**LoRA implementation with PyTorch**

In [2]:
import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

_ = torch.manual_seed(0)

In [3]:
# Image preprocessing pipeline: convert each image to a tensor, then normalize its
# single channel with mean 0.1307 and std 0.3081 (presumably the MNIST statistics).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Training split of MNIST (downloaded into ./data when missing), served in shuffled batches of 10
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Test split of MNIST, with the same preprocessing and batching
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


Create the Neural Network to classify the digits, making it overly complicated to better show the power of LoRA

In [13]:
# Create an overly expensive neural network to classify MNIST digits
class RichNet(nn.Module):
    """Deliberately oversized fully connected classifier for 28x28 images.

    Architecture: 784 -> hidden_size_1 -> hidden_size_2 -> 10, with a ReLU
    after each of the first two linear layers. The final layer returns raw
    (unnormalized) class scores.

    Args:
        hidden_size_1: Width of the first hidden layer.
        hidden_size_2: Width of the second hidden layer.
    """

    def __init__(self, hidden_size_1 = 1000, hidden_size_2=2000):
        super().__init__()
        input_dim, num_classes = 28 * 28, 10
        # Layers are created in this fixed order so seeded initialization stays reproducible.
        self.linear1 = nn.Linear(input_dim, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, num_classes)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten ``img`` to rows of 784 values and return scores of shape [batch, 10]."""
        hidden = img.view(-1, 28 * 28)
        for layer in (self.linear1, self.linear2):
            hidden = self.relu(layer(hidden))
        return self.linear3(hidden)
    
# Instantiate the network with its default hidden sizes and move it to the selected device
net = RichNet().to(device)
# Show the shape of the first layer's weight matrix
print(net.linear1.weight.shape)

torch.Size([1000, 784])


Train the network only for 1 epoch to simulate a complete general pre-training on the data

In [5]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    """Train ``net`` with Adam (lr=0.001) and cross-entropy loss.

    Args:
        train_loader: Iterable yielding ``(x, y)`` batches; ``x`` is reshaped to
            rows of 784 values before being fed to ``net``.
        net: Model to optimize in place.
        epochs: Number of passes over ``train_loader``.
        total_iterations_limit: Optional cap on the total number of batches
            processed across all epochs. When set, it is also used as the
            progress bar total, and training returns as soon as it is reached.

    Batches are moved to the module-level ``device`` before the forward pass.
    The progress bar shows the running mean of the batch losses for the
    current epoch.
    """
    # CrossEntropyLoss is applied directly to the raw outputs of ``net``
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    has_limit = total_iterations_limit is not None
    steps_done = 0

    for epoch_idx in range(epochs):
        net.train()
        running_loss = 0

        progress = tqdm(train_loader, desc=f'Epoch {epoch_idx+1}')
        if has_limit:
            progress.total = total_iterations_limit

        for step_in_epoch, (x, y) in enumerate(progress, start=1):
            steps_done += 1
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            loss = criterion(net(x.view(-1, 28*28)), y)

            # Report the mean loss over the batches seen so far in this epoch
            running_loss += loss.item()
            progress.set_postfix(loss=running_loss / step_in_epoch)

            loss.backward()
            optimizer.step()

            if has_limit and steps_done >= total_iterations_limit:
                return

# Run a single epoch over the whole training loader (no iteration cap)
train(train_loader, net, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:16<00:00, 369.23it/s, loss=0.239]


Keep a copy of the original weights (cloning them) so later we can prove that a fine-tuning with LoRA doesn't alter the original weights

In [6]:
# Snapshot every parameter of `net`, keyed by its parameter name.
# detach() gives a tensor that is outside the autograd graph (no gradients tracked),
# and clone() copies it into its own memory, so later updates to the model
# parameters leave the values stored in original_weights unchanged.
original_weights = {
    name: param.detach().clone()
    for name, param in net.named_parameters()
}

In [7]:
def test():
    """Evaluate the module-level ``net`` on ``test_loader`` and print a report.

    Prints the overall accuracy (rounded to 3 decimals), followed by one line
    per digit 0-9 with the number of samples of that true digit that were
    predicted as something else. Gradients are disabled during evaluation.
    """
    num_correct = 0
    num_seen = 0
    wrong_counts = [0] * 10

    with torch.no_grad():
        for x, y in tqdm(test_loader, desc='Testing'):
            x = x.to(device)
            y = y.to(device)

            # Predicted class = index of the largest score in each row
            preds = net(x.view(-1, 784)).argmax(dim=1)

            num_correct += (preds == y).sum().item()
            num_seen += y.size(0)

            for digit in range(10):
                # Samples whose true label is `digit` but were predicted otherwise
                missed = (y == digit) & (preds != digit)
                wrong_counts[digit] += missed.sum().item()

    print(f'Accuracy: {round(num_correct/num_seen, 3)}')
    for digit, count in enumerate(wrong_counts):
        print(f'wrong counts for the digit {digit}: {count}')

# Evaluate the current `net` on the test loader and print accuracy plus per-digit error counts
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 498.97it/s]

Accuracy: 0.957
wrong counts for the digit 0: 14
wrong counts for the digit 1: 21
wrong counts for the digit 2: 65
wrong counts for the digit 3: 47
wrong counts for the digit 4: 37
wrong counts for the digit 5: 55
wrong counts for the digit 6: 23
wrong counts for the digit 7: 53
wrong counts for the digit 8: 10
wrong counts for the digit 9: 102





As we can see, the original RichNet performs worst on digit 9 (102 misclassified test samples), so let's define a LoRA fine-tuning on digit 9.

In [10]:
# Count the weights and biases of the three linear layers, printing each layer's shapes
total_original_parameters = 0
for index, layer in enumerate((net.linear1, net.linear2, net.linear3)):
    print(f"Layer {index+1}: shape of weight: {layer.weight.shape}; shape of bias: {layer.bias.shape}")
    total_original_parameters += layer.weight.numel() + layer.bias.numel()
print(f"Total number of parameters: {total_original_parameters}")

Layer 1: shape of weight: torch.Size([1000, 784]); shape of bias: torch.Size([1000])
Layer 2: shape of weight: torch.Size([2000, 1000]); shape of bias: torch.Size([2000])
Layer 3: shape of weight: torch.Size([10, 2000]); shape of bias: torch.Size([10])
Total number of parameters: 2807010


In [None]:
class LoRAParametrization(nn.Module):
    """Low-rank additive update (LoRA) for a 2-D weight of shape (features_out, features_in).

    The module holds two trainable factors:

    * ``lora_A`` of shape ``(rank, features_out)``, initialized from N(0, 1)
    * ``lora_B`` of shape ``(features_in, rank)``, initialized to zero

    ``forward`` returns ``W + (lora_B @ lora_A)^T * (alpha / rank)``. Because
    ``lora_B`` starts at zero, the update is zero until training changes it.

    Args:
        features_in: Number of input features (second dimension of the weight).
        features_out: Number of output features (first dimension of the weight).
        rank: Rank of the update; must be at least 1.
        alpha: Scaling numerator; the update is multiplied by ``alpha / rank``.
        device: Device on which the two factors are allocated.

    Raises:
        ValueError: If ``rank`` is smaller than 1.
    """

    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cuda'):
        super().__init__()
        if rank < 1:
            # Fail with a clear message instead of a ZeroDivisionError in alpha / rank
            raise ValueError(f"rank must be a positive integer, got {rank}")

        # Allocate directly on the target device so the parameters are leaf tensors there
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out), device=device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank), device=device))
        nn.init.normal_(self.lora_A, mean=0, std=1)

        self.scale = alpha/rank
        # Set to False to make forward() return the weight untouched
        self.enabled = True

    def forward(self, original_weights):
        """Return ``original_weights`` plus the scaled low-rank update (when enabled)."""
        if not self.enabled:
            return original_weights

        # lora_B @ lora_A has shape (features_in, features_out), while the weight is
        # (features_out, features_in). It must be transposed: reshaping it with .view()
        # would only reinterpret the memory layout and destroy the low-rank structure.
        delta = torch.matmul(self.lora_B, self.lora_A).transpose(0, 1)
        return original_weights + delta * self.scale

Add the Parametrization to our network.

In [None]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parametrization(layer, device, rank=1, lora_alpha=1):
    """Build a ``LoRAParametrization`` sized to match ``layer.weight``.

    Args:
        layer: Module whose 2-D ``weight`` has shape (features_out, features_in).
        device: Device on which the LoRA factors are created.
        rank: Rank passed through to the parametrization.
        lora_alpha: Passed through as the parametrization's ``alpha``.

    Returns:
        A new ``LoRAParametrization`` instance; the layer itself is not modified.
    """
    out_dim, in_dim = layer.weight.shape
    lora = LoRAParametrization(in_dim, out_dim, rank=rank, alpha=lora_alpha, device=device)
    return lora

