# Interactive Parameter Space Explorer

This notebook helps you:
1. Visualize the full distribution of CAPE, CIN, and shear in your Sobol samples
2. Interactively adjust filtering criteria to see how many samples pass
3. Identify which parameter bounds need adjustment
4. Understand correlations between parameters

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from pathlib import Path
from ipywidgets import interact, FloatSlider, IntSlider, Checkbox, VBox, HBox
import ipywidgets as widgets

# Notebook-wide plot defaults: seaborn's white-grid theme and a 12x6 inch
# default figure size for any figure that does not pass its own figsize.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (12, 6)})

# Reaching this line means the imports above all resolved.
print("Imports successful!")

## Configuration

In [None]:
# --- CONFIG ---
# Repository root: if the kernel's working directory is named 'notebooks',
# the root is its parent; otherwise the working directory itself is the root.
if Path.cwd().name == 'notebooks':
    REPO_ROOT = Path.cwd().parent
else:
    REPO_ROOT = Path.cwd()

# Experiment output folder and the pickled sounding diagnostics inside it.
EXPERIMENT_DIR = REPO_ROOT.joinpath('outputs', 'sobol_exp_500')
DIAGNOSTICS_PKL = EXPERIMENT_DIR.joinpath('soundings', 'diagnostics.pkl')

# Echo the resolved paths so a wrong working directory is obvious immediately.
print(f"Repository root: {REPO_ROOT}")
print(f"Experiment directory: {EXPERIMENT_DIR}")
print(f"Diagnostics file: {DIAGNOSTICS_PKL}")
print(f"\nFile exists: {DIAGNOSTICS_PKL.exists()}")

## Load Data

In [None]:
# Read the pickled diagnostics, the pickled problem definition, and the
# parameter-value array from the experiment directory.
# NOTE: pickle.load can execute arbitrary code; only open files you produced
# yourself or otherwise trust.
with open(DIAGNOSTICS_PKL, 'rb') as f:
    diag = pickle.load(f)

with open(EXPERIMENT_DIR / 'problem.pkl', 'rb') as f:
    problem = pickle.load(f)
param_values = np.load(EXPERIMENT_DIR / 'param_values.npy')

# DataFrame column name -> key in the diagnostics mapping (insertion order
# fixes the column order of the resulting frame).
column_sources = {
    'sample_id': 'sample_id',
    'MUCAPE': 'mucape',
    'MUCIN': 'mucin',
    'SBCAPE': 'sbcape',
    'SBCIN': 'sbcin',
    'SH01': 'shear_0_1km',
    'SH03': 'shear_0_3km',
    'SH06': 'shear_0_6km',
    'PWAT': 'pwat',
    'surf_theta': 'surface_theta',
    'surf_qv': 'surface_qv',
}
df = pd.DataFrame({column: diag[key] for column, key in column_sources.items()})

# Magnitude of MUCIN, used wherever |CIN| is plotted or summarised.
df['CINabs'] = df['MUCIN'].abs()

# Attach each input parameter as a 'param_<name>' column; the row of
# param_values is selected by sample_id, the column by the parameter's position.
for i, name in enumerate(problem['names']):
    df[f'param_{name}'] = param_values[df['sample_id'], i]

print(f"Loaded {len(df)} samples")
print(f"\nColumns: {list(df.columns)}")
print(f"\nFirst few rows:")
df.head()

## Summary Statistics

In [None]:
print("="*80)
print("SUMMARY STATISTICS")
print("="*80)

# Diagnostics included in the describe() table below.
stats_cols = ['MUCAPE', 'CINabs', 'SBCAPE', 'SH01', 'SH03', 'SH06', 'PWAT']

# describe() with the 10/25/50/75/90th percentiles, printed to one decimal place.
summary = df.loc[:, stats_cols].describe(percentiles=[.1, .25, .5, .75, .9])
print(summary.round(1))

