In [None]:
#!/usr/bin/env python
"""
figure1_interpretability_overview.py — Publication Figure 1
=========================================================================

Creates Figure 1: Embedding Interpretability Overview

Configuration:
 - More space between panel titles and figures (increased pad)
 - Panel (b) reference lines moved to avoid title overlap
 - Removed statistics box from panel (c)
 - Increased all font sizes for labels, ticks, and numbers

Usage:
 python 04_figure1_interpretability_overview.py

"""

import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch, Rectangle
from matplotlib.lines import Line2D
from scipy.cluster.hierarchy import linkage, leaves_list
from scipy.spatial.distance import pdist
import os
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# CONFIGURATION
# =============================================================================

# Publication-quality matplotlib defaults, applied globally.
# Several panel functions pass explicit font sizes (e.g. fontsize=13 titles),
# so these values act as fallbacks where no explicit size is given.
plt.rcParams.update({
  'font.family': 'DejaVu Sans',
  'font.size': 11, # base font size
  'axes.linewidth': 1.0,
  'axes.labelsize': 13, # axis labels
  'axes.titlesize': 14, # panel titles (fallback)
  'axes.titleweight': 'bold',
  'xtick.labelsize': 10, # tick labels
  'ytick.labelsize': 10,
  'figure.dpi': 150, # interactive/preview resolution
  'savefig.dpi': 300,
  'savefig.bbox': 'tight',
  'savefig.pad_inches': 0.15,
  'legend.fontsize': 10,
  'legend.framealpha': 0.95,
})

# RESULTS_DIR: where the core-analysis CSV/NPZ outputs are read from.
# FIG_DIR: where the rendered figure files are written.
# DPI: resolution passed explicitly to fig.savefig.
RESULTS_DIR = 'results'
FIG_DIR = 'results/figures'
DPI = 300

# Short human-readable labels for the environmental-variable columns.
# Call sites fall back to the raw column name for unknown keys.
ENV_LABELS = {
  'elevation': 'Elevation', 'slope': 'Slope', 'aspect': 'Aspect',
  'soil_clay_pct': 'Clay %', 'soil_organic_carbon': 'Organic C',
  'soil_ph': 'Soil pH', 'soil_water_capacity': 'Water Cap.',
  'flow_acc_log': 'Flow Acc.', 'tree_cover_2000': 'Tree Cover',
  'impervious_pct': 'Impervious',
  'ndvi_mean': 'NDVI', 'ndvi_max': 'NDVI max',
  'evi_mean': 'EVI', 'lai_mean': 'LAI',
  'lst_day_c': 'LST Day', 'lst_night_c': 'LST Night',
  'albedo': 'Albedo',
  'precip_annual_mm': 'Precip.', 'precip_max_month': 'Max Precip.',
  'temp_mean_c': 'Temp Mean', 'temp_range_c': 'Temp Range',
  'soil_moisture': 'Soil Moist.', 'runoff_annual_mm': 'Runoff',
  'et_annual_mm': 'ET',
  'nightlights': 'Nightlights', 'pop_density': 'Pop. Density',
}

# Category of each environmental variable; drives the category coloring in
# panels (a) and (c). Call sites treat unknown variables as 'Urban'.
ENV_CATEGORY = {
  'elevation': 'Terrain', 'slope': 'Terrain', 'aspect': 'Terrain',
  'soil_clay_pct': 'Soil', 'soil_organic_carbon': 'Soil',
  'soil_ph': 'Soil', 'soil_water_capacity': 'Soil',
  'flow_acc_log': 'Hydrology', 'tree_cover_2000': 'Vegetation',
  'impervious_pct': 'Urban',
  'ndvi_mean': 'Vegetation', 'ndvi_max': 'Vegetation',
  'evi_mean': 'Vegetation', 'lai_mean': 'Vegetation',
  'lst_day_c': 'Temperature', 'lst_night_c': 'Temperature',
  'albedo': 'Radiation',
  'precip_annual_mm': 'Climate', 'precip_max_month': 'Climate',
  'temp_mean_c': 'Temperature', 'temp_range_c': 'Temperature',
  'soil_moisture': 'Hydrology', 'runoff_annual_mm': 'Hydrology',
  'et_annual_mm': 'Hydrology',
  'nightlights': 'Urban', 'pop_density': 'Urban',
}

# One color per category.
# NOTE(review): this palette is not verified colorblind-safe. Crimson vs
# forest green, and goldenrod vs gold, are likely hard to tell apart under
# common color-vision deficiencies; consider an Okabe-Ito style palette.
CATEGORY_COLORS = {
  'Terrain':   '#8B4513', # Saddle brown
  'Soil':    '#DAA520', # Goldenrod
  'Vegetation': '#228B22', # Forest green
  'Temperature': '#DC143C', # Crimson
  'Climate':   '#4169E1', # Royal blue
  'Hydrology':  '#00CED1', # Dark turquoise
  'Urban':    '#696969', # Dim gray
  'Radiation':  '#FFD700', # Gold
}

# Fixed order of entries in the shared category legend.
CATEGORY_ORDER = ['Terrain', 'Soil', 'Vegetation', 'Temperature', 
         'Climate', 'Hydrology', 'Urban', 'Radiation']


# =============================================================================
# DATA LOADING
# =============================================================================

def load_results():
    """Load all analysis results needed for Figure 1.

    Files are read from ``RESULTS_DIR``:

    * ``spearman_matrix.csv``       -> ``R['spearman']`` (first column as index)
    * ``dimension_dictionary.csv``  -> ``R['dd']``
    * ``rf_r2_scores.csv``          -> ``R['rf_r2']``
    * ``transformer_r2_scores.csv`` -> ``R['trans_r2']``
    * ``clustering_orders.npz``     -> ``R['row_order']``, ``R['col_order']``

    A missing file is reported on stdout and stored as ``None``, so the caller
    decides whether the figure can still be drawn.

    Returns:
        dict: every key listed above, each holding the loaded object or ``None``.
    """
    print("Loading results...")

    def read_csv(filename, label, describe, **read_kwargs):
        """Read one CSV from RESULTS_DIR, print a status line, or return None."""
        path = f'{RESULTS_DIR}/{filename}'
        if not os.path.exists(path):
            print(f" ✗ {label} not found")
            return None
        df = pd.read_csv(path, **read_kwargs)
        print(f" ✓ {label}: {describe(df)}")
        return df

    # Dict literal entries are evaluated in order, so status lines print in
    # the same order as the files are listed above.
    R = {
        'spearman': read_csv('spearman_matrix.csv', 'Spearman matrix',
                             lambda df: df.shape, index_col=0),
        'dd': read_csv('dimension_dictionary.csv', 'Dimension dictionary',
                       lambda df: f"{len(df)} dimensions"),
        'rf_r2': read_csv('rf_r2_scores.csv', 'RF R² scores',
                          lambda df: f"{len(df)} variables"),
        'trans_r2': read_csv('transformer_r2_scores.csv', 'Transformer R² scores',
                             lambda df: f"{len(df)} variables"),
    }

    # Precomputed heatmap clustering orders. np.load on an .npz returns an
    # NpzFile that keeps the archive open until closed, so read it in `with`.
    path = f'{RESULTS_DIR}/clustering_orders.npz'
    if os.path.exists(path):
        with np.load(path) as co:
            R['row_order'] = co['row_order']
            R['col_order'] = co['col_order']
        print(" ✓ Clustering orders loaded")
    else:
        R['row_order'] = None
        R['col_order'] = None
        print(" ✗ Clustering orders not found (will compute)")

    return R


# =============================================================================
# PANEL (A): BI-CLUSTERED HEATMAP
# =============================================================================

def _leaf_order(values):
    """Return the average-linkage dendrogram leaf order for the rows of ``values``.

    NaNs are replaced by 0 before clustering. Correlation distance
    (1 - Pearson r between rows) is undefined for constant rows, and ``pdist``
    returns NaN for those pairs. Such pairs are set to 1.0 ("uncorrelated")
    because ``linkage`` rejects non-finite distances.
    """
    vals = np.nan_to_num(np.asarray(values, dtype=float), nan=0.0)
    if vals.shape[0] < 2:
        return np.arange(vals.shape[0])
    dists = np.nan_to_num(pdist(vals, 'correlation'), nan=1.0)
    # Clip tiny negative round-off so the condensed matrix is a valid distance.
    return leaves_list(linkage(np.clip(dists, 0.0, None), method='average'))


def panel_a_heatmap(ax, corr_df, row_order, col_order):
    """Draw the bi-clustered Spearman correlation heatmap for panel (a).

    Rows are embedding dimensions (``corr_df.index``) and columns are
    environmental variables (``corr_df.columns``). A strip of category-colored
    rectangles is drawn just above the matrix, and a colorbar is attached to ``ax``.

    Args:
        ax: Matplotlib Axes to draw into.
        corr_df: DataFrame of Spearman correlations (dimensions x variables).
            NaN cells are painted light gray.
        row_order, col_order: Optional integer index arrays with precomputed
            leaf orders. If either is missing, or its length does not match
            ``corr_df``, both orders are recomputed by average-linkage
            clustering on correlation distance.

    Returns:
        tuple[list, list]: row labels and column labels in plotted order.
    """
    ae_cols = corr_df.index.tolist()
    env_cols = corr_df.columns.tolist()

    # Only trust precomputed orders that match this matrix; orders from a
    # different run would mislabel rows or raise IndexError.
    have_orders = (
        row_order is not None and col_order is not None
        and len(row_order) == len(ae_cols)
        and len(col_order) == len(env_cols)
    )
    if not have_orders:
        row_order = _leaf_order(corr_df.values)
        col_order = _leaf_order(corr_df.values.T)

    dim_ordered = [ae_cols[i] for i in row_order]
    env_ordered = [env_cols[i] for i in col_order]

    # Masked (NaN) cells are drawn with the colormap's "bad" color.
    hm = np.ma.masked_invalid(corr_df.loc[dim_ordered, env_ordered].values)
    cmap = plt.cm.RdBu_r.copy()
    cmap.set_bad('#F5F5F5')
    im = ax.imshow(hm, cmap=cmap, vmin=-1, vmax=1, aspect='auto',
                   interpolation='nearest')

    # X axis: every environmental variable, short labels.
    ax.set_xticks(np.arange(len(env_ordered)))
    ax.set_xticklabels([ENV_LABELS.get(v, v) for v in env_ordered],
                       rotation=55, ha='right', fontsize=8)

    # Y axis: label only every 4th dimension so labels stay legible.
    y_ticks = np.arange(0, len(dim_ordered), 4)
    ax.set_yticks(y_ticks)
    ax.set_yticklabels([dim_ordered[i] for i in y_ticks],
                       fontsize=8, family='monospace')

    # Category strip spanning y = -3 .. -0.5, i.e. just above row 0 (imshow
    # puts row 0 at the top). clip_on=False lets it draw outside the axes.
    for idx, v in enumerate(env_ordered):
        color = CATEGORY_COLORS.get(ENV_CATEGORY.get(v, 'Urban'), '#CCCCCC')
        ax.add_patch(Rectangle((idx - 0.5, -3), 1, 2.5, facecolor=color,
                               edgecolor='white', linewidth=0.3, clip_on=False))

    ax.set_xlabel('Environmental Variable', fontweight='bold', fontsize=12, labelpad=18)
    ax.set_ylabel('AlphaEarth Dimension', fontweight='bold', fontsize=12)

    cbar = plt.colorbar(im, ax=ax, fraction=0.025, pad=0.02, shrink=0.85)
    cbar.set_label('Spearman ρ', fontsize=11)
    cbar.ax.tick_params(labelsize=9)

    # Counts come from the data rather than being hard-coded in the title.
    n_dims, n_vars = corr_df.shape
    ax.set_title(f'(a) Spearman Correlation Matrix ({n_dims} dims × {n_vars} vars)',
                 fontweight='bold', fontsize=13, pad=20)

    return dim_ordered, env_ordered


