In [5]:
import os
import torch
import numpy as np

def diagnose_samples(samples_dir: str):
    """Comprehensive diagnosis of what state the data is actually in."""
    
    print("="*80)
    print("COMPREHENSIVE DATA DIAGNOSIS")
    print("="*80)
    
    name_map = {
        'combined': 'Combined',
        'vae_only': 'VAE only',
        '2d_traits': '2D traits',
        'unconditional': 'Unconditional'
    }
    
    real_data = None
    
    for filename in os.listdir(samples_dir):
        if filename.endswith('_synthetic.pt'):
            model_key = filename.replace('_synthetic.pt', '')
            model_name = name_map.get(model_key, model_key)
            
            print(f"\n{'='*80}")
            print(f"MODEL: {model_name}")
            print(f"{'='*80}")
            
            # Load synthetic
            synth_path = os.path.join(samples_dir, f"{model_key}_synthetic.pt")
            synth = torch.load(synth_path, map_location='cpu').float()
            
            # Load normalization
            norm_path = os.path.join(samples_dir, f"{model_key}_normalization.pt")
            norm_data = torch.load(norm_path, map_location='cpu')
            mu = norm_data['mean'].float()
            sd = norm_data['std'].float()
            
            print(f"\nSynthetic samples:")
            print(f"  Shape: {synth.shape}")
            synth_valid = synth[torch.isfinite(synth)]
            print(f"  Mean: {synth_valid.mean():.4f}")
            print(f"  Std: {synth_valid.std():.4f}")
            print(f"  Min: {synth_valid.min():.4f}")
            print(f"  Max: {synth_valid.max():.4f}")
            
            print(f"\nNormalization parameters:")
            print(f"  mu shape: {mu.shape}")
            print(f"  sd shape: {sd.shape}")
            for i in range(min(10, mu.shape[-1])):
                print(f"    Var {i}: mu={mu.flatten()[i]:.4f}, sd={sd.flatten()[i]:.4f}")
            
            # Test different hypotheses
            print(f"\nHypothesis testing:")
            
            # Hypothesis 1: Data is in standard normalized space (mean=0, std=1)
            if abs(synth_valid.mean()) < 0.5 and 0.5 < synth_valid.std() < 1.5:
                print(f"  ✓ Hypothesis 1: Data IS in standard normalized space")
            else:
                print(f"  ✗ Hypothesis 1: Data is NOT in standard normalized space")
            
            # Hypothesis 2: Data is in model's normalized space (using its mu/sd)
            synth_renorm = (synth - mu.view(1, 1, 1, -1)) / sd.view(1, 1, 1, -1)
            synth_renorm_valid = synth_renorm[torch.isfinite(synth_renorm)]
            print(f"  Hypothesis 2: If we normalize using (x-mu)/sd:")
            print(f"    Result mean: {synth_renorm_valid.mean():.4f}")
            print(f"    Result std: {synth_renorm_valid.std():.4f}")
            if abs(synth_renorm_valid.mean()) < 0.5 and 0.5 < synth_renorm_valid.std() < 1.5:
                print(f"    ✓ This produces standard normalized space")
                print(f"    → Conclusion: Data is ALREADY UNNORMALIZED (in physical space)")
            else:
                print(f"    ✗ This does NOT produce standard normalized space")
            
            # Hypothesis 3: Data is unnormalized (in physical space)
            synth_unnorm = synth * sd.view(1, 1, 1, -1) + mu.view(1, 1, 1, -1)
            synth_unnorm_valid = synth_unnorm[torch.isfinite(synth_unnorm)]
            print(f"  Hypothesis 3: If we unnormalize using x*sd + mu:")
            print(f"    Result mean: {synth_unnorm_valid.mean():.4f}")
            print(f"    Result std: {synth_unnorm_valid.std():.4f}")
            print(f"    Result range: [{synth_unnorm_valid.min():.4f}, {synth_unnorm_valid.max():.4f}]")
            
            # Load real data if available
            if real_data is None and model_key != 'unconditional':
                real_path = os.path.join(samples_dir, f"{model_key}_real.pt")
                if os.path.exists(real_path):
                    real_data = torch.load(real_path, map_location='cpu').float()
                    
                    print(f"\n{'='*80}")
                    print(f"REAL DATA (from {model_name})")
                    print(f"{'='*80}")
                    print(f"  Shape: {real_data.shape}")
                    real_valid = real_data[torch.isfinite(real_data)]
                    print(f"  Mean: {real_valid.mean():.4f}")
                    print(f"  Std: {real_valid.std():.4f}")
                    print(f"  Min: {real_valid.min():.4f}")
                    print(f"  Max: {real_valid.max():.4f}")
                    
                    print(f"\nHypothesis testing for real data:")
                    
                    # Is real data normalized?
                    if abs(real_valid.mean()) < 0.5 and 0.5 < real_valid.std() < 1.5:
                        print(f"  ✓ Real data appears to be in standard normalized space")
                    else:
                        print(f"  ✗ Real data is NOT in standard normalized space")
                    
                    # Try unnormalizing real data
                    real_unnorm = real_data * sd.view(1, 1, -1) + mu.view(1, 1, -1)
                    real_unnorm_valid = real_unnorm[torch.isfinite(real_unnorm)]
                    print(f"  If we unnormalize real using x*sd + mu:")
                    print(f"    Result mean: {real_unnorm_valid.mean():.4f}")
                    print(f"    Result std: {real_unnorm_valid.std():.4f}")
                    print(f"    Result range: [{real_unnorm_valid.min():.4f}, {real_unnorm_valid.max():.4f}]")
                    
                    # Try normalizing real data
                    real_norm = (real_data - mu.view(1, 1, -1)) / sd.view(1, 1, -1)
                    real_norm_valid = real_norm[torch.isfinite(real_norm)]
                    print(f"  If we normalize real using (x-mu)/sd:")
                    print(f"    Result mean: {real_norm_valid.mean():.4f}")
                    print(f"    Result std: {real_norm_valid.std():.4f}")
                    if abs(real_norm_valid.mean()) < 0.5 and 0.5 < real_norm_valid.std() < 1.5:
                        print(f"    ✓ This produces standard normalized space")
                        print(f"    → Conclusion: Real data is ALREADY UNNORMALIZED")
    
    print(f"\n{'='*80}")
    print("FINAL DIAGNOSIS")
    print(f"{'='*80}")
    print("\nBased on the analysis above, the data state is:")
    print("  [Check the hypothesis test results above]")
    print("\nRecommended fixes:")
    print("  1. If synthetic is already unnormalized: Don't unnormalize again")
    print("  2. If real is already unnormalized: Don't unnormalize again")
    print("  3. If both are normalized: Unnormalize both before comparison")
    print("  4. If normalizations differ between models: This is a training data bug")

if __name__ == "__main__":
    diagnose_samples("./synthetic_samples_fixed")

COMPREHENSIVE DATA DIAGNOSIS

MODEL: Combined

Synthetic samples:
  Shape: torch.Size([50, 5, 10, 50])
  Mean: -1.3334
  Std: 1.4504
  Min: -4.5656
  Max: 13.0277

Normalization parameters:
  mu shape: torch.Size([1, 1, 10])
  sd shape: torch.Size([1, 1, 10])
    Var 0: mu=-44.5473, sd=19.6076
    Var 1: mu=264.4723, sd=143.5289
    Var 2: mu=43.6924, sd=14.0836
    Var 3: mu=1.0400, sd=0.8400
    Var 4: mu=0.9334, sd=0.2054
    Var 5: mu=0.0014, sd=0.0004
    Var 6: mu=0.0016, sd=0.0006
    Var 7: mu=0.0085, sd=0.0061
    Var 8: mu=0.0015, sd=0.0005
    Var 9: mu=0.0030, sd=0.0016

Hypothesis testing:
  ✗ Hypothesis 1: Data is NOT in standard normalized space


RuntimeError: The size of tensor a (50) must match the size of tensor b (10) at non-singleton dimension 3

In [3]:
import torch
import numpy as np
import yaml
from main_model import CSDI_PM25


