<a href="https://colab.research.google.com/github/harvard-visionlab/psy1410/blob/master/psy1410_week04_alexnet_object_disentanglement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Week04 | Workshop on Object Representation

Today we're going to play with AlexNet! 

First, we're going to explore the extent to which AlexNet features at different stages "disentangle" object categories for a subset of imagenet categories.

Second, you'll either: (a) perform the same analyses on a different dataset, or (b) perform the same analyses with a different vision model. The vision model can have a different architecture ("resnet50"), or a different task ("object detection" instead of imagenet classification).

## imports

In [None]:
%config InlineBackend.figure_format = 'retina'

In [None]:
%%html
<style>
    .jp-Stdin-input {
        width: 90%;
    }
    
    div.p-Widget.jp-RenderedHTMLCommon.jp-RenderedMarkdown.jp-MarkdownOutput p {
      font-size: 16px
    }
    a
</style>

In [None]:
from torchvision import models, datasets, transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.pyplot as plt
import numpy as np 
import seaborn as sns 
from scipy.stats import pearsonr 

def show_kernels(model):
  '''
  Display the weights of `model.features[0]` as a grid of small images.

  Does nothing (returns None) when `model` has no `features` attribute.
  '''
  if not hasattr(model, 'features'):
    return

  # Work on a detached CPU copy so rescaling never touches the model weights.
  weights = model.features[0].weight.detach().clone().cpu()
  # Min-max rescale so the smallest weight maps to 0 and the largest to 1.
  shifted = weights - weights.min()
  normalized = shifted / shifted.max()
  grid_img = make_grid(normalized, nrow=16)
  # Move the first axis last before handing the grid to imshow.
  plt.imshow(grid_img.permute(1, 2, 0))

def show_images(dataset, num_categories=10, num_per_category=5):
  '''
  Plot example images from `dataset`: one row per category, one column
  per example.

  Args:
    dataset: object exposing `.targets` (one label per item) and
      `dataset[idx] -> (img, label)`.
    num_categories: number of distinct labels (rows) to show; the first
      `num_categories` values of `np.unique(targets)` are used.
    num_per_category: number of examples (columns) shown per label.
  '''
  targets = np.array(dataset.targets)
  show_targets = np.unique(targets)[0:num_categories]

  # figsize is (width, height): width follows the columns (examples per
  # category), height follows the rows (categories).
  fig = plt.figure(figsize=(5*num_per_category, 5*num_categories))
  grid = ImageGrid(fig, 111,  # similar to subplot(111)
                  nrows_ncols=(num_categories, num_per_category),  # rows=categories, cols=examples
                  axes_pad=0.5,  # pad between axes in inch.
                  )

  for row, target in enumerate(show_targets):
    indexes = np.where(targets == target)[0][0:num_per_category]
    for col, idx in enumerate(indexes):
      img, label = dataset[idx]
      # Address the axes by (row, col) so a category with fewer than
      # num_per_category examples does not shift later categories.
      ax = grid[row * num_per_category + col]
      ax.imshow(np.array(img), cmap='gray', vmin=0, vmax=1)
      ax.set_title(f'label={label}', fontsize=16)
  for ax in grid: ax.axis('off')

  plt.show()

In [None]:
'''
Utilities for visualizing feature entanglement/disentanglement.
'''
import math 
import seaborn as sns 
import numpy as np 
import pandas as pd 
from scipy.stats import pearsonr 
from fastprogress.fastprogress import master_bar, progress_bar 
from IPython.core.debugger import set_trace 
import pandas as pd
import seaborn as sns
from scipy.spatial.distance import squareform
from sklearn import manifold
from scipy.spatial.distance import pdist
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np
from typing import Optional, Union, List, Dict, Callable

