In [2]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Using cached pyarrow-19.0.1-cp39-cp39-macosx_12_0_arm64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp39-cp39-macosx_11_0_arm64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py39-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting aiohttp (from datasets)
  Using cached aiohttp-3.11.14-cp39-cp39-macosx_11_0_arm64.whl.metadata (7.7 kB)
Collecting aiohappyeyeballs>=2.3.0 (from aiohttp->datasets)
  Using cached aiohappyeyeballs-2.6.1-py3-none-any.whl.metadata (5.9 kB)
Collecting aiosignal>=1.1.2 (from aiohttp->datasets)
 

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel
import torch.nn.utils.prune as prune
import time
import gc
from tqdm import tqdm
import pandas as pd
from datasets import load_dataset
import os
from torch.ao.quantization import quantize_dynamic

# Set random seed for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Configuration
# Models to compare: display name -> Hugging Face model id.
# "gpt2-medium": "gpt2-medium" can be added for a larger (slower) sweep.
MODELS = {
    "distilgpt2": "distilgpt2",
    "gpt2": "gpt2",
}

# Target fraction of weights to prune per layer. Keep 0.0 in the list: the
# summary in main() uses the 0.0 rows as each model's uncompressed baseline.
# A finer sweep such as [0.0, 0.3, 0.5, 0.7, 0.9] also works, at higher cost.
PRUNING_RATES = [0.0, 0.4, 0.7]
EVAL_SAMPLES = 100  # Number of samples to evaluate on
WARMUP_RUNS = 5     # Number of warmup runs for inference timing
MEASURE_RUNS = 10   # Number of runs to average for measurements
MAX_SEQ_LENGTH = 128
RESULTS_DIR = "compression_results"

# Create results directory
os.makedirs(RESULTS_DIR, exist_ok=True)

# Helper functions
def get_model_size_mb(model):
    """Return the storage used by ``model.parameters()`` in MB (1 MB = 1024 * 1024 bytes).

    Buffers are not included; see get_model_total_size_mb for that.
    NOTE(review): dynamically quantized layers may keep their int8 weights
    outside ``parameters()``, so confirm this figure for quantized models.
    """
    bytes_per_mb = 1024 * 1024
    param_bytes = 0
    for param in model.parameters():
        param_bytes += param.numel() * param.element_size()
    return param_bytes / bytes_per_mb

def get_model_total_size_mb(model):
    """Return the combined storage of parameters and registered buffers, in MB.

    Each tensor contributes ``numel() * element_size()`` bytes; the byte total
    is divided by 1024 * 1024.
    """
    tensors = [*model.parameters(), *model.buffers()]
    total_bytes = sum(t.numel() * t.element_size() for t in tensors)
    return total_bytes / (1024 * 1024)

def count_non_zero_params(model):
    """Count parameter entries of ``model`` that are not exactly zero.

    Returns:
        A plain Python ``int``. Each per-tensor count is converted with
        ``int()`` so the result never carries a (possibly CUDA) 0-d tensor into
        downstream arithmetic, DataFrames or CSV output.
    """
    return sum(int(torch.count_nonzero(p)) for p in model.parameters())

def calculate_sparsity(model):
    """Return the percentage (0-100) of parameter entries that are exactly zero.

    Returns:
        A plain Python ``float``, so the value serializes cleanly into the
        results DataFrame and CSV. A model with no parameters yields 0.0
        instead of raising ZeroDivisionError.
    """
    total_params = sum(p.numel() for p in model.parameters())
    if total_params == 0:
        return 0.0
    # float() also unwraps a 0-d tensor, whatever type the counter returns.
    nonzero_params = float(count_non_zero_params(model))
    return 100.0 * (1.0 - nonzero_params / total_params)

