In [None]:
import torch 
import torch.nn as nn 
from tqdm import tqdm
import numpy as np 
from torchvision import datasets 
from torchvision.transforms import transforms 
from torch.utils.data import DataLoader

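1. Under-train the full model (a single epoch), then evaluate it to find the digit class it misclassifies most often.
2. Fine-tune with LoRA on that class only, so the update targets just the weak class.
3. Keep the base weights frozen and unchanged, so the LoRA update can be switched on or off like a plug-in.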

In [None]:
# Convert images to tensors and normalize them; (0.1307, 0.3081) are presumably the
# MNIST pixel mean / std.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# download=True fetches the files only when they are missing under ./data, so this
# cell also works on a fresh checkout (the fine-tuning cell already uses download=True).
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Prefer CUDA, then MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [None]:
class BIGnet(nn.Module):
    """Fully connected classifier: 784 -> 1000 -> 2000 -> 10 with ReLU in between."""

    def __init__(self):
        super().__init__()
        # Layers are created in forward order; their attribute names double as
        # parameter-name prefixes (e.g. 'linear.weight').
        self.linear = nn.Linear(28 * 28, 1000)
        self.linear2 = nn.Linear(1000, 2000)
        self.linear3 = nn.Linear(2000, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return the 10 raw logits per row."""
        flat = x.view(-1, 28 * 28)
        hidden = self.relu(self.linear(flat))
        hidden = self.relu(self.linear2(hidden))
        # No activation on the output layer: the result is unnormalized logits.
        return self.linear3(hidden)

# Instantiate the network and move its parameters to the selected device.
model = BIGnet().to(device)
# NOTE(review): train() builds its own CrossEntropyLoss and Adam(lr=0.001) internally,
# so `criterion` and `optimizer` are not used by any cell visible in this notebook.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [None]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    """Train ``net`` in place with Adam (lr=0.001) and cross-entropy loss.

    Batches are moved to the notebook-level ``device`` before the forward pass.

    Args:
        train_loader: iterable yielding ``(inputs, labels)`` batches.
        net: module to optimize.
        epochs: number of passes over ``train_loader``.
        total_iterations_limit: optional cap on the number of batches processed
            across all epochs; training stops as soon as it is reached. A value
            of zero or less trains nothing.

    Returns:
        None. A progress bar shows the running mean loss of the current epoch.
    """
    if total_iterations_limit is not None and total_iterations_limit <= 0:
        # A non-positive budget leaves nothing to train.
        return

    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        # The context manager closes the bar even when the limit ends the run early.
        with tqdm(train_loader, desc=f'Epoch {epoch+1}') as data_iterator:
            if total_iterations_limit is not None:
                # Show what is left of the budget, but never more than one pass
                # over the loader (tqdm's total is None when the length is unknown).
                remaining = total_iterations_limit - total_iterations
                loader_length = data_iterator.total
                if loader_length is None:
                    data_iterator.total = remaining
                else:
                    data_iterator.total = min(loader_length, remaining)

            for data in data_iterator:
                num_iterations += 1
                total_iterations += 1

                x, y = data
                x = x.to(device)
                y = y.to(device)

                optimizer.zero_grad()
                output = net(x)
                loss = cross_el(output, y)
                loss_sum += loss.item()

                # Running mean over the batches seen so far in this epoch.
                avg_loss = loss_sum / num_iterations
                data_iterator.set_postfix(loss=avg_loss)

                loss.backward()
                optimizer.step()

                if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                    return

train(train_loader, model, epochs=1)

In [None]:
# Snapshot every parameter of the trained model (detached copies), keyed by its
# parameter name, so later cells can compare against these values.
original_weights = {
    name: param.detach().clone()
    for name, param in model.named_parameters()
}

In [None]:
def test(net=None, loader=None, num_classes=10):
    """Evaluate a classifier and print its accuracy and per-class error counts.

    Batches are moved to the notebook-level ``device``. The module is switched to
    eval mode for the duration of the call and its previous mode is restored after.

    Args:
        net: module to evaluate; defaults to the notebook-level ``model``.
        loader: iterable of ``(inputs, labels)`` batches; defaults to the
            notebook-level ``test_loader``.
        num_classes: number of label values to report on (labels are expected
            to be integers in ``[0, num_classes)``).

    Returns:
        None. Results are printed.
    """
    net = model if net is None else net
    loader = test_loader if loader is None else loader

    correct = 0
    total = 0
    wrong_per_class = torch.zeros(num_classes, dtype=torch.long)

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for x, y in tqdm(loader, desc='Testing'):
                x = x.to(device)
                y = y.to(device)
                predictions = net(x).argmax(dim=1)

                # One device-to-host copy per batch instead of one comparison per sample.
                missed = (predictions != y).cpu()
                labels = y.cpu()

                total += labels.numel()
                correct += labels.numel() - int(missed.sum())
                # Count the misclassified samples by their true label.
                wrong_per_class += torch.bincount(labels[missed], minlength=num_classes)
    finally:
        net.train(was_training)

    wrong_counts = wrong_per_class.tolist()

    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()

In [None]:
def get_num_params(model):
    """Return the total number of scalar elements over all parameters of ``model``."""
    count = 0
    for tensor in model.parameters():
        count += tensor.numel()
    return count

# Count once and keep the value around for later comparison.
total_parameters_original = get_num_params(model)
print('total no of parameters are:', total_parameters_original)

# Weight and bias shapes, layer by layer.
layers = (model.linear, model.linear2, model.linear3)
for index, layer in enumerate(layers):
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')

In [None]:
class LoRA(nn.Module):
    """Low-rank additive update for a weight tensor, usable as a parametrization.

    ``forward(W)`` returns ``W + (lora_B @ lora_A) * (alpha / r)``.

    ``lora_B`` has shape ``(features_in, r)`` and starts at zero; ``lora_A`` has
    shape ``(r, features_out)`` and is drawn from a normal distribution with
    mean 0 and std 1. Because ``lora_B`` is zero, a freshly built module returns
    ``W`` unchanged. The product has shape ``(features_in, features_out)`` and is
    viewed as ``W``'s shape before being added.

    Setting ``enabled`` to False makes ``forward`` return ``W`` untouched.
    """

    def __init__(self, features_in, features_out, r=1, alpha=1, device='cpu'):
        super().__init__()

        self.r = r
        self.alpha = alpha
        self.scale = alpha / r

        # lora_A is registered before lora_B, which fixes their parameter order.
        a_init = torch.zeros((r, features_out), device=device)
        b_init = torch.zeros((features_in, r), device=device)
        self.lora_A = nn.Parameter(a_init)
        self.lora_B = nn.Parameter(b_init)
        nn.init.normal_(self.lora_A, mean=0, std=1)

        self.enabled = True

    def forward(self, original_weights):
        """Return the adapted weight, or ``original_weights`` itself when disabled."""
        if not self.enabled:
            return original_weights
        low_rank_update = (self.lora_B @ self.lora_A).view(original_weights.shape)
        return original_weights + low_rank_update * self.scale

In [None]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, lora_alpha=1, rank=1):
    """Build a LoRA module sized for ``layer.weight``.

    Only the weight matrix is adapted; the bias is left alone.

    Args:
        layer: module whose 2-D ``weight`` is going to be parametrized.
        device: device on which the LoRA matrices are created.
        lora_alpha: LoRA ``alpha``; the update is scaled by ``alpha / rank``.
        rank: rank of the low-rank update. Defaults to 1.

    Returns:
        A ``LoRA`` instance whose update has the same shape as ``layer.weight``.
    """
    # nn.Linear stores its weight as (out_features, in_features); LoRA only needs
    # the two dimensions in the same order as the weight itself.
    weight_rows, weight_cols = layer.weight.shape
    return LoRA(
        weight_rows, weight_cols, r=rank, alpha=lora_alpha, device=device
    )

# Attach one LoRA parametrization to the weight of every linear layer, in forward order.
for layer in (model.linear, model.linear2, model.linear3):
    # Re-running this cell would otherwise stack a second LoRA on the same weight,
    # while enable_disable_lora() only toggles the first one.
    if parametrize.is_parametrized(layer, "weight"):
        continue
    parametrize.register_parametrization(
        layer, "weight", linear_layer_parameterization(layer, device)
    )

def enable_disable_lora(enabled=True):
    """Set the ``enabled`` flag on the first "weight" parametrization of the three
    linear layers of the notebook-level ``model``."""
    for layer_name in ("linear", "linear2", "linear3"):
        lora = getattr(model, layer_name).parametrizations["weight"][0]
        lora.enabled = enabled

In [None]:
# Tally the LoRA matrices separately from each layer's weight and bias.
total_parameters_lora = 0
total_parameters_non_lora = 0

for index, layer in enumerate([model.linear, model.linear2, model.linear3]):
    lora = layer.parametrizations["weight"][0]
    total_parameters_lora += lora.lora_A.nelement() + lora.lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )

print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')

# Size of the LoRA addition relative to the original parameter count, in percent.
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters increment: {parameters_increment:.3f}%')

In [None]:
# Freeze every parameter whose name does not contain 'lora', so that only the
# LoRA matrices receive gradients during fine-tuning.
for name, param in model.named_parameters():
    if 'lora' in name:
        continue
    print(f'Freezing non-LoRA parameter {name}')
    param.requires_grad_(False)

# Load the training split again and keep only the samples labelled 8.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# NOTE: despite its name, this boolean mask marks the samples that are kept.
exclude_indices = mnist_trainset.targets == 8
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]

# NOTE: this rebinds `train_loader`; from here on it only yields digit-8 batches.
train_loader = DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Fine-tune on digit 8 only, stopping after 100 batches.
train(train_loader, model, epochs=1, total_iterations_limit=100)

In [None]:
# The frozen base weights must be exactly what they were before fine-tuning.
for layer, key in (
    (model.linear, 'linear.weight'),
    (model.linear2, 'linear2.weight'),
    (model.linear3, 'linear3.weight'),
):
    assert torch.all(layer.parametrizations.weight.original == original_weights[key])

enable_disable_lora(enabled=True)

# With LoRA on, linear.weight is produced by the parametrization's forward():
# the stored original weight plus the scaled low-rank product.
first_layer = model.linear.parametrizations.weight
lora = first_layer[0]
assert torch.equal(
    model.linear.weight,
    first_layer.original + (lora.lora_B @ lora.lora_A) * lora.scale,
)

enable_disable_lora(enabled=False)
# With LoRA off, linear.weight is exactly the snapshot taken before fine-tuning.
assert torch.equal(model.linear.weight, original_weights['linear.weight'])

In [None]:
# Evaluate twice: first with the LoRA update applied, then with it switched off.
for lora_enabled in (True, False):
    enable_disable_lora(enabled=lora_enabled)
    test()