def mds_plot(layer_name, pairwise_similarity, labels, ax=None):
  '''
  Scatter-plot a 2D MDS embedding of the images for one layer, colored by label.

  Args:
    layer_name: key into `pairwise_similarity`; also used as the plot title.
    pairwise_similarity: mapping from layer name to a square similarity
      matrix; it is converted to dissimilarity as `1 - similarity`.
    labels: one label per image, used for point colors. May be a torch
      tensor (on any device), a numpy array, or a sequence.
    ax: matplotlib Axes to draw on; defaults to the current Axes.
  '''
  if ax is None:
    ax = plt.gca()

  ax.set_title(f'{layer_name}')

  RDM = 1 - pairwise_similarity[layer_name]
  n_components = 2
  random_seed = None
  is_metric = True
  max_iter = 3000
  convergence_tolerance = 1e-9

  seed = None if random_seed is None else np.random.RandomState(seed=random_seed)
  # for whatever reason, MDS only allows dissimilarity='precomputed' or dissimilarity='euclidean'
  # Note: n_jobs>1 hangs indefinitely
  mds = manifold.MDS(n_components=n_components, metric=is_metric, max_iter=max_iter, eps=convergence_tolerance, random_state=seed, dissimilarity='precomputed')
  out = mds.fit_transform(RDM)

  # Matplotlib needs host-side array data; tensors (possibly on the GPU)
  # are detached and copied to the CPU first.
  if hasattr(labels, 'detach'):
    labels = labels.detach().cpu().numpy()
  else:
    labels = np.asarray(labels)
  num_labels = len(np.unique(labels))

  # plt.get_cmap replaces plt.cm.get_cmap, which newer Matplotlib removed.
  colorize = dict(c=labels, cmap=plt.get_cmap('rainbow', num_labels))
  ax.scatter(out[:, 0], out[:, 1], **colorize)
  ax.axis('square')
  ax.set_xlim([-1,1])
  ax.set_ylim([-1,1])

def show_mds_plots(pairwise_similarity, labels, ncols=3):
  '''
  Draw one MDS plot per layer in a grid with `ncols` columns.

  Args:
    pairwise_similarity: mapping from layer name to similarity matrix;
      every key gets its own subplot, in iteration order.
    labels: one label per image, forwarded to `mds_plot` for coloring.
    ncols: number of subplot columns; rows are added as needed.
  '''
  layer_names = list(pairwise_similarity.keys())
  num_layers = len(layer_names)
  nrows = max(1, math.ceil(num_layers/ncols))
  # squeeze=False always returns a 2D array of Axes, so the same code works
  # for a single subplot, a single row, and multiple rows.
  fig, axes = plt.subplots(nrows, ncols, sharex=True, squeeze=False, figsize=(ncols*4,nrows*4))
  fig.suptitle('MDS plots show the "position" of each image in each feature space')

  flat_axes = axes.flatten()
  for ax,layer_name in progress_bar(zip(flat_axes, layer_names), total=num_layers):
    mds_plot(layer_name, pairwise_similarity, labels, ax=ax)

  # Blank out grid cells that have no layer assigned to them.
  for ax in flat_axes[num_layers:]:
    ax.axis('off')


In [None]:
'''
Utilities for instrumenting a torch model.

InstrumentedModel will wrap a pytorch model and allow hooking
arbitrary layers to monitor or modify their output directly.
'''

import torch, numpy, types, copy
from collections import OrderedDict, defaultdict
from IPython.core.debugger import set_trace

