In [1]:
import numpy as np 
from PIL import Image
import requests
import torch

In [2]:
from transformers import ViTModel, ViTImageProcessor

In [3]:
## Load the pre-trained ViT-Base checkpoint (~86M parameters, see the count printed further down)
model_name = 'google/vit-base-patch16-224'

# The image processor and the bare ViT encoder are both loaded from the same checkpoint name.
# NOTE(review): the load warning printed below reports that `pooler.dense.*` is newly
# initialised, i.e. the pooler weights do not come from this checkpoint.
image_processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTModel.from_pretrained(model_name)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Count every parameter element that still has requires_grad=True.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params:,}")
# ~86 million parameters (exact figure is in the output below)

Trainable parameters: 86,389,248


In [5]:
# Select the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: the model is NOT moved to `device` here; the two calls below are left disabled.
# model.to(device)
# model.eval()

In [9]:
# Input width of the query projection in the first encoder block (displayed as the cell output).
model.encoder.layer[0].attention.attention.query.in_features


768

In [10]:
# Inspect only the first encoder block (hence the `break`) to confirm where the
# query/value projections live and what their Linear shapes are.
for layer_i, block_i in enumerate(model.encoder.layer):
    print("block_i.attention.attention.query:",block_i.attention.attention.query)
    print("block_i.attention.attention.value:",block_i.attention.attention.value)
    break


block_i.attention.attention.query: Linear(in_features=768, out_features=768, bias=True)
block_i.attention.attention.value: Linear(in_features=768, out_features=768, bias=True)


This is the part of the ViT architecture that matters for LoRA (the attention projections inside each encoder block):
```
model.encoder.layer[0]
└── attention
    └── attention
        ├── query: nn.Linear
        ├── key:   nn.Linear
        └── value: nn.Linear
```

# Things to do:


## Stage 0. Define the overall framework of the project and decide exactly what we are going to do

## Stage 1. Set up a Transformer based segmentation model using ViT+LoRA
0. Understand and play around with the models
1. Load the pre-trained model
    1.1 If required, we might have to switch to `timm`
2. Define the class for LoRA, using one of:
    2.1 The `peft` library, or
    2.2 A custom LoRA module built in PyTorch.
3. Apply LoRA to attention layers
4. Define (code) and add a segmentation head
    1. Simple MLP or some more complicated architecture? We need to look into it.
5. Training Set-Up
    5.1. Loss for the segmentation task
    5.2 optimizer
6. Training:
    6.1 typical LoRA rank (r): 4 or 8 -  a good balance for fine tuning
    6.2 How many parameters to freeze

Obtain some simple acceptable results for this. 

## Stage 2. Try one more LoRA variant for the same task:
0. Serial LoRA for the ViT (recent paper)
1. Other variants: DoRA, etc. (check the review paper for other versions)

## Stage 3: Try one other fine-tuning strategy, e.g. another adapter-based approach, IA3, etc.




# LoRA implementation

In [26]:
import torch.nn as nn
import torch
from torch import Tensor
import math
from safetensors.torch import save_file, load_file

In [27]:
### implement the LoRAlayer
class LoraLayer(nn.Module):
    """
    A projection layer augmented with a low-rank (LoRA) correction.

    The output is ``wt(x) + (alpha / rank_lora) * B(A(x))``.

    wt: the original projection module (this class does not freeze it;
        freezing is left to the caller)
    A, B: the two low-rank modules whose composition B(A(x)) forms delta W
    rank_lora: rank of the A and B matrices
    alpha: scaling numerator; the low-rank term is multiplied by alpha / rank_lora
    """
    def __init__(self, wt: nn.Module, A: nn.Module, B: nn.Module, rank_lora: int, alpha: int):
        super().__init__()
        # Sub-modules are registered in the order wt, A, B.
        self.wt = wt
        self.A = A
        self.B = B
        self.rank = rank_lora
        self.alpha = alpha

    def forward(self, x):
        """Return the base projection of ``x`` plus the scaled low-rank update."""
        scale = self.alpha / self.rank
        low_rank_update = self.B(self.A(x))
        return self.wt(x) + scale * low_rank_update

