In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
from torchvision.models import ViT_B_16_Weights
import copy

from torchvision.datasets import Flowers102
from torchvision import transforms

from LoRAs import *
from VeRAs import *
from IA3 import *

In [2]:
# Build torchvision's ViT-B/16 initialised with the IMAGENET1K_V1 weights.
# `model` is the module-level base network that later cells read from.
weights = ViT_B_16_Weights.IMAGENET1K_V1
model = torchvision.models.vit_b_16(weights=weights)

In [13]:
# NOTE(review): despite its name, `mlplayer` holds the self-attention module
# of the first encoder block, not an MLP.
mlplayer = model.encoder.layers[0].self_attention
# Last expression of the cell: displays the attention embedding width.
mlplayer.embed_dim

768

In [6]:
def run_tests():
    """Smoke-test the LoRA and VeRA self-attention wrappers on torchvision's ViT-B/16.

    Four groups of checks are run. Every result is reported with ``print``;
    nothing is asserted and nothing is returned.

    * Test A: the wrapped layers, right after construction, produce the same
      output as the original attention layer (max abs diff below 1e-5).
    * Test B: filling ``lora_B_q`` / ``vera_b_q`` with ones changes the output
      (max abs diff above 1e-3).
    * Test C: the full model still completes a forward pass with each wrapped
      layer swapped into the first encoder block.
    * Test D: after freezing every parameter, only parameters whose name
      contains ``lora_`` / ``vera_`` are re-enabled, and the trainable count
      must equal that adapter count.

    Side effects: loads the ImageNet-pretrained ViT-B/16 weights and mutates
    the locally created model and adapter layers in place.
    """
    print("--- Starting LoRA Implementation Tests ---")

    # Load Model (dummy weights are fine, but we use ImageNet for realism)
    print("1. Loading ViT-B/16...")
    model = torchvision.models.vit_b_16(weights=torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1)

    # Isolate one attention layer
    original_layer = model.encoder.layers[0].self_attention

    # Create dummy input (Batch=1, Seq=197, Dim=768)
    # 197 = 1 class token + (14x14) patches
    dummy_input = torch.randn(1, 197, 768)

    # ---------------------------------------------------------
    # Test A: Does the new layer produce the exact same output?
    # ---------------------------------------------------------
    print("\nTest A: Output Consistency (Original vs LoRA with B=0)")

    # Init LoRA layer
    lora_layer = LoRASelfAttention(original_layer)

    # Random projection matrices for the query projection only (rank 4, dim 768).
    matrices = { 'A_q': torch.randn(( 4, 768)), 'B_q': torch.randn((768, 4))}
    vera_layer = VeRASelfAttention(original_layer, matrices=matrices)

    print('lora qkv', lora_layer.which_proj)
    print('vera qkv', vera_layer.which_proj)

    # Set to eval mode to disable dropout in all three layers
    original_layer.eval()
    lora_layer.eval()
    vera_layer.eval()

    with torch.no_grad():
        # Original forward: self-attention, so query = key = value = input.
        out_orig, _ = original_layer(dummy_input, dummy_input, dummy_input, need_weights=False)

        # Adapter forwards
        out_lora, _ = lora_layer(dummy_input, dummy_input, dummy_input)
        out_vera, _ = vera_layer(dummy_input, dummy_input, dummy_input)

    # Check difference
    diff = torch.abs(out_orig - out_lora).max().item()
    diff_vera = torch.abs(out_orig - out_vera).max().item()
    print(f"   Max absolute difference lora: {diff:.8f}")
    print(f"   Max absolute difference vera: {diff_vera:.8f}")

    if diff < 1e-5:
        print("   ✅ LORA SUCCESS: Outputs match closely.")
    else:
        print("   ❌ LORA FAILURE: Outputs diverge too much.")

    if diff_vera < 1e-5:
        print("   ✅ VERA SUCCESS: Outputs match closely.")
    else:
        print("   ❌ VERA FAILURE: Outputs diverge too much.")

    # ---------------------------------------------------------
    # Test B: Does the adapter actually modify the output when trained?
    # ---------------------------------------------------------
    print("\nTest B: LoRA Activation Check")

    # Manually set LoRA B matrix to something non-zero to simulate training
    with torch.no_grad():
        lora_layer.lora_B_q.fill_(1.0)

    with torch.no_grad():
        out_lora_active, _ = lora_layer(dummy_input, dummy_input, dummy_input)

    diff_active = torch.abs(out_orig - out_lora_active).max().item()
    print(f"   Diff after activating LoRA: {diff_active:.4f}")

    if diff_active > 1e-3:
        print("   ✅ SUCCESS: LoRA is modifying the output.")
    else:
        print("   ❌ FAILURE: LoRA parameters are not affecting the output.")

    print("\nTest B: vera Activation Check")

    # Manually set VeRA b vector to something non-zero to simulate training
    with torch.no_grad():
        vera_layer.vera_b_q.fill_(1.0)

    with torch.no_grad():
        out_vera_active, _ = vera_layer(dummy_input, dummy_input, dummy_input)

    diff_active_vera = torch.abs(out_orig - out_vera_active).max().item()
    # Fixed label: this line reports the VeRA diff (it previously said "LoRA").
    print(f"   Diff after activating VeRA: {diff_active_vera:.4f}")

    if diff_active_vera > 1e-3:
        print("   ✅ SUCCESS: VeRA is modifying the output.")
    else:
        print("   ❌ FAILURE: VeRA parameters are not affecting the output.")

    # (display label, parameter-name marker, wrapped layer) for the checks that
    # are identical for both adapters; order matters for the printed report.
    adapters = (
        ("LoRA", "lora_", lora_layer),
        ("VeRA", "vera_", vera_layer),
    )

    # ---------------------------------------------------------
    # Test C: Full Model Integration (Swapping the layer)
    # ---------------------------------------------------------
    print("\nTest C: Full Model Integration")

    for _, _, adapted_layer in adapters:
        # Swap the layer in the model
        model.encoder.layers[0].self_attention = adapted_layer

        # Run full model forward pass; a crash is reported rather than raised
        # so the remaining checks still run.
        try:
            logits = model(torch.randn(1, 3, 224, 224)) # Standard ImageNet input
            print(f"   Output shape: {logits.shape}")
            print("   ✅ SUCCESS: Full model forward pass complete.")
        except Exception as e:
            print(f"   ❌ FAILURE: Model crashed with error: {e}")

    # ---------------------------------------------------------
    # Test D: Parameter Freeze Check
    # ---------------------------------------------------------
    print("\nTest D: Parameter Counting")

    for label, marker, adapted_layer in adapters:
        model.encoder.layers[0].self_attention = adapted_layer

        # Freeze base model
        for param in model.parameters():
            param.requires_grad = False

        # Unfreeze only this adapter's params, identified by name
        adapter_params = 0
        for name, param in model.named_parameters():
            if marker in name:
                param.requires_grad = True
                adapter_params += param.numel()

        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        print(f"   Total trainable parameters: {trainable_params}")

        if trainable_params == adapter_params and adapter_params > 0:
            print(f"   ✅ SUCCESS: Only {label} parameters are trainable.")
        else:
            print(f"   ❌ FAILURE: Trainable params ({trainable_params}) != {label} params ({adapter_params})")

