In [1]:
# basic torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.optim as optim

# hyperparameter optimization rtd
import optuna
import wandb

# os related
import os

# file handling

# segmentation model
from localised_lora_vit import LocalizedLoraVit
from segmentation_model import SegViT
from segmentation_head import CustomSegHead

# dataset class
from pet_dataset_class import PreprocessedPetDataset

# dataloaders
from create_dataloaders import get_pet_dataloaders

# trainer
from trainer import trainer

# loss and metrics
from loss_and_metrics_seg import * # idk what to import here tbh. Need to look into it

# data plotting
from data_plotting import plot_random_images_and_trimaps_2

# modules for loading the vit model
from transformers import ViTModel, ViTImageProcessor


In [2]:
import math

In [3]:
# LoRA block configuration.
# Both derived values stay floats (as before) and are cast with int() where the
# backbone is built. int() truncates silently, so a configuration that does not
# yield whole numbers is rejected here instead of being quietly rounded down.
rank_A_and_B = 16
num_blocks = 4
if num_blocks < 1:
    raise ValueError(f"num_blocks must be a positive integer, got {num_blocks}")

# Blocks per row is taken as sqrt(num_blocks), i.e. num_blocks must be a perfect square.
num_blocks_per_row = math.sqrt(num_blocks)
if not num_blocks_per_row.is_integer():
    raise ValueError(f"num_blocks must be a perfect square, got {num_blocks}")

# Per-block rank derived from the overall rank and the block layout.
r_block = rank_A_and_B / (num_blocks * num_blocks_per_row)
if r_block < 1 or not r_block.is_integer():
    raise ValueError(
        f"rank_A_and_B={rank_A_and_B} does not give a whole per-block rank >= 1 "
        f"for num_blocks={num_blocks} (got {r_block})"
    )
print(f"r_block is: {r_block}")

r_block is: 2.0


In [4]:
# Pre-trained ViT-Base checkpoint (~86M parameters) and its image processor.
model_name = 'google/vit-base-patch16-224'
image_processor = ViTImageProcessor.from_pretrained(model_name)
VIT_PRETRAINED = ViTModel.from_pretrained(model_name)

# Wrap the pre-trained ViT in the localized-LoRA backbone.
# r_block and num_blocks_per_row are floats upstream, hence the int() casts.
lora_vit_base = LocalizedLoraVit(
    vit_model=VIT_PRETRAINED,
    r_block=int(r_block),
    alpha=32,
    num_blocks_per_row=int(num_blocks_per_row),
)

# Segmentation model built on the LoRA backbone. No CustomSegHead is created
# by hand: per the original author's note, SegViT sets up its own head (the
# manual equivalent would have been
# CustomSegHead(hidden_dim=768, num_classes=3, patch_size=16, image_size=224)).
vit_seg_model = SegViT(
    vit_model=lora_vit_base,
    image_size=224,
    patch_size=16,
    dim=768,
    n_classes=3,
)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Display the backbone's num_trainable_params attribute (presumably the count of trainable LoRA parameters — confirm in LocalizedLoraVit).
lora_vit_base.num_trainable_params

147456

In [6]:
from lora_vit import LoraVit

