# Visualization Notebook
Scroll down to the bottom of the notebook to access the visualization code. You will need to provide the log
directory of a trained ProtoNet.

In [None]:
from typing import Tuple
from typing import Union

import os.path as osp

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib import cm

import torch

from torchvision.transforms import functional as F
from skimage.transform import resize

import numpy as np
from numpy import unravel_index

try:
    import pixpnet
except ImportError:
    import sys

    sys.path.append('..')

    import pixpnet
finally:
    from pixpnet.data import get_datasets
    from pixpnet.data import get_metadata
    from pixpnet.protonets.lit_model import ProtoLitModel
    from pixpnet.protonets.utils import load_config_and_best_model
    from pixpnet.symbolic.models import compute_rf_data
    from pixpnet.utils import parse_config_file
    from pixpnet.utils_torch import slices_to_bboxes
    from pixpnet.utils_torch import take_rf_from_bbox
    from pixpnet.utils_torch import take_rf

# Global figure styling for every plot produced by this notebook.
from matplotlib import rc

# seaborn theme first, so the matplotlib rc overrides below are applied on top.
sns.set(style='whitegrid', font_scale=2.5)

# NOTE(review): the family is 'serif' but the font list is stored under the
# 'sans-serif' key -- confirm whether 'serif': ['Times'] was intended.
rc('font', family='serif', **{'sans-serif': ['Times']})
# Render all figure text through LaTeX.
rc('text', usetex=True)


def projected_prototypical_samples(metadata, model, ds_train, rf_layer):
    """Gather the training sample, image patch and slices behind each prototype.

    For every prototype ``j`` the training sample at
    ``model.model.corresponding_sample_idxs[j]`` is fetched from ``ds_train``
    and the receptive field selected by ``model.model.min_fmap_idxs[j]`` is
    cut out of it via ``rf_layer``.

    Args:
        metadata: object exposing ``input_size`` (forwarded to
            ``rescale_image``).
        model: wrapper whose ``model`` attribute exposes
            ``corresponding_sample_idxs``, ``min_fmap_idxs`` and
            ``num_prototypes``.
        ds_train: indexable dataset whose items unpack as
            ``(index, x, y)``.
        rf_layer: receptive-field description passed to ``take_rf``.

    Returns:
        Tuple ``(projected_samples, projected_patches,
        projected_patch_slices)``. Each list has exactly one entry per
        prototype; the entry is ``None`` in all three lists when the
        prototype has no corresponding sample (index ``-1``). Samples and
        patches are numpy arrays transposed with ``(1, 2, 0)``.
    """
    corresponding_sample_idxs = model.model.corresponding_sample_idxs
    min_fmap_idxs = model.model.min_fmap_idxs

    num_prototypes = model.model.num_prototypes
    assert len(corresponding_sample_idxs) == num_prototypes
    assert len(min_fmap_idxs) == num_prototypes

    projected_samples = []
    projected_patches = []
    projected_patch_slices = []

    for j in range(num_prototypes):
        idx = corresponding_sample_idxs[j]
        if idx == -1:
            print(f'prototype {j} has no corresponding sample!')
            projected_samples.append(None)
            projected_patches.append(None)
            # Keep all three lists aligned with the prototype index; the
            # slices list previously got no placeholder and went out of sync.
            projected_patch_slices.append(None)
            continue

        idx_, x, _ = ds_train[idx]
        assert idx_ == idx, (idx_, idx)  # BIST

        (fmap_h_start, fmap_h_end,
         fmap_w_start, fmap_w_end) = min_fmap_idxs[j]
        rf_feat = take_rf(rf_layer, fmap_h_start, fmap_h_end,
                          fmap_w_start, fmap_w_end)

        # Add a leading batch axis for take_from, then drop it again.
        patch = rf_feat.take_from(
            x[None], all_channels=True).squeeze(axis=0)
        patch = rescale_image(patch, metadata.input_size)
        projected_patches.append(patch.numpy().transpose(1, 2, 0))

        sample = rescale_image(x, metadata.input_size)
        projected_samples.append(sample.numpy().transpose(1, 2, 0))

        projected_patch_slices.append(rf_feat.as_slices(
            all_channels=True))

    return projected_samples, projected_patches, projected_patch_slices


