# 01 – Preselect & Diagnostics (CAPE / CIN / Shear)
This notebook **does not run WRF**. It reads (or generates) soundings for your Sobol experiment, computes diagnostics (CAPE, CIN, 0–1 km, 0–3 km and 0–6 km bulk shear, PW), and visualizes **ranges and coverage**.

**Modes:**
1. *Generate mode*: Generate soundings on-the-fly using Sobol samples
2. *Folder mode*: Read existing WRF-ready `input_sounding_*` files from step2 output
3. *Diagnostics mode*: Load pre-computed diagnostics from step2

The goal is to explore **full distributions** – not just a low‑CAPE/high‑shear subset.

In [None]:
import os, glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import sys
import pickle

# Locate the repository root: when the working directory is notebooks/,
# step up one level; otherwise the working directory itself is the root.
_cwd = Path.cwd()
REPO_ROOT = _cwd.parent if _cwd.name == 'notebooks' else _cwd
SRC = REPO_ROOT.joinpath('src')
# Put src/ at the front of the module search path.
sys.path.insert(0, str(SRC))

# Import your sounding generator
from sounding_generator import (
    generate_sounding,
    calculate_cape_cin,
    read_input_sounding as read_wrf_sounding
)

In [None]:
# NOTE: this cell used to contain a bare `np.load()` call. np.load requires a
# file argument, so the call raised TypeError and aborted "Run All".
# Nothing has to be loaded here: the Sobol arrays are read inside the
# catalog-building cell when MODE == 'generate'.

In [None]:
# --- CONFIG ---
# Data source for the catalog: 'generate', 'folder', or 'diagnostics'.
MODE = 'folder'

# 'generate' mode: Sobol experiment directory, the base sounding file handed
# to the generator, and an upper bound on how many samples are processed.
EXPERIMENT_DIR = str(REPO_ROOT.joinpath('outputs', 'sobol_exp_500'))
BASE_SOUNDING = str(SRC.joinpath('input_sounding'))
N_SAMPLES = 500

# 'folder' mode: directory that holds the input_sounding_* files.
SOUNDINGS_DIR = str(REPO_ROOT.joinpath('outputs', 'sobol_exp_500', 'soundings'))

# 'diagnostics' mode: pickle with pre-computed diagnostics from step2.
DIAGNOSTICS_PKL = str(
    REPO_ROOT.joinpath('outputs', 'sobol_exp_500', 'soundings', 'diagnostics.pkl')
)

# Output catalog (CSV); create its parent directory up front.
OUT_CATALOG = str(REPO_ROOT.joinpath('outputs', 'env_catalog.csv'))
Path(OUT_CATALOG).parent.mkdir(parents=True, exist_ok=True)

