### SiNET validation on selected samples

**Author:** Jakub Walczak, PhD

This notebook validates the proposed SiNET method against kriging and
inverse distance weighting (IDW) on a set of validation samples.

The notebook was used to refine the design of SiNET.

In [1]:
import csv
import shutil
from functools import partial
from pathlib import Path
from typing import Any, Callable

import xarray as xr
from rich.console import Console

import climatrix as cm

%load_ext rich

In [2]:
# One rich console shared by every status message in this notebook.
console = Console()

# NaN-handling policy label. It is only reported here; no call in this
# notebook passes it on.
NAN_POLICY = "resample"
console.print("[bold green]Using NaN policy: [/bold green]", NAN_POLICY)

# Global seed; `run_single_method` also re-applies it before every run.
SEED = 1
console.print("[bold green]Using seed: [/bold green]", SEED)

# `data` directory two levels above `__session__` (presumably the notebook
# file path provided by the Jupyter front-end -- confirm for other front-ends).
DSET_PATH = Path(__session__).parent.parent / "data"
console.print("[bold green]Using dataset path: [/bold green]", DSET_PATH)

# Bounding box of Europe (presumably degrees) and a dense grid with step 0.1
# over it, used as the target for full-map reconstructions.
EUROPE_BOUNDS = dict(north=71, south=36, west=-24, east=35)
EUROPE_DOMAIN = cm.Domain.from_lat_lon(
    lat=slice(EUROPE_BOUNDS["south"], EUROPE_BOUNDS["north"], 0.1),
    lon=slice(EUROPE_BOUNDS["west"], EUROPE_BOUNDS["east"], 0.1),
    kind="dense",
)
cm.seed_all(SEED)

In [3]:
def get_all_dataset_idx(dset_path: "Path | str | None" = None) -> list[str]:
    """Return the sorted, de-duplicated dataset identifiers found in a directory.

    Every ``*.nc`` file contributes the part of its stem after the last
    underscore (or the whole stem if it has none). For example,
    ``ecad_obs_europe_train_3.nc`` and ``ecad_obs_europe_val_3.nc`` both
    map to ``"3"``.

    Parameters
    ----------
    dset_path : path-like, optional
        Directory to scan. Defaults to the notebook-level ``DSET_PATH``.

    Returns
    -------
    list[str]
        Unique identifiers, sorted as strings (so ``"10"`` precedes ``"2"``).
    """
    root = DSET_PATH if dset_path is None else Path(dset_path)
    # The set removes duplicates (train and val files share an id); sorted()
    # already returns a list.
    return sorted({path.stem.split("_")[-1] for path in root.glob("*.nc")})

In [4]:
def run_single_method(
    d: str, i: int, method: str, reconstruct_dense: bool = True, **params
):
    """Fit one reconstruction method on dataset ``d`` and apply it.

    The training split is reconstructed onto the validation split's domain,
    with the validation split also passed as ``validation=``. Optionally, it is
    also reconstructed onto the dense ``EUROPE_DOMAIN`` grid.

    Parameters
    ----------
    d : str
        Dataset identifier substituted into
        ``ecad_obs_europe_{train,val}_{d}.nc`` under ``DSET_PATH``.
    i : int
        Unused; kept for call-site compatibility.
    method : str
        Reconstruction method name forwarded to ``reconstruct`` (e.g. ``"sinet"``).
    reconstruct_dense : bool, default True
        Whether to also reconstruct onto ``EUROPE_DOMAIN``.
    **params
        Extra keyword arguments forwarded to both ``reconstruct`` calls.

    Returns
    -------
    tuple
        ``(val_dset, reconstructed_dset, reconstructed_dense)``; the last
        item is ``None`` when ``reconstruct_dense`` is False.
    """
    cm.seed_all(SEED)
    train_dset = xr.open_dataset(DSET_PATH / f"ecad_obs_europe_train_{d}.nc").cm
    val_dset = xr.open_dataset(DSET_PATH / f"ecad_obs_europe_val_{d}.nc").cm
    reconstructed_dset = train_dset.reconstruct(
        val_dset.domain,
        method=method,
        checkpoint="./checkpoint",
        overwrite_checkpoint=True,
        validation=val_dset,
        **params,
    )
    # Initialise so the return statement is valid when reconstruct_dense=False.
    reconstructed_dense = None
    if reconstruct_dense:
        # No overwrite_checkpoint here. Presumably this loads the checkpoint
        # written by the call above; the run log shows a torch.load of it.
        reconstructed_dense = train_dset.reconstruct(
            EUROPE_DOMAIN, method=method, checkpoint="./checkpoint", **params
        )
    return val_dset, reconstructed_dset, reconstructed_dense

