In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
from nb_005 import *

In [None]:
# Pin the notebook to GPU 3 when the machine actually has one; otherwise keep PyTorch's
# default device. The unconditional call fails on hosts with fewer than four GPUs or no CUDA.
if torch.cuda.is_available() and torch.cuda.device_count() > 3:
    torch.cuda.set_device(3)

# Model interpretation

## Basic data aug

In [None]:
# Dataset root handed to `data_from_imagefolder`, and the architecture handed to `ConvLearner` below.
PATH = Path('data/dogscats')
arch = tvm.resnet34

In [None]:
# Image size passed to the data loader and learning rate passed to `fit_one_cycle` below.
size,lr = 224,3e-3

# Pair of functions built from `imagenet_stats` (presumably normalize / denormalize — confirm in nb_005).
data_norm,data_denorm = normalize_funcs(*imagenet_stats)
# Augmentation settings: flipping on, with limits for rotation, zoom, lighting and warp.
tfms = get_transforms(do_flip=True, max_rotate=10, max_zoom=1.2, max_lighting=0.3, max_warp=0.15)
# Data object for the image folder at PATH: batch size 64, 8 workers, `tfms` as dataset transforms
# and `data_norm` as the `tfms` argument.
data = data_from_imagefolder(PATH, bs=64, ds_tfms=tfms, num_workers=8, tfms=data_norm, size=size)

## Save activations

In [None]:
#export

# Type of a hook function: called with (module, input, output); `Hook` stores whatever it returns.
HookFunc = Callable[[Model, Tensors, Tensors], Any]

class Hook():
    """Registers `hook_func` on module `m` and keeps its latest return value in `stored`.

    - `m`: the module to hook.
    - `hook_func`: called as `hook_func(module, input, output)` every time the hook fires.
    - `is_forward`: register with `register_forward_hook` if True, else `register_backward_hook`.
    """
    def __init__(self, m:Model, hook_func:HookFunc, is_forward:bool=True):
        self.hook_func,self.stored = hook_func,None
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)
        self.removed = False

    @staticmethod
    def _detached(x):
        "Detach `x`; lists/tuples become tuples of detached items, objects without `detach` pass through."
        if is_listy(x): return tuple(Hook._detached(o) for o in x)
        return x.detach() if hasattr(x, 'detach') else x

    def hook_fn(self, module:Model, input:Tensors, output:Tensors):
        "Call `hook_func` on detached versions of `input` and `output` and store what it returns."
        # Real tuples rather than generator expressions: a generator can be walked only once and
        # cannot be indexed, so a `hook_func` that simply returns it (e.g. `lambda m,i,o: o`)
        # left `stored` holding an object that was empty after its first iteration.
        self.stored = self.hook_func(module, self._detached(input), self._detached(output))

    def remove(self):
        "Unregister the hook; calling it again is a no-op."
        if not self.removed:
            self.hook.remove()
            self.removed=True

class Hooks():
    "Holds one `Hook` per module of `ms`, all sharing the same `hook_func`."
    def __init__(self, ms:Collection[Model], hook_func:HookFunc, is_forward:bool=True):
        registered = []
        for module in ms:
            registered.append(Hook(module, hook_func, is_forward))
        self.hooks = registered

    def __getitem__(self, i:int) -> Hook:
        return self.hooks[i]

    def __len__(self) -> int:
        return len(self.hooks)

    def __iter__(self):
        return iter(self.hooks)

    @property
    def stored(self):
        "The `stored` value of every hook, in the order the hooks were created."
        return [hook.stored for hook in self]

    def remove(self):
        "Remove every hook."
        for hook in self.hooks:
            hook.remove()

def hook_output(module:Model) -> Hook:
    "Hook `module` with a function that returns the output it is given, so `Hook.stored` holds that output."
    return Hook(module, lambda _module, _input, output: output)
def hook_outputs(modules:Collection[Model]) -> Hooks:
    "Hook every module of `modules` with a function that returns the output it is given."
    return Hooks(modules, lambda _module, _input, output: output)