def diagnose_model_output():
    """
    Diagnose what the model is actually outputting during synthesis.
    """
    print("="*80)
    print("MODEL OUTPUT DIAGNOSTIC")
    print("="*80)
    
    # Load one model
    config_path = './config/base_conditional_combined.yaml'
    ckpt_path = './wandb/run-20260119_161726-85mtveaj/files/diffusion-combined_20260119_161727/model_best_val.pth'
    
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}\n")
    
    model = CSDI_PM25(config, device=device)
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint)
    model = model.to(device)
    model.eval()
    
    print("Model loaded successfully\n")
    
    # Load normalization params that were used during training
    target_mean = np.load('target_mean.npy')
    target_std = np.load('target_std.npy')
    
    print("Training normalization parameters:")
    print(f"  Shape: {target_mean.shape}")
    print(f"  Mean (first 5): {target_mean.flatten()[:5]}")
    print(f"  Std (first 5): {target_std.flatten()[:5]}")
    
    # Generate a small batch
    print("\n" + "="*80)
    print("GENERATING TEST SAMPLES")
    print("="*80)
    
    batch_size = 10
    print(f"Generating {batch_size} samples...\n")
    
    with torch.no_grad():
        output = model.synthesize(batch_size=batch_size)
    
    output_np = output.cpu().numpy()
    
    print("Raw model output:")
    print(f"  Shape: {output_np.shape}")
    print(f"  Mean: {output_np.mean():.6f}")
    print(f"  Std: {output_np.std():.6f}")
    print(f"  Min: {output_np.min():.6f}")
    print(f"  Max: {output_np.max():.6f}")
    
    # Check per-variable statistics
    print(f"\nPer-variable statistics:")
    for d in range(min(output_np.shape[2], 5)):  # First 5 variables
        var_data = output_np[:, :, d]
        print(f"  Variable {d}: mean={var_data.mean():.4f}, std={var_data.std():.4f}, range=[{var_data.min():.4f}, {var_data.max():.4f}]")
    
    # Hypothesis 1: Is it normalized? (Should be mean~0, std~1)
    print(f"\n" + "="*80)
    print("HYPOTHESIS 1: Is output in normalized space (mean~0, std~1)?")
    print("="*80)
    
    is_normalized = abs(output_np.mean()) < 0.5 and 0.5 < output_np.std() < 1.5
    
    if is_normalized:
        print("✓ YES - Output appears to be in normalized space")
        print("  This is CORRECT - model should output normalized data")
    else:
        print("✗ NO - Output is NOT in normalized space")
        print(f"  Mean={output_np.mean():.4f} (should be ~0)")
        print(f"  Std={output_np.std():.4f} (should be ~1)")
        print("  This is a BUG - model should output normalized data")
    
    # Hypothesis 2: Try unnormalizing with saved params
    print(f"\n" + "="*80)
    print("HYPOTHESIS 2: What if we unnormalize using training params?")
    print("="*80)
    
    # Convert to tensor for easier manipulation
    output_tensor = torch.from_numpy(output_np)
    target_mean_tensor = torch.from_numpy(target_mean).float().view(1, 1, -1)
    target_std_tensor = torch.from_numpy(target_std).float().view(1, 1, -1)
    
    output_unnorm = output_tensor * target_std_tensor + target_mean_tensor
    output_unnorm_np = output_unnorm.numpy()
    
    print(f"After unnormalizing (x*std + mean):")
    print(f"  Mean: {output_unnorm_np.mean():.4f}")
    print(f"  Std: {output_unnorm_np.std():.4f}")
    print(f"  Range: [{output_unnorm_np.min():.4f}, {output_unnorm_np.max():.4f}]")
    
    # Check if these values are physically reasonable
    print(f"\nPhysical reasonability check:")
    print(f"  WRF_TEMP (var 0): mean={output_unnorm_np[:,:,0].mean():.1f}°C (should be ~-47°C)")
    print(f"  WRF_PRES (var 1): mean={output_unnorm_np[:,:,1].mean():.1f} hPa (should be ~250 hPa)")
    print(f"  WRF_RELH (var 2): mean={output_unnorm_np[:,:,2].mean():.1f}% (should be ~43%)")
    
    # Hypothesis 3: Is the diffusion process broken?
    print(f"\n" + "="*80)
    print("HYPOTHESIS 3: Check diffusion schedule parameters")
    print("="*80)
    
    print(f"Diffusion parameters:")
    print(f"  num_steps: {model.num_steps}")
    print(f"  beta range: [{model.beta.min():.6f}, {model.beta.max():.6f}]")
    print(f"  alpha range: [{model.alpha.min():.6f}, {model.alpha.max():.6f}]")
    print(f"  alpha[0]: {model.alpha[0]:.6f} (should be ~1.0)")
    print(f"  alpha[-1]: {model.alpha[-1]:.6f} (should be ~0.0)")
    
    # Final diagnosis
    print(f"\n" + "="*80)
    print("FINAL DIAGNOSIS")
    print("="*80)
    
    if is_normalized:
        print("\n✓ Model output is CORRECT")
        print("  Your generation script should save this as-is (normalized)")
        print("  Then unnormalize during analysis using the saved params")
    else:
        print("\n✗ Model output is BROKEN")
        print("  The diffusion sampling is not completing properly")
        print("  Possible issues:")
        print("    1. Bug in synthesize() method")
        print("    2. Wrong alpha/beta schedule")
        print("    3. Model wasn't trained properly")
        print("\n  FIX: Check the synthesize() method in main_model.py")
        print("       The final output should have mean~0, std~1")


if __name__ == "__main__":
    diagnose_model_output()

MODEL OUTPUT DIAGNOSTIC
Device: cuda

Model loaded successfully

Training normalization parameters:
  Shape: (1, 1, 10)
  Mean (first 5): [-44.547268   264.4723      43.692436     1.040037     0.93340003]
  Std (first 5): [ 19.607582   143.52885     14.083567     0.84001416   0.20537171]

GENERATING TEST SAMPLES
Generating 10 samples...

Raw model output:
  Shape: (10, 50, 10)
  Mean: -1.890928
  Std: 0.191723
  Min: -2.478394
  Max: -0.586174

Per-variable statistics:
  Variable 0: mean=-1.9094, std=0.1655, range=[-2.3727, -1.1870]
  Variable 1: mean=-1.8998, std=0.1783, range=[-2.4030, -0.9842]
  Variable 2: mean=-1.9301, std=0.1518, range=[-2.4339, -1.1244]
  Variable 3: mean=-1.9377, std=0.1466, range=[-2.3955, -1.2402]
  Variable 4: mean=-1.8898, std=0.1658, range=[-2.3737, -0.9920]

HYPOTHESIS 1: Is output in normalized space (mean~0, std~1)?
✗ NO - Output is NOT in normalized space
  Mean=-1.8909 (should be ~0)
  Std=0.1917 (should be ~1)
  This is a BUG - model should output no

In [1]:
"""
Better normalization check that properly excludes placeholders.
"""

import torch
import numpy as np
from dataset_crystaltraj import get_dataloader

