# Why LoRA is Better for On-Device LLMs

*A companion to the [KV Cache Walkthrough](kv_cache_always_on.ipynb) — connecting what we built in Chapters 3-4 to real-world on-device deployment.*

In the KV cache notebook, we learned how to make **inference fast** by avoiding redundant computation. But there's another efficiency challenge: how do you **customize** a model for a specific task (like your personal writing style, a medical domain, or a specific language) without retraining all 124 million parameters?

This is where **LoRA (Low-Rank Adaptation)** comes in — and it's especially important for on-device deployment (phones, laptops, edge hardware).

> **How to use this notebook:** Run the cells top-to-bottom (`Shift+Enter`). We'll build a working LoRA implementation from scratch using the same GPT model from the KV cache notebook, then compare parameter counts and see LoRA in action.

## The Problem: Full Fine-Tuning is Expensive

When you fine-tune a model the traditional way, you update **every parameter** in every weight matrix:

```
GPT-2 124M has these weight matrices (among others):
  W_query:  (768 × 768) = 589,824 parameters    ← update ALL of these
  W_key:    (768 × 768) = 589,824 parameters    ← update ALL of these
  W_value:  (768 × 768) = 589,824 parameters    ← update ALL of these
  FeedForward layer 1: (768 × 3072) = 2,359,296 ← update ALL of these
  ...
  Total: ~124 million parameters to store, update, and deploy

For a 7B model: 7 BILLION parameters × 4 bytes each = 28 GB just for the weights!
```

For on-device deployment, this is a non-starter. A single 28 GB float32 model already exceeds the 6-8 GB of RAM on a typical phone, and storing a separate full copy for every task multiplies the storage cost as well.
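
The arithmetic behind these numbers is easy to reproduce. The sketch below is a rough estimate only: `weight_memory_gb` is a hypothetical helper, it uses decimal gigabytes, it ignores activations, and the 16 bytes per parameter for training assumes plain float32 Adam (weights, gradients, and two moment buffers).

```python
def weight_memory_gb(num_params, bytes_per_param=4.0):
    """Memory for parameters alone, in decimal GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

# Inference: float32 weights only (4 bytes per parameter)
gpt2_inference_gb = weight_memory_gb(124_000_000)
seven_b_inference_gb = weight_memory_gb(7_000_000_000)

# Training (assumption: float32 Adam): weights + gradients + 2 moment buffers
seven_b_training_gb = weight_memory_gb(7_000_000_000, bytes_per_param=16.0)

print(f"GPT-2 124M weights: {gpt2_inference_gb:.3f} GB")
print(f"7B weights:         {seven_b_inference_gb:.1f} GB")
print(f"7B full fine-tune:  {seven_b_training_gb:.1f} GB (before activations)")
```

Under these assumptions, full fine-tuning needs roughly four times the memory of inference, which is why it is out of reach on-device even for models that can run there.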

## LoRA's Key Insight: Weight Changes Are Low-Rank

When you fine-tune a model, the **change** to each weight matrix (ΔW) can often be approximated well by a low-rank matrix; this is the central hypothesis of the LoRA paper. It means you can represent ΔW as the product of two much smaller matrices:

```
Full fine-tuning:
  W_new = W_original + ΔW                    ← ΔW is (768 × 768) = 589,824 values

LoRA:
  W_new = W_original + (A × B)               ← A is (768 × r), B is (r × 768)
                                                 where r = "rank" (typically 4-16)

With rank r = 8:
  A: (768 × 8)  =  6,144 parameters
  B: (8 × 768)  =  6,144 parameters
  Total:           12,288 parameters          ← that's ~2% of the original 589,824!
```

### Why is ΔW low-rank?

Think of it this way: when you fine-tune a medical chatbot from a general-purpose model, you're not changing *everything* the model knows. You're mostly adjusting a specific *direction* in the weight space — "pay more attention to medical terminology" or "format answers like a doctor would." These task-specific adjustments tend to live in a low-dimensional subspace of the full weight matrix, which is exactly what the low-rank A × B decomposition captures.
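
To make "low-rank" concrete, here is a small sketch using a synthetic ΔW. The matrix is built to have rank 8 exactly, which is an assumption for illustration; updates from real fine-tuning are only approximately low-rank. The SVD shows that all the information sits in 8 singular values, and that two thin factors reproduce the full matrix.

```python
import torch

torch.manual_seed(0)
d, r = 768, 8

# Synthetic stand-in for a fine-tuning update: rank 8 by construction
delta_w = torch.randn(d, r) @ torch.randn(r, d)

# Singular values past the 8th are numerically zero
U, S, Vh = torch.linalg.svd(delta_w)
print("top singular values:", S[:r])
print("9th singular value: ", S[r].item())

# Keep the top-r components and split them into two thin factors
A = U[:, :r] * S[:r]   # (768, 8)
B = Vh[:r, :]          # (8, 768)

rel_err = ((delta_w - A @ B).norm() / delta_w.norm()).item()
print(f"relative reconstruction error: {rel_err:.2e}")
print(f"values stored: {A.numel() + B.numel():,} instead of {delta_w.numel():,}")
```

LoRA never computes this SVD. It learns A and B directly by gradient descent, but the factorization shows why two thin matrices can be enough.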

## Visual: How LoRA Modifies a Linear Layer

```
                           Standard Fine-Tuning          LoRA
                           ────────────────────          ────

Input x ──→ [ W_original + ΔW ] ──→ output    Input x ──┬──→ [ W_original ] ──→  +  ──→ output
             (768 × 768)                                  │     (768 × 768)        ↑
             589,824 new params                           │     (FROZEN!)          │
                                                          │                        │
                                                          └──→ [ A ]──→[ B ]──────┘
                                                               (768×8)  (8×768)
                                                               12,288 new params
                                                               (TRAINABLE)
```

The original weight `W_original` stays **frozen** (unchanged). Only the tiny A and B matrices are trained. This is the magic of LoRA.
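
As a minimal sketch of what "only A and B are trained" means in practice, the toy loop below uses plain tensors rather than the GPT model. The sizes, the synthetic regression target, and the optimizer settings are all assumptions chosen so it runs instantly.

```python
import torch

torch.manual_seed(0)
d, r, scaling = 64, 4, 2.0   # toy sizes; scaling plays the role of alpha / rank

W = torch.randn(d, d)                              # frozen: no requires_grad
A = (0.01 * torch.randn(d, r)).requires_grad_()    # trainable, small random init
B = torch.zeros(r, d, requires_grad=True)          # trainable, zero init

# Hypothetical task: the target is W shifted by a rank-r update
x = torch.randn(256, d)
target = x @ (W + 0.1 * torch.randn(d, r) @ torch.randn(r, d))

opt = torch.optim.Adam([A, B], lr=1e-2)            # the optimizer never sees W
W_before = W.clone()
losses = []
for step in range(200):
    pred = x @ W + (x @ A @ B) * scaling           # frozen path + LoRA path
    loss = torch.nn.functional.mse_loss(pred, target)
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
print("W unchanged:", torch.equal(W, W_before))
print("W has a gradient:", W.grad is not None)
```

Because `W` is never handed to the optimizer and never requires gradients, the only state that changes (and the only state you need to save) is `A` and `B`.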

During inference, you can even **merge** A × B back into the original weight: `W_merged = W_original + scaling × (A × B)`, where `scaling` is the `alpha / rank` factor used in the implementation below. The merged layer is a single matrix of the original shape, so it runs at the same speed as the original, with no extra computation.

## Let's Build It

Here's a working LoRA layer that wraps any `nn.Linear`. We'll then apply it to our GPT model's attention layers.

In [None]:
import torch
import torch.nn as nn
import math

In [None]:
class LoRALinear(nn.Module):
    """A Linear layer with a LoRA adapter attached.
    
    The original weight W is frozen. Two small matrices A and B are trainable.
    Output = W(x) + B(A(x)) * scaling_factor
    
    Args:
        original_linear: The nn.Linear layer to wrap (will be frozen)
        rank:            The LoRA rank (r) — smaller = fewer params, larger = more expressive
        alpha:           Scaling factor — controls how much LoRA affects the output
                         (the actual scale applied is alpha/rank)
    """
    def __init__(self, original_linear, rank=8, alpha=16):
        super().__init__()
        
        self.original = original_linear
        in_features = original_linear.in_features    # e.g., 768
        out_features = original_linear.out_features  # e.g., 768
        
        # Freeze the original weight — we don't want gradients flowing to it
        self.original.weight.requires_grad = False
        if self.original.bias is not None:
            self.original.bias.requires_grad = False
        
        # LoRA matrices:
        # A: projects DOWN from input dimension to rank  (768 → 8)
        # B: projects UP from rank to output dimension   (8 → 768)
        # nn.Linear stores weights as (out, in), so in code the update ΔW is
        # lora_B.weight @ lora_A.weight (written A × B in the math above)
        self.lora_A = nn.Linear(in_features, rank, bias=False)   # weight: (8, 768)
        self.lora_B = nn.Linear(rank, out_features, bias=False)  # weight: (768, 8)
        
        # Initialize A with small random values (Kaiming uniform, as in the
        # reference loralib implementation; the LoRA paper describes a Gaussian init)
        # Initialize B with zeros, so at the start LoRA has NO effect (ΔW = 0)
        # This means the model starts from its pre-trained behavior
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)
        
        # Scaling factor: alpha / rank
        # Higher alpha = LoRA has more influence on the output
        # Dividing by rank keeps the scale consistent regardless of rank choice
        self.scaling = alpha / rank
    
    def forward(self, x):
        # Original path: W(x), with frozen weights (no gradients are computed for W)
        original_output = self.original(x)
        
        # LoRA path: B(A(x)) * scaling — trainable!
        # x → A → (768→8) → B → (8→768) → scale
        lora_output = self.lora_B(self.lora_A(x)) * self.scaling
        
        # Combined: original + LoRA adjustment
        return original_output + lora_output
    
    def merge(self):
        """Merge LoRA weights into the original weight for zero-overhead inference.
        
        After merging: W_new = W_original + (B × A) * scaling
        The returned layer runs at the same speed as the original. Use it in place
        of this wrapper: calling the wrapper itself after merge() would apply the
        adapter twice.
        """
        with torch.no_grad():
            # A.weight: (rank, in_features)
            # B.weight: (out_features, rank)
            # B × A:    (out_features, in_features) — same shape as W!
            delta_w = self.lora_B.weight @ self.lora_A.weight * self.scaling
            self.original.weight.add_(delta_w)
        return self.original  # Return the merged layer (no LoRA overhead)


print("LoRALinear class defined!")
print("\nKey design choices:")
print("  - B initialized to zeros → LoRA starts with no effect")
print("  - scaling = alpha/rank → consistent impact regardless of rank")
print("  - merge() → fold LoRA into weights for zero-overhead inference")

## Parameter Count: Full Fine-Tuning vs LoRA

Let's see the numbers with a real GPT-2 sized layer.

In [None]:
# Create a standard Linear layer (like W_query in our GPT model)
d_model = 768  # GPT-2 embedding dimension
original_layer = nn.Linear(d_model, d_model, bias=False)

full_params = sum(p.numel() for p in original_layer.parameters())
print(f"Original W_query layer: {full_params:,} parameters")
print(f"  Shape: ({d_model} × {d_model})")
print()

# Now wrap it with LoRA at different ranks
for rank in [1, 4, 8, 16, 32, 64]:
    lora_layer = LoRALinear(nn.Linear(d_model, d_model, bias=False), rank=rank)
    
    # Count only the TRAINABLE parameters (A and B matrices)
    trainable = sum(p.numel() for p in lora_layer.parameters() if p.requires_grad)
    total = sum(p.numel() for p in lora_layer.parameters())
    pct = trainable / total * 100
    
    print(f"  LoRA rank {rank:2d}: {trainable:>8,} trainable params "
          f"({pct:5.2f}% of total)  "
          f"[A: ({d_model}×{rank}), B: ({rank}×{d_model})]")

## Applying LoRA to a GPT Model's Attention Layers

In practice, LoRA is typically applied to the attention projection matrices (`W_query`, `W_key`, `W_value`, and sometimes `out_proj`). Let's see how many trainable parameters we'd have versus full fine-tuning.

In [None]:
def count_params(model):
    """Count total and trainable parameters in a model."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


def apply_lora_to_attention(model, rank=8, alpha=16):
    """Replace W_query and W_value in every attention layer with LoRA-wrapped versions.
    
    Following the original LoRA paper, we apply LoRA to W_query and W_value.
    (Some implementations also include W_key and out_proj for more expressiveness.)
    """
    lora_count = 0
    # Snapshot the module list first, because we replace submodules inside the loop
    for name, module in list(model.named_modules()):
        # Find MultiHeadAttention modules
        if hasattr(module, 'W_query') and hasattr(module, 'W_value'):
            # Wrap W_query with LoRA
            module.W_query = LoRALinear(module.W_query, rank=rank, alpha=alpha)
            # Wrap W_value with LoRA  
            module.W_value = LoRALinear(module.W_value, rank=rank, alpha=alpha)
            lora_count += 2
    
    # Freeze ALL parameters first
    for param in model.parameters():
        param.requires_grad = False
    
    # Then unfreeze ONLY LoRA parameters
    for name, param in model.named_parameters():
        if 'lora_' in name:
            param.requires_grad = True
    
    return lora_count


print("Helper functions defined.")
print("apply_lora_to_attention() wraps W_query and W_value in every attention layer.")

In [None]:
# Build a GPT-2 124M-style model (same config as the KV cache notebook).
# The model is defined inline so this notebook stays self-contained.
# Note: it has no forward() and no weight tying between tok_emb and out_head;
# it exists only for parameter counting, so the total prints as ~163M, not 124M.
class SimpleMultiHeadAttention(nn.Module):
    """Simplified MHA without KV cache — just for demonstrating LoRA param counts."""
    def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)

class SimpleGPT(nn.Module):
    """Minimal GPT structure for demonstrating LoRA parameter savings."""
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg['vocab_size'], cfg['emb_dim'])
        self.pos_emb = nn.Embedding(cfg['context_length'], cfg['emb_dim'])
        
        # Build transformer blocks
        self.blocks = nn.ModuleList()
        for _ in range(cfg['n_layers']):
            block = nn.Module()
            block.att = SimpleMultiHeadAttention(
                d_in=cfg['emb_dim'], d_out=cfg['emb_dim'],
                num_heads=cfg['n_heads'], qkv_bias=cfg.get('qkv_bias', False)
            )
            block.ff = nn.Sequential(
                nn.Linear(cfg['emb_dim'], 4 * cfg['emb_dim']),
                nn.GELU(),
                nn.Linear(4 * cfg['emb_dim'], cfg['emb_dim'])
            )
            block.norm1 = nn.LayerNorm(cfg['emb_dim'])
            block.norm2 = nn.LayerNorm(cfg['emb_dim'])
            self.blocks.append(block)
        
        self.final_norm = nn.LayerNorm(cfg['emb_dim'])
        self.out_head = nn.Linear(cfg['emb_dim'], cfg['vocab_size'], bias=False)


# GPT-2 124M config
GPT_CONFIG = {
    'vocab_size': 50257,
    'context_length': 1024,
    'emb_dim': 768,
    'n_heads': 12,
    'n_layers': 12,
    'qkv_bias': False,
}

torch.manual_seed(42)
model = SimpleGPT(GPT_CONFIG)

total_before, trainable_before = count_params(model)
print(f"=== Before LoRA (Full Fine-Tuning) ===")
print(f"  Total parameters:     {total_before:>12,}")
print(f"  Trainable parameters: {trainable_before:>12,}  (100%)")
print()

# Apply LoRA
num_lora = apply_lora_to_attention(model, rank=8, alpha=16)

total_after, trainable_after = count_params(model)
print(f"=== After LoRA (rank=8) ===")
print(f"  Total parameters:     {total_after:>12,}")
print(f"  Trainable parameters: {trainable_after:>12,}  ({trainable_after/total_after*100:.2f}%)")
print(f"  Frozen parameters:    {total_after - trainable_after:>12,}")
print(f"  LoRA layers added:    {num_lora} (W_query + W_value × {num_lora//2} blocks)")
print()
print(f"  Reduction: {trainable_before:,} → {trainable_after:,} trainable params")
print(f"  That's {trainable_after/trainable_before*100:.2f}% of the original — "
      f"a {trainable_before/trainable_after:.0f}x reduction!")
print()

# Storage comparison
full_mb = trainable_before * 4 / (1024**2)   # 4 bytes per float32 param
lora_mb = trainable_after * 4 / (1024**2)
print(f"  Storage for fine-tuned weights:")
print(f"    Full fine-tune: {full_mb:.1f} MB")
print(f"    LoRA adapter:   {lora_mb:.2f} MB")
print(f"    Savings:        {full_mb - lora_mb:.1f} MB per task")

## LoRA in Action: Before and After Merge

Let's verify that:
1. LoRA starts with zero effect (B is initialized to zeros)
2. After simulating some training, the output changes
3. After merging, the output is identical but with no LoRA overhead

In [None]:
# Demo with a single layer
torch.manual_seed(42)

# Create an original linear layer
original = nn.Linear(768, 768, bias=False)

# Create a test input
x = torch.randn(1, 4, 768)  # batch=1, seq_len=4, dim=768

# Output BEFORE LoRA
with torch.no_grad():
    out_original = original(x)

# Wrap with LoRA
lora_layer = LoRALinear(original, rank=8, alpha=16)

# Output immediately after wrapping (B=0, so LoRA has NO effect)
with torch.no_grad():
    out_with_lora_init = lora_layer(x)

# Verify: output should be IDENTICAL (LoRA starts at zero)
diff_init = (out_original - out_with_lora_init).abs().max().item()
print(f"Difference after LoRA init (should be ~0): {diff_init:.10f}")
print(f"  → LoRA has NO effect at initialization (B=0) ✓")
print()

# Simulate training: modify LoRA weights as if we trained them
with torch.no_grad():
    lora_layer.lora_A.weight.normal_(0, 0.01)
    lora_layer.lora_B.weight.normal_(0, 0.01)

# Output after "training" — should be DIFFERENT from original
with torch.no_grad():
    out_with_lora_trained = lora_layer(x)

diff_trained = (out_original - out_with_lora_trained).abs().max().item()
print(f"Difference after LoRA training: {diff_trained:.6f}")
print(f"  → LoRA is now modifying the output ✓")
print()

# Now MERGE LoRA into the original weight
merged_layer = lora_layer.merge()

# Output from merged layer — should be IDENTICAL to LoRA output
with torch.no_grad():
    out_merged = merged_layer(x)

diff_merged = (out_with_lora_trained - out_merged).abs().max().item()
print(f"Difference after merge (should be ~0): {diff_merged:.10f}")
print(f"  → Merged output matches LoRA output exactly ✓")
print()
print("After merging:")
print(f"  - The original layer now includes the LoRA adjustment")
print(f"  - No extra A, B matrices needed at inference time")
print(f"  - Zero additional computation or memory overhead")

## Why This Matters for On-Device Deployment

On a phone or laptop, you have severe constraints: limited RAM, limited storage, limited compute, and often no GPU. LoRA mainly eases the storage, task-switching, and training costs; inference RAM and compute stay those of the base model:

```
                               Full Fine-Tune     LoRA (rank 8)
                               ──────────────     ─────────────
Storage per task (7B model):   28 GB              ~18 MB           ← 1500x smaller!
RAM during inference:          28 GB              28 GB base       ← same base model
                                                   + 18 MB adapter   shared across tasks
Can swap tasks instantly?      No (reload model)  Yes (swap adapter)
Can run multiple tasks?        No (one model)     Yes (one model + many adapters)
Training data needed:          Typically more     Often less
Training time:                 Hours/days         Often minutes/hours
```
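
The adapter size in the table depends on architecture details the table does not spell out. As a rough check, the sketch below assumes a 7B-class model with 32 transformer blocks, a hidden size of 4096, square query and value projections, rank 8, and float32 storage. All of these values are assumptions for illustration.

```python
# Assumed 7B-class architecture (not specified in the table above)
n_layers = 32            # transformer blocks
d_model = 4096           # hidden size; W_query and W_value assumed (4096 x 4096)
rank = 8
adapted_per_layer = 2    # W_query and W_value
bytes_per_param = 4      # float32

# One adapted matrix adds A (d_model x rank) plus B (rank x d_model)
params_per_adapter = rank * (d_model + d_model)
lora_params = n_layers * adapted_per_layer * params_per_adapter
lora_mb = lora_params * bytes_per_param / 1e6

full_gb = 7_000_000_000 * bytes_per_param / 1e9
ratio = full_gb * 1e3 / lora_mb

print(f"LoRA parameters: {lora_params:,}")
print(f"Adapter size:    {lora_mb:.1f} MB")
print(f"Full model:      {full_gb:.1f} GB  (about {ratio:,.0f}x larger)")
```

Under these assumptions the adapter comes to roughly 17 MB, the same ballpark as the table's figure. Adapting more matrices or raising the rank grows it proportionally.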

## The Multi-Task Advantage

Here's where LoRA really shines on-device. Instead of storing separate full models for each task, you store **one base model** and multiple tiny LoRA adapters:

```
Traditional approach (one model per task):
  ┌─────────────────────┐
  │ Medical Model  28 GB │
  │ Legal Model    28 GB │
  │ Coding Model   28 GB │
  │ Chat Model     28 GB │
  └─────────────────────┘
  Total: 112 GB            ← impossible on a phone!

LoRA approach (one base + adapters):
  ┌─────────────────────┐
  │ Base Model     28 GB │ ← loaded once, shared
  │ Medical LoRA   18 MB │ ← swap in/out instantly
  │ Legal LoRA     18 MB │
  │ Coding LoRA    18 MB │
  │ Chat LoRA      18 MB │
  └─────────────────────┘
  Total: ~28.07 GB         ← barely more than ONE model (still needs quantization to fit a phone)
```
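
Storing an adapter separately comes down to filtering the state dict by parameter name. This is a minimal sketch: `TinyLoRALinear` is a hypothetical stand-in that only mirrors the `original` / `lora_A` / `lora_B` naming used by `LoRALinear`, and an in-memory buffer replaces a real file.

```python
import io
import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    """Hypothetical minimal stand-in with the same parameter names as LoRALinear."""
    def __init__(self, d=64, rank=4):
        super().__init__()
        self.original = nn.Linear(d, d, bias=False)
        self.lora_A = nn.Linear(d, rank, bias=False)
        self.lora_B = nn.Linear(rank, d, bias=False)

torch.manual_seed(0)
trained = TinyLoRALinear()

# Keep only the adapter tensors; the base weight ships once, separately
adapter_state = {k: v for k, v in trained.state_dict().items() if "lora_" in k}

buffer = io.BytesIO()                 # in-memory file, nothing touches disk
torch.save(adapter_state, buffer)
adapter_bytes = buffer.getbuffer().nbytes

# Load the adapter into another model instance.
# strict=False because the base weight is intentionally absent from the file.
device_model = TinyLoRALinear()
buffer.seek(0)
result = device_model.load_state_dict(torch.load(buffer), strict=False)

print("saved keys:  ", sorted(adapter_state))
print("adapter size:", adapter_bytes, "bytes")
print("not in file: ", result.missing_keys)
```

The `missing_keys` list is a useful sanity check: it should name only base-model weights, never a `lora_` parameter.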

Swapping between tasks is nearly instant — you just load a different set of A and B matrices and add them to the frozen base weights. No reloading the entire model.
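
One way to picture an adapter swap on merged weights is: subtract the old update, then add the new one. The sketch below uses plain tensors and two made-up adapters (`"medical"` and `"legal"` are placeholder names), following the `W + scaling × (A × B)` convention from earlier.

```python
import torch

torch.manual_seed(0)
d, r, scaling = 64, 4, 2.0

W_base = torch.randn(d, d)     # shared frozen base weight
adapters = {                   # hypothetical per-task (A, B) pairs
    "medical": (0.1 * torch.randn(d, r), 0.1 * torch.randn(r, d)),
    "legal":   (0.1 * torch.randn(d, r), 0.1 * torch.randn(r, d)),
}

def delta(name):
    """Full-size weight update for one adapter: scaling * (A @ B)."""
    A, B = adapters[name]
    return (A @ B) * scaling

W = W_base.clone()
W += delta("medical")          # activate the medical adapter
W -= delta("medical")          # unmerge it again
W += delta("legal")            # activate the legal adapter

expected = W_base + delta("legal")
max_diff = (W - expected).abs().max().item()
print(f"max difference vs. a fresh merge: {max_diff:.2e}")
```

Repeated merge and unmerge cycles accumulate floating-point rounding, so a long-running app might prefer to keep the base weights untouched and run the adapter as a side path, as `LoRALinear.forward` does.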

## How LoRA Connects to the KV Cache

Look at the `W_query`, `W_key`, and `W_value` matrices in the `MultiHeadAttention` class from our [KV cache notebook](kv_cache_always_on.ipynb):

```python
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)  # (768, 768) = 589,824 params
self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)  # (768, 768) = 589,824 params
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)  # (768, 768) = 589,824 params
```

With LoRA applied to W_query, the forward pass changes from:
```python
queries = self.W_query(x)                                    # original
```
to:
```python
queries = self.W_query(x) + self.lora_B(self.lora_A(x)) * self.scaling   # with LoRA
#          ↑ frozen original    ↑ trainable LoRA correction (scaling = alpha / rank)
```

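The KV cache works **exactly the same** with LoRA: the keys and values that go into the cache are just slightly adjusted by the adapter (queries are adjusted too, but they are never cached). Everything else we learned (prefill, decode, causal masking, sliding window) is unchanged. One caveat: a cache filled under one adapter is only valid for that adapter, so switching adapters mid-generation means recomputing the cache.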

And remember: after training, you can **merge** the LoRA weights back in, so at inference time the model is identical in structure to the original — the KV cache never even knows LoRA was involved.
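
A small check of this claim, sketched with plain tensors rather than the notebook's `MultiHeadAttention` class: values cached token by token through a LoRA-adjusted projection match a full recomputation, and also match a merged weight. The sizes and the random adapter are assumptions for illustration.

```python
import torch

torch.manual_seed(0)
d, r, scaling = 64, 4, 2.0

W_value = torch.randn(d, d) / d**0.5     # frozen value projection
A = 0.1 * torch.randn(d, r)              # stand-in for a trained adapter
B = 0.1 * torch.randn(r, d)

def values_with_lora(x):
    # Frozen path plus adapter path, applied to each token independently
    return x @ W_value + (x @ A @ B) * scaling

tokens = torch.randn(5, d)               # 5 toy token embeddings

# Prefill: cache the values for the first 4 tokens
v_cache = values_with_lora(tokens[:4])
# Decode: project only the new token and append it to the cache
v_cache = torch.cat([v_cache, values_with_lora(tokens[4:5])], dim=0)

# Reference 1: recompute all 5 values from scratch
v_full = values_with_lora(tokens)
# Reference 2: merged weight, no adapter path at all
v_merged = tokens @ (W_value + (A @ B) * scaling)

print("cache matches full recompute:", torch.allclose(v_cache, v_full, atol=1e-4))
print("cache matches merged weights:", torch.allclose(v_cache, v_merged, atol=1e-4))
```

The projection acts on each token separately, so caching is unaffected by whether the adapter runs as a side path or has been merged.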

## The Complete On-Device Stack: LoRA + KV Cache + Quantization

In production on-device LLMs (like Apple Intelligence, Google Gemini Nano, Samsung Galaxy AI), you'll typically see all three optimizations working together:

```
┌──────────────────────────────────────────────────────────────┐
│                    On-Device LLM Stack                       │
│                                                              │
│  1. QUANTIZATION (Ch 6+)                                     │
│     Shrink model from 28 GB → 3-4 GB                         │
│     by using 4-bit integers instead of 32-bit floats         │
│     (each parameter: 4 bytes → 0.5 bytes)                    │
│                                                              │
│  2. LoRA ADAPTERS (this notebook!)                           │
│     Customize for tasks without modifying the base model     │
│     18 MB per task instead of 3-4 GB                         │
│     Swap adapters instantly for different use cases           │
│                                                              │
│  3. KV CACHE (the KV cache notebook!)                        │
│     Generate tokens fast without recomputing old K,V         │
│     Sliding window to fit in limited device RAM              │
│     (our window_size parameter controls this)                │
│                                                              │
│  Result: A 7B-parameter model running on your phone,         │
│  customizable for different tasks, generating text           │
│  at interactive speeds — all in ~4 GB of RAM                 │
└──────────────────────────────────────────────────────────────┘
```
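
To give a feel for the quantization step, here is a sketch of about the simplest possible scheme: symmetric absmax rounding to 4-bit integers with one scale for the whole matrix. This is an illustration under stated assumptions, not how any particular product does it. Production schemes typically use per-group scales and other refinements, and the values here sit in an int8 container rather than being packed two per byte.

```python
import torch

torch.manual_seed(0)
w = 0.02 * torch.randn(768, 768)          # stand-in float32 weight matrix

# Symmetric absmax quantization to signed 4-bit integers in [-8, 7]
scale = w.abs().max() / 7
q = torch.clamp(torch.round(w / scale), -8, 7).to(torch.int8)

# Dequantize to see how much precision was lost
w_hat = q.float() * scale
max_err = (w - w_hat).abs().max().item()

fp32_bytes = w.numel() * 4
int4_bytes = w.numel() // 2               # two 4-bit values per byte once packed

print(f"scale: {scale.item():.5f}, max abs error: {max_err:.5f}")
print(f"storage: {fp32_bytes:,} bytes -> {int4_bytes:,} bytes")
```

The rounding error per weight is at most half the scale, and storage drops by 8x, which is the 4 bytes → 0.5 bytes step in the diagram above.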

### Key Takeaway

LoRA is better for on-device because it solves the fundamental tension between **model quality** (bigger models are better) and **device constraints** (phones have limited RAM/storage). Instead of making the model smaller (which hurts quality), LoRA lets you keep the full model and make task-specific adjustments through tiny, swappable adapters. Combined with the KV cache we implemented in the companion notebook and quantization (covered later in the book), this forms the complete stack that makes modern on-device AI possible.