# V5 GNN-ConvLSTM Stacking - Comprehensive Unit Tests

**Test Suite Version:** 5.0.1  
**Date:** January 18, 2026  
**Last Updated:** 2026-01-18 09:12  
**Status:** All tests passing  

**Changelog:**
- v5.0.1 (2026-01-18): Updated tests for new GNNBranch signature (n_features, n_nodes parameters)
- v5.0.0 (2026-01-15): Initial test suite for V5 dual-branch stacking

---

This notebook contains comprehensive unit tests for the V5 model components:
- **GNNBranch**: Graph neural network with per-layer validation
- **GridGraphFusion**: Cross-attention fusion with numerical stability
- **MetaLearner**: Interpretable branch weighting with comprehensive validation

**Test Coverage:**
1. Valid inputs (should pass)
2. Invalid dimensions (should fail with clear error)
3. NaN/Inf detection (should fail with diagnostics)
4. Device mismatch (should fail with clear error)
5. Integration test (all modules together)

**Expected Runtime:** ~2-3 minutes on CUDA-enabled GPU

**Note:** This test notebook uses mocked data and configuration. For actual training, use the main V5 notebook.

In [1]:
# FORCE CPU MODE (Fix for cuDNN missing error on Windows)
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''  # Hide CUDA devices to avoid cuDNN errors
print("⚠️ Running in CPU-ONLY mode (cuDNN not required)\n")

# Import required libraries
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from typing import Tuple, Optional, Literal
import math
import sys

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Only try to get CUDA info if devices are actually accessible
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"CUDA devices: {torch.cuda.device_count()}")
    print(f"Device: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA devices hidden by CUDA_VISIBLE_DEVICES environment variable")

# Set device (will be 'cpu' since we disabled CUDA above)
device = 'cuda' if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else 'cpu'
print(f"\nUsing device: {device}")

if device == 'cpu':
    print("\n✅ CPU mode active - no cuDNN required")
    print("   Tests will run slower but work on any system")
    print("   For GPU testing, run this notebook in Google Colab")

⚠️ Running in CPU-ONLY mode (cuDNN not required)

PyTorch version: 2.9.1
CUDA available: True
CUDA devices hidden by CUDA_VISIBLE_DEVICES environment variable

Using device: cpu

✅ CPU mode active - no cuDNN required
   Tests will run slower but work on any system
   For GPU testing, run this notebook in Google Colab


## Mock Configuration

Create a minimal V5Config for testing purposes.

In [None]:
class V5Config:
    """Mock V5 configuration for testing.
    
    Note:
        graph_input_dim has been REMOVED. GNNBranch now requires n_features
        as an explicit parameter. Accessing graph_input_dim will raise a
        helpful AttributeError explaining the correct usage.
    """
    def __init__(self):
        # Grid dimensions
        self.n_lat = 61
        self.n_lon = 65
        self.n_nodes = self.n_lat * self.n_lon  # 3965

        # GNN parameters
        # NOTE: n_features must be passed to GNNBranch, NOT from config
        self.gnn_hidden_dim = 128
        self.gnn_num_layers = 3
        self.gnn_dropout = 0.2
        self.use_temporal_attention = True
        self.gnn_type = "GAT"  # Default GNN type
        self.temporal_num_heads = 4

        # Fusion parameters
        self.convlstm_hidden_dim = 128
        self.fusion_hidden_dim = 32
        self.fusion_heads = 4
        self.fusion_dropout = 0.2
        self.use_layer_norm = True  # Required by GridGraphFusion

        # MetaLearner parameters
        self.meta_hidden_dim = 64
        self.meta_dropout = 0.2
        self.weight_floor = 0.05
        self.weight_regularization = 0.01

        # Prediction horizon
        self.horizon = 6

    @property
    def graph_input_dim(self):
        """Raises clear error - use n_features parameter instead.
        
        This property exists to provide a helpful error message when
        code mistakenly tries to access config.graph_input_dim, which
        has been removed in favor of explicit n_features parameter.
        """
        raise AttributeError(
            "V5Config.graph_input_dim has been removed.\n"
            "GNNBranch now requires n_features as explicit parameter:\n"
            "  gnn = GNNBranch(config, n_features=16, n_nodes=config.n_nodes, ...)\n"
            "See test notebook cell 7 for correct usage."
        )


config = V5Config()
print("V5Config created:")
print(f"  Grid: {config.n_lat} x {config.n_lon} = {config.n_nodes} nodes")
print(f"  GNN hidden: {config.gnn_hidden_dim}")
print(f"  Fusion hidden: {config.fusion_hidden_dim}")
print(f"  Horizon: {config.horizon}")
print(f"\nNote: graph_input_dim is NOT available - use n_features parameter instead")

## Mock Graph Structure

Create a simple edge_index for testing GNN operations.

In [3]:
def create_test_edge_index(n_nodes: int, num_edges: int, device: str = 'cuda'):
    """Create random edge index for testing."""
    edge_index = torch.randint(0, n_nodes, (2, num_edges), device=device)
    return edge_index

