In [27]:
import os
import sys
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
from utils import data_process

from analysis_utils import *

# Make modules in the parent directory importable (presumably where
# get_mlp_mappings, imported below, lives). `sys` has no `append`; the
# module search path list is `sys.path`.
sys.path.append("../")

from get_mlp_mappings import ComputeMLPContributions

In [None]:
# Multiple videos: dataset name -> video ids to analyse.
dataset_names = ["cityscapes", "vipseg"]
vidnames = {
    "cityscapes": ["0005", "0175"],
    "vipseg": ["12_n-ytHkMceew", "26_cblDl5vCZnw"],
}

# Nested lookups filled below, all indexed as [dataset_name][vidname].
cfg_dict, dataloader_dict, weights_dict = {}, {}, {}
ffn_models_dict, categories_dicts = {}, {}

# Pass 1: checkpoint path, config, category metadata and model for each video.
for dataset_name in dataset_names:
    for table in (weights_dict, cfg_dict, ffn_models_dict, categories_dicts):
        table[dataset_name] = {}

    for vidname in vidnames[dataset_name]:
        # NOTE(review): placeholder - every video shares this checkpoint path;
        # set the real per-video checkpoint before running.
        checkpoint_path = "path/to/checkpoint/"
        weights_dict[dataset_name][vidname] = checkpoint_path

        cfg, categories_dict = load_cfg(checkpoint_path, dataset_name, vidname)
        cfg_dict[dataset_name][vidname] = cfg
        categories_dicts[dataset_name][vidname] = categories_dict
        ffn_models_dict[dataset_name][vidname] = load_model(cfg)

# Pass 2: one dataloader per video, built from that video's config.
for dataset_name in dataset_names:
    dataloader_dict[dataset_name] = {
        vidname: get_loader(cfg_dict[dataset_name][vidname], dataset_name)
        for vidname in vidnames[dataset_name]
    }

In [32]:
@torch.no_grad()
def compute_inference_results(
    single_image_dataloader, ffn_model, cfg, categories_dict, reshape=True
):
    """Run ``ffn_model`` on the first batch of ``single_image_dataloader``.

    The whole function runs under ``torch.no_grad()``.

    Args:
        single_image_dataloader: iterable of batch dicts; only the first batch is
            used. Keys read: ``data``, ``annotations``, ``features_shape``,
            ``data_shape``.
        ffn_model: called as ``ffn_model(coords, img=data)``; its output must be a
            dict with ``predicted`` and ``intermediate_results``.
        cfg: config whose ``cfg.data`` fields drive coordinate generation and
            output reshaping.
        categories_dict: category id -> category dict with a ``name`` key. It must
            contain id ``-1``, which is always added to the frame's categories
            (presumably the "other" category from ``add_other_annotation`` - confirm).
        reshape: if True (the default, matching previous behaviour), map the raw
            prediction back to image layout with ``DataProcessor.process_outputs``.

    Returns:
        ``(inference_results, categories_in_frame, object_categories)``:
        ``inference_results`` holds ``data`` (the batch input as loaded), ``pred``,
        the processed ``annotations``, ``img_hw`` and the model's
        ``intermediate_results``. ``categories_in_frame`` lists the category dicts
        seen in the annotations (first-seen order, plus id -1), and
        ``object_categories`` are their names in the same order.
    """
    batch = next(iter(single_image_dataloader))

    data = batch["data"].cuda()
    # Unpack all four dims to keep the original's check that data is (N, C, H, W).
    _, _, H, W = data.shape
    annotations = add_other_annotation(
        convert_annotations_to_numpy(batch["annotations"])
    )

    features_shape = batch["features_shape"].squeeze().tolist()
    proc = data_process.DataProcessor(cfg.data, device="cpu")
    coords = proc.get_coordinates(
        data_shape=features_shape,
        patch_shape=cfg.data.patch_shape,
        split=cfg.data.coord_split,
        normalize_range=cfg.data.coord_normalize_range,
    )
    # Match the (CPU) input's dtype first, then move next to the GPU input.
    coords = coords.to(batch["data"]).cuda()

    out = ffn_model(coords, img=data)
    pred = out["predicted"]
    intermediate_results = out["intermediate_results"]

    if reshape:
        # Reshape the flat prediction back into an image.
        pred = proc.process_outputs(
            pred,
            input_img_shape=batch["data_shape"].squeeze().tolist(),
            features_shape=features_shape,
            patch_shape=cfg.data.patch_shape,
        )

    inference_results = {
        "data": batch["data"],
        "pred": pred,
        "annotations": annotations,
        "img_hw": (H, W),
        "intermediate_results": intermediate_results,
    }

    # Unique categories present in the frame, in first-seen order; id -1 is
    # always included (re-assigning keeps its position if it was already seen).
    categories_by_id = {}
    for ann in annotations:
        cat_id = ann["category_id"]
        if cat_id not in categories_by_id:
            categories_by_id[cat_id] = categories_dict[cat_id]
    categories_by_id[-1] = categories_dict[-1]

    categories_in_frame = list(categories_by_id.values())
    object_categories = [cat["name"] for cat in categories_in_frame]

    return inference_results, categories_in_frame, object_categories

