In [8]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import time
import json
import jax.numpy as np
import numpy as onp
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from timecast.utils.numpy import ecdf
from timecast.utils.losses import MeanSquareError
import torch
import matplotlib

plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
from ealstm.gaip.flood_data import FloodData
from ealstm.gaip.utils import MSE, NSE

cfg_path = "../data/models/runs/run_2006_0032_seed444/cfg.json"
ea_data = pickle.load(open("../data/models/runs/run_2006_0032_seed444/lstm_seed444.p", "rb"))
flood_data = FloodData(cfg_path)

LR_AR = 1e-5
AR_INPUT_DIM=32
AR_OUTPUT_DIM=1
BATCH_SIZE=1

In [10]:
from flax import nn
from flax import optim

import jax
import jax.numpy as jnp

In [77]:
class ARF(nn.Module):
    def apply(self, x, input_features, output_features, window_size, history=None):

        if x.ndim == 1:
            x = x.reshape(1, -1)
            
        if self.is_initializing() and history is not None:
            x = history
            
        self.history = self.state("history", shape=(window_size, input_features), initializer=nn.initializers.zeros)
        self.history.value = np.vstack((self.history.value, x))[x.shape[0]:]
        
        y = nn.DenseGeneral(inputs=self.history.value,
                            features=output_features,
                            axis=(0, 1),
                            batch_dims=(),
                            bias=True,
                            dtype=jnp.float32,
                            kernel_init=nn.initializers.zeros,
                            bias_init=nn.initializers.zeros,
                            precision=None,
                            name="linear"
                           )
        return y

In [78]:
class Take(nn.Module):
    def apply(self, x, i):
        return x[i]

In [79]:
class Identity(nn.Module):
    def apply(self, x):
        return x

In [80]:
class Plus(nn.Module):
    def apply(self, x, z):
        return x + z

In [81]:
class Ensemble(nn.Module):
    def apply(self, x, modules, args):
        return [module(x, **arg) for (module, arg) in zip(modules, args)]

In [82]:
class Sequential(nn.Module):
    def apply(self, x, modules, args):
        results = x
        for module, arg in zip(modules, args):
            results = module(results, **arg)
        return results

In [83]:
class Residual(nn.Module):
    def apply(self, x, input_features, output_features, window_size, history):
        y_arf = ARF(x=x[1],
                    input_features=input_features,
                    output_features=output_features,
                    window_size=window_size,
                    history=history
                   )
        return (x[0], y_arf)

In [84]:
def run(X, Y, optimizer, state, objective, loss_fn):
    def _run(optstate, xy):
        x, y = xy
        optimizer, state = optstate
        with nn.stateful(state) as state:
            loss, y_hat, grad = optimizer.compute_gradients(partial(objective, x, y, loss_fn))
            return (optimizer.apply_gradient(grad), state), y_hat
    _, pred = jax.lax.scan(_run, (optimizer, state), (X, Y))
    return pred

In [85]:
from functools import partial

def residual(x, y, loss_fn, model):
    y_hats = model(x)
    target, y_hat, loss = y, y_hats[0], 0
    for i in range(len(y_hats) - 1):
        loss += loss_fn(target - y_hats[i], y_hats[i + 1])
        target -= y_hats[i]
        y_hat += y_hats[i + 1]
    return loss, y_hat

In [86]:
def xboost(x, y, loss_fn, model, reg = 1.0):
    y_hats = model(x)
    g = jax.grad(loss_fn)
    u, loss = 0, 0
    for i in range(len(y_hats)):
        eta = (2 / (i + 2))
        loss += g(y, u) * y_hats[i] + (reg / 2) * y_hats[i] * y_hats[i]
        u = (1 - eta) * u + eta * y_hats[i]
    return loss.reshape(()), u

In [87]:
results = {}
mses = []
nses = []

for X, y, basin in tqdm.tqdm(flood_data.generator(), total=len(flood_data.basins)):
    with nn.stateful() as state:
#         model_def = Residual.partial(input_features=32,
#                                      output_features=1,
#                                      window_size=270,
#                                      history=X[:flood_data.cfg["seq_length"]-1]
#                                     )
        lstm = Sequential.partial(modules=[Take, Identity], args=[{"i": 0}, {}])
        
        arf = Sequential.partial(modules=[Take, ARF], args=[{"i": 1}, {"input_features": 32, "output_features": 1, "window_size": 270, "history": X[:flood_data.cfg["seq_length"]-1]}])
        
        model_def = Ensemble.partial(modules=[lstm, arf], args=[{}, {}])
        ys, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1, 32)])
        model = nn.Model(model_def, params)
    optim_def = optim.GradientDescent(learning_rate=1e-5)
    optimizer = optim_def.create(model)
    
    # NOTE: difference in indexing convention, so need to pad one row
    X_t = X[flood_data.cfg["seq_length"]-1:]
    Y_lstm = np.array(ea_data[basin].qsim)
    Y = np.array(ea_data[basin].qobs).reshape(-1, 1)
    
    Y_hat = run((Y_lstm, X_t), Y, optimizer, state, residual, lambda x, y: jnp.square(x-y).mean())
    
    mse = MSE(Y, Y_hat)
    nse = NSE(Y, Y_hat)
    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0],
        "avg_mse": np.mean(np.array(mses)),
        "avg_nse": np.mean(np.array(nses))
    }
    mses.append(mse)
    nses.append(nse)
    print(basin, mse, nse, np.mean(np.array(mses)), np.mean(np.array(nses)))
    

HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))

