# Import Libraries

In [None]:
import os
import copy
from tqdm.notebook import tqdm
import pprint
from ipywidgets import interact
import ipywidgets as widgets
from IPython.display import display
import torch
from torch import nn
import os
import math
from src.models.utils import cal_l1_loss
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.core.lightning import LightningModule
from easydict import EasyDict as edict
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from src.constants import DATA_PATH, MASTER_THESIS_DIR, TRAINING_CONFIG_PATH
from src.data_loader.data_set import Data_Set
from src.data_loader.utils import get_train_val_split
from src.experiments.utils import get_experiement_args, process_experiment_args
from src.models.callbacks.upload_comet_logs import UploadCometLogs
from src.models.simclr_model import SimCLR
from src.utils import get_console_logger, read_json
from torchvision import transforms
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from src.models.supervised_head_model import SupervisedHead

# Load configuration and data

In [None]:
# Training config (data / augmentation / loader settings) and the SimCLR model config.
train_param = edict(read_json(TRAINING_CONFIG_PATH))
model_param = edict(
    read_json(os.path.join(MASTER_THESIS_DIR, "src", "experiments", "simclr_config.json"))
)
# Augmentation overrides for this run: resize and crop on, rotation off.
# EasyDict.update assigns each key through setattr, exactly like `flags.resize = True`.
train_param.augmentation_flags.update(resize=True, rotate=False, crop=True)

In [None]:
# Loader and schedule overrides for this run.
train_param.update(num_workers=8, epochs=150)
model_param.update(warmup_epochs=10)
# Show both configs as the cell output for a quick visual check.
train_param, model_param

In [None]:
# Training dataset in "supervised" mode; the only transform is tensor conversion.
data = Data_Set(
    config=train_param,
    transform=transforms.Compose([transforms.ToTensor()]),
    train_set=True,
    experiment_type="supervised",
)
# NOTE(review): num_samples is the length of the whole dataset, not of the
# train split produced by get_train_val_split in the next cell. Confirm which
# size the model expects.
model_param.update(num_samples=len(data), alpha=5, gpu=True)

In [None]:
# Build the train and validation loaders from `data`, sharing batch size and worker count.
train_data_loader, val_data_loader = get_train_val_split(
    data,
    batch_size=train_param.batch_size,
    num_workers=train_param.num_workers,
)

In [None]:
# Supervised head built from a saved checkpoint (presumably the SimCLR pretraining run).
# NOTE(review): the checkpoint path is machine-specific, and model_param is
# passed as both the first and the third argument. Confirm this against the
# SupervisedHead signature.
model_ssl = SupervisedHead(
    model_param,
    "/local/home/adahiya/Documents/master_thesis/data/models/master-thesis/14910753afb9499ab9229174d6377efb/checkpoints/epoch=999.ckpt",
    model_param,
)

In [None]:
# Sanity check: the warmup epochs as stored on the model's own config.
getattr(model_ssl.config, "warmup_epochs")

In [None]:
# Comet experiment logger; the API key is read from the environment.
comet_logger = CometLogger(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="master-thesis",
    workspace="dahiyaaneesh",
    save_dir=os.path.join(DATA_PATH, "models"),
)
# Callbacks: the project's Comet log uploader and per-step learning-rate logging.
upload_comet_logs = UploadCometLogs(
    "step",
    get_console_logger("callback"),
    "supervised_ssl",
)
lr_monitor = LearningRateMonitor(logging_interval="step")
# Single-GPU trainer with 16-bit precision.
trainer = Trainer(
    gpus=1,
    precision=16,
    max_epochs=train_param.epochs,
    logger=comet_logger,
    callbacks=[upload_comet_logs, lr_monitor],
)

In [None]:
# Fit on the training loader, validating with the validation loader.
trainer.fit(
    model_ssl,
    train_data_loader,
    val_data_loader,
)