def create_grid_graph(n_lat: int, n_lon: int, device: str = 'cuda'):
    """Create 4-neighbor grid graph (more realistic for spatial data)."""
    edges = []
    for i in range(n_lat):
        for j in range(n_lon):
            node = i * n_lon + j
            # Right neighbor
            if j < n_lon - 1:
                edges.append([node, node + 1])
                edges.append([node + 1, node])  # Undirected
            # Bottom neighbor
            if i < n_lat - 1:
                edges.append([node, node + n_lon])
                edges.append([node + n_lon, node])  # Undirected

    edge_index = torch.tensor(edges, device=device).T
    return edge_index

# Create test edge index
edge_index = create_grid_graph(config.n_lat, config.n_lon, device=device)
print(f"Created grid graph edge_index: {edge_index.shape}")
print(f"  Nodes: {config.n_nodes}")
print(f"  Edges: {edge_index.shape[1]}")
print(f"  Average degree: {edge_index.shape[1] / config.n_nodes:.2f}")

Created grid graph edge_index: torch.Size([2, 15608])
  Nodes: 3965
  Edges: 15608
  Average degree: 3.94


## Define Fixed Module Classes

These classes are the same as in cells 18, 20, 22 of the main V5 notebook.
For testing purposes, they are defined here directly (not imported).

**Note**: This is a simplified test version. For actual training, use the main notebook.

In [4]:
# GNNBranch - Simplified test version with validation logic
class GNNBranch(nn.Module):
    """
    GNN Branch with per-layer validation (TEST VERSION).
    Simplified from main V5 notebook Cell 18 for unit testing.
    """
    def __init__(self, config, n_features: int, n_nodes: int, gnn_type: str = None, validate: bool = True):
        super().__init__()
        self.config = config
        self.n_nodes = n_nodes
        self.gnn_type = gnn_type or config.gnn_type
        self.validate = validate
        
        # Dimensions
        self.input_dim = n_features
        self.hidden_dim = config.gnn_hidden_dim
        self.num_layers = config.gnn_num_layers
        
        # Build GNN layers
        self.gnn_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        
        for i in range(self.num_layers):
            in_dim = self.input_dim if i == 0 else self.hidden_dim
            out_dim = self.hidden_dim
            
            if gnn_type == 'GCN':
                self.gnn_layers.append(GCNConv(in_dim, out_dim))
            elif gnn_type == 'GAT':
                self.gnn_layers.append(GATConv(in_dim, out_dim, heads=1))
            elif gnn_type == 'SAGE':
                self.gnn_layers.append(SAGEConv(in_dim, out_dim))
            
            self.norm_layers.append(nn.LayerNorm(out_dim))
        
        self.dropout = nn.Dropout(config.gnn_dropout)
        
    def _validate_tensor(self, tensor, name, expected_shape, check_numerical=True):
        """Validate tensor shape and numerical properties."""
        if not self.validate:
            return
        
        # Check NaN/Inf
        if check_numerical:
            if torch.isnan(tensor).any():
                nan_count = torch.isnan(tensor).sum().item()
                raise ValueError(
                    f"[GNN VALIDATION] {name} contains NaN!\n"
                    f"  NaN count: {nan_count}\n"
                    f"  Shape: {tensor.shape}\n"
                    f"  Stats: min={tensor[~torch.isnan(tensor)].min():.4f}, "
                    f"max={tensor[~torch.isnan(tensor)].max():.4f}"
                )
            
            if torch.isinf(tensor).any():
                inf_count = torch.isinf(tensor).sum().item()
                raise ValueError(f"[GNN VALIDATION] {name} contains Inf! Count: {inf_count}")
        
        # Check shape
        if expected_shape is not None:
            if tensor.shape[-1] != expected_shape[-1]:
                raise ValueError(
                    f"[GNN VALIDATION] {name} feature dimension mismatch!\n"
                    f"  Expected: {expected_shape}\n"
                    f"  Got: {tensor.shape}"
                )
    
    def forward(self, x, edge_index, edge_weight=None):
        """
        Args:
            x: (batch_size, num_nodes, input_dim)
            edge_index: (2, num_edges)
        Returns:
            h: (batch_size, num_nodes, hidden_dim)
        """
        batch_size, num_nodes, input_dim = x.shape
        
        # Validate input
        self._validate_tensor(x, "input x", (None, None, self.input_dim))
        
        # Validate edge_index bounds
        if self.validate and edge_index.max() >= num_nodes:
            raise ValueError(
                f"[GNN VALIDATION] edge_index contains out-of-bounds indices!\n"
                f"  Max node index: {num_nodes - 1}\n"
                f"  Max edge_index: {edge_index.max().item()}"
            )
        
        # Process each sample in batch
        outputs = []
        for b in range(batch_size):
            h_b = x[b]  # (num_nodes, input_dim)
            
            # Apply GNN layers
            for layer_idx, (gnn_layer, norm_layer) in enumerate(zip(self.gnn_layers, self.norm_layers)):
                h_in = h_b
                h_out = gnn_layer(h_b, edge_index, edge_weight)
                
                # Validate after message passing
                if self.validate and torch.isnan(h_out).any():
                    raise ValueError(
                        f"[GNN LAYER {layer_idx}] NaN detected after message passing!\n"
                        f"  Batch: {b}, Layer: {layer_idx}\n"
                        f"  Input stats: min={h_in.min():.4f}, max={h_in.max():.4f}"
                    )
                
                h_b = norm_layer(h_out)
                h_b = torch.relu(h_b)
                h_b = self.dropout(h_b)
                
                # Residual connection (if dimensions match)
                if True and h_in.shape == h_b.shape:  # Always use residual
                    h_b = h_b + h_in
            
            outputs.append(h_b)
        
        h = torch.stack(outputs, dim=0)  # (batch_size, num_nodes, hidden_dim)
        
        # Final validation
        self._validate_tensor(h, "final output", (None, None, self.hidden_dim))
        
        return h


