In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import time
import json
import jax.numpy as np
import numpy as onp
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from timecast.learners import AR
from timecast.learners._ar import _ar_predict, _ar_batch_window
from timecast.utils.numpy import ecdf
from timecast.utils.losses import MeanSquareError
import torch
import matplotlib

plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm



In [2]:
from ealstm.gaip.flood_data import FloodData
from ealstm.gaip.utils import MSE, NSE

# All artifacts of the evaluated run live in one directory.
RUN_DIR = "../data/models/runs/run_2006_0032_seed444"
cfg_path = f"{RUN_DIR}/cfg.json"

# Close the file deterministically instead of leaking the handle.
# NOTE(review): pickle.load can execute arbitrary code - only load trusted run artifacts.
with open(f"{RUN_DIR}/lstm_seed444.p", "rb") as ea_file:
    ea_data = pickle.load(ea_file)
flood_data = FloodData(cfg_path)

LR_AR = 1e-5        # learning rate for the AR model's optimizer
AR_INPUT_DIM = 32   # features per input row fed to the AR model
AR_OUTPUT_DIM = 1   # outputs produced by the AR model
BATCH_SIZE = 1      # NOTE(review): not referenced by any visible cell

In [3]:
from flax import nn
from flax import optim

import jax
import jax.numpy as jnp

In [4]:
class ARF(nn.Module):
    """Linear predictor over a sliding window of recent input rows.

    The last `window_size` rows are kept in a module state variable named
    "history". The whole window is mapped to `output_features` values by a
    single dense layer that contracts both window and feature axes. Kernel and
    bias are zero-initialised, so the prediction is zero until the parameters
    are updated.
    """
    def apply(self, x, input_features, output_features, window_size, history=None):
        """Push `x` into the rolling window and return the dense prediction.

        Args:
            x: new input row(s). A 1-D array is treated as a single row via
                `reshape(1, -1)`. Assumed to have `input_features` columns —
                TODO confirm at call sites.
            input_features: number of columns per stored history row.
            output_features: size of the dense layer's output.
            window_size: number of rows kept in the history window.
            history: optional rows used to pre-fill the window. They are used
                only while the module is initializing, and in that case `x` is
                NOT pushed.

        Returns:
            The `nn.DenseGeneral` output computed from the current window.
        """
        
        # Promote a single feature vector to a one-row batch so vstack works.
        if x.ndim == 1:
            x = x.reshape(1, -1)
            
        # Rolling window state, zero-filled on creation.
        # NOTE(review): updates presumably persist only when applied under a
        # mutable `nn.stateful(...)` context — confirm against flax.nn docs.
        self.history = self.state("history", shape=(window_size, input_features), initializer=nn.initializers.zeros)
        
        # Append new rows at the bottom and drop the same number of oldest rows
        # from the top, keeping exactly `window_size` rows.
        if self.is_initializing() and history is not None:
            self.history.value = np.vstack((self.history.value, history))[history.shape[0]:]
        else:
            self.history.value = np.vstack((self.history.value, x))[x.shape[0]:]
        
        # axis=(0, 1) with no batch dims: contract the full (window, feature)
        # matrix into `output_features` values.
        y = nn.DenseGeneral(inputs=self.history.value,
                            features=output_features,
                            axis=(0, 1),
                            batch_dims=(),
                            bias=True,
                            dtype=jnp.float32,
                            kernel_init=nn.initializers.zeros,
                            bias_init=nn.initializers.zeros,
                            precision=None,
                            name="linear"
                           )
        return y

In [5]:
class Take(nn.Module):
    """Select one element of the input by plain indexing."""

    def apply(self, x, i):
        """Return `x[i]`; `i` may be anything the input's `[]` accepts."""
        picked = x[i]
        return picked

In [6]:
class Identity(nn.Module):
    """No-op module: hands its input back untouched."""

    def apply(self, x):
        """Return `x` unchanged."""
        output = x
        return output

In [7]:
class Plus(nn.Module):
    """Add a second operand to the input."""

    def apply(self, x, z):
        """Return `x + z` (uses the operands' own `+`, including broadcasting)."""
        total = x + z
        return total

In [8]:
class Ensemble(nn.Module):
    """Run several modules on the same input and collect their outputs."""

    def apply(self, x, modules, args):
        """Call `modules[k](x, **args[k])` for each pair, in order.

        Pairs are formed with `zip`, so extra entries in the longer of
        `modules`/`args` are ignored. Returns the outputs as a list.
        """
        outputs = []
        for submodule, kwargs in zip(modules, args):
            outputs.append(submodule(x, **kwargs))
        return outputs

