In [None]:
%load_ext autoreload
%autoreload 2

from src.sampling.main import stratified_spatial_sampling_dual
from dataset.weather_graph_dataset import WeatherGraphDatasetWithRadarNew

import torch
from torch_geometric.data import HeteroData
from src.raingauge.utils import (
    get_station_coordinate_mappings,
    load_weather_station_dataset,
)
import pandas as pd
import numpy as np
import networkx as nx
from sklearn.neighbors import NearestNeighbors
import tqdm
import torch.nn.functional as F
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import time
from scipy.stats import pearsonr
from src.radar.utils import load_radar_dataset
from src.visualization.main import pandas_to_geodataframe, visualise_singapore_outline
from src.visualization.radar import (
    improved_visualise_radar_grid,
    visualize_one_radar_image_with_cropping,
    visualize_one_radar_image,
)
from src.visualization.raingauge import visualise_gauge_grid
import cartopy.crs as ccrs
import matplotlib as mpl
from src.radar.preprocessor import RadarPreprocessor
from scipy.spatial import cKDTree
from src.miscellaneous import get_straight_distance
from models.gnn_radar import HeteroGNN_WithRadar
import logging
from datetime import datetime

In [None]:
# Geographic extent of Singapore: "left"/"right" are longitudes and
# "top"/"bottom" are latitudes, all in degrees.
bounds_singapore = {
    "left": 103.6,
    "right": 104.1,
    "top": 1.5,
    "bottom": 1.188,
}

# Bin edges of the discrete colour scale. BoundaryNorm maps the bins onto a
# 256-entry colormap; extend="both" reserves colours for values that fall
# below the first or above the last edge.
# NOTE(review): the edges look like rain-rate thresholds -- confirm the units.
bounds = [0.1, 0.2, 0.5, 1, 2, 4, 7, 10, 20]
norm = mpl.colors.BoundaryNorm(
    boundaries=bounds,
    ncolors=256,
    extend="both",
)

# Preprocess Radar Data

## Find the timestamps common to the radar and weather-station data

In [None]:
print("=== Radar Image Preprocessing ===\n")

# Load the weather-station observations and the station-location table, then
# build `weather_station_df_pivot` (15-minute resampled, one column per
# variable/station pair) which the radar preprocessor needs.
try:
    weather_station_data = load_weather_station_dataset("weather_station_data.csv")
    weather_station_locations_raw = pd.read_csv(
        "database/weather_stations.csv"
    )

    # Keep only the location rows whose gid appears in the observation data.
    weather_stations_df = weather_station_locations_raw[
        weather_station_locations_raw["gid"].isin(weather_station_data)
    ].copy()
    weather_stations_df = (
        weather_stations_df.set_index("gid").loc[weather_station_data].reset_index()
    )
    # Location table indexed by gid; later cells slice its latitude/longitude.
    weather_station_locations = weather_station_locations_raw.set_index("gid").loc[
        weather_station_data
    ]

    # NOTE(review): `weather_stations_df` is derived from the station-location
    # table -- confirm it really carries `time_sgt` and the observation
    # columns, otherwise `cols.remove` / `pd.pivot` below will raise.
    cols = list(weather_stations_df.columns)
    cols.remove("time_sgt")
    cols.remove("gid")

    weather_station_df_pivot = (
        pd.pivot(data=weather_stations_df, index="time_sgt", columns="gid", values=cols)
        .resample("15min")
        .first()
    )

    print(f"Loaded weather station data: {weather_stations_df.shape[0]} timestamps")

except Exception as e:
    print(f"Error loading weather station data: {e}")
    print("Please ensure weather_station_data.csv is available")
    # Everything below needs `weather_station_df_pivot`; continuing would only
    # fail later with a confusing NameError, so surface the real error here.
    raise

# Initialize preprocessor
preprocessor = RadarPreprocessor(
    radar_base_path="database/sg_radar_data",
    output_path="database/sg_radar_data_cropped",
    weather_station_df=weather_station_df_pivot,
)

# Step 1: Get all radar files
print("\nStep 1: Scanning radar files...")
radar_files = preprocessor.get_all_radar_files()

# Stop here when nothing was found: the date-range print below indexes
# radar_files[0] / radar_files[-1] and would raise an IndexError otherwise.
if len(radar_files) == 0:
    raise ValueError("No radar files found! Please check the radar_base_path.")

print(f"Found {len(radar_files)} radar files")
print(f"Date range: {radar_files[0][0]} to {radar_files[-1][0]}")

# Step 2: Match with weather data
print("\nStep 2: Matching with weather station data...")
matched_radar_df, matched_weather_df = preprocessor.match_with_weather_data(radar_files)

if len(matched_radar_df) == 0:
    raise ValueError("No matching timestamps found!")

# Step 3: Process and crop all matched files
print("\nStep 3: Cropping radar images to Singapore bounds...")
results_df = preprocessor.process_all_matched_files(matched_radar_df)

# Step 4: Save metadata
preprocessor.save_metadata(results_df)

# Save matched weather dataframe next to the cropped radar images
matched_weather_path = preprocessor.output_path / "matched_weather_station_data.csv"
matched_weather_df.to_csv(matched_weather_path)
print(f"Matched weather station data saved to: {matched_weather_path}")

print("\n=== Preprocessing Complete ===")
print(f"Cropped radar images saved to: {preprocessor.output_path}")

In [None]:
# Load the cropped radar dataset and plot one radar image as a visual sanity check.
# NOTE(review): the preprocessing cell writes to "database/sg_radar_data_cropped"
# while this passes "sg_radar_data_cropped" -- presumably load_radar_dataset
# prepends the base directory; confirm.
radar_df = load_radar_dataset("sg_radar_data_cropped", cropped=True)
visualize_one_radar_image(radar_df=radar_df, n=1)

# Preprocess Radar Station Data

In [None]:
def _collect_radar_station_edges(neighbour_lists, grid_coords, station_coords, radius_km):
    """Turn KD-tree candidate lists into radar-cell -> station edges.

    `neighbour_lists[g]` holds the candidate station row indices for radar
    cell `g` (the output of `cKDTree.query_ball_point`). Each candidate is
    re-checked with `get_straight_distance` and kept only when that distance
    is <= `radius_km`.

    Returns:
        (src, dst, distances): radar-cell indices, station indices and the
        matching one-element `[distance]` lists, all in the same order.
    """
    src, dst, edge_distances = [], [], []
    for grid_idx, station_list in enumerate(neighbour_lists):
        for station_idx in station_list:
            # Same distance function as the station-to-station edges.
            dist = get_straight_distance(
                grid_coords[grid_idx],  # [lat, lon]
                station_coords[station_idx],  # [lat, lon]
            )
            if dist <= radius_km:
                src.append(grid_idx)
                dst.append(station_idx)
                edge_distances.append([dist])
    return src, dst, edge_distances