In [5]:
dset_idx = get_all_dataset_idx()
# Fail early with a clear message: every later cell indexes `dset_idx[IDX]`,
# which would otherwise raise an unexplained IndexError on an empty directory.
if not dset_idx:
    raise FileNotFoundError(f"No '*.nc' dataset files found in {DSET_PATH}")
console.print(
    f"[bold green]There is [bold yellow]{len(dset_idx)}[/bold yellow] samples available [/bold green]"
)

In [6]:
# Position, in the sorted `dset_idx` list, of the dataset analysed by the cells below.
IDX = 0

In [7]:
# Baseline SiNET run with hand-picked hyperparameters. Every keyword after
# `method` is forwarded unchanged to both `reconstruct` calls.
sinet_val_dset, sinet_reconstructed_dset, sinet_reconstructed_dense = run_single_method(
    d=dset_idx[IDX],
    i=IDX,
    method="sinet",
    lr=1e-3,
    weight_decay=0,
    hidden_dim=32,
    sorting_group_size=16,
    layers=2,
    num_epochs=20,
    batch_size=20,
    num_workers=0,
    device="cuda",
    gradient_clipping_value=1e2,
    mse_loss_weight=2.0,
    eikonal_loss_weight=1e-4,
    laplace_loss_weight=1e-4,
    patience=None,
)

10-09-2025 15:23:04 INFO | climatrix.reconstruct.nn.base_nn | Using checkpoint path: /home/jakub/projects/climatrix/experiments/jwalczak/01_Apr_02_compare_recon_method/notebooks/checkpoint
10-09-2025 15:23:04 INFO | climatrix.reconstruct.sinet.sinet | Initializing SiNET model...
10-09-2025 15:23:04 INFO | climatrix.reconstruct.sinet.sinet | Configuring Adam optimizer with learning rate: 0.001000
10-09-2025 15:23:09 INFO | climatrix.reconstruct.nn.base_nn | Training SiNET model...
10-09-2025 15:23:10 INFO | climatrix.reconstruct.nn.base_nn | Validation dataset is available. Using it for validation.


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


10-09-2025 15:23:21 INFO | climatrix.reconstruct.nn.base_nn | Reconstructing target domain...
10-09-2025 15:23:21 INFO | climatrix.reconstruct.nn.base_nn | Creating mini-batches for surface reconstruction...
10-09-2025 15:23:21 INFO | climatrix.reconstruct.nn.base_nn | Processing mini-batch 1/1...
10-09-2025 15:23:21 INFO | climatrix.reconstruct.nn.base_nn | Surface finding complete. Concatenating results.
10-09-2025 15:23:21 INFO | climatrix.reconstruct.nn.base_nn | Using checkpoint path: /home/jakub/projects/climatrix/experiments/jwalczak/01_Apr_02_compare_recon_method/notebooks/checkpoint
10-09-2025 15:23:21 INFO | climatrix.reconstruct.sinet.sinet | Initializing SiNET model...
10-09-2025 15:23:21 INFO | climatrix.reconstruct.nn.base_nn | Reconstructing target domain...
10-09-2025 15:23:21 INFO | climatrix.reconstruct.nn.base_nn | Creating mini-batches for surface reconstruction...


  torch.load(checkpoint, map_location=self.device)


10-09-2025 15:23:22 INFO | climatrix.reconstruct.nn.base_nn | Processing mini-batch 1/5...
10-09-2025 15:23:22 INFO | climatrix.reconstruct.nn.base_nn | Processing mini-batch 2/5...
10-09-2025 15:23:22 INFO | climatrix.reconstruct.nn.base_nn | Processing mini-batch 3/5...
10-09-2025 15:23:23 INFO | climatrix.reconstruct.nn.base_nn | Processing mini-batch 4/5...
10-09-2025 15:23:23 INFO | climatrix.reconstruct.nn.base_nn | Processing mini-batch 5/5...
10-09-2025 15:23:23 INFO | climatrix.reconstruct.nn.base_nn | Surface finding complete. Concatenating results.


In [8]:
# Validation metrics for the baseline run above.
# NOTE(review): the arguments are (validation, reconstruction) here but
# (reconstruction, validation) in `run_single_experiment`. RMSE/MAE/max error
# are symmetric, but R^2 is not -- confirm which order `cm.Comparison` expects.
cm.Comparison(sinet_val_dset, sinet_reconstructed_dset).compute_report()