# GridGraphFusion - Simplified test version
class GridGraphFusion(nn.Module):
    """
    Grid-Graph cross-attention fusion (TEST VERSION).
    Simplified from main V5 notebook Cell 20 for unit testing.
    """
    def __init__(self, config, n_lat, n_lon, validate=True):
        super().__init__()
        self.config = config
        self.n_lat = n_lat
        self.n_lon = n_lon
        self.n_nodes = n_lat * n_lon
        self.validate = validate
        
        # Dimensions
        self.grid_input_dim = config.convlstm_hidden_dim
        self.graph_input_dim = config.gnn_hidden_dim
        self.fusion_hidden_dim = config.fusion_hidden_dim
        self.num_heads = config.fusion_heads  # FIXED: was fusion_num_heads
        self.head_dim = self.fusion_hidden_dim // self.num_heads
        
        # Temperature scaling for numerical stability
        self.temperature = math.sqrt(self.head_dim)
        
        # Projection layers
        self.grid_proj = nn.Linear(self.grid_input_dim, self.fusion_hidden_dim)
        self.graph_proj = nn.Linear(self.graph_input_dim, self.fusion_hidden_dim)
        
        # Output layers
        self.grid_out = nn.Linear(self.fusion_hidden_dim, self.fusion_hidden_dim)
        self.graph_out = nn.Linear(self.fusion_hidden_dim, self.fusion_hidden_dim)
        
        self.dropout = nn.Dropout(config.fusion_dropout)
        if True:  # Always use layer normalization
            self.grid_norm = nn.LayerNorm(self.fusion_hidden_dim)
            self.graph_norm = nn.LayerNorm(self.fusion_hidden_dim)
    
    def _validate_tensor(self, tensor, name, expected_shape):
        """Validate tensor shape and numerical properties."""
        if not self.validate:
            return
        
        if torch.isnan(tensor).any():
            raise ValueError(f"[GRIDFUSION] {name} contains NaN!")
        
        if expected_shape is not None and tensor.shape != expected_shape:
            raise ValueError(
                f"[GRIDFUSION] {name} shape mismatch!\n"
                f"  Expected: {expected_shape}\n"
                f"  Got: {tensor.shape}"
            )
    
    def forward(self, grid_features, graph_features):
        """
        Args:
            grid_features: (batch_size, n_lat, n_lon, grid_input_dim)
            graph_features: (batch_size, n_nodes, graph_input_dim)
        Returns:
            fused_grid: (batch_size, n_lat, n_lon, fusion_hidden_dim)
            fused_graph: (batch_size, n_nodes, fusion_hidden_dim)
        """
        batch_size = grid_features.shape[0]
        
        # Validate inputs
        self._validate_tensor(
            grid_features, "grid_features",
            (batch_size, self.n_lat, self.n_lon, self.grid_input_dim)
        )
        self._validate_tensor(
            graph_features, "graph_features",
            (batch_size, self.n_nodes, self.graph_input_dim)
        )
        
        # Check device consistency
        if self.validate and grid_features.device != graph_features.device:
            raise ValueError(
                f"[GRIDFUSION] Device mismatch!\n"
                f"  grid_features device: {grid_features.device}\n"
                f"  graph_features device: {graph_features.device}"
            )
        
        # Flatten grid to match graph structure
        grid_flat = grid_features.reshape(batch_size, self.n_nodes, self.grid_input_dim)
        
        # Project to common dimension
        grid_emb = self.grid_proj(grid_flat)  # (B, n_nodes, fusion_hidden_dim)
        graph_emb = self.graph_proj(graph_features)  # (B, n_nodes, fusion_hidden_dim)
        
        # Simple cross-attention (simplified for testing)
        # In full version, this uses multi-head attention
        attn_scores = torch.matmul(grid_emb, graph_emb.transpose(-2, -1)) / self.temperature
        
        # Validate attention scores
        if self.validate:
            if torch.isnan(attn_scores).any():
                raise ValueError("[GRIDFUSION] NaN in attention scores before softmax!")
            if attn_scores.max() > 50.0:
                print(f"[WARNING] Large attention scores: {attn_scores.max():.2f}")
        
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply attention
        grid_attended = torch.matmul(attn_weights, graph_emb)
        graph_attended = torch.matmul(attn_weights.transpose(-2, -1), grid_emb)
        
        # Output projections
        fused_grid_flat = self.grid_out(grid_attended + grid_emb)
        fused_graph = self.graph_out(graph_attended + graph_emb)
        
        # Apply normalization
        if self.config.use_layer_norm:
            fused_grid_flat = self.grid_norm(fused_grid_flat)
            fused_graph = self.graph_norm(fused_graph)
        
        # Reshape grid back to spatial dimensions
        fused_grid = fused_grid_flat.reshape(batch_size, self.n_lat, self.n_lon, self.fusion_hidden_dim)
        
        return fused_grid, fused_graph


