In [None]:
#| default_exp stores
from functools import partial
from itertools import zip_longest
from operator import itemgetter
from typing import Callable, NamedTuple, Tuple, Any, Optional, Union, TypeVar, Sequence, Mapping, List, Tuple, Dict, Hashable, Iterable, Type, cast, overload

import haiku as hk
import jax
import jax.numpy as jnp
import lovely_jax as lj
import lovely_tensors as lt
import numpy as np
import optax
import torchvision
import torchvision.transforms as transforms
from lovely_numpy import lo as ln
# from ./stores import Writable
from torch.utils.data import DataLoader, default_collate
import torch

# Patch torch and JAX array reprs so values display as compact summaries
# (shape, value range, mean/std, device) instead of their full contents.
lt.monkey_patch()
lj.monkey_patch()
# Display which backend JAX selected (the recorded output of this cell is 'gpu').
jax.default_backend()

'gpu'

In [None]:
def get_dls(train_ds, valid_ds, bs, **kwargs):
    """Build a (train, valid) pair of DataLoaders.

    The training loader shuffles and uses batch size `bs`; the validation
    loader keeps dataset order and uses a batch size of `2 * bs`.
    Any extra keyword arguments are forwarded unchanged to both loaders.
    """
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs)
    valid_loader = DataLoader(valid_ds, batch_size=2 * bs, **kwargs)
    return train_loader, valid_loader

def collate_dict(ds):
    """Return a `collate_fn` that batches dict samples and selects feature columns.

    The column names are taken from `ds.features` once, when this factory is
    called.  The returned function runs `default_collate` on a list of dict
    samples and picks those columns in order, yielding a tuple of batched
    columns (or the bare column when there is exactly one feature, which is
    how `operator.itemgetter` behaves).
    """
    pick_columns = itemgetter(*ds.features)

    def _collate(samples):
        batched = default_collate(samples)
        return pick_columns(batched)

    return _collate

class DataLoaders:
    """A training and a validation loader, exposed as `.train` and `.valid`."""

    def __init__(self, *dls):
        # Only the first two loaders are kept; any further ones are ignored.
        if len(dls) < 2:
            raise ValueError(
                f"DataLoaders needs at least 2 loaders (train, valid), got {len(dls)}")
        self.train, self.valid = dls[:2]

    @classmethod
    def from_dd(cls, dd, batch_size, as_tuple=True, **kwargs):
        """Build train/valid loaders from a dict-like mapping of split name -> dataset.

        The first two splits (in iteration order) become train and valid.
        Previously every split was passed positionally to `get_dls`, so a third
        split (e.g. a test set) collided with the `bs` argument and raised.
        Batches are collated with `collate_dict(dd['train'])`, so a 'train' split
        exposing `.features` is required.

        Args:
            dd: mapping of split name -> dataset.
            batch_size: training batch size (`get_dls` doubles it for valid).
            as_tuple: accepted for API compatibility; currently unused.
            **kwargs: forwarded to both DataLoaders.
        """
        splits = list(dd.values())
        if len(splits) < 2:
            raise ValueError(f"from_dd needs at least 2 splits, got {len(splits)}")
        train_ds, valid_ds = splits[:2]
        collate = collate_dict(dd['train'])
        return cls(*get_dls(train_ds, valid_ds, bs=batch_size, collate_fn=collate, **kwargs))

## Types

In [None]:
# Initializer = Callable[[Sequence[int], Any], jax.Array]
# Params = Mapping[str, Mapping[str, jax.Array]]
# State = Mapping[str, Mapping[str, jax.Array]]

# # Missing JAX types.
# PRNGKey = jnp.ndarray  # pylint: disable=invalid-name

In [None]:
# JAX key constructor under a short name.
# NOTE(review): this alias is a *function*, not a type, yet it is used as a
# parameter annotation elsewhere; the keys it returns are jax.Array values.
PRNGKey = jax.random.PRNGKey

# Array-like values flowing through models and losses.
# TODO: should this also include np.ndarray / torch.Tensor?
Tensor = Union[jax.Array, jnp.ndarray]

# Recursive pytree: tensors, nested tuples/lists/dicts of pytrees, the Haiku
# params/state containers, optax optimizer states, or None.
# (Intended to cover both Haiku and Flax trees.)
PyTree = Union[
    Tensor,
    Tuple['PyTree', ...],
    List['PyTree'],
    Dict[Hashable, 'PyTree'],
    hk.Params,
    hk.State,
    optax.OptState,
    None,
]

# Two tensors in, one tensor out (presumably (predictions, targets) -> loss).
LossFn = Callable[[Tensor, Tensor], Tensor]
# Any call signature, returning (outputs, updated state).
ApplyFn = Callable[..., Tuple[Tensor, PyTree]]


