# Import Libraries

In [1]:
import os
import copy
from tqdm.notebook import tqdm
import pprint
from ipywidgets import interact
import ipywidgets as widgets
from IPython.display import display
import torch
from torch import nn
import os
import math
from src.models.utils import cal_l1_loss
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.core.lightning import LightningModule
from easydict import EasyDict as edict
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from src.constants import DATA_PATH, MASTER_THESIS_DIR, TRAINING_CONFIG_PATH
from src.data_loader.data_set import Data_Set
from src.data_loader.utils import get_train_val_split
from src.experiments.utils import get_experiement_args, process_experiment_args
from src.models.callbacks.upload_comet_logs import UploadCometLogs
from src.models.simclr_model import SimCLR
from src.utils import get_console_logger, read_json
from torchvision import transforms
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

# Read data

In [2]:
# Load the generic training configuration and the SimCLR model configuration
# from their JSON files; both are wrapped in EasyDict for attribute access.
simclr_config_path = os.path.join(
    MASTER_THESIS_DIR, "src", "experiments", "simclr_config.json"
)
train_param = edict(read_json(TRAINING_CONFIG_PATH))
model_param = edict(read_json(simclr_config_path))

# Switch on the resize, rotate and crop augmentation flags for this run.
for augmentation_name in ("resize", "rotate", "crop"):
    setattr(train_param.augmentation_flags, augmentation_name, True)

In [4]:
# Worker count stored on the training config; it is forwarded as `num_workers`
# to `get_train_val_split` when the data loaders are built.
train_param.num_workers=8

In [3]:
# Build the training split of the dataset in "supervised" mode. The only
# transform applied here is the conversion to a tensor.
to_tensor = transforms.Compose([transforms.ToTensor()])
data = Data_Set(
    config=train_param,
    transform=to_tensor,
    train_set=True,
    experiment_type="supervised",
)

# Extra fields added to the model config:
#   num_samples - dataset length, used by SupervisedHead.setup to derive steps per epoch
#   alpha       - weight of the z loss in the combined loss (loss_2d + alpha * loss_z)
model_param.num_samples = len(data)
model_param.alpha = 5
model_param.gpu = True

In [4]:
# Split the dataset into training and validation loaders.
# NOTE(review): the batch size is fixed to 512 here, while SupervisedHead scales
# its learning rate and steps-per-epoch from `config.batch_size` - confirm both agree.
loader_options = dict(batch_size=512, num_workers=train_param.num_workers)
train_data_loader, val_data_loader = get_train_val_split(data, **loader_options)

