In [None]:
%load_ext autoreload
%autoreload 2

from src.sampling.main import stratified_spatial_kfold_dual

import torch
from torch_geometric.data import Data
from src.raingauge.utils import (
    get_station_coordinate_mappings,
    load_weather_station_dataset,
)
import pandas as pd
import numpy as np
import tqdm
import random
import matplotlib.pyplot as plt
import time
from scipy.stats import pearsonr
from src.radar.utils import load_radar_dataset
from src.visualization.radar import (
    visualize_one_radar_image,
)
import matplotlib as mpl
from models.gnn import GNNInductive
from datetime import datetime
from src.performance_logger import PerformanceLogger
import os
from src.utils import (
    build_train_and_full_graph_homogeneous,
    add_homogeneous_weather_station_data,
    add_homogeneous_mask_to_data,
    generate_homogeneous_edges,
    add_homogeneous_edge_attributes_to_data,
    prepare_homogeneous_inductive_dataset,
    debug_dataloader,
)
import networkx as nx
from matplotlib.patches import Patch
import torch.nn.functional as F

In [None]:
# NOTE: Geographic extent of Singapore: longitude for left/right, latitude for top/bottom.
bounds_singapore = dict(left=103.6, right=104.1, top=1.5, bottom=1.188)

# Colour-scale boundaries for rainfall plots.
# NOTE(review): presumably mm/h — confirm units.
bounds = [0.1, 0.2, 0.5, 1, 2, 4, 7, 10, 20]
norm = mpl.colors.BoundaryNorm(boundaries=bounds, ncolors=256, extend="both")

# Every run writes its artefacts into a fresh, timestamped experiment directory.
run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
experiment_name = f"{run_stamp}_new"
experiment_dir = f"experiments/{experiment_name}"
os.makedirs(experiment_dir, exist_ok=True)
perf = PerformanceLogger(f"{experiment_dir}/training_log.jsonl")

# Preprocess Data

In [None]:
# One empty homogeneous graph container per cross-validation fold (5 folds).
data, data1, data2, data3, data4 = (Data() for _ in range(5))
dtype = torch.float32


# Preprocess station data.
Some stations record only rainfall, while others record rainfall together with additional weather variables.
We therefore split the stations into two groups: rainfall stations (rainfall only) and general stations (rainfall plus the additional variables listed below).

Additional variables recorded by general stations:
  Windspeed
  Wind Direction
  Temperature
  Relative Humidity

In [None]:
weather_station_data = load_weather_station_dataset("weather_station_data.csv")
weather_station_locations = get_station_coordinate_mappings()
# Number of stations with known coordinates vs. number of stations present in the CSV.
print(len(weather_station_locations.keys()))
print(len(set(weather_station_data["gid"].values)))

# Every column other than the timestamp and the station id is a measured variable.
cols = list(weather_station_data.columns)
for id_column in ("time_sgt", "gid"):
    cols.remove(id_column)

# Wide layout: one row per 15-minute timestamp, columns indexed by (variable, station).
# Within each 15-minute window the first reading is kept.
weather_station_df_pivot = (
    pd.pivot(data=weather_station_data, index="time_sgt", columns="gid", values=cols)
    .resample("15min")
    .first()
)
# NOTE(review): the x12 factor presumably converts a 5-minute rainfall total into an
# hourly rate — confirm the units of the raw `rain_rate` column.
weather_station_df_pivot["rain_rate"] = weather_station_df_pivot["rain_rate"] * 12

# Non-missing reading counts per (variable, station), reshaped to one row per station.
weather_station_df_counts = weather_station_df_pivot.count().reset_index()
weather_station_info = pd.pivot(
    data=weather_station_df_counts, index="gid", columns="level_0"
)

pd.set_option("display.max_rows", None)

# A station with zero readings for at least one variable is treated as rainfall-only.
# Every other station with known coordinates is a general station.
rainfall_station = [
    gid
    for gid, variable_counts in weather_station_info.iterrows()
    if 0 in variable_counts.value_counts()
]
rainfall_lookup = set(rainfall_station)
general_station = [s for s in weather_station_locations if s not in rainfall_lookup]

print(rainfall_station)
print(general_station)

# A timestep "contains rain" when the NaN-ignoring sum of rain_rate over all stations
# is non-zero.
count = sum(
    1
    for _, rain_row in weather_station_df_pivot["rain_rate"].iterrows()
    if np.nansum(rain_row.to_numpy()) != 0
)
print(f"Number of timesteps that contain rain: {count}")
print(f"Total_timesteps = {weather_station_df_pivot.shape[0]}")

# Summary statistics of the resampled station table.
print("--- Station Data Stats ---")
print(weather_station_df_pivot.describe())


In [None]:
general_station_data = {}
rainfall_station_data = {}

# TODO: Temporal Data Leakage - Filling missing values in the training set using all data including validation or test set is wrong.
# Gap-fill each station's series: linear interpolation first, then forward/backward fill
# for leading/trailing gaps. `.ffill()`/`.bfill()` replace the deprecated
# `fillna(method=...)` form (FutureWarning since pandas 2.1).
for station in weather_station_df_pivot.columns.get_level_values(1).unique():
    station_cols = (
        weather_station_df_pivot.xs(station, level=1, axis=1)
        .interpolate(method="linear")
        .ffill()
        .bfill()
    )
    if station in general_station:
        # General stations keep every measured variable.
        general_station_data[station] = station_cols.values
    else:
        # Rainfall-only stations keep just the rain rate, as a [T, 1] array.
        # Select the column by name: its position among the variable columns depends
        # on the column order of the CSV, so `[:, 0:1]` is not guaranteed to be rain.
        rainfall_station_data[station] = station_cols[["rain_rate"]].values

# S108 is excluded from the general stations.
# NOTE(review): the reason is not recorded here — document it.
general_station = [stn for stn in general_station if stn != "S108"]

# Feature arrays and their station ids, in matching order.
general_station_ids = list(general_station)
general_station_features = [general_station_data[stn] for stn in general_station_ids]
rainfall_station_ids = list(rainfall_station)
rainfall_station_features = [
    rainfall_station_data[stn] for stn in rainfall_station_ids
]


# Add Station Features to the Homogeneous `Data` Objects

In [None]:
# Attach the station features to every fold's homogeneous graph. All five graphs
# receive the same features; the fold-specific masks are added in a later cell.
# NOTE(review): include_metastation_info is not read anywhere in the visible cells.
include_metastation_info = False
data, data1, data2, data3, data4 = (
    add_homogeneous_weather_station_data(
        fold_graph,
        general_station_features,
        rainfall_station_features,
        general_station_ids,
        rainfall_station_ids,
        dtype=dtype,
    )
    for fold_graph in (data, data1, data2, data3, data4)
)