In [None]:
def compute_diagnostics(sounding):
    """Compute the catalog diagnostics for one sounding.

    Parameters
    ----------
    sounding : mapping
        Must provide equal-length 1-D sequences (lists or arrays) under the
        keys 'height', 'u', 'v', 'qv', 'p' and 't'. Index 0 is used as the
        reference level for bulk shear. The mapping is also passed unchanged
        to ``calculate_cape_cin``.
        NOTE(review): the conversions below treat qv as g/kg and p as hPa;
        height is presumably in metres and t in Kelvin -- confirm against
        the sounding reader/generator.

    Returns
    -------
    dict
        'MLCAPE', 'MLCIN' (as returned by ``calculate_cape_cin``),
        'SH01', 'SH03', 'SH06' (bulk shear magnitude between index 0 and the
        level nearest 1000 / 3000 / 6000) and 'PW' (column sum of
        rho * qv * dz; kg/m^2, i.e. mm, if the unit assumptions above hold).
    """
    # CAPE and CIN
    cape, cin = calculate_cape_cin(sounding)

    # Coerce once so that plain lists work as well as arrays.
    height = np.asarray(sounding['height'], dtype=float)
    u = np.asarray(sounding['u'], dtype=float)
    v = np.asarray(sounding['v'], dtype=float)
    qv = np.asarray(sounding['qv'], dtype=float)
    p = np.asarray(sounding['p'], dtype=float)
    t = np.asarray(sounding['t'], dtype=float)

    def get_shear(z_target):
        # Nearest available level is used; there is no interpolation to z_target.
        idx = np.argmin(np.abs(height - z_target))
        return np.sqrt((u[idx] - u[0])**2 + (v[idx] - v[0])**2)

    sh01 = get_shear(1000)
    sh03 = get_shear(3000)
    sh06 = get_shear(6000)

    # Precipitable water: layer-mean density * layer-mean qv * thickness,
    # summed over all layers (vectorized form of the per-level loop).
    dz = height[1:] - height[:-1]
    qv_avg = (qv[1:] + qv[:-1]) / 2 / 1000  # Convert g/kg to kg/kg
    p_avg = (p[1:] + p[:-1]) / 2 * 100  # Convert hPa to Pa
    t_avg = (t[1:] + t[:-1]) / 2
    rho = p_avg / (287 * t_avg)  # ideal gas law with R = 287
    pwat = np.sum(rho * qv_avg * dz)  # 0.0 when there is a single level

    return {
        'MLCAPE': cape,
        'MLCIN': cin,
        'SH01': sh01,
        'SH03': sh03,
        'SH06': sh06,
        'PW': pwat
    }

In [None]:
# --- Build diagnostics catalog ---
# Every row carries 'idx' plus these diagnostics. A sample that fails keeps
# its row (filled with NaN) so failures stay visible in the catalog.
_DIAG_COLUMNS = ('MLCAPE', 'MLCIN', 'SH01', 'SH03', 'SH06', 'PW')
_NAN_DIAG = dict.fromkeys(_DIAG_COLUMNS, np.nan)


def _sounding_sort_key(path):
    """Order files by the numeric value of their trailing '_<n>' suffix.

    A plain lexicographic sort puts 'input_sounding_10' before
    'input_sounding_2', which makes the positional idx disagree with the
    sample numbering when suffixes are not zero-padded. Files without a
    purely numeric suffix are sorted by path after the numeric ones.
    """
    suffix = os.path.basename(path).rsplit('_', 1)[-1]
    if suffix.isdecimal():
        return (0, int(suffix), path)
    return (1, 0, path)


rows = []

if MODE == 'generate':
    print(f"Generating {N_SAMPLES} soundings from Sobol samples...")

    # Load Sobol experiment.
    # SECURITY: pickle.load can execute arbitrary code; only open experiment
    # files that you produced yourself.
    with open(os.path.join(EXPERIMENT_DIR, 'problem.pkl'), 'rb') as f:
        problem = pickle.load(f)
    param_values = np.load(os.path.join(EXPERIMENT_DIR, 'param_values.npy'))

    # Generate subset of soundings
    n_process = min(N_SAMPLES, len(param_values))
    for i in range(n_process):
        param_dict = dict(zip(problem['names'], param_values[i]))

        try:
            sounding = generate_sounding(param_dict, base_sounding_file=BASE_SOUNDING)
            diag = compute_diagnostics(sounding)
        except Exception as e:  # deliberate best-effort: report and keep going
            print(f"  ERROR: Sample {i} failed: {e}")
            diag = _NAN_DIAG
        rows.append({'idx': i, **diag})

        # Reported for every 50th sample, whether or not it succeeded.
        if (i + 1) % 50 == 0:
            print(f"  Progress: {i+1}/{n_process}")

