# 🌍 ClimAx Phase 4 - Complete Multi-Scenario Training

**Complete notebook for training ClimAx models across multiple climate scenarios**

---

## 📋 Features
- ✅ Multi-scenario training (historical + SSP126/245/370/585)
- ✅ Stratified splitting (all scenarios in train/val/test)
- ✅ Memory-efficient for 6GB GPU
- ✅ Progress bars with tqdm
- ✅ Automatic checkpointing

## 🎯 Requirements
- PyTorch with CUDA
- 6GB+ GPU memory
- Climate data in E:/Datasets/ (or update paths)

---

## 📦 Section 1: Install & Import Dependencies

In [1]:
# =============================================================================
# 📦 SECTION 1: INSTALL & IMPORT DEPENDENCIES - FIXED
# =============================================================================

# Install required packages (uncomment if needed on Kaggle)
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# !pip install xarray netcdf4 scipy scikit-learn tqdm matplotlib

import os
import sys  # ✅ CRITICAL: Added for sys.stdout.flush()
import time
import json
import logging
import gc
import glob
import math
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
import xarray as xr
from scipy.ndimage import zoom
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.checkpoint import checkpoint

from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Setup logging with more detail
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger("phase4_training")

print("✅ All imports successful!")
print(f"🎮 PyTorch version: {torch.__version__}")
print(f"🎮 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
    print(f"🎮 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.1f} GB")
else:
    print("⚠️ WARNING: No GPU detected - training will be VERY slow!")

# Force flush output immediately
sys.stdout.flush()

✅ All imports successful!
🎮 PyTorch version: 2.6.0+cu124
🎮 CUDA available: True
🎮 GPU: Tesla P100-PCIE-16GB
🎮 GPU Memory: 15.9 GB


## ⚙️ Section 2: Configuration

**⚠️ UPDATE THE DATA PATHS BELOW!**

In [2]:
# =============================================================================
# 🌍 CONFIGURATION - KAGGLE PATHS
# =============================================================================

# ✅ Updated for Kaggle dataset structure
INPUT_DATA_DIR = "/kaggle/input/climate-dataset/Datasets/inputs/input4mips"
OUTPUT_DATA_DIR = "/kaggle/input/climate-dataset/Datasets/outputs/CMIP6"
OUTPUT_DIR = "climax_phase4_results"

# Training settings
SMOKE_TEST = True              # True = 2 institutions, 2 epochs for testing
RESUME_IF_MODEL_EXISTS = True  # Skip already trained models
SKIP_TRAINING = False          # Set True to only load results

# Memory settings
TARGET_SPATIAL_H = 9
TARGET_SPATIAL_W = 19
BATCH_SIZE = 1
USE_CHECKPOINTING = True
EPOCHS = 2 if SMOKE_TEST else 15  # Short for smoke test

# Model settings
EMBED_DIM = 64
DEPTH = 4
NUM_HEADS = 4

# All 18 institutions
ALL_INSTITUTIONS = [
    "AWI-CM-1-1-MR", "BCC-CSM2-MR", "CAS-ESM2-0", "CESM2",
    "CESM2-WACCM", "CMCC-CM2-SR5", "CMCC-ESM2", "CNRM-CM6-1-HR",
    "EC-Earth3", "EC-Earth3-Veg", "EC-Earth3-Veg-LR", "FGOALS-f3-L",
    "GFDL-ESM4", "INM-CM4-8", "INM-CM5-0", "MPI-ESM1-2-HR",
    "MRI-ESM2-0", "TaiESM1"
]

SCENARIOS = {
    'historical': 0,
    'ssp126': 1,
    'ssp245': 2,
    'ssp370': 3,
    'ssp585': 4
}

print("✅ Configuration loaded")
print(f"   📁 Input: {INPUT_DATA_DIR}")
print(f"   📁 Output: {OUTPUT_DATA_DIR}")
print(f"   💾 Results: /kaggle/working/{OUTPUT_DIR}")
print(f"   🗺️  Spatial: {TARGET_SPATIAL_H}×{TARGET_SPATIAL_W}")
print(f"   🏢 Institutions: {len(ALL_INSTITUTIONS)}")
print(f"   🌍 Scenarios: {list(SCENARIOS.keys())}")
print(f"   ⚙️  Smoke test: {SMOKE_TEST}")

# Verify paths exist
if not os.path.exists(INPUT_DATA_DIR):
    print(f"\n❌ ERROR: Input path not found!")
    print(f"   Path: {INPUT_DATA_DIR}")
    print(f"   Make sure 'climate-dataset' is attached to this notebook!")
else:
    print(f"\n✅ Input path verified")

if not os.path.exists(OUTPUT_DATA_DIR):
    print(f"\n❌ ERROR: Output path not found!")
    print(f"   Path: {OUTPUT_DATA_DIR}")
    print(f"   Make sure 'climate-dataset' is attached to this notebook!")
else:
    print(f"✅ Output path verified")

✅ Configuration loaded
   📁 Input: /kaggle/input/climate-dataset/Datasets/inputs/input4mips
   📁 Output: /kaggle/input/climate-dataset/Datasets/outputs/CMIP6
   💾 Results: /kaggle/working/climax_phase4_results
   🗺️  Spatial: 9×19
   🏢 Institutions: 18
   🌍 Scenarios: ['historical', 'ssp126', 'ssp245', 'ssp370', 'ssp585']
   ⚙️  Smoke test: True

✅ Input path verified
✅ Output path verified


In [3]:
# =============================================================================
# 🔍 DEBUG: Inspect Data Structure
# =============================================================================

print("="*80)
print("🔍 DEBUGGING: Checking actual file structure")
print("="*80)

# Check inputs structure
print("\n📊 INPUT STRUCTURE:")
input_base = "/kaggle/input/climate-dataset/Datasets/inputs/input4mips"

if os.path.exists(input_base):
    scenarios = os.listdir(input_base)
    print(f"✅ Found scenarios: {scenarios}\n")
    
    # Check historical scenario as example
    hist_path = os.path.join(input_base, "historical")
    if os.path.exists(hist_path):
        print(f"📁 Contents of historical/:")
        for item in os.listdir(hist_path):
            item_path = os.path.join(hist_path, item)
            if os.path.isdir(item_path):
                # Count files in this directory
                files = []
                for root, dirs, filenames in os.walk(item_path):
                    files.extend([f for f in filenames if f.endswith('.nc')])
                print(f"   📂 {item}/ → {len(files)} .nc files")
    else:
        print("❌ historical/ not found")
else:
    print(f"❌ Path not found: {input_base}")

# Check outputs structure
print("\n🎯 OUTPUT STRUCTURE:")
output_base = "/kaggle/input/climate-dataset/Datasets/outputs/CMIP6"

if os.path.exists(output_base):
    institutions = os.listdir(output_base)
    print(f"✅ Found {len(institutions)} institutions\n")
    
    # Check first institution as example
    first_inst = institutions[0] if institutions else None
    if first_inst:
        inst_path = os.path.join(output_base, first_inst)
        print(f"📁 Contents of {first_inst}/:")
        
        if os.path.isdir(inst_path):
            items = os.listdir(inst_path)
            print(f"   Items: {items[:5]}...")  # Show first 5
            
            # Check if scenarios are inside
            for scenario in ['historical', 'ssp126', 'ssp245']:
                scenario_path = os.path.join(inst_path, scenario)
                if os.path.exists(scenario_path):
                    print(f"   ✓ {scenario}/ exists")
                    # Check for pr variable
                    pr_path = os.path.join(scenario_path, "pr")
                    if os.path.exists(pr_path):
                        files = []
                        for root, dirs, filenames in os.walk(pr_path):
                            files.extend([f for f in filenames if f.endswith('.nc')])
                        print(f"      → pr/ → {len(files)} .nc files")
                else:
                    print(f"   ✗ {scenario}/ not found")
else:
    print(f"❌ Path not found: {output_base}")

print("\n" + "="*80)

🔍 DEBUGGING: Checking actual file structure

📊 INPUT STRUCTURE:
✅ Found scenarios: ['ssp585', 'ssp370', 'historical', 'ssp126', 'ssp245']

📁 Contents of historical/:
   📂 CO2_sum/ → 165 .nc files
   📂 SO2_sum/ → 495 .nc files
   📂 BC_sum/ → 495 .nc files
   📂 CH4_sum/ → 495 .nc files

🎯 OUTPUT STRUCTURE:
✅ Found 18 institutions

📁 Contents of CMCC-CM2-SR5/:
   Items: ['ssp585', 'ssp370', 'historical', 'ssp126', 'ssp245']...
   ✓ historical/ exists
      → pr/ → 165 .nc files
   ✓ ssp126/ exists
      → pr/ → 86 .nc files
   ✓ ssp245/ exists
      → pr/ → 86 .nc files



In [4]:
class Phase4Config:
    """Minimal config for Phase 4 - FIXED for scenario channels"""
    def __init__(self):
        self.INPUT_DATA_DIR = INPUT_DATA_DIR
        self.OUTPUT_DATA_DIR = OUTPUT_DATA_DIR
        self.SPATIAL_HEIGHT = TARGET_SPATIAL_H
        self.SPATIAL_WIDTH = TARGET_SPATIAL_W
        self.SPATIAL_H = TARGET_SPATIAL_H
        self.SPATIAL_W = TARGET_SPATIAL_W
        self.PATCH_SIZE = 1
        self.EMBED_DIM = EMBED_DIM
        self.DEPTH = DEPTH
        self.NUM_HEADS = NUM_HEADS
        self.MLP_RATIO = 4.0
        self.DROPOUT_RATE = 0.1
        self.ATTENTION_DROPOUT = 0.1
        self.DROP_PATH_RATE = 0.1
        self.SEQUENCE_INPUT_LENGTH = 12
        self.SEQUENCE_OUTPUT_LENGTH = 3
        self.TEMPORAL_STRIDE = 1
        self.BATCH_SIZE = BATCH_SIZE
        self.TRAIN_RATIO = 0.7
        self.VAL_RATIO = 0.15
        self.LEARNING_RATE = 1e-4
        self.WEIGHT_DECAY = 0.05
        self.CLIP_GRAD_NORM = 1.0
        
        # Physical input variables (from files only)
        self.INPUT_VARIABLES = [
            'BC_anthro_fires', 'BC_no_fires',
            'CH4_anthro_fires', 'CH4_no_fires',
            'CO2_sum',
            'SO2_anthro_fires', 'SO2_no_fires'
        ]
        
        # ✅ FIXED: Total input channels (7 variables + 1 scenario channel)
        self.NUM_INPUT_CHANNELS = 8
        
        self.OUTPUT_VARIABLE = 'pr'

# Create directories
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(os.path.join(OUTPUT_DIR, "checkpoints"), exist_ok=True)
os.makedirs(os.path.join(OUTPUT_DIR, "logs"), exist_ok=True)
os.makedirs(os.path.join(OUTPUT_DIR, "plots"), exist_ok=True)

print("✅ Config class created")
print("✅ Output directories created")

✅ Config class created
✅ Output directories created


## 📊 Section 3: Data Loader (Stratified Multi-Scenario)

In [5]:
# =============================================================================
# 📊 SECTION 3: DATA LOADER - FIXED VERSION (From working Python file)
# =============================================================================

class ClimateDatasetWithScenario(Dataset):
    """PyTorch dataset with scenario information"""
    
    def __init__(self, X: np.ndarray, y: np.ndarray, scenarios: np.ndarray):
        self.X = torch.from_numpy(X.astype(np.float32))
        self.y = torch.from_numpy(y.astype(np.float32))
        self.scenarios = torch.from_numpy(scenarios.astype(np.float32))
    
    def __len__(self):
        return self.X.shape[0]
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.scenarios[idx]


class Phase4MultiScenarioDataLoader:
    """
    Multi-Scenario Data Loader - FIXED VERSION from working Python file
    ✅ Robust file discovery
    ✅ Better error handling
    ✅ Memory-efficient
    """
    
    def __init__(self, config, target_h: int = 9, target_w: int = 19):
        self.input_base = config.INPUT_DATA_DIR
        self.output_base = config.OUTPUT_DATA_DIR
        self.target_h = target_h
        self.target_w = target_w
        self.scalers = {}
        
        self.scenarios = {'historical': 0, 'ssp126': 1, 'ssp245': 2, 'ssp370': 3, 'ssp585': 4}
        self.input_variables = config.INPUT_VARIABLES
        self.variable_dirs = {
            'BC_anthro_fires': 'BC_sum', 'BC_no_fires': 'BC_sum',
            'CH4_anthro_fires': 'CH4_sum', 'CH4_no_fires': 'CH4_sum',
            'CO2_sum': 'CO2_sum',
            'SO2_anthro_fires': 'SO2_sum', 'SO2_no_fires': 'SO2_sum'
        }
        self.output_variable = config.OUTPUT_VARIABLE
        logger.info(f"📊 DataLoader initialized: {target_h}×{target_w}")
    
    def _matches_variable_pattern(self, filename: str, variable: str) -> bool:
        """Pattern matching for filenames"""
        f = filename.lower()
        patterns = {
            'BC_anthro_fires': (['bc', 'anthro'], ['fire']),
            'BC_no_fires': (['bc', 'no'], ['fire']),
            'CH4_anthro_fires': (['ch4', 'anthro'], ['fire']),
            'CH4_no_fires': (['ch4', 'no'], ['fire']),
            'CO2_sum': (['co2'], []),
            'SO2_anthro_fires': (['so2', 'anthro'], ['fire']),
            'SO2_no_fires': (['so2', 'no'], ['fire'])
        }
        
        if variable in patterns:
            req, opt = patterns[variable]
            has_req = all(p in f for p in req)
            if 'no' in req:
                has_req = has_req and 'anthro' not in f
            if opt:
                return has_req and any(p in f for p in opt)
            return has_req
        return False
    
    def _is_file_readable(self, file_path: str) -> bool:
        """Check if file can be opened - CRITICAL FIX"""
        try:
            if not os.path.exists(file_path) or os.path.getsize(file_path) < 1000:
                return False
            # Quick check without fully loading
            with xr.open_dataset(file_path, decode_times=False) as ds:
                _ = list(ds.data_vars.keys())
            return True
        except:
            return False
    
    def discover_files_multi_scenario(self, institution: str):
        """Discover files for all scenarios - FIXED with better logging"""
        logger.info(f"🔍 Discovering files for {institution}...")
        all_scenario_files = {}
        
        for scenario in self.scenarios.keys():
            logger.info(f"   📁 {scenario}")
            results_inputs = {}
            results_outputs = {}
            
            # Input files
            for var in self.input_variables:
                folder = self.variable_dirs.get(var, var)
                path = os.path.join(self.input_base, scenario, folder)
                
                if not os.path.exists(path):
                    logger.warning(f"      ⚠️ Path not found: {path}")
                    continue
                
                found = []
                try:
                    for root, dirs, files in os.walk(path):
                        for fname in files:
                            if fname.endswith('.nc') and self._matches_variable_pattern(fname, var):
                                fpath = os.path.join(root, fname)
                                if self._is_file_readable(fpath):
                                    found.append(fpath)
                                    if len(found) % 10 == 0:  # Progress update
                                        logger.info(f"         Found {len(found)} {var} files...")
                except Exception as e:
                    logger.error(f"      ❌ Error walking {path}: {e}")
                    continue
                
                if found:
                    results_inputs[var] = found
                    logger.info(f"      ✓ {var}: {len(found)} files")
                else:
                    logger.warning(f"      ⚠️ {var}: No readable files found")
            
            # Output files
            out_path = os.path.join(self.output_base, institution, scenario, self.output_variable)
            if os.path.exists(out_path):
                out_files = []
                try:
                    for root, dirs, files in os.walk(out_path):
                        for fname in files:
                            if fname.endswith('.nc'):
                                fpath = os.path.join(root, fname)
                                if self._is_file_readable(fpath):
                                    out_files.append(fpath)
                                    if len(out_files) % 10 == 0:
                                        logger.info(f"         Found {len(out_files)} output files...")
                except Exception as e:
                    logger.error(f"      ❌ Error walking {out_path}: {e}")
                
                if out_files:
                    results_outputs[self.output_variable] = out_files
                    logger.info(f"      ✓ {self.output_variable}: {len(out_files)} files")
            else:
                logger.warning(f"      ⚠️ Output path not found: {out_path}")
            
            all_scenario_files[scenario] = {"inputs": results_inputs, "outputs": results_outputs}
        
        return all_scenario_files
    
    def downsample_spatial(self, arr, target_h, target_w):
        """Downsample using scipy.ndimage.zoom"""
        if arr.shape[1] == target_h and arr.shape[2] == target_w:
            return arr
        T, H, W = arr.shape
        downsampled = zoom(arr, [1.0, target_h/H, target_w/W], order=1, mode='nearest')
        return downsampled[:, :target_h, :target_w]
    
    def _load_netcdf_list(self, paths, var_hint=None):
        """Load and concatenate NetCDF files"""
        arrays = []
        for p in sorted(paths):
            try:
                ds = xr.open_dataset(p, decode_times=False)
                dvars = list(ds.data_vars.keys())
                if not dvars:
                    ds.close()
                    continue
                var = var_hint if (var_hint and var_hint in dvars) else dvars[0]
                arr = ds[var].values
                ds.close()
                if arr.ndim == 2:
                    arr = np.expand_dims(arr, 0)
                elif arr.ndim > 3:
                    arr = arr.reshape(-1, arr.shape[-2], arr.shape[-1])
                arrays.append(np.nan_to_num(arr, 0.0, 0.0, 0.0))
            except Exception as e:
                logger.warning(f"⚠️ Error reading {os.path.basename(p)}: {e}")
        
        if not arrays:
            return np.zeros((0, 0, 0), dtype=float)
        
        try:
            return np.concatenate(arrays, axis=0)
        except:
            mh = max(a.shape[1] for a in arrays)
            mw = max(a.shape[2] for a in arrays)
            padded = [np.pad(a, ((0,0), (0,mh-a.shape[1]), (0,mw-a.shape[2])), 'edge') for a in arrays]
            return np.concatenate(padded, axis=0)
    
    def align_temporal_dimensions(self, var_data):
        """Align temporal dimensions"""
        if not var_data:
            return var_data
        times = {v: a.shape[0] for v, a in var_data.items()}
        mt = min(times.values())
        if mt != max(times.values()):
            logger.warning(f"⚠️ Aligning to {mt} timesteps")
            return {v: a[:mt] for v, a in var_data.items()}
        return var_data
    
    def load_all_scenarios(self, all_files):
        """Load all scenario data"""
        logger.info("📊 Loading scenarios...")
        all_data = {}
        
        for scenario, files in all_files.items():
            logger.info(f"   📁 {scenario}")
            var_data = {}
            
            for var, paths in files.get("inputs", {}).items():
                if paths:
                    logger.info(f"      Loading {var} ({len(paths)} files)...")
                    arr = self._load_netcdf_list(paths)
                    if arr.size > 0:
                        var_data[var] = arr
                        logger.info(f"      ✓ {var}: {arr.shape}")
            
            for var, paths in files.get("outputs", {}).items():
                if paths:
                    logger.info(f"      Loading {var} ({len(paths)} files)...")
                    arr = self._load_netcdf_list(paths)
                    if arr.size > 0:
                        var_data[var] = arr
                        logger.info(f"      ✓ {var}: {arr.shape}")
            
            if var_data:
                var_data = self.align_temporal_dimensions(var_data)
                logger.info(f"      🔽 Downsampling...")
                down_data = {}
                for v, a in var_data.items():
                    d = self.downsample_spatial(a, self.target_h, self.target_w)
                    down_data[v] = d
                    logger.info(f"         {v}: {a.shape} → {d.shape}")
                all_data[scenario] = down_data
        
        logger.info(f"   ✅ Loaded {len(all_data)} scenarios")
        return all_data
    
    def normalize_data(self, arr, var_name, fit=True):
        """Normalize data"""
        arr = np.nan_to_num(arr, 0.0, 0.0, 0.0)
        flat = arr.reshape(-1, 1)
        
        if fit or var_name not in self.scalers:
            scaler = MinMaxScaler()
            try:
                scaler.fit(flat)
            except:
                scaler.min_, scaler.scale_ = np.min(flat), 1.0
            self.scalers[var_name] = scaler
        else:
            scaler = self.scalers[var_name]
        
        return np.nan_to_num(scaler.transform(flat).reshape(arr.shape), 0.0, 0.0, 0.0)
    
    def create_multi_scenario_sequences(self, all_data, seq_in=12, seq_out=3, stride=1, train_ratio=0.7, val_ratio=0.15, batch_size=1):
        """Create sequences with stratified splitting"""
        logger.info("🔄 Creating sequences...")
        X_seqs, Y_seqs, sc_ids = [], [], []
        
        for scenario, var_data in all_data.items():
            if not var_data:
                continue
            
            sc_id = self.scenarios[scenario]
            logger.info(f"   📦 {scenario} (id={sc_id})")
            
            missing = [v for v in self.input_variables if v not in var_data]
            if missing or self.output_variable not in var_data:
                logger.warning(f"      ⚠️ Skipping - missing data")
                continue
            
            # Normalize
            norm = {v: self.normalize_data(var_data[v], v, True) for v in self.input_variables if v in var_data}
            norm[self.output_variable] = self.normalize_data(var_data[self.output_variable], self.output_variable, True)
            
            # Stack
            X_sc = np.stack([norm[v] for v in self.input_variables if v in norm], axis=1)
            Y_sc = norm[self.output_variable]
            
            T = X_sc.shape[0]
            n_samp = T - seq_in - seq_out + 1
            
            if n_samp <= 0:
                logger.warning(f"      ⚠️ Not enough timesteps")
                continue
            
            for start in range(0, n_samp, stride):
                X_seqs.append(X_sc[start:start+seq_in])
                Y_seqs.append(Y_sc[start+seq_in:start+seq_in+seq_out])
                sc_ids.append(sc_id)
            
            logger.info(f"      ✓ {n_samp} sequences")
        
        if not X_seqs:
            raise RuntimeError("❌ No sequences!")
        
        X_all = np.stack(X_seqs, 0)
        Y_all = np.stack(Y_seqs, 0)
        sc_all = np.array(sc_ids)
        
        # Add scenario channel
        N, T, C, H, W = X_all.shape
        sc_ch = np.repeat(sc_all[:, None, None, None, None], T*H*W, 1).reshape(N, T, 1, H, W)
        X_all = np.concatenate([X_all, sc_ch], 2)
        Y_all = np.expand_dims(Y_all, 2)
        
        logger.info(f"   🧩 Total: X={X_all.shape}, Y={Y_all.shape}")
        
        # Stratified split
        logger.info("   📊 Stratified split...")
        X_tr, X_tmp, y_tr, y_tmp, sc_tr, sc_tmp = train_test_split(
            X_all, Y_all, sc_all, test_size=(1-train_ratio), stratify=sc_all, random_state=42
        )
        vt_ratio = val_ratio / (1 - train_ratio)
        X_val, X_te, y_val, y_te, sc_val, sc_te = train_test_split(
            X_tmp, y_tmp, sc_tmp, test_size=(1-vt_ratio), stratify=sc_tmp, random_state=42
        )
        
        splits = {
            "train": {"input": X_tr, "target": y_tr, "scenarios": sc_tr},
            "validation": {"input": X_val, "target": y_val, "scenarios": sc_val},
            "test": {"input": X_te, "target": y_te, "scenarios": sc_te}
        }
        
        # CRITICAL FIX: pin_memory=False for Kaggle
        loaders = {
            "train": DataLoader(ClimateDatasetWithScenario(X_tr, y_tr, sc_tr), batch_size, True, pin_memory=False, num_workers=0),
            "validation": DataLoader(ClimateDatasetWithScenario(X_val, y_val, sc_val), batch_size, False, pin_memory=False, num_workers=0),
            "test": DataLoader(ClimateDatasetWithScenario(X_te, y_te, sc_te), batch_size, False, pin_memory=False, num_workers=0)
        }
        
        logger.info(f"   📦 Train:{len(X_tr)} Val:{len(X_val)} Test:{len(X_te)}")
        return splits, loaders


print("✅ Fixed DataLoader loaded")

✅ Fixed DataLoader loaded


## 🧠 Section 4: ClimAx Model (Memory-Efficient)

In [6]:
class MemoryEfficientPatchEmbedding(nn.Module):
    """Memory-efficient patch embedding with conv2d."""
    
    def __init__(self, patch_size: int, in_channels: int, embed_dim: int, 
                 spatial_h: int, spatial_w: int):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.spatial_h = spatial_h
        self.spatial_w = spatial_w
        
        # Projection layer
        self.proj = nn.Conv2d(in_channels, embed_dim, 
                             kernel_size=patch_size, stride=patch_size)
        
        # Calculate number of patches
        self.num_patches_h = spatial_h // patch_size
        self.num_patches_w = spatial_w // patch_size
        self.num_patches = self.num_patches_h * self.num_patches_w
        
        logger.info(f"🔧 PatchEmbed: {spatial_h}×{spatial_w} → {self.num_patches_h}×{self.num_patches_w} = {self.num_patches} patches")
        
    def forward(self, x):
        # x: (B, T, C, H, W)
        B, T, C, H, W = x.shape
        
        # Verify dimensions
        if H != self.spatial_h or W != self.spatial_w:
            raise ValueError(f"Input spatial dims ({H}, {W}) don't match expected ({self.spatial_h}, {self.spatial_w})")
        
        # Process in chunks to save memory
        x = x.reshape(B * T, C, H, W)
        x = self.proj(x)  # (B*T, embed_dim, H', W')
        x = x.flatten(2).transpose(1, 2)  # (B*T, num_patches, embed_dim)
        x = x.reshape(B, T, self.num_patches, self.embed_dim)
        
        return x


class MemoryEfficientPositionalEmbedding(nn.Module):
    """Memory-efficient positional embeddings."""
    
    def __init__(self, num_patches: int, num_timesteps: int, embed_dim: int):
        super().__init__()
        self.num_patches = num_patches
        self.num_timesteps = num_timesteps
        self.embed_dim = embed_dim
        
        # Use smaller embeddings
        self.spatial_pos_embed = nn.Parameter(torch.zeros(1, 1, num_patches, embed_dim))
        self.temporal_pos_embed = nn.Parameter(torch.zeros(1, num_timesteps, 1, embed_dim))
        
        # Initialize with smaller std
        nn.init.trunc_normal_(self.spatial_pos_embed, std=0.01)
        nn.init.trunc_normal_(self.temporal_pos_embed, std=0.01)
    
    def forward(self, x):
        B, T, P, D = x.shape
        
        if P != self.num_patches:
            raise ValueError(f"Number of patches mismatch: got {P}, expected {self.num_patches}")
        if T > self.num_timesteps:
            raise ValueError(f"Timesteps exceed maximum: got {T}, max {self.num_timesteps}")
        
        return x + self.spatial_pos_embed + self.temporal_pos_embed[:, :T, :, :]


class MemoryEfficientTransformerBlock(nn.Module):
    """Memory-efficient Transformer block with gradient checkpointing."""
    
    def __init__(self, embed_dim: int, num_heads: int, mlp_ratio: float = 4.0,
                 dropout: float = 0.1, attention_dropout: float = 0.1):
        super().__init__()
        
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim, num_heads, 
            dropout=attention_dropout, 
            batch_first=True
        )
        
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        # Self-attention with residual
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        
        # FFN with residual
        x = x + self.mlp(self.norm2(x))
        
        return x