In [13]:
# Second, independent load of the same checkpoint (model_name) under new names,
# so image_processor / VIT_PRETRAINED from the earlier cell are not rebound.
image_processor2 = ViTImageProcessor.from_pretrained(model_name)
VIT_PRETRAINED2 = ViTModel.from_pretrained(model_name)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
# Display the module tree of the freshly loaded ViT to inspect its layer names and shapes.
VIT_PRETRAINED2

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTOutput(
          (d

In [5]:
# NOTE: loading the checkpoint emits this warning: "Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." Only the pooler weights are affected. TODO: confirm that the segmentation model does not rely on the pooler output.

In [6]:
# Resolve the project directory. Scripts expose __file__; notebooks and
# interactive sessions do not, so fall back to the current working directory.
try:
    base_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    base_dir = os.getcwd()

# Dataset layout: <base_dir>/data_oxford_iiit/{resized_images, resized_masks}
data_dir = os.path.join(base_dir, 'data_oxford_iiit')
image_folder, trimap_folder = (
    os.path.join(data_dir, sub_dir)
    for sub_dir in ('resized_images', 'resized_masks')
)

# Build the train/val/test dataloaders on a 40-sample subset. This is only a
# smoke test of the pipeline; overfitting on it is intentional at this stage.
train_dl, val_dl, test_dl = get_pet_dataloaders(
    DatasetClass=PreprocessedPetDataset,
    image_folder=image_folder,
    mask_folder=trimap_folder,
    batch_size=34,
    all_data=False,
    num_datapoints=40,
)

[INFO] Using only 40 datapoints out of 7390 total files.
train_size: 32, val_size: 4 and test_size: 4


In [7]:
# Keyword arguments for the trainer, collected in one place.
trainer_input_params = dict(
    model=vit_seg_model,
    # A single Adam optimizer over all model parameters with one learning rate.
    # The learning rate lives on the optimizer, so no separate "lr" entry is passed.
    # Alternative kept for reference (separate rates for backbone and head):
    #     torch.optim.Adam([
    #         {"params": vit_seg_model.backbone_parameters, "lr": 1e-5},
    #         {"params": vit_seg_model.head_parameters, "lr": 1e-4},
    #     ])
    optimizer=torch.optim.Adam(params=vit_seg_model.parameters(), lr=5e-4),
    criterion=log_cosh_dice_loss,
    num_epoch=8,
    dataloaders=dict(train=train_dl, val=val_dl),
    use_trap_scheduler=False,
    # Pinned to CPU for now; the device-aware choice would be
    # "cuda" if torch.cuda.is_available() else "cpu".
    device="cpu",
    # Settings stored under "criterion_kwargs" (presumably forwarded to the
    # criterion by the trainer — confirm); "return_metrics" is deliberately not set.
    criterion_kwargs=dict(num_classes=3, epsilon=1e-6),
    # Optional entries currently left out:
    # "model_kwargs", "want_backbone_frozen_initially", "freeze_epochs".
)


In [8]:
# instantiate the trainer:
# each entry of trainer_input_params is unpacked into a keyword argument of trainer(...)
trainer_seg_model = trainer(**trainer_input_params)

In [9]:
# Run training; the recorded output shows per-epoch train/val loss plus Dice and IoU scores.
trainer_seg_model.train()

Epochs:  12%|█▎        | 1/8 [01:37<11:25, 97.91s/it]

Epoch [1/8] - Train Loss: 0.2279 | Val Loss: 0.1580 | Dice score: 0.42288434505462646 |IOU score: 0.3040440 


Epochs:  25%|██▌       | 2/8 [03:38<11:08, 111.46s/it]

Epoch [2/8] - Train Loss: 0.1375 | Val Loss: 0.1140 | Dice score: 0.5133711695671082 |IOU score: 0.4178324 


Epochs:  38%|███▊      | 3/8 [05:28<09:13, 110.70s/it]

Epoch [3/8] - Train Loss: 0.1106 | Val Loss: 0.0993 | Dice score: 0.5468373894691467 |IOU score: 0.4645833 


Epochs:  50%|█████     | 4/8 [07:06<07:02, 105.62s/it]

Epoch [4/8] - Train Loss: 0.0941 | Val Loss: 0.0917 | Dice score: 0.5650730729103088 |IOU score: 0.4912463 


Epochs:  62%|██████▎   | 5/8 [08:48<05:13, 104.48s/it]

Epoch [5/8] - Train Loss: 0.0792 | Val Loss: 0.0882 | Dice score: 0.5736867785453796 |IOU score: 0.5042699 


Epochs:  75%|███████▌  | 6/8 [10:36<03:31, 105.69s/it]

Epoch [6/8] - Train Loss: 0.0692 | Val Loss: 0.0865 | Dice score: 0.5781173706054688 |IOU score: 0.5110564 


Epochs:  88%|████████▊ | 7/8 [12:22<01:45, 105.49s/it]

Epoch [7/8] - Train Loss: 0.0661 | Val Loss: 0.0850 | Dice score: 0.5817290544509888 |IOU score: 0.5164741 


Epochs: 100%|██████████| 8/8 [13:54<00:00, 104.36s/it]

Epoch [8/8] - Train Loss: 0.0632 | Val Loss: 0.0835 | Dice score: 0.5855134129524231 |IOU score: 0.5219638 





In [None]:
# Yeah, I had a feeling that we might get some error message here.