def rescale_image(img: Union[torch.Tensor, np.ndarray, Tuple], meta_size: int,
                  max_size: int = 256):
    """Shrink an image (or bounding box) by the factor ``max_size / meta_size``.

    Nothing is done, and ``img`` is returned as-is, when ``meta_size`` is
    smaller than ``max_size``.

    Args:
        img: a ``torch.Tensor`` unpacked as ``(c, h, w)``, a ``np.ndarray``
            unpacked as ``(h, w, c)``, or a bounding-box tuple
            ``((x, y), width, height)``.
        meta_size: reference size the scale factor is computed against.
        max_size: size above which rescaling kicks in.

    Returns:
        The resized tensor/array, or the bounding box with every coordinate
        multiplied by the scale factor.

    Raises:
        NotImplementedError: if ``img`` is none of the supported types.
    """
    if meta_size < max_size:
        return img

    factor = max_size / meta_size

    if isinstance(img, torch.Tensor):
        _, height, width = img.size()
        target = [round(height * factor), round(width * factor)]
        return F.resize(img, target)

    if isinstance(img, np.ndarray):
        height, width, _ = img.shape
        target = [round(height * factor), round(width * factor)]
        return resize(img, target)

    if isinstance(img, tuple):  # bbox
        (x, y), width, height = img
        return (x * factor, y * factor), width * factor, height * factor

    raise NotImplementedError(type(img))


def get_last_layer_coef(readout_type, last_layer, y_i, proto_idx, proto_class):
    """Look up the readout coefficient linking one prototype to class ``y_i``.

    Args:
        readout_type: one of ``'linear'``, ``'sparse'`` or ``'proto'``.
        last_layer: readout layer exposing ``weight`` (and ``groups`` for the
            sparse readout); ``None`` short-circuits to ``None``.
        y_i: class index.
        proto_idx: prototype index.
        proto_class: class the prototype belongs to.

    Returns:
        ``None`` without a last layer; the weight entry for ``'linear'``;
        the weight entry or ``0`` for ``'sparse'`` (``0`` when the prototype
        belongs to another class); ``1``/``0`` for ``'proto'``.

    Raises:
        NotImplementedError: for any other ``readout_type``.
    """
    if last_layer is None:
        return None

    if readout_type == 'linear':
        return last_layer.weight[y_i, proto_idx]

    if readout_type == 'sparse':
        if proto_class != y_i:
            return 0
        # NOTE(review): the column is taken modulo `groups`, whereas
        # compute_contributions_sample_proto2patch reshapes similarities to
        # (groups, in_features_per_group) -- confirm this is the right modulus.
        return last_layer.weight[y_i, proto_idx % last_layer.groups]

    if readout_type == 'proto':
        return 1 if y_i == proto_class else 0

    raise NotImplementedError(readout_type)


def compute_contributions_sample_proto2patch(readout_type, last_layer, y_i,
                                             proto_sims_i):
    """Compute each prototype's contribution to the score of class ``y_i``.

    Args:
        readout_type: one of ``'linear'``, ``'sparse'`` or ``'proto'``.
        last_layer: readout layer. ``weight`` is read for the linear and
            sparse readouts; ``groups`` and ``in_features_per_group`` are
            read for the sparse and proto readouts.
        y_i: class index whose score is being explained.
        proto_sims_i: prototype similarities of a single sample; must be
            reshapeable to ``(groups, in_features_per_group)`` for the
            sparse and proto readouts.

    Returns:
        Tensor of contributions, one per prototype. For the sparse and proto
        readouts it is flat and every entry outside group ``y_i`` is zero.

    Raises:
        NotImplementedError: for any other ``readout_type``.
    """
    if readout_type == 'linear':
        return last_layer.weight[y_i] * proto_sims_i

    if readout_type not in ('sparse', 'proto'):
        raise NotImplementedError(readout_type)

    proto_sims_grouped = proto_sims_i.reshape(
        last_layer.groups, last_layer.in_features_per_group)
    # Allocate with the *grouped* shape: the sparse branch used to allocate
    # with the flat shape, so the 2-D row assignment below failed on 1-D input.
    contributions = torch.zeros_like(proto_sims_grouped)

    class_sims = proto_sims_grouped[y_i, :]
    if readout_type == 'sparse':
        class_sims = last_layer.weight[y_i] * class_sims
    # Only the prototypes in the group of class y_i contribute.
    contributions[y_i, :] = class_sims
    return contributions.flatten()


