In [ ]:
from stable_baselines3.common.callbacks import EveryNTimesteps, BaseCallback

import gymnasium as gym
from stable_baselines3 import A2C, PPO, DQN
from feature_extraction.wrappers.feature_extraction_observation_wrapper import FeatureExtractionObservationWrapper
import time
import wandb

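# wandb API key: never hardcode it in the notebook — run `wandb login` or set the WANDB_API_KEY environment variable (and revoke any key that was previously committed here).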
# wandb.init(project="test-Block_based", 
#            config={"algorithm": "PPO", 
#                    "env": "ALE/Breakout-v5", 
#                    "feature_extractor": "Block-based", 
#                    "n_steps": 10_000})

class CustomCallback(BaseCallback):
    """Log the return of every episode that finished on the current step to W&B.

    The ``"episode"`` entry (holding the episode return under ``"r"``) is only present
    in an env's ``info`` dict on the step where that episode ends (it is added by SB3's
    ``Monitor`` wrapper). On all other steps the key is absent, so it is checked before
    being read instead of raising ``KeyError``.

    ``wandb.log`` requires an active run, so ``wandb.init`` must be called before this
    callback fires.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)

    def _on_step(self) -> bool:
        # ``infos`` holds one dict per (vectorized) env for the current step.
        for info in self.locals.get("infos", []):
            episode = info.get("episode")
            if episode is not None:
                wandb.log({"episode_reward": episode["r"]})
        # Returning True tells SB3 to continue training.
        return True
# NOTE(review): `freq_checkpoint` is not passed to `model.learn` below. Its child callback
# calls `wandb.log`, which needs an active run, and `wandb.init` above is commented out.
# Re-enable `wandb.init` before passing `callback=freq_checkpoint`.
freq_checkpoint = EveryNTimesteps(n_steps=500, callback=CustomCallback())

TOTAL_TIMESTEPS = 10_000

env_id = "ALE/Breakout-v5"  # single (non-vectorized) Gymnasium environment
env = gym.make(env_id)
env = FeatureExtractionObservationWrapper(env)
print("Supposed observation space: ", env.observation_space)

model = PPO("MlpPolicy", env, verbose=1)
start = time.perf_counter()  # monotonic clock, suited to measuring elapsed time
model.learn(total_timesteps=TOTAL_TIMESTEPS)
print(f"Time taken to train {TOTAL_TIMESTEPS:_} timesteps: ", time.perf_counter() - start)

model.save("stage1_test_Block")
env.close()
# 17 minutes to train 10_000 timesteps with PPO(MlpPolicy) and stage feature extractor

In [None]:
from torchinfo import torchinfo
import torch
import torch.nn as nn
from torchvision import models

class BlockFeatureExtractor(nn.Module):
    """Truncate a backbone after a stem, a number of stages and a number of blocks.

    ``model.children()`` is split into:
      * a stem: the first ``stem_stages`` children, and
      * the next ``stage_num`` children. The last of these is cut down to its first
        ``block_num`` sub-modules when it is an ``nn.Sequential``.

    The selected modules are shared with ``model`` (not copied).

    Args:
        model: backbone whose direct children are stem layers followed by stages
            (e.g. a torchvision ResNet).
        block_num: number of sub-modules kept from the last selected stage.
        stem_stages: number of leading children treated as the stem.
        stage_num: number of children after the stem to keep.
        verbose: print the selected layers while building.
    """

    def __init__(self, model, block_num=3, stem_stages=4, stage_num=4, verbose=False):
        super().__init__()
        self.block_num = block_num

        stages = list(model.children())
        # nn.ModuleList (not a plain list) registers the sub-modules, so their
        # parameters appear in .parameters() and .eval()/.to(device) reach them.
        self.stem = nn.ModuleList(stages[:stem_stages])

        custom_stages = stages[stem_stages:stem_stages + stage_num]
        if custom_stages and isinstance(custom_stages[-1], nn.Sequential):
            # Keep only the first `block_num` blocks of the last selected stage.
            custom_stages[-1] = custom_stages[-1][:self.block_num]
        self.custom_Stages = nn.ModuleList(custom_stages)

        if verbose:
            print("len stages: ", len(stages))
            print("len stem: ", len(self.stem))
            print("len custom stages", len(self.custom_Stages))
            print(self.custom_Stages)

    def forward(self, x):
        """Run ``x`` through the stem, then through each selected stage in order."""
        for layer in self.stem:
            x = layer(x)
        for layer in self.custom_Stages:
            x = layer(x)
        return x

if __name__ == "__main__":
    model = models.resnet50(weights="DEFAULT")
    # The extractor reuses this model's sub-modules, so eval() here puts its
    # BatchNorm layers in inference mode as well.
    model.eval()
    block_feature_extractor = BlockFeatureExtractor(model)
    rand_input = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        block_output = block_feature_extractor(rand_input)

    # Use the extractor's own block count so the label matches what was actually run.
    block_num = block_feature_extractor.block_num
    print("Shape of the output from block {}:".format(block_num), block_output.shape)


In [None]:
import torchvision.models as models
# `pretrained=True` is deprecated in torchvision; `weights=` matches the rest of this notebook.
torchinfo.summary(models.resnet50(weights="DEFAULT"))

In [None]:
import torch
import torch.nn as nn
from torchinfo import torchinfo
from torchvision import models
from torchvision.transforms import transforms

from feature_extraction.feature_extractors.feature_extractor import FeatureExtractor


class BlockFeatureExtractor(nn.Module, FeatureExtractor):
    """Frozen backbone truncated after a number of whole stages and/or blocks.

    The first four children of ``model`` are always kept as the stem (for a
    torchvision ResNet: conv1, bn1, relu, maxpool). After that, ``nn.Sequential``
    children (stages) are kept whole while fewer than ``num_stages`` have been taken,
    or while the remaining block budget (``num_blocks`` minus blocks already taken)
    covers the whole stage. Then, if budget is left, the leading blocks of the next
    stage are kept. ``num_blocks`` therefore counts blocks cumulatively across all
    kept stages.

    Args:
        model: backbone whose first four children form the stem and whose stages
            are ``nn.Sequential`` children (e.g. torchvision ResNet).
        num_blocks: cumulative block budget.
        num_stages: number of whole stages that are always kept.
    """

    def __init__(self, model, num_blocks=1, num_stages=0):
        nn.Module.__init__(self)
        self.num_blocks = num_blocks
        self.num_stages = num_stages

        self.Conv2d, self.BatchNorm2d, self.ReLU, self.MaxPool2d = list(model.children())[:4]
        # Plain list of four slots (a stage, a partial stage, or None). The modules
        # themselves are registered via the stage1..stage4 attributes below.
        self.stages = self._get_block_features(model)
        self.stage1, self.stage2, self.stage3, self.stage4 = self.stages
        self.adaptive_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Built once here instead of on every process_image() call.
        self._image_processor = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Freeze (and switch to eval mode) BEFORE the dummy pass in output_dim().
        # In training mode that pass would update the pretrained BatchNorm running
        # statistics with statistics of random noise.
        self.freeze_params()
        # NOTE: this instance attribute shadows the output_dim() method. After
        # construction, `self.output_dim` holds the feature shape (batch dim included).
        self.output_dim = self.output_dim()

    def freeze_params(self):
        """Disable gradients for every parameter and switch the module to eval mode.

        Eval mode makes BatchNorm use its stored running statistics rather than the
        current batch's, and stops those statistics from being updated.
        """
        for param in self.parameters():
            param.requires_grad = False
        self.eval()

    def output_dim(self):
        """Return the shape of ``extract_features`` for one random 1x3x224x224 input.

        The shape includes the batch dimension, i.e. ``torch.Size([1, C])``.
        """
        dummy_input = torch.rand(1, 3, 224, 224)
        output = self.extract_features(dummy_input)
        return output.shape

    def process_image(self, image):
        """Resize ``image`` to 224x224 and normalize it with the ImageNet mean/std.

        NOTE(review): assumes a float tensor shaped (..., 3, H, W) with values in
        [0, 1] — confirm against what the observation wrapper passes in.
        """
        return self._image_processor(image)

    def _get_block_features(self, fe_model):
        """Select the stages/blocks of ``fe_model`` to keep.

        Returns:
            A list of four slots filled from index 0. Each slot holds a whole stage,
            an ``nn.Sequential`` of the leading blocks of a partially kept stage, or
            ``None`` if unused.
        """
        sequentials = [None] * 4
        blocks = []
        current_stages = 0
        total_blocks = 0

        # Walk the backbone's children in order; only nn.Sequential children are stages.
        for module in fe_model.children():
            if not self.num_blocks and not self.num_stages:
                break
            if not isinstance(module, nn.Sequential):
                continue
            if current_stages >= len(sequentials):
                # All slots are filled; there is nowhere to store another stage.
                break
            # Keep the whole stage if a whole-stage slot is still open, or if the
            # remaining block budget covers every block in it.
            if current_stages < self.num_stages or len(module) <= self.num_blocks - total_blocks:
                sequentials[current_stages] = module
                current_stages += 1
                total_blocks += len(module)
                print(f"Adding a whole stage, stage-{current_stages}, totalblocks: {total_blocks}")
                continue

            # Otherwise spend what is left of the block budget on this stage's leading blocks.
            if self.num_blocks:
                for idx, block in enumerate(module):
                    if total_blocks < self.num_blocks:
                        blocks.append(block)
                        total_blocks += 1
                        print(
                            f"Adding block-{idx + 1} from stage-{current_stages + 1}, totalblocks: {total_blocks}")
                if blocks:
                    sequentials[current_stages] = nn.Sequential(*blocks)
            break
        return sequentials

    def forward(self, x):
        """Run the stem and every selected (whole or partial) stage without autograd."""
        with torch.no_grad():
            x = self.Conv2d(x)
            x = self.BatchNorm2d(x)
            x = self.ReLU(x)
            x = self.MaxPool2d(x)
            # Apply every selected slot, including stages added through the block
            # budget and a partially kept last stage. Slicing to num_stages here
            # would silently drop those.
            for stage in self.stages:
                if stage is not None:
                    x = stage(x)
        return x

    def reduce_dim(self, features):
        """Global-average-pool ``features`` to one value per channel and flatten to (N, C)."""
        reduced_features = self.adaptive_avg_pool(features)
        return reduced_features.view(features.size(0), -1)

    def extract_features(self, image):
        """Preprocess ``image``, run the truncated backbone, and pool to a flat vector."""
        processed_image = self.process_image(image)
        feature_embeddings = self.forward(processed_image)
        reduced_dim = self.reduce_dim(feature_embeddings)
        return reduced_dim




if __name__ == "__main__":
    torch.manual_seed(0)  # reproducible random input
    model = models.resnet50(weights='DEFAULT')
    # The extractor and the reference model share these modules. eval() keeps every
    # forward pass below (and the extractor's own dummy pass) from updating the
    # pretrained BatchNorm running statistics.
    model.eval()
    rand_input = torch.rand(1, 3, 224, 224)

    # Reference: stem + first two stages (drops layer3, layer4, avgpool, fc).
    custom_model = nn.Sequential(*list(model.children())[:-4])
    feature_extractor = BlockFeatureExtractor(model, num_stages=2, num_blocks=7)

    output1 = feature_extractor.extract_features(rand_input)
    with torch.no_grad():
        output2 = feature_extractor.reduce_dim(custom_model(feature_extractor.process_image(rand_input)))

    print(output1.shape, output2.shape)
    print(feature_extractor.output_dim)
    torchinfo.summary(feature_extractor)

    # Check if the outputs are the same
    if torch.allclose(output1, output2):
        print("Outputs of the models are the same.")
    else:
        print("Outputs of the models are different.")


In [None]:
    model = models.resnet50(weights='DEFAULT')

    # Stem (conv1, bn1, relu, maxpool) followed by the first residual stage, layer1.
    last_model = nn.Sequential(*list(model.children())[:-5])
    # Truncate layer1 to its first block only ([:1] keeps a single block).
    last_model[-1] = nn.Sequential(*list(last_model[-1].children())[:1])
    torchinfo.summary(last_model)

In [None]:
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecVideoRecorder, DummyVecEnv
from stable_baselines3 import PPO
from wandb.integration.sb3 import WandbCallback
import wandb

config = {
    "total_timesteps": 10_000,
    "num_envs": 4,
}
run = wandb.init(
    project="sb3",
    config=config,
    sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
    monitor_gym=True,  # auto-upload the videos of agents playing the game
    save_code=True,  # optional
)

env = make_atari_env('BreakoutNoFrameskip-v4', n_envs=config["num_envs"], seed=0)
env = VecFrameStack(env, n_stack=4)
# Output directories are keyed by the W&B run id so successive runs don't overwrite each other.
env = VecVideoRecorder(env, f"videos/{run.id}",
    record_video_trigger=lambda x: x % 1000 == 0, video_length=200)  # record videos
model = PPO(
    "CnnPolicy",
    env,
    n_steps=128,
    n_epochs=4,
    learning_rate=lambda progression: 2.5e-4 * progression,
    ent_coef=0.01,
    clip_range=lambda progression: 0.1 * progression,
    batch_size=256,
    verbose=1,
    tensorboard_log=f"runs/{run.id}",
)
try:
    model.learn(
        # Read from config so the logged hyperparameters match what actually ran.
        total_timesteps=config["total_timesteps"],
        callback=WandbCallback(
            gradient_save_freq=500,
            model_save_freq=2000,
            model_save_path=f"models/{run.id}",
        ),
    )
finally:
    # Close the video recorder so a partially recorded clip is written out, then
    # mark the W&B run as finished, even if training was interrupted.
    env.close()
    run.finish()