In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import time
import json
import jax.numpy as np
import numpy as onp
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import torch
import matplotlib

plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm

In [2]:
from ealstm.gaip.flood_data import FloodData
from ealstm.gaip.utils import MSE, NSE

# Directory of the EA-LSTM run whose config and pickled per-basin outputs are reused.
RUN_DIR = "../data/models/runs/run_2006_0032_seed444"
cfg_path = os.path.join(RUN_DIR, "cfg.json")

# NOTE: pickle.load can execute arbitrary code -- only load run artifacts you trust.
# Use a context manager so the file handle is closed deterministically.
with open(os.path.join(RUN_DIR, "lstm_seed444.p"), "rb") as run_file:
    ea_data = pickle.load(run_file)
flood_data = FloodData(cfg_path)

# Hyperparameters for the AR model.
LR_AR = 1e-5
AR_INPUT_DIM = 32
AR_OUTPUT_DIM = 1
BATCH_SIZE = 1

In [3]:
from flax import nn
from flax import optim

import jax
import jax.numpy as jnp

from timecast.learners import Sequential, Ensemble, AR
from timecast import smap
from timecast.objectives import residual

In [4]:
class Identity(nn.Module):
    """Pass-through module: hands its input back unchanged."""

    def apply(self, x):
        output = x
        return output
class Take(nn.Module):
    """Select element `i` of an indexable input (e.g. one member of a tuple of inputs)."""

    def apply(self, x, i):
        # `i` is bound by keyword when the module is configured, e.g. {"i": 0}.
        chosen = x[i]
        return chosen

In [5]:
# For every basin: build an ensemble of (LSTM prediction pass-through, AR on raw
# features), run it through `smap` against observed flow, and record MSE/NSE plus
# running averages over all basins processed so far.
results = {}
mses = []
nses = []

seq_length = flood_data.cfg["seq_length"]


def mse_loss(y_hat, y):
    """Mean squared error; the loss function handed to `smap`."""
    return jnp.square(y_hat - y).mean()


for X, y, basin in tqdm.tqdm(flood_data.generator(), total=len(flood_data.basins)):
    with nn.stateful() as state:
        # Member 0: take input[0] (the LSTM prediction) and pass it through unchanged.
        lstm = Sequential.partial(modules=[Take, Identity], args=[{"i": 0}, {}])

        # Member 1: take input[1] (the features) and feed it to an AR learner,
        # warm-started with the first seq_length - 1 feature rows.
        arf = Sequential.partial(
            modules=[Take, AR],
            args=[
                {"i": 1},
                {
                    "output_features": AR_OUTPUT_DIM,
                    # NOTE(review): hard-coded 270 vs. the seq_length - 1 history rows
                    # passed below -- confirm this is intended.
                    "history_len": 270,
                    "history": X[:seq_length - 1],
                },
            ],
        )

        model_def = Ensemble.partial(modules=[lstm, arf], args=[{}, {}])
        ys, params = model_def.init_by_shape(
            jax.random.PRNGKey(0), [(BATCH_SIZE, AR_INPUT_DIM)]
        )
        model = nn.Model(model_def, params)
    optim_def = optim.GradientDescent(learning_rate=LR_AR)
    optimizer = optim_def.create(model)

    # NOTE: difference in indexing convention, so need to pad one row
    X_t = X[seq_length - 1:]
    Y_lstm = np.array(ea_data[basin].qsim)
    Y = np.array(ea_data[basin].qobs).reshape(-1, 1)

    Y_hat = smap((Y_lstm, X_t), Y, optimizer, state, residual, mse_loss)

    mse = MSE(Y, Y_hat)
    nse = NSE(Y, Y_hat)

    # Append BEFORE averaging: the stored running averages must include this basin
    # (matching the printed line) and must never be the mean of an empty list.
    mses.append(mse)
    nses.append(nse)
    avg_mse = np.mean(np.array(mses))
    avg_nse = np.mean(np.array(nses))

    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0],
        "avg_mse": avg_mse,
        "avg_nse": avg_nse,
    }
    print(basin, mse, nse, avg_mse, avg_nse)
    

HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))



01022500 0.7921799 0.8391017 0.7921799 0.8391017
01031500 1.034634 0.88986236 0.91340697 0.86448205
01047000 1.2811565 0.87995404 1.0359901 0.8696394
01052500 1.9988087 0.847376 1.2766948 0.8640735
01054200 7.8198547 0.7346721 2.5853267 0.83819324
01055000 4.73436 0.73252463 2.9434988 0.82058173


KeyboardInterrupt: 

In [None]:
# Are JAX arrays plain NumPy ndarrays? `issubclass` requires a class as its first
# argument (passing an array instance raises TypeError), so test the instance with
# `isinstance`. Here `np` is jax.numpy and `onp` is NumPy.
jax_array_is_ndarray = isinstance(np.array(1), onp.ndarray)
jax_array_is_ndarray