<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

## Setup

#### Data

As in `miniai`, we will be using the `FashionMNIST` dataset for demonstration. `Reax` is not intended to be a complete library; the `data` module is just a copy from [miniai]() to make it work.

In [5]:
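from functools import partial
import jax.numpy as jnp
import torch, torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader  # assumed: the standard PyTorch DataLoader
from reax.data import DataLoaders, Batch, Tensor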

In [6]:
XMEAN,XSTD, BATCH_SIZE, NUM_CLASSES = 0.28,0.35, 500, 10

tfm = transforms.Compose([transforms.PILToTensor(), 
                          transforms.Lambda(lambda x: x/255), transforms.Normalize(XMEAN, XSTD), 
                          transforms.Lambda(lambda x: torch.flatten(x))])
ds = partial(torchvision.datasets.FashionMNIST,root="data",download=True, transform = tfm)
train_ds, valid_ds = ds(train=True), ds(train=False)
tdl = DataLoader(train_ds, batch_size=BATCH_SIZE)
vdl = DataLoader(valid_ds, batch_size=BATCH_SIZE)
dls = DataLoaders(tdl, vdl)
batch = Batch(*map(jnp.array, next(iter(dls.train))))
batch

Batch(input=Array[500, 784] n=392000 x∈[-0.800, 2.057] μ=0.011 σ=1.006 gpu:0, target=Array[500] i32 x∈[0, 9] μ=4.402 σ=2.838 gpu:0)

:::{.callout-note}
Have you noticed how tensors are printed? This is [lovely-jax](https://xl0.github.io/lovely-jax/), the wonderful library that makes the JAX array representation more friendly. 
:::

## Model

The basic [Haiku](https://dm-haiku.readthedocs.io/) object to represent a model is a [TransformedWithState](https://dm-haiku.readthedocs.io/en/latest/api.html#transformedwithstate).  It represents a `function` or `module` that has been transformed by a `hk.transform` function.  Here we are using `hk.transform_with_state` which is the superset of the transform functions.  

State, in `Haiku` lingo, means everything that makes your original `Callable` not a pure function. It is the context or state. Some common `DNN` modules like `batch_norm` keep some `state` to perform their work. `State`, `Buffers` and `Context` are common names for this.
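
To make the idea of state concrete, here is a minimal sketch in plain Python (no Haiku involved; `running_mean_step` and its state layout are made up for illustration). A stateful computation becomes a pure function once its state is passed in and returned explicitly, which is roughly the shape of the `apply` function you get back from `hk.transform_with_state`.

```python
from typing import Dict, Tuple

State = Dict[str, float]

def running_mean_step(state: State, x: float) -> Tuple[float, State]:
    """Pure function: returns the updated mean and a NEW state, never mutates."""
    count = state["count"] + 1
    mean = state["mean"] + (x - state["mean"]) / count
    return mean, {"count": count, "mean": mean}

# The caller owns the state and threads it through every call.
state: State = {"count": 0, "mean": 0.0}
for x in [2.0, 4.0, 6.0]:
    mean, state = running_mean_step(state, x)

print(mean, state)  # 4.0 {'count': 3, 'mean': 4.0}
```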

In [7]:
def forward(x: jnp.ndarray) -> jnp.ndarray:
  return hk.nets.MLP(output_sizes=[50,NUM_CLASSES])(x) # todo: remove NUM_CLASSES dependency
network = hk.transform_with_state(forward)
type(network)

haiku._src.transform.TransformedWithState

#### Model class

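In `Reax`, a [`Model`](https://fredguth.github.io/reax/core.html#model) is an immutable object whose `params` and `state` fields are [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html), the nested containers (dicts, lists, tuples of arrays) that JAX knows how to traverse.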

In [9]:
class Model(NamedTuple):
    params: PyTree # the model's parameters, weights and biases
    state: PyTree  # the model auxiliary state, e.g. batchnorm buffers
    apply: ApplyFn # the model forward pass function
    input_shape: Tuple[int, ...] # the shape of the input, used to infer the model output shape

    rng = hk.PRNGSequence(42) # random number generator

    @staticmethod
    def from_haiku(
        transformed: hk.TransformedWithState,       # transformed haiku model
        x: Tensor                                   # example input (e.g. batch.input)
    ):
        ''' Create a Model from a Haiku Transformed object and an example input.'''
        init, apply = transformed
        params, state = jax.jit(init)(next(Model.rng), x)
        return Model(params=params, state=state, apply=apply, input_shape=x.shape)

In [1]:
#| echo: false
#| output: asis
show_doc(Model)

---

[source](https://github.com/fredguth/reax/blob/main/reax/core.py#L29){target="_blank" style="float:right; font-size:smaller"}

### Model

>      Model (params:Union[jax.Array,numpy.ndarray,Tuple[ForwardRef('PyTree'),..
>             .],List[ForwardRef('PyTree')],Dict[Hashable,ForwardRef('PyTree')],
>             Mapping[str,Mapping[str,jax.Array]],Iterable[ForwardRef('ArrayTree
>             ')],Mapping[Any,ForwardRef('ArrayTree')],NoneType], state:Union[ja
>             x.Array,numpy.ndarray,Tuple[ForwardRef('PyTree'),...],List[Forward
>             Ref('PyTree')],Dict[Hashable,ForwardRef('PyTree')],Mapping[str,Map
>             ping[str,jax.Array]],Iterable[ForwardRef('ArrayTree')],Mapping[Any
>             ,ForwardRef('ArrayTree')],NoneType], apply:Callable[...,Tuple[Unio
>             n[jax.Array,numpy.ndarray],Union[jax.Array,numpy.ndarray,Tuple[For
>             wardRef('PyTree'),...],List[ForwardRef('PyTree')],Dict[Hashable,Fo
>             rwardRef('PyTree')],Mapping[str,Mapping[str,jax.Array]],Iterable[F
>             orwardRef('ArrayTree')],Mapping[Any,ForwardRef('ArrayTree')],NoneT
>             ype]]], input_shape:Tuple[int,...])

In [10]:
m = Model.from_haiku(transformed=network, x=batch.input)
m

Model(params={'mlp/~/linear_0': {'b': Array[50] all_zeros gpu:0, 'w': Array[784, 50] n=39200 x∈[-0.071, 0.071] μ=0.000 σ=0.032 gpu:0}, 'mlp/~/linear_1': {'b': Array[10] all_zeros gpu:0 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 'w': Array[50, 10] n=500 x∈[-0.276, 0.270] μ=-0.001 σ=0.121 gpu:0}}, state={}, apply=<function transform_with_state.<locals>.apply_fn at 0x7f514c0ed280>, input_shape=(500, 784))

Let's keep ourselves sane and improve the model representation.

In [13]:
m = Model.from_haiku(transformed=network, x=batch.input)
m

Model:
+---------------------------------------------+---------+
| Params                                      | State   |
| mlp/~/linear_0:                             | {}      |
|   b: all_zeros                              |         |
|   w: x∈[-0.071, 0.071] μ=-9.844e-05 σ=0.032 |         |
| mlp/~/linear_1:                             |         |
|   b: all_zeros                              |         |
|   w: x∈[-0.275, 0.279] μ=-0.002 σ=0.123     |         |
+---------------------------------------------+---------+

In [15]:
print(m)

+--------------+-------------------------+-----------------+-------------+---------------+
| Input        | Module                  | Module params   | Output      |   Param count |
| f32[500,784] | mlp (MLP)               |                 | f32[500,10] |        39,760 |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,784] | mlp/~/linear_0 (Linear) | w: f32[784,50]  | f32[500,50] |        39,250 |
|              |  └ mlp (MLP)            | b: f32[50]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,50]  | mlp/~/linear_1 (Linear) | w: f32[50,10]   | f32[500,10] |           510 |
|              |  └ mlp (MLP)            | b: f32[10]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+
+---------------------------------------------+---------+
| Params                        

#### Model Reactivity (Model Store)

Ok, now we will start to play with reactivity.  In `fastai` (also in Keras, vanilla PyTorch, etc) there is the concept of `Callbacks`.  It is the way to be notified when something of interest happens. 

> Don't nudge me, let me __call you back__ when I have something for you!

In general, you will need callbacks only during training. After all, that is when your `things` change: the model, the hyperparameters, the metrics, etc.

The __fastai/miniai__ [`Learner`](https://fredguth.github.io/reax/core.html#learner) is an `Observable` that can hold multiple callbacks. Every callback keeps its state in the Learner object. You can have callbacks for metrics, for logging, for saving the training process... even callbacks that depend on other callbacks! That is why there is that ... shall I say... __ugly__ `order` property in the `Callback` class.

`Reax` is just an experiment in handling this reactivity in another way.  Maybe it will prove too bloated... or not. I decided to do it in `JAX/Haiku` to force a `functional programming` perspective.

The basic abstraction in `Reax` is the `store`: an observable that holds any value. We could have used [RxPy], which is an incredible package, but its superpowers may be too much for what we need. That is why I took inspiration from the `Svelte` JS framework to create `stores` (which became their own package, [Sveltish](https://fredguth.github.io/sveltish)).

A [`ModelStore`](https://fredguth.github.io/reax/core.html#modelstore) is just a [`Writable`](https://fredguth.github.io/reax/stores.html#writable) store that holds values of type [`Model`](https://fredguth.github.io/reax/core.html#model). 
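
To give a feeling for what a writable store does, here is a minimal sketch in plain Python. It is not the actual `Reax`/`Sveltish` implementation; `SketchWritable` is a hypothetical name. It follows the Svelte store contract: `subscribe` calls the callback once with the current value and returns a function that unsubscribes, and `set` notifies every subscriber.

```python
from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")

class SketchWritable(Generic[T]):
    """Hypothetical minimal store; not the actual reax/Sveltish code."""

    def __init__(self, initial_value: T) -> None:
        self._value = initial_value
        self._subscribers: List[Callable[[T], None]] = []

    def get(self) -> T:
        return self._value

    def set(self, new_value: T) -> None:
        """Replace the value and notify every subscriber."""
        self._value = new_value
        for callback in list(self._subscribers):
            callback(new_value)

    def subscribe(self, callback: Callable[[T], None]) -> Callable[[], None]:
        """Register a callback, call it once now, and return an unsubscriber."""
        self._subscribers.append(callback)
        callback(self._value)

        def unsubscribe() -> None:
            if callback in self._subscribers:
                self._subscribers.remove(callback)

        return unsubscribe

seen: List[int] = []
counter = SketchWritable(0)
unsubscribe = counter.subscribe(seen.append)  # fires immediately with 0
counter.set(1)                                # fires again with 1
unsubscribe()
counter.set(2)                                # nobody is listening anymore
print(seen)  # [0, 1]
```

Returning the unsubscriber from `subscribe` means the caller does not need to keep a reference to the callback just to remove it later.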

In [2]:
#| echo: false
#| output: asis
show_doc(ModelStore)

---

[source](https://github.com/fredguth/reax/blob/main/reax/core.py#L72){target="_blank" style="float:right; font-size:smaller"}

### ModelStore

>      ModelStore (initial_value:__main__.Model)

A Model store. Custom Writable store

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| initial_value | Model | Initial value of the store |
| **Returns** | **None** |  |

#### Improving the ModelStore representation

We can also improve its representation.

In [20]:
ms = ModelStore(m)
ms

ModelStore:
+---------------------------------------------+---------+-------------+
| Params                                      | State   | Callbacks   |
| mlp/~/linear_0:                             | {}      | []          |
|   b: all_zeros                              |         |             |
|   w: x∈[-0.071, 0.071] μ=-9.844e-05 σ=0.032 |         |             |
| mlp/~/linear_1:                             |         |             |
|   b: all_zeros                              |         |             |
|   w: x∈[-0.275, 0.279] μ=-0.002 σ=0.123     |         |             |
+---------------------------------------------+---------+-------------+

In [22]:
print(ms)

+--------------+-------------------------+-----------------+-------------+---------------+
| Input        | Module                  | Module params   | Output      |   Param count |
| f32[500,784] | mlp (MLP)               |                 | f32[500,10] |        39,760 |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,784] | mlp/~/linear_0 (Linear) | w: f32[784,50]  | f32[500,50] |        39,250 |
|              |  └ mlp (MLP)            | b: f32[50]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,50]  | mlp/~/linear_1 (Linear) | w: f32[50,10]   | f32[500,10] |           510 |
|              |  └ mlp (MLP)            | b: f32[10]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+
+---------------------------------------------+---------+-------------+
| Params          

A `callback` is any `Callable` that you pass to `subscribe`.

In [23]:
u1 = ms.subscribe(lambda x: print("1: callback 1"))

1: callback 1


A change in the store value triggers all callbacks subscribed to it.

In [24]:
m = ms.get()
ms.set(Model(**(m._asdict()|{"state": {'a': 1, 'b': 2}})))

1: callback 1
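
Because a `Model` is a `NamedTuple`, "changing" it really means building a new one. The cell above does this by merging dictionaries; the standard-library `_replace` method is an equivalent and slightly shorter way. A small sketch with a stand-in class (`TinyModel` is hypothetical and only mimics two of the fields):

```python
from typing import Any, Dict, NamedTuple

class TinyModel(NamedTuple):  # hypothetical stand-in for Model
    params: Dict[str, Any]
    state: Dict[str, Any]

old = TinyModel(params={"w": 1.0}, state={})

# _replace returns a NEW tuple; `old` is left untouched.
new = old._replace(state={"a": 1, "b": 2})

# Same result as the dict-merge idiom used in the cell above.
same = TinyModel(**(old._asdict() | {"state": {"a": 1, "b": 2}}))

print(new == same, old.state)  # True {}
```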


#### Optimizer

You can have different `stores` for different things.  For example, this is a simpler one to deal with the optimizer.

In [3]:
#| echo: false
#| output: asis
show_doc(Optimizer)

---

[source](https://github.com/fredguth/reax/blob/main/reax/core.py#L104){target="_blank" style="float:right; font-size:smaller"}

### Optimizer

>      Optimizer (state:Union[jax.Array,Iterable[ForwardRef('ArrayTree')],Mappin
>                 g[Any,ForwardRef('ArrayTree')]], apply:Callable)

By the way, we will use [Optax](https://optax.readthedocs.io/), which is a good companion for `Haiku`.

In [26]:
grad_tfm = optax.sgd(1e-3)
apply = grad_tfm.update
optState = grad_tfm.init(m.params) # you initialize the optimizer with the model params
optimizer = Optimizer(state=optState, apply=apply)
optimizer

Optimizer(state=(EmptyState(), EmptyState()), apply=<function chain.<locals>.update_fn at 0x7f512c6ab280>)

In [27]:
os = OptimizerStore(optimizer)  # note: this name shadows the stdlib `os` module
u2 = os.subscribe(lambda x: print(f"callback 2: {x}"))

callback 2: Optimizer(state=(EmptyState(), EmptyState()), apply=<function chain.<locals>.update_fn at 0x7f512c6ab280>)


In [28]:
grad_tf2 = optax.adam(1e-4)
optState2 = grad_tf2.init(m.params)
os.set(Optimizer(state=optState2, apply=grad_tf2.update))

callback 2: Optimizer(state=(ScaleByAdamState(count=Array i32 gpu:0 0, mu={'mlp/~/linear_0': {'b': Array[50] all_zeros gpu:0, 'w': Array[784, 50] all_zeros gpu:0}, 'mlp/~/linear_1': {'b': Array[10] all_zeros gpu:0 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 'w': Array[50, 10] all_zeros gpu:0}}, nu={'mlp/~/linear_0': {'b': Array[50] all_zeros gpu:0, 'w': Array[784, 50] all_zeros gpu:0}, 'mlp/~/linear_1': {'b': Array[10] all_zeros gpu:0 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 'w': Array[50, 10] all_zeros gpu:0}}), EmptyState()), apply=<function chain.<locals>.update_fn at 0x7f512c6abca0>)


Cleaning up... you should remember to unsubscribe when you are done with a store.

In [29]:
u1(), u2()

(None, None)

In [30]:
m = Model.from_haiku(transformed=network, x=batch.input)
ms = ModelStore(m)
u1 = ms.subscribe(lambda x: print(f"cb 1:\n{x}"))

cb 1:
+--------------+-------------------------+-----------------+-------------+---------------+
| Input        | Module                  | Module params   | Output      |   Param count |
| f32[500,784] | mlp (MLP)               |                 | f32[500,10] |        39,760 |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,784] | mlp/~/linear_0 (Linear) | w: f32[784,50]  | f32[500,50] |        39,250 |
|              |  └ mlp (MLP)            | b: f32[50]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+
| f32[500,50]  | mlp/~/linear_1 (Linear) | w: f32[50,10]   | f32[500,10] |           510 |
|              |  └ mlp (MLP)            | b: f32[10]      |             |               |
+--------------+-------------------------+-----------------+-------------+---------------+
+-----------------------------------------+---------+
| Params                      

## Training

Finally, we arrive at training, the `core` of the `core`  ```¯\_(ツ)_/¯```

Here is where we will most need callbacks.

#### Learner

Like in `fastai`, we create a [`Learner`](https://fredguth.github.io/reax/core.html#learner) class that will deal with the training. 

In [4]:
#| echo: false
#| output: asis
show_doc(Learner)

---

[source](https://github.com/fredguth/reax/blob/main/reax/core.py#L112){target="_blank" style="float:right; font-size:smaller"}

### Learner

>      Learner (model:__main__.ModelStore, dls:reax.data.DataLoaders, loss_func:
>               Callable[[Union[jax.Array,numpy.ndarray],Union[jax.Array,numpy.n
>               darray]],Union[jax.Array,numpy.ndarray]],
>               optimizer:reax.stores.Writable[__main__.Optimizer])

Basic class for handling the training loop.

In [32]:
learner = Learner(model=ms, dls=dls, loss_func=optax.softmax_cross_entropy_with_integer_labels, optimizer=os)
learner

Learner:
+-----------------+-----------------+-----------------+-----------------+
|           Model |     DataLoaders |          LossFn |       Optimizer |
| 139986614063552 | 139987146199584 | 139990720158448 | 139986614065904 |
+-----------------+-----------------+-----------------+-----------------+

The `Learner` itself is not a store, but it holds different stores for different aspects of the training.

We have a [`ModelStore`](https://fredguth.github.io/reax/core.html#modelstore), an `OptimizerStore`... the only thing missing is the most important thing we want to __observe__: the training loop itself. We need a [`TrainingStore`](https://fredguth.github.io/reax/core.html#trainingstore).

But first, let's examine what we need.  Let's take a look at the __training loop__:

#### The anatomy of a training loop

```python
# pseudo-code

def fit(epochs: int)->None:
    '''Train the model for a number of epochs.'''
    # before fit
    for epoch in range(epochs):
        # is_training
        one_epoch(dls.train) # train for one epoch
        # is_validating
        one_epoch(dls.valid) # validate for one epoch
        # should halt epochs?
    # after fit

def one_epoch(dl)->None:
    '''Train or validate for one epoch.'''
    # before epoch
    for batch_n, batch in enumerate(dl): 
        one_batch(batch_n, batch)
        # should halt batches?
    # after epoch

def one_batch(batch_n: int, batch: Batch)->None:
    '''Train or validate for one batch.'''
    # before batch
    predict(...) # preds
    evaluate(...)# loss
    if is_training: update_model(...)
    # after batch
```
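
The pseudo-code above can be turned into a runnable skeleton that does nothing but record the hook points. This is only a sketch to show where notifications could be emitted: `emit`, `events` and the placeholder loss are hypothetical and there is no model involved.

```python
from typing import Iterable, List, Sequence

events: List[str] = []

def emit(name: str) -> None:
    """Stand-in for notifying subscribers; here we only record the event."""
    events.append(name)

def one_batch(batch_n: int, batch: Sequence[float], is_training: bool) -> None:
    emit("before_batch")
    loss = sum(batch) / len(batch)  # placeholder for predict + evaluate
    if is_training:
        emit("update_model")        # placeholder for the optimizer step
    emit("after_batch")

def one_epoch(dl: Iterable[Sequence[float]], is_training: bool) -> None:
    emit("before_epoch")
    for batch_n, batch in enumerate(dl):
        one_batch(batch_n, batch, is_training)
    emit("after_epoch")

def fit(epochs: int, train_dl, valid_dl) -> None:
    emit("before_fit")
    for _ in range(epochs):
        one_epoch(train_dl, is_training=True)   # train for one epoch
        one_epoch(valid_dl, is_training=False)  # validate for one epoch
    emit("after_fit")

fit(1, train_dl=[[1.0, 2.0]], valid_dl=[[3.0]])
print(events)
```

Note that `update_model` only shows up in the training pass; the validation pass goes through the same `one_batch` with `is_training=False`.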

Our [`TrainingStore`](https://fredguth.github.io/reax/core.html#trainingstore) should tell us where we are in the training loop, along with the information relevant at that point.

>  I am `training`, `epoch` 5, `iteration` 345, after `evaluate` with a certain `current loss`.


It also seems that it should be a [`Readable`](https://fredguth.github.io/reax/stores.html#readable) store. After all, we don't want just any callback to be able to change information like:
`in which batch of which epoch am I?`

The one exception: we do want to be able to tell the [`TrainingStore`](https://fredguth.github.io/reax/core.html#trainingstore) to halt.

Let's start with:

In [5]:
#| echo: false
#| output: asis
show_doc(TrainingState)

---

[source](https://github.com/fredguth/reax/blob/main/reax/core.py#L124){target="_blank" style="float:right; font-size:smaller"}

### TrainingState

>      TrainingState (epochs:int, epoch:int, step:int, iter:int,
>                     batch:Optional[reax.data.Batch], last:Dict=None,
>                     is_running:bool=False, is_training:bool=False,
>                     is_validating:bool=False, should_halt:bool=False)

In [34]:
t = TrainingState(epochs=0, epoch=0, step=0, iter=0, batch=None)
t

TrainingState:
-------------  -----
epochs             0
epoch              0
step               0
iter               0
batch
last
is_running     False
is_training    False
is_validating  False
should_halt    False
-------------  -----

#### Training Store

In [6]:
#| echo: false
#| output: asis
show_doc(TrainingStore)

---

[source](https://github.com/fredguth/reax/blob/main/reax/core.py#L149){target="_blank" style="float:right; font-size:smaller"}

### TrainingStore

>      TrainingStore (initial_value:T, start:Notifier)

A store that keeps tracking of the training loop state

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| initial_value | T | initial value of the store |
| start | Notifier | function called when the first subscriber is added |
| **Returns** | **None** |  |
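
The `start` parameter suggests the Svelte-style readable pattern: subscribers cannot set the value, and the setter is handed only to a `start` function when the first subscriber arrives. The sketch below illustrates that pattern in plain Python. It is an assumption about the general idea, not the `Reax` source; `SketchReadable` and `producer` are hypothetical names.

```python
from typing import Callable, Dict, Generic, List, Optional, TypeVar

T = TypeVar("T")
Setter = Callable[[T], None]

class SketchReadable(Generic[T]):
    """Hypothetical read-only store: consumers can subscribe but not set."""

    def __init__(self, initial_value: T,
                 start: Callable[[Setter], Optional[Callable[[], None]]]) -> None:
        self._value = initial_value
        self._start = start
        self._stop: Optional[Callable[[], None]] = None
        self._subscribers: List[Callable[[T], None]] = []

    def _set(self, value: T) -> None:
        """Private setter: only ever handed to `start`."""
        self._value = value
        for callback in list(self._subscribers):
            callback(value)

    def subscribe(self, callback: Callable[[T], None]) -> Callable[[], None]:
        if not self._subscribers:            # first subscriber: start producing
            self._stop = self._start(self._set)
        self._subscribers.append(callback)
        callback(self._value)

        def unsubscribe() -> None:
            self._subscribers.remove(callback)
            if not self._subscribers and self._stop is not None:
                self._stop()                 # last subscriber left: clean up

        return unsubscribe

producer: Dict[str, Setter] = {}

def start(set_value: Setter) -> Callable[[], None]:
    producer["set"] = set_value              # the owner keeps the setter
    return producer.clear                    # called when nobody listens anymore

seen: List[str] = []
store = SketchReadable("idle", start)
unsubscribe = store.subscribe(seen.append)   # receives "idle"
producer["set"]("training")                  # only the owner can push a value
unsubscribe()
print(seen, producer)  # ['idle', 'training'] {}
```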

#### TrainingStore representation

In [37]:
# a = [("A", "B", "C"), (1,2,3)]
# b = [("D", "E", None), (4,5,None)]
a = [("A", 1), ["B", 2], ["C", 3]]
b = [["D", 4], ["E", 5]]
c = list(zip(*a))
d = list(zip(*b))
table = list(itertools.zip_longest(*c,*d))
print(tabulate(table, headers=['','H1','', "H2"], tablefmt='grid'))

+----+------+----+------+
|    |   H1 |    |   H2 |
| A  |    1 | D  |    4 |
+----+------+----+------+
| B  |    2 | E  |    5 |
+----+------+----+------+
| C  |    3 |    |      |
+----+------+----+------+


In [38]:
a = [["A", 1], ["B", 2], ["C", 3]]
b = []
c = list(zip(*a))
d = list(zip(*b))
table = list(itertools.zip_longest(*c,*d))
print(tabulate(table, headers=['','H1','', "H2"], tablefmt='grid'))

+----+------+
|    |   H1 |
| A  |    1 |
+----+------+
| B  |    2 |
+----+------+
| C  |    3 |
+----+------+


In [39]:
a = [('epoch', 0), ('step', 0), ('batch_n', 0), ('batch', None), ('metrics', None), ('last_event', None), ('is_training', False), ('should_halt', False)]
b = [('0:', lambda:None)]
c = list(zip(*a))
d = list(zip(*b))
table = list(itertools.zip_longest(*c,*d))
print(tabulate(table, headers=['','H1','', "H2"], tablefmt='grid'))

+-------------+-------+----+---------------------------------------+
|             |    H1 |    | H2                                    |
| epoch       |     0 | 0: | <function <lambda> at 0x7f514c0ed820> |
+-------------+-------+----+---------------------------------------+
| step        |     0 |    |                                       |
+-------------+-------+----+---------------------------------------+
| batch_n     |     0 |    |                                       |
+-------------+-------+----+---------------------------------------+
| batch       |       |    |                                       |
+-------------+-------+----+---------------------------------------+
| metrics     |       |    |                                       |
+-------------+-------+----+---------------------------------------+
| last_event  |       |    |                                       |
+-------------+-------+----+---------------------------------------+
| is_training | False |    |      

In [41]:
ts = TrainingStore(t, lambda x:None)
u4 = ts.subscribe(lambda x: print(f"callback 4:\n {x}"))

callback 4:
 -------------  -----
epochs             0
epoch              0
step               0
iter               0
batch
last
is_running     False
is_training    False
is_validating  False
should_halt    False
-------------  -----


In [42]:
print(ts)

+---------------+---------+----+---------------------------------------+
|               |   State |    | Calbacks                              |
| epochs        |       0 | 0: | <function <lambda> at 0x7f512c3d5af0> |
+---------------+---------+----+---------------------------------------+
| epoch         |       0 |    |                                       |
+---------------+---------+----+---------------------------------------+
| step          |       0 |    |                                       |
+---------------+---------+----+---------------------------------------+
| iter          |       0 |    |                                       |
+---------------+---------+----+---------------------------------------+
| batch         |         |    |                                       |
+---------------+---------+----+---------------------------------------+
| last          |         |    |                                       |
+---------------+---------+----+-------------------

In [43]:
unsubs = []
for i in range(12):
    u = ts.subscribe(lambda x, i=i: print(f"callback: {i}"))  # i=i binds the current value (avoids late binding)
    unsubs.append(u)
ts

callback: 0
callback: 1
callback: 2
callback: 3
callback: 4
callback: 5
callback: 6
callback: 7
callback: 8
callback: 9
callback: 10
callback: 11


TrainingStore:
+---------------+---------+-----+---------------------------------------+
|               |   State |     | Calbacks                              |
| epochs        |       0 | 0:  | <function <lambda> at 0x7f512c3f6430> |
+---------------+---------+-----+---------------------------------------+
| epoch         |       0 | 1:  | <function <lambda> at 0x7f512c3d5e50> |
+---------------+---------+-----+---------------------------------------+
| step          |       0 | 2:  | <function <lambda> at 0x7f512c3d5670> |
+---------------+---------+-----+---------------------------------------+
| iter          |       0 | 3:  | <function <lambda> at 0x7f512c3f6670> |
+---------------+---------+-----+---------------------------------------+
| batch         |         | 4:  | <function <lambda> at 0x7f512c3f60d0> |
+---------------+---------+-----+---------------------------------------+
| last          |         | 5:  | <function <lambda> at 0x7f512c3d5af0> |
+---------------+------

In [44]:
for u in unsubs: u()

In [45]:
a = NamedTuple("a", [("n", int)])
a.n = 1

In [54]:
a.x = 3
print(a)

<class '__main__.a'>


In [67]:
class Bunch:
    __init__ = lambda self, **kw: setattr(self, '__dict__', kw)
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.__dict__!r})"

In [68]:
a = Bunch(n=2,x=4)

In [71]:
a.y=10
a.iter = 20

In [72]:
a

Bunch({'n': 2, 'x': 4, 'y': 10, 'iter': 20})

In [113]:
a.__dict__|{'y':12}

{'n': 2, 'x': 4, 'y': 12, 'iter': 20, 'z': 30}

In [73]:
setattr(a, 'z', 30)
a

Bunch({'n': 2, 'x': 4, 'y': 10, 'iter': 20, 'z': 30})

In [163]:
class with_interceptor:
    def __init__(self, nm): self.nm = nm
    def __call__(self, f):
        print('intercepting....')
        print(f)
        # print(f'args: {args}')
        # print(f'kwargs: {kwargs}')
        return f
    def __setattr__(self, k,v) -> None:
        print('i.__setattr__:', (k,v))
        super().__setattr__(k,v)

class DummyCls:

    @with_interceptor
    def dummy_fn(self, x):
        x = 1
        print('dummy')

a = DummyCls()
a.dummy_fn()

i.__setattr__: ('nm', <function DummyCls.dummy_fn at 0x7f50e4388160>)


TypeError: __call__() missing 1 required positional argument: 'f'
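
The `TypeError` above comes from applying `@with_interceptor` without parentheses: the function itself is passed to `__init__` as `nm`, so `dummy_fn` becomes a `with_interceptor` instance, and calling it reaches `__call__` with no `f`. A class used as a decorator with an argument has to be applied as `@name("something")`. A minimal sketch of that form (`intercept` and `log` are hypothetical names):

```python
import functools
from typing import Callable, List

log: List[str] = []

class intercept:
    """Decorator factory: use it as `@intercept("name")`, with parentheses."""

    def __init__(self, nm: str) -> None:
        self.nm = nm

    def __call__(self, f: Callable) -> Callable:
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            log.append(f"before_{self.nm}")
            result = f(*args, **kwargs)
            log.append(f"after_{self.nm}")
            return result
        return wrapper  # a plain function, so it still binds as a method

class Dummy:
    @intercept("dummy")
    def dummy_fn(self, x: int) -> int:
        return x + 1

print(Dummy().dummy_fn(1), log)  # 2 ['before_dummy', 'after_dummy']
```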

In [164]:
def make_pretty(func):
    def inner(*args, **kwargs):
        print("I got decorated")
        func(*args, **kwargs)
    return inner


@make_pretty
def ordinary(x):
    print("I am ordinary")


ordinary(1)

I got decorated
I am ordinary


In [182]:
class make_pretty:
    def __init__(self, f):self.f = f

    def __call__(self, *args, **kwargs):
        self.args = list(args)
        self.kwargs = dict(kwargs)
        print("I got decorated")
        self.f(*args, **kwargs)

    def __setattr__(self, k,v) -> None:
        print('intercepting:', (k,v))
        super().__setattr__(k,v)


#     def inner(*args, **kwargs):
#         print("I got decorated")
#         func(*args, **kwargs)
#     return inner


@make_pretty
def ordinary(x):
    x = 3
    print("I am ordinary ", x)

intercepting: ('f', <function ordinary at 0x7f50e43c0310>)


In [183]:
ordinary(1)

intercepting: ('args', [1])
intercepting: ('kwargs', {})
I got decorated
I am ordinary  3


In [155]:
class dummyInterceptor:
    def __init__(self, nm): self.nm = nm
    def __call__(self, f):
        def _f(*args, **kwargs):
            f(*args, **kwargs)
        return _f

    def __setattr__(self, k,v) -> None:
        print('intercepting:', (k,v))
        super().__setattr__(k,v)

# def __call__(self, f):
#         def _f(o, *args, **kwargs):
#             try:
#                 o.callback(f'before_{self.nm}')
#                 f(o, *args, **kwargs)
#                 o.callback(f'after_{self.nm}')
#             except globals()[f'Cancel{self.nm.title()}Exception']: pass
#             finally: o.callback(f'cleanup_{self.nm}')
#         return _f


@dummyInterceptor
def dummy(x):
    x = 1
    def g():
        y = 2

intercepting: ('nm', <function dummy at 0x7f50e4412c10>)


In [156]:
dummy()

TypeError: __call__() missing 1 required positional argument: 'f'

In [75]:
import inspect

In [117]:
f.__code__.co_varnames

('x', 'y', 'z', 'a', 'g', 'c', 'd')

In [125]:
learner.__dict__

{'__stored_args__': {'model': ModelStore:
  +-----------------------------------------+---------+--------------------------------------------+
  | Params                                  | State   | Callbacks                                  |
  | mlp/~/linear_0:                         | {}      | - 0: <function <lambda> at 0x7f512c6abf70> |
  |   b: all_zeros                          |         |                                            |
  |   w: x∈[-0.071, 0.071] μ=-0.000 σ=0.032 |         |                                            |
  | mlp/~/linear_1:                         |         |                                            |
  |   b: all_zeros                          |         |                                            |
  |   w: x∈[-0.280, 0.278] μ=-0.008 σ=0.123 |         |                                            |
  +-----------------------------------------+---------+--------------------------------------------+,
  'dls': <reax.data.DataLoaders at 0x7f514c1f522

'__main__'

In [111]:
@fc.patch
def __getattr__(self:Learner, name):
    if name in ('epochs','epoch','iter','one_batch'):
        return partial(self.callback, name)
    raise AttributeError(name)

AttributeError: 'Learner' object has no attribute 'getattrs'

In [None]:
class with_cbs:
    def __init__(self, nm): self.nm = nm
    def __call__(self, f):
        def _f(o, *args, **kwargs):
            try:
                o.callback(f'before_{self.nm}')
                f(o, *args, **kwargs)
                o.callback(f'after_{self.nm}')
            except globals()[f'Cancel{self.nm.title()}Exception']: pass
            finally: o.callback(f'cleanup_{self.nm}')
        return _f

In [None]:
def callback(self, method_nm): 
    run_cbs(self.cbs, method_nm, self)
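
`callback` relies on a `run_cbs` helper that is not shown here. Below is a minimal sketch of what such a helper could look like, assuming callbacks are objects with an `order` attribute and optional methods named after the event. `RecorderCB` and `TinyLearner` are hypothetical and exist only to exercise it.

```python
from operator import attrgetter
from typing import List

def run_cbs(cbs, method_nm: str, learn=None) -> None:
    """Call `method_nm` on every callback that defines it, sorted by `order`."""
    for cb in sorted(cbs, key=attrgetter("order")):
        method = getattr(cb, method_nm, None)
        if method is not None:
            method(learn)

class RecorderCB:
    """Hypothetical callback that only reacts to `before_fit`."""

    def __init__(self, name: str, order: int) -> None:
        self.name, self.order = name, order

    def before_fit(self, learn) -> None:
        learn.log.append(f"{self.name}:before_fit")

class TinyLearner:
    """Hypothetical learner holding callbacks and a log."""

    def __init__(self, cbs) -> None:
        self.cbs = cbs
        self.log: List[str] = []

    def callback(self, method_nm: str) -> None:
        run_cbs(self.cbs, method_nm, self)

learn = TinyLearner([RecorderCB("late", order=1), RecorderCB("early", order=0)])
learn.callback("before_fit")
learn.callback("after_fit")  # no callback defines it, so nothing happens
print(learn.log)  # ['early:before_fit', 'late:before_fit']
```

This is where the `order` property earns its keep: the callbacks run sorted by `order`, regardless of the order in which they were registered.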