# Collecting Samples for Activation Atlases with captum.optim

This notebook demonstrates how to collect the activation and corresponding attribution samples required for [Activation Atlases](https://distill.pub/2019/activation-atlas/) for the InceptionV1 model imported from Caffe.

In [None]:
%load_ext autoreload
%autoreload 2

from typing import List, Optional, Tuple, cast

import os
import torch
import torchvision

from tqdm.auto import tqdm

from captum.optim.models import googlenet

import captum.optim as opt

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Dataset Download & Setup 

To begin, we'll need to download and set up the image dataset that our model was trained on. You can download ImageNet's ILSVRC2012 dataset from the [ImageNet website](http://www.image-net.org/challenges/LSVRC/2012/) or via BitTorrent from [Academic Torrents](https://academictorrents.com/details/a306397ccf9c2ead27155983c254227c0fd938e2).

In [None]:
collect_attributions = True  # Set to False for no attributions

# Setup basic transforms
# The model has the normalization step in its internal transform_input
# function, so we don't need to normalize our inputs here.
transform_list = [
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
]
transform_list = torchvision.transforms.Compose(transform_list)

To make it easier to load the ImageNet dataset, we can use [Torchvision](https://pytorch.org/vision/stable/datasets.html#imagenet)'s `torchvision.datasets.ImageNet` instead of the default `ImageFolder`.

In [None]:
# Load the dataset
image_dataset = torchvision.datasets.ImageNet(
    root="path/to/dataset", split="train", transform=transform_list
)

Now we wrap our dataset in a `torch.utils.data.DataLoader` instance, and set the desired batch size.

In [None]:
# Set desired batch size & load dataset with torch.utils.data.DataLoader
image_loader = torch.utils.data.DataLoader(
    image_dataset,
    batch_size=32,
    shuffle=True,
)

We load our model, then set the desired model target layers and corresponding file names.

In [None]:
# Model to collect samples from, what layers of the model to collect samples from,
# and the desired names to use for the target layers.
sample_model = (
    googlenet(
        pretrained=True, replace_relus_with_redirectedrelu=False, bgr_transform=True
    )
    .eval()
    .to(device)
)
sample_targets = [sample_model.mixed4c_relu]
sample_target_names = ["mixed4c_relu_samples"]
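
The target layers chosen above are the modules whose outputs get recorded during collection. As a rough illustration of what recording a layer's output involves, the sketch below uses a plain PyTorch forward hook on a small stand-in network. This is not necessarily how `opt.models.collect_activations` is implemented. The toy model and the `captured` and `save_output` names are assumptions made for this example only.

```python
import torch

# Hypothetical toy model standing in for InceptionV1
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 4, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(4, 2),
).eval()

target = model[1]  # the layer whose output we want to record
captured = {}


def save_output(module, inputs, output):
    # Store the layer output, keyed by the module itself
    captured[module] = output.detach()


handle = target.register_forward_hook(save_output)
with torch.no_grad():
    logits = model(torch.randn(2, 3, 8, 8))
handle.remove()  # always remove hooks once they are no longer needed

print(captured[target].shape)  # torch.Size([2, 4, 8, 8])
```

Keying the dictionary by module mirrors how the notebook later looks up activations with the target layer itself as the key.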

By default, the activation samples will not have accurate class attributions, so we load a second instance of our model and replace all of its `nn.MaxPool2d` layers with Captum's `MaxPool2dRelaxed` layer. The relaxed max pooling layer lets us estimate the class attributions of a sample by determining the rate at which each output class changes as the sampled neuron activations increase.

In [None]:
# Optionally collect attributions from a copy of the first model that's
# been setup with relaxed pooling layers.
if collect_attributions:
    sample_model_attr = (
        googlenet(
            pretrained=True, replace_relus_with_redirectedrelu=False, bgr_transform=True
        )
        .eval()
        .to(device)
    )
    opt.models.replace_layers(
        sample_model_attr,
        torch.nn.MaxPool2d,
        opt.models.MaxPool2dRelaxed,
        transfer_vars=True,
    )
    sample_attr_targets = [sample_model_attr.mixed4c_relu]
    sample_logit_target = sample_model_attr.fc
else:
    sample_model_attr = None
    sample_attr_targets = None
    sample_logit_target = None
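
To give some intuition for what a relaxed pooling layer can look like, here is a minimal sketch, assuming the relaxation means "max pooling in the forward pass, average pooling gradient in the backward pass". `RelaxedMaxPoolSketch` is a hypothetical class written for this illustration. It is not Captum's `MaxPool2dRelaxed`, which may differ in its details.

```python
import torch
import torch.nn.functional as F


class RelaxedMaxPoolSketch(torch.nn.Module):
    """Hypothetical example: max pool values with average pool gradients."""

    def __init__(self, kernel_size, stride=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride

    def forward(self, x):
        max_out = F.max_pool2d(x, self.kernel_size, self.stride)
        avg_out = F.avg_pool2d(x, self.kernel_size, self.stride)
        # The returned values equal max_out, but the detached difference
        # carries no gradient, so gradients only flow through avg_out.
        return avg_out + (max_out - avg_out).detach()


x = torch.arange(16.0).reshape(1, 1, 4, 4).clone().requires_grad_(True)
out = RelaxedMaxPoolSketch(kernel_size=2)(x)
out.sum().backward()

print(out.detach())  # same values as regular max pooling
print(x.grad)  # every input position receives a gradient of 0.25
```

With regular max pooling, only the maximum in each window would receive a gradient. Spreading the gradient across the whole window gives every spatial position a non-zero contribution.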

With our dataset loaded and models ready to go, we can now start collecting our samples. To perform the sample collection, we define a function called `capture_activation_samples` that randomly samples an x and y position in each target layer's activations for every image. The helper function `attribute_spatial_position` computes the class attributions for the sampled positions.
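
The sampling step can be sketched in isolation on a random tensor. The shapes below are arbitrary choices for this example, and the sketch assumes, as the collection code does, that positions on the outer border are skipped.

```python
import torch

torch.manual_seed(0)
activations = torch.randn(3, 8, 14, 14)  # [batch, channels, height, width]
_, channels, h, w = activations.shape

samples = []
for b in range(activations.size(0)):
    # The upper bound of randint is exclusive, so this draws from [1, h - 2]
    # and [1, w - 2], which leaves out the border rows and columns.
    y = int(torch.randint(low=1, high=h - 1, size=[1]))
    x = int(torch.randint(low=1, high=w - 1, size=[1]))
    # Keep all channels at the sampled position as a [channels, 1] column
    samples.append(activations[b, :, y, x].unsqueeze(1))

stacked = torch.cat(samples, dim=1)
print(stacked.shape)  # torch.Size([8, 3])
```

Each image contributes one column, so concatenating along dimension 1 yields a `[channels, n_samples]` tensor.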

In [None]:
def attribute_spatial_position(
    target_activ: torch.Tensor,
    logit_activ: torch.Tensor,
    position_mask: torch.Tensor,
) -> torch.Tensor:
    """
    This function employs the double backward trick in order to perform
    forward-mode AD.

    See here for more details:
    https://github.com/renmengye/tensorflow-forward-ad/issues/2

    Based on the Collect Activations Lucid tutorial:
    https://colab.research.google.com/github/tensorflow
    /lucid/blob/master/notebooks/activation-atlas/activation-atlas-collect.ipynb

    Args:

        target_activ (torch.Tensor): Captured activations from the target layer.
        logit_activ (torch.Tensor): Captured activations from the FC / logit layer.
        position_mask (torch.Tensor): A mask with the same shape as target_activ
            that is used to zero out all the non-target positions.

    Returns:
        logit_attr (torch.Tensor): The class attributions for the target spatial
            positions, with the same shape as logit_activ.
    """

    assert target_activ.dim() == 2 or target_activ.dim() == 4
    assert logit_activ.dim() == 2

    zeros = torch.nn.Parameter(torch.zeros_like(logit_activ))
    target_zeros = target_activ * position_mask

    grad_one = torch.autograd.grad(
        outputs=[logit_activ],
        inputs=[target_activ],
        grad_outputs=[zeros],
        create_graph=True,
    )
    logit_attr = torch.autograd.grad(
        outputs=grad_one,
        inputs=[zeros],
        grad_outputs=[target_zeros],
        create_graph=True,
    )[0]
    return logit_attr


def capture_activation_samples(
    loader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    targets: List[torch.nn.Module],
    target_names: Optional[List[str]] = None,
    sample_dir: str = "",
    num_images: Optional[int] = None,
    samples_per_image: int = 1,
    input_device: torch.device = torch.device("cpu"),
    collect_attributions: bool = False,
    attr_model: Optional[torch.nn.Module] = None,
    attr_targets: Optional[List[torch.nn.Module]] = None,
    logit_target: Optional[torch.nn.Module] = None,
    show_progress: bool = False,
):
    """
    Capture randomly sampled activations & optional attributions for those samples,
    for an image dataset from one or more target layers.

    Samples are saved to files for speed and memory efficiency, and to preserve them
    in the event of any crashes.

    Based on the Collect Activations Lucid tutorial:
    https://colab.research.google.com/github/tensorflow
    /lucid/blob/master/notebooks/activation-atlas/activation-atlas-collect.ipynb

    Args:

        loader (torch.utils.data.DataLoader): A torch.utils.data.DataLoader
            instance for an image dataset.
        model (nn.Module): A PyTorch model instance.
        targets (list of nn.Module): A list of layers to collect activation samples
            from.
        target_names (list of str, optional): A list of names to use when saving sample
            tensors as files. Names will automatically be chosen if set to None.
            Default: None
        sample_dir (str): Path to where activation samples should be saved.
            Default: ""
        num_images (int, optional): How many images to collect samples from. Set to
            None to collect samples from every image in the dataset.
            Default: None
        samples_per_image (int): How many samples to collect per image.
            Default: 1
        input_device (torch.device, optional): The device to use for model
            inputs.
            Default: torch.device("cpu")
        collect_attributions (bool, optional): Whether or not to collect attributions
            for samples.
            Default: False
        attr_model (nn.Module, optional): A PyTorch model instance to use for
            calculating sample attributions.
            Default: None
        attr_targets (list of nn.Module, optional): A list of attribution model layers
            to collect attributions from. This should be the exact same as the targets
            parameter, except for the attribution model.
            Default: None
        logit_target (nn.Module, optional): The final layer in the attribution model
            that determines the classes. This parameter is only enabled if
            collect_attributions is set to True.
            Default: None
        show_progress (bool, optional): Whether or not to show progress.
            Default: False
    """

    if target_names is None:
        target_names = ["target" + str(i) + "_" for i in range(len(targets))]

    assert len(target_names) == len(targets)
    assert os.path.isdir(sample_dir)

    def random_sample(
        activations: torch.Tensor,
    ) -> Tuple[List[torch.Tensor], List[List[List[int]]]]:
        """
        Randomly sample H & W dimensions of activations with 4 dimensions.
        """
        assert activations.dim() == 4 or activations.dim() == 2

        activation_samples: List = []
        position_list: List = []

        with torch.no_grad():
            for i in range(samples_per_image):
                sample_position_list: List = []
                for b in range(activations.size(0)):
                    if activations.dim() == 4:
                        h, w = activations.shape[2:]
                        y = torch.randint(low=1, high=h - 1, size=[1])
                        x = torch.randint(low=1, high=w - 1, size=[1])
                        activ = activations[b, :, y, x]
                        sample_position_list.append((b, y, x))
                    elif activations.dim() == 2:
                        activ = activations[b].unsqueeze(1)
                        sample_position_list.append(b)
                    activation_samples.append(activ)
                position_list.append(sample_position_list)
        return activation_samples, position_list

    def attribute_samples(
        activations: torch.Tensor,
        logit_activ: torch.Tensor,
        position_list: List[List[List[int]]],
    ) -> List[torch.Tensor]:
        """
        Collect attributions for target sample positions.
        """
        assert activations.dim() == 4 or activations.dim() == 2

        sample_attributions: List = []
        with torch.set_grad_enabled(True):
            zeros_mask = torch.zeros_like(activations)
            for sample_pos_list in position_list:
                for c in sample_pos_list:
                    if activations.dim() == 4:
                        zeros_mask[c[0], :, c[1], c[2]] = 1
                    elif activations.dim() == 2:
                        zeros_mask[c] = 1
                attr = attribute_spatial_position(
                    activations, logit_activ, position_mask=zeros_mask
                ).detach()
                sample_attributions.append(attr)
        return sample_attributions

    if collect_attributions:
        attr_model = cast(torch.nn.Module, attr_model)
        if logit_target is None:
            # Default to the final child module of the attribution model
            logit_target = list(attr_model.children())[-1]
        # Build a new list so that the caller's attr_targets is not mutated
        attr_targets = cast(List[torch.nn.Module], attr_targets) + [logit_target]

    if show_progress:
        total = (
            len(loader.dataset) if num_images is None else num_images  # type: ignore
        )
        pbar = tqdm(total=total, unit=" images")

    image_count, batch_count = 0, 0
    with torch.no_grad():
        for inputs, _ in loader:
            inputs = inputs.to(input_device)
            image_count += inputs.size(0)
            batch_count += 1

            target_activ_dict = opt.models.collect_activations(model, targets, inputs)
            if collect_attributions:
                with torch.set_grad_enabled(True):
                    target_activ_attr_dict = opt.models.collect_activations(
                        attr_model, attr_targets, inputs
                    )
                    logit_activ = target_activ_attr_dict[logit_target]
                    del target_activ_attr_dict[logit_target]

            sample_coords = []
            for t, n in zip(target_activ_dict, target_names):
                sample_tensors, p_list = random_sample(target_activ_dict[t])
                torch.save(
                    sample_tensors,
                    os.path.join(
                        sample_dir, n + "_activations_" + str(batch_count) + ".pt"
                    ),
                )
                sample_coords.append(p_list)

            if collect_attributions:
                for t, n, s_coords in zip(
                    target_activ_attr_dict, target_names, sample_coords
                ):
                    sample_attrs = attribute_samples(
                        target_activ_attr_dict[t], logit_activ, s_coords
                    )
                    torch.save(
                        sample_attrs,
                        os.path.join(
                            sample_dir,
                            n + "_attributions_" + str(batch_count) + ".pt",
                        ),
                    )

            if show_progress:
                pbar.update(inputs.size(0))

            # Stop once the requested number of images has been processed
            if num_images is not None and image_count >= num_images:
                break

    if show_progress:
        pbar.close()
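
The double backward trick used in `attribute_spatial_position` can be checked on a small function. The sketch below is an illustration under assumed toy shapes, not part of the collection code. It computes a Jacobian-vector product with two reverse-mode passes and compares it against an explicitly computed Jacobian.

```python
import torch

torch.manual_seed(0)
W = torch.randn(3, 5)
x = torch.randn(5, requires_grad=True)
y = torch.tanh(W @ x)  # stand-in for the logits, shape [3]
v = torch.randn(5)  # direction in input space, shape [5]

# Pass 1: vector-Jacobian product with a dummy cotangent u.
# g = J^T u is linear in u, and create_graph=True keeps that dependence.
u = torch.zeros_like(y, requires_grad=True)
g = torch.autograd.grad(
    outputs=[y], inputs=[x], grad_outputs=[u], create_graph=True
)[0]

# Pass 2: differentiate (g . v) with respect to u, which gives J v.
jvp = torch.autograd.grad(outputs=[g], inputs=[u], grad_outputs=[v])[0]

# Reference result from the full Jacobian
jacobian = torch.autograd.functional.jacobian(
    lambda t: torch.tanh(W @ t), x.detach()
)
print(torch.allclose(jvp, jacobian @ v, atol=1e-5))
```

In the notebook, the role of `v` is played by the target activations multiplied by the position mask, so the result estimates how the sampled position's activations contribute to each logit.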

We now collect our activation samples and attributions as we iterate through our image dataset. Note that this step can be rather time-consuming, depending on the image dataset being used.

In [None]:
# Directory to save sample files to
sample_dir = "inceptionv1_samples"
# Create the directory if it does not already exist
os.makedirs(sample_dir, exist_ok=True)

# Collect samples & optionally attributions as well
capture_activation_samples(
    loader=image_loader,
    model=sample_model,
    targets=sample_targets,
    target_names=sample_target_names,
    attr_model=sample_model_attr,
    attr_targets=sample_attr_targets,
    input_device=device,
    sample_dir=sample_dir,
    show_progress=True,
    collect_attributions=collect_attributions,
    logit_target=sample_logit_target,
)
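
Since the loop in `capture_activation_samples` writes one activations file per batch for each target layer, plus one attributions file when attributions are enabled, it can help to estimate the number of files before starting a long run. The following is a minimal sketch of that arithmetic. `expected_file_count` is a hypothetical helper rather than part of Captum, and it assumes the default `DataLoader` behavior of keeping the final partial batch.

```python
import math


def expected_file_count(
    num_images: int,
    batch_size: int,
    num_targets: int = 1,
    collect_attributions: bool = True,
) -> int:
    """Hypothetical helper: estimate how many sample files a full run writes."""
    num_batches = math.ceil(num_images / batch_size)
    files_per_batch = num_targets * (2 if collect_attributions else 1)
    return num_batches * files_per_batch


# The ILSVRC2012 training split contains 1,281,167 images
print(expected_file_count(1_281_167, batch_size=32))  # 80074
```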

Now that we've collected our samples, we need to combine them into a single tensor. Below we use the `consolidate_samples` function to load each list of tensor samples, and then concatenate them into a single tensor.

In [None]:
def consolidate_samples(
    sample_dir: str,
    sample_basename: str = "",
    dim: int = 1,
    num_files: Optional[int] = None,
    show_progress: bool = False,
) -> torch.Tensor:
    """
    Combine samples collected from capture_activation_samples into a single tensor
    with a shape of [n_target_classes, n_samples].

    Args:

        sample_dir (str): The directory where activation samples were saved.
        sample_basename (str, optional): If samples from different layers are present
            in sample_dir, then you can use samples from only a specific layer by
            specifying the basename that samples of the same layer share.
            Default: ""
        dim (int, optional): The dimension to concatenate the samples together on.
            Default: 1
        num_files (int, optional): The number of sample files that you wish to
            concatenate together, if you do not wish to concatenate all of them.
            Default: None
        show_progress (bool, optional): Whether or not to show progress.
            Default: False

    Returns:
        sample_tensor (torch.Tensor): A tensor containing all the specified sample
            tensors with a shape of [n_target_classes, n_samples].
    """

    assert os.path.isdir(sample_dir)

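    # Sort the file names so that separate calls (e.g. for activations and
    # attributions) combine their batches in the same order.
    tensor_samples = sorted(
        os.path.join(sample_dir, name)
        for name in os.listdir(sample_dir)
        if sample_basename.lower() in name.lower()
        and os.path.isfile(os.path.join(sample_dir, name))
    )
    assert len(tensor_samples) > 0

    # Only use the first num_files files if a limit was given
    if num_files is not None:
        tensor_samples = tensor_samples[:num_files]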

    if show_progress:
        total = len(tensor_samples) if num_files is None else num_files  # type: ignore
        pbar = tqdm(total=total, unit=" sample batches collected")

    samples: List[torch.Tensor] = []
    for file in tensor_samples:
        sample_batch = torch.load(file)
        for s in sample_batch:
            samples += [s.cpu()]
        if show_progress:
            pbar.update(1)

    if show_progress:
        pbar.close()
    return torch.cat(samples, dim)
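
The load-and-concatenate pattern can be tried on a few small files first. This sketch writes three made-up batch files to a temporary directory and combines them. The `demo_activations_` file names and the tensor shapes are assumptions chosen for the example.

```python
import os
import tempfile

import torch

with tempfile.TemporaryDirectory() as tmp_dir:
    # Three hypothetical batch files, each holding a list of two
    # [channels, 1] samples filled with the batch index.
    for batch_idx in range(1, 4):
        batch = [torch.full((4, 1), float(batch_idx)) for _ in range(2)]
        torch.save(
            batch, os.path.join(tmp_dir, f"demo_activations_{batch_idx}.pt")
        )

    # Load every file in a fixed order and flatten the lists of samples
    samples = []
    for name in sorted(os.listdir(tmp_dir)):
        samples += torch.load(os.path.join(tmp_dir, name))

combined = torch.cat(samples, dim=1)
print(combined.shape)  # torch.Size([4, 6])
```

Activation samples are stored as `[channels, 1]` columns, so they are joined on dimension 1. The attribution tensors returned by `attribute_spatial_position` have the shape of the logit activations, with the batch dimension first, which is why the notebook joins them on dimension 0.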

In [None]:
# Combine our newly collected samples into single tensors.
# We load the sample tensors from sample_dir and then
# concatenate them.

for name in sample_target_names:
    print("Combining " + name + " samples:")
    activation_samples = consolidate_samples(
        sample_dir=sample_dir,
        sample_basename=name + "_activations",
        dim=1,
        show_progress=True,
    )
    if collect_attributions:
        sample_attributions = consolidate_samples(
            sample_dir=sample_dir,
            sample_basename=name + "_attributions",
            dim=0,
            show_progress=True,
        )

    # Save the results, using an underscore to separate the layer name
    # from the file name suffix
    torch.save(activation_samples, name + "_activation_samples.pt")
    if collect_attributions:
        torch.save(sample_attributions, name + "_attribution_samples.pt")

Now that we have successfully collected the required sample activations & attributions, we can move on to the main Activation Atlas and Class Activation Atlas tutorials!