class InstrumentedModel(torch.nn.Module):
    '''
    A wrapper for hooking, probing and intervening in pytorch Modules.
    Example usage:

    ```
    model = load_my_model()
    with InstrumentedModel(model) as inst:
        inst.retain_layer(layername)
        inst.edit_layer(layername, ablation=0.5, replacement=target_features)
        inst(inputs)
        original_features = inst.retained_layer(layername)
    ```

    Hooks are installed by replacing the `forward` attribute on the hooked
    layer instances; `close()` (called automatically when leaving a `with`
    block) restores the originals.
    '''

    def __init__(self, model):
        super().__init__()
        self.model = model
        # aka -> most recent retained output (None until the model is run).
        self._retained = OrderedDict()
        # aka -> whether the retained output is detached from the graph.
        self._detach_retained = {}
        # aka -> keyword arguments passed to the edit rule.
        self._editargs = defaultdict(dict)
        # aka -> edit rule callable.
        self._editrule = {}
        # aka -> layer name of the hooked layer.
        self._hooked_layer = {}
        # layer name -> (layer, aka, instance-level forward or None).
        self._old_forward = {}
        if isinstance(model, torch.nn.Sequential):
            self._hook_sequential()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def forward(self, *inputs, **kwargs):
        return self.model(*inputs, **kwargs)

    def retain_layer(self, layername, detach=True):
        '''
        Pass a fully-qualified layer name (E.g., module.submodule.conv3)
        to hook that layer and retain its output each time the model is run.
        A pair (layername, aka) can be provided, and the aka will be used
        as the key for the retained value instead of the layername.
        '''
        self.retain_layers([layername], detach=detach)

    def retain_layers(self, layernames, detach=True):
        '''
        Retains a list of layers at once.  Each entry is either a layer
        name or a (layername, aka) pair.
        '''
        self.add_hooks(layernames)
        for layername in layernames:
            aka = layername
            if not isinstance(aka, str):
                layername, aka = layername
            if aka not in self._retained:
                self._retained[aka] = None
                self._detach_retained[aka] = detach

    def stop_retaining_layers(self, layernames):
        '''
        Removes a list of layers from the set retained.  The layers stay
        hooked; only their retained values are dropped.
        '''
        self.add_hooks(layernames)
        for layername in layernames:
            aka = layername
            if not isinstance(aka, str):
                layername, aka = layername
            if aka in self._retained:
                del self._retained[aka]
                del self._detach_retained[aka]

    def retained_features(self, clear=False):
        '''
        Returns a dict of all currently retained features.  If clear is
        set, the retained values are reset to None after being copied
        into the result.
        '''
        result = OrderedDict(self._retained)
        if clear:
            for k in result:
                self._retained[k] = None
        return result

    def retained_layer(self, aka=None, clear=False):
        '''
        Retrieve retained data that was previously hooked by retain_layer.
        Call this after the model is run.  If clear is set, then the
        retained value will be returned and also cleared.
        '''
        if aka is None:
            # Default to the first retained layer.
            aka = next(iter(self._retained))
        result = self._retained[aka]
        if clear:
            self._retained[aka] = None
        return result

    def edit_layer(self, layername, rule=None, **kwargs):
        '''
        Pass a fully-qualified layer name (E.g., module.submodule.conv3)
        to hook that layer and modify its output each time the model is run.
        The output of the layer will be modified to be a convex combination
        of the replacement and x interpolated according to the ablation, i.e.:
        `output = x * (1 - a) + (r * a)`.

        A custom `rule(x, imodel, **kwargs)` can be supplied instead of the
        default; keyword arguments accumulate across calls for the same aka.
        '''
        if not isinstance(layername, str):
            layername, aka = layername
        else:
            aka = layername

        # The default editing rule is apply_ablation_replacement
        if rule is None:
            rule = apply_ablation_replacement

        self.add_hooks([(layername, aka)])
        self._editargs[aka].update(kwargs)
        self._editrule[aka] = rule

    def remove_edits(self, layername=None):
        '''
        Removes edits at the specified layer, or removes edits at all layers
        if no layer name is specified.
        '''
        if layername is None:
            self._editargs.clear()
            self._editrule.clear()
            return

        if not isinstance(layername, str):
            layername, aka = layername
        else:
            aka = layername
        if aka in self._editargs:
            del self._editargs[aka]
        if aka in self._editrule:
            del self._editrule[aka]

    def add_hooks(self, layernames):
        '''
        Sets up a set of layers to be hooked.

        Usually not called directly: use edit_layer or retain_layer instead.
        Raises ValueError if a requested layer name is not found in the model.
        '''
        needed = set()
        aka_map = {}
        for name in layernames:
            aka = name
            if not isinstance(aka, str):
                name, aka = name
            # Skip layers already hooked under this same aka.
            if self._hooked_layer.get(aka, None) != name:
                aka_map[name] = aka
                needed.add(name)
        if not needed:
            return
        for name, layer in self.model.named_modules():
            if name in aka_map:
                needed.remove(name)
                aka = aka_map[name]
                self._hook_layer(layer, name, aka)
        for name in needed:
            raise ValueError('Layer %s not found in model' % name)

    def _hook_layer(self, layer, layername, aka):
        '''
        Internal method to replace a forward method with a closure that
        intercepts the call, and tracks the hook so that it can be reverted.
        '''
        if aka in self._hooked_layer:
            raise ValueError('Layer %s already hooked' % aka)
        if layername in self._old_forward:
            raise ValueError('Layer %s already hooked' % layername)
        self._hooked_layer[aka] = layername
        # Remember only an instance-level forward override (None when the
        # layer uses its class's forward) so unhooking can restore it exactly.
        self._old_forward[layername] = (layer, aka,
                layer.__dict__.get('forward', None))
        editor = self
        original_forward = layer.forward
        def new_forward(self, *inputs, **kwargs):
            original_x = original_forward(*inputs, **kwargs)
            x = editor._postprocess_forward(original_x, aka)
            return x
        layer.forward = types.MethodType(new_forward, layer)

    def _unhook_layer(self, aka):
        '''
        Internal method to remove a hook, restoring the original forward method.
        Does nothing if `aka` is not currently hooked.
        '''
        if aka not in self._hooked_layer:
            return
        layername = self._hooked_layer[aka]
        # Remove any retained data and any edit rules
        if aka in self._retained:
            del self._retained[aka]
            del self._detach_retained[aka]
        self.remove_edits(aka)
        # Restore the unhooked method for the layer
        layer, check, old_forward = self._old_forward[layername]
        assert check == aka
        if old_forward is None:
            if 'forward' in layer.__dict__:
                del layer.__dict__['forward']
        else:
            layer.forward = old_forward
        del self._old_forward[layername]
        del self._hooked_layer[aka]

    def _postprocess_forward(self, x, aka):
        '''
        The internal method called by the hooked layers after they are run.
        '''
        # Retain output before edits, if desired.
        if aka in self._retained:
            if self._detach_retained[aka]:
                self._retained[aka] = x.detach()
            else:
                self._retained[aka] = x
        # Apply any edits requested.
        rule = self._editrule.get(aka, None)
        if rule is not None:
            x = rule(x, self, **(self._editargs[aka]))
        return x

    def _hook_sequential(self):
        '''
        Replaces 'forward' of sequential with a version that takes
        additional keyword arguments: layer allows a single layer to be run;
        first_layer and last_layer allow a subsequence of layers to be run.
        '''
        model = self.model
        self._hooked_layer['.'] = '.'
        self._old_forward['.'] = (model, '.',
                model.__dict__.get('forward', None))
        def new_forward(this, x, layer=None, first_layer=None, last_layer=None):
            assert layer is None or (first_layer is None and last_layer is None)
            first_layer, last_layer = [str(layer) if layer is not None
                    else str(d) if d is not None else None
                    for d in [first_layer, last_layer]]
            including_children = (first_layer is None)
            for name, layer in this._modules.items():
                if name == first_layer:
                    first_layer = None
                    including_children = True
                if including_children:
                    x = layer(x)
                if name == last_layer:
                    last_layer = None
                    including_children = False
            assert first_layer is None, '%s not found' % first_layer
            assert last_layer is None, '%s not found' % last_layer
            return x
        model.forward = types.MethodType(new_forward, model)

    def close(self):
        '''
        Unhooks all hooked layers in the model.
        '''
        # _unhook_layer is keyed by aka, so walk the aka-keyed table.
        # _old_forward is keyed by layer name, which differs from the aka
        # whenever a (layername, aka) pair was hooked.
        for aka in list(self._hooked_layer.keys()):
            self._unhook_layer(aka)
        assert len(self._old_forward) == 0