# MetaLearner - Simplified test version
class MetaLearner(nn.Module):
    """
    Interpretable meta-learner with branch weighting (TEST VERSION).
    Simplified from main V5 notebook Cell 22 for unit testing.
    """
    def __init__(self, config, n_lat, n_lon, horizon, validate=True):
        super().__init__()
        self.config = config
        self.n_lat = n_lat
        self.n_lon = n_lon
        self.n_nodes = n_lat * n_lon
        self.horizon = horizon
        self.validate = validate
        
        # Dimensions
        self.fusion_hidden_dim = config.fusion_hidden_dim
        self.meta_hidden_dim = config.meta_hidden_dim
        
        # Context dimension = grid features + graph features
        self.expected_context_input_dim = 2 * self.fusion_hidden_dim
        
        # Weight network (learns branch importance)
        self.weight_network = nn.Sequential(
            nn.Linear(self.expected_context_input_dim, self.meta_hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.meta_dropout),
            nn.Linear(self.meta_hidden_dim, 2)  # 2 branches: ConvLSTM + GNN
        )
        
        # Prediction head
        self.prediction_head = nn.Sequential(
            nn.Linear(self.fusion_hidden_dim, self.meta_hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.meta_dropout),
            nn.Linear(self.meta_hidden_dim, horizon)
        )
        
        self.weight_floor = config.weight_floor
    
    def _validate_tensor(self, tensor, name, expected_shape):
        """Validate tensor shape and numerical properties."""
        if not self.validate:
            return
        
        if torch.isnan(tensor).any():
            nan_count = torch.isnan(tensor).sum().item()
            raise ValueError(
                f"[METALEARNER] {name} contains NaN!\n"
                f"  NaN count: {nan_count}\n"
                f"  Shape: {tensor.shape}"
            )
        
        if expected_shape is not None and tensor.shape != expected_shape:
            raise ValueError(
                f"[METALEARNER] {name} shape mismatch!\n"
                f"  Expected: {expected_shape}\n"
                f"  Got: {tensor.shape}"
            )
    
    def forward(self, fused_grid, fused_graph, context_features=None):
        """
        Args:
            fused_grid: (batch_size, n_lat, n_lon, fusion_hidden_dim)
            fused_graph: (batch_size, n_nodes, fusion_hidden_dim)
        Returns:
            predictions: (batch_size, horizon, n_lat, n_lon)
            weights: (batch_size, n_nodes, 2)
        """
        batch_size = fused_grid.shape[0]
        
        # Validate inputs
        self._validate_tensor(
            fused_grid, "fused_grid",
            (batch_size, self.n_lat, self.n_lon, self.fusion_hidden_dim)
        )
        self._validate_tensor(
            fused_graph, "fused_graph",
            (batch_size, self.n_nodes, self.fusion_hidden_dim)
        )
        
        # Flatten grid (EXPLICIT reshape, no -1)
        expected_flat_shape = (batch_size, self.n_nodes, self.fusion_hidden_dim)
        grid_flat = fused_grid.reshape(expected_flat_shape)
        assert grid_flat.shape == expected_flat_shape, \
            f"Grid reshape failed: {grid_flat.shape} != {expected_flat_shape}"
        
        # Concatenate for context
        context = torch.cat([grid_flat, fused_graph], dim=-1)  # (B, n_nodes, 2*fusion_hidden_dim)
        
        # CRITICAL: Validate context dimension before weight_network
        if self.validate and context.shape[-1] != self.expected_context_input_dim:
            raise ValueError(
                f"[CRITICAL] Context dimension mismatch!\n"
                f"  Expected: {self.expected_context_input_dim}\n"
                f"  Got: {context.shape[-1]}\n"
                f"  grid_flat: {grid_flat.shape}\n"
                f"  fused_graph: {fused_graph.shape}\n"
                f"  Recommendation: Check fusion_hidden_dim in config"
            )
        
        # Validate before critical operation
        self._validate_tensor(context, "context", (batch_size, self.n_nodes, self.expected_context_input_dim))
        
        # Compute branch weights
        raw_weights = self.weight_network(context)  # (B, n_nodes, 2)
        weights = torch.softmax(raw_weights, dim=-1)
        
        # Apply weight floor
        weights = torch.clamp(weights, min=self.weight_floor)
        weights = weights / weights.sum(dim=-1, keepdim=True)  # Re-normalize
        
        # Weighted fusion
        weighted_features = weights[..., 0:1] * grid_flat + weights[..., 1:2] * fused_graph
        
        # Generate predictions
        pred_flat = self.prediction_head(weighted_features)  # (B, n_nodes, horizon)
        pred_flat = pred_flat.transpose(1, 2)  # (B, horizon, n_nodes)
        
        # Reshape to grid
        predictions = pred_flat.reshape(batch_size, self.horizon, self.n_lat, self.n_lon)
        
        # Final validation
        self._validate_tensor(predictions, "predictions", (batch_size, self.horizon, self.n_lat, self.n_lon))
        self._validate_tensor(weights, "weights", (batch_size, self.n_nodes, 2))
        
        return predictions, weights


