# Network Dissection (for classifiers)

In this notebook, we will examine internal layer representations for a classifier trained to recognize scene categories.

Set up matplotlib for inline, high-resolution (retina) rendering in the browser.

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import matplotlib as mpl
from importlib import reload
import IPython
# Plot styling: thin (0.25) line and axis strokes, and no top/right axis spines.
mpl.rcParams.update({
    'lines.linewidth': 0.25,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.linewidth': 0.25,
})

Set up experiment directory and settings

In [None]:
import torch, argparse, os, shutil, inspect, json, numpy, math
import netdissect
from netdissect.easydict import EasyDict
from netdissect import pbar, nethook, renormalize, parallelfolder, pidfile
from netdissect import upsample, tally, imgviz, imgsave, bargraph, show
from experiment import dissect_experiment as experiment

# Experiment settings. Choices for `model` are alexnet, vgg16, or resnet152.
args = EasyDict(
    model='vgg16',
    dataset='places',
    seg='netpqc',
    layer='conv5_3',
    quantile=0.01,
)
# The results directory name encodes every setting; the quantile is written
# in thousandths (0.01 -> 10).
resdir = 'results/new-%s-%s-%s-%s-%s' % (
    args.model,
    args.dataset,
    args.seg,
    args.layer,
    int(args.quantile * 1000),
)
def resfile(f):
    """Return the path of file `f` inside the results directory `resdir`."""
    path_in_resdir = os.path.join(resdir, f)
    return path_in_resdir


Load the classifier model and dataset.

In [None]:
# Load the bare classifier: instrumented=False means it is not wrapped in
# InstrumentedModel.
model = experiment.load_model(args, instrumented=False)
# Name of the layer to inspect, as resolved by the experiment helper.
layername = experiment.instrumented_layername(args)
dataset = experiment.load_dataset(args)
sample_size = len(dataset)
# Complement of the quantile setting (quantile 0.01 -> level 0.99).
percent_level = 1.0 - args.quantile

print('Inspecting layer %s of model %s on %s' % (
    layername, args.model, args.dataset))

In [None]:
# Show the loaded network's repr as this cell's output.
model

Load segmenter, segment labels, classifier labels

In [None]:
# Classifier labels
from urllib.request import urlopen
from netdissect import renormalize

# Class names are taken directly from the dataset. Alternative kept for
# reference: parse them from the published Places365 category list.
# synset_url = 'http://gandissect.csail.mit.edu/models/categories_places365.txt'
# classlabels = [r.split(' ')[0][3:] for r in urlopen(synset_url).read().decode('utf-8').split('\n')]
classlabels = dataset.classes

# NOTE(review): presumably the segmenter model plus its per-label names and
# per-label category names; confirm against experiment.setting.load_segmenter.
(segmodel,
 seglabels,
 segcatlabels) = experiment.setting.load_segmenter(args.seg)

# Renormalizer built from the dataset with target 'zc'; it is applied to image
# batches before they are passed to the segmenter.
renorm = renormalize.renormalizer(dataset, target='zc')

Test classifier on some images

In [None]:
from netdissect import renormalize

# Run the classifier on a few hand-picked dataset images and show each image
# next to its predicted and ground-truth class names.
indices = [200, 755, 709, 423, 60, 100, 110, 120]
# Fetch each sample once; element 0 is the image tensor, element 1 the class index.
samples = [dataset[i] for i in indices]
batch = torch.stack([sample[0] for sample in samples])
truth = [classlabels[sample[1]] for sample in samples]
# Inference only: no_grad() keeps autograd from recording the forward pass,
# which would otherwise hold activation memory for a backward pass that never runs.
with torch.no_grad():
    preds = model(batch.cuda()).max(1)[1]
imgs = [renormalize.as_image(t, source=dataset) for t in batch]
prednames = [classlabels[p.item()] for p in preds]
show([[img, 'pred: ' + pred, 'true: ' + gt] for img, pred, gt in zip(imgs, prednames, truth)])


Segment the sample batch of images, and visualize the labels.

In [None]:
from netdissect import imgviz

# Visualizer for the sample batch (120 is presumably the rendered image size;
# confirm against imgviz.ImageVisualizer).
iv = imgviz.ImageVisualizer(120, source=dataset)
# Segment the renormalized batch on the GPU.
seg = segmodel.segment_batch(renorm(batch).cuda(), downsample=4)

# One row per image: the image, its segmentation rendering, and the
# segmentation key, all built from index 0 of that image's segmentation output.
rows = []
for i in range(len(seg)):
    labels_i = seg[i, 0]
    rows.append((
        iv.image(batch[i]),
        iv.segmentation(labels_i),
        iv.segment_key(labels_i, segmodel),
    ))
show(rows)

## Collect quantile statistics

First, unconditional quantiles over the activations; also keep track of the top 100 activating images for each channel.

In [None]:
# Label the next progress bar.
pbar.descnext('rq/topk')
from netdissect import dissect

# Collect activation statistics for the chosen layer over the whole dataset;
# results are cached under resdir.
# NOTE(review): topk / rq / run are presumably the top-k tally, the running
# quantile tally, and the activation-computing function; confirm in dissect.acts_stats.
topk, rq, run = dissect.acts_stats(
    model, dataset, layer=layername,
    k=100, batch_size=30, num_workers=30,
    cachedir=resdir)

In [None]:
# Label the next progress bar.
pbar.descnext('unit_images')

# Visualizer that masks images at each unit's activation level for the
# configured quantile.
iv = imgviz.ImageVisualizer(
    (100, 100),
    source=dataset,
    quantiles=rq,
    level=rq.quantiles(percent_level))

# Masked renderings of the top 5 images per unit, cached in top5images.npz
# inside the results directory.
unit_images = iv.masked_images_for_topk(
    run, dataset, topk,
    k=5, num_workers=30, pin_memory=True,
    cachefile=resfile('top5images.npz'))

## Label Units

Label according to the new algorithm, where IoU is scored only among the top-k matching images.

In [None]:
# Re-import netdissect.dissect so that edits made to it since the first import
# take effect.
reload(dissect)
# Activation level at the configured quantile.
level = rq.quantiles(percent_level)

# IoU statistics between units and segmentation labels, computed over the
# top-100 images.
all_iou = dissect.topk_label_stats_using_segmodel(
    dataset, segmodel, run, level, topk,
    k=100, downsample=4)


In [None]:
all_iou[:, 0] = 0  # ignore the dummy label

# For every unit, take the best-scoring concept and its IoU.
best_ious, best_concepts = all_iou.max(1)
unit_label_99 = []
for bestiou, concept in zip(best_ious, best_concepts):
    unit_label_99.append((
        concept.item(),
        seglabels[concept],
        segcatlabels[concept],
        bestiou.item(),
    ))

# Keep the category label (element 2) of units whose best IoU (element 3)
# exceeds 0.04.
label_list = [entry[2] for entry in unit_label_99 if entry[3] > 0.04]
display(IPython.display.SVG(experiment.graph_conceptcatlist(label_list)))
len(label_list)

Show a few units with their labels

In [None]:
# Show the first 30 units with their labels. The count is capped at the number
# of labelled units, so a layer with fewer than 30 units does not raise IndexError.
for u in range(min(30, len(unit_label_99))):
    print('unit %d, label %s, iou %.3f' % (u, unit_label_99[u][1], unit_label_99[u][3]))
    display(unit_images[u])