In [5]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import time
import json
import jax.numpy as np
import numpy as onp
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from timecast.learners import AR
from timecast.learners._ar import _ar_predict, _ar_batch_window
from timecast.utils.numpy import ecdf
from timecast.utils.losses import MeanSquareError
import torch
import matplotlib

plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm



In [100]:
from ealstm.gaip import FloodLSTM
from ealstm.gaip import FloodData
from ealstm.gaip.utils import MSE, NSE

from timecast.learners import AR
from timecast.optim import SGD

# Artifacts of the trained EA-LSTM run that every experiment below builds on.
cfg_path = "../data/models/runs/run_2006_0032_seed444/cfg.json"

# Per-basin LSTM results; later cells read `.qobs` and `.qsim` from each entry.
# NOTE(security): pickle.load can execute arbitrary code -- only load files
# produced by a trusted run.
# The context manager closes the file handle, which the bare open() call
# previously left to the garbage collector.
with open("../data/models/runs/run_2006_0032_seed444/lstm_seed444.p", "rb") as lstm_results_file:
    ea_data = pickle.load(lstm_results_file)
flood_data = FloodData(cfg_path)

# Hyper-parameters of the residual autoregressive model.
LR_AR = 1e-5
AR_INPUT_DIM=32
AR_OUTPUT_DIM=1
BATCH_SIZE=1

In [16]:
# Baseline: fit a timecast AR model online to the LSTM residual (qobs - qsim)
# and score the corrected prediction (LSTM + AR) per basin.
results = {}
mses = []
nses = []
for X, y, basin in tqdm.tqdm(flood_data.generator(), total=len(flood_data.basins)):
    sgd = SGD(learning_rate=LR_AR, online=False)
    ar = AR(input_dim=AR_INPUT_DIM,
            output_dim=AR_OUTPUT_DIM,
            window_size=flood_data.cfg["seq_length"],
            optimizer=sgd,
            history=X[:flood_data.cfg["seq_length"]],
            fit_intercept=True,
            constrain=False
           )
    # NOTE: difference in indexing convention, so need to pad one row
    X_trimmed = np.vstack((X[flood_data.cfg["seq_length"]:], np.ones((1, X.shape[1]))))
    Y = np.array(ea_data[basin].qobs).reshape(-1, 1)

    Y_lstm = np.array(ea_data[basin].qsim).reshape(-1, 1)
    # The AR model learns the residual the LSTM leaves behind.
    Y_target = Y - Y_lstm

    Y_ar = ar.predict_and_update(X_trimmed, Y_target, batch_size=BATCH_SIZE)

    Y_hat = Y_lstm + Y_ar

    mse = MSE(Y, Y_hat)
    nse = NSE(Y, Y_hat)

    # Record the current basin *before* averaging so the stored running
    # averages match the printed ones (previously they lagged one basin behind
    # and were NaN for the first basin, being the mean of an empty array).
    mses.append(mse)
    nses.append(nse)
    avg_mse = np.mean(np.array(mses))
    avg_nse = np.mean(np.array(nses))

    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0],
        "avg_mse": avg_mse,
        "avg_nse": avg_nse
    }
    print(basin, mse, nse, avg_mse, avg_nse)
    # Deliberately stop after the first basin: the cells below reuse X, Y,
    # Y_lstm and Y_target of this basin as their reference case.
    break
    

HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))

01022500 0.59074104 0.8800156 0.59074104 0.8800156


In [128]:
from flax import nn
from flax import optim

import jax
import jax.numpy as jnp

In [139]:
def batch_window(X, window_size, offset=0):
    """Build overlapping sliding windows along the first axis of ``X``.

    Window ``j`` holds rows ``j + offset`` through ``j + offset + window_size - 1``
    of ``X``. Rows are fetched with ``np.roll``, so indices past the end wrap
    around to the start of ``X``; with a non-zero ``offset`` the trailing
    windows therefore contain wrapped rows.

    Args:
        X: array whose first axis is the one being windowed.
        window_size: number of consecutive rows per window.
        offset: extra shift applied to every window start.

    Returns:
        Array of shape ``(X.shape[0] - window_size + 1, window_size, *X.shape[1:])``.
    """
    n_windows = X.shape[0] - window_size + 1

    # One rolled copy of X per position inside the window.
    shifted_copies = []
    for lag in range(window_size):
        shifted_copies.append(np.roll(X, shift=-(lag + offset), axis=0))

    # Stacking on axis 1 puts the window position right after the row index.
    windows = np.stack(shifted_copies, axis=1)
    return windows[:n_windows]