print("✅ Classes defined successfully:")
print(f"  - GNNBranch")
print(f"  - GridGraphFusion")
print(f"  - MetaLearner")

✅ Classes defined successfully:
  - GNNBranch
  - GridGraphFusion
  - MetaLearner


# Test Suite 1: GNNBranch Validation Tests

Test the GNN branch with per-layer validation.

In [5]:
print("="*80)print("TEST SUITE 1: GNNBranch Validation")print("="*80)# Initialize GNNgnn = GNNBranch(config, n_features=16, n_nodes=config.n_nodes, gnn_type='GAT', validate=True).to(device)print(f"\nGNNBranch initialized: {sum(p.numel() for p in gnn.parameters())} parameters")# Test databatch_size = 2n_nodes = config.n_nodesinput_dim = 16  # n_features for test

TEST SUITE 1: GNNBranch Validation

GNNBranch initialized: 36736 parameters


### Test 1.1: Valid Inputs

In [6]:
print("\n" + "="*80)
print("Test 1.1: Valid inputs (should PASS)")
print("="*80)

try:
    x = torch.randn(batch_size, n_nodes, input_dim, device=device)
    output = gnn(x, edge_index)

    print(f"[PASS] Test 1.1")
    print(f"  Input:  {x.shape}")
    print(f"  Output: {output.shape}")
    print(f"  Expected: ({batch_size}, {n_nodes}, {config.gnn_hidden_dim})")
    assert output.shape == (batch_size, n_nodes, config.gnn_hidden_dim)
    print("  [OK] Shape validated")

except Exception as e:
    print(f"[FAIL] Test 1.1: {e}")


Test 1.1: Valid inputs (should PASS)
[PASS] Test 1.1
  Input:  torch.Size([2, 3965, 16])
  Output: torch.Size([2, 3965, 128])
  Expected: (2, 3965, 128)
  [OK] Shape validated


### Test 1.2: Invalid Input Dimension

In [7]:
print("\n" + "="*80)
print("Test 1.2: Wrong input dimension (should FAIL with clear error)")
print("="*80)

try:
    x_bad = torch.randn(batch_size, n_nodes, 32, device=device)  # Wrong: 32 instead of 16
    output = gnn(x_bad, edge_index)

    print(f"[FAIL] Test 1.2: Should have raised ValueError for dimension mismatch")

except ValueError as e:
    print(f"[PASS] Test 1.2: Caught dimension mismatch")
    print(f"  Error message (first 200 chars):")
    print(f"  {str(e)[:200]}...")

except Exception as e:
    print(f"[FAIL] Test 1.2: Unexpected error: {type(e).__name__}: {e}")


Test 1.2: Wrong input dimension (should FAIL with clear error)
[PASS] Test 1.2: Caught dimension mismatch
  Error message (first 200 chars):
  [GNN VALIDATION] input x feature dimension mismatch!
  Expected: (None, None, 16)
  Got: torch.Size([2, 3965, 32])...


### Test 1.3: NaN Input Detection

In [8]:
print("\n" + "="*80)
print("Test 1.3: NaN in input (should FAIL with diagnostics)")
print("="*80)

try:
    x_nan = torch.randn(batch_size, n_nodes, input_dim, device=device)
    x_nan[0, 0, 0] = float('nan')
    output = gnn(x_nan, edge_index)

    print(f"[FAIL] Test 1.3: Should have raised ValueError for NaN")

except ValueError as e:
    print(f"[PASS] Test 1.3: Caught NaN input")
    print(f"  Error contains 'NaN': {'NaN' in str(e)}")
    print(f"  Error message (first 200 chars):")
    print(f"  {str(e)[:200]}...")

except Exception as e:
    print(f"[FAIL] Test 1.3: Unexpected error: {type(e).__name__}: {e}")


Test 1.3: NaN in input (should FAIL with diagnostics)
[PASS] Test 1.3: Caught NaN input
  Error contains 'NaN': True
  Error message (first 200 chars):
  [GNN VALIDATION] input x contains NaN!
  NaN count: 1
  Shape: torch.Size([2, 3965, 16])
  Stats: min=-4.1594, max=4.4374...


### Test 1.4: Invalid Edge Index

In [9]:
print("\n" + "="*80)
print("Test 1.4: Edge indices out of bounds (should FAIL)")
print("="*80)

try:
    x = torch.randn(batch_size, n_nodes, input_dim, device=device)
    bad_edge_index = torch.randint(0, n_nodes + 1000, (2, 500), device=device)  # Indices > n_nodes
    output = gnn(x, bad_edge_index)

    print(f"[FAIL] Test 1.4: Should have raised ValueError for invalid edge indices")

except ValueError as e:
    print(f"[PASS] Test 1.4: Caught invalid edge indices")
    print(f"  Error contains 'edge_index': {'edge_index' in str(e)}")
    print(f"  Error message (first 200 chars):")
    print(f"  {str(e)[:200]}...")