print("\n" + "="*80)
print("PARAMETER RANGES (from Sobol samples)")
print("="*80)
# For every parameter, show the declared bounds next to the min/max that
# actually occur in the corresponding column of param_values.
for i, (name, bounds) in enumerate(zip(problem['names'], problem['bounds'])):
    column = param_values[:, i]
    actual_min = column.min()
    actual_max = column.max()
    declared = f"{name:20s}: [{bounds[0]:8.2f}, {bounds[1]:8.2f}] "
    observed = f"(actual: [{actual_min:8.2f}, {actual_max:8.2f}])"
    print(declared + observed)

## Distributions of Key Variables

In [None]:
# 2x4 grid of histograms, one panel per variable, each marked with its
# mean (red) and median (orange).
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(16, 8))
axes = axes.flatten()

# Each entry: (DataFrame column, x-axis label, number of histogram bins).
plot_vars = [
    ('MUCAPE', 'MUCAPE (J/kg)', 50),
    ('CINabs', '|MUCIN| (J/kg)', 50),
    ('SBCAPE', 'SBCAPE (J/kg)', 50),
    ('SH01', '0-1 km Shear (m/s)', 30),
    ('SH06', '0-6 km Shear (m/s)', 30),
    ('PWAT', 'Precip Water (mm)', 30),
    ('surf_theta', 'Surface θ (K)', 30),
    ('surf_qv', 'Surface qv (g/kg)', 30)
]

for ax, (var, label, bins) in zip(axes, plot_vars):
    # Missing values are dropped before plotting and before the statistics.
    data = df[var].dropna()
    markers = (
        ('Mean', data.mean(), 'red'),
        ('Median', data.median(), 'orange'),
    )

    ax.hist(data, bins=bins, edgecolor='black', alpha=0.7, color='steelblue')
    for tag, value, colour in markers:
        ax.axvline(value, color=colour, linestyle='--', linewidth=2,
                   label=f'{tag}: {value:.1f}')

    ax.set_xlabel(label, fontsize=10)
    ax.set_ylabel('Count', fontsize=10)
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

# Text table of the 10th/50th/90th percentiles for the first six variables
# in plot_vars (the two surface variables are skipped).
print("\nKey Percentiles:")
print("="*80)
for var, label, _ in plot_vars[:6]:
    data = df[var].dropna()
    p10, p50, p90 = np.percentile(data, [10, 50, 90])
    print(f"{label:25s}: 10%={p10:7.1f}, 50%={p50:7.1f}, 90%={p90:7.1f}")

## Interactive Filter Explorer

**Adjust the sliders below to see how filtering criteria affect the number of viable samples.**

In [None]:
def _overlay_threshold_hist(ax, all_values, viable_values, threshold, xlabel, title, bins=50):
    """Draw 'All' and 'Viable' histograms on shared bins plus a threshold line.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    all_values, viable_values : pandas Series; missing values are dropped here.
    threshold : x position (and legend value) of the dashed red threshold line.
    xlabel, title : axis label and panel title.
    bins : number of equal-width bins spanning the range of ``all_values``.
    """
    all_values = all_values.dropna().to_numpy()
    viable_values = viable_values.dropna().to_numpy()

    # Both histograms must share the same edges; letting each call pick its own
    # 50 bins over its own range makes the overlaid bar heights incomparable.
    edges = np.histogram_bin_edges(all_values, bins=bins)

    ax.hist(all_values, bins=edges, alpha=0.6, label='All', color='gray')
    ax.hist(viable_values, bins=edges, alpha=0.8, label='Viable', color='green')
    ax.axvline(threshold, color='red', linestyle='--', linewidth=2,
               label=f'Threshold: {threshold}')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Count')
    ax.set_title(title)
    ax.legend()
    ax.grid(alpha=0.3)


