In [1]:
# basic torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.optim as optim

# hyperparameter optimization rtd
import optuna
import wandb

# os related
import os

# file handling

# segmentation model
from transformers import ViTModel, ViTImageProcessor # modules for loading the vit model
from lora_vit import LoraVit
from segmentation_model import SegViT
from segmentation_head import CustomSegHead

# dataset class
from pet_dataset_class import PreprocessedPetDataset

# dataloaders
from create_dataloaders import get_pet_dataloaders

# trainer
from trainer import trainer

# loss and metrics
from loss_and_metrics_seg import * # idk what to import here tbh. Need to look into it

# data plotting
from data_plotting import plot_random_images_and_trimaps_2


In [2]:
# Hugging Face checkpoint identifier of the pre-trained ViT backbone
# (noted as ~86M parameters by the original author). Nothing is loaded here;
# the name is passed to ViTModel.from_pretrained inside objective().
model_name = "google/vit-base-patch16-224"

In [3]:
# Resolve the project root. `__file__` exists when this code runs as a script;
# in a notebook / interactive session it is usually undefined, so fall back to
# the current working directory.
base_dir = (
    os.path.dirname(os.path.abspath(__file__))
    if "__file__" in globals()
    else os.getcwd()
)

# The dataset lives in the 'data_oxford_iiit' folder directly under the project root.
data_dir = os.path.join(base_dir, 'data_oxford_iiit')

# Image and trimap (mask) folders, both located inside data_dir.
image_folder, trimap_folder = (
    os.path.join(data_dir, subfolder)
    for subfolder in ('resized_images', 'resized_masks')
)

