# core

> This is the core of the `Reax` lib. 

Here we define the major abstractions.

## Setup

In [1]:
#|hide
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.2
%load_ext autoreload
%autoreload 2

env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.2


In [2]:
#| default_exp core

In [3]:
#|hide
#|export
import itertools
from functools import partial
from typing import (Callable, Dict, Hashable, List, Mapping, NamedTuple,
                    Optional, Sequence, Tuple, Union)

import haiku as hk
import jax
import jax.numpy as jnp
import lovely_jax as lj
import numpy as np
import optax
import torch
from torch.utils.data import DataLoader

In [4]:
#|hide
# Install lovely-jax's friendlier array representation.
lj.monkey_patch()
# Show which backend JAX selected for this session.
jax.default_backend()

'gpu'

#### Data

As in `miniai`, we will be using the `FashionMnist` dataset for demonstration.   `Reax` is not intended to be a complete library; the `data` module is just a copy from [miniai]() to make it work.

In [5]:

import torchvision
import torchvision.transforms as transforms
from reax.data import DataLoaders, Batch, Tensor

In [6]:

# Normalisation statistics, loader batch size and number of target classes.
XMEAN, XSTD = 0.28, 0.35
BATCH_SIZE, NUM_CLASSES = 500, 10

# PIL image -> tensor -> divided by 255 -> normalised -> flattened.
tfm = transforms.Compose([
    transforms.PILToTensor(),
    transforms.Lambda(lambda x: x/255),
    transforms.Normalize(XMEAN, XSTD),
    transforms.Lambda(lambda x: torch.flatten(x)),
])
ds = partial(torchvision.datasets.FashionMNIST, root="data", download=True, transform=tfm)
train_ds = ds(train=True)
valid_ds = ds(train=False)
tdl = DataLoader(train_ds, batch_size=BATCH_SIZE)
vdl = DataLoader(valid_ds, batch_size=BATCH_SIZE)
dls = DataLoaders(tdl, vdl)
# Take the first training batch and convert both of its tensors to JAX arrays.
batch = Batch(*map(jnp.array, next(iter(dls.train))))
batch

Batch(input=Array[500, 784] n=392000 x∈[-0.800, 2.057] μ=0.011 σ=1.006 gpu:0, target=Array[500] i32 x∈[0, 9] μ=4.402 σ=2.838 gpu:0)