class ClimAxMemoryEfficientModel(nn.Module):
    """
    ClimAx Model - MEMORY EFFICIENT VERSION
    
    ✅ Designed for 6GB GPU
    ✅ Uses ConvLSTM spatial dimensions (9×19)
    ✅ Gradient checkpointing enabled
    ✅ Optimized attention mechanisms
    ✅ Supports NUM_INPUT_CHANNELS for scenario embeddings
    """
    
    def __init__(self, config, spatial_h: Optional[int] = None, 
                 spatial_w: Optional[int] = None, use_checkpointing: bool = True):
        super().__init__()
        
        # Get spatial dims - prioritize 9×19 for memory efficiency
        if spatial_h is not None and spatial_w is not None:
            self.spatial_h = spatial_h
            self.spatial_w = spatial_w
        else:
            # Default to ConvLSTM's memory-efficient dimensions
            self.spatial_h = self._get_config_attr(config, 'SPATIAL_HEIGHT', 9)
            self.spatial_w = self._get_config_attr(config, 'SPATIAL_WIDTH', 19)
        
        # CRITICAL: Warn if dimensions are too large
        total_spatial = self.spatial_h * self.spatial_w
        if total_spatial > 500:
            logger.warning(f"⚠️  Large spatial dimensions ({self.spatial_h}×{self.spatial_w} = {total_spatial}) may cause OOM!")
            logger.warning(f"   Consider downsampling to 9×19 (171 pixels) like ConvLSTM")
        
        # Get model params
        self.patch_size = self._get_config_attr(config, 'PATCH_SIZE', 2)
        self.embed_dim = self._get_config_attr(config, 'EMBED_DIM', 128)
        self.depth = self._get_config_attr(config, 'DEPTH', 8)
        self.num_heads = self._get_config_attr(config, 'NUM_HEADS', 8)
        self.mlp_ratio = self._get_config_attr(config, 'MLP_RATIO', 4.0)
        self.dropout_rate = self._get_config_attr(config, 'DROPOUT_RATE', 0.1)
        self.attention_dropout = self._get_config_attr(config, 'ATTENTION_DROPOUT', 0.1)
        
        # Reduce model size if needed
        if total_spatial > 300:
            logger.warning("   → Reducing embed_dim from 128 to 64 to save memory")
            self.embed_dim = 64
            logger.warning("   → Reducing depth from 8 to 4 to save memory")
            self.depth = 4
        
        # ✅ FIXED: Input channels - prioritize NUM_INPUT_CHANNELS
        self.num_input_vars = self._get_config_attr(config, 'NUM_INPUT_CHANNELS', None)
        if self.num_input_vars is None:
            # Fall back to counting INPUT_VARIABLES
            input_vars = self._get_config_attr(config, 'INPUT_VARIABLES', [])
            self.num_input_vars = len(input_vars) if input_vars else 4
            logger.info(f"   📊 Input channels from INPUT_VARIABLES: {self.num_input_vars}")
        else:
            logger.info(f"   📊 Input channels from NUM_INPUT_CHANNELS: {self.num_input_vars}")
        
        self.seq_in_len = self._get_config_attr(config, 'SEQUENCE_INPUT_LENGTH', 12)
        self.seq_out_len = self._get_config_attr(config, 'SEQUENCE_OUTPUT_LENGTH', 3)
        
        # Calculate patches
        self.num_patches_h = self.spatial_h // self.patch_size
        self.num_patches_w = self.spatial_w // self.patch_size
        self.num_patches = self.num_patches_h * self.num_patches_w
        
        # Memory estimation
        tokens_per_batch = self.num_patches * self.seq_in_len
        attn_memory_gb = (tokens_per_batch ** 2 * 4) / (1024 ** 3)  # float32 bytes to GB
        
        logger.info(f"🔧 Memory-Efficient ClimAx initialized:")
        logger.info(f"   Spatial: {self.spatial_h}×{self.spatial_w} → {self.num_patches} patches")
        logger.info(f"   Tokens per sample: {tokens_per_batch}")
        logger.info(f"   Est. attention memory: {attn_memory_gb:.2f} GB per batch")
        logger.info(f"   Embed dim: {self.embed_dim}, Depth: {self.depth}")
        logger.info(f"   Gradient checkpointing: {use_checkpointing}")
        
        if attn_memory_gb > 4:
            logger.error(f"❌ Estimated memory ({attn_memory_gb:.2f} GB) exceeds safe limit!")
            logger.error(f"   Please reduce spatial dimensions or batch size")
            raise RuntimeError(f"Model too large for 6GB GPU! Need ~{attn_memory_gb:.1f}GB")
        
        self.use_checkpointing = use_checkpointing
        
        # Patch embedding
        self.patch_embed = MemoryEfficientPatchEmbedding(
            patch_size=self.patch_size,
            in_channels=self.num_input_vars,
            embed_dim=self.embed_dim,
            spatial_h=self.spatial_h,
            spatial_w=self.spatial_w
        )
        
        # Positional embeddings
        self.pos_embed = MemoryEfficientPositionalEmbedding(
            num_patches=self.num_patches,
            num_timesteps=self.seq_in_len,
            embed_dim=self.embed_dim
        )
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            MemoryEfficientTransformerBlock(
                embed_dim=self.embed_dim,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                dropout=self.dropout_rate,
                attention_dropout=self.attention_dropout
            ) for _ in range(self.depth)
        ])
        
        self.norm = nn.LayerNorm(self.embed_dim)
        
        # Lightweight prediction head
        self.head = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.GELU(),
            nn.Dropout(self.dropout_rate),
            nn.Linear(self.embed_dim, self.patch_size * self.patch_size * self.seq_out_len)
        )
        
        # Count parameters
        self.num_params = sum(p.numel() for p in self.parameters())
        logger.info(f"✅ Model ready - Parameters: {self.num_params:,}")
    
    def _get_config_attr(self, config, attr_name, default):
        """Safely get attribute from config."""
        if hasattr(config, attr_name):
            return getattr(config, attr_name)
        
        for ns in ['data', 'model', 'training']:
            if hasattr(config, ns):
                ns_obj = getattr(config, ns)
                if ns_obj is not None and hasattr(ns_obj, attr_name):
                    return getattr(ns_obj, attr_name)
        
        return default
    
    def forward(self, x):
        """
        Memory-efficient forward pass with gradient checkpointing.
        
        Args:
            x: (B, T, C, H, W) input tensor
        
        Returns:
            (B, T_out, H, W) predictions
        """
        B, T, C, H, W = x.shape
        
        # Store original spatial dims for reconstruction
        original_h, original_w = H, W
        
        # Verify dimensions match
        if H != self.spatial_h or W != self.spatial_w:
            raise ValueError(
                f"Input spatial dims ({H}, {W}) don't match model ({self.spatial_h}, {self.spatial_w}). "
                f"Please downsample your data to {self.spatial_h}×{self.spatial_w} before training."
            )
        
        # Patch embedding
        x = self.patch_embed(x)  # (B, T, num_patches, embed_dim)
        
        # Add positional embeddings
        x = self.pos_embed(x)
        
        # Flatten for transformer
        B, T, P, D = x.shape
        x = x.reshape(B, T * P, D)
        
        # Apply transformer blocks with gradient checkpointing
        for i, block in enumerate(self.blocks):
            if self.use_checkpointing and self.training:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        
        x = self.norm(x)
        
        # Prediction head
        x = self.head(x)  # (B, T*P, patch_size^2 * seq_out_len)
        
        # Reshape to spatial output
        x = x.reshape(B, T, P, self.patch_size, self.patch_size, self.seq_out_len)
        
        # Reorganize patches back to spatial grid
        x = x.reshape(B, T, self.num_patches_h, self.num_patches_w, 
                     self.patch_size, self.patch_size, self.seq_out_len)
        
        # Merge patches: (B, T, seq_out_len, H_recon, W_recon)
        x = x.permute(0, 1, 6, 2, 4, 3, 5).contiguous()
        
        # Calculate reconstructed dimensions
        h_recon = self.num_patches_h * self.patch_size
        w_recon = self.num_patches_w * self.patch_size
        
        x = x.reshape(B, T, self.seq_out_len, h_recon, w_recon)
        
        # Upsample if needed to match original dimensions
        if h_recon != original_h or w_recon != original_w:
            x = x.reshape(B * T * self.seq_out_len, 1, h_recon, w_recon)
            x = F.interpolate(x, size=(original_h, original_w), mode='bilinear', align_corners=False)
            x = x.reshape(B, T, self.seq_out_len, original_h, original_w)
        
        # Average over input timesteps
        x = x.mean(dim=1)  # (B, seq_out_len, H, W)
        
        return x


