In [None]:
#| default_exp activations

In [None]:
#| export
from __future__ import annotations
import random, math, torch, numpy as np, matplotlib.pyplot as plt
import fastcore.all as fc
from functools import partial

from fastai_course.datasets import *
from fastai_course.learner import *
from fastai_course.conv import *

In [None]:
import torch.nn.functional as F,matplotlib as mpl
from pathlib import Path
from operator import attrgetter,itemgetter
from contextlib import contextmanager

from torch import tensor,nn,optim
import torchvision.transforms.functional as TF
from datasets import load_dataset

from fastcore.test import test_close

torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
mpl.rcParams['figure.constrained_layout.use'] = True

import logging
logging.disable(logging.WARNING)

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
#|export
def set_seed(seed, deterministic=False):
    """Make runs repeatable by seeding torch, `random` and NumPy with the same `seed`.

    `deterministic` is forwarded to `torch.use_deterministic_algorithms` before any seeding happens.
    """
    torch.use_deterministic_algorithms(deterministic)
    # Seed each generator in turn: torch first, then the stdlib, then NumPy.
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

In [None]:
# Names of the input and target fields in each dataset batch.
x, y = 'image', 'label'
name = "fashion_mnist"
bs = 1024
dsd = load_dataset(name)

@inplace
def transformi(b):
    "Replace every item in the `x` field of batch `b` with `TF.to_tensor(item)`."
    b[x] = [TF.to_tensor(img) for img in b[x]]

# Apply the transform through `with_transform`, then build the loaders used below.
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_dd(tds, bs, num_workers=4)
dt = dls.train

In [None]:
# Pull one training batch and show the shape of its first element.
next(iter(dt))[0].shape

In [None]:
def cnn_layers():
    """Return a fresh list of layers: five `conv` blocks (1->8->16->32->64->10) followed by `nn.Flatten`.

    The first block uses `ks=5`; the last one is built with `act=False`.
    """
    # Blocks are created in network order so parameter initialisation consumes the RNG in the same sequence.
    stem = conv(1, 8, ks=5)
    widths = (8, 16, 32, 64)
    body = [conv(ni, nf) for ni, nf in zip(widths, widths[1:])]
    head = conv(64, 10, act=False)
    return [stem, *body, head, nn.Flatten()]

In [None]:
from torcheval.metrics import MulticlassAccuracy

In [None]:
# Callbacks shared by every training run in this notebook; `fit` below appends any extra ones to `cbs`.
metrics = MetricsCB(accuracy=MulticlassAccuracy())
cbs = [TrainerCB(), DeviceCB(), metrics, ProgressCB(plot=True)]

In [None]:
# Baseline run: a plain nn.Sequential over `cnn_layers()`, trained with cross-entropy at lr=0.6.
model = nn.Sequential(*cnn_layers())
learn = Learner(model, dls, loss_func=F.cross_entropy, lr=0.6, cbs=cbs)
learn.fit(1)

In [None]:
class SequentialModel(nn.Module):
    """Apply `layers` in order while recording the mean and std of every layer's output.

    `act_means[i]` and `act_stds[i]` each gain one entry per call to `forward` for layer `i`.
    Iterating the model yields its layers.
    """
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        # One history list per layer, filled in by `forward`.
        self.act_means = [[] for _ in layers]
        self.act_stds = [[] for _ in layers]

    def forward(self, x):
        for i, l in enumerate(self.layers):
            x = l(x)
            # Convert once and reuse for both statistics rather than calling `to_cpu` twice per layer.
            acts = to_cpu(x)
            self.act_means[i].append(acts.mean())
            self.act_stds[i].append(acts.std())
        return x

    def __iter__(self):
        return iter(self.layers)

In [None]:
set_seed(1)

model = SequentialModel(*cnn_layers())
# Grab one batch; named `xb, yb` so the builtin `input` is not shadowed for the rest of the session.
xb, yb = next(iter(dt))

In [None]:
# The model has not been called yet, so every per-layer list is still empty here.
model.act_means

In [None]:
def fit(model, epochs=1, xtra_cbs=None):
    """Train `model` on `dls` for `epochs` epochs and return the `Learner` that was used.

    Uses cross-entropy loss at lr=0.6 with the notebook-level `cbs`, extended by `xtra_cbs` when given.
    """
    all_cbs = cbs + fc.L(xtra_cbs)
    learner = Learner(model, dls, loss_func=F.cross_entropy, lr=0.6, cbs=all_cbs)
    learner.fit(epochs)
    return learner

In [None]:
fit(model, epochs=1)

In [None]:
len(model.act_means[0])

In [None]:
# One curve per layer: mean activation at each recorded forward pass.
for l in model.act_means: plt.plot(l)
# Size the legend from the data: there is one history per layer (Flatten included), not a fixed 5.
plt.legend(range(len(model.act_means)));

In [None]:
# One curve per layer: activation std at each recorded forward pass.
for l in model.act_stds: plt.plot(l)
# Size the legend from the data: there is one history per layer (Flatten included), not a fixed 5.
plt.legend(range(len(model.act_stds)));

