## LoRA implementation with PyTorch

In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch import nn 
import matplotlib.pyplot as plt  
from tqdm import tqdm

In [2]:
# Make torch determinitic
_ = torch.manual_seed(0)

In [3]:
# Getting MNIST dataset for training our simple model

transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms)

# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test dataset
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the devie
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Create the neural network to classify the digits, make the model little big to see the use of LoRA

In [4]:


class ClassifyNet(nn.Module):
    def __init__(self, hidden_size1 = 1000, hidden_size_2 = 2000):
        super(ClassifyNet, self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size1)
        self.linear2 = nn.Linear(hidden_size1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x
    
net = ClassifyNet().to(device)

In [5]:
def train(train_loader, net, epochs = 5, total_iterations_limit = None):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr = 0.001)

    total_iterations = 0 
    
    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iteration = tqdm(train_loader, desc = f'Epoch {epoch + 1}/{epochs}')
        
        if total_iterations_limit:
            data_iteration.total = total_iterations_limit
        
        for data in data_iteration:

            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            output = net(x)
            loss_value = loss_fn(output, y)
            loss_sum += loss_value.item()

            avg_loss = loss_sum / num_iterations

            data_iteration.set_postfix(loss = avg_loss)
            loss_value.backward()
            optimizer.step()

            if total_iterations_limit and total_iterations >= total_iterations_limit:
                return
            



In [6]:
train(train_loader, net, epochs = 1)

Epoch 1/1: 100%|██████████| 6000/6000 [00:31<00:00, 192.44it/s, loss=0.236]


### Keep a copy of original weights (clone them), so later we can confirm that fine tuning with LoRA doesn't alter the original weights

In [7]:
original_weights = {}

for name, param in net.named_parameters():
    original_weights[name] = param.data.clone().detach()

In [8]:
original_weights