class ARF(nn.Module):
    """Linear autoregressive model over a rolling window of feature rows.

    The module keeps a ``(window_size, input_features)`` state named
    ``"history"``. Each call pushes the new row(s) onto the end of that window,
    drops the same number of rows from the front, and applies one dense layer
    that contracts over the whole window. Kernel and bias start at zero.
    """

    def apply(self, x, input_features, output_features, window_size, history=None):
        """Advance the window with ``x`` and predict from the updated window.

        Args:
            x: new feature row(s); a 1-D vector is treated as a single row.
            input_features: number of features per row.
            output_features: size of the prediction.
            window_size: number of rows kept in the rolling window.
            history: optional rows used to seed the window while the module is
                being initialised. When omitted the window starts as zeros.

        Returns:
            Output of the dense layer applied to the current window.
        """
        # Promote a single feature vector to a one-row matrix so it can be stacked.
        if x.ndim == 1:
            x = x.reshape(1, -1)

        self.history = self.state("history", (window_size, input_features), nn.initializers.zeros)

        if self.is_initializing():
            # `history` defaults to None; reading `history.shape` unconditionally
            # crashed whenever no seed rows were supplied. Without a seed the
            # zero-initialised window is kept as is.
            if history is not None:
                # Append the seed rows and drop as many rows from the front,
                # which keeps the window at `window_size` rows.
                self.history.value = np.vstack((self.history.value, history))[history.shape[0]:]
        else:
            # Slide the window: append the new row(s), discard the oldest.
            self.history.value = np.vstack((self.history.value, x))[x.shape[0]:]

        # axis=(0, 1) with no batch dims contracts over both the time and the
        # feature axis of the window.
        y = nn.DenseGeneral(inputs=self.history.value,
                            features=output_features,
                            axis=(0, 1),
                            batch_dims=(),
                            bias=True,
                            dtype=jnp.float32,
                            kernel_init=nn.initializers.zeros,
                            bias_init=nn.initializers.zeros,
                            precision=None,
                            name="linear"
                           )
        return y

In [140]:
# Build the flax ARF model and capture its initial rolling-window state in
# `init_state`.
# NOTE(review): `X` is whatever the last executed basin loop left in the
# kernel -- confirm it is the intended basin.
# NOTE(review): the literals 32 / 1 / 270 presumably mirror AR_INPUT_DIM,
# AR_OUTPUT_DIM and flood_data.cfg["seq_length"] -- TODO confirm.
with nn.stateful() as init_state:
    # The window is seeded with the first seq_length-1 rows of X.
    model_def = ARF.partial(input_features=32, output_features=1, window_size=270, history=X[:flood_data.cfg["seq_length"]-1])
    # Initialise parameters from an input of shape (1, 32).
    ys, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1, 32)])
    model = nn.Model(model_def, params)

In [121]:
# Plain gradient descent with learning rate 1e-5 (the same value as LR_AR),
# wrapped around the freshly initialised model.
optim_def = optim.GradientDescent(learning_rate=1e-5)
optimizer = optim_def.create(model)

In [27]:
# Rows streamed through the model one at a time. The preceding seq_length-1
# rows of X were already passed as `history` when the model was initialised.
X_t = X[flood_data.cfg["seq_length"]-1:]

In [180]:
# Reference implementation: predict-then-update one sample at a time in a
# plain Python loop, collecting each prediction in `result`.
result = []

# Step loop-local copies so the module-level `optimizer` and `init_state`
# keep their initial values. Rebinding them here made the cell non-idempotent
# and handed already-trained / advanced state to the scan and `run` cells
# below, which pass `init_state` as their starting point.
# NOTE(review): relies on apply_gradient / nn.stateful returning new objects
# rather than mutating their inputs, as the reassignment pattern already
# presumes -- TODO confirm.
loop_optimizer, loop_state = optimizer, init_state

