# Lab 6 – Pruning an LLM with Unsloth (SST-2)

> **⚠️ IMPORTANT**: This lab requires **Google Colab with GPU enabled**
> - Go to Runtime → Change runtime type → GPU (T4 or better)
> - Unsloth requires an NVIDIA GPU with CUDA, so it will not run on a Mac or a CPU-only machine
> - See `COLAB_SETUP.md` for detailed setup instructions

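Pruning removes redundant weights (unstructured pruning) or whole neurons and channels (structured pruning) from a neural network to reduce its size and inference time. Unstructured pruning only zeroes individual entries and leaves tensor shapes unchanged, so its size and speed gains materialize only with sparse storage formats or sparsity-aware kernels; structured pruning shrinks the weight matrices themselves. In this lab, you'll experiment with both structured and unstructured pruning on a sentiment-classification task using the SST-2 dataset.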

## Why Prune? The Trade-offs

**Benefits of Pruning:**
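- 🚀 **Faster Inference**: Fewer parameters = less computation, if the hardware or kernels can actually skip pruned weights
- 💾 **Memory Savings**: Smaller model size = less RAM/VRAM usage, once pruned weights are physically removed or stored sparsely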
- 📱 **Deployment**: Easier to deploy on edge devices
- ⚡ **Energy Efficiency**: Less computation = lower power consumption

**Trade-offs:**
- 📉 **Accuracy Loss**: Removing parameters can hurt performance
- 🔧 **Tuning Required**: Finding the right sparsity level is crucial
- ⚖️ **Balance**: More pruning = more speed, but potentially more accuracy loss
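How much memory unstructured pruning saves depends on how the weights are stored. A pruned dense tensor still stores every zero, so its byte count is unchanged; a sparse format such as COO stores only the nonzeros but pays for one index per dimension on each of them. The sketch below uses a random 512×512 float32 matrix as a stand-in for one layer's weights (an assumption, not the lab's model) and compares the two storage costs at a few sparsity levels with plain PyTorch.

```python
import torch

def dense_bytes(t: torch.Tensor) -> int:
    # Dense storage: every entry costs element_size bytes, zero or not
    return t.numel() * t.element_size()

def coo_bytes(t: torch.Tensor) -> int:
    # COO storage: each nonzero stores its value plus one int64 index per dimension
    s = t.to_sparse().coalesce()
    return (s.values().numel() * s.values().element_size()
            + s.indices().numel() * s.indices().element_size())

torch.manual_seed(0)
w = torch.randn(512, 512)  # stand-in for one float32 weight matrix

results = {}
for sparsity in (0.2, 0.5, 0.9):
    k = int(sparsity * w.numel())
    # Magnitude pruning: zero exactly the k entries with the smallest |w|
    smallest = w.abs().flatten().argsort()[:k]
    pruned = w.clone().flatten()
    pruned[smallest] = 0.0
    pruned = pruned.view_as(w)
    results[sparsity] = (dense_bytes(pruned), coo_bytes(pruned))
    print(f"sparsity={sparsity:.0%}  dense={results[sparsity][0]:,} B  coo={results[sparsity][1]:,} B")
```

With float32 values and two int64 indices per nonzero, COO costs 20 bytes per stored weight versus 4 bytes per dense entry, so in this setting it only saves memory above 80% sparsity. Structured pruning avoids the index overhead by removing whole rows or channels.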

## Objectives

- Fine-tune a model for sentiment analysis on the SST-2 dataset
- **Evaluate baseline performance** before pruning (accuracy, speed, memory)
- Apply pruning techniques to remove unnecessary parameters
- **Compare performance** after pruning (accuracy vs. speed trade-offs)
- Measure sparsity, model size reduction, and changes in inference speed and accuracy
- **Analyze the trade-offs**: How much accuracy do we lose for how much speed gain?

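Unsloth is used here to load and fine-tune the model; the pruning itself uses PyTorch's pruning utilities in `torch.nn.utils.prune` (e.g., `l1_unstructured` for individual weights, `ln_structured` for whole rows/neurons). Adjust the `amount` hyperparameter to explore different sparsity levels.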
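A minimal sketch of what `torch.nn.utils.prune` does under the hood, shown on a toy `nn.Linear` rather than the lab's model. `l1_unstructured` does not delete anything: it renames the original tensor to `weight_orig`, registers a 0/1 `weight_mask` buffer, and recomputes `weight = weight_orig * weight_mask` before each forward pass. `prune.remove` then bakes the mask into a regular `weight` parameter.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(16, 8)  # weight has 8 * 16 = 128 entries

# Zero the 25% of weights (32 of 128) with the smallest absolute value
prune.l1_unstructured(layer, name="weight", amount=0.25)

param_names = [n for n, _ in layer.named_parameters()]  # 'weight_orig' replaces 'weight'
buffer_names = [n for n, _ in layer.named_buffers()]    # 'weight_mask' holds the 0/1 mask
zeros_before_remove = int((layer.weight == 0).sum())

# Make the pruning permanent: 'weight' becomes a regular parameter again
prune.remove(layer, "weight")
zeros_after_remove = int((layer.weight == 0).sum())

print(param_names, buffer_names, zeros_before_remove, zeros_after_remove)
```

Until `prune.remove` is called, the layer carries both the original tensor and the mask, so pruning temporarily increases memory use rather than reducing it.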

```python
# Install Unsloth using the official auto-install script
# This automatically detects your environment and installs the correct version
!wget -qO- https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/_auto_install.py | python -

# Alternative manual installation if auto-install fails:
# !pip install --upgrade pip
# !pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# !pip install "unsloth_zoo @ git+https://github.com/unslothai/unsloth-zoo.git"

print("✅ Unsloth installation complete! Now restart runtime before proceeding.")
print("⚠️ IMPORTANT: Use GPU runtime, not TPU! Unsloth requires CUDA GPU.")
```

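```python
# 1️⃣ Load SST-2 dataset

# Import Unsloth before transformers so its patches take effect
import unsloth  # noqa: F401
from datasets import load_dataset
from transformers import AutoTokenizer

# Load small (5%) subsets of the SST-2 splits to keep the lab fast
train_data = load_dataset('glue', 'sst2', split='train[:5%]')
val_data = load_dataset('glue', 'sst2', split='validation[:5%]')

print(train_data[0])

# Initialize tokenizer
base_model_name = "unsloth/Qwen2.5-7B-Instruct-bnb-4bit"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# The classifier reads the logits at the LAST position, so pad on the left:
# with right padding, position -1 would be a pad token for most sentences.
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

max_length = 128

def tokenize_function(examples):
    return tokenizer(examples['sentence'], padding='max_length', truncation=True, max_length=max_length)

train_dataset = train_data.map(tokenize_function, batched=True)
val_dataset = val_data.map(tokenize_function, batched=True)

print("Tokenized SST-2 dataset ready.")
```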
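The training and evaluation cells read class scores from `outputs.logits[:, -1, :2]`, i.e. the last position of each padded sequence. With right padding that position is a pad token for every sentence shorter than `max_length`, which is why the tokenizer's padding side matters. A padding-agnostic alternative is to locate the last real token from the attention mask; `last_token_index` and `last_token_logits` below are hypothetical helpers sketched for illustration, not library functions.

```python
import torch

def last_token_index(attention_mask: torch.Tensor) -> torch.Tensor:
    """Position of the last non-pad token in each row; works for left or right padding."""
    positions = torch.arange(attention_mask.size(1), device=attention_mask.device)
    # Pad positions are multiplied by 0, so the max is the last real token's position
    return (positions * attention_mask).max(dim=1).values

def last_token_logits(logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Select logits[b, last_real_position(b), :] for every row b."""
    rows = torch.arange(logits.size(0), device=logits.device)
    return logits[rows, last_token_index(attention_mask)]

right = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # right-padded batch
left = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])   # same lengths, left-padded
print(last_token_index(right).tolist(), last_token_index(left).tolist())  # [2, 4] [4, 4]

logits = torch.randn(2, 5, 10)             # (batch, seq, vocab) stand-in
picked = last_token_logits(logits, right)  # shape (2, 10)
class_scores = picked[:, :2]               # same "vocab ids 0/1" trick as the lab
```

In the training loop, this would replace `outputs.logits[:, -1, :2]` with `last_token_logits(outputs.logits, attention_mask)[:, :2]`.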


```python
# 2️⃣ Fine-tune a sentiment classifier on SST-2

import os
# Unsloth may skip materializing logits during training to save memory;
# this loop computes its own loss from the logits, so ask for them explicitly.
os.environ["UNSLOTH_RETURN_LOGITS"] = "1"

# Import unsloth before transformers/peft so its patches are applied
from unsloth import FastLanguageModel
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Load the 4-bit base model (the tokenizer from the previous cell is reused)
model, _ = FastLanguageModel.from_pretrained(
    model_name=base_model_name,
    max_seq_length=max_length,
    dtype=torch.float16,  # T4 GPUs have no native bfloat16 support
    load_in_4bit=True,
)

# Attach LoRA adapters through Unsloth's wrapper, which also prepares the
# 4-bit model for training and enables gradient checkpointing
model = FastLanguageModel.get_peft_model(
    model,
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    use_gradient_checkpointing="unsloth",
)
model.print_trainable_parameters()

def collate_fn(batch):
    """Stack pre-tokenized examples into tensors."""
    return {
        'input_ids': torch.tensor([item['input_ids'] for item in batch]),
        'attention_mask': torch.tensor([item['attention_mask'] for item in batch]),
        'labels': torch.tensor([item['label'] for item in batch]),
    }

batch_size = 8
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Only the LoRA adapters are trainable; LoRA usually needs a higher LR than full fine-tuning
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=2e-4)
num_epochs = 2
max_batches = 51  # cap per epoch to keep the demo short

print("🔄 Fine-tuning for sentiment classification...")
print(f"Epochs: {num_epochs}, Batch size: {batch_size}")

model.config.use_cache = False  # the KV cache is only useful for generation
model.train()

for epoch in range(num_epochs):
    epoch_loss, num_batches = 0.0, 0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")

    for batch_idx, batch in enumerate(progress_bar):
        if batch_idx >= max_batches:
            break
        input_ids = batch['input_ids'].to(model.device)
        attention_mask = batch['attention_mask'].to(model.device)
        labels = batch['labels'].to(model.device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        # Simplified classifier (no separate head): the logits of vocabulary ids 0 and 1
        # at the last position serve as the two class scores (relies on left padding)
        logits = outputs.logits[:, -1, :2].float()
        loss = torch.nn.functional.cross_entropy(logits, labels)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_loss += loss.item()
        num_batches += 1
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

    print(f"Epoch {epoch + 1} completed. Average loss: {epoch_loss / num_batches:.4f}")

print("✓ Fine-tuning complete!")
```
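The objectives call for a baseline measurement before pruning, but the notebook only evaluates after pruning. One way to compare both is to wrap the evaluation in a function and call it once before the pruning cell and once after. `evaluate_classifier` below is a hypothetical helper, assuming a model whose output has a `.logits` tensor of shape `(batch, seq, vocab)` and the batch dictionaries that `collate_fn` produces; it is exercised here on a tiny stand-in model rather than the 7B LLM.

```python
import time
from types import SimpleNamespace

import torch

@torch.no_grad()
def evaluate_classifier(model, dataloader, device, max_batches=None):
    """Return (accuracy %, samples/sec) using last-position logits of vocab ids 0/1 as class scores."""
    model.eval()  # note: leaves the model in eval mode
    correct, total, elapsed = 0, 0, 0.0
    for batch_idx, batch in enumerate(dataloader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        if device.type == "cuda":
            torch.cuda.synchronize()  # finish pending GPU work before starting the clock
        start = time.perf_counter()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        if device.type == "cuda":
            torch.cuda.synchronize()  # wait until the forward pass has actually finished
        elapsed += time.perf_counter() - start
        preds = outputs.logits[:, -1, :2].argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    throughput = total / elapsed if elapsed > 0 else float("inf")
    return 100.0 * correct / total, throughput

class ToyLM(torch.nn.Module):
    """Stand-in for the LLM: maps each token id straight to a 4-entry logit vector."""
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(10, 4)

    def forward(self, input_ids, attention_mask):
        return SimpleNamespace(logits=self.embed(input_ids))

toy = ToyLM()
with torch.no_grad():
    toy.embed.weight.zero_()
    toy.embed.weight[0, 0] = 1.0  # last token 0 -> predicts class 0
    toy.embed.weight[1, 1] = 1.0  # last token 1 -> predicts class 1

batch = {
    "input_ids": torch.tensor([[5, 0], [5, 1], [5, 1]]),
    "attention_mask": torch.ones(3, 2, dtype=torch.long),
    "labels": torch.tensor([0, 1, 0]),
}
acc, throughput = evaluate_classifier(toy, [batch], torch.device("cpu"))
print(f"accuracy={acc:.1f}%")  # 2 of 3 correct -> 66.7%
```

In the lab, a call such as `evaluate_classifier(model, val_dataloader, model.device, max_batches=21)` before the pruning cell would give the baseline; since the helper switches to eval mode, call `model.train()` again before any further training.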

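```python
# 3️⃣ Apply pruning to the fine-tuned model

import time  # used for timing in the evaluation cell
import torch
import torch.nn.utils.prune as prune

print("🔪 Applying pruning to the fine-tuned model...")

pruning_amount = 0.2  # zero 20% of the weights in each pruned layer
pruned_modules = []

for name, module in model.named_modules():
    # Skip non-linear modules and lm_head (it produces the class scores and is a very large matrix)
    if not isinstance(module, torch.nn.Linear) or name.endswith("lm_head"):
        continue
    # bitsandbytes 4-bit layers subclass nn.Linear but hold packed integer weights;
    # L1 magnitudes of those bytes are meaningless and masking them would break the
    # quantized layer, so only floating-point weights are pruned. With this 4-bit
    # model that is typically just the LoRA adapter matrices (lora_A / lora_B).
    if not module.weight.is_floating_point():
        continue
    # L1 unstructured (magnitude-based) pruning: zero the entries with the smallest |w|
    prune.l1_unstructured(module, name='weight', amount=pruning_amount)
    pruned_modules.append(name)

print(f"✅ Pruning complete! Pruned {len(pruned_modules)} modules, "
      f"{pruning_amount*100:.1f}% of weights each")

# Make pruning permanent: fold weight_orig * weight_mask back into a plain `weight`
for name in pruned_modules:
    prune.remove(model.get_submodule(name), 'weight')

print("✓ Pruning made permanent")
```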
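The pruning cell above is unstructured only. For the structured variant mentioned in the introduction, `prune.ln_structured` zeroes whole rows (output neurons) of a weight matrix, ranked by their L_n norm. Because entire rows are zero, they can be sliced out physically, which is what actually shrinks the matrix multiply. The sketch below uses a toy `nn.Linear`; applying it to the lab's model would require unquantized weights and also shrinking the next layer's input dimension, which is not shown.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(16, 8)

# Zero 50% of the output rows (dim=0), ranked by their L2 norm (n=2)
prune.ln_structured(layer, name="weight", amount=0.5, n=2, dim=0)
prune.remove(layer, "weight")

# Rows that are entirely zero are the pruned neurons
kept_rows = layer.weight.abs().sum(dim=1) != 0

# Physically remove them: a smaller Linear holding only the surviving outputs.
# (ln_structured does not touch the bias, so pruned neurons still output their bias.)
compact = nn.Linear(layer.in_features, int(kept_rows.sum()))
with torch.no_grad():
    compact.weight.copy_(layer.weight[kept_rows])
    compact.bias.copy_(layer.bias[kept_rows])

x = torch.randn(4, 16)
# The compact layer reproduces the surviving outputs of the pruned layer
assert torch.allclose(compact(x), layer(x)[:, kept_rows])
print(tuple(layer.weight.shape), "->", tuple(compact.weight.shape))
```

Unlike the unstructured case, the compact layer really has half as many weights and does half the work, with no sparse kernels required.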

```python
# 4️⃣ Evaluate pruned model performance

import time
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

def compute_sparsity(model):
    """Return (sparsity %, total, zeros) over trainable parameters (here: the LoRA adapters)."""
    total_params = 0
    zero_params = 0
    for name, param in model.named_parameters():
        if param.requires_grad:  # frozen 4-bit weights are packed bytes, so skip them
            total_params += param.numel()
            zero_params += (param == 0).sum().item()
    sparsity = (zero_params / total_params * 100) if total_params > 0 else 0.0
    return sparsity, total_params, zero_params

print("\n📊 Evaluating pruned model...\n")

sparsity_pct, total_params, zero_params = compute_sparsity(model)

print("🔍 Sparsity Analysis (trainable parameters only):")
print(f"  - Total trainable parameters: {total_params:,}")
print(f"  - Zero parameters: {zero_params:,}")
print(f"  - Sparsity: {sparsity_pct:.2f}%")

# Evaluate on the validation set
model.eval()
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

correct = 0
total = 0
inference_times = []
max_eval_batches = 21  # cap for the demo

print("\n🎯 Evaluating accuracy...")

with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(val_dataloader, desc="Validating")):
        if batch_idx >= max_eval_batches:
            break
        input_ids = batch['input_ids'].to(model.device)
        attention_mask = batch['attention_mask'].to(model.device)
        labels = batch['labels'].to(model.device)

        # CUDA kernels run asynchronously: synchronize so the timer covers the whole forward pass
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        inference_times.append(time.perf_counter() - start_time)

        # Same class scores as in training: vocab ids 0/1 at the last position
        logits = outputs.logits[:, -1, :2]
        predictions = torch.argmax(logits, dim=-1)

        correct += (predictions == labels).sum().item()
        total += labels.size(0)  # the last batch may be smaller than 8

accuracy = correct / total * 100
total_time = sum(inference_times)
avg_inference_time = total_time / len(inference_times)

print("\n📈 Performance Metrics:")
print(f"  - Accuracy: {accuracy:.2f}% ({correct}/{total} correct)")
print(f"  - Average inference time: {avg_inference_time*1000:.2f}ms per batch")
print(f"  - Samples/second: {total/total_time:.1f}")

print("\n✓ Evaluation complete!")
print("\n💡 Tip: Increase pruning amount (e.g., 0.3, 0.5) to see greater sparsity")
print("   but watch for accuracy degradation. Experiment with different pruning methods!")
```
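One more pruning method worth trying when exploring sparsity levels is `prune.global_unstructured`: instead of removing the same fraction from every layer, it ranks all listed weights together, so layers with many small weights lose more. The sketch below, on a toy two-layer network rather than the lab's model, sweeps a few amounts on fresh copies and reports per-layer and overall sparsity.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def sparsity(t: torch.Tensor) -> float:
    """Fraction of exactly-zero entries in a tensor."""
    return (t == 0).float().mean().item()

torch.manual_seed(0)
base = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5))  # 100 + 50 weights

results = {}
for amount in (0.2, 0.5):
    net = copy.deepcopy(base)  # prune a fresh copy at each level
    targets = [(net[0], "weight"), (net[2], "weight")]
    # Rank all 150 weights together and zero the smallest `amount` fraction globally
    prune.global_unstructured(targets, pruning_method=prune.L1Unstructured, amount=amount)
    for module, name in targets:
        prune.remove(module, name)
    zeros = sum(int((m.weight == 0).sum()) for m, _ in targets)
    total = sum(m.weight.numel() for m, _ in targets)
    results[amount] = zeros / total
    print(f"amount={amount}: layer0={sparsity(net[0].weight):.2f} "
          f"layer2={sparsity(net[2].weight):.2f} overall={results[amount]:.2f}")
```

The overall sparsity matches `amount` exactly, while the per-layer values can differ from it; comparing those per-layer numbers is one way to see which layers the magnitude criterion considers least important.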


## Reflection

- What sparsity levels did you achieve with different pruning configurations (e.g., 20%, 50%)?
- How did pruning affect accuracy and inference latency? Did structured pruning behave differently from unstructured pruning?
- Discuss how pruning, combined with quantization or distillation, could make LLMs more viable for deployment on resource-constrained devices.