elif MODE == 'folder':
    print(f"Reading soundings from {SOUNDINGS_DIR}...")
    files = sorted(
        glob.glob(os.path.join(SOUNDINGS_DIR, 'input_sounding_*')),
        key=_sounding_sort_key,
    )
    if not files:
        raise SystemExit(f'No input_sounding_* files found in {SOUNDINGS_DIR}')

    print(f"Found {len(files)} sounding files")
    # idx is the position in the numerically sorted file list.
    for i, fp in enumerate(files):
        try:
            sounding = read_wrf_sounding(fp)
            diag = compute_diagnostics(sounding)
        except Exception as e:  # deliberate best-effort: report and keep going
            print(f"  ERROR: File {fp} failed: {e}")
            diag = _NAN_DIAG
        rows.append({'idx': i, **diag})

        if (i + 1) % 100 == 0:
            print(f"  Progress: {i+1}/{len(files)}")

elif MODE == 'diagnostics':
    print(f"Loading pre-computed diagnostics from {DIAGNOSTICS_PKL}...")
    # SECURITY: pickle.load can execute arbitrary code; only open diagnostics
    # files that you produced yourself.
    with open(DIAGNOSTICS_PKL, 'rb') as f:
        diag_dict = pickle.load(f)

    # Convert the column-oriented dict into one row per sample.
    for i, sample_id in enumerate(diag_dict['sample_id']):
        rows.append({
            'idx': sample_id,
            'MLCAPE': diag_dict['cape'][i],
            'MLCIN': diag_dict['cin'][i],
            'SH01': diag_dict['shear_0_1km'][i],
            'SH03': diag_dict['shear_0_3km'][i],
            'SH06': diag_dict['shear_0_6km'][i],
            'PW': diag_dict['pwat'][i]
        })

else:
    raise ValueError(f'Unknown MODE: {MODE}')

df = pd.DataFrame(rows)
df.to_csv(OUT_CATALOG, index=False)
print(f"\nSaved catalog to {OUT_CATALOG}")
print(f"\nDataFrame shape: {df.shape}")
df.describe(include='all')

## Distributions – CAPE, |CIN|, and Shear
These give you the **ranges** covered by your design. Histograms are drawn for CAPE, CIN, the 0–1 / 0–3 / 0–6 km shear and PW; CIN is plotted as its absolute value.

In [None]:
# One histogram per diagnostic that is present in the catalog, with the mean
# and median marked; MLCIN is shown as its absolute value.
hist_columns = ['MLCAPE', 'MLCIN', 'SH01', 'SH03', 'SH06', 'PW']

for col in hist_columns:
    if col not in df.columns:
        continue

    is_cin = (col == 'MLCIN')
    vals = df[col].abs() if is_cin else df[col]
    vals_clean = vals.dropna()

    if vals_clean.empty:
        print(f"No valid data for {col}")
        continue

    mean_val, median_val = vals_clean.mean(), vals_clean.median()

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.hist(vals_clean.values, bins=40, edgecolor='black', alpha=0.7)
    ax.set_xlabel('|MLCIN| (J/kg)' if is_cin else col, fontsize=11)
    ax.set_ylabel('Count', fontsize=11)
    ax.set_title(f'Distribution of {col}', fontsize=12, fontweight='bold')
    ax.grid(alpha=0.3)

    # Mark mean and median
    ax.axvline(mean_val, color='red', linestyle='--', linewidth=2,
               label=f'Mean: {mean_val:.1f}')
    ax.axvline(median_val, color='orange', linestyle='--', linewidth=2,
               label=f'Median: {median_val:.1f}')
    ax.legend()

    fig.tight_layout()
    plt.show()

    print(f"{col}: min={vals_clean.min():.1f}, mean={mean_val:.1f}, max={vals_clean.max():.1f}")

## 2‑D Coverage (hexbin)
Quick maps to see where samples land: **CAPE vs shear**, and **CAPE vs |CIN|**.

In [None]:
# (x column, y column, plot title) for each 2-D coverage map.
pairs = [
    ('MLCAPE', 'SH06', 'CAPE vs 0–6 km shear'),
    ('MLCAPE', 'SH01', 'CAPE vs 0–1 km shear'),
    ('MLCAPE', 'MLCIN', 'CAPE vs |CIN|'),
    ('SH06', 'SH01', '0–6 km shear vs 0–1 km shear'),
]
cin_label = '|MLCIN| (J/kg)'

