# LoRA: Low-Rank Adaptation of Large Language Models

## Problem Statement

Implement LoRA (Low-Rank Adaptation), a parameter-efficient fine-tuning technique that adds trainable low-rank matrices to frozen pretrained weights.

## Background

### The Problem with Full Fine-Tuning

Fine-tuning all parameters of a large model is:
- **Memory intensive**: Need to store gradients for billions of parameters
- **Storage heavy**: Each task requires a full model copy
- **Slow**: Updating all parameters takes significant time

For a 7B parameter model trained in fp32 (4 bytes per value) with Adam:
- Model weights: 28 GB
- Gradients: 28 GB
- Optimizer states (Adam, two moments per parameter): 56 GB
- **Total: ~112 GB before counting activations!**
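
The arithmetic behind these figures can be reproduced with a short sketch. The helper name `full_finetune_memory_gb` is hypothetical, and the estimate assumes decimal gigabytes, fp32 for every tensor, and Adam keeping two moment tensors per parameter; activations and framework overhead are ignored.

```python
def full_finetune_memory_gb(n_params: float, bytes_per_value: int = 4) -> dict:
    """Rough training footprint in GB (hypothetical helper, ignores activations)."""
    gb = 1e9  # decimal gigabytes, matching the figures in the text
    weights = n_params * bytes_per_value / gb
    gradients = n_params * bytes_per_value / gb
    adam_states = 2 * n_params * bytes_per_value / gb  # first and second moments
    return {
        "weights": weights,
        "gradients": gradients,
        "adam_states": adam_states,
        "total": weights + gradients + adam_states,
    }


budget = full_finetune_memory_gb(7e9)
for name, value in budget.items():
    print(f"{name:>12}: {value:6.1f} GB")
```

Passing `bytes_per_value=2` gives a rough half-precision estimate, although mixed-precision setups usually keep some fp32 copies, so treat that number as a lower bound.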

### LoRA: The Key Insight

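LoRA rests on the hypothesis that the weight updates learned during fine-tuning have **low intrinsic rank**. Instead of updating the full weight matrix $W \in \mathbb{R}^{d \times k}$ directly, we keep $W$ frozen and parameterize the update as a product of two small matrices: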

$$W' = W + \Delta W = W + BA$$

where:
- $B \in \mathbb{R}^{d \times r}$ (up-projection, $r \to d$)
- $A \in \mathbb{R}^{r \times k}$ (down-projection, $k \to r$)
- $r \ll \min(d, k)$ is the rank (typically 4-64)
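
A small sketch can make the shapes concrete. The sizes below are arbitrary illustrative values, not taken from any real model; the point is only that the product $BA$ has the same shape as $W$ while its rank is bounded by $r$.

```python
import torch

torch.manual_seed(0)

d, k, r = 64, 48, 4  # illustrative sizes (assumption, not from a real model)

# float64 keeps the numerical rank estimate well away from round-off noise
B = torch.randn(d, r, dtype=torch.float64)  # maps r -> d
A = torch.randn(r, k, dtype=torch.float64)  # maps k -> r

delta_W = B @ A  # same shape as W, but built from far fewer numbers

print("delta_W shape:", tuple(delta_W.shape))
print("rank of delta_W:", torch.linalg.matrix_rank(delta_W).item())
print("values stored:", B.numel() + A.numel(), "vs", d * k)
```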

### Benefits

1. **Fewer trainable parameters**: From $d \times k$ to $r \times (d + k)$
2. **No added inference latency**: Merge $BA$ into $W$ at deployment, so the layer stays a single matrix multiply
3. **Task switching**: Swap LoRA adapters without reloading base model
4. **Memory efficient**: Only store/update low-rank matrices

## Mathematical Formulation

### Standard Linear Layer

$$h = Wx$$

### With LoRA

$$h = Wx + \frac{\alpha}{r}BAx$$

where:
- $\alpha$ is a scaling factor (typically equals $r$)
- $\frac{\alpha}{r}$ normalizes the contribution
- $A$ is initialized from $\mathcal{N}(0, \sigma^2)$
- $B$ is initialized to zeros (so $\Delta W = 0$ initially)
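
The formula and the zero initialization can be checked numerically with a few tensors. This is a minimal sketch: `lora_forward` is a hypothetical helper, the sizes are toy values, and $\sigma = 0.01$ for $A$ is an assumption chosen only for illustration.

```python
import torch

torch.manual_seed(0)


def lora_forward(W, A, B, x, alpha):
    """Compute h = W x + (alpha / r) * B (A x) for a single input vector x."""
    r = A.shape[0]
    return W @ x + (alpha / r) * (B @ (A @ x))


d, k, r, alpha = 6, 5, 2, 2.0  # toy sizes chosen for illustration
W = torch.randn(d, k)          # stands in for a frozen pretrained weight
A = torch.randn(r, k) * 0.01   # small random init (sigma = 0.01 is an assumption)
B = torch.zeros(d, r)          # zero init, so B A = 0 at the start
x = torch.randn(k)

# With B = 0 the LoRA path contributes nothing, so h equals W x
h_init = lora_forward(W, A, B, x, alpha)
matches_base = torch.allclose(h_init, W @ x)
print("matches base layer at init:", matches_base)

# After B changes, the two-step path B(Ax) agrees with the dense update (BA)x
B_trained = torch.randn(d, r)
h_lora = lora_forward(W, A, B_trained, x, alpha)
h_dense = (W + (alpha / r) * (B_trained @ A)) @ x
matches_dense = torch.allclose(h_lora, h_dense, atol=1e-6)
print("matches dense update:", matches_dense)
```

Computing $B(Ax)$ rather than $(BA)x$ avoids ever materializing the full $d \times k$ update during training.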

### Parameter Count

For a weight matrix $W \in \mathbb{R}^{d \times k}$:
- Original: $d \times k$ parameters
- LoRA: $r \times (d + k)$ parameters
- Ratio: $\frac{r(d+k)}{dk} = \frac{r}{d} + \frac{r}{k}$, which is $\frac{2r}{d}$ for a square matrix ($d = k$)

Example (GPT-3 175B attention):
- $d = k = 12288$, $r = 8$
- Original: ~151M params per weight matrix
- LoRA: ~197K params per weight matrix (0.13%)
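
The same counts can be computed directly. The helper `lora_param_counts` is a hypothetical name used only for this sketch; it counts a single $d \times k$ weight matrix and ignores biases.

```python
def lora_param_counts(d: int, k: int, r: int) -> tuple:
    """Return (full, lora, ratio) parameter counts for one d x k weight matrix."""
    full = d * k
    lora = r * (d + k)
    return full, lora, lora / full


full, lora, ratio = lora_param_counts(d=12288, k=12288, r=8)
print(f"full: {full:,}  lora: {lora:,}  ratio: {ratio:.2%}")

# The LoRA count grows linearly with the rank
for rank in (4, 16, 64):
    _, n_lora, frac = lora_param_counts(12288, 12288, rank)
    print(f"r={rank:<3} lora: {n_lora:>10,}  ratio: {frac:.2%}")
```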

## Learning Objectives

1. Understand why low-rank adaptation works
2. Implement a LoRA linear layer from scratch
3. Learn proper initialization (A random, B = 0)
4. Understand merging for inference
5. Know common hyperparameter choices (rank, alpha, target modules)

## Requirements

1. `LoRALayer` class that wraps a linear layer with low-rank adapters
2. Proper initialization (B = 0, A random; the implementation here uses Kaiming-uniform)
3. `merge()` and `unmerge()` methods for inference
4. Demonstration of parameter efficiency

## Hints

1. Initialize B to zeros so initial output equals pretrained model
2. Use `nn.Parameter` for trainable A and B matrices
3. Freeze the original weight with `requires_grad = False`
4. The scaling factor is typically $\alpha / r$ where $\alpha = r$

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
import math

torch.manual_seed(42)

## Implementation

In [None]:
class LoRALayer(nn.Module):
    """
    LoRA (Low-Rank Adaptation) layer that wraps a linear layer.
    
    The forward pass computes:
        h = Wx + (alpha/r) * BAx
    
    where W is frozen and only A, B are trained.
    """
    
    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 4,
        alpha: float = 1.0,
        dropout: float = 0.0,
    ):
        """
        Args:
            in_features: Input dimension
            out_features: Output dimension
            rank: Rank of the low-rank matrices (r)
            alpha: Scaling factor (typically set equal to rank)
            dropout: Dropout probability for LoRA path
        """
        super().__init__()
        
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        
        # Pretrained weight and bias, frozen so that only A and B receive gradients
        self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False)
        
        # LoRA matrices
        # A: down-projection (in_features -> rank)
        # B: up-projection (rank -> out_features)
        self.lora_A = nn.Parameter(torch.empty(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        
        # Dropout for regularization
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        
        # Track if weights are merged
        self.merged = False
        
        # Initialize
        self.reset_parameters()
    
    def reset_parameters(self):
        """Initialize parameters."""
        # Initialize main weight like nn.Linear
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        
        # LoRA initialization:
        # A: Kaiming-uniform random values (any small random init works here)
        # B: zeros (so delta_W starts at 0)
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: h = Wx + (alpha/r) * B(Ax)
        """
        # Base linear transformation
        result = F.linear(x, self.weight, self.bias)
        
        if not self.merged:
            # LoRA path: x -> dropout -> A -> B -> scale
            lora_out = self.dropout(x)
            lora_out = F.linear(lora_out, self.lora_A)  # x @ A.T
            lora_out = F.linear(lora_out, self.lora_B)  # (xA.T) @ B.T = x @ (BA).T
            result = result + self.scaling * lora_out
        
        return result
    
    def merge(self):
        """Merge LoRA weights into main weight for inference."""
        if not self.merged:
            # W' = W + (alpha/r) * BA
            delta_w = self.scaling * (self.lora_B @ self.lora_A)
            self.weight.data += delta_w
            self.merged = True
    
    def unmerge(self):
        """Unmerge LoRA weights (restore original for training)."""
        if self.merged:
            delta_w = self.scaling * (self.lora_B @ self.lora_A)
            self.weight.data -= delta_w
            self.merged = False
    
    def lora_parameters(self):
        """Return only LoRA parameters for optimizer."""
        return [self.lora_A, self.lora_B]
    
    @property
    def num_lora_params(self) -> int:
        """Number of trainable LoRA parameters."""
        return self.lora_A.numel() + self.lora_B.numel()
    
    @property
    def num_base_params(self) -> int:
        """Number of base (frozen) parameters."""
        return self.weight.numel() + self.bias.numel()

In [None]:
def apply_lora_to_model(
    model: nn.Module,
    rank: int = 4,
    alpha: float = 1.0,
    target_modules: Optional[list] = None,
) -> nn.Module:
    """
    Apply LoRA to specific modules in a model.
    
    Args:
        model: The model to modify
        rank: LoRA rank
        alpha: LoRA alpha
        target_modules: List of module names to apply LoRA to
                       (default: all Linear layers)
    
    Returns:
        Modified model with LoRA layers
    """
    if target_modules is None:
        target_modules = []
    
    # Freeze all parameters first
    for param in model.parameters():
        param.requires_grad = False
    
    # Replace target linear layers with LoRA versions.
    # Snapshot the module list first, since we swap submodules while looping.
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear):
            # Check if this module should get LoRA
            if not target_modules or any(t in name for t in target_modules):
                # Create LoRA layer with same dimensions
                lora_layer = LoRALayer(
                    in_features=module.in_features,
                    out_features=module.out_features,
                    rank=rank,
                    alpha=alpha,
                )
                
                # Copy pretrained weights
                lora_layer.weight.data = module.weight.data.clone()
                if module.bias is not None:
                    lora_layer.bias.data = module.bias.data.clone()
                
                # Freeze base weights
                lora_layer.weight.requires_grad = False
                lora_layer.bias.requires_grad = False
                
                # Replace in parent module
                parent_name = '.'.join(name.split('.')[:-1])
                child_name = name.split('.')[-1]
                if parent_name:
                    parent = dict(model.named_modules())[parent_name]
                else:
                    parent = model
                setattr(parent, child_name, lora_layer)
    
    return model

In [None]:
def count_parameters(model: nn.Module) -> tuple:
    """Count trainable and total parameters."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable, total

## Testing

In [None]:
# Test basic LoRA layer
print("=== Testing LoRA Layer ===")

in_features = 768
out_features = 768
rank = 8
batch_size = 4

lora_layer = LoRALayer(in_features, out_features, rank=rank, alpha=rank)

print(f"Input dim: {in_features}, Output dim: {out_features}, Rank: {rank}")
print(f"Base parameters: {lora_layer.num_base_params:,}")
print(f"LoRA parameters: {lora_layer.num_lora_params:,}")
print(f"LoRA / base parameter ratio: {lora_layer.num_lora_params / lora_layer.num_base_params * 100:.2f}%")

# Forward pass
x = torch.randn(batch_size, in_features)
y = lora_layer(x)

print(f"\nInput shape: {x.shape}")
print(f"Output shape: {y.shape}")
assert y.shape == (batch_size, out_features)
print("LoRA layer forward pass: PASSED")

In [None]:
# Test that initial output matches base model (B=0)
print("\n=== Testing Initial Equivalence ===")

# Create fresh LoRA layer
lora_layer = LoRALayer(in_features, out_features, rank=rank, alpha=rank)

# Create equivalent standard linear layer
linear = nn.Linear(in_features, out_features)
linear.weight.data = lora_layer.weight.data.clone()
linear.bias.data = lora_layer.bias.data.clone()

x = torch.randn(batch_size, in_features)

with torch.no_grad():
    lora_out = lora_layer(x)
    linear_out = linear(x)

# Since B is initialized to 0, outputs should be identical
max_diff = (lora_out - linear_out).abs().max().item()
print(f"Max difference (should be ~0): {max_diff:.2e}")
assert max_diff < 1e-6, "Initial LoRA output should match linear!"
print("Initial equivalence: PASSED")

In [None]:
# Test merge/unmerge
print("\n=== Testing Merge/Unmerge ===")

lora_layer = LoRALayer(in_features, out_features, rank=rank, alpha=rank)

# Simulate some training by modifying LoRA weights
with torch.no_grad():
    lora_layer.lora_B.data = torch.randn_like(lora_layer.lora_B) * 0.01

x = torch.randn(batch_size, in_features)

# Get output before merge
with torch.no_grad():
    out_before = lora_layer(x)

# Merge
lora_layer.merge()
print(f"Merged: {lora_layer.merged}")

# Get output after merge
with torch.no_grad():
    out_merged = lora_layer(x)

# Outputs should be the same
max_diff = (out_before - out_merged).abs().max().item()
print(f"Max diff (before vs merged): {max_diff:.2e}")
assert max_diff < 1e-5, "Merged output should match!"

# Unmerge
lora_layer.unmerge()
print(f"Merged after unmerge: {lora_layer.merged}")

# Output should still match
with torch.no_grad():
    out_unmerged = lora_layer(x)

max_diff = (out_before - out_unmerged).abs().max().item()
print(f"Max diff (before vs unmerged): {max_diff:.2e}")
assert max_diff < 1e-5, "Unmerged output should match original!"
print("Merge/Unmerge: PASSED")

In [None]:
# Test applying LoRA to a model
print("\n=== Testing LoRA on a Model ===")

# Create a simple transformer-like block
class SimpleTransformerBlock(nn.Module):
    def __init__(self, d_model=512):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )
    
    def forward(self, x):
        # Simplified attention (just for testing)
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        attn = F.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)), dim=-1)
        x = x + self.out_proj(attn @ v)
        x = x + self.mlp(x)
        return x

model = SimpleTransformerBlock(d_model=512)

# Count parameters before LoRA
trainable_before, total_before = count_parameters(model)
print(f"Before LoRA: {trainable_before:,} trainable / {total_before:,} total")

# Apply LoRA to q, k, v projections only
model = apply_lora_to_model(model, rank=8, alpha=8, target_modules=['q_proj', 'k_proj', 'v_proj'])

# Count parameters after LoRA
trainable_after, total_after = count_parameters(model)
print(f"After LoRA:  {trainable_after:,} trainable / {total_after:,} total")
print(f"Trainable parameter reduction: {(1 - trainable_after/trainable_before)*100:.1f}%")

# Verify forward pass works
x = torch.randn(2, 10, 512)  # batch=2, seq=10, dim=512
with torch.no_grad():
    y = model(x)
print(f"\nForward pass shape: {x.shape} -> {y.shape}")
assert y.shape == x.shape
print("LoRA on model: PASSED")

In [None]:
# Demonstrate training with LoRA
print("\n=== Training Demonstration ===")

# Create model with LoRA
model = SimpleTransformerBlock(d_model=256)
model = apply_lora_to_model(model, rank=4, alpha=4)

# Only optimize LoRA parameters
lora_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(lora_params, lr=1e-3)

print(f"Optimizing {len(lora_params)} LoRA parameter tensors")
print(f"Total LoRA params: {sum(p.numel() for p in lora_params):,}")

# Simple training loop
x = torch.randn(4, 8, 256)
target = torch.randn(4, 8, 256)

losses = []
for step in range(100):
    optimizer.zero_grad()
    output = model(x)
    loss = F.mse_loss(output, target)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"\nInitial loss: {losses[0]:.4f}")
print(f"Final loss: {losses[-1]:.4f}")
assert losses[-1] < losses[0], "Loss should decrease!"
print("Training demonstration: PASSED")

In [None]:
# Show parameter efficiency at scale
print("\n=== Parameter Efficiency at Scale ===")

# Simulate different model sizes
model_sizes = [
    ("GPT-2 Small", 768, 12),
    ("GPT-2 Medium", 1024, 24),
    ("GPT-2 Large", 1280, 36),
    ("LLaMA-7B", 4096, 32),
    ("LLaMA-13B", 5120, 40),
]

rank = 8

print(f"{'Model':<15} {'Full Params':>15} {'LoRA Params':>15} {'Ratio':>10}")
print("-" * 60)

for name, d_model, n_layers in model_sizes:
    # 4 projection matrices per layer (q, k, v, o)
    full_params = 4 * d_model * d_model * n_layers
    lora_params = 4 * rank * (d_model + d_model) * n_layers
    ratio = lora_params / full_params * 100
    
    print(f"{name:<15} {full_params:>15,} {lora_params:>15,} {ratio:>9.2f}%")

print("\nAt rank 8, LoRA trains roughly 0.3-2% as many parameters as the attention projections it adapts.")

In [None]:
print("\n" + "=" * 50)
print("All LoRA tests passed!")
print("=" * 50)

## Summary

### Key Concepts

1. **Low-Rank Decomposition**: $\Delta W = BA$ where $r \ll \min(d, k)$
   - Reduces trainable parameters by orders of magnitude
   - Based on observation that updates have low intrinsic rank

2. **Initialization**:
   - $A \sim \mathcal{N}(0, \sigma^2)$
   - $B = 0$ (so initial $\Delta W = 0$)
   - This ensures model starts identical to pretrained

3. **Scaling Factor**: $\frac{\alpha}{r}$
   - Normalizes contribution regardless of rank
   - Typically $\alpha = r$ (scaling = 1)

4. **Merging for Inference**:
   - $W' = W + \frac{\alpha}{r}BA$
   - No additional latency at inference time
   - Can swap adapters by unmerging/remerging
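
Adapter swapping by unmerging and remerging can be sketched with plain tensors. The adapter names `task_a` and `task_b` and all sizes are hypothetical; a real setup would load trained $(B, A)$ pairs instead of random ones.

```python
import torch

torch.manual_seed(0)

d, k, r = 8, 8, 2
scaling = 1.0  # alpha / r with alpha = r

W = torch.randn(d, k, dtype=torch.float64)  # stands in for a frozen pretrained weight

# Two hypothetical task adapters, each stored as a (B, A) pair
adapters = {
    name: (
        torch.randn(d, r, dtype=torch.float64),
        torch.randn(r, k, dtype=torch.float64),
    )
    for name in ("task_a", "task_b")
}


def delta(name: str) -> torch.Tensor:
    """Dense update (alpha / r) * B A for the named adapter."""
    B, A = adapters[name]
    return scaling * (B @ A)


W_a = W + delta("task_a")           # merge task_a for inference
W_restored = W_a - delta("task_a")  # unmerge to get the base weight back
W_b = W_restored + delta("task_b")  # merge task_b instead

restored_ok = torch.allclose(W_restored, W)
swapped_ok = torch.allclose(W_b, W + delta("task_b"))
print("base weight restored:", restored_ok)
print("task_b merged correctly:", swapped_ok)
```

Subtracting the update recovers $W$ only up to floating-point round-off, so keeping a clean copy of the base weights is the safer choice when exact restoration matters.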

### Common Hyperparameters

| Parameter | Typical Values | Notes |
|-----------|---------------|-------|
| rank | 4-64 | Higher = more capacity |
| alpha | = rank | Keeps scaling at 1 |
| target_modules | q, k, v, o | Attention projections |
| dropout | 0-0.1 | Regularization |

### When to Use LoRA

- Fine-tuning on limited compute
- Multiple task-specific adapters
- When full fine-tuning is too expensive
- Collaborative fine-tuning (share adapters)

## Interview Tips

1. **Why does LoRA work?** The updates learned during fine-tuning are hypothesized (and empirically observed) to have low intrinsic rank, so a rank-$r$ product can approximate them well

2. **Why initialize B=0?** So initial output equals pretrained model (no disruption)

3. **What is the rank?** Bottleneck dimension; controls capacity vs efficiency tradeoff

4. **Where to apply LoRA?** Typically attention projections (Q, K, V, O); sometimes MLP

5. **Inference overhead?** None after merging! $W' = W + \frac{\alpha}{r}BA$ is precomputed

6. **vs Full fine-tuning?** Often comparable quality, sometimes slightly lower, while training orders of magnitude fewer parameters

7. **QLoRA?** Quantized base model + LoRA adapters for even more efficiency

## References

1. [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) - Hu et al., 2021
2. [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) - Dettmers et al., 2023
3. [PEFT Library](https://github.com/huggingface/peft) - HuggingFace implementation
4. [The Practical Guides for Large Language Models](https://github.com/Mooler0410/LLMsPracticalGuide)