def min_max_norm(arr, min=None, max=None, inplace=True):
    """Min-max normalize ``arr`` as ``(arr - min) / (max - min)``.

    Args:
        arr: array to normalize; modified in place unless ``inplace`` is
            false, in which case ``arr.copy()`` is normalized instead.
        min: lower bound; defaults to ``arr.min()``.
        max: upper bound; defaults to ``arr.max()``. May be given without
            ``min``.
        inplace: whether to overwrite ``arr``.

    Returns:
        The normalized array (``arr`` itself when ``inplace`` is true). When
        the bounds coincide the division is skipped, so the result is only
        shifted by ``min`` instead of being filled with NaN/inf.
    """
    if not inplace:
        arr = arr.copy()
    # Resolve both bounds before mutating arr, so each of them can be
    # supplied (or defaulted) independently of the other.
    lower = arr.min() if min is None else min
    upper = arr.max() if max is None else max
    span = upper - lower
    arr -= lower
    if span != 0:
        arr /= span
    return arr


def overlay_heatmap(img: np.ndarray,
                    heatmap: np.ndarray,
                    heatmap_weight: float = 0.5,
                    cmap: str = 'jet',
                    data_format='channels_first'):
    """Blend a color-mapped heat map with an image.

    Args:
        img: 3-D image array. With ``data_format='channels_first'`` it is
            transposed with ``(1, 2, 0)`` before blending; otherwise it is
            used as given.
        heatmap: 2-D array, or 3-D array with a single channel (transposed
            the same way as ``img`` when channels-first).
        heatmap_weight: blending weight of the heat map; the image receives
            ``1 - heatmap_weight``.
        cmap: name of the matplotlib colormap applied to ``heatmap``.
        data_format: ``'channels_first'`` triggers the transposes above.

    Returns:
        The blended array: the first three channels of the color-mapped heat
        map weighted against the (channels-last) image.
    """
    assert img.ndim == 3, img.shape
    if data_format == 'channels_first':
        img = img.transpose((1, 2, 0))
        if heatmap.ndim == 3:
            heatmap = heatmap.transpose((1, 2, 0))
    if heatmap.ndim == 3:
        assert heatmap.shape[2] == 1, heatmap.shape
        heatmap = heatmap.squeeze(axis=2)
    else:
        assert heatmap.ndim == 2, heatmap.shape
    # `matplotlib.cm.get_cmap` was deprecated in matplotlib 3.7 and removed in
    # 3.9; `pyplot.get_cmap` is available across versions.
    colormap = plt.get_cmap(cmap)
    # Drop the alpha channel of the color-mapped heat map before blending.
    overlaid = (
            colormap(heatmap)[:, :, :3] * heatmap_weight +
            img * (1. - heatmap_weight)
    )
    return overlaid

