# Soft Weight-Sharing for Neural Network Compression - PyTorch Tutorial

This tutorial demonstrates the **Soft Weight-Sharing** approach from Ullrich, Meeds & Welling (ICLR 2017) using PyTorch.

## Overview

Soft weight-sharing learns a Gaussian mixture model as an empirical prior over network weights. The key idea is that this prior pulls the weights toward a small number of shared cluster centers during retraining, allowing us to:
1. Replace individual weights with their cluster center (quantization)
2. Store only cluster centers (codebook) + assignments
3. Achieve high compression with minimal accuracy loss
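
To make the codebook idea concrete, here is a back-of-the-envelope sketch of the storage cost when every weight is replaced by a fixed-width index into a shared codebook. The function name `codebook_storage_bits` is hypothetical, and the numbers (642,000 weights, 16 clusters, 32-bit floats) are taken from later sections of this tutorial. It ignores sparsity and entropy coding, so it only gives a rough lower bound on what the full pipeline can reach.

```python
import math


def codebook_storage_bits(num_weights: int, num_clusters: int, float_bits: int = 32) -> dict:
    """Estimate storage for a dense codebook encoding (hypothetical helper)."""
    # Each weight stores only the index of its cluster center.
    index_bits = math.ceil(math.log2(num_clusters))
    # Uncompressed model: one float per weight.
    dense_bits = num_weights * float_bits
    # Shared model: one index per weight plus one float per cluster center.
    shared_bits = num_weights * index_bits + num_clusters * float_bits
    return {
        "index_bits": index_bits,
        "dense_bits": dense_bits,
        "shared_bits": shared_bits,
        "ratio": dense_bits / shared_bits,
    }


est = codebook_storage_bits(num_weights=642_000, num_clusters=16)
print(f"Bits per index: {est['index_bits']}")
print(f"Dense: {est['dense_bits']:,} bits, shared: {est['shared_bits']:,} bits")
print(f"Compression from weight sharing alone: {est['ratio']:.2f}x")
```

Weight sharing alone gives a little under 8x here; the larger ratios quoted in the summary additionally rely on most weights being zero and on entropy coding of the remaining indices.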

## Three-Phase Approach

1. **PART 1: Pretrain** - Train a standard network on MNIST
2. **PART 2: Retrain** - Retrain with a learned Gaussian mixture prior that encourages clustering
3. **PART 3: Post-process** - Quantize weights to mixture means and evaluate compression

---

## Setup and Imports

In [None]:
import os
import json
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm

# Import our soft weight-sharing modules
from sws.models import TutorialNet
from sws.data import make_loaders
from sws.prior import init_mixture, MixturePrior
from sws.train import train_standard, retrain_soft_weight_sharing, evaluate
from sws.compress import compression_report
from sws.utils import collect_weight_params, set_seed, get_device
from sws.viz import TrainingGifVisualizer
from scripts.tutorial_helpers import (
    plot_weight_scatter,
    plot_weight_histogram,
    plot_mixture_components,
    plot_comparison_histograms
)

# Set random seed for reproducibility
set_seed(42)
device = get_device()
print(f"Using device: {device}")

## Load MNIST Dataset

In [None]:
# Load MNIST with batch size 128
train_loader, test_loader, num_classes = make_loaders(
    dataset="mnist",
    batch_size=128,
    num_workers=2
)

print(f"Training samples: {len(train_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")
print(f"Number of classes: {num_classes}")

---
# PART 1: Pretrain Network

We first train a standard convolutional neural network on MNIST:
- **Architecture**: 2 convolutional layers + 2 fully-connected layers
- **Parameters**: ~642,000 trainable weights
- **Training**: Standard cross-entropy loss with Adam optimizer
- **Expected accuracy**: ~98-99%

In [None]:
# Create the model
model = TutorialNet(num_classes=num_classes).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
weight_params = sum(p.numel() for p in collect_weight_params(model))

print("\nModel Architecture:")
print(model)
print(f"\nTotal parameters: {total_params:,}")
print(f"Weight parameters (to be compressed): {weight_params:,}")

### Train the baseline model

In [None]:
# Pretrain for 10 epochs
pretrain_acc = train_standard(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    epochs=10,
    lr=1e-3,
    wd=0.0,
    optim_name="adam",
    eval_every=5,
    desc="pretrain"
)

print(f"\n✓ Pretrained model accuracy: {pretrain_acc:.4f} ({pretrain_acc*100:.2f}%)")

### Save pretrained model

In [None]:
# Save the pretrained weights
os.makedirs("tutorial_outputs2", exist_ok=True)
torch.save(model.state_dict(), "tutorial_outputs2/pretrained_model.pt")
print("✓ Saved pretrained model to tutorial_outputs2/pretrained_model.pt")

# Store pretrained weights for later comparison
pretrained_weights = [w.clone() for w in collect_weight_params(model)]

### Visualize pretrained weight distribution

In [None]:
plot_weight_histogram(
    weights=pretrained_weights,
    title="Pretrained Weight Distribution",
    log_scale=False,
    save="tutorial_outputs2/pretrained_histogram.png"
)

---
# PART 2: Retrain with Gaussian Mixture Prior

Now we add a **Gaussian mixture prior** over the weights:

$$p(w) = \sum_{j=0}^{J-1} \pi_j \mathcal{N}(w | \mu_j, \sigma_j^2)$$

Where:
- **J = 16 components** (1 zero-spike + 15 non-zero clusters)
- **π₀ = 0.99** (high probability on zero for sparsity)
- **μ₀ = 0** (zero component is fixed)
- **μ₁...μ₁₅** are initialized from the range of the pretrained weights and then learned jointly with the network weights
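
The mixture density above can be evaluated for a single weight as in the following dependency-free sketch. It uses a toy 3-component mixture rather than the 16 components of the tutorial, and the helper names (`log_gaussian`, `mixture_log_prob`, `responsibilities`) are illustrative, not part of the `sws` package. The log-sum-exp trick keeps the computation stable when one component dominates.

```python
import math


def log_gaussian(w, mu, sigma):
    """Log density of N(w | mu, sigma^2)."""
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (w - mu) ** 2 / (2 * sigma ** 2)


def mixture_log_prob(w, pis, mus, sigmas):
    """log p(w) = log sum_j pi_j N(w | mu_j, sigma_j^2), via log-sum-exp."""
    terms = [math.log(p) + log_gaussian(w, m, s) for p, m, s in zip(pis, mus, sigmas)]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


def responsibilities(w, pis, mus, sigmas):
    """Posterior probability that each component generated w."""
    log_p = mixture_log_prob(w, pis, mus, sigmas)
    return [math.exp(math.log(p) + log_gaussian(w, m, s) - log_p)
            for p, m, s in zip(pis, mus, sigmas)]


# Toy mixture: a narrow zero component plus two wider non-zero components.
pis = [0.9, 0.05, 0.05]
mus = [0.0, -0.2, 0.2]
sigmas = [0.02, 0.05, 0.05]

for w in (0.0, 0.1, 0.2):
    resp = responsibilities(w, pis, mus, sigmas)
    print(f"w={w:+.2f}  log p(w)={mixture_log_prob(w, pis, mus, sigmas):+.3f}  "
          f"most responsible component: {resp.index(max(resp))}")
```

A weight sitting on a component mean gets a higher log-probability than one lying between two means, which is what makes the prior reward clustering.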

## Loss Function

The training loss becomes:

$$\mathcal{L} = \text{CrossEntropy}(y, \hat{y}) + \frac{\tau}{N} \sum_i -\log p(w_i)$$

Where:
- **τ = 0.003** (complexity regularization strength)
- **N = 60,000** (dataset size for proper normalization)

The negative log probability term encourages weights to cluster at the mixture component means.
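
To see why this term pulls weights toward the means, differentiate the penalty for a single weight. With responsibilities $r_j(w) = \pi_j \mathcal{N}(w | \mu_j, \sigma_j^2) / p(w)$, the gradient is

$$\frac{\partial}{\partial w}\left[-\log p(w)\right] = \sum_{j} r_j(w)\,\frac{w - \mu_j}{\sigma_j^2}$$

The sketch below checks this expression against a finite-difference estimate on a toy 3-component mixture. The helper names and the toy parameters are assumptions made for illustration; in the tutorial itself the gradient is obtained by autograd and scaled by τ/N.

```python
import math


def log_gaussian(w, mu, sigma):
    """Log density of N(w | mu, sigma^2)."""
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (w - mu) ** 2 / (2 * sigma ** 2)


def neg_log_prior(w, pis, mus, sigmas):
    """-log p(w) for a 1-D Gaussian mixture, computed with log-sum-exp."""
    terms = [math.log(p) + log_gaussian(w, m, s) for p, m, s in zip(pis, mus, sigmas)]
    top = max(terms)
    return -(top + math.log(sum(math.exp(t - top) for t in terms)))


def neg_log_prior_grad(w, pis, mus, sigmas):
    """Closed-form gradient: sum_j r_j(w) * (w - mu_j) / sigma_j^2."""
    terms = [math.log(p) + log_gaussian(w, m, s) for p, m, s in zip(pis, mus, sigmas)]
    top = max(terms)
    unnorm = [math.exp(t - top) for t in terms]
    total = sum(unnorm)
    resp = [u / total for u in unnorm]
    return sum(r * (w - m) / s ** 2 for r, m, s in zip(resp, mus, sigmas))


pis = [0.9, 0.05, 0.05]
mus = [0.0, -0.2, 0.2]
sigmas = [0.02, 0.05, 0.05]

w = 0.15  # lies between the zero component and the component at 0.2
analytic = neg_log_prior_grad(w, pis, mus, sigmas)
h = 1e-6
numeric = (neg_log_prior(w + h, pis, mus, sigmas)
           - neg_log_prior(w - h, pis, mus, sigmas)) / (2 * h)

tau, n_train = 0.003, 60_000
print(f"analytic gradient: {analytic:.4f}, finite difference: {numeric:.4f}")
print(f"contribution to the loss gradient (tau/N scaled): {tau / n_train * analytic:.3e}")
```

The gradient at `w = 0.15` is negative, so a gradient-descent step increases the weight and moves it toward the nearby mean at 0.2.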

### Initialize the Gaussian Mixture Prior

In [None]:
# Initialize mixture with 16 components
prior = init_mixture(
    model=model,
    J=16,  # Total components (1 zero + 15 non-zero)
    pi0=0.99,  # High probability on zero component
    init_means_mode="from_weights",  # Initialize from pretrained weight range
    init_sigma=0.25,  # Initial standard deviation
    device=device
)

print("\n✓ Initialized Gaussian mixture prior")
print(f"  - Number of components: {prior.J}")
print(f"  - Zero component mixing weight (π₀): {prior.pi0_init}")

# Inspect the initial mixture parameters
mu, sigma2, pi = prior.mixture_params()
print(f"\nInitial mixture means: {mu[:5].detach().cpu().numpy()}...")
print(f"Initial mixture stds: {torch.sqrt(sigma2[:5]).detach().cpu().numpy()}...")
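
The internals of `init_mixture` are not shown in this tutorial. As a rough illustration of what a "from_weights" initialization could look like, the sketch below spreads the non-zero means evenly over the range of the pretrained weights, fixes the first mean at zero, and splits the remaining mixing mass equally. The function `init_mixture_sketch` and its exact behaviour are assumptions, not the actual `sws.prior` implementation.

```python
def init_mixture_sketch(w_min, w_max, num_components=16, pi0=0.99, init_sigma=0.25):
    """Hypothetical initialization of a mixture prior from a weight range."""
    num_free = num_components - 1
    # Evenly spaced means covering [w_min, w_max] for the non-zero components.
    step = (w_max - w_min) / (num_free - 1)
    free_means = [w_min + k * step for k in range(num_free)]
    # Component 0 is the zero-spike with a fixed mean of 0.
    mus = [0.0] + free_means
    # The zero-spike gets pi0; the rest share the remaining mass equally.
    pis = [pi0] + [(1.0 - pi0) / num_free] * num_free
    sigmas = [init_sigma] * num_components
    return mus, sigmas, pis


# Example with a made-up pretrained weight range of [-0.3, 0.3].
mus, sigmas, pis = init_mixture_sketch(w_min=-0.3, w_max=0.3)
print("means:", [round(m, 3) for m in mus])
print("mixing proportions sum to:", round(sum(pis), 6))
```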

### Setup GIF Visualizer

We'll use the `TrainingGifVisualizer` to create an animated visualization showing how weights evolve during retraining. The GIF will show:
- Weight scatter plot (pretrained vs current weights)
- Marginal histograms
- Mixture component bands (mean ± 2σ)
- Test accuracy per epoch

In [None]:
# Create GIF visualizer to track weight evolution during retraining
viz = TrainingGifVisualizer(
    out_dir="tutorial_outputs2",
    tag="retraining",
    framerate=2,
    notebook_display=True,
    cleanup_frames=True
)

print("✓ GIF visualizer ready - will capture frames during training")

### Retrain with Soft Weight-Sharing

In [None]:
# Retrain with mixture prior and GIF visualization
retrain_acc = retrain_soft_weight_sharing(
    model=model,
    prior=prior,
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    epochs=10,
    lr_w=5e-4,  # Learning rate for network weights
    lr_theta=3e-4,  # Learning rate for mixture parameters
    weight_decay=0.0,
    tau=0.003,  # Complexity regularization (properly normalized by dataset size)
    tau_warmup_epochs=0,  # 0 disables warmup: tau is used at full strength from the first epoch
    complexity_mode="keras",  # Use keras-style normalization (tau/dataset_size)
    eval_every=10,
    cr_every=0,  # Don't compute compression during training (slow)
    mixture_every=0,  # Don't log mixture every epoch
    run_dir="tutorial_outputs2",
    viz=viz  # Pass visualizer to capture frames during training
)

print(f"\n✓ Retrained model accuracy: {retrain_acc:.4f} ({retrain_acc*100:.2f}%)")
print(f"  Accuracy drop from pretraining: {(pretrain_acc - retrain_acc)*100:.2f} percentage points")
print(f"\n✓ GIF animation saved to: tutorial_outputs2/retraining.gif")
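
The call above sets `tau_warmup_epochs=0`, so no warmup is used. If you do want to ramp the regularization up gradually, one simple option is a linear schedule like the one sketched here. The function `tau_at_epoch` is hypothetical and the linear shape is an assumption; the schedule actually used inside `retrain_soft_weight_sharing` may differ.

```python
def tau_at_epoch(epoch: int, tau_final: float, warmup_epochs: int) -> float:
    """Linear warmup for tau (hypothetical schedule, epochs are 0-indexed)."""
    if warmup_epochs <= 0:
        # No warmup: use the full regularization strength immediately.
        return tau_final
    # Ramp linearly and clamp at tau_final once warmup is over.
    return tau_final * min(1.0, (epoch + 1) / warmup_epochs)


for epoch in (0, 4, 9, 15):
    print(f"epoch {epoch:2d}: tau = {tau_at_epoch(epoch, tau_final=0.003, warmup_epochs=10):.5f}")
```

A warmup like this gives the mixture parameters a few epochs to adapt before the prior starts pulling strongly on the weights.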

### View the Training Animation

The GIF shows how weights migrate from their pretrained values (x-axis) toward mixture component means (y-axis) during retraining. You'll see:
- Weights clustering into horizontal bands (the mixture components)
- The marginal histogram becoming more sparse and discrete
- Test accuracy tracked in the title

In [None]:
# Display the GIF inline (works in Jupyter Lab/Notebook)
from IPython.display import Image, display

gif_path = "tutorial_outputs2/retraining.gif"
if os.path.exists(gif_path):
    display(Image(filename=gif_path))
    print("\n💡 Individual frames are kept in tutorial_outputs2/retraining_frames/ only when cleanup_frames=False")
else:
    print(f"GIF not found at {gif_path}")

### Save retrained (pre-quantized) model

In [None]:
# Save pre-quantized model
torch.save(model.state_dict(), "tutorial_outputs2/retrained_prequant_model.pt")
print("✓ Saved retrained model to tutorial_outputs2/retrained_prequant_model.pt")

# Store retrained weights for comparison
retrained_weights = [w.clone() for w in collect_weight_params(model)]

### Visualize learned mixture components

In [None]:
# Plot the learned Gaussian mixture
plot_mixture_components(
    prior=prior,
    xlim=(-0.5, 0.5),
    save="tutorial_outputs2/learned_mixture.png"
)

### Visualize weight movement

In [None]:
# Scatter plot showing how weights moved during retraining
plot_weight_scatter(
    weights_before=pretrained_weights,
    weights_after=retrained_weights,
    sample=20000,
    xlim=(-0.5, 0.5),
    ylim=(-0.5, 0.5),
    save="tutorial_outputs2/weight_movement.png"
)

print("\n💡 Weights should cluster towards mixture component means (away from diagonal)")

---
# PART 3: Post-Processing and Compression

Now we:
1. **Quantize** weights by assigning each to the mean of its most likely mixture component
2. **Evaluate** the quantized model accuracy
3. **Compute compression** using CSR + Huffman encoding

## Quantization Strategy

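We use **maximum likelihood (ML) assignment**:
- For each weight, find the component whose Gaussian gives it the highest likelihood, without weighting by the mixing proportions π
- Replace the weight with that component's mean
- Because π₀ = 0.99 is left out of the comparison, the zero-spike does not dominate the assignment during snapping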
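
The difference between ML assignment and an assignment that also weights by the mixing proportions (MAP) can be seen on a toy example. The `assign` helper and the 3-component mixture below are illustrative assumptions and do not reproduce `prior.quantize_model`.

```python
import math


def log_gaussian(w, mu, sigma):
    """Log density of N(w | mu, sigma^2)."""
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (w - mu) ** 2 / (2 * sigma ** 2)


def assign(w, pis, mus, sigmas, mode="ml"):
    """Return the index of the component a weight is snapped to.

    mode="ml":  argmax_j N(w | mu_j, sigma_j^2)
    mode="map": argmax_j pi_j * N(w | mu_j, sigma_j^2)
    """
    scores = []
    for p, m, s in zip(pis, mus, sigmas):
        score = log_gaussian(w, m, s)
        if mode == "map":
            score += math.log(p)
        scores.append(score)
    return max(range(len(scores)), key=scores.__getitem__)


# Toy mixture with a dominant zero component.
pis = [0.99, 0.005, 0.005]
mus = [0.0, -0.1, 0.1]
sigmas = [0.03, 0.03, 0.03]

w = 0.055  # slightly closer to the mean at 0.1 than to zero
print("ML  ->", mus[assign(w, pis, mus, sigmas, mode="ml")])
print("MAP ->", mus[assign(w, pis, mus, sigmas, mode="map")])
```

With equal variances, ML assignment reduces to picking the nearest mean, so this weight is snapped to 0.1. Under MAP the large π₀ wins and the same weight is zeroed out, which is the zero-spike bias that ML assignment avoids.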

### Quantize weights to mixture means

In [None]:
# Quantize: snap each weight to the mean of its maximum-likelihood mixture component
prior.quantize_model(
    model=model,
    skip_last_matrix=True,  # Keep final classifier at full precision
    assign="ml"  # Use maximum likelihood assignment (avoids zero-spike bias)
)

print("✓ Quantized weights to mixture means")

### Evaluate quantized model

In [None]:
# Evaluate accuracy after quantization
quantized_acc = evaluate(model, test_loader, device)

print(f"\n📊 Accuracy Comparison:")
print(f"  Pretrained:  {pretrain_acc:.4f} ({pretrain_acc*100:.2f}%)")
print(f"  Retrained:   {retrain_acc:.4f} ({retrain_acc*100:.2f}%)")
print(f"  Quantized:   {quantized_acc:.4f} ({quantized_acc*100:.2f}%)")
print(f"\n  Total accuracy drop: {(pretrain_acc - quantized_acc)*100:.2f} percentage points")

### Save quantized model

In [None]:
# Save final quantized model
torch.save(model.state_dict(), "tutorial_outputs2/quantized_model.pt")
print("✓ Saved quantized model to tutorial_outputs2/quantized_model.pt")

# Store quantized weights
quantized_weights = [w.clone() for w in collect_weight_params(model)]

### Compute compression statistics

In [None]:
# Detailed compression report using CSR + Huffman encoding
report = compression_report(
    model=model,
    prior=prior,
    dataset="mnist",
    use_huffman=True,
    pbits_fc=5,  # Bits for FC layer column index diffs
    pbits_conv=8,  # Bits for Conv layer column index diffs
    skip_last_matrix=True,  # Last layer was not quantized
    assign_mode="ml"  # Must match quantization assignment mode
)

print("\n🗜️  Compression Report:")
print(f"  Original bits:    {report['orig_bits']:,}")
print(f"  Compressed bits:  {report['compressed_bits']:,}")
print(f"  Compression Ratio: {report['CR']:.2f}x")
print(f"  Non-zero weights:  {report['nnz']:,} / {weight_params:,} ({100*report['nnz']/weight_params:.2f}%)")
print(f"  Sparsity:         {100*(1 - report['nnz']/weight_params):.2f}%")
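
To give a feel for what Huffman coding contributes to the ratio reported above, the following sketch computes Huffman code lengths for a small, made-up list of component indices and compares the total with a fixed-width encoding. It is a generic textbook Huffman construction using only the standard library; how `compression_report` implements its coding is not shown in this tutorial.

```python
import heapq
from collections import Counter


def huffman_code_lengths(symbols):
    """Return {symbol: code length in bits} for a Huffman code over `symbols`."""
    counts = Counter(symbols)
    if len(counts) == 1:
        # A single distinct symbol still needs one bit per occurrence.
        return {next(iter(counts)): 1}
    # Heap entries: (count, unique tie-breaker, symbols in this subtree).
    heap = [(count, i, [sym]) for i, (sym, count) in enumerate(counts.items())]
    heapq.heapify(heap)
    lengths = dict.fromkeys(counts, 0)
    tie = len(heap)
    while len(heap) > 1:
        c1, _, s1 = heapq.heappop(heap)
        c2, _, s2 = heapq.heappop(heap)
        # Merging two subtrees adds one bit to every symbol below them.
        for sym in s1 + s2:
            lengths[sym] += 1
        heapq.heappush(heap, (c1 + c2, tie, s1 + s2))
        tie += 1
    return lengths


# Made-up assignments of 100 non-zero weights to 4 components (skewed usage).
assignments = [1] * 80 + [2] * 10 + [3] * 6 + [4] * 4
lengths = huffman_code_lengths(assignments)
huffman_bits = sum(lengths[s] for s in assignments)
fixed_bits = len(assignments) * 2  # 2 bits are enough to index 4 components

print("code lengths:", lengths)
print(f"Huffman: {huffman_bits} bits, fixed-width: {fixed_bits} bits")
```

The more unevenly the components are used, the larger the saving over a fixed-width index.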

### Layer-wise compression breakdown

In [None]:
print("\n📋 Layer-wise Compression:")
print(f"{'Layer':<15} {'Shape':<20} {'Original (bits)':<18} {'Compressed (bits)':<20} {'CR':<8} {'Sparsity':<10}")
print("-" * 100)

for layer_info in report['layers']:
    # Total compressed size of the layer: index bits, value bits, and codebook bits
    compressed = (
        layer_info['bits_IR'] + layer_info['bits_IC']
        + layer_info['bits_A'] + layer_info['bits_codebook']
    )
    if layer_info['passthrough']:
        # Layer kept at full precision: no compression ratio to report
        cr_str = "N/A"
        sparsity = 0.0
    else:
        cr = layer_info['orig_bits'] / max(compressed, 1)
        cr_str = f"{cr:.2f}x"
        total_weights = np.prod(layer_info['shape'])
        sparsity = 100 * (1 - layer_info['nnz'] / total_weights)

    shape_str = 'x'.join(map(str, layer_info['shape']))
    orig_str = f"{layer_info['orig_bits']:,}"
    comp_str = f"{compressed:,}"

    print(f"{layer_info['layer']:<15} {shape_str:<20} {orig_str:<18} {comp_str:<20} {cr_str:<8} {sparsity:.1f}%")
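
The `pbits_fc` and `pbits_conv` arguments used earlier control how many bits are spent on column index differences. One common way to store sparse indices with a fixed bit budget is to encode the gap to the previous non-zero entry and to insert a filler zero whenever the gap does not fit. The sketch below illustrates that idea on a single row; the function names and the exact padding convention are assumptions, and the scheme inside `sws.compress` may differ in its details.

```python
def encode_relative(row, pbits):
    """Encode non-zeros of `row` as (gap, value) pairs with gaps of at most 2**pbits - 1."""
    max_gap = 2 ** pbits - 1
    entries = []
    prev = -1  # position of the previously stored entry
    for col, value in enumerate(row):
        if value == 0:
            continue
        gap = col - prev
        # If the gap does not fit in pbits, insert filler zeros along the way.
        while gap > max_gap:
            entries.append((max_gap, 0))
            gap -= max_gap
        entries.append((gap, value))
        prev = col
    return entries


def decode_relative(entries, length):
    """Invert encode_relative; filler entries simply write a zero."""
    row = [0] * length
    pos = -1
    for gap, value in entries:
        pos += gap
        row[pos] = value
    return row


# A sparse row of length 20 with non-zeros at columns 2 and 15.
row = [0] * 20
row[2], row[15] = 3, 7

for pbits in (2, 3, 4):
    entries = encode_relative(row, pbits)
    print(f"pbits={pbits}: {len(entries)} entries, "
          f"{len(entries) * pbits} index bits, entries={entries}")
```

Fewer bits per gap make each entry cheaper but force more filler entries on very sparse rows, which is why a different bit width can suit convolutional and fully-connected layers.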

### Visualize weight distributions across all phases

In [None]:
# Compare histograms: pretrained → retrained → quantized
plot_comparison_histograms(
    weights_pre=pretrained_weights,
    weights_retrained=retrained_weights,
    weights_quantized=quantized_weights,
    save_prefix="tutorial_outputs2/phase"
)

print("\n💡 Notice how weights progressively cluster and become sparse")

---
## Summary

We successfully applied soft weight-sharing to compress a neural network:

### Results:
- **Original model**: ~642K parameters at 32 bits each → ~20.5 Mbit (~2.6 MB)
- **Compressed model**: Achieves ~20-60x compression (depending on hyperparameters)
- **Accuracy loss**: Typically < 1% on MNIST

### Key Insights:

1. **Soft clustering happens during training**: Because the mixture prior is differentiable, weights migrate toward cluster centers during retraining rather than being snapped to them in a single hard clustering step

2. **Zero-spike induces sparsity**: The high mixing weight (π₀ = 0.99) on the zero component pulls many weights toward zero, and quantization then sets them to exactly zero

3. **Proper normalization matters**: The complexity term must be normalized by dataset size (τ/N) for stable training

4. **CSR + Huffman is efficient**: Storing only the non-zero weights in CSR format and Huffman-coding them takes fewer bits than a dense fixed-width encoding

### Next Steps:
- Try a different number of components (J)
- Experiment with different τ values
- Apply to larger models (ResNets, etc.)
- Combine with other compression techniques (pruning, quantization-aware training)

---

**Reference**: Ullrich, K., Meeds, E., & Welling, M. (2017). Soft Weight-Sharing for Neural Network Compression. *ICLR 2017*.

---
## Bonus: Analyze Mixture Component Usage

In [None]:
# Count how many weights are assigned to each component
mu, sigma2, pi = prior.mixture_params()
mu = mu.detach().cpu().numpy()
pi = pi.detach().cpu().numpy()

# Count assignments in quantized model
all_weights = torch.cat([w.flatten() for w in quantized_weights]).cpu().numpy()

print("\n📊 Component Assignment Statistics:\n")
print(f"{'Component':<12} {'Mean (μ)':<12} {'Mixing (π)':<12} {'# Weights':<15} {'Percentage':<12}")
print("-" * 70)

for j in range(len(mu)):
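    # Count weights at this mean (with small tolerance for floating point).
    # The last layer was not quantized (skip_last_matrix=True), so the
    # percentages may not sum to 100%.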
    count = np.sum(np.abs(all_weights - mu[j]) < 1e-6)
    percentage = 100 * count / len(all_weights)
    
    comp_name = f"Comp {j}" if j > 0 else "Zero-spike"
    print(f"{comp_name:<12} {mu[j]:>11.4f} {pi[j]:>11.4f} {count:>14,} {percentage:>11.2f}%")

print("\n💡 The zero-spike should capture most weights (high sparsity)")