# Analysis of the Inverse Distance Weighting (IDW) reconstruction method

**Author:** Jakub Walczak, PhD

In [9]:
import os
import sys
from pathlib import Path
import importlib
from functools import partial

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr
from bayes_opt import BayesianOptimization
import climatrix as cm
from climatrix.dataset.axis import Axis
from rich.console import Console

# NOTE: prepend the experiment root (two directories above the current
# working directory, i.e. where Jupyter was started) to sys.path so that the
# local `src` package can be imported; adjust if the notebook is run elsewhere.
sys.path.insert(0, str(Path(os.path.abspath("")).parent.parent))
from src.hyperparam import find_hyperparameters

In [10]:
# Shared rich console; its status spinner reports progress of the experiment loop.
console: Console = Console()

In [54]:
NAN_POLICY = "resample"

SEED = 1

# NOTE: paths are resolved relative to the directory Jupyter is started from,
# taken to be two levels below the experiment root; adjust if needed.
EXPERIMENT_ROOT: Path = Path(os.path.abspath("")).parent.parent

DSET_PATH: Path = EXPERIMENT_ROOT / "data"

# Budget passed to find_hyperparameters as n_init_points / n_iter.
OPTIM_INIT_POINTS: int = 50
OPTIM_N_ITERS: int = 10

SAMPLE_VERTICAL: int = 0

PLOT_DIR: Path = EXPERIMENT_ROOT / "results" / "idw" / "plots"
# parents=True also creates results/idw/, the directory METRICS_PATH lives in.
PLOT_DIR.mkdir(parents=True, exist_ok=True)

METRICS_PATH: Path = EXPERIMENT_ROOT / "results" / "idw" / "metrics.csv"

In [45]:
# Seed climatrix's random state with the experiment-wide SEED for reproducibility.
cm.seed_all(SEED)

In [46]:
def compute_criterion(train_dset, val_dset, **hparams) -> float:
    """Score one IDW hyper-parameter configuration on the validation split.

    The score is the negated "Max Abs Error" between the IDW reconstruction
    of ``train_dset`` on ``val_dset.domain`` and ``val_dset`` itself. It is
    negated so that a maximising optimiser minimises the error.

    Args:
        train_dset: climatrix dataset (``.cm`` accessor) with known samples.
        val_dset: climatrix dataset whose domain is reconstructed and whose
            values are the reference.
        **hparams: must provide ``k_min`` and ``k`` (truncated with ``int``)
            and ``power`` (converted with ``float``).

    Returns:
        ``-100`` when ``k_min > k``, otherwise the negated maximum absolute
        error.
    """
    neighbours_min = int(hparams["k_min"])
    neighbours = int(hparams["k"])
    exponent = float(hparams["power"])

    # Asking for a larger minimum than the neighbour count is presumably
    # infeasible; return a fixed penalty instead of reconstructing.
    if neighbours_min > neighbours:
        return -100

    reconstruction = train_dset.reconstruct(
        val_dset.domain,
        method="idw",
        k=neighbours,
        power=exponent,
        k_min=neighbours_min,
    )
    comparison = cm.Comparison(
        reconstruction, val_dset, map_nan_from_source=False
    )
    report = comparison.compute_report()
    worst_error = report["Max Abs Error"]
    # NOTE: negated so the optimiser's maximisation minimises the error.
    return -worst_error

In [47]:
# Search space for the IDW hyper-parameters as (lower, upper) pairs handed to
# find_hyperparameters. k and k_min are presumably sampled as floats (hence the
# int() casts in compute_criterion); combinations with k_min > k are penalised
# there rather than excluded here.
bounds = dict(
    k=(1, 50),
    power=(1e-7, 5.0),
    k_min=(1, 40),
)