def check_normalization_properly():
    """Check normalization while properly excluding placeholders."""
    
    print("\n" + "="*60)
    print("PROPER NORMALIZATION CHECK")
    print("="*60 + "\n")
    
    config = {
        'model': {
            'target_vars': ['WRF_TEMP', 'WRF_PRES', 'WRF_RELH', 'WRF_PHI', 'WRF_PHIS', 
                           'WRF_QICE', 'WRF_QSNOW', 'WRF_QVAPOR', 'WRF_QCLOUD', 'WRF_QRAIN'],
            'horizon': 50,
            'is_unconditional': False,
            'conditioning_type': 'combined'
        },
        'wandb_run': {
            'config': {
                'batch_size': 32
            }
        }
    }
    
    loaders = get_dataloader(config, batch_size=32, shuffle=False)
    
    # Collect all training data
    all_observed = []
    all_conditioning = []
    
    for batch in loaders['train']:
        all_observed.append(batch['observed_data'])
        all_conditioning.append(batch['conditioning_data'])
    
    all_observed = torch.cat(all_observed, dim=0)
    all_conditioning = torch.cat(all_conditioning, dim=0)
    
    print(f"Collected data shapes:")
    print(f"  Observed: {all_observed.shape}")
    print(f"  Conditioning: {all_conditioning.shape}")
    
    # Check observed data (target variables)
    print("\n" + "-"*60)
    print("OBSERVED DATA (Target Variables)")
    print("-"*60)
    
    placeholder = 0.0
    tolerance = 1e-3
    
    for i, var_name in enumerate(config['model']['target_vars']):
        # Get data for this variable across all samples and timesteps
        var_data = all_observed[:, i, :].flatten()
        
        # Exclude placeholders
        mask = torch.abs(var_data - placeholder) > tolerance
        valid_data = var_data[mask]
        
        if len(valid_data) > 0:
            mean = valid_data.mean().item()
            std = valid_data.std().item()
            
            print(f"{var_name:12s}: mean={mean:7.4f}, std={std:7.4f}", end="")
            
            # Check if normalized (mean ≈ 0, std ≈ 1)
            if abs(mean) < 0.15 and 0.7 < std < 1.3:
                print(" ✓")
            else:
                print(" ⚠️")
                
            print(f"              Valid: {len(valid_data)}/{len(var_data)} ({100*len(valid_data)/len(var_data):.1f}%)")
        else:
            print(f"{var_name:12s}: No valid data")
    
    # Check conditioning data
    print("\n" + "-"*60)
    print("CONDITIONING DATA")
    print("-"*60)
    
    # Check first 5 dimensions
    for i in range(min(10, all_conditioning.shape[1])):
        cond_data = all_conditioning[:, i, :].flatten()
        
        # Exclude placeholders
        mask = torch.abs(cond_data - placeholder) > tolerance
        valid_data = cond_data[mask]
        
        if len(valid_data) > 0:
            mean = valid_data.mean().item()
            std = valid_data.std().item()
            
            if i < 50:
                label = f"VAE_{i}"
            else:
                label = f"2D_{i-50}"
            
            print(f"{label:12s}: mean={mean:7.4f}, std={std:7.4f}", end="")
            
            if abs(mean) < 0.15 and 0.7 < std < 1.3:
                print(" ✓")
            else:
                print(" ⚠️")
    
    # Overall statistics
    print("\n" + "-"*60)
    print("OVERALL STATISTICS")
    print("-"*60)
    
    # Observed data (excluding placeholders)
    obs_flat = all_observed.flatten()
    obs_mask = torch.abs(obs_flat - placeholder) > tolerance
    obs_valid = obs_flat[obs_mask]
    
    if len(obs_valid) > 0:
        obs_mean = obs_valid.mean().item()
        obs_std = obs_valid.std().item()
        obs_pct = 100 * len(obs_valid) / len(obs_flat)
        
        print(f"\nObserved data (excluding placeholders):")
        print(f"  Valid values: {len(obs_valid)}/{len(obs_flat)} ({obs_pct:.1f}%)")
        print(f"  Mean: {obs_mean:.4f}")
        print(f"  Std:  {obs_std:.4f}")
        
        if abs(obs_mean) < 0.15 and 0.7 < obs_std < 1.3:
            print("  ✓ PROPERLY NORMALIZED")
        else:
            print("  ⚠️ May need adjustment")
    
    # Conditioning data (excluding placeholders)
    cond_flat = all_conditioning.flatten()
    cond_mask = torch.abs(cond_flat - placeholder) > tolerance
    cond_valid = cond_flat[cond_mask]
    
    if len(cond_valid) > 0:
        cond_mean = cond_valid.mean().item()
        cond_std = cond_valid.std().item()
        cond_pct = 100 * len(cond_valid) / len(cond_flat)
        
        print(f"\nConditioning data (excluding placeholders):")
        print(f"  Valid values: {len(cond_valid)}/{len(cond_flat)} ({cond_pct:.1f}%)")
        print(f"  Mean: {cond_mean:.4f}")
        print(f"  Std:  {cond_std:.4f}")
        
        if abs(cond_mean) < 0.15 and 0.7 < cond_std < 1.3:
            print("  ✓ PROPERLY NORMALIZED")
        else:
            print("  ⚠️ May need adjustment")
    
    # Load and check saved normalization parameters
    print("\n" + "-"*60)
    print("SAVED NORMALIZATION PARAMETERS")
    print("-"*60)
    
    target_mean = np.load("target_mean.npy")
    target_std = np.load("target_std.npy")
    data_mean = np.load("data_mean.npy")
    data_std = np.load("data_std.npy")
    
    print(f"\nTarget normalization (first 5 dims):")
    for i in range(min(5, target_mean.shape[2])):
        var_name = config['model']['target_vars'][i]
        print(f"  {var_name:12s}: mean={target_mean[0,0,i]:8.3f}, std={target_std[0,0,i]:8.3f}")
    
    print(f"\nConditioning normalization (first 5 dims):")
    for i in range(min(5, data_mean.shape[2])):
        print(f"  Dim {i:2d}: mean={data_mean[0,0,i]:8.3f}, std={data_std[0,0,i]:8.3f}")
    
    print("\n" + "="*60)
    print("NORMALIZATION CHECK COMPLETE")
    print("="*60 + "\n")


if __name__ == "__main__":
    check_normalization_properly()


PROPER NORMALIZATION CHECK


INITIALIZING DATALOADER
Config: ['WRF_TEMP', 'WRF_PRES', 'WRF_RELH', 'WRF_PHI', 'WRF_PHIS', 'WRF_QICE', 'WRF_QSNOW', 'WRF_QVAPOR', 'WRF_QCLOUD', 'WRF_QRAIN'], horizon=50, unconditional=False, type=combined
Batch size: 32

Loading target data...
  Raw shape: (24769, 71, 10), NaN%: 23.23%
  Trimmed to horizon: (24769, 50, 10)
  Placeholders: 6401947 (51.69%)

Creating data splits (before filtering)...
  Initial splits - Train: 19815, Val: 2477, Test: 2477

Loading conditioning data (type: combined)...
  Crystals: 24769

  Loading VAE embeddings to determine valid samples...
  Loading VAE embeddings...
    Matched: 21434/24769 (86.5%)
  Filtering all data to crystals with VAE embeddings...
    Keeping 21434/24769 crystals
    After filtering - Train: 17162, Val: 2138, Test: 2134
  Loading 2D traits...
    Shape: (21434, 14)

  Normalizing VAE embeddings...
  Normalization: all 1071700 values normalized

  Normalizing 2D traits...
  Normalization: all 300076 v

In [5]:
"""
KDE Analysis for Generated Samples

This script:
1. Loads all generated samples (normalized)
2. Unnormalizes them to original scale
3. Computes and plots KDEs for each variable
4. Compares synthetic vs real distributions
5. Provides statistical metrics
"""

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from scipy import stats
from scipy.stats import wasserstein_distance, ks_2samp
import warnings
warnings.filterwarnings('ignore')


# Variable names for plotting
VAR_NAMES = [
    'WRF_TEMP', 'WRF_PRES', 'WRF_RELH', 'WRF_PHI', 'WRF_PHIS',
    'WRF_QICE', 'WRF_QSNOW', 'WRF_QVAPOR', 'WRF_QCLOUD', 'WRF_QRAIN'
]

# Nice names for plotting
VAR_NICE_NAMES = {
    'WRF_TEMP': 'Temperature (°C)',
    'WRF_PRES': 'Pressure (hPa)',
    'WRF_RELH': 'Relative Humidity (%)',
    'WRF_PHI': 'Geopotential Height (km)',
    'WRF_PHIS': 'Surface Geopotential (km)',
    'WRF_QICE': 'Ice Mixing Ratio (g/kg)',
    'WRF_QSNOW': 'Snow Mixing Ratio (g/kg)',
    'WRF_QVAPOR': 'Water Vapor (g/kg)',
    'WRF_QCLOUD': 'Cloud Mixing Ratio (g/kg)',
    'WRF_QRAIN': 'Rain Mixing Ratio (g/kg)'
}


def load_model_data(data_dir, model_name):
    """Load synthetic and real data for a model."""
    data_dir = Path(data_dir)
    
    # Load synthetic samples
    synthetic = torch.load(data_dir / f"{model_name}_synthetic.pt").numpy()
    
    # Load normalization parameters
    norm_params = torch.load(data_dir / f"{model_name}_normalization.pt")
    mean = norm_params['mean'].numpy()  # (1, 1, D)
    std = norm_params['std'].numpy()    # (1, 1, D)
    
    # Load real data if available (not for unconditional)
    real = None
    if (data_dir / f"{model_name}_real.pt").exists():
        real = torch.load(data_dir / f"{model_name}_real.pt").numpy()
    
    return synthetic, real, mean, std


def unnormalize(data, mean, std):
    """
    Unnormalize data.
    
    Args:
        data: Normalized data
              - Shape: (K, S, D, T) for synthetic conditional (from saved .pt files)
              - Shape: (K, T, D) for real (from saved .pt files)
        mean: (1, 1, D)
        std: (1, 1, D)
    
    Returns:
        Unnormalized data in original scale
    """
    # Check the actual shape
    print(f"    DEBUG: data shape={data.shape}, mean shape={mean.shape}, std shape={std.shape}")
    
    if data.ndim == 4:  # (K, S, D, T)
        # D is the 3rd dimension (index 2)
        K, S, D, T = data.shape
        # Reshape mean/std to (1, 1, D, 1) for broadcasting
        mean_broadcast = mean.reshape(1, 1, -1, 1)
        std_broadcast = std.reshape(1, 1, -1, 1)
    elif data.ndim == 3:  # (K, T, D)
        # D is the last dimension
        K, T, D = data.shape
        # Reshape to (1, 1, D) for broadcasting
        mean_broadcast = mean.reshape(1, 1, -1)
        std_broadcast = std.reshape(1, 1, -1)
    else:
        raise ValueError(f"Unexpected data shape: {data.shape}")
    
    print(f"    DEBUG: Broadcasting ({data.shape}) * ({std_broadcast.shape}) + ({mean_broadcast.shape})")
    return data * std_broadcast + mean_broadcast