In [9]:
class Sequential(nn.Module):
    """Chain modules so each one consumes the previous one's output."""

    def apply(self, x, modules, args):
        """Fold `x` through `modules[k](..., **args[k])` in order.

        Pairs are formed with `zip`, so extra entries in the longer of
        `modules`/`args` are ignored. With no pairs, `x` is returned as-is.
        """
        current = x
        for stage, stage_kwargs in zip(modules, args):
            current = stage(current, **stage_kwargs)
        return current

In [10]:
class Residual(nn.Module):
    """Pass `x[0]` through and run an `ARF` on `x[1]`."""

    def apply(self, x, input_features, output_features, window_size, history):
        """Return the 2-tuple `(x[0], ARF prediction for x[1])`.

        The remaining arguments are forwarded verbatim to `ARF`.
        """
        features = x[1]
        correction = ARF(x=features,
                         input_features=input_features,
                         output_features=output_features,
                         window_size=window_size,
                         history=history)
        passthrough = x[0]
        return passthrough, correction

In [17]:
from functools import partial, reduce

def residual(x, y, loss_fn, model):
    """Stagewise residual objective for an ensemble whose outputs are summed.

    Member ``i + 1`` is trained to predict what remains of ``y`` after
    subtracting the predictions of members ``0..i``. The combined prediction
    is the sum of all member outputs.

    Args:
        x: input passed unchanged to ``model``.
        y: regression target. It is never modified.
        loss_fn: callable ``loss_fn(target, prediction)`` returning a scalar.
        model: callable returning a non-empty sequence of member predictions.

    Returns:
        ``(loss, y_hat)``. ``loss`` sums each later member's loss against its
        residual target; member 0 contributes no term. ``y_hat`` is the sum of
        all member predictions.
    """
    y_hats = model(x)
    loss = 0
    target = y
    y_hat = y_hats[0]
    for i in range(len(y_hats) - 1):
        # Out-of-place updates: `-=`/`+=` on NumPy arrays would mutate the
        # caller's `y` and the model's first output in place. They would also
        # fail when broadcasting enlarges the shape (e.g. 0-d + (1,)).
        target = target - y_hats[i]
        loss = loss + loss_fn(target, y_hats[i + 1])
        y_hat = y_hat + y_hats[i + 1]
    return loss, y_hat

def run(X, Y, optimizer, state, objective, loss_fn):
    """Online predict-then-update loop over a sequence using `jax.lax.scan`.

    At each step the current model is evaluated through
    `objective(x_t, y_t, loss_fn, model)` via `optimizer.compute_gradients`.
    The step's prediction is recorded, then one gradient step is applied.
    The prediction therefore comes from the parameters before that step's update.

    Args:
        X: inputs scanned along their leading axis (a pytree, e.g. a tuple of arrays).
        Y: targets scanned along their leading axis.
        optimizer: flax optimizer wrapping the model.
        state: flax.nn state collection threaded through the scan and made
            mutable for each step via `nn.stateful`.
        objective: callable presumably returning `(loss, prediction)`, as
            `residual` does.
        loss_fn: loss forwarded to `objective`.

    Returns:
        The per-step predictions stacked by `scan`. The final optimizer and
        state are discarded.
    """
    def _step(carry, inputs):
        x_t, y_t = inputs
        opt, coll = carry
        step_objective = partial(objective, x_t, y_t, loss_fn)
        with nn.stateful(coll) as new_coll:
            _loss, y_hat, grad = opt.compute_gradients(step_objective)
            return (opt.apply_gradient(grad), new_coll), y_hat

    _final_carry, predictions = jax.lax.scan(_step, (optimizer, state), (X, Y))
    return predictions

In [None]:
def xboost(x, y, loss_fn, model):
    """Boosting-style objective that averages members with a decaying step.

    Members are folded into a running convex combination
    ``u_i = (1 - eta_i) * u_{i-1} + eta_i * y_hats[i]`` with
    ``eta_i = 2 / (i + 2)``, so ``u_0 = y_hats[0]``.

    Args:
        x: input passed unchanged to ``model``.
        y: regression target.
        loss_fn: callable ``loss_fn(target, prediction)`` returning a scalar,
            with the same argument order as in ``residual``.
        model: callable returning a sequence of member predictions.

    Returns:
        ``(loss, y_hat)``. ``loss`` is the sum of ``loss_fn(y, u_i)`` over every
        partial combination. ``y_hat`` is the final combination, or ``0`` when
        the model returns no members.
    """
    # NOTE(review): the original body returned the undefined names `loss` and
    # `y_hat` (NameError). Scoring every partial combination is an assumption
    # about the intended objective — confirm.
    y_hats = model(x)
    loss, u = 0, 0
    for i, member in enumerate(y_hats):
        eta = 2 / (i + 2)
        u = (1 - eta) * u + eta * member
        loss = loss + loss_fn(y, u)
    return loss, u

