In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib as mpl
from importlib import reload
import IPython
# Notebook-wide figure styling: thin (0.25) lines and axes, no top/right
# spines, Arial for all text.
figure_style = [
    ('lines.linewidth', 0.25),
    ('axes.spines.top', False),
    ('axes.spines.right', False),
    ('axes.linewidth', 0.25),
    ('font.sans-serif', "Arial"),
    ('font.family', "Arial"),
]
for rc_key, rc_value in figure_style:
    mpl.rcParams[rc_key] = rc_value

In [None]:
import torch, argparse, os, shutil, inspect, json, numpy
import netdissect
from netdissect.easydict import EasyDict
from netdissect import pbar, nethook, renormalize, parallelfolder, pidfile
from netdissect import upsample, tally, imgviz, imgsave, bargraph, show
from experiment import dissect_experiment as experiment

# Ask cuDNN for deterministic behaviour and switch off its benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# choices are alexnet, vgg16, or resnet152.
args = EasyDict(
    model='vgg16',
    dataset='places',
    seg='netpqc',
    layer='conv5_3',
    quantile=0.01,
)
# Results directory is keyed on every experiment setting; the quantile is
# encoded as an integer number of thousandths (0.01 -> 10).
resdir_fields = (args.model, args.dataset, args.seg, args.layer,
                 int(args.quantile * 1000))
resdir = 'results/%s-%s-%s-%s-%s' % resdir_fields
def resfile(f):
    """Return the path of file name `f` inside the results directory `resdir`."""
    path_in_resdir = os.path.join(resdir, f)
    return path_in_resdir

In [None]:
import csv

model = experiment.load_model(args)
layername = experiment.instrumented_layername(args)
# Ask the instrumented model to keep this layer's output; later cells read it
# back with model.retained_layer(layername).
model.retain_layer(layername)
dataset = experiment.setting.load_dataset('imagenet', crop_size=224)

# Map each original class identifier (first CSV column) to a human-readable
# name (second CSV column).
readable_class = {}
# newline='' is what the csv module asks for when it is given a file object.
with open('datasets/imagenet/labels.txt', newline='') as f:
    for row in csv.reader(f):
        if len(row) < 2:
            # Blank or truncated line: nothing to map, and unpacking would raise.
            continue
        k, v = row[:2]
        readable_class[k] = v

# Keep the original identifiers around and expose readable names instead.
dataset.orig_classes = dataset.classes
dataset.classes = [readable_class[c] for c in dataset.orig_classes]

In [None]:
# Grab pre-relu data: switch relu5_3 to out-of-place operation.
# NOTE(review): presumably an in-place ReLU would overwrite the retained
# conv5_3 output with its rectified values — confirm against nethook.
model.model.features.relu5_3.inplace = False

In [None]:
# Class indices whose readable name is one of the three bus categories.
bus_class_names = ['minibus', 'school bus', 'trolleybus']
bus_classes = [
    class_index
    for class_index, class_name in enumerate(dataset.classes)
    if class_name in bus_class_names
]
bus_classes

In [None]:
# Class indices whose readable name is one of the two airplane categories.
airplane_class_names = ['airliner', 'warplane']
airplane_classes = [
    class_index
    for class_index, class_name in enumerate(dataset.classes)
    if class_name in airplane_class_names
]

In [None]:
from collections import defaultdict

# Split the indices of dataset.images into bus / non-bus lists. For each bus
# image also record a scatter colour chosen by its class index.
clr_choice = {654: 'r', 779: 'g', 874: 'm'}
bus_examples, nonbus_examples, bus_colors = [], [], []
for image_index in pbar(range(len(dataset.images))):
    image_entry, class_index = dataset.images[image_index]
    if class_index in bus_classes:
        bus_examples.append(image_index)
        bus_colors.append(clr_choice[class_index])
    else:
        nonbus_examples.append(image_index)
len(bus_examples), len(nonbus_examples)

In [None]:
from collections import defaultdict