def _add_radar_station_edges(
    data, station_type, forward_rel, reverse_rel, src, dst, edge_distances, dtype, empty_warning
):
    """Store radar<->station edges (both directions) and their distance attributes.

    When `src` is empty, `empty_warning` is printed and empty `(2, 0)` edge
    indices / `(0, 1)` edge attributes are stored so both edge types still
    exist in `data`.
    """
    forward = ("radar_grid", forward_rel, station_type)
    reverse = (station_type, reverse_rel, "radar_grid")

    if len(src) > 0:
        data[forward].edge_index = torch.tensor(np.array([src, dst]), dtype=torch.long)
        data[reverse].edge_index = torch.tensor(np.array([dst, src]), dtype=torch.long)
        # The reverse edges are emitted in the same order, so the distance
        # list is valid for both directions.
        data[forward].edge_attr = torch.tensor(edge_distances, dtype=dtype)
        data[reverse].edge_attr = torch.tensor(edge_distances, dtype=dtype)
    else:
        print(empty_warning)
        data[forward].edge_index = torch.empty((2, 0), dtype=torch.long)
        data[reverse].edge_index = torch.empty((2, 0), dtype=torch.long)
        data[forward].edge_attr = torch.empty((0, 1), dtype=dtype)
        data[reverse].edge_attr = torch.empty((0, 1), dtype=dtype)


data = HeteroData()
dtype = torch.float32
radius_km = (
    1  # Depends on the radius we want to connect the radar grid to weather stations
)

# ---- 1) prepare station ID list
# NOTE: in-place scaling -- re-running this cell without rebuilding the pivot
# multiplies rain_rate by 12 again.
# NOTE(review): the factor 12 presumably converts a 5-minute accumulation to
# an hourly rate -- confirm.
weather_station_df_pivot["rain_rate"] *= 12
station_ids = sorted({col[1] for col in weather_station_df_pivot.columns})

# ---- 2) count observations per variable/station
station_counts = weather_station_df_pivot.count().reset_index()
weather_station_info = pd.pivot(station_counts, index="gid", columns="level_0")

# ---- 3) prepare radar features
radar_features, grid_coords, grid_shape = preprocessor.prepare_radar_features_temporal(
    radar_df, weather_station_df_pivot
)

# ---- 4) classify rainfall vs general
# A station is "rainfall-only" when at least one variable has zero observations.
rainfall_station = [
    gid for gid, row in weather_station_info.iterrows() if 0 in row.value_counts()
]

# `weather_station_locations` is a DataFrame indexed by gid in this cell (it
# is column-sliced in step 5), so the station ids live in its index;
# `.keys()` would return the column names instead.
_rainfall_lookup = set(rainfall_station)
general_station = [
    s for s in weather_station_locations.index.unique() if s not in _rainfall_lookup
]

# restrict to stations actually present
rainfall_station = [s for s in rainfall_station if s in weather_station_info.index]
general_station = [s for s in general_station if s in weather_station_info.index]

# ---- 5) build coordinate DF (always consistent)
print("Weather station locations loaded: ", weather_station_locations)
loc_df = weather_station_locations[["latitude", "longitude"]].copy()


# ---- 6) flatten MultiIndex cols
df_all = weather_station_info.copy()
df_all.columns = [
    "_".join([str(c) for c in col])  # tuple -> join
    if isinstance(col, tuple)
    else str(col)  # normal -> stringify
    for col in df_all.columns
]

# split
df_rain = df_all.loc[rainfall_station].copy()
df_gen = df_all.loc[general_station].copy()

# ---- 7) attach coordinates (index matches)
df_all = df_all.join(loc_df)
df_rain = df_rain.join(loc_df)
df_gen = df_gen.join(loc_df)

# ---- 8) numpy coordinate arrays
rain_coords = df_rain[["latitude", "longitude"]].to_numpy()
gen_coords = df_gen[["latitude", "longitude"]].to_numpy()
all_coords = df_all[["latitude", "longitude"]].to_numpy()

# Convert radius to degrees
# 111.0 km per degree is approximate conversion for latitude. The KD-tree
# query in degrees is only a coarse pre-filter; the exact check against
# radius_km happens in _collect_radar_station_edges.
radius_deg = radius_km / 111.0

# ---- 9) Radar to general stations
tree_gen = cKDTree(gen_coords)
radar_to_gen_list = tree_gen.query_ball_point(grid_coords, r=radius_deg)
radar_to_gen_src, radar_to_gen_dst, radar_to_gen_distances = _collect_radar_station_edges(
    radar_to_gen_list, grid_coords, gen_coords, radius_km
)

print(f"Radar to general station connections: {len(radar_to_gen_src)}")

# ---- 10) Radar to rainfall stations
tree_rain = cKDTree(rain_coords)
radar_to_rain_list = tree_rain.query_ball_point(grid_coords, r=radius_deg)
radar_to_rain_src, radar_to_rain_dst, radar_to_rain_distances = _collect_radar_station_edges(
    radar_to_rain_list, grid_coords, rain_coords, radius_km
)

print(f"Radar to rainfall station connections: {len(radar_to_rain_src)}")

# TODO: Since radar grids are not prediction targets, it's latent features / spatial context providers, not necessary to have radar to radar edges
# ---- 11) Radar to radar edges (radius-based, not connectivity-based)
# radar_to_radar_edges = preprocessor.create_grid_edges_radius(grid_coords, radius_km)
# print(f"Radar to radar edges: {radar_to_radar_edges.shape[1]} edges created")
# print(f"First 10 radar-radar edges:\n{radar_to_radar_edges[:, :10]}")

# ---- 12) Add radar node features to HeteroData
# x and y are built from the same radar feature array.
data["radar_grid"].x = torch.tensor(radar_features, dtype=dtype)
data["radar_grid"].y = torch.tensor(radar_features, dtype=dtype)

# ---- 13) Add edges (with empty array handling)
# Radar to general stations
_add_radar_station_edges(
    data,
    "general_station",
    "radar_to_gen",
    "gen_to_radar",
    radar_to_gen_src,
    radar_to_gen_dst,
    radar_to_gen_distances,
    dtype,
    "WARNING: No radar to general station edges found within radius",
)

# Radar to rainfall stations
_add_radar_station_edges(
    data,
    "rainfall_station",
    "radar_to_rain",
    "rain_to_radar",
    radar_to_rain_src,
    radar_to_rain_dst,
    radar_to_rain_distances,
    dtype,
    "WARNING: No radar to rainfall station edges found within radius",
)

# ---- 14) Add masks for radar nodes
# Every radar node is flagged in all three splits.
n_radar_nodes = len(grid_coords)
data["radar_grid"].train_mask = [1] * n_radar_nodes
data["radar_grid"].val_mask = [1] * n_radar_nodes
data["radar_grid"].test_mask = [1] * n_radar_nodes

# ---- 15) Summary
print("\n=== Radar Grid Integration Summary ===")
print(f"Radar grid nodes: {n_radar_nodes}")
print(f"Grid shape: {grid_shape[0]} x {grid_shape[1]}")
print(f"Timesteps: {radar_features.shape[0]}")
print(f"Radar to general station edges: {len(radar_to_gen_src)}")
print(f"Radar to rainfall station edges: {len(radar_to_rain_src)}")
print(f"Station connection radius: {radius_km} km")
print(f"Grid connection radius: {radius_km} km")
print("=" * 40)

print("Radar Features: ", radar_features.shape)