# =============================================================================
# PANEL (B): TOP DIMENSIONS LOLLIPOP CHART
# =============================================================================

def panel_b_top_dimensions(ax, dd, n_top=20):
    """Draw a horizontal lollipop chart of the top interpretable dimensions.

    Plots the ``n_top`` dimensions with the largest ``sp_abs_max``, with the
    strongest on the top row. Each stem and dot is colored by ``sp_category``.
    Each row is annotated with its primary variable and Spearman ρ.
    Reference lines at |ρ| = 0.5 and 0.7 are labelled below the lowest row, so
    the labels cannot collide with the panel title.

    Args:
        ax: Matplotlib Axes to draw into.
        dd: Dimension dictionary with ``dimension``, ``sp_abs_max``, ``sp_rho``,
            ``sp_primary`` and ``sp_category`` columns.
        n_top: Maximum number of dimensions to show.
    """
    # Ascending sort + tail: the strongest dimension ends up at the highest y.
    top_dims = dd.sort_values('sp_abs_max', ascending=True).tail(n_top)
    n_shown = len(top_dims)
    y_pos = np.arange(n_shown)
    colors = [CATEGORY_COLORS.get(cat, '#999999') for cat in top_dims['sp_category']]

    # Stems plus right-hand annotations, one row per dimension.
    for i, (_, row) in enumerate(top_dims.iterrows()):
        color = colors[i]
        abs_rho = row['sp_abs_max']
        ax.hlines(y=i, xmin=0, xmax=abs_rho, color=color, linewidth=2, alpha=0.7)

        var_label = ENV_LABELS.get(row['sp_primary'], row['sp_primary'])
        sign = '+' if row['sp_rho'] > 0 else ''
        # annotation_clip=False: by default matplotlib hides a data-coordinate
        # annotation whose anchor lies outside the axes. That happens here
        # whenever |ρ| > 0.98, because the anchor x then exceeds the 1.0 x-limit.
        ax.annotate(f"{var_label} ({sign}{row['sp_rho']:.2f})",
                    xy=(abs_rho + 0.02, i), va='center', ha='left', fontsize=9,
                    color=color, fontweight='medium', annotation_clip=False)

    # Dots on top of the stems.
    ax.scatter(top_dims['sp_abs_max'], y_pos, c=colors, s=100,
               edgecolors='white', linewidths=2, zorder=5)

    ax.set_yticks(y_pos)
    ax.set_yticklabels(top_dims['dimension'], fontsize=10, family='monospace',
                       fontweight='bold')

    # Threshold lines, labelled at y=-1.5 (below row 0) rather than near the title.
    for level, color in ((0.7, '#2E7D32'), (0.5, '#F57C00')):
        ax.axvline(level, color=color, linestyle='--', linewidth=1.2, alpha=0.7)
        ax.text(level, -1.5, f'|ρ|={level}', fontsize=9, color=color,
                ha='center', va='top', fontweight='bold')

    ax.set_xlim(0, 1.0)
    ax.set_ylim(-2.5, n_shown - 0.5)  # extra room below row 0 for the labels
    ax.set_xlabel('|Spearman ρ|', fontweight='bold', fontsize=12)
    ax.grid(axis='x', alpha=0.3, linestyle='-', linewidth=0.5)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Title reflects how many dimensions were actually plotted.
    ax.set_title(f'(b) Top {n_shown} Interpretable Dimensions',
                 fontweight='bold', fontsize=13, pad=20)


# =============================================================================
# PANEL (C): PREDICTIVE PERFORMANCE
# =============================================================================

def panel_c_predictive_performance(ax, rf_r2, trans_r2):
    """Draw the grouped RF vs Transformer R² bar chart for panel (c).

    Variables are sorted by RF R² (``rf_r2['r2_cv']``, descending), and each
    bar pair is colored by the variable's category. Transformer bars come from
    ``trans_r2['val_r2']``; they are hatched and drawn only when at least one
    value is finite.

    Args:
        ax: Matplotlib Axes to draw into.
        rf_r2: DataFrame with ``variable`` and ``r2_cv`` columns, or ``None``
            (a placeholder message is drawn instead).
        trans_r2: DataFrame with ``variable`` and ``val_r2`` columns, or ``None``.
    """
    if rf_r2 is None:
        ax.text(0.5, 0.5, 'R² data not available', ha='center', va='center',
                transform=ax.transAxes, fontsize=14)
        ax.set_title('(c) Predictive Power from 64-D Embeddings',
                     fontweight='bold', fontsize=13, pad=20)
        return

    rf_sorted = rf_r2.sort_values('r2_cv', ascending=False)
    vars_sorted = rf_sorted['variable'].tolist()
    rf_vals = rf_sorted['r2_cv'].to_numpy(dtype=float)

    # Build one dict instead of filtering the frame per variable. The first
    # row per variable wins; a missing variable or column gives NaN.
    trans_lookup = {}
    if trans_r2 is not None and 'val_r2' in trans_r2.columns:
        first_rows = trans_r2.drop_duplicates('variable')
        trans_lookup = dict(zip(first_rows['variable'], first_rows['val_r2']))
    trans_vals = np.array([trans_lookup.get(v, np.nan) for v in vars_sorted], dtype=float)
    has_trans = bool(np.isfinite(trans_vals).any())

    x = np.arange(len(vars_sorted))
    width = 0.35
    colors = [CATEGORY_COLORS.get(ENV_CATEGORY.get(v, 'Urban'), '#999999')
              for v in vars_sorted]

    ax.bar(x - width / 2, rf_vals, width, color=colors, alpha=0.85,
           edgecolor='white', linewidth=0.5, label='Random Forest')
    if has_trans:
        ax.bar(x + width / 2, trans_vals, width, color=colors, alpha=0.45,
               edgecolor='black', linewidth=0.5, hatch='///', label='Transformer')

    ax.set_xticks(x)
    ax.set_xticklabels([ENV_LABELS.get(v, v) for v in vars_sorted],
                       rotation=55, ha='right', fontsize=9)

    # Threshold lines, labelled just above each line at the right edge.
    label_x = len(vars_sorted) - 0.3
    for level, color in ((0.9, '#2E7D32'), (0.7, '#F57C00'), (0.5, '#D32F2F')):
        ax.axhline(level, color=color, linestyle='--', linewidth=1.2, alpha=0.7)
        ax.text(label_x, level + 0.01, f'R²={level}', fontsize=9, color=color,
                ha='right', va='bottom', fontweight='bold')

    # Extend the y-axis below 0 when any R² is negative; a fixed 0 floor
    # would hide those bars entirely.
    all_vals = np.concatenate([rf_vals, trans_vals])
    all_vals = all_vals[np.isfinite(all_vals)]
    y_min = min(0.0, float(all_vals.min()) - 0.05) if all_vals.size else 0.0
    if y_min < 0:
        ax.axhline(0, color='black', linewidth=0.8)

    ax.set_ylabel('R² (5-fold CV)', fontweight='bold', fontsize=12)
    ax.set_ylim(y_min, 1.05)
    ax.set_xlim(-0.5, len(vars_sorted) - 0.5)
    ax.grid(axis='y', alpha=0.3, linestyle='-', linewidth=0.5)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Neutral-gray legend describing bar style only (color encodes category).
    legend_elements = [
        mpatches.Patch(facecolor='gray', alpha=0.85, edgecolor='white',
                       label='Random Forest'),
    ]
    if has_trans:
        legend_elements.append(
            mpatches.Patch(facecolor='gray', alpha=0.45, edgecolor='black',
                           hatch='///', label='Transformer'))
    ax.legend(handles=legend_elements, loc='upper right', fontsize=10,
              framealpha=0.95, bbox_to_anchor=(0.88, 1.0))

    ax.set_title('(c) Predictive Power: Environmental Variables from 64-D Embeddings',
                 fontweight='bold', fontsize=13, pad=20)


# =============================================================================
# MAIN FIGURE CREATION
# =============================================================================

def create_figure1(R):
    """Compose and save Figure 1, the interpretability overview.

    Layout (2x2 GridSpec):
      - Panel (a): Spearman heatmap (top-left)
      - Panel (b): top-dimension lollipop chart (top-right)
      - Panel (c): R² bar chart (bottom row, both columns)
    A shared category legend sits below the panels.

    Args:
        R: Results dict from ``load_results()``. ``spearman`` and ``dd`` are
            required; ``row_order``, ``col_order``, ``rf_r2`` and ``trans_r2``
            are optional.

    Writes ``fig1_interpretability_overview.png`` and ``.pdf`` into FIG_DIR.
    Prints an error and returns early if required data is missing.
    """
    print("\n Creating Figure 1: Interpretability Overview...")

    if R['spearman'] is None or R['dd'] is None:
        print("  ERROR: Missing required data (spearman_matrix.csv or dimension_dictionary.csv)")
        return

    fig = plt.figure(figsize=(18, 15))
    try:
        gs = gridspec.GridSpec(2, 2,
                               height_ratios=[1.3, 1],
                               width_ratios=[1.1, 1],
                               hspace=0.40,
                               wspace=0.30)

        panel_a_heatmap(fig.add_subplot(gs[0, 0]), R['spearman'],
                        R.get('row_order'), R.get('col_order'))
        panel_b_top_dimensions(fig.add_subplot(gs[0, 1]), R['dd'], n_top=20)
        panel_c_predictive_performance(fig.add_subplot(gs[1, :]),
                                       R.get('rf_r2'), R.get('trans_r2'))

        # Shared category legend, one column per category in CATEGORY_ORDER.
        legend_handles = [
            mpatches.Patch(facecolor=CATEGORY_COLORS[cat], edgecolor='white',
                           linewidth=0.5, label=cat)
            for cat in CATEGORY_ORDER
        ]
        fig.legend(handles=legend_handles,
                   loc='lower center',
                   ncol=len(legend_handles),
                   fontsize=11,
                   framealpha=0.95,
                   title='Environmental Variable Category',
                   title_fontsize=12,
                   bbox_to_anchor=(0.5, 0.01),
                   handlelength=1.5,
                   handleheight=1,
                   columnspacing=1.2)

        fig.suptitle('Figure 1: AlphaEarth Embedding Interpretability Analysis',
                     fontsize=16, fontweight='bold', y=0.98)

        # Keep the bottom 6% free for the legend and the top 5% for the title.
        # NOTE(review): tight_layout may recompute the hspace/wspace given to
        # the GridSpec above; confirm the final spacing is as intended.
        fig.tight_layout(rect=[0, 0.06, 1, 0.95])

        os.makedirs(FIG_DIR, exist_ok=True)
        for ext in ('png', 'pdf'):
            fig.savefig(f'{FIG_DIR}/fig1_interpretability_overview.{ext}',
                        dpi=DPI, facecolor='white', bbox_inches='tight')
    finally:
        # Release the figure even if a panel or save step raises.
        plt.close(fig)

    print("  ✓ Saved: fig1_interpretability_overview.png/pdf")


# =============================================================================
# MAIN
# =============================================================================