# Split the indices of dataset.images into airplane / non-airplane lists. For
# each airplane image also record a colour chosen by its class index.
airplane_clr_choice = {404: 'r', 405: 'g', 895: 'b'}
airplane_examples, nonairplane_examples, airplane_colors = [], [], []
for image_index in pbar(range(len(dataset.images))):
    image_entry, class_index = dataset.images[image_index]
    if class_index in airplane_classes:
        airplane_examples.append(image_index)
        airplane_colors.append(airplane_clr_choice[class_index])
    else:
        nonairplane_examples.append(image_index)
len(airplane_examples), len(nonairplane_examples)

In [None]:
unit_number = 19 # or 190
from netdissect import tally
def max_unit(imbatch, cls):
    """Return, per image in `imbatch`, the peak activation of one unit.

    Runs the model on the batch (moved to the GPU), reads the retained layer,
    selects channel `unit_number`, flattens everything after the batch
    dimension and takes the maximum. `cls` is accepted but not used.
    """
    model(imbatch.cuda())
    acts = model.retained_layer(layername)
    unit_max = acts[:, unit_number].view(len(imbatch), -1).max(1)[0]
    return unit_max

# The cache file name carries the unit index so that switching `unit_number`
# cannot silently pick up values computed for a different unit.
# NOTE(review): assumes tally_cat reuses `cachefile` when it already exists —
# confirm against netdissect.tally.
maxvals_all = tally.tally_cat(max_unit, dataset, batch_size=100, num_workers=30, pin_memory=True,
                              cachefile=resfile('maxvals_all_imagenet_unit%d.npz' % unit_number))

In [None]:
# Show the unit's max activation next to the image for the first ten bus
# images followed by the first ten non-bus images.
preview_indices = bus_examples[:10] + nonbus_examples[:10]
result = []
for image_index in preview_indices:
    activation = maxvals_all[image_index].item()
    picture = renormalize.as_image(dataset[image_index][0], source=dataset)
    result.append([activation, picture])
show(result)

In [None]:
# Show the unit's max activation next to the image for the first ten airplane
# images followed by the first ten non-airplane images.
preview_indices = airplane_examples[:10] + nonairplane_examples[:10]
result = []
for image_index in preview_indices:
    activation = maxvals_all[image_index].item()
    picture = renormalize.as_image(dataset[image_index][0], source=dataset)
    result.append([activation, picture])
show(result)

In [None]:
import random

# Per-image peak activations of unit `unit_number`, split by bus / non-bus.
bus_vals = maxvals_all[bus_examples].numpy()
nonbus_vals = maxvals_all[nonbus_examples].numpy()
# Equal-sized slice of the non-bus values: the first len(bus_vals) in dataset
# order, not a random draw.
nbvals = nonbus_vals[:len(bus_vals)]


def _jitter(values):
    """Return one random y position in [0, 1) per value, to spread points out."""
    return [random.random() for _ in values]


f, (a1, a2, a3) = plt.subplots(nrows=3, ncols=1, figsize=(15, 10), facecolor='white')

# Labels are derived from `unit_number` so they always describe the data that
# is actually plotted (they used to hard-code units 190/191/192).
# NOTE(review): "post-ReLU" is kept from the original label, but an earlier
# cell disables in-place relu5_3 to "grab pre-relu data" — confirm which is right.
unit_xlabel = 'unit %d activation value, maximum over featuremap (post-ReLU)' % unit_number

# NOTE(review): this panel plots exactly the same values as the first panel;
# its original title referred to a hard-coded unit 192.
a3.set_title('Possible error: testing unit %d (zero-indexing issue) (3 bus classes in color)' % unit_number)
a3.scatter(nonbus_vals, _jitter(nonbus_vals), alpha=0.1, s=10, c='gray')
a3.scatter(bus_vals, _jitter(bus_vals), alpha=0.9, s=10, c=bus_colors)
a3.get_yaxis().set_ticks([])
a3.set_xlabel(unit_xlabel)

