In [None]:
#|default_exp init

### Initialization

Temporary version to get things running. Much of this content still needs to be moved into a separate tutorial notebook.

In [None]:
#|export
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
import sys,gc,traceback
import fastcore.all as fc
from collections.abc import Mapping
from pathlib import Path
from operator import attrgetter,itemgetter
from functools import partial
from copy import copy
from contextlib import contextmanager

import torchvision.transforms.functional as TF,torch.nn.functional as F
from torch import tensor,nn,optim
from torch.utils.data import DataLoader,default_collate
from torch.nn import init
from torcheval.metrics import MulticlassAccuracy
from datasets import load_dataset,load_dataset_builder

from miniai.datasets import *
from miniai.utils import set_seed, def_device
from miniai.callbacks import *
from miniai.learner import *
from miniai.activations import *
from miniai.layers import GeneralRelu
from miniai.model_blocks import conv

In [None]:
import logging
from fastcore.test import test_close

# Suppress log records at WARNING level and below (e.g. noisy library warnings).
logging.disable(logging.WARNING)

# Compact tensor reprs: 2 decimals, wide lines, no scientific notation.
torch.set_printoptions(sci_mode=False, precision=2, linewidth=140)

# Seeding. set_seed(42) runs last; it presumably reseeds torch as well, which would
# make the manual_seed(1) call redundant — TODO confirm and keep only one.
torch.manual_seed(1)
set_seed(42)

In [None]:
# Avoid Apple's MPS backend: fall back to the CPU when MPS was auto-selected.
# Comment this cell out to train on MPS.
# NOTE(review): this only rebinds the notebook's global `def_device`; defaults captured
# inside miniai (e.g. DeviceCB's default device) may still be MPS — pass the device explicitly.
def_device = 'cpu' if def_device == 'mps' else def_device

In [None]:
# Column names of the inputs and the targets in the Hugging Face dataset.
xl = 'image'
yl = 'label'
name = "fashion_mnist"
# With no split requested, load_dataset returns all splits (a DatasetDict).
dsd = load_dataset(name)

@inplace
def transformi(b):
    """Convert every image in the batch's ``xl`` column to a tensor with ``TF.to_tensor``.

    Mutates ``b`` in place; the ``inplace`` decorator (from miniai) presumably returns
    the mutated batch so this can be used as a ``with_transform`` function.
    """
    imgs = b[xl]
    b[xl] = list(map(TF.to_tensor, imgs))

bs = 1024  # batch size
# Attach the tensor conversion; it is applied lazily whenever rows are accessed.
tds = dsd.with_transform(transformi)

dls = DataLoaders.from_dd(tds, bs, num_workers=4)
dt = dls.train
# Pull one training batch to sanity-check the input shape and a few labels.
xb, yb = next(iter(dt))
xb.shape, yb[:10]

