In [None]:
# basic torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.optim as optim

# hyperparameter optimization rtd
import optuna
import wandb

# os related
import os

# file handling

# segmentation model
from transformers import ViTModel, ViTImageProcessor # modules for loading the vit model
from lora_vit import LoraVit
from segmentation_model import SegViT
from serial_lora_vit import SerialLoraVit
from replora_vit import RepLoraVit
from localised_lora_vit import LocalizedLoraVit
from segmentation_head import CustomSegHead

# dataset class
from pet_dataset_class import PreprocessedPetDataset

# dataloaders
from create_dataloaders import get_pet_dataloaders

# trainer
from trainer import trainer

# loss and metrics
from loss_and_metrics_seg import * # idk what to import here tbh. Need to look into it

# data plotting
from data_plotting import plot_random_images_and_trimaps_2

#
from typing import Literal


In [2]:
# Hugging Face Hub id of the pre-trained ViT-Base checkpoint (~86M parameters, patch size 16,
# 224x224 input). The objective below loads it with ViTModel.from_pretrained for each trial.
model_name = "google/vit-base-patch16-224"

In [3]:
# Resolve the project root. When run as a script, use the directory containing this file;
# in Jupyter / interactive sessions `__file__` is not defined, so fall back to the cwd.
base_dir = (
    os.path.dirname(os.path.abspath(globals()["__file__"]))
    if "__file__" in globals()
    else os.getcwd()
)

# The Oxford-IIIT Pet data is expected in 'data_oxford_iiit' under the project root,
# with images in 'resized_images' and the matching trimap masks in 'resized_masks'.
data_dir = os.path.join(base_dir, "data_oxford_iiit")
image_folder, trimap_folder = (
    os.path.join(data_dir, sub_folder) for sub_folder in ("resized_images", "resized_masks")
)