# All non-bus images in gray, bus images coloured by class.
a1.set_title('VGG-16 places unit %d maxact across the imagenet test set (3 bus classes in color)' % unit_number)
a1.scatter(nonbus_vals, _jitter(nonbus_vals), alpha=0.1, s=10, c='gray')
a1.scatter(bus_vals, _jitter(bus_vals), alpha=0.9, s=10, c=bus_colors)
a1.get_yaxis().set_ticks([])
a1.set_xlabel(unit_xlabel)

# Equal numbers of non-bus and bus images; counts are taken from the data.
a2.set_title('VGG-16 places unit %d maxact, %d non-bus images (blue) and %d bus images (red)'
             % (unit_number, len(nbvals), len(bus_vals)))
a2.scatter(nbvals, _jitter(nbvals), alpha=0.9, s=10, c='blue')
a2.scatter(bus_vals, _jitter(bus_vals), alpha=0.9, s=10, c='red')
a2.get_yaxis().set_ticks([])
a2.set_xlabel(unit_xlabel)

plt.tight_layout()


In [None]:
import seaborn as sns

# Per-image peak activations of unit `unit_number`, split by bus / non-bus.
bus_vals = maxvals_all[bus_examples].numpy()
nonbus_vals = maxvals_all[nonbus_examples].numpy()
# First len(bus_vals) non-bus values, so both scatter groups have equal size.
nbvals = nonbus_vals[:len(bus_vals)]

nonbus_color = '#F0883B'
bus_color = '#4B4CBF'

fig, [ax1, ax2] = plt.subplots(nrows=2, ncols=1, figsize=(8, 3), dpi=300,
                               sharex='all')

# Bottom panel: one point per image; the y position is just the index within
# each list, used to spread the points out.
ax2.scatter(nbvals, range(len(nbvals)), alpha=0.5, s=10, c=nonbus_color)
ax2.scatter(bus_vals, range(len(bus_vals)), alpha=0.5, s=10, c=bus_color)
ax2.get_yaxis().set_ticks([])
ax2.set_ylabel('Jitterplot')

# Top panel: kernel density estimate over all values of each group.
sns.distplot(nonbus_vals, kde=True, hist=False,
             kde_kws={'linewidth': 3, "color": nonbus_color},
             label="images from non-bus imagenet classes",
             ax=ax1)
sns.distplot(bus_vals, kde=True, hist=False,
             kde_kws={'linewidth': 3, "color": bus_color},
             label="imagenet school bus, trolley bus, minibus images",
             ax=ax1)
ax1.set_ylabel('Density')
ax1.get_yaxis().set_ticks([])
ax1.set_xlim(-50, 500)
# Take the unit index from `unit_number` rather than hard-coding it, so the
# label follows the data when the unit is changed.
ax2.set_xlabel('Unit %d max activation in featuremap' % unit_number)

In [None]:
from netdissect import tally
def max_unit_150(imbatch, cls):
    """Return, per image in `imbatch`, the peak activation of channel 150.

    The batch is run through the model on the GPU; the retained layer's
    channel 150 is flattened per image and reduced with a max. `cls` is
    accepted but not used.
    """
    model(imbatch.cuda())
    featuremap = model.retained_layer(layername)[:, 150]
    flat_per_image = featuremap.view(len(imbatch), -1)
    peak_values = flat_per_image.max(1)[0]
    return peak_values

maxvals_all_150 = tally.tally_cat(
    max_unit_150, dataset,
    batch_size=100, num_workers=30, pin_memory=True,
    cachefile=resfile('maxval_val_150.npz'))

In [None]:
import seaborn as sns

# The values plotted here come from max_unit_150, which reads channel 150.
airplane_unit = 150

airplane_vals = maxvals_all_150[airplane_examples].numpy()
nonairplane_vals = maxvals_all_150[nonairplane_examples].numpy()
# First len(airplane_vals) non-airplane values, so both scatter groups have
# equal size.
navals = nonairplane_vals[:len(airplane_vals)]

nonairplane_color = '#F0883B'
airplane_color = '#4B4CBF'

fig, [ax1, ax2] = plt.subplots(nrows=2, ncols=1, figsize=(8, 3), dpi=300,
                               sharex='all')

