In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib as mpl
from importlib import reload
import IPython
# Plot styling for the whole notebook: thin default lines and axis frame,
# and no top/right axis spines.
mpl.rcParams['lines.linewidth'] = 0.25
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.linewidth'] = 0.25

Set up experiment directory and settings

In [None]:
import torch, argparse, os, shutil, inspect, json, numpy
import netdissect
from netdissect.easydict import EasyDict
from netdissect import experiment
from netdissect.experiment import resfile
from netdissect import pbar, nethook, renormalize, parallelfolder, pidfile
from netdissect import upsample, tally, imgviz, imgsave, bargraph, show

# choices are alexnet, vgg16, or resnet152.
# These settings name the result directory below and are handed to the
# experiment.* helpers in later cells.
args = EasyDict(model='vgg16', dataset='places', seg='netpqc', layer=None)
resdir = 'results/%s-%s-%s' % (args.model, args.dataset, args.seg)
# Cache files requested through resfile(...) are registered against this directory.
# NOTE(review): exact resfile/set_result_dir semantics live in netdissect.experiment -- confirm.
experiment.set_result_dir(resdir)

load classifier model and dataset

In [None]:
# Build the classifier and derive the name of the layer to inspect from args.
model = experiment.load_model(args)
layername = experiment.instrumented_layername(args)
# Later cells read this layer back with model.retained_layer(layername).
model.retain_layer(layername)
dataset = experiment.load_dataset(args)
# upfn is applied to retained activations before they are compared with
# segmentations or turned into quantile samples (see the tally cells).
upfn = experiment.make_upfn(args, dataset, model, layername)
# Passed as sample_size to the tally calls in later cells.
sample_size = len(dataset)

print('Inspecting layer %s of model %s on %s' % (layername, args.model, args.dataset))

In [None]:
# Classifier labels
from urllib.request import urlopen
from netdissect import renormalize

percent_level = 0.995
classlabels = dataset.classes
renorm = renormalize.renormalizer(dataset, mode='zc')


def compute_samples(batch, *args):
    """Return the upsampled activations of `layername` for one batch.

    The batch is moved to the GPU and run through `model`; the retained
    activations are passed through `upfn`, then rearranged so that the
    channel axis (dimension 1 of the retained activations) comes last and
    flattened to a 2-D tensor with one column per channel.
    """
    model(batch.cuda())
    layer_acts = model.retained_layer(layername)
    channels = layer_acts.shape[1]
    upsampled = upfn(layer_acts)
    return upsampled.permute(0, 2, 3, 1).contiguous().view(-1, channels)


# Quantile statistics over the sampled activations, cached in rq.npz.
pbar.descnext('rq')
rq = tally.tally_quantile(
    compute_samples,
    dataset,
    sample_size=sample_size,
    r=8192,
    num_workers=100,
    pin_memory=True,
    cachefile=resfile('rq.npz'),
)
from netdissect import imgviz
# Visualizer thresholded at the `percent_level` quantile of each unit.
iv = imgviz.ImageVisualizer(
    (100, 100),
    source=dataset,
    quantiles=rq,
    level=rq.quantiles(percent_level),
)


In [None]:
# Sanity check on one image: run the classifier, compare ground truth with the
# prediction, show the image, and display the retained activation shape.
model.retain_layer(layername)
image_index = 0
# dataset[i] is indexed as (image tensor, class id); [None,...] adds a batch axis.
out = model(dataset[image_index][0][None,...].cuda())

print('gt', classlabels[dataset[image_index][1]])
# out.max(1)[1] holds the arg-max class index per batch element.
print('pred', classlabels[out.max(1)[1][0]])
display(renormalize.as_image(dataset[image_index][0], source=dataset))
model.retained_layer(layername).shape

In [None]:
# Class that the adversarial attacks below try to reach.
target_class = 'bedroom'
# Resolve the label to its integer class id. Unlike a manual search loop,
# list.index raises ValueError for an unknown label instead of silently
# leaving target_class_id undefined (or stale from an earlier run).
target_class_id = classlabels.index(target_class)
print(target_class_id, target_class)


In [None]:
# Collect up to `num_images` dataset indices whose ground-truth label and
# predicted label both differ from `target_class`; these are the images that
# will be attacked towards the target class.
num_images = 1000
image_index = 0
good_indices = []
# When True, only images whose ground truth is 'ski_resort' are considered.
filtering_source_class = False
# Stop at the end of the dataset instead of running off it with an IndexError.
num_candidates = len(dataset)
while len(good_indices) < num_images and image_index < num_candidates:
    gt_label = classlabels[dataset.images[image_index][1]]
    if filtering_source_class and gt_label != 'ski_resort':
        image_index += 1
        continue
    out = model(dataset[image_index][0][None,...].cuda())
    pred_label = classlabels[out.max(1)[1][0]]
    if gt_label != target_class and pred_label != target_class:
        good_indices.append(image_index)
    else:
        # Report images rejected because they already are (or look like) the target.
        print('image {:d} gt {:s} pred {:s}'.format(image_index, gt_label, pred_label))
    image_index += 1

if len(good_indices) == num_images:
    print('get {:d} images from {:d} candidates'.format(num_images, image_index))
else:
    print('only found {:d} of {:d} requested images in {:d} candidates'.format(
        len(good_indices), num_images, image_index))
    

In [None]:
import json, urllib.request

# Published dissection results for vgg16 / places / conv5_3.
dissect_url = 'http://dissect.csail.mit.edu/results/vgg16-places-netpqc-conv5_3-10/'
# Close each HTTP response as soon as its JSON payload has been parsed.
with urllib.request.urlopen(dissect_url + 'report.json') as response:
    unit_names = json.load(response)
with urllib.request.urlopen(dissect_url + 'ttv_unit_ablation.json') as response:
    data = json.load(response)


# Ablation records for the target class; each record is read below through
# its 'unit' and 'val_acc' keys.
units = data[target_class]

unit_ids = []
unit_acc = []
for unit in units:
    unit_id = unit['unit']
    acc = unit['val_acc']
    unit_ids.append(unit_id)
    unit_acc.append(acc)
    label = unit_names['units'][unit_id]['label']
    # NOTE(review): acts_mean_average is assigned by the "loading results" cell
    # further down the notebook, so that cell must have been run first.
    print(unit_id, acc, label, acts_mean_average[unit_id])


In [None]:
dataset.images

In [None]:
print(unit_ids)

In [None]:
# NOTE(review): no cell above assigns acts_mean at top level; it is first set
# by cells further down, so this only works after one of those has been run.
print(acts_mean)

In [None]:
# For the first ten attacked images: reload the saved attack result, recompute
# the retained activations for the original and adversarial image, and display
# the images together with per-unit masks/heatmaps for the units in unit_ids.
# NOTE(review): results_dir, pickle_load and acts_mean_average are assigned by
# cells further down the notebook; those cells must have been run first.
for good_index in good_indices[:10]:
    result_path = os.path.join(results_dir, 'image_{:d}_target_{:s}.pkl'.format(good_index, target_class))
    data = pickle_load(result_path)
    image_id = data['image_id']
    target_id = data['target_id']
    ori_image = data['ori']
    adv_image = data['adv']
    image_ori = dataset[image_id][0]
    out = model(image_ori[None,...].cuda())
    acts_ori = model.retained_layer(layername).cpu()

    adv = renormalize.as_tensor(adv_image, source='pt', mode='imagenet')[None,...]
    pred_adv = model(adv.cuda())

    acts_adv = model.retained_layer(layername).cpu()
    # Mean activation change per unit (averaged over the two spatial axes).
    acts_mean = (acts_adv - acts_ori).mean(dim=(2, 3)).numpy()[0]

    diff_image = adv_image - ori_image
    # Map the signed perturbation into [0, 1] for display: normalise to
    # [-1, 1], halve, then centre on 0.5 (same mapping as the single-image
    # visualization cell). Without the 0.5 factor the values span [-0.5, 1.5].
    diff_peak = abs(diff_image).max()
    if diff_peak > 0:
        diff_image = diff_image/diff_peak*0.5+0.5
    else:
        # Identical images: show a flat mid-grey instead of dividing by zero.
        diff_image = diff_image+0.5
    img = renormalize.as_image(diff_image, source='pt')
    display(renormalize.as_image(ori_image, source='pt'))
    display(renormalize.as_image(adv_image, source='pt'))
    display(renormalize.as_image(diff_image, source='pt'))
    display(show.blocks(
        [[[
           'unit {:03d} vac drop {:.3f} diff_mean {:.3f}'.format(u, unit_acc[u_idx], acts_mean_average[u]),
           '{:s} {:3.3f} diff = {:.3f}'.format(unit_names['units'][u]['label'], unit_names['units'][u]['iou'], acts_mean[u]),
           [iv.masked_image(image_ori, acts_ori, (0, u))],
           [iv.heatmap(acts_ori, (0, u), mode='nearest')],
           [iv.masked_image(adv, acts_adv, (0, u))],
           [iv.heatmap(acts_adv, (0, u), mode='nearest')],
          ]
          for u_idx, u in enumerate(unit_ids)]
        ],
    ))

In [None]:
# Persist the selected image indices; the context manager guarantees the file
# is flushed and closed even if serialisation fails.
with open('ski_resort_to_bedroom.json', 'w') as indices_file:
    json.dump(good_indices, indices_file)

In [None]:
import foolbox
import torch
import torchvision.models as models
import numpy as np
# import cv2
print(foolbox.__version__)
from foolbox.criteria import TargetClass, TargetClassProbability
import numpy as np
import os
import pickle
from netdissect import pbar

def mkdir(path):
    """create a single empty directory if it didn't exist
    Parameters:
        path (str) -- a single directory path

    Missing parent directories are created as well. Nothing happens when the
    path already exists.
    """
    if not os.path.exists(path):
        # exist_ok closes the window in which another process creates the
        # directory between the existence check and makedirs.
        os.makedirs(path, exist_ok=True)

def pickle_load(file_name):
    """Return the object stored in the pickle file `file_name`.

    Only use on files written by this notebook: unpickling untrusted data can
    execute arbitrary code.
    """
    with open(file_name, 'rb') as handle:
        return pickle.load(handle)


