In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import time
import json
import jax.numpy as np
import numpy as onp
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from timecast.learners import AR
from timecast.learners._ar import _ar_predict, _ar_batch_window
from timecast.utils.numpy import ecdf
from timecast.utils.losses import MeanSquareError
import torch
import matplotlib

plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm



In [2]:
from ealstm.gaip.flood_data import FloodData
from ealstm.gaip.utils import MSE, NSE

cfg_path = "../data/models/runs/run_2006_0032_seed444/cfg.json"
ea_data = pickle.load(open("../data/models/runs/run_2006_0032_seed444/lstm_seed444.p", "rb"))
flood_data = FloodData(cfg_path)

LR_AR = 1e-5
AR_INPUT_DIM=32
AR_OUTPUT_DIM=1
BATCH_SIZE=1

In [3]:
from flax import nn
from flax import optim

import jax
import jax.numpy as jnp

In [4]:
class ARF(nn.Module):
    def apply(self, x, input_features, output_features, window_size, history=None):
        
        if x.ndim == 1:
            x = x.reshape(1, -1)
            
        self.history = self.state("history", shape=(window_size, input_features), initializer=nn.initializers.zeros)
        
        if self.is_initializing() and history is not None:
            self.history.value = np.vstack((self.history.value, history))[history.shape[0]:]
        else:
            self.history.value = np.vstack((self.history.value, x))[x.shape[0]:]
        
        y = nn.DenseGeneral(inputs=self.history.value,
                            features=output_features,
                            axis=(0, 1),
                            batch_dims=(),
                            bias=True,
                            dtype=jnp.float32,
                            kernel_init=nn.initializers.zeros,
                            bias_init=nn.initializers.zeros,
                            precision=None,
                            name="linear"
                           )
        return y

In [5]:
class Take(nn.Module):
    def apply(self, x, i):
        return x[i]

In [6]:
class Identity(nn.Module):
    def apply(self, x):
        return x

In [7]:
class Plus(nn.Module):
    def apply(self, x, z):
        return x + z

In [8]:
class Ensemble(nn.Module):
    def apply(self, x, modules, args):
        return [module(x, **arg) for (module, arg) in zip(modules, args)]

In [9]:
class Sequential(nn.Module):
    def apply(self, x, modules, args):
        results = x
        for module, arg in zip(modules, args):
            results = module(results, **arg)
        return results

In [10]:
class Residual(nn.Module):
    def apply(self, x, input_features, output_features, window_size, history):
        y_arf = ARF(x=x[1],
                    input_features=input_features,
                    output_features=output_features,
                    window_size=window_size,
                    history=history
                   )
        return (x[0], y_arf)

In [60]:
def run(X, Y, optimizer, state, objective, loss_fn):
    def _run(optstate, xy):
        x, y = xy
        optimizer, state = optstate
        with nn.stateful(state) as state:
            loss, y_hat, grad = optimizer.compute_gradients(partial(objective, x, y, loss_fn))
            return (optimizer.apply_gradient(grad), state), y_hat
    _, pred = jax.lax.scan(_run, (optimizer, state), (X, Y))
    return pred

In [61]:
from functools import partial

def residual(x, y, loss_fn, model):
    y_hats = model(x)
    target, y_hat, loss = y, y_hats[0], 0
    for i in range(len(y_hats) - 1):
        loss += loss_fn(target - y_hats[i], y_hats[i + 1])
        target -= y_hats[i]
        y_hat += y_hats[i + 1]
    return loss, y_hat

In [62]:
def xboost(x, y, loss_fn, model, reg = 1.0):
    y_hats = model(x)
    g = jax.grad(loss_fn)
    u, loss = 0, 0
    for i in range(len(y_hats)):
        eta = (2 / (i + 2))
        loss += g(y, u) * y_hats[i] + (reg / 2) * y_hats[i] * y_hats[i]
        u = (1 - eta) * u + eta * y_hats[i]
    return loss.reshape(()), u

In [59]:
results = {}
mses = []
nses = []

for X, y, basin in tqdm.tqdm(flood_data.generator(), total=len(flood_data.basins)):
    with nn.stateful() as state:
        model_def = Residual.partial(input_features=32,
                                     output_features=1,
                                     window_size=270,
                                     history=X[:flood_data.cfg["seq_length"]-1]
                                    )
        lstm = Sequential.partial(modules=[Take, Identity], args=[{"i": 0}, {}])
        arf = Sequential.partial(modules=[Take, ARF], args=[{"i": 1}, {"input_features": 32, "output_features": 1, "window_size": 270, "history": X[:flood_data.cfg["seq_length"]-1]}])
        model_def = Ensemble.partial(modules=[lstm, arf], args=[{}, {}])
        ys, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1, 32)])
        model = nn.Model(model_def, params)
    optim_def = optim.GradientDescent(learning_rate=1e-5)
    optimizer = optim_def.create(model)
    
    # NOTE: difference in indexing convention, so need to pad one row
    X_t = X[flood_data.cfg["seq_length"]-1:]
    Y_lstm = np.array(ea_data[basin].qsim)
    Y = np.array(ea_data[basin].qobs).reshape(-1, 1)
    
    Y_hat = run((Y_lstm, X_t), Y, optimizer, state, xboost, lambda x, y: jnp.square(x-y).mean())
    
    mse = MSE(Y, Y_hat)
    nse = NSE(Y, Y_hat)
    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0],
        "avg_mse": np.mean(np.array(mses)),
        "avg_nse": np.mean(np.array(nses))
    }
    mses.append(mse)
    nses.append(nse)
    print(basin, mse, nse, np.mean(np.array(mses)), np.mean(np.array(nses)))
    

HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))

01022500 5.523398 -0.12184799 5.523398 -0.12184799
01031500 7.231548 0.23019576 6.377473 0.054173887
01047000 10.041595 0.059090078 7.598847 0.055812616
01052500 10.929575 0.16544527 8.431529 0.08322078
01054200 26.1279 0.11347973 11.970803 0.08927257
01055000 14.437254 0.184344 12.381878 0.105117805
01057000 6.3041797 0.14192247 11.513635 0.11037562
01073000 6.3879013 0.029729962 10.872918 0.10029491
01078000 6.72403 -0.11996174 10.41193 0.07582195
01123000 4.9362335 -0.0027009249 9.864361 0.067969665
01134500 8.194088 0.0031113029 9.712518 0.062073447
01137500 8.231464 0.19052285 9.589096 0.07277756
01139000 4.437037 -0.24576068 9.192783 0.04827462
01139800 5.254281 -0.4408251 8.911462 0.013338928
01142500 5.561816 -0.12647164 8.688152 0.0040182234
01144000 4.6778274 0.048406124 8.437507 0.006792467
01162500 5.636065 -0.10834825 8.272716 1.9483707e-05
01169000 6.119 0.3199215 8.153065 0.017791817
01170100 6.779939 0.16121906 8.080795 0.02534062
01181000 7.52582 0.1792804 8.053046 0.0

KeyboardInterrupt: 

In [None]:
np.array(ea_data[basin].qsim)[0]