print("✅ ClimAxMemoryEfficientModel class defined")

✅ ClimAxMemoryEfficientModel class defined


## 🚂 Section 5: Training Function

In [7]:
# =============================================================================
# 🚂 SECTION 5: TRAINING FUNCTION - FIXED WITH DEBUG LOGGING
# =============================================================================

def train_single_institution_multi_scenario(institution: str, config: Phase4Config) -> dict:
    """Train a single institution - FIXED with extensive debug logging"""
    
    print(f"\n{'='*80}")
    print(f"🌍 Training: {institution}")
    print(f"{'='*80}")
    
    start_time = time.time()
    
    # Check for existing model
    model_path = os.path.join(OUTPUT_DIR, "checkpoints", f"{institution}_multiscenario_best.pt")
    
    if RESUME_IF_MODEL_EXISTS and os.path.exists(model_path):
        print(f"⭐ Model exists - skipping {institution}")
        return {'success': True, 'institution': institution, 'skipped': True}
    
    try:
        # STEP 1: Create data loader
        print("📊 Step 1/8: Initializing data loader...")
        sys.stdout.flush()
        dl = Phase4MultiScenarioDataLoader(config, target_h=TARGET_SPATIAL_H, target_w=TARGET_SPATIAL_W)
        print("✅ Data loader created")
        sys.stdout.flush()
        
        # STEP 2: Discover files
        print("🔍 Step 2/8: Discovering files...")
        sys.stdout.flush()
        all_scenario_files = dl.discover_files_multi_scenario(institution)
        
        has_data = any(
            bool(files['inputs']) or bool(files['outputs'])
            for files in all_scenario_files.values()
        )
        
        if not has_data:
            print(f"⚠️ No data for {institution}")
            return {'success': False, 'institution': institution, 'error': 'No data'}
        
        print("✅ File discovery complete")
        sys.stdout.flush()
        
        # STEP 3: Load all scenarios
        print("📊 Step 3/8: Loading data...")
        sys.stdout.flush()
        all_scenario_data = dl.load_all_scenarios(all_scenario_files)
        
        if not all_scenario_data:
            print(f"⚠️ Failed to load data")
            return {'success': False, 'institution': institution, 'error': 'Load failed'}
        
        print("✅ Data loaded")
        sys.stdout.flush()
        
        # STEP 4: Create sequences
        print("🔄 Step 4/8: Creating sequences...")
        sys.stdout.flush()
        data_splits, dataloaders = dl.create_multi_scenario_sequences(all_scenario_data)
        print("✅ Sequences created")
        sys.stdout.flush()
        
        # STEP 5: Create model
        print("🧠 Step 5/8: Creating model...")
        sys.stdout.flush()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"   Using device: {device}")
        
        model = ClimAxMemoryEfficientModel(
            config,
            spatial_h=TARGET_SPATIAL_H,
            spatial_w=TARGET_SPATIAL_W,
            use_checkpointing=USE_CHECKPOINTING
        ).to(device)
        
        print("✅ Model created")
        sys.stdout.flush()
        
        # STEP 6: Setup training
        print("⚙️ Step 6/8: Setting up optimizer...")
        sys.stdout.flush()
        optimizer = optim.AdamW(
            model.parameters(),
            lr=config.LEARNING_RATE,
            weight_decay=config.WEIGHT_DECAY
        )
        criterion = nn.MSELoss()
        
        history = {
            'train_loss': [],
            'val_loss': [],
            'scenario_losses': {}
        }
        best_val_loss = float('inf')
        
        print("✅ Optimizer ready")
        sys.stdout.flush()
        
        print(f"\n🚂 Step 7/8: Training for {EPOCHS} epochs...")
        sys.stdout.flush()
        
        # STEP 7: Training loop with progress bars
        epoch_pbar = tqdm(range(EPOCHS), desc=f"🌍 {institution}", position=0, leave=True)
        
        for ep in epoch_pbar:
            print(f"\n--- Epoch {ep+1}/{EPOCHS} ---")
            sys.stdout.flush()
            
            # === TRAINING PHASE ===
            model.train()
            train_loss = 0.0
            train_batches = 0
            
            print(f"Training: Processing {len(dataloaders['train'])} batches...")
            sys.stdout.flush()
            
            train_pbar = tqdm(
                dataloaders['train'], 
                desc=f"   Train Epoch {ep+1}/{EPOCHS}",
                position=1,
                leave=False,
                bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'
            )
            
            for batch_idx, batch in enumerate(train_pbar):
                try:
                    X, y, scenarios = batch
                    X = X.to(device)
                    y = y.to(device)
                    
                    optimizer.zero_grad()
                    preds = model(X)
                    
                    # Align shapes
                    if preds.shape != y.shape:
                        if preds.ndim == 4 and y.ndim == 5:
                            preds = preds.unsqueeze(2)
                    
                    loss = criterion(preds, y)
                    loss.backward()
                    
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.CLIP_GRAD_NORM)
                    optimizer.step()
                    
                    train_loss += loss.item()
                    train_batches += 1
                    
                    # Update progress bar
                    train_pbar.set_postfix({'loss': f'{loss.item():.6f}'})
                    
                    # Debug print every 5 batches
                    if batch_idx % 5 == 0:
                        print(f"      Batch {batch_idx}: loss={loss.item():.6f}")
                        sys.stdout.flush()
                    
                    if train_batches % 10 == 0:
                        torch.cuda.empty_cache()
                
                except Exception as e:
                    print(f"\n❌ Error in training batch {batch_idx}: {e}")
                    import traceback
                    traceback.print_exc()
                    raise
            
            avg_train_loss = train_loss / max(1, train_batches)
            history['train_loss'].append(avg_train_loss)
            
            print(f"   ✓ Train loss: {avg_train_loss:.6f}")
            sys.stdout.flush()
            
            # === VALIDATION PHASE ===
            model.eval()
            val_loss = 0.0
            val_batches = 0
            scenario_losses = {s: [] for s in SCENARIOS.keys()}
            
            print(f"Validation: Processing {len(dataloaders['validation'])} batches...")
            sys.stdout.flush()
            
            val_pbar = tqdm(
                dataloaders['validation'],
                desc=f"   Val Epoch {ep+1}/{EPOCHS}",
                position=1,
                leave=False,
                bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt}'
            )
            
            with torch.no_grad():
                for batch_idx, batch in enumerate(val_pbar):
                    try:
                        Xv, yv, scenarios_v = batch
                        Xv = Xv.to(device)
                        yv = yv.to(device)
                        
                        preds = model(Xv)
                        
                        if preds.shape != yv.shape:
                            if preds.ndim == 4 and yv.ndim == 5:
                                preds = preds.unsqueeze(2)
                        
                        batch_loss = criterion(preds, yv)
                        val_loss += batch_loss.item()
                        val_batches += 1
                        
                        # Track per-scenario losses
                        for i, scenario_id in enumerate(scenarios_v.cpu().numpy()):
                            scenario_name = [k for k, v in SCENARIOS.items() if v == scenario_id][0]
                            scenario_losses[scenario_name].append(batch_loss.item())
                        
                        val_pbar.set_postfix({'loss': f'{batch_loss.item():.6f}'})
                    
                    except Exception as e:
                        print(f"\n❌ Error in validation batch {batch_idx}: {e}")
                        import traceback
                        traceback.print_exc()
                        raise
            
            avg_val_loss = val_loss / max(1, val_batches)
            history['val_loss'].append(avg_val_loss)
            
            print(f"   ✓ Val loss: {avg_val_loss:.6f}")
            sys.stdout.flush()
            
            # Calculate per-scenario averages
            avg_scenario_losses = {
                s: np.mean(losses) if losses else 0.0
                for s, losses in scenario_losses.items()
            }
            history['scenario_losses'][f'epoch_{ep+1}'] = avg_scenario_losses
            
            # Update best model
            is_best = avg_val_loss < best_val_loss
            if is_best:
                best_val_loss = avg_val_loss
                best_indicator = "⭐ NEW BEST"
            else:
                best_indicator = ""
            
            # Update main progress bar
            epoch_pbar.set_postfix({
                'train': f'{avg_train_loss:.6f}',
                'val': f'{avg_val_loss:.6f}',
                'best': f'{best_val_loss:.6f}',
                'status': best_indicator
            })
            
            # Memory cleanup
            torch.cuda.empty_cache()
            gc.collect()
        
        # Close progress bars
        epoch_pbar.close()
        
        # STEP 8: Save model
        print("💾 Step 8/8: Saving model...")
        sys.stdout.flush()
        torch.save({
            'institution': institution,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'history': history,
            'best_val_loss': best_val_loss,
            'spatial_dims': [TARGET_SPATIAL_H, TARGET_SPATIAL_W],
            'scenarios': list(SCENARIOS.keys()),
            'timestamp': datetime.now().isoformat()
        }, model_path)
        
        print(f"✅ Saved to {model_path}")
        sys.stdout.flush()
        
        # Save results
        summary = {
            'institution': institution,
            'training_time': time.time() - start_time,
            'epochs_trained': len(history['train_loss']),
            'best_val_loss': best_val_loss,
            'spatial_dims': [TARGET_SPATIAL_H, TARGET_SPATIAL_W],
            'scenarios': list(SCENARIOS.keys()),
            'final_scenario_losses': history['scenario_losses'].get(f'epoch_{EPOCHS}', {}),
            'timestamp': datetime.now().isoformat()
        }
        
        results_path = os.path.join(OUTPUT_DIR, "logs", f"{institution}_phase4_results.json")
        with open(results_path, 'w') as f:
            json.dump(summary, f, indent=2)
        
        return {'success': True, 'institution': institution, 'results': summary}
    
    except Exception as e:
        print(f"\n❌ FATAL ERROR: {e}")
        import traceback
        traceback.print_exc()
        sys.stdout.flush()
        return {'success': False, 'institution': institution, 'error': str(e)}
    
    finally:
        torch.cuda.empty_cache()
        gc.collect()


