In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib as mpl
from importlib import reload
import IPython
# Figure defaults for the whole notebook: hairline strokes, no top/right spines.
mpl.rcParams.update({
    'lines.linewidth': 0.25,
    'axes.linewidth': 0.25,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

In [None]:
import torch, argparse, os, shutil, inspect, json, numpy, math
import netdissect
from netdissect.easydict import EasyDict
from experiment import intervention_experiment
from experiment import dissect_experiment
from experiment.intervention_experiment import sharedfile
from experiment.dissect_experiment import make_upfn
from experiment import setting
from netdissect import pbar, nethook, renormalize, parallelfolder, pidfile
from netdissect import upsample, tally, imgviz, imgsave, bargraph, show

# Gradients are switched off globally for every cell below.
torch.set_grad_enabled(False)

# choices are alexnet, vgg16, or resnet152.
# NOTE(review): the unit-count lookup in a later cell only lists vgg16 and alexnet.
args = EasyDict(model='vgg16', dataset='places', seg='netpqc', quantile=0.01, layer='conv5_3')
# Results directory name is built from the settings; '%d' truncates quantile*1000 to an integer.
resdir = 'results/%s-%s-%s-%s-%d' % (args.model, args.dataset, args.seg, args.layer, args.quantile*1000)
def resfile(f):
    """Return the path of file name ``f`` inside the results directory ``resdir``."""
    return os.path.join(resdir, f)

In [None]:
# Wrap the classifier with nethook so the chosen layer's activations are retained.
model = nethook.InstrumentedModel(setting.load_classifier(args.model).cuda())
# NOTE(review): the layer path is spelled out rather than derived from args.layer;
# confirm the naming before switching args.model.
layername = 'features.conv5_3'
model.retain_layer(layername)
# Load the dataset named in args so that it always matches the name baked into resdir.
dataset = setting.load_dataset(args.dataset)
upfn = make_upfn(args, dataset, model, layername)
sample_size = len(dataset)
percent_level = 1 - args.quantile

In [None]:
# Classifier labels
from urllib.request import urlopen
from netdissect import renormalize

# synset_url = 'http://gandissect.csail.mit.edu/models/categories_places365.txt'
# classlabels = [r.split(' ')[0][3:] for r in urlopen(synset_url).read().decode('utf-8').split('\n')]
classlabels = dataset.classes
# Load the segmenter named in args so that it always matches the name baked into resdir.
segmodel, seglabels, segcatlabels = setting.load_segmenter(args.seg)
renorm = renormalize.renormalizer(dataset)

In [None]:
# Use the segmodel for segmentations.  With broden, we could use ground truth instead.
def compute_conditional_indicator(batch, *args):
    """Pair thresholded unit activations with a segmentation for one batch.

    Runs ``model`` on the batch, reads the retained ``layername`` activations,
    upsamples them with ``upfn`` and binarizes them against ``level_at_99``.
    The indicator and the ``segmodel`` segmentation are passed to
    ``tally.conditional_samples`` and its result is returned.

    NOTE(review): ``level_at_99`` is not assigned in any cell visible here.
    Presumably it is a per-unit activation threshold at ``percent_level``
    (which is defined earlier but otherwise unused) — confirm it exists
    before this cell runs on a fresh kernel.
    """
    image_batch = batch.cuda()
    # Segmentation is computed on the renormalized images.
    seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
    # The forward pass is run only for its side effect of filling the retained layer.
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    hacts = upfn(acts)
    iacts = (hacts > level_at_99).float() # indicator
    return tally.conditional_samples(iacts, seg)
# Tally the conditional mean of the indicator over the dataset (cached in resdir).
pbar.descnext('condi99')
condi99 = tally.tally_conditional_mean(
    compute_conditional_indicator, dataset,
    sample_size=sample_size, num_workers=3, pin_memory=True,
    cachefile=resfile('condi99.npz'))
iou_99 = tally.iou_from_conditional_indicator_mean(condi99)

# One record per unit: (concept index, concept name, concept category, best IoU),
# where "best" is the maximum of iou_99 along its first axis.
unit_label_99 = []
for bestiou, concept in zip(*iou_99.max(0)):
    unit_label_99.append(
        (concept.item(), seglabels[concept], segcatlabels[concept], bestiou.item()))

# Keep only the units whose best IoU exceeds 0.025.
label_list = [rec[1] for rec in unit_label_99 if rec[3] > 0.025]
labelcat_list = [rec[2] for rec in unit_label_99 if rec[3] > 0.025]
display(IPython.display.SVG(dissect_experiment.graph_conceptcatlist(labelcat_list)))

In [None]:
# Per-class metrics of the unmodified model (cached in the shared directory).
baseline_cachefile = sharedfile(
    'pra-%s-%s/pra_baseline.npz' % (args.model, args.dataset))
(baseline_precision, baseline_recall,
 baseline_accuracy, baseline_ba) = intervention_experiment.test_perclass_pra(
    model, dataset, cachefile=baseline_cachefile)

In [None]:
# Mean of the baseline recall values.
baseline_recall.mean()

In [None]:
# Indices of the units whose best-matching concept name begins with 'seat'.
seat_units = [
    u for u, record in enumerate(unit_label_99)
    if record[1].startswith('seat')
]
seat_units

In [None]:
# Display the baseline accuracy values returned by test_perclass_pra.
baseline_accuracy

In [None]:
# Load all single-unit ablation accuracy
num_units = dict(vgg16=512, alexnet=256)[args.model]
# Four (unit, class) tables, one per metric returned by test_perclass_pra.
(single_unit_ablation_acc, single_unit_ablation_ba,
 single_unit_ablation_precision, single_unit_ablation_recall) = (
    torch.zeros(num_units, len(classlabels)) for _ in range(4))

for unit in range(num_units):
    unit_cachefile = sharedfile(
        'pra-%s-%s/pra_ablate_unit_%d.npz' % (args.model, args.dataset, unit))
    unit_precision, unit_recall, unit_acc, unit_ba = (
        intervention_experiment.test_perclass_pra(
            model, dataset,
            layername=layername,
            ablated_units=[unit],
            cachefile=unit_cachefile))
    # Row `unit` holds the per-class metrics obtained with that single unit ablated.
    single_unit_ablation_precision[unit] = unit_precision
    single_unit_ablation_recall[unit] = unit_recall
    single_unit_ablation_acc[unit] = unit_acc
    single_unit_ablation_ba[unit] = unit_ba

In [None]:
# Spot check: the accuracy entry for ablated unit 196 (row) and class index 70 (column).
single_unit_ablation_acc[196,70]

In [None]:
# For every class, list (lowest first) the units whose ablation changes the
# class's balanced accuracy by -0.01 or less relative to the baseline.
for classnum in range(len(classlabels)):
    class_ba = single_unit_ablation_ba[:, classnum]
    for unit in class_ba.sort(0)[1]:
        diff = class_ba[unit] - baseline_ba[classnum]
        if diff > -0.01:
            # Units are visited in ascending order, so nothing further qualifies.
            break
        print('%s: unit %d (%s) -> %.3f' % (
            classlabels[classnum], unit, unit_label_99[unit][1],
            diff ))

In [None]:
# Save and reload
ablation_npz = resfile('unit_ablation.npz')
numpy.savez(ablation_npz,
            single_unit_ablation_ba=single_unit_ablation_ba,
            baseline_ba=baseline_ba)
print(os.path.abspath(ablation_npz))
data = numpy.load(ablation_npz)
sua = torch.from_numpy(data['single_unit_ablation_ba'])
base = torch.from_numpy(data['baseline_ba'])
# Repeat the listing from the reloaded arrays to check the round trip.
for classnum in range(len(classlabels)):
    reloaded_class_ba = sua[:, classnum]
    for unit in reloaded_class_ba.sort(0)[1]:
        diff = reloaded_class_ba[unit] - base[classnum]
        if diff > -0.01:
            break
        print('%s: unit %d (%s) -> %.3f' % (
            classlabels[classnum], unit, unit_label_99[unit][1],
            diff ))

In [None]:
# Load the 'train' split of places; it is used below to rank units.
train_dataset = setting.load_dataset('places', 'train')

In [None]:
# Baseline per-class metrics on the training split (sample_size is passed through).
ttv_baseline_precision, ttv_baseline_recall, ttv_baseline_accuracy, ttv_baseline_ba  = (
    intervention_experiment.test_perclass_pra(
        model, train_dataset,
        sample_size=sample_size,
        cachefile=sharedfile('ttv-pra-%s-%s/pra_train_baseline.npz'
            % (args.model, args.dataset))))

# Balanced accuracy per (unit, class) on the training split with one unit ablated.
# The loop bound follows num_units so it always matches the table's first dimension.
ttv_single_unit_ablation_ba = torch.zeros(num_units, len(classlabels))
for unit in range(num_units):
    pbar.descnext('test unit %d' % unit)
    _, _, _, ablation_ba = intervention_experiment.test_perclass_pra(
            model, train_dataset,
            layername=layername,
            ablated_units=[unit],
            sample_size=sample_size,
            cachefile=
                sharedfile('ttv-pra-%s-%s/pra_train_ablate_unit_%d.npz' %
                (args.model, args.dataset, unit)))
    ttv_single_unit_ablation_ba[unit] = ablation_ba

# Entry k of each list is the change in the class's balanced accuracy (measured on
# `dataset`) after k units are ablated together; entry 0 is the no-ablation case.
ttv_ablate_salient = [0.0]
ttv_ablate_nonsalient = [0.0]

classnum = classlabels.index('ski_resort')
# Units ordered by training-split balanced accuracy for this class, lowest first.
# The ordering does not depend on num_salient, so it is computed once.
ranked_units = ttv_single_unit_ablation_ba[:, classnum].sort(0)[1]
for num_salient in range(1, num_units):
    # Ablate the first num_salient units of the ranking.
    unitlist = ranked_units[:num_salient]
    _, _, _, testba = intervention_experiment.test_perclass_pra(model, dataset,
            layername=layername,
            ablated_units=unitlist,
            cachefile=sharedfile('ttv-pra-%s-%s/pra_val_ablate_classunits_%s_ba_%d.npz' %
                                 (args.model, args.dataset, classlabels[classnum], len(unitlist))))
    # print([(classlabels[c], d.item()) for d, c in list(zip(*(testba - baseline_ba).sort(0)))[:5]])
    ttv_ablate_salient.append((testba[classnum] - baseline_ba[classnum]).item())

    # Ablate the last num_salient units of the ranking.
    unitlist = ranked_units[-num_salient:]
    _, _, _, testba2 = intervention_experiment.test_perclass_pra(model, dataset,
            layername=layername,
            ablated_units=unitlist,
            cachefile=sharedfile('ttv-pra-%s-%s/pra_val_ablate_classunits_%s_worstba_%d.npz' %
                                 (args.model, args.dataset, classlabels[classnum], len(unitlist))))
    # print([(classlabels[c], d.item()) for d, c in list(zip(*(testba2 - baseline_ba).sort(0)))[:5]])
    ttv_ablate_nonsalient.append((testba2[classnum] - baseline_ba[classnum]).item())
    

In [None]:
import matplotlib.ticker as mtick

# Every figure from here on uses the dark theme.
plt.style.use('dark_background')

classnum = classlabels.index('ski_resort')
b = baseline_ba[classnum].item()

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5.9, 3.1), dpi=300)
# Horizontal reference line at the unablated accuracy.
ax.axhline(b, color='gray', linewidth=0.5, linestyle='-')