def explore_filters(mucape_min, mucin_max, sbcape_min, use_sbcape):
    """
    Interactive function to explore filtering criteria.

    Reads the notebook-level DataFrame ``df``, prints how many samples pass the
    given thresholds, and draws (1) threshold histograms and (2) a MUCAPE vs
    0-6 km shear scatter in which viable samples are coloured by |MUCIN|.

    Parameters
    ----------
    mucape_min : minimum MUCAPE (J/kg) a sample needs to be viable.
    mucin_max : the most negative CIN allowed (e.g., -200 means |CIN| <= 200).
    sbcape_min : minimum SBCAPE (J/kg); only applied when ``use_sbcape`` is true.
    use_sbcape : whether the SBCAPE criterion is part of the filter.

    Returns
    -------
    None. Output is printed text and matplotlib figures.
    """
    # Individual criteria. Comparisons against NaN are False, so samples with a
    # missing value are also counted under the corresponding failure reason.
    mask_cape = df['MUCAPE'] >= mucape_min
    mask_cin = df['MUCIN'] >= mucin_max  # More negative = stronger cap
    mask_finite = df['MUCAPE'].notna() & df['MUCIN'].notna()

    viable_mask = mask_cape & mask_cin & mask_finite
    if use_sbcape:
        mask_sbcape = df['SBCAPE'] >= sbcape_min
        viable_mask = viable_mask & mask_sbcape

    n_viable = viable_mask.sum()
    n_total = len(df)

    def pct(count):
        # Guard against an empty DataFrame instead of dividing by zero.
        return count / n_total * 100 if n_total else 0.0

    pct_viable = pct(n_viable)
    pct_filtered = 100 - pct_viable if n_total else 0.0

    # Print summary
    print("="*80)
    print(f"FILTERING RESULTS")
    print("="*80)
    print(f"Total samples:   {n_total}")
    print(f"Viable samples:  {n_viable} ({pct_viable:.1f}%)")
    print(f"Filtered out:    {n_total - n_viable} ({pct_filtered:.1f}%)")
    print("\nCriteria:")
    print(f"  MUCAPE >= {mucape_min} J/kg")
    print(f"  MUCIN >= {mucin_max} J/kg  (i.e., |CIN| <= {-mucin_max} J/kg)")
    if use_sbcape:
        print(f"  SBCAPE >= {sbcape_min} J/kg")

    # Failure reasons (a sample can fail more than one criterion, so these
    # counts need not add up to the "Filtered out" total).
    n_low_cape = (~mask_cape).sum()
    n_high_cin = (~mask_cin).sum()
    print("\nReason for filtering:")
    print(f"  Insufficient MUCAPE: {n_low_cape:4d} ({pct(n_low_cape):5.1f}%)")
    print(f"  Excessive CIN:       {n_high_cin:4d} ({pct(n_high_cin):5.1f}%)")
    if use_sbcape:
        n_low_sbcape = (~mask_sbcape).sum()
        print(f"  Insufficient SBCAPE: {n_low_sbcape:4d} ({pct(n_low_sbcape):5.1f}%)")

    # Plot distributions with thresholds
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    _overlay_threshold_hist(axes[0], df['MUCAPE'], df.loc[viable_mask, 'MUCAPE'],
                            mucape_min, 'MUCAPE (J/kg)', 'MUCAPE Distribution')

    # The CIN threshold is shown on the |MUCIN| axis, hence the sign flip.
    _overlay_threshold_hist(axes[1], df['CINabs'], df.loc[viable_mask, 'CINabs'],
                            -mucin_max, '|MUCIN| (J/kg)', '|MUCIN| Distribution')

    if use_sbcape:
        _overlay_threshold_hist(axes[2], df['SBCAPE'], df.loc[viable_mask, 'SBCAPE'],
                                sbcape_min, 'SBCAPE (J/kg)', 'SBCAPE Distribution')
    else:
        # Placeholder panel: no data are drawn, so no legend, y-label or grid is
        # requested (a legend call here would only trigger a matplotlib warning
        # about missing labelled artists).
        axes[2].text(0.5, 0.5, 'SBCAPE filter\nnot used',
                     ha='center', va='center', fontsize=14, transform=axes[2].transAxes)
        axes[2].set_xticks([])
        axes[2].set_yticks([])

    plt.tight_layout()
    plt.show()

    # Scatter plot: CAPE vs Shear (colored by viability)
    fig, ax = plt.subplots(1, 1, figsize=(10, 7))

    # Rejected samples as a light-gray background layer
    ax.scatter(df.loc[~viable_mask, 'MUCAPE'], df.loc[~viable_mask, 'SH06'],
               c='lightgray', alpha=0.5, s=30, label='Filtered out')

    # Viable samples on top, coloured by |MUCIN|
    scatter = ax.scatter(df.loc[viable_mask, 'MUCAPE'], df.loc[viable_mask, 'SH06'],
                         c=df.loc[viable_mask, 'CINabs'], cmap='viridis',
                         s=50, alpha=0.8, edgecolor='black', linewidth=0.5,
                         label='Viable')

    ax.axvline(mucape_min, color='red', linestyle='--', linewidth=1.5, alpha=0.7)
    ax.set_xlabel('MUCAPE (J/kg)', fontsize=12)
    ax.set_ylabel('0-6 km Shear (m/s)', fontsize=12)
    ax.set_title('CAPE vs Shear (colored by |CIN|)', fontsize=13, fontweight='bold')
    ax.legend()
    ax.grid(alpha=0.3)

    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label('|MUCIN| (J/kg)', fontsize=11)

    plt.tight_layout()
    plt.show()