{'linear1.weight': tensor([[ 0.0015,  0.0210, -0.0276,  ...,  0.0237,  0.0055,  0.0039],
         [ 0.0078,  0.0126,  0.0172,  ...,  0.0073,  0.0216, -0.0023],
         [ 0.0124,  0.0475, -0.0007,  ...,  0.0122,  0.0337,  0.0406],
         ...,
         [-0.0175,  0.0460,  0.0443,  ...,  0.0133,  0.0394, -0.0132],
         [ 0.0716,  0.0349,  0.0206,  ...,  0.0549,  0.0485,  0.0493],
         [ 0.0216, -0.0021,  0.0491,  ...,  0.0584,  0.0319,  0.0423]],
        device='cuda:0'),
 'linear1.bias': tensor([-2.6380e-02, -6.2766e-03, -4.0181e-02, -1.0200e-02,  6.3419e-03,
         -4.9645e-02, -4.1591e-02, -5.0680e-03, -5.5057e-02, -7.0120e-03,
          4.5819e-03, -1.1517e-02, -2.7861e-02, -4.0841e-02,  4.1142e-03,
         -5.0162e-02, -4.8437e-02, -4.1003e-02, -1.7643e-02, -5.5479e-02,
         -3.2971e-02, -4.6315e-02,  1.0228e-02, -2.2976e-02, -1.4506e-02,
         -4.3609e-02, -1.1930e-03, -2.7267e-02,  8.9650e-03, -2.3131e-02,
          2.4631e-03, -4.0215e-02, -2.1087e-02, -8.7772

In [9]:
def test():

    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        net.eval()
        for data in tqdm(test_loader, desc="Testing"):
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = net(images)

            for idx, i in enumerate(outputs):
                if torch.argmax(i) == labels[idx]:
                    correct += 1
                else:
                    wrong_counts[torch.argmax(i)] += 1
                total += 1

    print(f'Accuracy: {round(correct/total, 4)}')

    for i in range(10):
        print(f'Wrong count for {i}: {wrong_counts[i]}')

In [10]:
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 377.18it/s]

Accuracy: 0.9663
Wrong count for 0: 13
Wrong count for 1: 35
Wrong count for 2: 36
Wrong count for 3: 30
Wrong count for 4: 43
Wrong count for 5: 20
Wrong count for 6: 16
Wrong count for 7: 29
Wrong count for 8: 90
Wrong count for 9: 25





#### Let's check how many parameters are in the original network, before introducing the LoRA matrices.

In [11]:
total_parameters_original = 0 
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f"Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}")

print(f"Total parameters: {total_parameters_original:,}")
    

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total parameters: 2,807,010


### LoRA parameterization as described in the paper.

it uses PyTorch parameterization: https://pytorch.org/tutorials/intermediate/parametrizations.html


In [12]:
class LoRAParameterization(nn.Module):
    def __init__(self, feature_in, feature_out, rank=1, alpha=1, device='cpu'):
        super(LoRAParameterization, self).__init__()

        self.lora_A = nn.Parameter(torch.zeros((rank, feature_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((feature_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        self.scale = alpha / rank 
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights


Add the parameterization to our network

In [13]:
from torch.nn.utils import parametrize 

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):

    feature_in, feature_out = layer.weight.shape

    return LoRAParameterization(feature_in, feature_out, rank=rank, alpha=lora_alpha, device=device)


In [14]:
parametrize.register_parametrization(net.linear1, "weight", linear_layer_parameterization(net.linear1, device))
parametrize.register_parametrization(net.linear2, "weight", linear_layer_parameterization(net.linear2, device))
parametrize.register_parametrization(net.linear3, "weight", linear_layer_parameterization(net.linear3, device))


ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRAParameterization()
    )
  )
)

In [15]:
net

ClassifyNet(
  (linear1): ParametrizedLinear(
    in_features=784, out_features=1000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (linear2): ParametrizedLinear(
    in_features=1000, out_features=2000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (linear3): ParametrizedLinear(
    in_features=2000, out_features=10, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (relu): ReLU()
)

In [16]:
net.named_parameters

<bound method Module.named_parameters of ClassifyNet(
  (linear1): ParametrizedLinear(
    in_features=784, out_features=1000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (linear2): ParametrizedLinear(
    in_features=1000, out_features=2000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (linear3): ParametrizedLinear(
    in_features=2000, out_features=10, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (relu): ReLU()
)>

In [17]:
total_parameters_lora = 0
total_parameters_non_lora = 0

for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += layer.parametrizations['weight'][0].lora_A.nelement() + layer.parametrizations['weight'][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()

    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}')
    
# Non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_original:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Paramters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameter increment: {parameters_increment:.2f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Paramters introduced by LoRA: 6,794
Parameter increment: 0.24%


In [18]:
net.linear1.parametrizations['weight'][0].lora_A.nelement()

784

Freeze all the parameters of the original network and only fine tuning the ones introduced by LoRA. Then fine tune the model on the digit 4 and only for 100 batches.

In [19]:
# Freeze the non-Lora parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False
    else:
        print(f'LoRA paramerter {name}')

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
LoRA paramerter linear1.parametrizations.weight.0.lora_A
LoRA paramerter linear1.parametrizations.weight.0.lora_B
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
LoRA paramerter linear2.parametrizations.weight.0.lora_A
LoRA paramerter linear2.parametrizations.weight.0.lora_B
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original
LoRA paramerter linear3.parametrizations.weight.0.lora_A
LoRA paramerter linear3.parametrizations.weight.0.lora_B


Training just on digit 4 labels

In [20]:
# Load the MNIST dadtaset again, by keeping only the digit 4
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms)
exclude_indices = mnist_trainset.targets!= 4
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]

train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)



In [21]:
train(train_loader, net, epochs=1, total_iterations_limit=100)

Epoch 1/1:  99%|█████████▉| 99/100 [00:00<00:00, 147.70it/s, loss=0.117]


Verify that the fine tuning did alter the original weights, but only the ones introduced by LoRA

In [22]:
assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])


In [23]:
def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations['weight'][0].enabled = enabled

In [24]:
net.linear1.parametrizations.weight.original

Parameter containing:
tensor([[ 0.0015,  0.0210, -0.0276,  ...,  0.0237,  0.0055,  0.0039],
        [ 0.0078,  0.0126,  0.0172,  ...,  0.0073,  0.0216, -0.0023],
        [ 0.0124,  0.0475, -0.0007,  ...,  0.0122,  0.0337,  0.0406],
        ...,
        [-0.0175,  0.0460,  0.0443,  ...,  0.0133,  0.0394, -0.0132],
        [ 0.0716,  0.0349,  0.0206,  ...,  0.0549,  0.0485,  0.0493],
        [ 0.0216, -0.0021,  0.0491,  ...,  0.0584,  0.0319,  0.0423]],
       device='cuda:0')

In [25]:
enable_disable_lora(enabled=True)

assert torch.equal(net.linear1.weight, net.linear1.parametrizations.weight.original + \
                   net.linear1.parametrizations.weight[0].lora_B @ \
                    net.linear1.parametrizations.weight[0].lora_A * \
                    net.linear1.parametrizations.weight[0].scale 
                    )

In [26]:
enable_disable_lora(enabled=False)

assert torch.equal(net.linear1.weight, original_weights['linear1.weight'])

In [27]:
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 274.37it/s]

Accuracy: 0.9693
Wrong count for 0: 14
Wrong count for 1: 20
Wrong count for 2: 25
Wrong count for 3: 42
Wrong count for 4: 27
Wrong count for 5: 14
Wrong count for 6: 21
Wrong count for 7: 32
Wrong count for 8: 75
Wrong count for 9: 37





In [28]:
# Test with LoRA disabled
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 362.20it/s]

Accuracy: 0.9663
Wrong count for 0: 13
Wrong count for 1: 35
Wrong count for 2: 36
Wrong count for 3: 30
Wrong count for 4: 43
Wrong count for 5: 20
Wrong count for 6: 16
Wrong count for 7: 29
Wrong count for 8: 90
Wrong count for 9: 25



