-
Notifications
You must be signed in to change notification settings - Fork 7.5k
/
interpret.py
137 lines (117 loc) · 6.39 KB
/
interpret.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/20_interpret.ipynb (unless otherwise specified).
__all__ = ['plot_top_losses', 'Interpretation', 'ClassificationInterpretation', 'SegmentationInterpretation']
# Cell
from .data.all import *
from .optimizer import *
from .learner import *
from .tabular.core import *
import sklearn.metrics as skm
# Cell
@typedispatch
def plot_top_losses(x, y, *args, **kwargs):
    "Fallback overload: reached only when no `plot_top_losses` implementation is registered for the types of `x` and `y`."
    # NotImplementedError is the idiomatic exception for an unimplemented overload.
    # It subclasses RuntimeError/Exception, so any caller catching `Exception` is unaffected.
    raise NotImplementedError(f"plot_top_losses is not implemented for {type(x)},{type(y)}")
# Cell
#nbdev_comment _all_ = ["plot_top_losses"]
# Cell
class Interpretation():
    "Interpretation base class, can be inherited for task specific Interpretation classes"

    def __init__(self, learn, dl, losses, act=None):
        store_attr()  # stash learn/dl/losses/act as attributes under the same names

    def __getitem__(self, idxs):
        "Return inputs, preds, targs, decoded outputs, and losses at `idxs`"
        # Normalise `idxs` to a plain python list so it can index items and losses alike
        if isinstance(idxs, Tensor): idxs = idxs.tolist()
        if not is_listy(idxs): idxs = [idxs]
        # `.iloc` exists for DataFrame-backed items; otherwise fancy-index via an `L`
        sel = getattr(self.dl.items, 'iloc', L(self.dl.items))[idxs]
        # Tabular items are already processed, so skip re-processing for a TabDataLoader
        sel_dl = self.learn.dls.test_dl(sel, with_labels=True,
                                        process=not isinstance(self.dl, TabDataLoader))
        inps,preds,targs,decoded = self.learn.get_preds(dl=sel_dl, with_input=True, with_loss=False,
                                                        with_decoded=True, act=self.act, reorder=False)
        return inps, preds, targs, decoded, self.losses[idxs]

    @classmethod
    def from_learner(cls, learn, ds_idx=1, dl=None, act=None):
        "Construct interpretation object from a learner"
        # shuffle/drop_last disabled so the per-item losses stay aligned with dl.items
        if dl is None: dl = learn.dls[ds_idx].new(shuffle=False, drop_last=False)
        _,_,losses = learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=False,
                                     with_preds=False, with_targs=False, act=act)
        return cls(learn, dl, losses, act)

    def top_losses(self, k=None, largest=True, items=False):
        "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`). Optionally include items."
        vals, ixs = self.losses.topk(ifnone(k, len(self.losses)), largest=largest)
        if not items: return vals, ixs
        return vals, ixs, getattr(self.dl.items, 'iloc', L(self.dl.items))[ixs]

    def plot_top_losses(self, k, largest=True, **kwargs):
        "Show `k` largest(/smallest) preds and losses. `k` may be int, list, or `range` of desired results."
        if is_listy(k) or isinstance(k, range):
            # list/range selection: rank everything, then slice out the requested positions
            all_losses, all_ixs = self.top_losses(None, largest)
            losses, idx = all_losses[k], all_ixs[k]
        else:
            losses, idx = self.top_losses(k, largest)
        inps, preds, targs, decoded, _ = self[idx]
        inps, targs, decoded = tuplify(inps), tuplify(targs), tuplify(decoded)
        x, y, its = self.dl._pre_show_batch(inps+targs)
        x1, y1, outs = self.dl._pre_show_batch(inps+decoded, max_n=len(idx))
        if its is not None:
            plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), preds, losses, **kwargs)
        #TODO: figure out if this is needed
        #its None means that a batch knows how to show itself as a whole, so we pass x, x1
        #else: show_results(x, x1, its, ctxs=ctxs, max_n=max_n, **kwargs)

    def show_results(self, idxs, **kwargs):
        "Show predictions and targets of `idxs`"
        if isinstance(idxs, Tensor): idxs = idxs.tolist()
        if not is_listy(idxs): idxs = [idxs]
        inps, _, targs, decoded, _ = self[idxs]
        self.dl.show_results(tuplify(inps)+tuplify(targs), tuplify(decoded), max_n=len(idxs), **kwargs)
