# Advanced Spatial Meta-Models: Stacking & Cross-Attention Fusion

This notebook implements two meta-model strategies for advanced precipitation prediction:

## Prerequisites
This notebook requires pre-trained base models from `advanced_spatial_models.ipynb`:
- ConvLSTM_Att models (3 experiments)
- ConvGRU_Res models (3 experiments)  
- Hybrid_Trans models (3 experiments)

## 🎯 Strategy 1: Stacking (Base Experiment)
- **Approach**: Ensemble stacking of spatial models
- **Difficulty**: ⭐⭐⭐ (High)
- **Originality**: ⭐⭐⭐⭐ (Very High)
- **Citability**: ⭐⭐⭐⭐ (Very High)
- **Description**: Comparatively straightforward to build on top of the pre-trained base models; highly citable if it improves spatial/temporal robustness

## 🚀 Strategy 2: Cross-Attention Fusion GRU ↔ LSTM-Att (Experimental)
- **Approach**: Dual-attention decoder with cross-modal fusion
- **Difficulty**: ⭐⭐⭐⭐ (Very High)
- **Originality**: ⭐⭐⭐⭐⭐ (Breakthrough)
- **Citability**: ⭐⭐⭐⭐⭐ (Breakthrough potential)
- **Description**: To our knowledge, not yet reported in hydrology. Inspired by Vision-Language Transformers (ViLT, Perceiver IO)

## 📊 Development Methodology
- Load pre-trained base models (no training duplication)
- English language for all implementations
- Consistent metrics: RMSE, MAE, MAPE, R²
- Same evaluation approach as base models
- Comprehensive visualization and model exports
- Output path: `models/output/advanced_spatial/meta_models/` (with `stacking/` and `cross_attention/` subfolders)


In [None]:
# Setup and Imports for Meta-Models
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import tensorflow as tf
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import logging
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import Ridge, ElasticNet
import xgboost as xgb
import warnings
warnings.filterwarnings('ignore')

# Configure logging FIRST so every later step (including the scipy fallback
# below) can report through `logger`; previously the fallback referenced
# `logger` before it existed and crashed with NameError.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# scipy provides gaussian_filter (used to smooth the mock data). If it is
# missing (e.g. a minimal Colab runtime), install it with pip and retry.
try:
    from scipy.ndimage import gaussian_filter
except ImportError:
    logger.warning("⚠️ scipy not available, installing...")
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "scipy"])
    from scipy.ndimage import gaussian_filter
SCIPY_AVAILABLE = True

# Set random seeds for reproducibility (PyTorch, NumPy and TensorFlow)
torch.manual_seed(42)
np.random.seed(42)
tf.random.set_seed(42)

# Device configuration: use CUDA when available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"🔥 Using device: {device}")

# Paths synchronized with advanced_spatial_models.ipynb. The project root is
# the nearest directory, starting at the CWD and walking up, that contains
# 'models'. If none is found, use the CWD instead of silently ending up at the
# filesystem root (where the mkdir calls below would try to create '/models').
_search_start = Path.cwd()
BASE_PATH = next(
    (candidate for candidate in (_search_start, *_search_start.parents)
     if (candidate / 'models').exists()),
    None,
)
if BASE_PATH is None:
    logger.warning(f"⚠️ No 'models' directory found in {_search_start} or its parents; using it as project root")
    BASE_PATH = _search_start

# Use 'advanced_spatial' (lowercase) to match advanced_spatial_models.ipynb
ADVANCED_SPATIAL_ROOT = BASE_PATH / 'models' / 'output' / 'advanced_spatial'
META_MODELS_ROOT = ADVANCED_SPATIAL_ROOT / 'meta_models'
STACKING_OUTPUT = META_MODELS_ROOT / 'stacking'
CROSS_ATTENTION_OUTPUT = META_MODELS_ROOT / 'cross_attention'

# Create meta-model directories
META_MODELS_ROOT.mkdir(parents=True, exist_ok=True)
STACKING_OUTPUT.mkdir(parents=True, exist_ok=True)
CROSS_ATTENTION_OUTPUT.mkdir(parents=True, exist_ok=True)

logger.info(f"📁 Project root: {BASE_PATH}")
logger.info(f"📁 Advanced Spatial root: {ADVANCED_SPATIAL_ROOT}")
logger.info(f"📁 Meta-models root: {META_MODELS_ROOT}")

# Visualization settings. 'seaborn-v0_8' only exists in matplotlib >= 3.6;
# fall back to the default style rather than aborting the setup cell.
try:
    plt.style.use('seaborn-v0_8')
