In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
from torchvision.models import ViT_B_16_Weights
import copy

from torchvision.datasets import Flowers102
from torchvision import transforms

from adapters.LoRAs import *
from adapters.VeRAs import *
from adapters.IA3 import *
from adapters.PEFTclass import *

import os

In [2]:
weights = ViT_B_16_Weights.IMAGENET1K_V1
model = torchvision.models.vit_b_16(weights=weights)

preprocess = weights.transforms()

In [3]:
model.encoder.layers[0]

EncoderBlock(
  (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (self_attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (mlp): MLPBlock(
    (0): Linear(in_features=768, out_features=3072, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=3072, out_features=768, bias=True)
    (4): Dropout(p=0.0, inplace=False)
  )
)

In [4]:
nb_mlps = 0
nb_att = 0

for layer in model.encoder.layers: 
    nb_mlps += 1
    nb_att += layer.self_attention.num_heads

print(nb_mlps, nb_att)


12 144


In [5]:
modified_model = PEFTViT(model, nb_classes=256, method='lora', attention=True, qkv = [False, True, True])

in 768 out 3072
4
in 3072 out 768
4
6144
in 768 out 3072
4
in 3072 out 768
4
6144
in 768 out 3072
4
in 3072 out 768
4
6144
in 768 out 3072
4
in 3072 out 768
4
6144
in 768 out 3072
4
in 3072 out 768
4
6144
in 768 out 3072
4
in 3072 out 768
4
6144
in 768 out 3072
4
in 3072 out 768
4
6144
in 768 out 3072
4
in 3072 out 768
4
6144
in 768 out 3072
4
in 3072 out 768
4
6144
in 768 out 3072
4
in 3072 out 768
4
6144
in 768 out 3072
4
in 3072 out 768
4
6144
in 768 out 3072
4
in 3072 out 768
4
6144


In [6]:
modified_model

PEFTViT(
  (model): VisionTransformer(
    (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (encoder): Encoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (self_attention): LoRASelfAttention(
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (0): LoRALinear(
              (original_layer): Linear(in_features=768, out_features=3072, bias=True)
            )
            (

In [7]:
mlplayer = model.encoder.layers[0].self_attention
model.encoder.layers[0].mlp[3]

Linear(in_features=3072, out_features=768, bias=True)

In [13]:
369409-709

368700

In [9]:
r = 4
12*2*r*768*3

221184

In [None]:
def check_PEFTViT():
    print("--- 1. SETUP: Loading Base Model ---")
    base_model = torchvision.models.vit_b_16(weights=weights)

    
    print("\n--- 2. WRAP: Creating PEFTViT ---")
    # Wrap the model
    model = PEFTViT(base_model, nb_classes = 1, method='vera', r=4, attention= True, qkv = [False, True, True])

    
    print("\n--- 3. FREEZE: Setting Gradients ---")
    # Simulate the freezing logic: Freeze everything EXCEPT head and LoRA
    model.set_trainable_parameters()
            
    model.print_trainable_parameters()
    # Expected: Small % (only head + lora params)

    print("\n--- 4. MOCK TRAINING (Modifying Weights) ---")
    # We manually modify a weight to prove saving works
    with torch.no_grad():
        # Modify one of the LoRA parameters


        original_val = model.model.encoder.layers[0].mlp[0].vera_middle[0].item()
        model.model.encoder.layers[0].mlp[0].vera_middle[0] += 1.0
        new_val = model.model.encoder.layers[0].mlp[0].vera_middle[0].item()

    print(f"   Modified lora_A[0,0]: {original_val:.4f} -> {new_val:.4f}")

    print("\n--- 5. SAVE: Testing state_dict override ---")
    save_path = "temp_lora_checkpoint.pt"
    #save_path = "temp_vera_checkpoint.pt"

    # This calls our custom state_dict() method!
    torch.save(model.state_dict(), save_path)
    
    file_size = os.path.getsize(save_path) / 1024
    print(f"   Saved checkpoint size: {file_size:.2f} KB")
    # If this were the full model, it would be ~45MB. 
    # Since it is <100KB, we know it saved ONLY the adapter.

    print("\n--- 6. LOAD: Testing load_state_dict override ---")
    # Create a fresh model (original weights) to prove we are loading the changes
    fresh_base = torchvision.models.vit_b_16(weights=weights)
    fresh_base.head = nn.Linear(512, 100)
    
    # Wrap it
    new_model = PEFTViT(fresh_base, nb_classes=10, method='lora', r=4)

    # Verify it has the OLD value before loading
    print(f"   Value before load: {new_model.model.encoder.layers[0].mlp[0].lora_A[0, 0].item():.4f}")
    #print(f"   Value before load: {new_model.model.encoder.layers[0].mlp[0].vera_middle[ 0].item():.4f}")

    
    
    # Load the saved adapter
    # This calls our custom load_state_dict() method!
    saved_weights = torch.load(save_path)
    new_model.load_state_dict(saved_weights)
    
    print(f"   Value after load:  {new_model.model.encoder.layers[0].mlp[0].lora_A[0, 0].item():.4f}")
    #print(f"   Value after load:  {new_model.model.encoder.layers[0].mlp[0].vera_middle[0].item():.4f}")
    
    if abs(new_model.model.encoder.layers[0].mlp[0].lora_A[0, 0].item() - new_val) < 1e-5:
        print("\n✅ SUCCESS: Weights restored correctly.")
    else:
        print("\n❌ FAILURE: Weights do not match.")
        
    # Cleanup
    if os.path.exists(save_path):
        os.remove(save_path)

check_PEFTViT()

--- 1. SETUP: Loading Base Model ---

--- 2. WRAP: Creating PEFTViT ---

--- 3. FREEZE: Setting Gradients ---
[Info] Trainable: 65,377 | Total: 85,937,761 | %: 0.08%

--- 4. MOCK TRAINING (Modifying Weights) ---
   Modified lora_A[0,0]: 1.0000 -> 2.0000

--- 5. SAVE: Testing state_dict override ---
   Saved checkpoint size: 287.57 KB

--- 6. LOAD: Testing load_state_dict override ---
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
   Value before load: -0.0059


RuntimeError: Error(s) in loading state_dict for VisionTransformer:
	size mismatch for heads.weight: copying a param with shape torch.Size([1, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).
	size mismatch for heads.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([10]).

In [23]:
check_PEFTViT()

--- 1. SETUP: Loading Base Model ---

--- 2. WRAP: Creating PEFTViT ---

--- 3. FREEZE: Setting Gradients ---
[Info] Trainable: 65,329 | Total: 85,900,849 | %: 0.08%

--- 4. MOCK TRAINING (Modifying Weights) ---
   Modified lora_A[0,0]: 1.0000 -> 2.0000

--- 5. SAVE: Testing state_dict override ---
   Saved checkpoint size: 287.57 KB

--- 6. LOAD: Testing load_state_dict override ---
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
in 768 out 3072
4
in 3072 out 768
4
   Value before load: -0.0000


RuntimeError: Error(s) in loading state_dict for VisionTransformer:
	size mismatch for heads.weight: copying a param with shape torch.Size([1, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).
	size mismatch for heads.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([10]).