def compute_kde_stats(synthetic, real, var_idx, var_name, bandwidth='scott'):
    """
    Compute KDE and distribution statistics.
    
    Args:
        synthetic: (K, S, T, D) or flattened
        real: (K, T, D) or flattened
        var_idx: Index of variable
        var_name: Name of variable
        
    Returns:
        Dictionary of statistics
    """
    # Flatten and extract variable
    if synthetic.ndim == 4:  # (K, S, T, D)
        synth_var = synthetic[:, :, :, var_idx].flatten()
    else:
        synth_var = synthetic.flatten()
    
    if real is not None:
        if real.ndim == 3:  # (K, T, D)
            real_var = real[:, :, var_idx].flatten()
        else:
            real_var = real.flatten()
    else:
        real_var = None
    
    # Remove any remaining NaNs or infs
    synth_var = synth_var[np.isfinite(synth_var)]
    if real_var is not None:
        real_var = real_var[np.isfinite(real_var)]
    
    stats_dict = {
        'var_name': var_name,
        'synthetic': {
            'mean': np.mean(synth_var),
            'std': np.std(synth_var),
            'median': np.median(synth_var),
            'q25': np.percentile(synth_var, 25),
            'q75': np.percentile(synth_var, 75),
            'min': np.min(synth_var),
            'max': np.max(synth_var),
            'n_samples': len(synth_var)
        }
    }
    
    if real_var is not None:
        stats_dict['real'] = {
            'mean': np.mean(real_var),
            'std': np.std(real_var),
            'median': np.median(real_var),
            'q25': np.percentile(real_var, 25),
            'q75': np.percentile(real_var, 75),
            'min': np.min(real_var),
            'max': np.max(real_var),
            'n_samples': len(real_var)
        }
        
        # Compute distribution distance metrics
        try:
            # Wasserstein distance (Earth Mover's Distance)
            wd = wasserstein_distance(synth_var, real_var)
            stats_dict['wasserstein_distance'] = wd
            
            # Kolmogorov-Smirnov test
            ks_stat, ks_pval = ks_2samp(synth_var, real_var)
            stats_dict['ks_statistic'] = ks_stat
            stats_dict['ks_pvalue'] = ks_pval
        except Exception as e:
            print(f"    Warning: Could not compute distance metrics for {var_name}: {e}")
    
    return stats_dict, synth_var, real_var


