In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as onp
import jax
import jax.numpy as jnp
import flax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import torch
import matplotlib
import timecast as tc

from mpl_toolkits import mplot3d


plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm



In [2]:
basins = pickle.load(open("../data/basins.p", "rb"))

In [10]:
@tc.experiment("basin", basins)
@tc.experiment("lr", jnp.linspace(-8, -1, 29))
def runner(basin, lr):
    import pickle
    from ealstm.gaip.utils import MSE

    import jax.numpy as jnp

    from timecast.learners import Sequential, Parallel, Index, PredictLast
    from timecast import smap
    from timecast.objectives import residual
    from timecast.optim import GradientDescent

    import tqdm.notebook as tqdm
    
    X = pickle.load(open("../data/test/{}.p".format(basin), "rb"))
    Y = pickle.load(open("../data/ealstm/{}.p".format(basin), "rb"))
    
    history_len = 270
    
    lstm = Index.partial(index=0)
    pl = PredictLast.partial()
    pl = Sequential.partial(learners=[Index.partial(index=1), pl])
    model, state = Parallel.new(shape=(1, 32), learners=[lstm, pl])
    
    optim_def = GradientDescent(learning_rate=(10 ** lr))
    optimizer = optim_def.create(model)

    # NOTE: difference in indexing convention, so need to pad one row
    X_t = X[history_len - 1:]
    Y_lstm = jnp.array(Y.qsim)
    Y = jnp.array(Y.qobs).reshape(-1, 1)

    Y_hat, optimizer, state = smap((Y_lstm, X_t), Y, optimizer, state=state, objective=residual)

    return {"basin": basin, "lr": lr, "mse": MSE(Y, Y_hat)}

In [11]:
results = runner.run(processes=10, tqdm=tqdm)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [12]:
for lr in tqdm.tqdm(jnp.linspace(-8, -1, 29)):
    print("Average MSE (lr={0:.10f}): {1:.2f}".format(lr, jnp.average(jnp.array([result["mse"] for result in results if result["lr"] == lr]))))

HBox(children=(FloatProgress(value=0.0, max=29.0), HTML(value='')))

Average MSE (lr=-8.0000000000): 4.14
Average MSE (lr=-7.7500000000): 4.14
Average MSE (lr=-7.5000000000): 4.14
Average MSE (lr=-7.2500000000): 4.14
Average MSE (lr=-7.0000000000): 4.14
Average MSE (lr=-6.7500000000): 4.14
Average MSE (lr=-6.5000000000): 4.14
Average MSE (lr=-6.2500000000): 4.14
Average MSE (lr=-6.0000000000): 4.14
Average MSE (lr=-5.7500000000): 4.14
Average MSE (lr=-5.5000000000): 4.14
Average MSE (lr=-5.2500000000): 4.14
Average MSE (lr=-5.0000000000): 4.14
Average MSE (lr=-4.7500000000): 4.14
Average MSE (lr=-4.5000000000): 4.14
Average MSE (lr=-4.2500000000): 4.14
Average MSE (lr=-4.0000000000): 4.14
Average MSE (lr=-3.7500000000): 4.14
Average MSE (lr=-3.5000000000): 4.14
Average MSE (lr=-3.2500000000): 4.14
Average MSE (lr=-3.0000000000): 4.14
Average MSE (lr=-2.7500000000): 4.14
Average MSE (lr=-2.5000000000): 4.14
Average MSE (lr=-2.2500000000): 4.14
Average MSE (lr=-2.0000000000): 4.14
Average MSE (lr=-1.7500000000): 4.14
Average MSE (lr=-1.5000000000): 4.14
A