class Optimizer(NamedTuple):
    """An optax optimizer state paired with its gradient transformation (e.g. Adam).

    NOTE(review): a later cell redefines `Optimizer` with fields (opt, state);
    this (state, gradTransformer) version is what `Learner` below annotates.
    """

    state: optax.OptState
    gradTransformer: optax.GradientTransformation

class Learner(NamedTuple):
    """Bundle of everything a training loop needs: data, model, optimizer, loss, state."""

    data: DataLoaders
    # Quoted forward reference: `Model` is only defined in a later cell, and
    # NamedTuple evaluates field annotations when the class is created, so a
    # bare `Model` raises NameError on a fresh kernel (Restart & Run All).
    model: 'Model'
    optimizer: Optimizer
    loss_fn: LossFn
    state: PyTree

class Batch(NamedTuple):
    """One mini-batch of model inputs and their labels."""

    # Leading axis is the batch; the flattened FashionMNIST pipeline in this
    # notebook produces [B, 784] (the values are JAX arrays in practice).
    input: np.ndarray
    # [B] integer class labels (0..9 in the observed batch).
    target: np.ndarray


#### Data

In [None]:
# Pixel standardisation constants (applied after scaling pixels to [0, 1]),
# loader batch size, and number of label classes.
XMEAN = 0.28
XSTD = 0.35
BATCH_SIZE = 500
NUM_CLASSES = 10

# PIL image -> uint8 tensor -> floats in [0, 1] -> standardised -> flat vector.
tfm = transforms.Compose(
    [
        transforms.PILToTensor(),
        transforms.Lambda(lambda img: img / 255),
        transforms.Normalize(XMEAN, XSTD),
        transforms.Lambda(torch.flatten),
    ]
)

# Dataset factory: FashionMNIST stored under ./data, downloaded if missing.
ds = partial(torchvision.datasets.FashionMNIST, root="data", download=True, transform=tfm)
train_ds = ds(train=True)
valid_ds = ds(train=False)

# NOTE(review): neither loader shuffles; `get_dls` (defined earlier) builds a
# shuffled training loader and may be preferable once a training loop exists.
tdl = DataLoader(train_ds, batch_size=BATCH_SIZE)
vdl = DataLoader(valid_ds, batch_size=BATCH_SIZE)
dls = DataLoaders(tdl, vdl)

# First training batch, converted from torch tensors to JAX arrays.
batch = Batch(*(jnp.array(t) for t in next(iter(dls.train))))
batch

Batch(input=Array[500, 784] n=392000 x∈[-0.800, 2.057] μ=0.011 σ=1.006 gpu:0, target=Array[500] i32 x∈[0, 9] μ=4.402 σ=2.838 gpu:0)

#### Model

In [None]:
def forward(x: jnp.ndarray, hidden_sizes: Sequence[int] = (50,)) -> jnp.ndarray:
    """MLP classifier producing NUM_CLASSES logits per example.

    Instantiates `hk.nets.MLP`, so it must run inside a Haiku transform (it is
    wrapped with `hk.transform_with_state` in the next cell).

    Args:
        x: input features; the data pipeline here yields [batch, 784].
        hidden_sizes: widths of the hidden layers (default: one layer of 50).

    Returns:
        Logits with last dimension NUM_CLASSES.
    """
    # The original annotation `jnp.array` is a function, not a type.
    return hk.nets.MLP(output_sizes=[*hidden_sizes, NUM_CLASSES])(x)

In [None]:
# Turn `forward` into a pure (init, apply) pair that also threads Haiku state.
network = hk.transform_with_state(forward)
# Seeded key sequence; each next(rng) yields a fresh PRNG key.
rng = hk.PRNGSequence(jax.random.PRNGKey(42))
# Initialise parameters and state by running `forward` on a real data batch.
params, state = jax.jit(network.init)(next(rng), batch.input)

In [None]:
# Exploration: only the last expression is displayed, so the shown output is
# the repr of the class `hk.TransformedWithState`; `type(network)` is discarded.
type(network)
hk.TransformedWithState

haiku._src.transform.TransformedWithState