# Preprocess station data.
Some stations record only rainfall, while others record rainfall together with additional variables.
We will split the stations into rainfall stations (rainfall only) and general stations (rainfall plus the additional variables listed below).

Additional info: 
  Windspeed
  Wind Direction
  Temperature
  Relative Humidity

In [None]:
# Load the raw observations and the station -> coordinate mapping.
weather_station_data = load_weather_station_dataset("weather_station_data.csv")
weather_station_locations = get_station_coordinate_mappings()
print(len(weather_station_locations.keys()))
print(len(set(weather_station_data["gid"].values)))

# Every column except the two pivot keys is a measured variable.
cols = list(weather_station_data.columns)
for key_column in ("time_sgt", "gid"):
    cols.remove(key_column)

# One row per 15-minute slot, one column per (variable, station) pair.
weather_station_df_pivot = (
    weather_station_data.pivot(index="time_sgt", columns="gid", values=cols)
    .resample("15min")
    .first()
)
# NOTE(review): the factor 12 presumably converts a 5-minute accumulation to
# an hourly rate -- confirm.
weather_station_df_pivot["rain_rate"] = weather_station_df_pivot["rain_rate"] * 12

# Number of non-missing observations for each (variable, station) pair,
# reshaped to one row per station and one column per variable.
weather_station_df_counts = weather_station_df_pivot.count().reset_index()
weather_station_info = weather_station_df_counts.pivot(index="gid", columns="level_0")

pd.set_option("display.max_rows", None)

# A station with zero observations for at least one variable is treated as
# rainfall-only; every other known station is a general station.
rainfall_station = [
    gid
    for gid, variable_counts in weather_station_info.iterrows()
    if 0 in variable_counts.value_counts()
]
general_station = [
    station for station in weather_station_locations if station not in rainfall_station
]

print(rainfall_station)
print(general_station)

# Count the timesteps in which the NaN-ignoring rain total over all stations is non-zero.
count = sum(
    1
    for _, rain_values in weather_station_df_pivot["rain_rate"].iterrows()
    if np.nansum(rain_values.to_numpy()) != 0
)
print(f"Number of timesteps that contain rain: {count}")
print(f"Total_timesteps = {weather_station_df_pivot.shape[0]}")


In [None]:
general_station_data = {}
rainfall_station_data = {}

# TODO: Temporal Data Leakage - Filling missing values in the training set using all data including validation or test set is wrong.
# Extract and interpolate station data: for each station take its block of
# variable columns, fill interior gaps linearly, then forward/backward fill
# whatever is still missing at the ends.
for station in weather_station_df_pivot.columns.get_level_values(1).unique():
    station_cols = (
        weather_station_df_pivot.xs(station, level=1, axis=1)
        .interpolate(method="linear")
        # .ffill()/.bfill() replace the deprecated fillna(method=...) spelling.
        .ffill()
        .bfill()
    )
    if station in general_station:
        # General stations keep every variable column.
        general_station_data[station] = station_cols.values
    else:
        # Rainfall-only stations keep just the first column.
        # NOTE(review): this assumes the first column is rain_rate -- confirm
        # against the column order of the pivot.
        rainfall_station_data[station] = station_cols.values[:, 0:1]

# NOTE(review): the reason for excluding S108 is not visible here --
# presumably it has coordinates but no usable observations; confirm.
general_station_temp = [stn for stn in general_station if stn != "S108"]
general_station = general_station_temp

# Prepare features in the correct order (same order as the station lists, so
# row i of each feature stack corresponds to station i of that list).
general_station_features = [
    general_station_data[station] for station in general_station
]
rainfall_station_features = [
    rainfall_station_data[station] for station in rainfall_station
]

In [None]:
# Add station features to HeteroData
data["general_station"].x = torch.tensor(
    np.array(general_station_features).transpose(1, 0, 2), dtype=dtype
)
data["rainfall_station"].x = torch.tensor(
    np.array(rainfall_station_features).transpose(1, 0, 2), dtype=dtype
)

# Add station targets
data["general_station"].y = torch.tensor(
    np.array(general_station_features)[:, :, 0:1].transpose(1, 0, 2), dtype=dtype
)
data["rainfall_station"].y = torch.tensor(
    np.array(rainfall_station_features).transpose(1, 0, 2), dtype=dtype
)

print(data)
print("\n=== Station Features Added ===")
print(f"General station features shape: {data['general_station'].x.shape}")
print(f"Rainfall station features shape: {data['rainfall_station'].x.shape}")

# After loading weather_station_df_pivot
print("--- Station Data Stats ---")
print(weather_station_df_pivot.describe())

# After loading radar data
print("\n--- Radar Data Stats ---")
print(f"Radar Min: {radar_features.min()}, Radar Max: {radar_features.max()}, Radar Mean: {radar_features.mean()}")

In [None]:
# Split the stations into train / validation sets; any station that lands in
# neither list is treated as a test station.
split_info = stratified_spatial_sampling_dual(weather_station_locations, seed=1111)
print(split_info)

train_station_ids = split_info["ml"]["train"]
validation_station_ids = split_info["ml"]["validation"]

# Build the three 0/1 masks for each node type, in the same order as the
# station list that was used to stack that node type's features.
for node_type, node_stations in (
    ("general_station", general_station),
    ("rainfall_station", rainfall_station),
):
    train_flags = [int(station in train_station_ids) for station in node_stations]
    val_flags = [int(station in validation_station_ids) for station in node_stations]
    test_flags = [
        int(in_train == 0 and in_val == 0)
        for in_train, in_val in zip(train_flags, val_flags)
    ]

    data[node_type].train_mask = train_flags
    data[node_type].val_mask = val_flags
    data[node_type].test_mask = test_flags

print(data)

print(data)

# Edge generation
Edges are generated from the station locations using a K-nearest-neighbour search.
General stations and rainfall stations are pooled into a single set of coordinates for this search, so a station can be connected to neighbours of either type. The cross-type edges are what link the two node layers together in the graph.

In [None]:
K = 4  # Number of neighbors per node

# General stations first, then rainfall stations: KNN row/column indices refer
# to positions in this combined list.
ids = general_station + rainfall_station
print(f"\nTotal stations for KNN: {len(ids)}")

coordinates = [weather_station_locations[station_id] for station_id in ids]
coords = np.array(coordinates)

# K + 1 neighbours because each query point is normally returned as its own
# nearest neighbour and has to be dropped.
knn = NearestNeighbors(n_neighbors=K + 1, algorithm="ball_tree")
knn.fit(coords)

distances, indices = knn.kneighbors(coords)

G = nx.Graph()

_edge_keys = (
    "rainfall_to_rainfall",
    "rainfall_to_general",
    "general_to_rainfall",
    "general_to_general",
)
edges = {key: [] for key in _edge_keys}
edge_attributes = {key: [] for key in _edge_keys}

# Add station coordinates for nx plotting; pos swaps the two stored
# coordinates (element 1 first, then element 0).
for idx, station in enumerate(ids):
    G.add_node(
        idx,
        pos=(
            weather_station_locations[station][1],
            weather_station_locations[station][0],
        ),
    )

color_map = ["green"] * len(general_station) + ["red"] * len(rainfall_station)

