In [2]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as np
import jax
import jax.numpy as jnp
import flax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib
import timecast as tc

from mpl_toolkits import mplot3d

plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
basins = pickle.load(open("../data/flood/meta.pkl", "rb"))["basins"]

In [19]:
def runner(basin, lr=1e-5):
    import pickle

    import jax.numpy as jnp

    from timecast.learners import Sequential, Parallel, Take, AR
    from timecast import smap
    from timecast.objectives import residual
    from timecast.optim import GradientDescent

    import tqdm.notebook as tqdm
    
    X = pickle.load(open("../data/flood/test/{}.pkl".format(basin), "rb"))
    Y = pickle.load(open("../data/flood/base/{}.pkl".format(basin), "rb"))
    
    history_len = 270
    
    lstm = Take.partial(index=0)
    pcr = AR.partial(history_len=history_len, history=X[:history_len - 1])
    pcr = Sequential.partial(learners=[Take.partial(index=1), pcr])
    model, state = Parallel.new(shape=(1, 32), learners=[lstm, pcr])
    
    optim_def = GradientDescent(learning_rate=lr)
    optimizer = optim_def.create(model)

    # NOTE: difference in indexing convention, so need to pad one row
    X_t = X[history_len - 1:]
    Y_lstm = jnp.array(Y.qsim)
    Y = jnp.array(Y.qobs).reshape(-1, 1)

    Y_hat, optimizer, state = smap((Y_lstm, X_t), Y, optimizer, state=state, objective=residual)

    return {"basin": basin, "mse": jnp.square(Y - Y_hat).mean()}

In [None]:
(3652, 270, 32) (3652, 2)
(3652,)
(3652, 270, 32) (3652,)
[[-0.45735672 -1.19245875 -2.47350645 ... -0.89466536 -0.46939585
  -0.22357242]
 [-0.45735672 -1.11444044 -2.58669424 ... -0.89466536 -0.46939585
  -0.22357242]
 [-0.45735672 -0.90697372 -2.22412539 ... -0.89466536 -0.46939585
  -0.22357242]
 ...
 [-0.42567217 -0.34831032 -0.10944621 ... -0.89466536 -0.46939585
  -0.22357242]
 [-0.45015568  0.08645047 -0.02294497 ... -0.89466536 -0.46939585
  -0.22357242]
 [-0.45735672  0.14874364 -0.17110133 ... -0.89466536 -0.46939585
  -0.22357242]] 0.6203073859214783

In [64]:
class BlackBox(flax.nn.Module):
    def apply(self, x, arr):
        self.index = self.state("index", shape=(), initializer=flax.nn.initializers.zeros)
        val = arr[self.index.value.astype(int)]
        if not self.is_initializing():
            self.index.value += 1
        return val

In [53]:
arr = onp.random.rand(10, 10)

In [54]:
arr

array([[0.22834598, 0.73783427, 0.21340127, 0.14594622, 0.93729401,
        0.21077966, 0.00217258, 0.0783711 , 0.10410216, 0.00548621],
       [0.28369015, 0.26600418, 0.88828835, 0.156351  , 0.30090768,
        0.46945827, 0.95541964, 0.87032048, 0.07816289, 0.28699262],
       [0.21109687, 0.35374314, 0.37884342, 0.0028461 , 0.65632721,
        0.98917952, 0.53040956, 0.5306344 , 0.28700854, 0.30349276],
       [0.37714655, 0.99690132, 0.3412204 , 0.43779144, 0.16694666,
        0.40623382, 0.07328402, 0.76408005, 0.03327027, 0.49728318],
       [0.00166951, 0.73465882, 0.39870041, 0.611089  , 0.51767051,
        0.98409228, 0.83809796, 0.93466848, 0.83634388, 0.93795758],
       [0.71520936, 0.95066642, 0.76092544, 0.29650691, 0.04293031,
        0.6451522 , 0.58337907, 0.2337206 , 0.5141206 , 0.86950689],
       [0.11889408, 0.51946055, 0.34231914, 0.16377098, 0.80210626,
        0.74457677, 0.18377496, 0.52103961, 0.86024615, 0.4611516 ],
       [0.37840992, 0.03470067, 0.7521452

In [55]:
with flax.nn.stateful() as state:
    model_def = BlackBox.partial(arr=arr)
    _, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(10, 10)])
    model = flax.nn.Model(model_def, params)

In [61]:
with flax.nn.stateful(state) as state:
    print(model(1))

[0.71520936 0.95066642 0.76092544 0.29650691 0.04293031 0.6451522
 0.58337907 0.2337206  0.5141206  0.86950689]


In [67]:
def runner2(basin, lr=1e-5):
    import pickle

    import jax.numpy as jnp

    from timecast.learners import Sequential, Parallel, Take, AR
    from timecast import smap
    from timecast.objectives import residual
    from timecast.optim import GradientDescent

    import tqdm.notebook as tqdm
    
    X = pickle.load(open("../data/flood/test/{}.pkl".format(basin), "rb"))
    Y = pickle.load(open("../data/flood/base/{}.pkl".format(basin), "rb"))
    
    history_len = 270
    
    # NOTE: difference in indexing convention, so need to pad one row
    X_t = X[history_len - 1:]
    Y_lstm = jnp.array(Y.qsim)
    Y = jnp.array(Y.qobs).reshape(-1, 1)
    
    lstm = BlackBox.partial(arr=Y_lstm)
    ar = AR.partial(history_len=history_len, history=X[:history_len - 1])
    model, state = Parallel.new(shape=(1, 32), learners=[lstm, ar])
    
    optim_def = GradientDescent(learning_rate=lr)
    optimizer = optim_def.create(model)

    Y_hat, optimizer, state = smap(X_t, Y, optimizer, state=state, objective=residual)

    return {"basin": basin, "mse": jnp.square(Y - Y_hat).mean()}

In [68]:
mses = []
for basin in tqdm.tqdm(basins):
    result = runner2(basin)
    mses.append(result["mse"])
    print(jnp.mean(jnp.array(mses)), result["mse"])

HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))