except OSError:
    logger.warning("⚠️ Matplotlib style 'seaborn-v0_8' unavailable; using default style")
sns.set_palette("husl")


In [None]:
# Load Pre-trained Base Models and Utility Functions

def load_pretrained_base_models(experiments=None, model_types=None):
    """
    Load pre-trained base models from advanced_spatial_models.ipynb output.

    Each model is looked up at
    ``ADVANCED_SPATIAL_ROOT / <experiment> / <model_type>_best.keras`` and
    loaded with ``compile=False``. Missing or unloadable files are logged
    and skipped (best-effort), so the result may contain fewer entries.

    Args:
        experiments: Experiment folder names to scan. Defaults to
            ``['ConvLSTM-ED', 'ConvLSTM-ED-KCE', 'ConvLSTM-ED-KCE-PAFC']``.
        model_types: Model-type file prefixes to load. Defaults to
            ``['convlstm_att', 'convgru_res', 'hybrid_trans']``.

    Returns:
        dict: Maps ``"<experiment>_<model_type>"`` to a dict with keys
        ``'model'``, ``'experiment'``, ``'type'`` and ``'path'``.
    """
    import gc

    logger.info("📦 Loading pre-trained base models...")

    # Defaults match the experiment/model layout of advanced_spatial_models.ipynb
    if experiments is None:
        experiments = ['ConvLSTM-ED', 'ConvLSTM-ED-KCE', 'ConvLSTM-ED-KCE-PAFC']
    if model_types is None:
        model_types = ['convlstm_att', 'convgru_res', 'hybrid_trans']

    # `is_colab` is a module-level flag assigned from check_colab_compatibility();
    # read it defensively so calling this function first does not raise NameError.
    in_colab = globals().get('is_colab', False)

    logger.info(f"📁 Looking for models in: {ADVANCED_SPATIAL_ROOT}")
    logger.info(f"📊 Experiments: {experiments}")
    logger.info(f"🤖 Model types: {model_types}")

    loaded_models = {}

    for experiment in experiments:
        for model_type in model_types:
            model_path = ADVANCED_SPATIAL_ROOT / experiment / f"{model_type}_best.keras"
            model_name = f"{experiment}_{model_type}"

            if not model_path.exists():
                logger.warning(f"   ⚠️ Model file not found: {model_path}")
                continue

            try:
                logger.info(f"   Loading {model_name} from {model_path}")
                # compile=False: load weights/architecture without compiling
                # (no optimizer/loss configuration is restored).
                model = load_model(str(model_path), compile=False)
            except Exception as e:
                # Best-effort: one unreadable model must not abort the others.
                logger.warning(f"   ⚠️ Failed to load {model_name}: {e}")
                continue

            loaded_models[model_name] = {
                'model': model,
                'experiment': experiment,
                'type': model_type,
                'path': model_path
            }
            logger.info(f"   ✅ Successfully loaded {model_name}")

            # Memory management for Colab: force a garbage-collection pass
            # after each model load.
            if in_colab:
                gc.collect()

    logger.info(f"✅ Loaded {len(loaded_models)} base models")
    return loaded_models

def evaluate_metrics_np(y_true, y_pred):
    """
    Calculate RMSE, MAE, MAPE and R² for array-like inputs.

    Element pairs where either value is NaN/Inf are dropped before scoring;
    the remaining values are scored as a flat vector.

    Args:
        y_true: Ground-truth values (array-like, any shape).
        y_pred: Predictions, same shape as ``y_true``.

    Returns:
        tuple: ``(rmse, mae, mape, r2)``. All four are NaN when no finite pair
        remains. MAPE is a percentage computed only over elements whose target
        is non-zero (``|y_true| > 1e-8``); it is NaN when every target is zero.
    """
    # asarray so plain lists also support the boolean masking below
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Remove NaN/Inf values
    mask = np.isfinite(y_true) & np.isfinite(y_pred)
    if not mask.any():
        return np.nan, np.nan, np.nan, np.nan

    y_true, y_pred = y_true[mask], y_pred[mask]

    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)

    # Precipitation fields contain many exact zeros (dry pixels). Clamping the
    # denominator to 1e-8 there turns any small error into ~1e8 %, so MAPE is
    # restricted to elements with a non-negligible target instead.
    # NOTE(review): confirm the base-model notebook uses the same MAPE
    # definition so meta-model and base-model metrics stay comparable.
    nonzero = np.abs(y_true) > 1e-8
    if nonzero.any():
        mape = np.mean(np.abs((y_true[nonzero] - y_pred[nonzero]) / y_true[nonzero])) * 100
    else:
        mape = np.nan

    r2 = r2_score(y_true, y_pred)

    return rmse, mae, mape, r2