In [None]:
model = nn.Sequential(*cnn_layers())

In [None]:
# Module-level histories, one list per top-level layer of `model`; filled by the `append_stats` hook below.
act_means = [[] for _ in model]
act_std = [[] for _ in model]
act_means

In [None]:
def append_stats(i, mod, inp, out):
    """Forward-hook body: record the mean and std of `out` under layer index `i`.

    Appends to the module-level `act_means[i]` and `act_std[i]` lists. `i` is expected to be
    bound in advance (e.g. with `partial`); `mod` and `inp` are ignored.
    """
    # Convert once and reuse for both statistics rather than calling `to_cpu` twice.
    acts = to_cpu(out)
    act_means[i].append(acts.mean())
    act_std[i].append(acts.std())

In [None]:
# Attach `append_stats` to every top-level layer, binding the layer index as its first argument.
# The handles returned by `register_forward_hook` are discarded, so these hooks are never removed here.
for i,m in enumerate(model):
    m.register_forward_hook(partial(append_stats, i))

In [None]:
fit(model)

In [None]:
len(act_means[0])

### Hook class

In [None]:
#|export
class Hook:
    """Register `f` as a forward hook on module `m` and keep the handle so it can be removed.

    `f` is called as `f(hook, module, inp, outp)`: this `Hook` instance is bound as the first
    argument, which lets `f` store state on it. The hook is also removed when the object is
    garbage-collected.
    """
    def __init__(self, m, f):
        self.hook = m.register_forward_hook(partial(f, self))

    def remove(self):
        "Detach the forward hook; a no-op if registration never completed."
        # `__del__` still runs when `__init__` raised before `self.hook` was assigned,
        # so look the handle up defensively instead of assuming it exists.
        handle = getattr(self, 'hook', None)
        if handle is not None: handle.remove()
    def __del__(self): self.remove()

In [None]:
def append_stats(hook, mod, inp, outp):
    """Forward-hook body: append the mean and std of `outp` to `hook.stats`.

    `hook.stats` is created as a `(means, stds)` pair of lists the first time the hook fires;
    `mod` and `inp` are ignored.
    """
    if not hasattr(hook, 'stats'):
        hook.stats = ([], [])
    acts = to_cpu(outp)
    means, stds = hook.stats[0], hook.stats[1]
    means.append(acts.mean())
    stds.append(acts.std())

In [None]:
model = nn.Sequential(*cnn_layers())

In [None]:
# Hook only the first five layers (`model[:5]`); the sixth layer, nn.Flatten, is left out.
hooks = [Hook(l, append_stats) for l in model[:5].children()]
hooks

In [None]:
fit(model)

In [None]:
# Mean activation of each hooked layer across the recorded forward passes.
for h in hooks:
    plt.plot(h.stats[0])
# Legend sized from the hooks; the trailing `;` keeps the Legend repr out of the cell output.
plt.legend(range(len(hooks)));

In [None]:
# Activation std of each hooked layer across the recorded forward passes.
for h in hooks:
    plt.plot(h.stats[1])
# Legend sized from the hooks; the trailing `;` keeps the Legend repr out of the cell output.
plt.legend(range(len(hooks)));

In [None]:
# Detach every hook so later forward passes of this model no longer record statistics.
for h in hooks:
    h.remove()

In [None]:
class DummyList(list):
    """A `list` that prints a message on every `del obj[i]` before deleting the item(s)."""
    def __delitem__(self, i):
        message = f"Deleting {i}"
        print(message)
        list.__delitem__(self, i)

In [None]:
dml = DummyList([1, 2, 3])
dml

In [None]:
del(dml[2])

In [None]:
#| export
class Hooks(list):
    """A list of `Hook`s, one per module in `ms`, all calling `f`.

    Usable as a context manager: every hook is removed on exit. Hooks are also removed when the
    container is garbage-collected, and when items are deleted with `del hooks[i]`.
    """
    def __init__(self, ms, f):
        super().__init__([Hook(m, f) for m in ms])
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()
    def __del__(self): self.remove()
    def remove(self):
        "Detach every hook currently in the list (the list itself is left unchanged)."
        for h in self: h.remove()
    def __delitem__(self, i):
        # An int index yields one hook, but a slice yields a plain list of hooks; calling
        # `.remove()` on that list would be `list.remove` with no argument and raise TypeError.
        doomed = self[i] if isinstance(i, slice) else [self[i]]
        for h in doomed: h.remove()
        super().__delitem__(i)

In [None]:
model = nn.Sequential(*cnn_layers())

In [None]:
# Hook every top-level layer for the duration of the block; the hooks are removed on exit.
with Hooks(model, append_stats) as hooks:
    fit(model)
    # Left panel: means (stats[0]); right panel: stds (stats[1]); one curve per hooked layer.
    fig,axs = plt.subplots(1,2, figsize=(10,4))
    for h in hooks:
        for i in 0,1:
            axs[i].plot(h.stats[i])
    
    # NOTE(review): plt.legend targets the current axes, so likely only one panel gets a legend — confirm.
    plt.legend(range(6))

