# Evaluation

In [4]:
from torch.utils.data import DataLoader # type: ignore
import torch.optim as optim # type: ignore
from torch.optim.lr_scheduler import ReduceLROnPlateau # type: ignore
from torch import nn # type: ignore

from hython.models.cudnnLSTM import CuDNNLSTM
from hython.datasets.datasets import get_dataset
from hython.sampler import *
from hython.normalizer import Normalizer
from hython.trainer import *
from hython.utils import read_from_zarr, missing_location_idx, set_seed, prepare_for_plotting
from hython.evaluator import predict
from hython.trainer import train_val

import matplotlib.pyplot as plt
from hython.viz import map_kge, map_bias, map_pbias, map_pearson, map_rmse, map_at_timesteps, ts_compare, plot_sampler, compute_kge_parallel, ts_plot, map_points

  from pandas.core.computation.check import NUMEXPR_INSTALLED
ERROR 1: PROJ: proj_create_from_database: Open of /home/iferrario/.local/miniforge/envs/emulator/share/proj failed


# Settings

In [5]:
## inputs

# wflow model name, i.e. surrogate input file produced by the preprocessing application

file_surr_input = "https://eurac-eo.s3.amazonaws.com/INTERTWIN/SURROGATE_INPUT/adg1km_eobs_preprocessed.zarr/"

surr_model =  "test.pt"
experiment =  "test"  


dir_wflow_model = "adg1km_eobs"
file_target = "run_default/output.nc"

# input directory 
dir_surr_input = "/mnt/CEPH_PROJECTS/InterTwin/hydrologic_data/surrogate_input"
dir_surr_model = "/home/iferrario/dev/itwinai/use-cases/eurac/tmp"
dir_wflow_input = "/mnt/CEPH_PROJECTS/InterTwin/Wflow/models"

## outputs

# directory to save the statistics computed during the normalization
dir_stats_output = "/home/iferrario/dev/itwinai/use-cases/eurac/tmp"

# === FILTER ==============================================================

# select temporal range
train_temporal_range = ["2016-01-01", "2018-12-31"] 
valid_temporal_range = ["2019-01-01", "2020-12-31"]

# select variable names
dynamic_names = ["precip", "pet", "temp"]
static_names = [ "wflow_dem", "Slope", "wflow_uparea"]  #[ 'thetaS', 'thetaR', 'RootingDepth', 'Swood','KsatVer', "Sl"] 
target_names = [ "runoff_river"]

# === MASK ========================================================================================

mask_names = ["mask_missing", "mask_lake"] # names depends on preprocessing application

# == MODEL  ========================================================================================

# DL model hyper parameters
HIDDEN_SIZE = 24
DYNAMIC_INPUT_SIZE = len(dynamic_names)
STATIC_INPUT_SIZE = len(static_names)
OUTPUT_SIZE = len(target_names)
TARGET_WEIGHTS = {t:0.5 for t in target_names}

BATCH = 256

# === METRICS =====================================================================================

metrics = { 
  "vwc": ["rmse", "kge",  "pbias"],
  "actevap": ["rmse", "kge",  "pbias"]
}

In [6]:

file_surr_model = f"{dir_surr_model}/{surr_model}"

file_wflow_target = f"{dir_wflow_input}/{dir_wflow_model}/{file_target}"


In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

train_temporal_range = slice(*train_temporal_range)
valid_temporal_range = slice(*valid_temporal_range)

In [8]:
# === READ TRAIN ============================================================= 

Xs = read_from_zarr(url=file_surr_input, group="xs", multi_index="gridcell").xs.sel(
     feat=static_names
 )


# === READ TEST ============================================================= 

Xd_test = (
    read_from_zarr(url=file_surr_input, group="xd", multi_index="gridcell")
    .sel(time=valid_temporal_range)
    .xd.sel(feat=dynamic_names)
)
Y_test = (
    read_from_zarr(url=file_surr_input, group="y", multi_index="gridcell")
    .sel(time=valid_temporal_range)
    .y.sel(feat=target_names)
)

