# All-in-One Experiment Runner
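This notebook contains all the functions and classes needed to run the LoRA experiments.
The code cells below are copied from the project source files (each cell names its origin in a `# Source:` comment). Note that the model-builder cells still import `inject_lora_bert` from `lora_utils.modeling`, so the project package must be importable when running this notebook.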

In [None]:
import yaml
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import json
import os
from datetime import datetime
from transformers import AutoTokenizer, BertForSequenceClassification, RobertaForSequenceClassification, DistilBertForSequenceClassification
from datasets import load_dataset

In [None]:
# Configuration: hyper-parameters shared by every experiment run below.
config = dict(
    seed=42,
    batch_size=32,
    max_epochs=5,
    learning_rate=2e-4,
    lambda_reg=0.01,          # weight of the gradient-regularisation term
    scale_factor=1.0,         # numerator of the backward-pass gradient scaling
    lora_rank=4,
    lora_alpha=1.0,
    dropout=0.1,
    warmup_steps=100,         # NOTE(review): not read by train() in this notebook
    unfreeze_layers_after=2,  # epoch index at which train() prints its unfreeze message
)

# Use the GPU whenever one is visible, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# Checkpoints and the results JSON are written into this directory.
os.makedirs("results", exist_ok=True)

In [None]:
# Source: lora_utils/modeling.py
import torch
import torch.nn as nn

# -------------------------------
# 1️⃣ LoRA Weighted Function
# -------------------------------
class LoRAWeightedFunction(torch.autograd.Function):
    """
    Custom forward/backward function for the LoRA path ``x @ A @ B``.

    This function implements the core logic of the weighted-gradient scheme.
    In the backward pass, the incoming gradient of every output row (every
    sample, or every token for sequence inputs) is multiplied by
    ``scale_factor / ||out_row||``, where ``out_row`` is that row of the forward
    output. Rows with a small output norm therefore receive larger gradients and
    rows with a large output norm smaller ones. (The original intent is that a
    small norm corresponds to "low confidence" — an interpretation, not
    something this function checks.)

    Inputs may be ``(batch, in_features)`` or ``(batch, ..., in_features)``
    (e.g. BERT hidden states ``(batch, seq, hidden)``).

    Side effect: the backward pass stores per-sample gradients (leading dim =
    dim 0 of ``x``, summed over any middle dims) on the class as
    ``grad_A_sample`` ``(batch, in_features, r)`` and ``grad_B_sample``
    ``(batch, r, out_features)``.
    """
    @staticmethod
    def forward(ctx, x, A, B, scale_factor=1.0):
        """
        Forward pass: computes x @ A @ B
        """
        ctx.save_for_backward(x, A, B)
        ctx.scale_factor = scale_factor
        out = x @ A @ B
        # Kept (detached) for the norm-based gradient weighting in backward.
        ctx.out_forward = out.detach()
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: scale grad_output by scale_factor / ||out|| (per row)
        and propagate it through x @ A @ B.

        Returns gradients for (x, A, B, scale_factor); scale_factor gets None.
        """
        x, A, B = ctx.saved_tensors
        x_dtype = x.dtype

        # Backward runs outside the autocast context, so under AMP the forward
        # output / grad_output can be float16 while x, A, B are float32. Do all
        # math in the parameters' dtype to avoid mixed-dtype matmuls.
        compute_dtype = A.dtype
        x = x.to(compute_dtype)
        out = ctx.out_forward.to(compute_dtype)

        # Per-row (sample / token) output norm -> gradient weight.
        out_norm = torch.norm(out, dim=-1, keepdim=True) + 1e-6
        weight = ctx.scale_factor / out_norm
        grad = grad_output.to(compute_dtype) * weight

        # Collapse any middle dims so both 2-D and 3-D inputs are handled:
        # (batch, ..., features) -> (batch, rows, features).
        batch = x.shape[0]
        x_rows = x.reshape(batch, -1, x.shape[-1])            # [batch, T, in]
        grad_rows = grad.reshape(batch, -1, grad.shape[-1])   # [batch, T, out]

        # Per-sample gradients, summed over the rows of each sample.
        grad_A_sample = x_rows.transpose(1, 2) @ (grad_rows @ B.T)  # [batch, in, r]
        grad_B_sample = (x_rows @ A).transpose(1, 2) @ grad_rows    # [batch, r, out]

        grad_A = grad_A_sample.sum(dim=0)
        grad_B = grad_B_sample.sum(dim=0)

        # d(x A B)/dx contracted with grad is grad @ B^T @ A^T.
        grad_x = None
        if ctx.needs_input_grad[0]:
            grad_x = (grad @ B.T @ A.T).to(x_dtype)

        # Save sample-level gradients for regularization
        # Note: This static storage is not thread-safe or multi-model safe.
        # Every LoRA layer overwrites it, so after a full backward it holds the
        # values from whichever layer's backward ran last.
        LoRAWeightedFunction.grad_A_sample = grad_A_sample
        LoRAWeightedFunction.grad_B_sample = grad_B_sample

        return grad_x, grad_A, grad_B, None

# -------------------------------
# 2️⃣ LoRA Linear Layer
# -------------------------------
class LoRABertLinear(nn.Module):
    """
    LoRA Linear Layer that replaces a standard nn.Linear layer.

    It keeps frozen copies of the original weight and (if present) bias, and
    adds trainable LoRA matrices ``lora_A`` (in_features x r) and ``lora_B``
    (r x out_features). The LoRA path runs through LoRAWeightedFunction so its
    gradients get the norm-based scaling. The output is

        linear(x, weight, bias) + (alpha / r) * dropout(x @ lora_A @ lora_B)
    """
    def __init__(self, original_linear, r=4, alpha=1.0, scale_factor=1.0, dropout=0.1):
        super().__init__()
        self.in_features = original_linear.in_features
        self.out_features = original_linear.out_features
        self.r = r
        self.alpha = alpha
        self.scale_factor = scale_factor
        self.scaling = alpha / r

        # Frozen copies of the pretrained projection. The bias must be kept:
        # dropping it would change the pretrained layer's output.
        self.weight = nn.Parameter(original_linear.weight.data.clone(), requires_grad=False)
        if original_linear.bias is not None:
            self.bias = nn.Parameter(original_linear.bias.data.clone(), requires_grad=False)
        else:
            self.register_parameter('bias', None)

        # LoRA parameters (small random init for both factors).
        self.lora_A = nn.Parameter(torch.randn(self.in_features, r) * 0.01)
        self.lora_B = nn.Parameter(torch.randn(r, self.out_features) * 0.01)

        self.dropout = nn.Dropout(p=dropout)

        # Last gradients seen by the parameter hooks registered below.
        # NOTE(review): a parameter hook receives the gradient of the whole
        # parameter, i.e. shapes (in_features, r) / (r, out_features), not a
        # per-sample (batch, ...) tensor despite the "_sample" names. Confirm
        # this matches what grad_regularization_bert expects.
        self.grad_A_sample = None
        self.grad_B_sample = None

        self.lora_A.register_hook(self._save_grad_A)
        self.lora_B.register_hook(self._save_grad_B)

    def _save_grad_A(self, grad):
        """Hook: remember the latest gradient w.r.t. lora_A (returns None, so grad is unchanged)."""
        self.grad_A_sample = grad

    def _save_grad_B(self, grad):
        """Hook: remember the latest gradient w.r.t. lora_B (returns None, so grad is unchanged)."""
        self.grad_B_sample = grad

    def forward(self, x):
        """Frozen projection (with bias) plus the scaled, dropped-out LoRA update."""
        main = nn.functional.linear(x, self.weight, self.bias)
        lora = LoRAWeightedFunction.apply(x, self.lora_A, self.lora_B, self.scale_factor)
        return main + self.scaling * self.dropout(lora)

# -------------------------------
# 3️⃣ Injection Utility
# -------------------------------
def inject_lora_bert(model, r=4, alpha=1.0, scale_factor=1.0, dropout=0.1,
                     target_modules=('query', 'key', 'value', 'q_lin', 'k_lin', 'v_lin')):
    """
    Inject LoRA layers into a BERT-based model (in place).

    Every ``nn.Linear`` whose qualified module name contains one of
    ``target_modules`` is replaced by a ``LoRABertLinear`` wrapping it. The
    defaults target the self-attention query/key/value projections of BERT and
    RoBERTa (``query``/``key``/``value``) and DistilBERT (``q_lin``/``k_lin``/``v_lin``).

    Args:
        model (nn.Module): The model to modify.
        r (int): LoRA rank.
        alpha (float): LoRA alpha scaling.
        scale_factor (float): Factor for the weighted gradient scaling.
        dropout (float): Dropout probability.
        target_modules (Iterable[str]): Substrings matched against each module's
            qualified name to select the layers to wrap.
    """
    # Collect the targets first, then replace them, so the module tree is not
    # mutated while named_modules() is still walking it.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and any(t in name for t in target_modules)
    ]

    for name, module in targets:
        # rpartition also handles a target that is a direct child of `model`.
        parent_name, _, child_name = name.rpartition('.')
        parent = model
        if parent_name:
            for part in parent_name.split('.'):
                parent = getattr(parent, part)

        setattr(parent, child_name, LoRABertLinear(module, r, alpha, scale_factor, dropout))

# -------------------------------
# 4️⃣ Regularization Loss
# -------------------------------
def grad_regularization_bert(model, logits, labels):
    """
    Compute the gradient regularization loss.

    Sums the squared stored gradients ``grad_A_sample`` / ``grad_B_sample`` of
    every LoRABertLinear module, restricted to the samples the model currently
    classifies correctly, and divides by the number of correct samples.
    Modules whose stored gradients are missing, or are not
    ``(batch, ...)`` tensors matching the current batch size, are skipped.

    NOTE(review): when called from train() this runs before the current
    backward, so the stored gradients come from the previous step and are not
    part of the current autograd graph. The term therefore changes the
    reported loss value but contributes no gradient. Confirm this is intended.

    Args:
        model (nn.Module): The model.
        logits (torch.Tensor): Output logits from the model [Batch, NumClasses].
        labels (torch.Tensor): Ground truth labels [Batch].

    Returns:
        torch.Tensor: The scalar (0-dim) regularization loss, always a tensor
        on ``logits.device`` (zero when nothing contributes).
    """
    correct_mask = logits.argmax(dim=-1) == labels
    count = int(correct_mask.sum().item())
    reg_loss = torch.zeros((), device=logits.device)
    if count == 0:
        return reg_loss

    batch_size = correct_mask.shape[0]
    for module in model.modules():
        if not isinstance(module, LoRABertLinear):
            continue
        grad_A = module.grad_A_sample
        grad_B = module.grad_B_sample
        if grad_A is None or grad_B is None:
            continue
        # Only per-sample tensors with a matching batch dim can be masked
        # (e.g. a smaller last batch, or whole-parameter gradients, are skipped).
        if grad_A.dim() != 3 or grad_B.dim() != 3:
            continue
        if grad_A.shape[0] != batch_size or grad_B.shape[0] != batch_size:
            continue

        # Boolean-index the batch dim; accumulate in fp32 to avoid fp16 overflow.
        reg_loss = reg_loss + grad_A[correct_mask].float().pow(2).sum()
        reg_loss = reg_loss + grad_B[correct_mask].float().pow(2).sum()

    return reg_loss / count

In [None]:
# Source: _datasets/mrpc.py
from datasets import load_dataset
import torch

def get_dataset_mrpc(split, tokenizer, max_length=128):
    """
    Fetch one split of GLUE MRPC and tokenize its sentence pairs.

    Args:
        split (str): Split name: 'train', 'validation' or 'test'.
        tokenizer (PreTrainedTokenizer): Applied to each (sentence1, sentence2) pair.
        max_length (int): Every encoded pair is padded/truncated to this length.

    Returns:
        Dataset: The split, formatted as torch tensors and exposing only the
        'input_ids', 'attention_mask' and 'label' columns.
    """
    raw_split = load_dataset('glue', 'mrpc', split=split)

    def _encode_pairs(batch):
        # Fixed-length padding gives every example the same shape.
        return tokenizer(
            batch['sentence1'],
            batch['sentence2'],
            padding='max_length',
            truncation=True,
            max_length=max_length,
        )

    encoded = raw_split.map(_encode_pairs, batched=True)
    encoded.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
    return encoded


# Source: _datasets/sts_b.py
from datasets import load_dataset
import torch

def get_dataset_stsb(split, tokenizer, max_length=128):
    """
    Fetch one split of GLUE STS-B and tokenize its sentence pairs.

    Args:
        split (str): Split name: 'train', 'validation' or 'test'.
        tokenizer (PreTrainedTokenizer): Applied to each (sentence1, sentence2) pair.
        max_length (int): Every encoded pair is padded/truncated to this length.

    Returns:
        Dataset: The split, formatted as torch tensors and exposing only the
        'input_ids', 'attention_mask' and 'label' columns. Labels are cast
        to Python floats because STS-B is a regression task.
    """
    raw_split = load_dataset('glue', 'stsb', split=split)

    def _encode_pairs(batch):
        # Fixed-length padding gives every example the same shape.
        return tokenizer(
            batch['sentence1'],
            batch['sentence2'],
            padding='max_length',
            truncation=True,
            max_length=max_length,
        )

    def _label_as_float(example):
        # Regression target: similarity score as a float.
        return {'label': float(example['label'])}

    encoded = raw_split.map(_encode_pairs, batched=True)
    encoded = encoded.map(_label_as_float)
    encoded.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
    return encoded

In [None]:
# Source: models/bert_lora.py
from transformers import BertForSequenceClassification
from lora_utils.modeling import inject_lora_bert

def build_model_bert(model_name="bert-base-uncased", num_labels=1, r=4, alpha=1.0, scale_factor=1.0, dropout=0.1):
    """
    Build a BERT model with LoRA layers injected.

    The pretrained encoder (``model.bert``) is frozen; the injected LoRA
    matrices and the classification head remain trainable.

    Args:
        model_name (str): Name of the pre-trained BERT model.
        num_labels (int): Number of output labels.
        r (int): LoRA rank.
        alpha (float): LoRA alpha.
        scale_factor (float): Scale factor for weighted gradients.
        dropout (float): Dropout probability.

    Returns:
        nn.Module: The modified BERT model with LoRA layers.
    """
    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

    # Freeze the pretrained encoder BEFORE injecting LoRA. The LoRA layers are
    # attached inside model.bert, so freezing afterwards would also freeze
    # lora_A / lora_B and leave only the classifier trainable.
    for param in model.bert.parameters():
        param.requires_grad = False

    inject_lora_bert(model, r=r, alpha=alpha, scale_factor=scale_factor, dropout=dropout)
    return model


# Source: models/roberta_lora.py
from transformers import RobertaForSequenceClassification
from lora_utils.modeling import inject_lora_bert

def build_model_roberta(model_name="roberta-base", num_labels=1, r=4, alpha=1.0, scale_factor=1.0, dropout=0.1):
    """
    Build a RoBERTa model with LoRA layers injected.

    The pretrained encoder (``model.roberta``) is frozen; the injected LoRA
    matrices and the classification head remain trainable.

    Args:
        model_name (str): Name of the pre-trained RoBERTa model.
        num_labels (int): Number of output labels.
        r (int): LoRA rank.
        alpha (float): LoRA alpha.
        scale_factor (float): Scale factor for weighted gradients.
        dropout (float): Dropout probability.

    Returns:
        nn.Module: The modified RoBERTa model with LoRA layers.
    """
    model = RobertaForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

    # Freeze the pretrained encoder BEFORE injecting LoRA. The LoRA layers are
    # attached inside model.roberta, so freezing afterwards would also freeze
    # lora_A / lora_B and leave only the classifier trainable.
    for param in model.roberta.parameters():
        param.requires_grad = False

    inject_lora_bert(model, r=r, alpha=alpha, scale_factor=scale_factor, dropout=dropout)
    return model


# Source: models/distilbert_lora.py
from transformers import DistilBertForSequenceClassification
from lora_utils.modeling import inject_lora_bert

def build_model_distilbert(model_name="distilbert-base-uncased", num_labels=1, r=4, alpha=1.0, scale_factor=1.0, dropout=0.1):
    """
    Build a DistilBERT model with LoRA layers injected.

    The pretrained encoder (``model.distilbert``) is frozen; the injected LoRA
    matrices and the classification head remain trainable.

    Args:
        model_name (str): Name of the pre-trained DistilBERT model.
        num_labels (int): Number of output labels.
        r (int): LoRA rank.
        alpha (float): LoRA alpha.
        scale_factor (float): Scale factor for weighted gradients.
        dropout (float): Dropout probability.

    Returns:
        nn.Module: The modified DistilBERT model with LoRA layers.
    """
    model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

    # Freeze the pretrained encoder BEFORE injecting LoRA. The LoRA layers are
    # attached inside model.distilbert, so freezing afterwards would also freeze
    # lora_A / lora_B and leave only the classifier trainable.
    for param in model.distilbert.parameters():
        param.requires_grad = False

    inject_lora_bert(model, r=r, alpha=alpha, scale_factor=scale_factor, dropout=dropout)
    return model

In [None]:
def evaluate(model, dataloader, device, is_regression=False):
    """
    Evaluate the model on a given dataset.

    Args:
        model (nn.Module): The model to evaluate; its output must expose ``.logits``.
        dataloader (Iterable[dict]): Batches of tensors including a 'label' key;
            every other key is passed to the model as a keyword argument.
        device (str): Device to run the evaluation on ('cuda' or 'cpu').
        is_regression (bool): Whether the task is regression (True) or classification (False).

    Returns:
        float: The evaluation metric (mean MSE over examples for regression,
        accuracy for classification).

    Raises:
        ValueError: If the dataloader yields no examples.
    """
    model.eval()
    criterion = nn.MSELoss()
    total_sq_error = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'label'}
            labels = batch['label'].to(device)

            logits = model(**inputs).logits
            batch_size = labels.size(0)

            if is_regression:
                # squeeze(-1), not squeeze(): a batch of one example must stay 1-D.
                logits = logits.squeeze(-1)
                total_sq_error += criterion(logits, labels).item() * batch_size
            else:
                correct += (logits.argmax(dim=-1) == labels).sum().item()

            total += batch_size

    if total == 0:
        raise ValueError("evaluate() received a dataloader with no examples")
    return total_sq_error / total if is_regression else correct / total

def train(model, train_loader, val_loader, config, device, is_regression=False):
    """
    Train the model and checkpoint the best validation result.

    Optimizes only parameters with ``requires_grad=True`` using AdamW with a
    per-epoch cosine schedule. Mixed precision (autocast + GradScaler) is used
    only when running on CUDA. For classification, the loss adds
    ``config['lambda_reg'] * grad_regularization_bert(...)``. After each epoch
    the model is evaluated, and the state dict is saved to
    ``results/best_model_<model_name>_<dataset_name>.pt`` whenever the
    validation metric improves.

    Args:
        model (nn.Module): The model to train.
        train_loader (DataLoader): DataLoader for the training dataset.
        val_loader (DataLoader): DataLoader for the validation dataset.
        config (dict): Configuration dictionary. Reads 'learning_rate',
            'max_epochs', 'lambda_reg', 'model_name', 'dataset_name' and,
            optionally, 'unfreeze_layers_after' / 'unfreeze_layers_count'.
        device (str): Device to run the training on.
        is_regression (bool): Whether the task is regression.

    Returns:
        float: The best metric achieved on the validation set (lowest MSE for
        regression, highest accuracy for classification).
    """
    model.to(device)
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=float(config['learning_rate']))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['max_epochs'])

    # torch.cuda.amp only acts on CUDA; enable it explicitly only there instead of
    # relying on it warning and disabling itself on CPU.
    use_amp = str(device).startswith('cuda') and torch.cuda.is_available()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    criterion = nn.MSELoss() if is_regression else nn.CrossEntropyLoss()
    lambda_reg = float(config['lambda_reg'])

    best_metric = float('inf') if is_regression else 0.0

    for epoch in range(config['max_epochs']):
        model.train()

        # Optional: Unfreeze layers
        if epoch == config.get('unfreeze_layers_after', 999):
            print(f"Unfreezing last {config.get('unfreeze_layers_count', 2)} layers...")
            # NOTE(review): no parameters are actually unfrozen here yet.

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['max_epochs']}")
        for batch in pbar:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'label'}
            labels = batch['label'].to(device)

            optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(**inputs)
                # squeeze(-1), not squeeze(): a batch of one example must stay 1-D.
                logits = outputs.logits.squeeze(-1) if is_regression else outputs.logits

                loss_task = criterion(logits, labels)

                # Gradient regularization (only for classification for now)
                loss_grad = torch.tensor(0., device=device)
                if not is_regression:
                    loss_grad = grad_regularization_bert(model, outputs.logits, labels)

                loss = loss_task + lambda_reg * loss_grad

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            pbar.set_postfix({'loss': loss.item()})

        scheduler.step()

        # Validation
        metric = evaluate(model, val_loader, device, is_regression)
        print(f"Validation {'MSE' if is_regression else 'Acc'}: {metric:.4f}")

        # Save best model if metric improves (lower MSE / higher accuracy).
        improved = metric < best_metric if is_regression else metric > best_metric
        if improved:
            best_metric = metric
            torch.save(model.state_dict(), f"results/best_model_{config['model_name']}_{config['dataset_name']}.pt")

    return best_metric

In [None]:
# Define experiments mapping
# Each entry: (Hugging Face checkpoint name, dataset key, model builder).
experiments_map = [
    ('bert-base-uncased', 'sts_b', build_model_bert),
    ('bert-base-uncased', 'mrpc', build_model_bert),
    ('roberta-base', 'sts_b', build_model_roberta),
    ('roberta-base', 'mrpc', build_model_roberta),
    ('distilbert-base-uncased', 'sts_b', build_model_distilbert),
    ('distilbert-base-uncased', 'mrpc', build_model_distilbert),
]

results = []
results_path = 'results/experiment_results_notebook.json'

print("Starting Experiments...")

for model_name, dataset_name, build_fn in experiments_map:
    print(f"\n🚀 Running {model_name} on {dataset_name}")

    # Re-seed per run so each (model, dataset) pair is reproducible on its own.
    # LoRA init, dropout and DataLoader shuffling all draw from torch's RNG.
    torch.manual_seed(config['seed'])

    # Recorded in config because train() uses them to name the checkpoint file.
    config['model_name'] = model_name
    config['dataset_name'] = dataset_name

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load Data: STS-B is regression (1 output), MRPC is binary classification.
    if dataset_name == 'sts_b':
        is_regression = True
        num_labels = 1
        train_data = get_dataset_stsb('train', tokenizer)
        val_data = get_dataset_stsb('validation', tokenizer)
    else:
        is_regression = False
        num_labels = 2
        train_data = get_dataset_mrpc('train', tokenizer)
        val_data = get_dataset_mrpc('validation', tokenizer)

    train_loader = DataLoader(train_data, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_data, batch_size=config['batch_size'])

    # Build Model
    model = build_fn(
        model_name=model_name,
        num_labels=num_labels,
        r=config['lora_rank'],
        alpha=config['lora_alpha'],
        scale_factor=config['scale_factor'],
        dropout=config['dropout']
    )

    # Train
    metric = train(model, train_loader, val_loader, config, device, is_regression)

    results.append({
        'model': model_name,
        'dataset': dataset_name,
        'metric': metric,
        'type': 'MSE' if is_regression else 'Accuracy'
    })

    # Persist after every run so finished results survive a failure later in the sweep.
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)

    # Clean up to save memory
    del model
    torch.cuda.empty_cache()

print("\n✅ All experiments completed!")