<a href="https://colab.research.google.com/github/kalyaannnn/TransforMER/blob/main/LoRAIntuition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
# Seed PyTorch's global RNG so the torch.randn draws in the following cells are reproducible.
# Binding the return value to `_` keeps it from being echoed as the cell's output.
_ = torch.manual_seed(42)

In [2]:
# Shape of the weight matrix (d x k) and the rank we deliberately give it.
d, k = 10, 10
w_rank = 2

# The product of a (d x w_rank) and a (w_rank x k) matrix has rank at most w_rank,
# so W is rank-deficient by construction.
left_factor = torch.randn(d, w_rank)
right_factor = torch.randn(w_rank, k)
W = left_factor @ right_factor
print(W)

tensor([[ 2.6746e+00, -1.4811e+00, -1.1993e+00, -4.3986e-01, -3.2132e+00,
         -9.8070e-01,  5.2384e+00,  1.1440e+00,  3.4292e-02, -8.0985e-01],
        [-2.6293e+00,  5.5638e-01,  2.8071e+00, -2.1880e+00,  1.5113e+00,
         -1.9586e+00, -8.4328e-01, -1.0369e+00,  1.2841e+00,  1.7830e+00],
        [ 1.2519e+00, -1.5893e-01, -1.5284e+00,  1.3505e+00, -5.2550e-01,
          1.2769e+00, -1.0580e-01,  4.8339e-01, -7.6663e-01, -9.6521e-01],
        [ 2.5833e+00, -1.0893e+00, -1.7760e+00,  5.6927e-01, -2.4785e+00,
          1.6158e-01,  3.4259e+00,  1.0717e+00, -4.6680e-01, -1.1566e+00],
        [-8.3761e-01,  1.5506e-01,  9.3442e-01, -7.6165e-01,  4.4083e-01,
         -6.9603e-01, -1.6246e-01, -3.2818e-01,  4.4156e-01,  5.9234e-01],
        [ 1.7718e-01,  1.0177e-01, -4.4121e-01,  5.5308e-01,  1.5319e-01,
          5.8442e-01, -6.0979e-01,  5.6301e-02, -2.9052e-01, -2.7291e-01],
        [-1.6017e-02,  7.9710e-02, -1.2103e-01,  2.0897e-01,  1.4897e-01,
          2.3602e-01, -3.7046e-0

In [3]:
import numpy as np

# Numerical rank of W, computed by NumPy on the tensor's array view.
# The value is reused below to size the truncated SVD factors.
W_rank = np.linalg.matrix_rank(W.numpy())
print(f"Rank of W : {W_rank}")

Rank of W : 2


In [4]:
# Reduced SVD of W: W = U @ diag(S) @ Vh.
# torch.svd is deprecated; torch.linalg.svd is its replacement and returns the
# already-transposed factor Vh instead of V.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
V = Vh.t()  # keep `V` defined exactly as torch.svd used to return it (W is real)

# For rank factorization, keep only the first W_rank singular values
# (and the corresponding columns of U / rows of Vh).
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
V_r = Vh[:W_rank, :]

# Compute B = U_r @ S_r and A = V_r, so that W is reproduced by B @ A
B = U_r @ S_r
A = V_r
print(f"Shape of B : {B.shape}")
print(f"Shape of A : {A.shape}")

Shape of B : torch.Size([10, 2])
Shape of A : torch.Size([2, 10])


In [5]:
# Random bias and input vector for one linear map.
bias = torch.randn(d)
x = torch.randn(d)

# Reference output with the full matrix: y = Wx + b
y = W @ x + bias

# Same map through the low-rank factors: y' = (BA)x + b
y_prime = (B @ A) @ x + bias

# The two printed vectors should match: B @ A reconstructs W.
print("Original y using W : \n", y, end="\n\n")
print("y computed using BA : \n", y_prime)

Original y using W : 
 tensor([ 8.7740, -3.2261,  0.4063,  7.2494, -0.3587, -0.8469, -1.2226,  1.8885,
         0.1394,  0.2454])

y computed using BA : 
 tensor([ 8.7740, -3.2261,  0.4063,  7.2494, -0.3587, -0.8469, -1.2226,  1.8885,
         0.1394,  0.2454])


In [6]:
# Storage cost: the full d x k matrix versus the two thin factors B and A.
factor_param_count = sum(factor.numel() for factor in (B, A))
print("Total parameters of W :", W.numel())
print("Total parameters of B and A :", factor_param_count)

Total parameters of W : 100
Total parameters of B and A : 40


In [7]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

In [8]:
# Re-seed PyTorch's global RNG before the MNIST part of the notebook so the
# weight initialisation and shuffling that follow are reproducible.
_ = torch.manual_seed(0)

In [9]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Convert images to tensors, then normalise with mean 0.1307 and std 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split (downloaded into ./data if missing) and its loader.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Test split and its loader.
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Last expression of the cell: display the selected device.
device

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 35507209.56it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1214502.50it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 10463688.79it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2589794.56it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






device(type='cuda')

In [10]:
class RichBoyNet(nn.Module):
  """Three-layer fully connected classifier for 28x28 images.

  Layer widths are 784 -> hidden_size_1 -> hidden_size_2 -> 10, with a ReLU
  after each of the first two linear layers.

  Args:
    hidden_size_1: width of the first hidden layer.
    hidden_size_2: width of the second hidden layer.
  """

  def __init__(self, hidden_size_1 = 1000, hidden_size_2 = 2000):
    super().__init__()
    n_pixels = 28*28
    n_outputs = 10
    self.linear1 = nn.Linear(n_pixels, hidden_size_1)
    self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
    self.linear3 = nn.Linear(hidden_size_2, n_outputs)
    self.relu = nn.ReLU()

  def forward(self, img):
    """Flatten `img` to rows of 784 values and return the 10 raw scores per row."""
    flat = img.view(-1, 28*28)
    hidden = self.relu(self.linear1(flat))
    hidden = self.relu(self.linear2(hidden))
    # No activation on the output layer.
    return self.linear3(hidden)

# Build the network with its default hidden sizes and move it to `device`.
net = RichBoyNet().to(device)

In [11]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    """Train `net` with Adam (lr=0.001) and cross-entropy loss.

    Args:
        train_loader: iterable yielding `(inputs, labels)` batches.
        net: model to optimise; it is called on inputs viewed as (-1, 784).
        epochs: number of passes over `train_loader`.
        total_iterations_limit: when not None, return as soon as this many
            batches have been processed in total (counted across epochs).

    Batches are moved to the module-level `device` before the forward pass.
    The progress bar shows the running mean loss of the current epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    has_limit = total_iterations_limit is not None
    batches_seen = 0

    for epoch in range(epochs):
        net.train()
        running_loss = 0

        progress = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if has_limit:
            # Make the bar count towards the iteration cap instead of the loader length.
            progress.total = total_iterations_limit

        for step, (x, y) in enumerate(progress, start=1):
            batches_seen += 1
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            loss = criterion(net(x.view(-1, 28*28)), y)
            running_loss += loss.item()
            progress.set_postfix(loss=running_loss / step)
            loss.backward()
            optimizer.step()

            if has_limit and batches_seen >= total_iterations_limit:
                return

train(train_loader, net, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:45<00:00, 130.59it/s, loss=0.234]


In [12]:
# Snapshot every parameter (detached copies) so the pre-LoRA values can be
# compared against the network later on.
original_weights = {
    name: param.clone().detach()
    for name, param in net.named_parameters()
}

In [13]:
def test():
    """Evaluate the module-level `net` on `test_loader`.

    Prints the overall accuracy (rounded to 3 decimals) and, for each digit
    0-9, how many samples with that true label were misclassified.
    """
    n_correct = 0
    n_seen = 0
    # wrong_counts[d] = number of misclassified samples whose true label is d
    wrong_counts = [0] * 10

    with torch.no_grad():
        for x, y in tqdm(test_loader, desc='Testing'):
            x, y = x.to(device), y.to(device)
            output = net(x.view(-1, 784))
            for idx, scores in enumerate(output):
                label = y[idx]
                n_seen += 1
                if torch.argmax(scores) == label:
                    n_correct += 1
                    continue
                # Miss: tally it under the true label.
                wrong_counts[label] += 1

    print(f'Accuracy: {round(n_correct/n_seen, 3)}')
    for digit, misses in enumerate(wrong_counts):
        print(f'wrong counts for the digit {digit}: {misses}')

test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 295.41it/s]

Accuracy: 0.961
wrong counts for the digit 0: 11
wrong counts for the digit 1: 20
wrong counts for the digit 2: 54
wrong counts for the digit 3: 68
wrong counts for the digit 4: 24
wrong counts for the digit 5: 18
wrong counts for the digit 6: 40
wrong counts for the digit 7: 55
wrong counts for the digit 8: 9
wrong counts for the digit 9: 87





In [17]:
# Report the weight/bias shapes of each linear layer and record the total
# number of parameters (weights + biases) of the network.
dense_layers = [net.linear1, net.linear2, net.linear3]
total_parameters_original = sum(
    dense.weight.nelement() + dense.bias.nelement() for dense in dense_layers
)
for index, layer in enumerate(dense_layers):
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


In [14]:
class LoRAParameterization(nn.Module):
  """Weight parametrization returning W + (B @ A) * (alpha / rank).

  `lora_B` has shape (features_in, rank) and `lora_A` has shape
  (rank, features_out), so B @ A has features_in * features_out entries and is
  viewed to the shape of the weight handed to `forward`.

  Args:
    features_in: number of rows of `lora_B`.
    features_out: number of columns of `lora_A`.
    rank: inner dimension of the low-rank product.
    alpha: numerator of the scaling factor alpha / rank.
    device: device on which both matrices are created.
  """

  def __init__(self, features_in, features_out, rank = 1, alpha = 1, device = 'cpu'):
    super().__init__()
    # B stays at zero and A is drawn from N(0, 1), so dW = BA is zero at the beginning of training.
    self.lora_A = nn.Parameter(torch.zeros(rank, features_out, device = device))
    self.lora_B = nn.Parameter(torch.zeros(features_in, rank, device = device))
    nn.init.normal_(self.lora_A, mean = 0, std = 1)

    # dW is scaled by alpha / rank. With Adam, tuning alpha is roughly the same
    # as tuning the learning rate (given an appropriately scaled initialisation),
    # so alpha is simply set to the first rank tried and not tuned afterwards.
    # The scaling reduces the need to retune hyperparameters when rank changes.
    self.scale = alpha / rank
    # Toggle: when False, forward() returns the weight untouched.
    self.enabled = True

  def forward(self, original_weights):
    """Return `original_weights` plus the scaled low-rank update, or unchanged when disabled."""
    if not self.enabled:
      return original_weights
    delta = (self.lora_B @ self.lora_A).view(original_weights.shape)
    return original_weights + delta * self.scale

In [15]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parametrization(layer, device, rank = 1, lora_alpha = 1):
  """Build a LoRAParameterization sized for `layer.weight` (the bias is left alone).

  The two dimensions of `layer.weight.shape` are forwarded, in order, as
  `features_in` and `features_out`. For the nn.Linear layers in this notebook
  the weight is stored as (outputs, inputs), so the names are swapped relative
  to the layer's own in/out features; the product B @ A still matches the
  weight's shape, which is all the parametrization needs.
  """
  n_rows, n_cols = layer.weight.shape
  return LoRAParameterization(
      n_rows, n_cols, rank = rank, alpha = lora_alpha, device = device
  )

# Attach a LoRA parametrization (default rank and alpha) to the weight of each
# linear layer, in layer order.
for layer in (net.linear1, net.linear2, net.linear3):
  parametrize.register_parametrization(
      layer, "weight", linear_layer_parametrization(layer, device)
  )

def enable_disable_lora(enabled = True, model = None):
  """Switch the LoRA update on or off for linear1, linear2 and linear3.

  Args:
    enabled: value written to the `enabled` flag of each layer's first
      "weight" parametrization.
    model: network whose layers are toggled; defaults to the module-level `net`.
  """
  if model is None:
    model = net
  for layer in (model.linear1, model.linear2, model.linear3):
    # register_parametrization exposes the registered modules under
    # `parametrizations` (plural); the singular name raises AttributeError.
    layer.parametrizations["weight"][0].enabled = enabled

In [18]:
# Count the parameters added by LoRA separately from the original weights/biases.
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    lora = layer.parametrizations["weight"][0]
    total_parameters_lora += lora.lora_A.nelement() + lora.lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {lora.lora_A.shape} + Lora_B: {lora.lora_B.shape}'
    )

# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original

total_parameters_with_lora = total_parameters_lora + total_parameters_non_lora
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_with_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')

# Relative size of the LoRA matrices, as a percentage of the original network.
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


In [19]:
# Freeze the non-Lora parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, by keeping only the digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 9 and only for 100 batches (hoping that it would improve the performance on the digit 9)
train(train_loader, net, epochs=1, total_iterations_limit=100)

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original


Epoch 1:  99%|█████████▉| 99/100 [00:00<00:00, 106.67it/s, loss=0.0961]
