# V5 GNN-ConvLSTM Stacking - Unit Tests

Functional tests for V5 components. CPU-only for portability.

In [1]:
# Force CPU mode
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from typing import Tuple, Optional
import math
import json
import re

device = 'cpu'
print(f"PyTorch: {torch.__version__} | Device: {device}")

PyTorch: 2.9.1 | Device: cpu


## V5Config

In [2]:
class V5Config:
    """V5 configuration for testing."""
    def __init__(self):
        self.n_lat = 61
        self.n_lon = 65
        self.n_nodes = self.n_lat * self.n_lon
        self.gnn_hidden_dim = 64
        self.gnn_num_layers = 2
        self.gnn_dropout = 0.1
        self.gnn_type = 'GAT'
        self.convlstm_hidden_dim = 64
        self.convlstm_num_layers = 2
        self.fusion_hidden_dim = 128
        self.fusion_heads = 4
        self.fusion_dropout = 0.1
        self.meta_hidden_dim = 64
        self.meta_dropout = 0.1
        self.horizon = 1
        self.learning_rate = 0.001
        self.weight_decay = 1e-5
        self.batch_size = 4
        self.epochs = 100
        self.early_stopping_patience = 10

    @property
    def graph_input_dim(self):
        raise AttributeError(
            "V5Config.graph_input_dim removed. Use n_features param: "
            "GNNBranch(config, n_features=16, n_nodes=...)"
        )

config = V5Config()
print(f"Config created: n_nodes={config.n_nodes}")

Config created: n_nodes=3965


## Graph Structure

In [3]:
def create_grid_graph(n_lat, n_lon, device='cpu'):
    """Create 4-connected grid graph."""
    edges = []
    for i in range(n_lat):
        for j in range(n_lon):
            idx = i * n_lon + j
            if j < n_lon - 1:
                edges.append([idx, idx + 1])
                edges.append([idx + 1, idx])
            if i < n_lat - 1:
                edges.append([idx, idx + n_lon])
                edges.append([idx + n_lon, idx])
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous().to(device)
    return edge_index

# Use small grid for tests
test_lat, test_lon = 5, 5
test_nodes = test_lat * test_lon
edge_index = create_grid_graph(test_lat, test_lon, device)
print(f"Test grid: {test_lat}x{test_lon} = {test_nodes} nodes, {edge_index.shape[1]} edges")

Test grid: 5x5 = 25 nodes, 80 edges


## GNNBranch

In [4]:
class GNNBranch(nn.Module):
    """Graph Neural Network branch."""

    def __init__(self, config, n_features: int, n_nodes: int,
                 gnn_type: str = None, validate: bool = True):
        super().__init__()
        self.config = config
        self.n_features = n_features
        self.n_nodes = n_nodes
        self.gnn_type = gnn_type or config.gnn_type
        self.validate = validate
        self.hidden_dim = config.gnn_hidden_dim
        self.num_layers = config.gnn_num_layers
        self.dropout = config.gnn_dropout

        # Build layers
        self.convs = nn.ModuleList()
        in_dim = n_features
        for i in range(self.num_layers):
            out_dim = self.hidden_dim
            if self.gnn_type == 'GCN':
                self.convs.append(GCNConv(in_dim, out_dim))
            elif self.gnn_type == 'GAT':
                self.convs.append(GATConv(in_dim, out_dim, heads=4, concat=False))
            elif self.gnn_type == 'SAGE':
                self.convs.append(SAGEConv(in_dim, out_dim))
            in_dim = out_dim

        self.dropout_layer = nn.Dropout(self.dropout)
        self.output_dim = self.hidden_dim

    def _validate_tensor(self, tensor, name):
        if self.validate:
            if torch.isnan(tensor).any():
                raise ValueError(f"{name} contains NaN")
            if torch.isinf(tensor).any():
                raise ValueError(f"{name} contains Inf")

    def forward(self, x, edge_index, edge_weight=None):
        self._validate_tensor(x, "GNNBranch input")

        for i, conv in enumerate(self.convs):
            if self.gnn_type == 'GCN' and edge_weight is not None:
                x = conv(x, edge_index, edge_weight)
            else:
                x = conv(x, edge_index)
            if i < len(self.convs) - 1:
                x = torch.relu(x)
                x = self.dropout_layer(x)

        self._validate_tensor(x, "GNNBranch output")
        return x

print("GNNBranch defined")

GNNBranch defined


## GridGraphFusion