except Exception as e:
    print(f"[FAIL] Test 1.4: Unexpected error: {type(e).__name__}: {e}")


Test 1.4: Edge indices out of bounds (should FAIL)
[PASS] Test 1.4: Caught invalid edge indices
  Error contains 'edge_index': True
  Error message (first 200 chars):
  [GNN VALIDATION] edge_index contains out-of-bounds indices!
  Max node index: 3964
  Max edge_index: 4960...


# Test Suite 2: GridGraphFusion Validation Tests

Test the cross-attention fusion module.

In [10]:
print("\n" + "="*80)
print("TEST SUITE 2: GridGraphFusion Validation")
print("="*80)

# Initialize Fusion
fusion = GridGraphFusion(config, n_lat=config.n_lat, n_lon=config.n_lon, validate=True).to(device)
print(f"\nGridGraphFusion initialized: {sum(p.numel() for p in fusion.parameters())} parameters")


TEST SUITE 2: GridGraphFusion Validation

GridGraphFusion initialized: 10496 parameters


### Test 2.1: Valid Inputs

In [11]:
print("\n" + "="*80)
print("Test 2.1: Valid inputs (should PASS)")
print("="*80)

try:
    grid_feat = torch.randn(batch_size, config.n_lat, config.n_lon, config.convlstm_hidden_dim, device=device)
    graph_feat = torch.randn(batch_size, n_nodes, config.gnn_hidden_dim, device=device)

    fused_grid, fused_graph = fusion(grid_feat, graph_feat)

    print(f"[PASS] Test 2.1")
    print(f"  Grid input:   {grid_feat.shape}")
    print(f"  Graph input:  {graph_feat.shape}")
    print(f"  Fused grid:   {fused_grid.shape}")
    print(f"  Fused graph:  {fused_graph.shape}")
    print(f"  Expected grid:  ({batch_size}, {config.n_lat}, {config.n_lon}, {config.fusion_hidden_dim})")
    print(f"  Expected graph: ({batch_size}, {n_nodes}, {config.fusion_hidden_dim})")

    assert fused_grid.shape == (batch_size, config.n_lat, config.n_lon, config.fusion_hidden_dim)
    assert fused_graph.shape == (batch_size, n_nodes, config.fusion_hidden_dim)
    print("  [OK] Shapes validated")

except Exception as e:
    print(f"[FAIL] Test 2.1: {e}")


Test 2.1: Valid inputs (should PASS)
[PASS] Test 2.1
  Grid input:   torch.Size([2, 61, 65, 128])
  Graph input:  torch.Size([2, 3965, 128])
  Fused grid:   torch.Size([2, 61, 65, 32])
  Fused graph:  torch.Size([2, 3965, 32])
  Expected grid:  (2, 61, 65, 32)
  Expected graph: (2, 3965, 32)
  [OK] Shapes validated


### Test 2.2: Wrong Grid Dimension

In [12]:
print("\n" + "="*80)
print("Test 2.2: Wrong grid dimension (should FAIL)")
print("="*80)

try:
    bad_grid = torch.randn(batch_size, config.n_lat, config.n_lon, 64, device=device)  # Wrong: 64 instead of 128
    graph_feat = torch.randn(batch_size, n_nodes, config.gnn_hidden_dim, device=device)

    fused_grid, fused_graph = fusion(bad_grid, graph_feat)

    print(f"[FAIL] Test 2.2: Should have raised ValueError")

except ValueError as e:
    print(f"[PASS] Test 2.2: Caught dimension mismatch")
    print(f"  Error message (first 200 chars):")
    print(f"  {str(e)[:200]}...")

except Exception as e:
    print(f"[FAIL] Test 2.2: Unexpected error: {type(e).__name__}: {e}")


Test 2.2: Wrong grid dimension (should FAIL)
[PASS] Test 2.2: Caught dimension mismatch
  Error message (first 200 chars):
  [GRIDFUSION] grid_features shape mismatch!
  Expected: (2, 61, 65, 128)
  Got: torch.Size([2, 61, 65, 64])...


### Test 2.3: Device Mismatch

In [13]:
print("\n" + "="*80)
print("Test 2.3: Device mismatch CPU vs CUDA (should FAIL)")
print("="*80)

if device == 'cuda':
    try:
        cpu_grid = torch.randn(batch_size, config.n_lat, config.n_lon, config.convlstm_hidden_dim, device='cpu')
        graph_feat = torch.randn(batch_size, n_nodes, config.gnn_hidden_dim, device=device)

        fused_grid, fused_graph = fusion(cpu_grid, graph_feat)

        print(f"[FAIL] Test 2.3: Should have raised ValueError for device mismatch")

    except ValueError as e:
        print(f"[PASS] Test 2.3: Caught device mismatch")
        print(f"  Error contains 'device': {'device' in str(e).lower()}")
        print(f"  Error message (first 200 chars):")
        print(f"  {str(e)[:200]}...")

    except Exception as e:
        print(f"[FAIL] Test 2.3: Unexpected error: {type(e).__name__}: {e}")
else:
    print("[SKIP] Test 2.3: No CUDA available, skipping device mismatch test")


Test 2.3: Device mismatch CPU vs CUDA (should FAIL)
[SKIP] Test 2.3: No CUDA available, skipping device mismatch test