# Bottom panel: one point per image; the y position is just the index within
# each list, used to spread the points out.
ax2.scatter(navals, range(len(navals)), alpha=0.5, s=10, c=nonairplane_color)
ax2.scatter(airplane_vals, range(len(airplane_vals)), alpha=0.5, s=10, c=airplane_color)
ax2.get_yaxis().set_ticks([])
ax2.set_ylabel('Jitterplot')

# Top panel: kernel density estimate over all values of each group.
sns.distplot(nonairplane_vals, kde=True, hist=False,
             kde_kws={'linewidth': 3, "color": nonairplane_color},
             label="images from non-airplane imagenet classes",
             ax=ax1)
sns.distplot(airplane_vals, kde=True, hist=False,
             kde_kws={'linewidth': 3, "color": airplane_color},
             label="imagenet airliner, warplane images",
             ax=ax1)
ax1.set_ylabel('Density')
ax1.get_yaxis().set_ticks([])
ax1.set_xlim(-50, 250)
# Label with the unit that was actually measured (150), not `unit_number`,
# which belongs to the bus experiment.
ax2.set_xlabel('Unit %d max activation in featuremap' % airplane_unit)

In [None]:
# Mean peak activation for bus images vs. non-bus images, using the
# bus_vals / nonbus_vals arrays built in the plotting cells above.
bus_vals.mean(), nonbus_vals.mean()

In [None]:
# Display the readable class names assigned to the dataset earlier.
# NOTE(review): this shows the entire list; consider slicing if the output is too long.
dataset.classes

In [None]:
# Load the training split and give it the same readable class names as the
# first dataset, keeping the original identifiers in `orig_classes`.
train_dataset = experiment.setting.load_dataset(
    'imagenet', split='train', crop_size=224)
train_dataset.orig_classes = train_dataset.classes
train_dataset.classes = list(
    map(readable_class.__getitem__, train_dataset.orig_classes))

# Peak activation of channel 150 for every training image.
maxvals_train_150 = tally.tally_cat(
    max_unit_150, train_dataset,
    batch_size=100, num_workers=30, pin_memory=True,
    cachefile=resfile('maxval_train_150.npz'))

In [None]:
# Sanity check on the first training image: print one pixel, the dataset's
# transforms, and the channel-150 peak activation, then display the image.
# Fetch the sample once so that all outputs refer to the same tensor even if
# the dataset's transforms are randomised.
first_train_image = train_dataset[0][0]
model(first_train_image[None].cuda())
print(first_train_image[0,0,0])
print(train_dataset.transforms)
print(model.retained_layer(layername)[:,150].view(1, -1).max(1)[0])

# Pass source=dataset like every other as_image call in this notebook.
# NOTE(review): presumably `source` tells as_image how the tensor was
# normalised — confirm against netdissect.renormalize.
renormalize.as_image(first_train_image, source=dataset)

In [None]:
from collections import defaultdict

# Split the indices of train_dataset.images into airplane / non-airplane lists.
airplane_train_examples, nonairplane_train_examples = [], []
for image_index in pbar(range(len(train_dataset.images))):
    image_entry, class_index = train_dataset.images[image_index]
    if class_index in airplane_classes:
        airplane_train_examples.append(image_index)
    else:
        nonairplane_train_examples.append(image_index)
len(airplane_train_examples), len(nonairplane_train_examples)

In [None]:
import seaborn as sns

sample_size = 1000
# Calculate means over the entire set of images, but illustrate with a sample
# of at most `sample_size` points per group to keep the plot legible.
all_airplane_train_vals = maxvals_train_150[airplane_train_examples].numpy()
airplane_train_vals = all_airplane_train_vals[:sample_size]
nonairplane_train_vals = maxvals_train_150[nonairplane_train_examples].numpy()
navals = nonairplane_train_vals[:len(airplane_train_vals)]

# One mean per group, shared by the legend text and the dashed lines so the
# two can never disagree.
airplane_mean = all_airplane_train_vals.mean().item()
nonairplane_mean = nonairplane_train_vals.mean().item()

nonairplane_color = "#F0883B"
airplane_color = "#4B4CBF"