def main():
    """Script entry point: load results and render Figure 1 into FIG_DIR."""
    os.makedirs(FIG_DIR, exist_ok=True)

    print("=" * 70)
    print("FIGURE 1: INTERPRETABILITY OVERVIEW")
    print("=" * 70)

    R = load_results()

    if R['spearman'] is None or R['dd'] is None:
        print("ERROR: Required data not found. Run 01_core_analysis.py first.")
        return

    create_figure1(R)

    print("\n" + "=" * 70)
    print("COMPLETE")
    print("=" * 70)
    print(f" Output: {FIG_DIR}/fig1_interpretability_overview.png/pdf")

    # Summary of the layout choices baked into this figure.
    print("\n Configuration:")
    for note in (
        "Increased padding between titles and figures",
        "Panel (b) reference lines now at bottom, not overlapping title",
        "Removed statistics box from panel (c)",
        "Increased all font sizes (labels, ticks, numbers)",
    ):
        print(f" • {note}")


# Build Figure 1 when this code is executed directly rather than imported.
if __name__ == '__main__':
  main()

In [None]:
#!/usr/bin/env python
"""
figure2_spatial_interpretability.py — Publication Figure 2
===========================================================================

Creates Figure 2: Spatial Interpretability Demonstration

Features:
  - 6 dimension-variable pairs (6×2 grid) covering different categories
  - Shows Spearman ρ, RF R², and Transformer R² for each pair
  - No arrows in the middle
  - Better spacing between header and content
  - Improved colorbar placement
  - Generates missing grids on the fly from parquet data

Usage:
  python 05_figure2_spatial_interpretability.py

"""

import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle
import pyarrow.parquet as pq
import os
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# CONFIGURATION
# =============================================================================

# Publication-quality matplotlib defaults, applied globally for Figure 2.
plt.rcParams.update({
    'font.family': 'DejaVu Sans',
    'font.size': 10,              # base font size
    'axes.linewidth': 0.8,
    'axes.labelsize': 11,
    'axes.titlesize': 11,         # per-panel titles
    'axes.titleweight': 'bold',
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'figure.dpi': 150,            # interactive/preview resolution
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.1,
})

# RESULTS_DIR: CSV/NPZ outputs of the core analysis.
# FIG_DIR: where rendered figures are written.
# DATA_DIR: per-year ``conus_{year}_unified.parquet`` files used for on-the-fly
# rasterization.
RESULTS_DIR = 'results'
FIG_DIR = 'results/figures'
DATA_DIR = '../data/unified_conus'
DPI = 300

# Full, human-readable environmental variable labels, with units where known.
ENV_LABELS_FULL = {
    'elevation': 'Elevation (m)', 
    'slope': 'Slope (°)', 
    'aspect': 'Aspect (°)',
    'soil_clay_pct': 'Soil Clay (%)', 
    'soil_organic_carbon': 'Soil Organic Carbon',
    'soil_ph': 'Soil pH', 
    'soil_water_capacity': 'Soil Water Capacity',
    'flow_acc_log': 'Flow Accumulation (log)', 
    'tree_cover_2000': 'Tree Cover (%)',
    'impervious_pct': 'Impervious Surface (%)',
    'ndvi_mean': 'NDVI (mean)', 
    'ndvi_max': 'NDVI (max)',
    'evi_mean': 'Enhanced Vegetation Index', 
    'lai_mean': 'Leaf Area Index',
    'lst_day_c': 'Land Surface Temp. Day (°C)', 
    'lst_night_c': 'Land Surface Temp. Night (°C)',
    'albedo': 'Surface Albedo',
    'precip_annual_mm': 'Annual Precipitation (mm)', 
    'precip_max_month': 'Max Monthly Precip. (mm)',
    'temp_mean_c': 'Mean Air Temperature (°C)', 
    'temp_range_c': 'Diurnal Temp. Range (°C)',
    'soil_moisture': 'Soil Moisture', 
    'runoff_annual_mm': 'Annual Runoff (mm)',
    'et_annual_mm': 'Evapotranspiration (mm/yr)',
    'nightlights': 'Nighttime Lights', 
    'pop_density': 'Population Density',
}

# Category of each environmental variable.
ENV_CATEGORY = {
    'elevation': 'Terrain', 'slope': 'Terrain', 'aspect': 'Terrain',
    'soil_clay_pct': 'Soil', 'soil_organic_carbon': 'Soil',
    'soil_ph': 'Soil', 'soil_water_capacity': 'Soil',
    'flow_acc_log': 'Hydrology', 'tree_cover_2000': 'Vegetation',
    'impervious_pct': 'Urban',
    'ndvi_mean': 'Vegetation', 'ndvi_max': 'Vegetation',
    'evi_mean': 'Vegetation', 'lai_mean': 'Vegetation',
    'lst_day_c': 'Temperature', 'lst_night_c': 'Temperature',
    'albedo': 'Radiation',
    'precip_annual_mm': 'Climate', 'precip_max_month': 'Climate',
    'temp_mean_c': 'Temperature', 'temp_range_c': 'Temperature',
    'soil_moisture': 'Hydrology', 'runoff_annual_mm': 'Hydrology',
    'et_annual_mm': 'Hydrology',
    'nightlights': 'Urban', 'pop_density': 'Urban',
}

# Accent color per category.
CATEGORY_COLORS = {
    'Terrain':     '#8B4513',
    'Soil':        '#DAA520',
    'Vegetation':  '#228B22',
    'Temperature': '#DC143C',
    'Climate':     '#4169E1',
    'Hydrology':   '#00CED1',
    'Urban':       '#696969',
    'Radiation':   '#FFD700',
}

# Matplotlib colormap name per category, for environmental-variable maps.
CATEGORY_CMAPS = {
    'Terrain':     'terrain',
    'Soil':        'YlOrBr',
    'Vegetation':  'YlGn',
    'Temperature': 'RdYlBu_r',
    'Climate':     'Blues',
    'Hydrology':   'GnBu',
    'Urban':       'gray_r',
    'Radiation':   'YlOrRd',
}

# CONUS bounding box as [lon_min, lon_max, lat_min, lat_max] in degrees, and
# the raster cell size (degrees) used when gridding point data.
CONUS_EXTENT = [-125.0, -66.5, 24.5, 49.5]
GRID_SPACING = 0.025


# =============================================================================
# DATA LOADING
# =============================================================================

def load_results():
    """Load all analysis results needed for Figure 2.

    Files are read from RESULTS_DIR:

    * ``dimension_dictionary.csv``  -> ``R['dd']`` (required)
    * ``rf_r2_scores.csv``          -> ``R['rf_r2']`` and ``R['rf_r2_dict']``
      ({variable: r2_cv})
    * ``transformer_r2_scores.csv`` -> ``R['trans_r2']`` and ``R['trans_r2_dict']``
      ({variable: val_r2})
    * ``conus_grids.npz``           -> ``R['grids']`` ({name: array})

    Missing optional files leave the corresponding ``*_dict`` / ``grids`` entry
    empty. A score CSV without the expected columns also yields an empty dict.

    Returns:
        dict | None: the results dict, or ``None`` if the dimension dictionary
        is missing.
    """
    print("Loading results...")

    def score_lookup(df, value_col):
        """Map variable -> score, or {} (with a warning) if columns are missing."""
        if 'variable' not in df.columns or value_col not in df.columns:
            print(f"  ✗ Expected columns 'variable' and '{value_col}' not found")
            return {}
        return dict(zip(df['variable'], df[value_col]))

    R = {}

    # Dimension dictionary (required).
    path = f'{RESULTS_DIR}/dimension_dictionary.csv'
    if not os.path.exists(path):
        print("  ✗ Dimension dictionary not found")
        return None
    R['dd'] = pd.read_csv(path)
    print(f"  ✓ Dimension dictionary: {len(R['dd'])} dimensions")

    # RF R² scores.
    path = f'{RESULTS_DIR}/rf_r2_scores.csv'
    if os.path.exists(path):
        R['rf_r2'] = pd.read_csv(path)
        R['rf_r2_dict'] = score_lookup(R['rf_r2'], 'r2_cv')
        print(f"  ✓ RF R² scores: {len(R['rf_r2'])} variables")
    else:
        R['rf_r2_dict'] = {}
        print("  ✗ RF R² not found")

    # Transformer R² scores.
    path = f'{RESULTS_DIR}/transformer_r2_scores.csv'
    if os.path.exists(path):
        R['trans_r2'] = pd.read_csv(path)
        R['trans_r2_dict'] = score_lookup(R['trans_r2'], 'val_r2')
        print(f"  ✓ Transformer R² scores: {len(R['trans_r2'])} variables")
    else:
        R['trans_r2_dict'] = {}
        print("  ✗ Transformer R² not found")

    # Pre-computed CONUS grids. The `with` block closes the NpzFile once every
    # array has been read into memory.
    # NOTE(review): allow_pickle=True permits unpickling object arrays, which
    # can execute code; only load grid files produced by this pipeline.
    path = f'{RESULTS_DIR}/conus_grids.npz'
    if os.path.exists(path):
        with np.load(path, allow_pickle=True) as npz:
            R['grids'] = {name: npz[name] for name in npz.files}
        print(f"  ✓ Pre-computed grids: {len(R['grids'])} grids")
    else:
        R['grids'] = {}
        print("  ✗ No pre-computed grids (will generate from parquet)")

    return R


def rasterize_from_parquet(var_name, year=2021):
    """Load one column from a yearly CONUS parquet file and grid it.

    Each point is snapped to the nearest node of a regular lon/lat grid defined
    by CONUS_EXTENT ([lon_min, lon_max, lat_min, lat_max]) and GRID_SPACING
    (degrees). Row 0 is at lat_min (south) and column 0 at lon_min (west).
    Points outside the grid are dropped. If several points land on the same
    cell, NumPy does not specify which value is kept.

    Args:
        var_name: Column to rasterize (an environmental variable or an
            embedding dimension).
        year: Selects ``{DATA_DIR}/conus_{year}_unified.parquet``.

    Returns:
        np.ndarray | None: float32 array of shape (n_rows, n_cols) with NaN where
        no data fell. Returns None if the file or column is missing, if no valid
        rows remain, or if reading fails. Errors are printed, not raised.
    """
    fp = f'{DATA_DIR}/conus_{year}_unified.parquet'
    if not os.path.exists(fp):
        print(f"    ✗ File not found: {fp}")
        return None

    try:
        # Check the schema first so a missing column does not require a full read.
        if var_name not in pq.read_schema(fp).names:
            print(f"    ✗ Column {var_name} not in {year} data")
            return None

        cols_needed = ['longitude', 'latitude', var_name]
        # Drop rows with missing coordinates as well as missing values: casting
        # a NaN coordinate to int below gives an undefined index.
        df = pd.read_parquet(fp, columns=cols_needed).dropna(subset=cols_needed)

        if len(df) == 0:
            print(f"    ✗ No valid data for {var_name}")
            return None

        lon_min, lon_max, lat_min, lat_max = CONUS_EXTENT
        n_cols = int(round((lon_max - lon_min) / GRID_SPACING))
        n_rows = int(round((lat_max - lat_min) / GRID_SPACING))

        grid = np.full((n_rows, n_cols), np.nan, dtype=np.float32)
        col_idx = np.round((df['longitude'].to_numpy() - lon_min) / GRID_SPACING).astype(int)
        row_idx = np.round((df['latitude'].to_numpy() - lat_min) / GRID_SPACING).astype(int)
        inside = (col_idx >= 0) & (col_idx < n_cols) & (row_idx >= 0) & (row_idx < n_rows)
        grid[row_idx[inside], col_idx[inside]] = df[var_name].to_numpy()[inside]

        valid_count = int(np.isfinite(grid).sum())
        print(f"    ✓ Rasterized {var_name}: {valid_count:,} valid pixels")
        return grid

    except Exception as e:
        # Best-effort by design: a bad file or column should only skip this
        # pair, not abort the whole figure.
        print(f"    ✗ Error rasterizing {var_name}: {e}")
        return None