In [None]:
class Model(NamedTuple):
    """Immutable bundle of a model's parameters, state and pure apply function."""

    # Class-level key sequence shared by every Model (no annotation, so it is
    # not a NamedTuple field); used when `from_haiku` gets no explicit key.
    rng = hk.PRNGSequence(jax.random.PRNGKey(42))

    params: PyTree  # model weights and biases
    state: PyTree  # buffers (aka context) of the model (e.g. batch norm running mean)
    apply: ApplyFn  # model pure inference function

    @classmethod
    def from_haiku(
        cls,
        transformed: hk.TransformedWithState,  # transformed haiku model
        x: Tensor,  # example input (e.g. batch.input)
        key: Optional[jax.Array] = None,  # init key; drawn from `cls.rng` when None
    ) -> 'Model':
        """Initialise a Haiku `transform_with_state` model on example input `x`.

        Returns an instance of `cls` (so subclasses build themselves) holding
        the initialised params/state and the transform's `apply` function.
        """
        init, apply = transformed
        if key is None:
            key = next(cls.rng)
        params, state = jax.jit(init)(key, x)
        return cls(params=params, state=state, apply=apply)

In [None]:
# Wrap the transformed network, initialised on the sample batch, in a Model.
model = Model.from_haiku(transformed=network, x=batch.input)

/**
 * Svelte-style custom store: a counter starting at 0 that exposes
 * `subscribe` plus increment / decrement / reset helpers instead of the
 * raw `set` / `update` functions.
 *
 * NOTE(review): `writable` is neither defined nor imported in this notebook
 * (presumably `writable` from 'svelte/store' — confirm). This is JavaScript
 * inside a Python notebook; if it sits in a Python code cell it is a SyntaxError.
 */
function createCount() {
	const { subscribe, set, update } = writable(0);

	const increment = () => update((n) => n + 1);
	const decrement = () => update((n) => n - 1);
	const reset = () => set(0);

	return { subscribe, increment, decrement, reset };
}

/** Shared counter store instance. */
export const count = createCount();

In [None]:
@partial(jax.jit, static_argnums=0)
def _evaluate_jit(apply: ApplyFn, params: PyTree, state: PyTree, key: jax.Array,
                  batch: Batch) -> Tuple[Tensor, PyTree]:
    """Jitted accuracy kernel.

    `apply` is a static argument, so compiled code is cached per apply
    function instead of being rebuilt on every `evaluate` call.
    """
    logits, new_state = apply(params, state, key, batch.input)
    preds = jnp.argmax(logits, axis=-1)
    return jnp.mean(preds == batch.target), new_state


def evaluate(model: Model, batch: Batch) -> Tuple[Tensor, PyTree]:
    """Classification accuracy of `model` on `batch`.

    Returns:
        `(accuracy, new_state)`. `accuracy` is the fraction of examples whose
        arg-max logit equals `batch.target` (a JAX scalar). `new_state` is the
        state returned by `model.apply`.

    A fresh key is drawn from `Model.rng` on each call. The previous version
    computed the result but returned None, and it re-jitted a new inner
    closure (with a nested `jax.jit(apply)`) on every call.
    """
    params, state, apply = model
    return _evaluate_jit(apply, params, state, next(Model.rng), batch)

evaluate(model, batch)

(Array gpu:0 0.146, {})

In [None]:
@partial(jax.jit, static_argnums=0)
def _loss_jit(apply: ApplyFn, params: hk.Params, state: hk.State, key: PRNGKey,
              batch: Batch) -> Tuple[jnp.ndarray, hk.State]:
    """Jitted mean softmax cross-entropy of `apply`'s logits vs integer labels.

    `apply` is static, so the compiled kernel is cached per apply function
    rather than re-created on every `loss_fn` call. It returns a JAX scalar,
    which, unlike `loss_fn`'s float, can be differentiated.
    """
    logits, new_state = apply(params, state, key, batch.input)
    losses = optax.softmax_cross_entropy_with_integer_labels(logits, batch.target)
    return jnp.mean(losses), new_state


def loss_fn(apply: ApplyFn, params:hk.Params, state: hk.State, key:PRNGKey, batch: Batch) -> Tuple[float, hk.State]:
    """Mean cross-entropy loss of `apply` on `batch`, for reporting.

    Returns:
        `(loss, new_state)`. `loss` is a Python float; the `float()` conversion
        makes this function non-differentiable. `new_state` is the Haiku state
        returned by `apply`.
    """
    loss_value, new_state = _loss_jit(apply, params, state, key, batch)
    return float(loss_value), new_state

# The original cell called an undefined name `loss` here.
e, state = loss_fn(network.apply, params, state, next(rng), batch)
e, state, type(e)

(2.541998863220215, {}, float)

In [None]:
class Optimizer(NamedTuple):
    """An optax gradient transformation (e.g. Adam) bundled with its current state.

    Redefines the earlier `Optimizer`; the fields here are ordered (opt, state).
    """

    opt: optax.GradientTransformation
    state: optax.OptState

In [None]:
# Adam step size, kept as a named constant instead of a repeated magic number.
LEARNING_RATE = 1e-3