def calculate_perplexity(model, eval_dataloader, device):
    """Compute token-level perplexity of ``model`` over ``eval_dataloader``.

    Each batch is a mapping of tensors (``input_ids`` plus optional
    ``attention_mask`` / ``labels``). It is moved to ``device`` and passed to
    ``model(**batch)``, whose output must expose ``.loss``.

    Padding is excluded: wherever ``attention_mask == 0`` the label is set to
    -100. Batch losses are then averaged weighted by their number of target
    tokens, not by batch size.
    NOTE(review): assumes ``.loss`` is the mean over non-(-100) next-token
    targets, i.e. labels shifted by one as Hugging Face causal-LM heads do.

    Returns:
        exp(mean negative log-likelihood per target token), as a float.

    Raises:
        ValueError: if the loader yields no non-padding target tokens.
    """
    model.eval()
    total_nll = 0.0
    total_tokens = 0

    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Calculating perplexity"):
            inputs = {k: v.to(device) for k, v in batch.items()}

            labels = inputs.get("labels", inputs["input_ids"])
            attention_mask = inputs.get("attention_mask")
            if attention_mask is not None:
                # Padding (here the EOS token) must not be scored as a target.
                labels = labels.masked_fill(attention_mask == 0, -100)
            inputs["labels"] = labels

            # Under the shifted-label assumption above, position 0 is never predicted.
            n_targets = int((labels[:, 1:] != -100).sum())
            if n_targets == 0:
                # An all-padding batch would produce a NaN mean loss.
                continue

            outputs = model(**inputs)
            total_nll += outputs.loss.item() * n_targets
            total_tokens += n_targets

    if total_tokens == 0:
        raise ValueError("eval_dataloader contains no non-padding target tokens")

    avg_nll = total_nll / total_tokens
    return torch.exp(torch.tensor(avg_nll, dtype=torch.float64)).item()

