# Illustration of Pruner-Zero

**Workflow**:
1. Define pruning utilities and search space
2. Load model and data
3. Run search-based pruning and simple pruning baseline
4. Compare results

In [1]:
import torch
import torch.nn as nn
import math
import random
import json
from statistics import mean
from datasets import load_dataset
import logging

# ============================================================
# Math Functions for Search Space
# ============================================================

def add(x, y):
    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
        x, y = torch.broadcast_tensors(x, y)
        return x + y
    return x + y

def sub(x, y):
    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
        x, y = torch.broadcast_tensors(x, y)
        return x - y
    return x - y

def mul(x, y):
    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
        x, y = torch.broadcast_tensors(x, y)
        return x * y
    return x * y

def div(x, y):
    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
        x, y = torch.broadcast_tensors(x, y)
        return x / torch.norm(y)
    elif isinstance(x, (int, float)) and isinstance(y, (int, float)):
        return x / abs(y)
    raise TypeError('Input types not supported')

def sqr(x):
    return x * x if isinstance(x, torch.Tensor) or isinstance(x, (int, float)) else x

def neg(x):
    return -x

def abs_val(x):
    return torch.abs(x) if isinstance(x, torch.Tensor) else math.fabs(x)

def log(x):
    return torch.log(torch.abs(x) + 0.001) if isinstance(x, torch.Tensor) else math.log(abs(x) + 0.001)

def sqrt(x):
    return torch.sqrt(torch.abs(x)) if isinstance(x, torch.Tensor) else math.sqrt(abs(x))

def tanh(x):
    return torch.tanh(x) if isinstance(x, torch.Tensor) else math.tanh(x)

def pow_op(x):
    return torch.pow(x, 2) if isinstance(x, torch.Tensor) else x**2

def skp(x):
    return x

def mms(x):
    if isinstance(x, torch.Tensor):
        return (x - x.min()) / (x.max() - x.min())
    return (x - min(x)) / (max(x) - min(x))

def zsn(x):
    if isinstance(x, torch.Tensor):
        return (x - x.mean()) / x.std()
    return (x - mean(x)) / x.std()

def exp(x):
    if isinstance(x, torch.Tensor):
        return torch.exp(x.clamp(max=100))
    return math.exp(min(x, 100))

UNARY_FUNCTIONS = [sqr, neg, abs_val, log, exp, sqrt, tanh, pow_op, skp, mms, zsn]
BINARY_FUNCTIONS = [add, sub, mul, div]
FUNCTIONS = UNARY_FUNCTIONS + BINARY_FUNCTIONS
TERMINALS = ['W', 'G', 'X']

FUNCTION_MAP = {func.__name__: func for func in FUNCTIONS}

# Genetic Algorithm Parameters
PROB_MUTATION = 0.1
XO_RATE = 0.8
MIN_DEPTH = 2
MAX_DEPTH = 4

# ============================================================
# GPTree Class
# ============================================================