def plot_kde_comparison(models_data, var_idx, var_name, output_dir):
    """
    Plot KDE comparison for one variable across all models.
    
    Args:
        models_data: Dict mapping model_name -> (synthetic_unnorm, real_unnorm)
        var_idx: Variable index
        var_name: Variable name
        output_dir: Where to save plot
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()
    
    model_names = ['combined', 'vae_only', '2d_traits', 'unconditional']
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
    
    for idx, (model_name, ax) in enumerate(zip(model_names, axes)):
        if model_name not in models_data:
            ax.axis('off')
            continue
        
        synthetic_unnorm, real_unnorm = models_data[model_name]
        
        # Extract variable and flatten
        if synthetic_unnorm.ndim == 4:  # (K, S, T, D)
            synth_var = synthetic_unnorm[:, :, :, var_idx].flatten()
        else:
            synth_var = synthetic_unnorm[:, :, var_idx].flatten()
        
        synth_var = synth_var[np.isfinite(synth_var)]
        
        # Plot synthetic KDE
        try:
            kde_synth = stats.gaussian_kde(synth_var, bw_method='scott')
            x_range = np.linspace(synth_var.min(), synth_var.max(), 200)
            ax.plot(x_range, kde_synth(x_range), color=colors[idx], 
                   linewidth=2, label='Synthetic', alpha=0.8)
            ax.fill_between(x_range, kde_synth(x_range), alpha=0.3, color=colors[idx])
        except Exception as e:
            print(f"    Warning: Could not plot KDE for {model_name}/{var_name}: {e}")
        
        # Plot real KDE if available
        if real_unnorm is not None:
            if real_unnorm.ndim == 3:  # (K, T, D)
                real_var = real_unnorm[:, :, var_idx].flatten()
            else:
                real_var = real_unnorm[:, :, var_idx].flatten()
            
            real_var = real_var[np.isfinite(real_var)]
            
            try:
                kde_real = stats.gaussian_kde(real_var, bw_method='scott')
                x_range_real = np.linspace(real_var.min(), real_var.max(), 200)
                ax.plot(x_range_real, kde_real(x_range_real), 'k--', 
                       linewidth=2, label='Real', alpha=0.6)
            except Exception as e:
                print(f"    Warning: Could not plot real KDE for {model_name}/{var_name}: {e}")
        
        # Formatting
        ax.set_xlabel(VAR_NICE_NAMES.get(var_name, var_name), fontsize=11)
        ax.set_ylabel('Density', fontsize=11)
        ax.set_title(f'{model_name.replace("_", " ").title()}', fontsize=12, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
    
    plt.suptitle(f'KDE Comparison: {VAR_NICE_NAMES.get(var_name, var_name)}', 
                 fontsize=14, fontweight='bold', y=0.995)
    plt.tight_layout()
    
    output_path = Path(output_dir) / f'kde_{var_name}.png'
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()
    
    print(f"  Saved: {output_path}")


def plot_statistics_comparison(all_stats, output_dir):
    """Plot comparison of statistics across all models."""
    output_dir = Path(output_dir)
    
    model_names = list(all_stats.keys())
    
    # 1. Mean comparison
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.flatten()
    
    for var_idx, var_name in enumerate(VAR_NAMES):
        ax = axes[var_idx]
        
        synth_means = []
        real_means = []
        labels = []
        
        for model_name in model_names:
            if var_name in all_stats[model_name]:
                stats_dict = all_stats[model_name][var_name]
                synth_means.append(stats_dict['synthetic']['mean'])
                labels.append(model_name)
                
                if 'real' in stats_dict:
                    real_means.append(stats_dict['real']['mean'])
        
        x = np.arange(len(labels))
        width = 0.35
        
        ax.bar(x - width/2, synth_means, width, label='Synthetic', alpha=0.8)
        if real_means:
            ax.bar(x + width/2, [real_means[0]]*len(labels), width, 
                  label='Real', alpha=0.6, color='gray')
        
        ax.set_xlabel('Model', fontsize=9)
        ax.set_ylabel('Mean', fontsize=9)
        ax.set_title(VAR_NICE_NAMES.get(var_name, var_name), fontsize=10, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels([l.replace('_', '\n') for l in labels], fontsize=8)
        if var_idx == 0:
            ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3, axis='y')
    
    plt.suptitle('Mean Values Comparison', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_dir / 'statistics_means.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"  Saved: {output_dir / 'statistics_means.png'}")
    
    # 2. Standard deviation comparison
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.flatten()
    
    for var_idx, var_name in enumerate(VAR_NAMES):
        ax = axes[var_idx]
        
        synth_stds = []
        real_stds = []
        labels = []
        
        for model_name in model_names:
            if var_name in all_stats[model_name]:
                stats_dict = all_stats[model_name][var_name]
                synth_stds.append(stats_dict['synthetic']['std'])
                labels.append(model_name)
                
                if 'real' in stats_dict:
                    real_stds.append(stats_dict['real']['std'])
        
        x = np.arange(len(labels))
        width = 0.35
        
        ax.bar(x - width/2, synth_stds, width, label='Synthetic', alpha=0.8)
        if real_stds:
            ax.bar(x + width/2, [real_stds[0]]*len(labels), width, 
                  label='Real', alpha=0.6, color='gray')
        
        ax.set_xlabel('Model', fontsize=9)
        ax.set_ylabel('Std Dev', fontsize=9)
        ax.set_title(VAR_NICE_NAMES.get(var_name, var_name), fontsize=10, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels([l.replace('_', '\n') for l in labels], fontsize=8)
        if var_idx == 0:
            ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3, axis='y')
    
    plt.suptitle('Standard Deviation Comparison', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_dir / 'statistics_stds.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"  Saved: {output_dir / 'statistics_stds.png'}")


def create_summary_table(all_stats, output_dir):
    """Create a text summary table of all statistics."""
    output_dir = Path(output_dir)
    
    with open(output_dir / 'statistics_summary.txt', 'w') as f:
        f.write("="*100 + "\n")
        f.write("DISTRIBUTION STATISTICS SUMMARY\n")
        f.write("="*100 + "\n\n")
        
        for model_name, model_stats in all_stats.items():
            f.write(f"\n{'='*100}\n")
            f.write(f"MODEL: {model_name.upper()}\n")
            f.write(f"{'='*100}\n\n")
            
            for var_name, stats_dict in model_stats.items():
                f.write(f"\n{'-'*100}\n")
                f.write(f"{VAR_NICE_NAMES.get(var_name, var_name)}\n")
                f.write(f"{'-'*100}\n")
                
                # Synthetic stats
                synth = stats_dict['synthetic']
                f.write(f"\nSynthetic:\n")
                f.write(f"  Mean:     {synth['mean']:12.4f}\n")
                f.write(f"  Std:      {synth['std']:12.4f}\n")
                f.write(f"  Median:   {synth['median']:12.4f}\n")
                f.write(f"  Q25-Q75:  [{synth['q25']:10.4f}, {synth['q75']:10.4f}]\n")
                f.write(f"  Range:    [{synth['min']:10.4f}, {synth['max']:10.4f}]\n")
                f.write(f"  N:        {synth['n_samples']:12,}\n")
                
                # Real stats if available
                if 'real' in stats_dict:
                    real = stats_dict['real']
                    f.write(f"\nReal:\n")
                    f.write(f"  Mean:     {real['mean']:12.4f}\n")
                    f.write(f"  Std:      {real['std']:12.4f}\n")
                    f.write(f"  Median:   {real['median']:12.4f}\n")
                    f.write(f"  Q25-Q75:  [{real['q25']:10.4f}, {real['q75']:10.4f}]\n")
                    f.write(f"  Range:    [{real['min']:10.4f}, {real['max']:10.4f}]\n")
                    f.write(f"  N:        {real['n_samples']:12,}\n")
                    
                    # Distance metrics
                    if 'wasserstein_distance' in stats_dict:
                        f.write(f"\nDistribution Metrics:\n")
                        f.write(f"  Wasserstein Distance: {stats_dict['wasserstein_distance']:12.6f}\n")
                        f.write(f"  KS Statistic:         {stats_dict['ks_statistic']:12.6f}\n")
                        f.write(f"  KS p-value:           {stats_dict['ks_pvalue']:12.6e}\n")
                        
                        # Interpretation
                        if stats_dict['ks_pvalue'] > 0.05:
                            f.write(f"  → Distributions are similar (p > 0.05) ✓\n")
                        else:
                            f.write(f"  → Distributions are different (p ≤ 0.05) ⚠️\n")
    
    print(f"  Saved: {output_dir / 'statistics_summary.txt'}")


def analyze_all_models(data_dir='./synthetic_samples_test', output_dir='./kde_analysis'):
    """Main analysis function."""
    data_dir = Path(data_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    print("\n" + "="*80)
    print("KDE ANALYSIS OF GENERATED SAMPLES")
    print("="*80)
    
    model_names = ['combined', 'vae_only', '2d_traits', 'unconditional']
    
    # Load all data
    print("\nLoading data...")
    models_data_norm = {}
    models_data_unnorm = {}
    
    for model_name in model_names:
        print(f"  Loading {model_name}...")
        synthetic_norm, real_norm, mean, std = load_model_data(data_dir, model_name)
        
        print(f"    Synthetic shape: {synthetic_norm.shape}")
        if real_norm is not None:
            print(f"    Real shape: {real_norm.shape}")
        
        # Unnormalize
        synthetic_unnorm = unnormalize(synthetic_norm, mean, std)
        real_unnorm = unnormalize(real_norm, mean, std) if real_norm is not None else None
        
        models_data_norm[model_name] = (synthetic_norm, real_norm)
        models_data_unnorm[model_name] = (synthetic_unnorm, real_unnorm)
        
        print(f"    Unnormalized synthetic range: [{np.nanmin(synthetic_unnorm):.2f}, {np.nanmax(synthetic_unnorm):.2f}]")
    
    # Compute statistics for all models and variables
    print("\nComputing statistics...")
    all_stats = {}
    
    for model_name in model_names:
        print(f"  {model_name}...")
        synthetic_unnorm, real_unnorm = models_data_unnorm[model_name]
        
        model_stats = {}
        for var_idx, var_name in enumerate(VAR_NAMES):
            stats_dict, synth_var, real_var = compute_kde_stats(
                synthetic_unnorm, real_unnorm, var_idx, var_name
            )
            model_stats[var_name] = stats_dict
        
        all_stats[model_name] = model_stats
    
    # Generate plots
    print("\nGenerating KDE plots...")
    for var_idx, var_name in enumerate(VAR_NAMES):
        print(f"  {var_name}...")
        plot_kde_comparison(models_data_unnorm, var_idx, var_name, output_dir)
    
    # Generate comparison plots
    print("\nGenerating comparison plots...")
    plot_statistics_comparison(all_stats, output_dir)
    
    # Create summary table
    print("\nCreating summary table...")
    create_summary_table(all_stats, output_dir)
    
    print("\n" + "="*80)
    print("ANALYSIS COMPLETE")
    print("="*80)
    print(f"\nResults saved to: {output_dir}/")
    print("\nGenerated files:")
    print(f"  - kde_*.png: KDE plots for each variable")
    print(f"  - statistics_means.png: Mean comparison across models")
    print(f"  - statistics_stds.png: Std deviation comparison")
    print(f"  - statistics_summary.txt: Detailed statistics table")


if __name__ == "__main__":
    analyze_all_models(
        data_dir='./synthetic_samples_test',
        output_dir='./kde_analysis'
    )


KDE ANALYSIS OF GENERATED SAMPLES

Loading data...
  Loading combined...
    Synthetic shape: (50, 5, 10, 50)
    Real shape: (50, 50, 10)
    DEBUG: data shape=(50, 5, 10, 50), mean shape=(1, 1, 10), std shape=(1, 1, 10)
    DEBUG: Broadcasting ((50, 5, 10, 50)) * ((1, 1, 10, 1)) + ((1, 1, 10, 1))
    DEBUG: data shape=(50, 50, 10), mean shape=(1, 1, 10), std shape=(1, 1, 10)
    DEBUG: Broadcasting ((50, 50, 10)) * ((1, 1, 10)) + ((1, 1, 10))
    Unnormalized synthetic range: [-68.54, 1004.39]
  Loading vae_only...
    Synthetic shape: (50, 5, 10, 50)
    Real shape: (50, 50, 10)
    DEBUG: data shape=(50, 5, 10, 50), mean shape=(1, 1, 10), std shape=(1, 1, 10)
    DEBUG: Broadcasting ((50, 5, 10, 50)) * ((1, 1, 10, 1)) + ((1, 1, 10, 1))
    DEBUG: data shape=(50, 50, 10), mean shape=(1, 1, 10), std shape=(1, 1, 10)
    DEBUG: Broadcasting ((50, 50, 10)) * ((1, 1, 10)) + ((1, 1, 10))
    Unnormalized synthetic range: [-69.98, 994.06]
  Loading 2d_traits...
    Synthetic shape: (50, 

In [6]:
"""
KDE Analysis for Generated Samples