In [34]:
def get_instance_info(inference_results, object_categories, categories):
    """Assign a per-frame instance name ``"<category>_<k>"`` to every annotated instance.

    Instances are numbered per category (k = 0, 1, ...) in the order their first
    annotation appears in ``inference_results["annotations"]``.

    Args:
        inference_results: dict whose ``annotations`` entry is a list of annotation
            dicts with ``id``, ``inst_id`` and ``category_id`` keys.
        object_categories: category names; list position becomes the value in
            ``obj_to_obj_name_idx``.
        categories: category dicts with ``id`` and ``name`` keys, used to resolve
            each annotation's ``category_id``. The first entry for an id wins.

    Returns:
        ``(inst_id_to_cat_and_inst_suffix, instance_to_ann_id_map, instance_names,
        object_to_instances_map, obj_to_obj_name_idx, instance_names)``:

        - ``inst_id_to_cat_and_inst_suffix``: inst_id -> ``{"category",
          "inst_suffix", "instance_name"}``.
        - ``instance_to_ann_id_map``: instance name -> id of the last annotation
          seen for that instance.
        - ``instance_names``: names sorted by (category, suffix), with ``"other_0"``
          (when present) moved to the end. The same list object is returned in
          positions 3 and 6 for backward compatibility.
        - ``object_to_instances_map``: category name -> instance names (defaultdict).
        - ``obj_to_obj_name_idx``: category name -> index in ``object_categories``.

    Raises:
        KeyError: if an annotation's ``category_id`` has no entry in ``categories``.
    """
    obj_to_obj_name_idx = {name: idx for idx, name in enumerate(object_categories)}

    # id -> name; setdefault keeps the first match, like a linear scan would.
    category_id_to_name = {}
    for cat in categories:
        category_id_to_name.setdefault(cat["id"], cat["name"])

    inst_id_to_cat_and_inst_suffix = {}
    instance_to_ann_id_map = {}
    object_to_instances_map = defaultdict(list)
    instance_names = []
    seen_instance_names = set()  # O(1) membership next to the ordered list

    for ann in inference_results["annotations"]:
        try:
            category_name = category_id_to_name[ann["category_id"]]
        except KeyError:
            raise KeyError(
                f"category_id {ann['category_id']!r} not found in categories"
            ) from None

        # Next free suffix for this category.
        num_instances_of_obj = len(object_to_instances_map[category_name])

        if ann["inst_id"] not in inst_id_to_cat_and_inst_suffix:
            inst_id_to_cat_and_inst_suffix[ann["inst_id"]] = {
                "category": category_name,
                "inst_suffix": num_instances_of_obj,
                "instance_name": category_name + "_" + str(num_instances_of_obj),
            }

        instance_name = inst_id_to_cat_and_inst_suffix[ann["inst_id"]]["instance_name"]
        instance_to_ann_id_map[instance_name] = ann["id"]
        if instance_name not in seen_instance_names:
            seen_instance_names.add(instance_name)
            object_to_instances_map[category_name].append(instance_name)
            instance_names.append(instance_name)

    def _category_then_suffix(name):
        # Split on the last "_" only: the category itself may contain "_".
        category, _, suffix = name.rpartition("_")
        return category, int(suffix)

    instance_names = sorted(instance_names, key=_category_then_suffix)
    # Keep the "other" instance last, but only when the frame actually has one.
    if "other_0" in seen_instance_names:
        instance_names.remove("other_0")
        instance_names.append("other_0")

    return (
        inst_id_to_cat_and_inst_suffix,
        instance_to_ann_id_map,
        instance_names,
        object_to_instances_map,
        obj_to_obj_name_idx,
        instance_names,
    )

