In [None]:
# LoRA

## Context

[BERT](https://arxiv.org/abs/1810.04805) and [GPT](https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf) popularized the approach of pre-training on large amounts of internet data, then fine-tuning on smaller datasets to perform specialized tasks. This approach of first teaching a model a broad skill (e.g. language modeling), and then re-using that model to learn a more specific task (e.g. sentiment analysis) is called transfer learning.

Initially, transfer learning meant fine-tuning every parameter of the pre-trained model, which produces an entirely new copy of the model for each task. This becomes computationally expensive and memory-intensive for larger models, both to train and to store.

### Enter Adapters

[Adapters](https://arxiv.org/abs/1902.00751) (Houlsby et al., 2019) introduced a more efficient approach: insert small bottleneck modules inside each transformer layer (after the attention and feed-forward sublayers). These adapter modules add only a few trainable parameters per task, while the original pre-trained weights stay completely frozen.

Adapters work by:
- Keeping the original pre-trained weights **frozen**
- Adding small bottleneck modules (a down-projection, a nonlinearity, and an up-projection) inside each layer
- Only training these new adapter parameters
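
The three steps above can be sketched for a single adapter module. This NumPy version is a minimal illustration, not Houlsby et al.'s code: the bottleneck size of 64, the ReLU nonlinearity, and the zero-initialized up-projection are assumptions, chosen so the adapter starts out as an exact identity map (the paper uses a near-identity initialization).

```python
import numpy as np

rng = np.random.default_rng(0)

d_model, bottleneck = 768, 64  # hypothetical sizes: bottleneck << d_model

# Output of a frozen pre-trained sublayer for 10 tokens (random stand-in values)
h = rng.standard_normal((10, d_model))

# Trainable adapter weights: down-projection, then up-projection
W_down = rng.standard_normal((d_model, bottleneck)) * 0.01
b_down = np.zeros(bottleneck)
W_up = np.zeros((bottleneck, d_model))  # zero init: adapter starts as identity
b_up = np.zeros(d_model)

def relu(z):
    return np.maximum(z, 0.0)

def adapter(h):
    """Bottleneck adapter with a residual connection: h + relu(h W_down) W_up."""
    return h + relu(h @ W_down + b_down) @ W_up + b_up

out = adapter(h)
adapter_params = W_down.size + b_down.size + W_up.size + b_up.size

print(out.shape)            # (10, 768): same shape as the input
print(adapter_params)       # 2*768*64 + 64 + 768 = 99136 trainable values
print(np.allclose(out, h))  # True at initialization
```

Because the adapter sits directly in the forward path, its two extra matrix multiplications also run at inference time, which is the latency cost listed below.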
### Other Parameter-Efficient Approaches

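Other approaches, such as prefix tuning, emerged as well. Instead of changing model weights, they optimize a small set of continuous "prefix" vectors that the frozen model attends to as if they were extra tokens.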
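
A minimal NumPy sketch of the idea for one single-head attention layer, with hypothetical sizes and random stand-in weights: the projections `W_q`, `W_k`, `W_v` stay frozen, and only the prefix keys `P_k` and prefix values `P_v` would be trained. Prefix tuning applies this at every layer; the sketch shows just one.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, n_prefix, d = 6, 4, 16  # hypothetical sizes

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Frozen pre-trained projections (random stand-ins)
W_q, W_k, W_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

# Trainable prefix: n_prefix key vectors and n_prefix value vectors for this layer
P_k = rng.standard_normal((n_prefix, d))
P_v = rng.standard_normal((n_prefix, d))

def attention_with_prefix(x):
    """Single-head attention where real tokens also attend to the trainable prefix."""
    q = x @ W_q                                  # (n, d)
    k = np.concatenate([P_k, x @ W_k], axis=0)   # (p + n, d)
    v = np.concatenate([P_v, x @ W_v], axis=0)   # (p + n, d)
    weights = softmax(q @ k.T / np.sqrt(d))      # (n, p + n)
    return weights @ v, weights

x = rng.standard_normal((n_tokens, d))
out, weights = attention_with_prefix(x)
print(out.shape)      # (6, 16)
print(weights.shape)  # (6, 10): each token attends over 4 prefix + 6 real positions
```

Each real token now attends over `p + n` positions, so the learned prefix takes up part of the attention context.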

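These methods were more efficient than full fine-tuning, but they had limitations:
1. **Inference latency**: Adapters add extra layers that run on every forward pass, increasing computational overhead
2. **Reduced context**: Prefix-based methods reserve part of the sequence length for the learned prefix
3. **Worse performance**: These methods sometimes underperform full fine-tuning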

Ideally, we'd like to keep the same model architecture as the original, so as to avoid the latency problem. We'd also like to avoid updating the entire model, so as to avoid the high computational cost.


In [None]:
## Enter LoRA 

**LoRA (Low-Rank Adaptation)** ([Hu et al., 2021](https://arxiv.org/abs/2106.09685)) provides a solution to this problem.

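In full fine-tuning, every weight can change, and we can write the fine-tuned weights as the original weights plus an update $\Delta W$. Hu et al. hypothesize, and support empirically, that this update has a low *intrinsic rank*: $\Delta W$ can be well approximated by a matrix of much lower rank than the original weight matrix $W_0$, so the change can be described with far fewer numbers than there are weights. This is the insight behind LoRA: freeze $W_0$ and train only a low-rank representation of $\Delta W$.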
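
To make "low rank" concrete: a $d \times k$ matrix of rank $r$ can be written exactly as the product of a $d \times r$ and an $r \times k$ matrix. The NumPy sketch below builds a synthetic $\Delta W$ that has rank 8 by construction (real fine-tuning updates are at best approximately low rank) and recovers such a factorization with a truncated SVD. LoRA itself never computes an SVD; it learns the two factors directly.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r = 768, 768, 8  # hypothetical sizes

# Synthetic update that has rank r by construction
delta_W = rng.standard_normal((d, r)) @ rng.standard_normal((r, k))
print(np.linalg.matrix_rank(delta_W))  # 8

# Truncated SVD: delta_W = (U_r diag(S_r)) @ V_r^T
U, S, Vt = np.linalg.svd(delta_W, full_matrices=False)
B = U[:, :r] * S[:r]  # (d, r): columns of U scaled by the singular values
A = Vt[:r, :]         # (r, k)
print(np.allclose(B @ A, delta_W))  # True: exact up to floating-point error

# Numbers needed to store the update in full vs. in factored form
print(delta_W.size, B.size + A.size)  # 589824 12288
```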

Instead of updating the full weight matrix:
$$W_{new} = W_0 + \Delta W$$

LoRA decomposes the update into two low-rank matrices:
$$W_{new} = W_0 + \Delta W = W_0 + BA$$

Where:
- $W_0 \in \mathbb{R}^{d \times k}$ is the original frozen weight matrix
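- $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$ are the trainable low-rank matrices. $A$ is initialized with random Gaussian values and $B$ with zeros, so $BA = 0$ and training starts from exactly the pre-trained model.
- $r \ll \min(d, k)$ is the rank (common choices range from 1 to 64)

In practice the update is also scaled by a constant: $W_{new} = W_0 + \frac{\alpha}{r} BA$, where $\alpha$ is a hyperparameter. This is the `scaling = alpha / rank` factor in the `LoRALayer` implementation below.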
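
A minimal NumPy sketch of the initialization and the scaled forward pass, using random stand-in weights. The shapes follow PyTorch's `nn.Linear` convention (`W0` is out × in, outputs are `x @ W0.T`), and `alpha = 16` is an assumption mirroring the default of the `LoRALayer` class below. It checks two claims from the list: the model is unchanged at initialization, and $BA$ never exceeds rank $r$.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r = 768, 768, 16      # W0 is d x k (out x in, as in nn.Linear.weight)
alpha = 16                  # assumed scaling hyperparameter
scaling = alpha / r

W0 = rng.standard_normal((d, k)) * 0.02  # frozen pre-trained weight (stand-in)
A = rng.standard_normal((r, k)) * 0.01   # random Gaussian init
B = np.zeros((d, r))                     # zero init, so BA = 0 at the start

x = rng.standard_normal((4, k))          # batch of 4 inputs

def lora_forward(x):
    # x (W0 + scaling * B A)^T, computed without forming the d x k update
    return x @ W0.T + scaling * (x @ A.T) @ B.T

y_init = lora_forward(x)
print(np.allclose(y_init, x @ W0.T))     # True: identical to the frozen model

# Once training makes B non-zero, the update B A still has rank at most r
B = rng.standard_normal((d, r))
rank_after = np.linalg.matrix_rank(B @ A)
print(rank_after)                        # 16
```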

### Parameter Reduction

For a weight matrix of size $d \times k$:
- **Full fine-tuning**: $d \times k$ parameters
- **LoRA**: $r \times (d + k)$ parameters

**Reduction factor**: $\frac{d \times k}{r \times (d + k)}$
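
For a square matrix ($d = k$), the reduction factor simplifies, which makes the role of the rank easy to read off:

$$\frac{d \times d}{r \times (d + d)} = \frac{d^2}{2rd} = \frac{d}{2r}$$

For example, DistilBERT's $768 \times 768$ attention projections with $r = 16$ give $\frac{768}{2 \times 16} = 24$: $16 \times (768 + 768) = 24{,}576$ trainable values instead of $768^2 = 589{,}824$. Halving $r$ doubles the factor.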

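The factor grows with matrix size and shrinks as $r$ grows, so the largest savings come from large models. Hu et al. report that, for GPT-3 175B, LoRA reduces the number of trainable parameters by roughly **10,000x** compared with full fine-tuning.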


In [None]:
## Build: Sentiment Analysis with LoRA

We'll implement LoRA for sentiment analysis using the IMDB movie reviews dataset. We'll adapt a pre-trained DistilBERT model (66M params). 

1. **Baseline**: Measure a naive, training-free heuristic on DistilBERT's raw embeddings for IMDB
2. **Resource Analysis**: Examine DistilBERT's architecture and parameter count
3. **Implement LoRA**: Build LoRA from scratch and see the math in action
4. **Training & Comparison**: Compare LoRA vs full fine-tuning.


In [None]:
%pip install transformers datasets torch scikit-learn matplotlib numpy

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    DistilBertTokenizer, 
    DistilBertForSequenceClassification,
    DistilBertModel,
    Trainer,
    TrainingArguments
)
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report
import time
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')



In [None]:
# Load IMDB dataset
print("Loading IMDB dataset...")
dataset = load_dataset("imdb")

# using subset for faster experimentation (5k samples)
SUBSET_SIZE = 5000
train_dataset = dataset["train"].shuffle(seed=42).select(range(SUBSET_SIZE))
test_dataset = dataset["test"].shuffle(seed=42).select(range(1000))  # 1k for testing

print(f"loaded {len(train_dataset)} training samples and {len(test_dataset)} test samples")

# look at a few examples
print("\nSample reviews:")
for i in range(2):
    review = train_dataset[i]
    label = "Positive" if review["label"] == 1 else "Negative"
    text_preview = review["text"][:100] + "..." if len(review["text"]) > 100 else review["text"]
    print(f"\n{label}: {text_preview}")


In [None]:
# Load pre-trained DistilBERT model and tokenizer
print("Loading DistilBERT...")
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')

# There is no sentiment head yet, so as a naive baseline we apply a crude heuristic
# to DistilBERT's hidden representations and check for any inherent sentiment signal
def simple_sentiment_baseline(texts, sample_size=100):
    """
    Simple baseline: use DistilBERT's [CLS] token representation 
    to see if there's any inherent sentiment signal
    """
    model.eval()
    correct = 0
    total = 0
    
    print(f"Testing baseline on {sample_size} samples...")
    
    with torch.no_grad():
        for i in tqdm(range(min(sample_size, len(texts)))):
            text = texts[i]["text"]
            true_label = texts[i]["label"]
            
            # Tokenize
            inputs = tokenizer(text, return_tensors="pt", truncation=True, 
                             max_length=128, padding=True)
            
            # Get DistilBERT representation
            outputs = model(**inputs)
            cls_embedding = outputs.last_hidden_state[:, 0, :]  # [CLS] token
            
            # Simple heuristic: sum of embedding values
            # (This is obviously not a good approach, but shows the baseline)
            sentiment_score = cls_embedding.sum().item()
            predicted_label = 1 if sentiment_score > 0 else 0
            
            if predicted_label == true_label:
                correct += 1
            total += 1
    
    accuracy = correct / total
    return accuracy

# Test baseline performance
baseline_acc = simple_sentiment_baseline(test_dataset, sample_size=200)
print(f"\nBaseline accuracy: {baseline_acc:.3f}")
print("(This is essentially random - DistilBERT wasn't trained for sentiment analysis!)")


In [None]:
# Analyze DistilBERT's parameters
def analyze_model_parameters(model):
    """Analyze the parameter distribution in DistilBERT"""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"DistilBERT Parameter Analysis:")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Model size: ~{total_params * 4 / 1024**2:.1f} MB (float32)")
    
    # Look at specific layer types
    layer_counts = {}
    attention_params = 0
    
    for name, param in model.named_parameters():
        layer_type = name.split('.')[0] if '.' in name else name
        
        if layer_type not in layer_counts:
            layer_counts[layer_type] = 0
        layer_counts[layer_type] += param.numel()
        
        # Count attention parameters specifically
        if 'attention' in name and any(x in name for x in ['q_lin', 'k_lin', 'v_lin', 'out_lin']):
            attention_params += param.numel()
    
    print(f"\nParameter breakdown by component:")
    for layer, count in sorted(layer_counts.items(), key=lambda x: x[1], reverse=True):
        print(f"  {layer}: {count:,} parameters ({count/total_params*100:.1f}%)")
    
    print(f"\nAttention mechanism parameters: {attention_params:,} ({attention_params/total_params*100:.1f}%)")
    
    return total_params, attention_params

# Analyze our model
total_params, attention_params = analyze_model_parameters(model)

# Look at a specific attention layer to understand the matrices we'll modify
print(f"\nExamining a single attention layer:")
first_attn = model.transformer.layer[0].attention
print(f"Query matrix shape: {first_attn.q_lin.weight.shape}")
print(f"Key matrix shape: {first_attn.k_lin.weight.shape}")  
print(f"Value matrix shape: {first_attn.v_lin.weight.shape}")
print(f"Output matrix shape: {first_attn.out_lin.weight.shape}")

# This is where LoRA will make a big difference
single_attn_params = sum(p.numel() for p in first_attn.parameters())
print(f"Parameters in one attention layer: {single_attn_params:,}")
print(f"DistilBERT has 6 such layers = {single_attn_params * 6:,} attention parameters total")


In [None]:
# Full Fine-tuning Implementation (OPTIONAL - takes longer)
# Set this to True only if you want to run full fine-tuning
RUN_FULL_FINETUNING = False  # Change to True to run

if RUN_FULL_FINETUNING:
    print("Starting full fine-tuning...")
    
    # Create a classification model
    classification_model = DistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-uncased', 
        num_labels=2
    )
    
    # Tokenize dataset (pad every example to max_length so the default collator can stack batches)
    def tokenize_function(examples):
        return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=128)
    
    tokenized_train = train_dataset.map(tokenize_function, batched=True)
    tokenized_test = test_dataset.map(tokenize_function, batched=True)
    
    # Training arguments - optimized for speed
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=2,  # Reduced for time
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=50,
        evaluation_strategy="steps",
        eval_steps=200,
        save_strategy="no",  # Don't save to reduce overhead
    )
    
    # Create trainer
    trainer = Trainer(
        model=classification_model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_test,
        compute_metrics=lambda eval_pred: {
            'accuracy': accuracy_score(eval_pred.label_ids, eval_pred.predictions.argmax(-1))
        }
    )
    
    # Train
    start_time = time.time()
    trainer.train()
    full_finetuning_time = time.time() - start_time
    
    # Evaluate
    full_finetuning_results = trainer.evaluate()
    full_finetuning_accuracy = full_finetuning_results['eval_accuracy']  # used by the comparison below
    print(f"Full fine-tuning completed in {full_finetuning_time:.1f} seconds")
    print(f"Full fine-tuning accuracy: {full_finetuning_accuracy:.3f}")
    
else:
    # Placeholder values (rough expectations, NOT measured in this notebook) so the
    # comparison cells below can run. Set RUN_FULL_FINETUNING = True for real numbers.
    print("Skipping full fine-tuning (RUN_FULL_FINETUNING=False)")
    print("Using rough placeholder expectations for full fine-tuning (not measured):")
    print("   - Accuracy: ~0.85-0.90")
    print("   - Training time: 20-30 min (GPU) / 2-3 hours (CPU)")
    print("   - Parameters updated: 66,955,010 (100%)")
    
    # Store placeholder results for later comparison
    full_finetuning_time = 1800  # placeholder: 30 minutes
    full_finetuning_accuracy = 0.87  # placeholder


In [None]:
# Calculate parameter savings for LoRA
d, k = 768, 768  # DistilBERT attention matrix dimensions
r = 16  # LoRA rank

# Full matrix update
full_params = d * k
print(f"Parameter Comparison:")
print(f"Full fine-tuning: {full_params:,} parameters per matrix")

# LoRA update  
lora_params = r * (d + k)
print(f"LoRA (rank {r}): {lora_params:,} parameters per matrix")

# Reduction factor
reduction = full_params / lora_params
print(f"Reduction factor: {reduction:.1f}x fewer parameters")

# For all attention matrices in DistilBERT
num_attention_matrices = 4 * 6  # 4 matrices (Q,K,V,O) × 6 layers
total_full = full_params * num_attention_matrices
total_lora = lora_params * num_attention_matrices

print(f"\nFor all DistilBERT attention layers:")
print(f"Full fine-tuning: {total_full:,} parameters")
print(f"LoRA: {total_lora:,} parameters")
print(f"Total reduction: {total_full / total_lora:.1f}x")

# Memory calculation
print(f"\nMemory impact:")
print(f"LoRA adds only: {total_lora * 4 / 1024**2:.2f} MB")


In [None]:
class LoRALayer(nn.Module):
    """
    LoRA (Low-Rank Adaptation) layer
    
    Implements: output = input @ (W_0 + B @ A)
    Where W_0 is frozen, and B, A are trainable low-rank matrices
    """
    def __init__(self, original_layer, rank=16, alpha=16):
        super().__init__()
        
        # Store the original layer (frozen)
        self.original_layer = original_layer
        self.original_layer.requires_grad_(False)  # Freeze original weights
        
        # Get dimensions
        if hasattr(original_layer, 'in_features'):  # Linear layer
            in_features = original_layer.in_features
            out_features = original_layer.out_features
        else:
            raise ValueError("LoRA currently supports nn.Linear layers")
        
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank  # Scaling factor
        
        # Initialize LoRA matrices
        # A: small random Gaussian values (the paper also uses a random Gaussian init for A)
        # B: zeros, so ΔW = B @ A is exactly zero at the start of training
        self.lora_A = nn.Parameter(torch.randn(rank, in_features) * 0.1)
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        
        print(f"Created LoRA layer: {out_features}×{in_features} → rank {rank}")
        print(f"   Original params: {in_features * out_features:,}")
        print(f"   LoRA params: {rank * (in_features + out_features):,}")
        print(f"   Reduction: {(in_features * out_features) / (rank * (in_features + out_features)):.1f}x")
    
    def forward(self, x):
        # Original computation: x @ W_0^T + b_0
        original_output = self.original_layer(x)
        
        # LoRA computation: x @ (B @ A)^T = x @ A^T @ B^T
        # Note: we need to transpose because PyTorch Linear does x @ W^T
        lora_output = x @ self.lora_A.T @ self.lora_B.T * self.scaling
        
        # Combine: equivalent to x @ (W_0 + scaling * B @ A)^T + b_0
        return original_output + lora_output
    
    def get_lora_parameters(self):
        """Return only the LoRA parameters for optimization"""
        return [self.lora_A, self.lora_B]

# Test our LoRA layer
print("Testing LoRA layer...")
test_linear = nn.Linear(768, 768)
test_input = torch.randn(1, 10, 768)  # Batch size 1, sequence length 10

# Create LoRA version
lora_layer = LoRALayer(test_linear, rank=16)

# Test forward pass
with torch.no_grad():
    original_output = test_linear(test_input)
    lora_output = lora_layer(test_input)
    
print(f"Forward pass successful!")
print(f"Output shapes match: {original_output.shape == lora_output.shape}")
print(f"Difference (should be ~0 initially): {(original_output - lora_output).abs().max().item():.6f}")


In [None]:
def apply_lora_to_distilbert(model, rank=16):
    """Apply LoRA to all attention layers in DistilBERT"""
    
    print(f"Applying LoRA (rank={rank}) to DistilBERT attention layers...")
    
    lora_params = []
    original_params = 0
    lora_param_count = 0
    
    # Apply LoRA to each transformer layer's attention
    for layer_idx, transformer_layer in enumerate(model.transformer.layer):
        attention = transformer_layer.attention
        
        print(f"\nLayer {layer_idx}:")
        
        # Apply to Query, Key, Value, and Output projections
        for name, linear_layer in [
            ('q_lin', attention.q_lin),
            ('k_lin', attention.k_lin), 
            ('v_lin', attention.v_lin),
            ('out_lin', attention.out_lin)
        ]:
            # Count original parameters
            orig_params = sum(p.numel() for p in linear_layer.parameters())
            original_params += orig_params
            
            # Replace with LoRA layer
            lora_layer = LoRALayer(linear_layer, rank=rank)
            setattr(attention, name, lora_layer)
            
            # Collect LoRA parameters
            lora_params.extend(lora_layer.get_lora_parameters())
            # numel() counts every element; len() of a tensor only returns its first dimension
            lora_param_count += sum(p.numel() for p in lora_layer.get_lora_parameters())
            
            print(f"  {name}: LoRA applied")
    
    print(f"\nLoRA Application Summary:")
    print(f"Original attention parameters: {original_params:,}")
    print(f"LoRA parameters added: {lora_param_count:,}")
    print(f"Reduction factor: {original_params / lora_param_count:.1f}x")
    
    return lora_params

# Create a fresh DistilBERT model for LoRA
print("Loading fresh DistilBERT for LoRA...")
lora_model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased', 
    num_labels=2
)

# Apply LoRA
lora_parameters = apply_lora_to_distilbert(lora_model, rank=16)

print(f"\nReady for LoRA training with {len(lora_parameters):,} trainable parameters!")


In [None]:
# Prepare data for LoRA training
def tokenize_function(examples):
    return tokenizer(examples['text'], truncation=True, padding=True, max_length=128)

# Tokenize datasets
tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_test = test_dataset.map(tokenize_function, batched=True)

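# LoRA Training setup
print("Setting up LoRA training...")

# Train only the LoRA matrices and the classification head. Both head layers
# (pre_classifier and classifier) are freshly initialized by
# DistilBertForSequenceClassification, so both need training.
trainable_params = [
    *lora_parameters,                          # LoRA A and B matrices
    *lora_model.pre_classifier.parameters(),   # head: 768 -> 768
    *lora_model.classifier.parameters(),       # head: 768 -> 2
]

# Freeze everything else so autograd doesn't compute gradients we never use
for p in lora_model.parameters():
    p.requires_grad_(False)
for p in trainable_params:
    p.requires_grad_(True)

optimizer = torch.optim.AdamW(trainable_params, lr=5e-4, weight_decay=0.01)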

# Simple training loop (faster than Trainer for this demo)
def train_lora_model(model, train_data, test_data, epochs=2, batch_size=16, eval_size=500):
    model.train()
    
    train_losses = []
    test_accuracies = []
    
    print(f"Training LoRA for {epochs} epochs...")
    start_time = time.time()
    
    for epoch in range(epochs):
        epoch_loss = 0
        num_batches = 0
        
        # Training
        for i in tqdm(range(0, len(train_data), batch_size), desc=f"Epoch {epoch+1}"):
            # Slicing a Hugging Face Dataset returns a dict of columns, not a list of rows
            batch = train_data[i:i + batch_size]
            texts = batch['text']
            labels = torch.tensor(batch['label'])
            
            inputs = tokenizer(texts, return_tensors="pt", truncation=True,
                               max_length=128, padding=True)
            
            # Forward pass
            outputs = model(**inputs, labels=labels)
            loss = outputs.loss
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            num_batches += 1
        
        avg_loss = epoch_loss / num_batches
        train_losses.append(avg_loss)
        
        # Quick evaluation on the first `eval_size` test examples
        model.eval()
        correct = 0
        total = 0
        n_eval = min(eval_size, len(test_data))
        
        with torch.no_grad():
            for i in range(0, n_eval, batch_size):
                batch = test_data[i:min(i + batch_size, n_eval)]
                labels = torch.tensor(batch['label'])
                
                inputs = tokenizer(batch['text'], return_tensors="pt", truncation=True,
                                   max_length=128, padding=True)
                
                predictions = model(**inputs).logits.argmax(-1)
                correct += (predictions == labels).sum().item()
                total += len(labels)
        
        accuracy = correct / total
        test_accuracies.append(accuracy)
        
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}, Accuracy = {accuracy:.3f}")
        model.train()
    
    training_time = time.time() - start_time
    return train_losses, test_accuracies, training_time

# Train LoRA model
losses, accuracies, lora_time = train_lora_model(
    lora_model, tokenized_train, tokenized_test, epochs=2
)

# Final evaluation
lora_model.eval()
final_accuracy = accuracies[-1] if accuracies else 0

print(f"\nLoRA Training Complete!")
print(f"Training time: {lora_time:.1f} seconds")
print(f"Final accuracy: {final_accuracy:.3f}")

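# Compare with full fine-tuning
# Count every parameter the optimizer actually updates (LoRA matrices + head)
lora_trainable = sum(p.numel() for group in optimizer.param_groups for p in group['params'])
full_param_count = 66_955_010  # all parameters of DistilBertForSequenceClassification

print(f"\nCOMPARISON SUMMARY:")
print(f"{'Method':<20} {'Accuracy':<10} {'Time (s)':<10} {'Trained params':<15}")
print(f"{'-'*58}")
print(f"{'Baseline':<20} {baseline_acc:<10.3f} {'N/A':<10} {'0':<15}")
print(f"{'LoRA':<20} {final_accuracy:<10.3f} {lora_time:<10.1f} {lora_trainable:<15,}")
print(f"{'Full Fine-tuning':<20} {full_finetuning_accuracy:<10.3f} {full_finetuning_time:<10.1f} {full_param_count:<15,}")

print(f"\nKey Takeaways (full fine-tuning numbers are placeholders unless RUN_FULL_FINETUNING=True):")
print(f"• LoRA reached {final_accuracy:.1%} accuracy while training {lora_trainable / full_param_count:.2%} as many parameters")
print(f"• Training time: LoRA {lora_time:.1f} s vs. full fine-tuning {full_finetuning_time:.1f} s ({full_finetuning_time / lora_time:.1f}x)")
print(f"• LoRA performance: {final_accuracy / full_finetuning_accuracy:.1%} of full fine-tuning accuracy")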

# Plot training curves
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(range(1, len(losses) + 1), losses, 'b-o', label='LoRA Loss')  # epochs numbered from 1
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(range(1, len(accuracies) + 1), accuracies, 'g-o', label='LoRA Accuracy')  # epochs numbered from 1
plt.axhline(y=full_finetuning_accuracy, color='r', linestyle='--', label='Full Fine-tuning')
plt.axhline(y=baseline_acc, color='gray', linestyle=':', label='Baseline')
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

print("\nCongratulations! You've successfully implemented and trained LoRA from scratch!")
