In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as np
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import torch
import matplotlib
import timecast as tc

from mpl_toolkits import mplot3d


plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm

%matplotlib notebook

In [7]:
import pickle
from ealstm.gaip.flood_data import FloodData
from ealstm.gaip.utils import MSE, NSE
from flax import nn

import jax
import jax.numpy as jnp

from timecast.learners import Sequential, Parallel, AR, Index
from timecast import smap
from timecast.objectives import residual
from timecast.optim import RMSProp

cfg_path = "../data/models/runs/run_2006_0032_seed444/cfg.json"
ea_data = pickle.load(open("../data/models/runs/run_2006_0032_seed444/lstm_seed444.p", "rb"))
flood_data = FloodData(cfg_path)

results = {}
mses = []
nses = []

lr = 1e-5
beta = 0.999

for X, y, basin in flood_data.generator():
    with nn.stateful() as state:
        lstm = Index.partial(index=0)
        take1 = Index.partial(index=1)
        ar = AR.partial(output_features=1, history_len=270, history=X[:flood_data.cfg["seq_length"]-1])
        arf = Sequential.partial(learners=[take1, ar])
        model_def = Parallel.partial(learners=[lstm, arf])
        ys, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1, 32)])
        model = nn.Model(model_def, params)
    optim_def = RMSProp(learning_rate=lr, beta2=beta)
    optimizer = optim_def.create(model)
    break
    # NOTE: difference in indexing convention, so need to pad one row
    X_t = X[flood_data.cfg["seq_length"]-1:]
    Y_lstm = jnp.array(ea_data[basin].qsim)
    Y = jnp.array(ea_data[basin].qobs).reshape(-1, 1)

#     Y_hat = smap(X_t, Y, optimizer, lambda x, y: jnp.square(x-y).mean(), state, residual, )
    optimizer, state, Y_hat = smap((Y_lstm, X_t), Y, optimizer, lambda x, y: jnp.square(x-y).mean(), state, residual, )

    mse = MSE(Y, Y_hat)
    nse = NSE(Y, Y_hat)
    mses.append(mse)
    nses.append(nse)

    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0],
        "avg_mse": jnp.mean(jnp.array(mses)),
        "avg_nse": jnp.mean(jnp.array(nses))
    }
    break

In [4]:
results

{'01022500': {'mse': DeviceArray(0.82482624, dtype=float32),
  'nse': DeviceArray(0.83247095, dtype=float32),
  'count': 3921,
  'avg_mse': DeviceArray(0.82482624, dtype=float32),
  'avg_nse': DeviceArray(0.83247095, dtype=float32)}}

In [8]:
state.state

{'/Sequential_1/AR_1': {'history': DeviceArray([[ 0.        ,  0.        ,  0.        , ...,  0.        ,
                 0.        ,  0.        ],
               [-0.45735672, -1.1924587 , -2.4735065 , ..., -0.8065217 ,
                -0.6008064 ,  0.20375867],
               [-0.45735672, -1.1144404 , -2.5866942 , ..., -0.8065217 ,
                -0.6008064 ,  0.20375867],
               ...,
               [-0.44727528, -0.1894627 , -0.5125051 , ..., -0.8065217 ,
                -0.6008064 ,  0.20375867],
               [-0.42567217, -0.34831032, -0.10944621, ..., -0.8065217 ,
                -0.6008064 ,  0.20375867],
               [-0.45015568,  0.08645047, -0.02294497, ..., -0.8065217 ,
                -0.6008064 ,  0.20375867]], dtype=float32)}}

In [5]:
model

Model(module=<class 'flax.nn.base.Parallel'>, params={'Index_0': {}, 'Sequential_1': {'AR_1': {'linear': {'bias': DeviceArray([0.], dtype=float32), 'kernel': DeviceArray([[[0.],
              [0.],
              [0.],
              ...,
              [0.],
              [0.],
              [0.]],

             [[0.],
              [0.],
              [0.],
              ...,
              [0.],
              [0.],
              [0.]],

             [[0.],
              [0.],
              [0.],
              ...,
              [0.],
              [0.],
              [0.]],

             ...,

             [[0.],
              [0.],
              [0.],
              ...,
              [0.],
              [0.],
              [0.]],

             [[0.],
              [0.],
              [0.],
              ...,
              [0.],
              [0.],
              [0.]],

             [[0.],
              [0.],
              [0.],
              ...,
              [0.],
              [0.],

In [None]:
from timecast.series import sp500
import flax

In [None]:
X, Y = sp500.generate()

In [None]:
with flax.nn.stateful(None) as state:
    a = 2 + 1

In [None]:
from timecast.learners import PredictConstant, PredictLast, AR
from timecast.learners import Sequential
from timecast.api import smap

In [None]:
pc4 = PredictConstant.make((), c=4)

In [None]:
X[:flood_data.cfg["seq_length"]-1]

In [None]:
pc4 = PredictConstant.make(c=4)
pc5 = PredictConstant.make(c=5)
pl = PredictLast.make()

In [None]:
s = Sequential.make((), learners=[pc4, pc5])

In [None]:
ar = AR.make((1,), output_features=1, history_len=10)

In [None]:
from timecast.optim import Adagrad

In [None]:
opt = Adagrad.make(pc4)

In [None]:
X_t

In [None]:
opt.optimizer.target

In [None]:
smap(X_t, Y, opt)

In [None]:
opt.optimizer.target

In [None]:
hyper_params = self.optimizer_def.update_hyper_params()

In [None]:
self.optimizer_def.apply_gradient(hyper_params, self.target, self.state, opt.optimizer.target)

In [None]:
params = self.target
state = self.state
grads = opt.optimizer.target

In [None]:
step = state.step

In [None]:
params_flat, treedef = jax.tree_flatten(params)
states_flat = treedef.flatten_up_to(state.param_states)
grads_flat = treedef.flatten_up_to(grads)
out = [self.apply_param_gradient(step, hyper_params, param, state, grad)
        for param, state, grad in zip(params_flat, states_flat, grads_flat)]

In [None]:
out

In [None]:
import flax
import jax
import jax.numpy as jnp

class Identity(flax.nn.Module):
    def apply(self, x):
        self.param("x", (1,) if not hasattr(x, "shape") else x.shape, flax.nn.initializers.zeros)
        return x

model_def = Identity.partial()
_, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(1,)])
model = flax.nn.Model(model_def, params)

def loss_fn(model, x, y):
    y_hat = model(x)
    return jnp.square(y - y_hat).mean(), y_hat

optim_def = flax.optim.Adam(learning_rate=lr)
optimizer = optim_def.create(model)

(loss, y_hat), grad = jax.value_and_grad(loss_fn, has_aux=True)(optimizer.target, 1.0, 2.0)
optimizer.apply_gradient(grad)

In [None]:
loss, y_hat, grad

In [None]:
optimizer.target

In [None]:
model_def

In [None]:
model_def = AR.partial(output_dim=1, history_len=270)

In [None]:
model_def

In [None]:
import numpy as onp
onp.random.rand(*())