print("✅ Fixed training function loaded")

✅ Fixed training function loaded


## 🚀 Section 6: Main Training Loop

In [None]:
# =============================================================================
# 🚀 SECTION 6: MAIN TRAINING EXECUTION - FIXED WITH DEBUG
# =============================================================================

print("="*80)
print("🌍 ClimAx Phase 4 - Multi-Scenario Training")
print("="*80)
print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
sys.stdout.flush()

# Create config
config = Phase4Config()

# Determine institutions to train
if SMOKE_TEST:
    institutions = ALL_INSTITUTIONS[:2]
    print(f"⚠️ SMOKE TEST - Training only {len(institutions)} institutions")
else:
    institutions = ALL_INSTITUTIONS
    print(f"📋 Training all {len(institutions)} institutions")

print(f"   Scenarios: {list(SCENARIOS.keys())}")
print(f"   Spatial: {TARGET_SPATIAL_H}×{TARGET_SPATIAL_W}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Epochs: {EPOCHS}")
print(f"   Patch size: {config.PATCH_SIZE}")
print(f"   Input channels: {config.NUM_INPUT_CHANNELS}")
sys.stdout.flush()

# GPU Check
if torch.cuda.is_available():
    print(f"\n🎮 GPU Status:")
    print(f"   Device: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.1f} GB")
    print(f"   Current allocation: {torch.cuda.memory_allocated() / (1024**2):.1f} MB")
