In [7]:
%load_ext autoreload
%autoreload 2

# `display` is imported from the public IPython.display module; importing it
# from IPython.core.display is deprecated in newer IPython releases.
from IPython.display import HTML, Image, display

# Inject CSS so elements with the `.container` class use the full page width.
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as np
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import torch
import matplotlib
import timecast as tc

from mpl_toolkits import mplot3d


plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm

%matplotlib notebook

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
@tc.experiment("beta", [0.999, 0.994])
@tc.experiment("lr", [1e-6, 1e-5, 1e-4])
def beta_lr(beta, lr):
    """Score one (beta, lr) RMSProp setting for an LSTM + AR ensemble.

    For the basin yielded by ``FloodData.generator()``, builds an ``Ensemble``
    of two branches -- a pass-through of the pickled LSTM output (input 0) and
    an ``AR`` learner over the features (input 1) -- runs it through ``smap``
    with the ``residual`` objective, and scores the predictions with
    ``MSE`` / ``NSE``.

    Args:
        beta: Passed to ``RMSProp`` as ``beta2``.
        lr: Passed to ``RMSProp`` as ``learning_rate``.

    Returns:
        dict mapping basin id to a dict with keys ``"mse"``, ``"nse"``,
        ``"count"`` (``X.shape[0]``), ``"avg_mse"`` and ``"avg_nse"`` (running
        means over the basins processed so far). Because the loop ends with an
        unconditional ``break``, only the first basin yielded is evaluated.
    """
    # Imports live inside the function, as in the original cell
    # (presumably so the body is self-contained when executed by
    # `beta_lr.run(processes=...)` -- TODO confirm).
    import pickle
    from ealstm.gaip.flood_data import FloodData
    from ealstm.gaip.utils import MSE, NSE
    from flax import nn

    import jax
    import jax.numpy as jnp

    from timecast.learners import Sequential, Ensemble, AR
    from timecast import smap
    from timecast.objectives import residual
    from timecast.optim import RMSProp

    class Identity(nn.Module):
        """Module that returns its input unchanged."""

        def apply(self, x):
            return x

    class Take(nn.Module):
        """Module that selects element ``i`` of its input."""

        def apply(self, x, i):
            return x[i]

    # Both artifacts come from the same run directory.
    run_dir = "../data/models/runs/run_2006_0032_seed444"
    cfg_path = run_dir + "/cfg.json"

    # Use a context manager so the file handle is closed deterministically.
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(run_dir + "/lstm_seed444.p", "rb") as lstm_file:
        ea_data = pickle.load(lstm_file)
    flood_data = FloodData(cfg_path)

    results = {}
    mses = []
    nses = []

    for X, _, basin in flood_data.generator():
        # Rows before this offset seed the AR history; rows from it onward are
        # the inputs fed to `smap`.
        offset = flood_data.cfg["seq_length"] - 1

        with nn.stateful() as state:
            # Branch 0: forward input 0 (the LSTM output) untouched.
            lstm = Sequential.partial(modules=[Take, Identity], args=[{"i": 0}, {}])
            # Branch 1: AR learner on input 1, seeded with the leading rows of X.
            arf = Sequential.partial(
                modules=[Take, AR],
                args=[
                    {"i": 1},
                    {"output_features": 1, "history_len": 270, "history": X[:offset]},
                ],
            )
            model_def = Ensemble.partial(modules=[lstm, arf], args=[{}, {}])
            _, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1, 32)])
            model = nn.Model(model_def, params)
        optim_def = RMSProp(learning_rate=lr, beta2=beta)
        optimizer = optim_def.create(model)

        # NOTE: difference in indexing convention, so need to pad one row
        X_t = X[offset:]
        Y_lstm = jnp.array(ea_data[basin].qsim)
        Y = jnp.array(ea_data[basin].qobs).reshape(-1, 1)

        Y_hat = smap(
            (Y_lstm, X_t),
            Y,
            optimizer,
            state,
            residual,
            lambda u, v: jnp.square(u - v).mean(),
        )

        mse = MSE(Y, Y_hat)
        nse = NSE(Y, Y_hat)
        mses.append(mse)
        nses.append(nse)

        results[basin] = {
            "mse": mse,
            "nse": nse,
            "count": X.shape[0],
            "avg_mse": jnp.mean(jnp.array(mses)),
            "avg_nse": jnp.mean(jnp.array(nses))
        }
        # NOTE(review): unconditional break -- only the first basin is evaluated,
        # so the running averages equal that basin's own scores.
        break
    return results

In [18]:
# Launch the beta_lr experiment; `tqdm` here is the tqdm.notebook module.
# presumably processes=6 gives one worker per (beta, lr) pair (2 x 3) -- TODO confirm
a = beta_lr.run(processes=6, tqdm=tqdm)

HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))




In [19]:
# Show the object returned by beta_lr.run (rendered below as a list of per-basin result dicts).
a

[{'01022500': {'mse': array(0.99373645, dtype=float32),
   'nse': array(0.7981639, dtype=float32),
   'count': 3921,
   'avg_mse': array(0.99373645, dtype=float32),
   'avg_nse': array(0.7981639, dtype=float32)}},
 {'01022500': {'mse': array(0.82482624, dtype=float32),
   'nse': array(0.83247095, dtype=float32),
   'count': 3921,
   'avg_mse': array(0.82482624, dtype=float32),
   'avg_nse': array(0.83247095, dtype=float32)}},
 {'01022500': {'mse': array(0.5105229, dtype=float32),
   'nse': array(0.89630854, dtype=float32),
   'count': 3921,
   'avg_mse': array(0.5105229, dtype=float32),
   'avg_nse': array(0.89630854, dtype=float32)}},
 {'01022500': {'mse': array(0.9982276, dtype=float32),
   'nse': array(0.7972517, dtype=float32),
   'count': 3921,
   'avg_mse': array(0.9982276, dtype=float32),
   'avg_nse': array(0.7972517, dtype=float32)}},
 {'01022500': {'mse': array(0.8596345, dtype=float32),
   'nse': array(0.82540107, dtype=float32),
   'count': 3921,
   'avg_mse': array(0.85963