class GPTree:
    def __init__(self, data=None, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

    def save_tree(self, filename):
        tree_data = self._serialize_tree()
        with open(filename, 'w') as file:
            json.dump(tree_data, file, indent=4)

    def _serialize_tree(self):
        data = {'data': self.node_label()}
        if self.left:
            data['left'] = self.left._serialize_tree()
        if self.right:
            data['right'] = self.right._serialize_tree()
        return data

    @staticmethod
    def load_tree(filename):
        with open(filename, 'r') as file:
            tree_data = json.load(file)
        return GPTree._deserialize_tree(tree_data)

    @staticmethod
    def _deserialize_tree(data):
        node = GPTree()
        node.data = GPTree._get_function_from_label(data['data'])
        if 'left' in data:
            node.left = GPTree._deserialize_tree(data['left'])
        if 'right' in data:
            node.right = GPTree._deserialize_tree(data['right'])
        return node

    def compute_tree(self, W, G, X):
        if self.data in FUNCTIONS:
            try:
                if self.data in UNARY_FUNCTIONS:
                    return self.data(self.left.compute_tree(W, G, X))
                else:
                    return self.data(
                        self.left.compute_tree(W, G, X),
                        self.right.compute_tree(W, G, X))
            except Exception as e:
                # print(f"Error computing tree: {e}")
                shape = W.shape if isinstance(W, torch.Tensor) else (1, 1)
                return torch.zeros(shape, dtype=torch.float32)
        elif self.data == 'W':
            return W
        elif self.data == 'G':
            return G
        elif self.data == 'X':
            return X
        else:
            shape = W.shape if isinstance(W, torch.Tensor) else G.shape
            return torch.full(shape, float(self.data), dtype=torch.float32)

    def forward(self, W, G, X):
        return self.compute_tree(W, G, X)

    @staticmethod
    def random_tree(method='grow', max_depth=4, depth=0):
        node = GPTree()
        if depth >= max_depth:
            node.data = TERMINALS[random.randint(0, len(TERMINALS) - 1)]
        elif method == 'full' or depth < MIN_DEPTH:
            node.data = FUNCTIONS[random.randint(0, len(FUNCTIONS) - 1)]
        else:
            if random.random() < 0.5:
                node.data = TERMINALS[random.randint(0, len(TERMINALS) - 1)]
            else:
                node.data = FUNCTIONS[random.randint(0, len(FUNCTIONS) - 1)]
        
        if node.data in UNARY_FUNCTIONS:
            node.left = GPTree.random_tree(method, max_depth, depth + 1)
        elif node.data in FUNCTIONS:
            node.left = GPTree.random_tree(method, max_depth, depth + 1)
            node.right = GPTree.random_tree(method, max_depth, depth + 1)
        return node

    def tree_to_string(self):
        if self.data in FUNCTIONS:
            if self.data in UNARY_FUNCTIONS:
                return f"{self.data.__name__}({self.left.tree_to_string()})"
            else:
                return f"({self.left.tree_to_string()} {self.data.__name__} {self.right.tree_to_string()})"
        return str(self.data)

    def size(self):
        if self.data in TERMINALS:
            return 1
        return 1 + (self.left.size() if self.left else 0) + (self.right.size() if self.right else 0)

    def depth(self):
        return max(self.left.depth() if self.left else 0, self.right.depth() if self.right else 0) + 1

    def copy(self):
        node = GPTree(self.data)
        if self.left: node.left = self.left.copy()
        if self.right: node.right = self.right.copy()
        return node

    def node_label(self):
        if self.data in TERMINALS: return str(self.data)
        if self.data in FUNCTIONS: return self.data.__name__
        return str(self.data)

    @staticmethod
    def _get_function_from_label(label):
        if label in TERMINALS: return label
        try: return float(label)
        except ValueError: pass
        for func in FUNCTIONS:
            if func.__name__ == label: return func
        raise ValueError(f"Unknown label: {label}")
        
    def scan_tree(self, count, replacement=None):
        count[0] -= 1
        if count[0] <= 0:
            if replacement is None:
                return self.copy() # Return copy to avoid ref issues
            else:
                self.data = replacement.data
                self.left = replacement.left
                self.right = replacement.right
        else:
            if self.left and count[0] > 0: self.left.scan_tree(count, replacement)
            if self.right and count[0] > 0: self.right.scan_tree(count, replacement)

    def mutation(self):
        if random.random() < PROB_MUTATION:
            new_tree = GPTree.random_tree(method='grow', max_depth=2)
            self.data = new_tree.data
            self.left = new_tree.left
            self.right = new_tree.right
        elif self.left:
            self.left.mutation()
        elif self.right:
            self.right.mutation()

    def crossover(self, other):
        if random.random() < XO_RATE:
            count = [random.randint(1, other.size())]
            second = other.scan_tree(count.copy(), None)
            if second:
                 count = [random.randint(1, self.size())]
                 self.scan_tree(count, second)
        elif self.left:
            self.left.crossover(other)
        elif self.right:
            self.right.crossover(other)

# ============================================================
# Helper Classes and Functions
# ============================================================

def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    try:
        traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    except:
        print("Failed to load wikitext-2 remote, trying local or skipping...")
        # Add fallback logic if needed
        raise
        
    trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt')
    
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader

def find_layers(module, layers=[nn.Linear], name=''):
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(
            child, layers=layers, name=name + '.' + name1 if name != '' else name1
        ))
    return res

def check_sparsity(model):
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers
    count = 0
    total_params = 0
    for i in range(len(layers)):
        layer = layers[i]
        subset = find_layers(layer)
        for name in subset:
            W = subset[name].weight.data
            count += (W == 0).sum().item()
            total_params += W.numel()
    model.config.use_cache = use_cache
    return float(count) / total_params

class WrappedGPT:
    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        self.columns = layer.weight.data.shape[1]
        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        self.nsamples = 0

    def add_batch(self, inp, out):
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        self.scaler_row *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = inp.type(torch.float32)
        self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples

