In [2]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import time
import json
import jax.numpy as np
import numpy as onp
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import torch
import matplotlib

plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm

In [3]:
from timecast.optim import Adagrad, RMSProp

In [4]:
from ealstm.gaip.flood_data import FloodData
from ealstm.gaip.utils import MSE, NSE

cfg_path = "../data/models/runs/run_2006_0032_seed444/cfg.json"
ea_data = pickle.load(open("../data/models/runs/run_2006_0032_seed444/lstm_seed444.p", "rb"))
flood_data = FloodData(cfg_path)

LR_AR = 1e-5
AR_INPUT_DIM=32
AR_OUTPUT_DIM=1
BATCH_SIZE=1

In [5]:
from flax import nn

import jax
import jax.numpy as jnp

from timecast.learners import Sequential, Ensemble, AR
from timecast import map_grad
from timecast.optim import GradientDescent, Adagrad, RMSProp
from timecast.objectives import residual

In [20]:
class Identity(nn.Module):
    def apply(self, x):
        return x
class Take(nn.Module):
    def apply(self, x, i):
        return x[i]

def batch_window(X, window_size, offset=0):
    num_windows = X.shape[0] - window_size
    return np.stack([np.roll(X, shift=-(i + offset), axis=0)[:window_size] for i in range(num_windows)])
# [:num_windows]

class ARF(nn.Module):
    def apply(self, x, output_features, history_len, history=None):
        if x.ndim == 1:
            x = x[np.newaxis, :]
            
        self.history = self.state("history", shape=(history_len, x.shape[1]), initializer=nn.initializers.zeros)
            
        batch_size = x.shape[0]
        if self.is_initializing() and history is not None:
            self.history.value = np.vstack((self.history.value, history))[:history_len]
            
        stacked = np.vstack((self.history.value, x))
    
#         if batch_size > 1:
        batches = batch_window(stacked, history_len)
        self.history.value = stacked[x.shape[0]:]
#         else:
#             batches = stacked[x.shape[0]:]
#             self.history.value = batches
        
        y = nn.DenseGeneral(inputs=batches,
                            features=output_features,
                            axis=(1, 2),
                            batch_dims=(0),
                            bias=True,
                            dtype=jnp.float32,
                            kernel_init=nn.initializers.zeros,
                            bias_init=nn.initializers.zeros,
                            precision=None,
                            name="linear"
                           )
        return y

In [18]:
class ARF(nn.Module):
    def apply(self, x, output_features, history_len, history=None):

        if x.ndim == 1:
            x = x.reshape(1, -1)
            
        if self.is_initializing() and history is not None:
            x = history
            
        self.history = self.state("history", shape=(history_len, x.shape[1]), initializer=nn.initializers.zeros)
        self.history.value = np.vstack((self.history.value, x))[x.shape[0]:]
        
        y = nn.DenseGeneral(inputs=self.history.value,
                            features=output_features,
                            axis=(0, 1),
                            batch_dims=(),
                            bias=True,
                            dtype=jnp.float32,
                            kernel_init=nn.initializers.zeros,
                            bias_init=nn.initializers.zeros,
                            precision=None,
                            name="linear"
                           )
        return y

In [26]:
results = {}
mses = []
nses = []
batch_size = 3
for X, y, basin in tqdm.tqdm(flood_data.generator(), total=len(flood_data.basins)):
    with nn.stateful() as state:
        lstm = Sequential.partial(modules=[Take, Identity], args=[{"i": 0}, {}])
        
        arf = Sequential.partial(modules=[Take, ARF], args=[{"i": 1}, {"output_features": 1, "history_len": 270, "history": X[:flood_data.cfg["seq_length"]-1]}])
        
        model_def = Ensemble.partial(modules=[lstm, arf], args=[{}, {}])
        ys, params = model_def.init(jax.random.PRNGKey(0), (np.ones((batch_size, 1)), np.ones((batch_size, 32))))
#         ys, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(batch_size, 32)])
        model = nn.Model(model_def, params)
    optim_def = GradientDescent(learning_rate=1e-5)
    optimizer = optim_def.create(model)
    
    # NOTE: difference in indexing convention, so need to pad one row
    X_t = X[flood_data.cfg["seq_length"]-1:]
    Y_lstm = np.array(ea_data[basin].qsim)
    Y = np.array(ea_data[basin].qobs).reshape(-1, 1)
    
    Y_hat = map_grad((Y_lstm, X_t), Y, optimizer, state, residual, lambda x, y: jnp.square(x-y).mean(), batch_size=batch_size)
    
    mse = MSE(Y, Y_hat)
    nse = NSE(Y, Y_hat)
    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0],
        "avg_mse": np.mean(np.array(mses)),
        "avg_nse": np.mean(np.array(nses))
    }
    mses.append(mse)
    nses.append(nse)
    print(basin, mse, nse, np.mean(np.array(mses)), np.mean(np.array(nses)))
    

HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))

3
3
Model(module=<class 'flax.nn.base.Module.partial.<locals>.PartialModule'>, params={'0': {'0': {}, '1': {}}, '1': {'0': {}, '1': {'linear': {'bias': Traced<ShapedArray(float32[3,1]):JaxprTrace(level=0/0)>, 'kernel': Traced<ShapedArray(float32[3,270,32,1]):JaxprTrace(level=0/0)>}}}})


TypeError: sub got incompatible shapes for broadcasting: (3652, 1), (3651, 3).

In [648]:
jax.vmap??

In [216]:
a.reshape((0,) * 2).shape

(0, 0)

In [203]:
trim_size

3652

In [210]:
Y_lstm[:trim_size].reshape((-1, 2) + Y_lstm.shape[1:]).shape

(1826, 2)

In [232]:
x, y = (2, 4), 3

In [233]:
def d(x):
    return x, x + 1

In [234]:
jax.tree_util.tree_map(d, (x, y))

(((2, 3), (4, 5)), (3, 4))

In [185]:
(a, c), b = jax.tree_util.tree_map(d, (x, y))

In [186]:
a

(4, 5)

In [183]:
b

(2, 3)

In [197]:
jax.tree_util.tree_transpose(out_def, in_def, jax.tree_util.tree_map(d, (x, y)))

(((4, 3), 2), ((5, 4), 3))

In [194]:
out_def = jax.tree_util.tree_structure((x, y))

In [195]:
in_def = jax.tree_util.tree_structure((1, 2))

In [196]:
in_def

PyTreeDef(tuple, [*,*])

In [237]:
jax.tree_util.tree_flatten((x, y))[0][0]

2

In [241]:
np.concatenate((np.ones(4), np.array([[1]])))

TypeError: Cannot concatenate arrays with different ranks, got 1, 2.

In [27]:
nn.DenseGeneral.partial(features=1,
                        axis=(0),
                        batch_dims=(),
                        bias=True,
                        dtype=jnp.float32,
                        kernel_init=nn.initializers.zeros,
                        bias_init=nn.initializers.zeros,
                        precision=None,
                        name="linear"
                       )

flax.nn.base.Module.partial.<locals>.PartialModule