In [35]:
import os
import sys
import warnings
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.cluster import KMeans

warnings.filterwarnings("ignore")

# Make modules in the parent directory importable. This must go through
# `sys.path` — the `sys` module itself has no `append` attribute.
# NOTE(review): presumably get_mlp_mappings / cluster_utils live in "../".
sys.path.append("../")

from utils import data_process

# NOTE(review): star import — presumably the source of load_cfg, load_model,
# get_loader, convert_annotations_to_numpy and add_other_annotation, which are
# used below but not defined in this notebook. Prefer explicit names.
from analysis_utils import *

from get_mlp_mappings import ComputeMLPContributions

from cluster_utils.cluster_utils import get_gabor_label_map

In [None]:
dataset_names = ["cityscapes", "vipseg"]
vidnames = {"cityscapes": ["0005"], "vipseg": ["26_cblDl5vCZnw"]}

# Every lookup below is nested as <dict>[dataset_name][vidname].
cfg_dict = {}
dataloader_dict = {}
weights_dict = {}
ffn_models_dict = {}
categories_dicts = {}

# Pass 1: checkpoint path -> (config, category metadata) -> model, per video.
for dataset_name in dataset_names:
    for per_dataset_store in (weights_dict, cfg_dict, ffn_models_dict, categories_dicts):
        per_dataset_store[dataset_name] = {}

    for vidname in vidnames[dataset_name]:
        # TODO: replace the placeholder with the real checkpoint of this video.
        weights_dict[dataset_name][vidname] = "path/to/checkpoint"

        cfg, categories_dict = load_cfg(
            weights_dict[dataset_name][vidname], dataset_name, vidname
        )
        cfg_dict[dataset_name][vidname] = cfg
        categories_dicts[dataset_name][vidname] = categories_dict
        ffn_models_dict[dataset_name][vidname] = load_model(cfg)

# Pass 2: one single-image dataloader per video, built from its config.
for dataset_name in dataset_names:
    dataloader_dict[dataset_name] = {}
    for vidname in vidnames[dataset_name]:
        single_image_dataloader = get_loader(
            cfg_dict[dataset_name][vidname], dataset_name
        )
        dataloader_dict[dataset_name][vidname] = single_image_dataloader

In [40]:
def get_instance_info(inference_results, object_categories, categories):
    """Assign readable, stable names to the instances annotated in one frame.

    Every distinct ``inst_id`` is named ``"<category>_<k>"``, where ``k`` is
    the number of instances of that category already seen (annotation order).

    Args:
        inference_results: dict whose ``"annotations"`` entry is a list of
            annotation dicts with ``"category_id"``, ``"inst_id"`` and ``"id"``.
        object_categories: category names; their positions define
            ``obj_to_obj_name_idx``.
        categories: list of category dicts with ``"id"`` and ``"name"``.

    Returns:
        6-tuple ``(inst_id_to_cat_and_inst_suffix, instance_to_ann_id_map,
        instance_names, object_to_instances_map, obj_to_obj_name_idx,
        instance_names)``:

        - ``inst_id_to_cat_and_inst_suffix``: inst_id -> {"category",
          "inst_suffix", "instance_name"}.
        - ``instance_to_ann_id_map``: instance name -> id of the last
          annotation seen for that instance.
        - ``instance_names``: sorted by (category, index), with ``"other_0"``
          (when present) moved to the end. It is returned twice (same list)
          for compatibility with existing callers.
        - ``object_to_instances_map``: defaultdict category -> instance names.
        - ``obj_to_obj_name_idx``: category name -> position in
          ``object_categories``.
    """
    inst_id_to_cat_and_inst_suffix = {}
    object_to_instances_map = defaultdict(list)
    instance_to_ann_id_map = {}
    instance_names = []
    seen_instance_names = set()

    obj_to_obj_name_idx = {
        object_cat: idx for idx, object_cat in enumerate(object_categories)
    }

    # First matching entry wins, as with a linear scan of `categories`.
    cat_id_to_name = {}
    for cat in categories:
        cat_id_to_name.setdefault(cat["id"], cat["name"])

    for ann in inference_results["annotations"]:
        category_name = cat_id_to_name[ann["category_id"]]
        inst_id = ann["inst_id"]

        num_instances_of_obj = len(object_to_instances_map[category_name])
        if inst_id not in inst_id_to_cat_and_inst_suffix:
            inst_id_to_cat_and_inst_suffix[inst_id] = {
                "category": category_name,
                "inst_suffix": num_instances_of_obj,
                "instance_name": category_name + "_" + str(num_instances_of_obj),
            }

        # Repeated annotations of an instance reuse the name from first sight.
        instance_name = inst_id_to_cat_and_inst_suffix[inst_id]["instance_name"]
        instance_to_ann_id_map[instance_name] = ann["id"]

        if instance_name not in seen_instance_names:
            seen_instance_names.add(instance_name)
            object_to_instances_map[category_name].append(instance_name)
            instance_names.append(instance_name)

    def custom_sort_key(item):
        # "<category>_<k>" -> ("<category>", k); categories may contain "_".
        prefix, _, suffix = item.rpartition("_")
        return (prefix, int(suffix))

    instance_names = sorted(instance_names, key=custom_sort_key)
    # The catch-all "other" instance is listed last; frames without one are
    # left as-is instead of raising ValueError.
    if "other_0" in instance_names:
        instance_names.remove("other_0")
        instance_names.append("other_0")

    return (
        inst_id_to_cat_and_inst_suffix,
        instance_to_ann_id_map,
        instance_names,
        object_to_instances_map,
        obj_to_obj_name_idx,
        instance_names,
    )

In [41]:
def get_instance_contribs(
    layer_1_output_contrib,
    layer_2_output_contrib,
    layer_3_output_contrib,
    annotations,
    instance_to_ann_id_map,
    instance_names,
):
    """Per-instance contribution statistics for the neurons of three MLP layers.

    Args:
        layer_{1,2,3}_output_contrib: per-neuron contribution maps indexed as
            ``[neuron, H, W]`` (image-wide totals sum over dims 1 and 2).
        annotations: list of annotation dicts with at least ``"id"``,
            ``"area"`` and ``"bimask"`` (a boolean mask over the ``H x W``
            map after ``squeeze()``).
        instance_to_ann_id_map: instance name -> annotation id.
        instance_names: ordered instance names; defines the column order.

    Returns:
        7-tuple:
        - layer_1/2/3 mean |contribution| per pixel of each instance
          (``num_neurons x num_instances``; ``area`` is the annotated area),
        - layer_1/2/3 ``|actual - expected| / expected``, where ``actual`` is
          the neuron's total |contribution| inside the instance and
          ``expected`` its image-wide total scaled by ``area / (H * W)``,
        - ``instance_areas`` (``num_instances``): annotated area per instance.
        Instances absent from ``instance_to_ann_id_map`` keep zero columns.
        Dead neurons (zero totals) yield NaN ratios.
    """
    output_contribs = (
        layer_1_output_contrib,
        layer_2_output_contrib,
        layer_3_output_contrib,
    )
    total_img_area = layer_1_output_contrib.size(-2) * layer_1_output_contrib.size(-1)
    num_instances = len(instance_names)
    instance_areas = torch.zeros(num_instances)

    abs_contribs = [torch.abs(contrib) for contrib in output_contribs]
    # Total |contribution| of every neuron over the whole image.
    image_totals = [torch.sum(abs_c, dim=(1, 2)) for abs_c in abs_contribs]
    avg_contribs = [torch.zeros((c.shape[0], num_instances)) for c in output_contribs]
    ratios = [torch.zeros((c.shape[0], num_instances)) for c in output_contribs]

    # First annotation per id wins, matching a linear search.
    ann_by_id = {}
    for ann in annotations:
        ann_by_id.setdefault(ann["id"], ann)

    for instance, ann_id in instance_to_ann_id_map.items():
        ann = ann_by_id[ann_id]
        area = ann["area"]
        bimask = ann["bimask"].squeeze()
        inst_idx = instance_names.index(instance)
        instance_areas[inst_idx] = float(area)
        area_fraction = area / total_img_area

        for abs_c, image_total, avg_out, ratio_out in zip(
            abs_contribs, image_totals, avg_contribs, ratios
        ):
            inst_total = torch.sum(abs_c[:, bimask], dim=-1)
            avg_out[:, inst_idx] = (inst_total / area).flatten()
            # Deviation from a spatially uniform share of the neuron's output.
            expected = image_total * area_fraction
            ratio_out[:, inst_idx] = torch.abs(inst_total - expected) / expected

    return (
        avg_contribs[0],
        avg_contribs[1],
        avg_contribs[2],
        ratios[0],
        ratios[1],
        ratios[2],
        instance_areas,
    )

In [42]:
def get_gridcell_contribs(
    layer_1_output_contrib,
    layer_2_output_contrib,
    layer_3_output_contrib,
    reg_stride_h,
    reg_stride_w,
):
    """Per-grid-cell contribution statistics for the neurons of three MLP layers.

    The map is split into a ``reg_stride_h x reg_stride_w`` grid of contiguous
    cells, each ``H // reg_stride_h`` by ``W // reg_stride_w`` pixels. Trailing
    rows/columns that do not fill a whole cell are left out of the per-cell
    sums but still count towards each neuron's image-wide total.

    Args:
        layer_{1,2,3}_output_contrib: per-neuron contribution maps indexed as
            ``[neuron, H, W]``.
        reg_stride_h, reg_stride_w: number of grid cells along H and W.

    Returns:
        6-tuple ``(layer_1/2/3 feature vectors, layer_1/2/3 deltas)``, each
        ``num_neurons x num_gridcells`` with cells in row-major order:
        - feature vectors: mean |contribution| per pixel of each cell,
        - deltas: ``(actual - expected) / expected``, where ``expected`` is the
          neuron's image-wide |contribution| scaled by ``cell_area / (H * W)``.

    Raises:
        ValueError: if the grid has more cells than pixels along an axis.
    """
    feature_vectors = []
    deltas = []
    for output_contrib in (
        layer_1_output_contrib,
        layer_2_output_contrib,
        layer_3_output_contrib,
    ):
        abs_contrib = torch.abs(output_contrib)
        num_neurons, height, width = abs_contrib.shape
        cell_h = height // reg_stride_h
        cell_w = width // reg_stride_w
        if cell_h == 0 or cell_w == 0:
            raise ValueError(
                f"A {reg_stride_h}x{reg_stride_w} grid does not fit a "
                f"{height}x{width} contribution map"
            )
        cell_area = cell_h * cell_w
        img_area = height * width

        # Contiguous cells: split H into (grid row, row-in-cell) and W into
        # (grid col, col-in-cell), then bring the two grid axes together.
        # Unfolding with window == step == reg_stride and moving the window
        # axes first would instead group pixels by their offset inside each
        # window, i.e. interleaved pixels scattered over the whole image.
        cells = (
            abs_contrib[:, : reg_stride_h * cell_h, : reg_stride_w * cell_w]
            .reshape(num_neurons, reg_stride_h, cell_h, reg_stride_w, cell_w)
            .permute(0, 1, 3, 2, 4)
            .reshape(num_neurons, reg_stride_h * reg_stride_w, cell_area)
        )  # num_neurons x num_gridcells x pixels_per_cell
        cell_totals = cells.sum(dim=-1)

        expected = abs_contrib.sum(dim=(1, 2))[:, None] * (cell_area / img_area)
        # NOTE(review): this delta is signed, whereas the instance / RGB /
        # Gabor variants in this notebook use abs(actual - expected). Confirm
        # which form the downstream variance comparison intends.
        deltas.append((cell_totals - expected) / expected)
        feature_vectors.append(cell_totals / cell_area)

    return (*feature_vectors, *deltas)

In [43]:
def compute_kmeans_clusters_in_rgb(image, num_clusters):
    """Segment an image by k-means clustering of its pixel colour values.

    Args:
        image: array whose values are grouped in triples by ``reshape(-1, 3)``
            (assumed ``H x W x 3`` — TODO confirm for non-RGB inputs).
        num_clusters: number of k-means clusters.

    Returns:
        ``image.shape[0] x image.shape[1]`` array of cluster labels.
    """
    height, width = image.shape[0], image.shape[1]
    pixels = image.reshape(-1, 3)

    # A single initialisation with a fixed seed keeps repeated runs identical.
    clusterer = KMeans(n_clusters=num_clusters, n_init=1, random_state=0)
    clusterer.fit(pixels)

    return clusterer.labels_.reshape(height, width)

In [44]:
def get_rgb_cluster_contribs(
    layer_1_output_contrib,
    layer_2_output_contrib,
    layer_3_output_contrib,
    rgb_cluster_map,
):
    """Per-RGB-cluster contribution statistics for the neurons of three MLP layers.

    Args:
        layer_{1,2,3}_output_contrib: per-neuron contribution maps indexed as
            ``[neuron, H, W]``.
        rgb_cluster_map: numpy array of cluster labels over the ``H x W`` map.

    Returns:
        7-tuple:
        - layer_1/2/3 mean |contribution| per pixel of each cluster
          (``num_neurons x n_rgb_clusters``),
        - layer_1/2/3 ``|actual - expected| / expected``, where ``expected`` is
          the neuron's image-wide |contribution| scaled by the cluster's share
          of the image area,
        - ``rgb_cluster_areas``: pixel count of each cluster.
        Column ``j`` corresponds to the ``j``-th label of
        ``np.unique(rgb_cluster_map)``; labels need not be ``0..n-1``.
        Dead neurons (zero totals) yield NaN ratios.
    """
    output_contribs = (
        layer_1_output_contrib,
        layer_2_output_contrib,
        layer_3_output_contrib,
    )
    total_img_area = layer_1_output_contrib.size(-2) * layer_1_output_contrib.size(-1)
    cluster_ids = np.unique(rgb_cluster_map)
    n_rgb_clusters = len(cluster_ids)
    rgb_cluster_areas = torch.zeros(n_rgb_clusters)

    abs_contribs = [torch.abs(contrib) for contrib in output_contribs]
    # Total |contribution| of every neuron over the whole image.
    image_totals = [torch.sum(abs_c, dim=(1, 2)) for abs_c in abs_contribs]
    avg_contribs = [torch.zeros((c.shape[0], n_rgb_clusters)) for c in output_contribs]
    ratios = [torch.zeros((c.shape[0], n_rgb_clusters)) for c in output_contribs]

    # Columns follow the position of each label in `cluster_ids`, so
    # non-contiguous labels cannot index past the end of the outputs.
    for col, cluster_id in enumerate(cluster_ids):
        bimask = (rgb_cluster_map == cluster_id).squeeze().astype(bool)
        area = int(bimask.sum())
        rgb_cluster_areas[col] = area
        mask = torch.from_numpy(bimask)
        area_fraction = area / total_img_area

        for abs_c, image_total, avg_out, ratio_out in zip(
            abs_contribs, image_totals, avg_contribs, ratios
        ):
            cluster_total = torch.sum(abs_c[:, mask.to(abs_c.device)], dim=-1)
            avg_out[:, col] = (cluster_total / area).flatten()
            # Deviation from a spatially uniform share of the neuron's output.
            expected = image_total * area_fraction
            ratio_out[:, col] = torch.abs(cluster_total - expected) / expected

    return (
        avg_contribs[0],
        avg_contribs[1],
        avg_contribs[2],
        ratios[0],
        ratios[1],
        ratios[2],
        rgb_cluster_areas,
    )

In [46]:
# For each Gabor cluster: mean per-pixel contribution, deviation ratio and area.
def get_gabor_cluster_contribs(
    layer_1_output_contrib,
    layer_2_output_contrib,
    layer_3_output_contrib,
    gabor_cluster_map,
):
    """Per-Gabor-cluster contribution statistics for the neurons of three MLP layers.

    Args:
        layer_{1,2,3}_output_contrib: per-neuron contribution maps indexed as
            ``[neuron, H, W]``.
        gabor_cluster_map: numpy array of cluster labels over the ``H x W``
            map.

    Returns:
        7-tuple:
        - layer_1/2/3 mean |contribution| per pixel of each cluster
          (``num_neurons x n_gabor_clusters``),
        - layer_1/2/3 ``|actual - expected| / expected``, where ``expected`` is
          the neuron's image-wide |contribution| scaled by the cluster's share
          of the image area,
        - ``gabor_cluster_areas``: pixel count of each cluster.
        Column ``j`` corresponds to the ``j``-th label of
        ``np.unique(gabor_cluster_map)``; labels need not be ``0..n-1``.
        Dead neurons (zero totals) yield NaN ratios.
    """
    output_contribs = (
        layer_1_output_contrib,
        layer_2_output_contrib,
        layer_3_output_contrib,
    )
    total_img_area = layer_1_output_contrib.size(-2) * layer_1_output_contrib.size(-1)
    cluster_ids = np.unique(gabor_cluster_map)
    n_gabor_clusters = len(cluster_ids)
    gabor_cluster_areas = torch.zeros(n_gabor_clusters)

    abs_contribs = [torch.abs(contrib) for contrib in output_contribs]
    # Total |contribution| of every neuron over the whole image.
    image_totals = [torch.sum(abs_c, dim=(1, 2)) for abs_c in abs_contribs]
    avg_contribs = [
        torch.zeros((c.shape[0], n_gabor_clusters)) for c in output_contribs
    ]
    ratios = [torch.zeros((c.shape[0], n_gabor_clusters)) for c in output_contribs]

    # Columns follow the position of each label in `cluster_ids`, so
    # non-contiguous labels cannot index past the end of the outputs.
    for col, cluster_id in enumerate(cluster_ids):
        bimask = (gabor_cluster_map == cluster_id).squeeze().astype(bool)
        area = int(bimask.sum())
        gabor_cluster_areas[col] = area
        mask = torch.from_numpy(bimask)
        area_fraction = area / total_img_area

        for abs_c, image_total, avg_out, ratio_out in zip(
            abs_contribs, image_totals, avg_contribs, ratios
        ):
            cluster_total = torch.sum(abs_c[:, mask.to(abs_c.device)], dim=-1)
            avg_out[:, col] = (cluster_total / area).flatten()
            # Deviation from a spatially uniform share of the neuron's output.
            expected = image_total * area_fraction
            ratio_out[:, col] = torch.abs(cluster_total - expected) / expected

    return (
        avg_contribs[0],
        avg_contribs[1],
        avg_contribs[2],
        ratios[0],
        ratios[1],
        ratios[2],
        gabor_cluster_areas,
    )

In [47]:
def compute_inference_results(
    single_image_dataloader,
    ffn_model,
    cfg,
    categories_dict,
    num_rgb_clusters,
    num_gabor_clusters,
):
    """Run the model on the first batch of a loader and gather per-frame data.

    Returns:
        ``(inference_results, categories_in_frame, object_categories)``:
        - ``inference_results``: dict with the raw input ``"data"``, the
          prediction reshaped by ``proc.process_outputs`` (``"pred"``), the
          annotations after ``convert_annotations_to_numpy`` and
          ``add_other_annotation``, ``"img_hw"``, the model's
          ``"intermediate_results"``, and the RGB k-means and Gabor label maps.
        - ``categories_in_frame``: ``categories_dict`` entries for every
          category id in the annotations (first-seen order), plus entry ``-1``.
        - ``object_categories``: the ``"name"`` of each of those entries.
    """
    with torch.no_grad():
        batch = next(iter(single_image_dataloader))

    data = batch["data"].cuda()
    _, _, H, W = data.shape
    annotations = add_other_annotation(
        convert_annotations_to_numpy(batch["annotations"])
    )

    features = batch["features"].squeeze().cuda()  # NOTE(review): unused below
    features_shape = batch["features_shape"].squeeze().tolist()

    proc = data_process.DataProcessor(cfg.data, device="cpu")
    coords = (
        proc.get_coordinates(
            data_shape=features_shape,
            patch_shape=cfg.data.patch_shape,
            split=cfg.data.coord_split,
            normalize_range=cfg.data.coord_normalize_range,
        )
        .to(batch["data"])
        .cuda()
    )

    with torch.no_grad():
        model_out = ffn_model(coords, img=data)
        intermediate_results = model_out["intermediate_results"]
        # Turn the model's flat prediction back into an image.
        pred = proc.process_outputs(
            model_out["predicted"],
            input_img_shape=batch["data_shape"].squeeze().tolist(),
            features_shape=features_shape,
            patch_shape=cfg.data.patch_shape,
        )

    first_image = data[0]
    rgb_cluster_map = compute_kmeans_clusters_in_rgb(
        first_image.permute(1, 2, 0).cpu().numpy(), num_rgb_clusters
    )

    # Gabor clustering works on an 8-bit HWC image.
    image_uint8 = (
        (first_image.clamp(0, 1) * 255).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    )
    gabor_cluster_map = get_gabor_label_map(image_uint8, num_gabor_clusters)

    inference_results = {
        "data": batch["data"],
        "pred": pred,
        "annotations": annotations,
        "img_hw": (H, W),
        "intermediate_results": intermediate_results,
        "rgb_cluster_map": rgb_cluster_map,
        "gabor_cluster_map": gabor_cluster_map,
    }

    frame_categories = {}
    for ann in annotations:
        category_id = ann["category_id"]
        if category_id not in frame_categories:
            frame_categories[category_id] = categories_dict[category_id]
    frame_categories[-1] = categories_dict[-1]

    categories_in_frame = list(frame_categories.values())
    object_categories = [category["name"] for category in categories_in_frame]

    return inference_results, categories_in_frame, object_categories

In [48]:
def compute_all_variables_for_image(
    inference_results,
    ffn_model,
    instance_to_ann_id_map,
    cell_stride_h,
    cell_stride_w,
    instance_names,
):
    """Compute neuron contribution maps and per-region deviation ratios.

    The per-layer contribution maps come from
    ``ComputeMLPContributions.compute_all_layer_mappings``. They are then
    summarised over four kinds of region: annotated instances, a regular
    grid, RGB k-means clusters and Gabor clusters.

    Args:
        inference_results: dict with ``"img_hw"``, ``"annotations"``,
            ``"intermediate_results"``, ``"rgb_cluster_map"`` and
            ``"gabor_cluster_map"`` (as built by ``compute_inference_results``).
        ffn_model: model whose neuron contributions are analysed.
        instance_to_ann_id_map, instance_names: instance naming, as returned
            by ``get_instance_info``.
        cell_stride_h, cell_stride_w: grid parameters forwarded to
            ``get_gridcell_contribs``.

    Returns:
        dict with the absolute per-layer contribution maps, the instance / RGB
        / Gabor region areas, the
        ``layer_{1,2,3}_{instance,gridcell,rgb_cluster,gabor_cluster}_contrib_ratio_to_total``
        tensors and ``num_instances_in_frame``.
    """
    height, width = inference_results["img_hw"]

    mlp_contribs = ComputeMLPContributions(
        ffn_model, inference_results["intermediate_results"], (height, width)
    )
    (
        layer_1_output_contrib,
        layer_2_output_contrib,
        layer_3_output_contrib,
        _,
        _,
        _,
    ) = mlp_contribs.compute_all_layer_mappings()
    layer_contribs = (
        layer_1_output_contrib,
        layer_2_output_contrib,
        layer_3_output_contrib,
    )

    # Only the deviation ratios and areas are kept from each summary.
    *_, inst_ratio_1, inst_ratio_2, inst_ratio_3, instance_areas = get_instance_contribs(
        *layer_contribs,
        inference_results["annotations"],
        instance_to_ann_id_map,
        instance_names,
    )
    *_, grid_ratio_1, grid_ratio_2, grid_ratio_3 = get_gridcell_contribs(
        *layer_contribs, cell_stride_h, cell_stride_w
    )
    *_, rgb_ratio_1, rgb_ratio_2, rgb_ratio_3, rgb_cluster_areas = get_rgb_cluster_contribs(
        *layer_contribs, inference_results["rgb_cluster_map"]
    )
    *_, gabor_ratio_1, gabor_ratio_2, gabor_ratio_3, gabor_cluster_areas = (
        get_gabor_cluster_contribs(
            *layer_contribs, inference_results["gabor_cluster_map"]
        )
    )

    # Some of the neurons in MLP are dead (all 0 contributions). These are removed in normalization
    return {
        "layer_1_output_contrib": torch.abs(layer_1_output_contrib),
        "layer_2_output_contrib": torch.abs(layer_2_output_contrib),
        "layer_3_output_contrib": torch.abs(layer_3_output_contrib),
        # areas
        "instance_areas": instance_areas,
        "rgb_cluster_areas": rgb_cluster_areas,
        "gabor_cluster_areas": gabor_cluster_areas,
        # per-region contribution ratios
        "layer_3_instance_contrib_ratio_to_total": inst_ratio_3,
        "layer_2_instance_contrib_ratio_to_total": inst_ratio_2,
        "layer_1_instance_contrib_ratio_to_total": inst_ratio_1,
        "layer_3_gridcell_contrib_ratio_to_total": grid_ratio_3,
        "layer_2_gridcell_contrib_ratio_to_total": grid_ratio_2,
        "layer_1_gridcell_contrib_ratio_to_total": grid_ratio_1,
        "layer_3_rgb_cluster_contrib_ratio_to_total": rgb_ratio_3,
        "layer_2_rgb_cluster_contrib_ratio_to_total": rgb_ratio_2,
        "layer_1_rgb_cluster_contrib_ratio_to_total": rgb_ratio_1,
        "layer_3_gabor_cluster_contrib_ratio_to_total": gabor_ratio_3,
        "layer_2_gabor_cluster_contrib_ratio_to_total": gabor_ratio_2,
        "layer_1_gabor_cluster_contrib_ratio_to_total": gabor_ratio_1,
        "num_instances_in_frame": len(instance_areas),
    }

In [58]:
def compute_variance_of_deltas(all_variables_for_image, num_rgb_clusters, cell_stride_h, cell_stride_w):
    """Sort and plot the variance of per-patch contribution ratios for each MLP layer.

    For every layer (1, 2, 3) and every patch grouping (instances, grid cells,
    RGB clusters, Gabor clusters) this reads
    ``all_variables_for_image[f"layer_{L}_{group}_contrib_ratio_to_total"]``,
    computes ``torch.var`` over its last dimension and sorts the resulting
    variances in ascending order. The sorted curves are drawn in a 1x3 figure
    (Layer 3, Layer 2, Layer 1 from left to right) with one shared legend.

    Note: despite the "deltas" naming of the function and of the returned keys,
    the stored values are *variances* of the contribution ratios.

    Args:
        all_variables_for_image: dict as returned by
            ``compute_all_variables_for_image``. It must contain the twelve
            ``layer_{1,2,3}_{instance,gridcell,rgb_cluster,gabor_cluster}_contrib_ratio_to_total``
            tensors and ``num_instances_in_frame``. The shape of the ratio
            tensors is not visible here; presumably (num_neurons, num_patches),
            so the variance is taken per neuron across patches -- TODO confirm.
        num_rgb_clusters: number of RGB clusters; only used in the figure title.
        cell_stride_h: grid height; only used in the figure title, where
            ``cell_stride_h * cell_stride_w`` is reported as the number of cells.
        cell_stride_w: grid width; only used in the figure title (see above).

    Returns:
        dict with
          * ``"layer_1"``, ``"layer_2"``, ``"layer_3"`` -> dict mapping
            ``"instances_deltas"``, ``"gridcells_deltas"``,
            ``"rgb_clusters_deltas"``, ``"gabor_clusters_deltas"`` to the
            sorted (ascending) variance tensors;
          * ``"sorted_indices"`` -> ``{"layer_N": {"instances", "gridcells",
            "rgb_clusters", "gabor_clusters"}}`` holding the permutation
            returned by ``torch.sort`` for each sorted tensor.
    """
    # (key fragment in all_variables_for_image, key fragment in the result,
    #  legend label, line color)
    patch_groups = (
        ("instance", "instances", "Instances variance", "r"),
        ("gridcell", "gridcells", "Grid cells variance", "g"),
        ("rgb_cluster", "rgb_clusters", "RGB Cluster variance", "b"),
        ("gabor_cluster", "gabor_clusters", "Gabor Cluster variance", "m"),
    )

    num_instances_in_frame = all_variables_for_image["num_instances_in_frame"]

    # Variance over the last dimension, then sort ascending. The permutation is
    # kept so callers can map sorted positions back to the original rows.
    sorted_deltas_dict = {}
    sorted_indices_dict = {}
    for layer in (1, 2, 3):
        layer_key = f"layer_{layer}"
        sorted_deltas_dict[layer_key] = {}
        sorted_indices_dict[layer_key] = {}
        for src_name, out_name, _, _ in patch_groups:
            ratios = all_variables_for_image[f"{layer_key}_{src_name}_contrib_ratio_to_total"]
            sorted_variances, sorted_indices = torch.sort(torch.var(ratios, dim=-1))
            sorted_deltas_dict[layer_key][f"{out_name}_deltas"] = sorted_variances
            sorted_indices_dict[layer_key][out_name] = sorted_indices
    sorted_deltas_dict["sorted_indices"] = sorted_indices_dict

    # Panels are ordered Layer 3, Layer 2, Layer 1 from left to right.
    plot_layers = (3, 2, 1)
    fig, axs = plt.subplots(1, len(plot_layers), figsize=(12, 8), tight_layout=True)
    for ax, layer in zip(axs, plot_layers):
        layer_deltas = sorted_deltas_dict[f"layer_{layer}"]
        for _, out_name, label, color in patch_groups:
            ax.plot(layer_deltas[f"{out_name}_deltas"], label=label, c=color)
        ax.set_title(f"Layer {layer}")
        ax.set_xlabel("Rank (ascending variance)")
        ax.set_ylabel("Variance of contribution ratio")

    fig.suptitle(
        f"num_inst={num_instances_in_frame}, num_rgb_clust={num_rgb_clusters}, num_cells={cell_stride_h*cell_stride_w}",
        fontweight="bold",
    )

    # Every panel plots the same series, so one figure-level legend suffices;
    # size its columns to the number of entries so it stays on a single row.
    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper center", ncol=len(handles), bbox_to_anchor=(0.5, 1.05))

    return sorted_deltas_dict

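Analyze each image and save the raw values for downstream visualization. For every MLP layer and patch grouping (instances, grid cells, RGB clusters, Gabor clusters), compute the variance of the per-patch contribution ratios and sort it. Store the sorted values, their sort indices and the cluster settings per video.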

In [None]:
per_vid_patch_deltas_var_dict = {}

# Cluster settings, keyed by video name. RGB and Gabor clustering use the same
# number of clusters for a given video.
num_rgb_and_gabor_clusters_dict = {"0005": 32, "26_cblDl5vCZnw": 24}
cell_stride_h_dict = {"0005": 4, "26_cblDl5vCZnw": 4}
cell_stride_w_dict = {"0005": 8, "26_cblDl5vCZnw": 6}

# Fail fast, before running any inference, if a configured video has no cluster settings.
_videos_missing_settings = [
    vidname
    for dataset_name in dataset_names
    for vidname in vidnames[dataset_name]
    if not all(
        vidname in settings
        for settings in (num_rgb_and_gabor_clusters_dict, cell_stride_h_dict, cell_stride_w_dict)
    )
]
assert not _videos_missing_settings, f"Missing cluster settings for videos: {_videos_missing_settings}"


for dataset_name in dataset_names:
    for vidname in vidnames[dataset_name]:
        single_image_dataloader = dataloader_dict[dataset_name][vidname]
        ffn_model = ffn_models_dict[dataset_name][vidname]
        categories_dict = categories_dicts[dataset_name][vidname]
        cfg = cfg_dict[dataset_name][vidname]

        categories = list(categories_dict.values())

        num_rgb_clusters = num_rgb_and_gabor_clusters_dict[vidname]
        num_gabor_clusters = num_rgb_and_gabor_clusters_dict[vidname]
        cell_stride_h, cell_stride_w = (
            cell_stride_h_dict[vidname],
            cell_stride_w_dict[vidname],
        )

        inference_results, categories_in_frame, object_categories = (
            compute_inference_results(
                single_image_dataloader,
                ffn_model,
                cfg,
                categories_dict,
                num_rgb_clusters,
                num_gabor_clusters,
            )
        )
        # NOTE(review): the third return value was previously also unpacked into
        # `instance_names` and then overwritten by the last element. Binding it
        # to an explicit unused name keeps the same final `instance_names`;
        # confirm which of the two values is actually intended.
        (
            inst_id_to_cat_and_inst_suffix,
            instance_to_ann_id_map,
            _instance_names_unused,
            object_to_instances_map,
            obj_to_obj_name_idx,
            instance_names,
        ) = get_instance_info(inference_results, object_categories, categories)

        all_variables_for_image = compute_all_variables_for_image(
            inference_results,
            ffn_model,
            instance_to_ann_id_map,
            cell_stride_h,
            cell_stride_w,
            instance_names,
        )

        sorted_deltas_dict = compute_variance_of_deltas(all_variables_for_image, num_rgb_clusters, cell_stride_h, cell_stride_w)

        # NOTE(review): results are keyed by video name only; two datasets with an
        # identically named video would overwrite each other here.
        per_vid_patch_deltas_var_dict[vidname] = {
            "sorted_deltas_dict": sorted_deltas_dict,
            "cluster_info": {
                "num_instances": all_variables_for_image["num_instances_in_frame"],
                "num_rgb_clusters": num_rgb_clusters,
                "num_gabor_clusters": num_gabor_clusters,
                "cell_stride_h": cell_stride_h,
                "cell_stride_w": cell_stride_w,
            },
        }

In [None]:
import pickle

# Persist the per-video sorted variances, sort indices and cluster settings
# so downstream visualization can load them without re-running the analysis.
save_dir = "../analysis_data/MLP/contributions_to_objects"
os.makedirs(save_dir, exist_ok=True)

save_path = os.path.join(save_dir, "per_vid_patch_deltas_var_dict.pkl")
with open(save_path, "wb") as pickle_file:
    pickle.dump(per_vid_patch_deltas_var_dict, pickle_file)