def get_grid(R, var_name, is_dimension=False):
    """Return the raster for ``var_name``, using ``R['grids']`` as a cache.

    Cache keys are ``dim_<name>`` for embedding dimensions and ``var_<name>``
    for environmental variables. On a miss, the grid is rasterized from
    parquet and cached on success. Keys that could not be rasterized are
    recorded in ``R['_missing_grids']``. Repeated lookups of an unavailable
    name then return None without re-reading the parquet file.

    Args:
        R: Results dict holding a ``'grids'`` mapping.
        var_name: Dimension or variable (column) name.
        is_dimension: Selects the ``dim_`` rather than the ``var_`` cache key.

    Returns:
        np.ndarray | None: the grid, or None if it is unavailable.
    """
    key = f"{'dim' if is_dimension else 'var'}_{var_name}"
    grids = R['grids']
    if key in grids:
        return grids[key]

    # Negative cache kept separate from R['grids'] so that mapping only ever
    # holds real arrays.
    missing = R.setdefault('_missing_grids', set())
    if key in missing:
        return None

    print(f"    Generating grid for {var_name}...")
    grid = rasterize_from_parquet(var_name)
    if grid is None:
        missing.add(key)
    else:
        grids[key] = grid
    return grid


def select_six_diverse_pairs(dd, R, n_pairs=6):
    """Pick dimension→variable pairs to map, favoring category diversity.

    First pass: for each category in a fixed preference order, take the
    highest-|ρ| (``sp_abs_max``) dimension of that ``sp_category`` for which
    both its own grid and its primary variable's (``sp_primary``) grid can be
    loaded. Second pass: if fewer than ``n_pairs`` were found, fill up from all
    dimensions in descending |ρ| order, ignoring category and skipping
    dimensions already chosen.

    Args:
        dd: Dimension dictionary with ``dimension``, ``sp_primary``, ``sp_rho``,
            ``sp_abs_max`` and ``sp_category`` columns.
        R: Results dict (grid cache plus ``rf_r2_dict`` / ``trans_r2_dict``).
        n_pairs: Maximum number of pairs to return (default 6).

    Returns:
        list[dict]: one dict per pair, with keys ``dimension``, ``variable``,
        ``rho``, ``category``, ``rf_r2``, ``trans_r2``, ``dim_grid`` and
        ``var_grid``.
    """
    print(f"\nSelecting {n_pairs} diverse dimension-variable pairs...")

    target_categories = ['Climate', 'Temperature', 'Vegetation', 'Terrain', 'Hydrology', 'Soil']
    rf_lookup = R.get('rf_r2_dict', {})
    trans_lookup = R.get('trans_r2_dict', {})
    selected = []

    def try_add(row, cat):
        """Append the pair described by ``row`` if both grids load; return success."""
        dim, var = row['dimension'], row['sp_primary']
        dim_grid = get_grid(R, dim, is_dimension=True)
        if dim_grid is None:
            return False
        var_grid = get_grid(R, var, is_dimension=False)
        if var_grid is None:
            return False

        pair = {
            'dimension': dim,
            'variable': var,
            'rho': row['sp_rho'],
            'category': cat,
            'rf_r2': rf_lookup.get(var, np.nan),
            'trans_r2': trans_lookup.get(var, np.nan),
            'dim_grid': dim_grid,
            'var_grid': var_grid,
        }
        selected.append(pair)
        print(f"  ✓ {cat}: {dim} → {var} (ρ={pair['rho']:+.3f}, "
              f"RF R²={pair['rf_r2']:.2f}, Trans R²={pair['trans_r2']:.2f})")
        return True

    # First pass: at most one pair per preferred category.
    for cat in target_categories:
        if len(selected) >= n_pairs:
            break
        cat_dims = dd[dd['sp_category'] == cat].sort_values('sp_abs_max', ascending=False)
        for _, row in cat_dims.iterrows():
            if try_add(row, cat):
                break

    # Second pass: top up from the strongest remaining dimensions.
    if len(selected) < n_pairs:
        used_dims = {pair['dimension'] for pair in selected}
        for _, row in dd.sort_values('sp_abs_max', ascending=False).iterrows():
            if len(selected) >= n_pairs:
                break
            if row['dimension'] in used_dims:
                continue
            if try_add(row, row['sp_category']):
                used_dims.add(row['dimension'])

    print(f"\nSelected {len(selected)} pairs for visualization")
    return selected


# =============================================================================
# FIGURE CREATION
# =============================================================================

def create_figure2(selected_pairs):
    """
    Create Figure 2: side-by-side embedding-dimension vs. variable maps.

    Layout: one row per pair x 2 columns
    - Left: the embedding dimension's grid on a diverging colormap whose
      limits are symmetric about zero (98th percentile of |value|)
    - Right: the paired environmental variable (2nd-98th percentile limits),
      titled with Spearman ρ plus RF and Transformer R² when they are finite
    - A category-coloured band is drawn above both maps of each row

    Parameters
    ----------
    selected_pairs : list of dict
        Entries with keys 'dimension', 'variable', 'rho', 'category',
        'rf_r2', 'trans_r2', 'dim_grid', 'var_grid' (as assembled by
        ``select_six_diverse_pairs``).

    Writes fig2_spatial_interpretability.png and .pdf into FIG_DIR.
    Returns early, without creating a figure, when ``selected_pairs`` is empty.
    """
    print("\nCreating Figure 2: Spatial Interpretability...")

    n_pairs = len(selected_pairs)
    if n_pairs == 0:
        # GridSpec(0, 2) raises ValueError, so there is nothing to lay out.
        print("  ✗ No dimension-variable pairs supplied; Figure 2 skipped.")
        return

    extent = CONUS_EXTENT
    # Height of the category band above each map, in data (axis) units.
    bar_height = 1.0

    def panel_label(k):
        # a..z, then a numeric fallback; a fixed 12-letter string would
        # raise IndexError as soon as more than 6 rows are requested.
        return chr(ord('a') + k) if k < 26 else str(k + 1)

    def add_category_band(ax, cat, cat_color):
        # Coloured strip just above the map top (clip_on=False lets it sit
        # outside the axes limits), labelled with the category name.
        ax.add_patch(Rectangle((extent[0], extent[3]), extent[1] - extent[0], bar_height,
                               facecolor=cat_color, edgecolor='none',
                               clip_on=False, zorder=10))
        ax.text((extent[0] + extent[1]) / 2, extent[3] + bar_height / 2, cat,
                ha='center', va='center', fontsize=8, fontweight='bold',
                color='white', zorder=11)

    # Figure height grows with the number of rows
    fig = plt.figure(figsize=(14, 3.2 * n_pairs + 1.5))

    fig.suptitle(
        'Figure 2: AlphaEarth Dimensions Encode Interpretable Geophysical Properties',
        fontsize=14, fontweight='bold', y=0.995
    )
    fig.text(0.5, 0.98,
             'Left: Embedding dimension values  |  Right: Corresponding environmental variable  |  Metrics: ρ=Spearman, RF/Trans=R²',
             ha='center', fontsize=10, style='italic', color='#555')

    # Top margin leaves room for the suptitle and subtitle above
    gs = gridspec.GridSpec(n_pairs, 2,
                           hspace=0.35, wspace=0.12,
                           left=0.06, right=0.94,
                           top=0.95, bottom=0.03)

    for row_idx, pair in enumerate(selected_pairs):
        dim = pair['dimension']
        var = pair['variable']
        rho = pair['rho']
        cat = pair['category']
        rf_r2 = pair['rf_r2']
        trans_r2 = pair['trans_r2']
        dim_grid = pair['dim_grid']
        var_grid = pair['var_grid']

        cat_color = CATEGORY_COLORS.get(cat, '#999')
        var_cmap = CATEGORY_CMAPS.get(cat, 'viridis')
        is_bottom_row = row_idx == n_pairs - 1

        # ===== LEFT PANEL: Embedding Dimension =====
        ax_dim = fig.add_subplot(gs[row_idx, 0])
        valid_dim = dim_grid[np.isfinite(dim_grid)]
        if len(valid_dim) > 0:
            # Symmetric limits so white in RdBu_r corresponds to zero
            vabs = np.percentile(np.abs(valid_dim), 98)
            im_dim = ax_dim.imshow(np.ma.masked_invalid(dim_grid), origin='lower', extent=extent,
                                   cmap='RdBu_r', vmin=-vabs, vmax=vabs,
                                   interpolation='nearest', aspect='auto')
            cbar_dim = plt.colorbar(im_dim, ax=ax_dim, fraction=0.03, pad=0.01,
                                    shrink=0.8, aspect=20)
            cbar_dim.ax.tick_params(labelsize=7)

        ax_dim.set_xlim(extent[0], extent[1])
        ax_dim.set_ylim(extent[2], extent[3])
        ax_dim.set_title(f'({panel_label(2 * row_idx)}) Dimension {dim}',
                         fontsize=11, fontweight='bold', pad=12)
        add_category_band(ax_dim, cat, cat_color)

        if is_bottom_row:
            ax_dim.set_xlabel('Longitude', fontsize=9)
        ax_dim.set_ylabel('Latitude', fontsize=9)
        ax_dim.tick_params(labelsize=7)

        # ===== RIGHT PANEL: Environmental Variable =====
        ax_var = fig.add_subplot(gs[row_idx, 1])
        valid_var = var_grid[np.isfinite(var_grid)]
        if len(valid_var) > 0:
            # 2nd-98th percentile limits keep outliers from washing out the map
            vmin_var = np.percentile(valid_var, 2)
            vmax_var = np.percentile(valid_var, 98)
            im_var = ax_var.imshow(np.ma.masked_invalid(var_grid), origin='lower', extent=extent,
                                   cmap=var_cmap, vmin=vmin_var, vmax=vmax_var,
                                   interpolation='nearest', aspect='auto')
            cbar_var = plt.colorbar(im_var, ax=ax_var, fraction=0.03, pad=0.01,
                                    shrink=0.8, aspect=20)
            cbar_var.ax.tick_params(labelsize=7)

            # Units are whatever sits inside the first "(...)" of the full label
            label_full = ENV_LABELS_FULL.get(var, var)
            open_idx = label_full.find('(')
            close_idx = label_full.find(')', open_idx + 1) if open_idx != -1 else -1
            if close_idx != -1:
                cbar_var.set_label(label_full[open_idx + 1:close_idx], fontsize=8)

        ax_var.set_xlim(extent[0], extent[1])
        ax_var.set_ylim(extent[2], extent[3])

        # Title: variable name plus whichever metrics are available
        metrics = [f'ρ={rho:+.2f}']
        if np.isfinite(rf_r2):
            metrics.append(f'RF={rf_r2:.2f}')
        if np.isfinite(trans_r2):
            metrics.append(f'Trans={trans_r2:.2f}')
        var_title = ENV_LABELS_FULL.get(var, var)
        ax_var.set_title(f'({panel_label(2 * row_idx + 1)}) {var_title}\n{", ".join(metrics)}',
                         fontsize=10, fontweight='bold', pad=12)
        add_category_band(ax_var, cat, cat_color)

        if is_bottom_row:
            ax_var.set_xlabel('Longitude', fontsize=9)
        ax_var.tick_params(labelsize=7)
        ax_var.set_yticklabels([])  # latitude labels only on the left column

        print(f"  ✓ Row {row_idx+1}: {dim} → {var}")

    os.makedirs(FIG_DIR, exist_ok=True)
    fig.savefig(f'{FIG_DIR}/fig2_spatial_interpretability.png',
                dpi=DPI, facecolor='white', bbox_inches='tight')
    fig.savefig(f'{FIG_DIR}/fig2_spatial_interpretability.pdf',
                dpi=DPI, facecolor='white', bbox_inches='tight')
    plt.close(fig)
    print(f"\n  ✓ Saved: fig2_spatial_interpretability.png/pdf")


