# LoRA implementation modified

The experiment design in the original tutorial was quite naive. Here I make some modifications that should, hopefully, show better results on the same general problem.

In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
import tqdm as tqdm

In [2]:
# Make the model deterministic by seeding PyTorch's global RNG.
# The return value is bound to `_` so it is not echoed as cell output.
_ = torch.manual_seed(0)

In [6]:
# Let's train on only digits 0-7; digits 8 and 9 are held back for the LoRA fine-tuning step

# Convert images to tensors, then normalise with mean 0.1307 and std 0.3081
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Split the training set in a single pass. Iterating `mnist_train` runs `transform`
# on every image, so filtering it twice would do all of that work twice.
mnist_train_filtered = []  # (image, label) pairs with label < 8
mnist_train_unused = []    # (image, label) pairs with label >= 8
for image, label in mnist_train:
    if label < 8:
        mnist_train_filtered.append((image, label))
    else:
        mnist_train_unused.append((image, label))

In [34]:
# One batch size shared by all loaders, defined once instead of repeated per call
BATCH_SIZE = 32

# Digits 0-7: base training data, reshuffled on every pass
train_loader = torch.utils.data.DataLoader(mnist_train_filtered, batch_size=BATCH_SIZE, shuffle=True)
# Full MNIST test split (all ten digits), served in a fixed order
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)

# Digits 8-9: held back for the LoRA fine-tuning step
train_loader_unused = torch.utils.data.DataLoader(mnist_train_unused, batch_size=BATCH_SIZE, shuffle=True)

In [8]:
# Pick the best available backend: CUDA first, then Apple MPS, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: mps


