/
nb_005a.py
133 lines (107 loc) · 5.52 KB
/
nb_005a.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#################################################
### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ###
#################################################
# file to edit: dev_nb/005a_interpretation.ipynb
from nb_005 import *
# Signature of a hook callback: receives (module, input, output); whatever it
# returns is saved on the owning `Hook` as `hook.stored` (see `Hook.hook_fn`).
HookFunc = Callable[[Model, Tensors, Tensors], Any]
class Hook():
    "Registers `hook_func` as a forward (or backward) hook on module `m` and stores its result"
    def __init__(self, m:Model, hook_func:HookFunc, is_forward:bool=True):
        self.hook_func,self.stored = hook_func,None
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)
        self.removed = False
    def hook_fn(self, module:Model, input:Tensors, output:Tensors):
        "Detach the tensors, then apply `hook_func` and keep its result in `self.stored`"
        # BUG FIX: the original built generator expressions here, so detaching was
        # lazy and the iterable was single-use — a hook_func iterating its inputs
        # twice silently got nothing the second time. Tuples detach eagerly and
        # can be traversed any number of times.
        input  = tuple(o.detach() for o in input ) if is_listy(input ) else input.detach()
        output = tuple(o.detach() for o in output) if is_listy(output) else output.detach()
        self.stored = self.hook_func(module, input, output)
    def remove(self):
        "Deregister the hook from the module (idempotent)"
        if not self.removed:
            self.hook.remove()
            self.removed=True
class Hooks():
    "Registers a `Hook` on every module in `ms`; behaves as a sequence of those hooks"
    def __init__(self, ms:Collection[Model], hook_func:HookFunc, is_forward:bool=True):
        self.hooks = [Hook(mod, hook_func, is_forward) for mod in ms]
    def __len__(self) -> int: return len(self.hooks)
    def __iter__(self): return iter(self.hooks)
    def __getitem__(self, i:int) -> Hook: return self.hooks[i]
    @property
    def stored(self):
        "Latest value captured by each hook, in registration order"
        return [hook.stored for hook in self]
    def remove(self):
        "Deregister every underlying hook"
        for hook in self.hooks: hook.remove()
def hook_output(module:Model) -> Hook:
    "Hook `module` so that each forward pass stores its output"
    return Hook(module, lambda mod, inp, out: out)
def hook_outputs(modules:Collection[Model]) -> Hooks:
    "Hook every module in `modules` so that each forward pass stores its output"
    return Hooks(modules, lambda mod, inp, out: out)
class HookCallback(LearnerCallback):
    "Callback that registers `self.hook` on `self.modules` for the duration of training"
    def __init__(self, learn:Learner, modules:Sequence[Model]=None, do_remove:bool=True):
        super().__init__(learn)
        self.modules,self.do_remove = modules,do_remove
    def on_train_begin(self, **kwargs):
        "Default to hooking every module of the model that has a `weight` attribute"
        if not self.modules:
            self.modules = [m for m in flatten_model(self.learn.model)
                            if hasattr(m, 'weight')]
        self.hooks = Hooks(self.modules, self.hook)
    def on_train_end(self, **kwargs):
        if self.do_remove: self.remove()
    def remove(self):
        # BUG FIX: the original was `self.hooks.remove` — a bare attribute access
        # with no call, so the hooks were never actually deregistered. Also guard
        # against `__del__` firing before `on_train_begin` ever created the hooks.
        if getattr(self, 'hooks', None) is not None: self.hooks.remove()
    def __del__(self): self.remove()
class ActivationStats(HookCallback):
    "Callback that records the mean/std of activations of each hooked module per batch"
    def on_train_begin(self, **kwargs):
        "Register the hooks and reset the stats accumulator"
        super().on_train_begin(**kwargs)
        self.stats = []
    def hook(self, m:Model, i:Tensors, o:Tensors) -> Tuple[Rank0Tensor,Rank0Tensor]:
        "Store the (mean, std) of this module's output activations"
        return o.mean().item(),o.std().item()
    def on_batch_end(self, **kwargs): self.stats.append(self.hooks.stored)
    def on_train_end(self, **kwargs):
        # BUG FIX: the original never called the parent's on_train_end, so
        # `do_remove` was ignored and the hooks stayed registered forever.
        super().on_train_end(**kwargs)
        # (n_batches, n_modules, 2) -> (2, n_modules, n_batches)
        self.stats = tensor(self.stats).permute(2,1,0)