def pickle_save(file_name, data):
    """Write `data` to `file_name` using the highest available pickle protocol."""
    with open(file_name, 'wb') as out_file:
        pickle.dump(data, out_file, protocol=pickle.HIGHEST_PROTOCOL)

# test_stop = 1
# Per-channel normalisation constants handed to foolbox as preprocessing.
mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))  #0.475, 0.441, 0.408
std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))

# Wrap the classifier for foolbox; inputs are expected in [0, 1].
fmodel = foolbox.models.PyTorchModel(model, bounds=(0, 1), num_classes=365, preprocessing=(mean, std))
results_dir = 'results/adv/vgg16/{:s}_images'.format(target_class)
print(results_dir)
mkdir(results_dir)
print('target id {:d}, class {:s}'.format(target_class_id, target_class))
# Run a targeted Carlini-Wagner L2 attack on every selected image and save the
# original/adversarial pair. Images that already have a result file are skipped,
# so the cell can be interrupted and resumed.
for position, good_index in enumerate(pbar(good_indices)):
    result_path = os.path.join(results_dir, 'image_{:d}_target_{:s}.pkl'.format(good_index, target_class))
    if os.path.isfile(result_path):
        continue
    image = dataset[good_index][0]
    image = renormalize.as_tensor(image, source=dataset, mode='pt').numpy()
    pred = np.argmax(fmodel.forward_one(image))
    attack = foolbox.attacks.CarliniWagnerL2Attack(fmodel, criterion=TargetClass(target_class_id))
    adversarial = attack(image, pred)
    if adversarial is None:
        # NOTE(review): foolbox is documented to return None when no
        # adversarial is found -- confirm for the installed version. Skip the
        # image rather than crash in forward_one / torch.from_numpy; cells that
        # later load results for every good index will not find a file for it.
        print('attack failed for image {:d}, skipping'.format(good_index))
        continue
    adv_label = np.argmax(fmodel.forward_one(adversarial))
    ori_image = torch.from_numpy(image).float()
    adv_image =  torch.from_numpy(adversarial).float()
    pickle_save(result_path, {'image_id': good_index, 'target_id': target_class_id, 'ori': ori_image, 'adv': adv_image})
    # Report progress by loop position; good_index is a dataset index and is
    # not comparable to len(good_indices).
    if position % 50 == 0:
        print('process {:d}/{:d}'.format(position, len(good_indices)))
        print('predicted class', pred, classlabels[pred])
        print('adversarial class', adv_label, classlabels[adv_label])

visualize activations for single layer of single image

In [None]:
import pickle
import numpy as np
def pickle_load(file_name):
    """Load and return the pickled object stored at `file_name`.

    This repeats the definition given in the attack cell so that this cell can
    be run on its own. Do not use on untrusted files: unpickling can execute
    arbitrary code.
    """
    with open(file_name, 'rb') as source:
        loaded = pickle.load(source)
    return loaded

# loading results 
# For every attacked image, reload the saved original/adversarial pair, run
# both through the model, and record the mean absolute activation change of
# each unit in the retained layer. The per-unit averages over all images end
# up in acts_mean_average.
results_dir = 'results/adv/vgg16/{:s}_images'.format(target_class)
os.makedirs(results_dir, exist_ok=True)
# test_stop = 70
# target_class = 'ski_resort'
acts_mean_abs_all = []
for good_index in good_indices:
    result_path = os.path.join(results_dir, 'image_{:d}_target_{:s}.pkl'.format(good_index, target_class))
    data = pickle_load(result_path)
    image_id = data['image_id']
    target_id = data['target_id']
    ori_image = data['ori']
    adv_image = data['adv']
    # Forward pass on the original image; it also refreshes the retained
    # activations that are read into acts_ori just below.
    pred_ori = model(dataset[image_id][0][None,...].cuda())

    image_ori = dataset[image_id][0]
#     out = model(image_ori[None,...].cuda())
    acts_ori = model.retained_layer(layername).cpu()

    # Convert the saved adversarial image to the 'imagenet' normalisation
    # mode and add a batch axis before feeding it to the model.
    adv = renormalize.as_tensor(adv_image, source='pt', mode='imagenet')[None,...]
    pred_adv = model(adv.cuda())

    acts_adv = model.retained_layer(layername).cpu()
    
    # Absolute change averaged over dims 2 and 3; [0] drops the batch axis.
    acts_mean_abs = (acts_adv - acts_ori).abs().mean(dim=(2, 3)).numpy()[0]
    acts_mean_abs_all.append(acts_mean_abs[..., np.newaxis])
#     if good_index >= test_stop:
#         break
# One column per image after concatenation; average across images (axis 1).
acts_mean_abs_all = np.concatenate(acts_mean_abs_all, axis=1)
acts_mean_average = np.mean(acts_mean_abs_all, axis=1)
# Ten units with the largest average change, in descending order.
sort_ids =  np.argsort(acts_mean_average)[::-1][:10] 
print(acts_mean_average[sort_ids])
print(sort_ids)
   

In [None]:
def Diff(li1, li2):
    """Return the distinct items of `li1` that do not occur in `li2`.

    The result is a list built from a set, so duplicates are dropped and the
    order is whatever set iteration yields.
    """
    excluded = set(li2)
    remaining = set(li1).difference(excluded)
    return list(remaining)

# Compare the average activation change of the ablation-selected units
# (unit_ids) with the overall average and with all remaining units.
print(acts_mean_average[unit_ids])
print(unit_ids)
print(np.mean(acts_mean_average))
print(np.mean(acts_mean_average[unit_ids]))
# Derive the unit count from the statistics array instead of hard-coding 512,
# so the cell still works for a layer with a different number of units.
remain_ids = Diff(range(len(acts_mean_average)), unit_ids)
print(np.mean(acts_mean_average[remain_ids]))


In [None]:
# Show the masked image and activation heatmap of the first (up to 32) units
# for dataset[image_index], restricted to units whose IoU exceeds iou_threshold.
_ = model(dataset[image_index][0][None,...].cuda())
image_ori = dataset[image_index][0]

acts_ori = model.retained_layer(layername).cpu()

# NOTE(review): iou_threshold is assigned in a later cell, and data['images']
# does not match what the cells above store in `data` (the ablation JSON or a
# pickled attack result) -- confirm which report this is meant to read.
display(show.blocks(
    [[['unit {0:03d}'.format(u), '{:s} {:f}'.format(data['images'][u]['label'], data['images'][u]['iou']),
       [iv.masked_image(image_ori, acts_ori, (0, u))],
       [iv.heatmap(acts_ori, (0, u), mode='nearest')]]
      # The unit count comes from acts_ori; `acts` only exists as a local
      # inside the compute_* helpers and is not defined at cell level.
      for u in range(min(acts_ori.shape[1], 32)) if data['images'][u]['iou'] > iou_threshold]
    ],
))




## Perform adversarial attacks


In [None]:
import foolbox
import torch
import torchvision.models as models
import numpy as np
import cv2
print(foolbox.__version__)

## Adversarial attack a pre-trained PyTorch model

In [None]:
%%time
from foolbox.criteria import TargetClass, TargetClassProbability

# Attack a single dataset image with a targeted Carlini-Wagner L2 attack and
# report the label before and after.
image_index=0

label = dataset[image_index][1]
image = dataset[image_index][0]
# Convert to the 'pt' normalisation mode as a numpy array for foolbox.
image = renormalize.as_tensor(image, source=dataset, mode='pt').numpy()
mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))  #0.475, 0.441, 0.408
std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))

fmodel = foolbox.models.PyTorchModel(model, bounds=(0, 1), num_classes=365, preprocessing=(mean, std))
print(image.max(), image.min(), image.shape, image.dtype)
print('label', label, classlabels[label])

pred = np.argmax(fmodel.forward_one(image))
print('predicted class', pred, classlabels[pred])
# attack = foolbox.attacks.DeepFoolAttack(fmodel)#, criterion=TargetClass(100))  #LBFGSAttack FGSM
# The target here is the fixed class id 100, not target_class_id.
attack = foolbox.attacks.CarliniWagnerL2Attack(fmodel, criterion=TargetClass(100))  #LBFGSAttack FGSM
adversarial = attack(image, pred)
adv_label = np.argmax(fmodel.forward_one(adversarial))
print('adversarial class', adv_label, classlabels[adv_label])

In [None]:
# visualize results 
from PIL import Image

# Display the original image, the adversarial image, and their difference.
ori_image = torch.from_numpy(image).float()
adv_image =  torch.from_numpy(adversarial).float()
diff_image = adv_image - ori_image
# Normalise the signed difference to [-1, 1], then map it into [0, 1].
diff_image = diff_image/abs(diff_image).max()*0.5+0.5
img = renormalize.as_image(diff_image, source='pt')
display(renormalize.as_image(ori_image, source='pt'))
display(renormalize.as_image(adv_image, source='pt'))
display(renormalize.as_image(diff_image, source='pt'))



In [None]:
from netdissect import imgviz

# Retained activations for the adversarial image of the single-image attack.
adv = renormalize.as_tensor(adv_image, source='pt', mode='imagenet')[None,...]
_ = model(adv.cuda())
acts_adv = model.retained_layer(layername).cpu()
iou_threshold = 0.025
display_units = 32
# One of 'id', 'iou', 'diff_n' (largest decrease first), 'diff_p' (largest increase first).
sort_method = 'diff_p'
num_units = acts_adv.shape[1]
# Mean activation change per unit; the display cells below print it for every
# sort mode, so compute it once up front.
acts_mean = (acts_adv - acts_ori).mean(dim=(2, 3)).numpy()[0]
# Strings are compared with ==; `is` tests object identity and only appears to
# work because of literal interning (SyntaxWarning on Python >= 3.8).
if sort_method == 'id':
    sort_ids = range(min(num_units, display_units))
elif sort_method == 'iou':
    ious = [data['images'][i]['iou'] for i in range(num_units)]
    sort_ids = np.argsort(ious)[::-1][:display_units]