def measure_inference_time(model, inputs, n_runs=10, warmup_runs=None):
    """Measure forward-pass latency of ``model`` on ``inputs``.

    Args:
        model: module called as ``model(**inputs)``; ``model.eval()`` is invoked first.
        inputs: mapping of keyword arguments for the forward call (e.g. tokenizer output).
        n_runs: number of timed forward passes; must be >= 1.
        warmup_runs: untimed passes run before measuring; defaults to WARMUP_RUNS.

    Returns:
        dict with "mean", "std", "min" and "max" latency in seconds.

    Raises:
        ValueError: if ``n_runs`` < 1 (the statistics would be undefined).
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")
    if warmup_runs is None:
        warmup_runs = WARMUP_RUNS

    # CUDA work is queued asynchronously; without synchronizing, the timer would
    # mostly measure kernel-launch overhead rather than the forward pass.
    on_cuda = any(getattr(value, "is_cuda", False) for value in inputs.values())

    def _sync():
        if on_cuda:
            torch.cuda.synchronize()

    model.eval()
    latencies = []
    with torch.no_grad():
        for _ in range(warmup_runs):
            model(**inputs)
        _sync()

        for _ in range(n_runs):
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            model(**inputs)
            _sync()
            latencies.append(time.perf_counter() - start_time)

    return {
        "mean": np.mean(latencies),
        "std": np.std(latencies),
        "min": np.min(latencies),
        "max": np.max(latencies)
    }

def prune_model(model, amount=0.5):
    """Apply L1 unstructured pruning to all linear layers"""
    for name, module in tqdm(model.named_modules(), desc=f"Pruning Model ({amount*100:.1f}%)"):
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=amount)

    # Make pruning permanent to save memory
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.remove(module, 'weight')

    return model

def prepare_eval_dataset(tokenizer, num_samples=100, max_length=None):
    """Load and tokenize a WikiText-2 (raw) test subset for perplexity evaluation.

    Args:
        tokenizer: tokenizer used with ``padding="max_length"``; it must have a
            pad token set (evaluate_model sets it to EOS).
        num_samples: maximum number of non-empty text lines to keep.
        max_length: length to pad/truncate every sequence to; defaults to MAX_SEQ_LENGTH.

    Returns:
        A torch-formatted ``datasets.Dataset`` with ``input_ids``,
        ``attention_mask`` and ``labels``. Labels copy ``input_ids`` but use
        -100 at padding positions so the loss ignores them. They are not
        shifted here; NOTE(review): the model is assumed to shift them internally.
    """
    if max_length is None:
        max_length = MAX_SEQ_LENGTH

    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    # Blank lines would become all-padding sequences with nothing to score.
    dataset = dataset.filter(lambda example: example["text"].strip() != "")

    # Take a subset for faster evaluation
    dataset = dataset.select(range(min(num_samples, len(dataset))))

    def tokenize_function(examples):
        encoded = tokenizer(examples["text"], padding="max_length",
                            truncation=True, max_length=max_length)
        encoded["labels"] = [
            [token if keep else -100 for token, keep in zip(ids, mask)]
            for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
        ]
        return encoded

    tokenized_dataset = dataset.map(tokenize_function, batched=True,
                                    remove_columns=dataset.column_names)
    tokenized_dataset.set_format(type="torch",
                                 columns=["input_ids", "attention_mask", "labels"])
    return tokenized_dataset

def _collect_metrics(model, eval_dataloader, sample_inputs, device):
    """Measure size, latency (mean/std, ms), perplexity and sparsity of ``model`` on ``device``."""
    size_mb = get_model_size_mb(model)
    timing = measure_inference_time(model, sample_inputs, MEASURE_RUNS)
    perplexity = calculate_perplexity(model, eval_dataloader, device)
    sparsity = calculate_sparsity(model)
    return {
        "perplexity": perplexity,
        "size_mb": size_mb,
        "latency_ms": timing["mean"] * 1000,
        "latency_std_ms": timing["std"] * 1000,
        "sparsity": sparsity,
    }


def evaluate_model(model_name, pruning_rates, use_quantization=False):
    """Evaluate ``model_name`` at each pruning rate, optionally followed by int8 quantization.

    A fresh copy of the model is loaded for every rate, pruned when the rate
    is > 0, optionally passed through ``quantize_dynamic``, and measured once.

    Returns:
        DataFrame with one row per pruning rate and columns model, technique,
        pruning_rate, perplexity, size_mb, latency_ms, latency_std_ms,
        sparsity and device. "device" is where the metrics were measured, so
        latencies taken on different devices are not compared blindly.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load tokenizer; reuse EOS as the padding token so fixed-length batches can be built
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Prepare evaluation dataset
    eval_dataset = prepare_eval_dataset(tokenizer, EVAL_SAMPLES)
    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=4, shuffle=False
    )

    technique = "pruning+quantization" if use_quantization else "pruning"
    # quantize_dynamic models are evaluated on CPU, the execution target of
    # torch's dynamic quantization.
    eval_device = "cpu" if use_quantization else device

    # Single input for latency testing
    sample_text = "The future of artificial intelligence is"
    sample_encoding = tokenizer(sample_text, return_tensors="pt")
    sample_inputs = {key: tensor.to(eval_device) for key, tensor in sample_encoding.items()}

    results = []
    for pruning_rate in pruning_rates:
        print(f"\n--- Evaluating {model_name} with pruning rate {pruning_rate:.2f} ---")
        # Fresh weights for every rate so pruning never compounds across iterations
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

        if pruning_rate > 0:
            model = prune_model(model, pruning_rate)

        if use_quantization:
            print("Applying quantization...")
            # NOTE(review): only torch.nn.Linear modules are swapped; layers of
            # other types (e.g. transformers' Conv1D in GPT-2) stay in float.
            model = quantize_dynamic(
                model.to(eval_device), {torch.nn.Linear}, dtype=torch.qint8
            )

        results.append({
            "model": model_name,
            "technique": technique,
            "pruning_rate": pruning_rate,
            **_collect_metrics(model, eval_dataloader, sample_inputs, eval_device),
            "device": eval_device,
        })

        # Free memory before loading the next copy
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return pd.DataFrame(results)

