In [22]:
%load_ext autoreload
%autoreload 2

# Import the display helpers from the public ``IPython.display`` module;
# importing ``display`` from ``IPython.core.display`` is deprecated.
from IPython.display import HTML, Image, display

# Inject CSS so elements with the ``container`` class span the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as np
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from timecast.learners import AR
from timecast.learners._ar import _ar_predict, _ar_batch_window
from timecast.utils.numpy import ecdf
from timecast.utils.losses import MeanSquareError
import torch
import matplotlib

plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Training

In [26]:
# Run a shell pipeline and capture its output lines: list everything under
# ../data/models/runs and keep only the paths containing "lstm".
model_paths = !find ../data/models/runs | grep lstm

In [27]:
# Show the collected paths as a sanity check before loading them.
model_paths

['../data/models/runs/run_1906_1007_seed888/ealstm_seed888.p',
 '../data/models/runs/run_1906_2336_seed333/ealstm_seed333.p',
 '../data/models/runs/run_1906_1010_seed555/lstm_seed555.p',
 '../data/models/runs/run_1606_0922_seed222/lstm_no_static_seed222.p',
 '../data/models/runs/run_1906_1009_seed222/lstm_seed222.p',
 '../data/models/runs/run_1606_0923_seed666/lstm_no_static_seed666.p',
 '../data/models/runs/run_1606_0923_seed777/lstm_no_static_seed777.p',
 '../data/models/runs/run_1906_1005_seed333/ealstm_seed333.p',
 '../data/models/runs/run_2006_0033_seed777/lstm_seed777.p',
 '../data/models/runs/run_1906_1009_seed111/lstm_seed111.p',
 '../data/models/runs/run_1606_0923_seed888/lstm_no_static_seed888.p',
 '../data/models/runs/run_1606_0922_seed111/lstm_no_static_seed111.p',
 '../data/models/runs/run_1906_2337_seed777/ealstm_seed777.p',
 '../data/models/runs/run_2006_0032_seed666/lstm_seed666.p',
 '../data/models/runs/run_1906_1010_seed444/lstm_seed444.p',
 '../data/models/runs/run_1

# Evaluation

In [48]:
from ealstm.gaip.utils import MSE, NSE

In [50]:
def total_se(results, setype):
    """Return the count-weighted mean of the per-key scores in ``results[setype]``.

    Args:
        results: Mapping that holds ``results[setype]`` (key -> score) and
            ``results["count"]`` (key -> weight); every key of the score
            table must also be present in the count table.
        setype: Name of the score table to aggregate (the caller in this
            notebook passes ``"mse"`` and ``"nse"``).

    Returns:
        The sum of ``score * count`` over all keys divided by the sum of the
        counts, using float division.

    Raises:
        ZeroDivisionError: If the score table is empty, since the division
            is then ``0 / 0.0``.
    """
    weighted_total = 0
    n_total = 0
    for site_id, score in results[setype].items():
        weight = results["count"][site_id]
        weighted_total += score * weight
        n_total += weight
    return weighted_total / float(n_total)

In [64]:
# Per-run evaluation results, keyed by run directory name. Each entry holds the
# per-key "mse", "nse" and "count" tables, the source "path", and the
# aggregated "total_mse" / "total_nse" scores.
models = {}
for model_path in tqdm.tqdm(model_paths):
    # The second-to-last path component (e.g. "run_1906_1007_seed888") names the run.
    # NOTE(review): two matching files in the same run directory would overwrite
    # each other here — confirm each run directory holds a single result file.
    name = model_path.split("/")[-2]
    models[name] = {"mse": {}, "nse": {}, "count": {}, "path": model_path}

    # Use a context manager so the file handle is closed after each load instead
    # of being left open for every path in the loop.
    # NOTE: pickle can execute arbitrary code on load; only unpickle trusted files.
    with open(model_path, "rb") as model_file:
        data = pickle.load(model_file)

    # Presumably `site` is a pandas DataFrame with `qobs` / `qsim` columns
    # (observed vs. simulated values) — TODO confirm against the pickled data.
    for key, site in data.items():
        models[name]["mse"][key] = MSE(site.qobs, site.qsim)
        models[name]["nse"][key] = NSE(site.qobs, site.qsim)
        # Number of rows for this key; used as the weight when aggregating.
        models[name]["count"][key] = site.shape[0]

    # Collapse the per-key scores into one count-weighted value per run.
    models[name]["total_mse"] = total_se(models[name], "mse")
    models[name]["total_nse"] = total_se(models[name], "nse")

HBox(children=(FloatProgress(value=0.0, max=48.0), HTML(value='')))




In [65]:
# Aggregate MSE of every run, keyed by run directory name.
{run_name: run_stats["total_mse"] for run_name, run_stats in models.items()}

{'run_1906_1007_seed888': 3.3638509750782526,
 'run_1906_2336_seed333': 3.337346505673905,
 'run_1906_1010_seed555': 3.237445282285454,
 'run_1606_0922_seed222': 4.476296932159444,
 'run_1906_1009_seed222': 3.2220703187962276,
 'run_1606_0923_seed666': 4.499379957060465,
 'run_1606_0923_seed777': 4.624725695014974,
 'run_1906_1005_seed333': 3.5001009681127004,
 'run_2006_0033_seed777': 3.2339230156853978,
 'run_1906_1009_seed111': 3.2352868484148094,
 'run_1606_0923_seed888': 4.473549428796491,
 'run_1606_0922_seed111': 4.433866150819652,
 'run_1906_2337_seed777': 3.3147425884061694,
 'run_2006_0032_seed666': 3.1949645419398265,
 'run_1906_1010_seed444': 3.2289923281497943,
 'run_1906_2337_seed555': 3.282823277819134,
 'run_1606_0923_seed555': 4.652720299467179,
 'run_2006_0031_seed111': 3.2397084946905363,
 'run_1906_1007_seed777': 3.473512175578632,
 'run_2006_0032_seed333': 3.1932420165053017,
 'run_1906_1005_seed222': 3.3914191445835082,
 'run_1906_1010_seed666': 3.2351552555689036

In [66]:
# Aggregate NSE of every run, keyed by run directory name.
{run_name: run_stats["total_nse"] for run_name, run_stats in models.items()}

{'run_1906_1007_seed888': 0.6789186889733216,
 'run_1906_2336_seed333': 0.6451087885394113,
 'run_1906_1010_seed555': 0.6834140732742268,
 'run_1606_0922_seed222': 0.29045259863318473,
 'run_1906_1009_seed222': 0.6995414040542303,
 'run_1606_0923_seed666': 0.26056844004656776,
 'run_1606_0923_seed777': 0.11695495061591926,
 'run_1906_1005_seed333': 0.6732774716033985,
 'run_2006_0033_seed777': 0.6410610093148134,
 'run_1906_1009_seed111': 0.6922506589353086,
 'run_1606_0923_seed888': 0.23367518333224074,
 'run_1606_0922_seed111': 0.26169006489219654,
 'run_1906_2337_seed777': 0.6450881751087002,
 'run_2006_0032_seed666': 0.6678921865006313,
 'run_1906_1010_seed444': 0.6854652297749018,
 'run_1906_2337_seed555': 0.6471139040089225,
 'run_1606_0923_seed555': 0.22364850467499747,
 'run_2006_0031_seed111': 0.65244279711988,
 'run_1906_1007_seed777': 0.671839267064583,
 'run_2006_0032_seed333': 0.6796908169889062,
 'run_1906_1005_seed222': 0.6793343678122624,
 'run_1906_1010_seed666': 0.694

In [67]:
# Look up the result-file path recorded for one specific run.
models["run_2006_0032_seed444"]["path"]

'../data/models/runs/run_2006_0032_seed444/lstm_seed444.p'