In [None]:
def get_model():
    """Build the small Fashion-MNIST CNN and move it to ``def_device``.

    Four ``conv`` blocks widen the channels 1 -> 8 -> 16 -> 32 -> 64, then a final
    ``conv`` without activation maps to 10 class channels, which are flattened into logits.
    """
    widths = [1, 8, 16, 32, 64]
    # Layers are constructed in the same order as before, so weight init draws are unchanged.
    layers = [conv(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
    layers.append(conv(64, 10, act=False))
    layers.append(nn.Flatten())
    return nn.Sequential(*layers).to(def_device)

In [None]:
# Learning-rate finder sweep starting at 1e-2 (gamma presumably is the per-step lr multiplier).
# device=def_device makes DeviceCB honour this notebook's device override; a bare DeviceCB()
# would use the default captured inside miniai (possibly MPS) and move the model there.
MomentumLearner(get_model(), dls, F.cross_entropy, cbs=[DeviceCB(device=def_device)]).lr_find(gamma=1.1, start_lr=1.e-2)

In [None]:
metrics = MetricsCB(accuracy=MulticlassAccuracy())
# Record activation statistics for modules that are nn.ReLU instances
# (fc.risinstance(nn.ReLU) builds an isinstance predicate).
astats = ActivationStatsCB(fc.risinstance(nn.ReLU))
# Pass the device explicitly so the notebook's def_device override is respected.
cbs = [DeviceCB(device=def_device), metrics, ProgressCB(plot=True), astats]
# No optimizer argument: torch.optim has no `ActivationStatsCB`, so the learner's
# default optimizer is used.
learn = TrainLearner(get_model(), dls, F.cross_entropy, lr=0.2, cbs=cbs)

In [None]:
# Train the learner configured above; the argument is presumably the number of epochs.
learn.fit(1)

In [None]:
# Visualise the activation statistics collected by `astats` during training
# (presumably a histogram heatmap per hooked layer).
astats.color_dim()

In [None]:
# Plot the activation statistics collected by `astats` (presumably per-layer means and stds).
astats.plot_stats()

In [None]:
#|export
class BatchTransformCB(Callback):
    """Callback that replaces ``learn.batch`` with ``tfm(learn.batch)`` before each batch.

    Args:
        tfm: callable taking a batch and returning the transformed batch.
        on_train: apply the transform to training batches.
        on_val: apply the transform to validation (non-training) batches.
    """
    def __init__(self, tfm, on_train=True, on_val=True): fc.store_attr()

    def before_batch(self, learn):
        # Pick the flag for the current phase, then transform only if it is enabled.
        enabled = self.on_train if learn.training else self.on_val
        if enabled:
            learn.batch = self.tfm(learn.batch)

In [None]:
def _norm(x, mean=None, std=None):
    """Normalise the inputs of an ``(inputs, targets)`` batch; targets are returned unchanged.

    Args:
        x: a batch indexable as ``x[0]`` (inputs) and ``x[1]`` (targets).
        mean: value subtracted from the inputs; defaults to the global ``xmean``.
        std: value the centred inputs are divided by; defaults to the global ``xstd``.

    Returns:
        The tuple ``((x[0] - mean) / std, x[1])``.
    """
    if mean is None: mean = xmean
    if std is None: std = xstd
    # Parenthesise the subtraction: `x[0] - mean/std` would only shift the inputs, not standardise them.
    return ((x[0] - mean) / std, x[1])

In [None]:
#|export
def plot_func(f, start=-5, end=5, steps=100):
    """Plot ``f`` on ``steps`` evenly spaced points in ``[start, end]``.

    Draws onto the current matplotlib axes with a dashed grid and solid black lines
    marking the x and y axes at zero.
    """
    xs = torch.linspace(start, end, steps)
    ys = f(xs)
    plt.plot(xs, ys)
    plt.grid(visible=True, which='both', ls='--')
    zero_line = dict(color='k', linewidth=1.0)
    plt.axhline(y=0, **zero_line)
    plt.axvline(x=0, **zero_line)

In [None]:
#| export
def init_weights(m, leaky=0.):
    """Kaiming-normal initialise the weight of a 1d/2d/3d convolution; other modules are untouched.

    Meant for ``model.apply(...)``. ``leaky`` is forwarded as ``a`` (the negative slope)
    to ``init.kaiming_normal_``. Biases are left as they are.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if not isinstance(m, conv_types):
        return
    init.kaiming_normal_(m.weight, a=leaky)

In [None]:
#|export
def _lsuv_stats(hook, mod, inp, outp):
    """Hook function: store the mean and std of a module's output on ``hook``.

    Signature follows the miniai ``Hook`` callback convention ``(hook, module, input, output)``;
    only ``outp`` is used. Sets ``hook.mean`` and ``hook.std``.
    """
    out_cpu = to_cpu(outp)  # presumably a detached CPU copy of the output
    hook.mean, hook.std = out_cpu.mean(), out_cpu.std()
    
def lsuv_init(m, m_in, xb, model=None, tol=1e-3, max_iters=100):
    """Layer-sequential unit-variance (LSUV) initialisation for one layer.

    Hooks ``m`` and repeatedly runs ``xb`` through the model. After each pass it
    subtracts the observed output mean from ``m_in.bias`` and divides ``m_in.weight``
    by the observed output std. It stops once the hooked output has |mean| <= tol and
    |std - 1| <= tol.

    Args:
        m: module whose output is measured (usually an activation layer).
        m_in: layer feeding ``m``; its ``bias`` and ``weight`` are adjusted in place.
        xb: a batch of data to run through the model.
        model: the network to run. Defaults to the global ``model`` for backward compatibility.
        tol: convergence tolerance on the mean and on the std's distance from 1.
        max_iters: maximum number of adjustments before giving up with a warning;
            ``None`` means no limit.

    Raises:
        NameError: if ``model`` is not given and no global ``model`` exists.
    """
    import warnings
    if model is None:
        model = globals().get('model')
        if model is None:
            raise NameError("lsuv_init: pass `model=`; no global `model` is defined")
    h = Hook(m, _lsuv_stats)
    try:
        with torch.no_grad():
            n_adjust = 0
            while True:
                model(xb)  # forward pass; the hook stores h.mean / h.std
                if not (abs(h.mean) > tol or abs(h.std - 1) > tol):
                    break
                # Without a cap the loop can run forever when the target is unreachable.
                if max_iters is not None and n_adjust >= max_iters:
                    warnings.warn(f"lsuv_init did not converge after {max_iters} adjustments "
                                  f"(mean={float(h.mean):.4g}, std={float(h.std):.4g})")
                    break
                m_in.bias -= h.mean
                m_in.weight.data /= h.std
                n_adjust += 1
    finally:
        h.remove()  # always detach the hook, even if the forward pass raises

### Export 

In [None]:
# Export the `#|export` cells of this notebook into the library module named by `#|default_exp init`.
import nbdev; nbdev.nbdev_export()