In [None]:
def savefig(logdir: str, prefix: str, fig: Union[plt.Figure, sns.FacetGrid],
            show: Union[bool, str] = False, save: bool = True):
    """Save a figure as a PDF under ``<logdir>/results`` and close it.

    The file name is ``<prefix>_<suffix>.pdf`` where ``<suffix>`` joins up to
    the last three path components of ``logdir`` with underscores.

    Args:
        logdir: log directory; the figure is written to its ``results``
            subdirectory, which is created if missing.
        prefix: leading part of the file name.
        fig: matplotlib figure or seaborn ``FacetGrid`` to save and close.
        show: ``True`` to display the figure (blocking); a string displays it
            only when ``prefix`` starts with that string.
        save: whether to write the PDF.
    """
    import os  # local import: the file only imports `os.path` (as `osp`)

    basenames = []
    logdir_running = logdir
    while len(basenames) < 3:
        parent, basename = osp.split(logdir_running)
        if basename:
            basenames.append(basename)
        # Stop when the path is exhausted *or* splitting makes no progress:
        # splitting a filesystem root returns the root again, which used to
        # loop forever for short absolute paths such as '/a'.
        if not parent or parent == logdir_running:
            break
        logdir_running = parent
    basenames = '_'.join(reversed(basenames))
    # append basenames string so there are unique filenames when uploading
    #  figures to things where filename conflicts are bad
    results_dir = osp.join(logdir, 'results')
    save_path = osp.join(results_dir, f'{prefix}_{basenames}.pdf')
    if save:
        os.makedirs(results_dir, exist_ok=True)
        print(f'Saving to "{save_path}"')
        fig.savefig(save_path)
    if show is True or (isinstance(show, str) and prefix.startswith(show)):
        plt.show(block=True)
    # A FacetGrid wraps its matplotlib figure in `.fig`.
    if isinstance(fig, sns.FacetGrid):
        plt.close(fig.fig)
    else:
        plt.close(fig)