:::{.callout-note}
Have you noticed how tensors are printed? This is [lovely-jax](https://xl0.github.io/lovely-jax/), the wonderful library that makes the JAX array representation more friendly. 
:::

## Model

The basic [Haiku](https://dm-haiku.readthedocs.io/) object to represent a model is a [TransformedWithState](https://dm-haiku.readthedocs.io/en/latest/api.html#transformedwithstate).  It represents a `function` or `module` that has been transformed by a `hk.transform` function.  Here we are using `hk.transform_with_state` which is the superset of the transform functions.  

State in the `Haiku` lingo means everything that makes your original `Callable` not a pure function.  It is the context or state.  Some common `DNN` modules like `batch_norm` can keep some `state` to perform their work.  `State`, `Buffers` and `Context` are common names for this.

In [7]:
def forward(x:jnp.ndarray) ->jnp.ndarray:
  '''Apply a Haiku MLP with layer sizes `[50, NUM_CLASSES]` to `x`.'''
  return hk.nets.MLP(output_sizes=[50,NUM_CLASSES])(x) # todo: remove NUM_CLASSES dependency
# Transform `forward` into a Haiku `TransformedWithState` (an init/apply pair).
network = hk.transform_with_state(forward)
type(network)

haiku._src.transform.TransformedWithState

#### Model class

In `Reax`, a `Model` is an immutable object. [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html) are JAX datastructures.

In [8]:
#|export

# Recursive type alias: a tensor, None, or a tuple/list/dict of PyTrees, plus the
# Haiku and Optax container types.
PyTree = Union[
    Tensor, Tuple["PyTree", ...], List["PyTree"], Dict[Hashable, "PyTree"], hk.Params, hk.State, optax.OptState, None
]  # I hope that with this definition it will work in  Haiku and Flax

# Signature of a model's forward function.
ApplyFn = Callable[..., Tuple[Tensor, PyTree]] # returns result and state (aka buffers)

In [9]:
#|exports
class Model(NamedTuple):
    '''Immutable bundle of a network's parameters, state, forward function and input shape.'''
    params: PyTree # the models parameters, weights and biases
    state: PyTree  # the model auxiliary state, e.g. batchnorm buffers
    apply: ApplyFn # the model forward pass function
    input_shape: Tuple[int, ...] # the shape of the input, used to infer the model output shape

    rng = hk.PRNGSequence(42) # random number generator (class-level, shared by every Model)

    @staticmethod
    def from_haiku(
        transformed: hk.TransformedWithState,       # transformed haiku model
        x: Tensor                                   # example input (e.g. batch.input)
    ):
        ''' Build a Model by initialising a transformed Haiku network on an example input.'''
        key = next(Model.rng)
        init_params, init_state = jax.jit(transformed.init)(key, x)
        return Model(init_params, init_state, transformed.apply, x.shape)

In [10]:
# Initialise the network on the example batch; the default NamedTuple repr is shown below.
m = Model.from_haiku(transformed=network, x=batch.input)
m

Model(params={'mlp/~/linear_0': {'b': Array[50] [38;2;127;127;127mall_zeros[0m gpu:0, 'w': Array[784, 50] n=39200 x∈[-0.071, 0.071] μ=0.000 σ=0.032 gpu:0}, 'mlp/~/linear_1': {'b': Array[10] [38;2;127;127;127mall_zeros[0m gpu:0 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 'w': Array[50, 10] n=500 x∈[-0.276, 0.270] μ=-0.001 σ=0.121 gpu:0}}, state={}, apply=<function transform_with_state.<locals>.apply_fn at 0x7fb8f8324280>, input_shape=(500, 784))

Let's keep us sane and improve the model representation.

In [11]:
#|export
import fastcore.all as fc
from tabulate import tabulate
from reax.utils import str_tree
# the tabulate package is also used by Haiku in its hk.experimental methods

In [12]:
#|export
@fc.patch
def __repr__(self:Model)->str:
    '''Render the params and the state side by side in a grid table.'''
    header = ["Params", "State"]
    row = [str_tree(self.params), str_tree(self.state)]
    grid = tabulate([header, row], headers='firstrow', tablefmt='grid')
    return f"{self.__class__.__name__}:\n{grid}"


In [13]:
# Rebuild the model so the patched `__repr__` is the one used for display.
m = Model.from_haiku(transformed=network, x=batch.input)
m

Model:
+---------------------------------------------+---------+
| Params                                      | State   |
| mlp/~/linear_0:                             | {}      |
|   b: all_zeros                              |         |
|   w: x∈[-0.071, 0.071] μ=-9.844e-05 σ=0.032 |         |
| mlp/~/linear_1:                             |         |
|   b: all_zeros                              |         |
|   w: x∈[-0.275, 0.279] μ=-0.002 σ=0.123     |         |
+---------------------------------------------+---------+

In [14]:
#|export
@fc.patch
def __str__(self:Model) -> str:
    s1 = hk.experimental.tabulate(self.apply,
            columns=["input", "module", "owned_params", "output", "params_size"])(jnp.ones(self.input_shape))
    s2 = '\n'.join(self.__repr__().split('\n')[1:])
    return f"{s1}\n{s2}"

In [15]:
print(m)

+--------------+-------------------------+-----------------+-------------+---------------+
| Input        | Module                  | Module params   | Output      |   Param count |
| f32[500,784] | mlp (MLP)               |                 | f32[500,10] |        39,760 |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,784] | mlp/~/linear_0 (Linear) | w: f32[784,50]  | f32[500,50] |        39,250 |
|              |  └ mlp (MLP)            | b: f32[50]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,50]  | mlp/~/linear_1 (Linear) | w: f32[50,10]   | f32[500,10] |           510 |
|              |  └ mlp (MLP)            | b: f32[10]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+
+---------------------------------------------+---------+
| Params                        

#### Model Reactivity (Model Store)

Ok, now we will start to play with reactivity.  In `fastai` (also in Keras, vanilla PyTorch, etc) there is the concept of `Callbacks`.  It is the way to be notified when something of interest happens. 

> Don't nudge me, let me __call you back__ when I have something for you!

In general, you will need a callback only during training, after all, it is when your `things` change.  The model, the hyperparameters, the metrics, etc.

The __fastai/miniai__ `Learner` is an `Observable` and it can hold multiple callbacks. Every callback keeps its state in the Learner object. You can have callbacks for metrics, for logging and saving the training process... callbacks that depend on other callbacks! That is why there is that ... shall I say... __ugly__ `order` property in the `Callback` class.

`Reax` is just an experiment on how to handle this reactivity in another way.  Maybe it will prove itself too bloated... or not. I decided to do it in `JAX/Haiku` to force a `functional programming` perspective.

The basic abstractions in `Reax` are `stores`, observables that hold any value. We could have used [RxPy], which is an incredible package, but its superpowers may be too much for what we need. That is why I took inspiration from the `Svelte` JS framework to create `stores` (it became its own package, [Sveltish](https://fredguth.github.io/sveltish)).


In [16]:
#|export
from reax.stores import Writable, Notifier

A `ModelStore` is just a `Writable` store that holds values of type `Model`. 

In [17]:
#|export

class ModelStore(Writable[Model]):
    ''' A Model store. Custom Writable store'''
    def __init__(self,
                initial_value: Model, # Initial value of the store
            ) -> None:
        # we won't need a Start/Stop Notifier, so hand the base class one that does nothing
        def _noop(x): return None
        super().__init__(initial_value, _noop)

#### Improving the ModelStore representation

We also may improve its representation.

In [18]:
#|export
import yaml


In [19]:
#|export
@fc.patch
def __repr__(self:ModelStore) -> str:
    '''Grid table with the stored model's params, its state and the subscribed callbacks.'''
    # Read the model held by this store (not a notebook global) so the table
    # stays correct after the store value is replaced with `set`.
    model = self.value
    # One `{index: str(callback)}` entry per subscriber, dumped as YAML for readability.
    callbacks = yaml.dump([{i:str(f)} for i,f in enumerate(self.subscribers)])
    table = [["Params", "State", "Callbacks"],
                [str_tree(model.params), str_tree(model.state), callbacks]]
    return f"{self.__class__.__name__}:\n{tabulate(table, headers='firstrow', tablefmt='grid')}"

In [20]:
# Wrap the model in a store; no callbacks are subscribed yet.
ms = ModelStore(m)
ms

ModelStore:
+---------------------------------------------+---------+-------------+
| Params                                      | State   | Callbacks   |
| mlp/~/linear_0:                             | {}      | []          |
|   b: all_zeros                              |         |             |
|   w: x∈[-0.071, 0.071] μ=-9.844e-05 σ=0.032 |         |             |
| mlp/~/linear_1:                             |         |             |
|   b: all_zeros                              |         |             |
|   w: x∈[-0.275, 0.279] μ=-0.002 σ=0.123     |         |             |
+---------------------------------------------+---------+-------------+

In [21]:
#|export
@fc.patch
def __str__(self:ModelStore) -> str:
    '''Haiku per-module summary table followed by the table part of the repr.'''
    model = self.value
    summary = hk.experimental.tabulate(
        model.apply,
        columns=["input", "module", "owned_params", "output", "params_size"],
    )(jnp.ones(model.input_shape))
    # keep everything but the first line of the repr
    _, *table_lines = self.__repr__().split('\n')
    return f"{summary}\n" + '\n'.join(table_lines)

In [22]:
print(ms)

+--------------+-------------------------+-----------------+-------------+---------------+
| Input        | Module                  | Module params   | Output      |   Param count |
| f32[500,784] | mlp (MLP)               |                 | f32[500,10] |        39,760 |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,784] | mlp/~/linear_0 (Linear) | w: f32[784,50]  | f32[500,50] |        39,250 |
|              |  └ mlp (MLP)            | b: f32[50]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,50]  | mlp/~/linear_1 (Linear) | w: f32[50,10]   | f32[500,10] |           510 |
|              |  └ mlp (MLP)            | b: f32[10]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+
+---------------------------------------------+---------+-------------+
| Params          

A `callback` is any `Callable` that you pass on `subscribe`.

In [23]:
u1 = ms.subscribe(lambda x: print("1: callback 1"))

1: callback 1


A change in the store value, triggers all callbacks subscribed to it.

In [24]:
m = ms.get()
# Build a new Model with only `state` replaced and put it in the store.
ms.set(Model(**(m._asdict()|{"state": {'a': 1, 'b': 2}})))

1: callback 1


#### Optimizer

You can have different `stores` for different things.  For example, this is a simpler one to deal with the optimizer.

In [25]:
#|export
class Optimizer(NamedTuple):
    '''Immutable pair of an optimizer state and the function that applies updates.'''
    state: optax.OptState  # the optimizer state
    apply: Callable        # the update function

# A store whose values are `Optimizer` instances.
OptimizerStore = Writable[Optimizer]

By the way, we will use [Optax](https://optax.readthedocs.io/), which is a good companion for `Haiku`.

In [26]:
# Plain SGD; its `update` function becomes the Optimizer's `apply`.
grad_tfm = optax.sgd(1e-3)
apply = grad_tfm.update
optState = grad_tfm.init(m.params) # you initialize the optimizer with the model params
optimizer = Optimizer(state=optState, apply=apply)
optimizer

Optimizer(state=(EmptyState(), EmptyState()), apply=<function chain.<locals>.update_fn at 0x7fb8f80ec940>)

In [27]:
# NOTE(review): the name `os` is also the name of the standard-library `os` module;
# a different name would avoid shadowing it if that module is ever imported here.
os= OptimizerStore(optimizer)
u2 = os.subscribe(lambda x: print(f"callback 2: {x}"))

callback 2: Optimizer(state=(EmptyState(), EmptyState()), apply=<function chain.<locals>.update_fn at 0x7fb8f80ec940>)


In [28]:
# Replace the stored optimizer with Adam; setting the store notifies callback 2.
grad_tf2 = optax.adam(1e-4)
optState2 = grad_tf2.init(m.params)
os.set(Optimizer(state=optState2, apply=grad_tf2.update))

callback 2: Optimizer(state=(ScaleByAdamState(count=Array i32 gpu:0 0, mu={'mlp/~/linear_0': {'b': Array[50] [38;2;127;127;127mall_zeros[0m gpu:0, 'w': Array[784, 50] [38;2;127;127;127mall_zeros[0m gpu:0}, 'mlp/~/linear_1': {'b': Array[10] [38;2;127;127;127mall_zeros[0m gpu:0 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 'w': Array[50, 10] [38;2;127;127;127mall_zeros[0m gpu:0}}, nu={'mlp/~/linear_0': {'b': Array[50] [38;2;127;127;127mall_zeros[0m gpu:0, 'w': Array[784, 50] [38;2;127;127;127mall_zeros[0m gpu:0}, 'mlp/~/linear_1': {'b': Array[10] [38;2;127;127;127mall_zeros[0m gpu:0 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 'w': Array[50, 10] [38;2;127;127;127mall_zeros[0m gpu:0}}), EmptyState()), apply=<function chain.<locals>.update_fn at 0x7fb8f80ecd30>)


Cleaning up... you should remember to unsubscribe when you are done with a store.

In [29]:
# `subscribe` returned these callables; calling them unsubscribes the callbacks.
u1(), u2()

(None, None)

In [30]:
# Start again from a fresh model and store; the callback prints the model's `__str__`.
m = Model.from_haiku(transformed=network, x=batch.input)
ms = ModelStore(m)
u1 = ms.subscribe(lambda x: print(f"cb 1:\n{x}"))

cb 1:
+--------------+-------------------------+-----------------+-------------+---------------+
| Input        | Module                  | Module params   | Output      |   Param count |
| f32[500,784] | mlp (MLP)               |                 | f32[500,10] |        39,760 |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,784] | mlp/~/linear_0 (Linear) | w: f32[784,50]  | f32[500,50] |        39,250 |
|              |  └ mlp (MLP)            | b: f32[50]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,50]  | mlp/~/linear_1 (Linear) | w: f32[50,10]   | f32[500,10] |           510 |
|              |  └ mlp (MLP)            | b: f32[10]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+
+-----------------------------------------+---------+
| Params                      

## Training

Finally we arrived at the Training, the  `core` of the `core`  ```¯\_(ツ)_/¯```

Here is where we will most need callbacks.

#### Learner

Like in `fastai`, we create a `Learner` class that will deal with the training. 

In [31]:
#|export
LossFn = Callable[[Tensor, Tensor], Tensor] # per example loss function
class Learner:
    '''Basic class for handling the training loop.'''
    def __init__(self, model:ModelStore, dls: DataLoaders, loss_func: LossFn, optimizer: OptimizerStore) -> None:
        # keeping fastai order here. I would prefer: dls, model, optimizer, loss_func, which seems more natural to me.
        fc.store_attr()

    def __repr__(self) -> str:
        return type(self).__name__ + ":\n" + str(self)

    def __str__(self) -> str:
        '''One-row grid table with the `id` of each component.'''
        parts = {"Model": self.model, "DataLoaders": self.dls,
                 "LossFn": self.loss_func, "Optimizer": self.optimizer}
        ids = [id(part) for part in parts.values()]
        return tabulate([list(parts), ids], headers='firstrow', tablefmt='grid')

# The learner is handed the stores (`ms`, `os`), not the bare model/optimizer values.
learner = Learner(model=ms, dls=dls, loss_func=optax.softmax_cross_entropy_with_integer_labels, optimizer=os)
learner

Learner:
+-----------------+-----------------+-----------------+-----------------+
|           Model |     DataLoaders |          LossFn |       Optimizer |
| 140432412740432 | 140432415838800 | 140436291828464 | 140432412267040 |
+-----------------+-----------------+-----------------+-----------------+

The Learner itself is not a store, but it holds different stores for different aspects of the training.

We have a `ModelStore`, an `OptimizerStore`... it is only missing the most important thing we want to __observe__... the training loop itself. We need a `TrainingStore`.

But for that... let's first examine what we need.  Let's take a look in the __training loop__:

#### The anatomy of a training loop

```python
# pseudo-code

def fit(epochs: int)->None:
    '''Train the model for a number of epochs.'''
    # before fit
    for epoch in range(epochs):
        # is_training
        one_epoch(dls.train) # train for one epoch
        # is_validating
        one_epoch(dls.valid) # validate for one epoch
        # should halt epochs?
    # after fit

def one_epoch(dl)->None:
    '''Train or validate for one epoch.'''
    # before epoch
    for batch_n, batch in enumerate(dl): 
        one_batch(batch_n, batch)
        # should halt batches?
    # after epoch

def one_batch(batch_n: int, batch: Batch)->None:
    '''Train or validate for one batch.'''
    # before batch
    predict(...) # preds
    evaluate(...)# loss
    update model(...) if is_training
    # after batch
```

Our `TrainingStore` shall tell us where we are in the training loop and some information relevant at this point.

>  I am `training`, `epoch` 5, `iteration` 345, after `evaluate` with certain `current loss`.


Another aspect is that it seems it should be a `Readable` store; after all, we don't want any callback to be able to change information like:
`in which batch of which epoch am I?`

Exceptionally, we want to tell the `TrainingStore` to halt.

Let's start with:

In [66]:
#|export
class TrainingState(NamedTuple):
    '''Immutable snapshot of where the training loop currently is.'''

    last: Dict                      # last event that happened {'event': {payload}),
                                    # eg. {'before_batch': {'iter': 345}}
    epochs: int                     # number of epochs to fit
    epoch: int                      # current epoch
    step: int                       # current step, since the beginning of the training
    iter: int                       # current batch number, since beginning of the epoch
    batch: Optional[Batch]          # current batch instance or None (if training hasn't started yet)
    is_running: bool=False          # True if running (training/validation), False if stopped
    is_training: bool=False         # True if training is in progress
    is_validating: bool=False       # True if evaluation is in progress
    should_halt: bool=False         # True if should stop, False otherwise

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}:\n{self}"
    def __str__(self) -> str:
        '''Render the fields as two grid tables: fields 0-5, then fields 6-9.'''
        # Read this instance's fields; the previous code read the notebook global `t`,
        # so every instance printed whatever `t` happened to be bound to.
        fields = self._asdict()
        names = list(fields)
        s = ""
        for (start,end) in [(0,6),(6,10)]:
            keys = names[start:end]
            values = [fields[key] for key in keys]
            s+= f"{tabulate([values], headers=keys, tablefmt='grid')}\n"
        return s

In [67]:
# An initial state: nothing has run yet and there is no current batch.
t = TrainingState(last={'created':None}, epochs=0, epoch=0, step=0, iter=0, batch=None)
t

TrainingState:
+-------------------+----------+---------+--------+--------+---------+
| last              |   epochs |   epoch |   step |   iter | batch   |
| {'created': None} |        0 |       0 |      0 |      0 |         |
+-------------------+----------+---------+--------+--------+---------+
+--------------+---------------+-----------------+---------------+
| is_running   | is_training   | is_validating   | should_halt   |
| False        | False         | False           | False         |
+--------------+---------------+-----------------+---------------+

### Training Store

In [None]:
#|export
class TrainingStore(Writable[TrainingState]):
    ''' A store that keeps tracking of the training loop state'''
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}:\n{self}"

    def __str__(self) -> str:
        '''Grid table with the state fields on the left and the numbered subscribers on the right.'''
        # Columns, in order: field names, field values, callback labels, callbacks.
        state_cols = list(zip(*self.value._asdict().items()))
        cb_cols = list(zip(*[(f"{i}:", cb) for i, cb in enumerate(self.subscribers)]))
        # zip_longest pads the shorter columns with None so both lists fit in one table.
        rows = list(itertools.zip_longest(*state_cols, *cb_cols))
        return tabulate(rows, headers=['','State','', 'Callbacks'], tablefmt='grid')

#### TrainingStore representation

In [None]:
# Scratch: two lists of (label, value) pairs of different lengths rendered side by side.
# a = [("A", "B", "C"), (1,2,3)]
# b = [("D", "E", None), (4,5,None)]
a = [("A", 1), ["B", 2], ["C", 3]]
b = [["D", 4], ["E", 5]]
# Transpose each list into a (labels, values) pair of columns.
c = list(zip(*a))
d = list(zip(*b))
# Zip the four columns back into rows; the shorter columns are padded with None.
table = list(itertools.zip_longest(*c,*d))
print(tabulate(table, headers=['','H1','', "H2"], tablefmt='grid'))

In [None]:
# Scratch: same layout when the second list is empty (it contributes no columns).
a = [["A", 1], ["B", 2], ["C", 3]]
b = []
c = list(zip(*a))
d = list(zip(*b))
table = list(itertools.zip_longest(*c,*d))
print(tabulate(table, headers=['','H1','', "H2"], tablefmt='grid'))

In [None]:
# Scratch: state-like pairs on the left, a single numbered callback on the right.
a = [('epoch', 0), ('step', 0), ('batch_n', 0), ('batch', None), ('metrics', None), ('last_event', None), ('is_training', False), ('should_halt', False)]
b = [('0:', lambda:None)]
c = list(zip(*a))
d = list(zip(*b))
table = list(itertools.zip_longest(*c,*d))
print(tabulate(table, headers=['','H1','', "H2"], tablefmt='grid'))

In [None]:
#|export

@fc.patch
def __repr__(self: TrainingStore) -> str:
    '''Class name on the first line, followed by the `__str__` table.'''
    return f"{self.__class__.__name__}:\n{self}"
@fc.patch
def __str__(self: TrainingStore) -> str:
    '''Grid table with the state fields on the left and the numbered subscribers on the right.'''
    # Columns, in order: field names, field values, callback labels, callbacks.
    state_cols = list(zip(*self.value._asdict().items()))
    cb_cols = list(zip(*[(f"{i}:", cb) for i, cb in enumerate(self.subscribers)]))
    # zip_longest pads the shorter columns with None so both lists fit in one table.
    rows = list(itertools.zip_longest(*state_cols, *cb_cols))
    return tabulate(rows, headers=['','State','', 'Callbacks'], tablefmt='grid')

In [None]:
# The second argument is a do-nothing callable, like the `start` notifier ModelStore passes to its base class.
ts = TrainingStore(t, lambda x:None)
u4 = ts.subscribe(lambda x: print(f"callback 4:\n {x}"))

In [None]:
print(ts)

In [None]:
unsubs = []
for i in range(12):
    # Bind the current `i` as a default argument: a plain closure would look `i` up
    # when the callback runs, so every later notification would print the same number.
    u = ts.subscribe(lambda x, i=i: print(f"callback: {i}"))
    unsubs.append(u)
ts

In [None]:
# Call every handle collected in `unsubs` to unsubscribe its callback.
for u in unsubs: u()

In [None]:
list(ts.value._asdict().keys())

In [None]:
#|export
# NOTE(review): this exported cell defines `LossFn` and `Learner` again; confirm both exports are wanted.
LossFn = Callable[[Tensor, Tensor], Tensor] # per example loss function
class Learner:
    '''Basic class for handling the training loop.'''
    def __init__(self, model:ModelStore, dls: DataLoaders, loss_func: LossFn, optimizer: OptimizerStore) -> None:
        # keeping fastai order here. I would prefer: dls, model, optimizer, loss_func, which seems more natural to me.
        fc.store_attr()

    def __repr__(self) -> str:
        return type(self).__name__ + ":\n" + str(self)

    def __str__(self) -> str:
        '''One-row grid table with the `id` of each component.'''
        parts = {"Model": self.model, "DataLoaders": self.dls,
                 "LossFn": self.loss_func, "Optimizer": self.optimizer}
        ids = [id(part) for part in parts.values()]
        return tabulate([list(parts), ids], headers='firstrow', tablefmt='grid')



In [None]:
class TrainingState:
    '''Plain-class variant of the training state; all fields are keyword-only.'''
    def __init__(self, *,
        epochs: int,                     # number of epochs to fit
        epoch: int,                      # current epoch
        step: int,                       # current step, since the beginning of the training
        iter: int,                       # current batch number, since beginning of the epoch
        batch: Optional[Batch],          # current batch instance or None (if training hasn't started yet)
        last: Optional[Dict]=None,       # last event that happened {'event': {payload}),
                                         # eg. {'before_batch': {'iter': 345}}
        is_running: bool=False,          # True if running (training/validation), False, if stopped
        is_training: bool=False,         # True if training is in progress
        is_validating: bool=False,       # True if evaluation is in progress
        should_halt: bool=False,         # True if should stop, False otherwise
    ) -> None: fc.store_attr()
    def as_dict(self):
        '''Return the `__stored_args__` mapping itself (not a copy).'''
        return self.__stored_args__
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}:\n{self}"
    def __str__(self) -> str:
        # Skips the first `__dict__` entry, presumably `__stored_args__`; confirm against fastcore.
        return tabulate(list(self.__dict__.items())[1:])

In [None]:
class T4:
    '''Scratch class checking that `fc.store_attr` handles keyword-only arguments and defaults.'''
    def __init__(self, *, a: int, b: float = 1):
        fc.store_attr()
        
# NOTE(review): this rebinds the notebook-level name `t`, which an earlier cell bound to a TrainingState.
t = T4(a=3)
assert t.a==3 and t.b==1

In [None]:
class TrainingStore(Writable[TrainingState]):
    ''' A store that keeps tracking of the training loop state'''
    # Redefinition with no members of its own; it replaces the earlier TrainingStore binding.

In [None]:
@fc.patch
def __str__(self: TrainingStore) -> str:
    '''Delegate to the string form of the stored value.'''
    return str(self.value)
    # state = list(self.value.__dict__.items())
    # state_t = list(zip(*state))
    # cbs = [(f"{i}:", v) for i, v in enumerate(self.subscribers)]
    # cbs_t = list(zip(*cbs))
    # table = list(itertools.zip_longest(*state_t,*cbs_t))
    # return tabulate(table, headers=['','State','', 'Calbacks'], tablefmt='grid')

In [None]:
@fc.patch
def __init__(self:Learner, model:ModelStore, dls: DataLoaders, loss_func: LossFn, optimizer: OptimizerStore) -> None:
    '''Store the constructor arguments and create this learner's training store.'''
    # keeping fastai order here. I would prefer: dls, model, optimizer, loss_func, which seems more natural to me.
    fc.store_attr()
    self.store = TrainingStore(TrainingState(epochs=0, epoch=0, step=0, iter=0, batch=None))
    # Read the names from this instance's own store; the previous code went through the
    # notebook global `learner`, which is a different (or not yet existing) object.
    self.watch = list(self.store.value.__dict__.keys())[1:] # attributes to watch for changes

In [None]:
@fc.patch
def __settr__(self:Learner, name, value) -> None:
    '''Debug hook: report the attribute being written, then store it on the instance.'''
    # NOTE(review): the name looks like a typo for `__setattr__`. It is left unchanged because
    # renaming it would intercept every attribute write, including those made in `__init__`
    # before `watch` exists.
    print ('name:', name)
    if name in self.watch:
        print('watcha!')
    # `self.super()` does not exist. Zero-argument `super()` is also unavailable in a function
    # defined outside the class body, so name the class explicitly.
    super(Learner, self).__setattr__(name, value)

@fc.patch
def __str__(self:Learner) -> str:
    '''String form of the instance attribute dictionary.'''
    attrs = self.__dict__
    return str(attrs)
    # table = [["Model", "DataLoaders", "LossFn", "Optimizer"],[id(self.model), id(self.dls), id(self.loss_func), id(self.optimizer)]]
    # table =  tabulate(table, headers='firstrow', tablefmt='grid')
    # return f"{table}\n ... \n {self.__dict__}"


In [None]:
LossFn = Callable[[Tensor, Tensor], Tensor] # per example loss function
class Learner:
    '''Basic class for handling the training loop.'''
    def __init__(self, model:ModelStore, dls: DataLoaders, loss_func: LossFn, optimizer: OptimizerStore) -> None:
        # keeping fastai order here. I would prefer: dls, model, optimizer, loss_func, which seems more natural to me.
        fc.store_attr()
        self.store = TrainingStore(TrainingState(epochs=0, epoch=0, step=0, iter=0, batch=None))
        # Read the names from this instance's own store; the previous code went through the
        # notebook global `learner`, i.e. a different, previously created object.
        self.watch = list(self.store.value.__dict__.keys())[1:] # attributes to watch for changes
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}:\n{self}"
    # def __str__(self) -> str:
    #     r = f"{list(self.__dict__.keys())}\n"
    #     table = [["Model", "DataLoaders", "LossFn", "Optimizer"],[id(self.model), id(self.dls), id(self.loss_func), id(self.optimizer)]]
    #     return r+tabulate(table, headers='firstrow', tablefmt='grid')
    def __setattr__(self, k,v) -> None:
        '''Mirror writes to watched names into the training store, then set the attribute.'''
        # `watch` does not exist while `__init__` is still storing its arguments, hence the hasattr guard.
        if hasattr(self, 'watch') and k in self.watch:
            new_value = self.store.value.as_dict()|{k:v}
            self.store.set(TrainingState(**new_value))
        super().__setattr__(k,v)

In [None]:
learner = Learner(model=ms, dls=dls, loss_func=optax.softmax_cross_entropy_with_integer_labels, optimizer=os)
# `x` is not a TrainingState argument, so it is only set on the learner.
learner.x = 1
# `epoch` and `iter` are TrainingState arguments; presumably these writes are also
# mirrored into `learner.store` by `__setattr__` (depends on what `watch` contains).
learner.epoch = 1
learner.iter = 45

In [None]:
# Scratch: `a` is the generated NamedTuple class itself (no instance is created),
# so the assignment below sets a class attribute.
a = NamedTuple("a", [("n", int)])
a.n = 1

In [None]:
# Another class-level attribute; `print(a)` shows the class, not field values.
a.x = 3
print(a)

In [None]:
class Bunch:
    '''Bag of attributes built from keyword arguments.'''
    def __init__(self, **kw):
        # adopt the keyword dict itself as the instance namespace
        self.__dict__ = kw

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.__dict__!r})"

In [None]:
a = Bunch(n=2,x=4)

In [None]:
# New attributes can be added after construction.
a.y=10
a.iter = 20

In [None]:
a

In [None]:
a.__dict__|{'y':12}

In [None]:
# Same as `a.z = 30`, with the attribute name given as a string.
setattr(a, 'z', 30)
a

In [None]:
# class with_interceptor:
#     def __init__(self, nm): self.nm = nm
#     def __call__(self, f):
#         print('intercepting....')
#         print(f)
#         # print(f'args: {args}')
#         # print(f'kwargs: {kwargs}')
#         return f
#     def __setattr__(self, k,v) -> None:
#         print('i.__setattr__:', (k,v))
#         super().__setattr__(k,v)

# class DummyCls:

#     @with_interceptor
#     def dummy_fn(self, x):
#         x = 1
#         print('dummy')

# a = DummyCls()
# a.dummy_fn()

In [None]:
# def make_pretty(func):
#     def inner(*args, **kwargs):
#         print("I got decorated")
#         func(*args, **kwargs)
#     return inner


# @make_pretty
# def ordinary(x):
#     print("I am ordinary")


# ordinary(1)

In [None]:
class with_cbs:
    '''Decorator factory that brackets a method call with `before_`/`after_`/`cleanup_` callbacks.'''
    def __init__(self, nm): self.nm = nm  # event name used to build the callback names
    def __call__(self, f):
        def _f(o, *args, **kwargs):
            try:
                o.callback(f'before_{self.nm}')
                f(o, *args, **kwargs)  # the return value of `f` is discarded
                o.callback(f'after_{self.nm}')
            # The exception class is looked up by name in the module globals
            # (`Cancel<Nm>Exception`); when it is raised the remaining steps are skipped silently.
            except globals()[f'Cancel{self.nm.title()}Exception']: pass
            finally: o.callback(f'cleanup_{self.nm}')  # runs whether or not something was raised
        return _f

In [None]:
class with_interceptor:
    '''Decorator that records each in-flight call of the wrapped function on `self.stack`.'''
    def __init__(self, name='no name'):
        self.name = name
        self.stack = []
    def __call__(self, f):
        '''Wrap `f`; the wrapper pushes a call record, runs `f` and pops the record again.'''
        def _f(o, *args, **kwargs):
            self.stack = [*self.stack, {f.__name__: {'obj': o, 'f': f, 'args': args, 'kwargs': kwargs}}]
            try:
                f(o, *args, **kwargs)  # the return value of `f` is discarded
            finally:
                # Pop even when `f` raises; otherwise the record would stay on the stack.
                self.stack = self.stack[:-1]
        return _f

    def __setattr__(self, k,v) -> None:
        # Every attribute write (including the stack updates above) is echoed to stdout.
        print('make pretty intercepting:', (k,v))
        super().__setattr__(k,v)


In [None]:
from functools import wraps


def stored(f):
    '''Decorator: wrap the receiver in a temporary `Writable` with a printing callback while `f` runs.'''
    @wraps(f)
    def wrapper(o, *args, **kwds):
        store = Writable(o)
        u = store.subscribe(lambda x: print(f"stored_callback: {x}"))
        try:
            f(o, *args, **kwds)  # the return value of `f` is discarded
        finally:
            # Unsubscribe even when `f` raises, so the callback is not left registered.
            u()
    return wrapper

class Child:
    '''Scratch class used to exercise the `stored` decorator.'''

    @stored
    def do_something(self, x, y, *args, **kwargs):
        # Sets `x` and `y` on the instance; everything after that only touches local names.
        print('2')
        self.x = x
        self.y = y
        z = x + y
        def g(t):
            z = 3
            u = t
            def h(z):
                a = z
            h(u)
        g(z)

c = Child()
# Called twice with the same arguments; the extra positional and keyword arguments are unused by the method.
c.do_something(1,2,3, a=1, b=2)
c.do_something(1,2,3, a=1, b=2)

In [None]:
def stored(f):
    '''Decorator: save the result of `f` in the receiver's `__dict__` under the function's name.'''
    def _f(o, *args, **kwargs):
        result = f(o, *args, **kwargs)
        # the wrapper itself returns nothing; the result lives on the instance
        vars(o)[f.__name__] = result
    return _f

In [None]:
# NOTE(review): `ordinary` is only defined inside a commented-out cell, so this call
# has nothing to resolve to on a fresh kernel.
ordinary(1)

In [None]:
class dummyInterceptor:
    '''Pass-through decorator factory whose attribute writes are echoed to stdout.'''
    def __init__(self, nm):
        self.nm = nm

    def __call__(self, f):
        '''Wrap `f`; the wrapper forwards its arguments and discards the result.'''
        def _f(*args, **kwargs):
            f(*args, **kwargs)
        return _f

    def __setattr__(self, k,v) -> None:
        pair = (k, v)
        print('intercepting:', pair)
        super().__setattr__(k,v)

# def __call__(self, f):
#         def _f(o, *args, **kwargs):
#             try:
#                 o.callback(f'before_{self.nm}')
#                 f(o, *args, **kwargs)
#                 o.callback(f'after_{self.nm}')
#             except globals()[f'Cancel{self.nm.title()}Exception']: pass
#             finally: o.callback(f'cleanup_{self.nm}')
#         return _f


# NOTE(review): the decorator is applied without being called, so `dummyInterceptor.__init__`
# receives the function as `nm` and `dummy` ends up bound to a dummyInterceptor instance.
@dummyInterceptor
def dummy(x):
    x = 1
    def g():
        y = 2

In [None]:
# NOTE(review): called with no argument, although `dummy` declares one parameter and
# `dummyInterceptor.__call__` expects the function to wrap.
dummy()

In [None]:
import inspect

In [None]:
# NOTE(review): no notebook-level `f` is defined in the visible cells; confirm what this should inspect.
f.__code__.co_varnames

In [None]:
# NOTE(review): same expression as the previous cell; `f` is not defined at notebook level.
f.__code__.co_varnames

In [None]:
learner.__dict__

In [None]:
@fc.patch
def __getattr__(self:Learner, name):
    '''Resolve a fixed set of names to `self.callback` pre-bound with that name; reject the rest.'''
    if name not in ('epochs','epoch','iter','one_batch'):
        raise AttributeError(name)
    return partial(self.callback, name)

In [None]:
# NOTE(review): same definition as the earlier `with_cbs` cell; this one rebinds the name.
class with_cbs:
    '''Decorator factory that surrounds a call with `before_`, `after_` and `cleanup_` callbacks.'''
    def __init__(self, nm): self.nm = nm  # event name used to build the callback names
    def __call__(self, f):
        def _f(o, *args, **kwargs):
            try:
                o.callback(f'before_{self.nm}')
                f(o, *args, **kwargs)  # the return value of `f` is discarded
                o.callback(f'after_{self.nm}')
            # `Cancel<Nm>Exception` is resolved by name from the module globals and swallowed.
            except globals()[f'Cancel{self.nm.title()}Exception']: pass
            finally: o.callback(f'cleanup_{self.nm}')  # runs whether or not something was raised
        return _f

In [None]:
def callback(self, method_nm): 
    '''Forward to `run_cbs` with `self.cbs`, the method name and `self`.'''
    # NOTE(review): `run_cbs` is not defined in the visible cells, and this function is
    # not attached to any class here; confirm where both are meant to come from.
    run_cbs(self.cbs, method_nm, self)

In [None]:

#| hide
# Run nbdev's export step, which writes the cells marked for export into the library modules.
import nbdev; nbdev.nbdev_export()

In [None]:
#|hide

# print(state)
# class TrainingStore(Writable[TrainingState]):

# @fc.patch
# def fit(self:Learner, n_epochs, trnState: TrainingState):
#     "Fit the model for `n_epochs` using batches from `dls`"
#     trnState.emit(Event(id="before_fit", payload=None))
#     for epoch in range(n_epochs):
#         self.one_epoch(is_training=True, trnState=trnState)
#         self.one_epoch(is_training=False, trnState=trnState)
#         if (trnState.get().should_halt): break

# training = TrainingStore(TrainingState(epoch=0, step=0, batch_n=0, batch=None, metrics=None, last_event=None))
# u3 = training.subscribe(lambda x: print(f"3:\n {x}"))
# @fc.patch
# def fit(self:Learner, n_epochs, trnState: TrainingState):
#     "Fit the model for `n_epochs` using batches from `dls`"
#     trnState.emit(Event(id="before_fit", payload=None))
#     for epoch in range(n_epochs):
#         self.one_epoch(is_training=True, trnState=trnState)
#         self.one_epoch(is_training=False, trnState=trnState)
#         if (trnState.get().should_halt): break

# @fc.patch
# def one_epoch(self:Learner, is_training: bool, trnState: TrainingState):
#     a = 1
#     # print(f"one_epoch: is_training={is_training}")
#     # print(trnState)
#     # trnState._s_is_training = is_training
#     # self.dl = self.dls.train if is_training else self.dls.valid
#     # trnState.emit(Event(id=f"before_epoch", payload=trnState._s_epoch))
#     # for batch_n, batch in enumerate(self.dl):
#     #     trnState._s_batch_n, trnState._s_batch  = batch_n, batch
#     #     # self.one_batch(trnState=trnState)
#     #     if (trnState._s_should_halt): break
#     # trnState.emit(Event(id=f"after_epoch", payload=trnState._s_epoch))



# params, state, apply, _ = ms.get()
# rng = hk.PRNGSequence(42) # random number generator
# @jax.jit
# def _predict(params, state, key, batch) -> Tensor:
#     logits, new_state = apply(params, state, key, batch.input)
#     return jnp.argmax (logits, axis=-1), new_state
# key = next(rng)
# _predict(params, state, key, batch)
# @jax.jit
# def _evaluate(params, state, key, batch) -> Tensor:
#     preds, _ = _predict(params, state, key, batch)
#     return jnp.mean(preds == batch.target)
# from torch.utils.benchmark import Timer
# evTimer = Timer(stmt="_evaluate(params, state, key, batch)", globals=globals())
# evTimer.timeit(1000)
# rng = hk.PRNGSequence(jax.random.PRNGKey(42))

# def evaluate(model: ModelStore, batch: Batch) -> Tensor:
#     params, state, apply, _ = model.get()
#     key = next(rng)
#     return _evaluate(params, state, key, batch)
# evaluate(ms, batch)
# @jax.jit
# def _loss_fn(params, state, key, batch)-> jnp.ndarray:
#     targs = batch.target
#     preds, new_state = apply(params, state, key, batch.input)
#     # return the expectation of the loss wrt the distribution of the targets
#     return jnp.sum(optax.softmax_cross_entropy_with_integer_labels(preds, targs)/targs.shape[0]), new_state
# key = next(rng)
# loss, new_state = _loss_fn(params, state, key, batch)
# lfTimer = Timer(stmt="_loss_fn(params, state, key, batch)", globals=globals())
# lfTimer.timeit(1000)

# a = NamedTuple('A', [('a', int), ('b', int)])(1,2)
# b = NamedTuple('A', [('a', int), ('b', int)])(3,3)
# s1 = set(a._asdict().items())
# s2 = set(b._asdict().items())
# s1 ^ s2
# trnState = TrainingStore(TrainingState(epoch=0, step=0, batch_n=0, batch=None, metrics=None, last_event=None))
# logs = []
# def logger(x):
#     logs.append(x)
#     last = set((logs[-1])._asdict().items())
#     curr = set((x)._asdict().items())
#     print (last ^ curr)

# u4 = trnState.subscribe(lambda x: logger(x))
# def one_batch(self):
#     self.preds = self.model(self.batch[0])
#     self.loss = self.loss_func(self.preds, self.batch[1])
#     if self.model.training:
#         self.loss.backward()
#         self.opt.step()
#         self.opt.zero_grad()
# class Learner():
#     def __init__(self, model, dls, loss_func, lr, cbs, opt_func=optim.SGD): fc.store_attr()

#     def one_batch(self):
#         self.preds = self.model(self.batch[0])
#         self.loss = self.loss_func(self.preds, self.batch[1])
#         if self.model.training:
#             self.loss.backward()
#             self.opt.step()
#             self.opt.zero_grad()

#     def one_epoch(self, train):
#         self.model.train(train)
#         self.dl = self.dls.train if train else self.dls.valid
#         try:
#             self.callback('before_epoch')
#             for self.iter,self.batch in enumerate(self.dl):
#                 try:
#                     self.callback('before_batch')
#                     self.one_batch()
#                     self.callback('after_batch')
#                 except CancelBatchException: pass
#             self.callback('after_epoch')
#         except CancelEpochException: pass
    
#     def fit(self, n_epochs):
#         self.n_epochs = n_epochs
#         self.epochs = range(n_epochs)
#         self.opt = self.opt_func(self.model.parameters(), self.lr)
#         try:
#             self.callback('before_fit')
#             for self.epoch in self.epochs:
#                 self.one_epoch(True)
#                 self.one_epoch(False)
#             self.callback('after_fit')
#         except CancelFitException: pass

#     def callback(self, method_nm): run_cbs(self.cbs, method_nm, self)
# #|export
# class with_cbs:
#     def __init__(self, nm): self.nm = nm
#     def __call__(self, f):
#         def _f(o, *args, **kwargs):
#             try:
#                 o.callback(f'before_{self.nm}')
#                 f(o, *args, **kwargs)
#                 o.callback(f'after_{self.nm}')
#             except globals()[f'Cancel{self.nm.title()}Exception']: pass
#             finally: o.callback(f'cleanup_{self.nm}')
#         return _f





# rng = hk.PRNGSequence(jax.random.PRNGKey(42))
# params, state, apply, _ = ms.get()
# @jax.jit
# def _loss_fn(params, state, batch)-> Tuple[jnp.ndarray, PyTree]:
#     bs, *_ = batch.target.shape
#     logits, state = apply(params, state, next(rng), batch.input)
#     state = {'a':1, 'b':2}
#     return jnp.sum(optax.softmax_cross_entropy_with_integer_labels(logits, batch.target)/bs)

# def loss_fn(model: ModelStore, batch: Batch) -> float:
#     params, state, apply, _ = model.get()
#     loss_value =  _loss_fn(params, state, batch)
#     new_model = Model(**(m._asdict()|{'state': new_state}))
#     model.set(new_model)
#     return float(loss_value)

# loss_fn(ms, batch)
# ms
# from functools import partial
# rng = hk.PRNGSequence(jax.random.PRNGKey(42))

# def update(model: ModelStore, optimizer: OptimizerStore, batch: Batch)->None:
#     m = model.get()
#     o = optimizer.get()
#     f = partial(loss_fn)(model=model)
#     grads = jax.grad(loss_fn)(batch)
#     @jax.jit
#     def _update():
#         updates, new_optState = o.apply(grads, o.state)
#         new_model_params = optax.apply_updates(m.params, updates)
#         return new_model_params, new_optState
#     new_model_params, new_optState = _update()
#     new_model = Model(**(m._asdict()|{'params': new_model_params}))
#     new_optimizer = Optimizer(**(o._asdict()|{'state': new_optState}))
#     model.set(new_model)
#     optimizer.set(new_optimizer)
#     return None
# TODO: try jax.tree_util.Partial
# m = ms.get()
# o = os.get()
# f = partial(loss_fn, model=ms)
# grads = jax.grad(f)(batch)
# rng = hk.PRNGSequence(jax.random.PRNGKey(42))
# params, state, apply, _ = ms.get()
# def loss_fn():
#     loss_value, new_state =  _loss_fn(params, state, batch)
    
# grads = jax.grad(_loss_fn)(params, state, batch)
# grads
# update(ms, os, batch)

# rng = hk.PRNGSequence(jax.random.PRNGKey(42))

# def loss_fn(model: ModelStore, batch: Batch) -> float:
#     params, state, apply, _ = model.get()
#     @jax.jit
#     def _loss(params, state, batch)-> jnp.ndarray:
#         bs, *_ = batch.target.shape
#         logits, state = (apply)(params, state, next(rng), batch.input)
#         return jnp.sum(optax.softmax_cross_entropy_with_integer_labels(logits, batch.target)/bs), state
#     loss_value, new_state =  _loss(params, state, batch)
#     new_model = Model(**(m._asdict()|{'state': new_state}))
#     model.set(new_model)
#     return float(loss_value)

# loss_fn(ms, batch)
# def get_loss(loss_func, *args): return jax.jit(lambda params: loss_func(get_model(params), *args))
# mse_loss = get_loss(mse, xb,tb) 
# mse_loss, mse_loss(W)
# from torch.utils.benchmark import Timer
# jax_grad = Timer( stmt="jax.grad(mse_loss)", globals=globals())
# jax_grad.timeit(1000)
# class TrainingStore(Writable[TrainingState]):

#     def emit(self, event: Event):
#         self.set(self.value._replace(last_event=event))
#     # def __getattr__(self, name): # FIXME: there is a bug here that is still unresolved
#     #     if name[:3]=='_s_' : return getattr(self.value, name[3:])
#     #     else: return super().__getattr__(name)
#     # def __setattr__(self, name, value):
#     #     if name[:3]=='_s_' and hasattr(self.value, name[3:]):
#     #         self.set(self.value._replace(**{name[3:]: value}))
#     #     else: super().__setattr__(name, value)
#     def __repr__(self) -> str:
#         return f"{self.__class__.__name__}:\n{self}"
#     def __str__(self) -> str:
#         state = list(self.value._asdict().items())
#         cbs = [(f"{i}:", v) for i, v in enumerate(self.subscribers)]
#         table = list(itertools.zip_longest(list(zip(*state)),list(zip(*cbs))))
#         return tabulate(table, headers=['State', 'Callbacks'], tablefmt='grid')
# def __repr__(self) -> str:
#         return f"{self.__class__.__name__}:\n{self}"
#     def __str__(self) -> str:
#         state = list(self.value._asdict().items())
#         state_t = list(zip(*state))
#         cbs = [(f"{i}:", v) for i, v in enumerate(self.subscribers)]
#         cbs_t = list(zip(*cbs))
#         table = list(itertools.zip_longest(*state_t,*cbs_t))
#         return tabulate(table, headers=['','State','', 'Callbacks'], tablefmt='grid')
#     # @property
#     # def _(self):
#     #     """The store value."""
#     #     return self.value
#     # @_.setter
#     # def _(self, value: TrainingState):
#     #     self.set(value)