In [12]:
import numpy as np 
from PIL import Image
import requests
import torch

In [13]:
from transformers import ViTModel, ViTImageProcessor

In [14]:
## Load the pre-trained ViT model (~86M parameters)
model_name = 'google/vit-base-patch16-224'

# Load the image processor and the model from the same checkpoint name.
# NOTE(review): the load warning reports 'pooler.dense.*' as newly initialised, so the
# pooler output is untrained. Presumably it is unused for segmentation; confirm, or
# consider loading with add_pooling_layer=False.
image_processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTModel.from_pretrained(model_name)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
# Count the elements of every parameter that currently has requires_grad=True.
# For this checkpoint the total is roughly 86 million.
trainable_params = sum(
    weight.numel()
    for weight in model.parameters()
    if weight.requires_grad
)
print(f"Trainable parameters: {trainable_params:,}")

Trainable parameters: 86,389,248


In [16]:
# Select the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# The model is not moved to the device or switched to eval mode here;
# the corresponding calls are left commented out:
# model.to(device)
# model.eval()

In [None]:
## I need to start working with the VIT model

## Things to do:

## Stage 0. We need to define a framework for the project and decide what exactly we are going to do

## Stage 1. Set up a Transformer based segmentation model using ViT+LoRA
0. Understand and play around with the models
1. Load the pre-trained model
    1.1 If required, we might have to switch to timm
2. Define the class for LoRA:
    2.1: Either can be set up using `peft` library
    2.2: Build a custom LoRA module using Pytorch.
3. Apply LoRA to attention layers
4. Define (code) and add a segmentation head
    1. Simple MLP or some more complicated architecture? We need to look into it.
5. Training Set-Up
    5.1. Loss for the segmentation task
    5.2 optimizer
6. Training:
    6.1 Typical LoRA rank (r): 4 or 8 - a good balance for fine-tuning
    6.2 How many parameters to freeze

Obtain some simple acceptable results for this. 

## Stage 2. Try at least one other variant of LoRA for the same task:
0. Serial LoRA for the ViT (recent paper)
1. Others: DoRA, etc. (check that review paper for other variants)

## Stage 3: Try one other fine-tuning strategy, e.g. another adapter-based approach, IA3, etc.




LoRA implementation

In [25]:
import torch.nn as nn
import torch
from torch import Tensor
import math
from safetensors.torch import save_file, load_file

In [None]:
### implement the LoRAlayer
class LoraLayer(nn.Module):
    """
    Linear projection augmented with a low-rank (LoRA) update.

    The forward pass returns ``wt(x) + (alpha / rank_lora) * B(A(x))``.

    wt: the original projection module (expected to be frozen by the caller;
        this class does not change ``requires_grad`` itself)
    A, B: the two low-rank modules whose composition forms delta W
    rank_lora: rank of the A and B matrices, used as the divisor of the scale
    alpha: weighting factor, used as the numerator of the scale
    """
    def __init__(self, wt: nn.Module, A: nn.Module, B: nn.Module, rank_lora: int, alpha: int):
        super().__init__()
        # Sub-modules are registered in the order wt, A, B.
        self.wt = wt
        self.A = A
        self.B = B
        self.rank, self.alpha = rank_lora, alpha

    def forward(self,x):
        base_out = self.wt(x)
        low_rank_out = self.B(self.A(x))
        # The scale is read from the attributes on every call.
        scale = self.alpha / self.rank
        return base_out + scale * low_rank_out

