In [1]:
from pathlib import Path

import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
import torch
import torch_geometric.nn
from rasterio.windows import Window
from sklearn.neighbors import NearestNeighbors
from torch.functional import F
from torch.utils.data import Dataset
from torch_geometric.data import Data, InMemoryDataset, DataLoader
from tqdm import tqdm

# `nn` must be bound before the model classes are defined (their class headers
# reference nn.Module); `F` is re-bound to its canonical module path.
import torch.nn as nn
import torch.nn.functional as F

In [None]:

# Column whitelists that clean_link_tables() applies (via keep_only) to each
# per-house link table. Every other column is discarded; the house id column
# is always retained in addition to these.

# Flood link table. "hazard_class_flom" is also the default flood target label
# in build_data_objects(); remaining numeric columns may become model features.
FLOOD_KEEP = [
    "region",
    "in_flom_analyseomraade",
    "hazard_class_flom",
]

# Closest-landslide link table. Optional: load_houses_data() reads it only when
# use_landslide_closest is set and the parquet file exists. Columns presumably
# describe the nearest mapped landslide objects (distances, areas, flags);
# TODO confirm against the upstream link builder.
LANDSLIDE_CLOSEST_KEEP = [
    "inside_source_area",
    "inside_runout_area",
    "any_landslide_area_inside",
    "dist_to_trigger_point",
    "dist_to_runout_point",
    "dist_to_source_area",
    "dist_to_runout_area",
    "dist_to_landslide_event",
    "sourcearea_area_m2",
    "runoutarea_area_m2",
    "trigger_point_far",
    "runout_point_far",
    "landslide_event_far",
]

# Raw landslide-event link table. encode_raw_landslides() counts its rows per
# house to build the landslide target label, so each row is presumably one
# (house, event) pair.
LANDSLIDE_EVENT_RAW_KEEP = [
    "skredID",
    "distance_m",
    "skredType",
    "skredTidspunkt",
]

# Hydrology link table: distances to water bodies plus catchment/runoff
# attributes (Norwegian field names; units presumably as suffixed, e.g. _km2,
# _Mm3Aar; TODO confirm against the source dataset).
HYDRO_KEEP = [
    "dist_to_river",
    "dist_to_lake",
    "dist_to_hyd",
    "arealEnhet_km2",
    "arealTotal_km2",
    "QNormal_lskm2",
    "QNormal_Mm3Aar",
    "QNormalOppstrm_Mm3Aar",
    "elveordenstrahler",
    "arealregineenhet_km2",
    "areal_km2",
    "arealNorge_km2",
    "nedborfeltareal_km2",
    "minsteVannforing",
]


def aggregate_per_house(df, house_id_col="bygningsnummer"):
    if df is None or df.empty:
        return df
    if house_id_col not in df.columns:
        return df

    num_cols = [
        c for c in df.columns
        if c != house_id_col and pd.api.types.is_numeric_dtype(df[c])
    ]
    bool_cols = [
        c for c in df.columns
        if c != house_id_col and df[c].dtype == "bool"
    ]
    other_cols = [
        c for c in df.columns
        if c not in ([house_id_col] + num_cols + bool_cols)
    ]

    agg = {}
    for c in num_cols:
        agg[c] = "median"
    for c in bool_cols:
        agg[c] = "max"     
    for c in other_cols:
        agg[c] = "first"   

    out = df.groupby(house_id_col, as_index=False).agg(agg)
    return out


def keep_only(df: pd.DataFrame, keep: list[str], *, always_keep: str | None = None) -> pd.DataFrame:
    keep_set = set(keep)
    if always_keep:
        keep_set.add(always_keep)

    cols = [c for c in df.columns if c in keep_set]
    return df.loc[:, cols].copy()


def encode_raw_landslides(
    landslides_events: pd.DataFrame,
    house_df: pd.DataFrame,
    house_id_col="bygningsnummer"
) -> pd.DataFrame:
    """Turn raw per-house landslide event rows into an ordinal hazard level.

    Returns one row per row of ``house_df`` with columns
    ``[house_id_col, "landslide_hazard_level"]`` where the level is:

    0 = no events
    1 = one event
    2 = two+ events
    """
    base = house_df[[house_id_col]].copy()

    if landslides_events is None or landslides_events.empty:
        base["landslide_hazard_level"] = 0
        return base

    # Event counts per house, capped at 2 (counts are always >= 1 here).
    levels = (
        landslides_events[house_id_col]
        .value_counts()
        .clip(upper=2)
        .rename("landslide_hazard_level")
        .rename_axis(house_id_col)
        .reset_index()
    )

    merged = base.merge(levels, on=house_id_col, how="left")
    merged["landslide_hazard_level"] = merged["landslide_hazard_level"].fillna(0).astype(int)
    return merged