# Position of each station inside its own node type. setdefault keeps the
# first occurrence, matching list.index().
_general_position = {}
for _pos, _station in enumerate(general_station):
    _general_position.setdefault(_station, _pos)
_rainfall_position = {}
for _pos, _station in enumerate(rainfall_station):
    _rainfall_position.setdefault(_station, _pos)


def _station_slot(station):
    """Return (node-type prefix, index within that node type) for a station id."""
    if station in _rainfall_position:
        return "rainfall", _rainfall_position[station]
    return "general", _general_position[station]


# Build edges
for origin, row in enumerate(indices):
    # Use the row number as the origin instead of trusting row[0]: when two
    # stations share coordinates the first returned neighbour may be the other
    # station, which would mislabel the origin and create self-loops.
    neighbours = [int(n) for n in row if n != origin][:K]
    origin_kind, start_id = _station_slot(ids[origin])

    for n in neighbours:
        G.add_edge(origin, n)
        neighbour_kind, end_id = _station_slot(ids[n])
        key = f"{origin_kind}_to_{neighbour_kind}"
        edges[key].append([start_id, end_id])
        edge_attributes[key].append(
            [
                get_straight_distance(
                    weather_station_locations[ids[origin]],
                    weather_station_locations[ids[n]],
                )
            ]
        )

print(f"\nGraph info: {G}")
print(f"Connected components: {len(list(nx.connected_components(G)))}")
nx.draw(G, nx.get_node_attributes(G, 'pos'), node_color = color_map, with_labels=True, font_weight='bold')

# Convert edge lists from [[src, dst], ...] to [[src, ...], [dst, ...]]
for key, pairs in edges.items():
    edges[key] = [[src for src, _ in pairs], [dst for _, dst in pairs]]

# Edge-list key -> HeteroData edge type
_station_relations = {
    "general_to_rainfall": ("general_station", "gen_to_rain", "rainfall_station"),
    "rainfall_to_general": ("rainfall_station", "rain_to_gen", "general_station"),
    "general_to_general": ("general_station", "gen_to_gen", "general_station"),
    "rainfall_to_rainfall": ("rainfall_station", "rain_to_rain", "rainfall_station"),
}

# Add station-to-station edges
for key, edge_type in _station_relations.items():
    data[edge_type].edge_index = torch.tensor(edges[key], dtype=torch.long)

# Add edge attributes. reshape(-1, 1) is a no-op for non-empty lists and
# gives an empty relation the same (0, 1) shape used for radar edges.
for key, edge_type in _station_relations.items():
    data[edge_type].edge_attr = torch.tensor(
        edge_attributes[key], dtype=dtype
    ).reshape(-1, 1)

print("\n=== Station-to-Station Edges Added ===")

print(data)

In [None]:
# ----------------------------------------------------------------------------
# Final summary of the assembled HeteroData object
# ----------------------------------------------------------------------------

banner = "=" * 60

print("\n" + banner)
print("FINAL HETERODATA STRUCTURE")
print(banner)
print(data)
print("\nNode types:", data.node_types)
print("Edge types:", data.edge_types)

general_store = data["general_station"]
rainfall_store = data["rainfall_station"]
radar_store = data["radar_grid"]

print("\n--- Feature Shapes ---")
print(f"General stations: {general_store.x.shape}")
print(f"Rainfall stations: {rainfall_store.x.shape}")
print(f"Radar grid: {radar_store.x.shape}")

# Number of edges (columns of edge_index) per edge type
print("\n--- Edge Counts ---")
for edge_type in data.edge_types:
    edge_count = data[edge_type].edge_index.shape[1]
    print(f"{edge_type}: {edge_count} edges")

# The masks are 0/1 lists, so their sum is the number of stations in a split.
print("\n--- Mask Counts ---")
print(f"General train: {sum(general_store.train_mask)}")
print(f"General val: {sum(general_store.val_mask)}")
print(f"General test: {sum(general_store.test_mask)}")
print(f"Rainfall train: {sum(rainfall_store.train_mask)}")
print(f"Rainfall val: {sum(rainfall_store.val_mask)}")
print(f"Rainfall test: {sum(rainfall_store.test_mask)}")
print(banner)

In [None]:
# Number of rainfall -> rainfall edges collected during KNN edge generation
print(len(edge_attributes["rainfall_to_rainfall"]))

# Inspect the graph and its edge types
print(data)
print(data.edge_types)

gen_to_rain_store = data["general_station", "gen_to_rain", "rainfall_station"]
rain_to_gen_store = data["rainfall_station", "rain_to_gen", "general_station"]
gen_to_gen_store = data["general_station", "gen_to_gen", "general_station"]
rain_to_rain_store = data["rainfall_station", "rain_to_rain", "rainfall_station"]

print(gen_to_rain_store.edge_attr)
print(rain_to_gen_store.edge_index)
print(gen_to_gen_store.edge_index)

# Number of distinct source nodes among the rainfall -> rainfall edges
rain_to_rain_sources = rain_to_rain_store.edge_index.detach().numpy()[0]
print(len(set(rain_to_rain_sources)))

# Structural checks provided by the HeteroData object
print(data.has_isolated_nodes())
print(data.has_self_loops())
print(data.is_undirected())

print(gen_to_rain_store["edge_index"])

# Creating the GNN

In [None]:
# Instantiate the radar-aware heterogeneous GNN.
model = HeteroGNN_WithRadar(
    hidden_channels=8,
    out_channels=1,
    num_layers=5,
)

# Use the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model.to(device=device)