def apply_ablation_replacement(x, imodel, **buffers):
    '''
    Default edit rule: scale `x` by `(1 - ablation)` and, when a
    `replacement` is also given, add `replacement * ablation`.

    `x` is returned unchanged when no `ablation` entry is supplied.
    `imodel` is accepted to match the edit-rule call signature but unused.
    '''
    ablation = make_matching_tensor(buffers, 'ablation', x)
    if ablation is None:
        # Nothing to ablate; the replacement is ignored in this case too.
        return x

    x = x * (1 - ablation)
    replacement = make_matching_tensor(buffers, 'replacement', x)
    if replacement is not None:
        x += (replacement * ablation)
    return x

def make_matching_tensor(valuedict, name, data):
    '''
    Look up `valuedict[name]` and coerce it into a tensor compatible with
    `data`: same device, same dtype, and same number of dimensions.

    Each conversion step writes its result back into `valuedict[name]`.
    Returns None when the entry is missing or None.
    '''
    tensor = valuedict.get(name, None)
    if tensor is None:
        return None

    # Accept non-torch data (lists, scalars, numpy arrays).
    if not isinstance(tensor, torch.Tensor):
        tensor = torch.from_numpy(numpy.array(tensor))
        valuedict[name] = tensor

    # Ensure device and type match.
    if tensor.device != data.device or tensor.dtype != data.dtype:
        assert not tensor.requires_grad, '%s wrong device or type' % (name)
        tensor = tensor.to(device=data.device, dtype=data.dtype)
        valuedict[name] = tensor

    # Pad with singleton dimensions: one in front, the rest at the end.
    missing_dims = len(data.shape) - len(tensor.shape)
    if missing_dims > 0:
        assert not tensor.requires_grad, '%s wrong dimensions' % (name)
        padded_shape = (1,) + tuple(tensor.shape) + (1,) * (missing_dims - 1)
        tensor = tensor.view(padded_shape)
        valuedict[name] = tensor
    return tensor