def prepare_calibration_input(model, dataloader, device):
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    if hasattr(model, 'hf_device_map') and "model.embed_tokens" in model.hf_device_map:
        device = model.hf_device_map["model.embed_tokens"]

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros((len(dataloader), model.seqlen, model.config.hidden_size), dtype=dtype, device=device)
    inps.requires_grad = False
    cache = {'i': 0, 'attention_mask': None, "position_ids": None}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs.get('attention_mask')
            cache['position_ids'] = kwargs.get('position_ids')
            raise ValueError
        def __getattr__(self, name):
            try:
                return super().__getattr__(name)
            except AttributeError:
                return getattr(self.module, name)

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(device))
        except ValueError:
            pass
    layers[0] = layers[0].module

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']
    model.config.use_cache = use_cache

    return inps, outs, attention_mask, position_ids

# ============================================================
# Main Pruning Function
# ============================================================

def apply_pruning(model, engine, sparsity_ratio, dataloader=None, nsamples=128, seed=0, tokenizer=None, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
    if dataloader is None:
        if tokenizer is None:
            raise ValueError("Tokenizer must be provided if dataloader is not.")
        # Default to wikitext2 if no dataloader provided
        dataloader = get_wikitext2(nsamples, seed, model.seqlen, tokenizer)

    use_cache = model.config.use_cache
    model.config.use_cache = False

    with torch.no_grad():
        inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, device)

    layers = model.model.layers

    for i in range(len(layers)):
        layer = layers[i]
        subset = find_layers(layer)

        if hasattr(model, 'hf_device_map') and f"model.layers.{i}" in model.hf_device_map:
            dev = model.hf_device_map[f"model.layers.{i}"]
            inps = inps.to(dev)
            outs = outs.to(dev)
            if attention_mask is not None:
                attention_mask = attention_mask.to(dev)
            if position_ids is not None:
                position_ids = position_ids.to(dev)

        wrapped_layers = {}
        for name in subset:
            wrapped_layers[name] = WrappedGPT(subset[name])

        def add_batch(name):
            def tmp(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = []
        for name in wrapped_layers:
            handles.append(subset[name].register_forward_hook(add_batch(name)))

        for j in range(len(dataloader)):
            with torch.no_grad():
                layer_args = {}

                if attention_mask is not None:
                    if attention_mask.size(0) == len(dataloader):
                        layer_args['attention_mask'] = attention_mask[j].unsqueeze(0)
                    else:
                        layer_args['attention_mask'] = attention_mask

                pos_ids = None
                if position_ids is not None:
                    if position_ids.size(0) == len(dataloader):
                        pos_ids = position_ids[j].unsqueeze(0)
                    else:
                        pos_ids = position_ids
                    layer_args['position_ids'] = pos_ids

                if hasattr(model.model, "rotary_emb") and pos_ids is not None:
                    try:
                        layer_args["position_embeddings"] = model.model.rotary_emb(
                            inps[j].unsqueeze(0), pos_ids
                        )
                    except TypeError:
                        layer_args["position_embeddings"] = model.model.rotary_emb(pos_ids)

                outs[j] = layer(inps[j].unsqueeze(0), **layer_args)[0]

        for h in handles:
            h.remove()

        for name in subset:
            # print(f"Pruning layer {i} name {name}")
            W = torch.abs(subset[name].weight.data)
            X = wrapped_layers[name].scaler_row.reshape((1, -1))
            G = torch.ones_like(W) # Dummy gradient

            W_metric = engine.forward(
                W.to(dtype=torch.float32),
                G.to(device=W.device, dtype=torch.float32),
                X.to(device=W.device, dtype=torch.float32),
            )

            sort_res = torch.sort(W_metric, dim=-1, stable=True)
            indices = sort_res[1][:, :int(W_metric.shape[1] * sparsity_ratio)]
            W_mask = torch.zeros_like(W_metric, dtype=torch.bool)
            W_mask.scatter_(1, indices, True)

            subset[name].weight.data[W_mask] = 0

        for j in range(len(dataloader)):
            with torch.no_grad():
                layer_args = {}
                if attention_mask is not None:
                    if attention_mask.size(0) == len(dataloader):
                        layer_args['attention_mask'] = attention_mask[j].unsqueeze(0)
                    else:
                        layer_args['attention_mask'] = attention_mask

                pos_ids = None
                if position_ids is not None:
                    if position_ids.size(0) == len(dataloader):
                        pos_ids = position_ids[j].unsqueeze(0)
                    else:
                        pos_ids = position_ids
                    layer_args['position_ids'] = pos_ids

                if hasattr(model.model, "rotary_emb") and pos_ids is not None:
                    try:
                        layer_args["position_embeddings"] = model.model.rotary_emb(
                            inps[j].unsqueeze(0), pos_ids
                        )
                    except TypeError:
                        layer_args["position_embeddings"] = model.model.rotary_emb(pos_ids)

                outs[j] = layer(inps[j].unsqueeze(0), **layer_args)[0]
        inps, outs = outs, inps

    model.config.use_cache = use_cache
    torch.cuda.empty_cache()
    return model

In [2]:
#!/usr/bin/env python3
"""
Simplified LLM pruning test script
Tests pruning on the Qwen2.5-0.5B model.
"""

import torch
import torch.nn as nn
import sys
import pruner_utils

# ============================================================
# Environment check
# ============================================================
print("=" * 60)
print("Environment Check")
print("=" * 60)
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
print("=" * 60)

# Try importing required packages
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    print("✓ transformers is installed")
except ImportError as e:
    print("✗ transformers is NOT installed")
    print(f"  Error: {e}")
    print("  Please run: pip install transformers")

try:
    from datasets import load_dataset
    print("✓ datasets is installed")
except ImportError as e:
    print("✗ datasets is NOT installed")
    print(f"  Error: {e}")
    print("  Please run: pip install datasets")

print("=" * 60)
print()

# ============================================================
# Configuration
# ============================================================
MODEL_NAME = "Qwen/Qwen2.5-0.5B"
SPARSITY_RATIO = 0.1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ============================================================
# Simplified data loading
# ============================================================

class TokenizerWrapper:
    def __init__(self, input_ids):
        self.input_ids = input_ids


def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """Load the WikiText-2 dataset."""
    print("Loading WikiText-2 dataset...")
    from datasets import load_dataset
    import random

    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train", download_mode='force_redownload')
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test",  download_mode='force_redownload')

    # Encode datasets
    trainenc = tokenizer(" ".join(traindata["text"]), return_tensors="pt")
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")

    # Generate samples
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc


def find_layers(module, layers=[nn.Linear], name=""):
    """Recursively find layers of specific types."""
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(
            find_layers(
                child, layers=layers, name=name + "." + name1 if name != "" else name1
            )
        )
    return res


def check_sparsity(model):
    """Check model sparsity (fraction of zero weights in Linear layers)."""
    use_cache = model.config.use_cache
    model.config.use_cache = False

    layers = model.model.layers
    zero_count = 0
    total_params = 0

    for i in range(len(layers)):
        subset = find_layers(layers[i])
        for name in subset:
            W = subset[name].weight.data
            zero_count += (W == 0).sum().item()
            total_params += W.numel()

    model.config.use_cache = use_cache
    return float(zero_count) / total_params


def eval_ppl_wikitext(model, testenc, bs=1, device=None):
    """Evaluate perplexity on WikiText."""
    testenc = testenc.input_ids
    nsamples = testenc.numel() // model.seqlen

    nlls = []
    print(f"Number of evaluation samples: {nsamples}")

    for i in range(0, nsamples, bs):
        if i % 50 == 0:
            print(f"  Sample {i}/{nsamples}")

        j = min(i + bs, nsamples)

        # Prepare inputs
        inputs = testenc[:, (i * model.seqlen):(j * model.seqlen)].to(device)
        inputs = inputs.reshape(j - i, model.seqlen)

        # Forward pass
        lm_logits = model(inputs).logits

        # Compute loss
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = inputs[:, 1:]

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        )

        # Negative log-likelihood
        neg_log_likelihood = loss.float() * model.seqlen * (j - i)
        nlls.append(neg_log_likelihood)

    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    return ppl.item()