def clean_link_tables(
    flood_links: pd.DataFrame | None = None,
    landslide_closest: pd.DataFrame | None = None,
    landslide_events_raw: pd.DataFrame | None = None,
    hydro_links: pd.DataFrame | None = None,
    *,
    house_id_col: str = "bygningsnummer",
) -> dict[str, pd.DataFrame]:
    """Whitelist the columns of each provided link table.

    Returns a dict keyed by "flood", "landslide_closest",
    "landslide_events_raw" and "hydro". Only tables that were passed (not
    None) appear, and each keeps its whitelist columns plus ``house_id_col``.
    """
    tables = (
        ("flood", flood_links, FLOOD_KEEP),
        ("landslide_closest", landslide_closest, LANDSLIDE_CLOSEST_KEEP),
        ("landslide_events_raw", landslide_events_raw, LANDSLIDE_EVENT_RAW_KEEP),
        ("hydro", hydro_links, HYDRO_KEEP),
    )
    return {
        name: keep_only(frame, whitelist, always_keep=house_id_col)
        for name, frame, whitelist in tables
        if frame is not None
    }


def load_houses_data(
    region: str,
    base_dir: str | Path = "master",
    target_epsg: int = 25833,
    house_id_col: str = "bygningsnummer",
    use_landslide_closest: bool = True,
):
    """Load a region's houses and left-join every per-house link table onto them.

    Reads ``raw/vector/houses/houses_{region}.gpkg`` (layer ``houses_{region}``),
    reprojected to ``target_epsg``, plus these link tables under ``processed/links/``:

    * flood links,
    * hydro links,
    * raw landslide events, encoded into ``landslide_hazard_level``,
    * the closest-landslide table, only if ``use_landslide_closest`` is set
      and the file exists.

    Each table is whitelisted (clean_link_tables) and collapsed to one row per
    house (aggregate_per_house). It is then left-joined onto ``houses``
    **independently**. Chaining the joins from the flood table would drop any
    house missing from that table, and with it that house's columns from all
    the other tables.

    Returns
    -------
    geopandas.GeoDataFrame
        The houses with all link columns attached, in the houses CRS.
    """
    base_dir = Path(base_dir)

    houses_path = base_dir / f"raw/vector/houses/houses_{region}.gpkg"
    houses_layer = f"houses_{region}"

    houses = gpd.read_file(houses_path, layer=houses_layer).to_crs(epsg=target_epsg)
    houses = houses.set_geometry("geometry")
    houses[house_id_col] = houses[house_id_col].astype("int64")

    # link tables
    f_links = pd.read_parquet(base_dir / f"processed/links/flood_links_{region}.parquet")
    h_links = pd.read_parquet(base_dir / f"processed/links/hydro_links_{region}.parquet")

    # closest landslide table is optional
    l_closest = None
    if use_landslide_closest:
        l_closest_path = base_dir / f"processed/links/landslide_links_closest_{region}.parquet"
        if l_closest_path.exists():
            l_closest = pd.read_parquet(l_closest_path)

    # raw landslide events -> used for encoding target
    l_events = pd.read_parquet(base_dir / f"processed/links/landslide_links_{region}.parquet")

    cleaned = clean_link_tables(
        flood_links=f_links,
        landslide_closest=l_closest,
        landslide_events_raw=l_events,
        hydro_links=h_links,
        house_id_col=house_id_col,
    )

    f_clean = aggregate_per_house(cleaned["flood"], house_id_col)
    h_clean = aggregate_per_house(cleaned["hydro"], house_id_col)
    l_closest_clean = aggregate_per_house(cleaned.get("landslide_closest"), house_id_col)

    l_encoded = encode_raw_landslides(
        cleaned.get("landslide_events_raw"), houses, house_id_col=house_id_col
    )
    l_encoded = aggregate_per_house(l_encoded, house_id_col)

    link_tables = [f_clean, h_clean, l_encoded]
    if l_closest_clean is not None:
        link_tables.append(l_closest_clean)

    # Join each table onto the full house set so no table can drop houses for another.
    merged = houses
    for table in link_tables:
        merged = merged.merge(table, on=house_id_col, how="left")

    return gpd.GeoDataFrame(merged, geometry="geometry", crs=houses.crs)


def build_knn_edge_index(coords: np.ndarray, k: int = 8, make_undirected: bool = True):
    """Build a k-nearest-neighbour ``edge_index`` (long tensor, shape ``(2, E)``).

    Each point ``i`` gets edges ``i -> j`` to its ``min(k, n - 1)`` nearest
    *other* points. The query point is removed by index, not by assuming it
    is the first hit. With duplicate coordinates (several houses at one
    point), sklearn may return a twin before the point itself. Dropping
    column 0 would then create a self-loop and lose a real neighbour.

    With ``make_undirected`` the reversed edges are appended and the result is
    de-duplicated, so mutual neighbours appear once per direction. Duplicate
    edges would otherwise be counted twice by GCN aggregation. Columns of the
    undirected result are sorted lexicographically by ``(src, dst)``.
    """
    n = len(coords)
    if n == 0:
        return torch.empty((2, 0), dtype=torch.long)

    n_neighbors = min(k + 1, n)
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="auto")
    nbrs.fit(coords)
    _, idxs = nbrs.kneighbors(coords)

    rows = np.arange(n)
    drop = idxs == rows[:, None]
    # Row whose own index was not returned (ties at distance 0): drop its farthest hit
    # instead, so every row keeps exactly n_neighbors - 1 neighbours.
    drop[~drop.any(axis=1), -1] = True

    src = np.repeat(rows, n_neighbors - 1)
    dst = idxs[~drop].reshape(-1)  # row-major, matches the repeat order of src
    edge_index = torch.from_numpy(np.stack([src, dst]).astype(np.int64))

    if make_undirected and edge_index.numel() > 0:
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
        edge_index = torch.unique(edge_index, dim=1)

    return edge_index