# Stratified K Fold Spatial Sampling

In [None]:
# Stratified spatial k-fold split over all stations with known coordinates.
split_info = stratified_spatial_kfold_dual(weather_station_locations, seed=123, plot=True)
print(split_info)

# Global node order used by every graph: general stations first, then rainfall stations.
stations = [*general_station, *rainfall_station]


In [None]:
# Attach fold k's split assignment (split_info[k]) to the k-th graph.
data, data1, data2, data3, data4 = (
    add_homogeneous_mask_to_data(fold_graph, split_info[fold], stations)
    for fold, fold_graph in enumerate((data, data1, data2, data3, data4))
)

for caption, fold_graph in zip(
    ("Data", "Data1", "Data2", "Data3", "Data4"), (data, data1, data2, data3, data4)
):
    print(f"{caption}: \n", fold_graph)

# Edge generation
We use the station locations to generate the edges.
General stations and rainfall stations are treated identically here: both share a single node set, so edges connect nearby stations regardless of their type. This joins the two station groups into one graph.

In [None]:
K = 4  # Number of neighbors per node

# One shared edge set (and edge attributes) computed from the station coordinates.
edges, edge_attributes = generate_homogeneous_edges(
    weather_station_locations, stations=stations, K=K
)

# Every fold graph gets the identical edge structure.
data, data1, data2, data3, data4 = (
    add_homogeneous_edge_attributes_to_data(
        fold_graph, edges, edge_attributes, dtype=dtype
    )
    for fold_graph in (data, data1, data2, data3, data4)
)

for fold_graph in (data, data1, data2, data3, data4):
    print(fold_graph)

In [None]:
# ============================================================================
# FINAL SUMMARY
# ============================================================================
# print_data_structure(data)
# print_data_structure(data1)
# print_data_structure(data2)
# print_data_structure(data3)
# print_data_structure(data4)


In [None]:
train_graphs = []
validation_graphs = []
full_graphs = []

# Build each fold from its OWN Data object. data, data1, ..., data4 share features and
# edges but carry the masks of split_info[0..4] respectively. Pairing fold k's split
# with fold 0's `data` could leave every fold with fold-0 masks if the builder reads
# masks from its Data argument.
fold_data = (data, data1, data2, data3, data4)
for fold_idx, fold_graph in enumerate(fold_data):
    train_g, val_g, full_g = build_train_and_full_graph_homogeneous(
        fold_graph, split_info[fold_idx], stations
    )
    train_graphs.append(train_g)
    validation_graphs.append(val_g)
    full_graphs.append(full_g)

print(
    f"Train Data: {train_graphs[0]}\nValidation Data: {validation_graphs[0]}\nData Full: {full_graphs[0]}"
)

# The structure analysis below inspects fold 0 only.
full_graph = full_graphs[0]
validation_graph = validation_graphs[0]
train_graph = train_graphs[0]

print("=" * 70)
print("GRAPH STRUCTURE ANALYSIS")
print("=" * 70)

# ============================================
# FULL GRAPH SETUP
# ============================================
print("\n--- Full Graph ---")
station_count = full_graph.station_id.shape[0]  # node count (one station_id per node)
print(f"Total stations: {station_count}")
for split_caption, split_mask in (
    ("Train nodes", full_graph.train_mask),
    ("Val nodes", full_graph.val_mask),
    ("Test nodes", full_graph.test_mask),
):
    print(f"{split_caption}: {split_mask.sum().item()}")

# NetworkX view of the full graph; edge_index already uses global node indices.
G_full = nx.Graph()
G_full.add_nodes_from(range(station_count))
G_full.add_edges_from(full_graph.edge_index.numpy().T)

print(f"Full graph edges: {G_full.number_of_edges()}")

# ============================================
# TRAIN GRAPH SETUP (CORRECTED)
# ============================================
print("\n--- Train Graph ---")

# train_graph is indexed locally (0 .. num_train_nodes-1). train_graph.orig_id, when
# present, maps every local index back to its global index in full_graph.
# NOTE(review): x is presumably (time, nodes, features), so dim 1 counts nodes — confirm.
num_train_nodes = train_graph.x.shape[1]
print(f"Train graph nodes: {num_train_nodes}")

if not hasattr(train_graph, "orig_id"):
    print("ERROR: train_graph missing 'orig_id' attribute!")
    # Fallback: rebuild local -> global from the full-graph train OR val masks.
    # NOTE(review): only valid if the subgraph keeps nodes in ascending global order.
    orig_ids = np.where(full_graph.train_mask.numpy() | full_graph.val_mask.numpy())[0]
    print(f"Reconstructed orig_id from masks: {len(orig_ids)} mappings")
else:
    orig_ids = train_graph.orig_id.numpy()
    print(f"orig_id mapping exists: {len(orig_ids)} mappings")

# Reverse lookup: global index -> local index.
global_to_local = {int(global_id): local_id for local_id, global_id in enumerate(orig_ids)}

print(f"Train graph local indices: 0 to {num_train_nodes - 1}")
print(f"Train graph global indices: {orig_ids[:5]}... (first 5)")

# NetworkX view of the train graph; its edge_index is in LOCAL indices.
G_train = nx.Graph()
G_train.add_nodes_from(range(num_train_nodes))
G_train.add_edges_from(train_graph.edge_index.numpy().T)

print(f"Train graph edges: {G_train.number_of_edges()}")

# ============================================
# VALIDATION GRAPH (train + val nodes, re-colored)
# ============================================
print("\n--- Validation Graph (train + val nodes) ---")

# The validation graph holds the train + val nodes, indexed locally.
# NOTE(review): x is presumably (time, nodes, features), so dim 1 counts nodes — confirm.
num_valgraph_nodes = validation_graph.x.shape[1]
print(f"Validation graph nodes: {num_valgraph_nodes}")

# Get mapping: local_idx -> global_idx
if hasattr(validation_graph, "orig_id"):
    val_orig_ids = validation_graph.orig_id.numpy()
    print(f"orig_id mapping exists: {len(val_orig_ids)} mappings")
else:
    print("ERROR: validation_graph missing 'orig_id' attribute!")
    # Fallback: this graph contains train AND val nodes, so rebuild the mapping from both
    # masks. A train-only mapping would be shorter than num_valgraph_nodes and break the
    # position lookup further down.
    # NOTE(review): assumes the subgraph keeps nodes in ascending global order.
    val_orig_ids = np.where(
        full_graph.train_mask.numpy() | full_graph.val_mask.numpy()
    )[0]
    print(f"Reconstructed orig_id from masks: {len(val_orig_ids)} mappings")