for x, y, title in pairs:
    # Skip pairs whose columns are missing from the catalog.
    if not (x in df and y in df):
        continue

    xv = df[x].values
    yv = df[y].values
    if y == 'MLCIN':
        yv = np.abs(yv)

    # Keep only points where both coordinates are present.
    mask = ~np.isnan(xv) & ~np.isnan(yv)
    xv, yv = xv[mask], yv[mask]

    if xv.size == 0:
        print(f"No valid data for {title}")
        continue

    fig, ax = plt.subplots(figsize=(7, 6))
    hb = ax.hexbin(xv, yv, gridsize=40, mincnt=1, cmap='YlOrRd')
    ax.set_xlabel(cin_label if x == 'MLCIN' else x, fontsize=11)
    ax.set_ylabel(cin_label if y == 'MLCIN' else y, fontsize=11)
    ax.set_title(title, fontsize=12, fontweight='bold')
    cb = fig.colorbar(hb, ax=ax, label='Count')
    ax.grid(alpha=0.3)
    fig.tight_layout()
    plt.show()

## Binned coverage tables
Adjust bins as needed for your parameter space.

In [None]:
# Define bins
cape_bins = [0, 500, 1000, 2000, 3000, 4000, np.inf]
shear_bins = [0, 10, 15, 20, 25, 30, 40, np.inf]
cin_bins = [0, 25, 50, 75, 100, 150, np.inf]

# Create binned columns.
# pd.cut intervals are right-closed by default, so with a first edge of 0 a
# value of exactly 0 (zero CAPE, zero |CIN|) would fall outside every bin,
# become NaN and silently vanish from the coverage tables.
# include_lowest=True closes the first interval on the left; pandas displays
# that interval as (-0.001, <edge>].
df['CINabs'] = np.abs(df['MLCIN'])
df['CAPE_bin'] = pd.cut(df['MLCAPE'], bins=cape_bins, include_lowest=True)
df['SH06_bin'] = pd.cut(df['SH06'], bins=shear_bins, include_lowest=True)
df['CIN_bin'] = pd.cut(df['CINabs'], bins=cin_bins, include_lowest=True)

# Create pivot tables: each cell counts the samples in that bin pair.
# observed=False keeps bins with no samples, which show up as 0.
print("\n" + "="*60)
print("CAPE vs 0-6 km Shear Coverage")
print("="*60)
pivot_cape_sh06 = df.pivot_table(
    index='CAPE_bin',
    columns='SH06_bin',
    values='idx',
    aggfunc='count',
    observed=False
).fillna(0).astype(int)
print(pivot_cape_sh06)

print("\n" + "="*60)
print("CAPE vs |CIN| Coverage")
print("="*60)
pivot_cape_cin = df.pivot_table(
    index='CAPE_bin',
    columns='CIN_bin',
    values='idx',
    aggfunc='count',
    observed=False
).fillna(0).astype(int)
print(pivot_cape_cin)

# Summary statistics: a sample counts as valid when its MLCAPE is not NaN.
print("\n" + "="*60)
print("Summary Statistics")
print("="*60)
total_samples = len(df)
valid_samples = df['MLCAPE'].notna().sum()
print(f"Total samples: {total_samples}")
print(f"Valid samples: {valid_samples} ({valid_samples/total_samples*100:.1f}%)")
print(f"Failed samples: {total_samples - valid_samples}")

## 3D Scatter: CAPE vs Shear vs |CIN|
Visualize the full 3D parameter space coverage.

In [None]:
from mpl_toolkits.mplot3d import Axes3D

# Keep only rows where all three plotted diagnostics are present.
mask = df[['MLCAPE', 'SH06', 'MLCIN']].notna().all(axis=1)
df_clean = df[mask].copy()

if len(df_clean) == 0:
    print("No valid data for 3D plot")