# =============================================================================
# MAIN
# =============================================================================

def main():
    """Build Figure 2 end to end: load results, pick pairs, draw and save."""
    os.makedirs(FIG_DIR, exist_ok=True)

    print("=" * 70)
    print("FIGURE 2: SPATIAL INTERPRETABILITY")
    print("=" * 70)

    R = load_results()
    if R is None:
        print("ERROR: Could not load dimension dictionary. Run 01_core_analysis.py first.")
        return

    # Up to six dimension-variable pairs, one per figure row
    selected_pairs = select_six_diverse_pairs(R['dd'], R)

    # This script requires at least 4 rows before drawing
    if len(selected_pairs) < 4:
        print("ERROR: Not enough dimension-variable pairs with grid data.")
        return

    create_figure2(selected_pairs)

    print("\n" + "=" * 70)
    print("COMPLETE")
    print("=" * 70)
    print(f"\n  Output: {FIG_DIR}/fig2_spatial_interpretability.png/pdf")
    print("\n  Improvements in this version:")
    print("  • 6 dimension-variable pairs (vs 3)")
    print("  • Shows ρ (Spearman), RF R², and Transformer R² for each")
    print("  • No arrows between panels")
    print("  • Better spacing between title and content")
    print("  • Improved colorbar placement")
    print("  • Diverse categories represented")


if __name__ == '__main__':
    main()

In [None]:
#!/usr/bin/env python
"""
figure3_method_networks.py — Publication Figure 3
=====================================================

Panel (a): Three bipartite network graphs showing dimension-variable 
           connections for Spearman, RF, and Transformer methods
Panel (b): Method agreement scatter plot at the pair level

Usage:
  python 06_figure3_method_networks.py
"""

import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import os
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# CONFIGURATION
# =============================================================================

# Publication styling shared by every Figure 3 panel, grouped by concern.
plt.rcParams.update({
    # --- typography ---
    'font.family': 'DejaVu Sans',
    'font.size': 13,
    'axes.labelsize': 15,
    'axes.titlesize': 16,
    'axes.titleweight': 'bold',
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    # --- frame / legend appearance ---
    'axes.linewidth': 1.2,
    'legend.framealpha': 0.95,
    # --- on-screen and export resolution ---
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.05,
})

# Input/output locations and export resolution
RESULTS_DIR = 'results'
FIG_DIR = f'{RESULTS_DIR}/figures'
DPI = 300

# Compact node labels for the network panels (keyed by variable column name)
ENV_LABELS_SHORT = dict(
    elevation='Elev', slope='Slope', aspect='Aspect',
    soil_clay_pct='Clay', soil_organic_carbon='OrgC',
    soil_ph='pH', soil_water_capacity='SoilWC',
    flow_acc_log='FlowAcc', tree_cover_2000='Trees',
    impervious_pct='Imperv',
    ndvi_mean='NDVI', ndvi_max='NDVImax',
    evi_mean='EVI', lai_mean='LAI',
    lst_day_c='LSTday', lst_night_c='LSTngt',
    albedo='Albedo',
    precip_annual_mm='Precip', precip_max_month='PrecMax',
    temp_mean_c='Temp', temp_range_c='TempRng',
    soil_moisture='SoilM', runoff_annual_mm='Runoff',
    et_annual_mm='ET',
    nightlights='Lights', pop_density='PopDen',
)

# Thematic category of each variable (drives colours and node ordering)
ENV_CATEGORY = dict(
    elevation='Terrain', slope='Terrain', aspect='Terrain',
    soil_clay_pct='Soil', soil_organic_carbon='Soil',
    soil_ph='Soil', soil_water_capacity='Soil',
    flow_acc_log='Hydrology', tree_cover_2000='Vegetation',
    impervious_pct='Urban',
    ndvi_mean='Vegetation', ndvi_max='Vegetation',
    evi_mean='Vegetation', lai_mean='Vegetation',
    lst_day_c='Temperature', lst_night_c='Temperature',
    albedo='Radiation',
    precip_annual_mm='Climate', precip_max_month='Climate',
    temp_mean_c='Temperature', temp_range_c='Temperature',
    soil_moisture='Hydrology', runoff_annual_mm='Hydrology',
    et_annual_mm='Hydrology',
    nightlights='Urban', pop_density='Urban',
)

# One colour per category; insertion order doubles as the display order
CATEGORY_COLORS = dict(
    Terrain='#8B4513',
    Soil='#DAA520',
    Vegetation='#228B22',
    Temperature='#DC143C',
    Climate='#4169E1',
    Hydrology='#00CED1',
    Urban='#696969',
    Radiation='#FFD700',
)

CATEGORY_ORDER = list(CATEGORY_COLORS)


def load_results():
    """
    Load the matrices Figure 3 needs from RESULTS_DIR.

    Returns
    -------
    dict or None
        None when ``spearman_matrix.csv`` (required) is missing. Otherwise a
        dict with keys:
        'spearman'  -- DataFrame indexed by its first CSV column
        'rf_imp'    -- DataFrame indexed by its first CSV column, or None
        'trans_imp' -- DataFrame indexed by its first CSV column, or None
        'dd'        -- dimension dictionary DataFrame, or None
    """
    print("Loading results...")

    def read_optional(name, **read_kwargs):
        # Optional input: return None (and log it) instead of raising.
        path = f'{RESULTS_DIR}/{name}'
        if not os.path.exists(path):
            print(f"  - {name} not found (skipped)")
            return None
        return pd.read_csv(path, **read_kwargs)

    spearman_path = f'{RESULTS_DIR}/spearman_matrix.csv'
    if not os.path.exists(spearman_path):
        # Report which file is missing so the caller's generic error is actionable
        print(f"  ✗ Required file missing: {spearman_path}")
        return None

    R = {'spearman': pd.read_csv(spearman_path, index_col=0)}
    print(f"  ✓ Spearman: {R['spearman'].shape}")

    R['rf_imp'] = read_optional('rf_importance_matrix.csv', index_col=0)
    R['trans_imp'] = read_optional('transformer_importance_matrix.csv', index_col=0)
    R['dd'] = read_optional('dimension_dictionary.csv')
    return R


def draw_bipartite_network(ax, matrix, dims, vars_, title, method_type='spearman',
                           threshold=0.25, max_edges=150):
    """
    Draw a dimension -> variable bipartite network on ``ax``.

    Dimensions form the left column of nodes. Variables form the right
    column, grouped by CATEGORY_ORDER and then sorted by name. An edge is
    drawn for every pair whose weight exceeds ``threshold``, keeping at most
    ``max_edges`` of the strongest. Edge width and opacity grow with weight,
    and edge colour follows the variable's category.

    Parameters
    ----------
    ax : matplotlib Axes
    matrix : pandas.DataFrame
        Rows labelled by dimension, columns by variable.
    dims, vars_ : list
        Row and column labels to include.
    title : str
        Panel title (also echoed in the console log).
    method_type : str
        'spearman': weight is |value|.
        Any other value: values are treated as importances and scaled per
        variable by that variable's largest positive importance.
    threshold : float
        Edges need a weight strictly greater than this.
    max_edges : int
        Maximum number of edges kept (strongest first).

    Returns
    -------
    list of dict
        Drawn edges, strongest first. Keys: 'dim_idx', 'var_idx', 'dim',
        'var', 'value', 'category'.
    """
    values = matrix.loc[dims, vars_].to_numpy(dtype=float)
    if method_type == 'spearman':
        values = np.abs(values)
        edge_label = '|ρ|'
    else:
        col_max = np.nanmax(values, axis=0, keepdims=True)
        # Scale only by a positive maximum. A negative max (all-negative
        # column) would flip signs and turn the weakest entries into huge
        # weights. Zero/NaN maxima leave the column unscaled, and NaNs are
        # skipped below.
        col_max = np.where(col_max > 0, col_max, 1.0)
        values = values / col_max
        edge_label = 'Rel. Imp.'

    n_dims = len(dims)
    n_vars = len(vars_)

    # Order variables by category (unknown categories last), then by name
    cat_rank = {cat: k for k, cat in enumerate(CATEGORY_ORDER)}
    vars_sorted = sorted(
        vars_, key=lambda v: (cat_rank.get(ENV_CATEGORY.get(v, 'Urban'), 99), v))
    col_of = {v: k for k, v in enumerate(vars_)}
    values = values[:, [col_of[v] for v in vars_sorted]]

    # Node positions span the full axes height
    dim_y = np.linspace(0.99, 0.01, n_dims)
    var_y = np.linspace(0.99, 0.01, n_vars)
    dim_x = 0.10
    var_x = 0.90

    edges = [
        {'dim_idx': i, 'var_idx': j, 'dim': d, 'var': v,
         'value': values[i, j], 'category': ENV_CATEGORY.get(v, 'Urban')}
        for i, d in enumerate(dims)
        for j, v in enumerate(vars_sorted)
        if np.isfinite(values[i, j]) and values[i, j] > threshold
    ]
    edges = sorted(edges, key=lambda e: e['value'], reverse=True)[:max_edges]
    print(f"    {title}: {len(edges)} edges drawn")

    # Map weight in (threshold, 1] onto (0, 1] for width/alpha; guard threshold >= 1
    span = max(1.0 - threshold, 1e-9)
    for edge in reversed(edges):  # weakest first so strong edges end up on top
        i, j = edge['dim_idx'], edge['var_idx']
        strength = (edge['value'] - threshold) / span
        color = CATEGORY_COLORS.get(edge['category'], '#999999')

        # Three-point polyline through the panel centre
        x_points = [dim_x, 0.5, var_x]
        y_points = [dim_y[i], (dim_y[i] + var_y[j]) / 2, var_y[j]]
        ax.plot(x_points, y_points, color=color, linewidth=1.0 + 4.5 * strength,
                alpha=0.3 + 0.5 * strength, solid_capstyle='round', zorder=1)

    # Dimension nodes (left)
    for i, d in enumerate(dims):
        ax.scatter(dim_x, dim_y[i], s=90, c='#333333',
                   edgecolors='white', linewidths=0.6, zorder=3)
        ax.text(dim_x - 0.015, dim_y[i], d, ha='right', va='center',
                fontsize=8, family='monospace', fontweight='bold')

    # Variable nodes (right), coloured by category
    for j, v in enumerate(vars_sorted):
        color = CATEGORY_COLORS.get(ENV_CATEGORY.get(v, 'Urban'), '#999999')
        ax.scatter(var_x, var_y[j], s=110, c=color,
                   edgecolors='white', linewidths=0.6, zorder=3)
        ax.text(var_x + 0.015, var_y[j], ENV_LABELS_SHORT.get(v, v),
                ha='left', va='center', fontsize=9, color=color, fontweight='bold')

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')
    ax.set_title(title, fontsize=15, fontweight='bold', pad=6)

    ax.text(0.5, -0.03, f'{len(edges)} connections ({edge_label} > {threshold:.2f})',
            ha='center', va='top', fontsize=11, color='#444', transform=ax.transAxes)

    return edges


