In [None]:
# Sets the fraction of GPU memory JAX preallocates (0.5 = half of the device).
# NOTE(review): per JAX's GPU memory-allocation docs this must be set before JAX first touches the GPU — presumably why it is the first cell.
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.5

In [None]:
from typing import Iterator, NamedTuple, Callable

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import lovely_jax as lj
lj.monkey_patch()

### Loading Data

In [None]:
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
from pathlib import Path

from datasets import load_dataset
from torch.utils.data import DataLoader,default_collate
import torchvision.transforms.functional as TF
from operator import itemgetter
import fastcore.all as fc

In [None]:
# Column names of the HF dataset; `x` is also read as a global by `transformi`
x,y = 'image','label'
name = "fashion_mnist"
dsd = load_dataset(name)
dsd

In [None]:
def collate_dict(ds):
    """Return a collate_fn for dict-row datasets.

    The returned function default-collates a list of dict rows and then pulls the
    collated columns out in the order of ``ds.features`` (a tuple when there are
    several features, a single value when there is one).
    """
    pick_columns = itemgetter(*ds.features)

    def _f(rows):
        collated = default_collate(rows)
        return pick_columns(collated)

    return _f

In [None]:
class DataLoaders:
    """Pair of DataLoaders: the first two passed in become ``train`` and ``valid``."""
    def __init__(self, *dls): self.train,self.valid = dls[:2]

    @classmethod
    def from_dd(cls, dd, batch_size, as_tuple=True, **kwargs):
        """Build one DataLoader per split of the dict-like ``dd`` (in ``dd.values()`` order).

        Args:
            dd: mapping of split name -> dataset (e.g. a HF ``DatasetDict``).
            batch_size: batch size for every loader.
            as_tuple: if True (default), batches are tuples of columns ordered like
                ``ds.features`` (via ``collate_dict``); if False, batches are the
                plain dicts produced by ``default_collate``.
            **kwargs: forwarded to every ``DataLoader`` (e.g. ``num_workers``).
        """
        def _collate(ds): return collate_dict(ds) if as_tuple else default_collate
        return cls(*[DataLoader(ds, batch_size, collate_fn=_collate(ds), **kwargs) for ds in dd.values()])

In [None]:
def inplace(f):
    """Decorator: call ``f(b)`` for its side effects on ``b`` and return ``b`` itself.

    ``functools.wraps`` keeps the wrapped function's ``__name__``/``__doc__`` so the
    decorated transform stays identifiable when inspected.
    """
    import functools

    @functools.wraps(f)
    def _f(b):
        f(b)
        return b
    return _f

In [None]:
@inplace
def transformi(b):
    """Replace each image in batch dict ``b[x]`` with a flattened tensor (in place)."""
    images = b[x]
    b[x] = list(map(lambda img: torch.flatten(TF.to_tensor(img)), images))

In [None]:
# Apply `transformi` lazily and build one DataLoader per split of the DatasetDict (first two become train/valid)
tds = dsd.with_transform(transformi)
bs = 1024*8
dls = DataLoaders.from_dd(tds, bs, num_workers=6)
dt = dls.train
xb,yb = next(iter(dt))
xb.shape,yb[:10]

In [None]:
# Grab one training batch; next(iter(...)) avoids materialising the whole epoch as list(enumerate(dt)) did
batch = next(iter(dt))
batch

In [None]:
# One mini-batch. `transformi` flattens every image, so `image` is presumably [B, H*W]
# (not [B, H, W, 1]) — TODO confirm; `label` is [B] integer class ids.
Batch = NamedTuple("Batch", [("image", np.ndarray), ("label", np.ndarray)])

In [None]:
# Convert the torch batch to JAX arrays and wrap it in a Batch
xb,yb = map(jnp.array, batch)
batch = Batch(xb,yb)
batch

### Model

In [None]:
def forward(x: jnp.ndarray) -> jnp.ndarray:
    """MLP with layer output sizes [50, 10]; returns 10 logits per example.

    Annotated with ``jnp.ndarray`` (a type); ``jnp.array`` is a constructor function.
    """
    return hk.nets.MLP(output_sizes=[50,10])(x)
# without_apply_rng: apply(params, x) takes no rng argument
model = hk.without_apply_rng(hk.transform(forward))

In [None]:
# Initialise model parameters from a fixed seed, using the example batch `xb` for input shapes
key = jax.random.PRNGKey(42)
initial_params = model.init(key, xb)