In [None]:
# Display the base class that HookCallback (next cell) extends; inspection only.
LearnerCallback

In [None]:
#export
class HookCallback(LearnerCallback):
    """Callback that registers hooks on `modules` when training begins.

    - `learn`: the learner this callback is attached to.
    - `modules`: modules to hook; when empty or None, every module of the flattened model that has
      a `weight` attribute is hooked.
    - `do_remove`: remove the hooks when training ends.

    Subclasses provide the hook function as a method named `hook`, taking (module, input, output).
    """
    def __init__(self, learn:Learner, modules:Sequence[Model]=None, do_remove:bool=True):
        super().__init__(learn)
        self.modules,self.do_remove = modules,do_remove

    def on_train_begin(self, **kwargs):
        "Pick the default modules if none were given, then register `self.hook` on each of them."
        if not self.modules:
            self.modules = [m for m in flatten_model(self.learn.model)
                            if hasattr(m, 'weight')]
        self.hooks = Hooks(self.modules, self.hook)

    def on_train_end(self, **kwargs):
        "Remove the hooks unless `do_remove` is False."
        if self.do_remove: self.remove()

    def remove(self):
        "Remove the registered hooks, if any."
        # `hooks` only exists once `on_train_begin` has run, and `__del__` may call this earlier.
        hooks = getattr(self, 'hooks', None)
        # The call parentheses matter: the bare attribute `hooks.remove` removed nothing.
        if hooks is not None: hooks.remove()

    def __del__(self): self.remove()

class ActivationStats(HookCallback):
    """Callback that records the mean and std of each hooked module's output after every batch.

    After training, `stats` is a tensor indexed as [0=mean / 1=std][module][batch].
    """
    def on_train_begin(self, **kwargs):
        "Register the hooks and start with an empty list of statistics."
        super().on_train_begin(**kwargs)
        self.stats = []

    def hook(self, m:Model, i:Tensors, o:Tensors) -> Tuple[float,float]:
        "Return the mean and std of output `o` as Python floats."
        return o.mean().item(),o.std().item()

    def on_batch_end(self, **kwargs):
        "Append the latest (mean, std) of every hooked module."
        self.stats.append(self.hooks.stored)

    def on_train_end(self, **kwargs):
        "Remove the hooks (if `do_remove`) and turn the recorded statistics into a tensor."
        # Without this call the parent's hook removal is skipped and `do_remove` has no effect.
        super().on_train_end(**kwargs)
        # The list is indexed [batch][module][mean/std]; permute to [mean/std][module][batch].
        self.stats = tensor(self.stats).permute(2,1,0)

def idx_dict(a):
    "Map every item of `a` to its position; when an item repeats, its last position is kept."
    position_of = {}
    for position, item in enumerate(a):
        position_of[item] = position
    return position_of

In [None]:
# Learner for `arch` on `data` (the positional 2 is presumably the number of classes — confirm
# against ConvLearner), with weight decay 1e-2, accuracy as metric, PATH as its path, and
# ActivationStats attached through `callback_fns`.
learn = ConvLearner(data, arch, 2, wd=1e-2, metrics=accuracy, path=PATH,
                    callback_fns=ActivationStats)

In [None]:
# Train with `fit_one_cycle` at learning rate `lr` (the 1 is presumably the number of epochs — confirm).
learn.fit_one_cycle(1, lr)

In [None]:
# Modules hooked by the ActivationStats callback, and a module -> position lookup for them.
ms = learn.activation_stats.modules
d = idx_dict(ms)
# Position of `learn.model[1][8]` among the hooked modules; the trailing `; ln` displays it.
ln = d[learn.model[1][8]]; ln

In [None]:
# Plot the recorded std (index 1 of the stats; index 0 is the mean) of module `ln`, one point per batch.
plt.plot(learn.activation_stats.stats[1][ln].numpy());

In [None]:
# Save the learner under the name 'e1'.
learn.save('e1')

## Best/worst

In [None]:
# New learner with the same data, architecture, weight decay and metric, but without the
# ActivationStats callback. NOTE(review): unlike the learner above, no `path=PATH` is passed here.
learn = ConvLearner(data, arch, 2, wd=1e-2, metrics=accuracy)

