## LoRA implementation with PyTorch

In [1]:
import torch 
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# Make the model deterministic by seeding PyTorch's global random number generator.
# The result is assigned to `_` so the cell does not display the returned value.
_ = torch.manual_seed(0)

We will train a network to classify MNIST digits, and then fine-tune it on a particular digit on which it doesn't perform well.

In [162]:
# Preprocessing: convert each image to a tensor, then normalize it.
# (0.1307 / 0.3081 are presumably the MNIST mean / std -- TODO confirm)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split (downloaded into ./data when missing)
mnist_trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Shuffled mini-batches of 10 training samples
train_loader = torch.utils.data.DataLoader(
    mnist_trainset,
    batch_size=10,
    shuffle=True,
)

# Test split, wrapped in a loader configured like the training one
mnist_testset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(
    mnist_testset,
    batch_size=10,
    shuffle=True,
)

# Use the first CUDA device when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Create an overly complicated neural network

In [163]:
class RichBoyNet(nn.Module):
    """Three-layer fully connected classifier for flattened 28x28 inputs.

    Layout: Linear(784 -> hidden_size_1) -> ReLU ->
    Linear(hidden_size_1 -> hidden_size_2) -> ReLU -> Linear(hidden_size_2 -> 10).
    """

    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        """Create the three linear layers and the shared ReLU activation.

        Args:
            hidden_size_1: Width of the first hidden layer.
            hidden_size_2: Width of the second hidden layer.
        """
        super().__init__()
        self.linear1 = nn.Linear(28 * 28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Return the 10 raw class scores (no softmax) for each input.

        Args:
            img: Tensor that can be viewed as ``(-1, 28*28)``.

        Returns:
            Tensor of shape ``(N, 10)`` where ``N`` is the number of rows
            obtained after flattening.
        """
        # Flatten every sample into a 784-element row vector
        flat = img.view(-1, 28 * 28)
        hidden = self.relu(self.linear1(flat))
        hidden = self.relu(self.linear2(hidden))
        return self.linear3(hidden)
    
# Instantiate the network with its default hidden sizes and move it to the selected device
net = RichBoyNet().to(device)

In [164]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    """Train ``net`` on ``train_loader`` with Adam and cross-entropy loss.

    Batches are moved to the module-level ``device`` before the forward pass,
    and a tqdm progress bar reports the running average loss of the epoch.

    Args:
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        net: Model to optimize; it is called on the inputs viewed as
            ``(-1, 28*28)``.
        epochs: Maximum number of passes over ``train_loader``.
        total_iterations_limit: Optional cap on the total number of batches
            processed across all epochs. ``None`` means no cap.

    Returns:
        None. ``net`` is updated in place.
    """
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        # Only override the bar length when a cap was actually requested.
        # Checking `total_iterations` here would always be true (it is an int),
        # which would set the total to None for uncapped runs and hide the bar.
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()

            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            # Track the running mean of the loss over the current epoch
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            # Stop as soon as the global batch budget is exhausted
            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
            
# Train for a single epoch with no cap on the number of iterations
train(train_loader, net, epochs=1)

Epoch 1: 6000it [00:20, 290.85it/s, loss=0.239]  


Keep a copy of the original weights, so that we can later verify that fine-tuning with LoRA does not alter them

In [165]:
# Snapshot every named parameter of the trained network.
# Each entry is an independent copy that is cut off from the autograd graph,
# so later training steps cannot modify it.
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.detach().clone()

In [166]:
def test():
    """Evaluate the module-level ``net`` on ``test_loader`` and print a report.

    Prints the overall accuracy (rounded to 3 decimals) followed by, for each
    digit 0-9, how many test samples with that label were misclassified.
    Relies on the module-level ``net``, ``test_loader`` and ``device``.

    Returns:
        None.
    """
    num_classes = 10
    correct = 0
    total = 0
    # Per-label count of misclassified samples, accumulated on the CPU
    wrong_counts = torch.zeros(num_classes, dtype=torch.long)

    net.eval()

    with torch.no_grad():
        for x, y in tqdm(test_loader, desc='Testing'):
            x = x.to(device)
            y = y.to(device)
            # Predicted class = index of the largest score in each row
            predictions = net(x.view(-1, 28*28)).argmax(dim=1)
            is_wrong = predictions != y

            # Handle the whole batch at once instead of comparing samples
            # one by one in Python.
            batch_size = y.numel()
            total += batch_size
            correct += batch_size - int(is_wrong.sum())
            wrong_counts += torch.bincount(y[is_wrong].cpu(), minlength=num_classes)

    # Convert to plain Python ints so the report prints numbers, not tensors
    wrong_counts = wrong_counts.tolist()
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

# Evaluate the network on the test set: prints accuracy and per-digit error counts
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 416.17it/s]

Accuracy: 0.957
wrong counts for the digit 0: 9
wrong counts for the digit 1: 14
wrong counts for the digit 2: 25
wrong counts for the digit 3: 29
wrong counts for the digit 4: 73
wrong counts for the digit 5: 21
wrong counts for the digit 6: 62
wrong counts for the digit 7: 46
wrong counts for the digit 8: 83
wrong counts for the digit 9: 68





In [167]:
# Count every weight and bias element of the three linear layers,
# printing the tensor shapes of each layer along the way
total_parameters_original = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
    total_parameters_original += sum(t.nelement() for t in (layer.weight, layer.bias))

print(f'Total number of parameters: {total_parameters_original: ,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters:  2,807,010


In [168]:
class LoRAParametrization(nn.Module):
    """Parametrization that adds a scaled low-rank update ``B @ A`` to a weight.

    ``lora_B`` has shape ``(features_in, rank)`` and starts at zero, while
    ``lora_A`` has shape ``(rank, features_out)`` and is drawn from a standard
    normal distribution. Because ``lora_B`` is all zeros initially, the update
    is zero until training changes it.
    """

    def __init__(self, features_in, features_out,  rank=1, alpha=1, device='cpu'):
        """Create the two low-rank factors.

        Args:
            features_in: Number of rows of the update ``B @ A``.
            features_out: Number of columns of the update ``B @ A``.
            rank: Inner dimension shared by the two factors.
            alpha: Numerator of the scaling factor ``alpha / rank``.
            device: Device on which both factors are allocated.
        """
        super().__init__()

        # A gets random values, B stays at zero
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out), device=device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank), device=device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        # Constant multiplier applied to the low-rank update
        self.scale = alpha / rank
        # Toggle: when False, forward() returns its input untouched
        self.enabled = True

    def forward(self, original_weights):
        """Return ``original_weights + (B @ A) * scale``, or the input if disabled.

        Args:
            original_weights: Weight tensor being parametrized; ``B @ A`` is
                viewed with this tensor's shape before being added.
        """
        if not self.enabled:
            return original_weights
        low_rank_update = (self.lora_B @ self.lora_A).view(original_weights.shape)
        return original_weights + low_rank_update * self.scale

In [169]:
import torch.nn.utils.parametrize as parameterize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    """Build a LoRAParametrization sized for ``layer.weight``.

    Args:
        layer: Module whose two-dimensional ``weight`` will be parametrized.
        device: Device on which the LoRA matrices are created.
        rank: Rank of the low-rank update.
        lora_alpha: Scaling numerator, forwarded as ``alpha``.

    Returns:
        A new ``LoRAParametrization`` that has not been registered yet.
    """
    # The two dimensions of the weight, in storage order, are passed
    # positionally as (features_in, features_out).
    dim_0, dim_1 = layer.weight.shape
    return LoRAParametrization(dim_0, dim_1, rank=rank, alpha=lora_alpha, device=device)

# Attach a LoRA parametrization (default rank and alpha) to the weight of each
# linear layer, in layer order.
for layer in (net.linear1, net.linear2, net.linear3):
    parameterize.register_parametrization(
        layer, "weight", linear_layer_parameterization(layer, device)
    )
def enable_disable_lora(enabled=True):
    """Switch the LoRA update on or off for all three linear layers of ``net``.

    Args:
        enabled: Value written to the ``enabled`` flag of the first weight
            parametrization of ``net.linear1``, ``net.linear2`` and
            ``net.linear3``.
    """
    for linear in (net.linear1, net.linear2, net.linear3):
        lora = linear.parametrizations['weight'][0]
        lora.enabled = enabled


In [170]:
# Compare the number of parameters added by LoRA with the size of the original network
total_parameters_lora = 0
total_parameters_non_lora = 0

for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    # Hoisting the parametrization keeps the f-string below free of nested
    # same-type quotes, which are a SyntaxError before Python 3.12.
    lora = layer.parametrizations['weight'][0]
    total_parameters_lora += lora.lora_A.nelement() + lora.lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()

    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} '
        f'+ Lora_A: {lora.lora_A.shape} + Lora_B: {lora.lora_B.shape}'
    )

# Registering the parametrizations must not have changed the original parameter count
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
# Relative size of the LoRA matrices, as a percentage of the original parameters
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')


Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


In [171]:
# Freeze every parameter whose name does not contain 'lora', so that only the
# LoRA matrices can be updated during fine-tuning
for name, param in net.named_parameters():
    if 'lora' in name:
        continue
    print(f'Freezing non-Lora parameter {name}')
    param.requires_grad_(False)

# Reload the training split and keep only the samples labelled 9.
# NOTE: despite its name, `exclude_indices` is the mask of samples that are KEPT.
mnist_trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
mnist_trainset.data = mnist_trainset.data[exclude_indices]

train_loader = torch.utils.data.DataLoader(
    mnist_trainset,
    batch_size=10,
    shuffle=True,
)

# Fine-tune for at most 100 batches
train(train_loader, net, epochs=1, total_iterations_limit=100)


Freezing non-Lora parameter linear1.bias
Freezing non-Lora parameter linear1.parametrizations.weight.original
Freezing non-Lora parameter linear2.bias
Freezing non-Lora parameter linear2.parametrizations.weight.original
Freezing non-Lora parameter linear3.bias
Freezing non-Lora parameter linear3.parametrizations.weight.original


Epoch 1:   0%|          | 0/595 [00:00<?, ?it/s]

Epoch 1:  99%|█████████▉| 99/100 [00:00<00:00, 177.68it/s, loss=0.103]


In [172]:
# The frozen weights must be exactly what they were before fine-tuning
for layer, key in (
    (net.linear1, 'linear1.weight'),
    (net.linear2, 'linear2.weight'),
    (net.linear3, 'linear3.weight'),
):
    assert torch.all(layer.parametrizations.weight.original == original_weights[key])

enable_disable_lora(enabled=True)
# With LoRA enabled, linear1.weight is produced by the "forward" function of our
# LoRA parametrization, while the original weights live in
# net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
lora = net.linear1.parametrizations.weight[0]
expected_weight = net.linear1.parametrizations.weight.original + (lora.lora_B @ lora.lora_A) * lora.scale
assert torch.equal(net.linear1.weight, expected_weight)

enable_disable_lora(enabled=False)
# With LoRA disabled, linear1.weight is the original weight again
assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])

In [173]:
# Display linear1's weight as exposed through the parametrization
# (the original weight if LoRA is currently disabled)
net.linear1.weight

Parameter containing:
tensor([[-0.0031,  0.0518,  0.0425,  ..., -0.0131,  0.0051,  0.0466],
        [ 0.0914,  0.0878,  0.0352,  ...,  0.0819,  0.0483,  0.0522],
        [ 0.0049,  0.0116,  0.0557,  ...,  0.0619,  0.0172,  0.0201],
        ...,
        [ 0.0400,  0.0402,  0.0274,  ...,  0.0167,  0.0276,  0.0548],
        [ 0.0083,  0.0622,  0.0219,  ...,  0.0701,  0.0290,  0.0628],
        [-0.0040,  0.0552,  0.0392,  ...,  0.0431,  0.0068,  0.0124]],
       device='cuda:0')

In [174]:
# Recompute by hand the weight linear1 uses when LoRA is enabled: W + (B @ A) * scale
net.linear1.parametrizations.weight.original + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A) * net.linear1.parametrizations.weight[0].scale

tensor([[-0.0079,  0.0308,  0.0341,  ...,  0.0016,  0.0039,  0.0547],
        [ 0.0914,  0.0878,  0.0352,  ...,  0.0819,  0.0483,  0.0522],
        [ 0.0005, -0.0079,  0.0479,  ...,  0.0757,  0.0161,  0.0277],
        ...,
        [ 0.0400,  0.0402,  0.0274,  ...,  0.0167,  0.0276,  0.0548],
        [ 0.0102,  0.0708,  0.0253,  ...,  0.0640,  0.0295,  0.0595],
        [-0.0040,  0.0552,  0.0392,  ...,  0.0431,  0.0068,  0.0124]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [175]:
# Test with LoRA enabled: the fine-tuned low-rank update is added to every linear layer's weight
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 339.79it/s]

Accuracy: 0.952
wrong counts for the digit 0: 9
wrong counts for the digit 1: 9
wrong counts for the digit 2: 25
wrong counts for the digit 3: 56
wrong counts for the digit 4: 123
wrong counts for the digit 5: 30
wrong counts for the digit 6: 56
wrong counts for the digit 7: 89
wrong counts for the digit 8: 65
wrong counts for the digit 9: 22





In [176]:
# Test with LoRA disabled: every linear layer falls back to its original (frozen) weight
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 389.95it/s]

Accuracy: 0.957
wrong counts for the digit 0: 9
wrong counts for the digit 1: 14
wrong counts for the digit 2: 25
wrong counts for the digit 3: 29
wrong counts for the digit 4: 73
wrong counts for the digit 5: 21
wrong counts for the digit 6: 62
wrong counts for the digit 7: 46
wrong counts for the digit 8: 83
wrong counts for the digit 9: 68