# The recorded values are changes relative to b; shift them back to absolute accuracy.
# The second curve is closed with 0.5 as its final point.
salient_curve = [b + delta for delta in ttv_ablate_salient]
nonsalient_curve = [b + delta for delta in ttv_ablate_nonsalient] + [0.5]
ax.plot(salient_curve, linewidth=1, label='Removing the most important units together',
       c="#4B4CBF")
ax.plot(nonsalient_curve, linewidth=1, label='Removing all but the most important units',
       c="#F0883B")

# Highlight three operating points: nothing removed, 20 removed, 492 removed.
# (An alternative pair of highlights at 2 and 510 was considered.)
ax.scatter([0, 20, 492], [b, salient_curve[20], nonsalient_curve[492]],
       color=['#55B05B', "#4B4CBF", "#F0883B"], zorder=10,s=50)

# Padding inside the last two labels keeps 492 and 512 from overlapping.
ax.set_xticks([0, 20, 128, 256, 384, 492, 512])
ax.set_xticklabels([0, 20, 128, 256, 384, '492   ', '    512'])
ax.set_yticks([0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))
ax.set_xlabel('number of conv5_3 units removed together')
ax.set_ylabel('single-class accuracy')
ax.set_ylim(0.5, 1.0)
ax.legend(loc='center right', bbox_to_anchor=(0.95, 0.3))
plt.savefig("ice-one-class.pdf", bbox_inches='tight')