from pathlib import Path
import torch
import numpy as np


def build_and_save_knn_graph(
    houses_gdf,
    out_path,
    house_id_col="bygningsnummer",
    k=8,
):
    """Build a kNN graph over house locations, save it with ``torch.save``, and return it.

    Nodes are houses sorted by ``house_id_col``. Rows with a missing or empty
    geometry are dropped. The payload holds:

    * ``house_id`` (long),
    * ``pos`` (float32, x/y),
    * ``edge_index`` (see build_knn_edge_index),
    * ``k``,
    * ``house_id_col``.

    Geometries are read via ``.x`` / ``.y``, i.e. they are assumed to be points.
    NOTE(review): confirm the houses layer never holds polygons.
    """
    houses = houses_gdf[[house_id_col, "geometry"]].copy()
    houses[house_id_col] = houses[house_id_col].astype(int)
    houses = houses.dropna(subset=["geometry"])
    # Empty geometries survive dropna but have no .x/.y.
    non_empty = np.array([not g.is_empty for g in houses["geometry"]], dtype=bool)
    houses = houses.loc[non_empty]
    houses = houses.sort_values(house_id_col).reset_index(drop=True)

    # Search neighbours in float64. At projected-metre magnitudes (~7e6 for UTM
    # northings) float32 spacing is 0.5 m, which would merge or reorder close houses.
    # Only the stored `pos` is downcast to float32.
    coords = np.array(
        [(g.x, g.y) for g in houses["geometry"]], dtype=np.float64
    ).reshape(-1, 2)
    pos = torch.from_numpy(coords.astype(np.float32))

    edge_index = build_knn_edge_index(coords, k=k)

    house_id = torch.tensor(houses[house_id_col].to_numpy(), dtype=torch.long)

    payload = {
        "house_id": house_id,
        "pos": pos,
        "edge_index": edge_index,
        "k": k,
        "house_id_col": house_id_col,
    }

    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, out_path)
    print("Saved graph:", out_path, "| nodes:", len(houses))

    return payload


def load_graph_bundle(path):
    """Load a saved graph payload (see build_and_save_knn_graph) onto the CPU."""
    bundle_file = Path(path)
    bundle = torch.load(bundle_file, map_location="cpu")
    return bundle


def default_raster_paths_for_region(region, base_dir="../"):
    """Map each terrain channel name to its GeoTIFF path for ``region``.

    The DTM lives under ``raw/rasters``; the derived layers (slope, aspect,
    curvature, D8 flow accumulation, TWI) live under ``processed/rasters``.
    """
    root = Path(base_dir)
    raw_dir = root / "raw" / "rasters"
    derived_dir = root / "processed" / "rasters"

    layout = {
        "dtm": (raw_dir, "dtm10"),
        "slope": (derived_dir, "slope_deg"),
        "aspect": (derived_dir, "aspect_deg"),
        "curv": (derived_dir, "curvature"),
        "flowacc": (derived_dir, "flowacc_d8"),
        "twi": (derived_dir, "twi"),
    }
    return {
        key: folder / f"{region}_{suffix}.tif"
        for key, (folder, suffix) in layout.items()
    }


from torch_geometric.data import Data


def _label_tensors(frame, column):
    """Return ``(labels, mask)`` long/bool tensors for ``column``, or ``(None, None)`` if absent.

    Missing labels become -1 in ``labels`` and False in ``mask``. Values are
    cast with ``astype(int)``, so non-integer values are truncated.
    """
    if column not in frame.columns:
        return None, None
    values = frame[column].to_numpy()
    mask = ~pd.isna(values)
    labels = pd.Series(values).fillna(-1).astype(int).to_numpy()
    return torch.tensor(labels, dtype=torch.long), torch.tensor(mask, dtype=torch.bool)