def evaluate_model(model, tokenizer, device=None):
    """Evaluate model perplexity."""
    print("Evaluating model perplexity...")
    _, testloader = get_wikitext2(128, 0, model.seqlen, tokenizer)

    with torch.no_grad():
        ppl_test = eval_ppl_wikitext(model, testloader, bs=1, device=device)

    return ppl_test


def apply_simple_pruning(model, sparsity_ratio=0.1):
    """Simple magnitude pruning (sets smallest-magnitude weights to zero)."""
    print(f"Applying pruning (target sparsity: {sparsity_ratio * 100}%)...")

    use_cache = model.config.use_cache
    model.config.use_cache = False

    layers = model.model.layers

    for i in range(len(layers)):
        layer = layers[i]
        subset = find_layers(layer)

        for name in subset:
            print(f"  Pruning layer {i} - {name}")
            W = subset[name].weight.data

            # Magnitude as pruning metric
            W_metric = torch.abs(W)

            # Threshold by global fraction within this matrix
            sort_res = torch.sort(W_metric.flatten(), stable=True)
            threshold = sort_res[0][int(W_metric.numel() * sparsity_ratio)]
            W_mask = W_metric <= threshold

            # Zero out weights
            subset[name].weight.data[W_mask] = 0

    model.config.use_cache = use_cache
    torch.cuda.empty_cache()
    return model