# Build one control per explore_filters argument (keyword names must match the
# function's parameter names). continuous_update=False on the sliders means the
# plots are redrawn when a slider is released rather than on every drag step.
filter_controls = dict(
    mucape_min=IntSlider(value=500, min=0, max=2000, step=100,
                         description='Min CAPE:', continuous_update=False),
    mucin_max=IntSlider(value=-200, min=-500, max=0, step=25,
                        description='Max CIN:', continuous_update=False),
    sbcape_min=IntSlider(value=100, min=0, max=500, step=50,
                         description='Min SBCAPE:', continuous_update=False),
    use_sbcape=Checkbox(value=False, description='Use SBCAPE filter'),
)

# Trailing semicolon suppresses the repr of interact's return value.
interact(explore_filters, **filter_controls);

## Parameter Correlations

This section shows which input parameters correlate with CAPE, CIN, and shear.

In [None]:
# Inputs are every 'param_*' column; outputs are a fixed set of diagnostics.
param_cols = list(filter(lambda c: c.startswith('param_'), df.columns))
output_cols = ['MUCAPE', 'CINabs', 'SBCAPE', 'SH01', 'SH06']

# Work on a renamed copy so the heatmap labels do not carry the 'param_' text.
corr_df = df[param_cols + output_cols].rename(
    columns=lambda c: c.replace('param_', '')
)

# Pairwise correlation between all selected columns (pandas default method).
corr = corr_df.corr()