def subsequence(sequential, first_layer=None, last_layer=None,
            share_weights=False):
    '''
    Builds a new torch.nn.Sequential from the children of `sequential`
    named `first_layer` through `last_layer` (both inclusive).  Leaving
    `first_layer` as None starts from the first child; leaving
    `last_layer` as None runs through the last child.

    With share_weights=True the new model holds the original child
    modules themselves; otherwise each child is deep-copied.

    Raises ValueError if a named layer is not found or nothing is selected.
    '''
    selected = OrderedDict()
    active = first_layer is None
    for child_name, child in sequential._modules.items():
        if child_name == first_layer:
            first_layer, active = None, True
        if active:
            if share_weights:
                selected[child_name] = child
            else:
                selected[child_name] = copy.deepcopy(child)
        if child_name == last_layer:
            last_layer, active = None, False

    # A boundary name is reset to None once matched; anything left was absent.
    for unmatched in (first_layer, last_layer):
        if unmatched is not None:
            raise ValueError('Layer %s not found' % unmatched)
    if len(selected) == 0:
        raise ValueError('Empty subsequence')
    return torch.nn.Sequential(OrderedDict(selected))



In [None]:
from collections import OrderedDict 
from fastprogress.fastprogress import master_bar, progress_bar
from IPython.core.debugger import set_trace 

