# In [None]:
import xarray as xr
import numpy as np
import glob
import os

import matplotlib.pyplot as plt

# --- Configuration ---
# Ensemble run directories (each should contain globalStats.nc and regionalStats.nc).
# Only directories are kept, so stray files under runs/ (logs, notes, ...) don't
# produce spurious "not found" warnings or shift the "Run N" numbering.
ensemble_dirs = sorted(d for d in glob.glob('runs/*') if os.path.isdir(d))  # Adjust as needed

# --- Helper Functions ---
def load_stats(ensemble_dirs, stats_filename):
    """Load the same stats file from each ensemble directory.

    Parameters
    ----------
    ensemble_dirs : iterable of str
        Run directories to look in.
    stats_filename : str
        File name to read inside each directory (e.g. 'globalStats.nc').

    Returns
    -------
    list
        One entry per directory, in the same order: an ``xarray.Dataset`` when
        the file exists, otherwise ``None`` (and a warning is printed). The
        ``None`` placeholders keep list positions aligned with ``ensemble_dirs``.
    """
    datasets = []
    for d in ensemble_dirs:
        path = os.path.join(d, stats_filename)
        if not os.path.exists(path):
            print(f"Warning: {path} not found.")
            datasets.append(None)
            continue
        # load_dataset reads the data into memory and closes the file, whereas
        # open_dataset keeps a file handle open per run that is never released.
        datasets.append(xr.load_dataset(path))
    return datasets

def plot_ensemble_time_series(datasets, varname, ylabel, title, ax=None, labels=None):
    """Plot one variable from each ensemble member as time series on shared axes.

    Parameters
    ----------
    datasets : sequence
        Per-run datasets, e.g. from ``load_stats``. ``None`` entries and
        datasets lacking ``varname`` are skipped. Plotted datasets must also
        provide ``daysSinceStart``.
    varname : str
        Name of the variable to plot.
    ylabel, title : str
        Y-axis label and plot title.
    ax : matplotlib Axes, optional
        Axes to draw on. When omitted, a new 8x5 figure is created.
    labels : sequence of str, optional
        Legend label for each entry of ``datasets``. Defaults to
        "Run 1", "Run 2", ... by position.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8,5))
    n_plotted = 0
    for i, ds in enumerate(datasets):
        if ds is None or varname not in ds:
            continue
        t = ds['daysSinceStart'].values / 365.0  # days -> years (365-day years)
        y = ds[varname].values
        label = labels[i] if labels is not None else f'Run {i+1}'
        ax.plot(t, y, label=label)
        n_plotted += 1
    ax.set_xlabel('Years')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # legend() with no labelled lines only warns and draws an empty box,
    # so add it only when something was actually plotted.
    if n_plotted:
        ax.legend()
    return ax

# --- Load Data ---
# Each list holds one entry per ensemble directory (None where the file is missing).
global_stats, regional_stats = (
    load_stats(ensemble_dirs, stats_file)
    for stats_file in ('globalStats.nc', 'regionalStats.nc')
)

# --- 1-4. Global time series: floating basal mass balance, total ice volume,
# volume above floatation, and (as an example of additional variables) grounded area ---
# Each entry: (variable name in globalStats.nc, y-axis label, plot title).
global_plot_specs = [
    ('floatingBasalMassBalance', 'Floating Basal Mass Balance (kg/s)',
     'Floating Basal Mass Balance Over Time'),
    ('totalIceVolume', 'Total Ice Volume (m³)',
     'Total Ice Volume Over Time'),
    ('volumeAboveFloatation', 'Volume Above Floatation (m³)',
     'Volume Above Floatation Over Time'),
    ('groundedArea', 'Grounded Area (m²)',
     'Grounded Area Over Time'),
]
for varname, ylabel, title in global_plot_specs:
    # Create the axes here and pass them in. Calling plt.figure() and then the
    # helper without ax= made the helper open a second figure, leaving an empty
    # extra figure behind for every plot.
    fig, ax = plt.subplots(figsize=(8,5))
    plot_ensemble_time_series(global_stats, varname, ylabel, title, ax=ax)
    fig.tight_layout()
    plt.show()

# --- 5. Regional Stats Example: Plot regional grounded ice volume for each region ---
# One figure per run, one line per region.
for run_idx, reg_ds in enumerate(regional_stats):
    # Skip runs whose regional file is missing or lacks the needed variables.
    if reg_ds is None or 'regionNames' not in reg_ds or 'groundedIceVolume' not in reg_ds:
        continue
    names = reg_ds['regionNames'].values.astype(str)
    years = reg_ds['daysSinceStart'].values / 365.0
    fig, ax = plt.subplots(figsize=(10,6))
    # NOTE(review): the positional [:, col] index assumes groundedIceVolume is
    # laid out as (time, region); confirm against the regionalStats.nc schema.
    grounded_vol = reg_ds['groundedIceVolume']
    for col, region in enumerate(names):
        ax.plot(years, grounded_vol[:, col], label=region)
    ax.set_xlabel('Years')
    ax.set_ylabel('Grounded Ice Volume (m³)')
    ax.set_title(f'Run {run_idx + 1}: Regional Grounded Ice Volume')
    ax.legend(loc='best', fontsize='small')
    plt.tight_layout()
    plt.show()

# --- 6. Summary Statistics ---
# Example: Print final total ice volume for each run
print("Final total ice volume for each run:")
for i, ds in enumerate(global_stats):
    vol = ds['totalIceVolume'].values if ds is not None and 'totalIceVolume' in ds else None
    # An empty series (no time records in the file) would make vol[-1] raise
    # IndexError; report it as unavailable like a missing file or variable.
    if vol is None or vol.size == 0:
        print(f"Run {i+1}: Data not available.")
        continue
    print(f"Run {i+1}: {vol[-1]:.3e} m³")

# --- 7. (Optional) Add more plots/statistics as needed ---
# For example, plot mean/median across ensemble, or add more variables.