forked from fastai/fastai
-
Notifications
You must be signed in to change notification settings - Fork 1
/
plots.py
86 lines (71 loc) · 3.08 KB
/
plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from .imports import *
from .torch_imports import *
from sklearn.metrics import confusion_matrix
def plots(ims, figsize=(12,6), rows=1, interp=False, titles=None, maintitle=None):
if type(ims[0]) is np.ndarray:
ims = np.array(ims)
if (ims.shape[-1] != 3): ims = ims.transpose((0,2,3,1))
f = plt.figure(figsize=figsize)
if maintitle is not None:
plt.suptitle(maintitle, fontsize=16)
for i in range(len(ims)):
sp = f.add_subplot(rows, len(ims)//rows, i+1)
sp.axis('Off')
if titles is not None: sp.set_title(titles[i], fontsize=16)
plt.imshow(ims[i], interpolation=None if interp else 'none')
def plots_from_files(imspaths, figsize=(10,5), rows=1, titles=None, maintitle=None):
"""Plots images given image files.
Arguments:
im_paths (list): list of paths
figsize (tuple): figure size
rows (int): number of rows
titles (list): list of titles
maintitle (string): main title
"""
f = plt.figure(figsize=figsize)
if maintitle is not None: plt.suptitle(maintitle, fontsize=16)
for i in range(len(imspaths)):
sp = f.add_subplot(rows, len(imspaths)//rows, i+1)
sp.axis('Off')
if titles is not None: sp.set_title(titles[i], fontsize=16)
img = plt.imread(imspaths[i])
plt.imshow(img)
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
(This function is copied from the scikit docs.)
"""
plt.figure()
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print(cm)
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, cm[i, j], horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
def plots_raw(ims, figsize=(12,6), rows=1, titles=None):
f = plt.figure(figsize=figsize)
for i in range(len(ims)):
sp = f.add_subplot(rows, len(ims)//rows, i+1)
sp.axis('Off')
if titles is not None: sp.set_title(titles[i], fontsize=16)
plt.imshow(ims[i])
def load_img_id(ds, idx, path): return np.array(PIL.Image.open(path+ds.fnames[idx]))
def plot_val_with_title(idxs, ds, probs, path):
imgs = [load_img_id(ds,x,path) for x in idxs]
title_probs = [probs[x] for x in idxs]
return plots_raw(imgs, rows=1, titles=title_probs, figsize=(16,8))
def most_by_mask(mask, mult, probs):
idxs = np.where(mask)[0]
return idxs[np.argsort(mult * probs[idxs])[:4]]
def most_by_incorrect(y, is_correct, preds, acts):
mult = -1 if (y==1)==is_correct else 1
return most_by_mask((preds == acts)==is_correct & (acts == y), mult)