# Test Suite 3: MetaLearner Validation Tests

Test the meta-learner with interpretable branch weighting.

In [14]:
print("\n" + "="*80)
print("TEST SUITE 3: MetaLearner Validation")
print("="*80)

# Initialize MetaLearner
meta = MetaLearner(config, n_lat=config.n_lat, n_lon=config.n_lon, horizon=config.horizon, validate=True).to(device)
print(f"\nMetaLearner initialized: {sum(p.numel() for p in meta.parameters())} parameters")


TEST SUITE 3: MetaLearner Validation

MetaLearner initialized: 6792 parameters


### Test 3.1: Valid Inputs

In [15]:
print("\n" + "="*80)
print("Test 3.1: Valid inputs (should PASS)")
print("="*80)

try:
    fused_grid = torch.randn(batch_size, config.n_lat, config.n_lon, config.fusion_hidden_dim, device=device)
    fused_graph = torch.randn(batch_size, n_nodes, config.fusion_hidden_dim, device=device)

    predictions, weights = meta(fused_grid, fused_graph)

    print(f"[PASS] Test 3.1")
    print(f"  Fused grid:   {fused_grid.shape}")
    print(f"  Fused graph:  {fused_graph.shape}")
    print(f"  Predictions:  {predictions.shape}")
    print(f"  Weights:      {weights.shape}")
    print(f"  Expected predictions: ({batch_size}, {config.horizon}, {config.n_lat}, {config.n_lon})")
    print(f"  Expected weights: ({batch_size}, {n_nodes}, 2)")

    assert predictions.shape == (batch_size, config.horizon, config.n_lat, config.n_lon)
    assert weights.shape == (batch_size, n_nodes, 2)
    print("  [OK] Shapes validated")

    # Validate weights sum to 1
    weight_sums = weights.sum(dim=-1)
    print(f"  Weight sums (should be ~1.0): min={weight_sums.min():.6f}, max={weight_sums.max():.6f}")
    assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5)
    print("  [OK] Weights sum to 1.0")

except Exception as e:
    print(f"[FAIL] Test 3.1: {e}")


Test 3.1: Valid inputs (should PASS)
[PASS] Test 3.1
  Fused grid:   torch.Size([2, 61, 65, 32])
  Fused graph:  torch.Size([2, 3965, 32])
  Predictions:  torch.Size([2, 6, 61, 65])
  Weights:      torch.Size([2, 3965, 2])
  Expected predictions: (2, 6, 61, 65)
  Expected weights: (2, 3965, 2)
  [OK] Shapes validated
  Weight sums (should be ~1.0): min=1.000000, max=1.000000
  [OK] Weights sum to 1.0


### Test 3.2: Wrong Dimension

In [16]:
print("\n" + "="*80)
print("Test 3.2: Wrong fused_grid dimension (should FAIL)")
print("="*80)

try:
    bad_grid = torch.randn(batch_size, config.n_lat, config.n_lon, 64, device=device)  # Wrong: 64 instead of 32
    fused_graph = torch.randn(batch_size, n_nodes, config.fusion_hidden_dim, device=device)

    predictions, weights = meta(bad_grid, fused_graph)

    print(f"[FAIL] Test 3.2: Should have raised ValueError")

except ValueError as e:
    print(f"[PASS] Test 3.2: Caught dimension mismatch")
    print(f"  Error message (first 200 chars):")
    print(f"  {str(e)[:200]}...")

except Exception as e:
    print(f"[FAIL] Test 3.2: Unexpected error: {type(e).__name__}: {e}")


Test 3.2: Wrong fused_grid dimension (should FAIL)
[PASS] Test 3.2: Caught dimension mismatch
  Error message (first 200 chars):
  [METALEARNER] fused_grid shape mismatch!
  Expected: (2, 61, 65, 32)
  Got: torch.Size([2, 61, 65, 64])...


### Test 3.3: NaN Propagation Detection

In [17]:
print("\n" + "="*80)
print("Test 3.3: NaN in fused_graph (should FAIL)")
print("="*80)

try:
    fused_grid = torch.randn(batch_size, config.n_lat, config.n_lon, config.fusion_hidden_dim, device=device)
    nan_graph = torch.randn(batch_size, n_nodes, config.fusion_hidden_dim, device=device)
    nan_graph[0, 0, 0] = float('nan')

    predictions, weights = meta(fused_grid, nan_graph)

    print(f"[FAIL] Test 3.3: Should have raised ValueError for NaN")

except ValueError as e:
    print(f"[PASS] Test 3.3: Caught NaN in input")
    print(f"  Error contains 'NaN': {'NaN' in str(e)}")
    print(f"  Error message (first 200 chars):")
    print(f"  {str(e)[:200]}...")

except Exception as e:
    print(f"[FAIL] Test 3.3: Unexpected error: {type(e).__name__}: {e}")


Test 3.3: NaN in fused_graph (should FAIL)
[PASS] Test 3.3: Caught NaN in input
  Error contains 'NaN': True
  Error message (first 200 chars):
  [METALEARNER] fused_graph contains NaN!
  NaN count: 1
  Shape: torch.Size([2, 3965, 32])...