In [16]:
# Online evaluation over all basins. The LSTM simulation (`qsim`) is combined
# with an ARF model trained step-by-step on what the LSTM misses. The combined
# prediction is scored against the `qobs` target.
results = {}
mses = []
nses = []

AR_WINDOW_SIZE = 270  # NOTE(review): magic window length kept from the original cell
seq_length = flood_data.cfg["seq_length"]


def squared_error(target, prediction):
    """Mean squared error used as the per-step training loss."""
    return jnp.square(target - prediction).mean()


for X, y, basin in tqdm.tqdm(flood_data.generator(), total=len(flood_data.basins)):
    # Rows before the first scored step pre-fill the ARF window during init.
    warmup = X[:seq_length - 1]
    with nn.stateful() as state:
        # Member 0 forwards the LSTM prediction unchanged; member 1 runs an ARF
        # on the raw features. `residual` trains member 1 on member 0's error.
        lstm = Sequential.partial(modules=[Take, Identity], args=[{"i": 0}, {}])
        arf = Sequential.partial(
            modules=[Take, ARF],
            args=[{"i": 1}, {"input_features": AR_INPUT_DIM,
                             "output_features": AR_OUTPUT_DIM,
                             "window_size": AR_WINDOW_SIZE,
                             "history": warmup}],
        )
        model_def = Ensemble.partial(modules=[lstm, arf], args=[{}, {}])
        # NOTE(review): init uses a single (1, AR_INPUT_DIM) array rather than the
        # (qsim, features) tuple that `run` feeds later. `Take(i=1)` on it
        # presumably relies on JAX clamping out-of-bounds indices — confirm.
        _, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1, AR_INPUT_DIM)])
        model = nn.Model(model_def, params)
    optim_def = optim.GradientDescent(learning_rate=LR_AR)
    optimizer = optim_def.create(model)

    # NOTE: difference in indexing convention, so need to pad one row:
    # the streamed inputs start right after the warm-up rows.
    X_t = X[seq_length - 1:]
    Y = np.array(ea_data[basin].qobs).reshape(-1, 1)
    Y_lstm = np.array(ea_data[basin].qsim)

    Y_hat = run((Y_lstm, X_t), Y, optimizer, state, residual, squared_error)

    mse = MSE(Y, Y_hat)
    nse = NSE(Y, Y_hat)
    # Record before averaging so the stored running averages include this
    # basin and match the printed line. Previously they lagged by one basin
    # and were the mean of an empty array (NaN) for the first basin.
    mses.append(mse)
    nses.append(nse)
    avg_mse = np.mean(np.array(mses))
    avg_nse = np.mean(np.array(nses))
    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0],
        "avg_mse": avg_mse,
        "avg_nse": avg_nse,
    }
    print(basin, mse, nse, avg_mse, avg_nse)
    

HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))

01022500 0.7921799 0.8391017 0.7921799 0.8391017
01031500 1.034634 0.88986236 0.91340697 0.86448205
01047000 1.2811565 0.87995404 1.0359901 0.8696394
01052500 1.9988087 0.847376 1.2766948 0.8640735
01054200 7.8198547 0.7346721 2.5853267 0.83819324
01055000 4.73436 0.73252463 2.9434988 0.82058173
01057000 0.91459924 0.87551165 2.653656 0.82842886
01073000 0.9153604 0.86096424 2.436369 0.8324958
01078000 0.6731362 0.8878817 2.2404542 0.8386498
01123000 0.8993281 0.8173188 2.1063416 0.8365167
01134500 1.715929 0.79124093 2.0708494 0.8324007
01137500 2.376047 0.766341 2.0962827 0.8268957
01139000 0.80212957 0.7747908 1.9967325 0.82288766
01139800 0.84583026 0.76805705 1.9145252 0.81897116
01142500 1.12383 0.7723832 1.8618122 0.81586534


KeyboardInterrupt: 

In [78]:
# Peek at the first simulated (`qsim`) value for the last basin processed above.
# NOTE(review): relies on `basin` leaking out of the evaluation loop, so the
# result depends on execution order.
qsim_values = np.array(ea_data[basin].qsim)
qsim_values[0]

DeviceArray(0.24191774, dtype=float32)