else:
    print("\n⚠️ WARNING: No GPU available - will be VERY slow!")
sys.stdout.flush()

if SKIP_TRAINING:
    print("SKIP_TRAINING=True - Loading existing results...")
    summary_path = os.path.join(OUTPUT_DIR, "logs", "phase4_training_summary.json")
    if os.path.exists(summary_path):
        with open(summary_path, 'r') as f:
            summary = json.load(f)
        print("✅ Loaded existing results")
    else:
        print("⚠️ No existing results found")
        summary = {}
else:
    # Training loop with main progress bar
    print("\n" + "="*80)
    print("🚂 STARTING TRAINING")
    print("="*80)
    sys.stdout.flush()
    
    t0 = time.time()
    all_results = []
    
    main_pbar = tqdm(
        institutions,
        desc="🌍 Overall Progress",
        position=0,
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'
    )
    
    for inst_idx, institution in enumerate(main_pbar):
        print(f"\n{'='*80}")
        print(f"Institution {inst_idx+1}/{len(institutions)}: {institution}")
        print(f"{'='*80}")
        sys.stdout.flush()
        
        main_pbar.set_description(f"🌍 Training {institution}")
        
        try:
            result = train_single_institution_multi_scenario(institution, config)
            all_results.append(result)
            
            # Update main progress with status
            successful = sum(1 for r in all_results if r.get('success') and not r.get('skipped'))
            skipped = sum(1 for r in all_results if r.get('skipped'))
            failed = sum(1 for r in all_results if not r.get('success'))
            
            print(f"\n📊 Progress Summary:")
            print(f"   ✅ Successful: {successful}")
            print(f"   ⭐ Skipped: {skipped}")
            print(f"   ❌ Failed: {failed}")
            sys.stdout.flush()
            
            main_pbar.set_postfix({
                'success': successful,
                'skipped': skipped,
                'failed': failed
            })
            
            # Save progress after each institution
            progress = {
                'completed': len(all_results),
                'total': len(institutions),
                'results': all_results,
                'timestamp': datetime.now().isoformat()
            }
            
            progress_path = os.path.join(OUTPUT_DIR, "logs", "phase4_progress.json")
            with open(progress_path, 'w') as f:
                json.dump(progress, f, indent=2)
            print(f"💾 Progress saved to {progress_path}")
            sys.stdout.flush()
        
        except Exception as e:
            print(f"\n❌ CRITICAL ERROR processing {institution}: {e}")
            import traceback
            traceback.print_exc()
            sys.stdout.flush()
            
            all_results.append({
                'success': False,
                'institution': institution,
                'error': str(e)
            })
    
    main_pbar.close()
    elapsed = time.time() - t0
    
    # Compile summary
    successful = [r for r in all_results if r.get('success') and not r.get('skipped')]
    skipped = [r for r in all_results if r.get('skipped')]
    failed = [r for r in all_results if not r.get('success')]
    
    summary = {
        'phase': 'phase4_multi_scenario',
        'total_institutions': len(institutions),
        'successful': len(successful),
        'skipped': len(skipped),
        'failed': len(failed),
        'total_time_hours': elapsed / 3600,
        'scenarios': list(SCENARIOS.keys()),
        'results': all_results,
        'timestamp': datetime.now().isoformat()
    }
    
    # Save summary
    summary_path = os.path.join(OUTPUT_DIR, "logs", "phase4_training_summary.json")
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)
    
    print(f"\n{'='*80}")
    print("✅ Phase 4 Training Complete!")
    print(f"{'='*80}")
    print(f"Total institutions: {len(institutions)}")
    print(f"  ✅ Successful: {len(successful)}")
    print(f"  ⭐ Skipped: {len(skipped)}")
    print(f"  ❌ Failed: {len(failed)}")
    print(f"Total time: {elapsed/3600:.2f} hours")
    print(f"{'='*80}")
    sys.stdout.flush()