def draw_scatter(ax, spearman_df, rf_df, trans_df):
    """
    Scatter |Spearman ρ| against per-variable-normalised RF importance.

    Each point is one (dimension, variable) pair that is finite in both
    matrices. RF importance is divided by the largest RF importance of the
    same variable, and set to 0 when that maximum is not positive. The 12
    pairs with the largest |ρ| are labelled, and the Pearson r between the
    two axes is reported.

    Parameters
    ----------
    ax : matplotlib Axes
    spearman_df : pandas.DataFrame
        Dimensions (rows) x variables (columns).
    rf_df : pandas.DataFrame or None
        Same labelling as ``spearman_df``. When None, or when nothing
        overlaps, a placeholder is drawn instead of the scatter.
    trans_df : unused
        Kept for call-site compatibility.

    Returns
    -------
    pandas.DataFrame
        Columns dim, var, spearman, rf, category, rf_norm (possibly empty).
    """
    title = '(b) Linear vs Nonlinear Method Agreement'

    pairs = []
    if rf_df is not None:
        for d in spearman_df.index:
            if d not in rf_df.index:
                continue
            for v in spearman_df.columns:
                if v not in rf_df.columns:
                    continue
                sp_val = spearman_df.loc[d, v]
                rf_val = rf_df.loc[d, v]
                if np.isfinite(sp_val) and np.isfinite(rf_val):
                    pairs.append({
                        'dim': d, 'var': v, 'spearman': abs(sp_val), 'rf': rf_val,
                        'category': ENV_CATEGORY.get(v, 'Urban'),
                    })
    pairs_df = pd.DataFrame(pairs, columns=['dim', 'var', 'spearman', 'rf', 'category'])

    if pairs_df.empty:
        # Without explicit columns this used to raise KeyError (e.g. rf_df None)
        pairs_df['rf_norm'] = pd.Series(dtype=float)
        ax.text(0.5, 0.5, 'RF importance data not available',
                ha='center', va='center', fontsize=14, transform=ax.transAxes)
        ax.set_title(title, fontweight='bold', fontsize=16, pad=10)
        return pairs_df

    # Normalise RF per variable; non-positive maxima map the whole variable to 0
    max_rf = pairs_df.groupby('var')['rf'].transform('max')
    pairs_df['rf_norm'] = (pairs_df['rf'] / max_rf.where(max_rf > 0)).fillna(0.0)

    colors = [CATEGORY_COLORS.get(cat, '#999999') for cat in pairs_df['category']]
    ax.scatter(pairs_df['spearman'], pairs_df['rf_norm'],
               c=colors, s=35, alpha=0.65, edgecolors='white',
               linewidths=0.3, rasterized=True)

    for _, row in pairs_df.nlargest(12, 'spearman').iterrows():
        label = f"{row['dim']}×{ENV_LABELS_SHORT.get(row['var'], row['var'])}"
        ax.annotate(label, (row['spearman'], row['rf_norm']),
                    fontsize=10, fontweight='bold',
                    xytext=(6, 6), textcoords='offset points',
                    bbox=dict(boxstyle='round,pad=0.25', facecolor='white',
                              alpha=0.85, edgecolor='none'))

    # Quadrant guides and captions
    ax.axhline(0.5, color='#888', linestyle='--', linewidth=1.2, alpha=0.6)
    ax.axvline(0.5, color='#888', linestyle='--', linewidth=1.2, alpha=0.6)
    ax.text(0.73, 0.08, 'High ρ, Low RF\n(Linear only)', ha='center', va='bottom',
            fontsize=12, color='#555', style='italic')
    ax.text(0.15, 0.92, 'Low ρ, High RF\n(Nonlinear)', ha='center', va='top',
            fontsize=12, color='#555', style='italic')
    ax.text(0.73, 0.92, 'High ρ, High RF\n(Agreement)', ha='center', va='top',
            fontsize=13, color='#228B22', fontweight='bold')

    ax.set_xlabel('Spearman |ρ|', fontweight='bold', fontsize=15)
    ax.set_ylabel('RF Importance (normalized)', fontweight='bold', fontsize=15)
    # Keep the 0.85 default, but never clip the strongest (annotated) pairs
    ax.set_xlim(0, max(0.85, pairs_df['spearman'].max() + 0.05))
    ax.set_ylim(0, 1.08)
    ax.grid(alpha=0.25, linestyle='-', linewidth=0.5)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    corr = pairs_df[['spearman', 'rf_norm']].corr().iloc[0, 1]
    ax.text(0.03, 0.97, f'r = {corr:.3f}\nn = {len(pairs_df):,}',
            transform=ax.transAxes, fontsize=13, fontweight='bold', va='top',
            bbox=dict(facecolor='white', alpha=0.9, edgecolor='#ccc'))

    ax.set_title(title, fontweight='bold', fontsize=16, pad=10)

    return pairs_df


def create_figure3(R):
    """
    Create Figure 3 and save it as PNG and PDF in FIG_DIR.

    Top row: three bipartite networks (Spearman |ρ|, RF importance,
    Transformer importance). Bottom row: a Spearman-vs-RF scatter spanning
    all three columns. A shared variable-category legend sits along the
    bottom edge.

    Parameters
    ----------
    R : dict
        As returned by ``load_results``: 'spearman' (DataFrame), 'rf_imp'
        and 'trans_imp' (DataFrame or None). A missing importance matrix
        leaves a labelled placeholder in its network slot.
    """
    print("\nCreating Figure 3...")

    spearman_df = R['spearman']
    rf_df = R['rf_imp']
    trans_df = R['trans_imp']

    # The Spearman matrix defines which dimensions/variables every panel shows
    dims = spearman_df.index.tolist()
    vars_ = spearman_df.columns.tolist()

    fig = plt.figure(figsize=(20, 18))
    gs = gridspec.GridSpec(2, 3,
                           height_ratios=[1.15, 0.85],
                           hspace=0.12, wspace=0.06,
                           left=0.01, right=0.99,
                           top=0.96, bottom=0.07)

    # Panel (a): (title, matrix, method_type, per-method edge threshold)
    network_specs = [
        ('(a₁) Spearman |ρ|', spearman_df, 'spearman', 0.30),
        ('(a₂) RF Importance', rf_df, 'importance', 0.18),
        ('(a₃) Transformer Importance', trans_df, 'importance', 0.12),
    ]
    for col, (title, matrix, method_type, threshold) in enumerate(network_specs):
        ax = fig.add_subplot(gs[0, col])
        if matrix is None:
            # Optional matrix missing: label the slot instead of leaving bare axes
            ax.axis('off')
            ax.text(0.5, 0.5, 'Data not available', ha='center', va='center',
                    fontsize=14, color='#666', transform=ax.transAxes)
            ax.set_title(title, fontsize=15, fontweight='bold', pad=6)
            continue
        draw_bipartite_network(ax, matrix, dims, vars_, title,
                               method_type=method_type,
                               threshold=threshold, max_edges=150)

    # Panel (b): method-agreement scatter across the full width
    ax_scatter = fig.add_subplot(gs[1, :])
    draw_scatter(ax_scatter, spearman_df, rf_df, trans_df)

    legend_handles = [mpatches.Patch(facecolor=CATEGORY_COLORS[cat],
                                     edgecolor='white', linewidth=0.5, label=cat)
                      for cat in CATEGORY_ORDER]
    fig.legend(handles=legend_handles, loc='lower center', ncol=8,
               fontsize=12, framealpha=0.95, title='Variable Category',
               title_fontsize=13, bbox_to_anchor=(0.5, 0.005),
               handlelength=2, handleheight=1.3, columnspacing=1.5)

    os.makedirs(FIG_DIR, exist_ok=True)
    fig.savefig(f'{FIG_DIR}/fig3_method_networks.png', dpi=DPI, facecolor='white')
    fig.savefig(f'{FIG_DIR}/fig3_method_networks.pdf', dpi=DPI, facecolor='white')
    plt.close(fig)
    print(f"  ✓ Saved: {FIG_DIR}/fig3_method_networks.png/pdf")


def main():
    """Script entry point: ensure FIG_DIR exists, load inputs, draw Figure 3."""
    os.makedirs(FIG_DIR, exist_ok=True)
    results = load_results()
    if results is not None:
        create_figure3(results)
    else:
        # load_results yields None when the Spearman matrix file is absent
        print("ERROR: Data not found")


if __name__ == '__main__':
    main()

In [None]:
#!/usr/bin/env python
"""
figure4_validation.py — Publication Figure 4
================================================

Panel (a): Spatial CV Comparison - Random vs Spatial Block CV for RF/Transformer
           (dumbbell plot showing generalization gap)
           
Panel (b): Temporal Stability Distribution - All 64 dimensions ranked by 
           stability across 7 years (2017-2023)
           
Panel (c): Year-to-Year Correlation Heatmap - Profile correlations between years
           
Panel (d): Temporal Evolution - Top dimension-variable pairs over 7 years

Usage:
  python 07_figure4_validation.py
"""

import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
import os
import warnings
warnings.filterwarnings('ignore')

# =============================================================================
# CONFIGURATION
# =============================================================================

# Publication styling shared by every Figure 4 panel, grouped by concern.
plt.rcParams.update({
    # --- typography ---
    'font.family': 'DejaVu Sans',
    'font.size': 13,
    'axes.labelsize': 15,
    'axes.titlesize': 16,
    'axes.titleweight': 'bold',
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    # --- frame / legend appearance ---
    'axes.linewidth': 1.2,
    'legend.framealpha': 0.95,
    # --- on-screen and export resolution ---
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.05,
})

# Input/output locations and export resolution
RESULTS_DIR = 'results'
FIG_DIR = f'{RESULTS_DIR}/figures'
DPI = 300

# Readable variable names used for tick labels and annotations
ENV_LABELS_SHORT = dict(
    elevation='Elevation', slope='Slope', aspect='Aspect',
    soil_clay_pct='Clay %', soil_organic_carbon='Organic C',
    soil_ph='Soil pH', soil_water_capacity='Soil WC',
    flow_acc_log='Flow Acc.', tree_cover_2000='Tree Cover',
    impervious_pct='Impervious',
    ndvi_mean='NDVI', ndvi_max='NDVI max',
    evi_mean='EVI', lai_mean='LAI',
    lst_day_c='LST Day', lst_night_c='LST Night',
    albedo='Albedo',
    precip_annual_mm='Precipitation', precip_max_month='Precip Max',
    temp_mean_c='Temperature', temp_range_c='Temp Range',
    soil_moisture='Soil Moisture', runoff_annual_mm='Runoff',
    et_annual_mm='ET',
    nightlights='Nightlights', pop_density='Pop Density',
)

# Thematic category of each variable (drives colours)
ENV_CATEGORY = dict(
    elevation='Terrain', slope='Terrain', aspect='Terrain',
    soil_clay_pct='Soil', soil_organic_carbon='Soil',
    soil_ph='Soil', soil_water_capacity='Soil',
    flow_acc_log='Hydrology', tree_cover_2000='Vegetation',
    impervious_pct='Urban',
    ndvi_mean='Vegetation', ndvi_max='Vegetation',
    evi_mean='Vegetation', lai_mean='Vegetation',
    lst_day_c='Temperature', lst_night_c='Temperature',
    albedo='Radiation',
    precip_annual_mm='Climate', precip_max_month='Climate',
    temp_mean_c='Temperature', temp_range_c='Temperature',
    soil_moisture='Hydrology', runoff_annual_mm='Hydrology',
    et_annual_mm='Hydrology',
    nightlights='Urban', pop_density='Urban',
)

