# Analysis of ODE-LSTM Results

This notebook reproduces the comparison between ODE-LSTM and MTS-LSTM reported in the paper.
To run it, place the models' predictions (downloaded or generated yourself) in the folder `BASE_DIR`.

`README.md` contains information on where to obtain the required data.

In [1]:
import pickle
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr
from scipy.stats import wilcoxon
from tqdm.notebook import tqdm

from neuralhydrology.evaluation.metrics import calculate_metrics

BASE_DIR = Path('/home/mgauch/mts-lstm/results/odelstm')

basins = ['01022500', '02064000', '02374500', '05593575', '06404000', '06889500', '08190000', '09352900', '11481200', '12189500']
metric = 'NSE'
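
The `metric` selected above, NSE, is the Nash-Sutcliffe efficiency: one minus the ratio of the squared simulation error to the squared deviation of the observations from their mean. The following is a minimal, dependency-free sketch of that standard formula, not the `neuralhydrology` implementation. The helper name `nse` is hypothetical, and skipping pairs with missing values is an assumption made for illustration.

```python
import math


def nse(obs, sim):
    """Nash-Sutcliffe efficiency: 1 - SSE / squared deviation of obs from its mean.

    Hypothetical helper for illustration; pairs containing NaN are skipped.
    """
    pairs = [(o, s) for o, s in zip(obs, sim)
             if not (math.isnan(o) or math.isnan(s))]
    if not pairs:
        return math.nan
    mean_obs = sum(o for o, _ in pairs) / len(pairs)
    sse = sum((o - s) ** 2 for o, s in pairs)
    spread = sum((o - mean_obs) ** 2 for o, _ in pairs)
    if spread == 0:
        return math.nan  # NSE is undefined for constant observations
    return 1.0 - sse / spread


observed = [1.0, 2.0, 3.0, 4.0]
print(nse(observed, observed))   # perfect simulation -> 1.0
print(nse(observed, [2.5] * 4))  # always predicting the observed mean -> 0.0
```

A value of 1 indicates a perfect fit, while 0 means the simulation is no better than predicting the mean of the observations.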

## Preparation
### Load predictions and metrics for each model ensemble

In [2]:
# (a) train on 1D+12H, evaluate on 1H (forward-filling each 12H-prediction to hourly steps)
# (b) train on 1H+3H, evaluate on 1D (aggregating every 8 3H-predictions)
# (c) train on 1H+1D, evaluate on 1H+1D
def load_ensemble(experiment, model, basin):
    """Load one basin's pickled ensemble results and close the file afterwards."""
    with open(BASE_DIR / f'ensemble_{model}_{experiment}_{basin}.p', 'rb') as fp:
        return pickle.load(fp)[basin]

a_mtslstm, b_mtslstm = {}, {}
a_odelstm, b_odelstm = {}, {}
for b in basins:
    # MTS-LSTM predictions (single-basin)
    a_mtslstm[b] = load_ensemble('a', 'mtslstm', b)
    b_mtslstm[b] = load_ensemble('b', 'mtslstm', b)

    # ODE-LSTM (single-basin)
    a_odelstm[b] = load_ensemble('a', 'odelstm', b)
    b_odelstm[b] = load_ensemble('b', 'odelstm', b)

### (Dis-)aggregate MTS-LSTM predictions to missing timescales and calculate metrics

In [3]:
for basin in tqdm(basins):
    a_mtslstm[basin]['1H']['xr'] = a_mtslstm[basin]['12H']['xr'].resample({'datetime': '1H'}).ffill()
    b_mtslstm[basin]['1D']['xr'] = b_mtslstm[basin]['3H']['xr'].resample({'datetime': '1D'}).mean()
    a_mtslstm[basin]['1H']['xr']['qobs_mm_per_hour_obs'] = b_mtslstm[basin]['1H']['xr']['qobs_mm_per_hour_obs']
    b_mtslstm[basin]['1D']['xr']['qobs_mm_per_hour_obs'] = a_mtslstm[basin]['1D']['xr']['qobs_mm_per_hour_obs']

    a_mtslstm[basin]['1H'][f'{metric}_1H'] = calculate_metrics(a_mtslstm[basin]['1H']['xr']['qobs_mm_per_hour_obs'],
                                                               a_mtslstm[basin]['1H']['xr']['qobs_mm_per_hour_sim'],
                                                               [metric], resolution='1H')[metric]
    b_mtslstm[basin]['1D'][f'{metric}_1D'] = calculate_metrics(b_mtslstm[basin]['1D']['xr']['qobs_mm_per_hour_obs'],
                                                               b_mtslstm[basin]['1D']['xr']['qobs_mm_per_hour_sim'],
                                                               [metric], resolution='1D')[metric]
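
The resampling step above can be pictured on plain lists. This is a rough sketch, assuming each coarse prediction is a mean rate (mm per hour) over its interval, so that disaggregation repeats the value and aggregation averages it. It ignores timestamp alignment and missing values, which the `xarray` calls handle, and it drops incomplete trailing groups as a simplification. The helper names `disaggregate` and `aggregate_mean` are hypothetical.

```python
def disaggregate(values, factor):
    """Repeat each coarse value `factor` times (forward fill to a finer step)."""
    return [v for v in values for _ in range(factor)]


def aggregate_mean(values, factor):
    """Average consecutive, complete groups of `factor` fine values."""
    n_groups = len(values) // factor  # drop an incomplete trailing group
    return [sum(values[i * factor:(i + 1) * factor]) / factor
            for i in range(n_groups)]


# Two 12-hourly rates (mm/h) become 24 hourly rates.
hourly = disaggregate([0.5, 1.5], 12)
print(len(hourly), hourly[0], hourly[12])  # 24 0.5 1.5

# Eight 3-hourly rates (mm/h) become one daily mean rate.
daily = aggregate_mean([1.0, 2.0, 3.0, 4.0, 4.0, 3.0, 2.0, 1.0], 8)
print(daily)  # [2.5]
```

Because the values are rates per hour, repeating them at a finer step leaves the unit unchanged.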

In [4]:
def to_df(dict_mtslstm, dict_odelstm):
    df = pd.DataFrame.from_dict({**{('MTS-LSTM', basin): {f'{metric}_{f}': dict_mtslstm[basin][f][f'{metric}_{f}']
                                                          for f in dict_mtslstm[basin]} for basin in basins}, 
                                 **{('ODE-LSTM', basin): {f'{metric}_{f}': dict_odelstm[basin][f][f'{metric}_{f}']
                                                          for f in dict_odelstm[basin]} for basin in basins}},
                               orient='index')
    return df

a_df = to_df(a_mtslstm, a_odelstm)
b_df = to_df(b_mtslstm, b_odelstm)

In [5]:
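# The `level` argument of `median`/`mean` was removed in pandas 2.0;
# grouping by the first index level (the model name) gives the same summary.
print('(A) Train on 1D, 12H. Evaluate on 1H. Medians:')
display(a_df.groupby(level=0).median())
print('   Means')
display(a_df.groupby(level=0).mean())

print('(B) Train on 1H, 3H. Evaluate on 1D. Medians:')
display(b_df.groupby(level=0).median())
print('   Means')
display(b_df.groupby(level=0).mean())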

(A) Train on 1D, 12H. Evaluate on 1H. Medians:

| Model | NSE_1D | NSE_12H | NSE_1H |
|---|---|---|---|
| MTS-LSTM | 0.726355 | 0.734082 | 0.705742 |
| ODE-LSTM | 0.719632 | 0.705877 | 0.63906 |

Means

| Model | NSE_1D | NSE_12H | NSE_1H |
|---|---|---|---|
| MTS-LSTM | 0.664499 | 0.67229 | 0.633878 |
| ODE-LSTM | 0.65077 | 0.638408 | 0.591831 |

(B) Train on 1H, 3H. Evaluate on 1D. Medians:

| Model | NSE_1H | NSE_3H | NSE_1D |
|---|---|---|---|
| MTS-LSTM | 0.700421 | 0.727794 | 0.745852 |
| ODE-LSTM | 0.677459 | 0.674996 | 0.586775 |

Means

| Model | NSE_1H | NSE_3H | NSE_1D |
|---|---|---|---|
| MTS-LSTM | 0.633374 | 0.672315 | 0.718217 |
| ODE-LSTM | 0.585862 | 0.592676 | 0.54601 |
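
The summaries above group the basin-level scores by model and take the median and mean across basins. As a dependency-free sketch of that grouping, the snippet below uses invented toy scores and placeholder basin names (none of these numbers come from the notebook's results).

```python
from statistics import mean, median

# Toy NSE scores keyed by (model, basin); the numbers are invented for illustration.
scores = {
    ('MTS-LSTM', 'basin_1'): 0.80,
    ('MTS-LSTM', 'basin_2'): 0.70,
    ('MTS-LSTM', 'basin_3'): 0.30,
    ('ODE-LSTM', 'basin_1'): 0.75,
    ('ODE-LSTM', 'basin_2'): 0.60,
    ('ODE-LSTM', 'basin_3'): 0.15,
}

# Collect the scores of each model, mirroring a group-by on the first index level.
by_model = {}
for (model, _basin), value in scores.items():
    by_model.setdefault(model, []).append(value)

summary = {model: {'median': median(vals), 'mean': mean(vals)}
           for model, vals in by_model.items()}
for model, stats in summary.items():
    print(model, stats)
```

In the toy data a single poorly modeled basin pulls the mean well below the median, which is one reason to report both statistics.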
