In [None]:
import torch 
import torch.nn as nn 
from tqdm import tqdm
import numpy as np 
from torchvision import datasets 
from torchvision.transforms import transforms 
from torch.utils.data import DataLoader

: 


	1.	Under-train full model → find weak class
	2.	Targeted LoRA fine-tune → fix only that class
	3.	Keep base weights unchanged → plug-and-play adaptability



In [6]:
import ssl
# Temporarily disable SSL certificate verification
ssl._create_default_https_context = ssl._create_unverified_context

transform = transforms.ToTensor() 
train_dataset = datasets.MNIST(root='./data',train=True,transform=transform,download=True)
test_dataset = datasets.MNIST(root='./data',train=False,transform=transform,download=True)
train_loader = DataLoader(train_dataset,batch_size=32,shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=32,shuffle=False)

In [26]:
class BIGnet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(28*28, 1000)
        self.linear2 = nn.Linear(1000, 2000)
        self.linear3 = nn.Linear(2000, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.relu(self.linear(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

    

In [27]:
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
device 

'mps'

In [31]:
model = BIGnet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)


In [88]:
def train(model, train_loader, device, num_epochs, total_iterations_limit, criterion):
    total_iterations = 0
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        correct = 0
        
        total = 0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            total_iterations += 1
            optimizer.step()
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
        train_loss /= len(train_loader)
        train_acc = 100. * correct / total
        print(f'Epoch: {epoch+1}/{num_epochs}, Loss: {train_loss:.3f}, Accuracy: {train_acc:.2f}%')
    return train_loss, train_acc



In [33]:


train(model = model,optimizer=optimizer,criterion=criterion,train_loader=train_loader,num_epochs=1,device=device)

Epoch: 1/1, Loss: 0.314, Accuracy: 91.16%


(0.3141168975392977, 91.15833333333333)

In [35]:
# copying the parameters from the original model which we will use later 
original_weights = {}
for name, param in model.named_parameters():
    original_weights[name] = param.clone().detach()

In [36]:
original_weights

{'linear.weight': tensor([[ 0.0279,  0.0322,  0.0298,  ..., -0.0158,  0.0351,  0.0303],
         [ 0.0206,  0.0348, -0.0100,  ...,  0.0064, -0.0168, -0.0191],
         [-0.0001, -0.0043, -0.0311,  ..., -0.0263,  0.0251,  0.0039],
         ...,
         [-0.0200, -0.0323, -0.0332,  ..., -0.0112,  0.0125,  0.0234],
         [-0.0290, -0.0022,  0.0127,  ..., -0.0173,  0.0035,  0.0124],
         [-0.0080, -0.0253,  0.0223,  ...,  0.0341, -0.0339,  0.0051]],
        device='mps:0'),
 'linear.bias': tensor([-1.2583e-02,  3.0831e-02,  4.5020e-02,  1.4790e-02,  3.0004e-03,
          3.2661e-02, -2.6347e-02, -8.0484e-03,  2.0411e-02,  3.8631e-03,
          4.0270e-02,  3.2066e-02, -2.3868e-02,  1.6084e-02, -9.5288e-03,
         -1.2936e-02,  3.6315e-02,  4.7068e-03,  1.3151e-03,  2.0537e-02,
         -8.4143e-03,  2.2198e-04,  3.4289e-02, -3.1297e-02, -1.5063e-02,
          1.8455e-02,  1.9217e-02,  3.8654e-03, -1.9805e-02, -1.5384e-02,
          2.0541e-02,  2.3218e-02,  3.3477e-02,  1.1731e-0

In [39]:
# testing the model to see on which digit the model is struggling
def test():
    correct = 0
    total = 0

    wrong_counts = [0]*10
    for i in range(10):
        wrong_counts[i] = 0

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                else:
                    wrong_counts[y[idx]] +=1
                total +=1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()

Testing: 100%|██████████| 313/313 [00:05<00:00, 54.53it/s]

Accuracy: 0.954
wrong counts for the digit 0: 19
wrong counts for the digit 1: 15
wrong counts for the digit 2: 56
wrong counts for the digit 3: 39
wrong counts for the digit 4: 49
wrong counts for the digit 5: 54
wrong counts for the digit 6: 39
wrong counts for the digit 7: 58
wrong counts for the digit 8: 73
wrong counts for the digit 9: 58





our model is the most wrong for digit 8, so will fine tune our small model on digit 8

In [70]:
# total no of parameters in our model 
def get_num_params(model):
  return sum(p.numel() for p in model.parameters())
print('total no of parameters are:',get_num_params(model))

total_parameters_original = sum(p.numel() for p in model.parameters())

# no of parameters layer wise 
for index, layer in enumerate([model.linear, model.linear2, model.linear3]):
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')



total no of parameters are: 2813804
Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])


In [64]:
class LoRA(nn.Module):
    def __init__(self, features_in, features_out, r=1, alpha=1, device = device):
        super(LoRA, self).__init__()
        self.r = r
        self.alpha = alpha
        self.scale = self.alpha / self.r
        self.lora_A = nn.Parameter(torch.zeros((r,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, r)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1) 
        self.enabled = True 

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

In [65]:
# do the parameterization 
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the Bias
    features_in, features_out = layer.weight.shape
    return LoRA(
        features_in, features_out, r=1, alpha=lora_alpha, device=device
    )

parametrize.register_parametrization(
    model.linear, "weight", linear_layer_parameterization(model.linear, device)
)
parametrize.register_parametrization(
    model.linear2, "weight", linear_layer_parameterization(model.linear2, device)
)
parametrize.register_parametrization(
    model.linear3, "weight", linear_layer_parameterization(model.linear3, device)
)

ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRA()
    )
  )
)

In [73]:
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([model.linear, model.linear2, model.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )
# The non-LoRA parameters count must match the original network
# assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


In [None]:
# freezing the non lora parameters as the paper says cause we no longer need them 
for name, param in model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# loading the dataseat again and now focusing only on the digit 8
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 8
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader for the training
train_loader = DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 9 and only for 100 batches (hoping that it would improve the performance on the digit 8)
print(train(train_loader=train_loader, model=model, num_epochs=1, device=device, total_iterations_limit=100, criterion=criterion))


Freezing non-LoRA parameter linear.bias
Freezing non-LoRA parameter linear.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original
None


In [None]:
# Check that the frozen parameters are still unchanged by the finetuning
assert torch.all(model.linear.parametrizations.weight.original == original_weights['linear.weight'])
assert torch.all(model.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(model.linear3.parametrizations.weight.original == original_weights['linear3.weight'])

enable_disable_lora(enabled=True)
# The new linear1.weight is obtained by the "forward" function of our LoRA parametrization
# The original weights have been moved to model.linear.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
assert torch.equal(model.linear.weight, model.linear.parametrizations.weight.original + (model.linear.parametrizations.weight[0].lora_B @ model.linear.parametrizations.weight[0].lora_A) * model.linear.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
assert torch.equal(model.linear.weight, original_weights['linear1.weight'])

In [None]:
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

In [None]:
# Test with LoRA disabled
enable_disable_lora(enabled=False)
test()