def get_features(model, dataloader, layer_names):
  '''
  Run every batch of `dataloader` through `model` and collect flattened
  activations for the requested layers, plus the flattened input pixels.

  Args:
    model: a torch module, or an InstrumentedModel wrapping one. It is
      moved to the GPU when one is available and put in eval mode.
    dataloader: iterable of batches where batch[0] is the image tensor
      and batch[1] is the label tensor.
    layer_names: a layer name or list of layer names (strings) to record.

  Returns:
    (features, labels, pairwise_similarity):
      features: OrderedDict with key 'pixels' followed by each layer
        name, each a CPU tensor with one flattened row per image.
      labels: CPU tensor of all labels, in dataloader order.
      pairwise_similarity: OrderedDict with the same keys, each the
        image-by-image `np.corrcoef` matrix of the matching features.
  '''
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  model.to(device)
  model.eval()

  if not isinstance(model, InstrumentedModel):
    model = InstrumentedModel(model)

  if isinstance(layer_names, str):
    layer_names = [layer_names]

  model.retain_layers(layer_names)

  pairwise_similarity = OrderedDict()
  features = OrderedDict()
  labels = []

  try:
    mb = master_bar(dataloader)
    for batch in mb:
      imgs = batch[0].to(device)
      batch_size = imgs.shape[0]
      # Labels are only stored and plotted, so they stay on the CPU.
      labels.append(batch[1].cpu())

      # Store pixels on the CPU: accumulating them on the GPU wastes device
      # memory and numpy cannot read CUDA tensors.
      features.setdefault('pixels', []).append(
          imgs.reshape(batch_size, -1).cpu())

      with torch.no_grad():
        model(imgs)

      for layer_name in progress_bar(layer_names, parent=mb):
        X = model.retained_layer(layer_name)
        # reshape (unlike view) also handles non-contiguous layer outputs.
        X = X.reshape(batch_size, -1)
        features.setdefault(layer_name, []).append(X.detach().cpu())
  finally:
    # Always restore the model's original forward methods, even when a
    # batch fails part-way through.
    model.stop_retaining_layers(layer_names)
    model.close()

  labels = torch.cat(labels)
  for layer_name in progress_bar(features.keys(), parent=mb):
    features[layer_name] = torch.cat(features[layer_name], dim=0)
    # Rows are images, so this is the image-by-image correlation matrix.
    pairwise_similarity[layer_name] = np.corrcoef(features[layer_name].numpy())

  return features, labels, pairwise_similarity

## AlexNet

Let's explore the alexnet architecture.

In [None]:
# pretrained=True loads the weights of a model trained on ImageNet classification
model = models.alexnet(pretrained=True)
model

In [None]:
show_kernels(model)

## Exercise 1 | Disentangled Object Representations

- [ ] choose & download a dataset
- [ ] create a dataloader
- [ ] get activations from one layer
- [ ] create an MDS plot to visualize whether your object categories are "disentangled" or "entangled"

### download dataset

Skip this step if you've already downloaded this dataset (click the folder icon on the left; if you see "imagenette-320" then you are all set).

You can find a collection of vision datasets here:
https://course.fast.ai/datasets

We're going to download and "unpack" one of them, the "imagenette" subset, which is a subset of 10 image categories from the full ImageNet-1000 category dataset:

https://s3.amazonaws.com/fast-ai-imageclas/imagenette-320.tgz


In [None]:
!wget -c https://s3.amazonaws.com/fast-ai-imageclas/imagenette-320.tgz

In [None]:
!tar -xf imagenette-320.tgz

### setup dataloader

In [None]:
# for starters let's load our dataset without transforming to a "tensor" 
# it's just easier to visualize them this way
transform = transforms.Compose([
  transforms.Resize(224),
  transforms.CenterCrop(224),
])

In [None]:
# first load the dataset with only the resize/crop transforms (no tensor conversion) so we can look at the images
dataset = datasets.ImageFolder('./imagenette-320/val', transform=transform)
dataset

In [None]:
show_images(dataset, num_categories=10, num_per_category=5)

In [None]:
# ok, add the transforms needed to feed images to alexnet
transform = transforms.Compose([
  transforms.Resize(224),
  transforms.CenterCrop(224),
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
])

In [None]:
# dataset with transforms
dataset = datasets.ImageFolder('./imagenette-320/val', transform=transform)
dataset