fig, [ax1, ax2] = plt.subplots(nrows=2, ncols=1, figsize=(8, 3), dpi=300,
                               sharex='all')

# Bottom panel: one point per sampled image; y is the index within the sample.
ax2.scatter(navals, range(len(navals)), alpha=0.5, s=10, c=nonairplane_color)
ax2.scatter(airplane_train_vals, range(len(airplane_train_vals)), alpha=0.5, s=10, c=airplane_color)
ax2.get_yaxis().set_ticks([])
ax2.set_ylabel('Jitterplot')

# Top panel: kernel density estimates (all non-airplane values; the sampled
# airplane values, as before).
sns.distplot(nonairplane_train_vals, kde=True, hist=False,
             kde_kws={'linewidth': 3, "color": nonairplane_color},
             label="non-airplane imagenet images, mean=%.1f" % nonairplane_mean,
             ax=ax1)
sns.distplot(airplane_train_vals, kde=True, hist=False,
             kde_kws={'linewidth': 3, "color": airplane_color},
             label="imagenet airplane images, mean=%.1f" % airplane_mean,
             ax=ax1)
ax1.set_ylabel('Density')
ax1.get_yaxis().set_ticks([])
ax1.set_xlim(-50, 250)
# Dashed vertical lines at the group means.
ax1.axvline(airplane_mean, color=airplane_color, linewidth=1.5, linestyle='--')
ax1.axvline(nonairplane_mean, color=nonairplane_color, linewidth=1.5, linestyle='--')
# Highlight sample positions 1, 3 and 5 of each group with a larger '|' marker.
for i in [1,3,5]:
    ax2.scatter([nonairplane_train_vals[i]], [i], s=50,
           marker='|', edgecolors='none', color=nonairplane_color)
    ax2.scatter([airplane_train_vals[i]], [i], s=50,
           marker='|', edgecolors='none', color=airplane_color)

plt.savefig("scup.pdf", bbox_inches='tight')

In [None]:
# Inspect the tallied channel-150 peak activation of the first training image.
maxvals_train_150[0]

In [None]:
# Choose one activation threshold that separates airplane from non-airplane
# training images, and report the class-balanced accuracy it achieves.

# Airplane (positive) and non-airplane (negative) values, each sorted ascending.
sort_pos = maxvals_train_150[airplane_train_examples].sort(0)[0]
sort_neg = maxvals_train_150[nonairplane_train_examples].sort(0)[0]
# Subsample the sorted negatives with a fixed stride down to len(sort_pos)
# values, then reverse them into descending order.
# NOTE(review): the stride is 0 (invalid slice step) if there are fewer
# negatives than positives.
unsort_neg = sort_neg[::len(sort_neg) // len(sort_pos)][:len(sort_pos)].flip(0)
# Count the positions where the ascending positive is still below the
# descending negative; the positive value at that index is the threshold.
# NOTE(review): the index equals len(sort_pos) (out of range) if every
# comparison is True.
thresh = sort_pos[(sort_pos < unsort_neg).sum()].item()
# Average of: fraction of positives at or above the threshold, and fraction
# of negatives below it.
acc = ((sort_pos >= thresh).sum().float() / len(sort_pos) +
       (sort_neg < thresh).sum().float() / len(sort_neg)).item() / 2
thresh, acc

In [None]:
# Show both values: a cell only displays its last expression, so they are
# combined into one tuple (training-set size, index of the first airplane image).
len(train_dataset), airplane_train_examples[0]

In [None]:
# For the first ten airplane training images: print the dataset index and the
# plotted activation value, then display the image.
for rank in range(10):
    example_index = airplane_train_examples[rank]
    print(example_index, airplane_train_vals[rank])
    picture = renormalize.as_image(train_dataset[example_index][0], source=dataset)
    display(picture)

In [None]:
# For the three highlighted non-airplane training images (positions 1, 3, 5):
# print the dataset index and activation value, then display the image.
for rank in (1, 3, 5):
    example_index = nonairplane_train_examples[rank]
    print(example_index, nonairplane_train_vals[rank])
    picture = renormalize.as_image(train_dataset[example_index][0], source=dataset)
    display(picture)