
# LoRA implementation

In [None]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Make torch deterministic: seed PyTorch's global random number generator with a fixed value.
# The result is bound to `_` so the cell does not display the returned generator object.
_ = torch.manual_seed(0)

### We will train a neural network to classify MNIST digits and then fine-tune it on a particular digit on which it doesn't perform well

In [None]:
# Per-image preprocessing: convert to a tensor, then normalise with
# mean 0.1307 and standard deviation 0.3081.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Training split (stored under ./data, downloaded when missing).
mnist_trainset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Test split, preprocessed exactly like the training split.
mnist_testset = datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

# Both loaders serve shuffled batches of 10 images.
train_loader = torch.utils.data.DataLoader(
    mnist_trainset,
    batch_size=10,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    mnist_testset,
    batch_size=10,
    shuffle=True,
)

In [None]:
# Pick the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): this second assignment unconditionally overrides the choice above,
# so every later cell runs on the CPU. Presumably deliberate — confirm before removing.
device = "cpu"

In [None]:
# create an overly expensive nn for classification


class RichBoyNet(nn.Module):
    """Deliberately oversized fully connected classifier for 28x28 images.

    Three linear layers (784 -> hidden_size_1 -> hidden_size_2 -> 10) with a
    ReLU after each of the first two.

    Args:
        hidden_size_1: Width of the first hidden layer.
        hidden_size_2: Width of the second hidden layer.
    """

    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super().__init__()
        # Layers are created in forward order so parameter registration (and
        # random initialisation) happens as linear1, linear2, linear3.
        self.linear1 = nn.Linear(28 * 28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten ``img`` to rows of 784 values and return the 10 raw logits per row."""
        flat = img.view(-1, 28 * 28)
        hidden_1 = self.relu(self.linear1(flat))
        hidden_2 = self.relu(self.linear2(hidden_1))
        # No activation on the output layer: the logits are returned as-is.
        return self.linear3(hidden_2)


# Build the network with its default hidden sizes and move it to `device`.
net = RichBoyNet().to(device)
# Last expression of the cell: `modules` is referenced without being called, so the
# cell shows the repr of the bound method, which embeds the model's layer summary.
net.modules

In [6]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    """Train ``net`` with cross-entropy loss and Adam (learning rate 0.001).

    Each batch is moved to the notebook-level ``device`` before the forward
    pass. A tqdm progress bar per epoch shows the running mean loss.

    Args:
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        net: Model to optimise in place. It is called on the inputs reshaped
            to ``(-1, 784)``.
        epochs: Maximum number of passes over ``train_loader``.
        total_iterations_limit: Optional cap on the number of batches processed
            across all epochs. ``None`` means no cap; a non-positive value
            trains on no batches.

    Returns:
        None. The model is updated in place.
    """
    from itertools import islice  # local import keeps this cell self-contained

    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        batches = train_loader
        bar_total = None  # tqdm falls back to len(train_loader) when it exists
        if total_iterations_limit is not None:
            remaining = total_iterations_limit - total_iterations
            if remaining <= 0:
                return
            # Truncate the iterable instead of returning from inside the loop:
            # the loop then finishes naturally, so the progress bar counts the
            # last batch and closes at its full total rather than one short.
            batches = islice(train_loader, remaining)
            try:
                bar_total = min(remaining, len(train_loader))
            except TypeError:
                # Loader without a length: the remaining budget is the best estimate.
                bar_total = remaining

        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(batches, desc=f"Epoch {epoch+1}", total=bar_total)
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28 * 28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            # Running mean of the loss over the batches seen so far in this epoch.
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()


# Pretrain the network for a single epoch on the loader built from the MNIST training split.
train(train_loader, net, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [03:58<00:00, 25.15it/s, loss=0.238]


In [7]:
# Names of the parameters registered on the network.
# (The previous extra line built a one-element tuple and discarded it; only this
# final expression was ever displayed.)
[name for name, _ in net.named_parameters()]

['linear1.weight',
 'linear1.bias',
 'linear2.weight',
 'linear2.bias',
 'linear3.weight',
 'linear3.bias']

### keep a copy of the original weights

In [8]:
# Snapshot every parameter as a detached copy, keyed by parameter name, so that
# later cells can compare against the pretrained values.
original_weights = {
    name: param.clone().detach() for name, param in net.named_parameters()
}

original_weights.keys()

dict_keys(['linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias', 'linear3.weight', 'linear3.bias'])

In [9]:
# Test the performance of the pretrained network. As we can see, the network performs poorly on the digit 9. Let's fine-tune it on the digit 9


def test():
    """Evaluate the notebook-level ``net`` on ``test_loader`` and print a report.

    Prints the overall accuracy (rounded to 3 decimals) followed by the number
    of misclassified samples for each true digit 0-9. Batches are moved to the
    notebook-level ``device``; gradients are disabled during evaluation.
    """
    num_correct = 0
    num_seen = 0

    # wrong_counts[d] = number of test samples with true label d that were misclassified
    wrong_counts = [0] * 10

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images = images.to(device)
            labels = labels.to(device)
            logits = net(images.view(-1, 784))
            for sample_logits, label in zip(logits, labels):
                num_seen += 1
                if torch.argmax(sample_logits) == label:
                    num_correct += 1
                else:
                    wrong_counts[label] += 1
    print(f"Accuracy: {round(num_correct/num_seen, 3)}")
    for digit, count in enumerate(wrong_counts):
        print(f"wrong counts for the digit {digit}: {count}")


# Evaluate the pretrained network and print the per-digit error counts.
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 421.72it/s]

Accuracy: 0.955
wrong counts for the digit 0: 12
wrong counts for the digit 1: 30
wrong counts for the digit 2: 44
wrong counts for the digit 3: 106
wrong counts for the digit 4: 28
wrong counts for the digit 5: 21
wrong counts for the digit 6: 32
wrong counts for the digit 7: 37
wrong counts for the digit 8: 9
wrong counts for the digit 9: 128





Let's count how many parameters are in the original network, before introducing the LoRA matrices.

In [10]:
# Show the shape of each layer's weight matrix and bias vector, then store the
# total number of parameters so it can be compared after LoRA is added.
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    print(f"Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}")
total_parameters_original = sum(
    layer.weight.nelement() + layer.bias.nelement()
    for layer in [net.linear1, net.linear2, net.linear3]
)
print(f"Total number of parameters: {total_parameters_original:,}")

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010


Define the LoRA parameterization as described in the paper. 

The full detail on how PyTorch parameterizations work is here: https://pytorch.org/tutorials/intermediate/parametrizations.html

> We illustrate our reparametrization in Figure 1. We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training. We then scale ∆Wx by α/r, where α is a constant in r. When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately. As a result, we simply set α to the first r we try and do not tune it. This scaling helps to reduce the need to retune hyperparameters when we vary r.

Let's break down this explanation in the context of **LoRA (Low-Rank Adaptation)** for training large models more efficiently. The goal here is to modify the weights of a pretrained model (\( W \)) in a way that’s efficient and doesn’t require retraining all parameters.

---

### Key Terms in Context:
1. **\( \Delta W = BA \):**  
   - \( B \) and \( A \) are low-rank matrices that approximate the updates to \( W \).  
   - Instead of updating \( W \) directly, LoRA expresses changes (\( \Delta W \)) in terms of these smaller matrices.  
   - Initially, \( \Delta W \) is zero to avoid interfering with the pretrained model's performance before training begins.

2. **Random Gaussian Initialization for \( A \):**  
   - The matrix \( A \) is initialized with random values drawn from a Gaussian distribution. This provides a starting point for optimization.

3. **Zero Initialization for \( B \):**  
   - \( B \) starts as a zero matrix, meaning that the initial contribution of \( \Delta W \) is zero. This ensures the model behaves exactly like the pretrained version at the start.

4. **Scaling \( \Delta W \) by \( \alpha/r \):**  
   - **\( \alpha \):** A scaling constant to control the magnitude of \( \Delta W \).  
   - **\( r \):** The rank of the low-rank factorization, i.e., the number of columns of \( B \) and the number of rows of \( A \).  
   - Scaling by \( \alpha/r \) ensures that \( \Delta W \) doesn’t dominate the original \( W \) when \( r \) is small.  

---

### Why Scaling Matters
1. **Consistent Magnitude Across \( r \):**  
   Without scaling, \( \Delta W \) might become disproportionately large or small when \( r \) changes, which would require retuning hyperparameters (like the learning rate) every time \( r \) is adjusted.  
   Scaling by \( \alpha/r \) ensures that the overall contribution of \( \Delta W \) remains consistent, regardless of \( r \).

2. **Reduced Hyperparameter Tuning:**  
   - Scaling simplifies optimization by making the impact of \( r \) predictable.  
   - This is why \( \alpha \) can simply be set to the first \( r \) tried and then left fixed, avoiding the need for extensive hyperparameter tuning.

---

### Connection to Adam Optimizer
- **Adam Optimizer:** Adjusts learning rates dynamically for each parameter.  
- By scaling \( \Delta W \) with \( \alpha/r \), the effective step size during optimization is kept consistent, making \( \alpha \) behave like a proxy for the learning rate.  
- Proper scaling ensures smooth training dynamics, even if \( r \) changes.

---

### Why \( A \) and \( B \) Initialization is Designed This Way:
1. **Zero Initialization for \( B \):**  
   - Prevents \( \Delta W \) from contributing anything initially. This avoids disturbing the pretrained model at the start of fine-tuning.  
   - \( B \) will learn to contribute meaningful changes during training.

2. **Random Gaussian Initialization for \( A \):**  
   - Introduces randomness to \( A \), which provides diversity in the directions that \( \Delta W \) can explore during training.  
   - This randomness allows \( \Delta W \) to adapt effectively as \( B \) updates.

---

### Practical Impact
- The combination of scaling and initialization strategies makes LoRA efficient and robust:  
  - Efficient because \( \Delta W \) is parameterized using smaller matrices \( A \) and \( B \).  
  - Robust because the scaling ensures that the magnitude of updates remains stable, minimizing the need for extensive hyperparameter tuning.

The next section works through a more concrete example of how this works in practice.

Let’s work through an **example** to see how LoRA's initialization and scaling work in practice. We'll build up the intuition step by step.

---

### **Setup:**
We have a pretrained weight matrix \( W \) and we want to fine-tune it using LoRA. Instead of directly updating \( W \), we decompose the update as:

\[
W' = W + \Delta W \quad \text{where} \quad \Delta W = BA
\]

Here:
- \( W \): Original weights (e.g., \( 768 \times 768 \)).
- \( \Delta W \): Low-rank update matrix (\( 768 \times 768 \)).
- \( B \): Low-rank matrix (\( 768 \times r \)).
- \( A \): Low-rank matrix (\( r \times 768 \)).

---

### **Step 1: Initialization**
- **Initialize \( A \):**  
  \( A \) is randomly initialized using a Gaussian distribution (e.g., mean = 0, std = 0.02). This ensures that the directions of updates are diverse.

- **Initialize \( B \):**  
  \( B \) is initialized as a zero matrix. This ensures that \( \Delta W = 0 \) at the start, meaning no changes are applied to \( W \) initially.

---

### **Step 2: Scaling \( \Delta W \)**
We scale \( \Delta W \) by \( \alpha / r \), where:
- \( \alpha \): A constant controlling the magnitude of updates.
- \( r \): The rank of the decomposition.

This scaling ensures that the size of \( \Delta W \) doesn’t explode or vanish when \( r \) changes.

---

### **Code Example**
Let’s implement this step-by-step in Python.

```python
import torch

# Step 1: Initialize W (pretrained weights) and define dimensions
W = torch.randn(768, 768)  # Original weights
r = 4                      # Rank of decomposition
alpha = 16                 # Scaling constant

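# Initialize A and B as leaf tensors that track gradients, so the optimizer in Step 3 can update them
A = (torch.randn(r, 768) * 0.02).requires_grad_()  # Gaussian initialization for A
B = torch.zeros(768, r, requires_grad=True)        # Zero initialization for B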

# Compute initial Delta W (should be zero)
Delta_W = B @ A  # Shape: (768, 768)
print(f"Initial Delta_W: {Delta_W.norm():.4f}")  # Should be close to 0

# Step 2: Scale Delta W by alpha / r
scaled_Delta_W = (alpha / r) * Delta_W
print(f"Scaled Delta_W: {scaled_Delta_W.norm():.4f}")  # Still 0, as B = 0
```

---

### **Step 3: Training**
During training, \( A \) and \( B \) are updated by the optimizer (e.g., Adam). Let’s simulate one training step to show how \( \Delta W \) evolves.

```python
# Define a simple optimizer
optimizer = torch.optim.Adam([A, B], lr=1e-3)

# Simulate a training step
loss = (W + (alpha / r) * (B @ A)).norm()  # Example loss
loss.backward()  # Compute gradients
optimizer.step()  # Update A and B

# Compute updated Delta W
Delta_W = B @ A
scaled_Delta_W = (alpha / r) * Delta_W
print(f"Updated Delta_W: {scaled_Delta_W.norm():.4f}")
```

---

### **Why Scaling Helps**
1. **Consistent Magnitude:**  
   The scaling ensures that \( \Delta W \)'s magnitude remains controlled regardless of \( r \). For example:
   - If \( r = 4 \): Scaling factor is \( \alpha / 4 = 4 \).
   - If \( r = 16 \): Scaling factor is \( \alpha / 16 = 1 \).  
   This prevents \( \Delta W \) from dominating \( W \) as \( r \) changes.

2. **No Hyperparameter Retuning:**  
   Since the effective step size is scaled consistently, you don’t need to retune the learning rate or other hyperparameters when adjusting \( r \).

---

### **Summary of Key Insights**
1. **Initialization:**  
   \( B = 0 \) ensures \( \Delta W = 0 \) initially, so the pretrained model starts unchanged.  
   \( A \) is randomly initialized to allow flexibility in updates.

2. **Scaling by \( \alpha / r \):**  
   Controls the magnitude of \( \Delta W \) and ensures stability across different \( r \).

3. **Efficiency:**  
   By keeping \( r \) small (e.g., 2–16), LoRA reduces the number of trainable parameters while still capturing meaningful updates to \( W \).

The next section looks more closely at why this scaling keeps the size of the update under control.

Let’s clarify why **scaling by \( \alpha / r \)** ensures that \( \Delta W \)’s magnitude remains controlled, regardless of \( r \).

---

### **What is happening?**
- \( \Delta W = BA \): This is the low-rank update matrix.
- Scaling by \( \alpha / r \): This adjusts the overall magnitude of \( \Delta W \), depending on the rank \( r \).

Without scaling, \( \Delta W \)’s size could vary significantly as \( r \) changes. For example:
- A larger \( r \) (more components in \( B \) and \( A \)) would make \( \Delta W \) larger, potentially overwhelming \( W \).
- A smaller \( r \) would make \( \Delta W \) too small, reducing its impact.

By introducing the factor \( \alpha / r \), we ensure that \( \Delta W \)’s **overall contribution** is consistent.

---

### **Example with Numbers**
Let’s say:
- \( W \) is the original weight matrix, initialized with a norm of \( \|W\| = 10 \).
- \( \Delta W \) is the update matrix we compute using \( B \) and \( A \).
- \( \alpha = 16 \): A constant that controls the overall scaling.

#### Case 1: \( r = 4 \)
- Scaling factor: \( \alpha / r = 16 / 4 = 4 \).
- Without scaling, \( \| \Delta W \| \) might be, say, \( 0.5 \).
- After scaling:  
  \[
  \| \Delta W \| = 4 \times 0.5 = 2
  \]

#### Case 2: \( r = 16 \)
- Scaling factor: \( \alpha / r = 16 / 16 = 1 \).
- Without scaling, \( \| \Delta W \| \) might be larger due to the increased rank, say \( 2 \).
- After scaling:  
  \[
  \| \Delta W \| = 1 \times 2 = 2
  \]

---

### **Why This Matters**
The scaling ensures:
1. **Consistent Magnitude:**  
   Regardless of the rank \( r \), the effective size of \( \Delta W \) stays controlled. This prevents \( \Delta W \) from overshadowing \( W \) (or becoming negligible).

2. **Hyperparameter Stability:**  
   Since \( \Delta W \)’s magnitude doesn’t depend on \( r \), you don’t need to retune hyperparameters like the learning rate when changing \( r \).

---

### **Takeaway**
Scaling \( \Delta W \) by \( \alpha / r \) ensures that the low-rank updates are balanced and don’t overwhelm the original model’s weights. It’s a way to keep fine-tuning stable and predictable, regardless of how many low-rank components (\( r \)) you choose.

In [11]:
class LoRAParameterization(nn.Module):
    """Parametrization that adds a scaled low-rank update ``B @ A`` to a weight.

    Meant to be registered with
    ``torch.nn.utils.parametrize.register_parametrization``: ``forward``
    receives the stored original weight and returns the weight the layer uses.

    Args:
        features_in: Number of rows of the adapted weight; ``lora_B`` has shape
            ``(features_in, rank)``. NOTE: for an ``nn.Linear`` weight of shape
            ``(out_features, in_features)`` this is ``out_features`` — the
            parameter name is kept for backward compatibility.
        features_out: Number of columns of the adapted weight; ``lora_A`` has
            shape ``(rank, features_out)``.
        rank: Rank ``r`` of the update.
        alpha: Scaling constant; the update is multiplied by ``alpha / rank``.
        device: Device on which the LoRA matrices are created.
    """

    def __init__(self, features_in, features_out, rank=1, alpha=1, device="cpu"):
        super().__init__()
        # Section 4.1 of the paper:
        #   we use a random Gaussian initialization for A and zero for B,
        #   so ∆W = BA is zero at the beginning of training.
        #
        # The tensors are created directly on `device` and only then wrapped in
        # nn.Parameter. Calling `.to(device)` on an nn.Parameter returns a plain,
        # non-leaf tensor whenever a copy is required (e.g. CPU -> CUDA), which
        # would leave the LoRA matrices unregistered and therefore untrained.
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out), device=device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank), device=device))

        nn.init.normal_(self.lora_A, mean=0, std=1)  # Gaussian initialization for A only

        # Section 4.1 of the paper:
        #   we scale ∆Wx by α/r, where α is a constant in r. When optimizing with
        #   Adam, tuning α is roughly the same as tuning the learning rate if the
        #   initialization is scaled appropriately, so α is simply set to the
        #   first r tried and not tuned. This reduces the need to retune
        #   hyperparameters when r varies.
        self.scale = alpha / rank
        # When False, forward returns the original weight untouched.
        self.enabled = True

    def forward(self, original_weights):
        """Return ``original_weights + (lora_B @ lora_A) * scale``, or the input unchanged when disabled."""
        if not self.enabled:
            return original_weights
        low_rank_update = torch.matmul(self.lora_B, self.lora_A).view(
            original_weights.shape
        )
        return original_weights + low_rank_update * self.scale

Add the parameterization to our network

In [12]:
import torch.nn.utils.parametrize as parametrize


def linear_layer_parametrization(layer, device, rank=1, lora_alpha=1):
    """Build a ``LoRAParameterization`` sized for ``layer``'s weight matrix.

    Only the weight is adapted; the bias is left out. (Section 4.2 of the paper
    limits its study to adapting attention weights and freezes the MLP modules
    for downstream tasks.)

    Args:
        layer: Module whose 2-D ``weight`` provides the dimensions.
        device: Device for the LoRA matrices.
        rank: Rank of the low-rank update.
        lora_alpha: Scaling constant forwarded as ``alpha``.

    Returns:
        A new ``LoRAParameterization`` whose update has the same shape as
        ``layer.weight``.
    """
    weight_rows, weight_cols = layer.weight.shape
    return LoRAParameterization(
        weight_rows, weight_cols, rank=rank, alpha=lora_alpha, device=device
    )


# Attach a LoRA parametrization to the weight of every linear layer, in order.
# After registration, reading `layer.weight` runs the parametrization's forward
# on the stored original weight, so the original values are altered on the fly
# rather than overwritten.
for layer in (net.linear1, net.linear2, net.linear3):
    parametrize.register_parametrization(
        layer,
        "weight",
        linear_layer_parametrization(layer, device=device),
    )


def enable_disable_lora(enabled=True):
    """Set the ``enabled`` flag of the LoRA parametrization on all three linear layers of ``net``.

    Args:
        enabled: ``True`` to add the low-rank update to the weights, ``False``
            to use the original weights unchanged.
    """
    for layer_name in ("linear1", "linear2", "linear3"):
        lora_module = getattr(net, layer_name).parametrizations["weight"][0]
        lora_module.enabled = enabled

Display the number of parameters added by LoRA.



In [13]:
# Count the parameters added by LoRA separately from the pre-existing ones.
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    lora_module = layer.parametrizations["weight"][0]
    total_parameters_lora += lora_module.lora_A.nelement() + lora_module.lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    )

# Sanity check: adding LoRA must not change the size of the original network.
assert total_parameters_non_lora == total_parameters_original

total_parameters_combined = total_parameters_lora + total_parameters_non_lora
print(f"Total number of parameters (original): {total_parameters_non_lora:,}")
print(f"Total number of parameters (original + LoRA): {total_parameters_combined:,}")
print(f"Parameters introduced by LoRA: {total_parameters_lora:,}")

# Relative growth of the parameter count, as a percentage of the original size.
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f"Parameters incremment: {parameters_incremment:.3f}%")

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000]) + Lora_A: torch.Size([1, 784]) + Lora_B: torch.Size([1000, 1])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000]) + Lora_A: torch.Size([1, 1000]) + Lora_B: torch.Size([2000, 1])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10]) + Lora_A: torch.Size([1, 2000]) + Lora_B: torch.Size([10, 1])
Total number of parameters (original): 2,807,010
Total number of parameters (original + LoRA): 2,813,804
Parameters introduced by LoRA: 6,794
Parameters incremment: 0.242%


Freeze all the parameters of the original network and fine-tune only the ones introduced by LoRA. The model is fine-tuned on the digit 9 only, and for just 100 batches.



In [14]:
# Parameter names after registering the parametrizations: each layer's weight now
# appears as `parametrizations.weight.original` next to the new `lora_A` / `lora_B`.
[name for name, params in net.named_parameters()]

['linear1.bias',
 'linear1.parametrizations.weight.original',
 'linear1.parametrizations.weight.0.lora_A',
 'linear1.parametrizations.weight.0.lora_B',
 'linear2.bias',
 'linear2.parametrizations.weight.original',
 'linear2.parametrizations.weight.0.lora_A',
 'linear2.parametrizations.weight.0.lora_B',
 'linear3.bias',
 'linear3.parametrizations.weight.original',
 'linear3.parametrizations.weight.0.lora_A',
 'linear3.parametrizations.weight.0.lora_B']

In [15]:
# Freeze every parameter whose name does not contain "lora"
for name, param in net.named_parameters():
    print(f"{name=}")
    if "lora" in name:
        continue  # the LoRA matrices stay trainable
    print(f"Freezing non-LoRA parameter {name}")
    param.requires_grad = False

# Reload the MNIST training split and keep only the samples labelled 9.
# NOTE: despite its name, `exclude_indices` is a boolean mask of the samples to KEEP.
mnist_trainset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
mnist_trainset.data = mnist_trainset.data[exclude_indices]

# Loader over the digit-9 subset, in shuffled batches of 10
train_loader = torch.utils.data.DataLoader(
    mnist_trainset,
    batch_size=10,
    shuffle=True,
)

# Fine-tune with LoRA on the digit 9 only, for at most 100 batches
# (hoping that it improves the performance on the digit 9)
train(train_loader, net, epochs=1, total_iterations_limit=100)

name='linear1.bias'
Freezing non-LoRA parameter linear1.bias
name='linear1.parametrizations.weight.original'
Freezing non-LoRA parameter linear1.parametrizations.weight.original
name='linear1.parametrizations.weight.0.lora_A'
name='linear1.parametrizations.weight.0.lora_B'
name='linear2.bias'
Freezing non-LoRA parameter linear2.bias
name='linear2.parametrizations.weight.original'
Freezing non-LoRA parameter linear2.parametrizations.weight.original
name='linear2.parametrizations.weight.0.lora_A'
name='linear2.parametrizations.weight.0.lora_B'
name='linear3.bias'
Freezing non-LoRA parameter linear3.bias
name='linear3.parametrizations.weight.original'
Freezing non-LoRA parameter linear3.parametrizations.weight.original
name='linear3.parametrizations.weight.0.lora_A'
name='linear3.parametrizations.weight.0.lora_B'


Epoch 1:  99%|█████████▉| 99/100 [00:00<00:00, 121.39it/s, loss=0.11] 


Verify that the fine-tuning didn't alter the original weights, but only the ones introduced by LoRA.



In [16]:
# The pretrained weights were frozen, so the tensors now stored under
# `parametrizations.weight.original` must still equal the snapshot taken
# before fine-tuning.
for layer_name in ("linear1", "linear2", "linear3"):
    frozen_weight = getattr(net, layer_name).parametrizations.weight.original
    assert torch.all(frozen_weight == original_weights[f"{layer_name}.weight"])

enable_disable_lora(enabled=True)
# With LoRA enabled, linear1.weight is produced by the "forward" function of our
# LoRA parametrization: original + (lora_B @ lora_A) * scale.
# The original weights have been moved to net.linear1.parametrizations.weight.original
# More info here: https://pytorch.org/tutorials/intermediate/parametrizations.html#inspecting-a-parametrized-module
linear1_lora = net.linear1.parametrizations.weight[0]
expected_linear1_weight = (
    net.linear1.parametrizations.weight.original
    + (linear1_lora.lora_B @ linear1_lora.lora_A) * linear1_lora.scale
)
assert torch.equal(net.linear1.weight, expected_linear1_weight)

enable_disable_lora(enabled=False)
# With LoRA disabled, linear1.weight is the original pretrained weight again
assert torch.equal(net.linear1.weight, original_weights["linear1.weight"])

Test the network with LoRA enabled (the digit 9 should be classified better)



In [17]:
# Test with LoRA enabled: the low-rank update is added to all three layers' weights
# before evaluating (the error count for the digit 9 is expected to drop).
enable_disable_lora(enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 281.83it/s]

Accuracy: 0.941
wrong counts for the digit 0: 21
wrong counts for the digit 1: 33
wrong counts for the digit 2: 59
wrong counts for the digit 3: 145
wrong counts for the digit 4: 137
wrong counts for the digit 5: 29
wrong counts for the digit 6: 63
wrong counts for the digit 7: 63
wrong counts for the digit 8: 23
wrong counts for the digit 9: 14





Test the network with LoRA disabled (the accuracy and errors counts must be the same as the original network)



In [18]:
# Test with LoRA disabled: the layers use their original weights again, so the
# report should match the one obtained before fine-tuning.
enable_disable_lora(enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 404.89it/s]

Accuracy: 0.955
wrong counts for the digit 0: 12
wrong counts for the digit 1: 30
wrong counts for the digit 2: 44
wrong counts for the digit 3: 106
wrong counts for the digit 4: 28
wrong counts for the digit 5: 21
wrong counts for the digit 6: 32
wrong counts for the digit 7: 37
wrong counts for the digit 8: 9
wrong counts for the digit 9: 128





In [19]:
# `modules` is referenced without being called, so the cell shows the bound method's
# repr, which embeds the module tree including the registered parametrizations.
net.modules

<bound method Module.modules of RichBoyNet(
  (linear1): ParametrizedLinear(
    in_features=784, out_features=1000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (linear2): ParametrizedLinear(
    in_features=1000, out_features=2000, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (linear3): ParametrizedLinear(
    in_features=2000, out_features=10, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParameterization()
      )
    )
  )
  (relu): ReLU()
)>

Here’s a detailed, **command-by-command** explanation of the code, broken into its key sections:

---

### **1. Creating the Neural Network (`RichBoyNet`)**
```python
class RichBoyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(RichBoyNet, self).__init__()
        self.linear1 = nn.Linear(28 * 28, hidden_size_1)  # Input layer: 28x28 (image flattened) to hidden layer 1
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)  # Hidden layer 1 to hidden layer 2
        self.linear3 = nn.Linear(hidden_size_2, 10)  # Hidden layer 2 to output layer (10 classes)
        self.relu = nn.ReLU()  # ReLU activation function
```

- **Purpose:** Define an unnecessarily large neural network (hidden layers with 1000 and 2000 units) for classifying MNIST images into 10 classes (digits 0–9).
- `nn.Linear`: A fully connected (dense) layer.
- `ReLU`: Rectified Linear Unit activation introduces non-linearity.

---

```python
def forward(self, img):
    x = img.view(-1, 28 * 28)  # Flatten the input (batch_size, 28x28 -> batch_size, 784)
    x = self.relu(self.linear1(x))  # Apply linear1 and ReLU
    x = self.relu(self.linear2(x))  # Apply linear2 and ReLU
    x = self.linear3(x)  # Final layer (no activation since CrossEntropyLoss will handle softmax)
    return x
```

- **Purpose:** Define how the data flows through the network. Each layer applies its operation sequentially.
- `img.view(-1, 28 * 28)`: Flattens the image tensor for input to `linear1`.

---

### **2. Training the Model**
```python
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()  # Loss function: Cross-entropy for classification
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)  # Adam optimizer
```

- **Purpose:** Define a training loop.
- `nn.CrossEntropyLoss`: Combines `Softmax` and negative log-likelihood, suitable for multi-class classification.
- `torch.optim.Adam`: Adaptive learning rate optimizer.

---

```python
    for epoch in range(epochs):  # Iterate through epochs
        net.train()  # Set the model to training mode
        loss_sum = 0
        num_iterations = 0
        for data in tqdm(train_loader, desc=f"Epoch {epoch+1}"):  # Iterate over batches
            x, y = data  # Extract input (x) and labels (y)
            x, y = x.to(device), y.to(device)  # Move to device (CPU/GPU)
            optimizer.zero_grad()  # Clear gradients from previous step
            output = net(x.view(-1, 28 * 28))  # Forward pass
            loss = cross_el(output, y)  # Compute loss
            loss_sum += loss.item()  # Accumulate loss
            loss.backward()  # Backpropagation
            optimizer.step()  # Update parameters
```

- **Purpose:** Train the model over multiple epochs, calculate loss, and optimize weights using backpropagation.
- `tqdm`: Progress bar for visualization.
- `optimizer.zero_grad()`: Clears gradients to avoid accumulation.
- `loss.backward()`: Computes gradients for each parameter.
- `optimizer.step()`: Updates model parameters based on gradients.

---

### **3. Testing the Model**
```python
def test():
    correct = 0
    total = 0
    wrong_counts = [0 for i in range(10)]  # Track misclassifications for each digit
```

- **Purpose:** Evaluate model performance on test data and identify digits that are often misclassified.

---

```python
    with torch.no_grad():  # Disable gradient computation for inference
        for data in tqdm(test_loader, desc="Testing"):  # Iterate through test data
            x, y = data
            x, y = x.to(device), y.to(device)
            output = net(x.view(-1, 784))  # Forward pass
            for idx, i in enumerate(output):  # Iterate over predictions
                if torch.argmax(i) == y[idx]:  # Correct prediction
                    correct += 1
                else:  # Incorrect prediction
                    wrong_counts[y[idx]] += 1
                total += 1
```

- **Purpose:** Loop through the test dataset, compute predictions, and count correct and incorrect classifications.
- `torch.no_grad`: Saves memory and speeds up computation by disabling gradient tracking.

---

### **4. LoRA Parameterization**
LoRA (Low-Rank Adaptation) introduces a parameter-efficient method to fine-tune large models by only modifying a subset of the model's weights.

```python
class LoRAParameterization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device="cpu"):
        super().__init__()
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out))).to(device)
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank))).to(device)
        nn.init.normal_(self.lora_A, mean=0, std=1)  # Gaussian initialization
        self.scale = alpha / rank  # Scaling factor
        self.enabled = True  # Enable/disable LoRA
```

- **Purpose:** Define low-rank matrices (`lora_A`, `lora_B`) that adapt the original weight matrix during fine-tuning.
- `rank`: Controls the low-rank approximation. Lower ranks reduce memory usage.
- `alpha`: A scaling factor to adjust the update's magnitude.

---

```python
def forward(self, original_weights):
    if self.enabled:
        return (
            original_weights
            + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape)
            * self.scale
        )
    else:
        return original_weights
```

- **Purpose:** Add the low-rank update (`B @ A * scale`) to the original weights during forward passes when LoRA is enabled.

---

### **5. Registering LoRA Parameterization**
```python
parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parametrization(net.linear1, device=device)
)
```

- **Purpose:** Apply the LoRA parameterization to specific layers (`linear1`, `linear2`, `linear3`) by replacing their weights with the LoRA-adjusted weights.

---

### **6. Training with LoRA**
```python
for name, param in net.named_parameters():
    if "lora" not in name:
        param.requires_grad = False  # Freeze original weights
```

- **Purpose:** Fine-tune only the LoRA parameters (`lora_A`, `lora_B`) while freezing the original weights.

---

### **7. Results and Assertions**
```python
assert torch.all(
    net.linear1.weight == original_weights["linear1.weight"]
    + (net.linear1.parametrizations.weight[0].lora_B @ net.linear1.parametrizations.weight[0].lora_A)
    * net.linear1.parametrizations.weight[0].scale
)
```

- **Purpose:** Verify that the LoRA-adjusted weights match the expected values.

---

### **8. Summary**
This code:
1. Creates a large neural network (`RichBoyNet`).
2. Trains and tests the network on MNIST.
3. Introduces LoRA to efficiently fine-tune the model with minimal additional parameters.
4. Demonstrates LoRA's utility by training only on the digit "9" while freezing most parameters.

# 1. Base Neural Network (RichBoyNet)

```python
class RichBoyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(RichBoyNet, self).__init__()
        self.linear1 = nn.Linear(28 * 28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
```
- Inherits from `nn.Module`, PyTorch's base class for neural networks
- Takes two parameters for hidden layer sizes (default 1000 and 2000)
- Creates three linear layers:
  - Input layer: 784 (28×28 flattened MNIST image) → 1000 neurons
  - Hidden layer: 1000 → 2000 neurons
  - Output layer: 2000 → 10 neurons (one per digit)
- Uses ReLU activation function between layers

```python
def forward(self, img):
    x = img.view(-1, 28 * 28)  # Flatten image
    x = self.relu(self.linear1(x))  # First layer + ReLU
    x = self.relu(self.linear2(x))  # Second layer + ReLU
    x = self.linear3(x)  # Output layer (no ReLU)
    return x
```
- Defines forward pass through network
- Flattens input image to 784 dimensions
- Applies linear transformations with ReLU between layers
- Returns raw logits (no softmax needed with CrossEntropyLoss)

# 2. Training Function

```python
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
```
- Sets up CrossEntropyLoss for classification
- Uses Adam optimizer with learning rate 0.001
- Can limit total training iterations if specified

```python
    for epoch in range(epochs):
        net.train()
        loss_sum = 0
        num_iterations = 0
        
        data_iterator = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
```
- Loops through specified number of epochs
- Sets network to training mode
- Creates progress bar with tqdm
- Handles iteration limit if specified

```python
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
```
- Processes each batch of data
- Moves data to specified device (CPU/GPU)

```python
            optimizer.zero_grad()
            output = net(x.view(-1, 28 * 28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
```
- Zeroes gradients before each backward pass
- Forward pass through network
- Calculates loss
- Updates progress bar with average loss

```python
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
```
- Backpropagates gradients
- Updates weights
- Checks iteration limit

# 3. LoRA Implementation

```python
class LoRAParameterization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device="cpu"):
        super().__init__()
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out))).to(device)
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank))).to(device)
```
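- Creates LoRA matrices A and B
- A has shape `(rank, features_out)`
- B has shape `(features_in, rank)`, so `B @ A` has shape `(features_in, features_out)`
- Both are created as zero tensors; A is re-initialized from a normal distribution in the next step, while B stays zero, so the initial update `B @ A` is zero and the model starts out unchanged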

```python
        nn.init.normal_(self.lora_A, mean=0, std=1)  # initialize gaussian distribution
        self.scale = alpha / rank
        self.enabled = True
```
- Initializes A matrix with normal distribution
- Sets scaling factor (alpha/rank as per LoRA paper)
- Enables LoRA by default

```python
    def forward(self, original_weights):
        if self.enabled:
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights
```
- Implements LoRA update rule: W = W₀ + BA × (α/r)
- Returns original weights if LoRA disabled

# 4. LoRA Application

```python
def linear_layer_parametrization(layer, device, rank=1, lora_alpha=1):
    features_in, features_out = layer.weight.shape
    return LoRAParameterization(features_in, features_out, rank, lora_alpha, device)
```
- Creates LoRA parametrization for a linear layer
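- Reads the two dimensions of `layer.weight`. Note that `nn.Linear` stores its weight with shape `(out_features, in_features)`, so the variable named `features_in` actually holds the output size and `features_out` the input size; the product `B @ A` still has exactly the shape of `layer.weight`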
- Returns LoRA module with specified rank and alpha

```python
parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parametrization(net.linear1, device=device)
)
```
- Registers LoRA parametrization for each linear layer
- Uses PyTorch's parametrization system
- Applies to weight matrices only (not biases)

# 5. Parameter Management

```python
def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled
```
- Toggles LoRA on/off for all layers
- Allows switching between original and LoRA-modified weights

```python
# Freeze the non-Lora parameters
for name, param in net.named_parameters():
    if "lora" not in name:
        param.requires_grad = False
```
- Freezes original network parameters
- Only LoRA parameters will be updated during training

# 6. Parameter Counting and Verification

```python
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_lora += (
        layer.parametrizations["weight"][0].lora_A.nelement() +
        layer.parametrizations["weight"][0].lora_B.nelement()
    )
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
```
- Counts parameters in original network and LoRA additions
- Verifies parameter counts match expectations
- Calculates parameter increase percentage

# 7. Fine-tuning Setup

```python
# Load the MNIST dataset again, by keeping only the digit 9
mnist_trainset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
```
- Creates a dataset with only digit 9 (despite its name, the `exclude_indices` mask selects the samples that are kept)
- Used for specialized fine-tuning

# 8. Verification Steps

```python
assert torch.all(
    net.linear1.parametrizations.weight.original == original_weights["linear1.weight"]
)
```
- Verifies original weights unchanged after fine-tuning
- Checks LoRA weight computation correctness
- Confirms proper enabling/disabling of LoRA

This implementation demonstrates how LoRA can be used to efficiently fine-tune a neural network by:
1. Adding few trainable parameters (LoRA matrices)
2. Preserving original weights
3. Allowing easy switching between original and adapted behavior
4. Maintaining model performance while reducing memory requirements

The code follows the principles outlined in the LoRA paper while providing a practical PyTorch implementation.

This code snippet demonstrates how **LoRA (Low-Rank Adaptation)** can be implemented using PyTorch's parameterization utilities (`torch.nn.utils.parametrize`). Here's a detailed explanation:

---

### **1. What is happening here?**

- **LoRA Concept**: 
  Instead of updating the entire weight matrix during training, LoRA trains only a low-rank update ($\Delta W = BA$) while keeping the original weights frozen. This reduces the number of trainable parameters and the computational cost.

- **Parameterization**: 
  PyTorch's `parametrize` module lets you change how a parameter is computed without directly altering the original parameter. Here, the weight matrix of a linear layer is "parameterized" with LoRA, meaning the effective weight is computed as the original weight plus a low-rank update $\Delta W$.

---

### **2. Code Breakdown**

#### **(a) `linear_layer_parametrization` function**
```python
def linear_layer_parametrization(layer, device, rank=1, lora_alpha=1):
    features_in, features_out = layer.weight.shape
    return LoRAParameterization(features_in, features_out, rank, lora_alpha, device)
```

- **Purpose**: This function creates a custom parameterization for a linear layer’s weight matrix using the `LoRAParameterization` class (assumed to be defined elsewhere).
- **Parameters**:
  - `rank`: The rank of the low-rank decomposition ($\Delta W = BA$, where $B \in \mathbb{R}^{m \times r}$ and $A \in \mathbb{R}^{r \times n}$).
  - `lora_alpha`: A scaling factor to control the magnitude of $\Delta W$ relative to the original weight matrix.
  - `device`: Specifies where the parameterization will reside (e.g., GPU or CPU).

---

#### **(b) Registering Parameterizations**
```python
parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parametrization(net.linear1, device=device)
)
```

- **What happens here?**
  - The `weight` parameter of `net.linear1` is replaced with a parameterized version defined by `linear_layer_parametrization`.
  - The original `weight` is preserved (as `parametrizations.weight.original`), and the LoRA update ($\Delta W$) is added to it whenever the layer's `weight` is computed.

- **Why?**
  - This approach allows you to train only the low-rank components ($B$ and $A$) while freezing the original weights.
  - It’s efficient and keeps the changes modular.

---

#### **(c) Enabling/Disabling LoRA**
```python
def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled
```

- **Purpose**: Toggles LoRA updates for the layers (`linear1`, `linear2`, `linear3`).
- **How?**
  - Each parameterized weight has an `enabled` flag (defined in `LoRAParameterization`) that controls whether $\Delta W$ is applied during the forward pass.
  - When `enabled=False`, the forward pass uses only the original weight $W$, ignoring the LoRA updates.

---

### **3. Why Use `parametrize`?**

- **Non-Intrusive Modifications**: The original weight matrix remains untouched. LoRA’s updates are modularly added.
- **Flexibility**: You can enable or disable LoRA dynamically, making it easier to experiment with fine-tuning strategies.
- **Scalability**: The same approach can be applied to multiple layers without modifying their original implementation.

---

### **4. Practical Intuition**

- **Think of the layers as a trainable filter**:  
  LoRA adds a small, tunable adjustment ($\Delta W$) to the pre-trained weights. You can "switch on" these adjustments to fine-tune for a specific task or "switch off" to revert to the original model.

- **Efficient Fine-Tuning**: Instead of retraining all parameters in the network, you only train the low-rank components, saving computation and memory.

---

### **5. Key Advantages**

1. **Low Memory Footprint**: By training only $B$ and $A$, you train far fewer parameters than with full fine-tuning.
2. **Ease of Integration**: The `parametrize` module simplifies adding and managing LoRA updates.
3. **Dynamic Control**: The `enable_disable_lora` function provides a clean way to toggle LoRA for experimentation.

The next section looks at `register_parametrization` itself, using a smaller, self-contained example.

The `register_parametrization` function in PyTorch is a powerful tool for applying parameter transformations (e.g., custom modifications like **LoRA**) to layers in a model. Let’s break it down with a simple, **intuitive example** that shows how it works.

---

### **Example: Applying a Scaling Parametrization**

Suppose you have a linear layer (`nn.Linear`) and you want to scale its weights dynamically by a learnable factor. You can use `register_parametrization` to apply this scaling without directly modifying the layer's original weights.

---

### **Step-by-Step Code Explanation**

#### 1. **Define a Parametrization Class**
The parametrization defines how the original parameter (weights) is transformed.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

# Parametrization class: scales weights by a learnable factor
class ScaleParametrization(nn.Module):
    def __init__(self, scale_factor=1.0):
        super().__init__()
        # Learnable scaling factor
        self.scale = nn.Parameter(torch.tensor(scale_factor, dtype=torch.float32))

    def forward(self, weights):
        # Scale the weights dynamically
        return weights * self.scale
```

---

#### 2. **Create a Simple Model**
Define a model with a single `nn.Linear` layer.

```python
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(3, 2)  # A linear layer with input size 3 and output size 2

    def forward(self, x):
        return self.linear(x)
```

---

#### 3. **Register the Parametrization**
Apply the scaling parametrization to the weights of the `linear` layer.

```python
# Create the model
model = SimpleModel()

# Print original weights
print("Original weights:", model.linear.weight)

# Register the scaling parametrization
parametrize.register_parametrization(model.linear, "weight", ScaleParametrization(scale_factor=2.0))

# Print parametrized weights (scaled)
print("Parametrized weights:", model.linear.weight)
```

---

#### 4. **Observe the Transformation**
The `ScaleParametrization` scales the weights dynamically during forward passes, while the original weights remain unchanged.

---

#### 5. **Train the Model**
You can now train the model, and the scaling factor will be updated during optimization.

```python
# Define an optimizer and loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Dummy input and target
x = torch.tensor([[1.0, 2.0, 3.0]])
target = torch.tensor([[0.5, 1.5]])

# Training step
for _ in range(5):
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    # Observe changes
    print("Scaled weights:", model.linear.weight)
    print("Scaling factor:", model.linear.parametrizations.weight[0].scale)
```

---

### **Explanation of Key Concepts**

1. **Original vs. Parametrized Weights**:
   - The `parametrize.register_parametrization` applies a transformation (`ScaleParametrization`) to the `weight` tensor of the `linear` layer.
   - The original weights are preserved and can be accessed using:
     ```python
     model.linear.parametrizations.weight.original
     ```

2. **Dynamic Updates**:
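   - During training, gradients flow through the parametrization to both the scaling factor (`scale`) and the original weights. Because the example passes `model.parameters()` to the optimizer, both are updated here; to keep the original weights fixed (as LoRA does), set `model.linear.parametrizations.weight.original.requires_grad = False` before training.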

3. **Accessing Components**:
   - **Original weights**: `model.linear.parametrizations.weight.original`
   - **Parametrized weights**: `model.linear.weight`
   - **Scaling factor**: `model.linear.parametrizations.weight[0].scale`

---

### **Output Example**

After registering the parametrization (illustrative values — the actual numbers depend on the random initialization of the layer):

- **Original weights**:  
  ```
  tensor([[0.5, -0.2, 0.1], [0.3, 0.4, -0.5]])
  ```
- **Parametrized weights (scaled)**:  
  ```
  tensor([[1.0, -0.4, 0.2], [0.6, 0.8, -1.0]])  # Original weights * scale_factor (2.0)
  ```

As training progresses, the scaling factor will adjust, influencing the scaled weights.

---

This demonstrates how `register_parametrization` allows you to dynamically transform parameters while keeping the original values intact.