In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import math
import json
import numpy as np
import jax
import jax.numpy as jnp
import flax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib
import timecast as tc

from mpl_toolkits import mplot3d

plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm



In [2]:
from timecast.learners import ARX
from timecast.utils.ar import historify

In [3]:
basins = pickle.load(open("../data/flood/meta.pkl", "rb"))["basins"]

In [4]:
basin = basins[0]

X = pickle.load(open("../data/flood/test/{}.pkl".format(basin), "rb"))
Y = pickle.load(open("../data/flood/base/{}.pkl".format(basin), "rb"))

In [24]:
from jax.config import config
config.update('jax_disable_jit', False)

In [56]:
import pickle
import flax

from timecast.learners import ARX, Parallel, Precomputed
from timecast.learners import NewMixin

from timecast.optim import GradientDescent
from timecast.objectives import residual

from timecast import tscan

basin = basins[0]

X = pickle.load(open("../data/flood/test/{}.pkl".format(basin), "rb"))
Y = pickle.load(open("../data/flood/base/{}.pkl".format(basin), "rb"))
history_len = 270

Y_lstm = jnp.pad(jnp.array(Y.qsim), (history_len - 1, 0))
Y = jnp.pad(jnp.array(Y.qobs), (history_len - 1, 0)).reshape(-1, 1)

class FloodPredictor(NewMixin, flax.nn.Module):
    def apply(self, features, arr, history_len):
        lstm = Precomputed(x=features, arr=arr)
        arx = ARX(features=features, history_len=history_len, constrain=False)
        
        return [lstm, arx]

model, state = FloodPredictor.new(shapes=[(1, 32)], arr=Y_lstm, history_len=history_len)
optim_def = GradientDescent(learning_rate=10 ** -5)
optimizer = optim_def.create(model)

Y_hat, optimizer, state = tscan(X, Y, optimizer, state=state, objective=residual)

In [None]:
jnp.square(Y_hat.squeeze()[history_len - 1:] - Y.squeeze()[history_len - 1:]).mean()

In [38]:
Y_hat.squeeze()

DeviceArray([0.       , 0.       , 0.       , ..., 1.3323907, 1.0760297,
             1.052646 ], dtype=float32)

In [40]:
Y.squeeze()

DeviceArray([0.        , 0.        , 0.        , ..., 1.4529347 ,
             1.1823308 , 0.99915284], dtype=float32)

In [54]:
import pickle
import flax

from timecast.learners import AR, Parallel, Precomputed
from timecast.learners import NewMixin

from timecast.optim import GradientDescent
from timecast.objectives import residual

from timecast import tscan

basin = basins[0]

X = pickle.load(open("../data/flood/test/{}.pkl".format(basin), "rb"))
Y = pickle.load(open("../data/flood/base/{}.pkl".format(basin), "rb"))
history_len = 270

X = X[history_len - 1:]
Y_lstm = jnp.array(Y.qsim)
Y = jnp.array(Y.qobs).reshape(-1, 1)

class FloodPredictor(NewMixin, flax.nn.Module):
    def apply(self, features, arr, history_len):
        lstm = Precomputed(x=features, arr=arr)
        ar = AR(x=features, history_len=history_len, history=X[:history_len - 1])
        
        return [lstm, ar]

model, state = FloodPredictor.new(shapes=[(1, 32)], arr=Y_lstm, history_len=history_len)
optim_def = GradientDescent(learning_rate=10 ** -5)
optimizer = optim_def.create(model)

Y_hat, optimizer, state = tscan(X, Y, optimizer, state=state, objective=residual)

In [55]:
jnp.square(Y_hat - Y).mean()

DeviceArray(0.7916824, dtype=float32)

In [None]:
np.pad(a, (10, 0))

In [None]:
from timecast.optim import GradientDescent

In [None]:
optim_def = GradientDescent(learning_rate=10 ** -5)

In [None]:
optimizer = optim_def.create(model)

In [None]:
model.params

In [None]:
def func(model, truth, targets=None, features=None):
    y_hat = model(targets, features)
    return jnp.square(truth - y_hat).mean(), y_hat

with flax.nn.stateful(state) as state:
    (loss, y_hat), gradient = jax.value_and_grad(func, has_aux=True)(optimizer.target,
                                                                     jnp.ones((1, 2)),
                                                                     jax.random.uniform(tc.utils.random.generate_key(), (1, 2)),
                                                                     jax.random.uniform(tc.utils.random.generate_key(), (1, 1)))
    optimizer = optimizer.apply_gradient(gradient)

In [None]:
gradient

In [None]:
y_hat

In [None]:
optimizer.target.params

In [None]:
inputs = jax.random.uniform(jax.random.PRNGKey(2), (1, 2))
with flax.nn.stateful(state) as state:
    result = optimizer.target(inputs)

In [None]:
result

(batch, history, input1, input2) x (history, input1, input2, output1, output2) = (batch, input1, input2, output1, output2)

bias: (batch, output1, input1, input2, output2)

In [None]:
sorted([4,1,2])

In [None]:
jnp.tensordot(jnp.ones((3, 2, 4)), jnp.ones((2, 1, 5)), axes=[(1,), (0,)]).shape

In [None]:
a = jnp.ones((20, 2, 4))

In [None]:
historify(a, history_len = 4).shape

In [None]:
[1][:-1]

In [None]:
a.reshape((1, -1))

Loop 0
- targets: 0 to H
- features: 0 to H

next
- targets: 0 to H
- features: 1 to H + 1

finally
- targets: 1 to H + 1
- features: 1 to H + 1

In [None]:
def my_func(x):
    return x

c = jax.xla_computation(my_func)(1.0)
print_opts = jax.lib.xla_client._xla.HloPrintOptions.short_parsable()
local_backend = jax.lib.xla_client.get_local_backend('cpu')
out = local_backend.compile(c)
print(out.hlo_modules()[0].to_string(print_opts))