In [35]:
def get_instance_contribs(
    layer_1_output_contrib,
    layer_2_output_contrib,
    layer_3_output_contrib,
    annotations,
    instance_to_ann_id_map,
    instance_names,
):
    """Average the absolute per-pixel kernel contributions over each instance's mask.

    Args:
        layer_1_output_contrib, layer_2_output_contrib, layer_3_output_contrib:
            per-layer contribution tensors indexed as ``contrib[:, binary_mask]``;
            dim 0 is the kernel/neuron axis and the remaining dims must match the
            squeezed annotation masks.
        annotations: annotation dicts with ``id``, ``area`` and ``binary_mask``.
        instance_to_ann_id_map: instance name -> annotation ``id`` to use.
        instance_names: ordered instance names; position gives the output column.

    Returns:
        Tuple of three tensors (created with ``torch.zeros``), one per layer, each
        shaped ``(num_kernels_in_layer, len(instance_names))``. Column ``j`` holds
        ``sum(|contrib| over mask) / area`` for ``instance_names[j]``; columns of
        instances absent from ``instance_to_ann_id_map`` stay zero.

    Raises:
        IndexError: if a mapped annotation id is not present in ``annotations``.
        ValueError: if a mapped instance name is not in ``instance_names``.
    """
    layer_outputs = (
        layer_1_output_contrib,
        layer_2_output_contrib,
        layer_3_output_contrib,
    )
    num_instances = len(instance_names)
    # One (kernels x instances) table per layer, filled column by column below.
    per_layer_tables = tuple(
        torch.zeros((output.shape[0], num_instances)) for output in layer_outputs
    )

    for instance, ann_id in instance_to_ann_id_map.items():
        matching = [a for a in annotations if a["id"] == ann_id]
        ann = matching[0]
        area = ann["area"]
        mask = ann["binary_mask"].squeeze()

        per_layer_means = [
            torch.sum(torch.abs(output[:, mask]), dim=-1) / area
            for output in layer_outputs
        ]

        column = instance_names.index(instance)
        for table, mean_abs_contrib in zip(per_layer_tables, per_layer_means):
            table[:, column] = mean_abs_contrib.flatten()

    return per_layer_tables

In [36]:
def get_normalized_contribs(
    layer_1_to_instance_contribs,
    layer_2_to_instance_contribs,
    layer_3_to_instance_contribs,
):
    """Drop all-zero kernel rows, then scale every remaining row to sum to 1.

    Each input is a ``(num_kernels, num_instances)`` matrix. Rows whose sum is
    exactly zero (e.g. dead neurons) are removed first, which also avoids dividing
    by zero. Each surviving row is then divided by its own sum, so the values in a
    row are normalized across instances.

    Returns:
        Tuple of the three normalized matrices in layer order; each may have fewer
        rows than its input. Inputs are not modified.
    """

    def _drop_zero_rows_and_normalize(contribs):
        row_totals = contribs.sum(dim=1)
        live_rows = contribs[row_totals != 0, :]
        return live_rows / live_rows.sum(dim=1)[:, None]

    return tuple(
        _drop_zero_rows_and_normalize(matrix)
        for matrix in (
            layer_1_to_instance_contribs,
            layer_2_to_instance_contribs,
            layer_3_to_instance_contribs,
        )
    )

In [38]:
def compute_all_variables_for_frame(
    inference_results, instance_to_ann_id_map, instance_names, ffn_model
):
    """Compute row-normalized per-instance MLP neuron contributions for one frame.

    Per-layer output contributions come from ``ComputeMLPContributions``. For each
    instance, ``get_instance_contribs`` sums their absolute values over the
    instance mask and divides by the annotation area. ``get_normalized_contribs``
    then drops all-zero (dead) neuron rows and normalizes each remaining row across
    instances.

    Args:
        inference_results: dict from ``compute_inference_results``; reads
            ``intermediate_results``, ``img_hw`` and ``annotations``.
        instance_to_ann_id_map: instance name -> annotation id.
        instance_names: ordered instance names; defines the column order.
        ffn_model: model handed to ``ComputeMLPContributions``.

    Returns:
        dict with ``instance_names`` and the three
        ``layer_{1,2,3}_contribs_normalized_by_instance`` matrices.
    """
    intermediate = inference_results["intermediate_results"]
    height, width = inference_results["img_hw"]
    annotations = inference_results["annotations"]

    mapper = ComputeMLPContributions(ffn_model, intermediate, (height, width))
    # Only the first three of the six returned mappings are used here.
    l1_out, l2_out, l3_out, _, _, _ = mapper.compute_all_layer_mappings()

    l1_inst, l2_inst, l3_inst = get_instance_contribs(
        l1_out, l2_out, l3_out, annotations, instance_to_ann_id_map, instance_names
    )

    # Dead neurons (all-zero rows) are dropped during normalization.
    l1_norm, l2_norm, l3_norm = get_normalized_contribs(l1_inst, l2_inst, l3_inst)

    return {
        "instance_names": instance_names,
        "layer_1_contribs_normalized_by_instance": l1_norm,
        "layer_2_contribs_normalized_by_instance": l2_norm,
        "layer_3_contribs_normalized_by_instance": l3_norm,
    }