elif sort_method == 'diff_n':
    sort_ids = np.argsort(acts_mean)[:display_units]
elif sort_method == 'diff_p':
    sort_ids = np.argsort(acts_mean)[::-1][:display_units]
else:
    # Fail loudly instead of silently reusing a stale sort_ids.
    raise ValueError('unknown sort_method: {!r}'.format(sort_method))

print(sort_ids)


In [None]:

# For each unit in sort_ids with IoU above iou_threshold, show the masked
# image and heatmap for the original image followed by the adversarial one.
# NOTE(review): reads unit labels/IoU from data['images'] -- confirm that
# `data` holds a report with an 'images' list when this cell is run.
display(show.blocks(
    [[[
       'unit {0:03d}'.format(u),
       '{:s} {:3.3f} diff = {:3.3f}'.format(data['images'][u]['label'],
                          data['images'][u]['iou'], acts_mean[u]),
       [iv.masked_image(image_ori, acts_ori, (0, u))],
       [iv.heatmap(acts_ori, (0, u), mode='nearest')],
       [iv.masked_image(adv, acts_adv, (0, u))],
       [iv.heatmap(acts_adv, (0, u), mode='nearest')],
      ]
      for u in sort_ids if data['images'][u]['iou'] > iou_threshold]
    ],
))

In [None]:

# Same display as the preceding cell (the code is identical): original vs
# adversarial masks and heatmaps for the units in sort_ids above iou_threshold.
display(show.blocks(
    [[[
       'unit {0:03d}'.format(u),
       '{:s} {:3.3f} diff = {:3.3f}'.format(data['images'][u]['label'],
                          data['images'][u]['iou'], acts_mean[u]),
       [iv.masked_image(image_ori, acts_ori, (0, u))],
       [iv.heatmap(acts_ori, (0, u), mode='nearest')],
       [iv.masked_image(adv, acts_adv, (0, u))],
       [iv.heatmap(acts_adv, (0, u), mode='nearest')],
      ]
      for u in sort_ids if data['images'][u]['iou'] > iou_threshold]
    ],
))

## Collect quantile statistics

First, unconditional quantiles over the activations.  We will upsample them to 56x56 to match with segmentations later.


In [None]:
def compute_samples(batch, *args):
    """Produce quantile samples for one batch.

    Runs `model` on the GPU copy of `batch`, upsamples the retained
    activations of `layername` with `upfn`, and returns them as a 2-D tensor
    whose columns are the channels (dimension 1) of the retained activations.
    """
    model(batch.cuda())
    retained = model.retained_layer(layername)
    channels_last = upfn(retained).permute(0, 2, 3, 1).contiguous()
    return channels_last.view(-1, retained.shape[1])


# Unconditional activation quantiles, cached in rq.npz.
pbar.descnext('rq')
rq = tally.tally_quantile(
    compute_samples,
    dataset,
    sample_size=sample_size,
    r=8192,
    num_workers=100,
    pin_memory=True,
    cachefile=resfile('rq.npz'),
)

In [None]:
def compute_conditional_samples(batch, *args):
    """Pair upsampled activations with a segmentation of the same batch.

    The batch is run through `model`, segmented by `segmodel` (on the
    renormalised images, with downsample=4), and the upsampled activations
    plus segmentation are handed to tally.conditional_samples.
    """
    gpu_batch = batch.cuda()
    model(gpu_batch)
    retained = model.retained_layer(layername)
    segmentation = segmodel.segment_batch(renorm(gpu_batch), downsample=4)
    return tally.conditional_samples(upfn(retained), segmentation)


# Conditional quantiles, cached in condq.npz.
condq = tally.tally_conditional_quantile(
    compute_conditional_samples,
    dataset,
    batch_size=1,
    num_workers=30,
    pin_memory=True,
    sample_size=sample_size,
    cachefile=resfile('condq.npz'),
)

## Visualize Units

Collect topk stats first.

In [None]:
def compute_image_max(batch, *args):
    """Return, per image and channel, the maximum retained activation.

    The retained activations are flattened over everything after the first
    two dimensions and reduced with max along that flattened axis.
    """
    model(batch.cuda())
    retained = model.retained_layer(layername)
    flattened = retained.view(retained.shape[0], retained.shape[1], -1)
    return flattened.max(2)[0]


# Top-k statistics over the per-image maxima, cached in topk.npz.
pbar.descnext('topk')
topk = tally.tally_topk(
    compute_image_max,
    dataset,
    sample_size=sample_size,
    batch_size=50,
    num_workers=30,
    pin_memory=True,
    cachefile=resfile('topk.npz'),
)

Then we just need to run through and visualize the images.

In [None]:
pbar.descnext('unit_images')

# Visualizer thresholded at the `percent_level` quantile of each unit.
iv = imgviz.ImageVisualizer(
    (100, 100),
    source=dataset,
    quantiles=rq,
    level=rq.quantiles(percent_level),
)


def compute_acts(image_batch):
    """Run `model` on the GPU copy of the batch and return the retained activations."""
    model(image_batch.cuda())
    return model.retained_layer(layername)


# Masked images for the top 10 entries of `topk`, cached in top10images.npz.
unit_images = iv.masked_images_for_topk(
    compute_acts,
    dataset,
    topk,
    k=10,
    num_workers=30,
    pin_memory=True,
    cachefile=resfile('top10images.npz'),
)

In [None]:
# Spot-check the visualisations of a few units.
for u in [10, 20, 30, 40]:
    print('unit %d' % u)
    display(unit_images[u])

## Label Units

Collect 99.5 quantile stats.

In [None]:
# Use the segmodel for segmentations.  With broden, we could use ground truth instead.
def compute_conditional_indicator(batch, *args):
    """Return conditional samples of the thresholded activations for one batch.

    The batch is segmented with `segmodel`, run through `model`, and the
    upsampled activations are turned into a 0/1 indicator of exceeding
    `level_at_995` before being paired with the segmentation.
    """
    image_batch = batch.cuda()
    seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    hacts = upfn(acts)
    # NOTE(review): level_at_995 and segmodel are not assigned in any cell
    # visible above -- confirm where they are defined before running this.
    iacts = (hacts > level_at_995).float() # indicator
    return tally.conditional_samples(iacts, seg)
pbar.descnext('condi995')
# Conditional mean of the indicator, cached in condi995.npz.
condi995 = tally.tally_conditional_mean(compute_conditional_indicator,
        dataset, sample_size=sample_size,
        num_workers=3, pin_memory=True,
        cachefile=resfile('condi995.npz'))

In [None]:
# Label every unit with its best-matching concept at the fixed threshold.
iou_995 = tally.iou_from_conditional_indicator_mean(condi995)
# iou_995.max(0) yields (best IoU, concept index) pairs; each entry below is
# (concept id, concept label, concept category, IoU).
# NOTE(review): seglabels / segcatlabels are not assigned in any visible cell.
unit_label_995 = [
        (concept.item(), seglabels[concept], segcatlabels[concept], bestiou.item())
        for (bestiou, concept) in zip(*iou_995.max(0))]
# Keep only the labels of units whose IoU exceeds 0.04.
label_list = [label for concept, label, labelcat, iou in unit_label_995 if iou > 0.04]
display(IPython.display.SVG(experiment.graph_conceptlist(label_list)))
len(label_list)

In [None]:
from netdissect import experiment
# Concept categories of the units whose fixed-threshold IoU exceeds 0.04.
labelcat_list = [labelcat for concept, label, labelcat, iou in unit_label_995 if iou > 0.04]
display(IPython.display.SVG(experiment.graph_conceptcatlist(labelcat_list)))


In [None]:
# NOTE(review): unit_label_adaptive is built by the adaptive-labeling cell
# further down; this cell only works after that one has been run.
unit_label_adaptive

Show a few units with their labels

In [None]:
# unit_label_995 entries are (concept id, label, category, iou).
for u in [10, 20, 30, 40]:
    print('unit %d, label %s, iou %.3f' % (u, unit_label_995[u][1], unit_label_995[u][3]))
    display(unit_images[u])

Investigate secondary labels

In [None]:
if False:
    seg_cor = experiment.load_concept_correlation(args, segmodel, seglabels)

In [None]:
# Units ordered by descending IoU as (unit, concept, label, iou) tuples.
sorted_unit_label_995 = sorted([(unit, concept, label, iou)
    for unit, (concept, label, labelcat, iou) in enumerate(unit_label_995)
    ], key=lambda x: -x[-1])

# Disabled analysis: for each unit with IoU >= 0.02, look for one secondary
# concept that is not correlated with the primary one (per seg_cor) and also
# reaches IoU >= 0.02. seg_cor is only loaded in the (also disabled) cell above.
if False:
    count = 0
    double_count = 0
    multilabels = {}
    for unit, concept, label, iou in sorted_unit_label_995:
        if iou < 0.02:
            continue
        labels = [(label, iou)]
        # Walk the concepts for this unit from highest to lowest IoU.
        for c2 in iou_995[:, unit].sort(0, descending=True)[1]:
            if c2 == concept or seg_cor[c2, concept] > 0:
                continue
            if iou_995[c2, unit] < 0.02:
                break
            labels.append((seglabels[c2], iou_995[c2, unit]))
            # At most one secondary label is recorded per unit.
            break
        multilabels[unit] = labels
        count += 1
        double_count += 1 if len(labels) > 1 else 0
        print('unit %d: %s' % (unit, ', '.join(['%s: iou %.3f' % r for r in labels])))
        if len(labels) > 1 and label == 'bed':
            display(unit_images[unit])


In [None]:
# count / double_count are only assigned inside the disabled `if False:`
# analysis in the previous cell, so guard against them being undefined (or
# zero) instead of failing with NameError / ZeroDivisionError.
if 'count' in globals() and 'double_count' in globals() and count > 0:
    print('%d doubles out of %d (%.2f)' % (double_count, count, float(double_count) / count))
else:
    print('no secondary-label statistics available; enable the secondary-label analysis first')


# Adaptive labeling of units

using conditional quantiles and IQR