# Cell
class ClassificationInterpretation(Interpretation):
    "Interpretation methods for classification models."

    def __init__(self, learn, dl, losses, act=None):
        super().__init__(learn, dl, losses, act)
        self.vocab = self.dl.vocab
        # A listy vocab (e.g. tabular) keeps the target vocab as its last entry
        if is_listy(self.vocab): self.vocab = self.vocab[-1]

    def confusion_matrix(self):
        "Confusion matrix as an `np.ndarray`."
        cls_ids = torch.arange(0, len(self.vocab))
        _,targs,decoded = self.learn.get_preds(dl=self.dl, with_decoded=True, with_preds=True,
                                               with_targs=True, act=self.act)
        d,t = flatten_check(decoded, targs)
        # Broadcast trick: (1,C,N) & (C,1,N) masks summed over samples give a (C,C)
        # matrix where entry [i,j] counts items with actual class i predicted as j
        cm = ((d==cls_ids[:,None]) & (t==cls_ids[:,None,None])).long().sum(2)
        return to_np(cm)

    def plot_confusion_matrix(self, normalize=False, title='Confusion matrix', cmap="Blues", norm_dec=2,
                              plot_txt=True, **kwargs):
        "Plot the confusion matrix, with `title` and using `cmap`."
        # This function is mainly copied from the sklearn docs
        cm = self.confusion_matrix()
        if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        fig = plt.figure(**kwargs)
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        ticks = np.arange(len(self.vocab))
        plt.xticks(ticks, self.vocab, rotation=90)
        plt.yticks(ticks, self.vocab, rotation=0)
        if plot_txt:
            # Text colour flips to white on dark (above-half-max) cells for contrast
            thresh = cm.max() / 2.
            for r in range(cm.shape[0]):
                for c in range(cm.shape[1]):
                    val = f'{cm[r, c]:.{norm_dec}f}' if normalize else f'{cm[r, c]}'
                    plt.text(c, r, val, horizontalalignment="center", verticalalignment="center",
                             color="white" if cm[r, c] > thresh else "black")
        ax = fig.gca()
        ax.set_ylim(len(self.vocab)-.5,-.5)
        plt.tight_layout()
        plt.ylabel('Actual')
        plt.xlabel('Predicted')
        plt.grid(False)

    def most_confused(self, min_val=1):
        "Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences."
        cm = self.confusion_matrix()
        np.fill_diagonal(cm, 0)  # correct predictions are not "confusions"
        pairs = [(self.vocab[r], self.vocab[c], cm[r,c])
                 for r,c in zip(*np.where(cm>=min_val))]
        return sorted(pairs, key=itemgetter(2), reverse=True)

    def print_classification_report(self):
        "Print scikit-learn classification report"
        _,targs,decoded = self.learn.get_preds(dl=self.dl, with_decoded=True, with_preds=True,
                                               with_targs=True, act=self.act)
        d,t = flatten_check(decoded, targs)
        print(skm.classification_report(t, d, labels=list(self.vocab.o2i.values()),
                                        target_names=[str(v) for v in self.vocab]))
# Cell
class SegmentationInterpretation(Interpretation):
    "Interpretation methods for segmentation models."
    # No segmentation-specific behavior yet: everything (from_learner, top_losses,
    # plot_top_losses, show_results) comes from the `Interpretation` base class.
    pass