print(b, salient_curve[20], nonsalient_curve[492])


In [None]:
# Sizes of the training-split dataset and of the dataset loaded at the top.
len(train_dataset), len(dataset)

In [None]:
# Baseline b plus the recorded change after the first four ranked units are ablated together.
b + ttv_ablate_salient[4]

In [None]:
# Per-class results, one entry per class in classlabels order:
#   *_ba     balanced accuracy of the class after the ablation
#   *_ablate change relative to the unablated baseline
ttv_best_ba = []
ttv_worst_ba = []
ttv_base_ba = []
ttv_best_ablate = []
ttv_worst_ablate = []

# Change this to 256 to see supplemental results
num_best_units = 20

for classnum in range(len(classlabels)):
    # Units ordered by training-split balanced accuracy for this class, lowest first.
    ranked_units = ttv_single_unit_ablation_ba[:, classnum].sort(0)[1]

    # Ablate the first num_best_units units of the ranking.
    unitlist = ranked_units[:num_best_units]
    _, _, _, testba = intervention_experiment.test_perclass_pra(model, dataset,
            layername=layername,
            ablated_units=unitlist,
            cachefile=sharedfile('ttv-pra-%s-%s/pra_val_ablate_classunits_%s_ba_%d.npz' %
                                 (args.model, args.dataset, classlabels[classnum], len(unitlist))))
    # print([(classlabels[c], d.item()) for d, c in list(zip(*(testba - baseline_ba).sort(0)))[:5]])
    ttv_best_ba.append(testba[classnum].item())
    ttv_base_ba.append(baseline_ba[classnum].item())
    ttv_best_ablate.append((testba[classnum] - baseline_ba[classnum]).item())

    # Ablate every remaining unit; slicing from num_best_units works for any unit count.
    unitlist = ranked_units[num_best_units:]
    _, _, _, testba2 = intervention_experiment.test_perclass_pra(model, dataset,
            layername=layername,
            ablated_units=unitlist,
            cachefile=sharedfile('ttv-pra-%s-%s/pra_val_ablate_classunits_%s_worstba_%d.npz' %
                                 (args.model, args.dataset, classlabels[classnum], len(unitlist))))
    # print([(classlabels[c], d.item()) for d, c in list(zip(*(testba2 - baseline_ba).sort(0)))[:5]])
    ttv_worst_ba.append(testba2[classnum].item())
    ttv_worst_ablate.append((testba2[classnum] - baseline_ba[classnum]).item())