for x, y in tqdm.tqdm(zip(X_t, Y_target), total=X_t.shape[0]):
    def loss_fn(model):
        # x[None, ...] adds a leading axis so the model receives a single row.
        y_hat = model(x[None, ...])
        # Squared error; the prediction is returned as an auxiliary output.
        return jnp.square(y - y_hat).mean(), y_hat
    with nn.stateful(loop_state) as loop_state:
        loss, y_hat, grad = loop_optimizer.compute_gradients(loss_fn)
        loop_optimizer = loop_optimizer.apply_gradient(grad)
        result.append(y_hat)

HBox(children=(FloatProgress(value=0.0, max=3652.0), HTML(value='')))




In [182]:
# Corrected prediction = LSTM simulation + AR residual predictions collected
# by the Python loop; the cell output is its MSE against the observations.
Y_hat = Y_lstm + np.array(result)
MSE(Y_hat, Y)

DeviceArray(0.59074104, dtype=float32)

In [115]:
def test(optstate, xy):
    """One ``jax.lax.scan`` step: predict for a sample, then take a gradient step.

    Args:
        optstate: carry tuple ``(optimizer, state)``.
        xy: tuple ``(x, y)`` holding one input and its target.

    Returns:
        ``((updated_optimizer, updated_state), prediction)`` -- the new carry
        followed by the prediction made before the update.
    """
    features, target = xy
    current_optimizer, current_state = optstate

    def squared_error(model):
        prediction = model(features)
        # The prediction rides along as the auxiliary output of the loss.
        return jnp.square(target - prediction).mean(), prediction

    with nn.stateful(current_state) as next_state:
        _, prediction, gradient = current_optimizer.compute_gradients(squared_error)
        return (current_optimizer.apply_gradient(gradient), next_state), prediction

In [116]:
%%timeit
# Time a full pass of the predict-then-update step expressed as jax.lax.scan.
jax.lax.scan(test, (optimizer, init_state), (X_t, Y_target))

317 ms ± 34.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [117]:
# Run the step function over every (x, y) pair with jax.lax.scan; `optstate`
# is the final (optimizer, state) carry and `y_arf` the stacked predictions.
optstate, y_arf = jax.lax.scan(test, (optimizer, init_state), (X_t, Y_target))

In [118]:
# Corrected prediction from the scan-based run and its MSE.
# NOTE(review): the recorded output (4.7358737) differs from the 0.59074104
# recorded for the Python loop and for `run`. Execution counts are out of
# order, so the optimizer/state or ARF definition in the kernel at the time
# may have differed -- confirm with Restart & Run All.
Y_hat = Y_lstm + y_arf
MSE(Y_hat, Y)

DeviceArray(4.7358737, dtype=float32)

In [95]:
def run(X, Y, optimizer, state, loss_fn):
    """Stream ``(X, Y)`` through the model with ``jax.lax.scan``, learning online.

    For every sample the model first predicts, then the optimizer takes one
    gradient step on ``loss_fn(target, prediction)``.

    Args:
        X: inputs, scanned along the first axis.
        Y: targets, scanned along the first axis together with ``X``.
        optimizer: optimizer wrapping the model to train.
        state: model state collection used for the first step.
        loss_fn: callable ``loss_fn(y, y_hat)`` returning a scalar loss.

    Returns:
        The stacked per-step predictions, each made before its own update.
    """
    def scan_step(carry, sample):
        features, target = sample
        step_optimizer, step_state = carry

        def objective(model):
            prediction = model(features)
            # The prediction is returned as the auxiliary output of the loss.
            return loss_fn(target, prediction), prediction

        with nn.stateful(step_state) as updated_state:
            _, prediction, gradient = step_optimizer.compute_gradients(objective)
            return (step_optimizer.apply_gradient(gradient), updated_state), prediction

    initial_carry = (optimizer, state)
    _, predictions = jax.lax.scan(scan_step, initial_carry, (X, Y))
    return predictions