# NetworkX view of the validation graph; its edges connect train+val nodes only.
G_valgraph = nx.Graph()
G_valgraph.add_nodes_from(range(num_valgraph_nodes))
G_valgraph.add_edges_from(validation_graph.edge_index.numpy().T)

print(f"Validation graph edges: {G_valgraph.number_of_edges()}")

# ============================================
# GENERATE GEOGRAPHICAL LAYOUT
# ============================================
print("\n" + "=" * 70)
print("GENERATING GEOGRAPHICAL POSITIONS")
print("=" * 70)


def _lonlat_of_global_node(global_node, warn=True):
    """Return the (lon, lat) drawing position of full-graph node ``global_node``.

    NetworkX layouts use (x, y), so longitude comes first. A station missing from
    ``weather_station_locations`` is placed at (0, 0), with a warning if ``warn``.
    """
    station_str_id = stations[full_graph.station_id[global_node].item()]
    if station_str_id not in weather_station_locations:
        if warn:
            print(f"WARNING: Station {station_str_id} not found in locations")
        return (0, 0)
    lat, lon = weather_station_locations[station_str_id]
    return (lon, lat)


# Full graph: node ids are already global.
pos_full = {
    node_idx: _lonlat_of_global_node(node_idx) for node_idx in range(station_count)
}
print(f"Full graph positions generated: {len(pos_full)}")

# Train graph: local ids are mapped to global ids through orig_ids.
pos_train = {
    local_idx: _lonlat_of_global_node(int(orig_ids[local_idx]))
    for local_idx in range(num_train_nodes)
}
print(f"Train graph positions generated: {len(pos_train)}")

# Validation graph: same lookup through val_orig_ids; missing stations are placed silently.
pos_valgraph = {
    local_idx: _lonlat_of_global_node(int(val_orig_ids[local_idx]), warn=False)
    for local_idx in range(num_valgraph_nodes)
}
print(f"Validation graph positions generated: {len(pos_valgraph)}")
# ============================================
# CREATE CONSISTENT COLOR MAPS
# ============================================
print("\n" + "=" * 70)
print("CREATING COLOR MAPS")
print("=" * 70)


def _split_color(graph, idx, include_test=True):
    """Return the plot colour of node ``idx`` in ``graph`` from its split masks.

    Masks are checked in priority order: train (green), then val (blue), then test
    (red). Test is only checked when ``include_test`` is set and ``graph`` has a
    ``test_mask``. Nodes in none of the masks are gray.
    """
    if graph.train_mask[idx]:
        return "green"
    if graph.val_mask[idx]:
        return "blue"
    if include_test and hasattr(graph, "test_mask") and graph.test_mask[idx]:
        return "red"
    return "gray"


# --- Full graph: masks and labels use global indices ---
color_map_full = [_split_color(full_graph, node_idx) for node_idx in range(station_count)]
node_labels_full = {node_idx: f"{node_idx}" for node_idx in range(station_count)}
print(
    f"Full graph - Train: {color_map_full.count('green')}, "
    f"Val: {color_map_full.count('blue')}, "
    f"Test: {color_map_full.count('red')}"
)

# --- Train graph: masks use local indices; labels show local(global) ---
color_map_train = [_split_color(train_graph, i) for i in range(num_train_nodes)]
node_labels_train = {i: f"{i}\n({int(orig_ids[i])})" for i in range(num_train_nodes)}
print(
    f"Train graph - Train: {color_map_train.count('green')}, "
    f"Val: {color_map_train.count('blue')}, "
    f"Test: {color_map_train.count('red')}"
)

# --- Validation graph: only train/val membership is coloured ---
color_map_val = [
    _split_color(validation_graph, i, include_test=False)
    for i in range(num_valgraph_nodes)
]
node_labels_val = {
    i: f"{i}\n({int(val_orig_ids[i])})" for i in range(num_valgraph_nodes)
}
print(
    f"Validation graph - Train: {color_map_val.count('green')}, "
    f"Val: {color_map_val.count('blue')}, "
    f"Test: {color_map_val.count('red')}"
)

# ============================================
# VERIFICATION: Check Consistency
# ============================================
print("\n" + "=" * 70)
print("CONSISTENCY VERIFICATION")
print("=" * 70)

# The global ids reached through orig_ids should be exactly the full graph's train+val nodes.
train_val_in_full = set(
    np.where((full_graph.train_mask | full_graph.val_mask).numpy())[0]
)
train_nodes_mapped = set(orig_ids)

if train_val_in_full != train_nodes_mapped:
    print("⚠️ WARNING: Node sets don't match!")
    print(f"   Full graph train+val: {len(train_val_in_full)} nodes")
    print(f"   Train graph (via orig_id): {len(train_nodes_mapped)} nodes")
    print(f"   Difference: {train_val_in_full - train_nodes_mapped}")
else:
    print("✅ Node sets are consistent!")

# Train-coloured nodes in the train graph should match the full-graph train nodes among them.
full_train_count = sum(bool(full_graph.train_mask[gid]) for gid in orig_ids)
train_train_count = color_map_train.count("green")
counts_match = full_train_count == train_train_count

print("\nTrain node count:")
print(f"   Full graph (for train graph nodes): {full_train_count}")
print(f"   Train graph: {train_train_count}")
print(f"   Match: {'✅' if counts_match else '❌'}")

# ============================================
# DRAW THE PLOTS
# ============================================
print("\n" + "=" * 70)
print("DRAWING GRAPHS")
print("=" * 70)

fig, axes = plt.subplots(1, 3, figsize=(30, 10))

# Styling shared by all three panels.
shared_draw_style = dict(
    with_labels=True,
    node_size=400,
    font_size=7,
    font_weight="bold",
    edge_color="gray",
    width=1.5,
)

# (graph, positions, node colours, node labels, title) for each panel, left to right.
panel_specs = [
    (
        G_full,
        pos_full,
        color_map_full,
        node_labels_full,
        f"Full Graph ({station_count} stations)\n"
        f"Train={color_map_full.count('green')}, "
        f"Val={color_map_full.count('blue')}, "
        f"Test={color_map_full.count('red')}",
    ),
    (
        G_train,
        pos_train,
        color_map_train,
        node_labels_train,
        f"Train Graph ({num_train_nodes} stations)\n"
        f"Local(Global) indices shown\n"
        f"Train={color_map_train.count('green')}, "
        f"Val={color_map_train.count('blue')}",
    ),
    (
        G_valgraph,
        pos_valgraph,
        color_map_val,
        node_labels_val,
        "Validation Graph (Train + Val Nodes)",
    ),
]