In [None]:
# Disabled earlier variant of the adaptive labeling in the next cell (cutoffs
# only go up to 0.15 here). Kept for reference; it never runs.
# NOTE(review): `math` is not imported in any visible cell; enabling this
# block would need `import math`.
if False:
    cutoff_candidates = 1 - torch.logspace(-3, math.log10(0.15), 50)
    unit_quantile_zero = rq.normalize(torch.zeros(256))
    unit_quantile_mask = cutoff_candidates[None,:] <= unit_quantile_zero[:,None]
    iqr_candidates = tally.iqr_from_conditional_quantile(condq, cutoff=cutoff_candidates)
    iou_candidates = tally.iou_from_conditional_quantile(condq, cutoff=cutoff_candidates)

    # Ignore records for which unit is zeroed
    iqr_candidates[unit_quantile_mask[:,None,:].expand(iqr_candidates.shape)] = 0
    best_adaptive_iqr, best_iqr_choice = iqr_candidates.max(2)

    # Obtain the iou at the max-iqr threshold
    iou_at_best_iqr = iou_candidates.gather(2, best_iqr_choice[...,None])[...,0]

    # Ignore records for which the max-iqr is achieved at 50-50 (typically "painted", or "building")
    masked_iou_at_best_iqr = iou_at_best_iqr.clone()
    # masked_iou_at_best_iqr[best_iqr_choice == len(cutoff_candidates) - 1] = 0.0
    masked_iou_at_best_iqr

    best_adaptive_iou, best_adaptive_match = masked_iou_at_best_iqr.max(1)
    for u in range(256):
        print(u, best_adaptive_match[u].item(),
              seglabels[best_adaptive_match[u]],
              best_adaptive_iou[u].item(),
              1 - cutoff_candidates[best_iqr_choice[u, best_adaptive_match[u]]].item())
        print(unit_label_995[u])
        display(unit_images[u])

In [None]:
import math  # needed for math.log10 below; no visible cell above imports it

# Get the best entity based on iqr
# Number of candidate quantile thresholds; also bounds the gather index below.
n_cutoffs = 50
cutoff_candidates = 1 - torch.logspace(-3, math.log10(0.5), n_cutoffs)
# NOTE(review): 256 is hard-coded here and in the loop below, while another
# cell of this notebook uses range(512) for unit ids -- confirm the unit count.
unit_quantile_zero = rq.normalize(torch.zeros(256))
# True where a candidate threshold lies at or below the unit's quantile of zero.
unit_quantile_mask = cutoff_candidates[None,:] <= unit_quantile_zero[:,None]
iqr_candidates = tally.iqr_from_conditional_quantile(condq, cutoff=cutoff_candidates)
iou_candidates = tally.iou_from_conditional_quantile(condq, cutoff=cutoff_candidates)

# Keep an untouched copy before entries are zeroed in place below.
unmasked_iqr_candidates = iqr_candidates.clone()
# Also ignore thresholds past zero relu threshold
iqr_candidates[unit_quantile_mask[:,None,:].expand(iqr_candidates.shape)] = 0
best_adaptive_iqr, best_iqr_choice = iqr_candidates.max(2)

# Get rid of cases where the max iqr is at the lowest threshold: look up the
# mask one step past the chosen threshold (clamped to the last candidate).
max_at_low_quantile_mask = (unit_quantile_mask[:,None,:]
    .expand(best_iqr_choice.shape + (unit_quantile_mask.shape[1],))
    .gather(2, (best_iqr_choice[:,:,None] + 1).clamp(0, n_cutoffs - 1)))[...,0]
iqr_candidates[max_at_low_quantile_mask[...,None].expand(iqr_candidates.shape)] = 0
best_adaptive_iqr, best_iqr_choice = iqr_candidates.max(2)


# IQR and IoU of every (unit, concept) pair at its best threshold.
iqr_at_best_threshold = iqr_candidates.gather(2, best_iqr_choice[...,None])[...,0]
iou_at_best_iqr = iou_candidates.gather(2, best_iqr_choice[...,None])[...,0]

# Per unit: the concept with the highest IQR, and the IoU at that match.
best_adaptive_iqr, best_adaptive_match = iqr_at_best_threshold.max(1)
best_adaptive_iou = iou_at_best_iqr.gather(1, best_adaptive_match[...,None])[...,0]
unit_label_adaptive = []
for u in range(256):
    unit_label_adaptive.append((
        best_adaptive_match[u].item(),
        seglabels[best_adaptive_match[u]],
        segcatlabels[best_adaptive_match[u]],
        best_adaptive_iou[u].item()
    ))
    # Only report units whose adaptive label differs from the fixed-threshold one.
    if unit_label_995[u][1] == seglabels[best_adaptive_match[u]]:
        continue
    print('adaptive', u, best_adaptive_match[u].item(),
          seglabels[best_adaptive_match[u]],
          best_adaptive_iou[u].item(),
          1 - cutoff_candidates[best_iqr_choice[u, best_adaptive_match[u]]].item())
    print('fixed 99.5%', unit_label_995[u])
    display(unit_images[u])

In [None]:
best_iqr_choice.shape

In [None]:
unit_quantile_mask.shape

In [None]:
max_at_low_quantile_mask = (unit_quantile_mask[:,None,:]
    .expand(best_iqr_choice.shape + (unit_quantile_mask.shape[1],))
    .gather(2, (best_iqr_choice[:,:,None] + 1).clamp(0, 49)))[...,0]
max_at_low_quantile_mask.shape

In [None]:
best_iqr_choice.shape

In [None]:
# IQR as a function of the candidate threshold for entry [65, 4] of iqr_candidates.
plt.plot(cutoff_candidates.numpy(), iqr_candidates[65,4].numpy(), linewidth=2)
plt.xlabel('quantile threshold for unit')
plt.ylabel('information quality ratio')


In [None]:
# IQR as a function of the candidate threshold for entry [3, 2] of iqr_candidates.
plt.plot(cutoff_candidates.numpy(), iqr_candidates[3, 2].numpy(), linewidth=2)
plt.xlabel('quantile threshold for unit')
plt.ylabel('information quality ratio')
plt.title('segments vs unit at various thresholds')

In [None]:
# NOTE(review): the axis label and title mention unit 254 and "Signboard",
# but the plotted slice is iqr_candidates[3, 12] -- confirm which is intended.
plt.plot(cutoff_candidates.numpy(), iqr_candidates[3, 12].numpy(), linewidth=2)
plt.xlabel('quantile threshold for unit 254')
plt.ylabel('information quality ratio')
plt.title('Signboard segments vs unit 254 at various thresholds')

In [None]:
from netdissect import experiment
# Concept categories of all adaptively labeled units (the IoU filter is disabled).
labelcat_list = [labelcat for concept, label, labelcat, iou in unit_label_adaptive]#  if iou > 0.02]
display(IPython.display.SVG(experiment.graph_conceptcatlist(labelcat_list)))

In [None]:
best_iqr_choice[22,1], cutoff_candidates[best_iqr_choice[22,1]]

In [None]:
unit_quantile_zero[22]

In [None]:
unit_quantile_zero = rq.normalize(torch.zeros(256))
cutoff_candidates[None,:] > unit_quantile_zero[:,None]


# Intervention experiment

Part 2.

# Linear Discriminant Analysis

LDA of concepts -> single class.  This will give us a baseline.

In [None]:
# Concept covariance statistics computed separately for images of the focus
# class and for all other images (5000 samples each, cached per class).
focus_class = 'church-outdoor'
focus_classnum = classlabels.index(focus_class)
rcov_in_class = experiment.concept_covariance(
    args, segmodel, seglabels, sample_size=5000,
    filter_class=lambda x: x == focus_classnum,
    cachefile=experiment.sharedfile('lda-%s/%s-rcov.npz' % (args.seg, focus_class)))
# Same statistics with the class filter negated.
rcov_out_of_class = experiment.concept_covariance(
    args, segmodel, seglabels, sample_size=5000,
    filter_class=(lambda x: x != focus_classnum),
    cachefile=experiment.sharedfile('lda-%s/%s-negate-rcov.npz' % (args.seg, focus_class)))


In [None]:
import copy

def rcov_scaled_to_unit_mean(rcov):
    """Return a copy of `rcov` rescaled so every nonzero mean becomes 1.

    Each dimension is scaled by 1 / mean; dimensions whose mean is exactly 0
    get scale 0. `_mean` is multiplied by the scale and `cmom2` by the scale
    along both of its axes.

    The input is left untouched. The copy is shallow, so the scaled tensors
    are assigned as new objects rather than updated in place, which would
    also have modified the tensors shared with the caller's `rcov`.
    """
    scaled = copy.copy(rcov)
    scale = 1 / rcov._mean
    # Avoid inf for zero-mean dimensions.
    scale[rcov._mean == 0] = 0
    scaled._mean = rcov._mean * scale
    scaled.cmom2 = rcov.cmom2 * scale[:, None] * scale[None, :]
    return scaled

def rcov_scaled_to_unit_std(rcov):
    """Return a copy of `rcov` rescaled by the reciprocal standard deviation.

    The standard deviation is the square root of the diagonal of
    `rcov.covariance()`; dimensions with zero deviation get scale 0. `_mean`
    is multiplied by the scale and `cmom2` by the scale along both axes.

    The input is left untouched. The copy is shallow, so the scaled tensors
    are assigned as new objects rather than updated in place, which would
    also have modified the tensors shared with the caller's `rcov`.
    """
    scaled = copy.copy(rcov)
    std = rcov.covariance().diag().sqrt()
    scale = std.reciprocal()
    # Avoid inf for constant dimensions.
    scale[std == 0] = 0
    scaled._mean = rcov._mean * scale
    scaled.cmom2 = rcov.cmom2 * scale[:, None] * scale[None, :]
    return scaled

# rcov_in_class = rcov_scaled_to_unit_std(rcov_in_class)
# rcov_out_of_class = rcov_scaled_to_unit_std(rcov_out_of_class)

In [None]:
from netdissect import lda
reload(lda)
# LDA transform from the in-class and out-of-class covariance statistics, with
# priors of 1/365 and 364/365 for the two groups.
trans = lda.lda_transform_from_covariances([rcov_in_class, rcov_out_of_class], shrinkage=0.1,
                                           prior=[1/365.0, 364/365.0])