# Build the transformation once and initialise its state from the model params.
# Keyword arguments make the field mapping explicit, because `Optimizer` is
# defined twice in this notebook with different field orders.
adam = optax.adam(LEARNING_RATE)
optimizer = Optimizer(opt=adam, state=adam.init(params))

In [None]:
# Split the optimizer into its gradient transformation and optimizer state.
# Positional unpacking relies on the (opt, state) field order of the
# redefined Optimizer.
(opt,opt_state) =  optimizer
opt

GradientTransformation(init=<function chain.<locals>.init_fn>, update=<function chain.<locals>.update_fn>)

In [None]:

def _grad_loss(params: hk.Params, apply: ApplyFn, state: hk.State, key: jax.Array,
               batch: Batch) -> Tuple[jnp.ndarray, hk.State]:
    """Differentiable mean softmax cross-entropy.

    `params` comes first so `jax.value_and_grad` differentiates with respect to
    it (default `argnums=0`). Returns `(loss, new_state)` for `has_aux=True`.
    """
    logits, new_state = apply(params, state, key, batch.input)
    losses = optax.softmax_cross_entropy_with_integer_labels(logits, batch.target)
    return jnp.mean(losses), new_state


@partial(jax.jit, static_argnums=(0, 1))
def _update_step(apply: ApplyFn, tx_update: Callable, params: hk.Params, state: hk.State,
                 opt_state: optax.OptState, key: jax.Array, batch: Batch):
    """One jitted optimisation step; `apply` and `tx_update` are static functions."""
    (_, new_state), grads = jax.value_and_grad(_grad_loss, has_aux=True)(
        params, apply, state, key, batch)
    updates, new_opt_state = tx_update(grads, opt_state, params)
    return optax.apply_updates(params, updates), new_state, new_opt_state


def update(opt: Optimizer, batch: Batch) -> Tuple[hk.Params, optax.OptState]:
    """Take one gradient step on `batch`.

    Args:
        opt: `Optimizer(opt=<optax transformation>, state=<its current state>)`.
        batch: mini-batch with `.input` and integer `.target`.

    Returns:
        `(new_params, new_opt_state)`.

    Reads the notebook globals `network`, `params`, `state` and `rng`, as the
    original cell did. It advances `rng` but rebinds none of them; the caller
    stores the result. The updated Haiku state is discarded to keep the
    2-tuple return.

    Gradients are taken with respect to `params`. The previous version
    differentiated `loss_fn`'s first argument (the apply function) and its
    `float()` output, which raised. It also never ran its inner step.
    """
    new_params, _new_state, new_opt_state = _update_step(
        network.apply, opt.opt.update, params, state, opt.state, next(rng), batch)
    return new_params, new_opt_state

In [None]:
# One optimisation step on the sample batch through the module-level `update`,
# which returns a (params, opt_state) pair. (`_update` is not a global name.)
update(optimizer, batch)

TypeError: Argument '<function transform_with_state.<locals>.apply_fn at 0x7f1f989e1e50>' of type <class 'function'> is not a valid JAX type.

In [None]:
# Generic type variables for helper signatures.
K = TypeVar("K")
V = TypeVar("V")
T = TypeVar("T")
U = TypeVar("U")

# Class of JAX tree-structure objects, recovered from the structure of a
# trivial pytree so it can be used in annotations and isinstance checks.
PyTreeDef = type(jax.tree_util.tree_structure(()))

In [None]:
# Flatten the parameter pytree into its list of leaf arrays plus the treedef
# needed to rebuild it, then display the first leaf.
pFlat, pDef = jax.tree_util.tree_flatten(params)
a: jax.Array = pFlat[0]
a

In [None]:
# NOTE(review): this rebinds the global `rng` (restarting the key sequence used
# by later cells). It also initialises on a dummy [BATCH_SIZE, 1] input, i.e.
# 1 input feature instead of the 784 in `batch.input`; the result is only displayed.
rng = hk.PRNGSequence(jax.random.PRNGKey(42))
network.init(next(rng), jnp.ones([BATCH_SIZE, 1]))

In [None]:
# Exploration: display the type of the transformed network.
type(network)

In [None]:
# Re-initialise parameters and Haiku state (named `buffers` here) on the real
# sample batch, using a fresh key from `rng`. Initialising on a dummy
# [BATCH_SIZE, 1] input built first-layer weights for a single input feature.
# That silently replaced the global `params` with weights that no longer fit
# `batch.input` ([BATCH_SIZE, 784]), breaking `update`. `BATCH_SIZE` is
# already set in the data cell.
params, buffers = network.init(next(rng), batch.input)

In [None]:
# Exploration: display the type of the state returned by `network.init`.
type(buffers)

In [None]:
class Model(NamedTuple):
    apply: Callable # model inference function
    params: hk.Params