In [None]:
# Forward pass with the untrained params as a sanity check
logits = model.apply(initial_params, batch.image)
logits

In [None]:
@fc.typedispatch
@jax.jit
def evaluate(params:hk.Params, batch:Batch) -> jnp.ndarray:
    """Accuracy of the model with ``params`` on ``batch``: share of arg-max predictions equal to the label."""
    predicted = model.apply(params, batch.image).argmax(axis=-1)
    hits = predicted == batch.label
    return hits.mean()

evaluate(initial_params, batch)

### Loss

In [None]:
# Per-example cross-entropy on the example batch; the sum is the total batch loss
l = optax.softmax_cross_entropy_with_integer_labels(model.apply(initial_params, batch.image), batch.label)
jnp.sum(l)

In [None]:
# Loss
def loss(params:hk.Params, batch: Batch)-> jnp.ndarray:
    """Integer-label softmax cross-entropy averaged over the batch (a scalar)."""
    logits = model.apply(params, batch.image)
    n_examples = batch.image.shape[0]
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits, batch.label)
    return jnp.sum(per_example / n_examples)

loss(initial_params, batch), batch

### Learning

In [None]:
class TrainingState(NamedTuple):
  # Model parameters and optimiser state, returned fresh by each `update` call
  params: hk.Params
  opt_state: optax.OptState
  
# Optimiser
lr = 1e-3
opt = optax.adam(lr)  # Adam with a fixed learning rate
initial_opt_state = opt.init(initial_params)
state = TrainingState(initial_params, initial_opt_state)
opt


In [None]:
@jax.jit
def update(state: TrainingState, batch: Batch) -> TrainingState:
    """One optimiser step: differentiate ``loss`` w.r.t. the params and apply ``opt``'s update."""
    params, opt_state = state
    grads = jax.grad(loss)(params, batch)
    updates, new_opt_state = opt.update(grads, opt_state)
    return TrainingState(optax.apply_updates(params, updates), new_opt_state)

In [None]:
state.params, batch  # inspect current params and the example batch

In [None]:
# One jitted optimiser step; `state` is left unchanged (a new TrainingState is returned)
s = update(state, batch)
s


In [None]:
# Per-batch accumulators: accuracy, loss and batch size recorded for each batch.
TrainingStats = NamedTuple("TrainingStats", [("accuracy", list), ("losses", list), ("ns", list)])

stats = TrainingStats([], [], [])

In [None]:
@jax.jit
def update(state: TrainingState, batch: Batch) -> tuple[TrainingState, jnp.ndarray]:
    """One optimiser step that also reports the loss.

    Returns:
        ``(new_state, loss_value)``; ``loss_value`` is ``loss`` at the pre-update params.
    """
    loss_value, grad = jax.value_and_grad(loss)(state.params, batch)
    updates, opt_state = opt.update(grad, state.opt_state)
    params = optax.apply_updates(state.params, updates)
    return TrainingState(params, opt_state), loss_value

