In [15]:
from itertools import islice
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from tqdm import tqdm

In [16]:
# Seed PyTorch's global random number generator so repeated runs of this
# notebook start from the same random state.
torch.manual_seed(0)

<torch._C.Generator at 0x7deb7c33e650>

In [17]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

**What are we going to do?**

We will train a model on the MNIST dataset and then fine-tune it on one particular digit, so that we can compare the results of the pre-trained model with those of the fine-tuned model.

In [18]:
# Preprocessing applied to every image of both splits: convert to a tensor,
# then normalize with a fixed mean / std (presumably the usual MNIST statistics).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Training split, served in shuffled mini-batches of 10 images.
train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=10)

# Test split, served in mini-batches of 10 images without shuffling.
test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=10)


In [19]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier with two hidden ReLU layers.

    The input is flattened to 28*28 = 784 features per sample and mapped to
    10 output scores (one per MNIST class).

    Args:
        h1: Number of units in the first hidden layer.
        h2: Number of units in the second hidden layer.
    """

    def __init__(self, h1: int=1000, h2: int=2000):
        super().__init__()
        n_pixels = 28 * 28  # a 28x28 image flattened into one vector
        n_classes = 10      # one output score per digit
        self.linear1 = nn.Linear(n_pixels, h1)
        self.linear2 = nn.Linear(h1, h2)
        self.linear3 = nn.Linear(h2, n_classes)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """Return the raw (un-normalized) class scores for a batch of images."""
        # Collapse everything but the batch dimension: (..., 28, 28) -> (batch, 784).
        flat = x.view(-1, 28 * 28)
        hidden = self.relu(self.linear1(flat))
        hidden = self.relu(self.linear2(hidden))
        # No activation on the last layer: the caller receives logits.
        return self.linear3(hidden)

# Build the network with its default hidden sizes and move it to the selected device.
net = NeuralNetwork().to(device)

**Simulating pre-training of the model**

Here we use just 1 epoch, since the model is large enough to learn the patterns in a single pass over the data.

In [20]:
def train(model, train_loader, epochs, total_iterations_limit: Optional[int]=None):
    """Train `model` with Adam (lr=0.001) on a cross-entropy loss.

    Batches are moved to the module-level `device` before the forward pass.

    Args:
        model: The module to optimize. It is put into training mode at the
            start of every epoch and updated in place.
        train_loader: Iterable yielding `(inputs, labels)` batches.
        epochs: Maximum number of passes over `train_loader`.
        total_iterations_limit: Optional cap on the total number of batches
            processed across all epochs. `None` disables the cap; a value
            of 0 or less means no batch is processed at all.

    Returns:
        None.
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    total_iterations = 0  # batches processed so far, across all epochs

    for epoch in range(epochs):
        if total_iterations_limit is None:
            batches = train_loader
            bar_total = None  # let tqdm infer the length from the loader
        else:
            remaining = total_iterations_limit - total_iterations
            if remaining <= 0:
                # Budget exhausted (or never positive): nothing left to do.
                return
            # Bound the iterator itself instead of returning from inside the
            # loop: the loop then ends naturally and the progress bar reaches
            # its total instead of stopping one step short.
            batches = islice(train_loader, remaining)
            try:
                bar_total = min(remaining, len(train_loader))
            except TypeError:
                # Loader without a length: the remaining budget is the best estimate.
                bar_total = remaining

        model.train()
        data_iterator = tqdm(batches, desc=f'Epoch: {epoch+1}', total=bar_total)

        for x, y in data_iterator:
            total_iterations += 1
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()

            output = model(x)

            loss = loss_fn(output, y)
            loss.backward()
            optimizer.step()

# "Pre-train" the network: one epoch over the training loader, with no iteration cap.
train(net, train_loader, 1)

Epoch: 1: 100%|██████████| 6000/6000 [00:21<00:00, 279.37it/s]


**Original Weights of the model**

Keeping a copy of the model's original weights lets us compare them with the fine-tuned weights later.

In [21]:
# Snapshot every parameter of the network, keyed by parameter name. Each entry
# is a detached copy, so later training does not modify the snapshot.
original_weights = {
    param_name: tensor.detach().clone()
    for param_name, tensor in net.named_parameters()
}

**Testing the model**

In [22]:
def test():
    correct = total = 0
    wrong_counts = [0] * 10
    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)

            output = net(x)
            for idx, i in enumerate(output):
                pred = torch.argmax(i)
                actual = y[idx]
                if pred == actual:
                    correct += 1
                else:
                    wrong_counts[actual] += 1
                total += 1

    print(f'Accuracy: {round(correct/total, 2)}')
    for idx, out in enumerate(wrong_counts):
        print(f'Wrong counts for digit: {idx} is {out}')

# Evaluate the network on the test loader and print accuracy plus per-digit error counts.
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 358.15it/s]

Accuracy: 0.96
Wrong counts for digit: 0 is 11
Wrong counts for digit: 1 is 20
Wrong counts for digit: 2 is 54
Wrong counts for digit: 3 is 68
Wrong counts for digit: 4 is 24
Wrong counts for digit: 5 is 18
Wrong counts for digit: 6 is 40
Wrong counts for digit: 7 is 55
Wrong counts for digit: 8 is 9
Wrong counts for digit: 9 is 87





**LoRA Parametrization**