In [None]:
## implement the LoRA VIT
### ? Things to check ?
## ? vit_model.transformer.blocks
## ? vit_model.parameters()
class LoraVit(nn.Module):
    """
    Wrap a pre-trained ViT and inject LoRA adapters into its query/value projections.

    Every parameter of ``vit_model`` is frozen; only the low-rank ``A``/``B``
    matrices created here stay trainable. ``vit_model`` is modified in place.

    Two block layouts are recognised:
      * ``vit_model.transformer.blocks[i].attn.proj_q / proj_v``
      * ``vit_model.encoder.layer[i].attention.attention.query / value``
        (the layout used by Hugging Face ``ViTModel``)

    vit_model: pre-trained vit model
    r: rank of the LoRA matrices
    alpha: scaling strength for lora (passed on to ``LoraLayer`` together with ``r``)
    lora_layers: indices of the transformer blocks to apply LoRA to;
        ``None`` (or an empty sequence) applies LoRA to every block
    """
    def __init__(self, vit_model, r:int, alpha:int, lora_layers=None):
        super().__init__()

        assert r>0, "r (rank of lora matrices) must be >0"
        assert alpha>0 , "alpha >0"

        blocks = self._get_blocks(vit_model)

        if lora_layers:
            # Stored as a list so that a tuple/range argument still compares
            # equal to the list read back in load_lora_params.
            self.lora_layers = list(lora_layers)
        else: ## apply lora to all
            self.lora_layers = list(range(len(blocks)))

        # Freeze the pre-trained parameters. The LoRA matrices are created
        # after this loop, so they keep requires_grad=True.
        for param in vit_model.parameters():
            param.requires_grad = False

        # References to the LoRA modules, in block order. They are plain lists:
        # the modules become sub-modules of this class through the LoraLayer
        # wrappers placed inside vit_model (stored as self.vit_lora below).
        self.list_q_As, self.list_q_Bs = [], []
        self.list_v_As, self.list_v_Bs = [], []

        # Replace the plain Q and V projections with LoRA layers.
        for layer_i, block_i in enumerate(blocks):
            if layer_i not in self.lora_layers:
                continue # (next iteration)
            attn, q_name, v_name = self._get_qv_names(block_i)
            w_q_linear = getattr(attn, q_name)
            w_v_linear = getattr(attn, v_name)

            # Sizes are read from the projections themselves instead of a
            # hard-coded/placeholder dimension.
            a_linear_q, b_linear_q = self._make_lora_pair(w_q_linear, r)
            a_linear_v, b_linear_v = self._make_lora_pair(w_v_linear, r)

            # Append lora params to the lists
            self.list_q_As.append(a_linear_q); self.list_q_Bs.append(b_linear_q)
            self.list_v_As.append(a_linear_v); self.list_v_Bs.append(b_linear_v)

            # Replace with LoRA layers
            setattr(attn, q_name, LoraLayer(w_q_linear, a_linear_q, b_linear_q, r, alpha))
            setattr(attn, v_name, LoraLayer(w_v_linear, a_linear_v, b_linear_v, r, alpha))

        self.init_lora_layers()# initialise the lora parameters
        self.vit_lora = vit_model

    @staticmethod
    def _get_blocks(vit_model):
        """Return the sequence of transformer blocks for a supported ViT layout."""
        transformer = getattr(vit_model, "transformer", None)
        if transformer is not None and hasattr(transformer, "blocks"):
            return transformer.blocks
        encoder = getattr(vit_model, "encoder", None)
        if encoder is not None and hasattr(encoder, "layer"):
            return encoder.layer
        raise AttributeError(
            "Unsupported ViT layout: expected 'transformer.blocks' or 'encoder.layer' "
            f"on {type(vit_model).__name__}"
        )

    @staticmethod
    def _get_qv_names(block):
        """Return ``(owner_module, query_attr_name, value_attr_name)`` for one block."""
        attn = getattr(block, "attn", None)
        if attn is not None and hasattr(attn, "proj_q") and hasattr(attn, "proj_v"):
            return attn, "proj_q", "proj_v"
        self_attn = getattr(getattr(block, "attention", None), "attention", None)
        if self_attn is not None and hasattr(self_attn, "query") and hasattr(self_attn, "value"):
            return self_attn, "query", "value"
        raise AttributeError(
            "Unsupported attention layout: expected 'attn.proj_q/proj_v' or "
            f"'attention.attention.query/value' on {type(block).__name__}"
        )

    @staticmethod
    def _make_lora_pair(linear, r):
        """Create the bias-free (A, B) linears matching ``linear``'s sizes, device and dtype."""
        factory = {"device": linear.weight.device, "dtype": linear.weight.dtype}
        a_linear = nn.Linear(linear.in_features, r, bias=False, **factory)
        b_linear = nn.Linear(r, linear.out_features, bias=False, **factory)
        return a_linear, b_linear

    def init_lora_layers(self) -> None:
        """
        Initialise the LoRA matrices.

        A is initialised with Kaiming-uniform values and B with zeros, so the
        low-rank product B(A(x)) is zero right after initialisation.
        """
        for A in self.list_q_As + self.list_v_As:
            nn.init.kaiming_uniform_(A.weight, a=math.sqrt(5))
            # if you want to use normal distn for initialisation: nn.init.normal_(A.weight, std=1e-3)
        for B in self.list_q_Bs + self.list_v_Bs:
            nn.init.zeros_(B.weight)

    def save_lora_params(self, filename:str):
        """
        Save the LoRA matrices and the adapted block indices to a safetensors file.

        The file holds ``lora_layers`` (int32 tensor) plus ``q_A_i``, ``q_B_i``,
        ``v_A_i`` and ``v_B_i`` for the i-th adapted block, in block order.

        filename: path of the file to write
        """
        # Create dict for safetensors, keys = str, values = tensors
        state_dict = {}

        # Save lora_layers as a tensor
        state_dict['lora_layers'] = torch.tensor(self.lora_layers, dtype=torch.int32)

        # Save all LoRA params with keys indicating their index and type
        for i, (a_q, b_q, a_v, b_v) in enumerate(zip(self.list_q_As, self.list_q_Bs,self.list_v_As, self.list_v_Bs)):
            state_dict[f'q_A_{i}'] = a_q.weight.data
            state_dict[f'q_B_{i}'] = b_q.weight.data
            state_dict[f'v_A_{i}'] = a_v.weight.data
            state_dict[f'v_B_{i}'] = b_v.weight.data

        save_file(state_dict, filename)
        print(f"Saved LoRA params and layers to {filename}")

    def load_lora_params(self, filename:str):
        """
        Load LoRA matrices written by ``save_lora_params`` into this model.

        filename: path of the safetensors file to read

        Raises:
            ValueError: if the file was saved for a different set of blocks than
                this model adapts. The weights are stored positionally, so
                loading them anyway would put them into the wrong blocks.
        """
        loaded = load_file(filename)
        # Load lora_layers first (convert to list)
        loaded_layers = loaded['lora_layers'].tolist()

        if loaded_layers != self.lora_layers:
            raise ValueError(
                f"LoRA layers in {filename} ({loaded_layers}) do not match the layers "
                f"of this model ({self.lora_layers}); rebuild LoraVit with "
                f"lora_layers={loaded_layers} before loading."
            )

        # Now load weights into LoRA modules
        with torch.no_grad():
            for i, (a_q, b_q, a_v, b_v) in enumerate(zip(self.list_q_As, self.list_q_Bs, self.list_v_As, self.list_v_Bs)):
                a_q.weight.copy_(loaded[f'q_A_{i}'])
                b_q.weight.copy_(loaded[f'q_B_{i}'])
                a_v.weight.copy_(loaded[f'v_A_{i}'])
                b_v.weight.copy_(loaded[f'v_B_{i}'])

        print(f"Loaded LoRA params and layers from {filename}")


    def forward(self, x:Tensor) -> Tensor:
        """
        Run the LoRA-adapted ViT.

        x: input passed unchanged as the single positional argument of the wrapped model
        Returns whatever the wrapped model returns (this may be a model-specific
        output object rather than a plain tensor, depending on the wrapped model).
        """
        return self.vit_lora(x)

In [None]:
# TODO: implement the Serial LoRA layer and the ViT wrapper that uses it (placeholders below)
class SerialLoRALayer(nn.Module):
    """Placeholder for the Serial LoRA layer; no behaviour is implemented yet."""
    pass

class SerialLoraVit(nn.Module):
    """
    Placeholder for the ViT wrapper built on Serial LoRA layers; not implemented yet.

    Subclasses ``nn.Module`` (like ``LoraVit``) so that, once implemented, its
    sub-modules and parameters are registered and visible to ``.parameters()``,
    ``.to()`` and ``state_dict()``.
    """
    pass

In [None]:
# run the experiment and get it moving :)

In [None]:
# more advanced adapter based method
# what are the other methods that we can have for this work