## MLP Contributions to Instances

In [45]:
custom_colors = [
    "#1f77b4",
    "#2ca02c",
    "#d62728",
    "#e377c2",
    "#7f7f7f",
    "#843c39",
    "#7b4173",
    "#5254a3",
    "#fdc7c7",
    "#637939",
    "#bcbd22",
    "#17becf",
    "#b1c2b3",
    "#ff9896",
    "#ff7f0e",
    "#f5c651",
    "#02a3a3",
]


def plot_kernel_instance_contribs(
    all_variables_for_frame,
    obj_to_obj_name_idx,
    object_categories,
    custom_colors,
    generator=None,
):
    """Plot per-instance contributions for a few randomly sampled neurons per MLP layer.

    For each of the three layers, ``num_neuron_samples_per_layer`` rows (neurons)
    are sampled with ``torch.randperm`` from that layer's row-normalized matrix.
    Layer 1's samples go to ticks a, b, c, layer 2's to p, q, r and layer 3's to
    x, y, z. Each instance is drawn as points at those ticks. Points are coloured by
    object category, and the marker varies between instances of the same category.
    At most ``num_instances_per_obj`` instances are drawn per category, and zero
    contributions are left undrawn.

    Args:
        all_variables_for_frame: dict with ``instance_names`` and
            ``layer_{1,2,3}_contribs_normalized_by_instance`` (rows = neurons,
            columns = instances).
        obj_to_obj_name_idx: category name -> index; key order picks each
            category's horizontal offset.
        object_categories: category names, zipped with ``custom_colors``. Every
            drawn category needs a colour, so ``custom_colors`` must be at least as
            long.
        custom_colors: colour strings.
        generator: optional ``torch.Generator`` for reproducible neuron sampling.
            The default None uses torch's global RNG, as before.

    Returns:
        dict with ``inst_kernel_contribs``, shaped ``(num_instances,
        3 * num_neuron_samples_per_layer)``, plus the metadata needed to redraw
        the plot.
    """
    num_instances_per_obj = 3
    num_neuron_samples_per_layer = 3
    num_layers = 3

    instance_names = all_variables_for_frame["instance_names"]
    layer_contribs = [
        all_variables_for_frame["layer_1_contribs_normalized_by_instance"],
        all_variables_for_frame["layer_2_contribs_normalized_by_instance"],
        all_variables_for_frame["layer_3_contribs_normalized_by_instance"],
    ]

    # Layer k (0-based) occupies columns k*n .. k*n+n-1 (n = samples per layer),
    # matching the tick groups a,b,c / p,q,r / x,y,z. Slots left unfilled (a layer
    # with fewer live neurons than n) stay zero and are not drawn.
    inst_kernel_contribs = torch.zeros(
        (len(instance_names), num_neuron_samples_per_layer * num_layers)
    )
    for layer_idx, contribs in enumerate(layer_contribs):
        sampled_neurons = torch.randperm(contribs.shape[0], generator=generator)[
            :num_neuron_samples_per_layer
        ]
        for neuron_idx, sampled_neuron in enumerate(sampled_neurons):
            col = layer_idx * num_neuron_samples_per_layer + neuron_idx
            inst_kernel_contribs[:, col] = contribs[sampled_neuron, :]

    object_to_color_map = dict(
        zip(object_categories, custom_colors[: len(object_categories)])
    )
    markers = [".", "*", "p", "X", "^", "s", "2", "d"]
    object_order = list(obj_to_obj_name_idx.keys())

    fig, ax = plt.subplots(figsize=(6, 4), tight_layout=True)
    x = np.arange(num_neuron_samples_per_layer * num_layers)
    # Small per-category horizontal shift so categories don't overlap at one tick.
    x_offsets = np.linspace(-0.25, 0.25, len(object_order))

    def _object_of(inst_name):
        # Instance names are "<category>_<k>"; the category may itself contain "_".
        return inst_name.rpartition("_")[0]

    # Assumes instances of one category are adjacent in instance_names (sorted by
    # category). The marker index restarts for each new category.
    leading_obj = _object_of(instance_names[0])
    marker_idx = -1
    for inst_num, inst_name in enumerate(instance_names):
        obj = _object_of(inst_name)
        if obj == leading_obj:
            marker_idx += 1
        else:
            marker_idx = 0
            leading_obj = obj
        if marker_idx >= num_instances_per_obj:
            continue

        row = inst_kernel_contribs[inst_num, :].numpy()
        # NaN entries are not drawn, so zero contributions leave gaps.
        data_with_gaps = np.where(row == 0.0, np.nan, row)
        offset_for_obj = x_offsets[object_order.index(leading_obj)]
        label = inst_name.replace(" ", "_").replace("vegitation", "vegetation")
        ax.plot(
            x + offset_for_obj,
            data_with_gaps,
            marker=markers[marker_idx],
            linestyle="None",
            label=label,
            color=object_to_color_map[leading_obj],
        )

    ax.set_xticks(x)
    ax.set_xticklabels(["a", "b", "c", "p", "q", "r", "x", "y", "z"])
    ax.set_ylabel("Contributions")
    ax.set_xlabel("Sampled Neurons from Layers 1, 2, 3")
    # Dashed separators between the layer groups (x = 2.5 and 5.5).
    for boundary in range(1, num_layers):
        ax.axvline(
            x=boundary * num_neuron_samples_per_layer - 0.5,
            linestyle="--",
            color="gray",
            alpha=0.4,
            linewidth=1,
        )

    handles, labels = ax.get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc="upper center",
        ncol=6,
        bbox_to_anchor=(0.5, 1.10),
        fontsize="small",
    )
    fig.suptitle("MLP Layers - Instance Contributions for Sampled Neurons", y=1.15)

    neuron_to_inst_contrib_dict = {
        "instance_names": instance_names,
        "object_categories": object_categories,
        "obj_to_obj_name_idx": obj_to_obj_name_idx,
        "num_instances_per_obj": num_instances_per_obj,
        "num_neuron_samples_per_layer": num_neuron_samples_per_layer,
        "inst_kernel_contribs": inst_kernel_contribs,
    }
    return neuron_to_inst_contrib_dict