In [None]:
import seaborn as sns
import random
import matplotlib.ticker as mtick

fig, [ax1, ax2] =plt.subplots(nrows=2, ncols=1, figsize=(6.5, 3), dpi=300,
                              sharex='all')

# Bottom panel: one row per class, one dot per condition.
ax2.scatter(ttv_worst_ba, range(len(ttv_worst_ba)), alpha=0.5, s=10, c='#F0883B')
ax2.scatter(ttv_best_ba, range(len(ttv_best_ba)), alpha=0.5, s=10, c='#4B4CBF')
ax2.scatter(ttv_base_ba, range(len(ttv_base_ba)), alpha=0.5, s=10, c='#55B05B')
ax2.get_yaxis().set_ticks([])
ax2.set_ylabel('Scene class')

# Top panel: dashed vertical lines at the mean of each condition.
best_mean = numpy.array(ttv_best_ba).mean().item()
worst_mean = numpy.array(ttv_worst_ba).mean().item()
base_mean = numpy.array(ttv_base_ba).mean().item()
ax1.axvline(best_mean, color='#B6B6F2', linewidth=1.5, linestyle='--')
ax1.axvline(worst_mean, color='#F2CFB6', linewidth=1.5, linestyle='--')
ax1.axvline(base_mean, color='#B6F2BA', linewidth=1.5, linestyle='--')

# NOTE(review): sns.distplot is deprecated in recent seaborn releases; it is kept
# here so the density curves stay exactly as they were.
sns.distplot(ttv_base_ba, kde=True, hist=False, kde_kws = {'linewidth': 3, "color":'#55B05B'},
             label="No units removed, mean class accuracy=%.1f%%" % (100*base_mean),
            ax=ax1)
sns.distplot(ttv_best_ba, kde=True, hist=False, kde_kws = {'linewidth': 3, "color":"#4B4CBF"},
             label="%d units most damaging to class, mean=%.1f%%" %
             (num_best_units, 100*best_mean),
            ax=ax1)
# The "other units" count is taken from the ablation table rather than a literal 512.
sns.distplot(ttv_worst_ba, kde=True, hist=False, kde_kws = {'linewidth': 3, "color":"#F0883B"},
             label="All %d other units removed, mean=%.1f%%" %
             (ttv_single_unit_ablation_ba.shape[0] - num_best_units, 100*worst_mean),
            ax=ax1)
ax1.set_ylabel('Density')
ax1.get_yaxis().set_ticks([])
ax1.set_ylim([0, 20])
ax2.set_xlabel('Balanced single-class accuracy when sets of units are removed')
ax1.set_xlim(0.48, 1.02)
# Pin the tick positions and format them as percentages, so the labels cannot
# drift away from the values they annotate if the auto-locator picks other ticks.
ax2.set_xticks([0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
ax2.xaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0))
legend = ax1.legend(loc='upper right', bbox_to_anchor=(1, 1.1))
# legend.get_frame().set_facecolor('none')
legend.get_frame().set_edgecolor('none')
plt.savefig("ice-all-classes.pdf", bbox_inches='tight')

In [None]:
import seaborn as sns
import random

fig, [ax1, ax2] =plt.subplots(nrows=2, ncols=1, figsize=(8, 3), dpi=300,
                              sharex='all')

# Bottom panel: per-class change in balanced accuracy; the faint green column marks zero.
ax2.scatter(ttv_worst_ablate, range(len(ttv_worst_ablate)), alpha=0.5, s=10, c='#F0883B')
ax2.scatter(ttv_best_ablate, range(len(ttv_best_ablate)), alpha=0.5, s=10, c='#4B4CBF')
ax2.scatter([0] * len(ttv_best_ablate), range(len(ttv_best_ablate)), alpha=0.05, s=10, c='#55B05B')
ax2.get_yaxis().set_ticks([])
ax2.set_ylabel('Scene class')