In [None]:
def train_epoch(model, data, dataloader, optimizer, device, verbose=False, log_file="training_radar_debug.log"):
    """Run one training epoch with leave-one-station-out masking.

    For every sample of every batch, each training station is hidden in turn
    (feature column 0 of that station is set to zero), the model is run on the
    masked inputs plus the radar grid, and the MSE between its prediction for
    the hidden station and the target is recorded. The losses of one sample
    are averaged and back-propagated as a single optimizer step.

    Args:
        model: called as ``model(x_dict, edge_index_dict, edge_attr_dict)``;
            must return a dict with "general_station" and "rainfall_station".
        data: unused; kept so existing callers keep working.
        dataloader: iterable of batch dicts with the keys "gen_x", "rain_x",
            "radar_x", "gen_y", "rain_y", "metastation_mask",
            "rainfallstation_mask", "edge_index_dict" and "edge_attr_dict".
        optimizer: optimizer over the model parameters.
        device: device the batch tensors and masks are moved to.
        verbose: when True, per-station debug statistics are appended to
            ``log_file`` through the "train_debug" logger.
        log_file: path of the debug log (only used when ``verbose``).

    Returns:
        float: mean of the per-sample losses, or NaN when no sample produced
        a loss (empty dataloader or no training stations).
    """
    import os  # local import: only needed to normalise the log path

    model.train()
    losses = []

    charge_bar = tqdm.tqdm(dataloader, desc='training')

    # Setup logging if verbose
    logger = None
    if verbose:
        logger = logging.getLogger("train_debug")
        logger.setLevel(logging.INFO)
        # Attach one handler per distinct log file; repeated calls (one per
        # epoch) with the same path must not add duplicates.
        log_path = os.path.abspath(log_file)
        already_attached = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
            for handler in logger.handlers
        )
        if not already_attached:
            fh = logging.FileHandler(log_file, mode="a")
            fh.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
            logger.addHandler(fh)
        logger.info(f"=== Training Epoch Debug Log Started at {datetime.now()} ===")

    def _dict_to_device(tensors):
        # Move tensor values to `device`; anything else is passed through.
        if not isinstance(tensors, dict):
            return tensors
        return {
            key: (val.to(device) if torch.is_tensor(val) else val)
            for key, val in tensors.items()
        }

    for batch in charge_bar:
        edge_index_dict = _dict_to_device(batch["edge_index_dict"])
        edge_attribute_dict = _dict_to_device(batch["edge_attr_dict"])

        # Masks and inputs do not depend on the sample index, so build them
        # once per batch and keep everything on the same device.
        train_metastation_mask = torch.as_tensor(batch['metastation_mask'], dtype=torch.bool).to(device)
        train_rainfallstation_mask = torch.as_tensor(batch['rainfallstation_mask'], dtype=torch.bool).to(device)

        training_metastation_indices = train_metastation_mask.nonzero(as_tuple=False)
        training_rainfallstation_indices = train_rainfallstation_mask.nonzero(as_tuple=False)

        gen_x = batch['gen_x'].to(device)  # [batch_size, num_gen_nodes, gen_features]
        rain_x = batch['rain_x'].to(device)  # [batch_size, num_rain_nodes, rain_features]
        radar_x = batch["radar_x"].to(device)
        gen_y = batch['gen_y'].to(device)
        rain_y = batch['rain_y'].to(device)

        for i in range(gen_x.shape[0]):
            # One optimizer step per sample: clear gradients of the previous step.
            optimizer.zero_grad()
            step_loss = []

            # Inputs with all non-training stations blanked out.
            gen_x_train_only = gen_x[i].clone()
            rain_x_train_only = rain_x[i].clone()
            gen_x_train_only[~train_metastation_mask] = 0
            rain_x_train_only[~train_rainfallstation_mask] = 0

            # Start by individually masking metastations
            for idx in training_metastation_indices:
                gen_x_masked = gen_x_train_only.clone()
                rain_x_masked = rain_x_train_only.clone()
                # mask only the first value which corresponds to the rainfall value
                gen_x_masked[idx, 0] = 0

                x_dict = {
                    "general_station": gen_x_masked,
                    "rainfall_station": rain_x_masked,
                    "radar_grid": radar_x[i],
                }
                out = model(x_dict, edge_index_dict, edge_attribute_dict)

                # Model prediction for the hidden station only
                gen_predictions = out['general_station'][idx]
                gen_actual = gen_y[i][idx]

                training_loss = F.mse_loss(gen_predictions, gen_actual)
                if verbose:
                    log_msg = f"""
                    --- DEBUG STATS (General Station) ---
                    --- INPUT FEATURES ---
                    Radar grid features:     Min={x_dict['radar_grid'].min():.2f}, Max={x_dict['radar_grid'].max():.2f}, Mean={x_dict['radar_grid'].mean():.2f}
                    General station features: Min={x_dict['general_station'].min():.2f}, Max={x_dict['general_station'].max():.2f}, Mean={x_dict['general_station'].mean():.2f}

                    --- MODEL PREDICTIONS (raw) ---
                    Pred (gen) tensor:    Min={out['general_station'].min():.2f}, Max={out['general_station'].max():.2f}, Mean={out['general_station'].mean():.2f}

                    --- GROUND TRUTH ---
                    Truth (gen) tensor:   Min={gen_y[i].min():.2f}, Max={gen_y[i].max():.2f}, Mean={gen_y[i].mean():.2f}

                    --- LOSS ---
                    Loss for this sample: {training_loss.item():.2f}
                    -------------------------------------------------
                    """
                    logger.info(log_msg)
                step_loss.append(training_loss)

            # Individually mask rain stations
            for idx in training_rainfallstation_indices:
                gen_x_masked = gen_x_train_only.clone()
                rain_x_masked = rain_x_train_only.clone()
                # Mask the selected rainfall station
                rain_x_masked[idx, 0] = 0

                x_dict = {
                    'general_station': gen_x_masked,
                    'rainfall_station': rain_x_masked,
                    "radar_grid": radar_x[i],
                }
                out = model(x_dict, edge_index_dict, edge_attribute_dict)

                # Model prediction for the hidden station only
                rain_predictions = out['rainfall_station'][idx]
                rainfall_actual = rain_y[i][idx]

                training_loss = F.mse_loss(rain_predictions, rainfall_actual)
                if verbose:
                    log_msg = f"""
                    --- DEBUG STATS (Rainfall Station) ---
                    --- INPUT FEATURES ---
                    Radar grid features:     Min={x_dict['radar_grid'].min():.2f}, Max={x_dict['radar_grid'].max():.2f}, Mean={x_dict['radar_grid'].mean():.2f}
                    Rainfall station features: Min={x_dict['rainfall_station'].min():.2f}, Max={x_dict['rainfall_station'].max():.2f}, Mean={x_dict['rainfall_station'].mean():.2f}

                    --- MODEL PREDICTIONS (raw) ---
                    Pred (rain) tensor:   Min={out['rainfall_station'].min():.2f}, Max={out['rainfall_station'].max():.2f}, Mean={out['rainfall_station'].mean():.2f}

                    --- GROUND TRUTH ---
                    Truth (rain) tensor:  Min={rain_y[i].min():.2f}, Max={rain_y[i].max():.2f}, Mean={rain_y[i].mean():.2f}

                    --- LOSS ---
                    Loss for this sample: {training_loss.item():.2f}
                    -------------------------------------------------
                    """
                    logger.info(log_msg)
                step_loss.append(training_loss)

            # Nothing to learn from a sample without training stations;
            # torch.stack would raise on the empty list.
            if not step_loss:
                continue

            loss = torch.stack(step_loss).mean()
            losses.append(loss.detach())

            # backpropagate
            loss.backward()

            # Update weights
            optimizer.step()

    if not losses:
        return float("nan")

    return torch.stack(losses).mean().item()