# Print the 20 concepts with the largest value in column 0 of the transform,
# largest first, together with their in-class mean.
for c in trans[:,0].sort(0)[1][-20:].flip(0):
    print(seglabels[c], c.item(), trans[c, 0].item(), rcov_in_class.mean()[c].item())

## Salience of concept to class

Here we compute the mutual information between each visual concept and each scene class.
(We binarize the visual concept by thresholding at some number of pixels in the image; then we compute mutual information between this binary variable and each scene category.  For each scene-concept pair we choose the threshold that maximizes mutual information.)

Listed below are the top 3 visual concepts with highest mutual information to each class.

In [None]:
#salience = experiment.load_salience_matrix(args, segmodel, classlabels, seglabels)
# salience is indexed [class, concept] below.
salience = experiment.load_class_concept_correlation(args, segmodel, classlabels, seglabels)
for cls in [100, 200, 300]: # range(len(classlabels)):
    print(classlabels[cls])
    # sort(0) is ascending; the reversed slice takes the five highest values
    # (the text above this cell says top 3).
    for mi, concept in list(zip(*salience[cls].sort(0)))[:-5-1:-1]:
         print(mi.item(), concept.item(), seglabels[concept])

Here we print the salience information the other way: for each visual concept, we list the top scene categories with highest mutual information to that visual concept.

In [None]:
# For a few concepts, list the five classes with the highest salience value.
for concept in [5, 10, 15, 20, 25]:
    print(seglabels[concept])
    # sort(0) is ascending; the reversed slice takes the five highest entries.
    for mi, cls in list(zip(*salience[:,concept].sort(0)))[:-5-1:-1]:
        print(mi.item(), cls.item(), classlabels[cls])


## Per-class topk

visualization over subsets

In [None]:
def compute_image_max_per_class(batch, class_batch, index_batch, *args):
    """Yield per-class groups of per-image maximum activations for one batch.

    For every class id that occurs in `class_batch`, yields a tuple of
    (class id, per-channel maxima of the images of that class, their entries
    of `index_batch`). Maxima are taken over everything after the first two
    dimensions of the retained activations and moved to the CPU.
    """
    present_classes = class_batch.bincount().nonzero()
    model(batch.cuda())
    retained = model.retained_layer(layername)
    flattened = retained.view(retained.shape[0], retained.shape[1], -1)
    image_max = flattened.max(2)[0].cpu()
    for class_id in present_classes:
        selected = (class_batch == class_id)
        yield (class_id.item(), image_max[selected], index_batch[selected])


# Per-class top-k statistics, cached in topk_perclass.npz.
pbar.descnext('topk in each class')
topk_perclass = tally.tally_conditional_topk(
    compute_image_max_per_class,
    dataset,
    sample_size=sample_size,
    batch_size=50,
    num_workers=30,
    pin_memory=True,
    cachefile=resfile('topk_perclass.npz'),
)

# Visualization of class accuracy drop

Plotting per-class accuracy drop versus salience (mutual information) ordering

In [None]:
# Which classes are most salient to each concept?

def align_yaxis(ax1, ax2):
    """Align zeros of the two axes, zooming them out by same ratio.

    Both axes must expose get_ylim() and set_ylim(bottom, top). The axis whose
    upper limit takes the smaller fraction of its range is extended upwards,
    the other one downwards.
    """
    lower_ax, upper_ax = ax1, ax2
    lower_lim, upper_lim = ax1.get_ylim(), ax2.get_ylim()
    # Fraction of each axis range that lies at or below its upper limit.
    lower_top = lower_lim[1] / (lower_lim[1] - lower_lim[0])
    upper_top = upper_lim[1] / (upper_lim[1] - upper_lim[0])
    # Make sure "lower" really is the axis with the smaller fraction.
    if lower_top > upper_top:
        lower_ax, upper_ax = upper_ax, lower_ax
        lower_lim, upper_lim = upper_lim, lower_lim
        lower_top, upper_top = upper_top, lower_top

    # Combined span both plots need when their zoom levels are kept.
    tot_span = upper_top + 1 - lower_top

    new_top = lower_lim[0] + tot_span * (lower_lim[1] - lower_lim[0])
    new_bottom = upper_lim[1] - tot_span * (upper_lim[1] - upper_lim[0])
    lower_ax.set_ylim(lower_lim[0], new_top)
    upper_ax.set_ylim(new_bottom, upper_lim[1])

def plot_intervention_classes(concept, acc_diff, count=None, title=None):
    """Plot concept-class salience as bars with the class accuracy drop overlaid.

    Classes are ordered by decreasing ``salience[:, concept]``; only the
    first ``count`` are drawn (all of them when ``count`` is None).  The
    bars show the salience values and the orange line shows ``-acc_diff``
    for the same classes on a twin y axis.  Reads the notebook globals
    ``salience``, ``classlabels`` and ``seglabels``.
    """
    # For every concept column, class indices from most to least salient.
    ranked_classes = salience.sort(0)[1].flip(0)
    f, a1 = plt.subplots(1, 1, dpi=100, figsize=(10, 5))

    chosen = ranked_classes[:, concept][:count]
    x = list(range(len(chosen)))
    labels = [classlabels[cls] for cls in chosen]
    mutual_info = [salience[cls, concept] for cls in chosen]
    accuracy_drop = [-acc_diff[cls] for cls in chosen]

    a1.bar(x, mutual_info)
    a1.set_ylabel('Concept-class mutual information')
    a1.set_xlabel('Classes ordered by correlation with concept "%s"' % seglabels[concept])
    # Tick labels become unreadable once there are too many classes.
    if len(x) < 60:
        a1.set_xticks(x)
        a1.set_xticklabels([name.replace('_', ' ') for name in labels], rotation='vertical')

    a2 = a1.twinx()
    a2.plot(x, accuracy_drop, linewidth=2, color='orange')
    # The notebook-wide rcParams hide the right spine; the twin axis needs it.
    a2.spines["right"].set_visible(True)
    a2.set_ylabel('Class accuracy drop')
    if title is None:
        title = 'Effect of zeroing detector for %s' % (seglabels[concept])
    a2.set_title(title)
    align_yaxis(a2, a1)
    plt.show()



# Ablation of single concept detectors.

Question: when we ablate the group of units that detect a single concept, how does it affect the accuracy of each output class?

In [None]:
# Per-class accuracy of the unmodified model; every ablation below is compared to it.
pbar.descnext('baseline_acc')
# NOTE(review): the cachefile argument presumably lets repeated runs reuse the stored result - confirm.
baseline_accuracy = experiment.test_perclass_accuracy(model, dataset,
        cachefile=resfile('acc_baseline.npy'))
# Report the mean over classes.
pbar.print('baseline acc', baseline_accuracy.mean().item())

Above: recall which concepts are present and seem to have one or more units specific to that concept.

Below, pull out units to probe corresponding to the top 12 concepts.

In [None]:
ablation_size = 5

# For each top concept, zero its detector units together and compare the
# per-class accuracy with the unablated baseline.
for label, group in experiment.get_top_label_unit_groups(unit_label_995,
        size=ablation_size, num=5, min_iou=0.02):
    # Every group entry is a (unit, concept, iou) triple.
    concept = group[0][1]
    group_units = [unit for unit, _, _ in group]
    pbar.descnext('test %s' % label)
    ablation_accuracy = experiment.test_perclass_accuracy(model, dataset,
        layername=layername,
        ablated_units=group_units,
        cachefile=resfile('acc_ablate_%d_%s.npy' % (ablation_size, label)))
    pbar.print('ablate %s units of %s(%d) acc %.3f %s' %
            (len(group), label, concept, ablation_accuracy.mean().item(),
                args.model) )
    pbar.print(', '.join(['unit %d: iou %.3f' % (unit, iou)
        for unit, _, iou in group]))
    for unit in group_units:
        print('unit %d: %s' % (unit, ', '.join(['%s: iou %.3f' % r for r in multilabels[unit]])))

    # Show the image panel of the first unit in the group.
    unit = group[0][0]
    display(unit_images[unit])
    # Which classes are most damaged?  (Most negative accuracy change first.)
    acc_diff = ablation_accuracy - baseline_accuracy
    for cls in acc_diff.sort(0)[1][:5]:
        pbar.print('%s(%d) (mi %.4f): acc %.2f -> %.2f' % (
            classlabels[cls], cls,
            salience[cls, concept],
            baseline_accuracy[cls], ablation_accuracy[cls]))
    # The title reports the number of units actually ablated instead of a fixed "5".
    plot_title = 'Effect of zeroing %d units (%s %s detectors)' % (
        len(group), args.model, label)
    plot_concept = seglabels.index(label)
    # Full class range first, then a zoom on the 50 most salient classes.
    plot_intervention_classes(plot_concept, acc_diff, title=plot_title)
    plot_intervention_classes(plot_concept, acc_diff, count=50, title=plot_title)
    pbar.print()


In [None]:
# Inspect which keys `multilabels` holds (displayed as the cell result).
multilabels.keys()

# Ablation of single units.

Question: when we ablate a single unit, how does it affect accuracy of each output class?

In [None]:
# (unit, label, iou) for the 300 units with the highest IoU, best first.
top_iou_units = sorted([(unit, label, iou)
        for unit, (concept, label, labelcat, iou) in enumerate(unit_label_995)],
        key=lambda x: -x[-1])[:300]
