In [None]:
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os

# Shared figure styling applied to every plot in this notebook.
sns.set_style(style='whitegrid')
plt.rcParams.update({'figure.dpi': 100})

## Load Results

In [None]:
results_path = '../results'


def load_results(results_dir, filename):
    """Load one JSON results file from ``results_dir``.

    Parameters
    ----------
    results_dir : str
        Directory holding the experiment output files.
    filename : str
        Name of the JSON file inside ``results_dir``.

    Returns
    -------
    The parsed JSON content, or None when the file does not exist
    (i.e. that experiment has not been run yet).
    """
    path = os.path.join(results_dir, filename)
    # isfile (not exists) so a directory with a matching name is treated as missing.
    if not os.path.isfile(path):
        return None
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Load MNIST results
mnist_sparsity = load_results(results_path, 'MNIST_sparsity_results.json')
mnist_noise = load_results(results_path, 'MNIST_noise_results.json')

# Load FashionMNIST results
fashion_sparsity = load_results(results_path, 'FashionMNIST_sparsity_results.json')
fashion_noise = load_results(results_path, 'FashionMNIST_noise_results.json')

loaded_results = {
    "MNIST Sparsity": mnist_sparsity,
    "MNIST Noise": mnist_noise,
    "FashionMNIST Sparsity": fashion_sparsity,
    "FashionMNIST Noise": fashion_noise,
}

print("Loaded results:")
found_any = False
for experiment_label, experiment_results in loaded_results.items():
    # `is not None` so a file that exists but holds an empty list still shows up.
    if experiment_results is not None:
        print(f"  {experiment_label}: {len(experiment_results)} experiments")
        found_any = True
if not found_any:
    print(f"  (none found in {results_path})")

## Label Sparsity Analysis