def build_data_objects(
    region,
    house_id_col="bygningsnummer",
    base_dir="../",
    flood_label="hazard_class_flom",
    landslide_label="landslide_hazard_level",
    graph_path=None,
):
    """Assemble a PyG ``Data`` object whose nodes follow a saved kNN graph.

    Houses are loaded with load_houses_data and restricted to the graph's
    ``house_id`` list. They are then reordered to match it exactly, so row i
    of the returned ``houses`` frame is node i.

    ``x`` holds all numeric house/link columns except ids and the label
    columns. All-NaN columns are dropped and the remaining NaNs are filled
    with the column median. Label columns that exist become ``y_flood`` /
    ``y_landslide`` (long, -1 where missing) plus matching boolean
    ``*_mask`` tensors.

    Returns
    -------
    (data, houses, feature_columns, graph)

    Raises
    ------
    ValueError
        If ``graph_path`` is empty/None. The check runs before the expensive
        data load.
    AssertionError
        If the houses cannot be aligned 1:1 with the graph nodes.
    """
    if not graph_path:
        raise ValueError("graph_path must point to a saved graph bundle (see build_and_save_knn_graph)")

    houses_gdf = load_houses_data(
        region,
        base_dir=base_dir,
        house_id_col=house_id_col,
    )
    houses_gdf = houses_gdf.set_geometry("geometry").copy()
    houses_gdf[house_id_col] = houses_gdf[house_id_col].astype(int)

    graph = load_graph_bundle(graph_path)
    graph_ids = graph["house_id"].cpu().numpy().tolist()

    # 1) filter houses to graph ids
    houses = houses_gdf[houses_gdf[house_id_col].isin(graph_ids)].copy()

    # 2) enforce EXACT order = graph house_id order
    order_map = {hid: i for i, hid in enumerate(graph_ids)}
    houses["__order"] = houses[house_id_col].map(order_map)
    houses = houses.sort_values("__order").drop(columns="__order").reset_index(drop=True)

    # sanity: lengths alone can match while ids differ (a duplicated house offsetting a missing one)
    assert len(houses) == len(graph_ids), f"Graph nodes ({len(graph_ids)}) != houses ({len(houses)})"
    assert houses[house_id_col].tolist() == graph_ids, "House ids do not match graph node order"

    # 3) build numeric features
    df_for_features = pd.DataFrame(houses.drop(columns=["geometry"], errors="ignore"))

    drop_always = {house_id_col, "bygningId", "gml_id", "uuidBygning", "versjonId"}
    drop_labels = {flood_label, landslide_label}

    num_df = df_for_features.select_dtypes(include=[np.number]).copy()
    num_df = num_df[[c for c in num_df.columns if c not in (drop_always | drop_labels)]]
    num_df = num_df.dropna(axis=1, how="all")
    num_df = num_df.fillna(num_df.median(numeric_only=True))

    x = torch.tensor(num_df.to_numpy(dtype=np.float32))

    # 4) labels + masks
    y_flood_t, y_flood_mask_t = _label_tensors(df_for_features, flood_label)
    y_landslide_t, y_landslide_mask_t = _label_tensors(df_for_features, landslide_label)

    data = Data(
        x=x,
        pos=graph["pos"],
        edge_index=graph["edge_index"],
        house_id=graph["house_id"],
    )

    if y_flood_t is not None:
        data.y_flood = y_flood_t
        data.y_flood_mask = y_flood_mask_t

    if y_landslide_t is not None:
        data.y_landslide = y_landslide_t
        data.y_landslide_mask = y_landslide_mask_t

    return data, houses, num_df.columns.tolist(), graph

from torch_geometric.transforms import RandomNodeSplit


def add_splits(data, val_ratio=0.1, test_ratio=0.1):
    """Attach random train/val/test node masks via PyG's RandomNodeSplit.

    Val and test sizes are ``int(ratio * num_nodes)``; the remaining nodes
    form the training split (RandomNodeSplit's default "train_rest" mode).
    """
    n_nodes = data.num_nodes
    transform = RandomNodeSplit(
        num_val=int(val_ratio * n_nodes),
        num_test=int(test_ratio * n_nodes),
    )
    return transform(data)