0.7921799 0.7921799
0.91340697 1.034634
1.0359901 1.2811565
1.2766948 1.9988087
2.5853267 7.8198547
2.9434988 4.73436
2.653656 0.91459924
2.436369 0.9153604
2.2404542 0.6731362
2.1063416 0.8993281



KeyboardInterrupt: 

In [69]:
from timecast.learners import BlackBox

In [70]:
arr = onp.random.rand(10, 10)

In [71]:
arr

array([[0.90518325, 0.31685522, 0.39291534, 0.99181748, 0.29065332,
        0.24163112, 0.95548252, 0.99844331, 0.09140171, 0.65865713],
       [0.52713931, 0.88044497, 0.35095084, 0.71680116, 0.26766067,
        0.77530987, 0.31168818, 0.19164685, 0.95863018, 0.04147895],
       [0.20860077, 0.46211544, 0.29739237, 0.85566494, 0.12643679,
        0.1112244 , 0.47639521, 0.48575916, 0.7120798 , 0.81211941],
       [0.4754026 , 0.11204976, 0.47827696, 0.99039971, 0.80149364,
        0.24717626, 0.03159279, 0.95914768, 0.08948704, 0.7104717 ],
       [0.02233964, 0.50984689, 0.57352503, 0.39096701, 0.61187166,
        0.27400898, 0.94770433, 0.05261495, 0.1486659 , 0.88145313],
       [0.75581062, 0.93018558, 0.44643126, 0.48888059, 0.61728009,
        0.52795606, 0.60196177, 0.40297298, 0.84111828, 0.60850531],
       [0.75595573, 0.63371538, 0.41360125, 0.23757455, 0.09884147,
        0.89886265, 0.57831166, 0.05443015, 0.23131649, 0.96927147],
       [0.04700768, 0.39412255, 0.2333068

In [84]:
with flax.nn.stateful() as state:
    model_def = BlackBox.partial(arr=arr, name="BlackBox")
    _, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(10, 10)])
    model = flax.nn.Model(model_def, params)

In [85]:
print(model)

Model(module=<class 'flax.nn.base.BlackBox'>, params={})


In [86]:
with flax.nn.stateful(state) as state:
    print(state.state)
    for i in range(arr.shape[0]):
        print(arr[i], model(1))

{'/BlackBox': {'index': DeviceArray(0., dtype=float32)}}
[0.90518325 0.31685522 0.39291534 0.99181748 0.29065332 0.24163112
 0.95548252 0.99844331 0.09140171 0.65865713] [0.90518325 0.31685522 0.39291534 0.99181748 0.29065332 0.24163112
 0.95548252 0.99844331 0.09140171 0.65865713]
[0.52713931 0.88044497 0.35095084 0.71680116 0.26766067 0.77530987
 0.31168818 0.19164685 0.95863018 0.04147895] [0.52713931 0.88044497 0.35095084 0.71680116 0.26766067 0.77530987
 0.31168818 0.19164685 0.95863018 0.04147895]
[0.20860077 0.46211544 0.29739237 0.85566494 0.12643679 0.1112244
 0.47639521 0.48575916 0.7120798  0.81211941] [0.20860077 0.46211544 0.29739237 0.85566494 0.12643679 0.1112244
 0.47639521 0.48575916 0.7120798  0.81211941]
[0.4754026  0.11204976 0.47827696 0.99039971 0.80149364 0.24717626
 0.03159279 0.95914768 0.08948704 0.7104717 ] [0.4754026  0.11204976 0.47827696 0.99039971 0.80149364 0.24717626
 0.03159279 0.95914768 0.08948704 0.7104717 ]
[0.02233964 0.50984689 0.57352503 0.39096

In [89]:
with flax.nn.stateful() as state:
    model(1)

ValueError: No state variable named `index` exists.