In [None]:
def explain(logdir, metadata, model, sample_norm, sample, y, rf_layer,
            projected_patches, sample_id, k=11, sort='contributions',
            plot=True, save=True):
    """Draw a prototype-by-prototype explanation figure for one sample.

    The model is run on ``sample_norm`` and a grid with ``k`` rows and five
    columns is drawn. Each row shows, for one prototype: the sample with the
    matched patch outlined, the prototype's projected patch, the matched
    patch of the sample, the sample overlaid with the prototype's heat map,
    and the prototype's contribution to the predicted class as text.

    Args:
        logdir: only referenced by the commented-out ``savefig`` call.
        metadata: object exposing ``input_size`` and ``label_names``.
        model: callable accepting ``return_features=True`` and returning a
            mapping with the keys ``'logits'``, ``'distances'``,
            ``'min_dist_idxs'`` and ``'max_similarities'``; also exposes
            ``device`` and the wrapped ``model`` attribute.
        sample_norm: sample fed to the model (a batch axis is added here).
        sample: sample used for display; ``.numpy()`` is called on it and it
            is transposed with ``(1, 2, 0)`` before plotting.
        y: label of the sample; only used for the printed summary.
        rf_layer: receptive-field description passed to
            ``take_rf_from_bbox``.
        projected_patches: per-prototype patches, indexed by prototype index;
            ``None`` entries leave the cell empty.
        sample_id: only referenced by the commented-out ``savefig`` call.
        k: number of prototypes (rows) to draw; ``None`` means all of them.
        sort: ``'contributions'`` or ``'similarity'`` to order prototypes in
            descending order of that quantity; any falsy value keeps the
            prototype order.
        plot: only referenced by the commented-out ``savefig`` call.
        save: only referenced by the commented-out ``savefig`` call.

    Raises:
        ValueError: if ``sort`` is truthy but not a supported key.

    Note:
        Nothing is returned and the figure is neither saved nor closed here
        while the ``savefig`` call at the end stays commented out.
    """
    class_specific = model.model.class_specific
    prototype_class_identity = model.model.prototype_class_identity
    proto_layer_stride = model.model.prototype_layer_stride
    n_prototypes, _, proto_h, proto_w = model.model.prototype_shape
    if k is None:
        k = n_prototypes
    last_layer = model.model.last_layer
    readout_type = model.model.readout_type

    # Forward pass on a batch of one; gradients are not needed.
    with torch.no_grad():
        result = model(sample_norm[None].to(model.device),
                       return_features=True)
    logits = result['logits']
    proto_dists = result['distances']
    min_dist_idxs = result['min_dist_idxs']
    proto_max_sims = result['max_similarities']

    preds = torch.argmax(logits, dim=1)

    # Contributions are computed w.r.t. the *predicted* class, not the label.
    contributions = compute_contributions_sample_proto2patch(
        readout_type, last_layer, preds[0], proto_max_sims[0])

    # Choose the order in which prototypes are shown (row j -> sort_idxs[j]).
    if sort == 'contributions':
        sort_idxs = torch.argsort(contributions, descending=True)
    elif sort == 'similarity':
        sort_idxs = torch.argsort(proto_max_sims[0], descending=True)
    elif not sort:
        sort_idxs = torch.arange(len(contributions))
    else:
        raise ValueError(f'sort = {sort}')
    sample_npy = sample.numpy()

    # One row per displayed prototype, one column per panel.
    ncols = 5
    scale = 3
    fig, axes = plt.subplots(k, ncols, squeeze=False,
                             figsize=(ncols * scale, k * scale))

    # Use human-readable class names when the metadata provides them.
    pred_lab = preds[0]
    y_lab = y
    if metadata.label_names is not None:
        pred_lab = metadata.label_names[pred_lab]
        y_lab = metadata.label_names[y]

    # PxHxW
    sample_heat_map_max_all = model.model.pixel_space_map(
        sample_norm.to(model.device), proto_dists, sigma_factor=1.)[0].cpu().numpy()

    for j, axes_j in enumerate(axes):
        # Index of the prototype shown in row j.
        j_adjust = sort_idxs[j]

        projected_patch = projected_patches[j_adjust]

        # Turn the flat index stored for this prototype into a (row, col)
        # position over the trailing dimensions of the distance map.
        # `.item()` requires a single element, i.e. a batch of one.
        fmap_h_start, fmap_w_start = unravel_index(
            min_dist_idxs[:, j_adjust].item(), proto_dists.shape[2:])
        # retrieve the corresponding feature map patch
        rf_feat = take_rf_from_bbox(rf_layer, fmap_h_start, fmap_w_start,
                                    proto_h, proto_w, proto_layer_stride)
        sample_patch = rf_feat.take_from(
            sample_npy[None], all_channels=True).squeeze(axis=0)

        # Heat map of this prototype, given a leading channel axis, scaled to
        # [0, 1] and blended with the sample. The indexing yields a view, so
        # the in-place normalization also writes into
        # `sample_heat_map_max_all`.
        sample_heat_map_max = sample_heat_map_max_all[j_adjust][None]
        sample_heat_map_max = min_max_norm(sample_heat_map_max)
        sample_heat_map_max_overlaid = overlay_heatmap(
            sample_npy, sample_heat_map_max, heatmap_weight=.4)

        bboxes_sample = slices_to_bboxes(rf_feat.as_slices(all_channels=True))

        # For class-specific prototypes, look up the class of the prototype.
        if class_specific:
            proto_class = torch.argmax(
                prototype_class_identity[j_adjust]).item()
            if metadata.label_names is not None:
                proto_class = metadata.label_names[proto_class]
            proj_extra = f'\n{proto_class}'
        else:
            proj_extra = ''

        contrib_str = f' = {contributions[j_adjust]:.3g}'
        # Printed once per row (i.e. per displayed prototype).
        print(f'Predicted: {pred_lab}  |  Actual: {y_lab}')
        print(proj_extra)
        # Each entry is (content, column title, bounding boxes to outline).
        for ax, (patch, title, bboxes) in zip(axes_j, (
                # The test sample (full w/ patch bbox)
                (rescale_image(sample_npy.transpose(1, 2, 0),
                               metadata.input_size), 'Sample', bboxes_sample),
                # Pushed/projected training sample (patch)
                (projected_patch, f'Prototype', None),
                # The test sample (patch)
                (rescale_image(sample_patch.transpose(1, 2, 0),
                               metadata.input_size),
                 f'Corresponding\nImage Patch', None),
                (rescale_image(sample_heat_map_max_overlaid,
                               metadata.input_size),
                 'Overlaid Heat Map', bboxes_sample),
                (contrib_str, 'Contribution', None),
        )):
            if isinstance(patch, str):
                # Text cell: center the string in a fixed coordinate window.
                ax.text(0, 0, patch, va='center', ha='center')
                ax.set_xlim(-3, 3)
                ax.set_ylim(-3, 3)
            elif patch is not None:
                # Image cell; single-channel images get the 'jet' colormap.
                kwargs = {}
                if patch.shape[2] == 1:
                    kwargs['cmap'] = 'jet'
                ax.imshow(patch, **kwargs)
            if bboxes is not None:
                for bbox in bboxes:
                    # Scale the box the same way as the displayed image.
                    bbox = rescale_image(bbox, metadata.input_size)
                    rect = patches.Rectangle(*bbox, linewidth=2,
                                             edgecolor='r',
                                             facecolor='none')
                    ax.add_patch(rect)
            # Column titles only on the first row.
            if j == 0:
                ax.set_title(title)
            ax.axis('off')
            ax.set_xticks([])
            ax.set_yticks([])

    fig.tight_layout()
    # Only consumed by the commented-out savefig call below. `k` can no
    # longer be None at this point because it was replaced at the top.
    topk = '' if k is None else f'_top{k}'
    sort_str = ('_' + sort) if sort else ''
    # Uncomment to save figure
    # savefig(logdir, f'explanation{sort_str}_{sample_id}{topk}', fig, show=plot,
    #         save=save)