In [48]:
def optimize_hyperparameters(train_dset, val_dset):
    """Search IDW hyper-parameters using the validation split.

    Runs ``find_hyperparameters`` over ``bounds`` with ``compute_criterion``
    as the objective. The search is seeded with ``SEED`` and its budget is
    ``OPTIM_INIT_POINTS`` / ``OPTIM_N_ITERS``.

    Args:
        train_dset: climatrix dataset providing the known samples.
        val_dset: climatrix dataset used to score each candidate.

    Returns:
        Tuple ``(power, k, k_min)``. ``power`` is a ``float``; ``k`` and
        ``k_min`` are ``int``, converted exactly as ``compute_criterion``
        converted them when scoring.
    """
    result = find_hyperparameters(
        train_dset,
        val_dset,
        compute_criterion,
        bounds=bounds,
        n_init_points=OPTIM_INIT_POINTS,
        n_iter=OPTIM_N_ITERS,
        seed=SEED,
        verbose=0,
    )
    best = result["params"]

    # compute_criterion scored every candidate with int() truncation of k and
    # k_min. Apply the same conversion so the final reconstruction uses the
    # configuration that was actually evaluated. If the optimiser returns
    # floats, passing them through raw would hand non-integer neighbour
    # counts to reconstruct().
    optim_power = float(best["power"])
    optim_k = int(best["k"])
    optim_k_min = int(best["k_min"])

    return optim_power, optim_k, optim_k_min

In [49]:
# One entry per date. Each date has train/val/test files
# (ecad_obs_europe_{split}_{date}.nc), so globbing "*.nc" would list every
# date three times and rerun the whole experiment for it. Derive the dates from
# the train files only, and sort them because glob order is arbitrary.
dset_dates = sorted(
    path.stem.split("_")[-1]
    for path in DSET_PATH.glob("ecad_obs_europe_train_*.nc")
)

In [None]:
def _load_split(split: str, date: str):
    """Load the ECAD ``split`` ("train"/"val"/"test") for ``date``.

    ``xr.load_dataset`` reads the file into memory and closes it, so no
    NetCDF handles stay open across loop iterations. Returns the ``.cm``
    accessor of the loaded dataset.
    """
    path = DSET_PATH / f"ecad_obs_europe_{split}_{date}.nc"
    return xr.load_dataset(path).cm


def _save_and_close(figure, path: Path) -> None:
    """Save ``figure`` to ``path``, then close it."""
    figure.savefig(path)
    # Figures created through pyplot stay registered (and in memory) until
    # closed. With four plots per date they would otherwise pile up for the
    # whole run.
    plt.close(figure)


with console.status("[magenta]Preparing experiment...") as status:
    all_metrics = {}
    for d in dset_dates:
        status.update(
            f"[magenta]Processing date: {d}...", spinner="bouncingBall"
        )
        train_dset = _load_split("train", d)
        val_dset = _load_split("val", d)
        test_dset = _load_split("test", d)

        status.update(
            f"[magenta]Optimizing hyper-parameters for date: {d}...",
            spinner="bouncingBall",
        )
        power, k, k_min = optimize_hyperparameters(train_dset, val_dset)

        status.update(
            "[magenta]Reconstructing with optimised parameters...",
            spinner="bouncingBall",
        )
        reconstructed_dset = train_dset.reconstruct(
            test_dset.domain,
            method="idw",
            k=k,
            power=power,
            k_min=k_min,
        )

        # Compute each output path once and reuse it for both the status
        # message and the save call.
        recon_png = PLOT_DIR / f"{d}_reconstructed.png"
        status.update(
            f"[magenta]Saving reconstructed dset to {recon_png}...",
            spinner="bouncingBall",
        )
        _save_and_close(reconstructed_dset.plot(False).get_figure(), recon_png)

        test_png = PLOT_DIR / f"{d}_test.png"
        status.update(
            f"[magenta]Saving test dset to {test_png}...",
            spinner="bouncingBall",
        )
        _save_and_close(test_dset.plot(False).get_figure(), test_png)

        status.update("[magenta]Evaluating...", spinner="bouncingBall")
        cmp = cm.Comparison(reconstructed_dset, test_dset)
        _save_and_close(
            cmp.plot_diff().get_figure(), PLOT_DIR / f"{d}_diffs.png"
        )
        _save_and_close(
            cmp.plot_signed_diff_hist().get_figure(), PLOT_DIR / f"{d}_hist.png"
        )
        all_metrics[d] = cmp.compute_report()

    status.update("[magenta]Saving quality metrics...", spinner="bouncingBall")
    # Transposed so that dates become rows (assumes compute_report returns a
    # flat metric-name -> value mapping — confirm against climatrix).
    pd.DataFrame(all_metrics).transpose().to_csv(METRICS_PATH)