In [63]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import json
import numpy as np
import jax
import jax.numpy as jnp
import flax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib
import timecast as tc

from mpl_toolkits import mplot3d


plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
"""timecast top-level API"""
from functools import partial
from typing import Callable
from typing import Tuple
from typing import Union

import flax
import jax
import jax.numpy as jnp
import numpy as np


def _objective(x, y, loss_fn, model):
    """Default objective function"""
    y_hat = model(x)
    return loss_fn(y, y_hat), y_hat


def tmap(
    X: Union[np.ndarray, Tuple[np.ndarray, ...]],
    Y: Union[np.ndarray, Tuple[np.ndarray, ...]],
    optimizer: flax.optim.base.Optimizer,
    loss_fn: Callable[[np.ndarray, np.ndarray], np.ndarray] = lambda true, pred: jnp.square(
        true - pred
    ).mean(),
    state: flax.nn.base.Collection = None,
    objective: Callable[
        [
            np.ndarray,
            np.ndarray,
            Callable[[np.ndarray, np.ndarray], np.ndarray],
            flax.nn.base.Model,
        ],
        Tuple[np.ndarray, np.ndarray],
    ] = None,
    batch_size: int = 1
):
    """Take gradients steps performantly on one data item at a time
    Args:
        X: np.ndarray or tuple of np.ndarray of inputs
        Y: np.ndarray or tuple of np.ndarray of outputs
        optimizer: initialized optimizer
        loss_fn: loss function to compose where first arg is true value and
        second is pred
        state: state required by flax
        objective: function composing loss functions
    Returns:
        np.ndarray: result
    """
    state = state or flax.nn.Collection()
    objective = objective or _objective

    def _tmap(optstate, xy):
        """Helper function"""
        x, y = xy
        optimizer, state = optstate
        func = partial(objective, x, y, loss_fn)
        with flax.nn.stateful(state) as state:
            (loss, y_hat), grad = jax.value_and_grad(func, has_aux=True)(optimizer.target)
        return (optimizer.apply_gradient(grad), state), y_hat

    (optimizer, state), pred = jax.lax.scan(_tmap, (optimizer, state), (X, Y))
    return pred, optimizer, state

In [4]:
from timecast.learners import AR

In [112]:
model_def = flax.nn.DenseGeneral.partial(
    features=1,
    axis=(1, 2),
    batch_dims=0,
    kernel_init=flax.nn.initializers.kaiming_normal()
)

In [113]:
_, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(2, 1, 57)])

In [114]:
print(params["bias"].shape, params["kernel"].shape)

(1,) (1, 57, 1)


In [115]:
model = flax.nn.Model(model_def, params)

In [116]:
X = np.random.rand(6, 1, 57)

In [117]:
model(X)

DeviceArray([[1.4990948 ],
             [0.75036776],
             [1.395464  ],
             [0.9861583 ],
             [0.9448118 ],
             [1.0112364 ]], dtype=float32)

In [22]:
np.tensordot(params["kernel"], X, axes=[(0, 1), (0, 1)])

array([0.9262869])

In [72]:
model_def = flax.nn.Dense.partial(
    features=1,
    kernel_init=flax.nn.initializers.kaiming_normal()
)

In [73]:
_, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(57)])

In [74]:
print(params["bias"].shape, params["kernel"].shape)

(1,) (57, 1)


In [69]:
model = flax.nn.Model(model_def, params)

In [70]:
X = onp.random.rand(3, 57)

In [71]:
model(X)

DeviceArray([[1.7362856],
             [0.7493028],
             [1.4401443]], dtype=float32)