def idx_dict(a):
    "Map each element of `a` to its position: `{element: index}`"
    return dict((elem, idx) for idx, elem in enumerate(a))
def calc_loss(y_pred:Tensor, y_true:Tensor, loss_class:type=nn.CrossEntropyLoss):
    "Calculate loss between `y_pred` and `y_true` using `loss_class`"
    # NOTE(review): `bs` is read from module scope here (presumably supplied by
    # the `from nb_005 import *` at the top) — confirm it is defined before use.
    loss_dl = DataLoader(TensorDataset(tensor(y_pred),tensor(y_true)), bs)
    with torch.no_grad():
        # reduction='none' keeps one loss value per sample; concatenate batches
        # into a single flat tensor of per-sample losses.
        return torch.cat([loss_class(reduction='none')(*b) for b in loss_dl])
class ClassificationInterpretation():
    "Interpretation methods for classification models"
    def __init__(self, data:DataBunch, y_pred:Tensor, y_true:Tensor,
                 loss_class:type=nn.CrossEntropyLoss, sigmoid:bool=True):
        self.data,self.y_pred,self.y_true,self.loss_class = data,y_pred,y_true,loss_class
        self.losses = calc_loss(y_pred, y_true, loss_class=loss_class)
        # BUG FIX: the original read an undefined global `preds`; the
        # predictions are the `y_pred` argument.
        self.probs = y_pred.sigmoid() if sigmoid else y_pred
        self.pred_class = self.probs.argmax(dim=1)

    def top_losses(self, k, largest=True):
        "`k` largest(/smallest) losses, as a (values, indices) pair"
        return self.losses.topk(k, largest=largest)

    def plot_top_losses(self, k, largest=True, figsize=(12,12)):
        "Show images in `top_losses` along with their loss, label, and prediction"
        tl = self.top_losses(k, largest=largest)
        classes = self.data.classes
        rows = math.ceil(math.sqrt(k))
        fig,axes = plt.subplots(rows,rows,figsize=figsize)
        # Reuse the indices from `tl` instead of recomputing top_losses.
        for i,idx in enumerate(tl[1]):
            # BUG FIX: the original read an undefined global `data`.
            t = self.data.valid_ds[idx]
            # NOTE(review): `self.probs[idx][0]` shows the probability of class 0,
            # not of the predicted class — confirm this is intended.
            t[0].show(ax=axes.flat[i], title=
                f'{classes[self.pred_class[idx]]}/{classes[t[1]]} / {self.losses[idx]:.2f} / {self.probs[idx][0]:.2f}')

    def confusion_matrix(self):
        "Confusion matrix as an `np.ndarray`"
        # BUG FIX: the original read an undefined global `data`.
        x = torch.arange(0, self.data.c)
        # Broadcast predicted and true classes against every class id and count
        # matches: cm[true, pred] = number of samples with that (true, pred) pair.
        cm = ((self.pred_class==x[:,None]) & (self.y_true==x[:,None,None])).sum(2)
        return cm.cpu().numpy()

    def plot_confusion_matrix(self, normalize:bool=False, title:str='Confusion matrix', cmap:Any="Blues", figsize:tuple=None):
        "Plot the confusion matrix"
        # Adapted from the scikit-learn docs.
        cm = self.confusion_matrix()
        # BUG FIX: the original normalized AFTER plt.imshow, so normalization
        # only affected the cell text, never the plotted image.
        if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        plt.figure(figsize=figsize)
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        # BUG FIX: the original read an undefined name `classes` here.
        tick_marks = np.arange(len(self.data.classes))
        plt.xticks(tick_marks, self.data.classes, rotation=45)
        plt.yticks(tick_marks, self.data.classes)
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, cm[i, j], horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")
        plt.tight_layout()
        plt.ylabel('True label')
        plt.xlabel('Predicted label')