def _plot_metric_vs_pruning(results_df, metric, title, ylabel, filename):
    """Line-plot ``metric`` against pruning rate (colour = model, style = technique) and save it."""
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.lineplot(data=results_df, x="pruning_rate", y=metric,
                 hue="model", style="technique", markers=True, dashes=False, ax=ax)
    ax.set_title(title, fontsize=16)
    ax.set_xlabel("Pruning Rate", fontsize=14)
    ax.set_ylabel(ylabel, fontsize=14)
    ax.tick_params(axis="both", labelsize=12)
    fig.savefig(f"{RESULTS_DIR}/{filename}", dpi=300, bbox_inches='tight')


def _plot_size_tradeoff(results_df, metric, title, ylabel, filename):
    """Plot ``metric`` against model size, one line per (model, technique), points labelled by pruning rate."""
    fig, ax = plt.subplots(figsize=(12, 8))
    for model in results_df["model"].unique():
        model_data = results_df[results_df["model"] == model]
        for technique in model_data["technique"].unique():
            data = model_data[model_data["technique"] == technique]
            ax.plot(data["size_mb"], data[metric],
                    marker='o', label=f"{model} - {technique}")
            for _, row in data.iterrows():
                ax.annotate(f"{row['pruning_rate']:.1f}",
                            (row["size_mb"], row[metric]),
                            textcoords="offset points",
                            xytext=(0, 10),
                            ha='center')
    ax.set_title(title, fontsize=16)
    ax.set_xlabel("Model Size (MB)", fontsize=14)
    ax.set_ylabel(ylabel, fontsize=14)
    ax.legend(fontsize=12)
    ax.grid(True)
    fig.savefig(f"{RESULTS_DIR}/{filename}", dpi=300, bbox_inches='tight')


def create_visualizations(results_df):
    """Create the compression-study figures and save them as PNGs in RESULTS_DIR.

    ``results_df`` needs the columns model, technique, pruning_rate, size_mb,
    latency_ms, perplexity and sparsity. It is not modified; all plotting
    uses a copy.
    """
    sns.set(style="whitegrid")

    metrics = ["pruning_rate", "size_mb", "latency_ms", "perplexity", "sparsity"]
    # Copy so the caller's frame is not mutated. Coercing every metric to float
    # also unwraps any 0-d tensor values that seaborn could not plot.
    plot_df = results_df.copy()
    plot_df[metrics] = plot_df[metrics].astype(float)

    # 1-3. Metric vs pruning rate (by model and technique)
    _plot_metric_vs_pruning(plot_df, "size_mb", "Model Size vs Pruning Rate",
                            "Model Size (MB)", "size_vs_pruning.png")
    _plot_metric_vs_pruning(plot_df, "perplexity", "Perplexity vs Pruning Rate",
                            "Perplexity (lower is better)", "perplexity_vs_pruning.png")
    _plot_metric_vs_pruning(plot_df, "latency_ms", "Inference Latency vs Pruning Rate",
                            "Latency (ms)", "latency_vs_pruning.png")

    # 4-5. Trade-offs against model size
    _plot_size_tradeoff(plot_df, "perplexity", "Compression Efficiency: Perplexity vs Model Size",
                        "Perplexity (lower is better)", "perplexity_vs_size.png")
    _plot_size_tradeoff(plot_df, "latency_ms", "Memory-Latency Tradeoff",
                        "Latency (ms)", "latency_vs_size.png")

    # 6. Sparsity bar chart: one bar per (model, technique) at each target rate
    plot_df["pruning_rate_str"] = plot_df["pruning_rate"].astype(str)
    plot_df["configuration"] = plot_df["model"] + " - " + plot_df["technique"]
    fig, ax = plt.subplots(figsize=(14, 8))
    sns.barplot(data=plot_df, x="pruning_rate_str", y="sparsity", hue="configuration",
                errorbar=None, palette="viridis", ax=ax)
    ax.set_title("Achieved Sparsity by Model and Technique", fontsize=16)
    ax.set_xlabel("Target Pruning Rate", fontsize=14)
    ax.set_ylabel("Actual Sparsity (%)", fontsize=14)
    ax.tick_params(axis="both", labelsize=12)
    ax.legend(title="Model - Technique", fontsize=12)
    fig.savefig(f"{RESULTS_DIR}/sparsity_by_technique.png", dpi=300, bbox_inches='tight')

    # 7. Scatterplot matrix for all metrics. pairplot builds its own figure and
    # needs exactly one marker per hue level, so size the marker list to the model count.
    base_markers = ["o", "s", "D", "^", "v", "P", "X", "*"]
    n_models = plot_df["model"].nunique()
    markers = [base_markers[i % len(base_markers)] for i in range(n_models)]
    g = sns.pairplot(plot_df, vars=metrics, hue="model", palette="colorblind",
                     markers=markers, height=3, aspect=1.2)
    g.fig.suptitle("Relationships Between Compression Metrics", fontsize=20, y=1.02)
    g.fig.savefig(f"{RESULTS_DIR}/metrics_relationships.png", dpi=300, bbox_inches='tight')