In [None]:
def _accuracy(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    """Share of rows whose arg-max logit equals the label."""
    return jnp.mean(jnp.argmax(logits, axis=-1) == labels)

@fc.typedispatch
@jax.jit
def evaluate(logits: jnp.ndarray, batch:Batch) -> jnp.ndarray:
    """Accuracy of precomputed ``logits`` against ``batch.label``."""
    return _accuracy(logits, batch.label)

@fc.typedispatch
@jax.jit
def evaluate(params:hk.Params, batch:Batch) -> jnp.ndarray:
    """Accuracy of the model with ``params`` on ``batch``."""
    return _accuracy(model.apply(params, batch.image), batch.label)

In [None]:
# First validation batch, converted from torch tensors to JAX arrays
xb, yb = next(iter(dls.valid))
b = Batch(jnp.array(xb), jnp.array(yb))
b

In [None]:
evaluate(state.params, b)  # accuracy on that single validation batch

In [None]:
# Train from the initial params, then report accuracy over the WHOLE validation set
# (each batch's mean accuracy weighted by its size), not just its first batch.
state = TrainingState(initial_params, initial_opt_state)
n_epochs = 2
for epoch in range(n_epochs):
    for xb, yb in dls.train:
        state, _ = update(state, Batch(jnp.array(xb), jnp.array(yb)))
    n_correct, n_seen = 0.0, 0
    for xb, yb in dls.valid:
        n = len(yb)
        n_correct += float(evaluate(state.params, Batch(jnp.array(xb), jnp.array(yb)))) * n
        n_seen += n
    accuracy = n_correct / n_seen if n_seen else float("nan")
    print({"epoch": epoch, "accuracy": f"{accuracy:.3f}"})

### Learner

In [None]:
subkey = jax.random.PRNGKey(42)
xb, yb = next(iter(dls.train))
# The DataLoader yields torch tensors; convert before handing them to Haiku/JAX
initial_params = model.init(subkey, jnp.array(xb))

In [None]:
state = TrainingState(initial_params, initial_opt_state)
stats = TrainingStats([],[],[])
class Learner:
    """Bundle a Haiku model with its data, loss and optax optimiser.

    ``opt_func`` may be an optax factory taking the learning rate (e.g. ``optax.sgd``)
    or an already-built optax transformation (e.g. ``opt``, in which case ``lr`` is not
    used to build it). The built optimiser is stored as ``self.opt``. When ``state`` /
    ``stats`` are None, fresh ones are created and stored on the instance.
    """
    # Class-level default key: every new instance splits from this same key, so
    # initial params are reproducible across instances.
    key = jax.random.PRNGKey(42)
    def __init__(self, 
            model:hk.Transformed, 
            dls:DataLoaders, 
            loss_func:Callable[..., jnp.ndarray], 
            lr: float, 
            opt_func:Callable=optax.sgd, 
            state=None, 
            stats=None ): 
        fc.store_attr()
        # Duck-type: a built optax transformation has init/update; a factory like optax.sgd does not.
        if hasattr(opt_func, "init") and hasattr(opt_func, "update"):
            self.opt = opt_func
        else:
            self.opt = opt_func(lr)
        if state is None:
            self.key, subkey = jax.random.split(self.key)
            xb, _ = next(iter(dls.train))
            # DataLoader batches are torch tensors; convert for Haiku
            params = model.init(subkey, jnp.array(xb))
            self.state = TrainingState(params, self.opt.init(params))
        if stats is None:
            self.stats = TrainingStats([],[],[])

# learn = Learner(model, dls, loss, 1e-2, opt, state, stats)
learn = Learner(model, dls, loss, 1e-2, opt, None, None)

In [None]:
# Loop flags kept as module globals: TrainingState only has (params, opt_state).
is_training = False
epoch = 0

def one_batch(batch):
    """Run one mini-batch, then record its stats.

    When ``is_training`` is set, take an optimiser step (updating the global ``state``);
    otherwise only compute the loss. ``batch`` is an ``(images, labels)`` pair as yielded
    by the DataLoaders (torch tensors) or a ``Batch``; it is converted to a ``Batch`` of
    JAX arrays here.
    """
    global state  # reassigned below; without this, reading `state` raised UnboundLocalError
    xb, yb = batch
    batch = Batch(jnp.array(xb), jnp.array(yb))
    if is_training:
        state, batch_loss = update(state, batch)
    else:
        batch_loss = loss(state.params, batch)
    with jax.default_device(jax.devices("cpu")[0]): calc_stats(batch, batch_loss)

def calc_stats(batch, loss):
    """Append size-weighted accuracy and loss, plus the batch size, to the global ``stats``."""
    logits = model.apply(state.params, batch.image)
    acc = jnp.mean(jnp.argmax(logits, axis=-1) == batch.label)
    n = len(batch.label)
    # Weight by n so sum(...) / sum(ns) in one_epoch is a per-example mean
    stats.accuracy.append(acc * n)
    stats.losses.append(loss * n)
    stats.ns.append(n)

def one_epoch():
    """One pass over dls.train (if ``is_training``) or dls.valid; prints epoch, mode, mean loss, mean accuracy."""
    for column in stats: column.clear()  # report this pass only, not a running total
    dl = dls.train if is_training else dls.valid
    for batch in dl:
        one_batch(batch)
    n = sum(stats.ns)
    if n == 0:
        print(epoch, is_training, "no batches")
        return
    print(epoch, is_training, float(sum(stats.losses)/n), float(sum(stats.accuracy)/n))

def fit(n_epochs):
    """Alternate a training pass and a validation pass for ``n_epochs`` epochs."""
    global epoch, is_training
    for epoch in range(n_epochs):
        is_training = True
        one_epoch()
        is_training = False
        one_epoch()

In [None]:
# TrainingState has exactly two fields (params, opt_state)
state = TrainingState(initial_params, initial_opt_state)
stats = TrainingStats([],[],[])
fit(2)

In [None]:
class Learner:
    """Minimal PyTorch training loop.

    Args:
        model: a ``torch.nn.Module``.
        dls: object with ``train`` and ``valid`` iterables of ``(x, y)`` batches.
        loss_func: ``loss_func(preds, y)`` returning a scalar tensor.
        lr: learning rate passed to ``opt_func``.
        opt_func: optimiser factory called as ``opt_func(params, lr)``.
        device: where model and batches live; defaults to CUDA when available, else CPU.
    """
    def __init__(self, model, dls, loss_func, lr, opt_func=torch.optim.SGD, device=None):
        self.model, self.dls, self.loss_func, self.lr, self.opt_func = model, dls, loss_func, lr, opt_func
        self.device = device if device is not None else ('cuda' if torch.cuda.is_available() else 'cpu')

    def one_batch(self):
        """Forward pass and loss; in training mode also one optimiser step. Then record stats."""
        self.xb,self.yb = (t.to(self.device) for t in self.batch)
        self.preds = self.model(self.xb)
        self.loss = self.loss_func(self.preds, self.yb)
        if self.model.training:
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()
        with torch.no_grad(): self.calc_stats()

    def calc_stats(self):
        """Record correct-prediction count, size-weighted loss and batch size."""
        acc = (self.preds.argmax(dim=1)==self.yb).float().sum()
        self.accs.append(acc)
        n = len(self.xb)
        self.losses.append(self.loss*n)
        self.ns.append(n)

    def one_epoch(self, train):
        """One pass over train (``train=True``) or valid data; prints epoch, mode, mean loss, accuracy."""
        # Module.train() also switches submodules (dropout, batchnorm); assigning .training does not
        self.model.train(train)
        self.accs,self.losses,self.ns = [],[],[]  # metrics for this pass only
        dl = self.dls.train if train else self.dls.valid
        for self.num,self.batch in enumerate(dl): self.one_batch()
        n = sum(self.ns)
        if n == 0: return
        print(self.epoch, self.model.training, sum(self.losses).item()/n, sum(self.accs).item()/n)
    
    def fit(self, n_epochs):
        """Train then validate for ``n_epochs`` epochs."""
        self.accs,self.losses,self.ns = [],[],[]
        self.model.to(self.device)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.n_epochs = n_epochs
        for self.epoch in range(n_epochs):
            self.one_epoch(True)
            with torch.no_grad(): self.one_epoch(False)

In [None]:
# Manual SGD: jitted value_and_grad of a mean cross-entropy, then w <- w - lr * dw on every param leaf.
dl = dls.train
n_epochs = 2

def _sgd_loss(p, xb, yb):
    """Mean integer-label cross-entropy (a scalar, as jax.grad requires)."""
    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(model.apply(p, xb), yb))