Analyze one image per video (the first batch of its dataloader) and save the raw per-instance contribution values for downstream visualization.

In [None]:
# torch.randperm inside plot_kernel_instance_contribs picks which neurons are
# plotted and saved; seed it so re-running the notebook reproduces the pickle.
NEURON_SAMPLING_SEED = 0
torch.manual_seed(NEURON_SAMPLING_SEED)

# Results are keyed by video name alone (this is the pickled schema), so video
# names must be unique across datasets.
per_vid_neuron_to_inst_contrib_dict = {}

for dataset_name in dataset_names:
    for vidname in vidnames[dataset_name]:
        if vidname in per_vid_neuron_to_inst_contrib_dict:
            raise ValueError(
                f"Duplicate video name {vidname!r} across datasets; "
                "results would be overwritten"
            )

        single_image_dataloader = dataloader_dict[dataset_name][vidname]
        ffn_model = ffn_models_dict[dataset_name][vidname]
        categories_dict = categories_dicts[dataset_name][vidname]
        cfg = cfg_dict[dataset_name][vidname]

        categories = list(categories_dict.values())

        # Run the model on one frame and collect its annotations and categories.
        inference_results, categories_in_frame, object_categories = (
            compute_inference_results(
                single_image_dataloader, ffn_model, cfg, categories_dict
            )
        )
        # Name each instance and map names to annotation ids. get_instance_info
        # returns instance_names twice, so it is bound twice here.
        (
            inst_id_to_cat_and_inst_suffix,
            instance_to_ann_id_map,
            instance_names,
            object_to_instances_map,
            obj_to_obj_name_idx,
            instance_names,
        ) = get_instance_info(inference_results, object_categories, categories)

        all_variables_for_frame = compute_all_variables_for_frame(
            inference_results, instance_to_ann_id_map, instance_names, ffn_model
        )
        per_vid_neuron_to_inst_contrib_dict[vidname] = plot_kernel_instance_contribs(
            all_variables_for_frame, obj_to_obj_name_idx, object_categories, custom_colors
        )

In [46]:
import pickle

# Persist the per-video results (dicts that include torch tensors) so downstream
# visualization can load them without re-running the models.
save_dir = "../analysis_data/MLP/contributions_to_objects"
os.makedirs(save_dir, exist_ok=True)

save_path = os.path.join(save_dir, "per_vid_neuron_to_inst_contrib_dict.pkl")
with open(save_path, "wb") as pkl_file:
    pickle.dump(per_vid_neuron_to_inst_contrib_dict, pkl_file)