🌍 ClimAx Phase 4 - Multi-Scenario Training
Start time: 2025-10-09 09:58:13
⚠️ SMOKE TEST - Training only 2 institutions
   Scenarios: ['historical', 'ssp126', 'ssp245', 'ssp370', 'ssp585']
   Spatial: 9×19
   Batch size: 1
   Epochs: 2
   Patch size: 1
   Input channels: 8

🎮 GPU Status:
   Device: Tesla P100-PCIE-16GB
   Memory: 15.9 GB
   Current allocation: 0.0 MB

🚂 STARTING TRAINING


🌍 Overall Progress:   0%|          | 0/2 [00:00<?]


Institution 1/2: AWI-CM-1-1-MR


🌍 Training AWI-CM-1-1-MR:   0%|          | 0/2 [00:00<?]


🌍 Training: AWI-CM-1-1-MR
📊 Step 1/8: Initializing data loader...
✅ Data loader created
🔍 Step 2/8: Discovering files...


## 📊 Section 7: Display Results

In [None]:
# Display results
if summary:
    print("\n" + "="*80)
    print("📊 PHASE 4 RESULTS SUMMARY")
    print("="*80)
    print(f"Total institutions: {summary.get('total_institutions', 0)}")
    print(f"  ✅ Successful: {summary.get('successful', 0)}")
    print(f"  ⭐ Skipped: {summary.get('skipped', 0)}")
    print(f"  ❌ Failed: {summary.get('failed', 0)}")
    print(f"Total time: {summary.get('total_time_hours', 0):.2f}h")
    print(f"Scenarios: {summary.get('scenarios', [])}")
    
    # Scenario performance
    print(f"\n🌍 Scenario Performance Across All Institutions:")
    print("="*80)
    
    scenario_perf = {s: [] for s in SCENARIOS.keys()}
    
    for result in summary.get('results', []):
        if result.get('success') and not result.get('skipped'):
            inst_results = result.get('results', {})
            final_losses = inst_results.get('final_scenario_losses', {})
            
            for scenario, loss in final_losses.items():
                if scenario in scenario_perf:
                    scenario_perf[scenario].append(loss)
    
    for scenario, losses in scenario_perf.items():
        if losses:
            avg = np.mean(losses)
            std = np.std(losses)
            min_loss = np.min(losses)
            max_loss = np.max(losses)
            print(f"   {scenario.upper():12s}: Avg={avg:.6f}±{std:.6f} | Min={min_loss:.6f} | Max={max_loss:.6f}")
        else:
            print(f"   {scenario.upper():12s}: No data")
    
    print("="*80)

print(f"\n✅ Results saved to: {OUTPUT_DIR}")
print("="*80)

## 🎉 Done!

Your models are now trained and saved in the `climax_phase4_results/` directory.

### 📁 Output Files:
- **Models**: `climax_phase4_results/checkpoints/*.pt`
- **Logs**: `climax_phase4_results/logs/*.json`
- **Summary**: `climax_phase4_results/logs/phase4_training_summary.json`

### 🔄 Next Steps:
1. Load trained models for inference
2. Evaluate on test set
3. Generate predictions for different scenarios
4. Visualize results