_sgd_value_and_grad = jax.jit(jax.value_and_grad(_sgd_loss))
params = initial_params
batch_loss = float("nan")
for epoch in range(n_epochs):  # range(): an int is not iterable
    for batch in dl:
        xb, yb = map(jnp.array, batch)
        # named batch_loss so the `loss` function is not shadowed
        batch_loss, param_grads = _sgd_value_and_grad(params, xb, yb)
        params = jax.tree_util.tree_map(lambda w, dw: w - lr * dw, params, param_grads)
    print({"epoch": epoch, "last_batch_loss": f"{float(batch_loss):.3f}"})

In [None]:
# fastcore typedispatch demo: each decorated `fn` registers an overload for its annotated argument type
@fc.typedispatch
def fn(x:int): return x+1

@fc.typedispatch
def fn(x:float): return x+2

fn(1), fn(1.0)  # presumably (2, 3.0) if dispatch follows the annotations


In [None]:
# Registering a new `int` overload; the unpacking below relies on it replacing the earlier x+1 version
@fc.typedispatch
def fn(x:int): return 1, x+5

g,h = fn(1)
g,h

In [None]:
i = fn(1)  # the whole returned tuple, without unpacking
i

In [None]:

# Model
def forward(x: jnp.ndarray) -> jnp.ndarray:
    """MLP with layer output sizes [50, 10]; returns 10 logits per example."""
    return hk.nets.MLP(output_sizes=[50,10])(x)
