# V7-AMES: Adaptive Multi-Expert Ensemble System

**Complete Training Pipeline for Monthly Precipitation Prediction**

This notebook implements a physics-guided multi-expert ensemble with adaptive routing for spatiotemporal precipitation forecasting in mountainous regions.

**Key Components:**
- 3 specialized experts (high/medium/low elevation)
- Physics-guided gating network
- Physics-informed meta-learner
- 3-stage hierarchical training protocol


## 1. Model Architecture

V7-AMES consists of 5 main components:

1. **Expert1_HighElevation**: GNN-TAT for elevation >3000m
2. **Expert2_LowElevation**: ConvLSTM for elevation <2000m
3. **Expert3_Transition**: Hybrid GNN + CNN (one GAT layer and two Conv2d layers applied to the latest timestep) for 2000-3000m
4. **PhysicsGuidedGating**: Adaptive routing with orographic priors
5. **PhysicsInformedMetaLearner**: Final ensemble with physics constraints


In [None]:
# =============================================================================
# SECTION 1: V7-AMES MODEL ARCHITECTURE
# =============================================================================

# V7-AMES: Adaptive Multi-Expert Ensemble System - Complete Architecture
# This file contains all 5 main components for the V7-AMES model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple


@dataclass
class V7Config:
    """Configuration for the V7-AMES model.

    Collects the grid dimensions, per-component hyper-parameters, the 3-stage
    training schedule, physics-loss weights, elevation thresholds (metres) and
    output paths.
    """

    # Data dimensions
    n_lat: int = 61
    n_lon: int = 65
    # NOTE(review): not derived from n_lat/n_lon; keep in sync manually (3965).
    n_nodes: int = 61 * 65
    horizon: int = 12

    # Expert 1: High Elevation (GNN-TAT)
    expert1_gnn_hidden: int = 64
    expert1_gnn_layers: int = 3
    expert1_gat_heads: int = 4
    expert1_lstm_hidden: int = 64
    expert1_lstm_layers: int = 2
    expert1_attn_heads: int = 4
    expert1_dropout: float = 0.2

    # Expert 2: Low Elevation (ConvLSTM)
    expert2_convlstm_hidden: int = 64
    expert2_convlstm_layers: int = 2
    expert2_kernel_size: int = 3
    expert2_dropout: float = 0.2

    # Expert 3: Transition (Hybrid)
    expert3_gnn_hidden: int = 32
    expert3_gnn_layers: int = 1
    expert3_conv_hidden: int = 32
    expert3_conv_layers: int = 2
    expert3_dropout: float = 0.2

    # Gating network
    gating_hidden: int = 32
    gating_dropout: float = 0.2
    physics_prior_weight_init: float = 0.3

    # Meta-learner
    meta_hidden_1: int = 64
    meta_hidden_2: int = 32
    meta_dropout: float = 0.2

    # Training
    batch_size: int = 8
    lr_stage1: float = 0.001
    lr_stage2: float = 0.001
    lr_stage3: float = 0.0001
    epochs_stage1: int = 50
    epochs_stage2: int = 30
    epochs_stage3: int = 50
    patience: int = 15

    # Physics loss weights
    lambda_mass_conservation: float = 0.05
    lambda_orographic: float = 0.1

    # Elevation thresholds
    high_elev_threshold: float = 3000.0
    medium_elev_min: float = 2000.0
    medium_elev_max: float = 3000.0
    low_elev_threshold: float = 2000.0

    # Paths
    data_dir: Path = Path("output/V7_AMES_Data")
    model_dir: Path = Path("output/V7_AMES_Models")
    v2_path: Optional[Path] = None
    v4_path: Optional[Path] = None