for ax, (graph, positions, node_colors, node_labels, title) in zip(axes, panel_specs):
    nx.draw(
        graph,
        positions,
        node_color=node_colors,
        labels=node_labels,
        ax=ax,
        **shared_draw_style,
    )
    ax.set_title(title, fontsize=14, fontweight="bold")

# One legend for the whole figure.
legend_elements = [
    Patch(facecolor=swatch, label=caption)
    for swatch, caption in (
        ("green", "Train"),
        ("blue", "Validation"),
        ("red", "Test"),
        ("gray", "Unknown"),
    )
]
fig.legend(handles=legend_elements, loc="upper center", ncol=4, fontsize=12)

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig("graph_comparison_corrected.png", dpi=150, bbox_inches="tight")
print("✅ Graphs saved to 'graph_comparison_corrected.png'")
plt.show()

# ============================================
# DETAILED MAPPING TABLE (for debugging)
# ============================================
print("\n" + "=" * 70)
print("NODE MAPPING TABLE (First 10 nodes)")
print("=" * 70)
print(
    f"{'Local Idx':<12} {'Global Idx':<12} {'Station ID':<15} {'Type in Train':<15} {'Type in Full':<15}"
)
print("-" * 70)

for local_idx in range(min(10, num_train_nodes)):
    global_idx = int(orig_ids[local_idx])
    station_str_id = stations[full_graph.station_id[global_idx].item()]

    # Split membership inside the train graph (local index).
    if train_graph.train_mask[local_idx]:
        type_train = "Train"
    elif train_graph.val_mask[local_idx]:
        type_train = "Val"
    else:
        type_train = "Other"

    # Split membership of the same station in the full graph (global index).
    if full_graph.train_mask[global_idx]:
        type_full = "Train"
    elif full_graph.val_mask[global_idx]:
        type_full = "Val"
    elif full_graph.test_mask[global_idx]:
        type_full = "Test"
    else:
        type_full = "Other"

    print(
        f"{local_idx:<12} {global_idx:<12} {station_str_id:<15} {type_train:<15} {type_full:<15}"
    )

# Only claim consistency when both checks of the verification section actually passed.
checks_passed = (train_val_in_full == train_nodes_mapped) and (
    full_train_count == train_train_count
)
if checks_passed:
    consistency_summary = (
        "✅ Consistency: Node colors and positions are aligned!\n"
        "   - All three graphs use the same geographical layout\n"
        "   - Color coding matches across all visualizations\n"
        "   - orig_id properly maps local -> global indices"
    )
else:
    consistency_summary = (
        "⚠️ Consistency: node sets or train-node counts do not match!\n"
        "   - See the CONSISTENCY VERIFICATION output above"
    )

print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)
print(f"""
✅ Full Graph: {station_count} nodes, {G_full.number_of_edges()} edges
   - Uses GLOBAL indices (0 to {station_count - 1})
   - Shows Train/Val/Test splits

✅ Train Graph: {num_train_nodes} nodes, {G_train.number_of_edges()} edges
   - Uses LOCAL indices (0 to {num_train_nodes - 1})
   - Node set given by orig_id (checked against full-graph Train+Val above)
   - Labels show: Local(Global) index mapping

✅ Validation Graph: {num_valgraph_nodes} nodes, {G_valgraph.number_of_edges()} edges
   - Uses LOCAL indices (0 to {num_valgraph_nodes - 1})
   - Contains only Train+Val nodes from full graph
   - Labels show: Local(Global) index mapping

{consistency_summary}
""")

# Creating the GNN

In [None]:
hidden_channels = 4
in_channels = 1
out_channels = 1
num_layers = 8

# Seed every RNG BEFORE building the models so that weight initialisation is
# reproducible under Restart & Run All. The training-seed cell further down runs only
# after the models already exist, so it cannot make their initial weights deterministic.
model_init_seed = 123
random.seed(model_init_seed)
np.random.seed(model_init_seed)
torch.manual_seed(model_init_seed)
torch.cuda.manual_seed_all(model_init_seed)

# One independently initialised model per fold, all sharing the same hyper-parameters.
gnn_kwargs = dict(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    num_layers=num_layers,
)
model, model1, model2, model3, model4 = (GNNInductive(**gnn_kwargs) for _ in range(5))

device = "cuda" if torch.cuda.is_available() else "cpu"
for fold_model in (model, model1, model2, model3, model4):
    fold_model.to(device=device)

In [None]:
def train_epoch(
    model,
    dataloader,
    optimizer,
    device,
    verbose=False,
    log_file="training_gnn_new_debug.log",
    random_noise_masking=False,
    scheduler=None,
):
    """
    Run one training epoch over PyG-batched graphs, with gradient debugging.

    For each batch, the features of nodes outside ``batch.mask`` are zeroed, the
    model runs on the whole batched graph, and the MSE loss is computed only on
    the masked (trainable) nodes. Gradient norms are inspected after every
    backward pass.

    Args:
        model: Module called as ``model(x, edge_index, edge_attributes=edge_attr)``.
        dataloader: Iterable of PyG ``Batch`` objects exposing ``x``, ``y``,
            ``mask``, ``edge_index``, ``edge_attr`` and ``num_graphs``.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        device: Device every batch is moved to.
        verbose: Print the loss and per-parameter gradient norms for every batch.
        log_file: Currently unused; kept for backward compatibility.
        random_noise_masking: Add Gaussian noise (std 0.1) to the features
            before masking.
        scheduler: Optional LR scheduler, stepped once per batch.

    Returns:
        The mean per-batch training loss as a float. Returns ``float("nan")`` when
        no batch contained trainable nodes, and ``None`` if a backward pass
        produced no gradients at all.
    """
    model.train()
    epoch_losses = []
    charge_bar = tqdm.tqdm(dataloader, desc="training")

    for batch_idx, batch in enumerate(charge_bar):
        optimizer.zero_grad()

        # PyG Batch object - move to device
        batch = batch.to(device)

        x = batch.x  # [B*N, F]
        y = batch.y  # [B*N, Tgt]
        # Cast to bool: with an integer 0/1 mask, `~mask` is a bitwise NOT (-1/-2),
        # and indexing with it gathers rows by position instead of selecting them.
        mask = batch.mask.bool()
        edge_index = batch.edge_index
        edge_attr = getattr(batch, "edge_attr", None)
        num_graphs = batch.num_graphs

        # A per-graph mask of length N is tiled across the B graphs of the batch.
        # A mask that is already [B*N] is used as is.
        if mask.shape[0] == x.shape[0] // num_graphs:
            mask_expanded = mask.repeat(num_graphs)  # [B*N]
        else:
            mask_expanded = mask

        assert mask_expanded.shape[0] == x.shape[0], (
            f"mask_expanded {mask_expanded.shape} != x {x.shape}"
        )

        # Skip batches without any trainable node (the loss would be NaN).
        num_trainable = mask_expanded.sum().item()
        if num_trainable == 0:
            print(f"WARNING: No trainable nodes in batch {batch_idx}!")
            continue

        if random_noise_masking:
            x = x + torch.randn_like(x) * 0.1

        # Hide the features of non-trainable nodes. Only parameters need gradients,
        # so the input is not marked requires_grad (that only wastes autograd work).
        # NOTE(review): the loss nodes keep their own features visible here, while
        # validate() applies no feature masking at all. Confirm that x does not
        # contain the target y; otherwise the label leaks into the input.
        x_masked = x.clone()
        x_masked[~mask_expanded] = 0.0

        out = model(x_masked, edge_index, edge_attributes=edge_attr)

        # Loss ONLY on trainable nodes.
        loss = F.mse_loss(out[mask_expanded], y[mask_expanded])

        if verbose or batch_idx == 0:
            print(f"loss: {loss.item()}, requires_grad: {loss.requires_grad}")

        loss.backward()

        # Global L2 norm over all parameter gradients, for debugging.
        total_grad_norm = 0.0
        num_params_with_grad = 0
        for name, param in model.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                total_grad_norm += grad_norm**2
                num_params_with_grad += 1
                if verbose and grad_norm > 1e-6:
                    print(f"  {name}: grad_norm={grad_norm:.2e}")
            elif verbose or batch_idx == 0:
                print(f"  {name}: NO GRADIENT")

        total_grad_norm = np.sqrt(total_grad_norm)

        if num_params_with_grad == 0:
            print(f"ERROR: No gradients computed in batch {batch_idx}!")
            return None

        if total_grad_norm < 1e-8 and batch_idx % 20 == 0:
            print(
                f"WARNING: Very small gradient norm {total_grad_norm:.2e} in batch {batch_idx}"
            )

        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        epoch_losses.append(loss.item())
        charge_bar.set_postfix(
            {
                "loss": loss.item(),
                "grad_norm": total_grad_norm,
                "trainable": num_trainable,
            }
        )

    # Explicit NaN instead of np.mean([]) (which warns) when every batch was skipped.
    if not epoch_losses:
        print("WARNING: no batch contained trainable nodes; returning NaN.")
        return float("nan")
    return float(np.mean(epoch_losses))