model = hk.without_apply_rng(hk.transform(forward))

# Optimiser
lr = 1e-3
opt = optax.adam(lr)

# Loss
def loss(params:hk.Params, batch)-> jnp.ndarray:
    """Mean integer-label cross-entropy over ``batch = (xb, yb)``.

    Returns a scalar so it can be differentiated with ``jax.grad``.
    """
    xb,yb = batch
    logits = model.apply(params, xb)
    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, yb))

@jax.jit
def evaluate(params, batch) -> jnp.ndarray:
    """Accuracy: share of examples in ``batch = (xb, yb)`` whose arg-max logit equals the label."""
    xb, yb = batch
    preds = jnp.argmax(model.apply(params, xb), axis=-1)
    return jnp.mean(preds == yb)

In [None]:
key = jax.random.PRNGKey(42)
# jnp.array also converts a torch batch coming straight from the DataLoader
params = model.init(key, jnp.array(xb))
params

In [None]:
# `model` is wrapped in hk.without_apply_rng, so apply takes (params, inputs) — no rng/None argument
preds = model.apply(params, jnp.array(xb))
preds

### Loss

In [None]:
import optax

In [None]:
yb # integer class labels, not one-hot encodings

In [None]:
loss_func = optax.softmax_cross_entropy_with_integer_labels
# if yb were one-hot encoded, `optax.softmax_cross_entropy` could be used instead
loss_func(preds, yb)

In [None]:
def loss(params, x,y):
    """Mean ``loss_func`` of the model's logits on ``x`` against labels ``y``.

    ``apply`` takes no rng argument (the model uses without_apply_rng). The mean makes
    the result a scalar, which ``value_and_grad`` requires.
    """
    preds = model.apply(params, x)
    return jnp.mean(loss_func(preds, y))

In [None]:
loss(params, xb,yb)  # loss on the example batch

In [None]:
@hk.transform
def loss_fn(batch) ->jnp.ndarray:
    """Haiku-transformed loss: builds the MLP and returns the mean cross-entropy on ``batch = (xb, yb)``.

    Use as ``loss_fn.init(rng, batch)`` / ``loss_fn.apply(params, rng, batch)``.
    """
    xb, yb = batch
    logits = hk.nets.MLP(output_sizes=[50,10])(xb)
    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, yb))

### Backprop

In [None]:
jnp.DeviceArray??

In [None]:
lr = 0.02
@jax.jit
def update(params,x,y):
    """Return ``(loss_value, grads)`` of ``loss`` at ``params`` for inputs ``x`` and labels ``y``.

    No parameter step is applied here.
    """
    v, g = jax.value_and_grad(loss)(params, x,y)
    # An SGD step would be: jax.tree_util.tree_map(lambda w, dw: w - lr * dw, params, g)
    return v,g

In [None]:
update(params, xb,yb)  # (loss, grads) for the full example batch

In [None]:
# Single example: xb[0]/yb[0] drop the batch axis
v,g = update(params, xb[0],yb[0])
v,g

In [None]:
def loss(params, x,y):
    """Mean ``loss_func`` with the model vmapped over the leading (batch) axis of ``x``."""
    preds = jax.vmap(model.apply, in_axes=(None, 0))(params, x)
    return jnp.mean(loss_func(preds, y))

In [None]:
v,g = update(params, xb,yb)  # NOTE(review): jit may reuse the trace cached by the earlier call with these shapes, i.e. the previous `loss`

In [None]:
# `update` returns (loss, grads), so `p` is that pair, not new params
p = update(params, xb,yb)
p

### Training

In [None]:
type(params)  # container type of the params pytree

In [None]:
model.apply  # inspect the transformed apply function

In [None]:
# Manual SGD (self-contained): jitted value_and_grad of a mean cross-entropy, then w <- w - lr * dw per leaf.
dl = dls.train
n_epochs = 2

def _sgd_loss(p, xb, yb):
    """Mean integer-label cross-entropy (a scalar, as jax.grad requires)."""
    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(model.apply(p, xb), yb))

_sgd_value_and_grad = jax.jit(jax.value_and_grad(_sgd_loss))
params = initial_params
batch_loss = float("nan")
for epoch in range(n_epochs):  # range(): an int is not iterable
    for batch in dl:
        xb, yb = map(jnp.array, batch)
        # named batch_loss so the `loss` function is not shadowed
        batch_loss, param_grads = _sgd_value_and_grad(params, xb, yb)
        params = jax.tree_util.tree_map(lambda w, dw: w - lr * dw, params, param_grads)
    print({"epoch": epoch, "last_batch_loss": f"{float(batch_loss):.3f}"})