In [None]:
def validate(model, data, dataloader, device):
    """Return the mean validation loss of ``model`` over ``dataloader``.

    For every sample the input features of the train- and test-split stations
    are zeroed, the radar grid features are passed through unchanged, and the
    loss is the sum of the MSE on the validation-split general stations and
    the MSE on the validation-split rainfall stations.

    Args:
        model: Called as ``model(x_dict, edge_index_dict, edge_attr_dict)``;
            its output is indexed by "general_station" and "rainfall_station".
        data: Graph object exposing ``val_mask`` / ``test_mask`` / ``train_mask``
            on both station node types, plus ``edge_index_dict`` and
            ``edge_attr_dict``.
        dataloader: Iterable of batches with keys "gen_x", "rain_x", "radar_x",
            "gen_y" and "rain_y".
        device: Device all tensors are moved to.

    Returns:
        float: Average over batches of the per-batch mean sample loss.
    """
    model.eval()  # inference behaviour for the whole pass

    def split_mask(node_type, split):
        # Boolean node mask of one split for one node type, on the target device
        raw_mask = getattr(data[node_type], f"{split}_mask")
        return torch.tensor(raw_mask, dtype=torch.bool).to(device)

    # Nodes the loss is evaluated on
    gen_eval_mask = split_mask("general_station", "val")
    rain_eval_mask = split_mask("rainfall_station", "val")

    # Nodes whose inputs are hidden from the model (test split + train split)
    gen_test_mask = split_mask("general_station", "test")
    rain_test_mask = split_mask("rainfall_station", "test")
    gen_hidden_mask = gen_test_mask | split_mask("general_station", "train")
    rain_hidden_mask = rain_test_mask | split_mask("rainfall_station", "train")

    # Graph structure is shared by all samples; move it to the device once
    edge_index_dict = {
        edge_type: index.to(device)
        for edge_type, index in data.edge_index_dict.items()
    }
    edge_attr_dict = {
        edge_type: attr.to(device)
        for edge_type, attr in data.edge_attr_dict.items()
    }

    tensor_keys = ("gen_x", "rain_x", "radar_x", "gen_y", "rain_y")
    running_loss = 0

    with torch.no_grad():  # no gradients are needed while validating
        for batch in tqdm.tqdm(dataloader, desc="Validation"):
            gen_x, rain_x, radar_x, gen_y, rain_y = (
                batch[key].to(device) for key in tensor_keys
            )

            n_samples = gen_x.shape[0]
            sample_losses = []

            for s in range(n_samples):
                # Work on copies so the batch tensors stay intact
                gen_in = gen_x[s].clone()
                rain_in = rain_x[s].clone()
                gen_in[gen_hidden_mask] = 0
                rain_in[rain_hidden_mask] = 0

                out = model(
                    {
                        "general_station": gen_in,
                        "rainfall_station": rain_in,
                        "radar_grid": radar_x[s],  # radar input is never masked
                    },
                    edge_index_dict,
                    edge_attr_dict,
                )

                # Loss is computed on validation stations only
                gen_loss = F.mse_loss(
                    out["general_station"][gen_eval_mask], gen_y[s][gen_eval_mask]
                )
                rain_loss = F.mse_loss(
                    out["rainfall_station"][rain_eval_mask], rain_y[s][rain_eval_mask]
                )
                sample_losses.append((gen_loss + rain_loss).item())

            running_loss += sum(sample_losses) / n_samples

    return running_loss / len(dataloader)

In [None]:
# Seed every random number generator used in this notebook (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value.
seed = 123
for seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(seed)

# One dataset per split, both built from the same graph object.
batch_size = 16
train_dataset, val_dataset = (
    WeatherGraphDatasetWithRadarNew(data, mode=split) for split in ("train", "val")
)


def collate_temporal_graphs(batch):
    """Collate a list of per-timestep samples into one batch dictionary.

    The feature and target tensors ('gen_x', 'rain_x', 'radar_x', 'gen_y',
    'rain_y') are stacked along a new leading batch dimension. The masks and
    the edge dictionaries are taken from the first sample only and are not
    stacked.

    Args:
        batch: Non-empty list of sample dictionaries.

    Returns:
        dict: Stacked tensors plus the first sample's masks and edge dicts.
    """
    stacked_keys = ('gen_x', 'rain_x', 'radar_x', 'gen_y', 'rain_y')
    shared_keys = (
        'metastation_mask',
        'rainfallstation_mask',
        'edge_index_dict',
        'edge_attr_dict',
    )

    collated = {
        key: torch.stack([sample[key] for sample in batch]) for key in stacked_keys
    }

    # Graph structure and masks are read from the first sample only
    first_sample = batch[0]
    collated.update({key: first_sample[key] for key in shared_keys})
    return collated

# Neither loader shuffles, so samples are visited in dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_temporal_graphs,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_temporal_graphs,
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
training_loss_arr = []
validation_loss_arr = []
early = 0  # consecutive epochs without a validation improvement
# Start at +inf so the first epoch always counts as an improvement, whatever
# the scale of the loss (a fixed starting value would stop training early
# without ever recording a best model if every loss were above it).
mini = float("inf")
stopping_condition = 3  # patience, in epochs
epochs = 0
best_state_dict = None  # snapshot of the weights with the lowest validation loss

training_start = time.time()
for i in range(20):
    print(f"-----EPOCH: {i + 1}-----")
    train_loss = train_epoch(model, data, train_loader, optimizer, device, verbose=True)
    validation_loss = validate(model, data, val_loader, device)
    training_loss_arr.append(train_loss)
    validation_loss_arr.append(validation_loss)
    epochs += 1

    # Report before the early-stop check so the final epoch is logged too
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Validation Loss: {validation_loss:.4f}")

    if mini >= validation_loss:
        mini = validation_loss
        early = 0
        # state_dict() returns references to the live tensors, so clone them;
        # otherwise later optimizer steps would overwrite the snapshot.
        best_state_dict = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    else:
        early += 1

    if early >= stopping_condition:
        print("Early stop loss")
        break

training_end = time.time()

print(f"Training took {training_end - training_start} seconds over {epochs} epochs")
plt.plot(training_loss_arr, label="training_loss", color="blue")
plt.plot(validation_loss_arr, label="validation_loss", color="red")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

# Restore the best weights so both the saved file and the in-memory model hold
# the lowest-validation-loss parameters rather than the last epoch's.
# best_state_dict stays None only if no epoch produced a comparable loss
# (e.g. every validation loss was NaN); then the current weights are saved.
if best_state_dict is not None:
    model.load_state_dict(best_state_dict)
torch.save(model.state_dict(), "weights/weather_gnn_best.pth")
print("✅ model weights saved to weather_gnn_best.pth")


In [None]:
# Pull the first validation batch and show the shape of its stacked "gen_x" tensor
first_val_batch = next(iter(val_loader))
print(first_val_batch["gen_x"].shape)

In [None]:
# Total number of scalar parameters in the model
total_params = sum(param.numel() for param in model.parameters())
print(total_params)

# Summarise each parameter by name and shape instead of printing every weight
# tensor in full, which floods the cell output without being readable.
for param_name, param in model.named_parameters():
    print(
        f"{param_name}: shape={tuple(param.shape)}, "
        f"numel={param.numel()}, requires_grad={param.requires_grad}"
    )