This script:
1. Loads all generated samples (normalized)
2. Unnormalizes them to original scale
3. Computes and plots KDEs for each variable
4. Compares synthetic vs real distributions
5. Provides statistical metrics
"""

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from scipy import stats
from scipy.stats import wasserstein_distance, ks_2samp
import warnings
warnings.filterwarnings('ignore')


# Variable names for plotting
VAR_NAMES = [
    'WRF_TEMP', 'WRF_PRES', 'WRF_RELH', 'WRF_PHI', 'WRF_PHIS',
    'WRF_QICE', 'WRF_QSNOW', 'WRF_QVAPOR', 'WRF_QCLOUD', 'WRF_QRAIN'
]

# Nice names for plotting
VAR_NICE_NAMES = {
    'WRF_TEMP': 'Temperature (°C)',
    'WRF_PRES': 'Pressure (hPa)',
    'WRF_RELH': 'Relative Humidity (%)',
    'WRF_PHI': 'Geopotential Height (km)',
    'WRF_PHIS': 'Surface Geopotential (km)',
    'WRF_QICE': 'Ice Mixing Ratio (g/kg)',
    'WRF_QSNOW': 'Snow Mixing Ratio (g/kg)',
    'WRF_QVAPOR': 'Water Vapor (g/kg)',
    'WRF_QCLOUD': 'Cloud Mixing Ratio (g/kg)',
    'WRF_QRAIN': 'Rain Mixing Ratio (g/kg)'
}


def load_model_data(data_dir, model_name):
    """Load synthetic and real data for a model."""
    data_dir = Path(data_dir)
    
    # Load synthetic samples
    synthetic = torch.load(data_dir / f"{model_name}_synthetic.pt").numpy()
    
    # Load normalization parameters
    norm_params = torch.load(data_dir / f"{model_name}_normalization.pt")
    mean = norm_params['mean'].numpy()  # (1, 1, D)
    std = norm_params['std'].numpy()    # (1, 1, D)
    
    # Load real data if available (not for unconditional)
    real = None
    if (data_dir / f"{model_name}_real.pt").exists():
        real = torch.load(data_dir / f"{model_name}_real.pt").numpy()
    
    return synthetic, real, mean, std


def unnormalize(data, mean, std):
    """
    Unnormalize data.
    
    Args:
        data: Normalized data
              - Shape: (K, S, D, T) for synthetic conditional (from saved .pt files)
              - Shape: (K, T, D) for real (from saved .pt files)
        mean: (1, 1, D)
        std: (1, 1, D)
    
    Returns:
        Unnormalized data in original scale
    """
    # Check the actual shape
    print(f"    DEBUG: data shape={data.shape}, mean shape={mean.shape}, std shape={std.shape}")
    
    if data.ndim == 4:  # (K, S, D, T)
        # D is the 3rd dimension (index 2)
        K, S, D, T = data.shape
        # Reshape mean/std to (1, 1, D, 1) for broadcasting
        mean_broadcast = mean.reshape(1, 1, -1, 1)
        std_broadcast = std.reshape(1, 1, -1, 1)
    elif data.ndim == 3:  # (K, T, D)
        # D is the last dimension
        K, T, D = data.shape
        # Reshape to (1, 1, D) for broadcasting
        mean_broadcast = mean.reshape(1, 1, -1)
        std_broadcast = std.reshape(1, 1, -1)
    else:
        raise ValueError(f"Unexpected data shape: {data.shape}")
    
    print(f"    DEBUG: Broadcasting ({data.shape}) * ({std_broadcast.shape}) + ({mean_broadcast.shape})")
    return data * std_broadcast + mean_broadcast


def compute_kde_stats(synthetic, real, var_idx, var_name, bandwidth='scott'):
    """
    Compute KDE and distribution statistics.
    
    Args:
        synthetic: (K, S, T, D) or flattened
        real: (K, T, D) or flattened
        var_idx: Index of variable
        var_name: Name of variable
        
    Returns:
        Dictionary of statistics
    """
    # Flatten and extract variable
    # Synthetic shape: (K, S, D, T) where D is variables
    # Real shape: (K, T, D) where D is variables
    if synthetic.ndim == 4:  # (K, S, D, T)
        synth_var = synthetic[:, :, var_idx, :].flatten()
    else:
        synth_var = synthetic.flatten()
    
    if real is not None:
        if real.ndim == 3:  # (K, T, D)
            real_var = real[:, :, var_idx].flatten()
        else:
            real_var = real.flatten()
    else:
        real_var = None
    
    # Remove any remaining NaNs or infs
    synth_var = synth_var[np.isfinite(synth_var)]
    if real_var is not None:
        real_var = real_var[np.isfinite(real_var)]
    
    stats_dict = {
        'var_name': var_name,
        'synthetic': {
            'mean': np.mean(synth_var),
            'std': np.std(synth_var),
            'median': np.median(synth_var),
            'q25': np.percentile(synth_var, 25),
            'q75': np.percentile(synth_var, 75),
            'min': np.min(synth_var),
            'max': np.max(synth_var),
            'n_samples': len(synth_var)
        }
    }
    
    if real_var is not None:
        stats_dict['real'] = {
            'mean': np.mean(real_var),
            'std': np.std(real_var),
            'median': np.median(real_var),
            'q25': np.percentile(real_var, 25),
            'q75': np.percentile(real_var, 75),
            'min': np.min(real_var),
            'max': np.max(real_var),
            'n_samples': len(real_var)
        }
        
        # Compute distribution distance metrics
        try:
            # Wasserstein distance (Earth Mover's Distance)
            wd = wasserstein_distance(synth_var, real_var)
            stats_dict['wasserstein_distance'] = wd
            
            # Kolmogorov-Smirnov test
            ks_stat, ks_pval = ks_2samp(synth_var, real_var)
            stats_dict['ks_statistic'] = ks_stat
            stats_dict['ks_pvalue'] = ks_pval
        except Exception as e:
            print(f"    Warning: Could not compute distance metrics for {var_name}: {e}")
    
    return stats_dict, synth_var, real_var


def plot_kde_comparison(models_data, var_idx, var_name, output_dir):
    """
    Plot KDE comparison for one variable across all models.
    
    Args:
        models_data: Dict mapping model_name -> (synthetic_unnorm, real_unnorm)
        var_idx: Variable index
        var_name: Variable name
        output_dir: Where to save plot
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()
    
    model_names = ['combined', 'vae_only', '2d_traits', 'unconditional']
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
    
    for idx, (model_name, ax) in enumerate(zip(model_names, axes)):
        if model_name not in models_data:
            ax.axis('off')
            continue
        
        synthetic_unnorm, real_unnorm = models_data[model_name]
        
        # Extract variable and flatten
        # Data shape is (K, S, D, T) where D is variables, T is time
        if synthetic_unnorm.ndim == 4:  # (K, S, D, T)
            synth_var = synthetic_unnorm[:, :, var_idx, :].flatten()
        else:
            synth_var = synthetic_unnorm[:, :, var_idx].flatten()
        
        synth_var = synth_var[np.isfinite(synth_var)]
        
        # Plot synthetic KDE
        try:
            kde_synth = stats.gaussian_kde(synth_var, bw_method='scott')
            x_range = np.linspace(synth_var.min(), synth_var.max(), 200)
            ax.plot(x_range, kde_synth(x_range), color=colors[idx], 
                   linewidth=2, label='Synthetic', alpha=0.8)
            ax.fill_between(x_range, kde_synth(x_range), alpha=0.3, color=colors[idx])
        except Exception as e:
            print(f"    Warning: Could not plot KDE for {model_name}/{var_name}: {e}")
        
        # Plot real KDE if available
        if real_unnorm is not None:
            # Real shape is (K, T, D) where D is variables, T is time
            if real_unnorm.ndim == 3:  # (K, T, D)
                real_var = real_unnorm[:, :, var_idx].flatten()
            else:
                real_var = real_unnorm[:, :, var_idx].flatten()
            
            real_var = real_var[np.isfinite(real_var)]
            
            try:
                kde_real = stats.gaussian_kde(real_var, bw_method='scott')
                x_range_real = np.linspace(real_var.min(), real_var.max(), 200)
                ax.plot(x_range_real, kde_real(x_range_real), 'k--', 
                       linewidth=2, label='Real', alpha=0.6)
            except Exception as e:
                print(f"    Warning: Could not plot real KDE for {model_name}/{var_name}: {e}")
        
        # Formatting
        ax.set_xlabel(VAR_NICE_NAMES.get(var_name, var_name), fontsize=11)
        ax.set_ylabel('Density', fontsize=11)
        ax.set_title(f'{model_name.replace("_", " ").title()}', fontsize=12, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
    
    plt.suptitle(f'KDE Comparison: {VAR_NICE_NAMES.get(var_name, var_name)}', 
                 fontsize=14, fontweight='bold', y=0.995)
    plt.tight_layout()
    
    output_path = Path(output_dir) / f'kde_{var_name}.png'
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()
    
    print(f"  Saved: {output_path}")


def plot_statistics_comparison(all_stats, output_dir):
    """Plot comparison of statistics across all models."""
    output_dir = Path(output_dir)
    
    model_names = list(all_stats.keys())
    
    # 1. Mean comparison
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.flatten()
    
    for var_idx, var_name in enumerate(VAR_NAMES):
        ax = axes[var_idx]
        
        synth_means = []
        real_means = []
        labels = []
        
        for model_name in model_names:
            if var_name in all_stats[model_name]:
                stats_dict = all_stats[model_name][var_name]
                synth_means.append(stats_dict['synthetic']['mean'])
                labels.append(model_name)
                
                if 'real' in stats_dict:
                    real_means.append(stats_dict['real']['mean'])
        
        x = np.arange(len(labels))
        width = 0.35
        
        ax.bar(x - width/2, synth_means, width, label='Synthetic', alpha=0.8)
        if real_means:
            ax.bar(x + width/2, [real_means[0]]*len(labels), width, 
                  label='Real', alpha=0.6, color='gray')
        
        ax.set_xlabel('Model', fontsize=9)
        ax.set_ylabel('Mean', fontsize=9)
        ax.set_title(VAR_NICE_NAMES.get(var_name, var_name), fontsize=10, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels([l.replace('_', '\n') for l in labels], fontsize=8)
        if var_idx == 0:
            ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3, axis='y')
    
    plt.suptitle('Mean Values Comparison', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_dir / 'statistics_means.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"  Saved: {output_dir / 'statistics_means.png'}")
    
    # 2. Standard deviation comparison
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.flatten()
    
    for var_idx, var_name in enumerate(VAR_NAMES):
        ax = axes[var_idx]
        
        synth_stds = []
        real_stds = []
        labels = []
        
        for model_name in model_names:
            if var_name in all_stats[model_name]:
                stats_dict = all_stats[model_name][var_name]
                synth_stds.append(stats_dict['synthetic']['std'])
                labels.append(model_name)
                
                if 'real' in stats_dict:
                    real_stds.append(stats_dict['real']['std'])
        
        x = np.arange(len(labels))
        width = 0.35
        
        ax.bar(x - width/2, synth_stds, width, label='Synthetic', alpha=0.8)
        if real_stds:
            ax.bar(x + width/2, [real_stds[0]]*len(labels), width, 
                  label='Real', alpha=0.6, color='gray')
        
        ax.set_xlabel('Model', fontsize=9)
        ax.set_ylabel('Std Dev', fontsize=9)
        ax.set_title(VAR_NICE_NAMES.get(var_name, var_name), fontsize=10, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels([l.replace('_', '\n') for l in labels], fontsize=8)
        if var_idx == 0:
            ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3, axis='y')
    
    plt.suptitle('Standard Deviation Comparison', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_dir / 'statistics_stds.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"  Saved: {output_dir / 'statistics_stds.png'}")


def create_summary_table(all_stats, output_dir):
    """Create a text summary table of all statistics."""
    output_dir = Path(output_dir)
    
    with open(output_dir / 'statistics_summary.txt', 'w') as f:
        f.write("="*100 + "\n")
        f.write("DISTRIBUTION STATISTICS SUMMARY\n")
        f.write("="*100 + "\n\n")
        
        for model_name, model_stats in all_stats.items():
            f.write(f"\n{'='*100}\n")
            f.write(f"MODEL: {model_name.upper()}\n")
            f.write(f"{'='*100}\n\n")
            
            for var_name, stats_dict in model_stats.items():
                f.write(f"\n{'-'*100}\n")
                f.write(f"{VAR_NICE_NAMES.get(var_name, var_name)}\n")
                f.write(f"{'-'*100}\n")
                
                # Synthetic stats
                synth = stats_dict['synthetic']
                f.write(f"\nSynthetic:\n")
                f.write(f"  Mean:     {synth['mean']:12.4f}\n")
                f.write(f"  Std:      {synth['std']:12.4f}\n")
                f.write(f"  Median:   {synth['median']:12.4f}\n")
                f.write(f"  Q25-Q75:  [{synth['q25']:10.4f}, {synth['q75']:10.4f}]\n")
                f.write(f"  Range:    [{synth['min']:10.4f}, {synth['max']:10.4f}]\n")
                f.write(f"  N:        {synth['n_samples']:12,}\n")
                
                # Real stats if available
                if 'real' in stats_dict:
                    real = stats_dict['real']
                    f.write(f"\nReal:\n")
                    f.write(f"  Mean:     {real['mean']:12.4f}\n")
                    f.write(f"  Std:      {real['std']:12.4f}\n")
                    f.write(f"  Median:   {real['median']:12.4f}\n")
                    f.write(f"  Q25-Q75:  [{real['q25']:10.4f}, {real['q75']:10.4f}]\n")
                    f.write(f"  Range:    [{real['min']:10.4f}, {real['max']:10.4f}]\n")
                    f.write(f"  N:        {real['n_samples']:12,}\n")
                    
                    # Distance metrics
                    if 'wasserstein_distance' in stats_dict:
                        f.write(f"\nDistribution Metrics:\n")
                        f.write(f"  Wasserstein Distance: {stats_dict['wasserstein_distance']:12.6f}\n")
                        f.write(f"  KS Statistic:         {stats_dict['ks_statistic']:12.6f}\n")
                        f.write(f"  KS p-value:           {stats_dict['ks_pvalue']:12.6e}\n")
                        
                        # Interpretation
                        if stats_dict['ks_pvalue'] > 0.05:
                            f.write(f"  → Distributions are similar (p > 0.05) ✓\n")
                        else:
                            f.write(f"  → Distributions are different (p ≤ 0.05) ⚠️\n")
    
    print(f"  Saved: {output_dir / 'statistics_summary.txt'}")


def analyze_all_models(data_dir='./synthetic_samples_test', output_dir='./kde_analysis'):
    """Main analysis function."""
    data_dir = Path(data_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    print("\n" + "="*80)
    print("KDE ANALYSIS OF GENERATED SAMPLES")
    print("="*80)
    
    model_names = ['combined', 'vae_only', '2d_traits', 'unconditional']
    
    # Load all data
    print("\nLoading data...")
    models_data_norm = {}
    models_data_unnorm = {}
    
    for model_name in model_names:
        print(f"  Loading {model_name}...")
        synthetic_norm, real_norm, mean, std = load_model_data(data_dir, model_name)
        
        print(f"    Synthetic shape: {synthetic_norm.shape}")
        if real_norm is not None:
            print(f"    Real shape: {real_norm.shape}")
        
        # Unnormalize
        synthetic_unnorm = unnormalize(synthetic_norm, mean, std)
        real_unnorm = unnormalize(real_norm, mean, std) if real_norm is not None else None
        
        models_data_norm[model_name] = (synthetic_norm, real_norm)
        models_data_unnorm[model_name] = (synthetic_unnorm, real_unnorm)
        
        print(f"    Unnormalized synthetic range: [{np.nanmin(synthetic_unnorm):.2f}, {np.nanmax(synthetic_unnorm):.2f}]")
    
    # Compute statistics for all models and variables
    print("\nComputing statistics...")
    all_stats = {}
    
    for model_name in model_names:
        print(f"  {model_name}...")
        synthetic_unnorm, real_unnorm = models_data_unnorm[model_name]
        
        model_stats = {}
        for var_idx, var_name in enumerate(VAR_NAMES):
            stats_dict, synth_var, real_var = compute_kde_stats(
                synthetic_unnorm, real_unnorm, var_idx, var_name
            )
            model_stats[var_name] = stats_dict
        
        all_stats[model_name] = model_stats
    
    # Generate plots
    print("\nGenerating KDE plots...")
    for var_idx, var_name in enumerate(VAR_NAMES):
        print(f"  {var_name}...")
        plot_kde_comparison(models_data_unnorm, var_idx, var_name, output_dir)
    
    # Generate comparison plots
    print("\nGenerating comparison plots...")
    plot_statistics_comparison(all_stats, output_dir)
    
    # Create summary table
    print("\nCreating summary table...")
    create_summary_table(all_stats, output_dir)
    
    print("\n" + "="*80)
    print("ANALYSIS COMPLETE")
    print("="*80)
    print(f"\nResults saved to: {output_dir}/")
    print("\nGenerated files:")
    print(f"  - kde_*.png: KDE plots for each variable")
    print(f"  - statistics_means.png: Mean comparison across models")
    print(f"  - statistics_stds.png: Std deviation comparison")
    print(f"  - statistics_summary.txt: Detailed statistics table")


if __name__ == "__main__":
    analyze_all_models(
        data_dir='./synthetic_samples_test',
        output_dir='./kde_analysis'
    )


KDE ANALYSIS OF GENERATED SAMPLES

Loading data...
  Loading combined...
    Synthetic shape: (50, 5, 10, 50)
    Real shape: (50, 50, 10)
    DEBUG: data shape=(50, 5, 10, 50), mean shape=(1, 1, 10), std shape=(1, 1, 10)
    DEBUG: Broadcasting ((50, 5, 10, 50)) * ((1, 1, 10, 1)) + ((1, 1, 10, 1))
    DEBUG: data shape=(50, 50, 10), mean shape=(1, 1, 10), std shape=(1, 1, 10)
    DEBUG: Broadcasting ((50, 50, 10)) * ((1, 1, 10)) + ((1, 1, 10))
    Unnormalized synthetic range: [-68.54, 1004.39]
  Loading vae_only...
    Synthetic shape: (50, 5, 10, 50)
    Real shape: (50, 50, 10)
    DEBUG: data shape=(50, 5, 10, 50), mean shape=(1, 1, 10), std shape=(1, 1, 10)
    DEBUG: Broadcasting ((50, 5, 10, 50)) * ((1, 1, 10, 1)) + ((1, 1, 10, 1))
    DEBUG: data shape=(50, 50, 10), mean shape=(1, 1, 10), std shape=(1, 1, 10)
    DEBUG: Broadcasting ((50, 50, 10)) * ((1, 1, 10)) + ((1, 1, 10))
    Unnormalized synthetic range: [-69.98, 994.06]
  Loading 2d_traits...
    Synthetic shape: (50, 

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# ---- paste your results dict here ----
results = {
    "crps": {
        "normalized": {
            "Unconditional": {
                "global": {
                    "mean": 0.455703431609233,
                    "variance": 0.3438329177007727
                },
                "per_variable": {
                    "mean": {
                        "WRF_TEMP": 0.5096260824362288,
                        "WRF_PRES": 0.4395695622774767,
                        "WRF_RELH": 0.49019994824486923,
                        "WRF_PHI": 0.22738192745169078,
                        "WRF_PHIS": 0.9351946365837736,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    },
                    "variance": {
                        "WRF_TEMP": 0.1906905688142736,
                        "WRF_PRES": 0.5395076839447306,
                        "WRF_RELH": 0.20997768019665505,
                        "WRF_PHI": 0.3124788846628543,
                        "WRF_PHIS": 8.227896343368474,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    }
                }
            },
            "Combined": {
                "global": {
                    "mean": 0.33918000741170284,
                    "variance": 0.3318952141059212
                },
                "per_variable": {
                    "mean": {
                        "WRF_TEMP": 0.2972459089736422,
                        "WRF_PRES": 0.303127531954899,
                        "WRF_RELH": 0.45649683224786713,
                        "WRF_PHI": 0.20656920967678336,
                        "WRF_PHIS": 0.8880345891937982,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    },
                    "variance": {
                        "WRF_TEMP": 0.20316804931292368,
                        "WRF_PRES": 0.4954027079631665,
                        "WRF_RELH": 0.20458508362613612,
                        "WRF_PHI": 0.29265388750953686,
                        "WRF_PHIS": 8.458615798517402,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    }
                }
            }
        },
        "unnormalized": {
            "Unconditional": {
                "global": {
                    "mean": 24.307392536854348,
                    "variance": 3980.6346510209146
                },
                "per_variable": {
                    "mean": {
                        "WRF_TEMP": 10.586785691423838,
                        "WRF_PRES": 62.85041876850196,
                        "WRF_RELH": 7.81457515932265,
                        "WRF_PHI": 0.17148069797649954,
                        "WRF_PHIS": 0.2499163029303237,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    },
                    "variance": {
                        "WRF_TEMP": 82.29134561445036,
                        "WRF_PRES": 11029.574569532699,
                        "WRF_RELH": 53.362646028294314,
                        "WRF_PHI": 0.177721215058749,
                        "WRF_PHIS": 0.5875894263679595,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    }
                }
            },
            "Combined": {
                "global": {
                    "mean": 16.99427657687965,
                    "variance": 3370.282298559872
                },
                "per_variable": {
                    "mean": {
                        "WRF_TEMP": 6.174877708207189,
                        "WRF_PRES": 43.3417005147448,
                        "WRF_RELH": 7.277293313404588,
                        "WRF_PHI": 0.15578473035573392,
                        "WRF_PHIS": 0.23731350964145767,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    },
                    "variance": {
                        "WRF_TEMP": 87.67592580893309,
                        "WRF_PRES": 10127.902293209856,
                        "WRF_RELH": 51.99219931368879,
                        "WRF_PHI": 0.1664458209263574,
                        "WRF_PHIS": 0.6040660938714525,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    }
                }
            }
        }
    },
    "variance": {
        "normalized": {
            "Unconditional": {
                "global": {
                    "mean": 1.302576270455835,
                    "variance": 66795.67286804893
                },
                "per_variable": {
                    "mean": {
                        "WRF_TEMP": 0.8381569862695308,
                        "WRF_PRES": 0.7429375265225866,
                        "WRF_RELH": 0.7683658694691656,
                        "WRF_PHI": 2.868489754062307,
                        "WRF_PHIS": 0.4900988649012914,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    },
                    "variance": {
                        "WRF_TEMP": 0.14837369097661918,
                        "WRF_PRES": 0.20069064636267162,
                        "WRF_RELH": 0.1225233211961902,
                        "WRF_PHI": 267807.46165457397,
                        "WRF_PHIS": 0.37525484555233135,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    }
                }
            },
            "Combined": {
                "global": {
                    "mean": 0.4181735781956309,
                    "variance": 103.5704875001463
                },
                "per_variable": {
                    "mean": {
                        "WRF_TEMP": 0.3525519493908465,
                        "WRF_PRES": 0.48805906916023234,
                        "WRF_RELH": 0.6070444682986504,
                        "WRF_PHI": 0.09186156503102837,
                        "WRF_PHIS": 0.5671214404670987,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    },
                    "variance": {
                        "WRF_TEMP": 0.08096426058862939,
                        "WRF_PRES": 0.469784958608979,
                        "WRF_RELH": 0.14247130163740468,
                        "WRF_PHI": 4.276453717188019,
                        "WRF_PHIS": 577.6046809265014,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    }
                }
            }
        },
        "unnormalized": {
            "Unconditional": {
                "global": {
                    "mean": 3927.5237278310947,
                    "variance": 63104421.64892162
                },
                "per_variable": {
                    "mean": {
                        "WRF_TEMP": 361.7014959104242,
                        "WRF_PRES": 15188.44937254412,
                        "WRF_RELH": 195.26853877345937,
                        "WRF_PHI": 1.6314430155436663,
                        "WRF_PHIS": 0.035000067029186316,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    },
                    "variance": {
                        "WRF_TEMP": 27631.646969330897,
                        "WRF_PRES": 83878260.70430753,
                        "WRF_RELH": 7913.109628165262,
                        "WRF_PHI": 86628.31861429926,
                        "WRF_PHIS": 0.0019138005175506226,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    }
                }
            },
            "Combined": {
                "global": {
                    "mean": 2123.015678328688,
                    "variance": 56591451.73670605
                },
                "per_variable": {
                    "mean": {
                        "WRF_TEMP": 152.14162689064378,
                        "WRF_PRES": 9977.770940509672,
                        "WRF_RELH": 154.27114998884772,
                        "WRF_PHI": 0.05224592782823873,
                        "WRF_PHIS": 0.040500580294212496,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    },
                    "variance": {
                        "WRF_TEMP": 15077.982161072297,
                        "WRF_PRES": 196345698.95179492,
                        "WRF_RELH": 9201.440327584261,
                        "WRF_PHI": 1.383314687585893,
                        "WRF_PHIS": 2.945785112167575,
                        "WRF_QICE": NaN,
                        "WRF_QSNOW": NaN,
                        "WRF_QVAPOR": NaN,
                        "WRF_QCLOUD": NaN,
                        "WRF_QRAIN": NaN
                    }
                }
            }
        }
    },
    "jsd": {
        "Unconditional": {
            "WRF_TEMP": 0.08576902278743079,
            "WRF_PRES": 0.2034697367766451,
            "WRF_RELH": 0.015190060911810868,
            "WRF_PHI": 0.4299411585190164,
            "WRF_PHIS": 0.5649119930937401,
            "WRF_QICE": NaN,
            "WRF_QSNOW": NaN,
            "WRF_QVAPOR": NaN,
            "WRF_QCLOUD": NaN,
            "WRF_QRAIN": NaN
        },
        "Combined": {
            "WRF_TEMP": 0.03227201271672949,
            "WRF_PRES": 0.13908606747024363,
            "WRF_RELH": 0.006626244447357493,
            "WRF_PHI": 0.031922824254363005,
            "WRF_PHIS": 0.08828215123430269,
            "WRF_QICE": NaN,
            "WRF_QSNOW": NaN,
            "WRF_QVAPOR": NaN,
            "WRF_QCLOUD": NaN,
            "WRF_QRAIN": NaN
        }
    }
}

# Choose which CRPS space to compare: "normalized" or "unnormalized"
CRPS_SPACE = "normalized"

# Variables you want in the plot (edit as needed)
vars_to_plot = ["WRF_TEMP", "WRF_PRES", "WRF_RELH"]

label_map = {
    "WRF_TEMP": "Temperature",
    "WRF_PRES": "Pressure",
    "WRF_RELH": "Relative Humidity",
}

uncond = results["crps"][CRPS_SPACE]["Unconditional"]["per_variable"]["mean"]
cond   = results["crps"][CRPS_SPACE]["Combined"]["per_variable"]["mean"]

labels = []
pct_improvement = []

for v in vars_to_plot:
    u = uncond.get(v, np.nan)
    c = cond.get(v, np.nan)

    # Skip NaNs or degenerate cases
    if not np.isfinite(u) or not np.isfinite(c) or u == 0:
        continue

    pct = (u - c) / u * 100.0
    pct_improvement.append(pct)
    labels.append(label_map.get(v, v))

# Sort bars from largest improvement to smallest
order = np.argsort(pct_improvement)[::-1]
pct_improvement = np.array(pct_improvement)[order]
labels = np.array(labels)[order]

plt.figure(figsize=(8, 4.5))
plt.bar(labels, pct_improvement)
plt.axhline(0, linewidth=1)
plt.ylabel("Percent improvement in CRPS (%)")
plt.title(f"Percent improvement: Conditional (Combined) vs Unconditional\n({CRPS_SPACE.capitalize()} CRPS; higher is better)")
plt.xticks(rotation=20, ha="right")

# Annotate each bar with the numeric value
for i, val in enumerate(pct_improvement):
    plt.text(i, val + (0.8 if val >= 0 else -0.8), f"{val:.1f}%", ha="center",
             va="bottom" if val >= 0 else "top")

plt.tight_layout()
plt.show()