01022500 0.7921799 0.8391017 0.7921799 0.8391017
01031500 1.034634 0.88986236 0.91340697 0.86448205
01047000 1.2811565 0.87995404 1.0359901 0.8696394
01052500 1.9988087 0.847376 1.2766948 0.8640735
01054200 7.8198547 0.7346721 2.5853267 0.83819324
01055000 4.73436 0.73252463 2.9434988 0.82058173
01057000 0.91459924 0.87551165 2.653656 0.82842886
01073000 0.9153604 0.86096424 2.436369 0.8324958
01078000 0.6731362 0.8878817 2.2404542 0.8386498
01123000 0.8993281 0.8173188 2.1063416 0.8365167
01134500 1.715929 0.79124093 2.0708494 0.8324007
01137500 2.376047 0.766341 2.0962827 0.8268957
01139000 0.80212957 0.7747908 1.9967325 0.82288766
01139800 0.84583026 0.76805705 1.9145252 0.81897116
01142500 1.12383 0.7723832 1.8618122 0.81586534
01144000 0.729486 0.85160327 1.7910419 0.81809896
01162500 0.7964086 0.84338397 1.7325339 0.8195863
01169000 2.2953029 0.7448952 1.763799 0.8154368
01170100 1.8372197 0.7727082 1.7676632 0.8131879
01181000 1.9368113 0.78878325 1.7761205 0.8119677
01187300 2.

02479155 6.0705748 0.6767037 1.9745606 0.7358027
02479300 1.6803375 0.8036434 1.9728094 0.73620653
02479560 1.3658031 0.84836113 1.9692177 0.7368702
02481000 6.3722224 0.74335706 1.9951175 0.7369083
02481510 2.4130614 0.80391026 1.9975618 0.73730016
04015330 2.252899 0.66174245 1.9990463 0.7368609
04024430 0.4891839 0.7830616 1.9903187 0.73712784
04027000 0.27036124 0.887639 1.9804341 0.7379929
04040500 0.33802435 0.8444825 1.9710487 0.73860145
04043050 0.81265485 0.81663686 1.9644668 0.7390448
04045500 0.024126073 0.97218746 1.9535044 0.74036205
04057510 0.03834418 0.9494198 1.9427452 0.7415365
04057800 0.47224835 0.7852626 1.9345303 0.7417807
04059500 0.11727201 0.8931821 1.9244344 0.74262184
04063700 0.05912218 0.9114649 1.9141285 0.7435547
04074950 0.03584394 0.86431634 1.9038085 0.7442183
04105700 0.02420245 0.85471606 1.8935375 0.744822
04115265 0.113158405 0.5945382 1.8838615 0.74400526
04122200 0.02616846 0.86854124 1.8738197 0.7446785
04122500 0.020689014 0.88301164 1.8638568 

06910800 1.7256471 0.8276372 1.9332426 0.7316581
06911900 2.003322 0.82916117 1.9334532 0.7319508
06917000 3.2621305 0.7207122 1.9374313 0.7319172
06918460 1.1655071 0.8108257 1.9351269 0.7321527
06919500 1.2852887 0.842589 1.9331928 0.7324814
06921070 1.9437113 0.8230838 1.9332243 0.7327503
06921200 5.341466 0.67031974 1.9433078 0.7325656
07057500 0.98629683 0.6631954 1.9404846 0.73236096
07060710 1.6115025 0.65128326 1.939517 0.7321225
07066000 1.9083974 0.61063623 1.939426 0.7317662
07083000 0.14945406 0.9673984 1.934192 0.73245513
07142300 0.015930453 0.32153428 1.9285994 0.73125714
07145700 1.4042596 0.74024403 1.9270753 0.73128325
07167500 10.403313 0.41187114 1.9516438 0.7303574
07180500 2.0487287 0.6984247 1.9519244 0.73026514
07184000 8.064598 0.7379743 1.9695402 0.7302874
07195800 2.09204 0.67005473 1.9698925 0.7301143
07196900 8.081579 0.5772962 1.9874043 0.7296765
07197000 1.9835986 0.76535374 1.9873935 0.7297784
07208500 0.044143576 0.7479501 1.9818571 0.72983015
07261000 

14305500 14.770721 0.8732027 2.5913358 0.7347387
14306340 9.119397 0.85993594 2.6044445 0.73499006
14306500 3.7769217 0.9062961 2.6067939 0.7353334
14308990 1.7836583 0.81164193 2.6051476 0.73548603
14309500 3.3706017 0.8773304 2.6066756 0.73576903
14316700 3.410111 0.8767089 2.6082761 0.7360498
14325000 8.259205 0.8635902 2.6195104 0.7363034
14362250 0.16910686 0.8536346 2.6146488 0.7365362
14400000 14.843471 0.91909236 2.638864 0.73689777
10259000 0.40964818 0.3775329 2.6344585 0.7361875
11124500 1.385836 0.45980328 2.6319957 0.7356424
11141280 0.6180134 0.61667645 2.628031 0.73540807
11143000 2.6242275 0.90091413 2.6280236 0.7357334
11148900 5.8810487 0.77296036 2.6344023 0.73580635
11151300 0.14276774 0.59954906 2.6295261 0.7355397
11176400 0.97019565 0.60199463 2.626285 0.73527884
11230500 0.31112644 0.96245605 2.621772 0.73572165
11237500 2.0008547 0.9084885 2.6205642 0.73605776
11264500 0.7651049 0.9316566 2.6169615 0.73643756
11266500 0.9652136 0.91575474 2.6137602 0.73678505
1

In [None]:
np.array(ea_data[basin].qsim)[0]

In [1]:
class A:
    pass
class B(A):
    pass
class C(B):
    pass

In [2]:
c = C()

In [5]:
issubclass(type(c), A)

True

In [96]:
issubclass(np.array(1), onp.ndarray)

TypeError: issubclass() arg 1 must be a class