# One colour per category
CATEGORY_COLORS = dict(
    Terrain='#8B4513',
    Soil='#DAA520',
    Vegetation='#228B22',
    Temperature='#DC143C',
    Climate='#4169E1',
    Hydrology='#00CED1',
    Urban='#696969',
    Radiation='#FFD700',
)


def load_results():
    """
    Load every input Figure 4 uses from RESULTS_DIR.

    Missing files are not an error: the corresponding entry is None, or,
    for the yearly matrices, that year is simply absent from the dict.

    Returns
    -------
    dict
        'temporal_stab', 'temporal_series', 'cv_rf', 'trans_r2', 'rf_r2', 'dd'
            DataFrame or None.
        'yearly_spearman'
            {year: DataFrame indexed by its first CSV column} for each
            temporal_spearman_<year>.csv found for years 2017-2023.
    """
    print("Loading results...")

    def read_optional(name, **read_kwargs):
        # Return the parsed CSV, or None when the file does not exist.
        path = f'{RESULTS_DIR}/{name}'
        return pd.read_csv(path, **read_kwargs) if os.path.exists(path) else None

    R = {
        'temporal_stab': read_optional('temporal_stability.csv'),
        'temporal_series': read_optional('temporal_series.csv'),
        'cv_rf': read_optional('cv_comparison.csv'),
        'trans_r2': read_optional('transformer_r2_scores.csv'),
        'rf_r2': read_optional('rf_r2_scores.csv'),
        'dd': read_optional('dimension_dictionary.csv'),
    }

    if R['temporal_stab'] is not None:
        print(f"  ✓ Temporal stability: {len(R['temporal_stab'])} dims")
    if R['temporal_series'] is not None:
        print(f"  ✓ Temporal series: {len(R['temporal_series'])} pairs")
    if R['trans_r2'] is not None:
        print(f"  ✓ Transformer R²: {len(R['trans_r2'])} vars")

    # Yearly Spearman matrices feed the year-to-year heatmap
    R['yearly_spearman'] = {}
    for yr in range(2017, 2024):
        frame = read_optional(f'temporal_spearman_{yr}.csv', index_col=0)
        if frame is not None:
            R['yearly_spearman'][yr] = frame
    print(f"  ✓ Yearly Spearman: {len(R['yearly_spearman'])} years")

    return R


# =============================================================================
# PANEL (a): SPATIAL CV COMPARISON
# =============================================================================

def panel_a_spatial_cv(ax, trans_r2, cv_rf):
    """
    Dumbbell plot of Transformer R² under random vs spatial-block CV.

    Each row is one variable, sorted by random-CV R² in ascending order. A
    circle marks random CV and a square marks spatial block CV. The
    connecting line is coloured by variable category. Gaps above 0.02 are
    annotated with Δ, and a summary box reports the mean and max gap.

    Parameters
    ----------
    ax : matplotlib Axes
    trans_r2 : pandas.DataFrame or None
        Needs columns 'variable', 'random_cv_r2', 'spatial_cv_r2'.
    cv_rf : unused
        Kept for call-site compatibility.
    """
    short_title = '(a) Spatial Generalization'
    if trans_r2 is None:
        ax.text(0.5, 0.5, 'Transformer CV data not available',
                ha='center', va='center', fontsize=14)
        ax.set_title(short_title, fontweight='bold', pad=10)
        return

    df = trans_r2.dropna(subset=['random_cv_r2', 'spatial_cv_r2']).copy()
    if df.empty:
        # Nothing to plot; avoids NaN statistics and an inverted y-range
        ax.text(0.5, 0.5, 'No variables with both CV scores',
                ha='center', va='center', fontsize=14)
        ax.set_title(short_title, fontweight='bold', pad=10)
        return

    df['gap'] = df['random_cv_r2'] - df['spatial_cv_r2']
    df = df.sort_values('random_cv_r2', ascending=True).reset_index(drop=True)

    n_vars = len(df)
    y_pos = np.arange(n_vars)

    # After reset_index the row label i doubles as the y position
    for i, row in df.iterrows():
        color = CATEGORY_COLORS.get(ENV_CATEGORY.get(row['variable'], 'Urban'), '#999')
        r_rand, r_spat = row['random_cv_r2'], row['spatial_cv_r2']

        ax.plot([r_rand, r_spat], [i, i],
                color=color, linewidth=3.5, solid_capstyle='round', alpha=0.8)
        ax.scatter(r_rand, i, s=140, c=color, marker='o',
                   edgecolors='white', linewidths=1.5, zorder=5)
        ax.scatter(r_spat, i, s=140, c=color, marker='s',
                   edgecolors='white', linewidths=1.5, zorder=5)

        if row['gap'] > 0.02:
            ax.text(min(r_rand, r_spat) - 0.02, i,
                    f"Δ={row['gap']:.3f}", fontsize=9, va='center', ha='right', color='#666')

    ax.set_yticks(y_pos)
    ax.set_yticklabels([ENV_LABELS_SHORT.get(v, v) for v in df['variable']], fontsize=11)

    # Reference lines at R² = 0.9 and 0.8
    ax.axvline(0.9, color='#888', linestyle='--', linewidth=1.2, alpha=0.5)
    ax.axvline(0.8, color='#888', linestyle=':', linewidth=1, alpha=0.4)

    legend_elements = [
        Line2D([0], [0], marker='o', color='gray', markersize=12,
               markerfacecolor='gray', markeredgecolor='white', linewidth=0,
               label='Random CV'),
        Line2D([0], [0], marker='s', color='gray', markersize=12,
               markerfacecolor='gray', markeredgecolor='white', linewidth=0,
               label='Spatial Block CV'),
    ]
    ax.legend(handles=legend_elements, loc='lower right', fontsize=11, framealpha=0.95)

    mean_gap = df['gap'].mean()
    max_gap = df['gap'].max()
    n_small_gap = (df['gap'] < 0.02).sum()
    stats_text = f"Mean gap: {mean_gap:.3f}\nMax gap: {max_gap:.3f}\n{n_small_gap}/{n_vars} vars: gap < 0.02"
    ax.text(0.03, 0.97, stats_text, transform=ax.transAxes,
            fontsize=12, va='top', ha='left', fontweight='bold',
            bbox=dict(facecolor='white', alpha=0.9, edgecolor='#ccc', linewidth=0.5))

    # R² can be negative; extend the left edge rather than clip such points
    r2_min = min(df['random_cv_r2'].min(), df['spatial_cv_r2'].min())
    ax.set_xlim(min(0.0, r2_min - 0.05), 1.05)
    ax.set_ylim(-0.8, n_vars - 0.2)
    ax.set_xlabel('R² (5-Fold Cross-Validation)', fontweight='bold', fontsize=14)
    ax.grid(axis='x', alpha=0.3, linestyle='-', linewidth=0.5)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.set_title('(a) Spatial Generalization: Random vs Block CV',
                 fontweight='bold', fontsize=15, pad=10)


# =============================================================================
# PANEL (b): TEMPORAL STABILITY DISTRIBUTION
# =============================================================================

def panel_b_temporal_stability(ax, temporal_stab, dd):
    """
    Horizontal bar chart of per-dimension temporal stability.

    Each bar is a dimension's mean pairwise profile correlation, sorted in
    ascending order, with ±std error bars. Bars are coloured by the
    dimension's primary Spearman category from the dimension dictionary.
    The three most stable dimensions are labelled "dim→variable"; the
    three least stable are labelled with the dimension only.

    Parameters
    ----------
    ax : matplotlib Axes
    temporal_stab : pandas.DataFrame or None
        Columns 'dimension', 'mean_profile_corr', 'std_profile_corr'.
    dd : pandas.DataFrame or None
        Dimension dictionary with 'dimension', 'sp_primary', 'sp_category'.
    """
    if temporal_stab is None or dd is None:
        ax.text(0.5, 0.5, 'Temporal stability data not available',
                ha='center', va='center', fontsize=14)
        ax.set_title('(b) Temporal Stability', fontweight='bold', pad=10)
        return

    # Left merge: dimensions missing from the dictionary get NaN category/primary
    stab = temporal_stab.merge(dd[['dimension', 'sp_primary', 'sp_category']],
                               on='dimension', how='left')
    stab = stab.sort_values('mean_profile_corr', ascending=True).reset_index(drop=True)

    n_dims = len(stab)
    y_pos = np.arange(n_dims)
    corr = stab['mean_profile_corr']

    colors = [CATEGORY_COLORS.get(cat, '#999') for cat in stab['sp_category']]
    ax.barh(y_pos, corr, color=colors, edgecolor='white', linewidth=0.4, height=0.85)
    ax.errorbar(corr, y_pos, xerr=stab['std_profile_corr'],
                fmt='none', ecolor='#333', elinewidth=0.8, capsize=2, capthick=0.8)

    # Label every 4th dimension to keep the axis readable
    ax.set_yticks(y_pos[::4])
    ax.set_yticklabels(stab['dimension'].values[::4], fontsize=10, family='monospace')

    ax.axvline(0.95, color='#2E7D32', linestyle='--', linewidth=2, alpha=0.7)
    ax.axvline(0.90, color='#F57C00', linestyle='--', linewidth=2, alpha=0.7)
    ax.text(0.952, n_dims * 0.97, 'r = 0.95', fontsize=11, color='#2E7D32', fontweight='bold')
    ax.text(0.902, n_dims * 0.97, 'r = 0.90', fontsize=11, color='#F57C00', fontweight='bold')

    # After reset_index the row label equals the bar's y position
    for y, row in stab.nlargest(3, 'mean_profile_corr').iterrows():
        primary = row['sp_primary']
        if isinstance(primary, str):
            label = f"{row['dimension']}→{ENV_LABELS_SHORT.get(primary, primary[:8])}"
        else:
            # No dictionary entry (NaN after the merge): show the dimension only
            label = f"{row['dimension']}"
        ax.annotate(label, (row['mean_profile_corr'], y),
                    xytext=(8, 0), textcoords='offset points',
                    fontsize=10, fontweight='bold', color='#2E7D32', va='center')

    for y, row in stab.nsmallest(3, 'mean_profile_corr').iterrows():
        ax.annotate(f"{row['dimension']}", (row['mean_profile_corr'], y),
                    xytext=(-5, 0), textcoords='offset points',
                    fontsize=10, fontweight='bold', color='#D32F2F', va='center', ha='right')

    # Counts are reported against the actual number of dimensions plotted
    mean_stab = corr.mean()
    n_above_90 = (corr > 0.90).sum()
    n_above_95 = (corr > 0.95).sum()
    stats_text = (f"Mean: r = {mean_stab:.3f}\n"
                  f"{n_above_95}/{n_dims} > 0.95\n"
                  f"{n_above_90}/{n_dims} > 0.90")
    ax.text(0.03, 0.97, stats_text, transform=ax.transAxes,
            fontsize=12, va='top', ha='left', fontweight='bold',
            bbox=dict(facecolor='white', alpha=0.9, edgecolor='#ccc', linewidth=0.5))

    # Keep the 0.70 floor unless a bar would be clipped below it
    ax.set_xlim(min(0.70, corr.min() - 0.05), 1.02)
    ax.set_ylim(-1, n_dims)
    ax.set_xlabel('Mean Pairwise Profile Correlation (2017–2023)', fontweight='bold', fontsize=14)
    ax.grid(axis='x', alpha=0.3, linestyle='-', linewidth=0.5)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.set_title(f'(b) Temporal Stability of {n_dims} Dimensions',
                 fontweight='bold', fontsize=15, pad=10)