# Annotated heatmap on a diverging colour scale centred at zero.
heatmap_kwargs = dict(
    annot=True, fmt='.2f', cmap='RdBu_r', center=0,
    vmin=-1, vmax=1, square=True, cbar_kws={'shrink': 0.8},
)
fig, ax = plt.subplots(figsize=(14, 12))
sns.heatmap(corr, ax=ax, **heatmap_kwargs)
ax.set_title('Correlation Matrix: Input Parameters vs Outputs', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Rank every other column by its correlation with MUCAPE, most positive first.
print("\nStrongest correlations with MUCAPE:")
print("="*60)
cape_corr = corr['MUCAPE'].drop(labels='MUCAPE').sort_values(ascending=False)
for param, val in cape_corr.items():
    print(f"{param:25s}: {val:6.3f}")

## Pairplot: Key Parameters vs Outputs

In [None]:
# Four input parameters (stored as 'param_<name>' columns) plus three outputs.
pairplot_cols = [
    f'param_{p}'
    for p in ('low_level_lapse', 'mid_level_lapse', 'low_level_rh', 'surface_theta')
] + ['MUCAPE', 'CINabs', 'SH06']

# Renamed copy: drop the 'param_' text so axis labels stay short.
pairplot_df = df[pairplot_cols].rename(columns=lambda c: c.replace('param_', ''))

# Cap the number of plotted points at 500; the fixed random_state makes the
# subsample the same on every run.
if len(pairplot_df) > 500:
    pairplot_df = pairplot_df.sample(500, random_state=42)
    print(f"Sampled 500 points for pairplot (out of {len(df)})")

g = sns.pairplot(pairplot_df, diag_kind='hist', plot_kws={'alpha': 0.6})
# y slightly above 1 places the title just above the top row of panels.
g.fig.suptitle('Pairplot: Key Input Parameters vs Outputs', y=1.01, fontsize=14, fontweight='bold')
plt.show()

## Recommendations

Based on the analysis above, here are some recommendations:

In [None]:
print("="*80)
print("RECOMMENDATIONS FOR PARAMETER BOUNDS")
print("="*80)

# MUCAPE percentiles over the samples that have a MUCAPE value.
cape_data = df['MUCAPE'].dropna()
cape_p10, cape_p50, cape_p90 = np.percentile(cape_data, [10, 50, 90])

print(
    f"\nCurrent MUCAPE distribution:",
    f"  10th percentile: {cape_p10:.0f} J/kg",
    f"  50th percentile: {cape_p50:.0f} J/kg",
    f"  90th percentile: {cape_p90:.0f} J/kg",
    sep="\n",
)

# A 90th percentile below 2000 J/kg triggers the lapse-rate suggestion.
if cape_p90 < 2000:
    print(
        "\n⚠️  WARNING: 90% of samples have CAPE < 2000 J/kg",
        "   Recommendation: Increase lapse rate bounds",
        "   Suggested ranges:",
        "     - low_level_lapse: [7.5, 9.5] K/km (current: check your bounds)",
        "     - mid_level_lapse: [6.0, 8.0] K/km (current: check your bounds)",
        sep="\n",
    )

# |MUCIN| percentiles over the samples that have a value.
cin_data = df['CINabs'].dropna()
cin_p50, cin_p90 = np.percentile(cin_data, [50, 90])

print(
    f"\nCurrent |MUCIN| distribution:",
    f"  50th percentile: {cin_p50:.0f} J/kg",
    f"  90th percentile: {cin_p90:.0f} J/kg",
    sep="\n",
)

# Filtering recommendations
print("\n" + "="*80)
print("RECOMMENDED FILTERING CRITERIA")
print("="*80)

# Candidate (minimum MUCAPE, most negative MUCIN allowed) pairs, from relaxed
# to strict; report how many samples each pair would keep.
thresholds = [(300, -250), (500, -200), (700, -150)]
print("\nViable samples with different criteria:")
for cape_min, cin_max in thresholds:
    enough_cape = df['MUCAPE'].ge(cape_min)
    weak_enough_cap = df['MUCIN'].ge(cin_max)
    mask = enough_cape & weak_enough_cap & df['MUCAPE'].notna()
    n_viable = mask.sum()
    pct = n_viable / len(df) * 100
    print(f"  CAPE>={cape_min:4d}, |CIN|<={-cin_max:3d}: {n_viable:3d} samples ({pct:5.1f}%)")

print(
    "\n💡 Recommendation:",
    "   Start with relaxed criteria (e.g., CAPE>=300, |CIN|<=250)",
    "   Then tighten based on WRF results",
    "   Drop SBCAPE requirement unless specifically needed",
    sep="\n",
)