class TerrainBuilderStack:
    """Read fixed-size, multi-raster terrain patches centred on house points.

    Rasters are opened lazily on first use and kept open until ``close()``.
    On first open, ``houses_gdf`` is reprojected to the DTM's CRS.

    Patch windows are computed on the DTM grid and applied unchanged to every
    raster. NOTE(review): this assumes the processed rasters share the DTM's
    grid (transform and size); confirm that is how they were derived.
    """

    # Channel order of every stack returned by extract_one_patch.
    RASTER_KEYS = ("dtm", "slope", "aspect", "curv", "flowacc", "twi")

    def __init__(self, region, raster_paths, houses_gdf, patches_m=200):
        self.region = region
        self.raster_paths = raster_paths  # dict: {"dtm": Path, ...}
        self.houses_gdf = houses_gdf
        self.patches_m = patches_m  # patch side length in raster map units
        self._sources = None
        self.label_flood = "hazard_class_flom"
        self.label_landslide = "landslide_hazard_level"

    def _open_sources(self):
        """Open every raster once and return them keyed by RASTER_KEYS.

        If opening any raster fails, the ones already opened are closed before
        re-raising, so no file handles leak.
        """
        if self._sources is None:
            opened = {}
            try:
                for key in self.RASTER_KEYS:
                    opened[key] = rasterio.open(self.raster_paths[key])
            except Exception:
                for src in opened.values():
                    src.close()
                raise
            self._sources = opened

            # ensure house CRS matches rasters
            dtm_crs = self._sources["dtm"].crs
            if self.houses_gdf.crs != dtm_crs:
                self.houses_gdf = self.houses_gdf.to_crs(dtm_crs)

        return self._sources

    def close(self):
        """Close all open rasters (safe to call repeatedly)."""
        if self._sources:
            for src in self._sources.values():
                src.close()
        self._sources = None

    def extract_one_patch(self, geom):
        """Return a float32 ``(6, P, P)`` patch around point ``geom``, or None.

        ``geom`` must be in the rasters' CRS. ``P = 2 * int(patches_m / 2 / res)``.
        Returns None when the window does not fit entirely inside the raster.
        Raster nodata values and non-finite values are treated as missing and
        replaced, per channel, by the median of that channel's valid pixels
        (0.0 if the channel has none).
        """
        srcs = self._open_sources()
        dtm_src = srcs["dtm"]

        res_x, res_y = dtm_src.res
        assert abs(res_x - res_y) < 1e-6, "Rasters must have square pixels"
        res = float(res_x)

        half_pixels = int((self.patches_m / 2) / res)
        patch_pixels = half_pixels * 2

        row_idx, col_idx = dtm_src.index(geom.x, geom.y)
        row_off = row_idx - half_pixels
        col_off = col_idx - half_pixels

        # Skip edge cases: window would leave the raster.
        if (
            row_off < 0
            or col_off < 0
            or row_off + patch_pixels > dtm_src.height
            or col_off + patch_pixels > dtm_src.width
        ):
            return None

        window = Window(col_off, row_off, patch_pixels, patch_pixels)

        bands = []
        for key in self.RASTER_KEYS:
            src = srcs[key]
            band = src.read(1, window=window).astype("float32")
            if band.shape != (patch_pixels, patch_pixels):
                return None
            # Nodata sentinels (e.g. -9999) are finite, so isfinite() alone would keep them.
            if src.nodata is not None:
                band[band == np.float32(src.nodata)] = np.nan
            bands.append(band)

        stack = np.stack(bands, axis=0)

        stack = np.where(np.isfinite(stack), stack, np.nan)
        for c in range(stack.shape[0]):
            med = np.nanmedian(stack[c])
            if np.isnan(med):
                med = 0.0
            stack[c] = np.where(np.isnan(stack[c]), med, stack[c])

        return stack


class TerrainEncoderCNN(nn.Module):
    """Encode a ``(batch, in_channels, H, W)`` terrain patch stack into a ``t_dim`` vector.

    Pipeline: a conv stem, then three double-conv blocks (32->64->96->128
    channels) with 2x max-pooling after the first two. Global average pooling
    follows, then an MLP head. Every conv is 3x3 with padding 1, no bias,
    followed by BatchNorm and SiLU.
    """

    def __init__(self, in_channels=6, t_dim=128):
        super().__init__()
        self.stem = nn.Sequential(*self._conv_bn_silu(in_channels, 32))
        self.block1 = nn.Sequential(*self._conv_bn_silu(32, 64), *self._conv_bn_silu(64, 64))
        self.block2 = nn.Sequential(*self._conv_bn_silu(64, 96), *self._conv_bn_silu(96, 96))
        self.block3 = nn.Sequential(*self._conv_bn_silu(96, 128), *self._conv_bn_silu(128, 128))

        self.pool = nn.MaxPool2d(2)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))

        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 256),
            nn.SiLU(),
            nn.Dropout(0.2),
            nn.Linear(256, t_dim),
        )

    @staticmethod
    def _conv_bn_silu(c_in, c_out):
        """One 3x3 conv -> BatchNorm -> SiLU unit, as a list of modules."""
        return [
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.SiLU(),
        ]

    def forward(self, x):
        stages = (
            self.stem,
            self.block1,
            self.pool,
            self.block2,
            self.pool,
            self.block3,
            self.gap,
            self.head,
        )
        for stage in stages:
            x = stage(x)
        return x


class TabularMLP(nn.Module):
    """Two hidden layers of width 128 (Linear -> LayerNorm -> SiLU -> Dropout(0.15)),
    followed by a linear projection to ``d_tab``."""

    def __init__(self, in_dim, d_tab):
        super().__init__()
        hidden = 128
        layers = []
        width = in_dim
        for _ in range(2):
            layers += [
                nn.Linear(width, hidden),
                nn.LayerNorm(hidden),
                nn.SiLU(),
                nn.Dropout(0.15),
            ]
            width = hidden
        layers.append(nn.Linear(width, d_tab))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class CombinedMLP(nn.Module):
    """Fuse terrain and tabular embeddings into one node embedding of size ``d_node``.

    The inputs are concatenated on the last axis, then passed through
    Linear(256) -> LayerNorm -> SiLU -> Dropout(0.2) -> Linear(d_node).
    """

    def __init__(self, d_terrain=128, d_tab=64, d_node=192):
        super().__init__()
        fused_dim = d_terrain + d_tab
        self.net = nn.Sequential(
            nn.Linear(fused_dim, 256),
            nn.LayerNorm(256),
            nn.SiLU(),
            nn.Dropout(0.2),
            nn.Linear(256, d_node),
        )

    def forward(self, z_terrain, z_tab):
        fused = torch.cat((z_terrain, z_tab), dim=-1)
        return self.net(fused)