[1m{[0m
    [32m'RMSE'[0m: [1;36m1.308415174484253[0m,
    [32m'MAE'[0m: [1;36m0.8040948510169983[0m,
    [32m'Max Abs Error'[0m: [1;36m5.594202041625977[0m,
    [32m'R^2'[0m: [1;36m0.8809424042701721[0m
[1m}[0m

### After optimising hyperparameters

In [None]:
# Hyperparameter search space. Tuples are presumably (low, high) ranges and the
# list for `hidden_dim` a set of discrete choices -- confirm against
# `cm.optim.HParamFinder`.
# NOTE(review): the batch_size upper bound 2044 looks like a typo for 2048.
BOUNDS = dict(
    lr=(1e-5, 1e-2),
    num_epochs=(400, 1000),
    batch_size=(2, 2044),
    mse_loss_weight=(1e-8, 1e3),
    eikonal_loss_weight=(0, 1e-2),
    laplace_loss_weight=(0, 1e-2),
    patience=(10, 200),
    scale=(0.01, 10.0),
    hidden_dim=[16, 64, 128, 256],
    weight_decay=(0.0, 1e-2),
)
console.print("[bold green]Hyperparameter bounds: [/bold green]", BOUNDS)

# NOTE(review): passed as `n_init_points` to `find_hyperparameters`, which
# does not forward it to `HParamFinder`, so it currently has no effect.
OPTIM_INIT_POINTS: int = 50
console.print(
    "[bold green]Using nbr initial points for optimization: [/bold green]",
    OPTIM_INIT_POINTS,
)

OPTIM_N_ITERS: int = 100
console.print(
    "[bold green]Using iterations for optimization[/bold green]",
    OPTIM_N_ITERS,
)
console.print("[bold green]Dataset: [/bold green]", dset_idx[IDX])

In [None]:
def find_hyperparameters(
    train_dset: "cm.BaseClimatrixDataset",
    val_dset: "cm.BaseClimatrixDataset",
    bounds: dict[str, Any],
    n_init_points: int = 30,
    n_iter: int = 200,
    seed: int = 0,
    verbose: int = 2,
) -> dict[str, Any]:
    """Search SiNET hyperparameters that minimise MAE on the validation split.

    Parameters
    ----------
    train_dset, val_dset : cm.BaseClimatrixDataset
        Training and validation datasets handed to ``cm.optim.HParamFinder``.
    bounds : dict
        Search space forwarded as ``bounds=``.
    n_init_points : int
        Accepted for call compatibility but not forwarded. The
        ``HParamFinder`` call below has no such argument; TODO confirm
        whether climatrix exposes an equivalent.
    n_iter : int
        Number of optimisation iterations, forwarded as ``n_iters=``.
    seed : int
        Random seed, forwarded as ``random_seed=``.
    verbose : int
        Accepted for call compatibility but not forwarded.

    Returns
    -------
    dict
        Whatever ``HParamFinder.optimize()`` returns. Callers in this notebook
        index it with the keys ``"best_params"`` and ``"best_score"``.
    """
    # Use the arguments rather than the notebook globals (BOUNDS,
    # OPTIM_N_ITERS, SEED) so the function honours what callers pass.
    finder = cm.optim.HParamFinder(
        "sinet",
        train_dset,
        val_dset,
        metric="mae",
        n_iters=n_iter,
        bounds=bounds,
        random_seed=seed,
    )
    return finder.optimize()


def run_single_experiment(d: str):
    """Tune SiNET on dataset ``d``, refit with the best values, and score it.

    Steps:
      1. Load ``ecad_obs_europe_train_{d}.nc`` / ``ecad_obs_europe_val_{d}.nc``
         from ``DSET_PATH``.
      2. Search hyperparameters with ``find_hyperparameters`` over ``BOUNDS``.
      3. Print each optimised value and the best optimisation score.
      4. Reconstruct the validation domain with those values and compare the
         result against the validation split.

    Returns
    -------
    tuple[dict, dict]
        ``(metrics, hyperparams)``: the comparison report with ``"dataset_id"``
        added, and a flat record of the dataset id, the chosen hyperparameters
        and ``"opt_loss"`` (the optimiser's best score).
    """
    train_dset = xr.open_dataset(DSET_PATH / f"ecad_obs_europe_train_{d}.nc").cm
    val_dset = xr.open_dataset(DSET_PATH / f"ecad_obs_europe_val_{d}.nc").cm

    result = find_hyperparameters(
        train_dset,
        val_dset,
        BOUNDS,
        n_init_points=OPTIM_INIT_POINTS,
        n_iter=OPTIM_N_ITERS,
        seed=SEED,
        verbose=2,
    )
    best = result["best_params"]

    # (console label, hyperparameter key), in the order the values are both
    # printed and recorded in the returned hyperparams dict.
    reported = (
        ("[yellow]Learning rate (lr):[/yellow]", "lr"),
        ("[yellow]Number of epochs:[/yellow]", "num_epochs"),
        ("[yellow]Scale:[/yellow]", "scale"),
        ("[yellow]Batch size:[/yellow]", "batch_size"),
        ("[yellow]MSE loss weight:[/yellow]", "mse_loss_weight"),
        ("[yellow]Eikonal loss weight:[/yellow]", "eikonal_loss_weight"),
        ("[yellow]Laplace loss weight:[/yellow]", "laplace_loss_weight"),
        ("[yellow]Early stopping patience:[/yellow]", "patience"),
        ("[yellow]Hidden dimension:[/yellow]", "hidden_dim"),
        ("[yellow]Weight decay:[/yellow]", "weight_decay"),
    )
    console.print("[bold yellow]Optimized parameters:[/bold yellow]")
    for label, key in reported:
        console.print(label, best[key])
    console.print("[yellow]Best loss:[/yellow]", result["best_score"])

    reconstructed_dset = train_dset.reconstruct(
        val_dset.domain,
        method="sinet",
        device="cuda",
        lr=best["lr"],
        weight_decay=best["weight_decay"],
        num_epochs=best["num_epochs"],
        batch_size=best["batch_size"],
        num_workers=0,
        scale=best["scale"],
        mse_loss_weight=best["mse_loss_weight"],
        eikonal_loss_weight=best["eikonal_loss_weight"],
        laplace_loss_weight=best["laplace_loss_weight"],
        patience=best["patience"],
        hidden_dim=best["hidden_dim"],
        checkpoint="./sinet_checkpoint.pth",
        overwrite_checkpoint=True,
    )
    # NOTE(review): the order here is (reconstruction, validation), the reverse
    # of the other `cm.Comparison` calls in this notebook. Symmetric metrics are
    # unaffected but R^2 is not -- confirm the intended order.
    metrics = cm.Comparison(reconstructed_dset, val_dset).compute_report()
    metrics["dataset_id"] = d

    hyperparams = {"dataset_id": d}
    hyperparams.update((key, best[key]) for _, key in reported)
    hyperparams["opt_loss"] = result["best_score"]
    return metrics, hyperparams

In [None]:
# Hyperparameter search plus refit for the selected dataset; `hyperparams`
# supplies the tuned values used by the re-run in the next cell.
metrics, hyperparams = run_single_experiment(dset_idx[IDX])

In [None]:
# Re-run SiNET with the optimised hyperparameters. Forward every tuned value
# (the same set `run_single_experiment` uses), so scale, the loss weights and
# patience are not silently left at whatever defaults apply when omitted.
sinet_val_dset, sinet_reconstructed_dset, sinet_reconstructed_dense = (
    run_single_method(
        dset_idx[IDX],
        IDX,
        "sinet",
        lr=hyperparams["lr"],
        weight_decay=hyperparams["weight_decay"],
        num_epochs=hyperparams["num_epochs"],
        batch_size=hyperparams["batch_size"],
        num_workers=0,
        device="cuda",
        scale=hyperparams["scale"],
        mse_loss_weight=hyperparams["mse_loss_weight"],
        eikonal_loss_weight=hyperparams["eikonal_loss_weight"],
        laplace_loss_weight=hyperparams["laplace_loss_weight"],
        patience=hyperparams["patience"],
        hidden_dim=hyperparams["hidden_dim"],
    )
)

In [None]:
# Tuned SiNET reconstruction over the dense EUROPE_DOMAIN grid.
sinet_reconstructed_dense.plot()

In [None]:
# Raw training observations for the selected dataset. Bind the dataset itself
# (not the return value of `.plot()`) so `train_dset` refers to data; the
# trailing `;` suppresses the plot call's return-value repr.
train_dset = xr.open_dataset(
    DSET_PATH / f"ecad_obs_europe_train_{dset_idx[IDX]}.nc"
).cm
train_dset.plot();

In [None]:
# Tuned SiNET reconstruction on the validation split's domain.
sinet_reconstructed_dset.plot()

In [None]:
# Validation metrics for the tuned re-run.
# NOTE(review): the arguments are (validation, reconstruction) here but
# (reconstruction, validation) in `run_single_experiment`; R^2 depends on the
# order -- confirm which one `cm.Comparison` expects.
cm.Comparison(sinet_val_dset, sinet_reconstructed_dset).compute_report()