In [None]:
def update(model, params, loss_func, x, y, step_size=0.02):
    """One plain-SGD step on the mean loss: ``params - step_size * grad``.

    Args:
        model: a Haiku ``Transformed`` (its ``.apply`` is used) or any ``f(params, x)`` callable.
        params: pytree of parameters (any structure, not only a list of (w, b) pairs).
        loss_func: ``loss_func(preds, y)`` returning per-example or scalar losses; averaged here.
        x: inputs.
        y: targets.
        step_size: SGD learning rate.

    Returns:
        The updated params pytree, same structure as ``params``.
    """
    apply = getattr(model, "apply", model)

    def loss(p):
        return jnp.mean(loss_func(apply(p, x), y))

    grads = jax.grad(loss)(params)
    return jax.tree_util.tree_map(lambda w, dw: w - step_size * dw, params, grads)

# model and loss_func are Python callables, not arrays, so jit must treat them as static (hashable) args
update = jax.jit(update, static_argnums=(0, 2))

In [None]:
(lambda a,b: a+b)(1,2)  # scratch: immediately-invoked lambda, evaluates to 3

In [None]:
def fit(model, params, dl, n_epochs: int, loss_func=None, step_size: float = 0.02):
    """Train ``params`` with plain SGD over ``dl`` for ``n_epochs`` epochs and return them.

    Args:
        model: a Haiku ``Transformed`` (its ``.apply`` is used) or any ``f(params, x)`` callable.
        params: initial parameter pytree (not modified; a new pytree is returned).
        dl: iterable of ``(x, y)`` batches (torch tensors or array-likes).
        n_epochs: number of passes over ``dl``.
        loss_func: ``loss_func(preds, y)``; defaults to optax integer-label softmax cross-entropy.
        step_size: SGD learning rate.

    Returns:
        The trained params pytree (``params`` itself when ``n_epochs`` is 0 or ``dl`` is empty).
    """
    if loss_func is None:
        loss_func = optax.softmax_cross_entropy_with_integer_labels
    apply = getattr(model, "apply", model)

    def _loss(p, x, y):
        return jnp.mean(loss_func(apply(p, x), y))

    grad_fn = jax.jit(jax.grad(_loss))  # compiled once per fit call
    updated_params = params
    for _ in range(n_epochs):
        for batch in dl:
            xb, yb = (jnp.array(t) for t in batch)
            grads = grad_fn(updated_params, xb, yb)
            updated_params = jax.tree_util.tree_map(lambda w, dw: w - step_size * dw, updated_params, grads)
    return updated_params

In [None]:
# def fit(model, params, dls, n_epochs: int):
    
#     def one_batch(batch:jnp.array, is_training:bool):
#         # convert torch tensors to jnp.arrays
#         xb,yb = batch
#         preds = model.apply(params, None, xb)
#         if is_training:
#             loss, grads = value_and_grad(loss_func)(preds,yb)
            
#         else:
#             loss = loss_func(preds, yb)
        

#         loss = loss_func(preds, yb)
#     def one_epoch(is_training: bool):
#         dl = dls.train if is_training else dls.valid
#         for num, batch in enumerate(dl): one_batch(map(jnp.array, batch), is_training)
        
#     for epoch in range(n_epochs):
#         one_epoch(is_training=True)
#         one_epoch(is_training=False)

### Creating a Learner

In [None]:
import fastcore.all as fc

In [None]:
class Learner:
    """Skeleton training loop: each epoch iterates the train loader, then the valid loader.

    Per-batch work so far only unpacks ``self.batch`` into ``self.xb`` / ``self.yb``.
    """
    def __init__(self,dls): fc.store_attr()

    def one_batch(self):
        xb, yb = self.batch
        self.xb = xb
        self.yb = yb

    def one_epoch(self, is_training):
        loader = self.dls.train if is_training else self.dls.valid
        for self.num, self.batch in enumerate(loader):
            self.one_batch()

    def fit(self, n_epochs):
        self.n_epochs = n_epochs
        for self.epoch in range(n_epochs):
            for phase_is_training in (True, False):
                self.one_epoch(phase_is_training)