from torch_geometric.nn import GCNConv  
class GNNBackbone(nn.Module):
    """Stack of GCNConv layers, each followed by ReLU (including the last).

    The first layer maps ``in_dim -> hidden_dim`` and the rest map
    ``hidden_dim -> hidden_dim``. At least one layer is always built, even if
    ``num_layers < 1``.
    """

    def __init__(self, in_dim=192, hidden_dim=192, num_layers=3):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * max(num_layers, 1)
        self.convs = nn.ModuleList(
            GCNConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x, edge_index):
        h = x
        for layer in self.convs:
            h = torch.relu(layer(h, edge_index))
        return h


class HazardHeads(nn.Module):
    """Flood classification head: Linear(128) -> ReLU -> Linear(num_flood_classes).

    ``forward`` returns raw logits of shape ``(..., num_flood_classes)``.
    """

    def __init__(self, in_dim, num_flood_classes=3):
        super().__init__()
        hidden = 128
        self.flood_head = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_flood_classes),
        )

    def forward(self, x):
        return self.flood_head(x)


class HazardGNNModel(nn.Module):
    """Terrain CNN + tabular MLP node encoder, GCN message passing, flood head.

    ``forward(data, patches, node_idx)`` returns flood logits for the nodes in
    ``node_idx`` only. ``patches`` holds terrain stacks for exactly those nodes,
    in the same order.

    With ``encode_all_nodes=True`` (default), tabular features are encoded for
    every node in the graph, so GCN neighbours of the batch nodes carry real
    information. Terrain embeddings exist only for batch nodes; other nodes get
    a zero terrain embedding. With ``encode_all_nodes=False`` (the previous
    behaviour), only batch nodes are encoded and all other nodes enter the GCN
    as zero vectors. Message passing then contributes almost nothing, and
    predictions depend on which neighbours share the batch.
    """

    def __init__(
        self,
        terrain_in_channels=6,
        terrain_dim=128,
        tabular_in_dim=50,
        tabular_dim=64,
        gnn_hidden_dim=192,
        gnn_layers=3,
        num_flood_classes=3,
        encode_all_nodes=True,
    ):
        super(HazardGNNModel, self).__init__()
        self.encode_all_nodes = encode_all_nodes

        self.terrain_encoder = TerrainEncoderCNN(
            in_channels=terrain_in_channels,
            t_dim=terrain_dim
        )
        self.tabular_mlp = TabularMLP(
            in_dim=tabular_in_dim,
            d_tab=tabular_dim
        )
        self.combined_mlp = CombinedMLP(
            d_terrain=terrain_dim,
            d_tab=tabular_dim,
            d_node=gnn_hidden_dim
        )

        # GNN input dim equals the combined MLP's output dim (gnn_hidden_dim).
        self.gnn_backbone = GNNBackbone(
            in_dim=gnn_hidden_dim,
            hidden_dim=gnn_hidden_dim,
            num_layers=gnn_layers
        )

        self.hazard_heads = HazardHeads(
            in_dim=gnn_hidden_dim,
            num_flood_classes=num_flood_classes,
        )

    def forward(self, data, patches, node_idx):
        z_terrain = self.terrain_encoder(patches)

        if self.encode_all_nodes:
            # Zero terrain for nodes without a patch; tabular features for everyone.
            terrain_full = z_terrain.new_zeros((data.num_nodes, z_terrain.size(-1)))
            terrain_full[node_idx] = z_terrain
            z_tab_full = self.tabular_mlp(data.x)
            x_full = self.combined_mlp(terrain_full, z_tab_full)
        else:
            z_tab = self.tabular_mlp(data.x[node_idx])
            combined = self.combined_mlp(z_terrain, z_tab)
            x_full = combined.new_zeros((data.num_nodes, combined.size(-1)))
            x_full[node_idx] = combined

        gnn_feats = self.gnn_backbone(x_full, data.edge_index)
        flood_logits = self.hazard_heads(gnn_feats[node_idx])
        return flood_logits