# Main execution
def _run_all_evaluations():
    """Evaluate every model in MODELS with pruning only, then with pruning + quantization."""
    all_results = []
    for use_quantization in (False, True):
        for model_name in MODELS.values():
            all_results.append(
                evaluate_model(model_name, PRUNING_RATES, use_quantization=use_quantization)
            )
    return pd.concat(all_results, ignore_index=True)


def _print_configuration(header, row, show_compression=False):
    """Print one result row in the summary format."""
    print(header)
    print(f"  Model: {row['model']}")
    print(f"  Technique: {row['technique']}")
    print(f"  Pruning Rate: {row['pruning_rate']:.2f}")
    print(f"  Size: {row['size_mb']:.2f} MB")
    if show_compression:
        print(f"  Compression Ratio: {row['compression_ratio']:.2f}x")
    print(f"  Perplexity: {row['perplexity']:.2f}")
    print(f"  Latency: {row['latency_ms']:.2f} ms")


def main():
    """Run the full study, save CSV and figures, and print the best configurations."""
    combined_results = _run_all_evaluations()

    # Save results to CSV
    combined_results.to_csv(f"{RESULTS_DIR}/compression_results.csv", index=False)

    # Create visualizations
    create_visualizations(combined_results)

    print("\n--- Summary of Best Configurations ---")

    max_perplexity_factor = 1.5  # perplexity may grow to at most 1.5x the model's own baseline
    min_compression_ratio = 2.0  # "at least 50% size reduction" == baseline_size / size >= 2

    # Uncompressed reference per model: pruning rate 0 without quantization.
    is_baseline = (combined_results["pruning_rate"] == 0) & (combined_results["technique"] == "pruning")
    baselines = combined_results[is_baseline].drop_duplicates("model").set_index("model")
    if baselines.empty:
        print("\nNo uncompressed baseline (pruning rate 0.0) found; skipping summary.")
        return

    # Models without a baseline map to NaN and drop out of both filters below.
    baseline_perplexity = combined_results["model"].map(baselines["perplexity"])
    baseline_size = combined_results["model"].map(baselines["size_mb"])
    summary = combined_results.assign(compression_ratio=baseline_size / combined_results["size_mb"])

    # Best size reduction with acceptable perplexity (relative to each model's own baseline)
    acceptable_results = summary[summary["perplexity"] <= baseline_perplexity * max_perplexity_factor]
    if not acceptable_results.empty:
        best_compression = acceptable_results.loc[acceptable_results["size_mb"].idxmin()]
        _print_configuration(
            f"\nBest Size Reduction (with perplexity <= {max_perplexity_factor:.1f}x the model's baseline):",
            best_compression,
        )

    # Best perplexity with significant compression
    good_compression = summary[summary["compression_ratio"] >= min_compression_ratio]
    if good_compression.empty:
        print(f"\nNo configuration reached a compression ratio >= {min_compression_ratio:.1f}x.")
    else:
        best_quality = good_compression.loc[good_compression["perplexity"].idxmin()]
        _print_configuration(
            f"\nBest Quality (with compression ratio >= {min_compression_ratio:.1f}):",
            best_quality,
            show_compression=True,
        )

if __name__ == "__main__":
    main()