In [9]:
# Use same network from original tutorial
class Model(nn.Module):
    """Fully connected classifier for 28x28 images with ten output classes.

    Two hidden linear layers, each followed by ReLU, feed a final linear layer
    that produces ten unnormalised scores per image.

    Args:
        hidden_size1: Width of the first hidden layer.
        hidden_size2: Width of the second hidden layer.
    """

    def __init__(self, hidden_size1=1000, hidden_size2=2000):
        super().__init__()
        # Submodules are created in the order fc1, fc2, fc3, relu so parameter
        # names and ordering stay exactly as the rest of the notebook expects.
        self.fc1 = nn.Linear(28 * 28, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten ``img`` to rows of 784 values and return a ``(rows, 10)`` score tensor."""
        activations = img.view(-1, 28 * 28)
        for hidden_layer in (self.fc1, self.fc2):
            activations = self.relu(hidden_layer(activations))
        # No activation on the output layer: the raw scores are returned
        return self.fc3(activations)
    
# Build the network with its default layer widths and move it to the selected device
model = Model().to(device)

In [10]:
# Train the network; the default of 25 epochs is enough to get decent results
def train(train_loader, model, epochs=25, lr=0.001, device=None):
    """Train ``model`` with Adam and cross-entropy loss; return per-epoch mean losses.

    Args:
        train_loader: Iterable yielding ``(images, labels)`` batches.
        model: The ``nn.Module`` to optimise in place.
        epochs: Number of full passes over ``train_loader``.
        lr: Learning rate handed to Adam.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter, so the function does not depend on a
            notebook-level ``device`` variable.

    Returns:
        A list with one float per epoch: the mean batch loss of that epoch.

    Raises:
        ValueError: If ``train_loader`` yields no batches during an epoch.
    """
    cross_entropy_loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    if device is None:
        # Follow the model so inputs and weights always live on the same device
        device = next(model.parameters()).device
    train_losses = []

    for epoch in range(epochs):
        model.train()
        loss_sum = 0.0
        num_iterations = 0

        for images, labels in tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = cross_entropy_loss(outputs, labels)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            num_iterations += 1

        if num_iterations == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError below
            raise ValueError("train_loader yielded no batches")

        avg_loss = loss_sum / num_iterations
        train_losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    return train_losses

In [11]:
# Base training run on the digits 0-7 loader; keeps the per-epoch average losses
train_losses = train(train_loader, model)

Epoch 1/25: 100%|██████████| 1507/1507 [00:05<00:00, 252.69it/s]


Epoch 1/25, Loss: 0.1509


Epoch 2/25: 100%|██████████| 1507/1507 [00:05<00:00, 298.19it/s]


Epoch 2/25, Loss: 0.0739


Epoch 3/25: 100%|██████████| 1507/1507 [00:05<00:00, 295.34it/s]


Epoch 3/25, Loss: 0.0589


Epoch 4/25: 100%|██████████| 1507/1507 [00:05<00:00, 294.29it/s]


Epoch 4/25, Loss: 0.0439


Epoch 5/25: 100%|██████████| 1507/1507 [00:05<00:00, 295.72it/s]


Epoch 5/25, Loss: 0.0388


Epoch 6/25: 100%|██████████| 1507/1507 [00:05<00:00, 291.00it/s]


Epoch 6/25, Loss: 0.0343


Epoch 7/25: 100%|██████████| 1507/1507 [00:05<00:00, 296.46it/s]


Epoch 7/25, Loss: 0.0346


Epoch 8/25: 100%|██████████| 1507/1507 [00:05<00:00, 295.55it/s]


Epoch 8/25, Loss: 0.0307


Epoch 9/25: 100%|██████████| 1507/1507 [00:05<00:00, 298.61it/s]


Epoch 9/25, Loss: 0.0232


Epoch 10/25: 100%|██████████| 1507/1507 [00:05<00:00, 292.20it/s]


Epoch 10/25, Loss: 0.0320


Epoch 11/25: 100%|██████████| 1507/1507 [00:05<00:00, 296.26it/s]


Epoch 11/25, Loss: 0.0200


Epoch 12/25: 100%|██████████| 1507/1507 [00:05<00:00, 280.75it/s]


Epoch 12/25, Loss: 0.0261


Epoch 13/25: 100%|██████████| 1507/1507 [00:05<00:00, 278.22it/s]


Epoch 13/25, Loss: 0.0222


Epoch 14/25: 100%|██████████| 1507/1507 [00:05<00:00, 283.07it/s]


Epoch 14/25, Loss: 0.0196


Epoch 15/25: 100%|██████████| 1507/1507 [00:05<00:00, 281.58it/s]


Epoch 15/25, Loss: 0.0173


Epoch 16/25: 100%|██████████| 1507/1507 [00:05<00:00, 282.35it/s]


Epoch 16/25, Loss: 0.0273


Epoch 17/25: 100%|██████████| 1507/1507 [00:05<00:00, 282.66it/s]


Epoch 17/25, Loss: 0.0213


Epoch 18/25: 100%|██████████| 1507/1507 [00:05<00:00, 291.80it/s]


Epoch 18/25, Loss: 0.0175


Epoch 19/25: 100%|██████████| 1507/1507 [00:05<00:00, 295.82it/s]


Epoch 19/25, Loss: 0.0250


Epoch 20/25: 100%|██████████| 1507/1507 [00:05<00:00, 293.32it/s]


Epoch 20/25, Loss: 0.0162


Epoch 21/25: 100%|██████████| 1507/1507 [00:05<00:00, 295.43it/s]


Epoch 21/25, Loss: 0.0238


Epoch 22/25: 100%|██████████| 1507/1507 [00:05<00:00, 280.40it/s]


Epoch 22/25, Loss: 0.0219


Epoch 23/25: 100%|██████████| 1507/1507 [00:05<00:00, 292.83it/s]


Epoch 23/25, Loss: 0.0183


Epoch 24/25: 100%|██████████| 1507/1507 [00:05<00:00, 296.84it/s]


Epoch 24/25, Loss: 0.0195


Epoch 25/25: 100%|██████████| 1507/1507 [00:05<00:00, 295.37it/s]

Epoch 25/25, Loss: 0.0165





In [13]:
# Make a copy of original weights
# Each entry is a detached clone keyed by parameter name, so later training
# cannot change the snapshot.
original_weights = {
    param_name: tensor.clone().detach()
    for param_name, tensor in model.named_parameters()
}

In [14]:
# Test the performance
def test():
    """Evaluate the notebook-level ``model`` on ``test_loader`` and print a report.

    Prints the overall accuracy in percent, then for each digit 0-9 the number of
    test samples with that true label that were misclassified. Returns nothing.
    Reads the notebook-level ``model``, ``test_loader`` and ``device``.
    """
    correct = 0
    total = 0
    wrong_counts = torch.zeros(10, dtype=torch.long)

    # Measure in evaluation mode; train() switches back to training mode itself
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm.tqdm(test_loader, desc="Testing"):
            # Predictions come back to the CPU once per batch, so there is no
            # per-sample comparison (and device sync) in a Python loop.
            predictions = model(images.to(device)).argmax(dim=1).cpu()
            labels = labels.cpu()
            mistakes = predictions != labels
            # Histogram of the true labels of the misclassified samples
            wrong_counts += torch.bincount(labels[mistakes], minlength=10)
            correct += int((~mistakes).sum())
            total += labels.size(0)

    accuracy = correct / total * 100
    print(f"Accuracy: {accuracy:.2f}%")
    for digit, wrong in enumerate(wrong_counts.tolist()):
        print(f"Digit {digit}: {wrong} wrong")

In [15]:
# Baseline: evaluate the model trained on digits 0-7 against the full test split
test()

Testing: 100%|██████████| 313/313 [00:04<00:00, 74.29it/s]

Accuracy: 79.06%
Digit 0: 6 wrong
Digit 1: 9 wrong
Digit 2: 16 wrong
Digit 3: 15 wrong
Digit 4: 13 wrong
Digit 5: 14 wrong
Digit 6: 17 wrong
Digit 7: 21 wrong
Digit 8: 974 wrong
Digit 9: 1009 wrong





As expected, the model performs well on digits 0-7 and fails on 8 and 9. Let's now try to fine-tune it with LoRA, using only digits 8 and 9.

In [18]:
# Report each linear layer's shapes, then the number of weights and biases in total
base_layers = (model.fc1, model.fc2, model.fc3)
for layer_number, layer in enumerate(base_layers, start=1):
    print(f"Layer {layer_number} has {layer.weight.shape} weights and {layer.bias.shape} biases")
total_parameters_original = sum(
    layer.weight.nelement() + layer.bias.nelement() for layer in base_layers
)
print(f"Total parameters: {total_parameters_original:,}")

Layer 1 has torch.Size([1000, 784]) weights and torch.Size([1000]) biases
Layer 2 has torch.Size([2000, 1000]) weights and torch.Size([2000]) biases
Layer 3 has torch.Size([10, 2000]) weights and torch.Size([10]) biases
Total parameters: 2,807,010


In [25]:
class LoRAParameterization(nn.Module):
    """Low-rank additive update for a weight matrix, usable as a parametrization.

    The update is ``lora_B @ lora_A`` multiplied by ``alpha / rank``.
    ``lora_B`` has shape ``(features_in, rank)`` and ``lora_A`` has shape
    ``(rank, features_out)``, so the product is ``(features_in, features_out)``
    and is viewed to the shape of the weight it is added to.

    Args:
        features_in: First dimension of the wrapped weight matrix.
        features_out: Second dimension of the wrapped weight matrix.
        rank: Inner dimension of the two low-rank factors.
        alpha: Numerator of the scaling factor ``alpha / rank``.
        device: Device on which the two factors are created.
    """

    def __init__(self, features_in, features_out, rank=1, alpha=1.0, device="cpu"):
        super().__init__()
        self.enabled = True
        self.scale = alpha / rank

        # A is drawn from a standard normal and B is all zeros, so B @ A is zero
        # and the wrapped weight is unchanged until the factors are trained.
        self.lora_A = nn.Parameter(torch.zeros(rank, features_out, device=device))
        self.lora_B = nn.Parameter(torch.zeros(features_in, rank, device=device))
        nn.init.normal_(self.lora_A, mean=0.0, std=1)
        nn.init.zeros_(self.lora_B)

    def forward(self, original_weights):
        """Return ``original_weights`` plus the scaled update, or unchanged when disabled."""
        if not self.enabled:
            return original_weights
        low_rank_update = (self.lora_B @ self.lora_A).view(original_weights.shape)
        return original_weights + low_rank_update * self.scale

In [26]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parametrization(layer, device, rank=1, lora_alpha=1):
    """Build a ``LoRAParameterization`` sized to match ``layer.weight``.

    The two dimensions of ``layer.weight.shape`` are passed on in order as
    ``features_in`` and ``features_out``. For the ``nn.Linear`` layers of this
    notebook the weight is stored as ``(out_features, in_features)``, so the
    names describe the weight's dimension order, not the layer's in/out sizes.

    Args:
        layer: Module with a two-dimensional ``weight``.
        device: Device on which the LoRA factors are created.
        rank: Rank of the low-rank update.
        lora_alpha: Scaling numerator passed on as ``alpha``.
    """
    first_dim, second_dim = layer.weight.shape
    return LoRAParameterization(
        first_dim,
        second_dim,
        rank=rank,
        alpha=lora_alpha,
        device=device,
    )

In [27]:
# Attach a LoRA parametrization (default rank 1) to the weight of each linear layer.
# The factors are drawn in the order fc1, fc2, fc3, and the value of the last call is
# what the cell displays, so the statement order is deliberate.
# NOTE(review): re-running this cell presumably stacks a second parametrization on
# each layer; restart the kernel before re-running. TODO confirm.
parametrize.register_parametrization(model.fc1, "weight", linear_layer_parametrization(model.fc1, device))
parametrize.register_parametrization(model.fc2, "weight", linear_layer_parametrization(model.fc2, device))
parametrize.register_parametrization(model.fc3, "weight", linear_layer_parametrization(model.fc3, device))

ParametrizedLinear(
  in_features=2000, out_features=10, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRAParameterization()
    )
  )
)

In [28]:
# Display the module tree to confirm every linear layer now carries a LoRA parametrization
model

Model(
  (fc1): ParametrizedLinear(
    in_features=784, out_features=1000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (fc2): ParametrizedLinear(
    in_features=1000, out_features=2000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (fc3): ParametrizedLinear(
    in_features=2000, out_features=10, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (relu): ReLU()
)

In [42]:
def enable_disable_lora(enabled=True):
    """Switch the LoRA update on or off for fc1, fc2 and fc3 of the notebook-level ``model``.

    Sets the ``enabled`` flag on the first parametrization registered on each
    layer's ``weight``.

    Args:
        enabled: ``True`` to apply the LoRA update, ``False`` to bypass it.
    """
    lora_layers = (model.fc1, model.fc2, model.fc3)
    for lora_layer in lora_layers:
        weight_parametrizations = lora_layer.parametrizations["weight"]
        weight_parametrizations[0].enabled = enabled

In [33]:
# Find total number of parameters added by LoRA
lora_layers = (model.fc1, model.fc2, model.fc3)
total_parameters_lora = 0
total_parameters_non_lora = 0

for layer_number, layer in enumerate(lora_layers, start=1):
    lora = layer.parametrizations["weight"][0]
    # The two low-rank factors are the only parameters LoRA adds to this layer
    total_parameters_lora += lora.lora_A.nelement() + lora.lora_B.nelement()
    # Weight and bias element counts, as counted before LoRA was attached
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(f"Layer {layer_number} has {layer.weight.shape} weights and {layer.bias.shape} biases")

# Attaching LoRA must not change the base parameter count
assert total_parameters_non_lora == total_parameters_original
# Note: "Total parameters" below reports the base (non-LoRA) count
print(f"Total parameters: {total_parameters_non_lora:,}")
print(f"Total LoRA parameters: {total_parameters_lora:,}")
print(f"Total non-LoRA parameters: {total_parameters_non_lora:,}")
print(f"Total parameters added by LoRA: {total_parameters_lora:,}")

Layer 1 has torch.Size([1000, 784]) weights and torch.Size([1000]) biases
Layer 2 has torch.Size([2000, 1000]) weights and torch.Size([2000]) biases
Layer 3 has torch.Size([10, 2000]) weights and torch.Size([10]) biases
Total parameters: 2,807,010
Total LoRA parameters: 6,794
Total non-LoRA parameters: 2,807,010
Total parameters added by LoRA: 6,794


In [39]:
# Freeze all the other layers
# Parameters whose name contains "lora" stay trainable; everything else is frozen
for name, param in model.named_parameters():
    if "lora" in name:
        continue
    print(f"Freezing {name}")
    param.requires_grad = False

Freezing fc1.bias
Freezing fc1.parametrizations.weight.original
Freezing fc2.bias
Freezing fc2.parametrizations.weight.original
Freezing fc3.bias
Freezing fc3.parametrizations.weight.original


In [40]:
# Fine-tune on the held-out digits (labels >= 8) for 10 epochs. Every parameter without
# "lora" in its name was frozen above, so only the LoRA factors can still be updated.
# The returned list of per-epoch losses is the cell's last expression and is displayed.
train(train_loader_unused, model, epochs=10)

Epoch 1/10: 100%|██████████| 369/369 [00:01<00:00, 231.61it/s]


Epoch 1/10, Loss: 8.5926


Epoch 2/10: 100%|██████████| 369/369 [00:01<00:00, 282.61it/s]


Epoch 2/10, Loss: 0.1853


Epoch 3/10: 100%|██████████| 369/369 [00:01<00:00, 294.08it/s]


Epoch 3/10, Loss: 0.1239


Epoch 4/10: 100%|██████████| 369/369 [00:01<00:00, 297.18it/s]


Epoch 4/10, Loss: 0.0875


Epoch 5/10: 100%|██████████| 369/369 [00:01<00:00, 287.81it/s]


Epoch 5/10, Loss: 0.0731


Epoch 6/10: 100%|██████████| 369/369 [00:01<00:00, 281.20it/s]


Epoch 6/10, Loss: 0.0593


Epoch 7/10: 100%|██████████| 369/369 [00:01<00:00, 285.95it/s]


Epoch 7/10, Loss: 0.0541


Epoch 8/10: 100%|██████████| 369/369 [00:01<00:00, 268.32it/s]


Epoch 8/10, Loss: 0.0500


Epoch 9/10: 100%|██████████| 369/369 [00:01<00:00, 290.70it/s]


Epoch 9/10, Loss: 0.0444


Epoch 10/10: 100%|██████████| 369/369 [00:01<00:00, 286.43it/s]

Epoch 10/10, Loss: 0.0413





[8.592565112320711,
 0.18527764489818718,
 0.12389598642968065,
 0.08751744613919639,
 0.07305905122441167,
 0.059264586043359865,
 0.054050454506423415,
 0.049986466438557475,
 0.044386885734022594,
 0.041255637037638467]

In [43]:
# The frozen base weights must be exactly what they were before fine-tuning
for layer_name in ("fc1", "fc2", "fc3"):
    base_weight = getattr(model, layer_name).parametrizations.weight.original
    assert torch.all(base_weight == original_weights[f"{layer_name}.weight"])

# With LoRA enabled, the effective weight is the base weight plus (B @ A) * scale
enable_disable_lora(enabled=True)
fc1_lora = model.fc1.parametrizations.weight[0]
fc1_update = (fc1_lora.lora_B @ fc1_lora.lora_A).view(model.fc1.weight.shape)
assert torch.equal(
    model.fc1.weight,
    model.fc1.parametrizations.weight.original + fc1_update * fc1_lora.scale,
)

# With LoRA disabled, the effective weight falls back to the original one.
# LoRA is left disabled when this cell finishes.
enable_disable_lora(enabled=False)
assert torch.equal(model.fc1.weight, original_weights["fc1.weight"])

In [44]:
# Evaluate with the LoRA update applied to the weights
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 313/313 [00:05<00:00, 57.85it/s]

Accuracy: 28.37%
Digit 0: 718 wrong
Digit 1: 1134 wrong
Digit 2: 865 wrong
Digit 3: 889 wrong
Digit 4: 977 wrong
Digit 5: 791 wrong
Digit 6: 829 wrong
Digit 7: 919 wrong
Digit 8: 35 wrong
Digit 9: 6 wrong





In [46]:
# Evaluate with the LoRA update bypassed, i.e. with the original weights only
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 313/313 [00:03<00:00, 83.48it/s]

Accuracy: 79.06%
Digit 0: 6 wrong
Digit 1: 9 wrong
Digit 2: 16 wrong
Digit 3: 15 wrong
Digit 4: 13 wrong
Digit 5: 14 wrong
Digit 6: 17 wrong
Digit 7: 21 wrong
Digit 8: 974 wrong
Digit 9: 1009 wrong