In [None]:
if mnist_sparsity or fashion_sparsity:
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    datasets = []
    if mnist_sparsity:
        datasets.append(('MNIST', mnist_sparsity))
    if fashion_sparsity:
        datasets.append(('FashionMNIST', fashion_sparsity))

    # One entry per subplot: marker style, y-axis label, title, and how to
    # pull that panel's metric out of a single experiment record.
    sparsity_panels = [
        {'marker': 'o-', 'ylabel': 'Test Accuracy',
         'title': 'Classification Accuracy vs Label Sparsity',
         'value': lambda r: r['final_test_accuracy']},
        {'marker': 's-', 'ylabel': 'Beta-VAE Score',
         'title': 'Disentanglement (Beta-VAE) vs Label Sparsity',
         'value': lambda r: r['disentanglement_metrics']['beta_vae']},
        {'marker': '^-', 'ylabel': 'MIG Score',
         'title': 'Mutual Information Gap vs Label Sparsity',
         'value': lambda r: r['disentanglement_metrics']['mig']},
    ]

    for dataset_name, results in datasets:
        # Sort by label fraction so the connecting lines run left-to-right,
        # whatever order the experiments were written to the JSON file.
        ordered = sorted(results, key=lambda r: r['label_fraction'])
        label_fracs = [r['label_fraction'] * 100 for r in ordered]
        for ax, panel in zip(axes, sparsity_panels):
            ax.plot(label_fracs, [panel['value'](r) for r in ordered], panel['marker'],
                    label=dataset_name, linewidth=2, markersize=8)

    # Axis decoration does not depend on the dataset, so apply it once after
    # all lines are drawn (the legend then lists every dataset).
    for ax, panel in zip(axes, sparsity_panels):
        ax.set_xlabel('Label Fraction (%)', fontsize=12)
        ax.set_ylabel(panel['ylabel'], fontsize=12)
        ax.set_title(panel['title'], fontsize=14, fontweight='bold')
        ax.set_xscale('log')
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=10)

    plt.tight_layout()
    fig.savefig(os.path.join(results_path, 'sparsity_comparison.png'), dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No sparsity results found. Run the experiment first.")

## Label Noise Analysis

In [None]:
if mnist_noise or fashion_noise:
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    datasets = []
    if mnist_noise:
        datasets.append(('MNIST', mnist_noise))
    if fashion_noise:
        datasets.append(('FashionMNIST', fashion_noise))

    # One entry per subplot: marker style, y-axis label, title, and how to
    # pull that panel's metric out of a single experiment record.
    noise_panels = [
        {'marker': 'o-', 'ylabel': 'Test Accuracy',
         'title': 'Classification Accuracy vs Label Noise',
         'value': lambda r: r['final_test_accuracy']},
        {'marker': 's-', 'ylabel': 'Beta-VAE Score',
         'title': 'Disentanglement (Beta-VAE) vs Label Noise',
         'value': lambda r: r['disentanglement_metrics']['beta_vae']},
        {'marker': '^-', 'ylabel': 'MIG Score',
         'title': 'Mutual Information Gap vs Label Noise',
         'value': lambda r: r['disentanglement_metrics']['mig']},
    ]

    for dataset_name, results in datasets:
        # Sort by corruption rate so the connecting lines run left-to-right,
        # whatever order the experiments were written to the JSON file.
        ordered = sorted(results, key=lambda r: r['corruption_rate'])
        corruption_rates = [r['corruption_rate'] * 100 for r in ordered]
        for ax, panel in zip(axes, noise_panels):
            ax.plot(corruption_rates, [panel['value'](r) for r in ordered], panel['marker'],
                    label=dataset_name, linewidth=2, markersize=8)

    # Axis decoration does not depend on the dataset, so apply it once after
    # all lines are drawn (the legend then lists every dataset).
    for ax, panel in zip(axes, noise_panels):
        ax.set_xlabel('Label Corruption Rate (%)', fontsize=12)
        ax.set_ylabel(panel['ylabel'], fontsize=12)
        ax.set_title(panel['title'], fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=10)

    plt.tight_layout()
    fig.savefig(os.path.join(results_path, 'noise_comparison.png'), dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No noise results found. Run the experiment first.")

## Summary Statistics

In [None]:
def print_summary(results, experiment_type):
    """Print a per-experiment text summary of final metrics.

    Parameters
    ----------
    results : list of dict or None
        Experiment records. Each record is read for ``final_test_accuracy``,
        ``final_test_elbo`` and ``disentanglement_metrics`` (``beta_vae``,
        ``factor_vae``, ``mig``), plus an optional ``label_fraction`` and/or
        ``corruption_rate`` used as the entry header.
    experiment_type : str
        Heading printed above the summary.

    Nothing is printed when ``results`` is None or empty.
    """
    if not results:
        return

    print(f"\n{experiment_type} Summary:")
    print("=" * 80)

    for r in results:
        metrics = r['disentanglement_metrics']
        # Blank line before every entry, independent of which header key the
        # record carries, so noise entries are separated like sparsity ones.
        print()
        if 'label_fraction' in r:
            print(f"Label Fraction: {r['label_fraction']*100:.2f}%")
        if 'corruption_rate' in r:
            print(f"Corruption Rate: {r['corruption_rate']*100:.2f}%")

        print(f"  Accuracy: {r['final_test_accuracy']:.4f}")
        print(f"  ELBO: {r['final_test_elbo']:.4e}")
        print(f"  Beta-VAE: {metrics['beta_vae']:.4f}")
        print(f"  Factor-VAE: {metrics['factor_vae']:.4f}")
        print(f"  MIG: {metrics['mig']:.4f}")

# Print every summary in a fixed order; print_summary itself skips result
# sets that were not loaded.
summary_jobs = (
    (mnist_sparsity, "MNIST Label Sparsity"),
    (mnist_noise, "MNIST Label Noise"),
    (fashion_sparsity, "FashionMNIST Label Sparsity"),
    (fashion_noise, "FashionMNIST Label Noise"),
)
for summary_results, summary_title in summary_jobs:
    print_summary(summary_results, summary_title)

## Training Curves

In [None]:
# Plot per-epoch training curves for each label-sparsity result set that was
# loaded (one two-panel figure per dataset).
curve_datasets = [
    ('MNIST', mnist_sparsity),
    ('FashionMNIST', fashion_sparsity),
]

for dataset_name, results in curve_datasets:
    if not results:
        continue

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Sort so the legend lists label fractions in increasing order.
    for r in sorted(results, key=lambda rec: rec['label_fraction']):
        label = f"{r['label_fraction']*100:.1f}%"
        # Separate epoch ranges so each curve is plotted against its own length.
        axes[0].plot(range(1, len(r['test_accuracies']) + 1), r['test_accuracies'],
                     label=label, alpha=0.7)
        axes[1].plot(range(1, len(r['test_elbos']) + 1), r['test_elbos'],
                     label=label, alpha=0.7)

    axes[0].set_xlabel('Epoch', fontsize=12)
    axes[0].set_ylabel('Test Accuracy', fontsize=12)
    axes[0].set_title(f'{dataset_name}: Accuracy During Training', fontsize=14, fontweight='bold')
    axes[0].legend(title='Label %', fontsize=9)
    axes[0].grid(True, alpha=0.3)

    axes[1].set_xlabel('Epoch', fontsize=12)
    axes[1].set_ylabel('Test ELBO', fontsize=12)
    axes[1].set_title(f'{dataset_name}: ELBO During Training', fontsize=14, fontweight='bold')
    axes[1].legend(title='Label %', fontsize=9)
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    # 'MNIST' -> 'mnist_training_curves.png' (unchanged filename for MNIST).
    fig.savefig(os.path.join(results_path, f'{dataset_name.lower()}_training_curves.png'),
                dpi=300, bbox_inches='tight')
    plt.show()

## Key Findings

### Questions to Answer:

1. **At what label fraction does performance collapse?**
   - Examine the sparsity plots above
   - Look for sharp drops in accuracy and disentanglement

2. **How much label noise can the model tolerate?**
   - Check the noise plots
   - Identify corruption rates where metrics significantly degrade

3. **Is disentanglement more robust than accuracy?**
   - Compare the relative change in accuracy against the relative changes in the Beta-VAE and MIG scores
   - As label quality drops, does disentanglement degrade faster or slower than accuracy?

4. **How do MNIST and FashionMNIST compare?**
   - Is FashionMNIST more sensitive to label quality?
   - Which dataset maintains better disentanglement under stress?