# To Run
Set `logdir` to the path of your log directory (absolute, or relative to this notebook). The template below shows the expected path format.

In [None]:
# Placeholder path: point this at the log directory of a trained model.
logdir = '/path/to/logs/protonet/dataset/protonet/timestamp'

# Load the run's config and model, switch to eval mode and move the model to
# the GPU when one is available.
config, model = load_config_and_best_model(logdir)
model.eval()
if torch.cuda.is_available():
    model.cuda()

# Config overrides applied before the datasets are built.
# NOTE(review): presumably val_size = 0 disables the validation split --
# confirm against get_datasets.
config.dataset.val_size = 0
config.debug = False
# get_datasets returns four values; the second and fourth are used here as
# the train and test sets. Unnormalized data is used for display, normalized
# data is fed to the model.
print('get unnormalized datasets')
_, ds_train, _, ds_test = get_datasets(config, normalize=False)
print('get normalized datasets')
_, _, _, ds_test_norm = get_datasets(config, normalize=True)

metadata = get_metadata(config)

# Receptive-field data for the feature extractor at the dataset's input size,
# keyed by module name.
print('compute rf data')
_, rf_data = compute_rf_data(config.model.feature_extractor,
                             metadata.input_size,
                             metadata.input_size,
                             num_classes=1)
# Sanity check: the model's last feature module must match the configured one.
assert model.model.features.last_module_name == config.model.feature_layer
rf_layer = rf_data[model.model.features.last_module_name]
# Report mean/max/min length of the entries in rf_layer.flat as a percentage
# of input_size squared.
rf_hcc_lens = [len(hcc) for hcc in rf_layer.flat]
im_size = metadata.input_size * metadata.input_size
print(f'mean/max/min rf: {100 * np.mean(rf_hcc_lens) / im_size:.2f}% / '
      f'{100 * np.max(rf_hcc_lens) / im_size:.2f}% / '
      f'{100 * np.min(rf_hcc_lens) / im_size:.2f}%')

# Training samples and patches behind each prototype; `projected_patches` is
# later passed to `explain`.
print('project prototypical samples')
with torch.no_grad():
    (projected_samples, projected_patches,
     projected_patch_slices) = projected_prototypical_samples(
        metadata, model, ds_train, rf_layer)

In [None]:
# Forwarded to `explain`.
plot = True
save = True

# Explain ten randomly chosen test samples.
idxs_to_explain = torch.randperm(len(ds_test))[:10]

for sort in ['contributions']:
    # Iterate over plain Python ints so the log line and the sample id read
    # `5` instead of `tensor(5)`.
    for idx in idxs_to_explain.tolist():
        print(f'Explain by {sort} for idx={idx}')

        # Fetch the unnormalized item once and reuse it for sample and label.
        test_item = ds_test[idx]

        with torch.no_grad():
            explain(
                logdir, metadata, model, ds_test_norm[idx][0],
                test_item[0], test_item[1], rf_layer,
                projected_patches, f'{idx}', sort=sort, plot=plot,
                save=save, k=4,
            )