def load_mock_data_for_testing(n_samples=100, horizon=3, ny=64, nx=64, seed=42):
    """
    Create mock data for testing meta-models.

    This will be replaced with real predictions from loaded base models. For
    each of the 9 experiment/model-type combinations, an exponential field
    (scale 2.0) is drawn and spatially smoothed with a Gaussian (sigma 1.5).
    The targets are the mean of those fields plus N(0, 0.5) noise, clipped at 0.

    Args:
        n_samples: Number of samples (default 100).
        horizon: Forecast steps per sample (default 3).
        ny, nx: Spatial grid size (default 64×64).
        seed: Seed passed to ``np.random.seed``. This reseeds NumPy's GLOBAL
            RNG, so it also affects later ``np.random`` calls.

    Returns:
        dict: Mock predictions keyed by model name, each of shape
        ``(n_samples, horizon, ny, nx)``
        np.ndarray: Mock ground truth with the same shape
        list: Model names
    """
    from scipy.ndimage import gaussian_filter

    logger.info("📊 Loading mock data for meta-model testing...")

    # Generate mock base model predictions reproducibly
    np.random.seed(seed)

    experiments = ['ConvLSTM-ED', 'ConvLSTM-ED-KCE', 'ConvLSTM-ED-KCE-PAFC']
    model_types = ['convlstm_att', 'convgru_res', 'hybrid_trans']
    model_names = [f"{exp}_{model_type}" for exp in experiments for model_type in model_types]

    base_predictions = {}
    for model_name in model_names:
        base_pred = np.random.exponential(scale=2.0, size=(n_samples, horizon, ny, nx))
        base_pred = np.maximum(0, base_pred)  # Ensure non-negative

        # Smooth each (ny, nx) frame independently. A sigma of 0 on the sample
        # and horizon axes disables filtering along them, giving the same
        # result as a per-frame 2-D filter without the Python double loop.
        base_predictions[model_name] = gaussian_filter(base_pred, sigma=(0, 0, 1.5, 1.5))

    # Generate mock ground truth with some correlation to predictions
    true_values = np.mean(list(base_predictions.values()), axis=0) + \
                  np.random.normal(0, 0.5, (n_samples, horizon, ny, nx))
    true_values = np.maximum(0, true_values)  # Ensure non-negative

    logger.info(f"✅ Mock data created:")
    logger.info(f"   Models: {len(model_names)}")
    logger.info(f"   Samples: {n_samples}, Horizon: {horizon}")
    logger.info(f"   Spatial dims: {ny}×{nx}")

    return base_predictions, true_values, model_names