# Only unit 48 is examined here; widen the filter to inspect other units.
for unit, label, iou in [r for r in top_iou_units if r[0] == 48]:
    # Use this unit's own concept id.  Without this, `concept` would be
    # whatever value an earlier cell's loop happened to leave behind.
    concept = unit_label_995[unit][0]
    pbar.descnext('test unit %d' % unit)
    ablation_accuracy = experiment.test_perclass_accuracy(model, dataset,
        layername=layername,
        ablated_units=[unit],
        cachefile=resfile('acc_ablate_unit_%d.npy' % (unit)))
    pbar.print('ablate unit %d (%s iou %.3f) acc %.3f %s' %
            (unit, label, iou, ablation_accuracy.mean().item(),
                args.model) )
    display(unit_images[unit])
    # Which classes are most damaged?  (Most negative accuracy change first.)
    acc_diff = ablation_accuracy - baseline_accuracy
    for cls in acc_diff.sort(0)[1][:10]:
        pbar.print('%s(%d) (mi %.4f): acc %.2f -> %.2f' % (
            classlabels[cls], cls,
            salience[cls, concept],
            baseline_accuracy[cls], ablation_accuracy[cls]))
        display(iv.masked_image_for_conditional_topk(compute_acts, dataset, topk_perclass, cls.item(), unit))
    plot_title = 'Effect of zeroing unit %d (%s %s detector, iou %.3f)' % (
        unit, args.model, label, iou)
    # Full class range first, then a zoom on the 50 most salient classes.
    plot_intervention_classes(seglabels.index(label), acc_diff,
            title=plot_title)
    plot_intervention_classes(seglabels.index(label), acc_diff,
            count=50,
            title=plot_title)
    pbar.print()

## Load all single-unit ablation perclass accuracy matrix

In [None]:
# One row per unit: the per-class accuracy measured with only that unit ablated.
single_unit_ablation_acc = torch.zeros(num_units, len(classlabels))

for unit in range(num_units):
    unit_cachefile = resfile('acc_ablate_unit_%d.npy' % (unit))
    single_unit_ablation_acc[unit] = experiment.test_perclass_accuracy(
        model, dataset,
        layername=layername,
        ablated_units=[unit],
        cachefile=unit_cachefile)

In [None]:
# Accuracy change per (unit, class) relative to the unablated baseline.
ablation_delta = single_unit_ablation_acc - baseline_accuracy
# For each class take the best and the worst change over all units,
# then average those extremes over the classes.
best_change_per_class = ablation_delta.max(0)[0]
worst_change_per_class = ablation_delta.min(0)[0]
best_change_per_class.mean(), worst_change_per_class.mean()

# Focus on single discriminative class

In [None]:
# Pick the scene class to study and look up its index in the dataset's class list.
focus_class = 'mosque-outdoor'
clsnum = dataset.classes.index(focus_class)
# Display the index as the cell result.
clsnum

Recall on church images is 41%.

In [None]:
# Baseline accuracy entry for the class index `clsnum` (displayed as the cell result).
baseline_accuracy[clsnum]

In [None]:
# Load the per-concept LDA vector for the focus class.
# NOTE(review): it is indexed as discrimination[c, 0] later in the notebook, so it is
# presumably shaped (num_concepts, 1) - confirm against experiment.load_lda_vector.
discrimination = experiment.load_lda_vector(focus_class, args, segmodel, classlabels, seglabels, shrinkage=0.1)

Here are the top 20 concepts that are most salient to the focus class, ranked by the LDA discrimination vector loaded above.

In [None]:
# Concept indices ordered from the highest to the lowest discrimination score.
top_concepts = discrimination.sort(0)[1].flip(0)[:20]
for concept in top_concepts:
    print(seglabels[concept], discrimination[concept].item())

In [None]:
# The ten units whose individual ablation lowers the focus class accuracy the most
# (ascending sort, so the most negative changes come first).
most_damaging_units = ablation_delta[:, clsnum].sort(0)[1][:10]
for unit in most_damaging_units:
    concept, label, labelcat, iou = unit_label_995[unit]
    damage = ablation_delta[unit, clsnum]
    print('unit %d (%s, iou %.3f) causes damage %.3f' % (unit, label, iou, damage))
    display(unit_images[unit])
    display(iv.masked_image_for_conditional_topk(
        compute_acts, dataset, topk_perclass, clsnum, unit.item()))


In [None]:
# Display the per-unit label records; entries unpack as (concept, label, labelcat, iou).
unit_label_995

# Visualize units


In [None]:
def plot_twin(triples, count=None, title=None, dpi=100, barlabel=None, linelabel=None,
              label_ticks=True, figsize=(10, 5)):
    """Draw a bar chart and a line chart over the same items on twin y axes.

    ``triples`` is a sequence of ``(bar value, line value, label)`` entries.
    They are drawn in descending order of the whole triple, so the bar
    value is the primary sort key, and only the first ``count`` are kept
    (all of them when ``count`` is None).  Labels are used as x tick
    labels when ``label_ticks`` is true.
    """
    # Pairing each triple with its index makes the descending order total.
    ranked = sorted((t, i) for i, t in enumerate(triples))[::-1]
    shown = [t for t, _ in ranked[:count]]
    x = list(range(len(shown)))
    bars = [t[0] for t in shown]
    lines = [t[1] for t in shown]
    labels = [t[2] for t in shown]

    f, a1 = plt.subplots(1, 1, dpi=dpi, figsize=figsize)
    f.patch.set_facecolor('white')
    a1.bar(x, bars)
    if barlabel is not None:
        a1.set_ylabel(barlabel)
        a1.set_xlabel('Ordered by %s' % barlabel)
    if label_ticks:
        a1.set_xticks(x)
        a1.set_xticklabels([name.replace('_', ' ') for name in labels], rotation='vertical')

    a2 = a1.twinx()
    a2.plot(x, lines, linewidth=2, color='orange')
    # The notebook-wide rcParams hide the right spine; the twin axis needs it.
    a2.spines["right"].set_visible(True)
    if linelabel is not None:
        a2.set_ylabel(linelabel)
    if title:
        a2.set_title(title)
    align_yaxis(a2, a1)
    plt.show()

In [None]:
# Switch the focus to the class at index 28 of the dataset's class list
# and reload its LDA vector.
focus_class = dataset.classes[28]
clsnum = dataset.classes.index(focus_class)
discrimination = experiment.load_lda_vector(focus_class, args, segmodel, classlabels, seglabels, shrinkage=0.1)

# Bare expression in the middle of the cell: it has no effect and is not displayed.
clsnum

def plot_intervention_units(clsnum, ablation_delta, unit_label_995, discrimination, count=None,
                            figsize=(20, 5), label_ticks=True):
    """Plot, per unit, the salience of its concept against the damage it causes.

    For every unit (one per row of ``ablation_delta``) the bar is the
    largest ``discrimination[c, 0]`` among the concepts ``c`` that the unit
    matches with IoU above 0.02, falling back to the unit's single best-IoU
    concept when none clears that threshold.  The line is the negated
    accuracy change ``-ablation_delta[unit, clsnum]``.

    The class name used in the title and axis labels is derived from
    ``clsnum``, so the plot no longer depends on the global ``focus_class``.
    ``unit_label_995`` is accepted for call compatibility but not used.
    Reads the notebook globals ``iou_995``, ``seglabels`` and ``dataset``.
    """
    class_name = dataset.classes[clsnum]
    triples = []
    # Cover every unit present in ablation_delta instead of a fixed 256.
    for unit in range(ablation_delta.shape[0]):
        matching_concepts = (iou_995[:, unit] > 0.02).nonzero()[:,0]
        if len(matching_concepts) == 0:
            # No concept clears the threshold: use the best-matching one.
            matching_concepts = [iou_995[:, unit].max(0)[1]]
        relevance, relevant_concept = max([(discrimination[c, 0], c) for c in matching_concepts])
        bar = relevance.item()
        label = '%s (%d)' % (seglabels[relevant_concept], unit)
        # Negated so that a larger value means more damage.
        line = -ablation_delta[unit, clsnum].item()
        triples.append((bar, line, label))
    plot_twin(triples, count=count, figsize=figsize, label_ticks=label_ticks,
              title="Can we predict how much a unit will damage %s classification accuracy?" % class_name,
              barlabel="salience of unit concept to %s (bars)" % class_name,
              linelabel="damage to accuracy of %s (line)" % class_name)

plot_intervention_units(clsnum, ablation_delta, unit_label_995, discrimination, count=256, figsize=(50, 5))


In [None]:
# Spot-check one unit: which concept with IoU above 0.02 has the largest
# discrimination score for the current focus class?
unit = 11
matching_concepts = (iou_995[:, unit] > 0.02).nonzero()[:,0]
# max() raises ValueError when no concept clears the threshold for this unit.
relevance, relevant_concept = max([(discrimination[c, 0], c) for c in matching_concepts])
# Display the concept name and its score as the cell result.
seglabels[relevant_concept], relevance

In [None]:
# Two views per class: all 256 bars without tick labels, then the top 80 with labels.
unit_plot_views = (dict(count=256, label_ticks=False),
                   dict(count=80, label_ticks=True))

# Visit every tenth class.
for clsnum in range(0, len(classlabels), 10):
    focus_class = dataset.classes[clsnum]
    print(focus_class)
    discrimination = experiment.load_lda_vector(
        focus_class, args, segmodel, classlabels, seglabels, shrinkage=0.1)
    for view in unit_plot_views:
        plot_intervention_units(clsnum, ablation_delta, unit_label_995, discrimination, **view)


In [None]:
# Make a table of the concepts that are most discriminative:
# row = scene class, column = concept, value = first column of the class's LDA vector.
discriminate_matrix = torch.zeros(len(classlabels), len(seglabels))
for clsnum in range(len(classlabels)):
    focus_class = dataset.classes[clsnum]
    d = experiment.load_lda_vector(
        focus_class, args, segmodel, classlabels, seglabels, shrinkage=0.1)
    discriminate_matrix[clsnum] = d[:,0]

In [None]:
# Accuracy change per (class, unit) when that single unit is ablated.
# The unit count comes from num_units, the same value used when the
# single-unit ablation accuracies were first loaded, rather than a fixed 256.
unit_damage_matrix = torch.zeros(len(classlabels), num_units)
for unit in range(num_units):
    ablation_accuracy = experiment.test_perclass_accuracy(model, dataset,
        layername=layername,
        ablated_units=[unit],
        cachefile=resfile('acc_ablate_unit_%d.npy' % (unit)))
    acc_diff = ablation_accuracy - baseline_accuracy
    unit_damage_matrix[:,unit] = acc_diff