In [7]:
# Run the LoRA/VeRA smoke checks; results are printed, not asserted.
run_tests()

--- Starting LoRA Implementation Tests ---
1. Loading ViT-B/16...

Test A: Output Consistency (Original vs LoRA with B=0)
lora qkv [True, False, False]
vera qkv [True, False, False]
   Max absolute difference lora: 0.00000222
   Max absolute difference vera: 0.00000222
   ✅ LORA SUCCESS: Outputs match closely.
   ✅ VERA SUCCESS: Outputs match closely.

Test B: LoRA Activation Check
   Diff after activating LoRA: 1.2421
   ✅ SUCCESS: LoRA is modifying the output.

Test B: vera Activation Check
   Diff after activating LoRA: 4.1646
   ✅ SUCCESS: VeRA is modifying the output.

Test C: Full Model Integration
   Output shape: torch.Size([1, 1000])
   ✅ SUCCESS: Full model forward pass complete.
   Output shape: torch.Size([1, 1000])
   ✅ SUCCESS: Full model forward pass complete.

Test D: Parameter Counting
   Total trainable parameters: 6144
   ✅ SUCCESS: Only LoRA parameters are trainable.
   Total trainable parameters: 772
   ✅ SUCCESS: Only VeRA parameters are trainable.


In [None]:
# Encoder blocks of the module-level pretrained `model`.
layers = model.encoder.layers