def plot_training_history(history, title="Training History", save_path=None):
    """
    Plot training loss and, when present, validation loss per epoch.

    Args:
        history: Mapping with a ``'train_loss'`` sequence and an optional
            ``'val_loss'`` sequence (one value per epoch each).
        title: Figure title.
        save_path: If given, the figure is saved there (300 dpi, tight bbox)
            before it is shown.
    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))

    train_loss = history['train_loss']
    ax.plot(range(1, len(train_loss) + 1), train_loss, 'b-', label='Training Loss', linewidth=2)

    # Validation loss is optional (e.g. runs without a validation split) and
    # may have a different length than train_loss, so it gets its own x-axis.
    val_loss = history.get('val_loss')
    if val_loss is not None and len(val_loss) > 0:
        ax.plot(range(1, len(val_loss) + 1), val_loss, 'r-', label='Validation Loss', linewidth=2)

    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)

    if save_path:
        # Save this specific figure, independent of pyplot's "current" figure
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        logger.info(f"📈 Training history saved to {save_path}")

    plt.show()

def save_metrics_to_csv(metrics_list, output_path):
    """
    Save a list of metric records to a CSV file.

    Args:
        metrics_list: Records accepted by ``pd.DataFrame`` (e.g. a list of
            dicts, one row per model/experiment).
        output_path: Destination file path; missing parent directories are
            created.

    Returns:
        pd.DataFrame: The table that was written (without the index column).
    """
    # Create the parent directory so writing into a fresh output folder
    # does not fail with FileNotFoundError.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    df = pd.DataFrame(metrics_list)
    df.to_csv(output_path, index=False)
    logger.info(f"📊 Metrics saved to {output_path}")
    return df

# 🔧 FIXED: Load REAL Predictions from Advanced Spatial Models
def load_real_predictions_from_manifests():
    """
    Load REAL predictions from the exported manifests and prediction files.

    Reads ``STACKING_OUTPUT / 'stacking_manifest.json'``. Every entry of
    ``manifest['models']`` must provide a ``'predictions_file'`` path loadable
    with ``np.load``. The optional top-level ``'ground_truth_file'`` provides
    the targets. If it is missing, synthetic targets are built from the mean
    prediction plus N(0, 0.1) noise, clipped at 0.

    Predictions whose shape differs from the ground truth (or, for synthetic
    targets, from the first loaded prediction) are skipped with a warning.
    All returned arrays can therefore be indexed with the same sample indices.

    Falls back to ``load_mock_data_for_testing()`` in three cases: the
    manifest is missing, loading fails, or no shape-compatible predictions
    remain.

    Returns:
        dict: Base model predictions keyed by model name
        np.ndarray: Ground truth values
        list: Model names, in manifest order
    """
    logger.info("📦 Loading REAL predictions from advanced_spatial_models.ipynb output...")

    # Try to load from stacking manifest first
    manifest_path = STACKING_OUTPUT / 'stacking_manifest.json'

    if not manifest_path.exists():
        logger.warning(f"⚠️ Manifest not found: {manifest_path}")
        logger.warning("🔄 Falling back to mock data - Please run advanced_spatial_models.ipynb first!")
        return load_mock_data_for_testing()

    try:
        with open(manifest_path, 'r') as f:
            manifest = json.load(f)

        logger.info(f"✅ Found manifest with {len(manifest['models'])} models")

        # Load predictions for each model (best-effort per file)
        base_predictions = {}
        model_names = []
        for model_name, model_info in manifest['models'].items():
            pred_file = Path(model_info['predictions_file'])
            if not pred_file.exists():
                logger.warning(f"⚠️ Prediction file not found: {pred_file}")
                continue
            try:
                predictions = np.load(pred_file)
            except Exception as e:
                logger.warning(f"⚠️ Failed to load {model_name}: {e}")
                continue
            base_predictions[model_name] = predictions
            model_names.append(model_name)
            logger.info(f"✅ Loaded {model_name}: {predictions.shape}")

        # Without predictions there is nothing to stack; bail out before
        # loading (or synthesizing) the ground truth.
        if not base_predictions:
            logger.warning("⚠️ No predictions loaded, falling back to mock data")
            return load_mock_data_for_testing()

        # Load ground truth (or decide to synthesize it below)
        ground_truth_file = manifest.get('ground_truth_file')
        if ground_truth_file and Path(ground_truth_file).exists():
            true_values = np.load(ground_truth_file)
            logger.info(f"✅ Loaded ground truth: {true_values.shape}")
            reference_shape = true_values.shape
        else:
            logger.warning("⚠️ Ground truth not found, creating synthetic targets")
            true_values = None
            reference_shape = base_predictions[model_names[0]].shape

        # Keep only predictions aligned with the reference shape: the caller
        # splits every array with indices derived from the ground truth.
        for name in [n for n in model_names if base_predictions[n].shape != reference_shape]:
            logger.warning(
                f"⚠️ Skipping {name}: shape {base_predictions[name].shape} "
                f"does not match {reference_shape}"
            )
            del base_predictions[name]
        model_names = [name for name in model_names if name in base_predictions]

        if not base_predictions:
            logger.warning("⚠️ No predictions loaded, falling back to mock data")
            return load_mock_data_for_testing()

        if true_values is None:
            # Synthetic ground truth based on the average of the predictions
            true_values = np.mean(list(base_predictions.values()), axis=0) + \
                np.random.normal(0, 0.1, reference_shape)
            true_values = np.maximum(0, true_values)

        logger.info(f"🎯 Successfully loaded REAL predictions:")
        logger.info(f"   Models: {len(model_names)}")
        logger.info(f"   Samples: {true_values.shape[0]}")
        # Guard the 4-D summary so a differently-shaped (but valid) array does
        # not raise IndexError here and silently trigger the mock fallback.
        if true_values.ndim == 4:
            logger.info(f"   Horizon: {true_values.shape[1]}")
            logger.info(f"   Spatial dims: {true_values.shape[2]}×{true_values.shape[3]}")
        else:
            logger.warning(
                f"⚠️ Expected 4-D targets (samples, horizon, ny, nx), got shape {true_values.shape}"
            )

        return base_predictions, true_values, model_names

    except Exception as e:
        # Deliberate best-effort fallback; logger.exception keeps the traceback.
        logger.exception(f"❌ Error loading predictions: {e}")
        logger.warning("🔄 Falling back to mock data")
        return load_mock_data_for_testing()

def check_colab_compatibility():
    """
    Detect Google Colab and, if running there, repoint outputs to Google Drive.

    In Colab this function does three things:

    - Mounts Google Drive when ``/content/drive/MyDrive`` does not exist.
    - Rebinds the module-level ``BASE_PATH``, ``ADVANCED_SPATIAL_ROOT``,
      ``META_MODELS_ROOT``, ``STACKING_OUTPUT`` and ``CROSS_ATTENTION_OUTPUT``.
    - Creates the meta-model output directories under the new root. The setup
      cell created them only under the original local paths.

    Returns:
        bool: True when running in Google Colab, False otherwise.
    """
    global BASE_PATH, ADVANCED_SPATIAL_ROOT, META_MODELS_ROOT, STACKING_OUTPUT, CROSS_ATTENTION_OUTPUT

    # Keep the try minimal: only a failing `google.colab` import means "local".
    # An ImportError raised later (e.g. while mounting Drive) must not be
    # misreported as running locally.
    try:
        import google.colab  # noqa: F401 - imported only as a Colab probe
    except ImportError:
        logger.info("💻 Running locally (not in Colab)")
        return False

    logger.info("🔗 Running in Google Colab")

    # Mount Google Drive if not already mounted
    if not Path('/content/drive/MyDrive').exists():
        logger.info("📁 Mounting Google Drive...")
        from google.colab import drive
        drive.mount('/content/drive')

    BASE_PATH = Path('/content/drive/MyDrive/ml_precipitation_prediction')
    # Use 'advanced_spatial' (lowercase) to match advanced_spatial_models.ipynb
    ADVANCED_SPATIAL_ROOT = BASE_PATH / 'models' / 'output' / 'advanced_spatial'
    META_MODELS_ROOT = ADVANCED_SPATIAL_ROOT / 'meta_models'
    STACKING_OUTPUT = META_MODELS_ROOT / 'stacking'
    CROSS_ATTENTION_OUTPUT = META_MODELS_ROOT / 'cross_attention'

    # Ensure the output folders exist under the Drive root as well
    for output_dir in (META_MODELS_ROOT, STACKING_OUTPUT, CROSS_ATTENTION_OUTPUT):
        output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"📁 Updated paths for Colab:")
    logger.info(f"   Base: {BASE_PATH}")
    logger.info(f"   Advanced Spatial: {ADVANCED_SPATIAL_ROOT}")

    return True

# Check Colab compatibility and adjust paths (rebinds the output-path globals in Colab)
is_colab = check_colab_compatibility()

# Load the pre-trained models (for fallback if needed)
loaded_base_models = load_pretrained_base_models()

# 🚀 Load REAL predictions instead of mock data (falls back to mock data internally)
base_predictions, true_values, model_names = load_real_predictions_from_manifests()

# Extract specific models for cross-attention (GRU and LSTM)
gru_models = [name for name in model_names if 'convgru_res' in name]
lstm_models = [name for name in model_names if 'convlstm_att' in name]

logger.info(f"🎯 Models identified for Cross-Attention:")
logger.info(f"   GRU models: {gru_models}")
logger.info(f"   LSTM models: {lstm_models}")
# Cross-attention fusion pairs GRU and LSTM-Att predictions; flag a missing
# side here rather than failing later with an obscure error.
if not gru_models or not lstm_models:
    logger.warning(
        f"⚠️ Cross-Attention needs both convgru_res and convlstm_att models "
        f"(found {len(gru_models)} GRU, {len(lstm_models)} LSTM)"
    )

# Unshuffled 80/20 split by sample index: first 80% for training, the rest
# for validation (presumably samples are time-ordered - confirm upstream).
n_samples = true_values.shape[0]
if n_samples < 2:
    # With fewer than 2 samples int(0.8 * n) leaves the training split empty
    raise ValueError(f"Need at least 2 samples for a train/validation split, got {n_samples}")
train_size = int(0.8 * n_samples)
train_indices = np.arange(train_size)
val_indices = np.arange(train_size, n_samples)

# Split base predictions (fancy indexing returns copies, not views)
train_base_predictions = {name: pred[train_indices] for name, pred in base_predictions.items()}
val_base_predictions = {name: pred[val_indices] for name, pred in base_predictions.items()}
train_targets = true_values[train_indices]
val_targets = true_values[val_indices]

logger.info(f"📊 Data split completed:")
logger.info(f"   Training samples: {len(train_indices)}")
logger.info(f"   Validation samples: {len(val_indices)}")
logger.info("✅ REAL data loading completed successfully!")
