In [1]:
# basic torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.optim as optim

# hyperparameter optimization rtd
import optuna
import wandb

# os related
import os

# file handling

# segmentation model
from lora_vit import LoraVit
from segmentation_model import SegViT
from segmentation_head import CustomSegHead

# dataset class
from pet_dataset_class import PreprocessedPetDataset

# dataloaders
from create_dataloaders import get_pet_dataloaders

# trainer
from trainer import trainer

# loss and metrics
from loss_and_metrics_seg import * # idk what to import here tbh. Need to look into it

# data plotting
from data_plotting import plot_random_images_and_trimaps_2

# modules for loading the vit model
from transformers import ViTModel, ViTImageProcessor


In [2]:
## load the pre-trained ViT-model (86 Mil)
# Seed torch's global RNG before any module is built. The checkpoint load
# reports newly initialised pooler weights, and the LoRA adapters / segmentation
# head are presumably randomly initialised as well, so without a seed a
# "Restart & Run All" would not reproduce the same starting weights.
SEED = 42
torch.manual_seed(SEED)

model_name = 'google/vit-base-patch16-224'

image_processor = ViTImageProcessor.from_pretrained(model_name)
VIT_PRETRAINED = ViTModel.from_pretrained(model_name)

# Read the backbone geometry from the checkpoint's own config instead of
# repeating 224 / 16 / 768 as literals, so that changing `model_name` cannot
# leave the segmentation model with mismatched sizes.
vit_config = VIT_PRETRAINED.config

# Number of output classes handed to SegViT.
# NOTE(review): presumably the three trimap labels of the pet dataset - confirm.
NUM_CLASSES = 3

# instantiating the lora vit based backbone model
# (LoRA rank r=4, scaling alpha=16, adapters on the layers listed in lora_layers)
lora_vit_base = LoraVit(vit_model=VIT_PRETRAINED,
                        r=4, alpha=16, lora_layers=[1, 2, 3, 4])

# No separate CustomSegHead is instantiated here: per the original author's
# note, SegViT creates its segmentation head automatically.

# instantiate the segmentation model
vit_seg_model = SegViT(vit_model=lora_vit_base,
                       image_size=vit_config.image_size,
                       patch_size=vit_config.patch_size,
                       dim=vit_config.hidden_size,
                       n_classes=NUM_CLASSES)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Inspect the backbone parameter handle exposed by the segmentation model.
# Per the recorded output this evaluates to a generator (Module.parameters),
# so the cell shows only the generator's repr, not the parameters themselves.
vit_seg_model.backbone_parameters # get backbone parameters

<generator object Module.parameters at 0x000001F59F1FB4C0>

In [4]:
# Inspect the segmentation-head parameter handle exposed by the model.
# Per the recorded output this evaluates to a generator (Module.parameters),
# so the cell shows only the generator's repr, not the parameters themselves.
vit_seg_model.head_parameters 

<generator object Module.parameters at 0x000001F59F1FB5A0>

In [5]:
# Locate the project root: the directory of this file when it runs as a
# script, otherwise (no __file__, e.g. inside a notebook kernel or an
# interactive session) the current working directory.
if "__file__" in globals():
    base_dir = os.path.dirname(os.path.abspath(__file__))
else:
    base_dir = os.getcwd()

# The dataset lives in 'data_oxford_iiit' under the project root; images and
# trimap masks sit in two sub-folders of it.
data_dir = os.path.join(base_dir, 'data_oxford_iiit')
image_folder, trimap_folder = (
    os.path.join(data_dir, sub_folder)
    for sub_folder in ('resized_images', 'resized_masks')
)

# Build the train / validation / test dataloaders.
# Only a small subset (80 datapoints) is used on purpose: this run is a smoke
# test of the pipeline and is intentionally allowed to overfit.
dataloader_options = dict(
    image_folder=image_folder,
    mask_folder=trimap_folder,
    DatasetClass=PreprocessedPetDataset,
    all_data=False,
    num_datapoints=80,
    batch_size=24,
)
train_dl, val_dl, test_dl = get_pet_dataloaders(**dataloader_options)

[INFO] Using only 80 datapoints out of 7390 total files.
train_size: 64, val_size: 8 and test_size: 8


In [7]:
# Keyword arguments that are later unpacked into trainer(**trainer_input_params).