In [9]:
# === MASK ============================================================= 
masks = (
    read_from_zarr(url=file_surr_input, group="mask")
    .mask.sel(mask_layer=mask_names)
    .any(dim="mask_layer")
)


In [10]:
# === NORMALIZE ============================================================= 

normalizer_dynamic = Normalizer(method="standardize", type="spacetime", axis_order="NTC")

normalizer_static = Normalizer(method="standardize", type="space", axis_order="NTC")

normalizer_target = Normalizer(method="standardize", type="spacetime", axis_order="NTC")


normalizer_dynamic.read_stats(f"{dir_stats_output}/xd.npy")
normalizer_static.read_stats(f"{dir_stats_output}/xs.npy")
normalizer_target.read_stats(f"{dir_stats_output}/y.npy")


Xd_test = normalizer_dynamic.normalize(Xd_test)
Xs = normalizer_static.normalize(Xs)
Y_test = normalizer_target.normalize(Y_test)

read from /home/iferrario/dev/itwinai/use-cases/eurac/tmp/xd.npy
read from /home/iferrario/dev/itwinai/use-cases/eurac/tmp/xs.npy
read from /home/iferrario/dev/itwinai/use-cases/eurac/tmp/y.npy


ValueError: operands could not be broadcast together with shapes (40140,3) (1,6) 

In [None]:
# ==== MODEL ============================================================================

model = CuDNNLSTM(
                  hidden_size=HIDDEN_SIZE, 
                  dynamic_input_size=DYNAMIC_INPUT_SIZE,
                  static_input_size=STATIC_INPUT_SIZE, 
                  output_size=OUTPUT_SIZE
)

model.to(device)

# model load precomputed weights 
print(f"loading model {file_surr_model}")
model.load_state_dict(torch.load(file_surr_model))

In [None]:
# === PREDICT =============================================================================

ds_target = xr.open_dataset(file_wflow_target, chunks= {"time":200}).isel(lat=slice(None, None, -1)).sel(layer=1, drop=True)

lat, lon, time = len(masks.lat),len(masks.lon), Xd_test.shape[1]

y_pred = predict(Xd_test.values, Xs.values, model, BATCH, device)


y_pred = normalizer_target.denormalize(y_pred)

Y_test = normalizer_target.denormalize( Y_test)

In [None]:
# === EVALUATE ==============================================================================================

for iv, var in enumerate(target_names):
    print(var)
    metrics_var = metrics.copy().pop(var)

    y_target_plot, y_pred_plot = prepare_for_plotting(y_target= Y_test[:,:,[iv]].values,
                                                y_pred = y_pred[:,:,[iv]], 
                                                shape = (lat, lon, time), 
                                                coords  = ds_target.sel(time=valid_temporal_range).coords)

    y_target_plot= y_target_plot.where(~masks.values[...,None])
    y_pred_plot = y_pred_plot.where(~masks.values[...,None])
    
    ts_compare(y_target_plot, y_pred_plot, lat = [46.4], lon = [11.4])    
    
    for metric in metrics_var:
        print(metric)
        if "rmse" in metric:
            fig, ax, rmse = map_rmse(y_target_plot, y_pred_plot, unit = f"{var} (mm)", figsize = (8, 8), return_rmse=True, title=f"{var} {metric}")
        elif "kge" in metric:
            fig, ax, kge = map_kge(y_target_plot, y_pred_plot, figsize = (8, 8), return_kge =True, kwargs_imshow={"vmin":-0.5, "vmax":1},
            ticks = np.linspace(-0.5, 1, 16), title=f"{var} {metric}")
        elif "pbias" in metric:
            fig, ax, pbias = map_pbias(y_target_plot, y_pred_plot, figsize = (8, 8), return_pbias=True, kwargs_imshow={"vmin":-100, "vmax":100}, 
                ticks = [l*10 for l in range(-10,11, 1)], title=f"{var} {metric}")
        else:
            print(f"{metric} not found")
            continue