In [None]:
## Implement the LoRA ViT
### Attribute mapping used when porting to the Hugging Face ViT:
##   vit_model.transformer.blocks  -->  vit_model.encoder.layer
##   block.attn.proj_q / proj_v    -->  block.attention.attention.query / value
##   vit_model.parameters()        -->  unchanged

class LoraVit(nn.Module):
    """
    Wrap a pre-trained ViT and inject LoRA adapters into its attention layers.

    All parameters of ``vit_model`` are frozen. In every selected encoder block
    the ``attention.attention.query`` and ``attention.attention.value``
    projections are replaced by a ``LoraLayer`` that adds a trainable low-rank
    update on top of the frozen projection.

    vit_model: pre-trained ViT exposing ``encoder.layer[i].attention.attention.{query,value}``
               (the layout inspected earlier in this notebook); it is modified in place
    r: rank of the LoRA matrices (must be > 0)
    alpha: scaling strength for LoRA (must be > 0); the update is scaled by alpha / r
    lora_layers: indices of the encoder blocks to adapt; ``None`` (or an empty
                 sequence) means "apply LoRA to every block"
    """
    def __init__(self, vit_model, r:int, alpha:int, lora_layers=None):
        super().__init__()

        assert r>0, "r (rank of lora matrices) must be >0"
        assert alpha>0 , "alpha >0"

        if lora_layers:
            # Normalise to a plain list so the equality check in
            # load_lora_params also works when a tuple/range was passed in.
            self.lora_layers = list(lora_layers)
        else:  # apply LoRA to all encoder blocks
            self.lora_layers = list(range(len(vit_model.encoder.layer)))

        # Freeze the backbone. The adapters are created afterwards, so they
        # keep the default requires_grad=True and are the only trainable part.
        for param in vit_model.parameters():
            param.requires_grad = False

        # Plain Python lists on purpose: the adapters get registered as
        # sub-modules through the LoraLayer objects placed inside vit_model,
        # so these lists are only handles for init / save / load.
        self.list_q_As, self.list_q_Bs = [], []
        self.list_v_As, self.list_v_Bs = [], []

        # Replace the plain query / value projections with LoRA layers.
        for layer_i, block_i in enumerate(vit_model.encoder.layer):
            if layer_i not in self.lora_layers:
                continue
            self_attn = block_i.attention.attention
            w_q_linear = self_attn.query
            w_v_linear = self_attn.value

            # Create each adapter on the same device / dtype as the projection it wraps.
            q_kwargs = dict(device=w_q_linear.weight.device, dtype=w_q_linear.weight.dtype)
            v_kwargs = dict(device=w_v_linear.weight.device, dtype=w_v_linear.weight.dtype)
            a_linear_q = nn.Linear(w_q_linear.in_features, r, bias=False, **q_kwargs)
            b_linear_q = nn.Linear(r, w_q_linear.out_features, bias=False, **q_kwargs)
            a_linear_v = nn.Linear(w_v_linear.in_features, r, bias=False, **v_kwargs)
            b_linear_v = nn.Linear(r, w_v_linear.out_features, bias=False, **v_kwargs)

            self.list_q_As.append(a_linear_q)
            self.list_q_Bs.append(b_linear_q)
            self.list_v_As.append(a_linear_v)
            self.list_v_Bs.append(b_linear_v)

            # Assign to the attributes the attention module actually holds
            # (`query` / `value`). Writing to `block_i.attn.proj_q` fails on
            # this model and would leave the adapters unused.
            self_attn.query = LoraLayer(w_q_linear, a_linear_q, b_linear_q, r, alpha)
            self_attn.value = LoraLayer(w_v_linear, a_linear_v, b_linear_v, r, alpha)

        self.init_lora_layers()  # initialise the LoRA parameters
        self.vit_lora = vit_model

    def init_lora_layers(self) -> None:
        """
        Initialise the LoRA matrices.

        A is initialised with Kaiming-uniform (a=sqrt(5)); B is set to zero,
        so B(A(x)) == 0 and the wrapped model initially behaves like the
        original one.
        """
        for A in self.list_q_As + self.list_v_As:
            nn.init.kaiming_uniform_(A.weight, a=math.sqrt(5))
            # Alternative (normal initialisation): nn.init.normal_(A.weight, std=1e-3)
        for B in self.list_q_Bs + self.list_v_Bs:
            nn.init.zeros_(B.weight)

    def save_lora_params(self, filename:str):
        """
        Save the LoRA weights and the adapted layer indices to a .safetensors file.

        Stored keys: ``lora_layers`` (int32 tensor of block indices) and, for the
        i-th adapted block, ``q_A_{i}``, ``q_B_{i}``, ``v_A_{i}``, ``v_B_{i}``.
        """
        assert filename.endswith(".safetensors"), "File name is required to have .safetensors extensions"

        # safetensors expects a flat dict: keys = str, values = tensors
        state_dict = {}
        state_dict['lora_layers'] = torch.tensor(self.lora_layers, dtype=torch.int32)

        for i, (a_q, b_q, a_v, b_v) in enumerate(zip(self.list_q_As, self.list_q_Bs,self.list_v_As, self.list_v_Bs)):
            state_dict[f'q_A_{i}'] = a_q.weight.detach().cpu().contiguous()
            state_dict[f'q_B_{i}'] = b_q.weight.detach().cpu().contiguous()
            state_dict[f'v_A_{i}'] = a_v.weight.detach().cpu().contiguous()
            state_dict[f'v_B_{i}'] = b_v.weight.detach().cpu().contiguous()

        save_file(state_dict, filename)
        print(f"Saved LoRA params and layers to {filename}")

    def load_lora_params(self, filename:str):
        """
        Load LoRA weights previously written by ``save_lora_params``.

        Raises:
            ValueError: if the layer indices stored in the file differ from the
                ones this model was built with. The weights are stored by
                position, so loading them anyway would silently put them into
                the wrong encoder blocks.
        """
        assert filename.endswith(".safetensors"), "File name is required to have .safetensors extensions"
        loaded = load_file(filename)
        loaded_layers = loaded['lora_layers'].tolist()

        if loaded_layers != self.lora_layers:
            raise ValueError(
                f"LoRA layers in {filename} ({loaded_layers}) do not match the layers of "
                f"this model ({self.lora_layers}); rebuild LoraVit with lora_layers={loaded_layers}"
            )

        # Copy the stored weights into the existing adapters without tracking gradients.
        with torch.no_grad():
            for i, (a_q, b_q, a_v, b_v) in enumerate(zip(self.list_q_As, self.list_q_Bs, self.list_v_As, self.list_v_Bs)):
                a_q.weight.copy_(loaded[f'q_A_{i}'])
                b_q.weight.copy_(loaded[f'q_B_{i}'])
                a_v.weight.copy_(loaded[f'v_A_{i}'])
                b_v.weight.copy_(loaded[f'v_B_{i}'])

        print(f"Loaded LoRA params and layers from {filename}")

    def forward(self, x:Tensor):
        """
        Run the LoRA-adapted ViT on ``x`` and return the wrapped model's output unchanged.

        NOTE(review): for the Hugging Face ``ViTModel`` loaded in this notebook
        this is expected to be a model-output object (token embeddings under
        ``last_hidden_state``) rather than a bare tensor — confirm before
        consuming the result downstream.
        """
        return self.vit_lora(x)