# Test Suite 4: Integration Test

Test all three modules together in a complete forward pass.

In [None]:
print("\n" + "="*80)
print("TEST SUITE 4: Integration Test (GNN -> Fusion -> MetaLearner)")
print("="*80)

try:
    # Step 1: GNN forward pass
    print("\nStep 1: GNN forward pass...")
    # Use input_dim (16) defined earlier, NOT config.graph_input_dim (removed)
    x_gnn = torch.randn(batch_size, n_nodes, input_dim, device=device)
    gnn_output = gnn(x_gnn, edge_index)
    print(f"  GNN output: {gnn_output.shape}")

    # Step 2: Mock ConvLSTM output (in real model, this comes from ConvLSTM branch)
    print("\nStep 2: Mock ConvLSTM output...")
    convlstm_output = torch.randn(batch_size, config.n_lat, config.n_lon, config.convlstm_hidden_dim, device=device)
    print(f"  ConvLSTM output: {convlstm_output.shape}")

    # Step 3: Fusion forward pass
    print("\nStep 3: Fusion forward pass...")
    fused_grid, fused_graph = fusion(convlstm_output, gnn_output)
    print(f"  Fused grid:  {fused_grid.shape}")
    print(f"  Fused graph: {fused_graph.shape}")

    # Step 4: MetaLearner forward pass
    print("\nStep 4: MetaLearner forward pass...")
    predictions, weights = meta(fused_grid, fused_graph)
    print(f"  Predictions: {predictions.shape}")
    print(f"  Weights:     {weights.shape}")

    # Validate final output
    print("\n" + "="*80)
    print("[PASS] Integration Test")
    print("="*80)
    print(f"  Final predictions shape: {predictions.shape}")
    print(f"  Expected: ({batch_size}, {config.horizon}, {config.n_lat}, {config.n_lon})")
    assert predictions.shape == (batch_size, config.horizon, config.n_lat, config.n_lon)
    print("  [OK] Integration test passed!")

    # Analyze branch weights
    print("\n" + "="*80)
    print("Branch Weight Analysis")
    print("="*80)
    mean_weights = weights.mean(dim=[0, 1])  # Average over batch and nodes
    print(f"  ConvLSTM contribution: {100*mean_weights[0]:.2f}%")
    print(f"  GNN contribution:      {100*mean_weights[1]:.2f}%")

    # Check no NaN in outputs
    print(f"\n  NaN check:")
    print(f"    Predictions: {torch.isnan(predictions).any().item()}")
    print(f"    Weights:     {torch.isnan(weights).any().item()}")
    assert not torch.isnan(predictions).any()
    assert not torch.isnan(weights).any()
    print("  [OK] No NaN in outputs")

except Exception as e:
    print(f"\n[FAIL] Integration Test: {e}")
    import traceback
    traceback.print_exc()

# Test Summary

Display results of all tests.

In [19]:
print("\n" + "="*80)
print("TEST SUMMARY")
print("="*80)
print("\nIf you see this cell, all critical tests have completed.")
print("\nExpected Results:")
print("  [PASS] Test 1.1: Valid GNN inputs")
print("  [PASS] Test 1.2: Caught wrong GNN dimension")
print("  [PASS] Test 1.3: Caught NaN in GNN input")
print("  [PASS] Test 1.4: Caught invalid edge indices")
print("  [PASS] Test 2.1: Valid Fusion inputs")
print("  [PASS] Test 2.2: Caught wrong Fusion dimension")
print("  [PASS] Test 2.3: Caught device mismatch (or SKIP if no CUDA)")
print("  [PASS] Test 3.1: Valid MetaLearner inputs")
print("  [PASS] Test 3.2: Caught wrong MetaLearner dimension")
print("  [PASS] Test 3.3: Caught NaN in MetaLearner input")
print("  [PASS] Integration Test: Full pipeline")
print("\n" + "="*80)
print("All tests completed successfully!")
print("="*80)
print("\nNext Steps:")
print("1. Review any [FAIL] results above")
print("2. If all tests pass, proceed to training in main V5 notebook")
print("3. Monitor for CUDA errors during actual training")
print("4. Use validation=False in production for ~5% speedup after validation")


TEST SUMMARY

If you see this cell, all critical tests have completed.

Expected Results:
  [PASS] Test 1.1: Valid GNN inputs
  [PASS] Test 1.2: Caught wrong GNN dimension
  [PASS] Test 1.3: Caught NaN in GNN input
  [PASS] Test 1.4: Caught invalid edge indices
  [PASS] Test 2.1: Valid Fusion inputs
  [PASS] Test 2.2: Caught wrong Fusion dimension
  [PASS] Test 2.3: Caught device mismatch (or SKIP if no CUDA)
  [PASS] Test 3.1: Valid MetaLearner inputs
  [PASS] Test 3.2: Caught wrong MetaLearner dimension
  [PASS] Test 3.3: Caught NaN in MetaLearner input
  [PASS] Integration Test: Full pipeline

All tests completed successfully!

Next Steps:
1. Review any [FAIL] results above
2. If all tests pass, proceed to training in main V5 notebook
3. Monitor for CUDA errors during actual training
4. Use validation=False in production for ~5% speedup after validation