In [5]:
class GridGraphFusion(nn.Module):
    """Cross-attention fusion between grid and graph representations."""

    def __init__(self, grid_dim: int, graph_dim: int, config):
        super().__init__()
        self.grid_dim = grid_dim
        self.graph_dim = graph_dim
        self.hidden_dim = config.fusion_hidden_dim
        self.n_heads = config.fusion_heads
        self.dropout = config.fusion_dropout

        self.grid_proj = nn.Linear(grid_dim, self.hidden_dim)
        self.graph_proj = nn.Linear(graph_dim, self.hidden_dim)
        self.cross_attention = nn.MultiheadAttention(
            self.hidden_dim, self.n_heads, dropout=self.dropout, batch_first=True
        )
        self.output_proj = nn.Linear(self.hidden_dim * 2, self.hidden_dim)
        self.layer_norm = nn.LayerNorm(self.hidden_dim)
        self.output_dim = self.hidden_dim

    def forward(self, grid_features, graph_features):
        grid_proj = self.grid_proj(grid_features)
        graph_proj = self.graph_proj(graph_features)

        fused, _ = self.cross_attention(grid_proj, graph_proj, graph_proj)
        combined = torch.cat([grid_proj, fused], dim=-1)
        output = self.output_proj(combined)
        output = self.layer_norm(output)
        return output

print("GridGraphFusion defined")

GridGraphFusion defined


## MetaLearner

