# Notebook 05: LoRA & Parameter-Efficient Fine-Tuning

## Learning Objectives
- Understand LoRA's mathematical foundation
- Implement LoRALayer from scratch
- Compare parameter counts (full vs LoRA)
- Demonstrate multi-agent LoRA with shared base

## LoRA Theory

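Full fine-tuning updates every entry of a pretrained weight matrix $W_0 \in \mathbb{R}^{d \times k}$: $W \leftarrow W_0 + \Delta W$, where $\Delta W$ is a dense $d \times k$ update.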

LoRA decomposes $\Delta W$ as a low-rank product:
$$\Delta W = B \cdot A \quad \text{where } B \in \mathbb{R}^{d \times r}, A \in \mathbb{R}^{r \times k},\ r \ll \min(d, k)$$
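A quick numerical check of the low-rank claim follows. The sizes match the $d = k = 768$, $r = 8$ example used in this notebook, and the factors are random stand-ins for learned ones. The product $BA$ is a full-size $d \times k$ matrix, but its rank is at most $r$, and it is described by only $r(d+k)$ numbers.

```python
import torch

torch.manual_seed(0)
d, k, r = 768, 768, 8

# Random stand-ins for the learned LoRA factors
B = torch.randn(d, r)
A = torch.randn(r, k)
delta_W = B @ A  # full (d, k) matrix built from only r * (d + k) numbers

print(delta_W.shape)                             # torch.Size([768, 768])
print(torch.linalg.matrix_rank(delta_W).item())  # 8: rank equals r for generic random factors
print(B.numel() + A.numel(), delta_W.numel())    # 12288 589824
```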

Forward pass:
$$h = W_0 x + \frac{\alpha}{r} BAx$$

**Initialization:** $A \sim \mathcal{N}(0, \sigma^2)$, $B = 0$ (so $\Delta W = 0$ initially)
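The forward pass and initialization above fit in a few lines of PyTorch. This is a minimal sketch, not the project's `LoRALayer`. The class name `LoRALinear` and the default `sigma=0.02` for $A$ are assumptions. It evaluates $B(Ax)$ right to left, so the $d \times k$ matrix $\Delta W$ is never materialized.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Hypothetical minimal LoRA wrapper: h = W0 x + (alpha / r) * B A x."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0, sigma: float = 0.02):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # freeze W0 (and its bias)
        d, k = base.out_features, base.in_features
        self.scaling = alpha / r
        # A ~ N(0, sigma^2) with shape (r, k); B = 0 with shape (d, r), so BA = 0 at init
        self.A = nn.Parameter(torch.randn(r, k) * sigma)
        self.B = nn.Parameter(torch.zeros(d, r))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (x A^T) B^T applies A, then B, to every row of x; Delta W = BA is never formed
        return self.base(x) + self.scaling * ((x @ self.A.T) @ self.B.T)


base = nn.Linear(768, 768)
layer = LoRALinear(base, r=8, alpha=16.0)
x = torch.randn(4, 32, 768)
print(layer.scaling)                                                  # 2.0
print(sum(p.numel() for p in layer.parameters() if p.requires_grad))  # 12288
print(torch.allclose(layer(x), base(x)))                              # True, since B = 0
```

Because $B = 0$, the wrapped layer reproduces the base layer at the start of training, and only $A$ and $B$ are trainable: $8 \cdot (768 + 768) = 12{,}288$ parameters.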

**Parameter reduction:**
- Full fine-tuning: $d \times k$ params
- LoRA (rank r): $r(d + k)$ params
- Example: $d=k=768$, $r=8$: $589{,}824 \to 12{,}288$ (**≈97.9% reduction**)
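Dividing the two counts gives the fraction of parameters LoRA keeps. For a square layer ($d = k$) it simplifies to $2r/d$, which for the example above is exactly $1/48$:

$$\frac{r(d+k)}{dk} \;\overset{d=k}{=}\; \frac{2r}{d} = \frac{2 \cdot 8}{768} = \frac{1}{48} \approx 2.08\%, \qquad 1 - \frac{1}{48} \approx 97.9\% \text{ reduction}$$

So the adapter size grows linearly in $r$, while the relative saving grows with layer width.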

```python
# !pip install torch matplotlib
```

```python
import sys
sys.path.insert(0, '..')  # let the notebook import the project's src/ package from the parent directory
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
print('Ready!')
```

## Step 1: Implement LoRA from Scratch

```python
from src.training.lora_utils import LoRALayer, apply_lora_to_linear, count_parameters, freeze_base_model

# Create a LoRA layer for a 768 -> 768 projection (768 is distilgpt2's hidden size)
lora = LoRALayer(in_features=768, out_features=768, r=8, alpha=16.0)
print('LoRA layer created')
print('  in=768, out=768, r=8, alpha=16')
print(f'  Scaling factor: alpha/r = {lora.scaling:.2f}')  # 16 / 8 = 2.00
print(f'  A shape: {lora.lora_A.weight.shape}')
print(f'  B shape: {lora.lora_B.weight.shape}')
print(f'  Trainable params: {lora.trainable_parameters():,}')  # r * (d + k) = 12,288 if A and B have no bias
```

```python
# Compare parameter counts for a single d x k weight matrix
d, k = 768, 768
print('Parameter count comparison:')
print(f'{"Rank":<8} {"LoRA Params":<15} {"Full Params":<15} {"Reduction":<12}')
print('-' * 50)
for r in [1, 2, 4, 8, 16, 32, 64]:
    lora_p = r * (d + k)  # A is r x k, B is d x r
    full_p = d * k        # dense Delta W
    reduction = (1 - lora_p / full_p) * 100
    print(f'{r:<8} {lora_p:<15,} {full_p:<15,} {reduction:.1f}%')
```

## Step 2: Multi-Agent LoRA Architecture

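The key idea: keep one frozen copy of the base model in memory and attach a small, separately trained adapter per agent role (solver, critic, reviser). Switching roles means switching adapters, not reloading the model.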

```
GPU Memory:
  [Frozen base model: ~330 MB of fp32 weights for distilgpt2 (~82M params)]
  + [Solver adapter A_s, B_s: ~0.05 MB per 768x768 layer at r=8]
  + [Critic adapter A_c, B_c: ~0.05 MB per 768x768 layer at r=8]
  + [Reviser adapter A_r, B_r: ~0.05 MB per 768x768 layer at r=8]
```
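A minimal sketch of the shared-base idea for a single layer. `MultiAdapterLinear` is a hypothetical class, not part of `src.training.lora_utils`, and the `0.02` init scale for $A$ is an assumption. One frozen `nn.Linear` holds $W_0$, each role owns only its own $(A, B)$ pair, and the adapter is picked by name at call time.

```python
import torch
import torch.nn as nn


class MultiAdapterLinear(nn.Module):
    """Hypothetical sketch: one frozen base layer shared by several named LoRA adapters."""

    def __init__(self, base: nn.Linear, names, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # W0 is stored once and never trained
        d, k = base.out_features, base.in_features
        self.scaling = alpha / r
        # One (A, B) pair per role; B = 0 so every adapter starts as a no-op
        self.A = nn.ParameterDict({n: nn.Parameter(torch.randn(r, k) * 0.02) for n in names})
        self.B = nn.ParameterDict({n: nn.Parameter(torch.zeros(d, r)) for n in names})

    def forward(self, x: torch.Tensor, adapter: str) -> torch.Tensor:
        delta = (x @ self.A[adapter].T) @ self.B[adapter].T
        return self.base(x) + self.scaling * delta


layer = MultiAdapterLinear(nn.Linear(768, 768), names=['solver', 'critic', 'reviser'])
base_params = sum(p.numel() for p in layer.base.parameters())
per_adapter = layer.A['solver'].numel() + layer.B['solver'].numel()
print(base_params, per_adapter)  # 590592 12288 (base count includes the 768 bias terms)
```

For this layer each extra role adds about 2% of the base layer's parameters, and switching roles is only a different `adapter` argument.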

```python
# Demonstrate: apply LoRA to a linear layer and verify the forward pass
linear = nn.Linear(768, 768)
freeze_base_model(linear)  # freeze W_0 so only the adapter can train
lora_layer = apply_lora_to_linear(linear, r=8, alpha=16.0)

x = torch.randn(4, 32, 768)  # batch=4, seq=32, hidden=768
with torch.no_grad():
    base_out = linear(x)
    lora_out = lora_layer(x)

print(f'Input shape:  {x.shape}')
print(f'Base output:  {base_out.shape}')
print(f'LoRA output:  {lora_out.shape}')
# With B initialized to 0, the LoRA term (alpha/r) * B A x is 0, so the outputs should match
delta = (lora_out - base_out).abs().mean()
print(f'Mean |LoRA out - base out| (expected 0 at init, since B = 0): {delta:.6f}')
```
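With $B = 0$ at initialization the LoRA path contributes nothing, so the output delta only appears once training starts. A self-contained sketch of one training step follows. It uses plain PyTorch tensors rather than `src.training.lora_utils`, and the random regression target and `lr=0.1` are arbitrary assumptions. $W_0$ stays bit-for-bit identical. On the first step only $B$ receives a non-zero gradient, because $\partial L / \partial A$ is proportional to $B$. Afterwards the output drifts away from the base.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d = k = 768
r, alpha = 8, 16.0

base = nn.Linear(k, d)
base.requires_grad_(False)                  # freeze W0 and bias
A = nn.Parameter(torch.randn(r, k) * 0.02)  # A ~ N(0, 0.02^2), assumed init scale
B = nn.Parameter(torch.zeros(d, r))         # B = 0
scaling = alpha / r


def lora_forward(x):
    return base(x) + scaling * ((x @ A.T) @ B.T)


W0_before = base.weight.clone()
opt = torch.optim.SGD([A, B], lr=0.1)       # only the adapter is optimized

x = torch.randn(16, k)
target = torch.randn(16, d)                 # arbitrary regression target
loss = nn.functional.mse_loss(lora_forward(x), target)
loss.backward()
opt.step()

print(A.grad.abs().max().item())            # 0.0: dL/dA is proportional to B, still 0
print(torch.equal(base.weight, W0_before))  # True: W0 untouched
with torch.no_grad():
    print((lora_forward(x) - base(x)).abs().mean().item() > 0)  # True: adapter now contributes
```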

---

## Exercises

1. **Rank sensitivity:** Try r=1, 4, 8, 64. How does generation quality trade off against parameter count?
2. **Alpha tuning:** With r=8, try alpha=8 (scaling=1), 16 (scaling=2), 32 (scaling=4)
3. **Multi-agent demo:** Create solver/critic/reviser LoRA adapters with different random inits
4. **Weight merging:** Implement `merge_lora_weights()` and verify that $W_{\text{merged}} = W_0 + \frac{\alpha}{r} BA$
5. **QLoRA:** Research QLoRA (quantized LoRA) — how does 4-bit quantization + LoRA work?
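One possible starting point for Exercise 4. The exercise names `merge_lora_weights()` but not its arguments, so the tensor-based signature below is an assumption, and the small random weights are stand-ins for a pretrained $W_0$ and a trained adapter. The check confirms that one matmul with the merged weight reproduces the two-path forward pass.

```python
import torch


def merge_lora_weights(W0: torch.Tensor, A: torch.Tensor, B: torch.Tensor, alpha: float) -> torch.Tensor:
    """Return W0 + (alpha / r) * B @ A. Shapes: W0 (d, k), A (r, k), B (d, r)."""
    r = A.shape[0]
    return W0 + (alpha / r) * (B @ A)


torch.manual_seed(0)
d, k, r, alpha = 768, 768, 8, 16.0
W0 = torch.randn(d, k) * 0.02  # stand-in for pretrained weights
A = torch.randn(r, k) * 0.02
B = torch.randn(d, r) * 0.02   # pretend B has been trained (non-zero)
x = torch.randn(4, k)

unmerged = x @ W0.T + (alpha / r) * ((x @ A.T) @ B.T)  # two-path forward pass
merged = x @ merge_lora_weights(W0, A, B, alpha).T      # single matmul at inference
print(torch.allclose(unmerged, merged, atol=1e-5))     # True up to float32 rounding
```

After merging, inference costs the same as the base layer; the trade-off is that one merged matrix serves only one adapter at a time.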