def apply_LoRA(model, r=4, mlps = True, mlpsblock=False, attention=False, qkv=[False, False, False]):
    """Return a deep copy of ``model`` with LoRA wrappers installed in every encoder layer.

    The input model is left untouched; all replacements happen on the copy.

    Args:
        model: module exposing ``encoder.layers``, where each layer has ``mlp``
            and ``self_attention`` attributes (e.g. torchvision's ViT).
        r: LoRA rank forwarded to the wrappers.
        mlps: wrap the two linear layers of each MLP (``mlp[0]`` and ``mlp[3]``)
            with ``LoRALinear``.
        mlpsblock: wrap the whole ``mlp`` block with ``LoRALinear``.
            NOTE(review): assumes ``LoRALinear`` accepts a non-Linear module
            here; confirm against its definition.
        attention: replace ``self_attention`` with ``LoRASelfAttention``.
        qkv: three flags forwarded as ``q``, ``k``, ``v`` to
            ``LoRASelfAttention``. The default list is only read, never mutated.

    Returns:
        The adapted deep copy of ``model``.
    """
    # nn.Module has no .copy(); deepcopy gives an independent model to modify.
    new_model = copy.deepcopy(model)
    for layer in new_model.encoder.layers:
        if mlps:
            layer.mlp[0] = LoRALinear(layer.mlp[0], r=r)
            layer.mlp[3] = LoRALinear(layer.mlp[3], r=r)
        if mlpsblock:
            layer.mlp = LoRALinear(layer.mlp, r=r)
        if attention:
            layer.self_attention = LoRASelfAttention(layer.self_attention, rank=r, q=qkv[0], k=qkv[1], v=qkv[2])
    return new_model


def apply_VeRA(model, r=4, mlps = True, mlpsblock=False, attention=False, qkv=[False, False, False], seed=0):
    """Return a deep copy of ``model`` with VeRA wrappers installed in every encoder layer.

    The random projection matrices are created once and shared by all layers.
    They are drawn from a local generator seeded with ``seed``, so the result
    is reproducible and the global RNG state is not consumed.

    Args:
        model: module exposing ``encoder.layers``, where each layer has ``mlp``
            and ``self_attention`` attributes (e.g. torchvision's ViT).
        r: rank of the shared random matrices, also forwarded to the wrappers.
        mlps: wrap ``mlp[0]`` and ``mlp[3]`` of each layer with ``VeRALinear``.
        mlpsblock: accepted for signature parity with ``apply_LoRA``; unused.
        attention: replace ``self_attention`` with ``VeRASelfAttention``.
        qkv: three flags forwarded as ``q``, ``k``, ``v`` to
            ``VeRASelfAttention``. The default list is only read, never mutated.
        seed: seed for the generator that draws the shared matrices.

    Returns:
        The adapted deep copy of ``model``.
    """
    # nn.Module has no .copy(); deepcopy gives an independent model to modify.
    new_model = copy.deepcopy(model)
    layers = new_model.encoder.layers

    rng = torch.Generator().manual_seed(seed)

    def _shared(*shape):
        # Draw one shared random matrix from the seeded generator.
        return torch.randn(*shape, generator=rng)

    if mlps:
        # Dimensions are read from the first layer; `layers` itself is the
        # container and has no `.mlp` attribute.
        mlp_dim = layers[0].mlp[0].out_features
        # NOTE(review): both matrices are sized by the MLP hidden width and
        # reused for mlp[0] and mlp[3], whose in/out widths differ. This
        # presumably relies on VeRALinear slicing them; confirm.
        A_mlp = _shared(r, mlp_dim)
        B_mlp = _shared(mlp_dim, r)

    if attention:
        attention_dim = layers[0].self_attention.embed_dim
        attention_matrices = {
            'A_q': _shared(r, attention_dim),
            'A_k': _shared(r, attention_dim),
            'A_v': _shared(r, attention_dim),
            'B_q': _shared(attention_dim, r),
            'B_k': _shared(attention_dim, r),
            'B_v': _shared(attention_dim, r),
        }

    for layer in layers:
        if mlps:
            layer.mlp[0] = VeRALinear(layer.mlp[0], A_mlp, B_mlp, r=r)
            layer.mlp[3] = VeRALinear(layer.mlp[3], A_mlp, B_mlp, r=r)

        if attention:
            layer.self_attention = VeRASelfAttention(layer.self_attention, rank=r, q=qkv[0], k=qkv[1], v=qkv[2], matrices = attention_matrices)

    return new_model


