# LoRA implementation with PyTorch

This is the implementation from the tutorial. I will expand on this implementation in the other file lora2.ipynb

In [31]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
import tqdm as tqdm

In [32]:
# Seed PyTorch's global random number generator so runs are repeatable.
# Assigning the result to `_` keeps the returned value out of the cell output.
_ = torch.manual_seed(0)

We will train a network to classify MNIST digits and then fine-tune it on a particular digit on which it did not perform well.

In [33]:
# Turn each image into a tensor, then normalise it with mean 0.1307 and std 0.3081.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# MNIST training split, served in shuffled batches of 32.
mnist_trainset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
train_loader = torch.utils.data.DataLoader(
    mnist_trainset, batch_size=32, shuffle=True
)

# MNIST test split, served in order in batches of 32.
mnist_testset = datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)
test_loader = torch.utils.data.DataLoader(
    mnist_testset, batch_size=32, shuffle=False
)

# Define the device: prefer CUDA, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

mps


Create the neural network to classify the digits, making it overly complicated to better show the power of LoRA.

In [34]:
class Model(nn.Module):
    """Fully connected classifier for 28x28 images with 10 output classes.

    The network is 784 -> hidden_size1 -> hidden_size2 -> 10 with a ReLU after
    each of the first two linear layers.

    Args:
        hidden_size1: Width of the first hidden layer.
        hidden_size2: Width of the second hidden layer.
    """

    def __init__(self, hidden_size1=1000, hidden_size2=2000):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Return raw class scores of shape ``(N, 10)``.

        Args:
            img: Tensor whose element count is a multiple of 784; it is
                flattened to ``(N, 784)`` before the first layer.
        """
        # reshape behaves like view for contiguous inputs but also accepts
        # non-contiguous ones (e.g. a transposed batch), where view raises.
        x = img.reshape(-1, 28 * 28)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Build the classifier with its default hidden sizes and move it to the selected device.
model = Model().to(device)

Train the network only for 1 epoch to simulate a complete general pre-training on the data

In [35]:
def train(train_loader, model, epochs=5, total_iterations_limit=None):
    """Optimise ``model`` in place with Adam (lr=0.001) and cross-entropy loss.

    Reads the notebook-level ``device`` to decide where each batch is moved.

    Args:
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        model: Module to train; every batch is flattened to 784 features
            before being passed to it.
        epochs: Number of passes over ``train_loader``.
        total_iterations_limit: Optional cap on the total number of
            optimisation steps across all epochs. ``None`` (the default)
            means every batch of every epoch is used.

    Returns:
        None.
    """
    cross_entropy_loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    total_iterations = 0

    for epoch in range(epochs):
        model.train()
        loss_sum = 0
        num_iterations = 0

        progress = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for x, y in progress:
            # Stop before taking another step once the optional budget is spent.
            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                progress.close()
                return
            num_iterations += 1
            total_iterations += 1
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(x.view(-1, 28*28))
            loss = cross_entropy_loss(output, y)
            loss_sum = loss_sum + loss.item()
            # Running mean of the loss over this epoch, shown next to the bar.
            avg_loss = loss_sum / num_iterations
            progress.set_postfix(loss=avg_loss, refresh=False)
            loss.backward()
            optimizer.step()

# A single pass over the full MNIST training set stands in for pre-training.
train(train_loader, model, epochs=1)

Epoch 1/1: 100%|██████████| 1875/1875 [00:07<00:00, 237.28it/s]


Keep a copy of the original weights (clone them) so that later we can prove that fine-tuning with LoRA doesn't alter the original weights.

In [36]:
# Snapshot every parameter as a detached copy, keyed by its parameter name,
# so later cells can compare the model against its pre-fine-tuning state.
original_weights = {
    param_name: tensor.detach().clone()
    for param_name, tensor in model.named_parameters()
}

Test the performance of the pretrained network. As we can see, the network performs poorly on the digit 9. Let's finetune it on the digit 9

In [37]:
def test():
    correct = 0
    total = 0
    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm.tqdm(test_loader, desc="Testing"):
            x, y = data
            x, y = x.to(device), y.to(device)
            output = model(x.view(-1, 28*28))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx]] += 1
                total += 1
            
    print(f"Accuracy: {correct / total * 100:.2f}%")
    for i in range(len(wrong_counts)):
        print(f"Class {i}: {wrong_counts[i]} wrong predictions")

# Baseline: evaluate the network right after pre-training.
test()

Testing: 100%|██████████| 313/313 [00:03<00:00, 88.56it/s]

Accuracy: 96.69%
Class 0: 9 wrong predictions
Class 1: 27 wrong predictions
Class 2: 40 wrong predictions
Class 3: 52 wrong predictions
Class 4: 21 wrong predictions
Class 5: 34 wrong predictions
Class 6: 33 wrong predictions
Class 7: 33 wrong predictions
Class 8: 14 wrong predictions
Class 9: 68 wrong predictions





Let's visualize how many parameters are in the original network, before introducing the LoRA matrices.

In [38]:
# Report the weight and bias shape of each linear layer and accumulate the
# total parameter count of the network as it is before LoRA is added.
total_parameters_original = 0
for layer_number, linear in enumerate((model.fc1, model.fc2, model.fc3), start=1):
    total_parameters_original += linear.weight.numel() + linear.bias.numel()
    print(f"Layer {layer_number}: W: {linear.weight.shape}, b: {linear.bias.shape}")
print(f"Total parameters: {total_parameters_original:,}")

Layer 1: W: torch.Size([1000, 784]), b: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]), b: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]), b: torch.Size([10])
Total parameters: 2,807,010


Define the LoRA parameterization as described in the paper. The mechanism used to attach it to a layer is explained in the PyTorch documentation for `torch.nn.utils.parametrize`.

In [39]:
class LoRAParameterization(nn.Module):
    """Low-rank additive update that can be registered as a weight parametrization.

    Two trainable factors are held: ``lora_B`` with shape ``(features_in, rank)``
    and ``lora_A`` with shape ``(rank, features_out)``. ``forward`` adds their
    product, multiplied by ``alpha / rank``, to the tensor it is given, so that
    tensor must contain ``features_in * features_out`` elements.

    Args:
        features_in: Number of rows of ``lora_B``.
        features_out: Number of columns of ``lora_A``.
        rank: Inner dimension shared by the two factors.
        alpha: Numerator of the scaling factor ``alpha / rank``.
        device: Device on which both factors are created.
    """

    def __init__(self, features_in, features_out, rank=1, alpha=1.0, device="cpu"):
        super().__init__()
        # Section 4.1 of the paper: A gets a random Gaussian initialisation and
        # B starts at zero, so deltaW = BA is zero at the start of training.
        self.lora_A = nn.Parameter(torch.zeros(rank, features_out, device=device))
        self.lora_B = nn.Parameter(torch.zeros(features_in, rank, device=device))
        nn.init.normal_(self.lora_A, mean=0.0, std=1)
        nn.init.zeros_(self.lora_B)

        # Section 4.1 of the paper: deltaW is scaled by alpha / rank. With Adam,
        # tuning alpha is roughly the same as tuning the learning rate when the
        # initialisation is scaled appropriately, so alpha is set to the first
        # rank tried and not tuned; the scaling reduces the need to retune
        # hyperparameters when the rank changes.
        self.scale = alpha / rank
        # Switch read by forward(); False makes the parametrization a no-op.
        self.enabled = True

    def forward(self, original_weights):
        """Return ``original_weights + (B @ A) * scale``, or the input itself when disabled."""
        if not self.enabled:
            return original_weights
        low_rank_update = (self.lora_B @ self.lora_A).view(original_weights.shape)
        return original_weights + low_rank_update * self.scale

Add the parameterization to the network

In [40]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    """Create a LoRAParameterization sized to match ``layer.weight``.

    Only the weight matrix is given a parametrization; the bias is ignored.

    Args:
        layer: Module whose two-dimensional ``weight`` determines the sizes.
        device: Device on which the LoRA matrices are created.
        rank: Rank of the low-rank update.
        lora_alpha: Value passed on as ``alpha``.

    Returns:
        A new ``LoRAParameterization`` whose ``features_in`` and
        ``features_out`` are the first and second dimension of ``layer.weight``.
    """
    # From section 4.2 of the paper: the study is limited to adapting the
    # attention weights for downstream tasks and freezes the MLP modules;
    # the empirical investigation of [...] and biases is left to future work.
    weight_rows, weight_cols = layer.weight.shape
    return LoRAParameterization(
        features_in=weight_rows,
        features_out=weight_cols,
        rank=rank,
        alpha=lora_alpha,
        device=device,
    )

# Attach a separate LoRA parametrization to the weight of each linear layer.
# NOTE(review): re-running this cell presumably stacks a second parametrization
# on every layer -- confirm against the register_parametrization documentation.
parametrize.register_parametrization(
    model.fc1, "weight", linear_layer_parameterization(model.fc1, device)
)
parametrize.register_parametrization(
    model.fc2, "weight", linear_layer_parameterization(model.fc2, device)
)   
# The value returned by this last call is what the cell displays as its output.
parametrize.register_parametrization(
    model.fc3, "weight", linear_layer_parameterization(model.fc3, device)
)


ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRAParameterization()
    )
  )
)

In [41]:
def enable_disable_lora(enabled=True):
    """Set the ``enabled`` flag of the first weight parametrization on fc1, fc2 and fc3.

    Operates on the notebook-level ``model``.

    Args:
        enabled: Value written to each parametrization's ``enabled`` attribute.
    """
    for adapted_layer in (model.fc1, model.fc2, model.fc3):
        lora = adapted_layer.parametrizations["weight"][0]
        lora.enabled = enabled

Display the total number of parameters added by LoRA

In [42]:
# Count the parameters contributed by the LoRA matrices and, separately,
# the weight and bias parameters of the three linear layers.
total_parameters_lora = 0
total_parameters_non_lora = 0

for layer_number, linear in enumerate((model.fc1, model.fc2, model.fc3), start=1):
    adapter = linear.parametrizations["weight"][0]
    total_parameters_lora += adapter.lora_A.numel() + adapter.lora_B.numel()
    total_parameters_non_lora += linear.weight.numel() + linear.bias.numel()
    print(
        f"Layer {layer_number}: W: {linear.weight.shape}, b: {linear.bias.shape}, "
        f"Lora_A: {adapter.lora_A.shape}, Lora_B: {adapter.lora_B.shape}"
    )

# The non-LoRA parameter count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f"Total number of parameters (original): {total_parameters_original:,}")
print(f"Total number of parameters (original + LoRA): {total_parameters_original + total_parameters_lora:,}")
print(f"Total number of parameters (LoRA): {total_parameters_lora:,}")
# Size of the LoRA matrices relative to the original network, in percent.
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f"Parameters increment: {parameters_increment:.2f}%")

Layer 1: W: torch.Size([1000, 784]), b: torch.Size([1000]), Lora_A: torch.Size([1, 784]), Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]), b: torch.Size([2000]), Lora_A: torch.Size([1, 1000]), Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]), b: torch.Size([10]), Lora_A: torch.Size([1, 2000]), Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Total number of parameters (LoRA): 6,794
Parameters increment: 0.24%


In [43]:
# Freeze every parameter of the original model so that only the matrices
# introduced by LoRA are trained, then fine-tune on the digit 9 for one epoch.
for param_name, parameter in model.named_parameters():
    if "lora" in param_name:
        continue
    print(f"Freezing non-LoRA parameter {param_name}")
    parameter.requires_grad = False

# Load the MNIST training set again and keep only the samples labelled 9.
mnist_trainset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
keep_mask = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[keep_mask]
mnist_trainset.targets = mnist_trainset.targets[keep_mask]

# Serve the digit-9 subset in shuffled batches of 16.
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=16, shuffle=True)

# Train the network with LoRA on the digit 9 only (hoping to improve the
# performance on that digit). This is one full pass over the subset, which
# the recorded run shows as 372 batches -- the call applies no 100-batch cap.
train(train_loader, model, epochs=1)

Freezing non-LoRA parameter fc1.bias
Freezing non-LoRA parameter fc1.parametrizations.weight.original
Freezing non-LoRA parameter fc2.bias
Freezing non-LoRA parameter fc2.parametrizations.weight.original
Freezing non-LoRA parameter fc3.bias
Freezing non-LoRA parameter fc3.parametrizations.weight.original


Epoch 1/1: 100%|██████████| 372/372 [00:01<00:00, 193.83it/s]


In [44]:
# Show the module tree; each linear layer now carries a LoRAParameterization on its weight.
print(model)

Model(
  (fc1): ParametrizedLinear(
    in_features=784, out_features=1000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (fc2): ParametrizedLinear(
    in_features=1000, out_features=2000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (fc3): ParametrizedLinear(
    in_features=2000, out_features=10, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (relu): ReLU()
)


In [45]:
# Verify that fine-tuning changed only the LoRA matrices: the frozen weights
# kept under parametrizations.weight.original must equal the saved snapshot.
for linear, snapshot_key in (
    (model.fc1, "fc1.weight"),
    (model.fc2, "fc2.weight"),
    (model.fc3, "fc3.weight"),
):
    assert torch.all(linear.parametrizations.weight.original == original_weights[snapshot_key])

enable_disable_lora(enabled=True)
# With LoRA enabled, fc1.weight is produced by the forward() of the LoRA
# parametrization, while the original tensor is stored in
# model.fc1.parametrizations.weight.original.
# More information: PyTorch docs on inspecting a parametrized module.
fc1_adapter = model.fc1.parametrizations.weight[0]
fc1_update = (fc1_adapter.lora_B @ fc1_adapter.lora_A).view(model.fc1.weight.shape)
assert torch.equal(
    model.fc1.weight,
    model.fc1.parametrizations.weight.original + fc1_update * fc1_adapter.scale,
)

enable_disable_lora(enabled=False)
# With LoRA disabled, fc1.weight is the original weight again.
assert torch.equal(model.fc1.weight, original_weights["fc1.weight"])

In [46]:
# Test the network with LoRA enabled (the digit 9 should be classified better).
# In the recorded output class 9 reaches 0 errors while overall accuracy falls
# to 29.60% -- presumably because the LoRA matrices were trained on 9s only.
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 313/313 [00:05<00:00, 59.21it/s]

Accuracy: 29.60%
Class 0: 741 wrong predictions
Class 1: 1059 wrong predictions
Class 2: 474 wrong predictions
Class 3: 759 wrong predictions
Class 4: 931 wrong predictions
Class 5: 687 wrong predictions
Class 6: 489 wrong predictions
Class 7: 1025 wrong predictions
Class 8: 875 wrong predictions
Class 9: 0 wrong predictions





In [47]:
# Test the network with LoRA disabled: the accuracy and per-class error counts
# must be the same as for the original network, since the base weights were frozen.
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 313/313 [00:03<00:00, 92.03it/s]

Accuracy: 96.69%
Class 0: 9 wrong predictions
Class 1: 27 wrong predictions
Class 2: 40 wrong predictions
Class 3: 52 wrong predictions
Class 4: 21 wrong predictions
Class 5: 34 wrong predictions
Class 6: 33 wrong predictions
Class 7: 33 wrong predictions
Class 8: 14 wrong predictions
Class 9: 68 wrong predictions