In [None]:
# Built outside a `with` block, so these hooks stay registered until `remove()` runs or the object is deleted.
hooks = Hooks(model, append_stats)
len(hooks), hooks[0].hook

In [None]:
#|export
class HooksCallback(Callback):
    """Learner callback that attaches `hookfunc` as a forward hook for the duration of a fit.

    Parameters:
      hookfunc:   called as `hookfunc(hook, module, inp, outp)` when a hooked module runs and the
                  current phase is enabled.
      mod_filter: predicate used to pick modules from `learn.model.modules()` when `mods` is not given.
      on_train:   record while `learn.training` is true.
      on_valid:   record while `learn.training` is false.
      mods:       explicit modules to hook; when truthy, `mod_filter` is not consulted.

    Iterating the callback (or taking its `len`) delegates to the `Hooks` created in `before_fit`.
    """
    def __init__(self, hookfunc, mod_filter=fc.noop, on_train=True, on_valid=False, mods=None):
        fc.store_attr()
        super().__init__()

    def before_fit(self, learn):
        "Create the hooks on either the explicit `mods` or the filtered modules of `learn.model`."
        if self.mods: mods = self.mods
        else: mods = fc.filter_ex(learn.model.modules(), self.mod_filter)
        self.hooks = Hooks(mods, partial(self._hookfunc, learn))

    def _hookfunc(self, learn, *args, **kwargs):
        # Forward to the user's function only in the phases that were asked for.
        if (self.on_train and learn.training) or (self.on_valid and not learn.training):
            self.hookfunc(*args, **kwargs)

    def after_fit(self, learn):
        "Plot `stats[0]` and `stats[1]` of every hook side by side, then detach the hooks."
        # Detach in `finally`: if plotting raises (e.g. a hook never fired and so has no `stats`),
        # the hooks must not stay registered on the model.
        try:
            _,axs = plt.subplots(1, 2, figsize=(10,4))
            for h in self.hooks:
                for i in 0,1:
                    axs[i].plot(h.stats[i])
        finally:
            self.hooks.remove()

    def __iter__(self): return iter(self.hooks)
    def __len__(self): return len(self.hooks)

In [None]:
fc.risinstance(nn.Conv2d)

In [None]:
fc.filter_ex(model.modules(), fc.risinstance(nn.Conv2d))

In [None]:
hc = HooksCallback(append_stats, mod_filter=fc.risinstance(nn.Conv2d))

In [None]:
# Fresh model; `hc` creates its hooks in `before_fit` and plots/removes them in `after_fit`.
model = nn.Sequential(*cnn_layers())
fit(model, xtra_cbs=[hc]);

In [None]:
#|export
def append_stats(hook, mod, inp, outp):
    """Forward-hook body: append the mean, std and an absolute-value histogram of `outp` to `hook.stats`.

    `hook.stats` is created as a `(means, stds, hists)` triple of lists the first time the hook
    fires; `mod` and `inp` are ignored.
    """
    try: stats = hook.stats
    except AttributeError:
        stats = ([], [], [])
        hook.stats = stats
    acts = to_cpu(outp)
    stats[0].append(acts.mean())
    stats[1].append(acts.std())
    # 40-bin histogram of |activation| over the fixed range [0, 10].
    hist = acts.abs().histc(40, 0, 10)
    stats[2].append(hist)

In [None]:
# Rebuild the callback so it captures the three-series `append_stats` defined just above
# (the earlier `hc` still holds the previous two-series function).
hc = HooksCallback(append_stats, mod_filter=fc.risinstance(nn.Conv2d))
model = nn.Sequential(*cnn_layers())
fit(model, xtra_cbs=[hc]);

In [None]:
hc.hooks[0].stats[2][0].shape, len(hc.hooks[0].stats[2]), len(hc.hooks[0].stats[0])

In [None]:
len(hc.hooks[0].stats)

In [None]:
# Number of training batches, counted by materialising the whole loader.
# NOTE(review): `len(dt)` may give the same number without loading any data — confirm `dt` supports `len`.
len(list(iter(dt)))

In [None]:
len(list(iter(dls.valid)))

In [None]:
torch.stack(hc.hooks[0].stats[2]).t().float().log1p().shape

In [None]:
len(hc.hooks)

In [None]:
#|export
def get_hist(h):
    """Return the histograms stored in `h.stats[2]` as one float matrix with `log1p` applied.

    The stacked histograms are transposed, so each column of the result is one recorded histogram.
    """
    stacked = torch.stack(h.stats[2])
    per_bin = stacked.t().float()
    return torch.log1p(per_bin)

In [None]:
len(hc)

In [None]:
# Check how many axes `get_grid` hands back when asked for `len(hc)` plots with nrows=2.
fig,axes = get_grid(len(hc), figsize=(11,5), nrows=2)
len(axes.flat)

In [None]:
# One panel per hooked module. `get_hist` yields a (bins x recorded passes) matrix, passed to
# `show_image` with origin='lower'.
fig,axes = get_grid(len(hc), figsize=(11,5))
for ax,h in zip(axes.flat, hc):
    show_image(get_hist(h), ax, origin='lower')

In [None]:
import nbdev; nbdev.nbdev_export()