In [None]:
def test_model(model, data, device, collate_fn):
    """Evaluate ``model`` on the test split, save a scatter plot, return the RMSE.

    A test dataset is built from ``data`` and iterated one sample at a time.
    For each sample the input features of the validation- and test-split
    stations are zeroed, the radar grid features are passed through unchanged,
    and predictions are compared with the targets on the test-split general
    and rainfall stations. All test-station predictions/targets are pooled for
    the scatter plot and the Pearson correlation.

    Args:
        model: Called as ``model(x_dict, edge_index_dict, edge_attr_dict)``;
            its output is indexed by "general_station" and "rainfall_station".
        data: Graph object exposing ``val_mask`` / ``test_mask`` on both
            station node types, plus ``edge_index_dict`` and ``edge_attr_dict``.
        device: Device all tensors are moved to.
        collate_fn: Collate function handed to the test ``DataLoader``.

    Returns:
        float: Mean over test samples of the per-sample RMSE (general and
        rainfall test stations pooled within each sample). This is an average
        of per-sample RMSEs, not the RMSE of all pooled errors.

    Raises:
        ValueError: If the test dataset yields no samples.
    """
    model.eval()

    scatter_path = "radar_new_test_scatter_plot.png"
    test_dataset = WeatherGraphDatasetWithRadarNew(data, mode="test")

    def _split_mask(node_type, split):
        # Boolean node mask of one split for one node type, on the target device
        raw_mask = getattr(data[node_type], f"{split}_mask")
        return torch.tensor(raw_mask, dtype=torch.bool).to(device)

    def _print_stats(label, tensor, empty_name=None):
        # Print min/max/mean of a tensor; when empty_name is given, an empty
        # tensor is reported instead of calling min() on it.
        if empty_name is not None and tensor.numel() == 0:
            print(f"{empty_name} tensor is empty for this batch.")
            return
        print(
            f"{label}Min={tensor.min():.2f}, Max={tensor.max():.2f}, "
            f"Mean={tensor.mean():.2f}"
        )

    val_gen_mask = _split_mask("general_station", "val")
    val_rain_mask = _split_mask("rainfall_station", "val")
    test_gen_mask = _split_mask("general_station", "test")
    test_rain_mask = _split_mask("rainfall_station", "test")

    edge_index_dict = {key: val.to(device) for key, val in data.edge_index_dict.items()}
    edge_attr_dict = {key: val.to(device) for key, val in data.edge_attr_dict.items()}

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=collate_fn,
    )

    # Chunks are collected in lists and concatenated once after the loop;
    # concatenating inside the loop re-copies the whole array every sample.
    pred_chunks = []
    actual_chunks = []

    total_rmse = 0
    count = 0
    printed_stats = False  # the debug statistics are printed for one sample only
    with torch.no_grad():
        for batch in tqdm.tqdm(test_dataloader, desc="Testing"):
            gen_x = batch["gen_x"].to(device)
            rain_x = batch["rain_x"].to(device)
            radar_x = batch["radar_x"].to(device)
            gen_y = batch["gen_y"].to(device)
            rain_y = batch["rain_y"].to(device)

            batch_rmse = 0

            for i in range(gen_x.shape[0]):
                gen_x_masked = gen_x[i].clone()
                rain_x_masked = rain_x[i].clone()

                # Hide the inputs of all non-training (validation + test) nodes
                gen_x_masked[val_gen_mask] = 0
                rain_x_masked[val_rain_mask] = 0
                gen_x_masked[test_gen_mask] = 0
                rain_x_masked[test_rain_mask] = 0

                x_dict = {
                    "general_station": gen_x_masked,
                    "rainfall_station": rain_x_masked,
                    "radar_grid": radar_x[i],
                }

                out = model(x_dict, edge_index_dict, edge_attr_dict)

                # Predictions and targets restricted to the TEST nodes
                gen_predictions = out["general_station"][test_gen_mask]
                rain_predictions = out["rainfall_station"][test_rain_mask]
                gen_targets = gen_y[i][test_gen_mask]
                rain_targets = rain_y[i][test_rain_mask]

                if not printed_stats and i == 0:
                    print("\n--- DEBUG STATS (First Sample of Test Set) ---")

                    print("\n--- INPUT FEATURES (Masked) ---")
                    _print_stats("Radar grid features:     ", x_dict["radar_grid"])
                    _print_stats("General station features:  ", x_dict["general_station"])

                    print("\n--- MODEL PREDICTIONS (for TEST nodes) ---")
                    _print_stats("Pred (gen) tensor:    ", gen_predictions, "Pred (gen)")
                    _print_stats("Pred (rain) tensor:   ", rain_predictions, "Pred (rain)")

                    print("\n--- GROUND TRUTH (for TEST nodes) ---")
                    _print_stats("Truth (gen) tensor:    ", gen_targets, "Truth (gen)")
                    _print_stats("Truth (rain) tensor:   ", rain_targets, "Truth (rain)")

                    printed_stats = True

                pred_chunks.append(gen_predictions.cpu().detach().numpy().flatten())
                pred_chunks.append(rain_predictions.cpu().detach().numpy().flatten())
                actual_chunks.append(gen_targets.cpu().detach().numpy().flatten())
                actual_chunks.append(rain_targets.cpu().detach().numpy().flatten())

                # Per-sample RMSE over general and rainfall test stations together
                all_squared_errors = torch.cat(
                    [
                        (gen_predictions - gen_targets) ** 2,
                        (rain_predictions - rain_targets) ** 2,
                    ]
                )
                test_rmse = torch.sqrt(torch.mean(all_squared_errors))

                batch_rmse += test_rmse.item()
                count += 1

            total_rmse += batch_rmse

    if count == 0:
        raise ValueError("Test dataset produced no samples; nothing to evaluate.")

    # The leading empty float64 array keeps the pooled arrays float64
    plot_preds = np.concatenate([np.array([]), *pred_chunks])
    plot_actual = np.concatenate([np.array([]), *actual_chunks])

    # --- Plotting and final metrics ---
    plt.figure(figsize=(8, 8))
    plt.scatter(x=plot_actual, y=plot_preds, alpha=0.5)
    plot_bound = max(
        np.nanmax(plot_actual).astype(int), np.nanmax(plot_preds).astype(int)
    )
    # Red dashed identity line: points on it are perfect predictions
    plt.plot(np.linspace(0, plot_bound, 100), np.linspace(0, plot_bound, 100), 'r--')
    plt.xlabel("Actual Rainfall (Test Stations)")
    plt.ylabel("Predicted Rainfall (Test Stations)")
    plt.title("Test Set Performance")
    plt.grid(True)
    plt.savefig(scatter_path)
    plt.close()  # close the figure so the notebook does not display it again
    print(f"Saved test scatter plot to '{scatter_path}'")

    # Pearson correlation over pairs where neither value is NaN
    mask = ~np.isnan(plot_actual) & ~np.isnan(plot_preds)
    pearson_r_global, pearson_p_global = pearsonr(plot_actual[mask], plot_preds[mask])

    print(f"Pearson correlation (Test Stations): {pearson_r_global}")

    final_rmse = total_rmse / count
    print(f"Final Test RMSE (Test Stations): {final_rmse}")
    return final_rmse

In [None]:
# Evaluate on the test split, reusing the collate function of the training loaders
RMSE = test_model(
    model=model,
    data=data,
    device=device,
    collate_fn=collate_temporal_graphs,
)
print(f"TEST RMSE: {RMSE}")

# Visualisation of output
The test event runs from 04:15 to 06:15 on 02-05-2025.


