In [None]:
%%bash
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
pip install ninja 2>> install.log
git clone https://github.com/SIDN-IAP/global-model-repr.git tutorial_code 2>> install.log

In [None]:
try: # set up path
    import google.colab, sys, torch
    sys.path.append('/content/tutorial_code')
    if not torch.cuda.is_available():
        print("Change runtime type to include a GPU.")
except ImportError: # not running on Colab
    pass

# Import packages

Import torch, netdissect, and matplotlib, then turn on cudnn benchmarking and turn off gradient tracking, since we are not training anything.

In [None]:
import torch, os, matplotlib.pyplot as plt
from netdissect import nethook, imgviz, show, segmenter, renormalize, upsample, tally, pbar
from netdissect import setting

torch.backends.cudnn.benchmark = True
torch.set_grad_enabled(False) # not training anything!

# Load some data

**ds** is a dataset of pictures of places.

It is the validation set from the Places365 dataset.

In [None]:
ds = setting.load_dataset('places', 'val')
iv = imgviz.ImageVisualizer(224, source=ds, percent_level=0.99)
show(iv.image(ds[0][0]))

# Load a pretrained classifier model

**model** is a pretrained VGG classifier that classifies scenes.

In [None]:
model = setting.load_vgg16()
model = nethook.InstrumentedModel(model)
model.cuda()
renorm = renormalize.renormalizer(source=ds, target='zc')
ivsmall = imgviz.ImageVisualizer((56, 56), source=ds, percent_level=0.99)


## Warmup: look at the model

### Exercise 1.
* How many layers does the VGG network have?  `print(model)` will show them.
* What is the fully qualified name of the last convolutional layer?  Look at `model.layernames()`.

In [None]:
# print(model) etc.

In the short example below:
* **indexes** is a list of dataset indexes to retrieve.  `i` indicates a dataset index, and `j` is an index into the indexes array.
* **batch** is a `12 x 3 x 224 x 224` tensor that stacks up twelve RGB 224x224 images from the dataset.
* When we run `model(batch.cuda())`, it scores every image for every class, making a `12 x 365` tensor of scores.
* Then `.max(dim=1)` finds the maximum of 365 scores for each image; it returns a (scores, indexes) tuple.
* **preds** is a tensor of the 12 highest-scoring class indexes (each one a number from 0 to 364) predicted by the model.
* `iv.image(batch[j])` turns the jth `3 x 224 x 224` tensor into a PIL image for display.
* `ds.classes[ds[i][1]]` shows the human ground-truth label for the `i`th image in the dataset.

So the loop shows a set of twelve images, each with the dataset label and the model prediction. Scene classification is difficult and sometimes ambiguous; nevertheless the model does reasonably well.
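
The `.max(dim=1)` step can be tried without a model or a GPU. The following is a minimal sketch in plain Python, assuming a made-up `3 x 4` score table (three images, four classes) in place of the real `12 x 365` tensor; the class names, scores, and the helper `row_max` are hypothetical and only illustrate the idea.

```python
# Hypothetical scores: one row per image, one column per class.
classes = ['alley', 'beach', 'kitchen', 'soccer_field']
scores = [
    [0.1, 2.3, 0.4, 0.2],   # image 0
    [1.5, 0.2, 0.1, 0.9],   # image 1
    [0.0, 0.3, 0.2, 3.1],   # image 2
]

def row_max(row):
    """Return (best_score, best_index) for one image, like a max over dim=1."""
    best_index = max(range(len(row)), key=lambda c: row[c])
    return row[best_index], best_index

# Mirror the (scores, indexes) tuple; the indexes play the role of preds.
best_scores, preds = zip(*(row_max(row) for row in scores))
pred_names = [classes[p] for p in preds]
print(preds, pred_names)
```

Each entry of `preds` is a column index into the score table, which is why it can be used directly to look up a name in `classes`, just as the notebook does with `ds.classes[preds[j]]`.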

### Exercise 2.

(Optional.)  Explore the data set, and the model's predictions on the data.

* Change the **indexes** array to contain a few `soccer_field` and `baseball_field` images within the data set, that is, a set of indexes `i` for which `ds[i][1]` matches the class number for either of those classes.  A tip: `ds.classes.index('soccer_field')` is 310 and the index of `baseball_field` is 42.
* Can you find a baseball field image that is incorrectly classified as a soccer field?

Scrutinize the images, and consider how the model might be making its predictions.  What might it be looking for within the images to tell the difference between baseball and soccer?

In [None]:
target_class = ds.classes.index('soccer_field')
print(target_class)
indexes = range(100, 112)
batch = torch.stack([ds[i][0] for i in indexes])
_, preds = model(batch.cuda()).max(1)
show([[
    iv.image(batch[j]),
    'label: ' + ds.classes[ds[i][1]],
    'pred: ' + ds.classes[preds[j]],
    i,
] for j, i in enumerate(indexes)])

For reference, below is the typical way we evaluate a classifier: check its accuracy on the dataset.  While this gives us a global view of the model (e.g. 53% accuracy), it doesn't show us what the model does internally at all.

In [None]:
if False:
    correct = 0
    tested = 0
    for imagebatch, labelbatch in pbar(torch.utils.data.DataLoader(ds, batch_size=100)):
        modelpreds = model(imagebatch.cuda()).max(1)[1]
        # print(modelpreds.cpu(), labelbatch)
        correct += (modelpreds.cpu() == labelbatch).sum().item() # count matches as a Python int
        tested += len(labelbatch)
    print('%d correct out of %d' % (correct, tested))


## Examine raw unit activations.

This bit of code shows the output of individual filters in a layer directly.

It shows each filter in two ways: first, it overlays the region of high activation on the image; second, on the right, it shows a heatmap of the filter's activations.

### Exercise 3: look at individual activations.

* Change the layername, and compare the activation patterns in early convolutional layers, like conv2_1, with later ones, like conv5_3.
* Change `imagenum` to select an image with people in it, and look at all 512 filters of conv5_3.
* Do any filters seem to be sensitive to particular body parts in this image?  Which ones?

In [None]:
layername = 'features.conv5_3'
model.retain_layer(layername)
model(batch.cuda())
acts = model.retained_layer(layername).cpu()
show([
    [
        [ivsmall.masked_image(batch[imagenum], acts[imagenum], unitnum)],
        [ivsmall.heatmap(acts[imagenum], unitnum, mode='nearest')],
        'unit %d' % unitnum
    ]
    for unitnum in range(acts.shape[1])
    for imagenum in [6]
])

In [None]:
upfn = upsample.upsampler(
    target_shape=(56, 56),
    data_shape=(7, 7),
)

def flatten_activations(batch, *args):
    image_batch = batch.cuda()
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    hacts = upfn(acts)
    return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

rq = tally.tally_quantile(
    flatten_activations,
    dataset=ds,
    sample_size=1000,
    batch_size=100)
    #cachefile='results/rq_cache.npz')

### Exercise 4: look at the ranges of activations

The loop above collects statistics of each filter over a sample of 1000 images.
What are typical values of the filters?  How often do they fire?

* Plot median (0.5 quantile) values of each filter in conv5_3.
* Compare the 0.5, 0.8, 0.9, and 0.99 quantiles for each filter.
* Do different units activate in different ranges from one another?
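
As a reminder of what a quantile measures, here is a small sketch in plain Python. It assumes a hypothetical table of activations (one list per unit) and uses a simple nearest-rank quantile, which is not necessarily the estimator used by `tally.tally_quantile`.

```python
import math

def quantile(values, q):
    """Nearest-rank quantile: the smallest value with at least a
    fraction q of the samples at or below it."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q * len(ordered)))
    return ordered[rank - 1]

# Hypothetical activations of two units over ten sampled pixels.
unit_acts = {
    'unit_a': [0, 0, 0, 0, 0, 0, 0, 1, 2, 9],   # fires rarely
    'unit_b': [1, 2, 2, 3, 3, 3, 4, 4, 5, 6],   # fires almost always
}

for name, acts in unit_acts.items():
    summary = {q: quantile(acts, q) for q in (0.5, 0.8, 0.9, 0.99)}
    print(name, summary)
```

A unit whose median is zero but whose 0.99 quantile is large fires rarely but strongly, which is the kind of difference between units that the exercise asks you to look for.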


In [None]:
plt.plot(rq.quantiles(0.9))

### Exercise 5: examine images that maximize each unit

The loop below identifies the images, out of a sample of 1000, that cause each filter to activate most strongly.  The current code tallies up images that maximize the mean activation of the filter over the image.
* (Optional) Change the code to find images that maximize the maximum activation across the image instead.
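
To see why the choice of pooling matters, here is a toy sketch in plain Python. It assumes three hypothetical images with tiny flattened activation maps for a single unit; the names `unit_maps`, `mean_pool`, `max_pool`, and `rank_images` are made up for this illustration and are not part of netdissect.

```python
# Hypothetical 2x2 activation maps of one unit on three images, flattened.
unit_maps = {
    'img_a': [0.0, 0.0, 0.0, 8.0],   # one very strong spot
    'img_b': [3.0, 3.0, 3.0, 3.0],   # moderate everywhere
    'img_c': [1.0, 0.0, 2.0, 1.0],   # weak
}

def mean_pool(acts):
    """Average activation over all positions of the image."""
    return sum(acts) / len(acts)

def max_pool(acts):
    """Strongest single activation anywhere in the image."""
    return max(acts)

def rank_images(pool):
    """Sort image names from strongest to weakest under the given pooling."""
    return sorted(unit_maps, key=lambda name: pool(unit_maps[name]), reverse=True)

by_mean = rank_images(mean_pool)
by_max = rank_images(max_pool)
print('by mean:', by_mean)
print('by max: ', by_max)
```

Mean pooling favors images where the unit responds over a large area, while max pooling favors images with a single intense response, so the two rankings can disagree about the top image.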

In [None]:
sample_size = 1000

def max_activations(batch, *args):
    image_batch = batch.cuda()
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    return acts.view(acts.shape[:2] + (-1,)).max(2)[0]

def mean_activations(batch, *args):
    image_batch = batch.cuda()
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    return acts.view(acts.shape[:2] + (-1,)).mean(2)

topk = tally.tally_topk(
    mean_activations,
    dataset=ds,
    sample_size=sample_size,
    batch_size=100,
    cachefile='results/cache_mean_topk.npz'
)

top_indexes = topk.result()[1]

Below is a loop that runs the model for each of the top-activating images for a particular unit (12), and then shows where that unit activates within the images.

* Change the unit number to examine the behavior of different units.

In [None]:
show.blocks([
    ['unit %d' % u,
     'img %d' % i,
     'pred: %s' % ds.classes[model(ds[i][0][None].cuda()).max(1)[1].item()],
     [iv.masked_image(
        ds[i][0],
        model.retained_layer(layername)[0],
        u)]
    ]
    for u in [12]
    for i in top_indexes[u, :20]
])

The following code automates the above process for all the units, collecting a visualization of top images for each unit in the network.

* Examine `unit_images[u]` for various units `u`.

In [None]:
def compute_activations(image_batch, label_batch):
    image_batch = image_batch.cuda()
    _ = model(image_batch)
    acts_batch = model.retained_layer(layername)
    return acts_batch

unit_images = iv.masked_images_for_topk(
    compute_activations,
    ds,
    topk,
    k=5,
    num_workers=10,
    pin_memory=True,
    cachefile='results/cache_top10images.npz')

## Loading a segmentation model.

To systematically identify units that match semantic concepts better or worse, we can find units that align well with the predictions of a semantic segmentation network.

The code below runs and displays segmentations on a batch of images.

**seg** is a tensor that assigns a set of semantic segmentation labels to every pixel of an image.  `seg[i, 0]` shows the 0th label for each pixel of the `i`th image.

In [None]:
segmodel, seglabels, segcatlabels = setting.load_segmenter('netpqc')



In [None]:
seg = segmodel.segment_batch(renorm(batch).cuda(), downsample=4)
show([(iv.image(batch[i]), iv.segmentation(seg[i, 0]),
            iv.segment_key(seg[i,0], segmodel))
            for i in range(len(seg))])

The code below finds the intersections between every unit's 99th percentile activation and every segmentation class identified by the segmenter.  It can take a few minutes to run, so you can reduce the sample size if you do not want to wait.
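
Before running the full computation, it may help to see the underlying measurement on a single toy image. The sketch below, in plain Python, assumes a hypothetical `3 x 3` activation map, a hypothetical label map, and a made-up threshold standing in for the unit's 0.99 quantile. It computes intersection over union (IoU) for one image only, whereas the notebook accumulates the corresponding statistics over the whole sample through `tally`.

```python
# Hypothetical 3x3 activation map for one unit, and a segmentation label map.
acts = [
    [0.1, 0.2, 0.9],
    [0.0, 0.8, 1.2],
    [0.1, 0.1, 0.7],
]
seg = [
    ['sky', 'sky', 'tree'],
    ['sky', 'tree', 'tree'],
    ['grass', 'grass', 'grass'],
]
threshold = 0.6  # stands in for the unit's 0.99 quantile level

def iou(acts, seg, label, threshold):
    """Intersection over union of (acts > threshold) and (seg == label)."""
    intersection = 0
    union = 0
    for act_row, seg_row in zip(acts, seg):
        for a, s in zip(act_row, seg_row):
            active = a > threshold
            labeled = s == label
            intersection += int(active and labeled)
            union += int(active or labeled)
    return intersection / union if union else 0.0

for label in ('sky', 'tree', 'grass'):
    print(label, round(iou(acts, seg, label, threshold), 3))
```

In this toy case the unit overlaps the `tree` region far better than `sky` or `grass`, so `tree` would be reported as its best-matching concept.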

In [None]:
level_at_99 = rq.quantiles(0.99).cuda()[None,:,None,None]

def compute_selected_segments(batch, *args):
    image_batch = batch.cuda()
    seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
    _ = model(image_batch)
    acts = model.retained_layer(layername)
    hacts = upfn(acts)
    iacts = (hacts > level_at_99).float() # indicator where > 0.99 quantile.
    return tally.conditional_samples(iacts, seg)

condi99 = tally.tally_conditional_mean(
    compute_selected_segments,
    dataset=ds,
    sample_size=sample_size,
    cachefile='results/cache_condi99.npz')

iou99 = tally.iou_from_conditional_indicator_mean(condi99)
iou99.shape

The code below sorts the units, showing the units with the best match to a segmentation class first.

In [None]:
iou_unit_label_99 = sorted([(
    unit, concept.item(), seglabels[concept], bestiou.item())
    for unit, (bestiou, concept) in enumerate(zip(*iou99.max(0)))],
    key=lambda x: -x[-1])
for unit, concept, label, score in iou_unit_label_99[:20]:
    show(['unit %d; iou %g; label "%s"' % (unit, score, label),
          [unit_images[unit]]])


Which types of patterns are detected across the whole representation?

The following code counts up segmentation classes that are matched by units, and plots the histograms.
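
The counting step itself is simple and can be sketched in plain Python. This example assumes a hypothetical list of per-unit best matches (label, category, IoU); the labels, categories, and scores are invented for illustration, and the text histogram only stands in for the SVG graph drawn by the notebook.

```python
from collections import Counter

# Hypothetical (label, category, best IoU) triples, one per unit.
unit_matches = [
    ('tree', 'object', 0.12),
    ('sky', 'object', 0.09),
    ('head', 'part', 0.05),
    ('grass', 'object', 0.03),    # below threshold: not counted
    ('red', 'color', 0.06),
    ('leg', 'part', 0.01),        # below threshold: not counted
]
iou_threshold = 0.04

# Keep only the categories of units whose best match clears the threshold.
matched_categories = [cat for _, cat, iou in unit_matches if iou > iou_threshold]
histogram = Counter(matched_categories)

for category, count in histogram.most_common():
    print('%-8s %s' % (category, '#' * count))
```

Raising `iou_threshold` keeps only units with a cleaner match, so the histogram shrinks; lowering it admits weaker, noisier matches.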

In [None]:
iou_threshold = 0.04
unit_label_99 = [
        (concept.item(), seglabels[concept],
            segcatlabels[concept], bestiou.item())
        for (bestiou, concept) in zip(*iou99.max(0))]
labelcat_list = [labelcat
        for concept, label, labelcat, iou in unit_label_99
        if iou > iou_threshold]
import IPython
IPython.display.SVG(setting.graph_conceptcatlist(labelcat_list))

To delete cached results and run things again, you can remove and recreate the results directory.

In [None]:
# rm -rv '../results'

In [None]:
# mkdir '../results'

# Network Dissection

Network dissection is a systematic method for finding and measuring single units (convolutional filters) that match meaningful semantic concepts in a vision network.

Our fundamental question is this: how does the network decompose the task of understanding what an image is?  Does it identify any features that are understandable to a human?

Running this notebook as-is produces a basic dissection, but each step includes exercises for modifying the notebook to find more interesting results.

## About the netdissect library

The netdissect library contains several useful packages for inspecting internals of a vision network.
Here are packages that we use in this notebook:

 * **nethook** wraps any pytorch model, adding the ability to record or modify any internal computation.
 * **imgviz** provides ImageVisualizer, which collects together several useful image visualization functions.
 * **show** arranges nested arrays of PIL images and strings as nicely formatted HTML for display in a notebook.
 * **segmenter** provides an interface and a pretrained implementation for a semantic segmentation network.
 * **tally** gathers statistics over a dataset, based on your function to compute features for each datum.
 * **renormalize** deals with conversions between the zoo of RGB encoding scales typically seen in vision data.
 * **upsample** provides simple functions for resampling grid data at higher or lower resolutions.
 * **pbar** is a progress bar.

These will be explained a bit more in the exercises below.  Of course you can always run `help(object)` for a bit more information on most things in the library.  For this tutorial we also have a package **setting**, which automatically downloads and creates the datasets and pretrained models that we will be looking at.

## About pretrained models and data

Here are some fixed variables that we define up-front for all the objects that we will be inspecting in this tutorial.

* **model** is the network we will look at.  It is a VGG convolutional network, trained to classify images of scenes into one of 365 place categories.  We wrap `model` as a `nethook.InstrumentedModel` so that we can easily retrieve and modify its internal activations.
* **ds** is a small held-out sample from the Places dataset that was used to train the model; each entry is a pytorch tensor representing an image, and an integer representing the class.  A pytorch dataset can be indexed like an array, so `ds[35]` is a pair `(x, y)` where `x` is a tensor containing RGB image data for a scene and `y` is an integer for the human-given class label.  Class names are available as `ds.classes[y]`.
* **renorm** is a function that renormalizes RGB data from the statistically based scaling used in `ds` to a simple `[-1...1]` range scale.
* **segmodel** is a semantic segmentation network trained to recognize a large vocabulary of objects and parts of objects within scenes.  We will use it as a reference, to see if there are any internal filters that approximately match the same concepts.
* **seglabels** are human-readable names for the numerical segmentation classes.
* **iv** is an image visualization object that visualizes 2d data such as images and heatmaps as 224x224 images.
* **ivsmall** is another visualization object, but outputs smaller 56x56 images.
* **resfile** is a function that generates filenames in a results subdirectory that we will use for caching data.
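
The arithmetic behind a renormalization like **renorm** can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: it assumes that `ds` uses per-channel mean and standard deviation scaling with the common ImageNet statistics, which may differ from the values actually used, and the helper names `to_normalized` and `renorm_pixel` are hypothetical.

```python
# Assumed per-channel statistics (the common ImageNet values); the
# statistics actually used by `ds` may differ.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

def to_normalized(rgb01):
    """Map an RGB pixel in [0, 1] to mean/std-normalized values."""
    return tuple((v - m) / s for v, m, s in zip(rgb01, MEAN, STD))

def renorm_pixel(normalized):
    """Undo the mean/std scaling, then stretch [0, 1] to [-1, 1]."""
    return tuple((v * s + m) * 2.0 - 1.0
                 for v, m, s in zip(normalized, MEAN, STD))

black, white = (0.0, 0.0, 0.0), (1.0, 1.0, 1.0)
for pixel in (black, white):
    result = renorm_pixel(to_normalized(pixel))
    print(pixel, '->', [round(v, 6) for v in result])
```

Under these assumptions a black pixel maps to -1 in every channel and a white pixel maps to 1, which is the `[-1...1]` range described above.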