In [97]:
# Same experiment through the reusable `run` helper, with mean squared error
# as the per-step loss.
y_arf = run(X_t, Y_target, optimizer, init_state, lambda x, y: jnp.square(x - y).mean())

In [98]:
# Corrected prediction from `run` and its MSE; the recorded output matches
# the value recorded for the Python loop (0.59074104).
Y_hat = Y_lstm + y_arf
MSE(Y_hat, Y)

DeviceArray(0.59074104, dtype=float32)

In [138]:
# Flax/scan counterpart of the timecast AR baseline: for every basin, fit an
# ARF model online to the LSTM residual and score LSTM + ARF.
results = {}
mses = []
nses = []

for X, y, basin in tqdm.tqdm(flood_data.generator(), total=len(flood_data.basins)):
    # Fresh model, state and optimizer per basin. The window is seeded with
    # the first seq_length-1 rows of this basin.
    # NOTE(review): window_size=270 presumably equals
    # flood_data.cfg["seq_length"] -- TODO confirm before replacing the literal.
    with nn.stateful() as state:
        model_def = ARF.partial(input_features=AR_INPUT_DIM, output_features=AR_OUTPUT_DIM, window_size=270, history=X[:flood_data.cfg["seq_length"]-1])
        ys, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1, AR_INPUT_DIM)])
        model = nn.Model(model_def, params)
    optim_def = optim.GradientDescent(learning_rate=LR_AR)
    optimizer = optim_def.create(model)

    # Unlike the timecast AR cell, nothing is padded here: streaming simply
    # starts at row seq_length-1, right after the rows used as history.
    X_t = X[flood_data.cfg["seq_length"]-1:]
    Y = np.array(ea_data[basin].qobs).reshape(-1, 1)

    Y_lstm = np.array(ea_data[basin].qsim).reshape(-1, 1)
    # The ARF model learns the residual the LSTM leaves behind.
    Y_target = Y - Y_lstm

    Y_ar = run(X_t, Y_target, optimizer, state, lambda x, y: jnp.square(x-y).mean())

    Y_hat = Y_lstm + Y_ar

    mse = MSE(Y, Y_hat)
    nse = NSE(Y, Y_hat)

    # Record the current basin *before* averaging so the stored running
    # averages match the printed ones (previously they lagged one basin behind
    # and were NaN for the first basin, being the mean of an empty array).
    mses.append(mse)
    nses.append(nse)
    avg_mse = np.mean(np.array(mses))
    avg_nse = np.mean(np.array(nses))

    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0],
        "avg_mse": avg_mse,
        "avg_nse": avg_nse
    }
    print(basin, mse, nse, avg_mse, avg_nse)
    

HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))

01022500 0.7921799 0.8391017 0.7921799 0.8391017
01031500 1.034634 0.88986236 0.91340697 0.86448205
01047000 1.2811567 0.87995404 1.0359902 0.8696394
01052500 1.9988087 0.847376 1.2766949 0.8640735
01054200 7.819855 0.73467207 2.585327 0.83819324
01055000 4.7343593 0.7325247 2.9434988 0.82058185
01057000 0.91459924 0.87551165 2.653656 0.8284289
01073000 0.9153604 0.86096424 2.436369 0.83249587
01078000 0.6731362 0.8878817 2.2404542 0.83864987
01123000 0.8993281 0.8173188 2.1063416 0.83651674
01134500 1.715929 0.79124093 2.0708494 0.83240074
01137500 2.3760467 0.76634103 2.0962827 0.8268958
01139000 0.80212957 0.7747908 1.9967325 0.8228877
01139800 0.84583026 0.76805705 1.9145252 0.8189712
01142500 1.1238301 0.7723832 1.8618122 0.81586534
01144000 0.72948605 0.85160327 1.7910419 0.81809896
01162500 0.7964086 0.84338397 1.7325339 0.8195863
01169000 2.2953029 0.7448952 1.763799 0.8154368
01170100 1.8372196 0.7727082 1.7676632 0.8131879
01181000 1.9368114 0.78878325 1.7761208 0.8119677
011

KeyboardInterrupt: 