In [6]:
class MetaLearner(nn.Module):
    """Meta-learner for combining branch predictions."""

    def __init__(self, fusion_dim: int, config):
        super().__init__()
        self.fusion_dim = fusion_dim
        self.hidden_dim = config.meta_hidden_dim
        self.dropout = config.meta_dropout

        self.fc1 = nn.Linear(fusion_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim // 2)
        self.fc_weights = nn.Linear(self.hidden_dim // 2, 2)
        self.fc_pred = nn.Linear(self.hidden_dim // 2, 1)
        self.dropout_layer = nn.Dropout(self.dropout)

    def forward(self, fusion_features):
        x = torch.relu(self.fc1(fusion_features))
        x = self.dropout_layer(x)
        x = torch.relu(self.fc2(x))

        weights = torch.softmax(self.fc_weights(x), dim=-1)
        prediction = self.fc_pred(x)
        return prediction, weights

print("MetaLearner defined")

MetaLearner defined


## Test Suite 1: GNNBranch

In [7]:
print("=" * 60)
print("TEST SUITE 1: GNNBranch")
print("=" * 60)

passed, failed = 0, 0

# Test 1.1: Basic forward pass
try:
    gnn = GNNBranch(config, n_features=16, n_nodes=test_nodes, gnn_type='GAT').to(device)
    x = torch.randn(test_nodes, 16, device=device)
    out = gnn(x, edge_index)
    assert out.shape == (test_nodes, config.gnn_hidden_dim)
    print("[PASS] 1.1 Basic forward pass")
    passed += 1
except Exception as e:
    print(f"[FAIL] 1.1 Basic forward pass: {e}")
    failed += 1

# Test 1.2: Batch processing
try:
    x_batch = torch.randn(2, test_nodes, 16, device=device)
    out_batch = torch.stack([gnn(x_batch[i], edge_index) for i in range(2)])
    assert out_batch.shape == (2, test_nodes, config.gnn_hidden_dim)
    print("[PASS] 1.2 Batch processing")
    passed += 1
except Exception as e:
    print(f"[FAIL] 1.2 Batch processing: {e}")
    failed += 1

# Test 1.3: GNN types
for gnn_type in ['GCN', 'GAT', 'SAGE']:
    try:
        gnn_test = GNNBranch(config, n_features=16, n_nodes=test_nodes, gnn_type=gnn_type).to(device)
        out = gnn_test(x, edge_index)
        assert out.shape == (test_nodes, config.gnn_hidden_dim)
        print(f"[PASS] 1.3 GNN type: {gnn_type}")
        passed += 1
    except Exception as e:
        print(f"[FAIL] 1.3 GNN type {gnn_type}: {e}")
        failed += 1

# Test 1.4: NaN detection
try:
    x_nan = torch.randn(test_nodes, 16, device=device)
    x_nan[0, 0] = float('nan')
    try:
        gnn(x_nan, edge_index)
        print("[FAIL] 1.4 NaN detection - should have raised error")
        failed += 1
    except ValueError as e:
        if "NaN" in str(e):
            print("[PASS] 1.4 NaN detection")
            passed += 1
        else:
            raise e
except Exception as e:
    print(f"[FAIL] 1.4 NaN detection: {e}")
    failed += 1

print(f"\nSuite 1 Results: {passed} passed, {failed} failed")

TEST SUITE 1: GNNBranch
[PASS] 1.1 Basic forward pass
[PASS] 1.2 Batch processing
[PASS] 1.3 GNN type: GCN
[PASS] 1.3 GNN type: GAT
[PASS] 1.3 GNN type: SAGE
[PASS] 1.4 NaN detection

Suite 1 Results: 6 passed, 0 failed


## Test Suite 2: GridGraphFusion

In [8]:
print("=" * 60)
print("TEST SUITE 2: GridGraphFusion")
print("=" * 60)

passed, failed = 0, 0

# Test 2.1: Basic fusion
try:
    fusion = GridGraphFusion(64, 64, config).to(device)
    grid_feat = torch.randn(4, test_nodes, 64, device=device)
    graph_feat = torch.randn(4, test_nodes, 64, device=device)
    out = fusion(grid_feat, graph_feat)
    assert out.shape == (4, test_nodes, config.fusion_hidden_dim)
    print("[PASS] 2.1 Basic fusion")
    passed += 1
except Exception as e:
    print(f"[FAIL] 2.1 Basic fusion: {e}")
    failed += 1

# Test 2.2: Different input dimensions
try:
    fusion2 = GridGraphFusion(32, 128, config).to(device)
    grid_feat = torch.randn(4, test_nodes, 32, device=device)
    graph_feat = torch.randn(4, test_nodes, 128, device=device)
    out = fusion2(grid_feat, graph_feat)
    assert out.shape == (4, test_nodes, config.fusion_hidden_dim)
    print("[PASS] 2.2 Different input dimensions")
    passed += 1
except Exception as e:
    print(f"[FAIL] 2.2 Different input dimensions: {e}")
    failed += 1

# Test 2.3: Single sample
try:
    grid_feat = torch.randn(1, test_nodes, 64, device=device)
    graph_feat = torch.randn(1, test_nodes, 64, device=device)
    out = fusion(grid_feat, graph_feat)
    assert out.shape == (1, test_nodes, config.fusion_hidden_dim)
    print("[PASS] 2.3 Single sample")
    passed += 1
except Exception as e:
    print(f"[FAIL] 2.3 Single sample: {e}")
    failed += 1

print(f"\nSuite 2 Results: {passed} passed, {failed} failed")

TEST SUITE 2: GridGraphFusion
[PASS] 2.1 Basic fusion
[PASS] 2.2 Different input dimensions
[PASS] 2.3 Single sample

Suite 2 Results: 3 passed, 0 failed


## Test Suite 3: MetaLearner

In [9]:
print("=" * 60)
print("TEST SUITE 3: MetaLearner")
print("=" * 60)

passed, failed = 0, 0

# Test 3.1: Basic forward
try:
    meta = MetaLearner(128, config).to(device)
    fusion_feat = torch.randn(4, test_nodes, 128, device=device)
    pred, weights = meta(fusion_feat)
    assert pred.shape == (4, test_nodes, 1)
    assert weights.shape == (4, test_nodes, 2)
    print("[PASS] 3.1 Basic forward")
    passed += 1
except Exception as e:
    print(f"[FAIL] 3.1 Basic forward: {e}")
    failed += 1

# Test 3.2: Weights sum to 1
try:
    weight_sums = weights.sum(dim=-1)
    assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5)
    print("[PASS] 3.2 Weights sum to 1")
    passed += 1
except Exception as e:
    print(f"[FAIL] 3.2 Weights sum to 1: {e}")
    failed += 1

# Test 3.3: Weights in [0, 1]
try:
    assert (weights >= 0).all() and (weights <= 1).all()
    print("[PASS] 3.3 Weights in [0, 1]")
    passed += 1
except Exception as e:
    print(f"[FAIL] 3.3 Weights in [0, 1]: {e}")
    failed += 1

print(f"\nSuite 3 Results: {passed} passed, {failed} failed")

TEST SUITE 3: MetaLearner
[PASS] 3.1 Basic forward
[PASS] 3.2 Weights sum to 1
[PASS] 3.3 Weights in [0, 1]

Suite 3 Results: 3 passed, 0 failed


## Test Suite 4: V5Config

In [10]:
print("=" * 60)
print("TEST SUITE 4: V5Config")
print("=" * 60)

passed, failed = 0, 0

# Test 4.1: Required attributes
required = ['n_lat', 'n_lon', 'n_nodes', 'gnn_hidden_dim', 'gnn_num_layers',
            'gnn_dropout', 'convlstm_hidden_dim', 'fusion_hidden_dim', 'meta_hidden_dim']
try:
    for attr in required:
        assert hasattr(config, attr), f"Missing {attr}"
    print("[PASS] 4.1 Required attributes present")
    passed += 1
except Exception as e:
    print(f"[FAIL] 4.1 Required attributes: {e}")
    failed += 1

# Test 4.2: Removed attribute error
try:
    _ = config.graph_input_dim
    print("[FAIL] 4.2 graph_input_dim - should raise AttributeError")
    failed += 1
except AttributeError as e:
    if "removed" in str(e).lower() or "n_features" in str(e).lower():
        print("[PASS] 4.2 graph_input_dim raises helpful error")
        passed += 1
    else:
        print(f"[FAIL] 4.2 graph_input_dim - wrong error message: {e}")
        failed += 1

# Test 4.3: n_nodes calculation
try:
    assert config.n_nodes == config.n_lat * config.n_lon
    print("[PASS] 4.3 n_nodes = n_lat * n_lon")
    passed += 1
except Exception as e:
    print(f"[FAIL] 4.3 n_nodes calculation: {e}")
    failed += 1

print(f"\nSuite 4 Results: {passed} passed, {failed} failed")

TEST SUITE 4: V5Config
[PASS] 4.1 Required attributes present
[PASS] 4.2 graph_input_dim raises helpful error
[PASS] 4.3 n_nodes = n_lat * n_lon

Suite 4 Results: 3 passed, 0 failed


## Test Suite 5: Integration

In [11]:
print("=" * 60)
print("TEST SUITE 5: Integration")
print("=" * 60)

passed, failed = 0, 0

# Test 5.1: Full pipeline
try:
    batch_size = 2
    n_features = 16

    # Create modules
    gnn = GNNBranch(config, n_features=n_features, n_nodes=test_nodes, gnn_type='GAT').to(device)
    fusion = GridGraphFusion(gnn.output_dim, gnn.output_dim, config).to(device)
    meta = MetaLearner(fusion.output_dim, config).to(device)

    # Forward pass
    x = torch.randn(batch_size, test_nodes, n_features, device=device)
    gnn_out = torch.stack([gnn(x[i], edge_index) for i in range(batch_size)])
    grid_feat = torch.randn(batch_size, test_nodes, gnn.output_dim, device=device)
    fused = fusion(grid_feat, gnn_out)
    pred, weights = meta(fused)

    assert pred.shape == (batch_size, test_nodes, 1)
    assert weights.shape == (batch_size, test_nodes, 2)
    print("[PASS] 5.1 Full pipeline forward")
    passed += 1
except Exception as e:
    print(f"[FAIL] 5.1 Full pipeline: {e}")
    failed += 1

# Test 5.2: Gradient flow
try:
    pred.sum().backward()
    # Check gradients exist
    for name, param in gnn.named_parameters():
        if param.grad is not None:
            assert not torch.isnan(param.grad).any(), f"NaN in {name} grad"
    print("[PASS] 5.2 Gradient flow (no NaN)")
    passed += 1
except Exception as e:
    print(f"[FAIL] 5.2 Gradient flow: {e}")
    failed += 1

print(f"\nSuite 5 Results: {passed} passed, {failed} failed")

TEST SUITE 5: Integration
[PASS] 5.1 Full pipeline forward
[PASS] 5.2 Gradient flow (no NaN)

Suite 5 Results: 2 passed, 0 failed


## Test Suite 6: Edge Cases

In [12]:
print("=" * 60)
print("TEST SUITE 6: Edge Cases")
print("=" * 60)

passed, failed = 0, 0

# Test 6.1: Single node graph
try:
    single_edge = torch.tensor([[0], [0]], device=device)
    gnn_single = GNNBranch(config, n_features=16, n_nodes=1, gnn_type='GCN').to(device)
    x_single = torch.randn(1, 16, device=device)
    out = gnn_single(x_single, single_edge)
    assert out.shape == (1, config.gnn_hidden_dim)
    print("[PASS] 6.1 Single node graph")
    passed += 1
except Exception as e:
    print(f"[FAIL] 6.1 Single node: {e}")
    failed += 1

# Test 6.2: Very small features
try:
    gnn_small = GNNBranch(config, n_features=1, n_nodes=test_nodes, gnn_type='GCN').to(device)
    x_small = torch.randn(test_nodes, 1, device=device)
    out = gnn_small(x_small, edge_index)
    assert out.shape == (test_nodes, config.gnn_hidden_dim)
    print("[PASS] 6.2 Single feature dimension")
    passed += 1
except Exception as e:
    print(f"[FAIL] 6.2 Single feature: {e}")
    failed += 1

# Test 6.3: Large feature dimension
try:
    gnn_large = GNNBranch(config, n_features=512, n_nodes=test_nodes, gnn_type='GCN').to(device)
    x_large = torch.randn(test_nodes, 512, device=device)
    out = gnn_large(x_large, edge_index)
    assert out.shape == (test_nodes, config.gnn_hidden_dim)
    print("[PASS] 6.3 Large feature dimension")
    passed += 1
except Exception as e:
    print(f"[FAIL] 6.3 Large feature: {e}")
    failed += 1

# Test 6.4: Zero input
try:
    x_zero = torch.zeros(test_nodes, 16, device=device)
    gnn_test = GNNBranch(config, n_features=16, n_nodes=test_nodes, gnn_type='GCN', validate=False).to(device)
    out = gnn_test(x_zero, edge_index)
    assert not torch.isnan(out).any()
    print("[PASS] 6.4 Zero input")
    passed += 1
except Exception as e:
    print(f"[FAIL] 6.4 Zero input: {e}")
    failed += 1

print(f"\nSuite 6 Results: {passed} passed, {failed} failed")

TEST SUITE 6: Edge Cases
[PASS] 6.1 Single node graph
[PASS] 6.2 Single feature dimension
[PASS] 6.3 Large feature dimension
[PASS] 6.4 Zero input

Suite 6 Results: 4 passed, 0 failed


## Test Suite 7: Static Notebook Validation

In [13]:
print("=" * 60)
print("TEST SUITE 7: Static Notebook Validation")
print("=" * 60)

passed, failed = 0, 0

MAIN_NOTEBOOK = 'base_models_gnn_convlstm_stacking_v5.ipynb'

# Test 7.1: Valid JSON
try:
    with open(MAIN_NOTEBOOK, encoding='utf-8') as f:
        main_nb = json.load(f)
    print("[PASS] 7.1 Main notebook is valid JSON")
    passed += 1
except Exception as e:
    print(f"[FAIL] 7.1 JSON parsing: {e}")
    failed += 1
    main_nb = None

if main_nb:
    # Test 7.2: No config.graph_input_dim
    try:
        bad_pattern = re.compile(r'config\.graph_input_dim')
        found = []
        for i, cell in enumerate(main_nb['cells']):
            src = ''.join(cell.get('source', []))
            if bad_pattern.search(src):
                found.append(i)
        if found:
            print(f"[FAIL] 7.2 Found config.graph_input_dim in cells: {found}")
            failed += 1
        else:
            print("[PASS] 7.2 No config.graph_input_dim references")
            passed += 1
    except Exception as e:
        print(f"[FAIL] 7.2: {e}")
        failed += 1

    # Test 7.3: GNNBranch has n_features param
    try:
        gnn_pattern = re.compile(r'GNNBranch\s*\([^)]*n_features')
        found_correct = False
        for cell in main_nb['cells']:
            src = ''.join(cell.get('source', []))
            if gnn_pattern.search(src):
                found_correct = True
                break
        if found_correct:
            print("[PASS] 7.3 GNNBranch uses n_features parameter")
            passed += 1
        else:
            print("[FAIL] 7.3 GNNBranch missing n_features parameter")
            failed += 1
    except Exception as e:
        print(f"[FAIL] 7.3: {e}")
        failed += 1

print(f"\nSuite 7 Results: {passed} passed, {failed} failed")

TEST SUITE 7: Static Notebook Validation
[PASS] 7.1 Main notebook is valid JSON
[PASS] 7.2 No config.graph_input_dim references
[PASS] 7.3 GNNBranch uses n_features parameter

Suite 7 Results: 3 passed, 0 failed


## Summary

In [14]:
print("=" * 60)
print("V5 TEST SUMMARY")
print("=" * 60)

print("""
Test Suites:
1. GNNBranch      - Forward pass, batch, types, NaN
2. GridGraphFusion - Basic fusion, dims, single sample
3. MetaLearner    - Forward, weights sum, bounds
4. V5Config       - Required attrs, removed attr
5. Integration    - Full pipeline, gradients
6. Edge Cases     - Single node, feature dims, zero
7. Static         - JSON valid, no bad patterns

All tests use CPU mode for portability.
Run in Colab or local Jupyter.
""")

V5 TEST SUMMARY

Test Suites:
1. GNNBranch      - Forward pass, batch, types, NaN
2. GridGraphFusion - Basic fusion, dims, single sample
3. MetaLearner    - Forward, weights sum, bounds
4. V5Config       - Required attrs, removed attr
5. Integration    - Full pipeline, gradients
6. Edge Cases     - Single node, feature dims, zero
7. Static         - JSON valid, no bad patterns

All tests use CPU mode for portability.
Run in Colab or local Jupyter.