In [None]:
# dataloader for feeding batches to our model
dataloader = DataLoader(dataset, batch_size=250, num_workers=8, pin_memory=True, shuffle=False)

### get activations

In [None]:
# features.0.12
model

In [None]:
features, labels, pairwise_similarity = get_features(model, dataloader, layer_names=['features.12','classifier.6'])

In [None]:
features['pixels'].shape

In [None]:
features['features.12'].shape

In [None]:
features['classifier.6'].shape

In [None]:
import seaborn as sns 
from scipy.stats import pearsonr 

# take first two images, and scatter plot their features
# we'll take the correlation as a measure of similarity
img1_features = features['pixels'][0]
img2_features = features['pixels'][1]
corr = pearsonr(img1_features, img2_features)[0]
print(corr)
sns.scatterplot(x=img1_features, y=img2_features);

In [None]:
# take first two images, and scatter plot their features
# we'll take the correlation as a measure of similarity
img1_features = features['classifier.6'][0]
img2_features = features['classifier.6'][1]
corr = pearsonr(img1_features, img2_features)[0]
print(corr)
sns.scatterplot(x=img1_features, y=img2_features);

In [None]:
# Heatmap of image-by-image similarity for the final classifier layer.
similarity_matrix = pairwise_similarity['classifier.6']
# Mask the lower triangle (diagonal included); only the upper triangle is drawn.
mask = np.tril(np.ones_like(similarity_matrix, dtype=bool))
ax = sns.heatmap(similarity_matrix, mask=mask, square=True, cmap="coolwarm");

In [None]:
mds_plot('classifier.6', pairwise_similarity, labels)

In [None]:
show_mds_plots(pairwise_similarity, labels, ncols=3)

***Project Idea #1***
It would be great if these points were clickable, so that we could click and "see" the images that were clustering next to each other, and click on "oddballs" and see if we can guess why they are clustering "incorrectly". Maybe the ones that don't fit in their cluster are just oddballs. To make this happen, you'd have to learn how to make interactive plots in "jupyter notebooks".

# Exercise 2 | Analyze a new dataset, or new model, or both!

***1) try a different dataset***   
You can find other easy-to-work-with datasets here: https://course.fast.ai/datasets

After you see the options there, you can develop a hypothesis. e.g., Are sub-categories more entangled (compare imagenette with imagewoof, or the "flowers" dataset).

***related project idea#2*** 
The machine vision community has generated tons of image and video datasets, and you might generate some testable ideas looking at those. There are many lists/resources, but this might be a good place to start https://www.kaggle.com/datasets

One interesting question related to this week's reading is whether a CNN trained on regular images would struggle with line-drawings (try googling "computer vision line drawings dataset"). 

***2) try a different model***
The torchvision package has different models: https://pytorch.org/vision/0.8/models.html

Many of these models were trained on ImageNet-1000 classification, but others were trained on semantic segmentation or object detection. The easiest thing would be to grab a new object-classification network first, since networks trained for other tasks might have unfamiliar architectures, making it harder to work with them (but you have me here to help!).

In [None]:
# notice that this model has a transform built in! 
# So we'll want to re-create our dataset without the Normalize transform (the cell below still resizes and center-crops to a fixed size)
model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model

In [None]:
# model.rpn

In [None]:
transform = transforms.Compose([
  transforms.Resize(320),
  transforms.CenterCrop(320),
  transforms.ToTensor(),
  #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset = datasets.ImageFolder('./imagenette-320/val', transform=transform)
dataloader = DataLoader(dataset, batch_size=50, num_workers=8, pin_memory=True, shuffle=False)

In [None]:
# Collect activations from the detection model's backbone and region-proposal stages.
# NOTE(review): get_features detaches and flattens each retained output as a single
# tensor; 'rpn' looks like it returns a (boxes, losses) tuple in torchvision rather
# than a tensor, which would fail here. Confirm, or pick a tensor-valued layer.
features, labels, pairwise_similarity = get_features(model, dataloader, layer_names=['backbone.body.layer4','rpn'])