<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

## Setup

env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.2
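The `env:` line above is the output of IPython's `%env` magic. Outside a notebook, a minimal equivalent is to set the variable from Python before JAX is imported. This assumes nothing has imported JAX yet:

```python
import os

# JAX reads this variable when it creates its GPU client, so set it before
# `import jax` (or at least before the first JAX computation).
# "0.2" asks XLA to preallocate at most 20% of the GPU memory.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.2"
```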


#### Data

As in `miniai`, we will be using the Fashion-MNIST dataset (`FashionMNIST` in `torchvision`) for demonstration. `Reax` is not intended to be a complete library; its `data` module is just a copy of the one in `miniai`, included to make the examples work.

In [5]:
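from functools import partial
import jax.numpy as jnp
import torch, torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from reax.data import DataLoaders, Batch, Tensor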

In [6]:
XMEAN, XSTD, BATCH_SIZE, NUM_CLASSES = 0.28, 0.35, 500, 10

# PIL image -> uint8 tensor -> floats in [0, 1] -> normalized -> flattened 784-vector
tfm = transforms.Compose([
    transforms.PILToTensor(),
    transforms.Lambda(lambda x: x / 255),
    transforms.Normalize(XMEAN, XSTD),
    transforms.Lambda(lambda x: torch.flatten(x)),
])
ds = partial(torchvision.datasets.FashionMNIST, root="data", download=True, transform=tfm)
train_ds, valid_ds = ds(train=True), ds(train=False)
tdl = DataLoader(train_ds, batch_size=BATCH_SIZE)
vdl = DataLoader(valid_ds, batch_size=BATCH_SIZE)
dls = DataLoaders(tdl, vdl)
# convert the first training batch from torch tensors to JAX arrays
batch = Batch(*map(jnp.array, next(iter(dls.train))))
batch

Batch(input=Array[500, 784] n=392000 x∈[-0.800, 2.057] μ=0.011 σ=1.006 gpu:0, target=Array[500] i32 x∈[0, 9] μ=4.402 σ=2.838 gpu:0)

## Model

The basic [Haiku](https://dm-haiku.readthedocs.io/) object representing a model is a [TransformedWithState](https://dm-haiku.readthedocs.io/en/latest/api.html#transformedwithstate): the pair of pure functions `init` and `apply` obtained by transforming a `function` or `module`. Here we use `hk.transform_with_state`, the most general of the transform functions: unlike `hk.transform`, it also supports modules that keep state.

In `Haiku` lingo, state is everything that keeps your original `Callable` from being a pure function: its context. Some common `DNN` modules, like `batch_norm`, keep some `state` to do their work (running statistics, for example). `State`, `Buffers` and `Context` are common names for this.
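Here is a plain-Python sketch of the idea that does not depend on Haiku (the counter and `CounterState` are hypothetical). The impure version hides its state in a global. The pure version takes the state in and hands the new state back, which is the same `(output, new_state)` convention that `apply` of a `transform_with_state` network follows:

```python
from typing import NamedTuple, Tuple

# Impure: the result depends on, and mutates, hidden global state.
_count = 0
def impure_counter(x: float) -> float:
    global _count
    _count += 1
    return x * _count

# Pure: the state is an explicit input and the updated state an explicit output.
class CounterState(NamedTuple):
    count: int

def pure_counter(state: CounterState, x: float) -> Tuple[float, CounterState]:
    new_state = CounterState(count=state.count + 1)
    return x * new_state.count, new_state

state = CounterState(count=0)
out1, state = pure_counter(state, 2.0)  # out1 == 2.0, state.count == 1
out2, state = pure_counter(state, 2.0)  # out2 == 4.0, state.count == 2
```

Same inputs always give the same outputs for `pure_counter`, which is exactly what lets JAX transformations such as `jit` and `grad` treat it safely.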

In [7]:
import haiku as hk

def forward(x: jnp.ndarray) -> jnp.ndarray:
    return hk.nets.MLP(output_sizes=[50, NUM_CLASSES])(x)  # todo: remove NUM_CLASSES dependency

network = hk.transform_with_state(forward)
type(network)

haiku._src.transform.TransformedWithState

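In `Reax`, a [`Model`](https://fredguth.github.io/reax/core.html#model) is an immutable object that bundles the network's `params` and `state` (both [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html), JAX's nested container data structures) with its `apply` function and `input_shape`.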
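`Model` supports `_asdict()` (it is used that way further below), so it appears to be a `NamedTuple`, and that would be where the immutability comes from. The sketch below uses a hypothetical stand-in, `ToyModel`. It has the same field names as the `Model` signature but is not Reax's class:

```python
from typing import Any, Callable, NamedTuple, Tuple

class ToyModel(NamedTuple):
    """Hypothetical stand-in for reax's Model: an immutable bundle of fields."""
    params: Any                   # a PyTree, e.g. nested dicts of arrays
    state: Any                    # a PyTree (an empty dict for a stateless network)
    apply: Callable
    input_shape: Tuple[int, ...]

m = ToyModel(params={"w": 1.0}, state={}, apply=lambda *args: None, input_shape=(500, 784))

# Fields cannot be reassigned in place...
try:
    m.state = {"a": 1}
except AttributeError:
    print("immutable")

# ...so an "update" builds a new instance. These two spellings are equivalent:
m2 = m._replace(state={"a": 1, "b": 2})
m3 = ToyModel(**(m._asdict() | {"state": {"a": 1, "b": 2}}))  # dict union needs Python 3.9+
```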


---

[source](https://github.com/fredguth/reax/blob/main/reax/core.py#L28){target="_blank" style="float:right; font-size:smaller"}

### Model

>      Model (params:Union[jax.Array,numpy.ndarray,Tuple[ForwardRef('PyTree'),...],
>             List[ForwardRef('PyTree')],Dict[Hashable,ForwardRef('PyTree')],
>             Mapping[str,Mapping[str,jax.Array]],Iterable[ForwardRef('ArrayTree')],
>             Mapping[Any,ForwardRef('ArrayTree')],NoneType],
>             state:Union[jax.Array,numpy.ndarray,Tuple[ForwardRef('PyTree'),...],
>             List[ForwardRef('PyTree')],Dict[Hashable,ForwardRef('PyTree')],
>             Mapping[str,Mapping[str,jax.Array]],Iterable[ForwardRef('ArrayTree')],
>             Mapping[Any,ForwardRef('ArrayTree')],NoneType],
>             apply:Callable[...,Tuple[Union[jax.Array,numpy.ndarray],
>             Union[jax.Array,numpy.ndarray,Tuple[ForwardRef('PyTree'),...],
>             List[ForwardRef('PyTree')],Dict[Hashable,ForwardRef('PyTree')],
>             Mapping[str,Mapping[str,jax.Array]],Iterable[ForwardRef('ArrayTree')],
>             Mapping[Any,ForwardRef('ArrayTree')],NoneType]]],
>             input_shape:Tuple[int,...])

In [9]:
m = Model.from_haiku(transformed=network, x=batch.input)
m

Model(params={'mlp/~/linear_0': {'b': Array[50] all_zeros gpu:0, 'w': Array[784, 50] n=39200 x∈[-0.071, 0.071] μ=0.000 σ=0.032 gpu:0}, 'mlp/~/linear_1': {'b': Array[10] all_zeros gpu:0 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 'w': Array[50, 10] n=500 x∈[-0.276, 0.270] μ=-0.001 σ=0.121 gpu:0}}, state={}, apply=<function transform_with_state.<locals>.apply_fn at 0x7f41486c1d30>, input_shape=(500, 784))

To keep ourselves sane, let's improve the model representation.

In [12]:
m = Model.from_haiku(transformed=network, x=batch.input)
m

Model:
+---------------------------------------------+---------+
| Params                                      | State   |
| mlp/~/linear_0:                             | {}      |
|   b: all_zeros                              |         |
|   w: x∈[-0.071, 0.071] μ=-9.844e-05 σ=0.032 |         |
| mlp/~/linear_1:                             |         |
|   b: all_zeros                              |         |
|   w: x∈[-0.275, 0.279] μ=-0.002 σ=0.123     |         |
+---------------------------------------------+---------+
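The cell that patches the representation is not shown. One way to get a table like the one above is to override `__repr__` on the container. The sketch below uses hypothetical helpers (`summarize`, `params_table`, `PrettyModel`) and plain lists of floats in place of JAX arrays, so it runs without JAX:

```python
import statistics
from typing import Dict, List, NamedTuple

Leaf = List[float]  # stand-in for a JAX array, to keep the sketch dependency-free

def summarize(values: Leaf) -> str:
    """One-line summary of a parameter, similar in spirit to the table above."""
    if all(v == 0 for v in values):
        return "all_zeros"
    return (f"x∈[{min(values):.3f}, {max(values):.3f}] "
            f"μ={statistics.fmean(values):.3f} σ={statistics.pstdev(values):.3f}")

def params_table(params: Dict[str, Dict[str, Leaf]]) -> str:
    """Hypothetical helper: one line per module, one indented line per parameter."""
    lines = []
    for module, leaves in params.items():
        lines.append(f"{module}:")
        for name, values in leaves.items():
            lines.append(f"  {name}: {summarize(values)}")
    return "\n".join(lines)

class PrettyModel(NamedTuple):
    """Hypothetical model container whose repr is the summary table."""
    params: Dict[str, Dict[str, Leaf]]

    def __repr__(self) -> str:
        return f"{type(self).__name__}:\n{params_table(self.params)}"

print(repr(PrettyModel({"mlp/~/linear_0": {"b": [0.0, 0.0], "w": [-0.5, 0.5, 0.0, 1.0]}})))
```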

In [14]:
print(m)

+--------------+-------------------------+-----------------+-------------+---------------+
| Input        | Module                  | Module params   | Output      |   Param count |
| f32[500,784] | mlp (MLP)               |                 | f32[500,10] |        39,760 |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,784] | mlp/~/linear_0 (Linear) | w: f32[784,50]  | f32[500,50] |        39,250 |
|              |  └ mlp (MLP)            | b: f32[50]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,50]  | mlp/~/linear_1 (Linear) | w: f32[50,10]   | f32[500,10] |           510 |
|              |  └ mlp (MLP)            | b: f32[10]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+

#### Model Reactivity (Model Store)

Ok, now we will start playing with reactivity. In `fastai` (and also in Keras, vanilla PyTorch, etc.) there is the concept of `Callbacks`: a way to be notified when something of interest happens.

> Don't nudge me, let me __call you back__ when I have something for you!

In general, you only need callbacks during training; after all, that is when your `things` change: the model, the hyperparameters, the metrics, etc.

The __fastai/miniai__ [`Learner`](https://fredguth.github.io/reax/core.html#learner) is an `Observable` that can hold multiple callbacks. Every callback keeps its state in the Learner object. You can have callbacks for metrics, for logging, for saving the training process... even callbacks that depend on other callbacks! That is why there is that... shall I say... __ugly__ `order` property in the `Callback` class.
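For comparison, this is roughly how miniai dispatches callbacks. Each callback is an object with an `order` attribute and optional methods named after training events, and the learner calls them sorted by `order`. This is a simplified sketch, and the two callback classes are hypothetical:

```python
from operator import attrgetter

class Callback:
    order = 0

def run_cbs(cbs, method_nm, learn=None):
    # lower `order` runs first; callbacks without a method for this event are skipped
    for cb in sorted(cbs, key=attrgetter("order")):
        method = getattr(cb, method_nm, None)
        if method is not None:
            method(learn)

calls = []

class MetricsCB(Callback):
    order = 0
    def after_batch(self, learn): calls.append("metrics")

class LoggerCB(Callback):
    order = 1  # must run after MetricsCB, because it logs the metrics
    def after_batch(self, learn): calls.append("logger")

run_cbs([LoggerCB(), MetricsCB()], "after_batch")  # calls == ["metrics", "logger"]
```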

`Reax` is just an experiment in handling this reactivity another way. Maybe it will prove too bloated... or not. I decided to do it in `JAX/Haiku` to force a `functional programming` perspective.

The basic abstraction in `Reax` is the `store`, an observable that holds any value. We could have used RxPy, which is an incredible package, but its superpowers may be too much for what we need. That is why I took inspiration from the `Svelte` JS framework to create `stores` (they became a package of their own, [Sveltish](https://fredguth.github.io/sveltish)).
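Here is a minimal plain-Python sketch of a Svelte-style writable store, showing the behavior seen in the cells below. It is not Sveltish's implementation, and `ToyWritable` is a hypothetical name. `subscribe` immediately calls the callback with the current value and returns an unsubscribe function, and `set` notifies every subscriber:

```python
from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")

class ToyWritable(Generic[T]):
    """Minimal Svelte-style writable store (illustrative sketch)."""

    def __init__(self, value: T) -> None:
        self.value = value
        self.subscribers: List[Callable[[T], None]] = []

    def get(self) -> T:
        return self.value

    def set(self, value: T) -> None:
        self.value = value
        for cb in list(self.subscribers):  # copy: a callback may unsubscribe
            cb(value)

    def update(self, fn: Callable[[T], T]) -> None:
        self.set(fn(self.value))

    def subscribe(self, cb: Callable[[T], None]) -> Callable[[], None]:
        self.subscribers.append(cb)
        cb(self.value)  # new subscribers immediately receive the current value
        def unsubscribe() -> None:
            if cb in self.subscribers:
                self.subscribers.remove(cb)
        return unsubscribe

store = ToyWritable(0)
seen = []
unsub = store.subscribe(seen.append)  # seen == [0]
store.set(1)                          # seen == [0, 1]
unsub()
store.set(2)                          # no longer notified
```

This sketch notifies on every `set`, even if the value did not change. Real store libraries may choose differently.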

A [`ModelStore`](https://fredguth.github.io/reax/core.html#modelstore) is just a [`Writable`](https://fredguth.github.io/reax/stores.html#writable) store that holds values of type [`Model`](https://fredguth.github.io/reax/core.html#model). 


---

[source](https://github.com/fredguth/reax/blob/main/reax/core.py#L71){target="_blank" style="float:right; font-size:smaller"}

### ModelStore

>      ModelStore (initial_value:__main__.Model)

A Model store. Custom Writable store

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| initial_value | Model | Initial value of the store |
| **Returns** | **None** |  |

#### Improving the ModelStore representation

We can improve its representation as well.

In [19]:
ms = ModelStore(m)
ms

ModelStore:
+---------------------------------------------+---------+-------------+
| Params                                      | State   | Callbacks   |
| mlp/~/linear_0:                             | {}      | []          |
|   b: all_zeros                              |         |             |
|   w: x∈[-0.071, 0.071] μ=-9.844e-05 σ=0.032 |         |             |
| mlp/~/linear_1:                             |         |             |
|   b: all_zeros                              |         |             |
|   w: x∈[-0.275, 0.279] μ=-0.002 σ=0.123     |         |             |
+---------------------------------------------+---------+-------------+

In [21]:
print(ms)

+--------------+-------------------------+-----------------+-------------+---------------+
| Input        | Module                  | Module params   | Output      |   Param count |
| f32[500,784] | mlp (MLP)               |                 | f32[500,10] |        39,760 |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,784] | mlp/~/linear_0 (Linear) | w: f32[784,50]  | f32[500,50] |        39,250 |
|              |  └ mlp (MLP)            | b: f32[50]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,50]  | mlp/~/linear_1 (Linear) | w: f32[50,10]   | f32[500,10] |           510 |
|              |  └ mlp (MLP)            | b: f32[10]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+

A `callback` is any `Callable` that you pass to `subscribe`. Note that subscribing immediately calls the callback with the store's current value:

In [22]:
u1 = ms.subscribe(lambda x: print("1: callback 1"))

1: callback 1


A change in the store value triggers all the callbacks subscribed to it.

In [23]:
m = ms.get()
ms.set(Model(**(m._asdict()|{"state": {'a': 1, 'b': 2}})))

1: callback 1


#### Optimizer

You can have different `stores` for different things. For example, here is a simpler one for the optimizer.


---

[source](https://github.com/fredguth/reax/blob/main/reax/core.py#L103){target="_blank" style="float:right; font-size:smaller"}

### Optimizer

>      Optimizer (state:Union[jax.Array,Iterable[ForwardRef('ArrayTree')],
>                 Mapping[Any,ForwardRef('ArrayTree')]], apply:Callable)

By the way, we will use [Optax](https://optax.readthedocs.io/), which is a good companion for `Haiku`.

In [25]:
import optax

grad_tfm = optax.sgd(1e-3)
apply = grad_tfm.update
optState = grad_tfm.init(m.params) # initialize the optimizer state from the model params
optimizer = Optimizer(state=optState, apply=apply)
optimizer

Optimizer(state=(EmptyState(), EmptyState()), apply=<function chain.<locals>.update_fn at 0x7f41484f4280>)
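Optax optimizers follow a functional shape:

- `init(params)` builds the optimizer state.
- `update(updates, state, params=None)` turns gradients into updates and returns `(updates, new_state)`.
- `optax.apply_updates(params, updates)` produces new params.

Nothing is mutated, which fits nicely with keeping the optimizer in a store. Below is a toy momentum-SGD with the same shape, over flat dicts of floats. All names are hypothetical, and this is not Optax code:

```python
from typing import Callable, Dict, NamedTuple, Tuple

Params = Dict[str, float]  # flat stand-in for a PyTree of arrays

class ToyTransformation(NamedTuple):
    """Hypothetical mirror of an optax GradientTransformation: two pure functions."""
    init: Callable[[Params], Params]
    update: Callable[[Params, Params], Tuple[Params, Params]]

def toy_momentum_sgd(lr: float, beta: float = 0.9) -> ToyTransformation:
    def init(params: Params) -> Params:
        return {k: 0.0 for k in params}  # one velocity per parameter

    def update(grads: Params, velocity: Params) -> Tuple[Params, Params]:
        new_velocity = {k: beta * velocity[k] + grads[k] for k in grads}
        updates = {k: -lr * v for k, v in new_velocity.items()}
        return updates, new_velocity  # new state is returned, never mutated

    return ToyTransformation(init, update)

def apply_updates(params: Params, updates: Params) -> Params:
    return {k: params[k] + updates[k] for k in params}

opt = toy_momentum_sgd(lr=0.1)
params = {"w": 1.0}
opt_state = opt.init(params)
updates, opt_state = opt.update({"w": 2.0}, opt_state)
params = apply_updates(params, updates)  # w moves from 1.0 to about 0.8
```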

In [26]:
os = OptimizerStore(optimizer)  # note: this name shadows the stdlib `os` module in this notebook
u2 = os.subscribe(lambda x: print(f"callback 2: {x}"))

callback 2: Optimizer(state=(EmptyState(), EmptyState()), apply=<function chain.<locals>.update_fn at 0x7f41484f4280>)


In [27]:
grad_tf2 = optax.adam(1e-4)
optState2 = grad_tf2.init(m.params)
os.set(Optimizer(state=optState2, apply=grad_tf2.update))

callback 2: Optimizer(state=(ScaleByAdamState(count=Array i32 gpu:0 0, mu={'mlp/~/linear_0': {'b': Array[50] all_zeros gpu:0, 'w': Array[784, 50] all_zeros gpu:0}, 'mlp/~/linear_1': {'b': Array[10] all_zeros gpu:0 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 'w': Array[50, 10] all_zeros gpu:0}}, nu={'mlp/~/linear_0': {'b': Array[50] all_zeros gpu:0, 'w': Array[784, 50] all_zeros gpu:0}, 'mlp/~/linear_1': {'b': Array[10] all_zeros gpu:0 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 'w': Array[50, 10] all_zeros gpu:0}}), EmptyState()), apply=<function chain.<locals>.update_fn at 0x7f41484f4ee0>)


Cleaning up... you should remember to unsubscribe when you are done with a store.

In [28]:
u1(), u2()

(None, None)

m = Model.from_haiku(transformed=network, x=batch.input)
ms = ModelStore(m)
u1 = ms.subscribe(lambda x: print(f"cb 1:\n{x}"))

## Training

Finally, we arrive at training, the `core` of the `core` `¯\_(ツ)_/¯`

This is where we need callbacks the most.

#### Learner

Like in `fastai`, we create a [`Learner`](https://fredguth.github.io/reax/core.html#learner) class that will deal with the training. 


---

[source](https://github.com/fredguth/reax/blob/main/reax/core.py#L111){target="_blank" style="float:right; font-size:smaller"}

### Learner

>      Learner (model:__main__.ModelStore, dls:reax.data.DataLoaders,
>               loss_func:Callable[[Union[jax.Array,numpy.ndarray],
>               Union[jax.Array,numpy.ndarray]],Union[jax.Array,numpy.ndarray]],
>               optimizer:reax.stores.Writable[__main__.Optimizer])

Basic class for handling the training loop.

In [30]:
learner = Learner(model=ms, dls=dls, loss_func=optax.softmax_cross_entropy_with_integer_labels, optimizer=os)
learner

Learner:
+-----------------+-----------------+-----------------+-----------------+
|           Model |     DataLoaders |          LossFn |       Optimizer |
| 139918362578752 | 139918365725360 | 139922468264256 | 139918362655424 |
+-----------------+-----------------+-----------------+-----------------+

`Learner` itself is not a store, but it holds different stores for different aspects of the training.

We have a [`ModelStore`](https://fredguth.github.io/reax/core.html#modelstore) and an `OptimizerStore`... what is missing is the most important thing we want to __observe__: the training loop itself. We need a [`TrainingStore`](https://fredguth.github.io/reax/core.html#trainingstore).

But first, let's examine what we need by taking a look at the __training loop__:

#### `Fit`: the training Loop

```python
# pseudo-code

def fit(epochs: int)->None:
    '''Train the model for a number of epochs.'''
    # before fit
    for epoch in range(epochs):
        # is_training
        one_epoch(dls.train) # train for one epoch
        # is_validating
        one_epoch(dls.valid) # validate for one epoch
        # should halt epochs?
    # after fit

def one_epoch(dl)->None:
    '''Train or validate for one epoch.'''
    # before epoch
    for batch_n, batch in enumerate(dl): 
        one_batch(batch_n, batch)
        # should halt batches?
    # after epoch

def one_batch(batch_n: int, batch: Batch)->None:
    '''Train or validate for one batch.'''
    # before batch
    predict(...)   # preds
    evaluate(...)  # loss
    if is_training: update_model(...)
    # after batch
```
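Here is a runnable skeleton of the pseudo-code above. A hypothetical `emit` function stands in for store notifications, and the placeholder steps only record event names. This makes it easy to see which hooks a one-epoch run passes through:

```python
from typing import Iterable, List

events: List[str] = []
should_halt = False  # a subscriber could flip this to stop training early

def emit(event: str) -> None:
    """Hypothetical stand-in for notifying TrainingStore subscribers."""
    events.append(event)

def one_batch(batch_n: int, batch, is_training: bool) -> None:
    emit("before_batch")
    emit("predict")
    emit("evaluate")
    if is_training:
        emit("update_model")
    emit("after_batch")

def one_epoch(dl: Iterable, is_training: bool) -> None:
    emit("before_epoch")
    for batch_n, batch in enumerate(dl):
        one_batch(batch_n, batch, is_training)
        if should_halt:
            break
    emit("after_epoch")

def fit(epochs: int, train_dl: Iterable, valid_dl: Iterable) -> None:
    emit("before_fit")
    for epoch in range(epochs):
        one_epoch(train_dl, is_training=True)
        one_epoch(valid_dl, is_training=False)
        if should_halt:
            break
    emit("after_fit")

fit(1, train_dl=[0], valid_dl=[0])  # one epoch, one batch each for train and valid
```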

Our [`TrainingStore`](https://fredguth.github.io/reax/core.html#trainingstore) should tell us where we are in the training loop, along with the information relevant at that point.

>  I am `training`, `epoch` 5, `iteration` 345, after `evaluate`, with a certain `current loss`.


It also seems it should be a [`Readable`](https://fredguth.github.io/reax/stores.html#readable) store; after all, we don't want any callback to be able to change information such as `which batch of which epoch am I in?`

The one exception: we do want to be able to tell the [`TrainingStore`](https://fredguth.github.io/reax/core.html#trainingstore) to halt.
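Here is a sketch of the read-only idea, loosely following Svelte's `readable(value, start)`. The `start` function is called when the first subscriber arrives and receives a private setter, so only the producer (here, the training loop) can change the value. `LoopState`, `ToyReadable` and `request_halt` are hypothetical names used to model the halting exception; this is not Reax's API:

```python
from typing import Any, Callable, Dict, List, NamedTuple

class LoopState(NamedTuple):
    """Hypothetical, much smaller stand-in for reax's TrainingState."""
    epoch: int = 0
    step: int = 0
    should_halt: bool = False

class ToyReadable:
    """Read-only store: only the `start` function receives a setter."""

    def __init__(self, value: Any, start: Callable[[Callable[[Any], None]], None]) -> None:
        self._value = value
        self._start = start
        self._subscribers: List[Callable[[Any], None]] = []

    def _set(self, value: Any) -> None:
        self._value = value
        for cb in list(self._subscribers):
            cb(value)

    def get(self) -> Any:
        return self._value

    def subscribe(self, cb: Callable[[Any], None]) -> Callable[[], None]:
        self._subscribers.append(cb)
        if len(self._subscribers) == 1:
            self._start(self._set)  # first subscriber: hand the private setter to the producer
        cb(self._value)
        return lambda: self._subscribers.remove(cb)

    def request_halt(self) -> None:
        """Hypothetical: the single change a subscriber is allowed to request."""
        self._set(self._value._replace(should_halt=True))

# The producer (the training loop) keeps the setter; subscribers never see it.
producer: Dict[str, Callable[[Any], None]] = {}
store = ToyReadable(LoopState(), start=lambda set_fn: producer.update(set=set_fn))

log: List[LoopState] = []
unsub = store.subscribe(log.append)
producer["set"](LoopState(epoch=0, step=1))  # the loop publishes progress
store.request_halt()                          # a callback asks the loop to stop
```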

Let's start with:


---

[source](https://github.com/fredguth/reax/blob/main/reax/core.py#L123){target="_blank" style="float:right; font-size:smaller"}

### TrainingState

>      TrainingState (epochs:int, epoch:int, step:int, iter:int,
>                     batch:Optional[reax.data.Batch], last:Dict=None,
>                     is_running:bool=False, is_training:bool=False,
>                     is_validating:bool=False, should_halt:bool=False)

In [32]:
t = TrainingState(epochs=0, epoch=0, step=0, iter=0, batch=None)
t

TrainingState:
-------------  -----
epochs             0
epoch              0
step               0
iter               0
batch
last
is_running     False
is_training    False
is_validating  False
should_halt    False
-------------  -----


---

[source](https://github.com/fredguth/reax/blob/main/reax/core.py#L148){target="_blank" style="float:right; font-size:smaller"}

### TrainingStore

>      TrainingStore (initial_value:T, start:Notifier)

A store that keeps track of the training loop state

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| initial_value | T | initial value of the store |
| start | Notifier | function called when the first subscriber is added |
| **Returns** | **None** |  |

#### TrainingStore representation

In [35]:
import itertools
from tabulate import tabulate

# a = [("A", "B", "C"), (1,2,3)]
# b = [("D", "E", None), (4,5,None)]
a = [("A", 1), ["B", 2], ["C", 3]]
b = [["D", 4], ["E", 5]]
c = list(zip(*a))
d = list(zip(*b))
table = list(itertools.zip_longest(*c,*d))
print(tabulate(table, headers=['','H1','', "H2"], tablefmt='grid'))

+----+------+----+------+
|    |   H1 |    |   H2 |
| A  |    1 | D  |    4 |
+----+------+----+------+
| B  |    2 | E  |    5 |
+----+------+----+------+
| C  |    3 |    |      |
+----+------+----+------+
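The trick in the cell above is to transpose each list of `(key, value)` pairs with `zip(*pairs)`. `itertools.zip_longest` then pads the shorter side with `None`, so the state and the callbacks can sit side by side even when they have different lengths. Here is a stdlib-only sketch of the same row construction, with a hypothetical `side_by_side` helper:

```python
import itertools
from typing import Any, List, Sequence, Tuple

Pair = Tuple[Any, Any]

def side_by_side(left: Sequence[Pair], right: Sequence[Pair]) -> List[Tuple[Any, ...]]:
    """Hypothetical helper: join row i of `left` with row i of `right`,
    padding the shorter side with (None, None)."""
    return [(*lhs, *rhs)
            for lhs, rhs in itertools.zip_longest(left, right, fillvalue=(None, None))]

rows = side_by_side([("A", 1), ("B", 2), ("C", 3)], [("D", 4), ("E", 5)])
for row in rows:
    print(" | ".join("" if cell is None else str(cell) for cell in row))
```

Unlike the cell above, this version always produces four cells per row, so the column count stays stable when one side is empty.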


In [36]:
a = [["A", 1], ["B", 2], ["C", 3]]
b = []
c = list(zip(*a))
d = list(zip(*b))
table = list(itertools.zip_longest(*c,*d))
print(tabulate(table, headers=['','H1','', "H2"], tablefmt='grid'))

+----+------+
|    |   H1 |
| A  |    1 |
+----+------+
| B  |    2 |
+----+------+
| C  |    3 |
+----+------+


In [37]:
a = [('epoch', 0), ('step', 0), ('batch_n', 0), ('batch', None), ('metrics', None), ('last_event', None), ('is_training', False), ('should_halt', False)]
b = [('0:', lambda:None)]
c = list(zip(*a))
d = list(zip(*b))
table = list(itertools.zip_longest(*c,*d))
print(tabulate(table, headers=['','H1','', "H2"], tablefmt='grid'))

+-------------+-------+----+---------------------------------------+
|             |    H1 |    | H2                                    |
| epoch       |     0 | 0: | <function <lambda> at 0x7f414849e670> |
+-------------+-------+----+---------------------------------------+
| step        |     0 |    |                                       |
+-------------+-------+----+---------------------------------------+
| batch_n     |     0 |    |                                       |
+-------------+-------+----+---------------------------------------+
| batch       |       |    |                                       |
+-------------+-------+----+---------------------------------------+
| metrics     |       |    |                                       |
+-------------+-------+----+---------------------------------------+
| last_event  |       |    |                                       |
+-------------+-------+----+---------------------------------------+
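| is_training | False |    |                                       |
+-------------+-------+----+---------------------------------------+
| should_halt | False |    |                                       |
+-------------+-------+----+---------------------------------------+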

In [39]:
ts = TrainingStore(t, lambda x:None)
u4 = ts.subscribe(lambda x: print(f"callback 4:\n {x}"))

callback 4:
 -------------  -----
epochs             0
epoch              0
step               0
iter               0
batch
last
is_running     False
is_training    False
is_validating  False
should_halt    False
-------------  -----


In [40]:
print(ts)

+---------------+---------+----+---------------------------------------+
|               |   State |    | Calbacks                              |
| epochs        |       0 | 0: | <function <lambda> at 0x7f41484a6700> |
+---------------+---------+----+---------------------------------------+
| epoch         |       0 |    |                                       |
+---------------+---------+----+---------------------------------------+
| step          |       0 |    |                                       |
+---------------+---------+----+---------------------------------------+
| iter          |       0 |    |                                       |
+---------------+---------+----+---------------------------------------+
| batch         |         |    |                                       |
+---------------+---------+----+---------------------------------------+
| last          |         |    |                                       |
+---------------+---------+----+---------------------------------------+

In [41]:
unsubs = []
for i in range(12):
    u = ts.subscribe(lambda x, i=i: print(f"callback: {i}"))  # bind i now, not at call time
    unsubs.append(u)
ts

callback: 0
callback: 1
callback: 2
callback: 3
callback: 4
callback: 5
callback: 6
callback: 7
callback: 8
callback: 9
callback: 10
callback: 11


TrainingStore:
+---------------+---------+-----+---------------------------------------+
|               |   State |     | Calbacks                              |
| epochs        |       0 | 0:  | <function <lambda> at 0x7f41484ad040> |
+---------------+---------+-----+---------------------------------------+
| epoch         |       0 | 1:  | <function <lambda> at 0x7f41484a6a60> |
+---------------+---------+-----+---------------------------------------+
| step          |       0 | 2:  | <function <lambda> at 0x7f41484ad280> |
+---------------+---------+-----+---------------------------------------+
| iter          |       0 | 3:  | <function <lambda> at 0x7f41484a6ca0> |
+---------------+---------+-----+---------------------------------------+
| batch         |         | 4:  | <function <lambda> at 0x7f41484a64c0> |
+---------------+---------+-----+---------------------------------------+
| last          |         | 5:  | <function <lambda> at 0x7f41484ad4c0> |
+---------------+---------+-----+---------------------------------------+

In [42]:
for u in unsubs: u()
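There is a subtlety when registering callbacks in a loop. A plain `lambda x: print(f"callback: {i}")` looks up `i` when it is called, not when it is created. The output above shows 0 to 11 only because `subscribe` runs each callback right away; any later notification would make every such callback print 11. Binding the value as a default argument (`i=i`) avoids this:

```python
late = [lambda x: i for i in range(3)]        # each closure looks up i at call time
bound = [lambda x, i=i: i for i in range(3)]  # each lambda stores its own i

print([cb(None) for cb in late])   # [2, 2, 2]
print([cb(None) for cb in bound])  # [0, 1, 2]
```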

In [46]:
# @fc.patch
# def fit(self:Learner, n_epochs, trnState: TrainingState):
#     "Fit the model for `n_epochs` using batches from `dls`"
#     trnState.emit(Event(id="before_fit", payload=None))
#     for epoch in range(n_epochs):
#         self.one_epoch(is_training=True, trnState=trnState)
#         self.one_epoch(is_training=False, trnState=trnState)
#         if (trnState.get().should_halt): break

In [47]:
# training = TrainingStore(TrainingState(epoch=0, step=0, batch_n=0, batch=None, metrics=None, last_event=None))
# u3 = training.subscribe(lambda x: print(f"3:\n {x}"))

In [48]:
# @fc.patch
# def one_epoch(self:Learner, is_training: bool, trnState: TrainingState):
#     a = 1
#     # print(f"one_epoch: is_training={is_training}")
#     # print(trnState)
#     # trnState._s_is_training = is_training
#     # self.dl = self.dls.train if is_training else self.dls.valid
#     # trnState.emit(Event(id=f"before_epoch", payload=trnState._s_epoch))
#     # for batch_n, batch in enumerate(self.dl):
#     #     trnState._s_batch_n, trnState._s_batch  = batch_n, batch
#     #     # self.one_batch(trnState=trnState)
#     #     if (trnState._s_should_halt): break
#     # trnState.emit(Event(id=f"after_epoch", payload=trnState._s_epoch))

In [49]:
# params, state, apply, _ = ms.get()
# rng = hk.PRNGSequence(42) # random number generator
# @jax.jit
# def _predict(params, state, key, batch) -> Tensor:
#     logits, new_state = apply(params, state, key, batch.input)
#     return jnp.argmax (logits, axis=-1), new_state
# key = next(rng)
# _predict(params, state, key, batch)
# @jax.jit
# def _evaluate(params, state, key, batch) -> Tensor:
#     preds, _ = _predict(params, state, key, batch)
#     return jnp.mean(preds == batch.target)
# from torch.utils.benchmark import Timer
# evTimer = Timer(stmt="_evaluate(params, state, key, batch)", globals=globals())
# evTimer.timeit(1000)
# rng = hk.PRNGSequence(jax.random.PRNGKey(42))

# def evaluate(model: ModelStore, batch: Batch) -> Tensor:
#     params, state, apply, _ = model.get()
#     key = next(rng)
#     return _evaluate(params, state, key, batch)
# evaluate(ms, batch)
# @jax.jit
# def _loss_fn(params, state, key, batch)-> jnp.ndarray:
#     targs = batch.target
#     preds, new_state = apply(params, state, key, batch.input)
#     # return the expectation of the loss wrt the distribution of the targets
#     return jnp.sum(optax.softmax_cross_entropy_with_integer_labels(preds, targs)/targs.shape[0]), new_state
# key = next(rng)
# loss, new_state = _loss_fn(params, state, key, batch)
# lfTimer = Timer(stmt="_loss_fn(params, state, key, batch)", globals=globals())
# lfTimer.timeit(1000)

In [50]:
# a = NamedTuple('A', [('a', int), ('b', int)])(1,2)
# b = NamedTuple('A', [('a', int), ('b', int)])(3,3)
# s1 = set(a._asdict().items())
# s2 = set(b._asdict().items())
# s1 ^ s2

In [51]:
# trnState = TrainingStore(TrainingState(epoch=0, step=0, batch_n=0, batch=None, metrics=None, last_event=None))
# logs = []
# def logger(x):
#     logs.append(x)
#     last = set((logs[-1])._asdict().items())
#     curr = set((x)._asdict().items())
#     print (last ^ curr)

# u4 = trnState.subscribe(lambda x: logger(x))

In [52]:
# def one_batch(self):
#     self.preds = self.model(self.batch[0])
#     self.loss = self.loss_func(self.preds, self.batch[1])
#     if self.model.training:
#         self.loss.backward()
#         self.opt.step()
#         self.opt.zero_grad()

In [53]:
# class Learner():
#     def __init__(self, model, dls, loss_func, lr, cbs, opt_func=optim.SGD): fc.store_attr()

#     def one_batch(self):
#         self.preds = self.model(self.batch[0])
#         self.loss = self.loss_func(self.preds, self.batch[1])
#         if self.model.training:
#             self.loss.backward()
#             self.opt.step()
#             self.opt.zero_grad()

#     def one_epoch(self, train):
#         self.model.train(train)
#         self.dl = self.dls.train if train else self.dls.valid
#         try:
#             self.callback('before_epoch')
#             for self.iter,self.batch in enumerate(self.dl):
#                 try:
#                     self.callback('before_batch')
#                     self.one_batch()
#                     self.callback('after_batch')
#                 except CancelBatchException: pass
#             self.callback('after_epoch')
#         except CancelEpochException: pass
    
#     def fit(self, n_epochs):
#         self.n_epochs = n_epochs
#         self.epochs = range(n_epochs)
#         self.opt = self.opt_func(self.model.parameters(), self.lr)
#         try:
#             self.callback('before_fit')
#             for self.epoch in self.epochs:
#                 self.one_epoch(True)
#                 self.one_epoch(False)
#             self.callback('after_fit')
#         except CancelFitException: pass

#     def callback(self, method_nm): run_cbs(self.cbs, method_nm, self)

In [54]:
# #|export
# class with_cbs:
#     def __init__(self, nm): self.nm = nm
#     def __call__(self, f):
#         def _f(o, *args, **kwargs):
#             try:
#                 o.callback(f'before_{self.nm}')
#                 f(o, *args, **kwargs)
#                 o.callback(f'after_{self.nm}')
#             except globals()[f'Cancel{self.nm.title()}Exception']: pass
#             finally: o.callback(f'cleanup_{self.nm}')
#         return _f

In [55]:
# rng = hk.PRNGSequence(jax.random.PRNGKey(42))
# params, state, apply, _ = ms.get()
# @jax.jit
# def _loss_fn(params, state, batch)-> Tuple[jnp.ndarray, PyTree]:
#     bs, *_ = batch.target.shape
#     logits, state = apply(params, state, next(rng), batch.input)
#     state = {'a':1, 'b':2}
#     return jnp.sum(optax.softmax_cross_entropy_with_integer_labels(logits, batch.target)/bs)

# def loss_fn(model: ModelStore, batch: Batch) -> float:
#     params, state, apply, _ = model.get()
#     loss_value =  _loss_fn(params, state, batch)
#     new_model = Model(**(m._asdict()|{'state': new_state}))
#     model.set(new_model)
#     return float(loss_value)

# loss_fn(ms, batch)
# ms

In [56]:
# from functools import partial

In [57]:
# rng = hk.PRNGSequence(jax.random.PRNGKey(42))

# def update(model: ModelStore, optimizer: OptimizerStore, batch: Batch)->None:
#     m = model.get()
#     o = optimizer.get()
#     f = partial(loss_fn)(model=model)
#     grads = jax.grad(loss_fn)(batch)
#     @jax.jit
#     def _update():
#         updates, new_optState = o.apply(grads, o.state)
#         new_model_params = optax.apply_updates(m.params, updates)
#         return new_model_params, new_optState
#     new_model_params, new_optState = _update()
#     new_model = Model(**(m._asdict()|{'params': new_model_params}))
#     new_optimizer = Optimizer(**(o._asdict()|{'state': new_optState}))
#     model.set(new_model)
#     optimizer.set(new_optimizer)
#     return None

In [58]:
# todo: try jax.tree_util.Partial

In [59]:
# m = ms.get()
# o = os.get()
# f = partial(loss_fn, model=ms)
# grads = jax.grad(f)(batch)
# rng = hk.PRNGSequence(jax.random.PRNGKey(42))
# params, state, apply, _ = ms.get()
# def loss_fn():
#     loss_value, new_state =  _loss_fn(params, state, batch)
    
# grads = jax.grad(_loss_fn)(params, state, batch)
# grads
# update(ms, os, batch)

In [60]:
# rng = hk.PRNGSequence(jax.random.PRNGKey(42))

# def loss_fn(model: ModelStore, batch: Batch) -> float:
#     params, state, apply, _ = model.get()
#     @jax.jit
#     def _loss(params, state, batch)-> jnp.ndarray:
#         bs, *_ = batch.target.shape
#         logits, state = (apply)(params, state, next(rng), batch.input)
#         return jnp.sum(optax.softmax_cross_entropy_with_integer_labels(logits, batch.target)/bs), state
#     loss_value, new_state =  _loss(params, state, batch)
#     new_model = Model(**(m._asdict()|{'state': new_state}))
#     model.set(new_model)
#     return float(loss_value)

# loss_fn(ms, batch)

In [61]:
# def get_loss(loss_func, *args): return jax.jit(lambda params: loss_func(get_model(params), *args))
# mse_loss = get_loss(mse, xb,tb) 
# mse_loss, mse_loss(W)
# from torch.utils.benchmark import Timer
# jax_grad = Timer( stmt="jax.grad(mse_loss)", globals=globals())
# jax_grad.timeit(1000)

In [62]:
# class TrainingStore(Writable[TrainingState]):

#     def emit(self, event: Event):
#         self.set(self.value._replace(last_event=event))
#     # def __getattr__(self, name): # there  is a bug, I can't fi
#     #     if name[:3]=='_s_' : return getattr(self.value, name[3:])
#     #     else: return super().__getattr__(name)
#     # def __setattr__(self, name, value):
#     #     if name[:3]=='_s_' and hasattr(self.value, name[3:]):
#     #         self.set(self.value._replace(**{name[3:]: value}))
#     #     else: super().__setattr__(name, value)
#     def __repr__(self) -> str:
#         return f"{self.__class__.__name__}:\n{self}"
#     def __str__(self) -> str:
#         state = list(self.value._asdict().items())
#         cbs = [(f"{i}:", v) for i, v in enumerate(self.subscribers)]
#         table = list(itertools.zip_longest(list(zip(*state)),list(zip(*cbs))))
#         return tabulate(table, headers=['State', 'Calbacks'], tablefmt='grid')

In [63]:
# def __repr__(self) -> str:
#         return f"{self.__class__.__name__}:\n{self}"
#     def __str__(self) -> str:
#         state = list(self.value._asdict().items())
#         state_t = list(zip(*state))
#         cbs = [(f"{i}:", v) for i, v in enumerate(self.subscribers)]
#         cbs_t = list(zip(*cbs))
#         table = list(itertools.zip_longest(*state_t,*cbs_t))
#         return tabulate(table, headers=['','State','', 'Calbacks'], tablefmt='grid')
#     # @property
#     # def _(self):
#     #     """The store value."""
#     #     return self.value
#     # @_.setter
#     # def _(self, value: TrainingState):
#     #     self.set(value)