def apply_IA3(model, mlps=True, attention=True, qkv = [False, True, True]):
    """Return a deep copy of ``model`` with IA3 wrappers installed in every encoder layer.

    The input model is left untouched; all replacements happen on the copy.

    Args:
        model: module exposing ``encoder.layers``, where each layer has ``mlp``
            and ``self_attention`` attributes (e.g. torchvision's ViT).
        mlps: wrap the second MLP linear layer (``mlp[3]``) with ``IA3Linear``.
        attention: replace ``self_attention`` with ``IA3SelfAttention``.
        qkv: currently unused.
            NOTE(review): presumably meant to select which projections
            ``IA3SelfAttention`` rescales; wire it through once that
            constructor's signature is confirmed.

    Returns:
        The adapted deep copy of ``model``.
    """
    # nn.Module has no .copy(); deepcopy gives an independent model to modify.
    new_model = copy.deepcopy(model)

    # The transformer blocks live under `encoder.layers`, not on the model itself.
    for layer in new_model.encoder.layers:
        if mlps:
            # Wrap the linear layer being replaced, not the whole encoder block.
            layer.mlp[3] = IA3Linear(layer.mlp[3])

        if attention:
            layer.self_attention = IA3SelfAttention(layer.self_attention)

    return new_model

MLPBlock(
  (0): Linear(in_features=768, out_features=3072, bias=True)
  (1): GELU(approximate='none')
  (2): Dropout(p=0.0, inplace=False)
  (3): Linear(in_features=3072, out_features=768, bias=True)
  (4): Dropout(p=0.0, inplace=False)
)
MLPBlock(
  (0): Linear(in_features=768, out_features=3072, bias=True)
  (1): GELU(approximate='none')
  (2): Dropout(p=0.0, inplace=False)
  (3): Linear(in_features=3072, out_features=768, bias=True)
  (4): Dropout(p=0.0, inplace=False)
)
MLPBlock(
  (0): Linear(in_features=768, out_features=3072, bias=True)
  (1): GELU(approximate='none')
  (2): Dropout(p=0.0, inplace=False)
  (3): Linear(in_features=3072, out_features=768, bias=True)
  (4): Dropout(p=0.0, inplace=False)
)
MLPBlock(
  (0): Linear(in_features=768, out_features=3072, bias=True)
  (1): GELU(approximate='none')
  (2): Dropout(p=0.0, inplace=False)
  (3): Linear(in_features=3072, out_features=768, bias=True)
  (4): Dropout(p=0.0, inplace=False)
)
MLPBlock(
  (0): Linear(in_features=768,