# =============================================================================
# PANEL (c): YEAR-TO-YEAR CORRELATION HEATMAP
# =============================================================================

def panel_c_year_heatmap(ax, yearly_spearman, min_overlap=10):
    """Draw panel (c): heatmap of correlations between yearly profiles.

    Each year's table in ``yearly_spearman`` is flattened to a vector. That
    vector is Pearson-correlated (``np.corrcoef``) with every other year's
    vector, using only the cells that are finite in both. Each cell of the
    resulting year x year matrix is annotated with its value, and a colorbar
    is attached to ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes. The colorbar is added to the figure that owns ``ax``.
    yearly_spearman : mapping
        Maps year -> table exposing ``.values`` (presumably pandas
        DataFrames -- verify against the loader). An empty mapping draws a
        placeholder message instead of the heatmap.
    min_overlap : int, default 10
        A pair of years gets a correlation only when strictly more than this
        many cells are finite in both tables. Otherwise the cell stays NaN
        and is left blank.
    """
    if not yearly_spearman:
        ax.text(0.5, 0.5, 'Yearly data not available',
                ha='center', va='center', fontsize=14)
        ax.set_title('(c) Year-to-Year Consistency', fontweight='bold',
                     fontsize=15, pad=10)
        return

    years = sorted(yearly_spearman.keys())
    n_years = len(years)
    year_labels = [str(y) for y in years]

    # Flatten every year's table once, rather than once per (i, j) pair.
    # NOTE(review): cells are paired by position, so all years are assumed
    # to share the same row/column layout -- confirm upstream.
    profiles = [np.asarray(yearly_spearman[y].values, dtype=float).ravel()
                for y in years]

    # The correlation matrix is symmetric, so compute each pair once and
    # mirror it. Pairs with too little overlap stay NaN.
    corr_mat = np.full((n_years, n_years), np.nan)
    for i, p1 in enumerate(profiles):
        for j in range(i, n_years):
            p2 = profiles[j]
            mask = np.isfinite(p1) & np.isfinite(p2)
            if mask.sum() > min_overlap:
                r = np.corrcoef(p1[mask], p2[mask])[0, 1]
                corr_mat[i, j] = r
                corr_mat[j, i] = r

    # Fixed colour range: the figure is meant to contrast values >= 0.90.
    im = ax.imshow(corr_mat, cmap='RdYlGn', vmin=0.90, vmax=1.0, aspect='equal')

    ax.set_xticks(range(n_years))
    ax.set_xticklabels(year_labels, fontsize=12, rotation=45, ha='right')
    ax.set_yticks(range(n_years))
    ax.set_yticklabels(year_labels, fontsize=12)

    # Annotate the finite cells. White text is used on the darkest greens.
    for i in range(n_years):
        for j in range(n_years):
            val = corr_mat[i, j]
            if np.isfinite(val):
                color = 'white' if val > 0.97 else 'black'
                ax.text(j, i, f'{val:.3f}', ha='center', va='center',
                        fontsize=11, fontweight='bold', color=color)

    # Attach the colorbar through the axes' own figure instead of pyplot's
    # "current figure", which may be a different figure.
    cbar = ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('Profile Correlation (r)', fontsize=12, fontweight='bold')
    cbar.ax.tick_params(labelsize=11)

    ax.set_xlabel('Year', fontweight='bold', fontsize=14)
    ax.set_ylabel('Year', fontweight='bold', fontsize=14)

    ax.set_title('(c) Year-to-Year Profile Correlation', fontweight='bold', fontsize=15, pad=10)


# =============================================================================
# PANEL (d): TEMPORAL EVOLUTION OF TOP PAIRS
# =============================================================================

def panel_d_temporal_evolution(ax, temporal_series, dd, top_n=10):
    """Line plots showing year-by-year |ρ| for the top dimension-variable pairs.

    The ``top_n`` rows of ``dd`` with the largest ``sp_abs_max`` are looked
    up in ``temporal_series``, matching ``dimension`` to ``dimension`` and
    ``sp_primary`` to ``variable``. Their per-year ``rho_<year>`` values are
    drawn as absolute values and coloured by ``sp_category`` via
    CATEGORY_COLORS. Pairs with no matching row are skipped.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    temporal_series : pandas.DataFrame or None
        Needs ``dimension`` and ``variable`` columns plus per-year columns
        named ``rho_<year>`` (e.g. ``rho_2017``). Other ``rho_*`` columns
        whose suffix is not an integer are ignored.
    dd : pandas.DataFrame or None
        Needs ``dimension``, ``sp_primary``, ``sp_category`` and
        ``sp_abs_max`` columns. If either table is None, a placeholder
        message is drawn instead.
    top_n : int, default 10
        Number of strongest pairs to plot.
    """
    if temporal_series is None or dd is None:
        ax.text(0.5, 0.5, 'Temporal series not available',
                ha='center', va='center', fontsize=14)
        ax.set_title('(d) Temporal Evolution', fontweight='bold',
                     fontsize=15, pad=10)
        return

    # Use only the per-year columns, sorted chronologically so the lines run
    # left to right. Columns such as "rho_mean" are skipped, not fed to int().
    prefix = 'rho_'
    year_cols = sorted(
        (int(col[len(prefix):]), col)
        for col in temporal_series.columns
        if isinstance(col, str) and col.startswith(prefix)
        and col[len(prefix):].isdigit()
    )
    years = [year for year, _ in year_cols]
    rho_cols = [col for _, col in year_cols]

    plotted = []
    for _, row in dd.nlargest(top_n, 'sp_abs_max').iterrows():
        dim = row['dimension']
        var = row['sp_primary']
        color = CATEGORY_COLORS.get(row['sp_category'], '#999')

        match = temporal_series[(temporal_series['dimension'] == dim) &
                                (temporal_series['variable'] == var)]
        if match.empty:
            continue

        # Take the first matching row. abs() keeps NaN as NaN, so missing
        # years show up as gaps in the line.
        rho_vals = np.abs(match.iloc[0][rho_cols].to_numpy(dtype=float))
        plotted.append(rho_vals)

        label = f"{dim}→{ENV_LABELS_SHORT.get(var, var)}"

        ax.plot(years, rho_vals, '-o', color=color, linewidth=2.5,
                markersize=8, markeredgecolor='white', markeredgewidth=1.2,
                alpha=0.85, label=label)

    # The x-range follows the years present (2016.5-2023.5 for 2017-2023).
    if years:
        ax.set_xlim(min(years) - 0.5, max(years) + 0.5)
    else:
        ax.set_xlim(2016.5, 2023.5)

    # The default y-range is 0.45-0.85. Widen it only when the data would
    # otherwise be clipped out of view.
    y_lo, y_hi = 0.45, 0.85
    if plotted:
        values = np.concatenate(plotted)
        values = values[np.isfinite(values)]
        if values.size:
            y_lo = min(y_lo, float(values.min()) - 0.02)
            y_hi = max(y_hi, float(values.max()) + 0.02)
    ax.set_ylim(y_lo, y_hi)

    ax.set_xticks(years)
    ax.set_xticklabels([str(y) for y in years], fontsize=12)
    ax.set_xlabel('Year', fontweight='bold', fontsize=14)
    ax.set_ylabel('|Spearman ρ|', fontweight='bold', fontsize=14)
    ax.grid(alpha=0.3, linestyle='-', linewidth=0.5)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Only build a legend when at least one labelled line was drawn.
    if plotted:
        ax.legend(loc='lower left', fontsize=10, ncol=2, framealpha=0.95,
                  columnspacing=1, handletextpad=0.5)

    ax.text(0.97, 0.97, f'Top {top_n} pairs by |ρ|',
            transform=ax.transAxes, fontsize=12, va='top', ha='right',
            fontweight='bold', style='italic',
            bbox=dict(facecolor='white', alpha=0.9, edgecolor='#ccc'))

    ax.set_title('(d) Temporal Consistency of Top Relationships',
                 fontweight='bold', fontsize=15, pad=10)


# =============================================================================
# MAIN FIGURE
# =============================================================================

def create_figure4(R):
    """Create Figure 4: Validation Analysis.

    Lays out a 2x2 grid with four panels:
    (a) spatial cross-validation,
    (b) temporal stability,
    (c) year-to-year profile correlation,
    (d) temporal evolution of the top pairs.
    A shared variable-category legend goes along the bottom edge, and PNG
    and PDF copies are written to ``FIG_DIR/fig4_validation.{png,pdf}``.

    Parameters
    ----------
    R : dict
        Results bundle read by key: 'trans_r2', 'cv_rf', 'temporal_stab',
        'dd', 'yearly_spearman' and 'temporal_series' (presumably as returned
        by load_results -- verify).
    """
    print("\nCreating Figure 4: Validation...")

    fig = plt.figure(figsize=(20, 18))
    try:
        # 2x2 grid; the bottom margin leaves room for the category legend.
        gs = gridspec.GridSpec(2, 2,
                               height_ratios=[1, 1],
                               width_ratios=[1, 1],
                               hspace=0.22, wspace=0.18,
                               left=0.06, right=0.96,
                               top=0.96, bottom=0.06)

        # Panel (a): spatial CV - top left
        ax_a = fig.add_subplot(gs[0, 0])
        panel_a_spatial_cv(ax_a, R['trans_r2'], R['cv_rf'])

        # Panel (b): temporal stability - top right
        ax_b = fig.add_subplot(gs[0, 1])
        panel_b_temporal_stability(ax_b, R['temporal_stab'], R['dd'])

        # Panel (c): year-to-year heatmap - bottom left
        ax_c = fig.add_subplot(gs[1, 0])
        panel_c_year_heatmap(ax_c, R['yearly_spearman'])

        # Panel (d): temporal evolution - bottom right
        ax_d = fig.add_subplot(gs[1, 1])
        panel_d_temporal_evolution(ax_d, R['temporal_series'], R['dd'])

        # Shared category legend along the bottom edge.
        categories = ['Terrain', 'Soil', 'Vegetation', 'Temperature',
                      'Climate', 'Hydrology', 'Urban', 'Radiation']
        legend_handles = [mpatches.Patch(facecolor=CATEGORY_COLORS[cat],
                                         edgecolor='white', linewidth=0.5,
                                         label=cat)
                          for cat in categories]

        fig.legend(handles=legend_handles, loc='lower center', ncol=8,
                   fontsize=11, framealpha=0.95, title='Variable Category',
                   title_fontsize=12, bbox_to_anchor=(0.5, 0.005),
                   handlelength=1.8, handleheight=1.2, columnspacing=1.5)

        os.makedirs(FIG_DIR, exist_ok=True)
        for ext in ('png', 'pdf'):
            fig.savefig(f'{FIG_DIR}/fig4_validation.{ext}', dpi=DPI,
                        facecolor='white')
    finally:
        # Always release the figure, even if a panel or savefig raises, so
        # that repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"  ✓ Saved: {FIG_DIR}/fig4_validation.png/pdf")


def main():
    """Script entry point: ensure FIG_DIR exists, load results, draw Figure 4."""
    rule = "=" * 70
    os.makedirs(FIG_DIR, exist_ok=True)

    # Header banner: a rule, the title, then another rule.
    print("\n".join([rule, "FIGURE 4: VALIDATION (SPATIAL & TEMPORAL)", rule]))

    results = load_results()
    create_figure4(results)

    # Footer banner, preceded by a blank line.
    print("\n".join(["\n" + rule, "COMPLETE", rule]))


if __name__ == "__main__":
    main()