In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# make torch deterministic
_ = torch.manual_seed(0)

In [2]:
# (0.1307,), (0.3081,) mean and std computed on the training set of MNIST.
transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms)
# dataloader for the training
train_loader = DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms)
test_loader = DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "mps")


In [3]:
# Create the Neural Network to classify the digits

class RichBoyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(RichBoyNet, self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x
        

net = RichBoyNet().to(device=device)

In [4]:
# train the net on ly for 1 epoch

def train(train_loader: DataLoader, net: RichBoyNet, epochs: int=5, total_iteration_limit: int=None, device: type = "mps"):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch + 1}')
        if total_iteration_limit is not None:
            data_iterator.total = total_iteration_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            ouput = net(x.view(-1, 28*28))
            loss = cross_el(ouput, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iteration_limit is not None and total_iterations >= total_iteration_limit:
                return
            
train(train_loader=train_loader, net=net, epochs=5, device=device)


Epoch 1: 100%|██████████| 6000/6000 [00:46<00:00, 130.38it/s, loss=0.235]
Epoch 2: 100%|██████████| 6000/6000 [00:45<00:00, 130.61it/s, loss=0.134]
Epoch 3: 100%|██████████| 6000/6000 [00:46<00:00, 129.20it/s, loss=0.101] 
Epoch 4: 100%|██████████| 6000/6000 [00:46<00:00, 129.54it/s, loss=0.0907]
Epoch 5: 100%|██████████| 6000/6000 [00:47<00:00, 127.20it/s, loss=0.0805]


In [5]:
# keep a copy of the oroginal weghts () LoRA doesn't touch this originals.
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()
    

In [6]:
def test(device: type = "mps"):
    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc="I'm testing"):
            x, y, = data
            x = x.to(device)
            y = y.to(device)
            output = net(x.view(-1, 28*28))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx]] += 1
                total += 1

    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digits {i}: {wrong_counts[i]}')

test(device=device)


I'm testing: 100%|██████████| 1000/1000 [00:05<00:00, 177.47it/s]

Accuracy: 0.964
wrong counts for the digits 0: 5
wrong counts for the digits 1: 8
wrong counts for the digits 2: 44
wrong counts for the digits 3: 82
wrong counts for the digits 4: 27
wrong counts for the digits 5: 11
wrong counts for the digits 6: 46
wrong counts for the digits 7: 19
wrong counts for the digits 8: 61
wrong counts for the digits 9: 53





In [7]:
# Lets visualize how many parameters are in the original network (before LoRA)

total_parameters_original = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index +1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')


Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


#### LoRA made in Pytorch

In [8]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in: int, features_out: int, rank: int = 1, alpha: int = 1, device: type = "mps"):
        super().__init__()
        # Section 4.1 of the paper:
        #  We use a random gaussian initialization for A and zero for B, so ∆W = BA is zero at the beggining of the training.
        self.lora_A = nn.Parameter(torch.zeros(rank, features_out).to(device=device))
        self.lora_B = nn.Parameter(torch.zeros(features_in, rank).to(device=device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # Section 4.1 of the paper:
        #   We then scale ∆Wx by a/r, where a is a constant in r.
        #   When optimizing Adam, tuning a is roughly the same as tuning the learning rate if we scale the initialization appropiatelly.
        #   As a result, we simply set a to the first r we try and do not tune it
        #   this scaling helps to reduce the need to retune hyperparamters when we vary r
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return X + (B*A) * scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights
        
    

In [9]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Only add the parametrization to the weight matrix, ignoring Bias.

    # From section 4.2 of the paper:
    #    We limit our study to only adapting the attention weigths for downstram task and freeze the MLP modules...
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )


parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)

parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)

parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parameterization(net.linear3, device)
)


def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled



In [10]:
# Displya the numbers of parameters added by LoRA technique

total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index + 1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}')
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,} <--- We ONLY train this added matrix')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameter increment: {parameters_increment:.3f}%')


Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794 <--- We ONLY train this added matrix
Parameter increment: 0.242%


In [16]:
# freeze all non-Lora parameters

for name, params in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, by keeping only the digit # 9
mnist_trainset = datasets.MNIST(root = './data', train=True, download=True, transform=transforms)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# create a dataloader for the trainning
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# train the network with LoRA only on the digit 7 and only for 100 batches (hopping inprove)
train(train_loader, net, epochs=5,)


Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


Epoch 1:   0%|          | 0/595 [00:00<?, ?it/s, loss=0]

Epoch 1: 100%|██████████| 595/595 [00:07<00:00, 80.79it/s, loss=1e-10]   
Epoch 2: 100%|██████████| 595/595 [00:07<00:00, 84.39it/s, loss=0]
Epoch 3: 100%|██████████| 595/595 [00:07<00:00, 84.56it/s, loss=0]
Epoch 4: 100%|██████████| 595/595 [00:07<00:00, 80.15it/s, loss=0]
Epoch 5: 100%|██████████| 595/595 [00:07<00:00, 81.02it/s, loss=0]


In [None]:
# # Check that the frozen parameters are still unchanged by the finetuning
# assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
# assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
# assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])


enable_disable_lora(enabled=True)
# assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * net.linear1.parametrizations.weight[0].scale)

enable_disable_lora(enabled=False)
# If we disable LoRA, the linear1.weight is the original one
# assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])

In [17]:
enable_disable_lora(enabled=True)
test(device=device)


I'm testing: 100%|██████████| 1000/1000 [00:08<00:00, 121.68it/s]

Accuracy: 0.428
wrong counts for the digits 0: 389
wrong counts for the digits 1: 765
wrong counts for the digits 2: 180
wrong counts for the digits 3: 700
wrong counts for the digits 4: 926
wrong counts for the digits 5: 589
wrong counts for the digits 6: 233
wrong counts for the digits 7: 1012
wrong counts for the digits 8: 921
wrong counts for the digits 9: 0





In [19]:
# Test with LoRA disabled
enable_disable_lora(enabled=False)
test(device=device)


I'm testing: 100%|██████████| 1000/1000 [00:07<00:00, 136.93it/s]

Accuracy: 0.552
wrong counts for the digits 0: 201
wrong counts for the digits 1: 390
wrong counts for the digits 2: 125
wrong counts for the digits 3: 643
wrong counts for the digits 4: 778
wrong counts for the digits 5: 470
wrong counts for the digits 6: 135
wrong counts for the digits 7: 937
wrong counts for the digits 8: 796
wrong counts for the digits 9: 0