Finding:

On average, zeroing a unit that detects the most discriminative concept for a class (for the 153 classes for which there is a unit for the most discriminative concept) damages accuracy of classification of class by an average of 4.1%, whereas zeroing other units damages accuracy of that class only by an average of 0.05%.

In [None]:
# Idea: for each unit, salience maybe should be given by the most salient high-iou concept detected by the unit.
relevant_units = []
all_other_units = []
counted_classes = 0
# Every unit column of the damage matrix, rather than a hard-coded 256.
all_unit_ids = range(unit_damage_matrix.shape[1])

for clsnum in range(len(classlabels)):
    # Concept with the largest discrimination score for this class.
    segnum = discriminate_matrix[clsnum].max(0)[1]
    # Units labelled with that concept at an IoU above 0.03.
    units = [unit
             for unit, (s, _, _, iou) in enumerate(unit_label_995)
             if s == segnum and iou > 0.03
            ]
    if not len(units):
        # No detector for the top concept: this class is not counted.
        continue
    counted_classes += 1
    unit_set = set(units)
    other_units = [u for u in all_unit_ids if u not in unit_set]
    relevant_units.extend(unit_damage_matrix[clsnum, units].numpy().tolist())
    all_other_units.extend(unit_damage_matrix[clsnum, other_units].numpy().tolist())

print('Counted %d classes' % counted_classes)
print('Of the %d most relevant units, average damage is %.3g' %
      (len(relevant_units), torch.tensor(relevant_units).mean().item()))
print('Of the %d most other units, average damage is %.3g' %
      (len(all_other_units), torch.tensor(all_other_units).mean().item()))

Second experiment: compare ablation of the most-relevant concept detectors, where the most-relevant concept is counted among only those concepts that exist.

In [None]:
relevant_units = []
all_other_units = []

iou_floor = 0.03

# Loop-invariant: (unit, concept) pairs whose label clears the IoU floor.
confident_units = [(unit, s)
                   for unit, (s, _, _, iou) in enumerate(unit_label_995)
                   if iou > iou_floor]
# Every unit column of the damage matrix, rather than a hard-coded 256.
all_unit_ids = range(unit_damage_matrix.shape[1])

for clsnum in range(len(classlabels)):
    # Even if there is no unit for the most discriminative concept, use the
    # most discriminative concept for which a unit does exist.
    units = sorted([(-discriminate_matrix[clsnum, s], s, unit)
                    for unit, s in confident_units])
    segnum = units[0][1]
    units = [unit for unit, s in confident_units if s == segnum]
    unit_set = set(units)
    other_units = [u for u in all_unit_ids if u not in unit_set]
    relevant_units.extend(unit_damage_matrix[clsnum, units].numpy().tolist())
    all_other_units.extend(unit_damage_matrix[clsnum, other_units].numpy().tolist())

# Compute each mean once and reuse it for the ratio.
relevant_mean = torch.tensor(relevant_units).mean().item()
other_mean = torch.tensor(all_other_units).mean().item()
print('Of the %d most relevant units, average damage is %.3g' %
      (len(relevant_units), relevant_mean))
print('Of the %d most other units, average damage is %.3g' %
      (len(all_other_units), other_mean))

print('Ratio %.3f' % (relevant_mean / other_mean))

Third idea: for each class, order units according to the salience of the detected concept, and then average the impacts.

In [None]:
# One column per labelled unit; rows are filled with that class's damage
# values arranged in salience order.
all_sorted_damage = torch.zeros(len(classlabels), len(unit_label_995))
# Concept ids (first field of each label record), not label strings.
all_unit_concepts = set(int(s) for s, _, _, _ in unit_label_995)
for clsnum in range(len(classlabels)):
    # Most salient concept first; ties broken by concept id, then unit id.
    unit_sorter = sorted([(-discriminate_matrix[clsnum, s], s, unit)
             for unit, (s, _, _, iou) in enumerate(unit_label_995)
             ])
    unit_order = [u[-1] for u in unit_sorter]
    # A tensor is needed so that `== s` gives an elementwise mask.  Comparing
    # a Python list with a scalar is always False, which silently skipped
    # the averaging below.
    unit_concept = torch.tensor([int(u[1]) for u in unit_sorter])
    sorted_damage = unit_damage_matrix[clsnum, unit_order]
    # since units with the same concept could be listed in any order, average their contributions
    for s in all_unit_concepts:
        same_concept = (unit_concept == s)
        sorted_damage[same_concept] = sorted_damage[same_concept].mean()
    all_sorted_damage[clsnum] = sorted_damage

f, a1 = plt.subplots(1, 1, dpi=200, figsize=(10, 5))
# First 50 positions individually, then one bar averaging all remaining positions.
mean_sorted_damage = all_sorted_damage.mean(0).numpy()
a1.bar(range(51), (-mean_sorted_damage[:50] * 100).tolist() +
       [-mean_sorted_damage[50:].mean() * 100])

a1.set_title('Effect of removing an object detector unit on classification accuracy of a scene class')
a1.set_ylabel('Damage to classification accuracy of scene class\nwhen a single unit is zeroed, percent')
a1.set_xlabel('Units ordered by (LDA-determined) salience of detected object to the affected scene class')
a1.set_xticks([0, 9, 19, 29, 39, 50])
a1.set_xticklabels(['1', '10', '20', '30', '40',  '>50'])

Adaptive case

In [None]:
# Same analysis as the preceding cell, using the adaptive unit labels.
all_sorted_damage = torch.zeros(len(classlabels), len(unit_label_adaptive))
# Concept ids (first field of each label record), not label strings.
all_unit_concepts = set(int(s) for s, _, _, _ in unit_label_adaptive)
for clsnum in range(len(classlabels)):
    # Most salient concept first; ties broken by concept id, then unit id.
    unit_sorter = sorted([(-discriminate_matrix[clsnum, s], s, unit)
             for unit, (s, _, _, iou) in enumerate(unit_label_adaptive)
             ])
    unit_order = [u[-1] for u in unit_sorter]
    # A tensor is needed so that `== s` gives an elementwise mask.  Comparing
    # a Python list with a scalar is always False, which silently skipped
    # the averaging below.
    unit_concept = torch.tensor([int(u[1]) for u in unit_sorter])
    sorted_damage = unit_damage_matrix[clsnum, unit_order]
    # since units with the same concept could be listed in any order, average their contributions
    for s in all_unit_concepts:
        same_concept = (unit_concept == s)
        sorted_damage[same_concept] = sorted_damage[same_concept].mean()
    all_sorted_damage[clsnum] = sorted_damage

f, a1 = plt.subplots(1, 1, dpi=200, figsize=(10, 5))
# First 50 positions individually, then one bar averaging all remaining positions.
mean_sorted_damage = all_sorted_damage.mean(0).numpy()
a1.bar(range(51), (-mean_sorted_damage[:50] * 100).tolist() +
       [-mean_sorted_damage[50:].mean() * 100])

a1.set_title('Effect of removing an object detector unit on classification accuracy of a scene class')
a1.set_ylabel('Damage to classification accuracy of scene class\nwhen a single unit is zeroed, percent')
a1.set_xlabel('Units ordered by (LDA-determined) salience of detected object to the affected scene class')
a1.set_xticks([0, 9, 19, 29, 39, 50])
a1.set_xticklabels(['1', '10', '20', '30', '40',  '>50'])

Fourth idea: scatterplot.  Salience rank of a unit on the x axis, and classification accuracy damage on the y axis.

In [None]:
# Collect (salience rank, damage) points for the scatterplot.
all_sorted_damage = torch.zeros(len(classlabels), len(unit_label_995))
# Concept ids (first field of each label record), not label strings.
all_unit_concepts = set(int(s) for s, _, _, _ in unit_label_995)
yvals = []
xvals = []
for clsnum in range(len(classlabels)):
    # Most salient concept first; ties broken by concept id, then unit id.
    unit_sorter = sorted([(-discriminate_matrix[clsnum, s], s, unit)
             for unit, (s, _, _, iou) in enumerate(unit_label_995)
             ])
    unit_order = [u[-1] for u in unit_sorter]
    # A tensor is needed so that `== s` gives an elementwise mask.  Comparing
    # a Python list with a scalar is always False, which silently skipped
    # the rank averaging below.
    unit_concept = torch.tensor([int(u[1]) for u in unit_sorter])
    sorted_damage = unit_damage_matrix[clsnum, unit_order]
    rank_order = torch.arange(len(sorted_damage), dtype=torch.float)
    # since units with the same concept could be listed in any order, give them their average rank
    for s in all_unit_concepts:
        same_concept = (unit_concept == s)
        rank_order[same_concept] = rank_order[same_concept].mean()
    xvals.extend(rank_order.numpy().tolist())
    yvals.extend(sorted_damage.numpy().tolist())
    all_sorted_damage[clsnum] = sorted_damage

In [None]:
import random

# A seeded generator keeps the jitter, and so the figure, reproducible across runs.
jitter = random.Random(0)
f, a1 = plt.subplots(1, 1, dpi=200, figsize=(30, 5))
# Jitter both coordinates slightly so that coincident points stay visible.
a1.scatter([x + jitter.random() for x in xvals],
           [y + jitter.random() * 0.01 for y in yvals], s=0.5, alpha=0.2)
a1.set_xlabel('Units ordered by salience of detected object to the affected scene class using LDA')
a1.set_ylabel('Change in class accuracy when the unit is zeroed')


In [None]:
# List (unit index, iou) for every unit whose label is 'bed'.
[(u, iou) for u, (s, label, labelcat, iou) in enumerate(unit_label_995) if label == 'bed']

Fifth idea, similar to the "Effect" graph, but here the x axis is purely determined by LDA and has nothing to do with the network being tested.

In [None]:
# Accumulate unit damage by the LDA rank of the unit's concept:
# row = scene class, column = rank of the concept for that class.
total_sorted_damage = torch.zeros(len(classlabels), len(seglabels))
count_sorted_damage = torch.zeros(len(classlabels), len(seglabels))