class HouseGraphDataset(Dataset):
    """Single-graph dataset: one PyG ``Data`` object plus on-demand terrain patches.

    ``__len__`` is 1 and ``__getitem__`` always returns the whole graph.
    Node-level batching is done by the training loops, which call
    ``get_patches_for_node_idx`` for the nodes they sample.

    ``k`` is stored but not used here: the graph, and therefore its k, comes
    from the bundle at ``graph_path``.
    """

    def __init__(
        self,
        region,
        base_dir="../",
        k=8,
        flood_label="hazard_class_flom",
        landslide_label="landslide_hazard_level",
        graph_path="",
        patches_m=100,
    ):
        self.region = region
        self.base_dir = base_dir
        self.k = k
        self.flood_label = flood_label
        self.landslide_label = landslide_label
        self.graph_path = graph_path
        self.patches_m = patches_m

        self.raster_paths = default_raster_paths_for_region(region, base_dir=base_dir)

        data, houses_ok, feature_cols, graph = build_data_objects(
            region,
            house_id_col="bygningsnummer",
            base_dir=base_dir,
            flood_label=flood_label,
            landslide_label=landslide_label,
            graph_path=graph_path,
        )

        self.data = data
        self.houses_ok = houses_ok  # row i == graph node i
        self.feature_cols = feature_cols
        self.graph = graph

        self.terrain = TerrainBuilderStack(
            region=region,
            raster_paths=self.raster_paths,
            houses_gdf=self.houses_ok,
            patches_m=patches_m,
        )

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return self.data

    def get_patches_for_node_idx(self, node_idx_tensor):
        """Return a float32 tensor ``(len(node_idx), 6, P, P)`` of terrain patches.

        Node ``i`` is row ``i`` of the houses table. Geometries come from the
        terrain builder's copy of the houses, which it reprojects to the raster
        CRS when the rasters are opened. Reading ``houses_ok`` directly would
        sample the wrong pixels whenever the houses CRS differs from the raster
        CRS. Nodes whose patch cannot be read get an all-zero patch. An empty
        index yields a ``(0, 6, P, P)`` tensor.
        """
        # Opening the sources also performs the CRS alignment of terrain.houses_gdf.
        sources = self.terrain._open_sources()
        houses_raster_crs = self.terrain.houses_gdf

        res = float(sources["dtm"].res[0])
        px = int((self.patches_m / 2) / res) * 2

        idx_list = [int(i) for i in node_idx_tensor.tolist()]
        if not idx_list:
            return torch.zeros((0, 6, px, px), dtype=torch.float32)

        fallback = np.zeros((6, px, px), dtype="float32")
        patches = []
        for i in idx_list:
            patch = self.terrain.extract_one_patch(houses_raster_crs.iloc[i].geometry)
            if patch is None:
                patch = fallback
            patches.append(torch.from_numpy(patch).float())

        return torch.stack(patches, dim=0)

    def close(self):
        self.terrain.close()

    def __del__(self):
        # Best-effort cleanup at garbage collection; the object may be half-built
        # (e.g. __init__ raised) or the interpreter may be shutting down.
        try:
            self.close()
        except Exception:
            pass


import torch.nn as nn


def masked_ce_loss(logits, y, mask):
    """Cross-entropy over the positions where ``mask`` is set.

    Returns the float 0.0 (not a tensor) when any argument is None or the mask
    selects nothing.
    """
    missing = logits is None or y is None or mask is None
    if missing or mask.sum() == 0:
        return 0.0

    criterion = nn.CrossEntropyLoss()
    return criterion(logits[mask], y[mask])


def _index_from_mask(mask):
    return mask.nonzero(as_tuple=False).view(-1)


def _safe_masked_idx(idx, mask):
    """Return subset of idx where mask is True."""
    if mask is None:
        return idx
    return idx[mask[idx]]


def train_epoch_node_batches(
    model,
    dataset,          # HouseGraphDataset
    train_idx,        # tensor of node indices
    optimizer,
    loss_fn_flood,
    loss_fn_landslide,
    device,
    node_batch_size=256,
    steps_per_epoch=200,
):
    """Run ``steps_per_epoch`` optimizer steps on random node batches.

    Each step samples up to ``node_batch_size`` nodes without replacement.
    Only the flood loss is implemented; ``loss_fn_landslide`` is accepted but
    unused. Nodes without a flood label (``y_flood_mask`` False) can therefore
    never contribute, so they are excluded from the sampling pool up front.
    That avoids wasted patch reads and skipped steps.

    Returns the mean loss over the steps that produced a loss (0.0 if none).
    """
    model.train()
    data = dataset.data.to(device)

    flood_mask = getattr(data, "y_flood_mask", None)
    has_flood = hasattr(data, "y_flood")

    pool = train_idx.cpu()
    if has_flood and flood_mask is not None:
        pool = pool[flood_mask.cpu()[pool]]

    total_loss = 0.0
    n_steps = 0

    if pool.numel() == 0:
        return 0.0

    batch_size = min(node_batch_size, pool.numel())

    for _ in tqdm(range(steps_per_epoch), desc="train steps"):
        node_idx = pool[torch.randperm(pool.numel())[:batch_size]].to(device)

        # load patches only for these nodes
        patches = dataset.get_patches_for_node_idx(node_idx.cpu()).to(device)
        flood_logits = model(data, patches, node_idx)

        loss = None
        if has_flood:
            if flood_mask is not None:
                labelled = flood_mask[node_idx]
            else:
                labelled = torch.ones_like(node_idx, dtype=torch.bool)
            if labelled.any():
                loss = loss_fn_flood(flood_logits[labelled], data.y_flood[node_idx][labelled])

        # TODO: landslide loss (loss_fn_landslide) once the model has a landslide head.

        if loss is None:
            continue

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_steps += 1

    return total_loss / max(n_steps, 1)


