#### LoRA
When fine-tuning large models, we need to update weight matrices W ∈ ℝ^(d×k). Full fine-tuning updates all parameters, which is:
- Memory intensive: Need to store gradients and optimizer states for all parameters
- Storage expensive: Each task needs a full copy of the model

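Full fine-tuning learns an unconstrained update: W_new = W_old + ΔW
- LoRA freezes W and decomposes ΔW into two low-rank matrices: W_new = W + ΔW = W + BA
    - A ∈ ℝ^(r×k): Low-rank "down-projection" (applied first, maps k dimensions to r)
    - B ∈ ℝ^(d×r): Low-rank "up-projection" (maps r dimensions back to d)
    - The rank r is much smaller than d and k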

Why This Works:
The number of trainable parameters goes from d×k to r×(d+k):
- Example: For a 4096×4096 weight matrix:

    - Full fine-tuning: 16,777,216 parameters
    - LoRA with r=8: 65,536 parameters (~0.4% of original!)
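
The arithmetic above can be checked in a few lines. This is a minimal sketch; the helper name `lora_param_counts` is made up for illustration.

```python
def lora_param_counts(d: int, k: int, r: int) -> tuple[int, int]:
    """Return (full fine-tuning params, LoRA params) for one d x k weight matrix."""
    full = d * k
    lora = r * (d + k)  # B is d x r, A is r x k
    return full, lora


full, lora = lora_param_counts(4096, 4096, 8)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.2%}")
# full: 16,777,216  lora: 65,536  ratio: 0.39%
```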

Key Mathematical Properties
- Forward pass: h = W₀x + s·BAx
    - s is a scaling factor (often α/r where α is a hyperparameter)
- Merging: After training, the scaled update can be merged into W₀, so inference has no extra cost
    - W_merged = W₀ + s·BA
- Task switching: Keep W₀ frozen, swap different (B,A) pairs for different tasks
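
To make the merging property concrete, the sketch below compares the unmerged forward pass with a single matmul against the merged weight. The sizes and the value of α are arbitrary, and B is filled with random values only to stand in for a trained adapter.

```python
import torch

torch.manual_seed(0)
d, k, r, alpha = 6, 4, 2, 4.0
s = alpha / r

W0 = torch.randn(d, k)  # frozen pre-trained weight
B = torch.randn(d, r)   # non-zero here, as if training had already happened
A = torch.randn(r, k)
x = torch.randn(k)

# Unmerged: base path plus the low-rank path
h_adapter = W0 @ x + s * (B @ (A @ x))

# Merged: fold the scaled update into the weight once, then use a single matmul
W_merged = W0 + s * (B @ A)
h_merged = W_merged @ x

print(torch.allclose(h_adapter, h_merged, atol=1e-5))
```

Task switching follows the same pattern: W0 stays fixed and only the small (B, A) pair is replaced.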

##### Practical Implementation Details
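Which Layers to Apply LoRA?
A transformer block contains several weight matrices. The original LoRA paper adapts only the attention projections and reports that adapting Wq and Wv together works best for a fixed parameter budget.
Typical choices (from fewest to most trainable parameters):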

- Query & Value projections (Wq, Wv): Most common, best bang for buck
- Query, Key, Value, Output (Wq, Wk, Wv, Wo): More expressive
- All linear layers: Including FFN layers (Wup, Wdown, Wgate in LLaMA)

Why Q and V are often preferred (a common intuition, not a proven rule):
- Q controls what information to attend to
- V controls what information to extract
- K is often less critical (useful key patterns are largely learned during pre-training)
- Trade-off: More adapted layers = more parameters but usually better adaptation
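
The cost side of this trade-off can be sketched by counting how many LoRA parameters each choice of target layers would add. `ToyAttention`, `lora_params_for`, and the layer names `q_proj`, `k_proj`, `v_proj`, `o_proj` are illustrative assumptions, not something the notes above define.

```python
import torch.nn as nn


class ToyAttention(nn.Module):
    """Stand-in for one attention block; the layer names are illustrative."""

    def __init__(self, dim: int):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)


def lora_params_for(model: nn.Module, targets: set[str], r: int) -> int:
    """Count LoRA parameters if every targeted nn.Linear gets a rank-r adapter."""
    total = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and name.split(".")[-1] in targets:
            total += r * (module.in_features + module.out_features)
    return total


block = ToyAttention(dim=512)
qv = lora_params_for(block, {"q_proj", "v_proj"}, r=8)
qkvo = lora_params_for(block, {"q_proj", "k_proj", "v_proj", "o_proj"}, r=8)
print(qv, qkvo)  # 16384 32768
```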

##### Initialization Strategies
Critical for training stability!
- Matrix A (down-projection):
    - Initialized randomly: Usually Gaussian N(0, σ²) or Kaiming/Xavier
    - Provides the "diversity" in the low-rank space, so B receives a non-zero gradient from the first step

- Matrix B (up-projection):
    - Initialized to zero: BA = 0 at start
    - Ensures the adapted layer starts identical to the pre-trained one: W₀ + BA = W₀
    - Model begins exactly as pre-trained, then gradually adapts
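
A small check of this initialization, written with plain tensors under arbitrary toy sizes. It shows that the output matches the base layer at step zero and that the first backward pass updates B but not yet A, which suggests why only one of the two matrices starts at zero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d, k, r = 5, 3, 2

W0 = torch.randn(d, k)                      # frozen pre-trained weight
A = nn.Parameter(torch.randn(r, k) * 0.01)  # random init
B = nn.Parameter(torch.zeros(d, r))         # zero init

x = torch.randn(k)
h = W0 @ x + B @ (A @ x)

# With B = 0 the LoRA path contributes nothing, so the output equals the base output.
same_as_base = torch.equal(h, W0 @ x)

# dL/dB depends on A @ x (non-zero); dL/dA is multiplied by B (zero).
h.sum().backward()
b_gets_gradient = bool(B.grad.abs().sum() > 0)
a_gets_gradient = bool(A.grad.abs().sum() > 0)
print(same_as_base, b_gets_gradient, a_gets_gradient)
```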

##### The Scaling Factor (α/r)
Forward pass: h = W₀x + (α/r)·BAx
- Parameters:
    - α (alpha): Constant, often set to r or 2r
    - r: The rank we choose

Common strategies:
- α = r: Scale is 1, straightforward
- α = 2r: Scale is 2, which amplifies the LoRA contribution
- α = constant (e.g. 16) while r varies: The scale α/r shrinks as r grows, which keeps the LoRA magnitude roughly comparable when changing r

Why scale?
- The magnitude of BAx depends on the rank, so an unscaled update can be too small or too large relative to W₀x
- Tying the scale to r reduces how much the learning rate needs retuning when r changes
- Allows fairer comparison between different ranks
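
The strategies above differ only in the multiplier α/r. A quick sketch to tabulate it for a few ranks (the helper `lora_scale` is a made-up name):

```python
def lora_scale(alpha: float, r: int) -> float:
    """Multiplier applied to BAx in the forward pass."""
    return alpha / r


for r in (4, 8, 16, 64):
    print(
        f"r={r:>2}  alpha=r: {lora_scale(r, r):.2f}  "
        f"alpha=2r: {lora_scale(2 * r, r):.2f}  "
        f"alpha=16: {lora_scale(16, r):.2f}"
    )
```

With α tied to r the multiplier never changes, while a fixed α = 16 gives 4.00 at r=4 and 0.25 at r=64.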

##### Key insights:
- Higher rank: More expressive, but more parameters and more memory/compute during training
- Lower rank: Faster, fewer parameters, may underfit
- Empirical finding: A small rank such as r=8 is often sufficient for many tasks

##### Compute/Memory Trade-offs
For each LoRA-adapted layer:

- Extra memory: r×(d+k) parameters + gradients + optimizer states
- Extra compute: Two extra matrix multiplies per forward pass (until the adapter is merged)
    - Ax: (r×k) @ (k×batch) = O(r·k·batch)
    - B(Ax): (d×r) @ (r×batch) = O(d·r·batch)
- Training memory is still far lower than full fine-tuning, because gradients and optimizer states are kept only for A and B
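
A rough, back-of-the-envelope estimate of the memory side. It assumes FP32 values and an Adam-style optimizer that keeps two moment buffers per trainable parameter, and it ignores activations. The helper `training_bytes` is hypothetical.

```python
from typing import Optional


def training_bytes(d: int, k: int, r: Optional[int] = None, bytes_per_value: int = 4) -> int:
    """Rough bytes for one d x k layer during training.

    Trainable parameters are counted four times: weight, gradient, and
    two Adam moment buffers. Frozen parameters are counted once.
    """
    if r is None:          # full fine-tuning: the whole matrix is trainable
        frozen, trainable = 0, d * k
    else:                  # LoRA: W0 is frozen, only A and B are trainable
        frozen, trainable = d * k, r * (d + k)
    return (frozen + 4 * trainable) * bytes_per_value


MiB = 2**20
full_mib = training_bytes(4096, 4096) / MiB
lora_mib = training_bytes(4096, 4096, r=8) / MiB
print(f"full: {full_mib:.0f} MiB  lora: {lora_mib:.0f} MiB")
# full: 256 MiB  lora: 65 MiB
```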

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [3]:
class LoRA(nn.Module):
    """Low-rank update delta_W = (alpha / rank) * B @ A for a weight of shape (out, in)."""

    def __init__(self, in_features, out_features, rank, alpha):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.alpha = alpha

        # A: down-projection (in_features -> rank); B: up-projection (rank -> out_features).
        # Shapes follow the nn.Linear convention (out, in), so B @ A matches linear.weight.
        self.a = nn.Parameter(torch.empty(rank, in_features))
        self.b = nn.Parameter(torch.zeros(out_features, rank))
        self.scaling = alpha / rank

        self.reset_parameters()

    def reset_parameters(self):
        # A is random and B is zero, so B @ A = 0 and training starts from the pre-trained model.
        nn.init.kaiming_normal_(self.a)
        nn.init.zeros_(self.b)

    def forward(self, x):
        out = F.linear(x, self.a)    # (..., in_features) -> (..., rank)
        out = F.linear(out, self.b)  # (..., rank) -> (..., out_features)
        return out * self.scaling

In [None]:
class LinearWithLoRA(nn.Module):
    def __init__(self, linear: nn.Linear, rank, alpha):
        super().__init__()

        self.linear = linear
        self.rank = rank
        self.alpha = alpha

        # Freeze the pre-trained weight and bias; only the LoRA matrices are trained.
        self.linear.weight.requires_grad = False
        if self.linear.bias is not None:
            self.linear.bias.requires_grad = False

        self.lora = LoRA(
            linear.in_features,
            linear.out_features,
            rank,
            alpha
        )

        self.merged = False

    def forward(self, x):
        if self.merged:
            # The LoRA update is already folded into linear.weight.
            return self.linear(x)
        else:
            return self.linear(x) + self.lora(x)

    def _delta_w(self):
        # (out, rank) @ (rank, in) -> (out, in), the same shape as linear.weight.
        return torch.matmul(self.lora.b, self.lora.a) * self.lora.scaling

    def merge(self):
        if not self.merged:
            with torch.no_grad():
                self.linear.weight.add_(self._delta_w())
            self.merged = True

    def unmerge(self):
        if self.merged:
            with torch.no_grad():
                self.linear.weight.sub_(self._delta_w())
            self.merged = False

#### QLoRA (Quantized LoRA)
QLoRA is LoRA + 4-bit quantization of the base model. It's the technique that made fine-tuning 65B models possible on a single 48GB GPU!

The Core Idea
- LoRA problem:
    - Base model weights W still need to be loaded in memory (even if frozen)
    - 7B model in FP16 = 14GB just for weights
    - 65B model in FP16 = 130GB (won't fit on consumer GPUs!)
- QLoRA solution:
    - Store W in 4-bit (8x compression from FP32, 4x from FP16)
    - Keep LoRA adapters (BA) in higher precision (FP16/BF16)
    - Dequantize W on-the-fly during forward/backward pass
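
The memory figures above follow directly from parameter count times bits per weight. A minimal sketch, ignoring the small overhead of quantization constants and everything other than the base weights (`weight_gb` is a made-up helper):

```python
def weight_gb(n_params: float, bits: int) -> float:
    """Storage for the base weights alone, in GB (1 GB = 1e9 bytes)."""
    return n_params * bits / 8 / 1e9


for name, n in [("7B", 7e9), ("65B", 65e9)]:
    print(f"{name}: FP16 = {weight_gb(n, 16):.1f} GB, 4-bit = {weight_gb(n, 4):.1f} GB")
# 7B: FP16 = 14.0 GB, 4-bit = 3.5 GB
# 65B: FP16 = 130.0 GB, 4-bit = 32.5 GB
```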

In [4]:
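class QLoRALinear(nn.Module):
    """Simplified QLoRA-style layer: frozen quantized base weight plus trainable LoRA matrices.

    Note: this uses plain per-tensor affine quantization stored in int8 (so bits <= 8)
    for readability. The QLoRA paper uses the 4-bit NormalFloat (NF4) data type with
    block-wise quantization.
    """

    def __init__(self, weights, in_features, out_features, rank, alpha, bits):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        self.rank = rank
        self.alpha = alpha
        self.bits = bits

        # `weights` is assumed to use the nn.Linear layout (out_features, in_features).
        # Buffers move with .to(device) and are saved in state_dict, but are never trained.
        q, scale, zero_point = self._quantize_weights(weights.detach())
        self.register_buffer("quantized_weights", q)
        self.register_buffer("scale", scale)
        self.register_buffer("zero_point", zero_point)

        # A: down-projection (random init); B: up-projection (zero init).
        self.lora_a = nn.Parameter(torch.empty(rank, in_features))
        self.lora_b = nn.Parameter(torch.zeros(out_features, rank))
        self.lora_scale = alpha / rank

        self._init_lora()

    def _init_lora(self):
        nn.init.kaiming_uniform_(self.lora_a)
        nn.init.zeros_(self.lora_b)

    def _quantize_weights(self, weights):
        # Signed integer range, e.g. [-8, 7] for 4 bits.
        q_max = 2 ** (self.bits - 1) - 1
        q_min = -(2 ** (self.bits - 1))

        m_max = weights.max()
        m_min = weights.min()

        # Guard against a zero scale when all weights are equal.
        scale = ((m_max - m_min) / (q_max - q_min)).clamp(min=1e-8)

        zero_point = q_min - torch.round(m_min / scale)

        quantized_linear = torch.clamp(
            torch.round(weights / scale) + zero_point,
            q_min, q_max
        ).to(dtype=torch.int8)

        return quantized_linear, scale, zero_point

    def _dequantize_weights(self):
        return (self.quantized_weights.float() - self.zero_point) * self.scale

    def forward(self, x):
        # Dequantize on the fly; the base weight is a buffer and gets no gradient.
        W = self._dequantize_weights()

        base_out = F.linear(x, W)  # x @ W.T, with W of shape (out, in)

        lora_out = F.linear(x, self.lora_a)
        lora_out = F.linear(lora_out, self.lora_b)
        lora_out = lora_out * self.lora_scale

        return base_out + lora_out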
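
To get a feel for how much information the quantization step loses, here is a standalone round-trip sketch. It uses the same kind of per-tensor affine scheme as the cell above (not NF4), stores the 4-bit codes in an int8 container, and the function names `quantize_affine` and `dequantize_affine` are made up for this example.

```python
import torch


def quantize_affine(w: torch.Tensor, bits: int = 4):
    """Per-tensor affine quantization to a signed integer range (stored as int8)."""
    q_min, q_max = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    scale = (w.max() - w.min()) / (q_max - q_min)
    zero_point = q_min - torch.round(w.min() / scale)
    q = torch.clamp(torch.round(w / scale) + zero_point, q_min, q_max).to(torch.int8)
    return q, scale, zero_point


def dequantize_affine(q: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor):
    """Map integer codes back to approximate float weights."""
    return (q.float() - zero_point) * scale


torch.manual_seed(0)
w = torch.randn(64, 64)
q, scale, zp = quantize_affine(w, bits=4)
w_hat = dequantize_affine(q, scale, zp)

n_levels = q.unique().numel()       # at most 2**4 = 16 distinct codes
max_err = (w - w_hat).abs().max()   # expected to be on the order of half a step
print(n_levels, float(scale), float(max_err))
```

The reconstruction error stays within about one quantization step, which is the part of the signal the LoRA adapters then have room to compensate for during fine-tuning.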