In [4]:
# Authenticate with Weights & Biases WITHOUT embedding the secret in the notebook.
# The API key is read from the WANDB_API_KEY environment variable; when it is
# unset, key=None lets wandb fall back to its own credential lookup
# (e.g. a login previously stored on this machine).
# NOTE(security): an API key used to be hard-coded in this cell. It must be
# treated as leaked and revoked/rotated in the W&B account settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\NEELKANTH RAWAT\.netrc
[34m[1mwandb[0m: Currently logged in as: [33mneelkanth-rawat[0m ([33mnetwork-to-network[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
# Number of transformer encoder blocks in the backbone
# (per the original note: len(vit_pretrained.encoder.layer)).
NUM_BLOCKS = 12

# Deterministic candidate sets of block indices that receive LoRA adapters.
# For each depth (3, then 6) the list holds the first `depth` blocks followed
# by the last `depth` blocks, i.e. in order:
#   first 3, last 3, first 6, last 6
LORA_LAYER_OPTIONS = [
    block_indices
    for depth in (3, 6)
    for block_indices in (
        tuple(range(depth)),
        tuple(range(NUM_BLOCKS - depth, NUM_BLOCKS)),
    )
]

def objective(trial, optimize_for="loss", num_epoch=3):
    """Optuna objective: train one LoRA-ViT segmentation model and score it.

    Samples a hyperparameter configuration from ``trial``, builds the
    LoRA-wrapped ViT + segmentation head, trains it for ``num_epoch`` epochs
    while logging per-epoch metrics to a dedicated W&B run, and returns the
    last-epoch value of the selected validation metric.

    Args:
        trial: Optuna trial used to sample hyperparameters via ``suggest_*``.
        optimize_for: Which validation quantity to return: ``"loss"``,
            ``"dice"`` or ``"iou"``. The study direction must match
            (minimize for loss, maximize for dice / iou).
        num_epoch: Number of training epochs per trial (default 3, the value
            previously hard-coded in two places).

    Returns:
        The final-epoch value of the chosen validation metric, as stored on
        the trainer's per-epoch lists.

    Raises:
        ValueError: If ``optimize_for`` is not one of the supported names or
            ``num_epoch`` is smaller than 1. Both are checked before any
            sampling, W&B setup or training happens.
    """
    # Fail fast: validate the target metric BEFORE spending a full training run.
    # Maps the public option name to the trainer attribute holding its history.
    metric_history_attr = {
        "loss": "val_error_epoch_list",
        "dice": "val_dice_epoch_list",
        "iou": "val_iou_epoch_list",
    }
    if optimize_for not in metric_history_attr:
        raise ValueError(f"Unknown optimize_for='{optimize_for}'")
    if num_epoch < 1:
        # With zero epochs there would be no last-epoch metric to return.
        raise ValueError(f"num_epoch must be >= 1, got {num_epoch}")

    # Sample hyperparameters
    lora_rank = trial.suggest_categorical("lora_rank", [4, 8, 16])
    lora_alpha = trial.suggest_categorical("lora_alpha", [4, 8, 16, 32])

    lora_layers = trial.suggest_categorical("lora_layers", LORA_LAYER_OPTIONS)

    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-4, 1e-2, log=True)  # L2 decay
    optimizer_name = trial.suggest_categorical("optimizer", ["adamw", "adam"])

    batch_size = trial.suggest_categorical("batch_size", [8, 16, 24])
    dropout_rate = trial.suggest_categorical("dropout_rate", [0.0, 0.15])
    use_bn = trial.suggest_categorical("use_bn", [True, False])

    # W&B setup: one run per Optuna trial.
    wandb.init(
        project="Lora_vit_segmentation",
        config={
            "lora_rank": lora_rank,
            "lora_alpha": lora_alpha,
            "lora_layers": lora_layers,
            "lr": lr,
            "weight_decay": weight_decay,
            "optimizer": optimizer_name,
            "batch_size": batch_size,
            "dropout_rate": dropout_rate,
            "use_bn": use_bn,
            "num_epoch": num_epoch,
        },
        reinit='finish_previous'  # ensures a new W&B run is created for each Optuna trial
    )

    try:
        # Build the model with this trial's hyperparameters. A fresh pre-trained
        # backbone is loaded for every trial so trials do not share weights.
        vit_pretrained = ViTModel.from_pretrained(model_name)
        lora_vit_base = LoraVit(vit_model=vit_pretrained,
                                r=lora_rank, alpha=lora_alpha,
                                lora_layers=lora_layers)

        seg_head = CustomSegHead(hidden_dim=768, num_classes=3,
                                 patch_size=16, image_size=224,
                                 dropout_rate=dropout_rate, use_bn=use_bn)

        vit_seg_model = SegViT(vit_model=lora_vit_base,
                               image_size=224, patch_size=16,
                               dim=768, n_classes=3,
                               head=seg_head)

        # Create dataloaders with the sampled batch size (test loader is unused here).
        train_dl, val_dl, _ = get_pet_dataloaders(
            image_folder=image_folder,
            mask_folder=trimap_folder,
            DatasetClass=PreprocessedPetDataset,
            all_data=False,
            num_datapoints=25,
            val_ratio=0.2,
            test_ratio=0.1,
            batch_size=batch_size
        )

        # Setup optimizer
        optimizer_cls = torch.optim.AdamW if optimizer_name == "adamw" else torch.optim.Adam
        optimizer = optimizer_cls(vit_seg_model.parameters(), lr=lr, weight_decay=weight_decay)

        # Trainer
        trainer_input_params = {
            "model": vit_seg_model,
            "optimizer": optimizer,
            "lr": lr,
            "criterion": log_cosh_dice_loss,
            "num_epoch": num_epoch,
            "dataloaders": {"train": train_dl, "val": val_dl},
            "use_trap_scheduler": True,
            "device": "cpu",
            "criterion_kwargs": {"num_classes": 3, "epsilon": 1e-6}
        }
        trainer_seg_model = trainer(**trainer_input_params)

        # The epoch loop is driven from here (instead of trainer.train()) so that
        # metrics can be logged to W&B after every epoch.
        for epoch in range(trainer_seg_model.num_epoch):
            # train and validation step for each epoch
            avg_train_loss = trainer_seg_model.train_epoch(epoch)
            avg_val_loss, avg_val_dice, avg_val_iou = trainer_seg_model.val_epoch()

            # accumulate losses
            trainer_seg_model.train_error_epoch_list.append(avg_train_loss)
            trainer_seg_model.val_error_epoch_list.append(avg_val_loss)
            # accumulate metrics
            trainer_seg_model.val_dice_epoch_list.append(avg_val_dice)
            trainer_seg_model.val_iou_epoch_list.append(avg_val_iou)

            # log to W&B
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss,
                "val_dice": avg_val_dice,
                "val_iou": avg_val_iou,
                "lr": optimizer.param_groups[0]["lr"]
            })

        # Last-epoch value of the requested validation quantity.
        result = getattr(trainer_seg_model, metric_history_attr[optimize_for])[-1]
    except BaseException:
        # Close the run even when the trial crashes or is interrupted, and mark
        # it as failed so it is not mistaken for a completed trial.
        wandb.finish(exit_code=1)
        raise

    # finish W&B run
    wandb.finish()

    return result

In [6]:
# Study 1: search for the configuration with the lowest final validation loss.
N_TRIALS = 2  # number of Optuna trials to run in a study
study = optuna.create_study(direction="minimize")
study.optimize(
    lambda t: objective(t, optimize_for="loss"),
    n_trials=N_TRIALS,
)
# Show the best hyperparameters found
print(study.best_params)

[I 2025-08-19 15:55:59,799] A new study created in memory with name: no-name-3253dfb8-6e95-498e-b829-3556a665a47e


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[INFO] Using only 25 datapoints out of 7390 total files.
train_size: 18, val_size: 5 and test_size: 2


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▅█
lr,▁▁▁
train_loss,█▃▁
val_dice,▁▅█
val_iou,▁▅█
val_loss,█▄▁

0,1
epoch,3.0
lr,7e-05
train_loss,0.15986
val_dice,0.37296
val_iou,0.24268
val_loss,0.18492


[I 2025-08-19 15:59:57,244] Trial 0 finished with value: 0.18491728603839874 and parameters: {'lora_rank': 16, 'lora_alpha': 32, 'lora_layers': (0, 1, 2), 'lr': 7.148500193053507e-05, 'weight_decay': 0.004557645356079437, 'optimizer': 'adam', 'batch_size': 16, 'dropout_rate': 0.0, 'use_bn': False}. Best is trial 0 with value: 0.18491728603839874.


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[INFO] Using only 25 datapoints out of 7390 total files.
train_size: 18, val_size: 5 and test_size: 2


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▅█
lr,▁▁▁
train_loss,█▄▁
val_dice,▁▅█
val_iou,▁▅█
val_loss,█▄▁

0,1
epoch,3.0
lr,3e-05
train_loss,0.19302
val_dice,0.33663
val_iou,0.20878
val_loss,0.20556


[I 2025-08-19 16:03:16,471] Trial 1 finished with value: 0.20556245744228363 and parameters: {'lora_rank': 16, 'lora_alpha': 4, 'lora_layers': (9, 10, 11), 'lr': 3.127560309360304e-05, 'weight_decay': 0.0005006394791299488, 'optimizer': 'adamw', 'batch_size': 16, 'dropout_rate': 0.15, 'use_bn': False}. Best is trial 0 with value: 0.18491728603839874.


{'lora_rank': 16, 'lora_alpha': 32, 'lora_layers': (0, 1, 2), 'lr': 7.148500193053507e-05, 'weight_decay': 0.004557645356079437, 'optimizer': 'adam', 'batch_size': 16, 'dropout_rate': 0.0, 'use_bn': False}


In [7]:
# Print the best hyperparameters of whichever study is currently bound to `study`.
print(study.best_params)

{'lora_rank': 16, 'lora_alpha': 32, 'lora_layers': (0, 1, 2), 'lr': 7.148500193053507e-05, 'weight_decay': 0.004557645356079437, 'optimizer': 'adam', 'batch_size': 16, 'dropout_rate': 0.0, 'use_bn': False}


In [None]:
# Study 2: search for the configuration with the highest final validation Dice.
# Note: this rebinds `study`, replacing the previous study object.
study = optuna.create_study(direction="maximize")
study.optimize(
    lambda t: objective(t, optimize_for="dice"),
    n_trials=N_TRIALS,
)
# Show the best hyperparameters found
print(study.best_params)

In [None]:
# Study 3: search for the configuration with the highest final validation IoU.
# Note: this rebinds `study`, replacing the previous study object.
study = optuna.create_study(direction="maximize")
study.optimize(
    lambda t: objective(t, optimize_for="iou"),
    n_trials=N_TRIALS,
)
# Show the best hyperparameters found
print(study.best_params)