In [23]:
class LoraParametrization(nn.Module):
    """Additive low-rank (LoRA) update for a weight tensor.

    While enabled, `forward` returns
    `orig_weights + (lora_b @ lora_a) * (alpha / rank)`; while disabled it
    returns `orig_weights` untouched.

    Args:
        device: Device on which both low-rank factors are created.
        features_in: Number of rows of `lora_b`, i.e. the first dimension of
            the `(features_in, features_out)` update matrix.
        features_out: Number of columns of `lora_a`, i.e. the second
            dimension of the update matrix.
        rank: Inner dimension shared by the two factors.
        alpha: Numerator of the scaling factor `alpha / rank`.
    """

    def __init__(self, device, features_in, features_out, rank: int=1, alpha: int=1):
        super().__init__()
        self.lora_a = nn.Parameter(torch.zeros(rank, features_out, device=device))
        self.lora_b = nn.Parameter(torch.zeros(features_in, rank, device=device))

        # Only lora_a receives random (normally distributed) values. lora_b stays
        # at zero, so lora_b @ lora_a is zero and the update starts as a no-op.
        nn.init.normal_(self.lora_a)

        self.scale = alpha / rank
        self.lora_enabled = True  # flipped from outside to switch the update on or off

    def forward(self, orig_weights: torch.Tensor):
        """Return `orig_weights`, plus the scaled low-rank update when enabled."""
        if not self.lora_enabled:
            return orig_weights

        # (features_in, rank) @ (rank, features_out), viewed in the weight's own shape.
        delta = torch.matmul(self.lora_b, self.lora_a).view(orig_weights.shape)
        return orig_weights + delta * self.scale


**Adding the parametrization in our network**

In [24]:
import torch.nn.utils.parametrize as P

def linear_layer_parametrization(layer, rank: int=1, alpha: int=1):
    """Build a `LoraParametrization` whose update matches `layer.weight`.

    Args:
        layer: A module with a 2-D `weight` attribute.
        rank: Rank of the low-rank update, forwarded to `LoraParametrization`.
        alpha: Scaling numerator, forwarded to `LoraParametrization`.

    Returns:
        A new `LoraParametrization` created on the module-level `device`.
    """
    # The two values are the first and second dimension of the weight tensor,
    # in that order. LoraParametrization builds a (features_in, features_out)
    # update, so passing them in this order yields an update shaped exactly
    # like `layer.weight`.
    features_in, features_out = layer.weight.shape
    # Forward rank and alpha; previously they were accepted but silently dropped.
    return LoraParametrization(device, features_in, features_out, rank=rank, alpha=alpha)

# Register a LoRA parametrization (default rank and alpha) on the `weight` of
# each of the three linear layers, in layer order.
P.register_parametrization(net.linear1, 'weight', linear_layer_parametrization(net.linear1))
P.register_parametrization(net.linear2, 'weight', linear_layer_parametrization(net.linear2))
P.register_parametrization(net.linear3, 'weight', linear_layer_parametrization(net.linear3))

ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoraParametrization()
    )
  )
)

In [25]:
def enable_disable_lora(flag:bool=True):
    """Switch the LoRA update of all three linear layers of `net` on or off.

    Args:
        flag: `True` applies the low-rank update to the weights, `False`
            makes every layer use its original weights.
    """
    for layer_name in ('linear1', 'linear2', 'linear3'):
        # The first parametrization registered on `weight` is the LoRA one.
        lora = getattr(net, layer_name).parametrizations['weight'][0]
        lora.lora_enabled = flag

**Freezing the non-LoRA Parameters**

In [26]:
# Freeze every parameter that does not belong to a LoRA parametrization, so
# that only the `lora_a` / `lora_b` matrices keep receiving gradients.
for name, param in net.named_parameters():
    if 'lora' in name:
        continue
    param.requires_grad_(False)

# Load the MNIST training split again and keep only the samples of digit 9.
mnist_trainset = MNIST(root='./data', train=True, transform=transform, download=True)

required_indices = mnist_trainset.targets == 9  # boolean mask over the samples

mnist_trainset.targets = mnist_trainset.targets[required_indices]
mnist_trainset.data = mnist_trainset.data[required_indices]

# NOTE: this rebinds `train_loader`, which until now referred to the full training set.
train_loader = DataLoader(mnist_trainset, shuffle=True, batch_size=10)

# Fine-tune the LoRA matrices on digit 9 only, stopping after 100 batches.
train(net, train_loader, epochs=1, total_iterations_limit=100)

Epoch: 1:  99%|█████████▉| 99/100 [00:00<00:00, 185.55it/s]


In [27]:
# Evaluate with the LoRA update applied to the weights, i.e. the fine-tuned model.
enable_disable_lora(True)
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 290.68it/s]

Accuracy: 0.94
Wrong counts for digit: 0 is 11
Wrong counts for digit: 1 is 21
Wrong counts for digit: 2 is 68
Wrong counts for digit: 3 is 81
Wrong counts for digit: 4 is 155
Wrong counts for digit: 5 is 53
Wrong counts for digit: 6 is 49
Wrong counts for digit: 7 is 147
Wrong counts for digit: 8 is 37
Wrong counts for digit: 9 is 19





In [28]:
# Evaluate with the LoRA update switched off: each parametrization then returns
# the original weights unchanged, so this is the pre-trained model again.
enable_disable_lora(False)
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 358.93it/s]

Accuracy: 0.96
Wrong counts for digit: 0 is 11
Wrong counts for digit: 1 is 20
Wrong counts for digit: 2 is 54
Wrong counts for digit: 3 is 68
Wrong counts for digit: 4 is 24
Wrong counts for digit: 5 is 18
Wrong counts for digit: 6 is 40
Wrong counts for digit: 7 is 55
Wrong counts for digit: 8 is 9
Wrong counts for digit: 9 is 87