In [5]:
class SupervisedHead(LightningModule):
    """Regression head trained on top of a frozen, pre-trained SimCLR encoder.

    The encoder is restored from a checkpoint and all of its parameters are
    marked as non-trainable; only ``final_layers`` receives gradient updates.
    The network output is reshaped to ``(-1, 21, 3)`` and compared with
    ``batch["joints"]`` through ``cal_l1_loss``.

    Args:
        simclr_config: Configuration passed to the ``SimCLR`` constructor.
        saved_simclr_model_path: Path of the checkpoint whose ``"state_dict"``
            entry is loaded into the SimCLR model.
        config: Training configuration. The fields read by this class are
            ``alpha``, ``batch_size``, ``num_samples``, ``lr``,
            ``opt_weight_decay`` and ``warmup_epochs``.
    """

    def __init__(self, simclr_config: edict, saved_simclr_model_path: str, config: edict):
        super().__init__()
        self.config = config
        self.encoder = self.get_simclr_model(simclr_config, saved_simclr_model_path)
        # Maps the encoder features (512 inputs expected) to 21 * 3 regression outputs.
        self.final_layers = nn.Sequential(
            nn.Linear(512, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Linear(128, 21 * 3)
        )
        # Slots filled during training/validation so external code can read them.
        self.train_metrics_epoch = None
        self.train_metrics = None
        self.validation_metrics_epoch = None
        self.plot_params = None

    def get_simclr_model(self, simclr_config, saved_simclr_model_path):
        """Restore a SimCLR model from a checkpoint and return its frozen encoder.

        Args:
            simclr_config: Configuration passed to the ``SimCLR`` constructor.
            saved_simclr_model_path: Checkpoint file containing a ``"state_dict"`` entry.

        Returns:
            The ``encoder`` attribute of the restored model, with
            ``requires_grad`` set to ``False`` on every SimCLR parameter.
        """
        simclr_model = SimCLR(simclr_config)
        # Load onto the CPU so a checkpoint written on a GPU can also be restored
        # on a machine without one; the weights are copied into the freshly
        # built model, which is moved to the target device later.
        checkpoint = torch.load(saved_simclr_model_path, map_location="cpu")
        simclr_model.load_state_dict(checkpoint["state_dict"])
        for param in simclr_model.parameters():
            param.requires_grad = False
        # NOTE(review): only gradients are disabled. If the encoder contains
        # BatchNorm layers they presumably still update their running statistics
        # while the module is in train mode - confirm whether that is intended.
        return simclr_model.encoder

    def forward(self, x):
        """Encode ``x``, apply the regression head and reshape to ``(-1, 21, 3)``."""
        features = self.encoder(x)
        return self.final_layers(features).view(-1, 21, 3)

    def _compute_losses(self, x, y):
        """Run the model on ``x`` and return ``(prediction, loss, loss_2d, loss_z)``.

        The combined loss is ``loss_2d + config.alpha * loss_z``.
        """
        prediction = self(x)
        loss_2d, loss_z = cal_l1_loss(prediction, y)
        loss = loss_2d + self.config.alpha * loss_z
        return prediction, loss, loss_2d, loss_z

    @staticmethod
    def _mean_epoch_metrics(outputs) -> dict:
        """Average ``loss``, ``loss_z`` and ``loss_2d`` over the per-step outputs."""
        return {
            key: torch.stack([step_output[key] for step_output in outputs]).mean()
            for key in ("loss", "loss_z", "loss_2d")
        }

    def training_step(self, batch: dict, batch_idx: int):
        """Compute the training losses for one batch.

        Args:
            batch: Dictionary providing the ``"image"`` and ``"joints"`` entries.
            batch_idx: Index of the batch (unused).

        Returns:
            Dictionary with ``"loss"`` (attached to the autograd graph) and the
            detached components ``"loss_z"`` and ``"loss_2d"``.
        """
        x, y = batch["image"], batch["joints"]
        # NOTE(review): autocast is applied here without a GradScaler being
        # visible in this class, and validation runs without autocast - confirm.
        with torch.cuda.amp.autocast():
            prediction, loss, loss_2d, loss_z = self._compute_losses(x, y)
        # Only "loss" has to keep its graph for the backward pass. Everything
        # that is merely reported or cached is detached so the graph of this
        # step is not kept alive after the optimizer update.
        metrics = {"loss": loss, "loss_z": loss_z.detach(), "loss_2d": loss_2d.detach()}
        self.train_metrics = {name: value.detach() for name, value in metrics.items()}
        self.plot_params = {
            "prediction": prediction.detach(),
            "ground_truth": y,
            "input": x,
        }
        return metrics

    def training_epoch_end(self, outputs):
        """Store the epoch means of the training losses in ``train_metrics_epoch``."""
        self.train_metrics_epoch = self._mean_epoch_metrics(outputs)

    def exclude_from_wt_decay(
        self, named_params, weight_decay, skip_list=("bias", "bn")
    ):
        """Split trainable parameters into weight-decay and no-weight-decay groups.

        Args:
            named_params: Iterable of ``(name, parameter)`` pairs.
            weight_decay: Weight decay assigned to the regular group.
            skip_list: Substrings; a parameter whose name contains any of them
                is placed in the group with ``weight_decay=0.0``.

        Returns:
            Two optimizer parameter groups: regular parameters first, excluded
            parameters second. Parameters with ``requires_grad=False`` are
            left out of both groups.
        """
        # NOTE(review): matching is purely by name. The BatchNorm1d inside
        # `final_layers` is registered by index, so its weight name presumably
        # does not contain "bn" and ends up in the decayed group - verify.
        params = []
        excluded_params = []
        for name, param in named_params:
            if not param.requires_grad:
                continue
            if any(layer_name in name for layer_name in skip_list):
                excluded_params.append(param)
            else:
                params.append(param)

        return [
            {"params": params, "weight_decay": weight_decay},
            {"params": excluded_params, "weight_decay": 0.0},
        ]

    def setup(self, stage=None):
        """Derive the number of optimizer steps per epoch from the config.

        Assumes ``config.batch_size`` and ``config.num_samples`` describe the
        training loader that is actually passed to the trainer - TODO confirm.
        """
        global_batch_size = self.trainer.world_size * self.config.batch_size
        # At least one step per epoch; a value of 0 would produce a zero-length
        # learning-rate schedule in `configure_optimizers`.
        self.train_iters_per_epoch = max(
            1, self.config.num_samples // global_batch_size
        )

    def configure_optimizers(self):
        """Create the LARS-wrapped Adam optimizer and a per-step LR schedule.

        The learning rate is ``config.lr * sqrt(config.batch_size)``. The
        schedule is a linear warm-up followed by cosine annealing, stepped on
        every optimizer step, so its lengths are expressed in steps.

        Returns:
            ``([optimizer], [scheduler_config])``.
        """
        parameters = self.exclude_from_wt_decay(
            self.named_parameters(), weight_decay=self.config.opt_weight_decay
        )
        optimizer = LARSWrapper(
            torch.optim.Adam(
                parameters, lr=self.config.lr * math.sqrt(self.config.batch_size)
            )
        )
        # Keep the step counts local. Writing the converted value back into
        # `config.warmup_epochs` would multiply it again on every further call
        # (for example when the same config object is used for a second fit).
        warmup_steps = self.config.warmup_epochs * self.train_iters_per_epoch
        total_steps = self.trainer.max_epochs * self.train_iters_per_epoch

        linear_warmup_cosine_decay = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=warmup_steps,
            max_epochs=total_steps,
            warmup_start_lr=0,
            eta_min=0,
        )

        scheduler = {
            "scheduler": linear_warmup_cosine_decay,
            "interval": "step",
            "frequency": 1,
        }
        return [optimizer], [scheduler]

    def validation_step(self, batch: dict, batch_idx: int) -> dict:
        """Compute the validation losses for one batch.

        Args:
            batch: Dictionary providing the ``"image"`` and ``"joints"`` entries.
            batch_idx: Index of the batch (unused).

        Returns:
            Dictionary with ``"loss"``, ``"loss_z"`` and ``"loss_2d"``.
        """
        x, y = batch["image"], batch["joints"]
        prediction, loss, loss_2d, loss_z = self._compute_losses(x, y)
        self.plot_params = {"prediction": prediction, "ground_truth": y, "input": x}
        return {"loss": loss, "loss_z": loss_z, "loss_2d": loss_2d}

    def validation_epoch_end(self, outputs) -> None:
        """Store the epoch means of the validation losses in ``validation_metrics_epoch``."""
        self.validation_metrics_epoch = self._mean_epoch_metrics(outputs)