In [4]:
# Authenticate without embedding a secret in the notebook. Called with no key,
# wandb.login() uses the WANDB_API_KEY environment variable or an existing netrc entry,
# and otherwise prompts interactively.
# NOTE(review): the API key previously hard-coded on this line is in the notebook's history
# and should be revoked/rotated in the W&B user settings.
wandb.login()

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\NEELKANTH RAWAT\.netrc
[34m[1mwandb[0m: Currently logged in as: [33mneelkanth-rawat[0m ([33mnetwork-to-network[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
NUM_BLOCKS = 12  # expected len(vit_pretrained.encoder.layer) for google/vit-base-patch16-224; not used by objective

# Accepted values for objective()'s `lora_type` and `optimize_for` arguments.
LORA_TYPES = ("lora", "serial_lora", "replora", "localised_lora")
OPTIMIZE_TARGETS = ("loss", "dice", "iou")


def _sample_hyperparameters(trial, lora_type, separate_lr, want_backbone_frozen_initially):
    """Draw this trial's searchable hyperparameters from Optuna.

    Suggestions are made in a fixed order: lora_rank, lora_alpha, learning rate(s),
    optimizer, batch_size, then the conditional localised-LoRA and freeze settings.
    Keys of the returned dict double as the W&B config keys. Keys that depend on the
    flags ("lr" vs. "lr_vit_backbone"/"lr_seg_head", "r_block"/"num_blocks_per_row",
    "freeze_epochs") are present only when relevant.
    """
    hp = {
        "lora_rank": trial.suggest_categorical("lora_rank", [4, 8, 16]),
        "lora_alpha": trial.suggest_categorical("lora_alpha", [4, 8, 16, 32]),
    }
    if separate_lr:
        hp["lr_vit_backbone"] = trial.suggest_float("lr_vit_backbone", 1e-5, 1e-3, log=True)
        hp["lr_seg_head"] = trial.suggest_float("lr_seg_head", 1e-4, 1e-3, log=True)
    else:
        hp["lr"] = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    hp["optimizer"] = trial.suggest_categorical("optimizer", ["adamw", "adam"])
    hp["batch_size"] = trial.suggest_categorical("batch_size", [24, 32])
    if lora_type == "localised_lora":
        hp["r_block"] = trial.suggest_categorical("r_block", [2, 4, 8, 16])
        # The Optuna parameter is named "num_blocks"; W&B logs it as "num_blocks_per_row".
        hp["num_blocks_per_row"] = trial.suggest_categorical("num_blocks", [2, 4, 8, 16])
    if want_backbone_frozen_initially:
        hp["freeze_epochs"] = trial.suggest_categorical("freeze_epochs", [2, 5, 10])
    return hp


def _build_lora_backbone(lora_type, hp):
    """Load the pre-trained ViT (global `model_name`) and wrap it with the requested LoRA variant."""
    vit_pretrained = ViTModel.from_pretrained(model_name)
    if lora_type == "lora":
        return LoraVit(vit_model=vit_pretrained, r=hp["lora_rank"], alpha=hp["lora_alpha"])
    if lora_type == "serial_lora":
        # SerialLoraVit is built without alpha; lora_alpha is still sampled and logged.
        return SerialLoraVit(vit_model=vit_pretrained, r=hp["lora_rank"])
    if lora_type == "replora":
        return RepLoraVit(vit_model=vit_pretrained, r=hp["lora_rank"], alpha=hp["lora_alpha"])
    if lora_type == "localised_lora":
        # TODO (author): r_block / num_blocks_per_row still need to be derived properly.
        return LocalizedLoraVit(vit_model=vit_pretrained,
                                r_block=hp["r_block"],
                                alpha=hp["lora_alpha"],
                                num_blocks_per_row=hp["num_blocks_per_row"])
    raise ValueError(f"Unknown lora_type='{lora_type}'; expected one of {LORA_TYPES}")


def _build_optimizer(model, hp, separate_lr, weight_decay):
    """Create Adam/AdamW with one LR for all parameters, or separate backbone/head param groups."""
    optimizer_cls = torch.optim.AdamW if hp["optimizer"] == "adamw" else torch.optim.Adam
    if separate_lr:
        # backbone_parameters / head_parameters are provided by SegViT; they are passed as-is.
        param_groups = [
            {"params": model.backbone_parameters, "lr": hp["lr_vit_backbone"]},
            {"params": model.head_parameters, "lr": hp["lr_seg_head"]},
        ]
        return optimizer_cls(param_groups, weight_decay=weight_decay)
    return optimizer_cls(model.parameters(), lr=hp["lr"], weight_decay=weight_decay)


def objective(
    trial,
    lora_type: Literal["lora", "serial_lora", "replora", "localised_lora"] = "lora",
    optimize_for: Literal["loss", "dice", "iou"] = "loss",
    separate_lr=False,
    want_backbone_frozen_initially=False,
    num_epoch: int = 50,
    num_datapoints: int = 25,
    device: str = "cpu",
):
    """Optuna objective: train one LoRA-ViT segmentation model and return a validation score.

    Samples hyperparameters from `trial`, wraps the pre-trained ViT with the chosen LoRA
    variant, adds a CustomSegHead, trains for `num_epoch` epochs and logs every epoch to a
    W&B run. It returns the last-epoch validation value selected by `optimize_for`.

    Args:
        trial: Optuna trial used to sample hyperparameters.
        lora_type: LoRA wrapper applied to the pre-trained ViT.
        optimize_for: Last-epoch validation quantity to return: "loss", "dice" or "iou".
            Pair "loss" with direction="minimize" and "dice"/"iou" with "maximize".
        separate_lr: If True, sample separate learning rates for the ViT backbone and the
            segmentation head (two optimizer param groups). Otherwise sample a single "lr".
        want_backbone_frozen_initially: If True, also sample `freeze_epochs` and pass both
            values to the trainer (presumably freezes the backbone for the first epochs).
        num_epoch: Training epochs per trial.
        num_datapoints: Subset size passed to get_pet_dataloaders (with all_data=False).
        device: Device string forwarded to the trainer.

    Returns:
        The last-epoch validation loss, Dice score or IoU.

    Raises:
        ValueError: For an unknown `lora_type` or `optimize_for`, or `num_epoch < 1`.
            This is raised before any model is built or any W&B run is started.
    """
    # Validate arguments first so a typo fails immediately, not after a full training run.
    if lora_type not in LORA_TYPES:
        raise ValueError(f"Unknown lora_type='{lora_type}'; expected one of {LORA_TYPES}")
    if optimize_for not in OPTIMIZE_TARGETS:
        raise ValueError(f"Unknown optimize_for='{optimize_for}'")
    if num_epoch < 1:
        raise ValueError(f"num_epoch must be >= 1 to produce a validation score, got {num_epoch}")

    hp = _sample_hyperparameters(trial, lora_type, separate_lr, want_backbone_frozen_initially)

    # Settings deliberately kept fixed (not searched). They are logged to W&B for reproducibility.
    weight_decay = 0.0005
    use_bn = True
    dropout_rate = 0.1
    # Always bound, because the trainer is always given freeze_epochs.
    # NOTE(review): assumes trainer ignores freeze_epochs when freezing is disabled — confirm.
    freeze_epochs = hp.get("freeze_epochs", 0)

    wandb_config = {
        "lora_type": lora_type,
        **hp,
        "separate_lr": separate_lr,
        "want_backbone_frozen_initially": want_backbone_frozen_initially,
        "weight_decay": weight_decay,
        "use_bn": use_bn,
        "dropout_rate": dropout_rate,
        "num_epoch": num_epoch,
        "num_datapoints": num_datapoints,
    }

    # Using the run as a context manager finishes it even when training raises,
    # so a crashed trial does not leave a dangling W&B run.
    with wandb.init(
        project="Lora_vit_segmentation",
        config=wandb_config,
        reinit="finish_previous",
    ) as run:
        lora_vit_base = _build_lora_backbone(lora_type, hp)
        seg_head = CustomSegHead(hidden_dim=768, num_classes=3,
                                 patch_size=16, image_size=224,
                                 dropout_rate=dropout_rate,
                                 use_bn=use_bn)
        vit_seg_model = SegViT(vit_model=lora_vit_base,
                               image_size=224, patch_size=16,
                               dim=768, n_classes=3,
                               head=seg_head)

        # Dataloaders depend on the sampled batch size; the test split is unused here.
        train_dl, val_dl, _ = get_pet_dataloaders(
            image_folder=image_folder,
            mask_folder=trimap_folder,
            DatasetClass=PreprocessedPetDataset,
            all_data=False,
            num_datapoints=num_datapoints,
            val_ratio=0.2,
            test_ratio=0.1,
            batch_size=hp["batch_size"],
        )

        optimizer = _build_optimizer(vit_seg_model, hp, separate_lr, weight_decay)

        trainer_seg_model = trainer(
            model=vit_seg_model,
            optimizer=optimizer,
            criterion=log_cosh_dice_loss,
            num_epoch=num_epoch,
            dataloaders={"train": train_dl, "val": val_dl},
            use_trap_scheduler=True,
            device=device,
            criterion_kwargs={"num_classes": 3, "epsilon": 1e-6},
            want_backbone_frozen_initially=want_backbone_frozen_initially,
            freeze_epochs=freeze_epochs,
        )

        # trainer.train() is not used. The epoch loop is driven here so each epoch
        # is logged to W&B as soon as it completes.
        for epoch in range(trainer_seg_model.num_epoch):
            avg_train_loss = trainer_seg_model.train_epoch(epoch)
            avg_val_loss, avg_val_dice, avg_val_iou = trainer_seg_model.val_epoch()

            # Record this epoch on the trainer's per-epoch history lists.
            trainer_seg_model.train_error_epoch_list.append(avg_train_loss)
            trainer_seg_model.val_error_epoch_list.append(avg_val_loss)
            trainer_seg_model.val_dice_epoch_list.append(avg_val_dice)
            trainer_seg_model.val_iou_epoch_list.append(avg_val_iou)

            run.log({
                "epoch": epoch + 1,
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss,
                "val_dice": avg_val_dice,
                "val_iou": avg_val_iou,
                # With separate_lr this is the backbone group's LR (param group 0).
                "lr": optimizer.param_groups[0]["lr"],
            })

    history_by_target = {
        "loss": trainer_seg_model.val_error_epoch_list,
        "dice": trainer_seg_model.val_dice_epoch_list,
        "iou": trainer_seg_model.val_iou_epoch_list,
    }
    return history_by_target[optimize_for][-1]

In [None]:
# Search for the configuration with the lowest final validation loss.
N_TRIALS = 2  # number of Optuna trials; also used by the Dice and IoU studies below


def _loss_objective(trial):
    return objective(trial, optimize_for="loss")


study = optuna.create_study(direction="minimize")
study.optimize(_loss_objective, n_trials=N_TRIALS)
print(study.best_params)

In [7]:
# NOTE(review): the saved output below predates the current search space. It lists
# lora_layers, weight_decay, dropout_rate, use_bn and batch_size=16, none of which
# objective() samples now. Re-run to refresh.
best_params = study.best_params
print(best_params)

{'lora_rank': 16, 'lora_alpha': 32, 'lora_layers': (0, 1, 2), 'lr': 7.148500193053507e-05, 'weight_decay': 0.004557645356079437, 'optimizer': 'adam', 'batch_size': 16, 'dropout_rate': 0.0, 'use_bn': False}


In [None]:
# Search for the configuration with the highest final validation Dice score.
def _dice_objective(trial):
    return objective(trial, optimize_for="dice")


study = optuna.create_study(direction="maximize")
study.optimize(_dice_objective, n_trials=N_TRIALS)
print(study.best_params)

In [None]:
# Search for the configuration with the highest final validation IoU.
def _iou_objective(trial):
    return objective(trial, optimize_for="iou")


study = optuna.create_study(direction="maximize")
study.optimize(_iou_objective, n_trials=N_TRIALS)
print(study.best_params)