# Top panel: dashed vertical lines at the mean change of each condition.
ax1.axvline(numpy.array(ttv_best_ablate).mean().item(), color='#B6B6F2', linewidth=1.5, linestyle='--')
ax1.axvline(numpy.array(ttv_worst_ablate).mean().item(), color='#F2CFB6', linewidth=1.5, linestyle='--')

# NOTE(review): the legend means are taken from the absolute accuracies
# (ttv_best_ba / ttv_worst_ba), not from the plotted changes — confirm this is intended.
# NOTE(review): sns.distplot is deprecated in recent seaborn releases.
sns.distplot(ttv_best_ablate, kde=True, hist=False, kde_kws = {'linewidth': 3, "color":"#4B4CBF"},
             label="Removed most-important %d units, mean=%.1f%%" %
             (num_best_units, 100*numpy.array(ttv_best_ba).mean().item()),
             ax=ax1)
sns.distplot(ttv_worst_ablate, kde=True, hist=False, kde_kws = {'linewidth': 3, "color":"#F0883B"},
             label="Kept only most-important %d units, mean=%.1f%%" %
             (num_best_units, 100*numpy.array(ttv_worst_ba).mean().item()),
            ax=ax1)
ax1.axvline(0, color='#55B05B', linewidth=3, linestyle='-', label='No units removed')
ax1.set_ylabel('Density')
ax1.get_yaxis().set_ticks([])
ax1.set_ylim([0, 7])
ax2.set_xlabel('Change in balanced single-class accuracy when sets of units are removed')
ax1.set_xlim(-0.57, 0.3)
# Pin the tick positions so each label is attached to the value it names. All ticks
# lie inside the x-limits set above, so the view is not widened.
ax2.set_xticks([-0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3])
ax2.set_xticklabels(['-0.5', '-0.4', '-0.3', '-0.2', '-0.1', 'no change', '0.1', '0.2', '0.3'])
#legend = ax1.legend()
legend = ax1.legend(loc='upper left', bbox_to_anchor=(-0.01, 1.25))

legend.get_frame().set_facecolor('none')
legend.get_frame().set_edgecolor('none')

In [None]:
# Persist the training-split ablation table together with its baseline.
numpy.savez(
    resfile('ttv_unit_ablation.npz'),
    single_unit_ablation_ba=ttv_single_unit_ablation_ba,
    baseline_ba=ttv_baseline_ba)

# For each class, record the units whose ablation changes balanced accuracy by
# -0.005 or less relative to the training-split baseline, lowest first.
results = {}
for classnum in range(len(classlabels)):
    class_ba = ttv_single_unit_ablation_ba[:, classnum]
    unitlist = []
    for unit in class_ba.sort(0)[1]:
        diff = class_ba[unit] - ttv_baseline_ba[classnum]
        if diff > -0.005:
            # Ascending order: no later unit can qualify.
            break
        print('%s: unit %d -> %.3f' % (
            classlabels[classnum], unit, diff ))
        unitlist.append({'unit': unit.item(), 'val_acc': diff.item()})
    results[classlabels[classnum]] = unitlist
with open(resfile('ttv_unit_ablation.json'), 'w') as f:
    json.dump(results, f, indent=1)
    

In [None]:
def calculate_topN_accuracy(img, cls):
    """Return cumulative top-N hit indicators for a batch.

    Class indices are ordered per image from highest to lowest score. The
    returned float tensor has the same shape as the prediction; column k is
    1.0 when the label ``cls`` appears among the k+1 highest-scoring classes
    and 0.0 otherwise.
    """
    logits = model(img.cuda())
    # sort() is ascending, so flip the index tensor to get best-first order.
    ranked = logits.sort(1)[1].flip(1)
    labels = cls.cuda()[:, None].expand(ranked.shape)
    hits = (ranked == labels).float()
    return hits.cumsum(1)

# Average the cumulative hit indicators over the dataset (cached in the shared directory).
topn_cachefile = sharedfile(
    'pra-%s-%s/topn_accuracy.npz' % (args.model, args.dataset))
topN_acc = tally.tally_mean(
    calculate_topN_accuracy, dataset,
    batch_size=100, pin_memory=True,
    cachefile=topn_cachefile)
    

In [None]:
# Mean of the tallied top-N result.
topN_acc.mean()

In [None]:
# Show the first ten dataset images, each next to its index.
# Note: this rebinds `results`, which an earlier cell used for a dict.
results = []
for i, (im, c) in enumerate(dataset):
    if i >= 10:
        break
    # renormalize.as_image converts the stored tensor for display.
    results.append([[i], [renormalize.as_image(im)]])
show(results)


