# Full SWS Compression with Optimized Hyperprior Parameters

Run a complete soft weight-sharing (SWS) retraining with hyperparameters selected from the Pareto front exploration in `BO.ipynb`.

## Workflow
1. **Fill in hyperparameters** (from `BO.ipynb` Pareto front)
2. **Run full retraining** (100 epochs with diagnostics)
3. **Generate all plots** (GIF, training curves, mixture dynamics, etc.)
4. **Analyze compression** (layer-wise breakdown, sparsity, CR)

## Instructions
- **Step 1**: Run `BO.ipynb` to find Pareto-optimal hyperparameters
- **Step 2**: Select a solution from the Pareto front
- **Step 3**: Copy hyperparameters to the cell below
- **Step 4**: Run all cells

---
## 1. User Configuration

In [None]:
# ========================================
# USER: Fill in these values from Pareto front exploration (BO.ipynb)
# ========================================

# Model selection
MODEL = "lenet_300_100"  # Options: "lenet_300_100", "lenet5", "wrn_16_4"
DATASET = "mnist"        # Options: "mnist", "cifar10"

# Hyperparameters from Pareto front (REQUIRED - fill these in!)
TAU = None                # Example: 0.005
GAMMA_ALPHA = None        # Example: 250.0
GAMMA_BETA = None         # Example: 0.15
GAMMA_ALPHA_ZERO = None   # Example: 4500.0
GAMMA_BETA_ZERO = None    # Example: 2.5

# ========================================
# Validation (do not modify)
# ========================================

assert TAU is not None, "ERROR: Please set TAU (complexity regularization)"
assert GAMMA_ALPHA is not None, "ERROR: Please set GAMMA_ALPHA (Gamma prior shape for non-zero components)"
assert GAMMA_BETA is not None, "ERROR: Please set GAMMA_BETA (Gamma prior rate for non-zero components)"
assert GAMMA_ALPHA_ZERO is not None, "ERROR: Please set GAMMA_ALPHA_ZERO (Gamma prior shape for zero component)"
assert GAMMA_BETA_ZERO is not None, "ERROR: Please set GAMMA_BETA_ZERO (Gamma prior rate for zero component)"

print("✅ All hyperparameters validated")
print("\n" + "="*60)
print("CONFIGURATION")
print("="*60)
print(f"  Model:              {MODEL}")
print(f"  Dataset:            {DATASET}")
print(f"\n  Hyperparameters:")
print(f"    tau:              {TAU:.6g}")
print(f"    gamma_alpha:      {GAMMA_ALPHA:.6g}")
print(f"    gamma_beta:       {GAMMA_BETA:.6g}")
print(f"    gamma_alpha_zero: {GAMMA_ALPHA_ZERO:.6g}")
print(f"    gamma_beta_zero:  {GAMMA_BETA_ZERO:.6g}")
print("="*60)
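
The assertions above only confirm that each value was filled in. A minimal sketch of a stricter check follows, assuming every hyperparameter must be a positive, finite number. Gamma shape and rate parameters are positive by definition; the same requirement on `tau` is an assumption about `run_sws.py`. `validate_hyperparameters` is a hypothetical helper, not part of the project.

```python
import math

def validate_hyperparameters(params):
    """Return a list of error messages; an empty list means all values look usable."""
    errors = []
    for name, value in params.items():
        if value is None:
            errors.append(f"{name} is not set")
        elif isinstance(value, bool) or not isinstance(value, (int, float)):
            errors.append(f"{name} must be a number, got {type(value).__name__}")
        elif not math.isfinite(value) or value <= 0:
            errors.append(f"{name} must be positive and finite, got {value}")
    return errors

# Example values taken from the comments in the configuration cell
example = {
    "TAU": 0.005,
    "GAMMA_ALPHA": 250.0,
    "GAMMA_BETA": 0.15,
    "GAMMA_ALPHA_ZERO": 4500.0,
    "GAMMA_BETA_ZERO": 2.5,
}
print(validate_hyperparameters(example))
```

Collecting the messages in a list reports every problem in one pass instead of stopping at the first failed assertion.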

### Map Model to Pretrained Checkpoint

In [None]:
# Auto-detect checkpoint path based on model and dataset
CHECKPOINT_MAP = {
    ("lenet_300_100", "mnist"): "checkpoints/mnist_lenet_300_100_pre.pt",
    ("lenet5", "mnist"): "checkpoints/mnist_lenet5_pre.pt",
    ("wrn_16_4", "cifar10"): "checkpoints/cifar10_wrn_16_4_pre.pt",
}

CHECKPOINT = CHECKPOINT_MAP.get((MODEL, DATASET))
assert CHECKPOINT is not None, f"Unknown model/dataset combination: {MODEL}/{DATASET}"

import os
if not os.path.exists(CHECKPOINT):
    print(f"⚠️  WARNING: Checkpoint not found: {CHECKPOINT}")
    print(f"   Please run training.ipynb first to create baseline checkpoints.")
else:
    print(f"✅ Found pretrained checkpoint: {CHECKPOINT}")

---
## 2. Run SWS with Custom Hyperparameters

In [None]:
# Build command with user-specified parameters
cmd = f"""python run_sws.py \\
    --preset {MODEL} \\
    --load-pretrained {CHECKPOINT} \\
    --pretrain-epochs 0 \\
    --retrain-epochs 100 \\
    --tau {TAU} \\
    --gamma-alpha {GAMMA_ALPHA} \\
    --gamma-beta {GAMMA_BETA} \\
    --gamma-alpha-zero {GAMMA_ALPHA_ZERO} \\
    --gamma-beta-zero {GAMMA_BETA_ZERO} \\
    --num-components 17 \\
    --merge-kl-thresh 1e-10 \\
    --quant-assign map \\
    --complexity-mode keras \\
    --tau-warmup-epochs 0 \\
    --quant-skip-last \\
    --batch-size 128 \\
    --num-workers 4 \\
    --eval-every 1 \\
    --log-mixture-every 1 \\
    --make-gif \\
    --gif-fps 3 \\
    --run-name {MODEL}_compression \\
    --save-dir compression_results \\
    --seed 42"""

print("Command to execute:")
print("="*80)
print(cmd)
print("="*80)
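
The command above is assembled as one interpolated string, which works as long as no value contains spaces or shell metacharacters. A sketch of an alternative that builds an argument list and quotes it with `shlex.join` follows. `build_sws_command` is a hypothetical helper, and only a subset of the flags from the cell above is shown.

```python
import shlex

def build_sws_command(model, checkpoint, hyperparams, retrain_epochs=100):
    """Assemble the run_sws.py call as a list of arguments (subset of flags)."""
    args = [
        "python", "run_sws.py",
        "--preset", model,
        "--load-pretrained", checkpoint,
        "--pretrain-epochs", "0",
        "--retrain-epochs", str(retrain_epochs),
    ]
    # Each hyperparameter becomes a "--flag value" pair
    for flag, value in hyperparams.items():
        args += [f"--{flag}", str(value)]
    return args

# Example values taken from the comments in the configuration cell
args = build_sws_command(
    "lenet_300_100",
    "checkpoints/mnist_lenet_300_100_pre.pt",
    {"tau": 0.005, "gamma-alpha": 250.0, "gamma-beta": 0.15,
     "gamma-alpha-zero": 4500.0, "gamma-beta-zero": 2.5},
)
print(shlex.join(args))
```

The same list can be passed to `subprocess.run(args, check=True)`, which raises an exception if the process exits with a non-zero status.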

In [None]:
# Execute training
!{cmd}

### Set Run Directory for Plotting

In [None]:
RUN_DIR = f"compression_results/{MODEL}_compression"
print(f"Results directory: {RUN_DIR}")

# Verify it exists
if os.path.exists(RUN_DIR):
    print(f"✅ Run directory found")
else:
    print(f"⚠️  Run directory not found. Training may have failed.")
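
Checking only that the directory exists does not confirm that training finished. A sketch that also looks for the report files listed in the Summary section follows. `missing_outputs` is a hypothetical helper, and the file list is taken from that section rather than from `run_sws.py` itself.

```python
from pathlib import Path

# File names as listed under "Reports" in the Summary section
EXPECTED_FILES = ["report.json", "summary_paper_metrics.json", "metrics.csv"]

def missing_outputs(run_dir, expected=EXPECTED_FILES):
    """Return the expected files that are absent from run_dir."""
    run_path = Path(run_dir)
    return [name for name in expected if not (run_path / name).is_file()]

missing = missing_outputs("compression_results/lenet_300_100_compression")
if missing:
    print("Missing outputs:", ", ".join(missing))
else:
    print("All expected outputs found")
```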

---
## 3. Generate All Diagnostic Plots

### Training Curves (Accuracy, Loss, Compression)

In [None]:
!python scripts/plot_curves.py --run-dir {RUN_DIR}

### Mixture Evolution Over Epochs

In [None]:
!python scripts/plot_mixture_dynamics.py --run-dir {RUN_DIR}

### Weight Movement (Pretrained → Retrained)

In [None]:
!python scripts/plot_weights_scatter.py --run-dir {RUN_DIR} --sample 20000

### Final Mixture Components

In [None]:
!python scripts/plot_mixture.py --run-dir {RUN_DIR} --checkpoint prequant

---
## 4. Display Results

### Training Evolution GIF

In [None]:
from IPython.display import Image, display

gif_path = f"{RUN_DIR}/figures/retraining.gif"
if os.path.exists(gif_path):
    print("Training Evolution Animation:")
    display(Image(filename=gif_path))
    print(f"\n💡 Animation shows weight evolution over 100 epochs")
    print(f"   Weights migrate from pretrained values toward mixture component means")
else:
    print(f"⚠️  GIF not found at {gif_path}")

### Compression Summary

In [None]:
import json

# Load summary metrics
summary_file = f"{RUN_DIR}/summary_paper_metrics.json"
if os.path.exists(summary_file):
    with open(summary_file) as f:
        summary = json.load(f)
    
    print("\n" + "="*80)
    print("FINAL RESULTS - COMPRESSION WITH CUSTOM HYPERPARAMETERS")
    print("="*80)
    print(f"\n📊 ACCURACY METRICS:")
    print(f"  Pretrain accuracy:    {summary['acc_pretrain']:.4f} ({summary['acc_pretrain']*100:.2f}%)")
    print(f"  Retrain accuracy:     {summary['acc_retrain']:.4f} ({summary['acc_retrain']*100:.2f}%)")
    print(f"  Quantized accuracy:   {summary['acc_quantized']:.4f} ({summary['acc_quantized']*100:.2f}%)")
    print(f"  Total accuracy drop:  {summary['Delta[%]']:.2f}%")
    
    print(f"\n💾 COMPRESSION METRICS:")
    print(f"  Compression Ratio:    {summary['CR']:.2f}x")
    print(f"  Total parameters:     {int(summary['|W|']):,}")
    print(f"  Non-zero params:      {summary['|W_nonzero|/|W|[%]']:.2f}%")
    print(f"  Sparsity:             {100 - summary['|W_nonzero|/|W|[%]']:.2f}%")
    
    print(f"\n🔧 HYPERPARAMETERS USED:")
    print(f"  tau:                  {TAU:.6g}")
    print(f"  gamma_alpha:          {GAMMA_ALPHA:.6g}")
    print(f"  gamma_beta:           {GAMMA_BETA:.6g}")
    print(f"  gamma_alpha_zero:     {GAMMA_ALPHA_ZERO:.6g}")
    print(f"  gamma_beta_zero:      {GAMMA_BETA_ZERO:.6g}")
    print("="*80)
else:
    print(f"⚠️  Summary file not found: {summary_file}")

### Detailed Compression Report

In [None]:
report_file = f"{RUN_DIR}/report.json"
if os.path.exists(report_file):
    with open(report_file) as f:
        report = json.load(f)
    
    print("\n📋 DETAILED COMPRESSION REPORT")
    print("="*80)
    print(f"Original bits:        {report['orig_bits']:,}")
    print(f"Compressed bits:      {report['compressed_bits']:,}")
    print(f"Compression Ratio:    {report['CR']:.2f}x")
    print(f"Non-zero weights:     {report['nnz']:,}")
    print(f"Dataset:              {report.get('dataset', 'N/A')}")
    print(f"Huffman encoding:     {report.get('use_huffman', False)}")
    print("="*80)
else:
    print(f"⚠️  Report file not found: {report_file}")

### Layer-wise Compression Breakdown

In [None]:
if os.path.exists(report_file):
    import numpy as np
    
    print("\n📊 LAYER-WISE COMPRESSION:")
    print(f"{'Layer':<15} {'Shape':<20} {'Original (bits)':<18} {'Compressed (bits)':<20} {'CR':<8} {'Sparsity':<10}")
    print("-" * 100)
    
    for layer_info in report['layers']:
        # Sum the storage components once; .get() guards against entries
        # (e.g. passthrough layers) that may omit some of these fields
        compressed = sum(layer_info.get(key, 0) for key in
                         ('bits_IR', 'bits_IC', 'bits_A', 'bits_codebook'))
        if layer_info.get('passthrough', False):
            cr_str = "N/A"
            sparsity = 0.0
        else:
            cr = layer_info['orig_bits'] / max(compressed, 1)
            cr_str = f"{cr:.2f}x"
            total_weights = np.prod(layer_info['shape'])
            sparsity = 100 * (1 - layer_info['nnz'] / total_weights)

        shape_str = 'x'.join(map(str, layer_info['shape']))
        orig_str = f"{layer_info['orig_bits']:,}"
        comp_str = f"{compressed:,}"

        print(f"{layer_info['layer']:<15} {shape_str:<20} {orig_str:<18} {comp_str:<20} {cr_str:<8} {sparsity:.1f}%")
    print("-" * 100)
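
The per-layer arithmetic in the cell above can be isolated into a small function, which makes it easier to check by hand. A sketch follows, assuming the same keys the cell reads from `report['layers']`. `layer_stats` is a hypothetical helper, and the layer entry is made-up illustrative data, not output from a real run.

```python
import math

BIT_KEYS = ("bits_IR", "bits_IC", "bits_A", "bits_codebook")

def layer_stats(layer_info):
    """Return (compressed_bits, compression_ratio, sparsity_percent) for one layer."""
    compressed = sum(layer_info[key] for key in BIT_KEYS)
    # max(..., 1) avoids division by zero for a layer with no stored bits
    cr = layer_info["orig_bits"] / max(compressed, 1)
    total_weights = math.prod(layer_info["shape"])
    sparsity = 100 * (1 - layer_info["nnz"] / total_weights)
    return compressed, cr, sparsity

# Made-up entry for illustration only
toy_layer = {
    "layer": "fc_toy",
    "shape": [10, 10],
    "orig_bits": 3200,  # 100 weights x 32 bits
    "nnz": 5,
    "bits_IR": 40, "bits_IC": 30, "bits_A": 20, "bits_codebook": 70,
}
compressed, cr, sparsity = layer_stats(toy_layer)
print(f"{compressed} bits, CR {cr:.2f}x, sparsity {sparsity:.1f}%")
```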

---
## 5. List All Generated Files

In [None]:
print(f"\n📁 Generated files in {RUN_DIR}:")
print("="*80)
!ls -lh {RUN_DIR}

print(f"\n📁 Diagnostic plots in {RUN_DIR}/figures:")
print("="*80)
!ls -lh {RUN_DIR}/figures/

---
## Summary

If every cell above ran without warnings, SWS compression with the selected hyperparameters is complete.

### Outputs:
- **Models**: 
  - Pretrained: `{RUN_DIR}/{DATASET}_{MODEL}_pre.pt`
  - Pre-quantized: `{RUN_DIR}/{DATASET}_{MODEL}_prequant.pt`
  - Quantized: `{RUN_DIR}/{DATASET}_{MODEL}_quantized.pt`

- **Diagnostics**:
  - Training GIF: `{RUN_DIR}/figures/retraining.gif`
  - Training curves: `{RUN_DIR}/figures/plot_curves.png`
  - Mixture dynamics: `{RUN_DIR}/figures/plot_mixture_dynamics.png`
  - Weight scatter: `{RUN_DIR}/figures/plot_weights_scatter.png`
  - Final mixture: `{RUN_DIR}/figures/plot_mixture_prequant.png`

- **Reports**:
  - Compression report: `{RUN_DIR}/report.json`
  - Summary metrics: `{RUN_DIR}/summary_paper_metrics.json`
  - Layer pruning: `{RUN_DIR}/layer_pruning.json`
  - Training log: `{RUN_DIR}/metrics.csv`
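
The training log can be inspected without knowing its schema in advance. A sketch using the standard `csv` module follows. The column names are read from the file header rather than assumed, and `load_metrics` is a hypothetical helper.

```python
import csv
import os

def load_metrics(path):
    """Read a CSV log into a list of dicts, one per row; empty list if the file is missing."""
    if not os.path.exists(path):
        return []
    with open(path, newline="") as f:
        return list(csv.DictReader(f))

rows = load_metrics("compression_results/lenet_300_100_compression/metrics.csv")
if rows:
    print("Columns:", ", ".join(rows[0].keys()))
    print("Last row:", rows[-1])
else:
    print("No metrics found")
```

`csv.DictReader` returns every value as a string, so numeric columns need an explicit `float(...)` conversion before plotting or comparison.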

### Next Steps:
- Compare results with `no_hyperpriors.ipynb` baseline
- Try different hyperparameters from the Pareto front
- Apply to other models (LeNet5, WRN-16-4)