In [None]:
def validate(
    model, dataloader, device, verbose=False, log_file="validation_gnn_new_debug.log"
):
    """
    Validation loop for PyG batched graph data (inductive setting).

    Key aspects:
    1. Data comes as PyG Batch objects with features ``x`` of shape [B*N, F].
    2. ``batch.mask`` may be per-graph ([N]) or already batched ([B*N]). A
       per-graph mask is tiled across the B graphs, exactly as in ``train_epoch``.
    3. The loss is computed ONLY on masked (validation) nodes.
    4. No gradients are computed; the model runs in eval mode.

    Args:
        model: Module called as ``model(x, edge_index, edge_attributes=edge_attr)``.
        dataloader: Iterable of PyG ``Batch`` objects exposing ``x``, ``y``,
            ``mask``, ``edge_index``, ``edge_attr`` and ``num_graphs``.
        device: Device every batch is moved to.
        verbose: Currently unused; kept for backward compatibility.
        log_file: Currently unused; kept for backward compatibility.

    Returns:
        The MSE over all masked nodes of all batches, as a float. Each node is
        weighted equally, so a short final batch does not count as much as a full
        one, and a batch without validation nodes does not turn the result into
        NaN. Returns ``float("nan")`` when there are no masked nodes at all.
    """
    model.eval()
    all_preds = []
    all_targets = []

    charge_bar = tqdm.tqdm(dataloader, desc="validation")

    with torch.no_grad():
        for batch in charge_bar:
            batch = batch.to(device)

            x = batch.x  # [B*N, F]
            y = batch.y  # [B*N, Tgt]
            # Bool cast so `out[mask]` selects rows even for integer 0/1 masks.
            mask = batch.mask.bool()
            edge_index = batch.edge_index  # [2, E*B] - offset edge indices
            edge_attr = getattr(batch, "edge_attr", None)
            num_graphs = batch.num_graphs

            # Tile a per-graph [N] mask across the batch (same rule as train_epoch).
            if mask.shape[0] == x.shape[0] // num_graphs:
                mask = mask.repeat(num_graphs)

            assert mask.shape[0] == x.shape[0], (
                f"mask size {mask.shape[0]} != x size {x.shape[0]}"
            )

            out = model(x, edge_index, edge_attributes=edge_attr)  # [B*N, out_channels]

            preds = out[mask]
            targets = y[mask]
            all_preds.append(preds.detach().cpu())
            all_targets.append(targets.detach().cpu())

            # Per-batch loss is only shown on the progress bar.
            if preds.numel() > 0:
                charge_bar.set_postfix({"loss": F.mse_loss(preds, targets).item()})

    if not all_preds:
        return float("nan")

    all_preds = torch.cat(all_preds, dim=0)  # [Total_val_nodes, out_channels]
    all_targets = torch.cat(all_targets, dim=0)  # [Total_val_nodes, out_channels]
    if all_preds.numel() == 0:
        return float("nan")

    # Node-weighted MSE over the whole validation set.
    return float(F.mse_loss(all_preds, all_targets).item())


In [None]:
# Seed every RNG used during training (Python, NumPy, PyTorch CPU and CUDA).
seed = 123
for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seeder(seed)
perf.log_model_config(model.config)

batch_size = 16
# One (train, validation) loader pair per fold, built from that fold's subgraphs.
fold_loader_pairs = [
    prepare_homogeneous_inductive_dataset(
        fold_train_graph, fold_val_graph, batch_size=batch_size, mode="train"
    )
    for fold_train_graph, fold_val_graph in zip(train_graphs, validation_graphs)
]
(
    (train_loader, val_loader),
    (train_loader1, val_loader1),
    (train_loader2, val_loader2),
    (train_loader3, val_loader3),
    (train_loader4, val_loader4),
) = fold_loader_pairs

# Inspect fold 0's loaders.
debug_dataloader(train_loader)
debug_dataloader(val_loader)


