# Pruning LLMs: Structured MLP Pruning and Evaluation

Welcome to the pruning basics notebook. This notebook walks you through a practical, end‑to‑end workflow for pruning a small LLM (TinyLlama) and measuring the impact on speed, memory, and quality.

## What You’ll Learn
- Pruning concepts: how removing neurons/weights can shrink and speed up models.
- Structured MLP pruning: zeroing entire rows/columns across `gate_proj`, `up_proj`, and `down_proj`.
- Rebuilding layers: replacing pruned modules with smaller `nn.Linear` layers, so that zeroed neurons are physically removed rather than just masked.
- Evaluation metrics:
  - Latency (s) and Throughput (tokens/s) for generation.
  - Peak GPU memory (MiB) during a forward pass.
  - Perplexity (lower is better) as a proxy for model quality.
- Practical trade‑offs: how pruning fraction affects performance vs. quality.

## Notebook Flow
1. Baseline: load model/tokenizer and measure latency, throughput, peak memory, and perplexity.
2. Prune: apply structured pruning to MLP layers on CPU to avoid OOM.
3. Rebuild: reconstruct smaller MLP layers to realize speed/memory benefits.
4. Re‑measure: run the same metrics to compare against baseline.
5. Save: write the rebuilt model to disk and report size.

## Key Functions in This Notebook
- `measure_latency_and_throughput(...)`: generation benchmarking.
- `measure_peak_mem_and_perplexity(...)`: forward pass memory + loss→perplexity.
- `prune_mlp_rows_and_cols(...)`: apply structured pruning masks on MLP.
- `rebuild_mlp_blocks(...)`: build smaller dense layers using kept indices.
- `start()`: ties everything together and prints results.

## Metrics at a Glance
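- Latency: wall-clock time of one `generate` call producing up to `MAX_NEW_TOKENS` new tokens, averaged over runs.
- Throughput: tokens per second = generated_tokens / latency.
- Peak GPU Memory: maximum memory allocated by PyTorch during a forward pass (MiB). This includes the resident model weights, not just activations.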
- Perplexity: `exp(loss)`; lower suggests better next‑token prediction.
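The two derived metrics reduce to one-line formulas. Below is a minimal sketch in plain Python. The helper names are hypothetical and not part of the notebook, and the example assumes a run generated the full 50 tokens (generation can stop earlier at an end-of-sequence token).

```python
import math


def throughput_tok_s(generated_tokens: int, latency_s: float) -> float:
    """Tokens per second for one generation run (NaN if latency is not positive)."""
    return generated_tokens / latency_s if latency_s > 0 else float("nan")


def perplexity_from_loss(mean_nll: float) -> float:
    """Perplexity from a mean per-token negative log-likelihood in nats (e.g. HF `outputs.loss`)."""
    return math.exp(mean_nll)


# 50 generated tokens in 1.186 s
print(f"{throughput_tok_s(50, 1.186):.1f} tok/s")  # 42.2 tok/s
# A mean loss of ln(4.555) nats corresponds to a perplexity of 4.555
print(f"{perplexity_from_loss(math.log(4.555)):.3f}")  # 4.555
```

The notebook averages per-run throughputs instead of dividing by the mean latency, so its printed figure can differ slightly from this ratio.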

## Configuration & Reproducibility
- Model: `MODEL_NAME` (TinyLlama by default for quick runs).
- Pruning strength: `MLP_PRUNE_FRAC` (e.g., 0.5 → keep ~50% of MLP neurons).
- Generation length: `MAX_NEW_TOKENS`.
- Device: auto-selects CUDA if available, else CPU. FP16 is intended for GPU. On CPU, FP32 is usually the safer choice because half-precision kernels are slow or missing for some operations.

## Requirements & Notes
- Libraries: `torch`, `transformers`. A CUDA GPU is recommended but not required.
- Internet: Hugging Face Hub downloads the model on first use.
- Determinism: results can vary slightly across runs/hardware; consider seeding for tighter reproducibility.
- Trade‑off: higher pruning fractions tend to boost speed/memory savings but may degrade perplexity/quality.
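Here is a minimal seeding sketch, assuming only the standard library and PyTorch (the helper name is hypothetical). Seeding fixes random draws, which matters if generation uses sampling. It does not make latency measurements stable.

```python
import random

import torch


def seed_everything(seed: int = 0) -> None:
    """Seed Python's and PyTorch's random number generators."""
    random.seed(seed)
    torch.manual_seed(seed)  # seeds the default generators on all devices, CPU and CUDA


seed_everything(0)
first = torch.rand(3)
seed_everything(0)
second = torch.rand(3)
print(torch.equal(first, second))  # True: same seed, same draws
```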

## How to Use
- Run the notebook top‑to‑bottom, or execute `start()` to perform the full pipeline.
- Tweak `MLP_PRUNE_FRAC`, `MAX_NEW_TOKENS`, or `PROMPT`/`PERP_TEXT` to explore trade‑offs.

Let’s get started.

# Import Required Libraries
This cell:
- Imports standard libraries for file handling, timing, and mathematical operations.
- Imports PyTorch for deep learning operations and pruning utilities.
- Imports Hugging Face Transformers for model and tokenizer handling.

In [2]:
import os
import time
import math
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from transformers import AutoModelForCausalLM, AutoTokenizer

# Define Model and Evaluation Settings
This cell:
- Specifies the model name to be used for pruning and evaluation.
- Defines the fraction of neurons to prune in the MLP layers (`MLP_PRUNE_FRAC`).
- Sets the maximum number of tokens to generate during inference.
- Provides sample texts for benchmarking latency, throughput, and perplexity.

In [1]:
MODEL_NAME     = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
MLP_PRUNE_FRAC = 0.2       # fraction of inner neurons to prune
MAX_NEW_TOKENS = 50
PROMPT = (
    "Over the next decade, sustainable energy solutions will revolutionize "
    "global power grids, reducing carbon footprints and fostering resilient "
    "communities through innovative storage and distribution technologies."
)
PERP_TEXT = (
    "Artificial intelligence (AI) is intelligence demonstrated by machines, in contrast "
    "to the natural intelligence displayed by humans and animals. Leading AI textbooks "
    "define the field as the study of intelligent agents: any system that perceives "
    "its environment and takes actions that maximize its chance of achieving its goals."
)


# Load Model and Tokenizer
This function:
- Loads the tokenizer and model using Hugging Face Transformers.
- Loads the weights in FP16 on GPU for faster inference, and in FP32 on CPU, where half precision is slow or unsupported for some operations.
- Moves the model to the specified device (CPU or GPU).
- Sets the model to evaluation mode, which disables dropout. Gradient tracking is turned off separately by `torch.inference_mode()` in the measurement functions.

In [3]:
def load_model_and_tokenizer(model_name: str, device: torch.device):
    """
      - Load AutoTokenizer.from_pretrained(model_name, use_fast=True)
      - Load AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=...):
        FP16 on CUDA, FP32 on CPU
      - Move model to `device` and set to .eval()
      - Return tokenizer, model
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    # FP16 halves memory and speeds up GPU inference; FP32 is the safer choice on CPU
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype
    )
    # move to device and set to eval (disables dropout; does not disable autograd)
    model = model.to(device)
    model.eval()
    return tokenizer, model

# Measure Baseline Performance
These functions:
- Measure the model's performance. They are used for the baseline here and reused later for the pruned and rebuilt models.
- Evaluate:
  - **Latency**: Time taken to generate tokens for a given prompt.
  - **Throughput**: Tokens generated per second.
  - **Peak GPU Memory Usage**: Maximum memory allocated during a forward pass, including the model weights already resident on the GPU.
  - **Perplexity**: A measure of how well the model predicts the given text.
- Print the baseline metrics for comparison with the pruned model.

In [4]:
def measure_latency_and_throughput(model, tokenizer, prompt: str, device: torch.device, max_new_tokens=50, runs=3):
    """
    Measure latency and throughput for text generation.

    Args:
        model: The language model to evaluate.
        tokenizer: The tokenizer associated with the model.
        prompt (str): Input prompt for text generation.
        device (torch.device): Device to run the model on.
        max_new_tokens (int): Maximum number of new tokens to generate.
        runs (int): Number of runs for averaging metrics.

    Returns:
        tuple: Average latency (seconds) and throughput (tokens per second).
    """
    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(device)

    # Warmup
    with torch.inference_mode():
        _ = model.generate(**inputs, max_new_tokens=8)

    latencies = []
    throughputs = []
    for _ in range(runs):
        if device.type == "cuda":
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        with torch.inference_mode():
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
        if device.type == "cuda":
            torch.cuda.synchronize()
        t1 = time.perf_counter()

        gen_len = outputs.shape[1] - inputs["input_ids"].shape[1]
        latency = t1 - t0
        throughput = gen_len / latency if latency > 0 else float("nan")

        latencies.append(latency)
        throughputs.append(throughput)

    avg_latency = sum(latencies) / len(latencies)
    avg_throughput = sum(throughputs) / len(throughputs)
    return avg_latency, avg_throughput
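
Averaging three runs can hide a lot of variance. In the recorded output further down, the pruned-but-unrebuilt model has exactly the baseline's tensor shapes, yet it measured about 23% slower (1.453 s vs 1.186 s). That suggests run-to-run variation is not negligible at this scale. Below is a minimal alternative sketch that reports the median and the spread. It is not the notebook's method, and the `time_runs` name and its `sync` hook are hypothetical.

```python
import statistics
import time
from typing import Callable, List, Tuple


def time_runs(fn: Callable[[], object], runs: int = 5, warmup: int = 1,
              sync: Callable[[], None] = lambda: None) -> Tuple[float, List[float]]:
    """Return (median seconds, per-run seconds) for repeated calls to `fn`.

    `sync` should be torch.cuda.synchronize on GPU so queued kernels finish
    before the clock is read; the default does nothing (CPU).
    """
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(runs):
        sync()
        t0 = time.perf_counter()
        fn()
        sync()
        times.append(time.perf_counter() - t0)
    return statistics.median(times), times


median_s, all_s = time_runs(lambda: sum(range(100_000)), runs=5)
print(f"median={median_s * 1e3:.2f} ms  min={min(all_s) * 1e3:.2f} ms  max={max(all_s) * 1e3:.2f} ms")
```

In the notebook this could wrap generation, for example `time_runs(lambda: model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS), sync=torch.cuda.synchronize)` inside `torch.inference_mode()`, with more than three runs.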

In [5]:
def measure_peak_mem_and_perplexity(model, tokenizer, perp_text, device):
    """
    Measure peak GPU memory (MiB) during a forward pass and compute perplexity
    on the provided text. Returns (peak_mem_mib, perplexity).
    
    - Tokenizes `perp_text` and sets labels to input_ids to obtain loss.
    - Resets and reads CUDA peak memory stats if on GPU; returns NaN on CPU.
    """
    model.eval()
    
    # Tokenize and set labels for LM loss
    enc = tokenizer(perp_text, return_tensors="pt")
    enc["labels"] = enc["input_ids"].clone()
    enc = {k: v.to(device) for k, v in enc.items()}
    
    # Prepare CUDA memory measurement if available
    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.synchronize()
    
    # Forward pass to get loss (no gradients)
    with torch.inference_mode():
        outputs = model(**enc)
    
    # Read peak memory
    if device.type == "cuda":
        torch.cuda.synchronize()
        peak_bytes = torch.cuda.max_memory_allocated(device)
        peak_mib = peak_bytes / 1024**2
    else:
        peak_mib = float("nan")
    
    # Convert loss to perplexity
    loss = outputs.loss.item()
    perplexity = math.exp(loss)
    return peak_mib, perplexity

def measure_baseline(model: nn.Module, tokenizer, prompt: str, perp_text: str, device: torch.device):
    """
      - Warm up & measure generation latency & throughput on `prompt`
      - Measure peak GPU memory & perplexity on `perp_text`
      - Print or return these baseline metrics
    """
    # 1) Measure latency & throughput (pass MAX_NEW_TOKENS so the config setting takes effect)
    latency, throughput = measure_latency_and_throughput(
        model, tokenizer, prompt, device, max_new_tokens=MAX_NEW_TOKENS
    )
    # 2) Measure peak GPU memory & perplexity
    peak_mem, perplexity = measure_peak_mem_and_perplexity(model, tokenizer, perp_text, device)

    # 3) Print baseline metrics
    print(f"[Baseline] latency   = {latency:.3f}s")
    print(f"[Baseline] throughput= {throughput:.1f} tok/s")
    print(f"[Baseline] peak GPU  = {peak_mem:.1f} MiB")
    print(f"[Baseline] perplexity= {perplexity:.3f}")

    # Return in case caller wants to use them programmatically
    return {
        "latency": latency,
        "throughput": throughput,
        "peak_gpu_mem": peak_mem,
        "perplexity": perplexity
    }
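
`torch.cuda.max_memory_allocated` counts every tensor alive since the reset, including weights already resident on the GPU. The "peak" reported above is therefore weights plus activations. The small helper below, which is hypothetical and not from the notebook, sums parameter and buffer sizes so the two can be separated.

```python
import torch.nn as nn


def param_mib(model: nn.Module) -> float:
    """MiB occupied by parameters and buffers (weights only, no activations)."""
    n_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    n_bytes += sum(b.numel() * b.element_size() for b in model.buffers())
    return n_bytes / 1024**2


toy = nn.Linear(1024, 1024, bias=False).half()  # 1024 * 1024 FP16 weights
print(f"{param_mib(toy):.1f} MiB")  # 2.0 MiB
```

For example, `peak_mib - param_mib(model)` approximates activation and workspace overhead. For TinyLlama's roughly 1.10B parameters in FP16, the weights alone come to about 2,098 MiB. That suggests most of the 2,129.4 MiB baseline peak is weights.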

# Prune MLP Rows and Columns
This function:
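- Prunes the MLP layers in the model by:
  - Zeroing out a fraction of rows in the `gate_proj` and `up_proj` layers, each ranked by its own row L1 norms.
  - Zeroing out a fraction of columns in the `down_proj` layer, ranked by its own column L1 norms. Because the three selections are made independently, they generally do not correspond to the same inner neurons.
- Uses structured pruning to zero entire rows or columns at once.
- Performs pruning on the CPU to avoid GPU out-of-memory errors, since the pruning reparameterization temporarily stores both the original weight and a mask.
- Calls `prune.remove` after each mask is applied, so the zeros become permanent and the `weight_orig`/`weight_mask` reparameterization is dropped.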

In [6]:
def prune_mlp_rows_and_cols(model: nn.Module, prune_frac: float):
    """
      - Move model to CPU
      - For each layer in model.model.layers:
          • Zero out `prune_frac` of rows in gate_proj and up_proj (each ranked independently)
          • Zero out `prune_frac` of columns in down_proj (ranked independently, so not
            necessarily the same neurons as above)
      - Remove pruning reparameterizations
    """
    # 1) Ensure we prune on CPU to avoid GPU OOM
    model.cpu()
    torch.cuda.empty_cache()

    # 2) Iterate through each decoder layer’s MLP
    for layer in model.model.layers:
        gate = layer.mlp.gate_proj   # [inner, hidden]
        up   = layer.mlp.up_proj     # [inner, hidden]
        down = layer.mlp.down_proj   # [hidden, inner]

        # 2a) Zero out rows in gate_proj and up_proj (each call ranks its own rows by L1 norm)
        for proj in (gate, up):
            prune.ln_structured(
                proj,
                name="weight",
                amount=prune_frac,
                n=1,
                dim=0,           # prune entire rows
            )
            prune.remove(proj, "weight")

        # 2b) Zero out columns in down_proj, ranked by down_proj's own column L1 norms;
        #     these generally do NOT correspond to the rows zeroed above
        prune.ln_structured(
            down,
            name="weight",
            amount=prune_frac,
            n=1,
            dim=1,               # prune entire columns
        )
        prune.remove(down, "weight")

    # 3) Return the model (now with zeros in place)
    return model
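
Each `ln_structured` call above ranks its own rows or columns. As a result, the zeroed rows of `gate_proj` and `up_proj` and the zeroed columns of `down_proj` generally belong to different inner neurons. The sketch below shows an aligned variant, assuming a LLaMA-style MLP without biases. It computes one score per inner neuron and zeroes the same neurons in all three projections. The function name and the scoring rule (sum of the gate-row, up-row, and down-column L1 norms) are assumptions for illustration, and this variant has not been evaluated on TinyLlama here.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def prune_mlp_aligned(gate: nn.Linear, up: nn.Linear, down: nn.Linear,
                      prune_frac: float) -> torch.Tensor:
    """Zero the SAME inner neurons in gate/up rows and down columns; return kept indices."""
    inner = gate.weight.shape[0]
    # One importance score per inner neuron (illustrative choice: combined L1 norms)
    score = (gate.weight.float().abs().sum(dim=1)
             + up.weight.float().abs().sum(dim=1)
             + down.weight.float().abs().sum(dim=0))
    n_prune = int(round(prune_frac * inner))
    pruned_idx = torch.topk(score, n_prune, largest=False).indices
    keep = torch.ones(inner, dtype=torch.bool, device=score.device)
    keep[pruned_idx] = False
    gate.weight[~keep] = 0      # rows of gate_proj  [inner, hidden]
    up.weight[~keep] = 0        # rows of up_proj    [inner, hidden]
    down.weight[:, ~keep] = 0   # columns of down_proj [hidden, inner]
    return keep.nonzero(as_tuple=False).view(-1)


torch.manual_seed(0)
hidden, inner = 16, 64
gate = nn.Linear(hidden, inner, bias=False)
up = nn.Linear(hidden, inner, bias=False)
down = nn.Linear(inner, hidden, bias=False)
kept = prune_mlp_aligned(gate, up, down, prune_frac=0.25)
zero_g = gate.weight.abs().sum(dim=1) == 0
zero_u = up.weight.abs().sum(dim=1) == 0
zero_d = down.weight.abs().sum(dim=0) == 0
print(kept.numel(), torch.equal(zero_g, zero_u) and torch.equal(zero_g, zero_d))  # 48 True
```

When applied per layer (to `layer.mlp.gate_proj`, `layer.mlp.up_proj`, and `layer.mlp.down_proj`), this silences exactly `prune_frac` of the neurons rather than the union of three separate selections. The returned indices are the ones a rebuild would keep.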

# Rebuild MLP Blocks
This function:
- Reconstructs the pruned MLP layers with reduced dimensions.
- Identifies the neurons that were not pruned in the `gate_proj` layer.
- Creates new `nn.Linear` modules for `gate_proj`, `up_proj`, and `down_proj` with updated dimensions.
- Copies the weights and biases from the original layers to the new layers.
- Replaces the old modules with the new ones in the model.

In [7]:
def rebuild_mlp_blocks(model: nn.Module):
    """
      - For each layer in model.model.layers:
          1) Identify kept neuron indices in gate_proj
          2) Construct new nn.Linear modules for gate_proj, up_proj, down_proj
             with reduced dimensions
          3) Copy over weights and biases
          4) Replace the old modules on the model
    """
    for layer in model.model.layers:
        # original modules (still on CPU, dtype=original)
        old_gate = layer.mlp.gate_proj
        old_up   = layer.mlp.up_proj
        old_down = layer.mlp.down_proj

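        # discover surviving rows in gate_proj; neurons zeroed only in up_proj or
        # down_proj are kept here (they are still dead, just not removed)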
        Wg = old_gate.weight.data     # [inner_orig, hidden], dtype say torch.half
        keep_idx = (Wg.abs().sum(dim=1) != 0).nonzero(as_tuple=False).view(-1)
        inner_new = keep_idx.numel()
        hidden    = Wg.size(1)
        dtype     = Wg.dtype
        device    = Wg.device

        # helper to build a new Linear directly in the original dtype/device
        # (avoids allocating a throwaway FP32 weight first)
        def make_linear(in_f, out_f, bias, old_weight, old_bias=None):
            nl = nn.Linear(in_f, out_f, bias=bias, device=device, dtype=dtype)
            with torch.no_grad():
                nl.weight.copy_(old_weight)
                if bias and old_bias is not None:
                    nl.bias.copy_(old_bias)
            return nl

        # rebuild gate_proj: hidden -> inner_new
        new_gate = make_linear(
            hidden, inner_new, 
            bias=(old_gate.bias is not None),
            old_weight=old_gate.weight.data[keep_idx],
            old_bias=old_gate.bias.data[keep_idx] if old_gate.bias is not None else None
        )

        # rebuild up_proj: hidden -> inner_new
        new_up = make_linear(
            hidden, inner_new,
            bias=(old_up.bias is not None),
            old_weight=old_up.weight.data[keep_idx],
            old_bias=old_up.bias.data[keep_idx] if old_up.bias is not None else None
        )

        # rebuild down_proj: inner_new -> hidden
        new_down = make_linear(
            inner_new, hidden,
            bias=(old_down.bias is not None),
            old_weight=old_down.weight.data[:, keep_idx],
            old_bias=old_down.bias.data if old_down.bias is not None else None
        )

        # swap in-place
        layer.mlp.gate_proj = new_gate
        layer.mlp.up_proj   = new_up
        layer.mlp.down_proj = new_down

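    # Keep config.intermediate_size in sync with the new width: from_pretrained builds
    # the MLPs from the config, so a stale value makes the saved checkpoint fail to load
    # with a shape mismatch. This only works if every layer kept the same number of neurons.
    widths = {layer.mlp.gate_proj.out_features for layer in model.model.layers}
    if len(widths) == 1:
        model.config.intermediate_size = widths.pop()
    else:
        print(f"Warning: non-uniform MLP widths {sorted(widths)}; config not updated")

    return model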
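
One quick way to confirm that a rebuild preserves the pruned function is to compare outputs on random inputs. The sketch below does this on a hypothetical toy module (`ToyGatedMLP`) that mimics the LLaMA MLP structure, `down(SiLU(gate(x)) * up(x))`, and uses the same keep rule as `rebuild_mlp_blocks`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyGatedMLP(nn.Module):
    """Hypothetical stand-in for a LLaMA-style MLP: down(silu(gate(x)) * up(x))."""

    def __init__(self, hidden: int, inner: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


@torch.no_grad()
def shrink(mlp: ToyGatedMLP) -> ToyGatedMLP:
    """Keep only neurons whose gate row is non-zero (same rule as rebuild_mlp_blocks)."""
    keep = (mlp.gate_proj.weight.abs().sum(dim=1) != 0).nonzero(as_tuple=False).view(-1)
    small = ToyGatedMLP(mlp.gate_proj.in_features, keep.numel())
    small.gate_proj.weight.copy_(mlp.gate_proj.weight[keep])
    small.up_proj.weight.copy_(mlp.up_proj.weight[keep])
    small.down_proj.weight.copy_(mlp.down_proj.weight[:, keep])
    return small


torch.manual_seed(0)
mlp = ToyGatedMLP(hidden=32, inner=128)
with torch.no_grad():
    mlp.gate_proj.weight[::4] = 0  # kill every 4th neuron via its gate row
small = shrink(mlp)
x = torch.randn(5, 32)
with torch.no_grad():
    max_diff = (mlp(x) - small(x)).abs().max().item()
print(small.gate_proj.out_features, max_diff < 1e-5)  # 96 True
```

In FP32 the two outputs agree to rounding error. In FP16 on a GPU, matrices of different widths can accumulate in different orders, so small differences are expected. That is consistent with the pruned and rebuilt perplexities printed by `start()` being close but not identical.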

# Measure Performance After Rebuilding
This function:
- Evaluates the performance of the rebuilt model after pruning and reconstruction.
- Measures:
  - **Latency**: Time taken to generate tokens for a given prompt.
  - **Throughput**: Tokens generated per second.
  - **Peak GPU Memory Usage**: Maximum memory used during inference.
  - **Perplexity**: A measure of how well the model predicts the given text.
- Prints the metrics for comparison with the baseline model.

In [8]:
def measure_rebuilt(model: nn.Module, tokenizer, prompt: str, perp_text: str, device: torch.device):
    """
      - Move rebuilt model to `device` & .eval()
      - Re-measure latency, throughput, peak memory, perplexity
      - Print or return these metrics
    """
    # 1) Move to device and set to eval
    model.to(device)
    model.eval()

    # 2) Measure latency & throughput (pass MAX_NEW_TOKENS so the config setting takes effect)
    latency, throughput = measure_latency_and_throughput(
        model, tokenizer, prompt, device, max_new_tokens=MAX_NEW_TOKENS
    )

    # 3) Measure peak GPU memory & perplexity
    peak_mem, perplexity = measure_peak_mem_and_perplexity(
        model, tokenizer, perp_text, device
    )

    # 4) Print results
    print(f"[Rebuilt] latency   = {latency:.3f}s")
    print(f"[Rebuilt] throughput= {throughput:.1f} tok/s")
    print(f"[Rebuilt] peak GPU  = {peak_mem:.1f} MiB")
    print(f"[Rebuilt] perplexity= {perplexity:.3f}")

    # 5) Return for further use if needed
    return {
        "latency": latency,
        "throughput": throughput,
        "peak_gpu_mem": peak_mem,
        "perplexity": perplexity
    }

# Save Model and Report Size
This function:
- Saves the pruned and rebuilt model to the specified output directory.
- Calculates the total size of the saved model files on disk.
- Prints the on-disk size in MiB. The original checkpoint is not saved here, so any comparison with it has to be made separately.

In [9]:
def save_and_report_size(model: nn.Module, output_dir: str):
    """
      - model.save_pretrained(output_dir)
      - Walk `output_dir` to sum file sizes (in MiB)
      - Print the on-disk size
    """
    # 1) Save
    model.save_pretrained(output_dir)

    # 2) Sum file sizes
    total_bytes = 0
    for root, _, files in os.walk(output_dir):
        for fname in files:
            total_bytes += os.path.getsize(os.path.join(root, fname))

    # 3) Convert to MiB and print
    size_mb = total_bytes / 1024**2
    print(f"[Rebuilt] on-disk size = {size_mb:.1f} MiB")

    return size_mb
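
The reported size can be sanity-checked from the architecture. The back-of-the-envelope sketch below assumes TinyLlama-1.1B's published shape (22 decoder layers, hidden size 2048, MLP inner size 5632; verify against `model.config`) and FP16 weights. It also assumes that only the gate-zeroed neurons are removed, as in `rebuild_mlp_blocks`.

```python
# Assumed TinyLlama-1.1B shape (check model.config on your copy)
NUM_LAYERS, HIDDEN, INTERMEDIATE = 22, 2048, 5632
BYTES_PER_PARAM = 2  # FP16
PRUNE_FRAC = 0.2

# torch's ln_structured prunes round(amount * size) channels along the chosen dim
pruned_per_layer = round(PRUNE_FRAC * INTERMEDIATE)
# each removed neuron drops one gate row, one up row and one down column of length HIDDEN
removed_params = NUM_LAYERS * pruned_per_layer * 3 * HIDDEN
removed_mib = removed_params * BYTES_PER_PARAM / 1024**2

print(pruned_per_layer, removed_params, f"{removed_mib:.1f} MiB")  # 1126 152199168 290.3 MiB
```

The estimated saving of about 290 MiB lines up with the drop in peak GPU memory in the recorded output: 2,129.4 MiB to 1,839.9 MiB, a difference of 289.5 MiB.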

# Main Execution Flow
This function:
- Loads the model and tokenizer.
- Measures the baseline performance of the model.
- Applies structured pruning to the MLP layers.
- Rebuilds the pruned MLP layers with reduced dimensions.
- Measures the performance of the rebuilt model.
- Saves the pruned and rebuilt model to disk and reports its size.

In [10]:
def start():
    device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer, model = load_model_and_tokenizer(MODEL_NAME, device)

    # Baseline
    print("Measuring baseline performance...")
    measure_baseline(model, tokenizer, PROMPT, PERP_TEXT, device)

    # Prune on CPU
    print(f"Pruning {MLP_PRUNE_FRAC*100:.1f}% of MLP neurons...")
    prune_mlp_rows_and_cols(model, MLP_PRUNE_FRAC)

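    # Measure after pruning (before rebuild). measure_rebuilt is reused here, so its
    # output lines carry the "[Rebuilt]" prefix even though nothing is rebuilt yet.
    print("Measuring pruned (unrebuilt) model performance...")
    measure_rebuilt(model, tokenizer, PROMPT, PERP_TEXT, device)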
    
    # Rebuild smaller MLPs
    print("Rebuilding smaller MLP blocks...")
    rebuild_mlp_blocks(model)

    # Re-benchmark rebuilt model
    print("Measuring rebuilt model performance...")
    measure_rebuilt(model, tokenizer, PROMPT, PERP_TEXT, device)

    # Save & report on-disk size
    print("Saving rebuilt model and reporting on-disk size...")
    save_and_report_size(model, "llama_pruned_rebuilt")

# Start the Pruning and Evaluation Process
This cell:
- Calls the `start` function to execute the entire pruning and evaluation pipeline.
- Outputs the baseline and post-pruning metrics, as well as the on-disk size of the pruned model.

In [11]:
start()

Measuring baseline performance...
[Baseline] latency   = 1.186s
[Baseline] throughput= 42.2 tok/s
[Baseline] peak GPU  = 2129.4 MiB
[Baseline] perplexity= 4.555
Pruning 20.0% of MLP neurons...
Measuring pruned (unrebuilt) model performance...
[Rebuilt] latency   = 1.453s
[Rebuilt] throughput= 34.6 tok/s
[Rebuilt] peak GPU  = 2129.4 MiB
[Rebuilt] perplexity= 66230.342
Rebuilding smaller MLP blocks...
Measuring rebuilt model performance...
[Rebuilt] latency   = 1.289s
[Rebuilt] throughput= 38.9 tok/s
[Rebuilt] peak GPU  = 1839.9 MiB
[Rebuilt] perplexity= 66363.558
Saving rebuilt model and reporting on-disk size...
[Rebuilt] on-disk size = 1807.9 MiB


A huge jump in perplexity after pruning **usually indicates damage**. In general, though, it does not automatically mean the model is beyond repair. It means the model now assigns very low probability to the tokens that actually occur, so the average log-likelihood collapses. The sections below walk through the possible causes so you can pin down which one applies before drawing conclusions.


---

# 1. What a jump from about 4.6 to tens of thousands actually means

Perplexity over a sequence of $N$ predicted tokens is:

$$
\mathrm{PPL} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log p_\theta(x_i \mid x_{<i})\right)
$$

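If perplexity becomes enormous, the average negative log-likelihood has become large. There are two ways this can happen. A few tokens may have received probabilities so tiny that their log terms are huge negative numbers, or the model may assign low probability to most tokens.

A single very improbable prediction can inflate the average considerably, so a large perplexity does not by itself prove that the entire distribution is ruined. The size of the jump matters, though. Here $\ln(66{,}230) \approx 11.1$ nats per token. That is worse than guessing uniformly over TinyLlama's 32,000-token vocabulary ($\ln 32{,}000 \approx 10.4$), so the damage is most likely spread across the sequence rather than confined to one token.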
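
A toy calculation, using made-up probabilities, shows how far a single bad prediction can move the average:

```python
import math


def perplexity(token_probs):
    """exp of the mean negative log-likelihood over the predicted tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))


healthy = [0.25] * 60               # every token predicted with p = 0.25
one_bad = [0.25] * 59 + [1e-30]     # identical, except one catastrophic token
print(f"{perplexity(healthy):.2f}")  # 4.00
print(f"{perplexity(one_bad):.2f}")  # 12.36
```

On a 60-token sequence, one catastrophic token roughly triples perplexity. On its own, though, it cannot push the value into the tens of thousands unless that single probability is astronomically small.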

---

# 2. Usual reasons why pruning causes catastrophic perplexity

There are three main failure modes in manual pruning and rebuilding.

## 2.1 You pruned too aggressively

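If you drop a large share of the inner MLP neurons in a compact model (half or more is already a lot for a 1.1B model), the remaining neurons cannot compensate.
LLaMA-style models are sensitive to the width of the MLP.
The hidden states then drift, and the model assigns near-zero probability to the correct next tokens.

The effective pruning fraction can also be much higher than `MLP_PRUNE_FRAC`. In this notebook, `gate_proj`, `up_proj`, and `down_proj` are pruned by three independent `ln_structured` calls, and a neuron is dead if *any* of its gate row, up row, or down column is zero. If the three 20% selections were roughly independent, about $1 - 0.8^3 \approx 49\%$ of the neurons in each layer would be silenced.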

## 2.2 You rebuilt the layers incorrectly

This is very common in manual rebuild code.

Typical issues:

* misaligned indices in the layer slicing
* incorrect mapping between gate, up, and down projections
* mismatch in bias slicing
* wrong use of keep indices
* accidentally mixing rows and columns
* off-by-one errors in concatenated MLP structures
* using the wrong shape when creating the new Linear layer

Any structural mismatch will create a model that runs but no longer represents the original function.

The forward pass produces outputs, but the mapping is basically scrambled.

## 2.3 Pruned neurons that were not truly dead

Magnitude pruning is crude.
A small weight is not always an unimportant weight.

If your prune fraction is high, you cut out functional pathways.
Then the MLP no longer constructs the correct intermediate representation.
Perplexity explodes.

Magnitude pruning cannot distinguish useful low-magnitude weights from noise.

This is why sparsity-aware training and iterative, gradual pruning exist.
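
One alternative is to score each inner neuron by how much it actually contributes on calibration text, instead of by weight magnitude alone. This is sketched here as an assumption, not as the notebook's method. The function name, the score formula, and the random stand-in for calibration activations are all illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def activation_neuron_scores(gate: nn.Linear, up: nn.Linear, down: nn.Linear,
                             calib_x: torch.Tensor) -> torch.Tensor:
    """score_i = mean over tokens of |silu(g_i . x) * (u_i . x)|, times ||down[:, i]||_2.

    calib_x: [num_tokens, hidden] hidden states entering the MLP.
    """
    h = F.silu(gate(calib_x)) * up(calib_x)              # [num_tokens, inner]
    return h.abs().mean(dim=0) * down.weight.norm(dim=0)  # [inner]


torch.manual_seed(0)
hidden, inner = 16, 32
gate = nn.Linear(hidden, inner, bias=False)
up = nn.Linear(hidden, inner, bias=False)
down = nn.Linear(inner, hidden, bias=False)
with torch.no_grad():
    down.weight[:, 5] = 0  # neuron 5 can no longer affect the output
scores = activation_neuron_scores(gate, up, down, torch.randn(256, hidden))
prune_idx = torch.topk(scores, k=8, largest=False).indices  # 8 least useful neurons
print(scores.shape, 5 in prune_idx.tolist())  # torch.Size([32]) True
```

On a real model, `calib_x` would be hidden states captured at the MLP input (for instance with a forward hook) on a few hundred tokens of representative text.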

---

# 3. Does this mean the model is broken?

A perplexity jump from about 4.6 to more than 66,000 is not mild damage.
It indicates one of two things:

* a structural error in the rebuild, or
* an effective pruning fraction so high that it destroyed semantic structure.

A healthy pruned transformer typically stays within a factor of two or three of its original perplexity after moderate pruning. A jump of roughly four orders of magnitude (about 14,500×) is not normal at all.

So for practical purposes: **yes, the model in its current form is broken**.

But this does not mean the idea of pruning itself is broken.

It only means this particular pruning and rebuilding pass did not preserve the structure.

---

# 4. How to diagnose which of the three failure modes happened

Here is the easiest diagnostic approach.

### Step 1: measure perplexity right after zeroing weights

Before rebuilding, run perplexity evaluation.

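If perplexity is still sane (for example, 4.5 goes to 6 or 7), the pruning itself was tolerable.

If perplexity already explodes here, the pruning step is at fault: it was too aggressive, or it zeroed the wrong (or misaligned) neurons. This is what the recorded output shows. The pruned-but-unrebuilt model already reaches a perplexity of 66,230, before any rebuild has happened.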

### Step 2: measure perplexity right after rebuilding the first layer

Rebuild only MLP block 0 and measure again.
Then rebuild 0 and 1, measure again.
Continue.

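Compare each measurement against the pruned-but-unrebuilt perplexity rather than the baseline, because a correct rebuild should reproduce the pruned model's perplexity up to floating-point noise. If the explosion happens only after a particular layer rebuild, you have found a structural error in that layer's MLP reconstruction. In the recorded run, the rebuilt perplexity (66,364) is close to the pruned one (66,230). That is consistent with the rebuild being correct and the damage coming from the pruning step.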

This is the cleanest way to locate misalignment.
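
The incremental procedure can be written as a small loop. This is a sketch: `rebuild_layer` and `eval_ppl` are hypothetical callables you would supply. For example, `rebuild_layer` could be the body of `rebuild_mlp_blocks` restricted to one layer, and `eval_ppl` could wrap `measure_peak_mem_and_perplexity` to return only perplexity. The 1.5× threshold is an arbitrary assumption, and the stub at the end only simulates a broken layer 3.

```python
from typing import Callable, List, Optional


def first_exploding_layer(num_layers: int,
                          rebuild_layer: Callable[[int], None],
                          eval_ppl: Callable[[], float],
                          reference_ppl: float,
                          max_ratio: float = 1.5) -> Optional[int]:
    """Rebuild layers 0, 1, 2, ... one at a time, re-evaluating after each step.

    Returns the first layer whose rebuild pushes perplexity above
    reference_ppl * max_ratio, or None if all layers rebuild cleanly.
    reference_ppl should be the pruned-but-unrebuilt perplexity.
    """
    for i in range(num_layers):
        rebuild_layer(i)
        ppl = eval_ppl()
        print(f"after rebuilding layer {i}: ppl = {ppl:.3f}")
        if ppl > reference_ppl * max_ratio:
            return i
    return None


# Stub demo: pretend the rebuild of layer 3 is broken.
rebuilt: List[int] = []


def fake_ppl() -> float:
    return 5000.0 if 3 in rebuilt else 6.0


bad_layer = first_exploding_layer(22, rebuilt.append, fake_ppl, reference_ppl=6.0)
print(bad_layer)  # 3
```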

### Step 3: instrument shapes and index mappings

Print:

* keep_idx
* shape of old layers
* shape of rebuilt layers
* norms of pruned rows
* a few random input–output comparisons

Look for rows that were supposed to be all zero but are not.
Look for columns that do not match.
Look for mismatched shapes between gate, up and down.
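
Step 3 can be partly automated. The hypothetical helper below counts the zero rows and columns in each projection and checks how well they line up across `gate_proj`, `up_proj`, and `down_proj`. The demo applies three independent `ln_structured` calls to toy layers, the same way `prune_mlp_rows_and_cols` does.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


@torch.no_grad()
def mlp_zero_report(gate: nn.Linear, up: nn.Linear, down: nn.Linear) -> dict:
    """Count dead neurons per projection and how well the three zero patterns line up."""
    zg = gate.weight.abs().sum(dim=1) == 0  # zero rows in gate_proj
    zu = up.weight.abs().sum(dim=1) == 0    # zero rows in up_proj
    zd = down.weight.abs().sum(dim=0) == 0  # zero columns in down_proj
    return {
        "gate_zero_rows": int(zg.sum()),
        "up_zero_rows": int(zu.sum()),
        "down_zero_cols": int(zd.sum()),
        "aligned_in_all_three": int((zg & zu & zd).sum()),
        "dead_in_any": int((zg | zu | zd).sum()),  # neurons that contribute nothing
    }


# Reproduce the notebook's pattern on toy layers: three independent ln_structured calls
torch.manual_seed(0)
hidden, inner = 64, 256
gate = nn.Linear(hidden, inner, bias=False)
up = nn.Linear(hidden, inner, bias=False)
down = nn.Linear(inner, hidden, bias=False)
for proj, dim in ((gate, 0), (up, 0), (down, 1)):
    prune.ln_structured(proj, name="weight", amount=0.2, n=1, dim=dim)
    prune.remove(proj, "weight")
print(mlp_zero_report(gate, up, down))
```

Each projection gets its own 51 zeros, but typically only a few neurons are zeroed in all three, and `dead_in_any` comes out well above 51. On TinyLlama you would call the helper on `layer.mlp.gate_proj`, `layer.mlp.up_proj`, and `layer.mlp.down_proj` for each entry of `model.model.layers`.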

---


# Outro: What We Achieved and What Actually Happened

In this notebook we walked through magnitude-based structured pruning at a low level. PyTorch's `torch.nn.utils.prune` utilities were used only to apply the masks. The rebuild was done by hand, without higher-level tooling such as Hugging Face Optimum, so that the computational mechanics stay transparent rather than hidden behind abstractions.

The pruning process consisted of three essential parts:

**First**, we ranked MLP rows and columns by their L1 norms and zeroed the lowest-scoring 20% in `gate_proj`, `up_proj`, and `down_proj`. These neurons were not "deleted" at first. Instead their weights were set to zero, so their contribution to the output collapses to zero for every input and they stop affecting the forward pass. Because each projection was ranked separately, however, the zeroed rows and columns do not refer to the same neurons, so more neurons were silenced than the 20% setting suggests.

**Second**, we identified the structure that remained after pruning. A zero row in `gate_proj` or `up_proj`, or a zero column in `down_proj`, marks a dead hidden unit. Since these units contribute nothing, the model is already functionally equivalent to a smaller model. At this stage the model still *looks* large on disk and in memory, since the zeroed neurons remain part of the tensors.

**Third**, we rebuilt the MLPs into truly smaller Linear layers. For each transformer block we found the neurons whose `gate_proj` rows are non-zero and created new projection layers sized to exactly those neurons. We then sliced the original weights to keep only the matching rows of `gate_proj` and `up_proj` and columns of `down_proj`. This removes the gate-zeroed neurons entirely rather than just setting them to zero. Importantly, it does not change the function of the pruned model, because the removed units already produced zero outputs and had no influence on downstream computations. Neurons zeroed only in `up_proj` or `down_proj` are kept (still dead), so part of the possible savings is left unrealized.
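
The equivalence argument can be written out explicitly. Consider a LLaMA-style MLP without biases, as in LLaMA. Let $g_i^\top$ and $u_i^\top$ be the $i$-th rows of $W_{\text{gate}}$ and $W_{\text{up}}$, and let $d_i$ be the $i$-th column of $W_{\text{down}}$. Then:

$$
\mathrm{MLP}(x) = W_{\text{down}}\big(\mathrm{SiLU}(W_{\text{gate}}x)\odot W_{\text{up}}x\big) = \sum_{i=1}^{d_{\text{inner}}} d_i\,\mathrm{SiLU}(g_i^\top x)\,(u_i^\top x)
$$

Since $\mathrm{SiLU}(0) = 0\cdot\sigma(0) = 0$, the $i$-th term vanishes for every input whenever $g_i = 0$, $u_i = 0$, or $d_i = 0$. Dropping such a term is therefore exact. The set of dead neurons is the union of the three zero patterns, not only the rows that are zero in $W_{\text{gate}}$.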

This manual reconstruction step is the key difference between cosmetic sparsity and actual model compression. PyTorch's pruning utilities only apply masks and leave tensor shapes unchanged. The notebook instead compresses the architecture itself. The result is a smaller MLP that preserves the behavior of the pruned (zeroed) model, not of the original one.

In practice, this approach is useful for educational insight, research experiments, or controlled model surgery. Production workflows often rely on pruning frameworks that integrate sparsity-aware training, export pipelines, and hardware-aligned sparse kernels. However, the manual approach used here exposes the algebraic structure behind pruning in its clearest possible form. You can now see both the limitations of naïve masking and the precise mathematical reason why structurally zeroed neurons can be safely removed.

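At this point you have a pruned and rebuilt model that is smaller in parameter count and mathematically equivalent to the pruned dense model. Peak GPU memory fell from 2,129.4 MiB to 1,839.9 MiB, and the saved checkpoint is 1,807.9 MiB. In this run, though, the pruned model's quality collapsed (perplexity about 66,000), and latency and throughput did not improve over the baseline. The model is therefore not usable as-is: it needs aligned neuron selection, a gentler pruning fraction, or recovery fine-tuning. The underlying idea of removing structurally zeroed neurons to obtain genuinely smaller dense layers remains the foundation of structured MLP pruning and a stepping stone toward more advanced sparsity methods.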