# ============================================================================
# COMPONENT 1: Expert1_HighElevation (GNN-TAT for >3000m)
# ============================================================================
class Expert1_HighElevation(nn.Module):
    """Expert 1: GNN-TAT for high elevation (>3000m).

    Architecture based on V4 GNN-TAT:
    - Temporal encoder: multi-layer LSTM run per node (last output kept)
    - Graph attention: ``expert1_gnn_layers`` GAT layers over the spatial graph
    - Multi-head self-attention across the node axis (memory is O(nodes^2))
    - MLP head producing ``horizon`` values per node
    """

    def __init__(self, config: V7Config, n_features: int):
        super().__init__()
        self.config = config
        self.n_features = n_features
        self.horizon = config.horizon

        # Temporal encoder (LSTM); nn.LSTM only supports dropout between layers.
        self.temporal_encoder = nn.LSTM(
            input_size=n_features,
            hidden_size=config.expert1_lstm_hidden,
            num_layers=config.expert1_lstm_layers,
            batch_first=True,
            dropout=config.expert1_dropout if config.expert1_lstm_layers > 1 else 0.0,
        )

        # Graph attention layers (GAT): hidden layers concatenate their heads,
        # the final layer averages them so it outputs expert1_gnn_hidden features.
        self.gat_layers = nn.ModuleList()
        in_dim = config.expert1_lstm_hidden
        for i in range(config.expert1_gnn_layers):
            is_last = i == config.expert1_gnn_layers - 1
            self.gat_layers.append(
                GATConv(
                    in_channels=in_dim,
                    out_channels=config.expert1_gnn_hidden,
                    heads=config.expert1_gat_heads,
                    dropout=config.expert1_dropout,
                    concat=not is_last,
                )
            )
            if is_last:
                in_dim = config.expert1_gnn_hidden
            else:
                in_dim = config.expert1_gnn_hidden * config.expert1_gat_heads

        # Attention across nodes (attribute name kept for checkpoint compatibility)
        self.temporal_attention = nn.MultiheadAttention(
            embed_dim=config.expert1_gnn_hidden,
            num_heads=config.expert1_attn_heads,
            dropout=config.expert1_dropout,
            batch_first=True,
        )

        # Output projection
        self.output_projection = nn.Sequential(
            nn.Linear(config.expert1_gnn_hidden, config.expert1_gnn_hidden // 2),
            nn.ReLU(),
            nn.Dropout(config.expert1_dropout),
            nn.Linear(config.expert1_gnn_hidden // 2, config.horizon),
        )

    @staticmethod
    def _batch_graph(edge_index, edge_weight, batch_size, n_nodes):
        """Replicate a single-sample graph for every sample of a flattened batch.

        PyG message passing operates on one big node set; when batch and node
        axes are flattened to ``batch * nodes`` rows, sample ``b`` owns rows
        ``[b * n_nodes, (b + 1) * n_nodes)`` and needs its own offset copy of
        the edges. The inputs are returned unchanged when ``batch_size`` is 1,
        the edge list is empty, or it already references node ids
        ``>= n_nodes`` (i.e. the caller batched it).
        """
        if batch_size == 1 or edge_index.numel() == 0 or int(edge_index.max()) >= n_nodes:
            return edge_index, edge_weight
        offsets = torch.arange(
            batch_size, device=edge_index.device, dtype=edge_index.dtype
        ) * n_nodes
        # [2, 1, E] + [1, B, 1] -> [2, B, E] -> [2, B * E] (sample-major order)
        batched_index = (edge_index.unsqueeze(1) + offsets.view(1, -1, 1)).reshape(2, -1)
        if edge_weight is not None:
            repeats = (batch_size,) + (1,) * (edge_weight.dim() - 1)
            edge_weight = edge_weight.repeat(*repeats)
        return batched_index, edge_weight

    def forward(self, x, edge_index, edge_weight=None):
        """Forward pass.

        Args:
            x: [batch, nodes, time, features]
            edge_index: [2, num_edges] for one sample's graph (or already batched)
            edge_weight: [num_edges] (optional). NOTE(review): passed to GATConv's
                ``edge_attr`` slot; depending on the PyG version it may be ignored
                because the layers are built without ``edge_dim`` — confirm.

        Returns:
            predictions: [batch, nodes, horizon]
        """
        batch_size, n_nodes, time_steps, n_features = x.shape

        # Temporal encoding per node; keep the last timestep's output.
        temporal_features, _ = self.temporal_encoder(
            x.reshape(batch_size * n_nodes, time_steps, n_features)
        )
        node_features = temporal_features[:, -1, :]  # [batch*nodes, lstm_hidden]

        # Graph attention over one disjoint copy of the graph per sample.
        batched_index, batched_weight = self._batch_graph(
            edge_index, edge_weight, batch_size, n_nodes
        )
        for gat_layer in self.gat_layers:
            node_features = F.elu(gat_layer(node_features, batched_index, batched_weight))
        graph_features = node_features.reshape(batch_size, n_nodes, -1)

        # Self-attention over nodes
        attended_features, _ = self.temporal_attention(
            graph_features, graph_features, graph_features
        )

        return self.output_projection(attended_features)  # [batch, nodes, horizon]


# ============================================================================
# COMPONENT 2: Expert2_LowElevation (ConvLSTM for <2000m)
# ============================================================================
class ConvLSTMCell(nn.Module):
    """ConvLSTM cell for spatial-temporal processing.

    One convolution over the channel-concatenated ``[x, h]`` produces the
    input, forget, output and candidate gates. ``padding = kernel_size // 2``
    keeps the spatial size unchanged for odd kernel sizes.
    """

    def __init__(self, input_dim, hidden_dim, kernel_size):
        super().__init__()
        self.hidden_dim = hidden_dim
        padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels=input_dim + hidden_dim,
            out_channels=4 * hidden_dim,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, x, h, c):
        """Forward pass.

        Args:
            x: [batch, channels, height, width]
            h: [batch, hidden_dim, height, width]
            c: [batch, hidden_dim, height, width]

        Returns:
            h_next, c_next
        """
        combined = torch.cat([x, h], dim=1)
        gates = self.conv(combined)
        i, f, o, g = gates.chunk(4, dim=1)
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        g = torch.tanh(g)
        c_next = f * c + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next


class Expert2_LowElevation(nn.Module):
    """Expert 2: ConvLSTM for low elevation (<2000m).

    Architecture based on V2 Enhanced ConvLSTM:
    - Stacked ConvLSTM layers for spatial-temporal processing
    - 1x1-conv output projection for multi-horizon prediction
    """

    def __init__(self, config: V7Config, n_features: int):
        super().__init__()
        self.config = config
        self.n_features = n_features
        self.horizon = config.horizon
        self.n_lat = config.n_lat
        self.n_lon = config.n_lon

        # ConvLSTM layers
        self.convlstm_layers = nn.ModuleList()
        in_dim = n_features
        for _ in range(config.expert2_convlstm_layers):
            self.convlstm_layers.append(
                ConvLSTMCell(
                    input_dim=in_dim,
                    hidden_dim=config.expert2_convlstm_hidden,
                    kernel_size=config.expert2_kernel_size,
                )
            )
            in_dim = config.expert2_convlstm_hidden

        self.dropout = nn.Dropout(config.expert2_dropout)

        # Output projection
        self.output_projection = nn.Sequential(
            nn.Conv2d(
                config.expert2_convlstm_hidden,
                config.expert2_convlstm_hidden // 2,
                kernel_size=1,
            ),
            nn.ReLU(),
            nn.Dropout(config.expert2_dropout),
            nn.Conv2d(
                config.expert2_convlstm_hidden // 2,
                config.horizon,
                kernel_size=1,
            ),
        )

    def forward(self, x, edge_index=None, edge_weight=None):
        """Forward pass.

        Args:
            x: [batch, nodes, time, features]; nodes are assumed to be ordered
                row-major over the (n_lat, n_lon) grid.
            edge_index: Ignored (for API compatibility)
            edge_weight: Ignored (for API compatibility)

        Returns:
            predictions: [batch, nodes, horizon]
        """
        batch_size, n_nodes, time_steps, n_features = x.shape
        hidden_dim = self.config.expert2_convlstm_hidden
        n_layers = self.config.expert2_convlstm_layers

        # Reshape to grid: [batch, time, features, lat, lon]
        x_grid = x.reshape(batch_size, self.n_lat, self.n_lon, time_steps, n_features)
        x_grid = x_grid.permute(0, 3, 4, 1, 2)

        h = [torch.zeros(batch_size, hidden_dim, self.n_lat, self.n_lon, device=x.device)
             for _ in range(n_layers)]
        c = [torch.zeros(batch_size, hidden_dim, self.n_lat, self.n_lon, device=x.device)
             for _ in range(n_layers)]

        last_layer = len(self.convlstm_layers) - 1
        for t in range(time_steps):
            layer_input = x_grid[:, t]  # [batch, features, lat, lon]
            for layer_idx, convlstm_cell in enumerate(self.convlstm_layers):
                h[layer_idx], c[layer_idx] = convlstm_cell(
                    layer_input, h[layer_idx], c[layer_idx]
                )
                layer_input = h[layer_idx]
                # Dropout only on the input passed to the next layer; the stored
                # recurrent state stays intact (same semantics as nn.LSTM dropout).
                if layer_idx < last_layer:
                    layer_input = self.dropout(layer_input)

        # Project final hidden state of the top layer: [batch, horizon, lat, lon]
        predictions = self.output_projection(h[-1])

        # Back to node layout: [batch, nodes, horizon]
        predictions = predictions.permute(0, 2, 3, 1).reshape(batch_size, n_nodes, self.horizon)
        return predictions
# ============================================================================
# COMPONENT 3: Expert3_Transition (Hybrid for 2000-3000m)
# ============================================================================
class Expert3_Transition(nn.Module):
    """Expert 3: Hybrid GNN + Conv for transition zone (2000-3000m).

    Lightweight hybrid that only uses the most recent timestep:
    - 1 GAT layer (4 heads, averaged) for graph structure
    - 2 Conv2d layers for grid patterns
    - Linear fusion of both branches, then an MLP head to ``horizon`` outputs

    NOTE(review): ``expert3_gnn_layers`` / ``expert3_conv_layers`` are not read
    from the config; the layer counts and the 4 GAT heads are fixed here.
    """

    def __init__(self, config: V7Config, n_features: int):
        super().__init__()
        self.config = config
        self.n_features = n_features
        self.horizon = config.horizon
        self.n_lat = config.n_lat
        self.n_lon = config.n_lon

        # GNN branch (lightweight)
        self.gat = GATConv(
            in_channels=n_features,
            out_channels=config.expert3_gnn_hidden,
            heads=4,
            dropout=config.expert3_dropout,
            concat=False,
        )

        # Conv branch (lightweight); padding=1 preserves the lat/lon grid size
        self.conv1 = nn.Conv2d(
            in_channels=n_features,
            out_channels=config.expert3_conv_hidden,
            kernel_size=3,
            padding=1,
        )
        self.conv2 = nn.Conv2d(
            in_channels=config.expert3_conv_hidden,
            out_channels=config.expert3_conv_hidden,
            kernel_size=3,
            padding=1,
        )
        self.dropout = nn.Dropout(config.expert3_dropout)

        # Fusion
        self.fusion = nn.Linear(
            config.expert3_gnn_hidden + config.expert3_conv_hidden,
            config.expert3_gnn_hidden,
        )

        # Output projection
        self.output_projection = nn.Sequential(
            nn.Linear(config.expert3_gnn_hidden, config.expert3_gnn_hidden // 2),
            nn.ReLU(),
            nn.Dropout(config.expert3_dropout),
            nn.Linear(config.expert3_gnn_hidden // 2, config.horizon),
        )

    @staticmethod
    def _batch_graph(edge_index, edge_weight, batch_size, n_nodes):
        """Replicate a single-sample graph for every sample of a flattened batch.

        With batch and node axes flattened to ``batch * nodes`` rows, sample
        ``b`` owns rows ``[b * n_nodes, (b + 1) * n_nodes)`` and needs its own
        offset copy of the edges. Inputs are returned unchanged when
        ``batch_size`` is 1, the edge list is empty, or it already references
        node ids ``>= n_nodes`` (i.e. the caller batched it).
        """
        if batch_size == 1 or edge_index.numel() == 0 or int(edge_index.max()) >= n_nodes:
            return edge_index, edge_weight
        offsets = torch.arange(
            batch_size, device=edge_index.device, dtype=edge_index.dtype
        ) * n_nodes
        batched_index = (edge_index.unsqueeze(1) + offsets.view(1, -1, 1)).reshape(2, -1)
        if edge_weight is not None:
            repeats = (batch_size,) + (1,) * (edge_weight.dim() - 1)
            edge_weight = edge_weight.repeat(*repeats)
        return batched_index, edge_weight

    def forward(self, x, edge_index, edge_weight=None):
        """Forward pass.

        Args:
            x: [batch, nodes, time, features]; nodes assumed row-major over
                the (n_lat, n_lon) grid.
            edge_index: [2, num_edges] for one sample's graph (or already batched)
            edge_weight: [num_edges] (optional), forwarded to GATConv

        Returns:
            predictions: [batch, nodes, horizon]
        """
        batch_size, n_nodes, time_steps, n_features = x.shape

        # Only the last timestep is used
        x_last = x[:, :, -1, :]  # [batch, nodes, features]

        # GNN branch over one disjoint copy of the graph per sample
        batched_index, batched_weight = self._batch_graph(
            edge_index, edge_weight, batch_size, n_nodes
        )
        gnn_out = self.gat(
            x_last.reshape(batch_size * n_nodes, n_features), batched_index, batched_weight
        )
        gnn_out = F.elu(gnn_out).reshape(batch_size, n_nodes, -1)

        # Conv branch on the grid layout: [batch, features, lat, lon]
        x_grid = x_last.reshape(batch_size, self.n_lat, self.n_lon, n_features)
        x_grid = x_grid.permute(0, 3, 1, 2)
        conv_out = self.dropout(F.relu(self.conv1(x_grid)))
        conv_out = F.relu(self.conv2(conv_out))
        conv_out = conv_out.permute(0, 2, 3, 1).reshape(batch_size, n_nodes, -1)

        # Fusion
        fused = F.relu(self.fusion(torch.cat([gnn_out, conv_out], dim=-1)))

        return self.output_projection(fused)  # [batch, nodes, horizon]


# ============================================================================
# COMPONENT 4: PhysicsGuidedGating (Physics-informed routing)
# ============================================================================
class PhysicsGuidedGating(nn.Module):
    """Physics-guided gating network for expert routing.

    Combines:
    - Physics priors (rule-based from elevation)
    - Data-driven weights (learned from context)
    - Learnable balance parameter alpha = sigmoid(alpha_logit)
    """

    def __init__(self, config: V7Config, n_context_features: int):
        super().__init__()
        self.config = config

        # Learnable balance between physics and data, stored as a logit
        self.alpha_logit = nn.Parameter(
            torch.tensor(self._inverse_sigmoid(config.physics_prior_weight_init))
        )

        # Data-driven gating network
        self.context_encoder = nn.Sequential(
            nn.Linear(n_context_features, config.gating_hidden),
            nn.ReLU(),
            nn.Dropout(config.gating_dropout),
            nn.Linear(config.gating_hidden, 3),  # 3 experts
        )

    @staticmethod
    def _inverse_sigmoid(y):
        """Return logit(y) as a Python float, with y clamped to [0.001, 0.999]."""
        import math

        y = max(min(y, 0.999), 0.001)
        return math.log(y / (1 - y))

    def compute_physics_priors(self, elevation):
        """Compute physics-based routing weights.

        Args:
            elevation: [batch, nodes] or [batch, nodes, 1], in the same units as
                the config thresholds (metres).

        Returns:
            priors: [batch, nodes, 3] (weights for 3 experts, summing to 1)
        """
        if elevation.dim() == 3:
            elevation = elevation.squeeze(-1)

        # Expert 1: High elevation (>3000m)
        w1 = torch.sigmoid((elevation - self.config.high_elev_threshold) / 500.0)

        # Expert 2: Low elevation (<2000m)
        w2 = torch.sigmoid((self.config.low_elev_threshold - elevation) / 500.0)

        # Expert 3: Transition zone (2000-3000m) - Gaussian peak at the band centre
        center = (self.config.medium_elev_min + self.config.medium_elev_max) / 2
        sigma = (self.config.medium_elev_max - self.config.medium_elev_min) / 4
        w3 = torch.exp(-((elevation - center) ** 2) / (2 * sigma ** 2))

        # Normalise by the sum so relative magnitudes survive. A softmax over
        # values already in [0, 1] would squash them towards uniform (at most an
        # e:1 ratio), effectively disabling the physics prior. w1 + w2 > 0 for
        # any elevation; the clamp is only a numerical guard.
        priors = torch.stack([w1, w2, w3], dim=-1)  # [batch, nodes, 3]
        priors = priors / priors.sum(dim=-1, keepdim=True).clamp_min(1e-8)
        return priors

    def forward(self, context_features):
        """Forward pass.

        Args:
            context_features: Dict with keys:
                - 'elevation': [batch, nodes] or [batch, nodes, 1]
                - 'context': [batch, nodes, n_context]

        Returns:
            weights: [batch, nodes, 3] (routing weights for 3 experts)
            alpha: scalar (physics prior weight)
        """
        physics_priors = self.compute_physics_priors(context_features['elevation'])
        data_weights = F.softmax(self.context_encoder(context_features['context']), dim=-1)

        alpha = torch.sigmoid(self.alpha_logit)
        weights = alpha * physics_priors + (1 - alpha) * data_weights

        # Re-normalise to remove small numerical drift
        weights = weights / weights.sum(dim=-1, keepdim=True)
        return weights, alpha


# ============================================================================
# COMPONENT 5: PhysicsInformedMetaLearner (Final ensemble)
# ============================================================================
class PhysicsInformedMetaLearner(nn.Module):
    """Physics-informed meta-learner for combining expert predictions.

    Components:
    - Gating-weighted combination of expert predictions
    - Meta-residual MLP over concatenated expert outputs + context
    - Learnable physics correction (orographic enhancement + rain shadow)
    """

    def __init__(self, config: V7Config, n_context_features: int):
        super().__init__()
        self.config = config

        # Meta-residual network: input is 3 expert forecasts plus node context
        meta_input_dim = 3 * config.horizon + n_context_features
        self.meta_residual = nn.Sequential(
            nn.Linear(meta_input_dim, config.meta_hidden_1),
            nn.ReLU(),
            nn.Dropout(config.meta_dropout),
            nn.Linear(config.meta_hidden_1, config.meta_hidden_2),
            nn.ReLU(),
            nn.Dropout(config.meta_dropout),
            nn.Linear(config.meta_hidden_2, config.horizon),
        )

        # Physics correction parameters (learnable)
        self.orographic_enhancement = nn.Parameter(torch.tensor(0.1))
        self.rain_shadow_suppression = nn.Parameter(torch.tensor(0.05))

    def compute_physics_correction(self, context_features):
        """Compute physics-based correction term.

        Args:
            context_features: Dict with keys:
                - 'elevation': [batch, nodes] or [batch, nodes, 1]
                - 'slope': same shape as elevation (optional, defaults to zeros)

        Returns:
            correction: [batch, nodes, 1] (additive correction)
        """
        elevation = context_features['elevation']
        slope = context_features.get('slope', torch.zeros_like(elevation))
        if elevation.dim() == 3:
            elevation = elevation.squeeze(-1)
        if slope.dim() == 3:
            slope = slope.squeeze(-1)

        # NOTE(review): /5000 and /90 presume elevation in metres and slope in
        # degrees; confirm against what the training loop puts in 'slope'.
        oro_factor = (elevation / 5000.0) * (slope / 90.0)
        oro_enhancement = self.orographic_enhancement * oro_factor

        # Simplified rain shadow: suppress proportionally to slope (aspect unused)
        rain_shadow = -self.rain_shadow_suppression * (slope / 90.0)

        return (oro_enhancement + rain_shadow).unsqueeze(-1)

    def forward(self, expert_predictions, gating_weights, context_features):
        """Forward pass.

        Args:
            expert_predictions: List of 3 tensors, each [batch, nodes, horizon]
            gating_weights: [batch, nodes, 3]
            context_features: Dict with 'context' [batch, nodes, n_context] and
                the keys used by ``compute_physics_correction``

        Returns:
            final_prediction: [batch, nodes, horizon]
        """
        horizon = expert_predictions[0].shape[-1]

        # Weighted combination: [batch, nodes, horizon, 3] * [batch, nodes, 1, 3]
        stacked_preds = torch.stack(expert_predictions, dim=-1)
        weighted_pred = (stacked_preds * gating_weights.unsqueeze(2)).sum(dim=-1)

        # Meta-residual from all expert outputs + context
        preds_flat = torch.cat(list(expert_predictions), dim=-1)  # [batch, nodes, 3*horizon]
        meta_input = torch.cat([preds_flat, context_features['context']], dim=-1)
        meta_residual = self.meta_residual(meta_input)

        # Physics correction broadcast to all horizons
        physics_correction = self.compute_physics_correction(context_features)
        physics_correction = physics_correction.expand(-1, -1, horizon)

        return weighted_pred + meta_residual + physics_correction


# ============================================================================
# COMPLETE V7-AMES MODEL
# ============================================================================
class V7_AMES(nn.Module):
    """Complete V7-AMES model: Adaptive Multi-Expert Ensemble System.

    Combines:
    - 3 specialized experts (high/transition/low elevation)
    - Physics-guided gating network
    - Physics-informed meta-learner
    """

    def __init__(self, config: V7Config, n_features: int, n_context_features: int):
        super().__init__()
        self.config = config

        # Experts
        self.expert1 = Expert1_HighElevation(config, n_features)
        self.expert2 = Expert2_LowElevation(config, n_features)
        self.expert3 = Expert3_Transition(config, n_features)

        # Gating
        self.gating = PhysicsGuidedGating(config, n_context_features)

        # Meta-learner
        self.meta_learner = PhysicsInformedMetaLearner(config, n_context_features)

    def forward(self, x, edge_index, edge_weight, context_features, stage='full'):
        """Forward pass.

        Args:
            x: [batch, nodes, time, features]
            edge_index: [2, num_edges]
            edge_weight: [num_edges]
            context_features: Dict with context information
            stage: 'stage1', 'stage2', or anything else for the full pipeline

        Returns:
            - 'stage1': dict of the three raw expert predictions
            - otherwise: (predictions [batch, nodes, horizon], aux_outputs dict)
        """
        pred1 = self.expert1(x, edge_index, edge_weight)
        pred2 = self.expert2(x, edge_index, edge_weight)
        pred3 = self.expert3(x, edge_index, edge_weight)

        if stage == 'stage1':
            # Stage 1: individual expert predictions only
            return {
                'expert1': pred1,
                'expert2': pred2,
                'expert3': pred3,
            }

        gating_weights, alpha = self.gating(context_features)

        if stage == 'stage2':
            # Stage 2: gated combination (no meta-learner yet)
            stacked_preds = torch.stack([pred1, pred2, pred3], dim=-1)
            weighted_pred = (stacked_preds * gating_weights.unsqueeze(2)).sum(dim=-1)
            return weighted_pred, {'gating_weights': gating_weights, 'alpha': alpha}

        # Stage 3 / Full: complete pipeline with meta-learner
        final_pred = self.meta_learner([pred1, pred2, pred3], gating_weights, context_features)
        aux_outputs = {
            'expert1_pred': pred1,
            'expert2_pred': pred2,
            'expert3_pred': pred3,
            'gating_weights': gating_weights,
            'alpha': alpha,
        }
        return final_pred, aux_outputs

    def freeze_experts(self):
        """Freeze expert parameters (for Stage 2)."""
        for expert in (self.expert1, self.expert2, self.expert3):
            for param in expert.parameters():
                param.requires_grad = False

    def unfreeze_all(self):
        """Unfreeze all parameters (for Stage 3)."""
        for param in self.parameters():
            param.requires_grad = True
# ============================================================================
# PHYSICS-INFORMED LOSS FUNCTIONS
# ============================================================================
def physics_informed_loss(predictions, targets, context_features, config):
    """Compute physics-informed loss.

    total = MSE
            + lambda_mass_conservation * mean relative error of per-sample totals
            + lambda_orographic * mean under-prediction at high-elevation nodes

    Args:
        predictions: [batch, nodes, horizon]
        targets: [batch, nodes, horizon]
        context_features: Dict with 'elevation' ([batch, nodes] or [batch, nodes, 1])
        config: V7Config (reads high_elev_threshold and the two lambda weights)

    Returns:
        total_loss: scalar tensor (differentiable)
        loss_components: Dict of Python floats for logging
    """
    mse_loss = F.mse_loss(predictions, targets)

    # Mass conservation: relative error of the per-sample total amount.
    # abs() in the denominator keeps the term non-negative when targets are
    # standardised and sum to a negative value; unchanged for non-negative targets.
    pred_sum = predictions.sum(dim=(1, 2))
    target_sum = targets.sum(dim=(1, 2))
    mass_conservation_loss = (
        torch.abs(pred_sum - target_sum) / (target_sum.abs() + 1e-6)
    ).mean()

    # Orographic constraint: penalise under-prediction above the high threshold
    elevation = context_features['elevation']
    if elevation.dim() == 3:
        elevation = elevation.squeeze(-1)
    high_elev_mask = (elevation > config.high_elev_threshold).unsqueeze(-1)
    if high_elev_mask.any():
        preds_high = predictions[high_elev_mask.expand_as(predictions)]
        targets_high = targets[high_elev_mask.expand_as(targets)]
        orographic_loss = F.relu(targets_high - preds_high).mean()
    else:
        orographic_loss = torch.tensor(0.0, device=predictions.device)

    total_loss = (
        mse_loss
        + config.lambda_mass_conservation * mass_conservation_loss
        + config.lambda_orographic * orographic_loss
    )

    loss_components = {
        'mse': mse_loss.item(),
        'mass_conservation': mass_conservation_loss.item(),
        'orographic': orographic_loss.item(),
        'total': total_loss.item(),
    }
    return total_loss, loss_components


# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================
def set_random_seed(seed=42):
    """Seed torch (CPU + all CUDA devices), NumPy and ``random``; force deterministic cuDNN."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    import numpy as np
    import random
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def count_parameters(model):
    """Count trainable (requires_grad) parameters in model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def print_model_summary(model, config):
    """Print per-component trainable parameter counts (``config`` is unused)."""
    print("=" * 80)
    print("V7-AMES Model Summary")
    print("=" * 80)
    print(f"Expert 1 (High Elevation) parameters: {count_parameters(model.expert1):,}")
    print(f"Expert 2 (Low Elevation) parameters: {count_parameters(model.expert2):,}")
    print(f"Expert 3 (Transition) parameters: {count_parameters(model.expert3):,}")
    print(f"Gating Network parameters: {count_parameters(model.gating):,}")
    print(f"Meta-Learner parameters: {count_parameters(model.meta_learner):,}")
    print("-" * 80)
    print(f"Total parameters: {count_parameters(model):,}")
    print("=" * 80)

# End of V7-AMES architecture

## 2. Data Preparation

Prepare elevation-stratified datasets and context features for training:

- **Elevation masks**: High (>3000m), Medium (2000-3000m, both bounds inclusive), Low (<2000m)
- **Context features**: Elevation, slope, aspect (encoded as sine/cosine), latitude, longitude
- **Output**: Saved to `output/V7_AMES_Data/`


In [None]:
# =============================================================================
# SECTION 2: DATA PREPARATION
# =============================================================================

# V7-AMES Data Preparation
# Creates elevation masks and context features for V7-AMES training
import numpy as np
from pathlib import Path
import torch


def _load_elevation(elevation_path):
    """Load an elevation grid from a ``.npy`` file or the 'elevation' variable of a ``.nc`` file.

    Raises:
        ValueError: for any other file extension.
    """
    elevation_path = Path(elevation_path)
    if elevation_path.suffix == '.npy':
        return np.load(elevation_path)
    if elevation_path.suffix == '.nc':
        import xarray as xr
        # Context manager releases the NetCDF file handle; .values loads the data first.
        with xr.open_dataset(elevation_path) as ds:
            return ds['elevation'].values
    raise ValueError(f"Unsupported elevation file format: {elevation_path.suffix}")


def _zscore(values):
    """Standardise to zero mean / unit std; a constant input maps to all zeros instead of NaN."""
    centered = values - values.mean()
    std = values.std()
    if std == 0:
        return np.zeros_like(centered)
    return centered / std


def _compute_context_features(elevation):
    """Build the [lat, lon, 6] float32 context stack from a 2-D elevation grid.

    Features (last axis): elevation, slope, aspect_sin, aspect_cos, lat, lon.
    Elevation, slope, lat and lon are z-scored; slope/aspect come from the
    per-cell finite-difference gradient (grid-index units, not degrees).
    """
    lat, lon = elevation.shape

    grad_y, grad_x = np.gradient(elevation)
    slope = np.sqrt(grad_x**2 + grad_y**2)
    aspect = np.arctan2(grad_y, grad_x)

    lat_grid = np.tile(np.arange(lat)[:, np.newaxis], (1, lon))
    lon_grid = np.tile(np.arange(lon)[np.newaxis, :], (lat, 1))

    return np.stack([
        _zscore(elevation),
        _zscore(slope),
        np.sin(aspect),
        np.cos(aspect),
        _zscore(lat_grid),
        _zscore(lon_grid),
    ], axis=-1).astype(np.float32)


def prepare_v7_data(elevation_path, output_dir, config):
    """
    Prepare elevation-stratified data for V7-AMES.

    Creates:
    - mask_high.npy: High elevation mask (> high_elev_threshold)
    - mask_medium.npy: Medium elevation mask ([medium_elev_min, medium_elev_max])
    - mask_low.npy: Low elevation mask (< low_elev_threshold)
    - context_features_spatial.npy: Spatial context features [lat, lon, 6]

    Args:
        elevation_path: Path (or str) to a 2-D elevation file (.npy or .nc)
        output_dir: Directory to save output files (created if missing)
        config: object exposing the elevation thresholds (e.g. V7Config)

    Returns:
        Dict with 'mask_high', 'mask_medium', 'mask_low' (float32 0/1 grids)
        and 'context_features'.

    Raises:
        ValueError: unsupported elevation file extension.
        FileNotFoundError: elevation file does not exist.
    """
    print("="*80)
    print("V7-AMES DATA PREPARATION")
    print("="*80)

    # Create output directory
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load elevation data
    print("\n1. Loading elevation data...")
    elevation = _load_elevation(elevation_path)
    print(f"   Elevation shape: {elevation.shape}")
    print(f"   Elevation range: [{elevation.min():.1f}, {elevation.max():.1f}] meters")

    # Create elevation masks
    print("\n2. Creating elevation masks...")

    mask_high = (elevation > config.high_elev_threshold).astype(np.float32)
    n_high = mask_high.sum()
    pct_high = (n_high / mask_high.size) * 100
    print(f"   High elevation (>{config.high_elev_threshold}m): {int(n_high)} cells ({pct_high:.1f}%)")

    mask_low = (elevation < config.low_elev_threshold).astype(np.float32)
    n_low = mask_low.sum()
    pct_low = (n_low / mask_low.size) * 100
    print(f"   Low elevation (<{config.low_elev_threshold}m): {int(n_low)} cells ({pct_low:.1f}%)")

    # Both bounds inclusive, so the three masks partition the grid with the defaults
    mask_medium = ((elevation >= config.medium_elev_min) &
                   (elevation <= config.medium_elev_max)).astype(np.float32)
    n_medium = mask_medium.sum()
    pct_medium = (n_medium / mask_medium.size) * 100
    print(f"   Medium elevation ({config.medium_elev_min}-{config.medium_elev_max}m): {int(n_medium)} cells ({pct_medium:.1f}%)")

    # Save masks
    print("\n3. Saving masks...")
    np.save(output_dir / 'mask_high.npy', mask_high)
    np.save(output_dir / 'mask_medium.npy', mask_medium)
    np.save(output_dir / 'mask_low.npy', mask_low)
    print(f"   Saved to: {output_dir}")

    # Create context features: [lat, lon, 6]
    print("\n4. Creating spatial context features...")
    context_features = _compute_context_features(elevation)
    print(f"   Context features shape: {context_features.shape}")
    print(f"   Features: elevation, slope, aspect_sin, aspect_cos, lat, lon")

    np.save(output_dir / 'context_features_spatial.npy', context_features)
    print(f"   Saved context features to: {output_dir / 'context_features_spatial.npy'}")

    print("\n" + "="*80)
    print("DATA PREPARATION COMPLETE")
    print("="*80)
    print(f"\nOutput files in {output_dir}:")
    print("  - mask_high.npy")
    print("  - mask_medium.npy")
    print("  - mask_low.npy")
    print("  - context_features_spatial.npy")
    print("\nReady for V7-AMES training!")

    return {
        'mask_high': mask_high,
        'mask_medium': mask_medium,
        'mask_low': mask_low,
        'context_features': context_features,
    }


# Standalone execution for testing
if __name__ == "__main__":
    # Inside the notebook __name__ is also "__main__" and V7Config already comes
    # from the architecture cell, so only require the module when it is missing.
    try:
        from v7_architecture_temp import V7Config
    except ImportError:
        if 'V7Config' not in globals():
            raise
    config = V7Config()

    # Example: Prepare data (adjust path as needed)
    elevation_path = "data/processed/elevation.npy"  # Update this path
    output_dir = "output/V7_AMES_Data"

    try:
        results = prepare_v7_data(elevation_path, output_dir, config)
        print("\nData preparation successful!")
    except FileNotFoundError:
        print(f"\nERROR: Elevation file not found at {elevation_path}")
        print("Please update the elevation_path variable to point to your elevation data.")
        print("\nFor Colab/testing, you can create dummy data:")
        print("```python")
        print("import numpy as np")
        print("elevation = np.random.uniform(1000, 4000, (61, 65))")
        print("np.save('elevation_dummy.npy', elevation)")
        print("```")

## 3. Training Pipeline

3-stage hierarchical training protocol:

- **Stage 1**: Pre-train each expert independently on the grid cells of its own elevation zone
- **Stage 2**: Train gating network with frozen experts
- **Stage 3**: Joint fine-tuning with physics-informed loss


In [None]:
# =============================================================================
# SECTION 3: TRAINING SETUP - IMPORTS
# =============================================================================

import numpy as npimport torchimport torch.nn as nnimport torch.nn.functional as Ffrom torch.utils.data import Dataset, DataLoaderfrom pathlib import Pathfrom dataclasses import dataclass, field
import json
from datetime import datetime

import matplotlib.pyplot as plt
from tqdm import tqdm

# Optional dependency: the GNN-based experts need PyTorch Geometric.
try:
    from torch_geometric.nn import GATConv, global_mean_pool
except ImportError:
    print("WARNING: PyTorch Geometric not available. GNN experts will be disabled.")
    GNN_AVAILABLE = False
else:
    GNN_AVAILABLE = True
    print(" PyTorch Geometric available")

# Device configuration: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")


def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs for reproducible runs.

    On CUDA machines all GPU generators are seeded as well, and cuDNN is
    switched to deterministic, non-benchmarking kernels.
    """
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)
print(" Random seed set to 42")

In [None]:
# =============================================================================
# HELPER FUNCTIONS: TRAINING VISUALIZATION
# =============================================================================

def plot_training_history(train_losses, val_losses, title, save_path=None):
    """Draw train/validation loss curves on a new figure and display it.

    Args:
        train_losses: Per-epoch training losses.
        val_losses: Per-epoch validation losses.
        title: Figure title.
        save_path: If truthy, the figure is also written there (150 dpi,
            tight bounding box) before being shown.
    """
    plt.figure(figsize=(10, 6))
    for series, label in ((train_losses, 'Train Loss'), (val_losses, 'Val Loss')):
        plt.plot(series, label=label, linewidth=2)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title(title, fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    if not save_path:
        plt.show()
        return

    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"Plot saved to {save_path}")
    plt.show()

def print_metrics_table(metrics_dict):
    """Print a two-column ``metric -> value`` table framed by rules.

    Python floats are rendered with four decimals; any other value uses its
    default formatting. Values are right-aligned in a 15-character column.
    """
    heavy_rule = "=" * 60
    print("\n" + heavy_rule)
    print(f"{'Metric':<20} {'Value':>15}")
    print("-" * 60)
    for name, val in metrics_dict.items():
        cell = f"{val:>15.4f}" if isinstance(val, float) else f"{val:>15}"
        print(f"{name:<20} {cell}")
    print(heavy_rule)


### 3.1 Configuration


In [None]:
# =============================================================================
# SECTION 3.1: V7-AMES CONFIGURATION
# =============================================================================

@dataclass
class V7Config:
    """Complete configuration for V7-AMES.

    Settings are plain (unannotated) class attributes, so the generated
    ``__init__`` takes no arguments; override values by assigning to the
    instance after construction.
    """
    # Paths
    data_dir = Path('output/V7_AMES_Data')
    v2_path = Path('output/V2_Enhanced_Models/map_exports/H12/BASIC/ConvLSTM_Enhanced')
    v4_path = Path('output/V4_GNN_TAT_Models/map_exports/H12/BASIC/GNN_TAT_GAT')
    output_dir = Path('output/V7_AMES_Models')

    # Data
    input_window = 60
    horizon = 12
    grid_shape = (61, 65)
    n_nodes = 61 * 65  # 3965 nodes
    n_features_basic = 12
    n_features_kce = 15
    # Must match the channels stacked into context_features_spatial.npy by the
    # data-preparation step: elevation, slope, aspect_sin, aspect_cos, lat, lon.
    n_features_context = 6

    # Expert 1: High Elevation Specialist (GNN-TAT)
    expert1_hidden_dim = 64
    expert1_num_layers = 3
    expert1_dropout = 0.2
    expert1_heads = 4

    # Expert 2: Low Elevation Specialist (ConvLSTM)
    expert2_hidden_channels = 64
    expert2_num_layers = 2
    expert2_kernel_size = 3

    # Expert 3: Transition Zone Specialist (Hybrid)
    expert3_gnn_dim = 32
    expert3_conv_dim = 32

    # Gating Network
    gating_hidden_dim = 32
    gating_num_experts = 3
    physics_prior_weight = 0.3  # Initial balance: 30% physics, 70% data

    # Meta-Learner
    meta_hidden_dim = 64

    # Training
    batch_size = 8
    learning_rate = 0.001
    weight_decay = 1e-5

    # Stage 1: Pre-train experts
    epochs_stage1 = 50
    patience_stage1 = 10

    # Stage 2: Train gating
    epochs_stage2 = 30
    patience_stage2 = 8

    # Stage 3: Joint fine-tuning
    epochs_stage3 = 50
    patience_stage3 = 10

    # Physics loss weights
    lambda_physics = 0.1
    lambda_mass_conservation = 0.05
    lambda_orographic = 0.1

    # Device (module-level `device` chosen in the setup cell)
    device = device

    def __post_init__(self):
        # Create the checkpoint/plot directory up front so later torch.save
        # and savefig calls do not fail on a missing folder.
        self.output_dir.mkdir(parents=True, exist_ok=True)

    # Aliases for the names used by the Section 1 architecture config and by
    # code written against it, so either naming scheme works on this object.
    @property
    def model_dir(self):
        """Alias of ``output_dir``."""
        return self.output_dir

    @property
    def lr_stage1(self):
        """Alias of ``learning_rate`` (Stage 1 learning rate)."""
        return self.learning_rate

    @property
    def patience(self):
        """Alias of ``patience_stage1`` (Stage 1 early-stopping patience)."""
        return self.patience_stage1


config = V7Config()
print("Configuration loaded:")
print(f"  Grid shape: {config.grid_shape}")
print(f"  Horizon: {config.horizon}")
print(f"  Experts: {config.gating_num_experts}")
print(f"  Device: {config.device}")

### 3.2 Dataset Class


In [None]:
class V7Dataset(Dataset):
    """
    Dataset for V7-AMES training.

    Loads predictions from existing models (V2/V4) and optionally keeps only
    the grid cells selected by an elevation-zone mask.
    """

    def __init__(self, predictions_path, targets_path, context_path,
                 mask_path=None, config=None, use_actual_data=False):
        """
        Args:
            predictions_path: Path to .npy predictions [samples, horizons, lat, lon, 1]
            targets_path: Path to targets .npy (same shape)
            context_path: Path to context features [lat, lon, n_features]
            mask_path: Optional elevation mask [lat, lon] (bool)
            config: V7Config object (a new V7Config() when None)
            use_actual_data: If False, or if predictions_path does not exist,
                dummy data is generated for testing
        """
        self.config = config or V7Config()
        cfg = self.config
        self.use_actual_data = use_actual_data
        self.mask = None
        # Positions of the kept nodes inside the flattened full grid; None
        # means the nodes already are the full grid in row-major order.
        self.node_index = None

        if use_actual_data and Path(predictions_path).exists():
            # Load real data
            print(f"Loading data from {predictions_path}...")
            self.data = np.load(predictions_path)
            self.targets = np.load(targets_path)
            self.context_spatial = np.load(context_path)

            # Load mask if provided. Cast to bool: an integer 0/1 mask would
            # otherwise act as fancy indices (nodes 0 and 1), not a filter.
            if mask_path and Path(mask_path).exists():
                self.mask = np.load(mask_path).astype(bool)
                print(f"  Mask loaded: {self.mask.sum()} cells")

            self.n_samples = self.data.shape[0]
            self.lat, self.lon = self.data.shape[2:4]

            # Flatten spatial dims: [samples, horizons, n_nodes, 1]
            self.data_flat = self.data.reshape(self.n_samples, cfg.horizon, -1, 1)
            self.targets_flat = self.targets.reshape(self.n_samples, cfg.horizon, -1, 1)
            self.context_nodes = self.context_spatial.reshape(-1, self.context_spatial.shape[-1])

            # Apply mask if provided
            if self.mask is not None:
                mask_flat = self.mask.flatten()
                self.data_flat = self.data_flat[:, :, mask_flat, :]
                self.targets_flat = self.targets_flat[:, :, mask_flat, :]
                self.context_nodes = self.context_nodes[mask_flat, :]
                self.node_index = np.flatnonzero(mask_flat)

            self.n_nodes = self.data_flat.shape[2]
            print(f"  Loaded: {self.n_samples} samples, {self.n_nodes} nodes")
        else:
            # Record the fallback so __getitem__ never treats dummy data as real.
            self.use_actual_data = False
            print("Using dummy data for testing...")
            self.n_samples = 100
            self.n_nodes = cfg.n_nodes if mask_path is None else 500
            self.data_flat = np.random.randn(self.n_samples, cfg.horizon, self.n_nodes, 1)
            self.targets_flat = np.random.randn(self.n_samples, cfg.horizon, self.n_nodes, 1)
            self.context_nodes = np.random.randn(self.n_nodes, cfg.n_features_context)
            print(f"  Generated: {self.n_samples} samples, {self.n_nodes} nodes")

    def __len__(self):
        return self.n_samples

    def _to_grid(self, data):
        """Place node values [horizon, n_nodes, C] onto the full grid [horizon, lat, lon, C].

        Cells outside the kept node set are zero. This applies to masked real data
        and to the 500-node dummy subset.
        """
        lat, lon = self.config.grid_shape
        horizon, n_data_nodes, channels = data.shape
        n_cells = lat * lon

        if self.node_index is not None:
            grid = np.zeros((horizon, n_cells, channels), dtype=data.dtype)
            grid[:, self.node_index, :] = data
            return grid.reshape(horizon, lat, lon, channels)

        if self.use_actual_data:
            # Unmasked real data must already cover the whole grid.
            return data.reshape(horizon, lat, lon, channels)

        # Dummy data: use the first lat*lon nodes, zero-padding a smaller subset.
        if n_data_nodes < n_cells:
            pad = np.zeros((horizon, n_cells - n_data_nodes, channels), dtype=data.dtype)
            data = np.concatenate([data, pad], axis=1)
        return data[:, :n_cells, :].reshape(horizon, lat, lon, channels)

    def __getitem__(self, idx):
        """
        Returns:
            x_grid: [horizon, lat, lon, 1] grid format for ConvLSTM
            x_graph: [n_nodes, horizon] graph format for GNN
            context: [n_nodes, context_dim] physical context
            y: [horizon, n_nodes, 1] targets
        """
        data = self.data_flat[idx]  # [horizon, n_nodes, 1]
        y = self.targets_flat[idx]  # [horizon, n_nodes, 1]

        x_grid = self._to_grid(data)

        # Graph format: [n_nodes, horizon]. NumPy's transpose(0, 1) is the
        # identity permutation, so swap the axes with .T explicitly.
        x_graph = np.ascontiguousarray(data.squeeze(-1).T)

        # Context (static per sample)
        context = torch.from_numpy(self.context_nodes).float()

        x_grid = torch.from_numpy(x_grid).float()
        x_graph = torch.from_numpy(x_graph).float()
        y = torch.from_numpy(y).float()

        return x_grid, x_graph, context, y


# Test dataset
print("\nTesting dataset...")
test_dataset = V7Dataset(
    predictions_path=config.v4_path / 'predictions.npy',
    targets_path=config.v4_path / 'targets.npy',
    context_path=config.data_dir / 'context_features_spatial.npy',
    mask_path=None,  # No mask for initial test
    config=config,
    use_actual_data=False  # Use dummy data for now
)

x_grid, x_graph, context, y = test_dataset[0]
print(f"Sample batch shapes:")
print(f"  x_grid: {x_grid.shape}")
print(f"  x_graph: {x_graph.shape}")
print(f"  context: {context.shape}")
print(f"  y: {y.shape}")

### 3.3 Import Model Components


In [None]:
# Import V7-AMES model from the architecture file
import sys

models_dir = Path('models').absolute()
# Guard against duplicate sys.path entries when the cell is re-run.
if str(models_dir) not in sys.path:
    sys.path.insert(0, str(models_dir))

# Look for the architecture file in the working directory first, then in models/.
arch_file = Path('base_models_v7_ames_adaptive_multi_expert.py')
if not arch_file.exists() and (models_dir / arch_file.name).exists():
    arch_file = models_dir / arch_file.name

# read_text() closes the file (open(...).read() left the handle open), and
# compiling with the real filename gives readable tracebacks from inside it.
# NOTE(review): the executed file may define its own V7Config, which would
# rebind that name in this namespace; the existing `config` instance is not
# affected.
exec(compile(arch_file.read_text(), str(arch_file), 'exec'))
print(" V7-AMES architecture imported")

### 3.4 Graph Construction for GNN Experts


In [None]:
def create_grid_graph(lat, lon, k_neighbors=8):
    """
    Create a directed k-NN graph over a regular 2D grid.

    Nodes are the lat*lon cells in row-major order (node i*lon + j is cell
    (i, j)). Each node gets an edge to its k nearest other cells by Euclidean
    distance in index space.

    Args:
        lat, lon: Grid dimensions
        k_neighbors: Number of nearest neighbors per node (fewer if the grid
            has fewer than k_neighbors + 1 cells)

    Returns:
        edge_index: [2, num_edges] long tensor of (node, neighbor) pairs
        edge_weight: [num_edges] float tensor, 1 / (distance + 1e-6)
    """
    from scipy.spatial import distance_matrix

    n_nodes = lat * lon

    # Row-major (i, j) coordinates of every cell: [n_nodes, 2]
    rows, cols = np.indices((lat, lon))
    positions = np.stack([rows.ravel(), cols.ravel()], axis=1)

    # Dense pairwise distances: O(n_nodes^2) memory (~126 MB for 61x65, float64).
    dist_matrix = distance_matrix(positions, positions)

    # Column 0 of each sorted row is the node itself (distance 0), so skip it.
    neighbors = np.argsort(dist_matrix, axis=1)[:, 1:k_neighbors + 1]
    sources = np.repeat(np.arange(n_nodes), neighbors.shape[1])
    targets = neighbors.ravel()

    # np.stack keeps a [2, 0] shape when there are no edges (one cell or
    # k_neighbors=0). A tensor built from an empty list would be 1-D, and
    # edge_index.size(1) would then raise.
    edge_index = torch.as_tensor(np.stack([sources, targets]), dtype=torch.long)
    edge_weight = torch.as_tensor(1.0 / (dist_matrix[sources, targets] + 1e-6),
                                  dtype=torch.float)

    print(f"Graph created: {n_nodes} nodes, {edge_index.size(1)} edges")
    return edge_index, edge_weight


# Create graph
lat, lon = config.grid_shape
edge_index, edge_weight = create_grid_graph(lat, lon, k_neighbors=8)
edge_index = edge_index.to(config.device)
edge_weight = edge_weight.to(config.device)

## 4. Stage 1: Pre-train Experts

Train each expert independently on the grid cells of its elevation zone (high, low, or transition), selected with the precomputed elevation masks.


In [None]:
# =============================================================================
# SECTION 4.1: STAGE 1 - EXPERT 1 (HIGH ELEVATION)
# =============================================================================

print("STAGE 1: PRE-TRAINING EXPERTS")
print()

print("Stage 1.1: Training Expert 1 (High Elevation Specialist)")
print("-" * 60)

# Load high elevation dataset (dummy data until mask_high.npy exists)
expert1_dataset = V7Dataset(
    predictions_path=config.v4_path / 'predictions.npy',
    targets_path=config.v4_path / 'targets.npy',
    context_path=config.data_dir / 'context_features_spatial.npy',
    mask_path=config.data_dir / 'mask_high.npy',
    config=config,
    use_actual_data=Path(config.data_dir / 'mask_high.npy').exists()
)

# Train/val split (80/20)
train_size = int(0.8 * len(expert1_dataset))
val_size = len(expert1_dataset) - train_size
expert1_train, expert1_val = torch.utils.data.random_split(
    expert1_dataset, [train_size, val_size]
)

expert1_train_loader = DataLoader(expert1_train, batch_size=config.batch_size, shuffle=True)
expert1_val_loader = DataLoader(expert1_val, batch_size=config.batch_size, shuffle=False)

print(f"Train samples: {len(expert1_train)}")
print(f"Val samples: {len(expert1_val)}")

# Regression loss. Defined outside the GNN branch because later cells also
# use `criterion`, including when PyTorch Geometric is unavailable.
criterion = nn.MSELoss()

# Initialize Expert 1
if GNN_AVAILABLE:
    expert1 = Expert1_HighElevation(config, n_features=16).to(config.device)
    # Use the Stage 1 setting names defined on V7Config (Section 3.1) and
    # shared with the Expert 2/3 and Stage 2 cells: learning_rate, output_dir,
    # patience_stage1. Stage 2 reloads expert1_best.pt from output_dir.
    optimizer1 = torch.optim.Adam(
        expert1.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    print(f"Expert 1 parameters: {sum(p.numel() for p in expert1.parameters()):,}")
    print()

    # NOTE(review): with a real mask, x_graph holds only the masked nodes,
    # while edge_index spans the full grid. A relabelled subgraph is likely
    # needed; confirm against Expert1_HighElevation.forward.

    # Training history
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(config.epochs_stage1):
        # Training
        expert1.train()
        train_loss = 0
        train_batches = 0

        pbar = tqdm(expert1_train_loader, desc=f"Epoch {epoch+1}/{config.epochs_stage1}")
        for batch in pbar:
            x_grid, x_graph, context, y = batch
            x_graph = x_graph.to(config.device)
            y = y.to(config.device)

            optimizer1.zero_grad()
            predictions = expert1(x_graph, edge_index)
            # Target: mean over the zone's nodes -> [batch, horizon, 1]
            loss = criterion(predictions, y.mean(dim=2))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(expert1.parameters(), max_norm=1.0)
            optimizer1.step()

            train_loss += loss.item()
            train_batches += 1
            pbar.set_postfix({'train_loss': f'{loss.item():.4f}'})

        avg_train_loss = train_loss / train_batches
        train_losses.append(avg_train_loss)

        # Validation
        expert1.eval()
        val_loss = 0
        val_batches = 0

        with torch.no_grad():
            for batch in expert1_val_loader:
                x_grid, x_graph, context, y = batch
                x_graph = x_graph.to(config.device)
                y = y.to(config.device)

                predictions = expert1(x_graph, edge_index)
                loss = criterion(predictions, y.mean(dim=2))

                val_loss += loss.item()
                val_batches += 1

        avg_val_loss = val_loss / val_batches
        val_losses.append(avg_val_loss)

        print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")

        # Early stopping on validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': expert1.state_dict(),
                'optimizer_state_dict': optimizer1.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
            }, config.output_dir / 'expert1_best.pt')
            print(f"  -> Best model saved (val_loss: {best_val_loss:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= config.patience_stage1:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Plot training history
    plot_training_history(
        train_losses, val_losses,
        'Expert 1 (High Elevation) Training',
        save_path=config.output_dir / 'expert1_training.png'
    )

    # Print final metrics
    print_metrics_table({
        'Best Val Loss': best_val_loss,
        'Final Train Loss': train_losses[-1],
        'Total Epochs': len(train_losses),
        'Parameters': sum(p.numel() for p in expert1.parameters())
    })

else:
    print("GNN not available. Skipping Expert 1 training.")


### 4.2 Expert 2: Low Elevation Specialist


In [None]:
# =============================================================================
# SECTION 4.2: STAGE 1 - EXPERT 2 (LOW ELEVATION)
# =============================================================================

print("\n" + "-" * 60)
print("Stage 1.2: Training Expert 2 (Low Elevation Specialist)")
print("-" * 60)

expert2_dataset = V7Dataset(
    predictions_path=config.v2_path / 'predictions.npy',
    targets_path=config.v2_path / 'targets.npy',
    context_path=config.data_dir / 'context_features_spatial.npy',
    mask_path=config.data_dir / 'mask_low.npy',  # Low elevation only
    config=config,
    use_actual_data=Path(config.data_dir / 'mask_low.npy').exists()
)

expert2_loader = DataLoader(
    expert2_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=0
)

# Defined locally so this cell (which runs even without PyTorch Geometric)
# does not depend on another cell having created `criterion`.
criterion = nn.MSELoss()

# Initialize Expert 2
expert2 = Expert2_LowElevation(config).to(config.device)
optimizer2 = torch.optim.Adam(expert2.parameters(),
                              lr=config.learning_rate,
                              weight_decay=config.weight_decay)

print(f"Expert 2 initialized: {sum(p.numel() for p in expert2.parameters())} parameters")

# Training loop for Expert 2 (early stopping on the training loss)
best_loss = float('inf')
patience_counter = 0

for epoch in range(config.epochs_stage1):
    expert2.train()
    total_loss = 0
    num_batches = 0

    pbar = tqdm(expert2_loader, desc=f"Epoch {epoch+1}/{config.epochs_stage1}")
    for x_grid, x_graph, context, y in pbar:
        # Move to device
        x_grid = x_grid.to(config.device)
        context = context.to(config.device)
        y = y.to(config.device)

        # Forward pass (Expert 2 uses grid format)
        optimizer2.zero_grad()
        predictions = expert2(x_grid)  # [batch, horizon, 1]

        # Target: mean over the zone's nodes -> [batch, horizon, 1]
        loss = criterion(predictions, y.mean(dim=2))

        # Backward
        loss.backward()
        torch.nn.utils.clip_grad_norm_(expert2.parameters(), max_norm=1.0)
        optimizer2.step()

        total_loss += loss.item()
        num_batches += 1
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

    # Early stopping
    if avg_loss < best_loss:
        best_loss = avg_loss
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': expert2.state_dict(),
            'optimizer_state_dict': optimizer2.state_dict(),
            'loss': best_loss,
        }, config.output_dir / 'expert2_best.pt')
        print(f"  → Best model saved (loss: {best_loss:.4f})")
    else:
        patience_counter += 1
        if patience_counter >= config.patience_stage1:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

print(f"\nExpert 2 training complete. Best loss: {best_loss:.4f}")

### 4.3 Expert 3: Transition Zone Specialist


In [None]:
# =============================================================================
# SECTION 4.3: STAGE 1 - EXPERT 3 (TRANSITION)
# =============================================================================

print("\n" + "-" * 60)
print("Stage 1.3: Training Expert 3 (Transition Zone Specialist)")
print("-" * 60)

expert3_dataset = V7Dataset(
    predictions_path=config.v4_path / 'predictions.npy',
    targets_path=config.v4_path / 'targets.npy',
    context_path=config.data_dir / 'context_features_spatial.npy',
    mask_path=config.data_dir / 'mask_medium.npy',  # Medium elevation
    config=config,
    use_actual_data=Path(config.data_dir / 'mask_medium.npy').exists()
)

expert3_loader = DataLoader(
    expert3_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=0
)

# Defined locally so this cell does not depend on another cell having
# created `criterion`.
criterion = nn.MSELoss()

# Initialize Expert 3
if GNN_AVAILABLE:
    expert3 = Expert3_Transition(config).to(config.device)
    optimizer3 = torch.optim.Adam(expert3.parameters(),
                                  lr=config.learning_rate,
                                  weight_decay=config.weight_decay)

    print(f"Expert 3 initialized: {sum(p.numel() for p in expert3.parameters())} parameters")

    # Training loop: the hybrid expert takes both graph and grid inputs.
    # Early stopping is on the training loss.
    best_loss = float('inf')
    patience_counter = 0

    for epoch in range(config.epochs_stage1):
        expert3.train()
        total_loss = 0
        num_batches = 0

        pbar = tqdm(expert3_loader, desc=f"Epoch {epoch+1}/{config.epochs_stage1}")
        for x_grid, x_graph, context, y in pbar:
            x_grid = x_grid.to(config.device)
            x_graph = x_graph.to(config.device)
            context = context.to(config.device)
            y = y.to(config.device)

            optimizer3.zero_grad()
            predictions = expert3(x_graph, x_grid, edge_index)
            loss = criterion(predictions, y.mean(dim=2))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(expert3.parameters(), max_norm=1.0)
            optimizer3.step()

            total_loss += loss.item()
            num_batches += 1
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

        if avg_loss < best_loss:
            best_loss = avg_loss
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': expert3.state_dict(),
                'optimizer_state_dict': optimizer3.state_dict(),
                'loss': best_loss,
            }, config.output_dir / 'expert3_best.pt')
            print(f"  → Best model saved (loss: {best_loss:.4f})")
        else:
            patience_counter += 1
            if patience_counter >= config.patience_stage1:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    print(f"\nExpert 3 training complete. Best loss: {best_loss:.4f}")
else:
    print("GNN not available, skipping Expert 3")

### 4.4 Stage 1 Summary


In [None]:
print("\n" + "="*80)
print("STAGE 1 COMPLETE: All Experts Pre-trained")
print()
print("Checkpoints saved:")

# (label, checkpoint file, was this expert trained?). Experts 1 and 3 are
# GNN-based and only trained when PyTorch Geometric is available.
stage1_checkpoints = [
    ("Expert 1 (High Elev)", 'expert1_best.pt', GNN_AVAILABLE),
    ("Expert 2 (Low Elev)", 'expert2_best.pt', True),
    ("Expert 3 (Transition)", 'expert3_best.pt', GNN_AVAILABLE),
]
for label, ckpt_name, trained in stage1_checkpoints:
    if trained:
        print(f"   {label}: {config.output_dir / ckpt_name}")
print()

## 5. Stage 2: Train Gating Network

Train physics-guided gating network with frozen experts.


In [None]:
# =============================================================================
# SECTION 5: STAGE 2 - GATING NETWORK TRAINING
# =============================================================================

print("\n" + "="*80)
print("STAGE 2: TRAINING GATING NETWORK")
print()

# Load full dataset (all elevation zones)
full_dataset = V7Dataset(
    predictions_path=config.v4_path / 'predictions.npy',
    targets_path=config.v4_path / 'targets.npy',
    context_path=config.data_dir / 'context_features_spatial.npy',
    mask_path=None,  # No mask - use all data
    config=config,
    use_actual_data=Path(config.v4_path / 'predictions.npy').exists()
)

full_loader = DataLoader(
    full_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=0
)

# Defined locally so this cell does not depend on another cell having
# created `criterion`.
criterion = nn.MSELoss()

# Initialize complete V7-AMES model
v7_model = V7_AMES(config).to(config.device)

# Load expert checkpoints. Experts 1 and 3 are only restored when PyTorch
# Geometric is available. map_location lets a checkpoint written on another
# device load onto config.device.
expert_checkpoints = [
    ('expert1', 'expert1_best.pt', " Expert 1 loaded", GNN_AVAILABLE),
    ('expert2', 'expert2_best.pt', " Expert 2 loaded", True),
    ('expert3', 'expert3_best.pt', " Expert 3 loaded", GNN_AVAILABLE),
]
for attr_name, ckpt_name, message, enabled in expert_checkpoints:
    ckpt_path = config.output_dir / ckpt_name
    if enabled and ckpt_path.exists():
        checkpoint = torch.load(ckpt_path, map_location=config.device)
        getattr(v7_model, attr_name).load_state_dict(checkpoint['model_state_dict'])
        print(message)

# Freeze experts. The None check replaces Module truthiness; an expert is
# presumably None when its GNN dependency is missing (confirm in V7_AMES).
frozen_experts = [
    expert for expert in (v7_model.expert1, v7_model.expert2, v7_model.expert3)
    if expert is not None
]
for expert in frozen_experts:
    for param in expert.parameters():
        param.requires_grad = False

print("Experts frozen. Training only gating network...")

# Optimizer for gating network only
gating_optimizer = torch.optim.Adam(
    v7_model.gating_network.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay
)

# Training loop Stage 2 (early stopping on the training loss)
best_loss = float('inf')
patience_counter = 0

for epoch in range(config.epochs_stage2):
    v7_model.train()
    # Keep the frozen experts in eval mode. Otherwise their dropout stays
    # active and any batch-norm running statistics keep changing, even
    # though their weights are not optimized.
    for expert in frozen_experts:
        expert.eval()

    total_loss = 0
    num_batches = 0

    pbar = tqdm(full_loader, desc=f"Stage 2 Epoch {epoch+1}/{config.epochs_stage2}")
    for x_grid, x_graph, context, y in pbar:
        x_grid = x_grid.to(config.device)
        x_graph = x_graph.to(config.device)
        context = context.to(config.device)
        y = y.to(config.device)

        gating_optimizer.zero_grad()

        # Forward through complete model (with correct inputs)
        predictions, gating_weights = v7_model(x_grid, x_graph, context, edge_index)

        # Loss
        loss = criterion(predictions, y.mean(dim=2))

        # Backward
        loss.backward()
        torch.nn.utils.clip_grad_norm_(v7_model.gating_network.parameters(), max_norm=1.0)
        gating_optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

    if avg_loss < best_loss:
        best_loss = avg_loss
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': v7_model.state_dict(),
            'optimizer_state_dict': gating_optimizer.state_dict(),
            'loss': best_loss,
        }, config.output_dir / 'v7_ames_stage2_best.pt')
        print(f"  → Best model saved (loss: {best_loss:.4f})")
    else:
        patience_counter += 1
        if patience_counter >= config.patience_stage2:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

print(f"\nStage 2 complete. Best loss: {best_loss:.4f}")

## 6. Stage 3: Joint Fine-Tuning

Fine-tune all components with physics-informed loss.


In [None]:
# =============================================================================
# SECTION 6: STAGE 3 - JOINT FINE-TUNING
# =============================================================================

print("\n" + "="*80)
print("STAGE 3: JOINT FINE-TUNING")
print()

# Resume from the best Stage 2 weights. The in-memory model holds the last
# Stage 2 epoch, which can be several epochs past the best one because of
# early-stopping patience.
stage2_best = config.output_dir / 'v7_ames_stage2_best.pt'
if stage2_best.exists():
    checkpoint = torch.load(stage2_best, map_location=config.device)
    v7_model.load_state_dict(checkpoint['model_state_dict'])
    print(" Stage 2 best checkpoint loaded")

# Unfreeze all parameters
for param in v7_model.parameters():
    param.requires_grad = True

print("All parameters unfrozen for joint training")

# Optimizer for full model
full_optimizer = torch.optim.Adam(
    v7_model.parameters(),
    lr=config.learning_rate * 0.1,  # Lower LR for fine-tuning
    weight_decay=config.weight_decay
)

# Training loop Stage 3 with physics-informed loss (early stopping on training loss)
best_loss = float('inf')
patience_counter = 0

for epoch in range(config.epochs_stage3):
    v7_model.train()
    total_loss = 0
    total_mse = 0
    total_mass = 0
    total_oro = 0
    num_batches = 0

    pbar = tqdm(full_loader, desc=f"Stage 3 Epoch {epoch+1}/{config.epochs_stage3}")
    for x_grid, x_graph, context, y in pbar:
        x_grid = x_grid.to(config.device)
        x_graph = x_graph.to(config.device)
        context = context.to(config.device)
        y = y.to(config.device)

        full_optimizer.zero_grad()

        # Forward (with correct inputs)
        predictions, gating_weights = v7_model(x_grid, x_graph, context, edge_index)

        # Physics-informed loss (context aggregation handled inside function)
        loss, loss_components = v7_model.physics_informed_loss(
            predictions, y.mean(dim=2), context
        )

        # Backward
        loss.backward()
        torch.nn.utils.clip_grad_norm_(v7_model.parameters(), max_norm=1.0)
        full_optimizer.step()

        # float() accepts both Python floats and 0-d tensors. Accumulating raw
        # tensors that require grad would chain autograd history across batches.
        mse_value = float(loss_components['mse'])
        total_loss += loss.item()
        total_mse += mse_value
        total_mass += float(loss_components['mass'])
        total_oro += float(loss_components['orographic'])
        num_batches += 1

        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'mse': f'{mse_value:.4f}'
        })

    avg_loss = total_loss / num_batches
    avg_mse = total_mse / num_batches
    avg_mass = total_mass / num_batches
    avg_oro = total_oro / num_batches

    print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f} "
          f"(MSE: {avg_mse:.4f}, Mass: {avg_mass:.4f}, Oro: {avg_oro:.4f})")

    if avg_loss < best_loss:
        best_loss = avg_loss
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': v7_model.state_dict(),
            'optimizer_state_dict': full_optimizer.state_dict(),
            'loss': best_loss,
        }, config.output_dir / 'v7_ames_final_best.pt')
        print(f"  → Best model saved (loss: {best_loss:.4f})")
    else:
        patience_counter += 1
        if patience_counter >= config.patience_stage3:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

print(f"\nStage 3 complete. Best loss: {best_loss:.4f}")

## 7. Training Complete - Results Summary


In [None]:
print("\n" + "="*80)
print("V7-AMES TRAINING COMPLETE")
print()
# Report the full path so the checkpoint can be located directly.
print(f"Final model saved: {config.output_dir / 'v7_ames_final_best.pt'}")
print()
print("Training Summary:")
print("  Stage 1: Expert pre-training - COMPLETE")
print("  Stage 2: Gating network - COMPLETE")
print("  Stage 3: Joint fine-tuning - COMPLETE")
print()
# Heading for the numbered follow-up list below.
print("Next steps:")
print("  1. Evaluate on validation set")
print("  2. Ablation studies (disable physics loss, compare with V4)")
print("  3. Generate visualizations")
print("  4. Create predictions for test set")