In [None]:
# TODO: implement the Serial LoRA variants (layer and ViT wrapper); the classes below are placeholders.
class SerialLoRALayer(nn.Module):
    """Placeholder for the Serial LoRA layer; not implemented yet."""

class SerialLoraVit:
    """Placeholder for the Serial LoRA ViT wrapper; not implemented yet."""

### TO do:

1. ~~create a class for the segmentation task (or maybe look into some pre-defined architecture, maybe we might just need to use a simple mlp based architecture?):~~ 
2. ~~Look into the loss functions one can use for the task here.~~

In [13]:
import torch.nn.functional as F

In [None]:
class CustomSegHead(nn.Module):
    """
    Convolutional segmentation head operating on ViT patch embeddings.

    The patch tokens are laid out on a square grid, passed through a
    3x3 conv -> ReLU -> 1x1 conv stack, and bilinearly upsampled by
    ``patch_size`` to produce per-pixel class logits.

    Arguments:
    ----------
    hidden_dim : Dimension of the incoming patch embeddings (e.g. 768 for ViT-Base).

    num_classes : Number of output segmentation classes.

    patch_size : Side length of one image patch in pixels; also used as the upsampling factor.

    image_size : Side length of the (square) input image in pixels; together with
        ``patch_size`` it fixes the number of patches per spatial dimension.
    """
    def __init__(self, hidden_dim:int, num_classes:int, patch_size:int, image_size:int):
        super().__init__()

        self.patch_size = patch_size
        # Patches per side of the (square) token grid.
        self.num_patch_per_dim = image_size // patch_size

        mid_channels = hidden_dim // 2
        # 3x3 conv mixes neighbouring patches and halves the channel count.
        self.conv1 = nn.Conv2d(hidden_dim, mid_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        # 1x1 conv maps each grid location to class scores.
        self.conv2 = nn.Conv2d(mid_channels, num_classes, kernel_size=1)

    def forward(self, x):  # x shape: (B, N, D)
        """
        Map patch embeddings of shape (B, N, D) to segmentation logits.

        B = batch size, N = number of patches (must equal num_patch_per_dim ** 2
        for the reshape to succeed), D = embedding dimension.
        Returns logits of shape (B, num_classes, side * patch_size, side * patch_size),
        where side = num_patch_per_dim.
        """
        batch, _num_tokens, channels = x.shape
        side = self.num_patch_per_dim

        # (B, N, D) -> (B, D, N) -> (B, D, side, side): channels first, tokens on a 2D grid.
        feature_map = x.transpose(1, 2).reshape(batch, channels, side, side)

        hidden = self.relu(self.conv1(feature_map))
        logits = self.conv2(hidden)

        # Each grid cell covers patch_size x patch_size pixels, so upsample by that factor.
        return F.interpolate(logits, scale_factor=self.patch_size, mode='bilinear', align_corners=False)


Define a ViT segmentation model: combine the ViT feature extractor with the custom segmentation head.

In [None]:
# add the line below while creating a new file:
#   -from custom_seg_head import CustomSegHead  # Ensure this path is correct

class SegViT(nn.Module):
    """
    Wraps a ViT backbone and applies a custom segmentation head to its output.
    Converts patch embeddings into full-resolution segmentation logits.

    Arguments:
    ----------
    vit_model : Backbone producing token embeddings of shape (B, N+1, D) with the
        class (CLS) token first. May return a bare tensor, a tuple whose first
        element is that tensor, or an object exposing ``last_hidden_state``.

    image_size : Side length of the (square) input image in pixels.

    patch_size : Side length of one patch in pixels.

    dim : Embedding dimension D of the backbone tokens.

    n_classes : Number of segmentation classes.
    """

    def __init__(self,vit_model: nn.Module,
                    image_size: int,
                    patch_size: int,
                    dim: int,
                    n_classes: int,
                    ) -> None:
        super().__init__()

        self.vit = vit_model

        # Remove a classification head named `fc` if present, either on the
        # backbone itself or on the model held by a LoRA wrapper. Both
        # `vit_lora` (the attribute LoraVit uses in this notebook) and
        # `lora_vit` are checked.
        if hasattr(self.vit, "fc"):
            del self.vit.fc
        else:
            for wrapper_attr in ("vit_lora", "lora_vit"):
                inner = getattr(self.vit, wrapper_attr, None)
                if inner is not None and hasattr(inner, "fc"):
                    del inner.fc
                    break

        # Use custom segmentation head
        self.seg_head = CustomSegHead(
            hidden_dim=dim,
            num_classes=n_classes,
            patch_size=patch_size,
            image_size=image_size
        )

    def forward(self, x):
        """
        Forward pass:
        1. Get ViT token embeddings (B, N+1, D)
        2. Remove the class token (CLS), which is the FIRST token → (B, N, D)
        3. Feed the patch tokens to the custom segmentation head
        """
        features = self.vit(x)

        # Unwrap backbones that return an output object or a tuple rather than a bare tensor.
        if hasattr(features, "last_hidden_state"):
            features = features.last_hidden_state
        elif isinstance(features, (tuple, list)):
            features = features[0]

        # CLS sits at index 0. Slicing with [:-1] would keep CLS and drop the
        # last patch, shifting the whole patch grid by one position.
        patch_tokens = features[:, 1:, :]  # (B, N, D)

        return self.seg_head(patch_tokens)  # (B, num_classes, H, W)

task for today (31-july-2025):
0. Investigate the dataset
1. Check whether the models are working as intended before training.
2. ~~Define the loss functions.~~

    * ~~Cross entropy loss and weighted CE loss (-log loss).~~
        * `nn.CrossEntropyLoss`
        * In pytorch, input should be unnormalised logits.

    * ~~Dice loss.~~ 
    * ~~Log-Cosh Dice loss (finite and continuous gradients):~~
        * Original paper: https://arxiv.org/pdf/2006.14822
        * Review paper on loss: https://arxiv.org/html/2312.05391v1/#S3
3. Define the trainer class. 
4. test whether the code is working as intended for training.

Defining the loss functions for image segmentation tasks

1. Done
2. Need to do some checks on the code below to confirm that my implementation is indeed correct.

In [None]:
import torch
import torch.nn.functional as F

In [None]:
def cross_entropy():
    """
    Placeholder for a cross-entropy segmentation loss; not implemented yet.

    Currently takes no arguments and returns None.
    """
    # TODO: implement (the notes above point to `nn.CrossEntropyLoss` on raw logits).
    return None

def dice_loss(logits, targets, num_classes, epsilon=1e-6):
    """
    Multi-class soft Dice loss computed from raw logits.

    Args:
        logits: Tensor of shape (N, C, H, W) — raw (unnormalised) model outputs
        targets: Tensor of shape (N, H, W) — integer class indices
            (must be valid input for F.one_hot, i.e. int64 in [0, num_classes))
        num_classes: int — number of classes C
        epsilon: float — smoothing term added to numerator and denominator

    Returns:
        Scalar tensor: 1 - Dice. For num_classes == 2 only the foreground
        class (index 1) is used; otherwise the Dice score is averaged over
        all classes.
    """
    # (N, H, W) class ids -> (N, H, W, C) one-hot -> (N, C, H, W) float mask
    one_hot_targets = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()

    # Class probabilities along the channel axis
    class_probs = F.softmax(logits, dim=1)

    # Reduce over batch, height and width so one score per class remains: shape (C,)
    reduce_dims = (0, 2, 3)
    overlap = (class_probs * one_hot_targets).sum(dim=reduce_dims)
    total = (class_probs + one_hot_targets).sum(dim=reduce_dims)

    per_class_dice = (2. * overlap + epsilon) / (total + epsilon)

    if num_classes == 2:
        # Binary case: score the foreground class only
        return 1. - per_class_dice[1]
    # Multi-class case: average the per-class scores
    return 1. - per_class_dice.mean()


def log_cosh_dice_loss(logits, targets, num_classes, epsilon=1e-6):
    """
    Log-cosh Dice loss: log(cosh(dice_loss(...))).

    Args:
        logits: Tensor of shape (N, C, H, W) — raw model outputs
        targets: Tensor of shape (N, H, W) — class indices
        num_classes: int — number of classes
        epsilon: float — smoothing term forwarded to dice_loss

    Returns:
        Scalar tensor with the log-cosh of the Dice loss
    """
    dice_value = dice_loss(logits, targets, num_classes, epsilon)
    cosh_value = torch.cosh(dice_value)
    return torch.log(cosh_value)


Download the relevant dataset.

In [30]:
# run the experiment and get it moving :)

In [31]:
# more advanced adapter based method
# what are the other methods that we can have for this work