def train(model, train_loader, val_loader, fold, device="cpu"):
    """
    Train ``model`` for one cross-validation fold, with early stopping.

    Runs up to ``total_epochs`` epochs of ``train_epoch`` followed by ``validate``.
    Each epoch is logged through the module-level ``perf`` logger. Training stops
    after ``stopping_condition`` consecutive epochs without a validation-loss
    improvement, or as soon as ``train_epoch`` reports that no gradients were
    produced. The weights with the lowest validation loss are restored into
    ``model`` and saved to ``experiments/<experiment_name>/weather_gnn_best_<fold>.pth``.
    The loss curves are saved next to them.

    Args:
        model: The GNN to train (already on ``device``).
        train_loader: Loader of training batches.
        val_loader: Loader of validation batches.
        fold: Fold index, used in log output and file names.
        device: Device the batches are moved to.

    Returns:
        ``model`` with its best-validation weights loaded.
    """
    # CHECK 1: Print initial weights
    first_param = next(model.parameters())
    print(f"Initial weight sample: {first_param.data.flatten()[:5]}")

    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
    training_loss_arr = []
    validation_loss_arr = []
    early = 0
    mini = float("inf")  # best validation loss so far (no arbitrary 1000 sentinel)
    best_state = None  # snapshot of the weights that achieved `mini`
    stopping_condition = 5
    epochs = 0
    total_epochs = 10
    print(f"-----FOLD: {fold}-----")
    training_start = time.time()
    for i in range(total_epochs):
        print(f"-----EPOCH: {i + 1}-----")

        # CHECK 2: Snapshot a weight before training
        weight_before = first_param.data.clone()

        train_loss = train_epoch(
            model,
            train_loader,
            optimizer,
            device,
            verbose=False,
            random_noise_masking=False,
        )
        if train_loss is None:
            # train_epoch returns None when a backward pass produced no gradients.
            print("ERROR: training produced no gradients; stopping this fold.")
            break

        # CHECK 3: How much did that weight move?
        weight_change = (first_param.data - weight_before).abs().mean().item()
        print(f"Weight change: {weight_change:.20f}")

        validation_loss = validate(model, val_loader, device)
        training_loss_arr.append(train_loss)
        validation_loss_arr.append(validation_loss)
        perf.log_epoch(i, train_loss, validation_loss)
        epochs += 1

        if validation_loss <= mini:
            mini = validation_loss
            early = 0
            best_state = {
                key: value.detach().clone() for key, value in model.state_dict().items()
            }
        else:
            early += 1

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Validation Loss: {validation_loss:.4f}")

        # CHECK 4: Gradient norm left over from the last batch of the epoch
        total_norm = 0
        for p in model.parameters():
            if p.grad is not None:
                total_norm += p.grad.data.norm(2).item() ** 2
        total_norm = total_norm**0.5
        print(f"Gradient norm: {total_norm:.6f}")

        if early >= stopping_condition:
            print("Early stop loss")
            break

    training_end = time.time()
    total_time = training_end - training_start
    perf.finalise(total_time)

    print(f"Training took {total_time} seconds over {epochs} epochs")
    fig, ax = plt.subplots()
    ax.plot(training_loss_arr, label="training_loss", color="blue")
    ax.plot(validation_loss_arr, label="validation_loss", color="red")
    ax.set(title=f"Fold {fold} loss", xlabel="epoch", ylabel="MSE loss")
    ax.legend()
    fig.savefig(f"experiments/{experiment_name}/train_loss_plot_{fold}.png", dpi=300)
    plt.close(fig)

    # Restore the best-validation weights so the returned model matches the checkpoint.
    if best_state is not None:
        model.load_state_dict(best_state)
    checkpoint_path = f"experiments/{experiment_name}/weather_gnn_best_{fold}.pth"
    torch.save(model.state_dict(), checkpoint_path)
    print(f"✅ model weights saved to {checkpoint_path}")

    perf.log_model_parameters(model)
    return model


# Train the five fold models in order (fold 0 .. 4), each on its own loaders.
fold_runs = zip(
    (model, model1, model2, model3, model4),
    (train_loader, train_loader1, train_loader2, train_loader3, train_loader4),
    (val_loader, val_loader1, val_loader2, val_loader3, val_loader4),
)
model, model1, model2, model3, model4 = (
    train(fold_model, fold_train_loader, fold_val_loader, fold=fold, device=device)
    for fold, (fold_model, fold_train_loader, fold_val_loader) in enumerate(fold_runs)
)


In [None]:
# print(next(iter(val_loader))["gen_x"].shape)

In [None]:
# total_params = sum(param.numel() for param in model.parameters())
# print(total_params)
# print(list(param for param in model.parameters()))


