In [1]:
# basic torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.optim as optim

# hyperparameter optimization and experiment tracking
import optuna
import wandb

# os related
import os

# file handling

# segmentation model
from replora_vit import RepLoraVit
from segmentation_model import SegViT
from segmentation_head import CustomSegHead

# dataset class
from pet_dataset_class import PreprocessedPetDataset

# dataloaders
from create_dataloaders import get_pet_dataloaders

# trainer
from trainer import trainer

# loss and metrics
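from loss_and_metrics_seg import * # TODO: replace the wildcard with explicit names; log_cosh_dice_loss (used as the training criterion below) is the only one this notebook visibly needs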

# data plotting
from data_plotting import plot_random_images_and_trimaps_2

# modules for loading the vit model
from transformers import ViTModel, ViTImageProcessor


In [2]:
## Load the pre-trained ViT-Base checkpoint (~86M parameters) and wrap it for segmentation.
model_name = 'google/vit-base-patch16-224'

image_processor = ViTImageProcessor.from_pretrained(model_name)

# NOTE(review): this call warns that 'pooler.dense.weight' / 'pooler.dense.bias'
# are newly initialized, because the checkpoint does not contain them. If the
# backbone only consumes the token embeddings, ViTModel's `add_pooling_layer=False`
# option would avoid creating the pooler at all -- confirm against RepLoraVit
# before changing it.
VIT_PRETRAINED = ViTModel.from_pretrained(model_name)

# Read the model geometry from the checkpoint's own config instead of repeating
# 224 / 16 / 768 by hand, so the values cannot drift if `model_name` is changed.
vit_config = VIT_PRETRAINED.config

# LoRA-based backbone built on top of the pre-trained ViT.
# (r / alpha are presumably the LoRA rank and scaling factor -- see RepLoraVit.)
lora_vit_base = RepLoraVit(
    vit_model=VIT_PRETRAINED,
    r=4,
    alpha=16,
    lora_layers=[1, 2, 3, 4],
)

# No separate CustomSegHead(hidden_dim=768, num_classes=3, patch_size=16, image_size=224)
# is instantiated here: per the original author's note, SegViT creates its
# segmentation head automatically.
vit_seg_model = SegViT(
    vit_model=lora_vit_base,
    image_size=vit_config.image_size,
    patch_size=vit_config.patch_size,
    dim=vit_config.hidden_size,
    n_classes=3,
)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
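# NOTE: loading the checkpoint above emits this warning: "Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
# Only the pooler weights are affected. TODO: confirm that the segmentation model never reads the pooler output; if it does not, the warning can be ignored.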

In [4]:
# Resolve the project root. A script exposes __file__; a Jupyter notebook or
# interactive session does not, so fall back to the current working directory.
if "__file__" in globals():
    base_dir = os.path.dirname(os.path.abspath(__file__))
else:
    base_dir = os.getcwd()

# The dataset lives in 'data_oxford_iiit' inside the project root, with the
# images and trimap masks in two sibling sub-folders.
data_dir = os.path.join(base_dir, 'data_oxford_iiit')
image_folder, trimap_folder = (
    os.path.join(data_dir, subfolder)
    for subfolder in ('resized_images', 'resized_masks')
)

# Build the train / validation / test dataloaders.
# Only 80 datapoints are used for now: this is a smoke test to check that the
# pipeline runs end to end, and the model is intentionally overtrained on it.
train_dl, val_dl, test_dl = get_pet_dataloaders(
    DatasetClass=PreprocessedPetDataset,
    image_folder=image_folder,
    mask_folder=trimap_folder,
    batch_size=32,
    all_data=False,
    num_datapoints=80,
)

[INFO] Using only 80 datapoints out of 7390 total files.
train_size: 64, val_size: 8 and test_size: 8


In [5]:
# Keyword arguments for the trainer. The dictionary is unpacked with ** when the
# trainer is instantiated, so every key must match a trainer parameter name.
trainer_input_params = dict(
    model=vit_seg_model,
    # One learning rate for all parameters. Alternative kept for reference:
    # Adam with separate parameter groups, e.g. vit_seg_model.backbone_parameters
    # at lr=1e-5 and vit_seg_model.head_parameters at lr=1e-4.
    # (A separate "lr" entry is probably unnecessary, since the optimizer
    # already carries the learning rate.)
    optimizer=optim.Adam(vit_seg_model.parameters(), lr=5e-5),
    criterion=log_cosh_dice_loss,
    num_epoch=8,
    dataloaders=dict(train=train_dl, val=val_dl),
    use_trap_scheduler=False,
    # Pinned to CPU for now; the GPU-aware alternative is
    # "cuda" if torch.cuda.is_available() else "cpu".
    device="cpu",
    # Extra keyword arguments forwarded to the criterion. "return_metrics" is
    # deliberately not set here. Extra forward() kwargs could be supplied through
    # a "model_kwargs" entry if ever needed.
    criterion_kwargs=dict(num_classes=3, epsilon=1e-6),
    # Currently disabled options: "want_backbone_frozen_initially": False and
    # "freeze_epochs": 2.
)


In [6]:
# Instantiate the trainer, unpacking the parameter dictionary defined above as
# keyword arguments (each key must therefore match a trainer parameter name).
trainer_seg_model = trainer(**trainer_input_params)

In [None]:
# Start training with the trainer instance built in the previous cell.
# NOTE(review): what train() returns or logs is defined in the project's
# `trainer` module and is not visible from this notebook.
trainer_seg_model.train()

In [None]:
# Yeah, I had a feeling that we might get some error message here.
# TODO: record the error raised by the training cell above and how it gets fixed.