fair_facts_v2_viz_total:

- This notebook visualizes the totaled outputs from the fair_fact_v2_total.ipynb notebook.
- Ensemble means and aggregates of sea-level changes are compared and plotted.
- These selected visualizations may be heavy-handed for the current workflow, but they are also intended to spur ideas for future illustrations.


In [None]:
import os
from pathlib import Path

def list_output_files(output_dirs=None):
    """Print the NetCDF (``*.nc``) files found in each workflow output directory.

    Args:
        output_dirs: Iterable of directory paths (str or Path) to inspect.
            Defaults to the FAIR, land-water-storage and sterodynamics output
            folders under ``./data/output``.

    For each directory a header line is printed, followed by one line per
    ``.nc`` file (name and size in KB, sorted by name so the listing is stable
    across runs), or a note when the directory is missing or contains no
    ``.nc`` files. Returns None.
    """
    if output_dirs is None:
        output_dirs = [
            './data/output/fair',
            './data/output/lws',
            './data/output/sterodynamics',
            # Currently disabled components:
            # './data/output/emulandice',
            # './data/output/glacier'
        ]

    for dir_path in output_dirs:
        print(f"\n{dir_path}:")
        directory = Path(dir_path)
        if not directory.exists():
            print("  (directory doesn't exist)")
            continue

        # glob() yields entries in filesystem order; sort for a reproducible listing.
        files = sorted(directory.glob('*.nc'))
        if not files:
            print("  (no .nc files found)")
            continue

        for f in files:
            print(f"  - {f.name} ({f.stat().st_size / 1024:.1f} KB)")

list_output_files()


./data/output/fair:
  - climate.nc (151.5 KB)
  - oceantemp.nc (61.9 KB)
  - gsat.nc (61.3 KB)
  - ohc.nc (62.0 KB)

./data/output/lws:
  - gslr.nc (12.9 KB)
  - lslr.nc (13.6 KB)

./data/output/sterodynamics:
  - gslr.nc (14.9 KB)
  - lslr.nc (13.0 KB)


In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import warnings

# Globally switch off xarray's optional numbagg acceleration. This must run
# before any xarray computation so that no reduction goes through numbagg.
xr.set_options(use_numbagg=False)

# Silence only numpy "overflow encountered" RuntimeWarnings (presumably raised
# by large-magnitude fields such as ocean heat content); others still show.
warnings.filterwarnings(
    'ignore',
    message='overflow encountered',
    category=RuntimeWarning,
)