In [None]:
# Select the class explicitly. Without this the cell silently used whatever value
# `classnum` was left at by the last loop that ran. The counts 4/20/492 match the
# ski_resort figure above — NOTE(review): confirm ski_resort is the intended class.
classnum = classlabels.index('ski_resort')
# Units ordered by training-split balanced accuracy for this class, lowest first.
ranked_units = ttv_single_unit_ablation_ba[:, classnum].sort(0)[1]
for num_salient in [4, 20, 492]:
    # Ablate the last num_salient units of the ranking and report mean recall.
    unitlist = ranked_units[-num_salient:]
    test_pre, test_rec, test_acc, testba2 = intervention_experiment.test_perclass_pra(model, dataset,
            layername=layername,
            ablated_units=unitlist,
            cachefile=sharedfile('ttv-pra-%s-%s/pra_val_ablate_classunits_%s_worstba_%d.npz' %
                                 (args.model, args.dataset, classlabels[classnum], len(unitlist))))
    print(num_salient, test_rec.mean())
# Reference value: one over the number of classes.
print(1.0 / len(classlabels))


In [None]:
# Names, then values, of the eight highest entries of test_rec
# (left over from the last ablation run above).
recall_sorted, recall_order = test_rec.sort(0)
print([classlabels[i] for i in recall_order.flip(0)[:8]])
print(recall_sorted.flip(0)[:8])

In [None]:
# Shape of the IoU table.
iou_99.shape

In [None]:
# Shape of the training-split ablation table (allocated as num_units x len(classlabels)).
ttv_single_unit_ablation_ba.shape

In [None]:
# Shape of the training-split baseline balanced accuracy.
ttv_baseline_ba.shape

In [None]:
# Units whose ablation gives the lowest balanced-accuracy change for at least one class.
ttv_ba_change = ttv_single_unit_ablation_ba - ttv_baseline_ba[None, :]
worst_unit_per_class = ttv_ba_change.min(0)[1]
important_units = torch.unique(worst_unit_per_class)
len(important_units)

In [None]:
# Mean of the per-unit maximum IoU, restricted to important_units.
iou_99.max(0)[0][important_units].mean()

In [None]:
# Mean of the per-unit maximum IoU over every unit, for comparison.
iou_99.max(0)[0].mean()

In [None]:
# A unit counts as "important" to a class when it is among the important_cutoff
# units whose ablation lowers that class's balanced accuracy the most.
important_cutoff = 4
ttv_ba_change = ttv_single_unit_ablation_ba - ttv_baseline_ba[None,:]
top_units_per_class = ttv_ba_change.sort(0)[1][:important_cutoff]
# Count, per unit, the number of classes it is important to. minlength guarantees
# one entry per unit even when the highest-numbered units never appear; otherwise
# the result would be shorter than the unit axis and could not mask iou_99 later.
unit_importance = torch.bincount(
    top_units_per_class.reshape(-1), minlength=ttv_ba_change.shape[0])
most_important_units = (unit_importance >= 7).nonzero()[:,0]
print(len(most_important_units))
print(most_important_units)




In [None]:
imp_vals = torch.unique(unit_importance)
fig, ax = plt.subplots(figsize=(5,2.5), dpi=300)
tail = 7

# Per-unit maximum IoU, grouped by importance count: one group for every observed
# count below `tail`, then one pooled group for counts of `tail` or more.
best_iou = iou_99.max(0)[0]
groups = [best_iou[unit_importance == i] for i in imp_vals if i < tail]
groups.append(best_iou[unit_importance >= tail])

xlist = [i for i in imp_vals.numpy() if i < tail] + [tail]
ylist = [g.mean().item() for g in groups]
# Standard error of the mean within each group.
yerr = [g.std().item() / math.sqrt(len(g)) for g in groups]

ax.bar(xlist, ylist, yerr=yerr, color="#4B4CBF",
    error_kw=dict(lw=1, capsize=5, capthick=1)
)
# Annotate every bar with the number of units in its group.
for x, g in zip(xlist, groups):
    plt.text(x=x, y=1e-3, s='n=%d' % len(g), size=7.6, ha='center', va='baseline', color='white')
ax.set_xlabel('number of classes for which unit is important')
ax.set_ylabel('mean IoU$_{u,c}$')
ax.set_xticks(xlist)
ax.set_xticklabels([x if x < tail else '$\\geq %d$' % tail for x in xlist ])

In [None]:
imp_vals = torch.unique(unit_importance)
fig, ax = plt.subplots(figsize=(6,1.7), dpi=300)
tail = 7

# Per-unit maximum IoU, grouped by importance count: one group for every observed
# count below `tail`, then one pooled group for counts of `tail` or more.
best_iou = iou_99.max(0)[0]
groups = [best_iou[unit_importance == i] for i in imp_vals if i < tail]
groups.append(best_iou[unit_importance >= tail])

xlist = [i for i in imp_vals.numpy() if i < tail] + [tail]
ylist = [g.mean().item() for g in groups]
# Standard error of the mean within each group.
yerr = [g.std().item() / math.sqrt(len(g)) for g in groups]