In [None]:
def _collect_masked_predictions(model, dataloader, device, verbose=False):
    """Run ``model`` over ``dataloader`` and return (preds, targets) for masked nodes only.

    Predictions and targets are concatenated across batches and moved to the CPU.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()

    all_preds = []
    all_targets = []
    test_bar = tqdm.tqdm(dataloader, desc="Testing")

    with torch.no_grad():
        for batch in test_bar:
            batch = batch.to(device)

            # ----- Extract standard PyG Batch -----
            x = batch.x  # [B*N, F]
            y = batch.y  # [B*N, target_dim]
            mask = batch.mask  # [B*N], boolean
            edge_index = batch.edge_index
            edge_attr = batch.edge_attr  # None when the batch carries no edge features

            if verbose:
                print(f"x shape: {x.shape}")
                print(f"y shape: {y.shape}")
                print(f"mask shape: {mask.shape}")
                print(f"edge_index shape: {edge_index.shape}")

            assert mask.shape[0] == x.shape[0], (
                f"mask size {mask.shape[0]} != x size {x.shape[0]}"
            )

            # ----- Forward pass; keep only the held-out (masked) nodes -----
            out = model(x, edge_index, edge_attributes=edge_attr)  # [B*N, out_channels]
            preds = out[mask]
            targets = y[mask]

            # Per-batch MSE is only shown in the progress bar; the reported
            # metrics are computed once over all batches below.
            test_bar.set_postfix({"loss": F.mse_loss(preds, targets).item()})

            all_preds.append(preds.cpu())
            all_targets.append(targets.cpu())

    if not all_preds:
        raise ValueError("dataloader yielded no batches; nothing to evaluate")

    return torch.cat(all_preds, dim=0), torch.cat(all_targets, dim=0)


def _save_test_scatter_plot(targets_np, preds_np, pearson_r, rmse, path, fold):
    """Save an actual-vs-predicted scatter plot with a y = x reference line to ``path``."""
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.scatter(targets_np, preds_np, alpha=0.5)

    line_max = max(np.nanmax(preds_np), np.nanmax(targets_np))
    ax.plot([0, line_max], [0, line_max], "r--")

    ax.set_xlabel("Actual")
    ax.set_ylabel("Predicted")
    ax.set_title(f"Test Set Performance (fold {fold})")
    ax.grid(True)
    ax.text(
        0.05,
        0.95,
        f"Pearson r = {pearson_r:.3f}\nRMSE = {rmse:.3f}",
        transform=ax.transAxes,
        verticalalignment="top",
        bbox=dict(facecolor="white", alpha=0.7, edgecolor="black"),
    )

    fig.savefig(path, dpi=300)
    plt.close(fig)


def _save_series_plots(all_preds, all_targets, out_dir, fold):
    """Save one predicted-vs-actual series figure per output dimension into ``out_dir``."""
    # [num_test_nodes] or [num_test_nodes, target_dim] -> [num_test_nodes, target_dim]
    preds_2d = all_preds.reshape(all_preds.shape[0], -1).numpy()
    targets_2d = all_targets.reshape(all_targets.shape[0], -1).numpy()

    # Rows are masked nodes in dataloader order; stations repeat across batches.
    # Splitting by station would require a station index stored in the batch.
    for dim in range(preds_2d.shape[1]):
        fig, (ax_pred, ax_true) = plt.subplots(2, 1, figsize=(15, 8), sharex=True)
        ax_pred.plot(preds_2d[:, dim])
        ax_pred.set_title(f"Predicted (output {dim}, fold {fold})")
        ax_true.plot(targets_2d[:, dim])
        ax_true.set_title(f"Actual (output {dim}, fold {fold})")
        ax_true.set_xlabel("Test node (dataloader order)")
        # File name pattern kept from earlier runs for compatibility.
        fig.savefig(f"{out_dir}/station_{dim}_preds_actual_plot_{fold}.png")
        plt.close(fig)


def test_model(model, dataloader, device, fold=0, verbose=False):
    """
    Evaluate ``model`` on the masked (held-out) nodes of every batch in ``dataloader``.

    Uses the same batch conventions as validate():
      - PyG batch format
      - x, y shaped [B*N, F]
      - mask shaped [B*N]
      - edge_index expanded & offset automatically by PyG
      - metrics are computed ONLY on test nodes (mask == True)

    Side effects: prints the metrics and writes a scatter plot plus one
    prediction/target series plot per output dimension to
    ``experiments/{experiment_name}/`` (``experiment_name`` is a notebook global).

    Args:
        model: GNN called as ``model(x, edge_index, edge_attributes=...)``.
        dataloader: iterable of PyG batches exposing ``x``, ``y``, ``mask``,
            ``edge_index`` and ``edge_attr``.
        device: device each batch is moved to before the forward pass.
        fold: fold index, used only in printed text and output file names.
        verbose: if True, print tensor shapes for every batch.

    Returns:
        float: RMSE over all masked test values whose prediction and target are
        both non-NaN (the same pairs used for the Pearson correlation).

    Raises:
        ValueError: if ``dataloader`` yields no batches, or (from ``pearsonr``)
            if fewer than two valid prediction/target pairs remain.
    """
    all_preds, all_targets = _collect_masked_predictions(
        model, dataloader, device, verbose=verbose
    )

    print("Final aggregated prediction shape:", all_preds.shape)
    print("Final aggregated target shape:", all_targets.shape)

    # Flatten before differencing so a [M, 1] prediction and a [M] target are
    # compared element-wise instead of broadcasting to [M, M].
    preds_np = all_preds.numpy().ravel()
    targets_np = all_targets.numpy().ravel()

    # Both metrics use the same valid (non-NaN) pairs so they are comparable.
    valid = ~(np.isnan(preds_np) | np.isnan(targets_np))
    n_invalid = int(valid.size - valid.sum())
    if n_invalid:
        print(f"Ignoring {n_invalid} NaN prediction/target pairs in the metrics")

    pearson_r, _ = pearsonr(targets_np[valid], preds_np[valid])
    rmse = float(np.sqrt(np.mean((preds_np[valid] - targets_np[valid]) ** 2)))

    print(f"Pearson correlation (Test Nodes): {pearson_r}")
    print(f"Final Test RMSE: {rmse}")

    out_dir = f"experiments/{experiment_name}"
    scatter_path = f"{out_dir}/test_scatter_plot_{fold}.png"
    _save_test_scatter_plot(targets_np, preds_np, pearson_r, rmse, scatter_path, fold)
    print(f"Saved test scatter plot to {scatter_path}")

    _save_series_plots(all_preds, all_targets, out_dir, fold)

    return rmse


In [None]:
# Build the held-out test loader for every fold and score the model trained on it.
fold_models = [model, model1, model2, model3, model4]
test_loaders = [
    prepare_homogeneous_inductive_dataset(
        train_graphs[fold_idx],
        validation_graphs[fold_idx],
        full_graphs[fold_idx],
        batch_size=batch_size,
        mode="test",
    )
    for fold_idx in range(len(fold_models))
]
# Keep the per-fold names that the rest of the notebook may refer to.
test_loader, test_loader1, test_loader2, test_loader3, test_loader4 = test_loaders

fold_rmses = [
    test_model(fold_model, fold_loader, device, fold=fold_idx)
    for fold_idx, (fold_model, fold_loader) in enumerate(zip(fold_models, test_loaders))
]
RMSE, RMSE1, RMSE2, RMSE3, RMSE4 = fold_rmses

for fold_idx, fold_rmse in enumerate(fold_rmses):
    suffix = "" if fold_idx == 0 else str(fold_idx)
    print(f"TEST RMSE{suffix}: {fold_rmse}")

# Cross-validation summary (population standard deviation across folds).
print(
    f"TEST RMSE over {len(fold_rmses)} folds: "
    f"mean = {np.mean(fold_rmses):.4f}, std = {np.std(fold_rmses):.4f}"
)

# Visualisation of output
The test event spans 02-05-2025, 04:15 to 06:15.


In [None]:
def visualize_one_event(test_event_data, radar_features_event, do_plot=True):
    """
    Build node features for one event window and run ``model`` on them.

    ``test_event_data`` is a slice of the pivoted station frame (e.g.
    ``weather_station_df_pivot.iloc[593:602]``) whose columns carry station ids
    on MultiIndex level 1. Each station's series is linearly interpolated and
    then forward/back filled. Only the LAST timestep of the window is used as
    that node's features.

    Relies on notebook globals ``model``, ``data``, ``general_station`` and
    ``rainfall_station``.

    NOTE(review): ``model`` is called here with per-node-type dicts
    (``x_dict`` / ``edge_index_dict`` / ``edge_attr_dict``), i.e. a
    heterogeneous interface. The models trained in this notebook are called as
    ``model(x, edge_index, edge_attributes=...)`` on homogeneous batches, and
    ``data`` is created as a homogeneous ``Data()``. Confirm which model and
    graph this function is meant for before relying on it.

    Args:
        test_event_data: pandas DataFrame slice for the event window.
        radar_features_event: indexable sequence of radar feature tensors, one
            per timestep; the last entry becomes the radar-grid node features.
        do_plot: currently unused; kept for backward compatibility.

    Returns:
        gen_out: numpy array of model outputs for the general-station nodes,
            in ``general_station`` order.
        rain_out: numpy array of model outputs for the rainfall-station nodes,
            in ``rainfall_station`` order.
    """
    model.eval()
    # Run on whatever device the weights live on, so inputs and model match.
    device = next(model.parameters()).device

    # --- collect each station's gap-filled time series ---
    general_series = {}
    rainfall_series = {}
    for station in test_event_data.columns.get_level_values(1).unique():
        station_cols = (
            test_event_data.xs(station, level=1, axis=1)
            .interpolate(method="linear")
            .ffill()
            .bfill()
        )
        if station in general_station:
            general_series[station] = station_cols.values  # [T, gen_feat]
        else:
            # Only the first column is kept for rain-only stations
            # (presumably the rainfall value — confirm column order).
            rainfall_series[station] = station_cols.values[:, 0:1]  # [T, 1]

    # Stack in the global station order so row i lines up with node i.
    gen_arr = np.array([general_series[s] for s in general_station])  # [N_gen, T, Fg]
    rain_arr = np.array([rainfall_series[s] for s in rainfall_station])  # [N_rain, T, Fr]

    # Use the last timestep of the window as the per-node feature vector.
    gen_node_feats = gen_arr[:, -1, :].astype(np.float32)  # [N_gen, Fg]
    rain_node_feats = rain_arr[:, -1, :].astype(np.float32)  # [N_rain, Fr]
    radar_node_feats = radar_features_event[-1].float()

    x_dict = {
        "general_station": torch.as_tensor(
            gen_node_feats, dtype=torch.float, device=device
        ),
        "rainfall_station": torch.as_tensor(
            rain_node_feats, dtype=torch.float, device=device
        ),
        # Already a float tensor: move it instead of copy-constructing it.
        "radar_grid": radar_node_feats.to(device),
    }

    edge_index_dict = {k: v.to(device) for k, v in data.edge_index_dict.items()}
    edge_attr_dict = {k: v.to(device) for k, v in data.edge_attr_dict.items()}

    with torch.no_grad():
        out = model(x_dict, edge_index_dict, edge_attr_dict)

    gen_out = out["general_station"].cpu().numpy()
    rain_out = out["rainfall_station"].cpu().numpy()

    return gen_out, rain_out


In [None]:
# Row window of the test event: 9 consecutive timestamps.
# NOTE(review): `data["radar_grid"]` is heterogeneous-style access, but `data`
# is created as a homogeneous `Data()` at the top of the notebook — confirm.
EVENT_ROWS = slice(593, 602)

test_event_data = weather_station_df_pivot.iloc[EVENT_ROWS]
radar_features_event = data["radar_grid"].x[EVENT_ROWS]

gen_out, rain_out = visualize_one_event(test_event_data, radar_features_event)
out_np = np.concatenate((gen_out, rain_out), axis=0)

In [None]:
# `test_data` only exists inside visualize_one_event (as a clone of `data`), so
# print the template graph's edges directly; this also works on a fresh kernel.
print(data.edge_index_dict)

# Visualise rain on radar grid
Hard-coded to plot only 9 consecutive timestamps.

In [None]:
# NOTE(review): the divisor 12 is unexplained here — possibly a per-hour to
# per-5-minute conversion; confirm against the units of the model targets.
print(np.divide(out_np, 12))

# Visualize Radar Image

In [None]:
# Location of the Singapore radar data passed to load_radar_dataset.
RADAR_DATA_DIR = "sg_radar_data"

radar_df = load_radar_dataset(RADAR_DATA_DIR)
visualize_one_radar_image(radar_df=radar_df)

In [None]:
# fig, ax = plt.subplots(
#     3, 3, figsize=(15, 12), subplot_kw={"projection": ccrs.PlateCarree()}
# )

# out_np = out_np / 12
# for idx, timestamp in enumerate(out_np):
#     output = {}
#     count = 0

#     for stn in general_station:
#         output[stn] = float(timestamp[count])
#         count += 1
#     for stn in rainfall_station:
#         output[stn] = float(timestamp[count])
#         count += 1
#     axi = ax[idx // 3][idx % 3]
#     node_df = pd.Series(output)
#     node_df = pandas_to_geodataframe(node_df)
#     visualise_gauge_grid(node_df=node_df, ax=axi)
#     improved_visualise_radar_grid(
#         radar_df.iloc[idx], ax=axi, zoom=bounds_singapore, norm=norm
#     )
#     visualise_singapore_outline(ax=axi)

In [None]:
# Observed rain rates for the event: first reading in each 15-minute bin.
# NOTE(review): rows 1773:1797 differ from the 593:602 window fed to the model
# above — confirm both slices describe the same event.
_event_window = weather_station_df_pivot.iloc[1773:1797]
original_rainfall_rates = _event_window.resample("15min").first()["rain_rate"]

print(original_rainfall_rates)

In [None]:
# NOTE(review): `out` is never assigned at top level in the visible part of this
# notebook (it is a local inside visualize_one_event and test_model), so this cell
# only works with leftover kernel state — confirm which array is meant (possibly `out_np`).
print(out)

In [None]:
# Compare per-station predictions with the observed rain rates for the event.
# NOTE(review): `out` is not assigned at top level in the visible part of this
# notebook (it is a local inside visualize_one_event / test_model), so this cell
# depends on leftover kernel state — confirm which array is meant. Each row of
# `out` is treated as one timestamp with one column per station, ordered general
# stations first, then rainfall stations.
stations_in_order = list(general_station) + list(rainfall_station)

actual_arr = []
pred_arr = []
for idx, timestamp in enumerate(out):
    # Column position -> station id (a repeated station keeps its last column).
    output = {stn: float(timestamp[col]) for col, stn in enumerate(stations_in_order)}
    observed = original_rainfall_rates.iloc[idx]
    actual_arr.append([float(observed[stn]) for stn in output])
    pred_arr.append(list(output.values()))

actual_arr = np.array(actual_arr)
pred_arr = np.array(pred_arr)

print(actual_arr)
print(pred_arr)

# Square BEFORE averaging: squaring the mean error would give the squared bias,
# not the mean squared error. NaN gauges are ignored via nanmean.
squared_error = (actual_arr - pred_arr) ** 2
error = np.nanmean(squared_error, axis=1)  # per-timestamp MSE
MSE = np.nanmean(squared_error)  # overall MSE across all valid (timestamp, station) pairs
print(MSE)


In [None]:
# Observed rates for the first resampled interval of the event.
first_interval_rates = original_rainfall_rates.iloc[0]
print(first_interval_rates)