@torch.no_grad()
def eval_epoch_node_batches(
    model,
    dataset,
    eval_idx,
    loss_fn_flood,
    loss_fn_landslide,
    device,
    node_batch_size=512,
):
    """Evaluate the flood head on ``eval_idx`` in consecutive node chunks.

    Only nodes with a flood label are evaluated; the others cannot affect loss
    or accuracy.

    Returns
    -------
    (loss, flood_acc, ls_acc)
        ``loss`` is the per-node mean of ``loss_fn_flood`` over all labelled
        nodes. Each batch's mean loss is weighted by its node count, so a short
        final batch does not count as much as a full one. ``ls_acc`` stays 0.0
        until a landslide head exists; ``loss_fn_landslide`` is unused.
    """
    model.eval()
    data = dataset.data.to(device)

    flood_mask = getattr(data, "y_flood_mask", None)
    has_flood = hasattr(data, "y_flood")

    loss_sum = 0.0
    flood_correct = 0
    flood_total = 0

    ls_correct = 0
    ls_total = 0

    if has_flood:
        pool = eval_idx.cpu()
        if flood_mask is not None:
            pool = pool[flood_mask.cpu()[pool]]

        for start in tqdm(range(0, pool.numel(), node_batch_size), desc="eval batches"):
            node_idx = pool[start:start + node_batch_size].to(device)

            patches = dataset.get_patches_for_node_idx(node_idx.cpu()).to(device)
            flood_logits = model(data, patches, node_idx)

            y_f = data.y_flood[node_idx]
            n = y_f.numel()
            loss_sum += float(loss_fn_flood(flood_logits, y_f)) * n
            flood_correct += (flood_logits.argmax(dim=1) == y_f).sum().item()
            flood_total += n

    # TODO: landslide metrics once the model has a landslide head.

    flood_acc = flood_correct / flood_total if flood_total > 0 else 0.0
    ls_acc = ls_correct / ls_total if ls_total > 0 else 0.0
    mean_loss = loss_sum / flood_total if flood_total > 0 else 0.0

    return mean_loss, flood_acc, ls_acc


import torch
from torch.optim import Adam

# ---- configuration ----
SEED = 42
REGION = "sogn"
BASE_DIR = "../"
GRAPH_PATH = f"../processed/graphs/knn_graph_{REGION}_k8.pt"
NUM_EPOCHS = 20

# Seed every RNG used below (fallback split, model init, batch sampling).
torch.manual_seed(SEED)
np.random.seed(SEED)

ds = HouseGraphDataset(
    region=REGION,
    base_dir=BASE_DIR,
    graph_path=GRAPH_PATH,
    patches_m=100,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = ds.data

# The dataset already loaded this bundle from GRAPH_PATH; reuse it instead of a second torch.load.
graph = ds.graph
train_idx = graph.get("train_idx", None)
val_idx = graph.get("val_idx", None)
test_idx = graph.get("test_idx", None)

# Fallback if splits were not saved with the graph:
if train_idx is None:
    # simplest fallback using masks if they exist
    if hasattr(data, "train_mask"):
        train_idx = _index_from_mask(data.train_mask)
        val_idx = _index_from_mask(data.val_mask)
        test_idx = _index_from_mask(data.test_mask)
    else:
        # Last-resort random node split (seeded above).
        # NOTE(review): on a kNN graph a random split puts val/test houses next to
        # training houses (spatial leakage); consider a spatially blocked split.
        N = data.num_nodes
        perm = torch.randperm(N)
        n_test = int(0.1 * N)
        n_val = int(0.1 * N)
        test_idx = perm[:n_test]
        val_idx = perm[n_test:n_test + n_val]
        train_idx = perm[n_test + n_val:]

train_idx = train_idx.cpu()
val_idx = val_idx.cpu()
test_idx = test_idx.cpu()

# Model
tabular_in_dim = data.x.size(1)

# At least 3 flood classes; grow if the observed labels need more. Otherwise
# CrossEntropyLoss fails on out-of-range targets.
num_flood_classes = 3
flood_mask_all = getattr(data, "y_flood_mask", None)
if hasattr(data, "y_flood") and flood_mask_all is not None and bool(flood_mask_all.any()):
    num_flood_classes = max(num_flood_classes, int(data.y_flood[flood_mask_all].max()) + 1)

model = HazardGNNModel(
    terrain_in_channels=6,
    terrain_dim=128,
    tabular_in_dim=tabular_in_dim,
    tabular_dim=64,
    gnn_hidden_dim=192,
    gnn_layers=3,
    num_flood_classes=num_flood_classes,
).to(device)

# Losses (keep separate so you can weight later if you want)
loss_f = nn.CrossEntropyLoss()
loss_l = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Train
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = train_epoch_node_batches(
        model=model,
        dataset=ds,
        train_idx=train_idx,
        optimizer=optimizer,
        loss_fn_flood=loss_f,
        loss_fn_landslide=loss_l,
        device=device,
        node_batch_size=256,
        steps_per_epoch=200,
    )

    val_loss, val_f_acc, val_l_acc = eval_epoch_node_batches(
        model=model,
        dataset=ds,
        eval_idx=val_idx,
        loss_fn_flood=loss_f,
        loss_fn_landslide=loss_l,
        device=device,
        node_batch_size=512,
    )

    print(
        f"Epoch {epoch:02d} | "
        f"train_loss={train_loss:.4f} | "
        f"val_loss={val_loss:.4f} | "
        f"val_f_acc={val_f_acc:.3f} | "
        f"val_l_acc={val_l_acc:.3f}"
    )