ax.barh(xlist,ylist, xerr=yerr, color="#4B4CBF",
    error_kw=dict(lw=1, capsize=2, capthick=1)
)
# Annotate every bar with the number of units in its group.
for x, g in zip(xlist, groups):
    plt.text(y=x, x=1e-3, s='n=%d' % len(g), size=7.6, ha='left', va='center', color='white')
# An earlier label variant ('... unit is top-%d imp') was overwritten immediately
# and never shown; only the effective label is set.
ax.set_ylabel('classes for which\nunit is important')
ax.set_xlabel('mean IoU$_{u,c}$ of units')
ax.set_yticks(xlist)
# Raw string: '\g' is not a valid escape sequence in an ordinary string literal.
ax.set_yticklabels([x if x < tail else r'$\geq %d$' % tail for x in xlist ])
plt.savefig("ice-vs-iou.pdf", bbox_inches='tight')

In [None]:
# Positions (rank, class) at which unit 150 appears in the per-class ordering of units.
((ttv_single_unit_ablation_ba - ttv_baseline_ba[None,:]).sort(0)[1] == 150).nonzero()

In [None]:
# Order units by their balanced-accuracy change averaged over all classes, lowest first.
mean_ba_change = (ttv_single_unit_ablation_ba - ttv_baseline_ba[None,:]).mean(1)
sorted_change, sorted_units = mean_ba_change.sort(0)
# ydata: best IoU of each unit (last field of its record), in that order.
ydata = [unit_label_99[u][-1] for u in sorted_units]
xdata = sorted_change.numpy()

bsize = 32
ybatch, yerr, xbatch, xerr = [], [], [], []
# Walk the ordered units in consecutive groups of bsize. The bound follows the data
# length, and a short final group uses its own size for the standard error.
for start in range(0, len(ydata), bsize):
    ychunk = numpy.asarray(ydata[start:start + bsize])
    xchunk = xdata[start:start + bsize]
    ybatch.append(numpy.mean(ychunk))
    yerr.append(numpy.std(ychunk) / numpy.sqrt(len(ychunk)))
    # Negated so that a larger x means a larger accuracy drop.
    xbatch.append(-numpy.mean(xchunk))
    xerr.append(numpy.std(xchunk) / numpy.sqrt(len(xchunk)))

fig, ax = plt.subplots(figsize=(5.8, 5.5), dpi=300)
ax.plot(xbatch, ybatch, marker='o', color="#4B4CBF", lw=2)
ax.errorbar(xbatch, ybatch, yerr=yerr, xerr=xerr, capsize=2, capthick=1, color='black')
ax.set_ylabel('mean iou')
ax.set_xlabel('mean class importance\n(class importance of unit averaged over all classes)')
legend = ax.legend(['%d units grouped by mean class importance' % bsize,
                    'error bars show standard error'], loc='upper left')
legend.get_frame().set_facecolor('none')
legend.get_frame().set_edgecolor('none')

In [None]:
# Display the sorted mean balanced-accuracy changes computed in the previous cell.
xdata

In [None]:
from collections import defaultdict
# Map each unit to the class names it is important to. This uses its own name
# instead of rebinding `unit_importance`, which holds the per-unit count tensor
# that the earlier plotting cells index with.
classes_by_unit = defaultdict(list)
top_units_by_class = (
    (ttv_single_unit_ablation_ba - ttv_baseline_ba[None,:])
    .sort(0)[1][:important_cutoff].permute((1, 0)))
for cls, units in enumerate(top_units_by_class):
    for u in units:
        classes_by_unit[u.item()].append(classlabels[cls])
# One record per unit; units that never rank in the top get an empty list.
unit_importance_records = []
for u in range(ttv_single_unit_ablation_ba.shape[0]):
    unit_importance_records.append(dict(unit=u, important_to=classes_by_unit[u]))
with open(resfile('importance.json'), 'w') as f:
    json.dump(dict(importance=unit_importance_records), f)

In [None]:
import networkx as nx
# Graph linking units (nodes named by number) to the classes they are important
# to, restricted to units that are important to more than one class.
G = nx.Graph()
for record in unit_importance_records:
    if len(record['important_to']) > 1:
        unit_node = '%d' % record['unit']
        G.add_node(unit_node)
        for c in record['important_to']:
            # add_edge creates the class node when it is missing, so the former
            # bookkeeping set (which was never filled) is not needed.
            G.add_edge(unit_node, c)
fig, ax = plt.subplots(1, 1, figsize=(50, 50));
nx.draw_networkx(G, ax=ax)
plt.show()

In [None]:
# print(nx.minimum_cycle_basis(G))
# Print a cycle found when searching from the 'soccer_field' node.
# NOTE(review): presumably raises if that node is absent or no cycle exists — confirm.
print(nx.find_cycle(G, "soccer_field"))

In [None]:
# Find cycles of length 3
# adjacent_cls[c]: every class that shares at least one important unit with c
# (including c itself).
adjacent_cls = defaultdict(set)
for record in unit_importance_records:
    classes_here = record['important_to']
    for c in classes_here:
        adjacent_cls[c].update(classes_here)