class FACTSeaLevelVisualizer:
    """Visualizer for FACT Sea Level Framework outputs"""
    
    def __init__(self, output_base_dir='./data/output'):
        self.base_dir = Path(output_base_dir)
        self.datasets = {}
        self.scenario = 'ssp585'
        
    def diagnose_datasets(self):
        """Print detailed information about loaded datasets"""
        print("\n" + "="*60)
        print("DATASET DIAGNOSTICS")
        print("="*60)
        
        for key, ds in self.datasets.items():
            print(f"\n{key}:")
            print(f"  Variables: {list(ds.data_vars)}")
            print(f"  Coordinates: {list(ds.coords)}")
            print(f"  Dimensions: {dict(ds.dims)}")
            
            for var in ds.data_vars:
                var_data = ds[var]
                print(f"\n  {var}:")
                print(f"    Shape: {var_data.shape}")
                print(f"    Dims: {var_data.dims}")
                print(f"    Dtype: {var_data.dtype}")
                
                # Check for NaN values
                if np.issubdtype(var_data.dtype, np.floating):
                    values = var_data.values
                    n_nan = np.sum(np.isnan(values))
                    n_total = values.size
                    print(f"    NaN values: {n_nan}/{n_total} ({100*n_nan/n_total:.1f}%)")
                    
                    if n_nan < n_total:
                        valid_values = values[~np.isnan(values)]
                        print(f"    Range: [{np.min(valid_values):.3f}, {np.max(valid_values):.3f}]")
                        print(f"    Mean: {np.mean(valid_values):.3f}")
                    
    def load_outputs(self):
        """Load all available output files"""
        print("Loading output files...")
        
        output_files = {
            'climate': self.base_dir / 'fair/climate.nc',
            'gsat': self.base_dir / 'fair/gsat.nc',
            'oceantemp': self.base_dir / 'fair/oceantemp.nc',
            'ohc': self.base_dir / 'fair/ohc.nc',
            'lws_gslr': self.base_dir / 'lws/gslr.nc',
            'lws_lslr': self.base_dir / 'lws/lslr.nc',
            'stereo_gslr': self.base_dir / 'sterodynamics/gslr.nc',
            'stereo_lslr': self.base_dir / 'sterodynamics/lslr.nc',
            # 'ice_gslr': self.base_dir / 'emulandice/gslr.nc',
            # 'ice_lslr': self.base_dir / 'emulandice/lslr.nc'
        }
        
        for key, filepath in output_files.items():
            if filepath.exists():
                try:
                    if key == 'climate':
                        self.datasets[key] = xr.open_dataset(filepath, group=self.scenario)
                        print(f"  ✓ Loaded {key}: {filepath.name} (group: {self.scenario})")
                    else:
                        self.datasets[key] = xr.open_dataset(filepath)
                        print(f"  ✓ Loaded {key}: {filepath.name}")
                except Exception as e:
                    print(f"  ✗ Error loading {key}: {e}")
            else:
                print(f"  - Not found: {filepath}")
        
        return self.datasets
    
    def plot_climate_latest_snapshot(self, scenario='ssp585', figsize=(16, 10)):
        """Plot climate variables at the latest timestamp"""
        
        if 'climate' not in self.datasets:
            print("Climate data not available")
            return
        
        ds = self.datasets['climate']
        
        years = ds['years'].values
        latest_year = years[-1]
        latest_idx = -1
        
        print(f"\nPlotting climate snapshot at year {latest_year}")
        
        fig, axes = plt.subplots(2, 3, figsize=figsize)
        
        samples_dim = 'samples'
        
        variables = {
            'surface_temperature': ('Surface Temperature', '°C', 'Reds'),
            'deep_ocean_temperature': ('Deep Ocean Temperature', '°C', 'Blues'),
            'ocean_heat_content': ('Ocean Heat Content', 'YJ (10²⁴ J)', 'Greens')
        }
        
        for idx, (var_name, (title, unit, cmap)) in enumerate(variables.items()):
            if var_name not in ds.data_vars:
                print(f"⚠ Variable '{var_name}' not found")
                continue
            
            var_data = ds[var_name]
            
            # Get data at latest timestamp
            data_latest = var_data.isel(years=latest_idx).values
            
            # Scale ocean heat content
            if var_name == 'ocean_heat_content':
                data_latest = data_latest / 1e24
            
            print(f"\n{var_name}:")
            print(f"  Shape at latest time: {data_latest.shape}")
            print(f"  Range: [{np.min(data_latest):.3f}, {np.max(data_latest):.3f}] {unit}")
            print(f"  Mean: {np.mean(data_latest):.3f} {unit}")
            print(f"  Std: {np.std(data_latest):.3f} {unit}")
            
            # Top row: Histogram
            ax = axes[0, idx]
            ax.hist(data_latest, bins=30, alpha=0.7, color=plt.cm.get_cmap(cmap)(0.6), edgecolor='black')
            ax.axvline(np.mean(data_latest), color='red', linestyle='--', linewidth=2, 
                       label=f'Mean: {np.mean(data_latest):.2f} {unit}')
            ax.axvline(np.median(data_latest), color='darkred', linestyle=':', linewidth=2,
                       label=f'Median: {np.median(data_latest):.2f} {unit}')
            ax.set_title(f'{title}\nat {latest_year}', fontsize=12, fontweight='bold')
            ax.set_xlabel(f'{title} ({unit})')
            ax.set_ylabel('Frequency (Ensemble Members)')
            ax.legend(fontsize=9)
            ax.grid(True, alpha=0.3, axis='y')
            
            # Bottom row: Box plot with individual points
            ax = axes[1, idx]
            
            # Box plot
            bp = ax.boxplot([data_latest], positions=[1], widths=0.6, patch_artist=True,
                            showmeans=True, meanline=True,
                            boxprops=dict(facecolor=plt.cm.get_cmap(cmap)(0.4), alpha=0.7),
                            medianprops=dict(color='darkred', linewidth=2),
                            meanprops=dict(color='red', linewidth=2, linestyle='--'),
                            whiskerprops=dict(linewidth=1.5),
                            capprops=dict(linewidth=1.5))
            
            # Scatter individual points
            y_scatter = data_latest
            x_scatter = np.random.normal(1, 0.04, size=len(data_latest))  # Add jitter
            ax.scatter(x_scatter, y_scatter, alpha=0.5, s=50, 
                      c=plt.cm.get_cmap(cmap)(0.8), edgecolors='black', linewidth=0.5)
            
            # Add statistics text
            stats_text = f'n = {len(data_latest)}\n'
            stats_text += f'μ = {np.mean(data_latest):.2f}\n'
            stats_text += f'σ = {np.std(data_latest):.2f}\n'
            stats_text += f'5th = {np.percentile(data_latest, 5):.2f}\n'
            stats_text += f'95th = {np.percentile(data_latest, 95):.2f}'
            
            ax.text(1.5, np.mean(data_latest), stats_text,
                   fontsize=9, verticalalignment='center',
                   bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
            
            ax.set_xlim([0.5, 2])
            ax.set_xticks([1])
            ax.set_xticklabels([f'{latest_year}'])
            ax.set_ylabel(f'{title} ({unit})')
            ax.set_title(f'{title} - Ensemble Spread', fontsize=12, fontweight='bold')
            ax.grid(True, alpha=0.3, axis='y')
        
        plt.suptitle(f'Climate Variables at Latest Timestamp ({latest_year}) - {scenario.upper()}',
                     fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()
        
        return fig
    
    def plot_climate_evolution_to_latest(self, scenario='ssp585', figsize=(16, 12)):
        """Plot evolution of climate variables from start to latest timestamp"""
        
        if 'climate' not in self.datasets:
            print("Climate data not available")
            return
        
        ds = self.datasets['climate']
        
        years = ds['years'].values
        latest_year = years[-1]
        
        fig, axes = plt.subplots(3, 2, figsize=figsize)
        
        samples_dim = 'samples'
        
        variables = {
            'surface_temperature': ('Surface Temperature', '°C', 'red'),
            'deep_ocean_temperature': ('Deep Ocean Temperature', '°C', 'blue'),
            'ocean_heat_content': ('Ocean Heat Content', 'YJ (10²⁴ J)', 'green')
        }
        
        for idx, (var_name, (title, unit, color)) in enumerate(variables.items()):
            if var_name not in ds.data_vars:
                print(f"⚠ Variable '{var_name}' not found")
                continue
            
            var_data = ds[var_name]
            
            # Scale ocean heat content
            if var_name == 'ocean_heat_content':
                var_data = var_data / 1e24
            
            # Left column: Time series with all ensemble members highlighted at end
            ax = axes[idx, 0]
            
            # Plot all ensemble members
            for i in range(var_data.sizes[samples_dim]):
                member_data = var_data.isel(samples=i).values
                ax.plot(years, member_data, alpha=0.3, color=color, linewidth=0.5)
                
                # Highlight endpoint
                ax.scatter(years[-1], member_data[-1], s=50, color=color, 
                          alpha=0.6, edgecolors='black', linewidth=0.5, zorder=5)
            
            # Plot mean
            mean = var_data.mean(dim=samples_dim)
            ax.plot(years, mean, color='darkred', linewidth=2.5, label='Ensemble Mean')
            ax.scatter(years[-1], mean.values[-1], s=100, color='darkred',
                      marker='*', edgecolors='black', linewidth=1, zorder=10,
                      label=f'Latest: {mean.values[-1]:.2f} {unit}')
            
            ax.set_title(f'{title} - Evolution', fontsize=12, fontweight='bold')
            ax.set_xlabel('Year')
            ax.set_ylabel(f'{title} ({unit})')
            ax.legend(fontsize=9)
            ax.grid(True, alpha=0.3)
            ax.axvline(x=latest_year, color='gray', linestyle='--', alpha=0.5, linewidth=1)
            
            # Right column: Latest value distribution with trend indication
            ax = axes[idx, 1]
            
            data_latest = var_data.isel(years=-1).values
            data_previous = var_data.isel(years=-2).values if len(years) > 1 else data_latest
            
            # Violin plot
            parts = ax.violinplot([data_latest], positions=[1], widths=0.8,
                                  showmeans=True, showmedians=True)
            
            # Color the violin plot
            for pc in parts['bodies']:
                pc.set_facecolor(color)
                pc.set_alpha(0.7)
            
            # Scatter points with color indicating change from previous
            changes = data_latest - data_previous
            colors_scatter = ['green' if c > 0 else 'red' if c < 0 else 'gray' for c in changes]
            
            x_scatter = np.random.normal(1, 0.04, size=len(data_latest))
            ax.scatter(x_scatter, data_latest, c=colors_scatter, alpha=0.6, s=60,
                      edgecolors='black', linewidth=0.5)
            
            # Statistics
            mean_val = np.mean(data_latest)
            std_val = np.std(data_latest)
            
            ax.axhline(y=mean_val, color='darkred', linestyle='--', linewidth=2, alpha=0.7)
            
            # Add change indicator
            mean_change = np.mean(changes)
            change_symbol = '↑' if mean_change > 0 else '↓' if mean_change < 0 else '→'
            change_color = 'green' if mean_change > 0 else 'red' if mean_change < 0 else 'gray'
            
            stats_text = f'{change_symbol} Δ = {mean_change:.3f} {unit}\n'
            stats_text += f'μ = {mean_val:.2f} {unit}\n'
            stats_text += f'σ = {std_val:.2f} {unit}\n'
            stats_text += f'Range: [{np.min(data_latest):.2f}, {np.max(data_latest):.2f}]'
            
            ax.text(1.5, mean_val, stats_text, fontsize=9, verticalalignment='center',
                   bbox=dict(boxstyle='round', facecolor=change_color, alpha=0.3))
            
            ax.set_xlim([0.5, 2.5])
            ax.set_xticks([1])
            ax.set_xticklabels([f'{latest_year}'])
            ax.set_ylabel(f'{title} ({unit})')
            ax.set_title(f'{title} at {latest_year}', fontsize=12, fontweight='bold')
            ax.grid(True, alpha=0.3, axis='y')
        
        plt.suptitle(f'Climate Variables Evolution to Latest ({latest_year}) - {scenario.upper()}',
                     fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()
        
        return fig
    
    def plot_climate_heatmap_over_time(self, scenario='ssp585', figsize=(16, 10)):
        """Create heatmap showing ensemble spread over time"""
        
        if 'climate' not in self.datasets:
            print("Climate data not available")
            return
        
        ds = self.datasets['climate']
        
        years = ds['years'].values
        samples_dim = 'samples'
        
        fig, axes = plt.subplots(3, 1, figsize=figsize)
        
        variables = {
            'surface_temperature': ('Surface Temperature', '°C', 'RdYlBu_r'),
            'deep_ocean_temperature': ('Deep Ocean Temperature', '°C', 'RdYlBu_r'),
            'ocean_heat_content': ('Ocean Heat Content', 'YJ (10²⁴ J)', 'YlOrRd')
        }
        
        for idx, (var_name, (title, unit, cmap)) in enumerate(variables.items()):
            if var_name not in ds.data_vars:
                continue
            
            var_data = ds[var_name]
            
            # Scale ocean heat content
            if var_name == 'ocean_heat_content':
                var_data = var_data / 1e24
            
            ax = axes[idx]
            
            # Create 2D array: samples x years
            data_2d = var_data.values.T  # Transpose to get (samples, years)
            
            # Create heatmap
            im = ax.imshow(data_2d, aspect='auto', cmap=cmap, interpolation='nearest')
            
            # Set ticks
            n_ticks = min(10, len(years))
            tick_indices = np.linspace(0, len(years)-1, n_ticks, dtype=int)
            ax.set_xticks(tick_indices)
            ax.set_xticklabels([f'{years[i]:.0f}' for i in tick_indices])
            
            ax.set_yticks(range(data_2d.shape[0]))
            ax.set_yticklabels([f'M{i+1}' for i in range(data_2d.shape[0])], fontsize=8)
            
            ax.set_xlabel('Year')
            ax.set_ylabel('Ensemble Member')
            ax.set_title(f'{title} - Ensemble Heatmap', fontsize=12, fontweight='bold')
            
            # Colorbar
            cbar = plt.colorbar(im, ax=ax, orientation='vertical', pad=0.02)
            cbar.set_label(unit, rotation=270, labelpad=20)
            
            # Mark latest year
            ax.axvline(x=len(years)-1, color='white', linestyle='--', linewidth=2, alpha=0.8)
        
        plt.suptitle(f'Climate Variables Ensemble Heatmap - {scenario.upper()}',
                     fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()
        
        return fig
    
    def rolling_mean_numpy(self, data, window, axis=0):
        """Centered rolling mean along ``axis`` using plain NumPy (no numbagg).

        The input is edge-padded so the result has the same shape as ``data``.
        Output element ``i`` is the mean of input elements
        ``i - window // 2`` .. ``i + (window - 1 - window // 2)``, where indices
        outside the array are filled with the nearest edge value.

        Args:
            data: Array-like input.
            window: Number of elements averaged per output value; must be >= 1.
            axis: Axis to smooth along (negative values allowed).

        Returns:
            numpy.ndarray of the same shape as ``data`` (floating dtype).

        Raises:
            ValueError: If ``window`` is less than 1.
        """
        if window < 1:
            raise ValueError(f"window must be >= 1, got {window}")

        data = np.asarray(data)
        if data.shape[axis] == 0:
            # np.pad cannot edge-extend an empty axis.
            return data.astype(float)

        before = window // 2
        after = window - 1 - before
        pad_width = [(0, 0)] * data.ndim
        pad_width[axis] = (before, after)
        padded = np.pad(data, pad_width, mode='edge')

        # Padded length along `axis` is n + window - 1, giving exactly n windows;
        # the window elements land on a new trailing axis.
        windows = np.lib.stride_tricks.sliding_window_view(padded, window, axis=axis)
        return windows.mean(axis=-1)
    
    def plot_climate_variables(self, scenario='ssp585', figsize=(16, 12)):
        """Plot FAIR climate model outputs with ensemble uncertainty"""
        if 'climate' not in self.datasets:
            print("Climate data not available")
            return
        
        ds = self.datasets['climate']
        fig, axes = plt.subplots(3, 2, figsize=figsize)
        
        years = ds['years'].values
        samples_dim = 'samples'
        
        # 1. Surface Temperature - Individual trajectories
        ax = axes[0, 0]
        surf_temp = ds['surface_temperature']
        for i in range(surf_temp.sizes[samples_dim]):
            ax.plot(years, surf_temp.isel(samples=i), alpha=0.3, color='red', linewidth=0.5)
        
        mean_temp = surf_temp.mean(dim=samples_dim)
        ax.plot(years, mean_temp, color='darkred', linewidth=2.5, label='Ensemble Mean')
        ax.set_title('Surface Temperature - Individual Ensemble Members', fontsize=12, fontweight='bold')
        ax.set_xlabel('Year')
        ax.set_ylabel('Temperature Anomaly (°C)')
        ax.legend()
        ax.grid(True, alpha=0.3)
        
        # 2. Surface Temperature - Percentile bands
        ax = axes[0, 1]
        mean = surf_temp.mean(dim=samples_dim)
        p5 = surf_temp.quantile(0.05, dim=samples_dim)
        p95 = surf_temp.quantile(0.95, dim=samples_dim)
        p25 = surf_temp.quantile(0.25, dim=samples_dim)
        p75 = surf_temp.quantile(0.75, dim=samples_dim)
        
        ax.fill_between(years, p5, p95, alpha=0.2, color='red', label='5-95th percentile')
        ax.fill_between(years, p25, p75, alpha=0.4, color='red', label='25-75th percentile')
        ax.plot(years, mean, color='darkred', linewidth=2.5, label='Mean')
        ax.set_title('Surface Temperature - Uncertainty Ranges', fontsize=12, fontweight='bold')
        ax.set_xlabel('Year')
        ax.set_ylabel('Temperature Anomaly (°C)')
        ax.legend()
        ax.grid(True, alpha=0.3)
        
        # 3. Deep Ocean Temperature
        ax = axes[1, 0]
        deep_temp = ds['deep_ocean_temperature']
        mean_deep = deep_temp.mean(dim=samples_dim)
        p5_deep = deep_temp.quantile(0.05, dim=samples_dim)
        p95_deep = deep_temp.quantile(0.95, dim=samples_dim)
        
        ax.fill_between(years, p5_deep, p95_deep, alpha=0.3, color='blue')
        ax.plot(years, mean_deep, color='darkblue', linewidth=2.5)
        ax.set_title('Deep Ocean Temperature', fontsize=12, fontweight='bold')
        ax.set_xlabel('Year')
        ax.set_ylabel('Temperature Anomaly (°C)')
        ax.grid(True, alpha=0.3)
        
        # 4. Ocean Heat Content
        ax = axes[1, 1]
        ohc = ds['ocean_heat_content']
        mean_ohc = ohc.mean(dim=samples_dim) / 1e24
        p5_ohc = ohc.quantile(0.05, dim=samples_dim) / 1e24
        p95_ohc = ohc.quantile(0.95, dim=samples_dim) / 1e24
        
        ax.fill_between(years, p5_ohc, p95_ohc, alpha=0.3, color='green')
        ax.plot(years, mean_ohc, color='darkgreen', linewidth=2.5)
        ax.set_title('Ocean Heat Content', fontsize=12, fontweight='bold')
        ax.set_xlabel('Year')
        ax.set_ylabel('Heat Content (YJ, 10²⁴ J)')
        ax.grid(True, alpha=0.3)
        
        # 5. Temperature distribution at 2100
        ax = axes[2, 0]
        year_2100_idx = np.argmin(np.abs(years - 2100))
        temp_2100 = surf_temp.isel(years=year_2100_idx).values
        
        ax.hist(temp_2100, bins=15, alpha=0.7, color='red', edgecolor='black')
        ax.axvline(np.mean(temp_2100), color='darkred', linestyle='--', linewidth=2, 
                   label=f'Mean: {np.mean(temp_2100):.2f}°C')
        ax.axvline(np.median(temp_2100), color='red', linestyle=':', linewidth=2,
                   label=f'Median: {np.median(temp_2100):.2f}°C')
        ax.set_title(f'Surface Temperature Distribution at 2100', fontsize=12, fontweight='bold')
        ax.set_xlabel('Temperature Anomaly (°C)')
        ax.set_ylabel('Frequency')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')
        
        # 6. Warming rate - using manual calculation
        ax = axes[2, 1]
        window = 10
        
        surf_temp_values = surf_temp.values
        temp_diff = np.diff(surf_temp_values, axis=0)
        
        if temp_diff.shape[0] >= window:
            smoothed_diff = self.rolling_mean_numpy(temp_diff, window, axis=0)
            warming_rate_values = smoothed_diff * 10
            
            mean_rate = np.mean(warming_rate_values, axis=1)
            p5_rate = np.percentile(warming_rate_values, 5, axis=1)
            p95_rate = np.percentile(warming_rate_values, 95, axis=1)
            
            years_rate = years[1:len(mean_rate)+1]
            
            ax.fill_between(years_rate, p5_rate, p95_rate, alpha=0.3, color='orange')
            ax.plot(years_rate, mean_rate, color='darkorange', linewidth=2.5)
            ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
            ax.set_title(f'Decadal Warming Rate ({window}-year smoothed)', fontsize=12, fontweight='bold')
            ax.set_xlabel('Year')
            ax.set_ylabel('Warming Rate (°C/decade)')
            ax.grid(True, alpha=0.3)
        else:
            ax.text(0.5, 0.5, 'Insufficient data for warming rate', 
                   ha='center', va='center', transform=ax.transAxes)
        
        plt.suptitle(f'FAIR Climate Model Outputs - {scenario.upper()}', 
                     fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()
        
        return fig
    
    def plot_sea_level_components_comparison(self, year=2100, figsize=(16, 10)):
        """Compare sea level rise contributions from different components"""
        
        components = {
            'Land Water Storage': 'lws_gslr',
            'Sterodynamics': 'stereo_gslr',
            'Ice Sheets & Glaciers': 'ice_gslr'
        }
        
        available = {name: key for name, key in components.items() if key in self.datasets}
        
        if not available:
            print("No sea level rise data available")
            return
        
        n_components = len(available)
        fig, axes = plt.subplots(2, n_components + 1, figsize=figsize)
        
        all_time_series = []
        all_time_axes = []
        all_distributions = []
        colors = ['blue', 'green', 'orange', 'purple']
        
        for idx, (name, key) in enumerate(available.items()):
            ds = self.datasets[key]
            
            slr_vars = [v for v in ds.data_vars if 'slr' in v.lower() or 'sea' in v.lower()]
            if not slr_vars:
                print(f"No SLR variable found in {key}")
                continue
            
            slr_var = slr_vars[0]
            slr_data = ds[slr_var]
            
            time_dim = 'years' if 'years' in slr_data.dims else 'year' if 'year' in slr_data.dims else 'time'
            sample_dim = 'samples' if 'samples' in slr_data.dims else 'sample' if 'sample' in slr_data.dims else 'ensemble'
            
            if time_dim not in slr_data.dims:
                print(f"Could not find time dimension in {key}")
                continue
            
            years_data = ds[time_dim].values
            
            # Time series plot
            ax = axes[0, idx]
            if sample_dim in slr_data.dims:
                mean = slr_data.mean(dim=sample_dim).values.squeeze()
                p5 = slr_data.quantile(0.05, dim=sample_dim).values.squeeze()
                p95 = slr_data.quantile(0.95, dim=sample_dim).values.squeeze()
                
                # Ensure 1D
                if mean.ndim > 1:
                    mean = mean.flatten()
                if p5.ndim > 1:
                    p5 = p5.flatten()
                if p95.ndim > 1:
                    p95 = p95.flatten()
                
                ax.fill_between(years_data, p5, p95, alpha=0.3, color=colors[idx])
                ax.plot(years_data, mean, color=colors[idx], linewidth=2.5)
                
                all_time_series.append(mean)
                all_time_axes.append(years_data)
            else:
                slr_values = slr_data.values.squeeze()
                if slr_values.ndim > 1:
                    slr_values = slr_values.flatten()
                ax.plot(years_data, slr_values, color=colors[idx], linewidth=2.5)
                all_time_series.append(slr_values)
                all_time_axes.append(years_data)
            
            ax.set_title(name, fontsize=11, fontweight='bold')
            ax.set_xlabel('Year')
            ax.set_ylabel('SLR (mm)')
            ax.grid(True, alpha=0.3)
            
            # Distribution at target year
            ax = axes[1, idx]
            year_idx = np.argmin(np.abs(years_data - year))
            
            if sample_dim in slr_data.dims:
                data_at_year = slr_data.isel({time_dim: year_idx}).values.flatten()
                
                ax.hist(data_at_year, bins=20, alpha=0.7, color=colors[idx], edgecolor='black')
                ax.axvline(np.mean(data_at_year), color='darkred', linestyle='--', linewidth=2,
                          label=f'Mean: {np.mean(data_at_year):.1f} mm')
                ax.axvline(np.percentile(data_at_year, 5), color='gray', linestyle=':', linewidth=1.5,
                          label=f'5th: {np.percentile(data_at_year, 5):.1f} mm')
                ax.axvline(np.percentile(data_at_year, 95), color='gray', linestyle=':', linewidth=1.5,
                          label=f'95th: {np.percentile(data_at_year, 95):.1f} mm')
                
                all_distributions.append(data_at_year)
            else:
                data_at_year = slr_data.isel({time_dim: year_idx}).values.flatten()[0]
                ax.axvline(data_at_year, color=colors[idx], linewidth=3)
                all_distributions.append([data_at_year])
            
            ax.set_title(f'{name} at {year}', fontsize=11, fontweight='bold')
            ax.set_xlabel('SLR (mm)')
            ax.set_ylabel('Frequency')
            ax.legend(fontsize=8)
            ax.grid(True, alpha=0.3, axis='y')
        
        # Plot total - CHECK FOR CONSISTENT TIME AXES
        if len(all_time_series) > 1:
            # Check if all time axes are the same
            time_lengths = [len(t) for t in all_time_axes]
            
            if len(set(time_lengths)) == 1:
                # All same length, can sum directly
                ax = axes[0, -1]
                years_common = all_time_axes[0]
                total_ts = np.sum(all_time_series, axis=0)
                ax.plot(years_common, total_ts, color='red', linewidth=3)
                ax.fill_between(years_common, 0, total_ts, alpha=0.2, color='red')
                ax.set_title('Total SLR', fontsize=11, fontweight='bold')
                ax.set_xlabel('Year')
                ax.set_ylabel('Total SLR (mm)')
                ax.grid(True, alpha=0.3)
                
                # Distribution total
                ax = axes[1, -1]
                if all(len(d) > 1 for d in all_distributions):
                    total_dist = np.sum(all_distributions, axis=0)
                    ax.hist(total_dist, bins=25, alpha=0.7, color='red', edgecolor='black')
                    ax.axvline(np.mean(total_dist), color='darkred', linestyle='--', linewidth=2,
                              label=f'Mean: {np.mean(total_dist):.1f} mm')
                    ax.axvline(np.percentile(total_dist, 5), color='gray', linestyle=':', linewidth=1.5,
                              label=f'5th: {np.percentile(total_dist, 5):.1f} mm')
                    ax.axvline(np.percentile(total_dist, 95), color='gray', linestyle=':', linewidth=1.5,
                              label=f'95th: {np.percentile(total_dist, 95):.1f} mm')
                    ax.set_title(f'Total SLR at {year}', fontsize=11, fontweight='bold')
                    ax.set_xlabel('Total SLR (mm)')
                    ax.set_ylabel('Frequency')
                    ax.legend(fontsize=8)
                    ax.grid(True, alpha=0.3, axis='y')
            else:
                # Different lengths - need to interpolate to common time axis
                print(f"⚠ Time axes differ: {time_lengths}. Interpolating to common axis.")
                
                # Find the common time range
                min_year = max([t[0] for t in all_time_axes])
                max_year = min([t[-1] for t in all_time_axes])
                
                # Create common time axis
                n_points = min(time_lengths)
                years_common = np.linspace(min_year, max_year, n_points)
                
                # Interpolate all series to common axis
                interpolated_series = []
                for years_orig, data_orig in zip(all_time_axes, all_time_series):
                    data_interp = np.interp(years_common, years_orig, data_orig)
                    interpolated_series.append(data_interp)
                
                # Now sum
                ax = axes[0, -1]
                total_ts = np.sum(interpolated_series, axis=0)
                ax.plot(years_common, total_ts, color='red', linewidth=3)
                ax.fill_between(years_common, 0, total_ts, alpha=0.2, color='red')
                ax.set_title('Total SLR (interpolated)', fontsize=11, fontweight='bold')
                ax.set_xlabel('Year')
                ax.set_ylabel('Total SLR (mm)')
                ax.grid(True, alpha=0.3)
                
                # Distribution total
                ax = axes[1, -1]
                if all(len(d) > 1 for d in all_distributions):
                    total_dist = np.sum(all_distributions, axis=0)
                    ax.hist(total_dist, bins=25, alpha=0.7, color='red', edgecolor='black')
                    ax.axvline(np.mean(total_dist), color='darkred', linestyle='--', linewidth=2,
                              label=f'Mean: {np.mean(total_dist):.1f} mm')
                    ax.axvline(np.percentile(total_dist, 5), color='gray', linestyle=':', linewidth=1.5,
                              label=f'5th: {np.percentile(total_dist, 5):.1f} mm')
                    ax.axvline(np.percentile(total_dist, 95), color='gray', linestyle=':', linewidth=1.5,
                              label=f'95th: {np.percentile(total_dist, 95):.1f} mm')
                    ax.set_title(f'Total SLR at {year}', fontsize=11, fontweight='bold')
                    ax.set_xlabel('Total SLR (mm)')
                    ax.set_ylabel('Frequency')
                    ax.legend(fontsize=8)
                    ax.grid(True, alpha=0.3, axis='y')
        
        plt.suptitle(f'Sea Level Rise Components - {self.scenario.upper()}',
                     fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()
        
        return fig
    
    def plot_slr_stacked_contributions(self, figsize=(14, 8)):
        """Create stacked area plot showing relative contributions"""
        
        components = {
            'Land Water Storage': ('lws_gslr', 'blue'),
            'Sterodynamics': ('stereo_gslr', 'green'),
            'Ice Sheets & Glaciers': ('ice_gslr', 'orange')
        }
        
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize)
        
        data_dict = {}
        years_list = []
        
        for name, (key, color) in components.items():
            if key not in self.datasets:
                continue
            
            ds = self.datasets[key]
            slr_vars = [v for v in ds.data_vars if 'slr' in v.lower()]
            if not slr_vars:
                continue
            
            slr_var = slr_vars[0]
            slr_data = ds[slr_var]
            
            time_dim = 'years' if 'years' in slr_data.dims else 'year' if 'year' in slr_data.dims else 'time'
            sample_dim = 'samples' if 'samples' in slr_data.dims else 'sample' if 'sample' in slr_data.dims else None
            
            years_data = ds[time_dim].values
            years_list.append(years_data)
            
            if sample_dim:
                mean = slr_data.mean(dim=sample_dim).values.squeeze()
                if mean.ndim > 1:
                    mean = mean.flatten()
            else:
                mean = slr_data.values.squeeze()
                if mean.ndim > 1:
                    mean = mean.flatten()
            
            data_dict[name] = (years_data, mean, color)
        
        if not data_dict:
            print("Not enough data for stacked plot")
            return
        
        # Check if all time axes are the same
        time_lengths = [len(y) for y, _, _ in data_dict.values()]
        
        if len(set(time_lengths)) == 1:
            # All same length
            years_common = list(data_dict.values())[0][0]
            
            # Stacked area plot
            cumsum = np.zeros_like(years_common, dtype=float)
            for name, (_, data, color) in data_dict.items():
                ax1.fill_between(years_common, cumsum, cumsum + data, alpha=0.7, color=color, label=name)
                cumsum += data
            
            ax1.plot(years_common, cumsum, 'k-', linewidth=2, label='Total')
            ax1.set_title('Stacked Sea Level Rise Contributions', fontsize=13, fontweight='bold')
            ax1.set_xlabel('Year')
            ax1.set_ylabel('Cumulative SLR (mm)')
            ax1.legend(loc='upper left')
            ax1.grid(True, alpha=0.3)
            
            # Relative contributions
            total = cumsum
            # Avoid division by zero
            total_safe = np.where(total != 0, total, 1)
            for name, (_, data, color) in data_dict.items():
                percentage = (data / total_safe) * 100
                ax2.plot(years_common, percentage, linewidth=2.5, color=color, label=name)
            
            ax2.set_title('Relative Contributions to Total SLR', fontsize=13, fontweight='bold')
            ax2.set_xlabel('Year')
            ax2.set_ylabel('Contribution (%)')
            ax2.set_ylim([0, 100])
            ax2.legend(loc='best')
            ax2.grid(True, alpha=0.3)
        else:
            # Need to interpolate
            print(f"⚠ Time axes differ: {time_lengths}. Interpolating to common axis.")
            
            # Find common time range
            min_year = max([y[0] for y, _, _ in data_dict.values()])
            max_year = min([y[-1] for y, _, _ in data_dict.values()])
            n_points = min(time_lengths)
            years_common = np.linspace(min_year, max_year, n_points)
            
            # Interpolate all series
            interpolated_dict = {}
            for name, (years_orig, data_orig, color) in data_dict.items():
                data_interp = np.interp(years_common, years_orig, data_orig)
                interpolated_dict[name] = (data_interp, color)
            
            # Stacked area plot
            cumsum = np.zeros_like(years_common, dtype=float)
            for name, (data, color) in interpolated_dict.items():
                ax1.fill_between(years_common, cumsum, cumsum + data, alpha=0.7, color=color, label=name)
                cumsum += data
            
            ax1.plot(years_common, cumsum, 'k-', linewidth=2, label='Total')
            ax1.set_title('Stacked Sea Level Rise Contributions', fontsize=13, fontweight='bold')
            ax1.set_xlabel('Year')
            ax1.set_ylabel('Cumulative SLR (mm)')
            ax1.legend(loc='upper left')
            ax1.grid(True, alpha=0.3)
            
            # Relative contributions
            total = cumsum
            total_safe = np.where(total != 0, total, 1)
            for name, (data, color) in interpolated_dict.items():
                percentage = (data / total_safe) * 100
                ax2.plot(years_common, percentage, linewidth=2.5, color=color, label=name)
            
            ax2.set_title('Relative Contributions to Total SLR', fontsize=13, fontweight='bold')
            ax2.set_xlabel('Year')
            ax2.set_ylabel('Contribution (%)')
            ax2.set_ylim([0, 100])
            ax2.legend(loc='best')
            ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        return fig
    
    def create_summary_dashboard(self, year=2100, scenario='ssp585'):
        """Build a one-page summary dashboard of temperature and sea-level outputs.

        Layout (4 x 3 grid):

        * row 0: ``surface_temperature`` from the ``'climate'`` dataset, as the
          ensemble mean and 5-95 % band. Reference lines are drawn at the
          1.5 / 2 / 3 °C levels that the mean reaches.
        * row 1: mean and 5-95 % band of each SLR component (land water
          storage, sterodynamics, ice). Only components with a sample
          dimension are drawn and contribute to row 2.
        * row 2: sum of those component ensemble means over time, annotated
          at the time step nearest to ``year``.
        * row 3: per-component histogram at the time step nearest to ``year``.

        Panels whose dataset is not loaded are left out.

        Args:
            year: Target year for the annotation and the histograms.
            scenario: Scenario label, used in the figure title only.

        Returns:
            The created ``matplotlib.figure.Figure``.
        """
        fig = plt.figure(figsize=(20, 14))
        gs = fig.add_gridspec(4, 3, hspace=0.35, wspace=0.3)

        # 1. Temperature trajectory
        if 'climate' in self.datasets:
            ax = fig.add_subplot(gs[0, :])
            ds = self.datasets['climate']

            years = ds['years'].values
            surf_temp = ds['surface_temperature']

            mean = surf_temp.mean(dim='samples').values
            p5 = surf_temp.quantile(0.05, dim='samples').values
            p95 = surf_temp.quantile(0.95, dim='samples').values

            ax.fill_between(years, p5, p95, alpha=0.3, color='red', label='5-95% range')
            ax.plot(years, mean, 'r-', linewidth=3, label='Mean')

            # Reference lines only for warming levels the mean actually reaches.
            for temp_level in [1.5, 2.0, 3.0]:
                if np.any(mean >= temp_level):
                    ax.axhline(y=temp_level, color='gray', linestyle='--', alpha=0.5, linewidth=1)
                    ax.text(years[-1], temp_level, f'{temp_level}°C', va='center', ha='right', fontsize=9)

            ax.set_title('Global Surface Temperature Projection', fontsize=14, fontweight='bold')
            ax.set_xlabel('Year')
            ax.set_ylabel('Temperature Anomaly (°C)')
            ax.legend(loc='upper left')
            ax.grid(True, alpha=0.3)

        # 2-4. Individual SLR components
        components = [
            ('Land Water\nStorage', 'lws_gslr', gs[1, 0], 'blue'),
            ('Sterodynamics', 'stereo_gslr', gs[1, 1], 'green'),
            ('Ice Components', 'ice_gslr', gs[1, 2], 'orange')
        ]

        # (years, ensemble-mean series) of every component with a sample dimension
        all_means = []
        for name, key, grid_spec, color in components:
            if key not in self.datasets:
                continue
            ax = fig.add_subplot(grid_spec)
            ds = self.datasets[key]

            slr_vars = [v for v in ds.data_vars if 'slr' in v.lower()]
            if not slr_vars:
                continue

            slr_data = ds[slr_vars[0]]
            time_dim = 'years' if 'years' in slr_data.dims else 'year' if 'year' in slr_data.dims else 'time'
            sample_dim = 'samples' if 'samples' in slr_data.dims else 'sample' if 'sample' in slr_data.dims else None

            years_slr = ds[time_dim].values

            if sample_dim:
                mean = slr_data.mean(dim=sample_dim).values.squeeze()
                p5 = slr_data.quantile(0.05, dim=sample_dim).values.squeeze()
                p95 = slr_data.quantile(0.95, dim=sample_dim).values.squeeze()

                # Ensure 1D
                if mean.ndim > 1:
                    mean = mean.flatten()
                if p5.ndim > 1:
                    p5 = p5.flatten()
                if p95.ndim > 1:
                    p95 = p95.flatten()

                ax.fill_between(years_slr, p5, p95, alpha=0.3, color=color)
                ax.plot(years_slr, mean, color=color, linewidth=2.5)

                all_means.append((years_slr, mean))

            ax.set_title(name, fontsize=12, fontweight='bold')
            ax.set_xlabel('Year', fontsize=10)
            ax.set_ylabel('SLR (mm)', fontsize=10)
            ax.grid(True, alpha=0.3)

        # 5. Total SLR over time
        if all_means:
            ax = fig.add_subplot(gs[2, :])

            # Summing directly is only valid when every component uses the very
            # same years; equal lengths alone do not guarantee that.
            first_years = all_means[0][0]
            if all(np.array_equal(y, first_years) for y, _ in all_means):
                years_common = first_years
                total = np.sum([data for _, data in all_means], axis=0)
            else:
                # Interpolate onto an evenly spaced axis over the common years.
                min_year = max(y[0] for y, _ in all_means)
                max_year = min(y[-1] for y, _ in all_means)
                n_points = min(len(y) for y, _ in all_means)
                years_common = np.linspace(min_year, max_year, n_points)
                total = np.sum([np.interp(years_common, y, data) for y, data in all_means], axis=0)

            ax.plot(years_common, total, 'r-', linewidth=3)
            ax.fill_between(years_common, 0, total, alpha=0.2, color='red')

            # Annotate at the nearest available year (the one actually indexed),
            # so the marker sits on the curve and the label names the right year.
            year_idx = np.argmin(np.abs(years_common - year))
            marker_year = years_common[year_idx]
            total_at_year = total[year_idx]
            ax.plot(marker_year, total_at_year, 'ro', markersize=10, zorder=5)
            ax.text(marker_year, total_at_year, f'  {total_at_year:.0f} mm\n  in {marker_year:.0f}',
                    fontsize=11, fontweight='bold', va='center')

            ax.set_title('Total Sea Level Rise', fontsize=14, fontweight='bold')
            ax.set_xlabel('Year')
            ax.set_ylabel('Total SLR (mm)')
            ax.grid(True, alpha=0.3)

        # 6-8. Distributions at the time step nearest to the target year
        for col, (name, key, _, color) in enumerate(components):
            if key not in self.datasets:
                continue
            ax = fig.add_subplot(gs[3, col])
            ds = self.datasets[key]

            slr_vars = [v for v in ds.data_vars if 'slr' in v.lower()]
            if not slr_vars:
                continue

            slr_data = ds[slr_vars[0]]
            time_dim = 'years' if 'years' in slr_data.dims else 'year' if 'year' in slr_data.dims else 'time'
            sample_dim = 'samples' if 'samples' in slr_data.dims else 'sample' if 'sample' in slr_data.dims else None

            years_slr = ds[time_dim].values
            year_idx = np.argmin(np.abs(years_slr - year))

            if sample_dim:
                data_at_year = slr_data.isel({time_dim: year_idx}).values.flatten()

                ax.hist(data_at_year, bins=15, alpha=0.7, color=color, edgecolor='black')
                ax.axvline(np.mean(data_at_year), color='darkred', linestyle='--', linewidth=2)

                ax.set_title(f'{name.replace(chr(10), " ")} - {year}', fontsize=11, fontweight='bold')
                ax.set_xlabel('SLR (mm)', fontsize=10)
                ax.set_ylabel('Count', fontsize=10)
                ax.grid(True, alpha=0.3, axis='y')

        plt.suptitle(f'FACT Sea Level Framework - Summary Dashboard ({scenario.upper()})',
                     fontsize=18, fontweight='bold', y=0.998)

        return fig
    
    def export_statistics(self, year=2100, scenario='ssp585', output_file='slr_statistics.csv'):
        """Export detailed statistics to CSV"""
        
        results = []
        
        # Climate statistics
        if 'climate' in self.datasets:
            ds = self.datasets['climate']
            years = ds['years'].values
            year_idx = np.argmin(np.abs(years - year))
            
            for var in ['surface_temperature', 'deep_ocean_temperature', 'ocean_heat_content']:
                if var in ds.data_vars:
                    data = ds[var].isel(years=year_idx).values
                    
                    # For ocean_heat_content, scale down to avoid overflow in std calculation
                    if var == 'ocean_heat_content':
                        data_scaled = data / 1e24  # Convert to YJ
                        results.append({
                            'Component': 'FAIR',
                            'Variable': var,
                            'Unit': 'YJ (10²⁴ J)',
                            'Year': year,
                            'Mean': np.mean(data_scaled),
                            'Median': np.median(data_scaled),
                            'Std': np.std(data_scaled),
                            'P5': np.percentile(data_scaled, 5),
                            'P25': np.percentile(data_scaled, 25),
                            'P75': np.percentile(data_scaled, 75),
                            'P95': np.percentile(data_scaled, 95),
                            'Min': np.min(data_scaled),
                            'Max': np.max(data_scaled)
                        })
                    else:
                        results.append({
                            'Component': 'FAIR',
                            'Variable': var,
                            'Unit': '°C',
                            'Year': year,
                            'Mean': np.mean(data),
                            'Median': np.median(data),
                            'Std': np.std(data),
                            'P5': np.percentile(data, 5),
                            'P25': np.percentile(data, 25),
                            'P75': np.percentile(data, 75),
                            'P95': np.percentile(data, 95),
                            'Min': np.min(data),
                            'Max': np.max(data)
                        })
        
        # Sea level statistics
        slr_components = {
            'Land Water Storage': 'lws_gslr',
            'Sterodynamics': 'stereo_gslr',
            'Ice Components': 'ice_gslr'
        }
        
        for comp_name, key in slr_components.items():
            if key in self.datasets:
                ds = self.datasets[key]
                
                slr_vars = [v for v in ds.data_vars if 'slr' in v.lower()]
                for var in slr_vars:
                    slr_data = ds[var]
                    time_dim = 'years' if 'years' in slr_data.dims else 'year' if 'year' in slr_data.dims else 'time'
                    sample_dim = 'samples' if 'samples' in slr_data.dims else 'sample' if 'sample' in slr_data.dims else None
                    
                    years_slr = ds[time_dim].values
                    year_idx = np.argmin(np.abs(years_slr - year))
                    
                    if sample_dim:
                        data = slr_data.isel({time_dim: year_idx}).values
                        
                        results.append({
                            'Component': comp_name,
                            'Variable': var,
                            'Unit': 'mm',
                            'Year': year,
                            'Mean': np.mean(data),
                            'Median': np.median(data),
                            'Std': np.std(data),
                            'P5': np.percentile(data, 5),
                            'P25': np.percentile(data, 25),
                            'P75': np.percentile(data, 75),
                            'P95': np.percentile(data, 95),
                            'Min': np.min(data),
                            'Max': np.max(data)
                        })
        
        df = pd.DataFrame(results)
        df.to_csv(output_file, index=False)
        print(f"\nStatistics exported to {output_file}")
        print(f"\nSummary:")
        print(df[['Component', 'Variable', 'Unit', 'Mean', 'P5', 'P95']].to_string(index=False))
        
        return df

    def plot_individual_gslr_components(self, year=2100, figsize=(18, 12)):
        """Plot each individual GSLR component in detail with ensemble uncertainty"""
        
        gslr_components = {
            'Land Water Storage': 'lws_gslr',
            'Sterodynamics': 'stereo_gslr',
            'Ice Sheets & Glaciers': 'ice_gslr'
        }
        
        available = {name: key for name, key in gslr_components.items() if key in self.datasets}
        
        if not available:
            print("No GSLR data available")
            return
        
        n_components = len(available)
        fig, axes = plt.subplots(n_components, 3, figsize=figsize)
        
        if n_components == 1:
            axes = axes.reshape(1, -1)
        
        colors = ['blue', 'green', 'orange']
        
        plot_idx = 0
        for idx, (name, key) in enumerate(available.items()):
            ds = self.datasets[key]
            
            # Find GSLR variable - check all possible names
            gslr_vars = [v for v in ds.data_vars if 'gslr' in v.lower() or 'slr' in v.lower() or 'sea_level' in v.lower()]
            if not gslr_vars:
                print(f"⚠ No SLR variable found in {key}. Available variables: {list(ds.data_vars)}")
                # Mark plots as unavailable
                for col in range(3):
                    axes[idx, col].text(0.5, 0.5, f'No data available\nfor {name}',
                                       ha='center', va='center', transform=axes[idx, col].transAxes)
                    axes[idx, col].set_title(f'{name} - No Data', fontsize=12, fontweight='bold')
                continue
            
            gslr_var = gslr_vars[0]
            print(f"  Using variable '{gslr_var}' from {key}")
            gslr_data = ds[gslr_var]
            
            print(f"    Shape: {gslr_data.shape}, Dims: {gslr_data.dims}")
            
            # Determine dimensions
            time_dim = 'years' if 'years' in gslr_data.dims else 'year' if 'year' in gslr_data.dims else 'time'
            sample_dim = 'samples' if 'samples' in gslr_data.dims else 'sample' if 'sample' in gslr_data.dims else 'ensemble'
            
            if time_dim not in gslr_data.dims:
                print(f"⚠ Could not find time dimension in {key}. Available dims: {gslr_data.dims}")
                for col in range(3):
                    axes[idx, col].text(0.5, 0.5, f'No time dimension\nfor {name}',
                                       ha='center', va='center', transform=axes[idx, col].transAxes)
                    axes[idx, col].set_title(f'{name} - No Time Data', fontsize=12, fontweight='bold')
                continue
            
            years_data = ds[time_dim].values
            print(f"    Years: {years_data[0]} to {years_data[-1]} ({len(years_data)} points)")
            
            # Check if sample dimension exists
            has_samples = sample_dim in gslr_data.dims
            if has_samples:
                print(f"    Samples: {gslr_data.sizes[sample_dim]}")
            
            # Plot 1: Time series with uncertainty
            ax = axes[idx, 0]
            
            if has_samples:
                # Individual ensemble members (plot a few)
                n_plot = min(10, gslr_data.sizes[sample_dim])
                for i in range(n_plot):
                    sample_data = gslr_data.isel({sample_dim: i}).values.squeeze()
                    if sample_data.ndim > 1:
                        sample_data = sample_data.flatten()
                    ax.plot(years_data, sample_data, alpha=0.2, color=colors[idx], linewidth=0.5)
                
                # Statistics
                mean = gslr_data.mean(dim=sample_dim).values.squeeze()
                p5 = gslr_data.quantile(0.05, dim=sample_dim).values.squeeze()
                p95 = gslr_data.quantile(0.95, dim=sample_dim).values.squeeze()
                p25 = gslr_data.quantile(0.25, dim=sample_dim).values.squeeze()
                p75 = gslr_data.quantile(0.75, dim=sample_dim).values.squeeze()
                
                if mean.ndim > 1:
                    mean = mean.flatten()
                if p5.ndim > 1:
                    p5 = p5.flatten()
                if p95.ndim > 1:
                    p95 = p95.flatten()
                if p25.ndim > 1:
                    p25 = p25.flatten()
                if p75.ndim > 1:
                    p75 = p75.flatten()
                
                print(f"    Mean range: {np.min(mean):.2f} to {np.max(mean):.2f}")
                
                ax.fill_between(years_data, p5, p95, alpha=0.2, color=colors[idx], label='5-95th percentile')
                ax.fill_between(years_data, p25, p75, alpha=0.3, color=colors[idx], label='25-75th percentile')
                ax.plot(years_data, mean, color=colors[idx], linewidth=2.5, label='Mean')
            else:
                gslr_values = gslr_data.values.squeeze()
                if gslr_values.ndim > 1:
                    gslr_values = gslr_values.flatten()
                print(f"    Value range: {np.min(gslr_values):.2f} to {np.max(gslr_values):.2f}")
                ax.plot(years_data, gslr_values, color=colors[idx], linewidth=2.5)
            
            ax.set_title(f'{name} - Time Series', fontsize=12, fontweight='bold')
            ax.set_xlabel('Year')
            ax.set_ylabel('Global SLR (mm)')
            if has_samples:
                ax.legend(fontsize=8)
            ax.grid(True, alpha=0.3)
            
            # Plot 2: Distribution at target year
            ax = axes[idx, 1]
            
            try:
                year_idx = np.argmin(np.abs(years_data - year))
                actual_year = years_data[year_idx]
                print(f"    Target year {year}, using data from {actual_year}")
                
                if has_samples:
                    data_at_year = gslr_data.isel({time_dim: year_idx}).values.flatten()
                    
                    if len(data_at_year) > 0:
                        ax.hist(data_at_year, bins=20, alpha=0.7, color=colors[idx], edgecolor='black')
                        ax.axvline(np.mean(data_at_year), color='darkred', linestyle='--', linewidth=2,
                                  label=f'Mean: {np.mean(data_at_year):.1f} mm')
                        ax.axvline(np.median(data_at_year), color='red', linestyle=':', linewidth=2,
                                  label=f'Median: {np.median(data_at_year):.1f} mm')
                        ax.axvline(np.percentile(data_at_year, 5), color='gray', linestyle=':', linewidth=1.5,
                                  label=f'5th: {np.percentile(data_at_year, 5):.1f} mm')
                        ax.axvline(np.percentile(data_at_year, 95), color='gray', linestyle=':', linewidth=1.5,
                                  label=f'95th: {np.percentile(data_at_year, 95):.1f} mm')
                    else:
                        ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
                else:
                    data_at_year = gslr_data.isel({time_dim: year_idx}).values.flatten()
                    if len(data_at_year) > 0:
                        ax.axvline(data_at_year[0], color=colors[idx], linewidth=3)
                        ax.text(0.5, 0.5, f'Value: {data_at_year[0]:.1f} mm',
                               ha='center', va='center', transform=ax.transAxes)
                    else:
                        ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
                
                ax.set_title(f'{name} - Distribution at {actual_year}', fontsize=12, fontweight='bold')
                ax.set_xlabel('SLR (mm)')
                ax.set_ylabel('Frequency')
                if has_samples:
                    ax.legend(fontsize=8)
                ax.grid(True, alpha=0.3, axis='y')
                
            except Exception as e:
                print(f"    ⚠ Error plotting distribution: {e}")
                ax.text(0.5, 0.5, f'Error: {str(e)}', ha='center', va='center', transform=ax.transAxes)
            
            # Plot 3: Rate of change (decadal trend)
            ax = axes[idx, 2]
            
            try:
                if has_samples:
                    mean = gslr_data.mean(dim=sample_dim).values.squeeze()
                    if mean.ndim > 1:
                        mean = mean.flatten()
                else:
                    mean = gslr_data.values.squeeze()
                    if mean.ndim > 1:
                        mean = mean.flatten()
                
                # Calculate decadal rate
                if len(mean) > 1 and len(years_data) > 1:
                    # Calculate rate
                    year_diff = np.diff(years_data)
                    mean_diff = np.diff(mean)
                    rate = mean_diff / year_diff * 10  # mm per decade
                    years_rate = years_data[:-1] + year_diff / 2  # Use midpoint
                    
                    print(f"    Rate range: {np.min(rate):.2f} to {np.max(rate):.2f} mm/decade")
                    
                    # Smooth the rate if we have enough points
                    if len(rate) >= 10:
                        window = 10
                        # Use valid mode
                        rate_smoothed = np.convolve(rate, np.ones(window)/window, mode='valid')
                        # Calculate corresponding years
                        years_smoothed = years_rate[window//2:window//2+len(rate_smoothed)]
                        
                        ax.plot(years_smoothed, rate_smoothed, color=colors[idx], linewidth=2.5, label='Smoothed (10yr)')
                        ax.plot(years_rate, rate, color=colors[idx], linewidth=1, alpha=0.3, label='Annual')
                        ax.legend(fontsize=8)
                    else:
                        ax.plot(years_rate, rate, color=colors[idx], linewidth=2.5)
                    
                    ax.axhline(y=0, color='black', linestyle='--', linewidth=0.5)
                    ax.set_title(f'{name} - Rate of Change', fontsize=12, fontweight='bold')
                    ax.set_xlabel('Year')
                    ax.set_ylabel('SLR Rate (mm/decade)')
                    ax.grid(True, alpha=0.3)
                else:
                    ax.text(0.5, 0.5, 'Insufficient data\nfor rate calculation',
                           ha='center', va='center', transform=ax.transAxes)
                    ax.set_title(f'{name} - Rate of Change', fontsize=12, fontweight='bold')
                    
            except Exception as e:
                print(f"    ⚠ Error plotting rate: {e}")
                ax.text(0.5, 0.5, f'Error: {str(e)}', ha='center', va='center', transform=ax.transAxes)
                ax.set_title(f'{name} - Rate of Change', fontsize=12, fontweight='bold')
            
            plot_idx += 1
        
        plt.suptitle('Individual GSLR Component Analysis', fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()
        
        return fig
        
    def _plot_individual_gslr_components(self, year=2100, figsize=(18, 12)):
        """Plot each individual GSLR component in detail with ensemble uncertainty"""
        
        gslr_components = {
            'Land Water Storage': 'lws_gslr',
            'Sterodynamics': 'stereo_gslr',
            'Ice Sheets & Glaciers': 'ice_gslr'
        }
        
        available = {name: key for name, key in gslr_components.items() if key in self.datasets}
        
        if not available:
            print("No GSLR data available")
            return
        
        n_components = len(available)
        fig, axes = plt.subplots(n_components, 3, figsize=figsize)
        
        if n_components == 1:
            axes = axes.reshape(1, -1)
        
        colors = ['blue', 'green', 'orange']
        
        for idx, (name, key) in enumerate(available.items()):
            ds = self.datasets[key]
            
            # Find GSLR variable - check all possible names
            gslr_vars = [v for v in ds.data_vars if 'gslr' in v.lower() or 'slr' in v.lower() or 'sea_level' in v.lower()]
            if not gslr_vars:
                print(f"No SLR variable found in {key}. Available variables: {list(ds.data_vars)}")
                continue
            
            gslr_var = gslr_vars[0]
            print(f"  Using variable '{gslr_var}' from {key}")
            gslr_data = ds[gslr_var]
            
            # Determine dimensions
            time_dim = 'years' if 'years' in gslr_data.dims else 'year' if 'year' in gslr_data.dims else 'time'
            sample_dim = 'samples' if 'samples' in gslr_data.dims else 'sample' if 'sample' in gslr_data.dims else 'ensemble'
            
            if time_dim not in gslr_data.dims:
                print(f"Could not find time dimension in {key}. Available dims: {gslr_data.dims}")
                continue
            
            years_data = ds[time_dim].values
            
            # Plot 1: Time series with uncertainty
            ax = axes[idx, 0]
            
            if sample_dim in gslr_data.dims:
                # Individual ensemble members
                for i in range(min(10, gslr_data.sizes[sample_dim])):  # Plot max 10 members
                    sample_data = gslr_data.isel({sample_dim: i}).values.squeeze()
                    if sample_data.ndim > 1:
                        sample_data = sample_data.flatten()
                    ax.plot(years_data, sample_data, alpha=0.2, color=colors[idx], linewidth=0.5)
                
                # Statistics
                mean = gslr_data.mean(dim=sample_dim).values.squeeze()
                p5 = gslr_data.quantile(0.05, dim=sample_dim).values.squeeze()
                p95 = gslr_data.quantile(0.95, dim=sample_dim).values.squeeze()
                p25 = gslr_data.quantile(0.25, dim=sample_dim).values.squeeze()
                p75 = gslr_data.quantile(0.75, dim=sample_dim).values.squeeze()
                
                if mean.ndim > 1:
                    mean = mean.flatten()
                if p5.ndim > 1:
                    p5 = p5.flatten()
                if p95.ndim > 1:
                    p95 = p95.flatten()
                if p25.ndim > 1:
                    p25 = p25.flatten()
                if p75.ndim > 1:
                    p75 = p75.flatten()
                
                ax.fill_between(years_data, p5, p95, alpha=0.2, color=colors[idx], label='5-95th percentile')
                ax.fill_between(years_data, p25, p75, alpha=0.3, color=colors[idx], label='25-75th percentile')
                ax.plot(years_data, mean, color=colors[idx], linewidth=2.5, label='Mean')
            else:
                gslr_values = gslr_data.values.squeeze()
                if gslr_values.ndim > 1:
                    gslr_values = gslr_values.flatten()
                ax.plot(years_data, gslr_values, color=colors[idx], linewidth=2.5)
            
            ax.set_title(f'{name} - Time Series', fontsize=12, fontweight='bold')
            ax.set_xlabel('Year')
            ax.set_ylabel('Global SLR (mm)')
            ax.legend(fontsize=8)
            ax.grid(True, alpha=0.3)
            
            # Plot 2: Distribution at target year
            ax = axes[idx, 1]
            year_idx = np.argmin(np.abs(years_data - year))
            
            if sample_dim in gslr_data.dims:
                data_at_year = gslr_data.isel({time_dim: year_idx}).values.flatten()
                
                ax.hist(data_at_year, bins=20, alpha=0.7, color=colors[idx], edgecolor='black')
                ax.axvline(np.mean(data_at_year), color='darkred', linestyle='--', linewidth=2,
                          label=f'Mean: {np.mean(data_at_year):.1f} mm')
                ax.axvline(np.median(data_at_year), color='red', linestyle=':', linewidth=2,
                          label=f'Median: {np.median(data_at_year):.1f} mm')
                ax.axvline(np.percentile(data_at_year, 5), color='gray', linestyle=':', linewidth=1.5,
                          label=f'5th: {np.percentile(data_at_year, 5):.1f} mm')
                ax.axvline(np.percentile(data_at_year, 95), color='gray', linestyle=':', linewidth=1.5,
                          label=f'95th: {np.percentile(data_at_year, 95):.1f} mm')
            else:
                data_at_year = gslr_data.isel({time_dim: year_idx}).values.flatten()[0]
                ax.axvline(data_at_year, color=colors[idx], linewidth=3)
            
            ax.set_title(f'{name} - Distribution at {year}', fontsize=12, fontweight='bold')
            ax.set_xlabel('SLR (mm)')
            ax.set_ylabel('Frequency')
            ax.legend(fontsize=8)
            ax.grid(True, alpha=0.3, axis='y')
            
            # Plot 3: Rate of change (decadal trend)
            ax = axes[idx, 2]
            
            if sample_dim in gslr_data.dims:
                mean = gslr_data.mean(dim=sample_dim).values.squeeze()
                if mean.ndim > 1:
                    mean = mean.flatten()
                
            # Calculate decadal rate
            if len(mean) > 1 and len(years_data) > 1:
                rate = np.diff(mean) / np.diff(years_data) * 10  # mm per decade
                years_rate = years_data[:-1]
                
                # Smooth the rate
                if len(rate) >= 10:
                    window = 10
                    # Use valid mode which returns len(rate) - window + 1 points
                    rate_smoothed = np.convolve(rate, np.ones(window)/window, mode='valid')
                    # Adjust years to match: take the middle of the window
                    years_smoothed = years_rate[(window-1)//2:len(rate_smoothed)+(window-1)//2]
                    
                    # Ensure they match
                    if len(years_smoothed) != len(rate_smoothed):
                        years_smoothed = years_smoothed[:len(rate_smoothed)]
                    
                    ax.plot(years_smoothed, rate_smoothed, color=colors[idx], linewidth=2.5, label='Smoothed (10yr)')
                    ax.plot(years_rate, rate, color=colors[idx], linewidth=1, alpha=0.3, label='Annual')
                else:
                    ax.plot(years_rate, rate, color=colors[idx], linewidth=2.5)
                
                ax.axhline(y=0, color='black', linestyle='--', linewidth=0.5)
                ax.set_title(f'{name} - Rate of Change', fontsize=12, fontweight='bold')
                ax.set_xlabel('Year')
                ax.set_ylabel('SLR Rate (mm/decade)')
                if len(rate) >= 10:
                    ax.legend(fontsize=8)
                ax.grid(True, alpha=0.3)
            else:
                ax.text(0.5, 0.5, 'Insufficient data\nfor rate calculation',
                       ha='center', va='center', transform=ax.transAxes)

        plt.suptitle('Individual GSLR Component Analysis', fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()
        
        return fig
    
    def plot_gslr_differences(self, year=2100, figsize=(18, 10)):
        """Plot differences between GSLR components"""
        
        gslr_components = {
            'Land Water Storage': 'lws_gslr',
            'Sterodynamics': 'stereo_gslr',
            'Ice Sheets & Glaciers': 'ice_gslr'
        }
        
        available = {name: key for name, key in gslr_components.items() if key in self.datasets}
        
        if len(available) < 2:
            print("Need at least 2 GSLR components for difference plots")
            return
        
        # Validate that datasets have SLR variables
        valid_components = {}
        for name, key in available.items():
            ds = self.datasets[key]
            slr_vars = [v for v in ds.data_vars if 'gslr' in v.lower() or 'slr' in v.lower() or 'sea_level' in v.lower()]
            if slr_vars:
                valid_components[name] = key
            else:
                print(f"⚠ Skipping {name}: no SLR variable found. Available: {list(ds.data_vars)}")
        
        if len(valid_components) < 2:
            print("Need at least 2 valid GSLR components for difference plots")
            return
        
        # Get all pairwise combinations
        from itertools import combinations
        pairs = list(combinations(valid_components.items(), 2))
        
        n_pairs = len(pairs)
        fig, axes = plt.subplots(n_pairs, 3, figsize=(figsize[0], figsize[1] * n_pairs / 3))
        
        if n_pairs == 1:
            axes = axes.reshape(1, -1)
        
        for idx, ((name1, key1), (name2, key2)) in enumerate(pairs):
            ds1 = self.datasets[key1]
            ds2 = self.datasets[key2]
            
            # Get GSLR data with better error handling
            gslr_vars1 = [v for v in ds1.data_vars if 'gslr' in v.lower() or 'slr' in v.lower() or 'sea_level' in v.lower()]
            gslr_vars2 = [v for v in ds2.data_vars if 'gslr' in v.lower() or 'slr' in v.lower() or 'sea_level' in v.lower()]
            
            if not gslr_vars1 or not gslr_vars2:
                print(f"⚠ Skipping {name1} vs {name2}: missing variables")
                continue
            
            gslr_var1 = gslr_vars1[0]
            gslr_var2 = gslr_vars2[0]
            
            print(f"  Comparing {name1} ({gslr_var1}) vs {name2} ({gslr_var2})")
            
            gslr_data1 = ds1[gslr_var1]
            gslr_data2 = ds2[gslr_var2]
            
            # Determine dimensions
            time_dim1 = 'years' if 'years' in gslr_data1.dims else 'year' if 'year' in gslr_data1.dims else 'time'
            time_dim2 = 'years' if 'years' in gslr_data2.dims else 'year' if 'year' in gslr_data2.dims else 'time'
            sample_dim1 = 'samples' if 'samples' in gslr_data1.dims else 'sample' if 'sample' in gslr_data1.dims else 'ensemble'
            sample_dim2 = 'samples' if 'samples' in gslr_data2.dims else 'sample' if 'sample' in gslr_data2.dims else 'ensemble'
            
            years_data1 = ds1[time_dim1].values
            years_data2 = ds2[time_dim2].values
            
            # Get mean values
            if sample_dim1 in gslr_data1.dims:
                mean1 = gslr_data1.mean(dim=sample_dim1).values.squeeze()
                if mean1.ndim > 1:
                    mean1 = mean1.flatten()
            else:
                mean1 = gslr_data1.values.squeeze()
                if mean1.ndim > 1:
                    mean1 = mean1.flatten()
            
            if sample_dim2 in gslr_data2.dims:
                mean2 = gslr_data2.mean(dim=sample_dim2).values.squeeze()
                if mean2.ndim > 1:
                    mean2 = mean2.flatten()
            else:
                mean2 = gslr_data2.values.squeeze()
                if mean2.ndim > 1:
                    mean2 = mean2.flatten()
            
            # Check if need to interpolate
            if len(years_data1) != len(years_data2) or not np.allclose(years_data1, years_data2):
                print(f"  ⚠ Interpolating {name1} and {name2} to common time axis")
                min_year = max(years_data1[0], years_data2[0])
                max_year = min(years_data1[-1], years_data2[-1])
                n_points = min(len(years_data1), len(years_data2))
                years_common = np.linspace(min_year, max_year, n_points)
                
                mean1_interp = np.interp(years_common, years_data1, mean1)
                mean2_interp = np.interp(years_common, years_data2, mean2)
                
                years_data = years_common
                mean1 = mean1_interp
                mean2 = mean2_interp
            else:
                years_data = years_data1
            
            # Calculate difference
            diff = mean1 - mean2
            
            # Plot 1: Both time series
            ax = axes[idx, 0]
            ax.plot(years_data, mean1, linewidth=2.5, label=name1, color='blue')
            ax.plot(years_data, mean2, linewidth=2.5, label=name2, color='red')
            ax.set_title(f'{name1} vs {name2}', fontsize=12, fontweight='bold')
            ax.set_xlabel('Year')
            ax.set_ylabel('Global SLR (mm)')
            ax.legend()
            ax.grid(True, alpha=0.3)
            
            # Plot 2: Difference time series
            ax = axes[idx, 1]
            ax.plot(years_data, diff, linewidth=2.5, color='purple')
            ax.axhline(y=0, color='black', linestyle='--', linewidth=0.5)
            ax.fill_between(years_data, 0, diff, alpha=0.3, color='purple')
            ax.set_title(f'Difference: {name1} - {name2}', fontsize=12, fontweight='bold')
            ax.set_xlabel('Year')
            ax.set_ylabel('Difference (mm)')
            ax.grid(True, alpha=0.3)
            
            # Add statistics annotation
            mean_diff = np.mean(diff)
            std_diff = np.std(diff)
            max_diff = np.max(np.abs(diff))
            ax.text(0.05, 0.95, f'Mean: {mean_diff:.1f} mm\nStd: {std_diff:.1f} mm\nMax: {max_diff:.1f} mm',
                   transform=ax.transAxes, fontsize=9, verticalalignment='top',
                   bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
            
            # Plot 3: Relative difference (percentage)
            ax = axes[idx, 2]
            # Avoid division by zero
            mean2_safe = np.where(np.abs(mean2) > 1e-6, mean2, 1)
            rel_diff = (diff / mean2_safe) * 100
            
            # Remove infinities
            rel_diff = np.where(np.isinf(rel_diff), np.nan, rel_diff)
            
            ax.plot(years_data, rel_diff, linewidth=2.5, color='green')
            ax.axhline(y=0, color='black', linestyle='--', linewidth=0.5)
            ax.fill_between(years_data, 0, rel_diff, alpha=0.3, color='green')
            ax.set_title(f'Relative Difference (%)', fontsize=12, fontweight='bold')
            ax.set_xlabel('Year')
            ax.set_ylabel('Relative Difference (%)')
            ax.grid(True, alpha=0.3)
            
            # Add statistics
            mean_rel = np.nanmean(rel_diff)
            ax.text(0.05, 0.95, f'Mean: {mean_rel:.1f}%',
                   transform=ax.transAxes, fontsize=9, verticalalignment='top',
                   bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        
        plt.suptitle('GSLR Component Differences', fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()
        
        return fig
    
    def plot_gslr_correlation_matrix(self, year=2100, figsize=(12, 10)):
        """Create correlation matrix between GSLR components at specific year"""
        
        gslr_components = {
            'Land Water Storage': 'lws_gslr',
            'Sterodynamics': 'stereo_gslr',
            'Ice Sheets & Glaciers': 'ice_gslr'
        }
        
        available = {name: key for name, key in gslr_components.items() if key in self.datasets}
        
        if len(available) < 2:
            print("Need at least 2 GSLR components for correlation analysis")
            return
        
        # Collect ensemble data at target year
        ensemble_data = {}
        
        for name, key in available.items():
            ds = self.datasets[key]
            
            # Find SLR variable
            gslr_vars = [v for v in ds.data_vars if 'gslr' in v.lower() or 'slr' in v.lower() or 'sea_level' in v.lower()]
            if not gslr_vars:
                print(f"⚠ Skipping {name}: no SLR variable found")
                continue
            
            gslr_var = gslr_vars[0]
            gslr_data = ds[gslr_var]
            
            time_dim = 'years' if 'years' in gslr_data.dims else 'year' if 'year' in gslr_data.dims else 'time'
            sample_dim = 'samples' if 'samples' in gslr_data.dims else 'sample' if 'sample' in gslr_data.dims else 'ensemble'
            
            if time_dim not in gslr_data.dims:
                print(f"⚠ Skipping {name}: no time dimension found")
                continue
            
            years_data = ds[time_dim].values
            year_idx = np.argmin(np.abs(years_data - year))
            
            if sample_dim in gslr_data.dims:
                data_at_year = gslr_data.isel({time_dim: year_idx}).values.flatten()
                ensemble_data[name] = data_at_year
            else:
                print(f"⚠ Skipping {name}: no ensemble dimension for correlation")
        
        if len(ensemble_data) < 2:
            print("Need at least 2 components with ensemble data for correlation analysis")
            return
        
        # Create DataFrame
        df = pd.DataFrame(ensemble_data)
        
        # Calculate correlation
        corr_matrix = df.corr()
        
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
        
        # Plot 1: Correlation heatmap
        try:
            import seaborn as sns
            sns.heatmap(corr_matrix, annot=True, fmt='.3f', cmap='coolwarm', center=0,
                        square=True, linewidths=1, cbar_kws={"shrink": 0.8}, ax=ax1,
                        vmin=-1, vmax=1)
        except ImportError:
            # Fallback if seaborn not available
            im = ax1.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')
            ax1.set_xticks(range(len(corr_matrix.columns)))
            ax1.set_yticks(range(len(corr_matrix.columns)))
            ax1.set_xticklabels(corr_matrix.columns, rotation=45, ha='right')
            ax1.set_yticklabels(corr_matrix.columns)
            
            # Add text annotations
            for i in range(len(corr_matrix)):
                for j in range(len(corr_matrix)):
                    text = ax1.text(j, i, f'{corr_matrix.iloc[i, j]:.3f}',
                                   ha="center", va="center", color="black", fontsize=10)
            
            plt.colorbar(im, ax=ax1)
        
        ax1.set_title(f'Correlation Matrix at {year}', fontsize=14, fontweight='bold')
        
        # Plot 2: Scatter matrix (for first pair)
        if len(ensemble_data) >= 2:
            names = list(ensemble_data.keys())
            data1 = ensemble_data[names[0]]
            data2 = ensemble_data[names[1]]
            
            ax2.scatter(data1, data2, alpha=0.6, s=100, edgecolors='black')
            ax2.set_xlabel(f'{names[0]} (mm)', fontsize=11)
            ax2.set_ylabel(f'{names[1]} (mm)', fontsize=11)
            ax2.set_title(f'{names[0]} vs {names[1]} at {year}', fontsize=12, fontweight='bold')
            ax2.grid(True, alpha=0.3)
            
            # Add correlation coefficient
            corr_val = corr_matrix.loc[names[0], names[1]]
            ax2.text(0.05, 0.95, f'r = {corr_val:.3f}',
                    transform=ax2.transAxes, fontsize=11, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
            
            # Add best fit line
            z = np.polyfit(data1, data2, 1)
            p = np.poly1d(z)
            ax2.plot(data1, p(data1), "r--", alpha=0.8, linewidth=2, label=f'y={z[0]:.2f}x+{z[1]:.2f}')
            ax2.legend()
        
        plt.suptitle('GSLR Component Correlations', fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()
        
        return fig
        
    def plot_lslr_location_points(self, component='lws', year=2100, sample_idx=None, figsize=(16, 12)):
        """Plot LSLR for specific location points (when no spatial grid available)"""
        
        import cartopy.crs as ccrs
        import cartopy.feature as cfeature
        
        lslr_key = f'{component}_lslr'
        
        if lslr_key not in self.datasets:
            print(f"Dataset '{lslr_key}' not found")
            return
        
        ds = self.datasets[lslr_key]
        
        print(f"\nAnalyzing location-based {lslr_key}:")
        print(f"  Variables: {list(ds.data_vars)}")
        print(f"  Dimensions: {dict(ds.dims)}")
        
        # Check if we have location dimension
        loc_dim = None
        for name in ['location', 'locations', 'site', 'sites', 'point', 'points']:
            if name in ds.dims:
                loc_dim = name
                break
        
        if loc_dim is None:
            print("⚠ No location dimension found")
            return
        
        print(f"  Location dimension: {loc_dim}")
        n_locations = ds.dims[loc_dim]
        print(f"  Number of locations: {n_locations}")
        
        # Try to find lat/lon coordinates
        lat_coord = None
        lon_coord = None
        
        for name in ['lat', 'latitude']:
            if name in ds.coords or name in ds.data_vars:
                lat_coord = name
                break
        
        for name in ['lon', 'longitude']:
            if name in ds.coords or name in ds.data_vars:
                lon_coord = name
                break
        
        if lat_coord is None or lon_coord is None:
            print(f"⚠ No lat/lon coordinates found")
            print(f"  Available coords: {list(ds.coords)}")
            return
        
        lats = ds[lat_coord].values
        lons = ds[lon_coord].values
        
        print(f"  Coordinates found: {len(lats)} locations")
        print(f"  Lat range: [{np.min(lats):.2f}, {np.max(lats):.2f}]")
        print(f"  Lon range: [{np.min(lons):.2f}, {np.max(lons):.2f}]")
        
        # Get LSLR data
        lslr_vars = [v for v in ds.data_vars if 'lslr' in v.lower() or 'slr' in v.lower() or 'sea_level' in v.lower()]
        if not lslr_vars:
            print("No LSLR variable found")
            return
        
        lslr_var = lslr_vars[0]
        lslr_data = ds[lslr_var]
        
        # Get time and sample dimensions
        time_dim = 'years' if 'years' in lslr_data.dims else 'year' if 'year' in lslr_data.dims else 'time'
        sample_dim = 'samples' if 'samples' in lslr_data.dims else 'sample' if 'sample' in lslr_data.dims else None
        
        # Select data for target year
        if time_dim in lslr_data.dims:
            years = ds[time_dim].values
            year_idx = np.argmin(np.abs(years - year))
            actual_year = years[year_idx]
            print(f"  Using year: {actual_year}")
            lslr_at_year = lslr_data.isel({time_dim: year_idx})
        else:
            actual_year = year
            lslr_at_year = lslr_data
        
        # Handle samples
        if sample_dim and sample_dim in lslr_at_year.dims:
            if sample_idx is None:
                lslr_values = lslr_at_year.mean(dim=sample_dim).values
                title_suffix = "Mean across samples"
            else:
                lslr_values = lslr_at_year.isel({sample_dim: sample_idx}).values
                title_suffix = f"Sample {sample_idx}"
        else:
            lslr_values = lslr_at_year.values
            title_suffix = "Single value"
        
        # Flatten if needed
        if lslr_values.ndim > 1:
            lslr_values = lslr_values.flatten()
        
        print(f"  SLR range: [{np.min(lslr_values):.2f}, {np.max(lslr_values):.2f}] mm")
        
        # Create figure
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)
        
        # Main map
        ax_map = fig.add_subplot(gs[0, :], projection=ccrs.PlateCarree())
        
        # Add geographic features
        ax_map.coastlines(resolution='50m', linewidth=1)
        ax_map.add_feature(cfeature.BORDERS, linewidth=0.5, alpha=0.5)
        ax_map.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.3)
        ax_map.add_feature(cfeature.OCEAN, facecolor='lightblue', alpha=0.3)
        
        # Plot points colored by SLR value
        scatter = ax_map.scatter(lons, lats, c=lslr_values, s=100, 
                                cmap='RdYlBu_r', edgecolors='black', linewidth=0.5,
                                transform=ccrs.PlateCarree(), zorder=5)
        
        # Add colorbar
        cbar = plt.colorbar(scatter, ax=ax_map, orientation='horizontal', pad=0.05, shrink=0.8)
        cbar.set_label('Local Sea Level Rise (mm)', fontsize=12)
        
        # Label some key points
        for i in range(min(10, len(lats))):  # Label first 10 points
            ax_map.text(lons[i], lats[i]+1, f'{i+1}', transform=ccrs.PlateCarree(),
                       fontsize=8, ha='center', bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
        
        # Grid lines
        gl = ax_map.gridlines(draw_labels=True, linewidth=0.5, alpha=0.5, linestyle='--')
        gl.top_labels = False
        gl.right_labels = False
        
        # Set extent with padding
        extent = [np.min(lons)-5, np.max(lons)+5, np.min(lats)-5, np.max(lats)+5]
        ax_map.set_extent(extent, crs=ccrs.PlateCarree())
        
        component_name = component.upper()
        ax_map.set_title(f'{component_name} Local SLR at {actual_year} - {title_suffix}',
                         fontsize=14, fontweight='bold')
        
        # Histogram
        ax_hist = fig.add_subplot(gs[1, 0])
        ax_hist.hist(lslr_values, bins=30, alpha=0.7, color='steelblue', edgecolor='black')
        ax_hist.axvline(np.mean(lslr_values), color='red', linestyle='--', linewidth=2,
                        label=f'Mean: {np.mean(lslr_values):.1f} mm')
        ax_hist.set_xlabel('Local SLR (mm)')
        ax_hist.set_ylabel('Frequency')
        ax_hist.set_title('Distribution Across Locations', fontsize=12, fontweight='bold')
        ax_hist.legend()
        ax_hist.grid(True, alpha=0.3, axis='y')
        
        # Top locations table
        ax_table = fig.add_subplot(gs[1, 1])
        ax_table.axis('off')
        
        # Sort and get top/bottom locations
        sorted_idx = np.argsort(lslr_values)
        
        table_text = "Top 5 Highest SLR Locations:\n\n"
        for i in range(min(5, len(sorted_idx))):
            idx = sorted_idx[-(i+1)]
            table_text += f"{i+1}. Loc {idx+1}: {lslr_values[idx]:.1f} mm\n"
            table_text += f"   ({lats[idx]:.2f}°, {lons[idx]:.2f}°)\n"
        
        table_text += "\nBottom 5 Lowest SLR Locations:\n\n"
        for i in range(min(5, len(sorted_idx))):
            idx = sorted_idx[i]
            table_text += f"{i+1}. Loc {idx+1}: {lslr_values[idx]:.1f} mm\n"
            table_text += f"   ({lats[idx]:.2f}°, {lons[idx]:.2f}°)\n"
        
        ax_table.text(0.1, 0.9, table_text, transform=ax_table.transAxes,
                      fontsize=9, verticalalignment='top', family='monospace',
                      bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.5))
        
        plt.suptitle(f'Local Sea Level Rise - {component_name} ({n_locations} locations)',
                     fontsize=16, fontweight='bold')
        
        return fig

    def plot_lslr_spatial_map(self, component='lws', year=2100, sample_idx=None, figsize=(18, 12)):
        """Plot spatial map of Local SLR with sample locations highlighted"""
        
        import cartopy.crs as ccrs
        import cartopy.feature as cfeature
        
        # Get the LSLR dataset
        lslr_key = f'{component}_lslr'
        
        if lslr_key not in self.datasets:
            print(f"Dataset '{lslr_key}' not found. Available: {list(self.datasets.keys())}")
            return
        
        ds = self.datasets[lslr_key]
        
        print(f"\nAnalyzing {lslr_key}:")
        print(f"  Variables: {list(ds.data_vars)}")
        print(f"  Coordinates: {list(ds.coords)}")
        print(f"  Dimensions: {dict(ds.dims)}")
        
        # Find LSLR variable
        lslr_vars = [v for v in ds.data_vars if 'lslr' in v.lower() or 'slr' in v.lower() or 'sea_level' in v.lower()]
        if not lslr_vars:
            print(f"No LSLR variable found. Available variables: {list(ds.data_vars)}")
            return
        
        lslr_var = lslr_vars[0]
        print(f"  Using variable: {lslr_var}")
        lslr_data = ds[lslr_var]
        print(f"  Shape: {lslr_data.shape}")
        print(f"  Dims: {lslr_data.dims}")
        
        # Check for spatial dimensions
        lat_names = ['lat', 'latitude', 'y']
        lon_names = ['lon', 'longitude', 'x']
        
        lat_dim = None
        lon_dim = None
        
        for name in lat_names:
            if name in lslr_data.dims or name in ds.coords:
                lat_dim = name
                break
        
        for name in lon_names:
            if name in lslr_data.dims or name in ds.coords:
                lon_dim = name
                break
        
        if lat_dim is None or lon_dim is None:
            print(f"⚠ No spatial dimensions found (lat/lon)")
            print(f"  This dataset appears to be for specific locations only")
            
            # Try to read location file if referenced
            if 'location' in ds.dims or 'locations' in ds.dims:
                return self.plot_lslr_location_points(component, year, sample_idx, figsize)
            
            return
        
        print(f"  Spatial dims: lat={lat_dim}, lon={lon_dim}")
        
        # Get coordinates
        lats = ds[lat_dim].values
        lons = ds[lon_dim].values
        
        print(f"  Lat range: [{np.min(lats):.2f}, {np.max(lats):.2f}]")
        print(f"  Lon range: [{np.min(lons):.2f}, {np.max(lons):.2f}]")
        
        # Get time dimension
        time_dim = 'years' if 'years' in lslr_data.dims else 'year' if 'year' in lslr_data.dims else 'time'
        sample_dim = 'samples' if 'samples' in lslr_data.dims else 'sample' if 'sample' in lslr_data.dims else 'ensemble'
        
        # Find target year
        if time_dim in lslr_data.dims:
            years = ds[time_dim].values
            year_idx = np.argmin(np.abs(years - year))
            actual_year = years[year_idx]
            print(f"  Using year: {actual_year} (requested: {year})")
        else:
            year_idx = 0
            actual_year = year
        
        # Determine which sample to plot
        if sample_dim in lslr_data.dims:
            n_samples = lslr_data.sizes[sample_dim]
            print(f"  Number of samples: {n_samples}")
            
            if sample_idx is None:
                # Use mean across samples
                if time_dim in lslr_data.dims:
                    data_to_plot = lslr_data.isel({time_dim: year_idx}).mean(dim=sample_dim)
                else:
                    data_to_plot = lslr_data.mean(dim=sample_dim)
                title_suffix = "Mean"
            else:
                # Use specific sample
                if time_dim in lslr_data.dims:
                    data_to_plot = lslr_data.isel({time_dim: year_idx, sample_dim: sample_idx})
                else:
                    data_to_plot = lslr_data.isel({sample_dim: sample_idx})
                title_suffix = f"Sample {sample_idx}"
        else:
            if time_dim in lslr_data.dims:
                data_to_plot = lslr_data.isel({time_dim: year_idx})
            else:
                data_to_plot = lslr_data
            title_suffix = "Single Value"
        
        # Create figure with subplots
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3, height_ratios=[2, 1])
        
        # Main map
        ax_map = fig.add_subplot(gs[0, :], projection=ccrs.PlateCarree())
        
        # Plot the data
        data_values = data_to_plot.values
        
        # Handle different data shapes
        if data_values.ndim > 2:
            data_values = data_values.squeeze()
        
        print(f"  Plotting data shape: {data_values.shape}")
        print(f"  Data range: [{np.nanmin(data_values):.2f}, {np.nanmax(data_values):.2f}] mm")
        
        # Create mesh for plotting
        if len(lats.shape) == 1 and len(lons.shape) == 1:
            # 1D coordinates - create meshgrid
            lon_grid, lat_grid = np.meshgrid(lons, lats)
        else:
            # Already 2D
            lon_grid, lat_grid = lons, lats
        
        # Plot filled contours
        levels = np.linspace(np.nanmin(data_values), np.nanmax(data_values), 20)
        cf = ax_map.contourf(lon_grid, lat_grid, data_values, levels=levels,
                             cmap='RdYlBu_r', transform=ccrs.PlateCarree(),
                             extend='both')
        
        # Add contour lines
        cs = ax_map.contour(lon_grid, lat_grid, data_values, levels=10,
                            colors='black', linewidths=0.5, alpha=0.3,
                            transform=ccrs.PlateCarree())
        ax_map.clabel(cs, inline=True, fontsize=8, fmt='%1.0f')
        
        # Add geographic features
        ax_map.coastlines(resolution='50m', linewidth=1)
        ax_map.add_feature(cfeature.BORDERS, linewidth=0.5, alpha=0.5)
        ax_map.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.3, zorder=0)
        
        # Grid lines
        gl = ax_map.gridlines(draw_labels=True, linewidth=0.5, alpha=0.5, linestyle='--')
        gl.top_labels = False
        gl.right_labels = False
        
        # Colorbar
        cbar = plt.colorbar(cf, ax=ax_map, orientation='horizontal', pad=0.05, shrink=0.8)
        cbar.set_label('Local Sea Level Rise (mm)', fontsize=12)
        
        # Set extent
        extent = [np.min(lons)-5, np.max(lons)+5, np.min(lats)-5, np.max(lats)+5]
        ax_map.set_extent(extent, crs=ccrs.PlateCarree())
        
        component_name = component.upper()
        ax_map.set_title(f'{component_name} Local SLR at {actual_year} ({title_suffix})',
                         fontsize=14, fontweight='bold')
        
        # Histogram of values
        ax_hist = fig.add_subplot(gs[1, 0])
        valid_data = data_values[~np.isnan(data_values)].flatten()
        ax_hist.hist(valid_data, bins=50, alpha=0.7, color='steelblue', edgecolor='black')
        ax_hist.axvline(np.mean(valid_data), color='red', linestyle='--', linewidth=2,
                        label=f'Mean: {np.mean(valid_data):.1f} mm')
        ax_hist.axvline(np.median(valid_data), color='darkred', linestyle=':', linewidth=2,
                        label=f'Median: {np.median(valid_data):.1f} mm')
        ax_hist.set_xlabel('Local SLR (mm)')
        ax_hist.set_ylabel('Frequency')
        ax_hist.set_title('Distribution of SLR Values', fontsize=12, fontweight='bold')
        ax_hist.legend()
        ax_hist.grid(True, alpha=0.3, axis='y')
        
        # Statistics box
        ax_stats = fig.add_subplot(gs[1, 1])
        ax_stats.axis('off')
        
        stats_text = f"""
        Statistics for {actual_year}:
        
        Mean:     {np.mean(valid_data):.2f} mm
        Median:   {np.median(valid_data):.2f} mm
        Std Dev:  {np.std(valid_data):.2f} mm
        Min:      {np.min(valid_data):.2f} mm
        Max:      {np.max(valid_data):.2f} mm
        
        5th %ile: {np.percentile(valid_data, 5):.2f} mm
        95th %ile:{np.percentile(valid_data, 95):.2f} mm
        
        Grid points: {len(valid_data)}
        Lat range: [{np.min(lats):.1f}°, {np.max(lats):.1f}°]
        Lon range: [{np.min(lons):.1f}°, {np.max(lons):.1f}°]
        """
        
        ax_stats.text(0.1, 0.9, stats_text, transform=ax_stats.transAxes,
                      fontsize=10, verticalalignment='top', family='monospace',
                      bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.3))
        
        plt.suptitle(f'Local Sea Level Rise - {component_name}',
                     fontsize=16, fontweight='bold')
        
        return fig

    
    def plot_lslr_all_components_map(self, year=2100, figsize=(18, 14)):
        """Plot maps for all available LSLR components"""
        
        components = ['lws', 'stereo', 'ice']
        available = []
        
        for comp in components:
            key = f'{comp}_lslr'
            if key in self.datasets:
                available.append(comp)
        
        if not available:
            print("No LSLR components available")
            return
        
        n_comp = len(available)
        fig, axes = plt.subplots(n_comp, 1, figsize=(figsize[0], figsize[1]*n_comp/3),
                                subplot_kw={'projection': ccrs.PlateCarree()})
        
        if n_comp == 1:
            axes = [axes]
        
        for idx, comp in enumerate(available):
            print(f"\nPlotting {comp}...")
            # This is a simplified version - you'd need to adapt based on your data structure
            ax = axes[idx]
            ax.text(0.5, 0.5, f'{comp.upper()} LSLR - See individual plots',
                   ha='center', va='center', transform=ax.transAxes, fontsize=14)
            ax.set_title(f'{comp.upper()} Component', fontsize=14, fontweight='bold')
        
        plt.tight_layout()
        return fig
    def plot_gslr_map(self, component='lws', year=2100, sample_idx=None, figsize=(16, 12)):
        """Plot GSLR data - if it has spatial structure, otherwise show global value"""
        
        import cartopy.crs as ccrs
        import cartopy.feature as cfeature
        
        gslr_key = f'{component}_gslr'
        
        if gslr_key not in self.datasets:
            print(f"Dataset '{gslr_key}' not found. Available: {list(self.datasets.keys())}")
            return
        
        ds = self.datasets[gslr_key]
        
        print(f"\nAnalyzing {gslr_key}:")
        print(f"  Variables: {list(ds.data_vars)}")
        print(f"  Coordinates: {list(ds.coords)}")
        print(f"  Dimensions: {dict(ds.dims)}")
        
        # Find GSLR variable
        gslr_vars = [v for v in ds.data_vars if 'gslr' in v.lower() or 'slr' in v.lower() or 'sea_level' in v.lower()]
        if not gslr_vars:
            print(f"No GSLR variable found. Available variables: {list(ds.data_vars)}")
            return
        
        gslr_var = gslr_vars[0]
        print(f"  Using variable: {gslr_var}")
        gslr_data = ds[gslr_var]
        print(f"  Shape: {gslr_data.shape}")
        print(f"  Dims: {gslr_data.dims}")
        
        # Check for spatial dimensions
        lat_names = ['lat', 'latitude', 'y']
        lon_names = ['lon', 'longitude', 'x']
        
        lat_dim = None
        lon_dim = None
        
        for name in lat_names:
            if name in gslr_data.dims or name in ds.coords:
                lat_dim = name
                break
        
        for name in lon_names:
            if name in gslr_data.dims or name in ds.coords:
                lon_dim = name
                break
        
        # GSLR is typically global (no spatial variation)
        if lat_dim is None and lon_dim is None:
            print(f"  ✓ This is Global SLR (uniform value worldwide)")
            return self.plot_gslr_global_value(component, year, sample_idx, figsize)
        else:
            print(f"  ✓ This GSLR has spatial structure")
            print(f"    Spatial dims: lat={lat_dim}, lon={lon_dim}")
            # Continue with spatial plotting below
            
        # Get coordinates
        lats = ds[lat_dim].values
        lons = ds[lon_dim].values
        
        print(f"  Lat range: [{np.min(lats):.2f}, {np.max(lats):.2f}]")
        print(f"  Lon range: [{np.min(lons):.2f}, {np.max(lons):.2f}]")
        
        # Get time and sample dimensions
        time_dim = 'years' if 'years' in gslr_data.dims else 'year' if 'year' in gslr_data.dims else 'time'
        sample_dim = 'samples' if 'samples' in gslr_data.dims else 'sample' if 'sample' in gslr_data.dims else 'ensemble'
        
        # Find target year
        if time_dim in gslr_data.dims:
            years = ds[time_dim].values
            year_idx = np.argmin(np.abs(years - year))
            actual_year = years[year_idx]
            print(f"  Using year: {actual_year} (requested: {year})")
        else:
            year_idx = 0
            actual_year = year
        
        # Determine which sample to plot
        if sample_dim in gslr_data.dims:
            n_samples = gslr_data.sizes[sample_dim]
            print(f"  Number of samples: {n_samples}")
            
            if sample_idx is None:
                # Use mean across samples
                if time_dim in gslr_data.dims:
                    data_to_plot = gslr_data.isel({time_dim: year_idx}).mean(dim=sample_dim)
                else:
                    data_to_plot = gslr_data.mean(dim=sample_dim)
                title_suffix = "Mean"
            else:
                # Use specific sample
                if time_dim in gslr_data.dims:
                    data_to_plot = gslr_data.isel({time_dim: year_idx, sample_dim: sample_idx})
                else:
                    data_to_plot = gslr_data.isel({sample_dim: sample_idx})
                title_suffix = f"Sample {sample_idx}"
        else:
            if time_dim in gslr_data.dims:
                data_to_plot = gslr_data.isel({time_dim: year_idx})
            else:
                data_to_plot = gslr_data
            title_suffix = "Single Value"
        
        # Create figure
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3, height_ratios=[2, 1])
        
        # Main map
        ax_map = fig.add_subplot(gs[0, :], projection=ccrs.PlateCarree())
        
        # Plot the data
        data_values = data_to_plot.values
        
        if data_values.ndim > 2:
            data_values = data_values.squeeze()
        
        print(f"  Plotting data shape: {data_values.shape}")
        print(f"  Data range: [{np.nanmin(data_values):.2f}, {np.nanmax(data_values):.2f}] mm")
        
        # Create mesh for plotting
        if len(lats.shape) == 1 and len(lons.shape) == 1:
            lon_grid, lat_grid = np.meshgrid(lons, lats)
        else:
            lon_grid, lat_grid = lons, lats
        
        # Plot filled contours
        levels = np.linspace(np.nanmin(data_values), np.nanmax(data_values), 20)
        cf = ax_map.contourf(lon_grid, lat_grid, data_values, levels=levels,
                             cmap='RdYlBu_r', transform=ccrs.PlateCarree(),
                             extend='both')
        
        # Add contour lines
        cs = ax_map.contour(lon_grid, lat_grid, data_values, levels=10,
                            colors='black', linewidths=0.5, alpha=0.3,
                            transform=ccrs.PlateCarree())
        ax_map.clabel(cs, inline=True, fontsize=8, fmt='%1.0f')
        
        # Add geographic features
        ax_map.coastlines(resolution='50m', linewidth=1)
        ax_map.add_feature(cfeature.BORDERS, linewidth=0.5, alpha=0.5)
        ax_map.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.2, zorder=0)
        
        # Grid lines
        gl = ax_map.gridlines(draw_labels=True, linewidth=0.5, alpha=0.5, linestyle='--')
        gl.top_labels = False
        gl.right_labels = False
        
        # Colorbar
        cbar = plt.colorbar(cf, ax=ax_map, orientation='horizontal', pad=0.05, shrink=0.8)
        cbar.set_label('Global Sea Level Rise (mm)', fontsize=12)
        
        # Set extent
        ax_map.set_global()
        
        component_name = component.upper()
        ax_map.set_title(f'{component_name} Global SLR at {actual_year} ({title_suffix})',
                         fontsize=14, fontweight='bold')
        
        # Histogram
        ax_hist = fig.add_subplot(gs[1, 0])
        valid_data = data_values[~np.isnan(data_values)].flatten()
        ax_hist.hist(valid_data, bins=50, alpha=0.7, color='steelblue', edgecolor='black')
        ax_hist.axvline(np.mean(valid_data), color='red', linestyle='--', linewidth=2,
                        label=f'Mean: {np.mean(valid_data):.1f} mm')
        ax_hist.set_xlabel('Global SLR (mm)')
        ax_hist.set_ylabel('Frequency')
        ax_hist.set_title('Distribution', fontsize=12, fontweight='bold')
        ax_hist.legend()
        ax_hist.grid(True, alpha=0.3, axis='y')
        
        # Statistics
        ax_stats = fig.add_subplot(gs[1, 1])
        ax_stats.axis('off')
        
        stats_text = f"""
        Global SLR Statistics ({actual_year}):
        
        Mean:     {np.mean(valid_data):.2f} mm
        Std Dev:  {np.std(valid_data):.2f} mm
        Min:      {np.min(valid_data):.2f} mm
        Max:      {np.max(valid_data):.2f} mm
        
        5th %ile: {np.percentile(valid_data, 5):.2f} mm
        95th %ile:{np.percentile(valid_data, 95):.2f} mm
        
        Grid points: {len(valid_data)}
        """
        
        ax_stats.text(0.1, 0.9, stats_text, transform=ax_stats.transAxes,
                      fontsize=10, verticalalignment='top', family='monospace',
                      bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.3))
        
        plt.suptitle(f'Global Sea Level Rise - {component_name}',
                     fontsize=16, fontweight='bold')
        
        return fig

    def plot_gslr_global_value(self, component='lws', year=2100, sample_idx=None, figsize=(16, 10)):
        """Plot GSLR as a global uniform value (typical case)"""
        
        import cartopy.crs as ccrs
        import cartopy.feature as cfeature
        
        gslr_key = f'{component}_gslr'
        ds = self.datasets[gslr_key]
        
        # Find GSLR variable
        gslr_vars = [v for v in ds.data_vars if 'gslr' in v.lower() or 'slr' in v.lower() or 'sea_level' in v.lower()]
        gslr_var = gslr_vars[0]
        gslr_data = ds[gslr_var]
        
        # Get time and sample dimensions
        time_dim = 'years' if 'years' in gslr_data.dims else 'year' if 'year' in gslr_data.dims else 'time'
        sample_dim = 'samples' if 'samples' in gslr_data.dims else 'sample' if 'sample' in gslr_data.dims else 'ensemble'
        
        # Find target year
        if time_dim in gslr_data.dims:
            years = ds[time_dim].values
            year_idx = np.argmin(np.abs(years - year))
            actual_year = years[year_idx]
            gslr_at_year = gslr_data.isel({time_dim: year_idx})
        else:
            actual_year = year
            gslr_at_year = gslr_data
        
        # Handle samples
        if sample_dim in gslr_at_year.dims:
            n_samples = gslr_at_year.sizes[sample_dim]
            
            if sample_idx is None:
                gslr_value = gslr_at_year.mean(dim=sample_dim).values.item()
                print ("\n Raw mean gslr_value = ", gslr_value, "\n")
                gslr_ensemble = gslr_at_year.values.flatten()
                title_suffix = "Ensemble Mean"
            else:
                gslr_value = gslr_at_year.isel({sample_dim: sample_idx}).values.item()
                gslr_ensemble = gslr_at_year.values.flatten()
                title_suffix = f"Sample {sample_idx}"
        else:
            gslr_value = gslr_at_year.values.item()
            gslr_ensemble = None
            title_suffix = "Single Value"
        
        print(f"  GSLR value at {actual_year}: {gslr_value:.2f} mm")
        
        # Create figure
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(2, 2, hspace=0.35, wspace=0.3)
        
        # World map showing uniform value
        ax_map = fig.add_subplot(gs[0, :], projection=ccrs.Robinson())
        ax_map.set_global()
        
        # Create a uniform field for visualization
        lons = np.linspace(-180, 180, 360)
        lats = np.linspace(-90, 90, 180)
        lon_grid, lat_grid = np.meshgrid(lons, lats)
        uniform_field = np.ones_like(lon_grid) * gslr_value
        
        # Plot uniform value
        cf = ax_map.contourf(lon_grid, lat_grid, uniform_field,
                             levels=np.linspace(gslr_value*0.95, gslr_value*1.05, 10),
                             cmap='RdYlBu_r', transform=ccrs.PlateCarree())
        
        # Add geographic features
        ax_map.coastlines(resolution='110m', linewidth=1)
        ax_map.add_feature(cfeature.BORDERS, linewidth=0.5, alpha=0.5)
        ax_map.add_feature(cfeature.LAND, edgecolor='black', facecolor='none', linewidth=0.5)
        
        # Colorbar
        cbar = plt.colorbar(cf, ax=ax_map, orientation='horizontal', pad=0.05, shrink=0.6)
        cbar.set_label('Global Sea Level Rise (mm)', fontsize=12)
        
        # Add text annotation
        ax_map.text(0.5, 0.5, f'GSLR: {gslr_value:.1f} mm', transform=ax_map.transAxes,
                   fontsize=24, fontweight='bold', ha='center', va='center',
                   bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.8, edgecolor='black', linewidth=2))
        
        component_name = component.upper()
        ax_map.set_title(f'{component_name} Global SLR at {actual_year} ({title_suffix})\n(Uniform worldwide value)',
                         fontsize=14, fontweight='bold')
        
        # Time series
        ax_ts = fig.add_subplot(gs[1, 0])
        
        if time_dim in gslr_data.dims:
            years_all = ds[time_dim].values
            
            if sample_dim in gslr_data.dims:
                # Plot ensemble members
                for i in range(min(10, gslr_data.sizes[sample_dim])):
                    sample_data = gslr_data.isel({sample_dim: i}).values.squeeze()
                    ax_ts.plot(years_all, sample_data, alpha=0.3, color='blue', linewidth=0.5)
                
                mean = gslr_data.mean(dim=sample_dim).values.squeeze()
                ax_ts.plot(years_all, mean, color='darkblue', linewidth=2.5, label='Ensemble Mean')
            else:
                values = gslr_data.values.squeeze()
                ax_ts.plot(years_all, values, color='darkblue', linewidth=2.5)
            
            # Mark current year
            ax_ts.axvline(x=actual_year, color='red', linestyle='--', linewidth=2, alpha=0.7)
            ax_ts.plot(actual_year, gslr_value, 'r*', markersize=20, markeredgecolor='black', 
                      markeredgewidth=1.5, zorder=5)
            
            ax_ts.set_xlabel('Year')
            ax_ts.set_ylabel('Global SLR (mm)')
            ax_ts.set_title('Time Series Evolution', fontsize=12, fontweight='bold')
            ax_ts.legend()
            ax_ts.grid(True, alpha=0.3)
        else:
            ax_ts.text(0.5, 0.5, 'No time series available', ha='center', va='center',
                      transform=ax_ts.transAxes)
        
        # Ensemble distribution (if available)
        ax_dist = fig.add_subplot(gs[1, 1])
        
        if gslr_ensemble is not None and len(gslr_ensemble) > 1:
            ax_dist.hist(gslr_ensemble, bins=20, alpha=0.7, color='steelblue', edgecolor='black')
            ax_dist.axvline(gslr_value, color='red', linestyle='--', linewidth=2,
                           label=f'Mean: {gslr_value:.1f} mm')
            ax_dist.axvline(np.median(gslr_ensemble), color='darkred', linestyle=':', linewidth=2,
                           label=f'Median: {np.median(gslr_ensemble):.1f} mm')
            
            ax_dist.set_xlabel('Global SLR (mm)')
            ax_dist.set_ylabel('Frequency')
            ax_dist.set_title(f'Ensemble Distribution at {actual_year}', fontsize=12, fontweight='bold')
            ax_dist.legend()
            ax_dist.grid(True, alpha=0.3, axis='y')
            
            # Add statistics text
            stats_text = f"""
            Statistics:
            
            Mean:   {np.mean(gslr_ensemble):.2f} mm
            Median: {np.median(gslr_ensemble):.2f} mm
            Std:    {np.std(gslr_ensemble):.2f} mm
            Min:    {np.min(gslr_ensemble):.2f} mm
            Max:    {np.max(gslr_ensemble):.2f} mm
            
            5th:    {np.percentile(gslr_ensemble, 5):.2f} mm
            95th:   {np.percentile(gslr_ensemble, 95):.2f} mm
            
            N samples: {len(gslr_ensemble)}
            """
            
            ax_dist.text(1.05, 0.5, stats_text, transform=ax_dist.transAxes,
                        fontsize=9, verticalalignment='center', family='monospace',
                        bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.5))
        else:
            ax_dist.text(0.5, 0.5, f'Single value:\n{gslr_value:.2f} mm',
                        ha='center', va='center', transform=ax_dist.transAxes,
                        fontsize=16, fontweight='bold',
                        bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.5))
            ax_dist.set_title('Value at Target Year', fontsize=12, fontweight='bold')
        
        plt.suptitle(f'Global Sea Level Rise - {component_name}\n(Note: GSLR is globally uniform - same value everywhere)',
                     fontsize=16, fontweight='bold')
        
        return fig
    
    def plot_gslr_all_components(self, year=2100, figsize=(18, 12)):
        """Plot GSLR for all available components on one figure"""
        
        import cartopy.crs as ccrs
        import cartopy.feature as cfeature
        
        components = ['lws', 'stereo', 'ice']
        available = []
        gslr_values = {}
        
        for comp in components:
            key = f'{comp}_gslr'
            if key in self.datasets:
                ds = self.datasets[key]
                gslr_vars = [v for v in ds.data_vars if 'gslr' in v.lower() or 'slr' in v.lower() or 'sea_level' in v.lower()]
                if gslr_vars:
                    available.append(comp)
                    
                    # Get value at target year
                    gslr_var = gslr_vars[0]
                    gslr_data = ds[gslr_var]
                    
                    time_dim = 'years' if 'years' in gslr_data.dims else 'year' if 'year' in gslr_data.dims else 'time'
                    sample_dim = 'samples' if 'samples' in gslr_data.dims else 'sample' if 'sample' in gslr_data.dims else None
                    
                    if time_dim in gslr_data.dims:
                        years = ds[time_dim].values
                        year_idx = np.argmin(np.abs(years - year))
                        gslr_at_year = gslr_data.isel({time_dim: year_idx})
                    else:
                        gslr_at_year = gslr_data
                    
                    if sample_dim and sample_dim in gslr_at_year.dims:
                        value = gslr_at_year.mean(dim=sample_dim).values.item()
                        ensemble = gslr_at_year.values.flatten()
                    else:
                        value = gslr_at_year.values.item()
                        ensemble = None
                    
                    gslr_values[comp] = (value, ensemble)
        
        if not available:
            print("No GSLR data available")
            return
        
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)
        
        # World map with all components
        ax_map = fig.add_subplot(gs[0, :], projection=ccrs.Robinson())
        ax_map.set_global()
        ax_map.coastlines(resolution='110m', linewidth=1)
        ax_map.add_feature(cfeature.BORDERS, linewidth=0.5, alpha=0.5)
        ax_map.add_feature(cfeature.LAND, facecolor='lightgray', alpha=0.2)
        
        # Display values
        colors = ['blue', 'green', 'orange']
        y_pos = 0.7
        
        for idx, comp in enumerate(available):
            value, _ = gslr_values[comp]
            comp_name = comp.upper()
            ax_map.text(0.5, y_pos, f'{comp_name}: {value:.1f} mm',
                       transform=ax_map.transAxes, fontsize=16, fontweight='bold',
                       ha='center', color=colors[idx],
                       bbox=dict(boxstyle='round', facecolor='white', alpha=0.8, edgecolor=colors[idx], linewidth=2))
            y_pos -= 0.15
        
        # Total
        total_value = sum([v[0] for v in gslr_values.values()])
        ax_map.text(0.5, y_pos, f'TOTAL: {total_value:.1f} mm',
                   transform=ax_map.transAxes, fontsize=18, fontweight='bold',
                   ha='center', color='red',
                   bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.9, edgecolor='red', linewidth=3))
        
        ax_map.set_title(f'Global Sea Level Rise Components at {year}',
                         fontsize=16, fontweight='bold')
        
        # Bar chart comparison
        ax_bar = fig.add_subplot(gs[1, 0])
        
        comp_names = [c.upper() for c in available] + ['TOTAL']
        values = [gslr_values[c][0] for c in available] + [total_value]
        colors_bar = colors[:len(available)] + ['red']
        
        bars = ax_bar.bar(comp_names, values, color=colors_bar, alpha=0.7, edgecolor='black', linewidth=1.5)
        
        # Add value labels on bars
        for bar, val in zip(bars, values):
            height = bar.get_height()
            ax_bar.text(bar.get_x() + bar.get_width()/2., height,
                       f'{val:.1f} mm', ha='center', va='bottom', fontweight='bold', fontsize=11)
        
        ax_bar.set_ylabel('Sea Level Rise (mm)', fontsize=12)
        ax_bar.set_title('Component Comparison', fontsize=12, fontweight='bold')
        ax_bar.grid(True, alpha=0.3, axis='y')
        
        # Pie chart of contributions
        ax_pie = fig.add_subplot(gs[1, 1])
        
        contrib_values = [gslr_values[c][0] for c in available]
        contrib_labels = [f'{c.upper()}\n{v:.1f} mm\n({v/total_value*100:.1f}%)' 
                         for c, v in zip(available, contrib_values)]
        
        ax_pie.pie(contrib_values, labels=contrib_labels, colors=colors[:len(available)],
                  autopct='', startangle=90, textprops={'fontsize': 10, 'fontweight': 'bold'})
        ax_pie.set_title(f'Relative Contributions\n(Total: {total_value:.1f} mm)', 
                        fontsize=12, fontweight='bold')
        
        plt.suptitle(f'Global Sea Level Rise - All Components',
                     fontsize=18, fontweight='bold')
        
        return fig    
# ==================== USAGE ====================
# Loads every workflow output, prints diagnostics and saves the GSLR / LSLR
# figures as PNG files. Each figure is saved through the returned Figure object
# rather than plt.savefig, so the file never depends on which figure happens to
# be "current".

# Initialize visualizer
viz = FACTSeaLevelVisualizer('./data/output')

# Load all datasets
datasets = viz.load_outputs()

if datasets:
    viz.diagnose_datasets()
    print("\n" + "="*60)
    print("Generating visualizations...")
    print("="*60)
    
    # 12. Plot GSLR maps
    print("\n12. Plotting GSLR maps...")
    
    # Individual components
    for component in ['lws', 'stereo', 'ice']:
        key = f'{component}_gslr'
        if key in viz.datasets:
            print(f"\n  Plotting {component} GSLR...")
            fig12 = viz.plot_gslr_map(component=component, year=2100, sample_idx=None)
            if fig12 is not None:
                fig12.savefig(f'gslr_map_{component}.png', dpi=300, bbox_inches='tight')
                print(f"   ✓ Saved: gslr_map_{component}.png")
    
    # All components together
    print("\n  Plotting all GSLR components together...")
    fig12_all = viz.plot_gslr_all_components(year=2100)
    if fig12_all is not None:
        fig12_all.savefig('gslr_all_components.png', dpi=300, bbox_inches='tight')
        print("   ✓ Saved: gslr_all_components.png")
    
    # 11. Plot LSLR spatial maps
    print("\n11. Plotting LSLR spatial maps...")
    
    # Try each component (missing components yield no figure)
    for component in ['lws', 'stereo', 'ice']:
        print(f"\n  Plotting {component} LSLR...")
        fig11 = viz.plot_lslr_spatial_map(component=component, year=2100, sample_idx=None)
        if fig11 is not None:
            fig11.savefig(f'lslr_map_{component}.png', dpi=300, bbox_inches='tight')
            print(f"   ✓ Saved: lslr_map_{component}.png")
    # exit()
    
    # # 1. Climate variables
    # print("\n1. Plotting climate variables...")
    # fig1 = viz.plot_climate_variables(scenario='ssp585')
    # if fig1:
    #     plt.savefig('climate_outputs.png', dpi=300, bbox_inches='tight')
    #     print("   ✓ Saved: climate_outputs.png")
    
    # # 1b. Latest climate snapshot
    # print("\n1b. Plotting climate latest snapshot...")
    # fig1b = viz.plot_climate_latest_snapshot(scenario='ssp585')
    # if fig1b:
    #     plt.savefig('climate_latest_snapshot.png', dpi=300, bbox_inches='tight')
    #     print("   ✓ Saved: climate_latest_snapshot.png")
    
    # # 1c. Climate evolution to latest
    # print("\n1c. Plotting climate evolution to latest...")
    # fig1c = viz.plot_climate_evolution_to_latest(scenario='ssp585')
    # if fig1c:
    #     plt.savefig('climate_evolution_to_latest.png', dpi=300, bbox_inches='tight')
    #     print("   ✓ Saved: climate_evolution_to_latest.png")
    
    # # 1d. Climate heatmap
    # print("\n1d. Plotting climate heatmap...")
    # fig1d = viz.plot_climate_heatmap_over_time(scenario='ssp585')
    # if fig1d:
    #     plt.savefig('climate_heatmap.png', dpi=300, bbox_inches='tight')
    #     print("   ✓ Saved: climate_heatmap.png")

    # # 2. Sea level components comparison
    # print("\n2. Plotting sea level components...")
    # fig2 = viz.plot_sea_level_components_comparison(year=2100)
    # if fig2:
    #     plt.savefig('slr_components_2100.png', dpi=300, bbox_inches='tight')
    #     print("   ✓ Saved: slr_components_2100.png")
    
    # # 3. Stacked contributions
    # print("\n3. Plotting stacked contributions...")
    # fig3 = viz.plot_slr_stacked_contributions()
    # if fig3:
    #     plt.savefig('slr_stacked.png', dpi=300, bbox_inches='tight')
    #     print("   ✓ Saved: slr_stacked.png")
    
    # # 4. Summary dashboard
    # print("\n4. Creating summary dashboard...")
    # fig4 = viz.create_summary_dashboard(year=2100, scenario='ssp585')
    # if fig4:
    #     plt.savefig('summary_dashboard.png', dpi=300, bbox_inches='tight')
    #     print("   ✓ Saved: summary_dashboard.png")
    
    # # 5. Individual GSLR components
    # print("\n5. Plotting individual GSLR components...")
    # fig5 = viz.plot_individual_gslr_components(year=2100)
    # if fig5:
    #     plt.savefig('individual_gslr_components.png', dpi=300, bbox_inches='tight')
    #     print("   ✓ Saved: individual_gslr_components.png")
    
    # # 6. GSLR differences
    # print("\n6. Plotting GSLR differences...")
    # fig6 = viz.plot_gslr_differences(year=2100)
    # if fig6:
    #     plt.savefig('gslr_differences.png', dpi=300, bbox_inches='tight')
    #     print("   ✓ Saved: gslr_differences.png")
    
    # # 7. GSLR correlation matrix
    # print("\n7. Plotting GSLR correlation matrix...")
    # fig7 = viz.plot_gslr_correlation_matrix(year=2100)
    # if fig7:
    #     plt.savefig('gslr_correlation.png', dpi=300, bbox_inches='tight')
    #     print("   ✓ Saved: gslr_correlation.png")
    
    # # 8. Export statistics
    # print("\n8. Exporting statistics...")
    # stats_df = viz.export_statistics(year=2100, scenario='ssp585')
    
    print("\n" + "="*60)
    print("All visualizations complete!")
    print("="*60)
    
    plt.show()
else:
    print("\n⚠ No datasets loaded. Please run the workflow first to generate outputs.")

Loading output files...
  ✓ Loaded climate: climate.nc (group: ssp585)
  ✓ Loaded gsat: gsat.nc
  ✓ Loaded oceantemp: oceantemp.nc
  ✓ Loaded ohc: ohc.nc
  ✓ Loaded lws_gslr: gslr.nc
  ✓ Loaded lws_lslr: lslr.nc
  ✓ Loaded stereo_gslr: gslr.nc
  ✓ Loaded stereo_lslr: lslr.nc

DATASET DIAGNOSTICS

climate:
  Variables: ['surface_temperature', 'deep_ocean_temperature', 'ocean_heat_content']
  Coordinates: ['years', 'samples']
  Dimensions: {'years': 751, 'samples': 20}

  surface_temperature:
    Shape: (751, 20)
    Dims: ('years', 'samples')
    Dtype: float32
    NaN values: 0/15020 (0.0%)
    Range: [-0.576, 14.080]
    Mean: 4.010

  deep_ocean_temperature:
    Shape: (751, 20)
    Dims: ('years', 'samples')
    Dtype: float32
    NaN values: 0/15020 (0.0%)
    Range: [-0.040, 12.841]
    Mean: 2.617

  ocean_heat_content:
    Shape: (751, 20)
    Dims: ('years', 'samples')
    Dtype: float32
    NaN values: 0/15020 (0.0%)
    Range: [-79375934425205737259008.000, 217418536857941598

  print(f"  Dimensions: {dict(ds.dims)}")
  print(f"  Dimensions: {dict(ds.dims)}")


   ✓ Saved: gslr_map_lws.png

  Plotting stereo GSLR...

Analyzing stereo_gslr:
  Variables: ['lat', 'lon', 'sea_level_change']
  Coordinates: ['years', 'samples', 'locations']
  Dimensions: {'years': 29, 'samples': 20, 'locations': 1}
  Using variable: sea_level_change
  Shape: (20, 29, 1)
  Dims: ('samples', 'years', 'locations')
  ✓ This is Global SLR (uniform value worldwide)

 Raw mean gslr_value =  286.08575439453125 

  GSLR value at 2100: 286.09 mm


  print(f"  Dimensions: {dict(ds.dims)}")


In [None]:
##############

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Disable numbagg to avoid NumPy compatibility issues
xr.set_options(use_bottleneck=False, use_numbagg=False)

def load_sea_level_data(filepath, var_names=None):
    """Load sea level data from a NetCDF file, handling different structures.

    Parameters
    ----------
    filepath : str or pathlib.Path
        NetCDF file opened with ``xr.open_dataset``.
    var_names : sequence of str, optional
        Candidate names for the sea-level variable, tried in order. Defaults to
        ``('sea_level_change', 'sealevel_change', 'slr', 'sea_level')``.

    Returns
    -------
    data : xarray.DataArray
        The selected sea-level variable.
    years : numpy.ndarray or None
        Values of the ``years`` variable (or ``year`` if ``years`` is absent),
        or ``None`` when neither exists.
    ds : xarray.Dataset
        The opened dataset. It is returned open because ``data`` is backed by
        it; the caller is responsible for closing it.

    Raises
    ------
    ValueError
        If no candidate variable exists and the file holds no data variable
        other than ``lat``/``lon``. The dataset is closed before raising.
    """
    ds = xr.open_dataset(filepath)

    # Different component outputs use different names for the same quantity
    candidates = var_names if var_names is not None else (
        'sea_level_change', 'sealevel_change', 'slr', 'sea_level')
    var_name = next((name for name in candidates if name in ds.data_vars), None)

    if var_name is None:
        # Fall back to the first data variable that is not a lat/lon field
        # (lat/lon can be stored as data variables rather than coordinates).
        fallback = [v for v in ds.data_vars if v not in ('lat', 'lon')]
        if not fallback:
            ds.close()  # don't leak the open file handle on the error path
            raise ValueError(f"Could not find sea level variable in {filepath}")
        var_name = fallback[0]

    data = ds[var_name]

    # Accept both plural and singular year coordinate names
    year_name = next((name for name in ('years', 'year') if name in ds), None)
    years = ds[year_name].values if year_name is not None else None

    return data, years, ds

def align_to_common_years(data_dict, total_years):
    """
    Subset each component to the years it shares with the total file.

    No interpolation is performed. A component is passed through unchanged
    (with a printed note) when its years already equal *total_years*, when it
    has no year values, when it shares no years with the total, or when its
    data has no dimension whose name contains 'year'.

    Parameters
    -----------
    data_dict : dict
        Maps component name -> {'data': DataArray, 'years': array, 'ds': Dataset}.
    total_years : array
        Years from the total file.

    Returns
    --------
    aligned_dict : dict
        Same keys as *data_dict*. Aligned entries have 'data' subset along the
        year dimension and 'years' set to the sorted common years.
    """
    aligned = {}

    for comp_name, comp_info in data_dict.items():
        comp_data = comp_info['data']
        comp_years = comp_info['years']

        # load helpers return years=None when a file has no 'years' entry;
        # len(None) would otherwise crash below.
        if comp_years is None:
            print(f"  ⚠ {comp_name} has no 'years' values; leaving it unaligned")
            aligned[comp_name] = comp_info
            continue

        print(f"  Aligning {comp_name}: {len(comp_years)} years -> {len(total_years)} years")
        print(f"    Component years: {comp_years[0]} to {comp_years[-1]}")
        print(f"    Total years: {total_years[0]} to {total_years[-1]}")

        # array_equal is False for differing shapes, so no separate length check is needed.
        if np.array_equal(comp_years, total_years):
            aligned[comp_name] = comp_info
            print("    ✓ Years already match")
            continue

        # For each sorted common year, return_indices yields the index of its
        # first occurrence in comp_years (no per-year np.where scan needed).
        common_years, comp_year_indices, _ = np.intersect1d(
            comp_years, total_years, return_indices=True
        )

        if len(common_years) == 0:
            print(f"    ⚠ No common years found; leaving {comp_name} unaligned "
                  "(no interpolation is performed)")
            aligned[comp_name] = comp_info
            continue

        year_dim = next(
            (dim for dim in comp_data.dims if 'year' in str(dim).lower()), None
        )
        if year_dim is None:
            print("    ✗ Could not find year dimension")
            aligned[comp_name] = comp_info
            continue

        aligned[comp_name] = {
            'data': comp_data.isel({year_dim: comp_year_indices}),
            'years': common_years,
            'ds': comp_info['ds'],
        }
        print(f"    ✓ Aligned to {len(common_years)} common years")

    return aligned

def _subset_to_common_years(data, years, common_years, target_dims=None, name='data'):
    """Return *data* restricted to *common_years*, optionally reordered to *target_dims*.

    Parameters
    ----------
    data : xarray.DataArray
        Sea-level array. A 'locations' dimension, if present, is squeezed out
        (xarray refuses to squeeze a dimension longer than one).
    years : array
        Year values matching *data*'s year dimension.
    common_years : array
        Sorted, unique years that are all present in *years*.
    target_dims : tuple, optional
        Dimension order to transpose to. If the transpose fails, a warning is
        printed and the data keeps its own order.
    name : str
        Label used in that warning.
    """
    if 'locations' in data.dims:
        data = data.squeeze('locations')

    # The first dimension whose name mentions "year" is treated as the time axis.
    year_dim = next((dim for dim in data.dims if 'year' in str(dim).lower()), None)
    if year_dim is not None:
        # common_years is a sorted subset of years, so the third output holds the
        # index of the first occurrence of each common year in *years*, in order.
        _, _, year_idx = np.intersect1d(common_years, years, return_indices=True)
        data = data.isel({year_dim: year_idx})

    if target_dims is not None and data.dims != tuple(target_dims):
        try:
            data = data.transpose(*target_dims)
        except ValueError as err:
            print(f"  Warning: Could not transpose {name}: {err}")
    return data


def verify_totals(output_dir='./data/output', file_type='lslr'):
    """
    Verify that totaled_output_all_{file_type}.nc equals the sum of the
    component files ({output_dir}/<component>/{file_type}.nc).

    Prints difference statistics, then draws a 3x3 diagnostic figure that is
    saved as verification_{file_type}_totals.png in *output_dir* and shown.

    Parameters
    -----------
    output_dir : str
        Directory containing the totaled file and one sub-directory per component.
    file_type : str
        Either 'lslr' or 'gslr'.

    Returns
    -------
    dict
        'max_difference' and 'mean_difference' (floats, mm); 'total_data',
        'manual_sum' and 'difference' (DataArrays on the common years); and
        'common_years'.

    Raises
    ------
    ValueError
        If no component file is found, or the files share no common years.
    """
    print(f"\n{'='*60}")
    print(f"VERIFYING {file_type.upper()} TOTALS")
    print(f"{'='*60}\n")

    # Load the totaled file
    total_file = Path(output_dir) / f'totaled_output_all_{file_type}.nc'
    total_data, total_years, _ = load_sea_level_data(total_file)

    print(f"Loaded total file: {total_file.name}")
    print(f"  Shape: {total_data.shape}")
    print(f"  Dimensions: {total_data.dims}")
    print(f"  Years: {len(total_years)} points from {total_years[0]} to {total_years[-1]}")

    # Load individual component files
    components = ['lws', 'sterodynamics']  # Add 'ice' if you have it
    component_data = {}
    for comp in components:
        comp_file = Path(output_dir) / comp / f'{file_type}.nc'
        if not comp_file.exists():
            print(f"WARNING: Component file not found: {comp_file}")
            continue
        data, years, ds = load_sea_level_data(comp_file)
        component_data[comp] = {'data': data, 'years': years, 'ds': ds}
        print(f"Loaded {comp}: {comp_file.name}")
        print(f"  Shape: {data.shape}")
        print(f"  Dimensions: {data.dims}")
        print(f"  Years: {len(years)} points from {years[0]} to {years[-1]}")

    if not component_data:
        raise ValueError(f"No component '{file_type}.nc' files found under {output_dir}")

    # Align all data to common years
    print("\nAligning data to common years...")
    component_data = align_to_common_years(component_data, total_years)

    # Years present in the total file and in every component
    common_years = total_years
    for comp_info in component_data.values():
        common_years = np.intersect1d(common_years, comp_info['years'])
    if len(common_years) == 0:
        raise ValueError("The total and component files share no common years")

    print(f"\nCommon years across all files: {len(common_years)} years")
    print(f"  Range: {common_years[0]} to {common_years[-1]}")

    total_data_subset = _subset_to_common_years(
        total_data, total_years, common_years, name='total'
    )
    print(f"Total data subset shape: {total_data_subset.shape}")

    # Align each component once (in the total's dimension order) and reuse the
    # result for both the sum and every plot below.
    print("\nCalculating sum of components...")
    aligned = {}
    manual_sum = None
    for comp_name, comp_dict in component_data.items():
        comp_data = _subset_to_common_years(
            comp_dict['data'], comp_dict['years'], common_years,
            target_dims=total_data_subset.dims, name=comp_name,
        )
        aligned[comp_name] = comp_data
        manual_sum = comp_data.copy() if manual_sum is None else manual_sum + comp_data
        print(f"  Added {comp_name}: shape {comp_data.shape}")

    print(f"Manual sum shape: {manual_sum.shape}")
    print(f"Total data shape: {total_data_subset.shape}")

    # Calculate difference
    difference = total_data_subset - manual_sum
    abs_diff = abs(difference)
    max_diff = float(np.nanmax(abs_diff.values))
    mean_diff = float(np.nanmean(abs_diff.values))

    print("\nDifference Statistics:")
    print(f"  Max absolute difference: {max_diff:.6e} mm")
    print(f"  Mean absolute difference: {mean_diff:.6e} mm")

    if max_diff < 1e-3:  # Less than 0.001 mm difference
        print("  ✅ VERIFICATION PASSED: Total matches sum of components")
    else:
        print("  ⚠️  WARNING: Significant differences detected")

    # Ensemble means shared by plots 4, 7 and 8
    total_mean = total_data_subset.mean(dim='samples')
    manual_sum_mean = manual_sum.mean(dim='samples')
    comp_means = {name: comp.mean(dim='samples').values for name, comp in aligned.items()}

    plt.figure(figsize=(16, 12))

    # .sizes gives the ensemble size wherever 'samples' sits in the dim order
    # (shape[0] was the year count for ('years', 'samples') data).
    n_samples = total_data_subset.sizes['samples']
    # First, middle and last members; the set removes duplicates for tiny ensembles.
    sample_indices = sorted({0, n_samples // 2, n_samples - 1})
    sample_to_plot = 0

    # Plot 1: Time series for selected samples
    ax1 = plt.subplot(3, 3, 1)
    for idx in sample_indices:
        ax1.plot(common_years, total_data_subset.isel(samples=idx).values,
                 label=f'Total (sample {idx})', linewidth=2)
    ax1.set_xlabel('Year')
    ax1.set_ylabel('Sea Level Change (mm)')
    ax1.set_title(f'Total {file_type.upper()}: Selected Samples')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: Individual components for one sample
    ax2 = plt.subplot(3, 3, 2)
    for comp_name, comp_data in aligned.items():
        ax2.plot(common_years, comp_data.isel(samples=sample_to_plot).values,
                 label=comp_name, linewidth=2, marker='o', markersize=4)
    ax2.set_xlabel('Year')
    ax2.set_ylabel('Sea Level Change (mm)')
    ax2.set_title(f'Individual Components (sample {sample_to_plot})')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Plot 3: Manual sum vs Total for one sample
    ax3 = plt.subplot(3, 3, 3)
    ax3.plot(common_years, total_data_subset.isel(samples=sample_to_plot).values,
             label='Total (from file)', linewidth=2, marker='o')
    ax3.plot(common_years, manual_sum.isel(samples=sample_to_plot).values,
             label='Manual sum', linewidth=2, marker='s', linestyle='--')
    ax3.set_xlabel('Year')
    ax3.set_ylabel('Sea Level Change (mm)')
    ax3.set_title(f'Verification: Total vs Sum (sample {sample_to_plot})')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # Plot 4: Ensemble mean comparison
    ax4 = plt.subplot(3, 3, 4)
    ax4.plot(common_years, total_mean, label='Total mean', linewidth=3)
    ax4.plot(common_years, manual_sum_mean, label='Sum mean',
             linewidth=3, linestyle='--')
    for comp_name, comp_mean in comp_means.items():
        ax4.plot(common_years, comp_mean, label=f'{comp_name} mean',
                 linewidth=2, alpha=0.7)
    ax4.set_xlabel('Year')
    ax4.set_ylabel('Sea Level Change (mm)')
    ax4.set_title('Ensemble Means')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    # Plot 5: Ensemble spread (percentiles)
    ax5 = plt.subplot(3, 3, 5)
    total_p50 = total_data_subset.quantile(0.5, dim='samples').values.flatten()
    total_p05 = total_data_subset.quantile(0.05, dim='samples').values.flatten()
    total_p95 = total_data_subset.quantile(0.95, dim='samples').values.flatten()
    sum_p50 = manual_sum.quantile(0.5, dim='samples').values.flatten()
    ax5.fill_between(common_years, total_p05, total_p95, alpha=0.3, label='Total 5-95%')
    ax5.plot(common_years, total_p50, label='Total median', linewidth=2)
    ax5.plot(common_years, sum_p50, label='Sum median',
             linewidth=2, linestyle='--')
    ax5.set_xlabel('Year')
    ax5.set_ylabel('Sea Level Change (mm)')
    ax5.set_title('Ensemble Spread (5th-95th percentile)')
    ax5.legend()
    ax5.grid(True, alpha=0.3)

    # Plot 6: Absolute difference
    ax6 = plt.subplot(3, 3, 6)
    for idx in sample_indices:
        ax6.plot(common_years, abs_diff.isel(samples=idx).values,
                 label=f'Sample {idx}', linewidth=2)
    ax6.set_xlabel('Year')
    ax6.set_ylabel('Absolute Difference (mm)')
    ax6.set_title('|Total - Sum| by Sample')
    ax6.legend()
    ax6.grid(True, alpha=0.3)
    # Only set log scale if max difference is significant
    if max_diff > 1e-10:
        ax6.set_yscale('log')

    # Plot 7: Contribution by component (stacked area)
    ax7 = plt.subplot(3, 3, 7)
    ax7.stackplot(common_years, *comp_means.values(),
                  labels=list(comp_means), alpha=0.7)
    ax7.set_xlabel('Year')
    ax7.set_ylabel('Sea Level Change (mm)')
    ax7.set_title('Component Contributions (Stacked, Ensemble Mean)')
    ax7.legend(loc='upper left')
    ax7.grid(True, alpha=0.3)

    # Plot 8: Relative contribution by component
    ax8 = plt.subplot(3, 3, 8)
    total_abs_mean = np.abs(total_mean.values)
    for comp_name, comp_mean in comp_means.items():
        # Years where the total mean is 0 would give inf/nan; plot them as 0.
        with np.errstate(divide='ignore', invalid='ignore'):
            rel_contrib = 100 * comp_mean / total_abs_mean
        rel_contrib[~np.isfinite(rel_contrib)] = 0
        ax8.plot(common_years, rel_contrib, label=comp_name,
                 linewidth=2, marker='o')
    ax8.set_xlabel('Year')
    ax8.set_ylabel('Relative Contribution (%)')
    ax8.set_title('Relative Component Contributions')
    ax8.legend()
    ax8.grid(True, alpha=0.3)
    ax8.axhline(y=0, color='k', linestyle='-', linewidth=0.5)

    # Plot 9: Heatmap of differences (rows = samples, columns = years)
    ax9 = plt.subplot(3, 3, 9)
    heat = abs_diff.transpose('samples', ...).values
    im = ax9.imshow(heat, aspect='auto', cmap='viridis', interpolation='nearest')
    ax9.set_xlabel('Year Index')
    ax9.set_ylabel('Sample Index')
    ax9.set_title('Heatmap: |Total - Sum| (all samples)')
    plt.colorbar(im, ax=ax9, label='Absolute Difference (mm)')

    plt.suptitle(f'{file_type.upper()} Verification: Total vs Sum of Components\n' +
                 f'Max diff: {max_diff:.2e} mm, Mean diff: {mean_diff:.2e} mm',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()

    # Save figure
    output_file = Path(output_dir) / f'verification_{file_type}_totals.png'
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"\nSaved verification plot to: {output_file}")

    plt.show()

    return {
        'max_difference': max_diff,
        'mean_difference': mean_diff,
        'total_data': total_data_subset,
        'manual_sum': manual_sum,
        'difference': difference,
        'common_years': common_years
    }

# Verify the LSLR and GSLR totals against their components, then summarise
# the worst-case mismatch for each. Each verify_totals call prints its own
# report and shows its own figure.
print("=" * 80, "VERIFYING SEA LEVEL TOTALS", "=" * 80, sep="\n")

lslr_results = verify_totals(output_dir='./data/output', file_type='lslr')
gslr_results = verify_totals(output_dir='./data/output', file_type='gslr')

print("\n" + "=" * 80, "SUMMARY", "=" * 80, sep="\n")
print(f"LSLR Max Difference: {lslr_results['max_difference']:.6e} mm")
print(f"GSLR Max Difference: {gslr_results['max_difference']:.6e} mm")
print("=" * 80)

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Disable xarray's optional accelerators (both bottleneck and numbagg, not just
# numbagg) to avoid NumPy compatibility issues; this is a global xarray option.
xr.set_options(use_bottleneck=False, use_numbagg=False)

def load_gslr_file(filepath):
    """Open a GSLR NetCDF file and extract its sea-level variable.

    The variable is the first of the known sea-level names found in the
    dataset; otherwise the first data variable other than 'lat'/'lon'.
    A 'locations' dimension, if present, is squeezed out (it is expected
    to have length 1).

    Returns
    -------
    tuple
        (data, years, ds, var_name); years is None when the dataset has no
        'years' entry.

    Raises
    ------
    ValueError
        If the dataset has no usable data variable.
    """
    ds = xr.open_dataset(filepath)

    known_names = ('sea_level_change', 'sealevel_change', 'slr', 'sea_level')
    var_name = next((name for name in known_names if name in ds.data_vars), None)
    if var_name is None:
        fallback = [v for v in ds.data_vars if v not in ['lat', 'lon']]
        if not fallback:
            raise ValueError(f"Could not find sea level variable in {filepath}")
        var_name = fallback[0]

    years = ds['years'].values if 'years' in ds else None

    data = ds[var_name]
    if 'locations' in data.dims:
        data = data.squeeze('locations')

    return data, years, ds, var_name

def compare_gslr_files_barchart(output_dir='./data/output'):
    """
    Create bar-chart comparisons of the totaled GSLR files.

    Loads totaled_output_{all,lws,sterodynamics}_gslr.nc from *output_dir* and
    restricts them to their common years. Ensemble statistics are taken across
    the 'samples' dimension. Draws a 3x3 figure, saved as
    gslr_barchart_comparison.png in *output_dir* and then shown. Prints a
    summary and checks that the ensemble means satisfy
    All = LWS + Sterodynamics.

    Parameters
    ----------
    output_dir : str
        Directory containing the totaled GSLR files.

    Returns
    -------
    dict or None
        Per-label statistics: 'mean', 'std', 'p05', 'p95' and 'median' arrays
        over the common years, plus the subset DataArray under 'data_subset'.
        Returns None when no file could be loaded or the files share no years.
    """

    print("="*80)
    print("LOADING GSLR FILES FOR BAR CHART COMPARISON")
    print("="*80)

    # Define files to compare
    files_to_load = {
        'All (Total)': 'totaled_output_all_gslr.nc',
        'LWS': 'totaled_output_lws_gslr.nc',
        'Sterodynamics': 'totaled_output_sterodynamics_gslr.nc'
    }

    # Load all files
    data_dict = {}
    for label, filename in files_to_load.items():
        filepath = Path(output_dir) / filename
        if not filepath.exists():
            print(f"\nWARNING: File not found: {filepath}")
            continue
        data, years, ds, var_name = load_gslr_file(filepath)
        data_dict[label] = {'data': data, 'years': years, 'ds': ds, 'var_name': var_name}
        print(f"\n{label}: {filename}")
        print(f"  Shape: {data.shape}")
        print(f"  Dimensions: {data.dims}")
        print(f"  Years: {len(years)} points from {years[0]} to {years[-1]}")

    if not data_dict:
        print("ERROR: No files could be loaded!")
        return None

    # Find common years across all datasets
    all_years = [info['years'] for info in data_dict.values()]
    common_years = all_years[0]
    for years in all_years[1:]:
        common_years = np.intersect1d(common_years, years)
    if len(common_years) == 0:
        print("ERROR: The loaded files share no common years!")
        return None

    print(f"\nCommon years: {len(common_years)} years from {common_years[0]} to {common_years[-1]}")

    # Calculate statistics for each file at common years
    stats_dict = {}
    for label, info in data_dict.items():
        data, years = info['data'], info['years']
        year_dim = next((dim for dim in data.dims if 'year' in str(dim).lower()), None)
        if year_dim is not None:
            # First occurrence of each common year, in common_years order
            year_indices = [np.where(years == y)[0][0] for y in common_years]
            data_subset = data.isel({year_dim: year_indices})
        else:
            data_subset = data

        stats_dict[label] = {
            'mean': data_subset.mean(dim='samples').values,
            'std': data_subset.std(dim='samples').values,
            'p05': data_subset.quantile(0.05, dim='samples').values,
            'p95': data_subset.quantile(0.95, dim='samples').values,
            'median': data_subset.quantile(0.5, dim='samples').values,
            'data_subset': data_subset,
        }

    # Create comprehensive bar chart visualization
    plt.figure(figsize=(18, 12))

    colors = {
        'All (Total)': '#2E86AB',
        'LWS': '#A23B72',
        'Sterodynamics': '#F18F01'
    }

    labels_list = list(stats_dict)
    n_labels = len(labels_list)
    bar_colors = [colors[label] for label in labels_list]
    x = np.arange(len(common_years))
    x_pos = np.arange(n_labels)
    # Original 0.25 width for up to three datasets; narrower if more are added.
    width = min(0.25, 0.8 / n_labels)
    first_year, final_year = common_years[0], common_years[-1]

    def grouped_by_year(ax, series, ylabel, title):
        """Draw one bar per dataset for every common year (series: label -> array)."""
        for i, label in enumerate(labels_list):
            offset = (i - n_labels / 2 + 0.5) * width
            ax.bar(x + offset, series[label], width,
                   label=label, color=colors[label], alpha=0.8)
        ax.set_xlabel('Year')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(x)
        ax.set_xticklabels(common_years, rotation=45)
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

    def per_dataset_bars(ax, values, ylabel, title, errors=None):
        """Draw one bar per dataset and label each with its value (± error)."""
        bars = ax.bar(x_pos, values, yerr=errors, color=bar_colors,
                      alpha=0.8, capsize=10, edgecolor='black', linewidth=1.5)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(x_pos)
        ax.set_xticklabels(labels_list, rotation=15, ha='right')
        ax.grid(True, alpha=0.3, axis='y')
        spreads = errors if errors is not None else [0.0] * len(values)
        for bar, value, spread in zip(bars, values, spreads):
            text = f'{value:.1f}±{spread:.1f}' if errors is not None else f'{value:.1f}'
            # Place the label beyond the end of the bar (and its error bar):
            # above positive bars, below negative ones.
            if value >= 0:
                y, va = value + spread, 'bottom'
            else:
                y, va = value - spread, 'top'
            ax.text(bar.get_x() + bar.get_width() / 2., y, text,
                    ha='center', va=va, fontsize=9 if errors is not None else 10,
                    fontweight='bold')

    # Plot 1: Mean values by year (grouped bar chart)
    grouped_by_year(plt.subplot(3, 3, 1),
                    {label: stats_dict[label]['mean'] for label in labels_list},
                    'Mean GSLR (mm)', 'Mean Global Sea Level Rise by Year')

    # Plot 2: Final year comparison (mean ± std)
    per_dataset_bars(plt.subplot(3, 3, 2),
                     [stats_dict[label]['mean'][-1] for label in labels_list],
                     'GSLR (mm)',
                     f'Final Year ({final_year}) Comparison\n(Mean ± Std Dev)',
                     errors=[stats_dict[label]['std'][-1] for label in labels_list])

    # Plot 3: First year comparison (mean ± std)
    per_dataset_bars(plt.subplot(3, 3, 3),
                     [stats_dict[label]['mean'][0] for label in labels_list],
                     'GSLR (mm)',
                     f'First Year ({first_year}) Comparison\n(Mean ± Std Dev)',
                     errors=[stats_dict[label]['std'][0] for label in labels_list])

    # Plot 4: Total change (final - first year)
    per_dataset_bars(plt.subplot(3, 3, 4),
                     [stats_dict[label]['mean'][-1] - stats_dict[label]['mean'][0]
                      for label in labels_list],
                     'Total Change (mm)',
                     f'Total GSLR Change\n({first_year} to {final_year})')

    # Plot 5: Standard deviation by year
    grouped_by_year(plt.subplot(3, 3, 5),
                    {label: stats_dict[label]['std'] for label in labels_list},
                    'Std Dev (mm)', 'Ensemble Spread (Std Dev) by Year')

    # Plot 6: Ensemble percentile range by year. This is the spread of the
    # ensemble members, not a confidence interval on any estimate.
    grouped_by_year(plt.subplot(3, 3, 6),
                    {label: stats_dict[label]['p95'] - stats_dict[label]['p05']
                     for label in labels_list},
                    'Uncertainty Range (mm)',
                    'Ensemble 5th-95th Percentile Range (P95-P05) by Year')

    # Plot 7: Average across all years
    per_dataset_bars(plt.subplot(3, 3, 7),
                     [np.mean(stats_dict[label]['mean']) for label in labels_list],
                     'Average GSLR (mm)',
                     f'Average Across All Years\n({first_year}-{final_year})',
                     errors=[np.mean(stats_dict[label]['std']) for label in labels_list])

    # Plot 8: Contribution percentages (relative to total at final year)
    ax8 = plt.subplot(3, 3, 8)
    if 'All (Total)' in stats_dict:
        total_final = stats_dict['All (Total)']['mean'][-1]
        contrib_labels = [label for label in labels_list if label != 'All (Total)']
        contributions = [
            (stats_dict[label]['mean'][-1] / total_final) * 100 if total_final != 0 else 0
            for label in contrib_labels
        ]

        bars = ax8.bar(range(len(contributions)), contributions,
                       color=[colors[label] for label in contrib_labels],
                       alpha=0.8, edgecolor='black', linewidth=1.5)

        ax8.set_ylabel('Contribution (%)')
        ax8.set_title(f'Component Contribution to Total\n(Final Year: {final_year})')
        ax8.set_xticks(range(len(contributions)))
        ax8.set_xticklabels(contrib_labels, rotation=15, ha='right')
        ax8.grid(True, alpha=0.3, axis='y')
        ax8.axhline(y=100, color='red', linestyle='--', linewidth=2, alpha=0.5, label='100%')
        ax8.legend()  # without this the '100%' reference-line label is never shown

        for bar, pct in zip(bars, contributions):
            ax8.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                     f'{pct:.1f}%',
                     ha='center', va='bottom' if pct >= 0 else 'top',
                     fontsize=10, fontweight='bold')
    else:
        # Explain the empty panel instead of leaving blank axes.
        ax8.text(0.5, 0.5, "'All (Total)' file not loaded;\ncontributions unavailable",
                 ha='center', va='center', transform=ax8.transAxes)
        ax8.set_axis_off()

    # Plot 9: Rate of change (difference between consecutive years)
    ax9 = plt.subplot(3, 3, 9)
    for label in labels_list:
        rates = np.diff(stats_dict[label]['mean']) / np.diff(common_years)
        ax9.plot(common_years[:-1], rates, marker='o', linewidth=2,
                 label=label, color=colors[label])
    ax9.set_xlabel('Year')
    ax9.set_ylabel('Rate of Change (mm/year)')
    ax9.set_title('Rate of GSLR Change')
    ax9.legend()
    ax9.grid(True, alpha=0.3)
    ax9.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

    plt.suptitle('GSLR Comparison: All vs Components\n' +
                 f'({len(common_years)} years from {first_year} to {final_year})',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()

    # Save figure
    output_file = Path(output_dir) / 'gslr_barchart_comparison.png'
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"\nSaved bar chart comparison to: {output_file}")

    plt.show()

    # Print summary statistics
    print("\n" + "="*80)
    print("SUMMARY STATISTICS")
    print("="*80)

    year_span = final_year - first_year
    for label in labels_list:
        s = stats_dict[label]
        change = s['mean'][-1] - s['mean'][0]
        print(f"\n{label}:")
        print(f"  First year ({first_year}): {s['mean'][0]:.2f} ± {s['std'][0]:.2f} mm")
        print(f"  Final year ({final_year}): {s['mean'][-1]:.2f} ± {s['std'][-1]:.2f} mm")
        print(f"  Total change: {change:.2f} mm")
        print(f"  Average: {np.mean(s['mean']):.2f} mm")
        # A single common year gives a zero span; skip the rate instead of dividing by 0.
        if year_span > 0:
            print(f"  Average rate: {change / year_span:.2f} mm/year")

    # Verification: Check if All = LWS + Sterodynamics
    if {'All (Total)', 'LWS', 'Sterodynamics'} <= stats_dict.keys():
        print("\n" + "="*80)
        print("VERIFICATION: All = LWS + Sterodynamics")
        print("="*80)

        diff = stats_dict['All (Total)']['mean'] - (
            stats_dict['LWS']['mean'] + stats_dict['Sterodynamics']['mean']
        )
        max_diff = np.max(np.abs(diff))
        mean_diff = np.mean(np.abs(diff))

        print(f"Max absolute difference: {max_diff:.6e} mm")
        print(f"Mean absolute difference: {mean_diff:.6e} mm")

        if max_diff < 1e-3:
            print("✅ VERIFICATION PASSED: All = LWS + Sterodynamics")
        else:
            print("⚠️  WARNING: Significant differences detected")

    return stats_dict

# Run the comparison. Binding the result keeps the per-dataset statistics
# available and stops Jupyter from dumping the whole returned dict (arrays and
# DataArrays) as this cell's output.
gslr_stats = compare_gslr_files_barchart(output_dir='./data/output')

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Disable xarray's optional bottleneck AND numbagg accelerators (reductions then
# use plain NumPy) to avoid NumPy compatibility issues. Called outside a `with`
# block, so the setting stays in effect for the rest of the session.
xr.set_options(use_bottleneck=False, use_numbagg=False)

def load_lslr_file(filepath):
    """Load LSLR data from a NetCDF file.

    The sea-level variable is the first of ``possible_names`` present in the
    file; otherwise the first data variable other than ``lat``/``lon`` is used.
    A ``locations`` dimension is squeezed out when present (xarray's
    ``squeeze`` requires it to have length 1).

    Parameters
    ----------
    filepath : str or Path
        NetCDF file to open.

    Returns
    -------
    data : xarray.DataArray
        The selected sea-level variable.
    years : numpy.ndarray or None
        Values of the file's ``years`` variable. If the file has none, the
        coordinate of the first dimension of *data* whose name contains
        "year" is used instead; ``None`` only when neither exists.
    ds : xarray.Dataset
        The opened dataset. It is left open because *data* is read from it
        lazily; the caller is responsible for closing it.
    var_name : str
        Name of the selected variable.

    Raises
    ------
    ValueError
        If the file has no candidate data variable (the dataset is closed
        before raising).
    """
    ds = xr.open_dataset(filepath)

    possible_names = ['sea_level_change', 'sealevel_change', 'slr', 'sea_level']
    var_name = next((name for name in possible_names if name in ds.data_vars), None)

    if var_name is None:
        data_vars = [v for v in ds.data_vars if v not in ('lat', 'lon')]
        if not data_vars:
            # Release the file handle rather than leaking it with the exception.
            ds.close()
            raise ValueError(f"Could not find sea level variable in {filepath}")
        var_name = data_vars[0]

    data = ds[var_name]

    if 'years' in ds:
        years = ds['years'].values
    else:
        # Fall back to a coordinate on a dimension whose name contains "year".
        year_dim = next((dim for dim in data.dims if 'year' in str(dim).lower()), None)
        years = ds[year_dim].values if year_dim is not None and year_dim in ds.coords else None

    if 'locations' in data.dims:
        data = data.squeeze('locations')

    return data, years, ds, var_name

def _lslr_year_dim(data):
    """Return the first dimension of *data* whose name contains 'year' (case-insensitive), or None."""
    return next((dim for dim in data.dims if 'year' in str(dim).lower()), None)


def _lslr_ensemble_stats(data, years, common_years):
    """Subset *data* to *common_years* and summarise it across the 'samples' dimension.

    Parameters
    ----------
    data : xarray.DataArray
        Sea-level data with a 'samples' dimension.
    years : numpy.ndarray
        The file's own year values, aligned with its year dimension.
    common_years : numpy.ndarray
        Years shared by all loaded files; each must occur in *years*.

    Returns
    -------
    dict
        Per-year arrays 'mean', 'std', 'p05', 'p95', 'median', plus the
        in-memory subset under 'data_subset'.
    """
    year_dim = _lslr_year_dim(data)
    if year_dim is not None:
        # Position of each common year along this file's own year axis.
        year_indices = [int(np.flatnonzero(years == y)[0]) for y in common_years]
        data_subset = data.isel({year_dim: year_indices})
    else:
        data_subset = data
    # Pull the subset into memory so the source file can be closed afterwards.
    data_subset = data_subset.load()

    # One quantile pass for all three levels; adds a 'quantile' dimension.
    quantiles = data_subset.quantile([0.05, 0.5, 0.95], dim='samples')
    return {
        'mean': data_subset.mean(dim='samples').values,
        'std': data_subset.std(dim='samples').values,
        'p05': quantiles.isel(quantile=0).values,
        'p95': quantiles.isel(quantile=2).values,
        'median': quantiles.isel(quantile=1).values,
        'data_subset': data_subset,
    }


def _lslr_grouped_bars(ax, series, colors, common_years, ylabel, title, width=0.25):
    """Draw, for every year, one side-by-side bar per dataset in *series* (label -> per-year values)."""
    x = np.arange(len(common_years))
    n_labels = len(series)
    for i, (label, values) in enumerate(series.items()):
        # Centre each year's group of bars on its tick.
        offset = (i - n_labels / 2 + 0.5) * width
        ax.bar(x + offset, values, width, label=label, color=colors[label], alpha=0.8)
    ax.set_xlabel('Year')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(common_years, rotation=45)
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')


def _lslr_category_bars(ax, labels, values, colors, ylabel, title,
                        errors=None, texts=None, fontsize=10):
    """Draw one outlined bar per dataset label, with optional error bars and value annotations.

    Annotations are placed just beyond the end of each bar (and its error
    bar): above positive bars and below negative ones, so they never sit
    inside the bar.
    """
    x_pos = np.arange(len(labels))
    bar_kwargs = {
        'color': [colors[label] for label in labels],
        'alpha': 0.8,
        'edgecolor': 'black',
        'linewidth': 1.5,
    }
    if errors is not None:
        bar_kwargs.update(yerr=errors, capsize=10)
    bars = ax.bar(x_pos, values, **bar_kwargs)

    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(labels, rotation=15, ha='right')
    ax.grid(True, alpha=0.3, axis='y')

    if texts is None:
        return bars
    offsets = errors if errors is not None else [0.0] * len(labels)
    for bar, text, offset in zip(bars, texts, offsets):
        height = bar.get_height()
        if height >= 0:
            y, va = height + offset, 'bottom'
        else:
            y, va = height - offset, 'top'
        ax.text(bar.get_x() + bar.get_width() / 2., y, text,
                ha='center', va=va, fontsize=fontsize, fontweight='bold')
    return bars


def _lslr_print_summary(stats_dict, common_years):
    """Print first/final/average ensemble means and the mean rate of change per dataset."""
    print("\n" + "="*80)
    print("SUMMARY STATISTICS")
    print("="*80)

    span = common_years[-1] - common_years[0]
    for label, stats in stats_dict.items():
        mean, std = stats['mean'], stats['std']
        change = mean[-1] - mean[0]
        print(f"\n{label}:")
        print(f"  First year ({common_years[0]}): {mean[0]:.2f} ± {std[0]:.2f} mm")
        print(f"  Final year ({common_years[-1]}): {mean[-1]:.2f} ± {std[-1]:.2f} mm")
        print(f"  Total change: {change:.2f} mm")
        print(f"  Average: {np.mean(mean):.2f} mm")
        if span != 0:
            print(f"  Average rate: {change / span:.2f} mm/year")
        else:
            # A single common year has no time span to compute a rate over.
            print("  Average rate: n/a (single common year)")


def _lslr_verify_additivity(stats_dict, atol=1e-3, rtol=1e-5):
    """Check that the 'All (Total)' ensemble mean equals LWS + Sterodynamics.

    Uses ``np.allclose`` (absolute plus relative tolerance), so float
    rounding on large values does not raise a false warning. Any difference
    below *atol* passes, as with the previous pure-absolute check.
    """
    required = ('All (Total)', 'LWS', 'Sterodynamics')
    if not all(label in stats_dict for label in required):
        return

    print("\n" + "="*80)
    print("VERIFICATION: All = LWS + Sterodynamics")
    print("="*80)

    all_mean = stats_dict['All (Total)']['mean']
    component_sum = stats_dict['LWS']['mean'] + stats_dict['Sterodynamics']['mean']
    abs_diff = np.abs(all_mean - component_sum)

    print(f"Max absolute difference: {np.max(abs_diff):.6e} mm")
    print(f"Mean absolute difference: {np.mean(abs_diff):.6e} mm")

    if np.allclose(all_mean, component_sum, rtol=rtol, atol=atol):
        print("✅ VERIFICATION PASSED: All = LWS + Sterodynamics")
    else:
        print("⚠️  WARNING: Significant differences detected")


def compare_lslr_files_barchart(output_dir='./data/output'):
    """
    Create bar chart comparisons of totaled LSLR files.

    Loads the 'All', 'LWS' and 'Sterodynamics' totaled LSLR NetCDF files
    from *output_dir* and restricts them to the years they share. Each file
    is then summarised across its 'samples' dimension. The function draws a
    3x3 panel figure (saved as ``lslr_barchart_comparison.png`` in
    *output_dir*), prints summary statistics and checks that
    All = LWS + Sterodynamics.

    Parameters
    ----------
    output_dir : str or Path
        Directory containing the ``totaled_output_*_lslr.nc`` files.

    Returns
    -------
    dict or None
        Mapping of dataset label to a dict with per-year 'mean', 'std',
        'p05', 'p95', 'median' arrays and the in-memory 'data_subset'.
        Returns None if no usable file was found or the files share no years.
    """
    print("="*80)
    print("LOADING LSLR FILES FOR BAR CHART COMPARISON")
    print("="*80)

    files_to_load = {
        'All (Total)': 'totaled_output_all_lslr.nc',
        'LWS': 'totaled_output_lws_lslr.nc',
        'Sterodynamics': 'totaled_output_sterodynamics_lslr.nc'
    }

    data_dict = {}
    for label, filename in files_to_load.items():
        filepath = Path(output_dir) / filename
        if not filepath.exists():
            print(f"\nWARNING: File not found: {filepath}")
            continue
        data, years, ds, _ = load_lslr_file(filepath)
        if years is None:
            # Without a year axis the file cannot be aligned with the others.
            ds.close()
            print(f"\nWARNING: No 'years' variable in {filepath}; skipping")
            continue
        data_dict[label] = {'data': data, 'years': years, 'ds': ds}
        print(f"\n{label}: {filename}")
        print(f"  Shape: {data.shape}")
        print(f"  Dimensions: {data.dims}")
        print(f"  Years: {len(years)} points from {years[0]} to {years[-1]}")

    if not data_dict:
        print("ERROR: No files could be loaded!")
        return None

    # Years present in every loaded file.
    all_years = [info['years'] for info in data_dict.values()]
    common_years = all_years[0]
    for years in all_years[1:]:
        common_years = np.intersect1d(common_years, years)

    if len(common_years) == 0:
        print("ERROR: The loaded files share no common years!")
        for info in data_dict.values():
            info['ds'].close()
        return None

    print(f"\nCommon years: {len(common_years)} years from {common_years[0]} to {common_years[-1]}")

    try:
        stats_dict = {
            label: _lslr_ensemble_stats(info['data'], info['years'], common_years)
            for label, info in data_dict.items()
        }
    finally:
        # The subsets are loaded into memory, so the file handles can be released.
        for info in data_dict.values():
            info['ds'].close()

    colors = {
        'All (Total)': '#2E86AB',
        'LWS': '#A23B72',
        'Sterodynamics': '#F18F01'
    }
    labels_list = list(stats_dict)
    first_year, final_year = common_years[0], common_years[-1]

    fig = plt.figure(figsize=(18, 12))

    # Panel 1: ensemble mean per year.
    _lslr_grouped_bars(
        plt.subplot(3, 3, 1),
        {label: stats_dict[label]['mean'] for label in labels_list},
        colors, common_years, 'Mean LSLR (mm)', 'Mean Local Sea Level Rise by Year')

    # Panels 2 and 3: final / first common year, mean ± std.
    for position, index, which, year in ((2, -1, 'Final', final_year), (3, 0, 'First', first_year)):
        means = [stats_dict[label]['mean'][index] for label in labels_list]
        stds = [stats_dict[label]['std'][index] for label in labels_list]
        _lslr_category_bars(
            plt.subplot(3, 3, position), labels_list, means, colors, 'LSLR (mm)',
            f'{which} Year ({year}) Comparison\n(Mean ± Std Dev)',
            errors=stds,
            texts=[f'{m:.1f}±{s:.1f}' for m, s in zip(means, stds)],
            fontsize=9)

    # Panel 4: change in the mean between the first and final common year.
    total_changes = [stats_dict[label]['mean'][-1] - stats_dict[label]['mean'][0]
                     for label in labels_list]
    _lslr_category_bars(
        plt.subplot(3, 3, 4), labels_list, total_changes, colors, 'Total Change (mm)',
        f'Total LSLR Change\n({first_year} to {final_year})',
        texts=[f'{change:.1f}' for change in total_changes])

    # Panel 5: ensemble standard deviation per year.
    _lslr_grouped_bars(
        plt.subplot(3, 3, 5),
        {label: stats_dict[label]['std'] for label in labels_list},
        colors, common_years, 'Std Dev (mm)', 'Ensemble Spread (Std Dev) by Year')

    # Panel 6: P05-P95 spread of the ensemble members (not a confidence interval).
    _lslr_grouped_bars(
        plt.subplot(3, 3, 6),
        {label: stats_dict[label]['p95'] - stats_dict[label]['p05'] for label in labels_list},
        colors, common_years, 'Uncertainty Range (mm)', '90% Ensemble Range (P95-P05) by Year')

    # Panel 7: time-averaged mean with time-averaged std as error bars.
    avg_means = [np.mean(stats_dict[label]['mean']) for label in labels_list]
    avg_stds = [np.mean(stats_dict[label]['std']) for label in labels_list]
    _lslr_category_bars(
        plt.subplot(3, 3, 7), labels_list, avg_means, colors, 'Average LSLR (mm)',
        f'Average Across All Years\n({first_year}-{final_year})',
        errors=avg_stds,
        texts=[f'{mean:.1f}' for mean in avg_means])

    # Panel 8: each component's final-year mean as a percentage of the total.
    ax8 = plt.subplot(3, 3, 8)
    if 'All (Total)' in stats_dict:
        total_final = stats_dict['All (Total)']['mean'][-1]
        contrib_labels = [label for label in labels_list if label != 'All (Total)']
        contributions = [
            (stats_dict[label]['mean'][-1] / total_final) * 100 if total_final != 0 else 0
            for label in contrib_labels
        ]
        _lslr_category_bars(
            ax8, contrib_labels, contributions, colors, 'Contribution (%)',
            f'Component Contribution to Total\n(Final Year: {final_year})',
            texts=[f'{pct:.1f}%' for pct in contributions])
        ax8.axhline(y=100, color='red', linestyle='--', linewidth=2, alpha=0.5, label='100%')

    # Panel 9: rate of change between consecutive common years.
    ax9 = plt.subplot(3, 3, 9)
    year_steps = np.diff(common_years)
    for label in labels_list:
        rates = np.diff(stats_dict[label]['mean']) / year_steps
        ax9.plot(common_years[:-1], rates, marker='o', linewidth=2,
                 label=label, color=colors[label])
    ax9.set_xlabel('Year')
    ax9.set_ylabel('Rate of Change (mm/year)')
    ax9.set_title('Rate of LSLR Change')
    ax9.legend()
    ax9.grid(True, alpha=0.3)
    ax9.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

    plt.suptitle('LSLR Comparison: All vs Components\n' +
                 f'({len(common_years)} years from {first_year} to {final_year})',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()

    output_file = Path(output_dir) / 'lslr_barchart_comparison.png'
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"\nSaved bar chart comparison to: {output_file}")

    plt.show()
    # Free the figure so repeated runs do not accumulate open figures.
    plt.close(fig)

    _lslr_print_summary(stats_dict, common_years)
    _lslr_verify_additivity(stats_dict)

    return stats_dict

# Run the comparison. Binding the result keeps the statistics available to
# later cells and stops Jupyter from echoing the large dict as cell output.
lslr_mean_results = compare_lslr_files_barchart(output_dir='./data/output')

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Disable xarray's optional bottleneck AND numbagg accelerators (reductions then
# use plain NumPy) to avoid NumPy compatibility issues. Called outside a `with`
# block, so the setting stays in effect for the rest of the session.
xr.set_options(use_bottleneck=False, use_numbagg=False)

def load_sea_level_file(filepath):
    """Load sea level data from a NetCDF file.

    The sea-level variable is the first of ``possible_names`` present in the
    file; otherwise the first data variable other than ``lat``/``lon`` is used.
    A ``locations`` dimension is squeezed out when present (xarray's
    ``squeeze`` requires it to have length 1).

    Parameters
    ----------
    filepath : str or Path
        NetCDF file to open.

    Returns
    -------
    data : xarray.DataArray
        The selected sea-level variable.
    years : numpy.ndarray or None
        Values of the file's ``years`` variable. If the file has none, the
        coordinate of the first dimension of *data* whose name contains
        "year" is used instead; ``None`` only when neither exists.
    ds : xarray.Dataset
        The opened dataset. It is left open because *data* is read from it
        lazily; the caller is responsible for closing it.
    var_name : str
        Name of the selected variable.

    Raises
    ------
    ValueError
        If the file has no candidate data variable (the dataset is closed
        before raising).
    """
    ds = xr.open_dataset(filepath)

    possible_names = ['sea_level_change', 'sealevel_change', 'slr', 'sea_level']
    var_name = next((name for name in possible_names if name in ds.data_vars), None)

    if var_name is None:
        data_vars = [v for v in ds.data_vars if v not in ('lat', 'lon')]
        if not data_vars:
            # Release the file handle rather than leaking it with the exception.
            ds.close()
            raise ValueError(f"Could not find sea level variable in {filepath}")
        var_name = data_vars[0]

    data = ds[var_name]

    if 'years' in ds:
        years = ds['years'].values
    else:
        # Fall back to a coordinate on a dimension whose name contains "year".
        year_dim = next((dim for dim in data.dims if 'year' in str(dim).lower()), None)
        years = ds[year_dim].values if year_dim is not None and year_dim in ds.coords else None

    if 'locations' in data.dims:
        data = data.squeeze('locations')

    return data, years, ds, var_name

def _sumchart_year_dim(data):
    """Return the first dimension of *data* whose name contains 'year' (case-insensitive), or None."""
    return next((dim for dim in data.dims if 'year' in str(dim).lower()), None)


def _sumchart_ensemble_stats(data, years, common_years):
    """Subset *data* to *common_years* and compute SUM-based statistics across 'samples'.

    Returns
    -------
    dict
        'sum'       : per-year total over samples (mm x samples), accumulated in float64
        'std'       : per-year per-sample standard deviation (mm)
        'sum_std'   : std scaled by the sample count (mm x samples), so that
                      sum ± sum_std == N * (mean ± std) and shares the bars' units
        'p05', 'p95', 'median' : per-year per-sample quantiles (mm)
        'n_samples' : size of the 'samples' dimension
        'data_subset' : the in-memory subset (original dtype)
    """
    year_dim = _sumchart_year_dim(data)
    if year_dim is not None:
        # Position of each common year along this file's own year axis.
        year_indices = [int(np.flatnonzero(years == y)[0]) for y in common_years]
        data_subset = data.isel({year_dim: year_indices})
    else:
        data_subset = data
    # Pull the subset into memory so the source file can be closed afterwards.
    data_subset = data_subset.load()

    # Accumulate in float64: summing many lower-precision samples loses
    # precision and can overflow.
    values64 = data_subset.astype(np.float64)
    n_samples = data_subset.sizes['samples']
    std = values64.std(dim='samples').values
    quantiles = values64.quantile([0.05, 0.5, 0.95], dim='samples')
    return {
        'sum': values64.sum(dim='samples').values,
        'std': std,
        'sum_std': std * n_samples,
        'p05': quantiles.isel(quantile=0).values,
        'p95': quantiles.isel(quantile=2).values,
        'median': quantiles.isel(quantile=1).values,
        'n_samples': n_samples,
        'data_subset': data_subset,
    }


def _sumchart_grouped_bars(ax, series, colors, common_years, ylabel, title, width=0.25):
    """Draw, for every year, one side-by-side bar per dataset in *series* (label -> per-year values)."""
    x = np.arange(len(common_years))
    n_labels = len(series)
    for i, (label, values) in enumerate(series.items()):
        # Centre each year's group of bars on its tick.
        offset = (i - n_labels / 2 + 0.5) * width
        ax.bar(x + offset, values, width, label=label, color=colors[label], alpha=0.8)
    ax.set_xlabel('Year')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(common_years, rotation=45)
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')


def _sumchart_category_bars(ax, labels, values, colors, ylabel, title,
                            errors=None, texts=None, fontsize=10):
    """Draw one outlined bar per dataset label, with optional error bars and value annotations.

    Annotations are placed just beyond the end of each bar (and its error
    bar): above positive bars and below negative ones.
    """
    x_pos = np.arange(len(labels))
    bar_kwargs = {
        'color': [colors[label] for label in labels],
        'alpha': 0.8,
        'edgecolor': 'black',
        'linewidth': 1.5,
    }
    if errors is not None:
        bar_kwargs.update(yerr=errors, capsize=10)
    bars = ax.bar(x_pos, values, **bar_kwargs)

    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(x_pos)
    ax.set_xticklabels(labels, rotation=15, ha='right')
    ax.grid(True, alpha=0.3, axis='y')

    if texts is None:
        return bars
    offsets = errors if errors is not None else [0.0] * len(labels)
    for bar, text, offset in zip(bars, texts, offsets):
        height = bar.get_height()
        if height >= 0:
            y, va = height + offset, 'bottom'
        else:
            y, va = height - offset, 'top'
        ax.text(bar.get_x() + bar.get_width() / 2., y, text,
                ha='center', va=va, fontsize=fontsize, fontweight='bold')
    return bars


def _sumchart_print_summary(stats_dict, common_years):
    """Print first/final/average sums (± N·std) and the mean rate of change per dataset."""
    print("\n" + "="*80)
    print("SUMMARY STATISTICS (SUM)")
    print("="*80)

    span = common_years[-1] - common_years[0]
    for label, stats in stats_dict.items():
        total, spread = stats['sum'], stats['sum_std']
        change = total[-1] - total[0]
        print(f"\n{label}:")
        print(f"  Samples: {stats['n_samples']}")
        print(f"  First year ({common_years[0]}): {total[0]:.2f} ± {spread[0]:.2f} mm×samples")
        print(f"  Final year ({common_years[-1]}): {total[-1]:.2f} ± {spread[-1]:.2f} mm×samples")
        print(f"  Total change: {change:.2f} mm×samples")
        print(f"  Average: {np.mean(total):.2f} mm×samples")
        if span != 0:
            print(f"  Average rate: {change / span:.2f} (mm×samples)/year")
        else:
            # A single common year has no time span to compute a rate over.
            print("  Average rate: n/a (single common year)")


def _sumchart_verify_additivity(stats_dict, atol=1e-3, rtol=1e-5):
    """Check that the 'All (Total)' sum equals LWS + Sterodynamics.

    Sums over samples are only additive when every file has the same number
    of samples, so a mismatch is reported. The comparison uses
    ``np.allclose``: sums over thousands of samples are large, and a pure
    absolute tolerance would flag harmless float rounding.
    """
    required = ('All (Total)', 'LWS', 'Sterodynamics')
    if not all(label in stats_dict for label in required):
        return

    print("\n" + "="*80)
    print("VERIFICATION: All = LWS + Sterodynamics")
    print("="*80)

    sample_counts = {label: stats_dict[label]['n_samples'] for label in required}
    if len(set(sample_counts.values())) > 1:
        print(f"⚠️  NOTE: sample counts differ ({sample_counts}); sums are not directly comparable")

    all_sum = stats_dict['All (Total)']['sum']
    component_sum = stats_dict['LWS']['sum'] + stats_dict['Sterodynamics']['sum']
    abs_diff = np.abs(all_sum - component_sum)

    print(f"Max absolute difference: {np.max(abs_diff):.6e} mm×samples")
    print(f"Mean absolute difference: {np.mean(abs_diff):.6e} mm×samples")

    if np.allclose(all_sum, component_sum, rtol=rtol, atol=atol):
        print("✅ VERIFICATION PASSED: All = LWS + Sterodynamics")
    else:
        print("⚠️  WARNING: Significant differences detected")


def compare_files_barchart_sum(output_dir='./data/output', file_type='gslr'):
    """
    Create bar chart comparisons of totaled files using SUM statistics.

    The "sum" is the total over the 'samples' dimension, so its magnitude
    scales with the sample count (units: mm × samples). Error bars on sum
    panels are the per-sample std multiplied by that count, putting them in
    the same units as the bars. The per-sample spread panels (std, P05-P95)
    stay in mm.

    Parameters:
    -----------
    output_dir : str
        Directory containing the output files
    file_type : str
        Either 'lslr' or 'gslr'

    Returns:
    --------
    dict or None
        Mapping of dataset label to its statistics ('sum', 'std', 'sum_std',
        'p05', 'p95', 'median', 'n_samples', 'data_subset'). Returns None if
        no usable file was found or the files share no years.
    """
    ft = file_type.upper()

    print("="*80)
    print(f"LOADING {ft} FILES FOR BAR CHART COMPARISON (SUM)")
    print("="*80)

    files_to_load = {
        'All (Total)': f'totaled_output_all_{file_type}.nc',
        'LWS': f'totaled_output_lws_{file_type}.nc',
        'Sterodynamics': f'totaled_output_sterodynamics_{file_type}.nc'
    }

    data_dict = {}
    for label, filename in files_to_load.items():
        filepath = Path(output_dir) / filename
        if not filepath.exists():
            print(f"\nWARNING: File not found: {filepath}")
            continue
        data, years, ds, _ = load_sea_level_file(filepath)
        if years is None:
            # Without a year axis the file cannot be aligned with the others.
            ds.close()
            print(f"\nWARNING: No 'years' variable in {filepath}; skipping")
            continue
        data_dict[label] = {'data': data, 'years': years, 'ds': ds}
        print(f"\n{label}: {filename}")
        print(f"  Shape: {data.shape}")
        print(f"  Dimensions: {data.dims}")
        print(f"  Years: {len(years)} points from {years[0]} to {years[-1]}")

    if not data_dict:
        print("ERROR: No files could be loaded!")
        return None

    # Years present in every loaded file.
    all_years = [info['years'] for info in data_dict.values()]
    common_years = all_years[0]
    for years in all_years[1:]:
        common_years = np.intersect1d(common_years, years)

    if len(common_years) == 0:
        print("ERROR: The loaded files share no common years!")
        for info in data_dict.values():
            info['ds'].close()
        return None

    print(f"\nCommon years: {len(common_years)} years from {common_years[0]} to {common_years[-1]}")

    try:
        stats_dict = {
            label: _sumchart_ensemble_stats(info['data'], info['years'], common_years)
            for label, info in data_dict.items()
        }
    finally:
        # The subsets are loaded into memory, so the file handles can be released.
        for info in data_dict.values():
            info['ds'].close()

    colors = {
        'All (Total)': '#2E86AB',
        'LWS': '#A23B72',
        'Sterodynamics': '#F18F01'
    }
    labels_list = list(stats_dict)
    first_year, final_year = common_years[0], common_years[-1]

    fig = plt.figure(figsize=(18, 12))

    # Panel 1: sum over samples per year.
    _sumchart_grouped_bars(
        plt.subplot(3, 3, 1),
        {label: stats_dict[label]['sum'] for label in labels_list},
        colors, common_years, f'Total {ft} (mm × samples)', f'Total (Sum) {ft} by Year')

    # Panels 2 and 3: final / first common year, sum ± N·std (same units).
    for position, index, which, year in ((2, -1, 'Final', final_year), (3, 0, 'First', first_year)):
        sums = [stats_dict[label]['sum'][index] for label in labels_list]
        spreads = [stats_dict[label]['sum_std'][index] for label in labels_list]
        _sumchart_category_bars(
            plt.subplot(3, 3, position), labels_list, sums, colors, f'{ft} (mm × samples)',
            f'{which} Year ({year}) Comparison\n(Sum ± N × Std Dev)',
            errors=spreads,
            texts=[f'{s:.0f}±{e:.0f}' for s, e in zip(sums, spreads)],
            fontsize=9)

    # Panel 4: change in the sum between the first and final common year.
    total_changes = [stats_dict[label]['sum'][-1] - stats_dict[label]['sum'][0]
                     for label in labels_list]
    _sumchart_category_bars(
        plt.subplot(3, 3, 4), labels_list, total_changes, colors, 'Total Change (mm × samples)',
        f'Total {ft} Change\n({first_year} to {final_year})',
        texts=[f'{change:.0f}' for change in total_changes])

    # Panel 5: per-sample standard deviation per year (mm).
    _sumchart_grouped_bars(
        plt.subplot(3, 3, 5),
        {label: stats_dict[label]['std'] for label in labels_list},
        colors, common_years, 'Std Dev (mm)', 'Ensemble Spread (Std Dev) by Year')

    # Panel 6: P05-P95 spread of the ensemble members (not a confidence interval).
    _sumchart_grouped_bars(
        plt.subplot(3, 3, 6),
        {label: stats_dict[label]['p95'] - stats_dict[label]['p05'] for label in labels_list},
        colors, common_years, 'Uncertainty Range (mm)', '90% Ensemble Range (P95-P05) by Year')

    # Panel 7: time-averaged sum with time-averaged N·std as error bars.
    avg_sums = [np.mean(stats_dict[label]['sum']) for label in labels_list]
    avg_spreads = [np.mean(stats_dict[label]['sum_std']) for label in labels_list]
    _sumchart_category_bars(
        plt.subplot(3, 3, 7), labels_list, avg_sums, colors, f'Average {ft} (mm × samples)',
        f'Average Sum Across All Years\n({first_year}-{final_year})',
        errors=avg_spreads,
        texts=[f'{avg:.0f}' for avg in avg_sums])

    # Panel 8: each component's final-year sum as a percentage of the total.
    ax8 = plt.subplot(3, 3, 8)
    if 'All (Total)' in stats_dict:
        total_final = stats_dict['All (Total)']['sum'][-1]
        contrib_labels = [label for label in labels_list if label != 'All (Total)']
        contributions = [
            (stats_dict[label]['sum'][-1] / total_final) * 100 if total_final != 0 else 0
            for label in contrib_labels
        ]
        _sumchart_category_bars(
            ax8, contrib_labels, contributions, colors, 'Contribution (%)',
            f'Component Contribution to Total\n(Final Year: {final_year})',
            texts=[f'{pct:.1f}%' for pct in contributions])
        ax8.axhline(y=100, color='red', linestyle='--', linewidth=2, alpha=0.5, label='100%')

    # Panel 9: rate of change of the sum between consecutive common years.
    ax9 = plt.subplot(3, 3, 9)
    year_steps = np.diff(common_years)
    for label in labels_list:
        rates = np.diff(stats_dict[label]['sum']) / year_steps
        ax9.plot(common_years[:-1], rates, marker='o', linewidth=2,
                 label=label, color=colors[label])
    ax9.set_xlabel('Year')
    ax9.set_ylabel('Rate of Change ((mm × samples)/year)')
    ax9.set_title(f'Rate of {ft} Change')
    ax9.legend()
    ax9.grid(True, alpha=0.3)
    ax9.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

    plt.suptitle(f'{ft} Comparison (SUM): All vs Components\n' +
                 f'({len(common_years)} years from {first_year} to {final_year})',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()

    output_file = Path(output_dir) / f'{file_type}_barchart_comparison_sum.png'
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"\nSaved bar chart comparison to: {output_file}")

    plt.show()
    # Free the figure so repeated runs do not accumulate open figures.
    plt.close(fig)

    _sumchart_print_summary(stats_dict, common_years)
    _sumchart_verify_additivity(stats_dict)

    return stats_dict

# Run comparisons for both GSLR and LSLR using SUM
print("="*80)
print("COMPARING SEA LEVEL TOTALS USING SUM")
print("="*80)

# Compare GSLR files with SUM
gslr_sum_results = compare_files_barchart_sum(output_dir='./data/output', file_type='gslr')

# Compare LSLR files with SUM
lslr_sum_results = compare_files_barchart_sum(output_dir='./data/output', file_type='lslr')

print("\n" + "="*80)
print("SUMMARY (SUM METHOD)")
print("="*80)
# Results may be None (nothing loaded) or lack the 'All (Total)' entry when
# only component files exist; report only what is actually available.
for name, results in (('GSLR', gslr_sum_results), ('LSLR', lslr_sum_results)):
    total_stats = (results or {}).get('All (Total)')
    if total_stats is not None:
        print(f"{name} Final Year Total (All): {total_stats['sum'][-1]:.2f} mm×samples")
print("="*80)