# ============================================================
# Main
# ============================================================

def main():
    print("=" * 60)
    print("Start Testing")
    print("=" * 60)

    # Step 1: Load model
    print(f"\nStep 1: Loading model {MODEL_NAME}")
    print("-" * 60)

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            # force_download=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            # force_download=True,
        )
        model = model.to(DEVICE)
        model.eval()

        model.seqlen = 512
        print("✓ Model loaded successfully")
        print(f"  Sequence length: {model.seqlen}")
        print(f"  Device: {DEVICE}")
    except Exception as e:
        print(f"✗ Model loading failed: {e}")
        return

    # Step 2: Evaluate original model
    print("\nStep 2: Evaluating original model")
    print("-" * 60)
    try:
        original_ppl = evaluate_model(model, tokenizer, device=DEVICE)
        print(f"✓ Original model perplexity: {original_ppl:.4f}")
    except Exception as e:
        print(f"✗ Evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        return

    # Step 3: Apply pruning
    print("\nStep 3: Applying pruning")
    print("-" * 60)
    results = []
    try:
        import copy
        model_copy = copy.deepcopy(model)
        
        # Search-based pruning integration
        print("Creating pruning expression...")
        # Create a simple expression: mul(W, sqrt(X))
        simple_tree = pruner_utils.GPTree(pruner_utils.mul)
        simple_tree.left = pruner_utils.GPTree('W')
        simple_tree.right = pruner_utils.GPTree(pruner_utils.sqrt)
        simple_tree.right.left = pruner_utils.GPTree('X')
        print(f"Pruning expression: {simple_tree.tree_to_string()}")

        # Apply pruning using pruner_utils
        model_copy = pruner_utils.apply_pruning(
            model_copy,
            simple_tree,
            sparsity_ratio=SPARSITY_RATIO,
            tokenizer=tokenizer,
            device=DEVICE,
            nsamples=32  # Reduced samples for faster testing
        )
        
        # model_copy = apply_simple_pruning(model_copy, SPARSITY_RATIO)
        model_copy.eval()

        actual_sparsity = check_sparsity(model_copy)
        print("✓ Pruning completed")
        print(f"  Target sparsity: {SPARSITY_RATIO * 100:.1f}%")
        print(f"  Actual sparsity: {actual_sparsity * 100:.2f}%")
        results.append({
            "name": "Search-based",
            "model": model_copy,
            "sparsity": actual_sparsity,
        })
    except Exception as e:
        print(f"✗ Pruning failed: {e}")
        import traceback
        traceback.print_exc()
        return

    # Step 3b: Simple pruning baseline
    print("\nStep 3b: Applying simple pruning baseline")
    print("-" * 60)
    try:
        import copy
        simple_copy = copy.deepcopy(model)
        simple_copy = apply_simple_pruning(simple_copy, SPARSITY_RATIO)
        simple_copy.eval()

        simple_sparsity = check_sparsity(simple_copy)
        print("✓ Simple pruning completed")
        print(f"  Target sparsity: {SPARSITY_RATIO * 100:.1f}%")
        print(f"  Actual sparsity: {simple_sparsity * 100:.2f}%")
        results.append({
            "name": "Simple magnitude",
            "model": simple_copy,
            "sparsity": simple_sparsity,
        })
    except Exception as e:
        print(f"✗ Simple pruning failed: {e}")
        import traceback
        traceback.print_exc()
        return

    # Step 4: Evaluate pruned models
    print("\nStep 4: Evaluating pruned models")
    print("-" * 60)
    try:
        for item in results:
            item["ppl"] = evaluate_model(item["model"], tokenizer, device=DEVICE)
            print(f"✓ {item['name']} perplexity: {item['ppl']:.4f}")
    except Exception as e:
        print(f"✗ Evaluation failed: {e}")
        import traceback
        traceback.print_exc()
        return

    # Step 5: Summarize results
    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print(f"Original model perplexity:   {original_ppl:.4f}")
    print("\nComparison Table")
    print("-" * 60)
    header = f"{'Method':<20}{'PPL':>12}{'ΔPPL':>12}{'Growth%':>12}{'Sparsity%':>12}"
    print(header)
    print("-" * 60)
    for item in results:
        delta = item["ppl"] - original_ppl
        growth = (item["ppl"] / original_ppl - 1) * 100
        row = f"{item['name']:<20}{item['ppl']:>12.4f}{delta:>12.4f}{growth:>11.2f}%{item['sparsity'] * 100:>11.2f}%"
        print(row)
    print("=" * 60)

    # Save model
    save_path = f"pruned_qwen2.5_0.5b_sparsity{SPARSITY_RATIO}.pt"
    print(f"\nSaving pruned model to: {save_path}")
    torch.save(model_copy.state_dict(), save_path)
    print("✓ Done!")


main()

Environment Check
Python version: 3.12.12 (main, Oct  9 2025, 11:07:00) [Clang 17.0.0 (clang-1700.0.13.3)]
PyTorch version: 2.10.0
CUDA available: False
✓ transformers is installed
✓ datasets is installed

Start Testing

Step 1: Loading model Qwen/Qwen2.5-0.5B
------------------------------------------------------------


`torch_dtype` is deprecated! Use `dtype` instead!


Loading weights:   0%|          | 0/290 [00:00<?, ?it/s]

✓ Model loaded successfully
  Sequence length: 512
  Device: cpu

Step 2: Evaluating original model
------------------------------------------------------------
Evaluating model perplexity...
Loading WikiText-2 dataset...


test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (2541000 > 131072). Running this sequence through the model will result in indexing errors


Number of evaluation samples: 584
  Sample 0/584
  Sample 50/584
  Sample 100/584
  Sample 150/584
  Sample 200/584
  Sample 250/584
  Sample 300/584
  Sample 350/584
  Sample 400/584
  Sample 450/584
  Sample 500/584
  Sample 550/584
✓ Original model perplexity: 17.1919

Step 3: Applying pruning
------------------------------------------------------------
Creating pruning expression...
Pruning expression: (W mul sqrt(X))
✓ Pruning completed
  Target sparsity: 10.0%
  Actual sparsity: 9.95%

Step 3b: Applying simple pruning baseline
------------------------------------------------------------
Applying pruning (target sparsity: 10.0%)...
  Pruning layer 0 - self_attn.q_proj
  Pruning layer 0 - self_attn.k_proj
  Pruning layer 0 - self_attn.v_proj
  Pruning layer 0 - self_attn.o_proj
  Pruning layer 0 - mlp.gate_proj
  Pruning layer 0 - mlp.up_proj
  Pruning layer 0 - mlp.down_proj
  Pruning layer 1 - self_attn.q_proj
  Pruning layer 1 - self_attn.k_proj
  Pruning layer 1 - self_attn.v_p

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]



test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Number of evaluation samples: 584
  Sample 0/584
  Sample 50/584
  Sample 100/584
  Sample 150/584
  Sample 200/584
  Sample 250/584
  Sample 300/584
  Sample 350/584
  Sample 400/584
  Sample 450/584
  Sample 500/584
  Sample 550/584
✓ Search-based perplexity: 17.2372
Evaluating model perplexity...
Loading WikiText-2 dataset...


test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Number of evaluation samples: 584
  Sample 0/584
  Sample 50/584
  Sample 100/584
  Sample 150/584
  Sample 200/584
  Sample 250/584
  Sample 300/584
  Sample 350/584
  Sample 400/584
  Sample 450/584
  Sample 500/584
  Sample 550/584
✓ Simple magnitude perplexity: 17.4497

Summary
Original model perplexity:   17.1919

Comparison Table
------------------------------------------------------------
Method                       PPL        ΔPPL     Growth%   Sparsity%
------------------------------------------------------------
Search-based             17.2372      0.0453       0.26%       9.95%
Simple magnitude         17.4497      0.2578       1.50%      10.03%

Saving pruned model to: pruned_qwen2.5_0.5b_sparsity0.1.pt
✓ Done!