# Adam with two parameter groups so the backbone and the segmentation head
# can be stepped with different learning rates (1e-5 vs 1e-4).
seg_optimizer = optim.Adam(
    [
        dict(params=vit_seg_model.backbone_parameters, lr=1e-5),
        dict(params=vit_seg_model.head_parameters, lr=1e-4),
    ]
)

trainer_input_params = dict(
    model=vit_seg_model,
    optimizer=seg_optimizer,
    # NOTE(review): the optimizer already carries per-group learning rates;
    # how the trainer uses this separate value is not visible here - confirm.
    lr=1e-4,
    criterion=log_cosh_dice_loss,
    num_epoch=10,
    dataloaders=dict(train=train_dl, val=val_dl),
    # False disables the scheduler; a scheduler instance may be passed instead.
    use_trap_scheduler=False,
    # Pinned to CPU for now; the alternative kept by the author was
    # "cuda" if torch.cuda.is_available() else "cpu".
    device="cpu",
    # A "model_kwargs" entry can be added here for extra forward() kwargs.
    # Extra arguments forwarded to the criterion ("return_metrics" is
    # deliberately not set here).
    criterion_kwargs=dict(num_classes=3, epsilon=1e-6),
)


In [8]:
# instantiate the trainer
# The configuration dict built in the previous cell is unpacked into keyword
# arguments of the project's `trainer`.
trainer_seg_model = trainer(**trainer_input_params) 
# Run training. Per the recorded output this shows an "Epochs" progress bar and
# prints train loss, validation loss, Dice score and IoU score once per epoch
# (roughly 2-3 minutes per epoch in the recorded run).
# TODO (original author's note): train() still needs a closer look.
trainer_seg_model.train() # need to look into this thing

Epochs:  10%|█         | 1/10 [02:59<26:54, 179.36s/it]

Epoch [1/10] - Train Loss: 0.2038 | Val Loss: 0.1671 | Dice score: 0.40576425194740295 |IOU score: 0.2783524 


Epochs:  20%|██        | 2/10 [05:17<20:38, 154.86s/it]

Epoch [2/10] - Train Loss: 0.1452 | Val Loss: 0.1326 | Dice score: 0.47354504466056824 |IOU score: 0.3533309 


Epochs:  30%|███       | 3/10 [07:12<15:57, 136.78s/it]

Epoch [3/10] - Train Loss: 0.1140 | Val Loss: 0.1115 | Dice score: 0.5189403891563416 |IOU score: 0.4039131 


Epochs:  40%|████      | 4/10 [09:04<12:41, 126.92s/it]

Epoch [4/10] - Train Loss: 0.0941 | Val Loss: 0.0958 | Dice score: 0.5551862120628357 |IOU score: 0.4378821 


Epochs:  50%|█████     | 5/10 [10:56<10:08, 121.76s/it]

Epoch [5/10] - Train Loss: 0.0791 | Val Loss: 0.0860 | Dice score: 0.5792495012283325 |IOU score: 0.4553901 


Epochs:  60%|██████    | 6/10 [12:47<07:51, 117.96s/it]

Epoch [6/10] - Train Loss: 0.0714 | Val Loss: 0.0810 | Dice score: 0.5920671224594116 |IOU score: 0.4649752 


Epochs:  70%|███████   | 7/10 [14:49<05:57, 119.28s/it]

Epoch [7/10] - Train Loss: 0.0642 | Val Loss: 0.0732 | Dice score: 0.6127362847328186 |IOU score: 0.4891889 


Epochs:  80%|████████  | 8/10 [16:56<04:03, 121.90s/it]

Epoch [8/10] - Train Loss: 0.0577 | Val Loss: 0.0685 | Dice score: 0.6257516741752625 |IOU score: 0.5063803 


Epochs:  90%|█████████ | 9/10 [19:11<02:05, 125.86s/it]

Epoch [9/10] - Train Loss: 0.0538 | Val Loss: 0.0665 | Dice score: 0.6313516497612 |IOU score: 0.5125546 


Epochs: 100%|██████████| 10/10 [21:16<00:00, 127.67s/it]

Epoch [10/10] - Train Loss: 0.0500 | Val Loss: 0.0660 | Dice score: 0.6327686905860901 |IOU score: 0.5116880 