else:
    cape_vals = df_clean['MLCAPE']
    shear_vals = df_clean['SH06']
    cin_vals = df_clean['CINabs']  # column created by the binned-coverage cell

    # The lower CAPE limit and upper shear limit double as the positions of
    # the walls that receive the 2-D projections.
    cape_wall = cape_vals.min() - 200
    shear_wall = shear_vals.max() + 5

    fig = plt.figure(figsize=(14, 10))
    ax = fig.add_subplot(111, projection='3d')

    # Main 3D scatter, coloured by |CIN|
    scatter = ax.scatter(
        cape_vals, shear_vals, cin_vals,
        c=cin_vals, cmap='viridis', marker='o', s=20, alpha=0.6,
    )

    ax.set_xlabel('MLCAPE (J/kg)', fontsize=11)
    ax.set_ylabel('0-6 km Shear (m/s)', fontsize=11)
    ax.set_zlabel('|MLCIN| (J/kg)', fontsize=11)
    ax.set_title('3D Parameter Space: CAPE vs Shear vs |CIN|', fontsize=13, fontweight='bold')

    # Fix the axis limits before the projections are drawn.
    ax.set_xlim(cape_wall, cape_vals.max() + 200)
    ax.set_ylim(0, shear_wall)
    ax.set_zlim(0, cin_vals.max() + 10)

    # Projections on the walls: zdir names the axis that receives the
    # constant zs value; the two positional arrays fill the other two axes.
    wall_style = dict(marker='.', s=5, alpha=0.3)
    ax.scatter(cape_vals, shear_vals, zs=0, zdir='z', c='lightcoral', **wall_style)
    ax.scatter(cape_vals, cin_vals, zs=shear_wall, zdir='y', c='lightgreen', **wall_style)
    ax.scatter(shear_vals, cin_vals, zs=cape_wall, zdir='x', c='lightblue', **wall_style)

    plt.colorbar(scatter, label='|CIN| (J/kg)', shrink=0.6)
    plt.tight_layout()
    plt.show()

## Analysis Summary
Key metrics and coverage assessment.

In [None]:
heavy_rule = "=" * 70
light_rule = "-" * 70

print(heavy_rule)
print("PARAMETER SPACE COVERAGE SUMMARY")
print(heavy_rule)

# A sample is usable here only if CAPE, CIN and 0-6 km shear are all present.
valid_df = df.dropna(subset=['MLCAPE', 'MLCIN', 'SH06'])

if len(valid_df) == 0:
    print("No valid samples to analyze!")
else:
    print(f"\nValid samples: {len(valid_df)}")
    print("\nParameter Ranges:")
    print(light_rule)

    # (display name, unit, catalog column)
    range_rows = [
        ('MLCAPE', 'J/kg', 'MLCAPE'),
        ('|MLCIN|', 'J/kg', 'CINabs'),
        ('0-1 km Shear', 'm/s', 'SH01'),
        ('0-6 km Shear', 'm/s', 'SH06'),
        ('Precip Water', 'mm', 'PW'),
    ]

    for name, unit, col in range_rows:
        if col not in valid_df.columns:
            continue
        vals = valid_df[col]
        print(f"{name:20s}: {vals.min():7.1f} to {vals.max():7.1f} {unit:6s} "
              f"(mean: {vals.mean():7.1f})")

    # Subsets that may deserve a closer look
    print("\n" + light_rule)
    print("Potentially Interesting Cases:")
    print(light_rule)

    is_high_cape = valid_df['MLCAPE'] > 2000
    is_high_shear = valid_df['SH06'] > 20
    high_cape_high_shear = valid_df[is_high_cape & is_high_shear]
    low_cin = valid_df[valid_df['CINabs'] < 25]
    extreme_shear = valid_df[valid_df['SH06'] > 30]

    print(f"High CAPE + High Shear (>2000 J/kg, >20 m/s): {len(high_cape_high_shear)} samples")
    print(f"Low CIN (<25 J/kg): {len(low_cin)} samples")
    print(f"Extreme shear (>30 m/s): {len(extreme_shear)} samples")

    print("\n" + heavy_rule)