In [None]:
# Load the checkpoint written by `learn.save('e1')` earlier in this notebook. The previous name '1'
# is never saved here, so the cell failed on a fresh Restart & Run All.
learn.load('e1')

In [None]:
# Batch size for the cells below, and the class names taken from the validation dataset.
bs=64
classes = data.valid_ds.classes

In [None]:
# Predictions and targets returned by `learn.TTA()` (presumably test-time augmentation — confirm).
preds,y = learn.TTA()

In [None]:
def calc_loss(y_pred, y_true, loss_class, bs=64):
    """Compute the loss of every item of `y_pred` against `y_true`, without tracking gradients.

    - `y_pred`, `y_true`: predictions and targets, paired item by item along the first dimension.
    - `loss_class`: loss class (not an instance) accepting `reduction='none'`, e.g. `nn.CrossEntropyLoss`.
    - `bs`: number of items evaluated per batch; it no longer has to exist as a notebook global.

    Returns the per-batch losses concatenated into one tensor.
    """
    # Build the loss once instead of once per batch.
    loss_func = loss_class(reduction='none')
    loss_dl = DataLoader(TensorDataset(tensor(y_pred),tensor(y_true)), bs)
    with torch.no_grad():
        return torch.cat([loss_func(*b) for b in loss_dl])

In [None]:
class ClassificationInterpretation():
    """Per-item losses, probabilities and predicted classes of a classifier, with helpers to inspect them.

    - `data`: data object; `data.valid_ds` is indexed for the items to plot and `data.classes` for class names.
    - `y_pred`, `y_true`: predictions and targets, indexed like `data.valid_ds`.
    - `loss_class`: loss class handed to `calc_loss` to get one loss per item.
    - `sigmoid`: apply a sigmoid to `y_pred` to get `probs`; otherwise `y_pred` is used unchanged.
    """
    def __init__(self, data, y_pred, y_true, loss_class, sigmoid=True):
        self.data,self.y_pred,self.y_true,self.loss_class = data,y_pred,y_true,loss_class
        self.losses = calc_loss(y_pred, y_true, loss_class=loss_class)
        # Work from the arguments: the previous code read the notebook global `preds` and an
        # undefined name `probs`.
        self.probs = y_pred.sigmoid() if sigmoid else y_pred
        self.pred_class = self.probs.argmax(dim=1)

    def top_losses(self, k, largest=True):
        "Return the `k` largest (or smallest, when `largest` is False) losses together with their indices."
        return self.losses.topk(k, largest=largest)

    def plot_top_losses(self, k, largest=True, figsize=(12,12)):
        "Show the `k` top-loss items in a square grid, titled 'predicted/actual / loss / probability of class 0'."
        tl = self.top_losses(k,largest)
        # NOTE(review): another cell reads the names from `data.valid_ds.classes` — confirm `data.classes` exists.
        classes = self.data.classes
        rows = math.ceil(math.sqrt(k))
        # squeeze=False keeps `axes` a 2D array even for a 1x1 grid, so `.flat` works for k=1 too.
        fig,axes = plt.subplots(rows,rows,figsize=figsize,squeeze=False)
        # `tl[1]` holds the indices of the selected items (previously the undefined name `worst`);
        # `.tolist()` yields plain ints for indexing the dataset and the class list.
        for i,idx in enumerate(tl[1].tolist()):
            t = self.data.valid_ds[idx]
            t[0].show(ax=axes.flat[i], title=
                f'{classes[self.pred_class[idx]]}/{classes[t[1]]} / {self.losses[idx]:.2f} / {self.probs[idx][0]:.2f}')

In [None]:
# Build the interpretation object from `preds` and `y`, with cross-entropy as the loss class.
interp = ClassificationInterpretation(data, preds, y, loss_class=nn.CrossEntropyLoss)

In [None]:
# The 9 largest losses and their indices.
interp.top_losses(9)

In [None]:
# Plot the 9 items with the largest loss.
interp.plot_top_losses(9)

## Fin