set_of_sets = [set(record['important_to']) for record in unit_importance_records]

# Visit units from highest to lowest maximum IoU.
units_by_iou = iou_99.max(0)[0].sort(0)[1].flip(0)
for u in units_by_iou:
    r = unit_importance_records[u]
    print()
    print('unit %d:' % r['unit'])
    here = r['important_to']
    for i, c in enumerate(here):
        for d in here[i + 1:]:
            # Classes adjacent to both c and d that this unit does not itself cover.
            candidates = adjacent_cls[c].intersection(adjacent_cls[d]).difference(here)
            # Discard a candidate when some single unit already covers all three classes.
            for e in list(candidates):
                if any(s.issuperset([c, d, e]) for s in set_of_sets):
                    candidates.remove(e)
            if candidates:
                print(c, d, candidates)


In [None]:
# Measure covariance between (max) units and classes

def compute_maxact_and_pred(batch, *args):
    """Return per-unit peak activations and softmax predictions for a batch.

    The first result takes, for every image and channel of the retained
    ``layername`` activations, the maximum over all remaining positions.
    The second is the softmax of the model output along dimension 1.
    """
    image_batch = batch.cuda()
    logits = model(image_batch)
    preds = torch.softmax(logits, dim=1)
    acts = model.retained_layer(layername)
    num_images, num_channels = acts.shape[0], acts.shape[1]
    # Collapse everything after the channel axis, then take the maximum.
    maxacts = acts.view(num_images, num_channels, -1).max(2)[0]
    return maxacts, preds
# Tally the cross-covariance of the two outputs over the dataset (cached in resdir).
actpredcov_cachefile = resfile('actpredcov.npz')
actpredcov = tally.tally_cross_covariance(
    compute_maxact_and_pred, dataset,
    sample_size=sample_size, num_workers=3, pin_memory=True,
    cachefile=actpredcov_cachefile)


In [None]:
# Correlation derived from the tallied cross-covariance.
actpredcov.correlation()

In [None]:
# Keep the correlation table for the cells below and show its shape.
apc = actpredcov.correlation()
apc.shape

In [None]:
# important unit per class
# Per-class ordering of units by balanced-accuracy change (lowest first), split
# into the first important_cutoff units and the rest; both are (class, rank).
unit_order = (ttv_single_unit_ablation_ba - ttv_baseline_ba[None,:]).sort(0)[1]
iupc = unit_order[:important_cutoff].permute((1, 0))
niupc = unit_order[important_cutoff:].permute((1, 0))
# Column of class indices, sized from the data instead of a literal 365.
class_index = torch.arange(iupc.shape[0])[:,None]
# Class whose important units include the lowest unit-class correlation.
negclass = apc[iupc, class_index].min(1)[0].min(0)[1]
apc[:,negclass][iupc[negclass]], iupc[negclass]

In [None]:
# All important-unit/class correlations, sorted ascending. The class index is
# sized from iupc instead of a literal 365.
apc[iupc, torch.arange(iupc.shape[0])[:,None]].contiguous().view(-1).sort(0)

In [None]:
# Importance is the negated balanced-accuracy change, so larger means more damaging.
imp = -(ttv_single_unit_ablation_ba - ttv_baseline_ba[None,:])
# Column of class indices, sized from the data instead of a literal 365.
class_index = torch.arange(iupc.shape[0])[:,None]
fig, ax = plt.subplots(figsize=(6, 6), dpi=300)
#ax.scatter(imp.view(-1), apc.view(-1))
# Orange: every unit-class pair outside the per-class important set.
ax.scatter(
    imp[niupc, class_index].contiguous().view(-1),
    apc[niupc, class_index].contiguous().view(-1),
    s=0.1, alpha=0.2, color="#F0883B")
# Blue: the important unit-class pairs.
ax.scatter(
    imp[iupc, class_index].contiguous().view(-1),
    apc[iupc, class_index].contiguous().view(-1),
    s=1, alpha=0.5, color="#4B4CBF")
ax.set_ylabel('correlation between unit and class')
ax.set_xlabel('importance of unit to class accuracy')
# Red outline drawn downward from (-0.025, 0): width 0.05, height -0.15.
ax.add_patch(
     mpl.patches.Rectangle(
        (-0.025, 0),
        0.05,
        -0.15,
        linewidth=1,edgecolor='r',facecolor='none'
     ) )
ax.text(0.07, 0.1, '%.1f%% of important-unit-class\ncorrelations are positive'
        % (100*(apc[iupc, class_index] > 0).sum().double() / iupc.numel()))
ax.text(0.03, -0.1, '%.1f%% of all unit-class\ncorrelations are negative'
        % (100*((apc < 0).sum().float() / apc.numel())))
plt.show()


In [None]:
# Fraction of all entries of apc that are negative.
(apc < 0).sum().float() / apc.numel()