for clsnum in range(len(classlabels)):
    dscore, drank = (-discriminate_matrix[clsnum]).sort(0)
    # Concepts with a score of exactly zero are tied; they share all their ranks.
    zerorank = (dscore == 0).nonzero()[:, 0].numpy().tolist()
    zerorank_lookup = set(zerorank)  # constant-time membership for the map below
    rankmap = {s.item(): zerorank if r in zerorank_lookup else [r] for r, s in enumerate(drank)}
    for u, (s, label, labelcat, iou) in enumerate(unit_label_995):
        r = rankmap[s]
        damage = unit_damage_matrix[clsnum, u]
        # Spread the unit's contribution evenly over every rank it occupies.
        total_sorted_damage[clsnum, r] += (damage / len(r))
        count_sorted_damage[clsnum, r] += (1.0 / len(r))
# Ranks that never received a unit have a zero count and come out as NaN.
avg_sorted_damage = (total_sorted_damage.sum(0) / count_sorted_damage.sum(0))

In [None]:
# Average damage (as a percentage) at every LDA rank.
f, a1 = plt.subplots(1, 1, dpi=200, figsize=(10, 5))
a1.bar(range(len(avg_sorted_damage)), -avg_sorted_damage.numpy() * 100)
# plt.show() renders under the inline backend; Figure.show() only warns there.
plt.show()

# Total unit weight accumulated at every LDA rank.
f, a1 = plt.subplots(1, 1, dpi=200, figsize=(10, 5))
a1.bar(range(len(count_sorted_damage.sum(0))), count_sorted_damage.sum(0).numpy())
plt.show()

In [None]:
f, a1 = plt.subplots(1, 1, dpi=200, figsize=(10, 5))
damage_by_rank = avg_sorted_damage.numpy()
# Ranks with no unit are NaN (0/0 in the average).  nanmean keeps the final
# ">50" bar from turning into NaN, and equals mean() when no NaN is present.
a1.bar(range(51), (-damage_by_rank[:50] * 100).tolist() +
       [-numpy.nanmean(damage_by_rank[50:]) * 100])
a1.set_ylabel('Damage to classification accuracy of scene class\nwhen a single unit is zeroed, percent')
a1.set_xlabel('Which object detector is zeroed, identified by dissection, ordered by LDA salience of the object to the scene')
a1.set_xticks([0, 9, 19, 29, 39, 50])
a1.set_xticklabels(['1', '10', '20', '30', '40',  '>50'])

Sixth idea, put average rank by LDA on the y axis, and put rank by intervention impact on the x axis.

In [None]:
# Row = scene class, column = damage rank of a unit (0 = most negative change).
# The width follows the damage matrix instead of a hard-coded 256.
lda_rank_for_unit = torch.zeros(len(classlabels), unit_damage_matrix.shape[1])
lda_count_for_unit = torch.zeros(len(classlabels), unit_damage_matrix.shape[1])

for clsnum in range(len(classlabels)):
    damage, damrank = (unit_damage_matrix[clsnum]).sort(0)
    # unit -> every rank position holding the same damage value (ties share ranks).
    rankmap = {damrank[r].item(): (damage == d).nonzero()[:, 0].numpy().tolist() for r, d in enumerate(damage)}

    dscore, drank = (-discriminate_matrix[clsnum]).sort(0)
    # TODO: handle zero rank
    # concept -> its LDA rank for this class (0 = most salient).
    srankmap = {s.item(): r for r, s in enumerate(drank)}

    for u, (s, label, labelcat, iou) in enumerate(unit_label_995):
        ur = rankmap[u]
        sr = srankmap[s]

        # Spread the unit's contribution evenly over its tied rank positions.
        lda_rank_for_unit[clsnum, ur] += (sr / len(ur))
        lda_count_for_unit[clsnum, ur] += (1.0 / len(ur))

avg_lda_rank_for_unit = (lda_rank_for_unit.sum(0) / lda_count_for_unit.sum(0))

In [None]:
# Display the average LDA rank at each damage-rank position.
avg_lda_rank_for_unit

In [None]:
f, a1 = plt.subplots(1, 1, dpi=200, figsize=(10, 5))
lda_rank_curve = avg_lda_rank_for_unit.numpy()
# First 50 positions individually, then one point averaging all remaining positions.
a1.plot(range(51),
        lda_rank_curve[:50].tolist() + [lda_rank_curve[50:].mean()])
a1.set_ylabel('Average rank of object detected by unit,\nordered by salience to class')
a1.set_xlabel('Which unit zeroed, ordered by damage caused to the scene class')
a1.set_xticks([0, 9, 19, 29, 39, 50])
a1.set_xticklabels(['1', '10', '20', '30', '40',  '>50'])

Seventh idea, LDA coefficient on y axis

In [None]:
# Row = scene class, column = damage rank of a unit (0 = most negative change).
# The width follows the damage matrix instead of a hard-coded 256.
lda_weight_for_unit = torch.zeros(len(classlabels), unit_damage_matrix.shape[1])
lda_count_for_unit = torch.zeros(len(classlabels), unit_damage_matrix.shape[1])

for clsnum in range(len(classlabels)):
    damage, damrank = (unit_damage_matrix[clsnum]).sort(0)
    # unit -> every rank position holding the same damage value (ties share ranks).
    rankmap = {damrank[r].item(): (damage == d).nonzero()[:, 0].numpy().tolist() for r, d in enumerate(damage)}

    # Normalize out of place.  An in-place divide on the row view would
    # overwrite discriminate_matrix for every later cell.
    ldavec = discriminate_matrix[clsnum] / discriminate_matrix[clsnum].max()
    dscore, drank = (-ldavec).sort(0)
    dscore = -dscore
    # TODO: handle zero rank
    # concept -> normalized LDA score (the most salient concept scores 1.0).
    srankmap = {s.item(): sc.item() for s, sc in zip(drank, dscore)}

    for u, (s, label, labelcat, iou) in enumerate(unit_label_995):
        ur = rankmap[u]
        sr = srankmap[s]

        # Spread the unit's contribution evenly over its tied rank positions.
        lda_weight_for_unit[clsnum, ur] += (sr / len(ur))
        lda_count_for_unit[clsnum, ur] += (1.0 / len(ur))

avg_lda_weight_for_unit = (lda_weight_for_unit.sum(0) / lda_count_for_unit.sum(0))

In [None]:
# Mean over all entries of lda_weight_for_unit (displayed as the cell result).
lda_weight_for_unit.mean()

In [None]:
f, a1 = plt.subplots(1, 1, dpi=200, figsize=(10, 5))
lda_weight_curve = avg_lda_weight_for_unit.numpy()
overall_mean_weight = lda_weight_for_unit.mean().item()
# First 50 positions individually, then one point averaging all remaining positions.
a1.plot(range(51),
        lda_weight_curve[:50].tolist() + [lda_weight_curve[50:].mean()],
        linewidth=2,
        label="LDA salience of objects detected by units with largest causal effect on class accuracy.")
# Flat reference line at the overall mean.
a1.plot(range(51), 51 * [overall_mean_weight], linewidth=2, alpha=0.7,
        label="Mean LDA salience for random units %.2g.  (Object with maximum salience is 1.0.)" %
        overall_mean_weight)
a1.set_ylabel('Average LDA salience of object detected by zeroed unit')
a1.set_xlabel('Which unit zeroed, ordered by damage caused to the scene class')
a1.set_xticks([0, 9, 19, 29, 39, 50])
a1.set_xticklabels(['1', '10', '20', '30', '40',  '>50'])
f.legend()

Repeat the experiment for adaptive case.

In [None]:
# Row = scene class, column = damage rank of a unit (0 = most negative change).
# The width follows the damage matrix instead of a hard-coded 256.
lda_weight_for_unit = torch.zeros(len(classlabels), unit_damage_matrix.shape[1])
lda_count_for_unit = torch.zeros(len(classlabels), unit_damage_matrix.shape[1])

for clsnum in range(len(classlabels)):
    damage, damrank = (unit_damage_matrix[clsnum]).sort(0)
    # unit -> every rank position holding the same damage value (ties share ranks).
    rankmap = {damrank[r].item(): (damage == d).nonzero()[:, 0].numpy().tolist() for r, d in enumerate(damage)}

    # Normalize out of place.  An in-place divide on the row view would
    # overwrite discriminate_matrix for every later cell.
    ldavec = discriminate_matrix[clsnum] / discriminate_matrix[clsnum].max()
    dscore, drank = (-ldavec).sort(0)
    dscore = -dscore
    # TODO: handle zero rank
    # concept -> normalized LDA score (the most salient concept scores 1.0).
    srankmap = {s.item(): sc.item() for s, sc in zip(drank, dscore)}

    for u, (s, label, labelcat, iou) in enumerate(unit_label_adaptive):
        ur = rankmap[u]
        sr = srankmap[s]

        # Spread the unit's contribution evenly over its tied rank positions.
        lda_weight_for_unit[clsnum, ur] += (sr / len(ur))
        lda_count_for_unit[clsnum, ur] += (1.0 / len(ur))

avg_lda_weight_for_unit_adaptive = (lda_weight_for_unit.sum(0) / lda_count_for_unit.sum(0))
f, a1 = plt.subplots(1, 1, dpi=200, figsize=(10, 5))
adaptive_weight_curve = avg_lda_weight_for_unit_adaptive.numpy()
# First 50 positions individually, then one point averaging all remaining positions.
a1.plot(range(51), (adaptive_weight_curve[:50]).tolist() +
       [adaptive_weight_curve[50:].mean()], linewidth=2,
        label="LDA salience of objects detected by units with largest causal effect on class accuracy.")
# Flat reference line at the overall mean.
a1.plot(range(51), 51 * [lda_weight_for_unit.mean().item()], linewidth=2, alpha=0.7,
       label="Mean LDA salience for objects detected by random units")
a1.set_ylabel('Average LDA salience of object detected by zeroed unit')
a1.set_xlabel('Which unit zeroed, ordered by damage caused to the scene class')
a1.set_xticks([0, 9, 19, 29, 39, 50])
a1.set_xticklabels(['1', '10', '20', '30', '40',  '>50'])
f.legend()