In [6]:
# Checkpoint of the pre-trained SimCLR model whose encoder is reused.
# NOTE(review): absolute, machine-specific path - consider deriving it from DATA_PATH.
simclr_checkpoint_path = "/local/home/adahiya/Documents/master_thesis/data/models/master-thesis/a7c43b88d9c34332bf3c86eb81f5db7d/checkpoints/epoch=99.ckpt"

# The same config object is used both to build the SimCLR model and to train the head.
model_ssl = SupervisedHead(
    simclr_config=model_param,
    saved_simclr_model_path=simclr_checkpoint_path,
    config=model_param,
)

In [10]:
# Comet logging is disabled for this run. To re-enable it, restore the two
# objects below and pass `logger=comet_logger` and
# `callbacks=[upload_comet_logs, lr_monitor]` to the Trainer.
#
# comet_logger = CometLogger(
#     api_key=os.environ.get("COMET_API_KEY"),
#     project_name="master-thesis",
#     workspace="dahiyaaneesh",
#     save_dir=os.path.join(DATA_PATH, "models"),
# )
# upload_comet_logs = UploadCometLogs(
#     "step", get_console_logger("callback"), "supervised_ssl"
# )

# Created for the callback list above; it is not attached to the Trainer here.
lr_monitor = LearningRateMonitor(logging_interval="step")

# Single-GPU trainer running for 150 epochs.
trainer = Trainer(max_epochs=150, gpus=1, amp_level="O2")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [7]:
# Display the module's repr (the layer tree of the encoder and the head).
model_ssl

SupervisedHead(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tr

In [10]:
# Build a fresh single-GPU trainer (150 epochs) and fit the supervised head
# with the training and validation loaders created earlier.
trainer = Trainer(max_epochs=150, gpus=1, amp_level="O2")
trainer.fit(model_ssl, train_data_loader, val_data_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type       | Params
--------------------------------------------
0 | encoder      | ResNet     | 11 M  
1 | final_layers | Sequential | 74 K  


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…






1

In [13]:
# Remove the notebook-level name for the model.
# NOTE(review): `trainer` was fitted with this model and may still hold a
# reference to it, so the memory is not necessarily released here - confirm.
del model_ssl

In [8]:
# Presumably drops the tensors cached by autocast during training - TODO confirm
# against the PyTorch documentation for the installed version.
torch.clear_autocast_cache()

In [9]:
# Ask the CUDA caching allocator to release cached memory it is not using;
# tensors that are still referenced are expected to be unaffected - verify in the PyTorch docs.
torch.cuda.empty_cache()