In [None]:
def visualize_one_event(test_event_data, radar_features_event, do_plot=True):
    """Run the model once on the last timestep of a single event window.

    Station time series are taken from ``test_event_data``, gaps are filled
    (linear interpolation, then forward fill, then backward fill), and the
    feature row of the LAST timestep of the window is used as the node
    features. Unlike the training/validation/test loops, no station inputs are
    masked here.

    Relies on the notebook globals ``model``, ``data``, ``general_station`` and
    ``rainfall_station``.

    Args:
        test_event_data: pandas slice whose columns are a MultiIndex with the
            station identifier on level 1 (e.g.
            ``weather_station_df_pivot.iloc[593:602]``).
        radar_features_event: Indexable of per-timestep radar feature tensors;
            only the last entry is used.
        do_plot: Currently unused; kept so existing callers keep working.

    Returns:
        tuple: ``(gen_out, rain_out)`` numpy arrays holding the model output
        for the general and the rainfall station nodes respectively.
    """
    model.eval()

    # --- collect the gap-filled time series of every station ---
    general_series = {}
    rainfall_series = {}

    for station in test_event_data.columns.get_level_values(1).unique():
        # .ffill()/.bfill() replace the deprecated fillna(method=...) form
        station_cols = (
            test_event_data.xs(station, level=1, axis=1)
            .interpolate(method="linear")
            .ffill()
            .bfill()
        )
        if station in general_station:
            general_series[station] = station_cols.values  # all feature columns
        else:
            # rainfall stations only use the first feature column
            rainfall_series[station] = station_cols.values[:, 0:1]

    # Stack in the node ordering of the graph: [num_nodes, T, num_features]
    gen_arr = np.array([general_series[station] for station in general_station])
    rain_arr = np.array([rainfall_series[station] for station in rainfall_station])

    # Only the last timestep of the window is fed to the model
    gen_node_feats = gen_arr[:, -1, :].astype(np.float32)
    rain_node_feats = rain_arr[:, -1, :].astype(np.float32)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x_dict = {
        "general_station": torch.tensor(
            gen_node_feats, dtype=torch.float, device=device
        ),
        "rainfall_station": torch.tensor(
            rain_node_feats, dtype=torch.float, device=device
        ),
        # Copy via detach().clone() rather than torch.tensor(tensor), which
        # emits a copy-construct warning; the stored radar features stay untouched.
        "radar_grid": radar_features_event[-1]
        .detach()
        .clone()
        .to(device=device, dtype=torch.float),
    }

    # Same edge structures as used by the evaluation loops
    edge_index_dict = {k: v.to(device) for k, v in data.edge_index_dict.items()}
    edge_attr_dict = {k: v.to(device) for k, v in data.edge_attr_dict.items()}

    with torch.no_grad():
        out = model(x_dict, edge_index_dict, edge_attr_dict)

    gen_out = out["general_station"].cpu().numpy()
    rain_out = out["rainfall_station"].cpu().numpy()

    return gen_out, rain_out


In [None]:
# Row positions 593..601: the 9 timestamps of the test event, used for both the
# station table and the radar feature tensor.
event_rows = slice(593, 602)
test_event_data = weather_station_df_pivot.iloc[event_rows]
radar_features_event = data['radar_grid'].x[event_rows]

gen_out, rain_out = visualize_one_event(test_event_data, radar_features_event)

# General-station rows first, rainfall-station rows after
out_np = np.concatenate((gen_out, rain_out), axis=0)

In [None]:
# Show the edge index dictionary of `test_data`.
# NOTE(review): no module-level assignment to `test_data` is visible in this
# part of the notebook — confirm an earlier cell defines it, otherwise this
# raises NameError on a fresh kernel.
print(test_data.edge_index_dict)

# Visualise rain on radar grid
Hard-coded to plot exactly 9 consecutive timestamps.

In [None]:
# Preview the model output divided by 12 (out_np itself is left unchanged here).
# NOTE(review): the meaning of the factor 12 is not shown in this notebook
# section — presumably a unit / time-step conversion; confirm.
print(out_np / 12)

# Visualize Radar Image

In [None]:
# Load the radar dataset identified by "sg_radar_data"; `radar_df` is reused by
# the plotting cell further down.
radar_df = load_radar_dataset("sg_radar_data")

# Render a single radar image from the loaded dataset.
visualize_one_radar_image(radar_df=radar_df)

In [None]:
# 3x3 grid of map panels, one per plotted row
fig, ax = plt.subplots(
    3, 3, figsize=(15, 12), subplot_kw={"projection": ccrs.PlateCarree()}
)

# Scale into a NEW array instead of reassigning `out_np`: re-running this cell
# would otherwise divide the predictions by 12 again on every execution.
# NOTE(review): the meaning of the factor 12 is not shown here — presumably a
# unit / time-step conversion; confirm.
out_np_scaled = out_np / 12

# Value order inside each row: general stations first, then rainfall stations
station_order = [*general_station, *rainfall_station]

# NOTE(review): this loop treats every row of the array as one timestamp that
# holds one value per station — confirm `out_np` really has that layout.
for idx, timestamp in enumerate(out_np_scaled):
    output = {stn: float(timestamp[pos]) for pos, stn in enumerate(station_order)}

    axi = ax.flat[idx]  # panels are filled row by row
    node_df = pandas_to_geodataframe(pd.Series(output))
    visualise_gauge_grid(node_df=node_df, ax=axi)
    # NOTE(review): radar rows are taken by position idx from the start of
    # radar_df — verify they correspond to the same timestamps as the event.
    improved_visualise_radar_grid(
        radar_df.iloc[idx], ax=axi, zoom=bounds_singapore, norm=norm
    )
    visualise_singapore_outline(ax=axi)

In [None]:
# Rows 1773..1796 of the pivot table, reduced to the first record of every
# 15-minute bin; only the "rain_rate" columns are kept.
rate_window = weather_station_df_pivot.iloc[1773:1797]
rate_window_15min = rate_window.resample("15min").first()
original_rainfall_rates = rate_window_15min["rain_rate"]

print(original_rainfall_rates)

In [None]:
# Show the raw `out` object.
# NOTE(review): no module-level assignment to `out` is visible in this part of
# the notebook — confirm an earlier cell defines it, otherwise this raises
# NameError on a fresh kernel.
print(out)

In [None]:
# Compare predictions with the observed rain rates, one row per timestamp.
# NOTE(review): `out` is not assigned at module level in the visible part of
# the notebook — confirm an earlier cell defines it with one row per timestamp
# and one value per station.
actual_arr = []
pred_arr = []

# Value order inside each row: general stations first, then rainfall stations
station_order = [*general_station, *rainfall_station]

for idx, timestamp in enumerate(out):
    output = {stn: float(timestamp[pos]) for pos, stn in enumerate(station_order)}

    observed_row = original_rainfall_rates.iloc[idx]
    actual_arr.append([float(observed_row[stn]) for stn in output])
    pred_arr.append(list(output.values()))

actual_arr = np.array(actual_arr)
pred_arr = np.array(pred_arr)

print(actual_arr)
print(pred_arr)

# Per-timestamp MSE: mean of the SQUARED station errors of row i, ignoring NaN
# observations. The previous expression ignored the loop index and squared the
# mean error over the whole array, so positive and negative errors cancelled.
error = [
    np.nanmean((actual_arr[i] - pred_arr[i]) ** 2) for i in range(len(actual_arr))
]

MSE = np.mean(np.array(error))
print(MSE)


In [None]:
# Show the first row of the resampled rain-rate table.
print(original_rainfall_rates.iloc[0])