In [27]:
import json
import os
import pickle
import sys
from collections import defaultdict

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import ListedColormap
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
# Make the parent directory importable *before* any project-local import.
# (`sys.append` does not exist; the search path lives on `sys.path`.)
sys.path.append('../')

from get_mlp_mappings import ComputeMLPContributions
from image_vps_datasets import (single_image_cityscape_vps_dataset,
                                single_image_vipseg_dataset)
from model_all_analysis import ffn, lightning_model
from utils import data_process, helper

In [28]:
def convert_tensor_annotations_to_numpy(tensor_annotations):
    """Turn DataLoader-collated annotation dicts back into plain Python / NumPy values.

    Every scalar field is unwrapped with ``.item()``, ``bbox`` becomes a list of Python
    numbers, ``image_id`` takes the first entry of its collated list, and ``binary_mask``
    becomes a NumPy array.

    Args:
        tensor_annotations: iterable of annotation dicts as produced by the DataLoader.

    Returns:
        A new list with one plain dict per input annotation, in the same order.
    """
    def _to_plain(tensor_anno):
        return {
            'id': tensor_anno['id'].item(),
            'inst_id': tensor_anno['inst_id'].item(),
            # Collation wraps strings in a list; keep the single element.
            'image_id': tensor_anno['image_id'][0],
            'category_id': tensor_anno['category_id'].item(),
            'area': tensor_anno['area'].item(),
            'iscrowd': tensor_anno['iscrowd'].item(),
            'isthing': tensor_anno['isthing'].item(),
            'bbox': [coord.item() for coord in tensor_anno['bbox']],
            'binary_mask': tensor_anno['binary_mask'].numpy(),
        }

    return [_to_plain(tensor_anno) for tensor_anno in tensor_annotations]

def add_other_annotation(annotations):
    """Append an "other" annotation covering every pixel that no annotation claims.

    The list is modified in place (one dict appended) and also returned.

    Args:
        annotations: non-empty list of annotation dicts with ``"binary_mask"`` (NumPy array,
            possibly with a leading singleton axis) and ``"image_id"``.

    Returns:
        The same list with an extra annotation (``id = inst_id = category_id = -1``) whose
        ``binary_mask`` is the 2-D boolean complement of the union of all masks. ``area`` is
        its pixel count, and ``bbox`` is ``[0, 0, 0, 0]`` if every pixel is already annotated.

    Raises:
        ValueError: if ``annotations`` is empty (no frame shape / image_id to use).
    """
    if not annotations:
        raise ValueError("add_other_annotation needs at least one annotation to infer the frame shape and image_id")

    # Union of all annotated regions. A boolean OR cannot wrap around the way summing
    # many uint8 masks could.
    object_region_mask = np.zeros(annotations[0]['binary_mask'].squeeze().shape, dtype=bool)
    for ann in annotations:
        object_region_mask |= ann['binary_mask'].squeeze() != 0

    # "other" = the regions that contain no object.
    other_region_mask = ~object_region_mask
    if other_region_mask.any():
        other_bbox = compute_bbox(other_region_mask)
    else:
        other_bbox = [0, 0, 0, 0]

    annotations.append({
        "id": -1,
        "inst_id": -1,
        "bbox": other_bbox,
        "area": int(other_region_mask.sum()),
        "binary_mask": other_region_mask,
        'iscrowd': 0,
        'isthing': 0,
        'category_id': -1,
        'image_id': annotations[0]['image_id']
    })
    return annotations

def plot_image_with_instances(image, annotations, categories_dict, title=None, figsize=(15, 10)):
    """Show ``image`` with each annotation's bbox and a translucent mask in its category colour.

    Annotations with ``category_id == -1`` (the synthetic "other" region) are skipped.
    Global ``plt.rcParams`` are left untouched; the size is passed to ``plt.subplots``.

    Args:
        image: anything ``ax.imshow`` accepts.
        annotations: dicts with ``"category_id"``, ``"bbox"`` as ``[x, y, w, h]`` and
            ``"binary_mask"`` (possibly with a leading singleton axis).
        categories_dict: category id -> category dict whose ``"color"`` is on a 0-255 scale.
        title: optional title for the axes.
        figsize: figure size in inches.

    Raises:
        ValueError: if a plotted annotation has no ``"binary_mask"``.
    """
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(image)

    for anno in annotations:
        # Skip plotting "other" regions (regions without objects)
        if anno["category_id"] == -1:
            continue
        if 'binary_mask' not in anno:
            raise ValueError("No binary mask found in annotation")

        cat_color = np.array(categories_dict[int(anno["category_id"])]['color']) / 255

        x, y, w, h = anno["bbox"]
        ax.add_patch(patches.Rectangle((x, y), w, h, linewidth=2, edgecolor=cat_color, facecolor='none'))

        # RGBA overlay: the category colour everywhere, opaque only where the mask is set.
        mask = np.asarray(anno["binary_mask"]).squeeze() != 0
        colored_mask = np.zeros(mask.shape + (4,))
        colored_mask[..., :3] = cat_color[:3]
        colored_mask[..., 3] = mask
        ax.imshow(colored_mask, alpha=0.5)

    if title is not None:
        ax.set_title(title)
    plt.show()
    
def compute_bbox(binary_mask):
    """Return the tight COCO-format bbox ``[x, y, width, height]`` of the positive pixels of a 2-D mask.

    ``x`` indexes columns and ``y`` indexes rows; width and height are inclusive pixel counts.
    An empty mask yields ``[0, 0, 0, 0]``.
    """
    rows, cols = np.nonzero(np.asarray(binary_mask) > 0)
    if rows.size == 0:
        return [0, 0, 0, 0]
    # NumPy reductions instead of Python min()/max(), which iterate element by element.
    x_min, x_max = int(cols.min()), int(cols.max())
    y_min, y_max = int(rows.min()), int(rows.max())
    return [x_min, y_min, x_max - x_min + 1, y_max - y_min + 1]

In [29]:
def _load_categories_dict(anno_path):
    """Read the panoptic ``categories`` list from the JSON at ``anno_path`` and index it by id.

    A synthetic ``{'id': -1, 'name': 'other', ...}`` entry is appended for the "other" region.
    """
    with open(anno_path, 'r') as f:
        categories = json.load(f)['categories']
    categories.append(
        {'id': -1, 'name': 'other', 'supercategory': '', 'color':None}
    )
    return {el['id']: el for el in categories}


def load_cfg(model_ckpt_dir, dataset_name, vidname):
    """Load a checkpoint's ``exp_config.yaml`` and point its data section at the local dataset copy.

    Args:
        model_ckpt_dir: directory containing ``exp_config.yaml``.
        dataset_name: ``"cityscapes"`` (Cityscapes-VPS val split) or ``"vipseg"`` (VIPSeg 720P).
        vidname: video to analyse; frame 0 of that video is selected.

    Returns:
        ``(cfg, categories_dict)``: the modified OmegaConf config and a mapping
        category id -> category dict (including the synthetic ``-1`` "other" entry).

    Raises:
        ValueError: if ``dataset_name`` is not one of the two supported datasets.
    """
    if dataset_name not in ("cityscapes", "vipseg"):
        raise ValueError(f"Unknown dataset_name {dataset_name!r}; expected 'cityscapes' or 'vipseg'")

    cfg = OmegaConf.load(os.path.join(model_ckpt_dir, 'exp_config.yaml'))

    if dataset_name == "cityscapes":
        cfg.data.cityscapes_vps_root = "../data/cityscapes_vps"
        cfg.data.split = "val"
        cfg.data.panoptic_video_mask_dir = os.path.join(cfg.data.cityscapes_vps_root, cfg.data.split, "panoptic_video")
        cfg.data.panoptic_inst_mask_dir = os.path.join(cfg.data.cityscapes_vps_root, cfg.data.split, "panoptic_inst")

        cfg.data.vidname = vidname
        # We will work with the first annotated frame in the given video
        cfg.data.frame_num_in_video = 0

        cfg.data.data_path = os.path.join(cfg.data.cityscapes_vps_root, cfg.data.split, "img_all")
        cfg.data.anno_path = '../data/cityscapes_vps/panoptic_gt_val_city_vps.json'
    else:  # vipseg
        cfg.data.VIPSeg_720P_root = '../data/VIPSeg-Dataset/VIPSeg/VIPSeg_720P'
        cfg.data.panomasks_dir = os.path.join(cfg.data.VIPSeg_720P_root, "panomasks")
        cfg.data.panomasksRGB_dir = os.path.join(cfg.data.VIPSeg_720P_root, "panomasksRGB")

        cfg.data.vidname = vidname
        # We will work with the first annotated frame in the given video
        cfg.data.frame_num_in_video = 0

        cfg.data.data_path = os.path.join(cfg.data.VIPSeg_720P_root, "images")
        cfg.data.anno_path = '../data/VIPSeg-Dataset/VIPSeg/VIPSeg_720P/panoptic_gt_VIPSeg.json'

        # Crop for VIPSeg to match NeRV
        cfg.data.crop = [640, 1280]

    return cfg, _load_categories_dict(cfg.data.anno_path)

# object_categories = [v['name'] for k, v in categories_dict.items()]


def load_model(cfg):
    """Rebuild the ``ffn`` model for ``cfg`` from its checkpoint and return it on the GPU.

    The checkpoint is located with ``helper.find_ckpt(cfg.logging.checkpoint.logdir)``.

    Returns:
        The unwrapped ``ffn`` module (``lightning_model(...).model``) moved to CUDA.
    """
    save_dir = cfg.logging.checkpoint.logdir
    ckpt_path = helper.find_ckpt(save_dir)
    print(f'Loading checkpoint from {ckpt_path}')

    # Deserialize onto the CPU. This does not depend on the device the checkpoint was saved
    # from, and the weights are moved to the GPU exactly once by .cuda() below.
    checkpoint = torch.load(ckpt_path, map_location='cpu')

    # The checkpoint holds the Lightning wrapper's state_dict, so load into the wrapper and unwrap.
    model = lightning_model(cfg, ffn(cfg))
    model.load_state_dict(checkpoint['state_dict'])
    return model.model.cuda()

def get_loader(cfg, dataset_name, val=False):
    """Wrap the single-frame dataset for ``cfg`` in a batch-size-1, unshuffled, in-process DataLoader.

    ``dataset_name == "cityscapes"`` selects the Cityscapes-VPS dataset; any other value
    uses the VIPSeg dataset. ``val`` is accepted for call compatibility and is unused.
    The dataset yields the image together with its annotations.
    """
    if dataset_name == "cityscapes":
        dataset_cls = single_image_cityscape_vps_dataset
    else:
        dataset_cls = single_image_vipseg_dataset
    return DataLoader(dataset_cls(cfg), batch_size=1, shuffle=False, num_workers=0)

In [30]:
# Videos to analyse per dataset (the first annotated frame of each video is used).
dataset_names = ['cityscapes', 'vipseg']
vidnames = {
    'cityscapes': ['0005', '0175'],
    'vipseg': ['12_n-ytHkMceew', '26_cblDl5vCZnw']
}

# Sub-folder of output/ holding each dataset's trained models.
vid_data_folder_name = {
    "cityscapes": "Cityscapes_VPS_models",
    "vipseg": "VIPSeg_models"
}

# Seed torch's global RNG: plot_kernel_instance_contribs samples neurons with torch.randperm,
# so without a seed the saved per-video results change on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Per-dataset, per-video containers filled by the loading cell below.
cfg_dict = {}
dataloader_dict = {}
weights_dict = {}
ffn_models_dict = {}
categories_dicts = {}

In [None]:
# Load the config, category table and trained model for every (dataset, video) pair.
for dataset_name in dataset_names:
    weights_for_dataset = weights_dict[dataset_name] = {}
    cfgs_for_dataset = cfg_dict[dataset_name] = {}
    models_for_dataset = ffn_models_dict[dataset_name] = {}
    categories_for_dataset = categories_dicts[dataset_name] = {}

    for vidname in vidnames[dataset_name]:
        vid_data_folder = vid_data_folder_name[dataset_name]
        ckpt_dir = f'output/{vid_data_folder}/{vidname}/{vidname}_framenum_0_128_256'
        weights_for_dataset[vidname] = ckpt_dir

        cfg, categories_dict = load_cfg(ckpt_dir, dataset_name, vidname)
        cfgs_for_dataset[vidname] = cfg
        categories_for_dataset[vidname] = categories_dict

        models_for_dataset[vidname] = load_model(cfg)

# Then build the single-frame DataLoader for every video.
for dataset_name in dataset_names:
    loaders_for_dataset = dataloader_dict[dataset_name] = {}

    for vidname in vidnames[dataset_name]:
        single_image_dataloader = get_loader(cfg_dict[dataset_name][vidname], dataset_name)
        loaders_for_dataset[vidname] = single_image_dataloader

In [32]:
def compute_inference_results(single_image_dataloader, ffn_model, cfg, categories_dict):
    """Run ``ffn_model`` on the first batch of ``single_image_dataloader`` and gather its outputs.

    Args:
        single_image_dataloader: loader from get_loader (batch size 1).
        ffn_model: CUDA model from load_model. It is called as ``ffn_model(coords, img=data)``
            and must return ``"predicted"`` and ``"intermediate_results"``.
        cfg: experiment config; ``cfg.data`` supplies the patch / coordinate settings.
        categories_dict: category id -> category dict, including the ``-1`` "other" entry.

    Returns:
        ``(inference_results, categories_in_frame, object_categories)``:

        - ``inference_results`` has keys ``"data"`` (the input batch as returned by the
          loader), ``"pred"`` (prediction reshaped by ``proc.process_outputs``),
          ``"annotations"`` (NumPy annotations plus the appended "other" region),
          ``"img_hw"`` and ``"intermediate_results"``.
        - ``categories_in_frame`` lists the category dicts present in the frame, in
          first-seen order. The ``-1`` "other" entry is always included.
        - ``object_categories`` holds the names of those categories.
    """
    # The dataset runs in-process (num_workers=0), so no_grad also covers any torch work it does.
    with torch.no_grad():
        batch = next(iter(single_image_dataloader))

    data = batch['data'].cuda()
    _, _, H, W = data.shape  # expects an (N, C, H, W) batch
    annotations = add_other_annotation(convert_tensor_annotations_to_numpy(batch['annotations']))
    features_shape = batch['features_shape'].squeeze().tolist()

    proc = data_process.DataProcessor(cfg.data, device='cpu')
    coords = proc.get_coordinates(
        data_shape=features_shape,
        patch_shape=cfg.data.patch_shape,
        split=cfg.data.coord_split,
        normalize_range=cfg.data.coord_normalize_range,
    )
    # Match the input batch's dtype (and device) before copying to the GPU.
    coords = coords.to(batch['data']).cuda()

    with torch.no_grad():
        out = ffn_model(coords, img=data)
        intermediate_results = out["intermediate_results"]
        # Reshape the model's flat prediction back into image layout.
        pred = proc.process_outputs(
            out['predicted'],
            input_img_shape=batch['data_shape'].squeeze().tolist(),
            features_shape=features_shape,
            patch_shape=cfg.data.patch_shape,
        )

    inference_results = {
        "data": batch["data"],
        "pred": pred,
        "annotations": annotations,
        "img_hw": (H, W),
        "intermediate_results": intermediate_results,
    }

    # Unique categories in the frame, keyed by id in first-seen order; always include "other".
    categories_in_frame = {}
    for ann in annotations:
        cat_id = ann["category_id"]
        if cat_id not in categories_in_frame:
            categories_in_frame[cat_id] = categories_dict[cat_id]
    categories_in_frame[-1] = categories_dict[-1]

    object_categories = [cat['name'] for cat in categories_in_frame.values()]
    return inference_results, list(categories_in_frame.values()), object_categories

In [34]:
def get_instance_info(inference_results, object_categories, categories):
    """Give every annotated instance in the frame a readable name ``"<category>_<k>"``.

    ``k`` counts instances of the same category in annotation order. Annotations sharing
    an ``inst_id`` share one name, and the last such annotation's id wins in
    ``instance_to_ann_id_map``.

    Args:
        inference_results: dict with an ``"annotations"`` list; each annotation needs
            ``"id"``, ``"inst_id"`` and ``"category_id"``.
        object_categories: category names; their positions become ``obj_to_obj_name_idx``.
        categories: category dicts with ``"id"`` and ``"name"``.

    Returns:
        A 6-tuple ``(inst_id_to_cat_and_inst_suffix, instance_to_ann_id_map, instance_names,
        object_to_instances_map, obj_to_obj_name_idx, instance_names)``.

        - ``instance_names`` is sorted by ``(category, k)`` with ``"other_0"`` moved last.
          It appears twice (the same list) to keep existing unpacking working.
        - ``object_to_instances_map`` is a ``defaultdict(list)`` with a key for every
          category seen.

    Raises:
        KeyError: if an annotation's ``category_id`` is not in ``categories``.
        ValueError: if there is no ``"other_0"`` instance.
    """
    # Category id -> name; the first entry wins on duplicate ids.
    cat_id_to_name = {}
    for cat in categories:
        cat_id_to_name.setdefault(cat["id"], cat["name"])

    obj_to_obj_name_idx = {object_cat: idx for idx, object_cat in enumerate(object_categories)}

    # unique inst_id -> {category, per-category instance suffix, instance name}
    inst_id_to_cat_and_inst_suffix = {}
    instance_to_ann_id_map = {}
    object_to_instances_map = defaultdict(list)
    instance_names = []
    seen_instance_names = set()

    for ann in inference_results["annotations"]:
        category_name = cat_id_to_name[ann["category_id"]]

        # Number of instances of this category named so far (also registers the category key).
        num_instances_of_obj = len(object_to_instances_map[category_name])

        if ann["inst_id"] not in inst_id_to_cat_and_inst_suffix:
            inst_id_to_cat_and_inst_suffix[ann["inst_id"]] = {
                "category": category_name,
                "inst_suffix": num_instances_of_obj,
                "instance_name": category_name + '_' + str(num_instances_of_obj)
            }

        instance_name = inst_id_to_cat_and_inst_suffix[ann["inst_id"]]["instance_name"]
        instance_to_ann_id_map[instance_name] = ann['id']

        if instance_name not in seen_instance_names:
            seen_instance_names.add(instance_name)
            object_to_instances_map[category_name].append(instance_name)
            instance_names.append(instance_name)

    def _sort_key(name):
        # Category names may contain underscores; only the last "_" separates the suffix.
        category, _, suffix = name.rpartition('_')
        return category, int(suffix)

    instance_names = sorted(instance_names, key=_sort_key)

    # Keep the background "other" instance last regardless of alphabetical order.
    instance_names.append(instance_names.pop(instance_names.index("other_0")))

    return inst_id_to_cat_and_inst_suffix, instance_to_ann_id_map, instance_names, object_to_instances_map, obj_to_obj_name_idx, instance_names



In [35]:
# For each instance: mean |contribution| of every neuron of each MLP layer over the instance's pixels.
def get_instance_contribs(
    layer_1_output_contrib, layer_2_output_contrib, layer_3_output_contrib, annotations, instance_to_ann_id_map, instance_names 
):
    """Average absolute per-pixel neuron contributions over each instance's mask.

    Args:
        layer_k_output_contrib: contribution maps indexed ``[neuron, h, w]``
            (``n_feats x h x w``).
        annotations: annotation dicts with ``"id"`` and ``"binary_mask"`` (``h x w``,
            possibly with a leading singleton axis).
        instance_to_ann_id_map: instance name -> annotation id.
        instance_names: ordered instance names; defines the output column order.

    Returns:
        Three CPU float tensors of shape ``(n_neurons_k, len(instance_names))``. Column ``i``
        is the mean ``|contribution|`` of each neuron over the pixels of ``instance_names[i]``.
        Columns stay zero for instances with an empty mask or absent from
        ``instance_to_ann_id_map``.
    """
    layer_contribs = (layer_1_output_contrib, layer_2_output_contrib, layer_3_output_contrib)
    num_instances = len(instance_names)
    per_layer_results = [torch.zeros((contrib.shape[0], num_instances)) for contrib in layer_contribs]

    # id -> annotation; the first annotation with a given id wins, like a linear scan.
    ann_by_id = {}
    for ann in annotations:
        ann_by_id.setdefault(ann['id'], ann)

    for instance, ann_id in instance_to_ann_id_map.items():
        ann = ann_by_id[ann_id]
        inst_idx = instance_names.index(instance)

        # Explicit boolean mask: integer 0/1 masks would otherwise act as index arrays.
        mask = np.asarray(ann['binary_mask']).squeeze() != 0
        # Average over the pixels actually summed. The stored 'area' need not match the
        # mask's pixel count (e.g. after cropping or resizing).
        num_pixels = int(mask.sum())

        for contrib, result in zip(layer_contribs, per_layer_results):
            pixel_mask = torch.as_tensor(mask, device=contrib.device)
            total = torch.abs(contrib[:, pixel_mask]).sum(dim=-1)
            avg = total / num_pixels if num_pixels > 0 else torch.zeros_like(total)
            result[:, inst_idx] = avg.flatten()

    return tuple(per_layer_results)
    

In [36]:
def get_normalized_contribs(layer_1_to_instance_contribs, layer_2_to_instance_contribs, layer_3_to_instance_contribs):
    """Drop dead neurons and normalize each neuron's contributions across instances.

    For each ``(n_neurons, n_instances)`` matrix, rows whose total is zero (dead neurons)
    are removed. Every remaining row is divided by its own total, so it sums to 1. Because
    rows are removed, row indices of the outputs no longer match the original neuron indices.

    Returns:
        A 3-tuple with the normalized matrices for layers 1, 2 and 3.
    """
    def _drop_dead_and_normalize(contribs):
        row_totals = contribs.sum(dim=1)
        alive = contribs[row_totals != 0]
        return alive / alive.sum(dim=1, keepdim=True)

    return tuple(
        _drop_dead_and_normalize(layer_contribs)
        for layer_contribs in (layer_1_to_instance_contribs, layer_2_to_instance_contribs, layer_3_to_instance_contribs)
    )

In [38]:
# Per-frame analysis: neuron -> instance contributions for every MLP layer, normalized per neuron.
def compute_all_variables_for_frame(inference_results, ffn_model, instance_to_ann_id_map=None, instance_names=None):
    """Compute normalized neuron-to-instance contributions for every MLP layer on one frame.

    Args:
        inference_results: dict from compute_inference_results (uses ``"intermediate_results"``,
            ``"img_hw"`` and ``"annotations"``).
        ffn_model: the model that produced ``inference_results``.
        instance_to_ann_id_map: instance name -> annotation id, as from get_instance_info.
        instance_names: ordered instance names, as from get_instance_info.
            When either is omitted, the module-level global of the same name is used.
            Pass both explicitly to avoid depending on notebook state.

    Returns:
        dict with ``"instance_names"`` and ``"layer_{1,2,3}_contribs_normalized_by_instance"``.
        Dead neurons (all-zero contributions) are removed during normalization.
    """
    if instance_to_ann_id_map is None:
        instance_to_ann_id_map = globals()["instance_to_ann_id_map"]
    if instance_names is None:
        instance_names = globals()["instance_names"]

    (H, W) = inference_results["img_hw"]

    # Per-pixel contribution maps of each layer's neurons.
    compute_contrib_obj = ComputeMLPContributions(
        ffn_model, inference_results["intermediate_results"], (H, W)
    )
    layer_1_output_contrib, layer_2_output_contrib, layer_3_output_contrib, _, _, _ = compute_contrib_obj.compute_all_layer_mappings()

    layer_1_to_instance_contribs, layer_2_to_instance_contribs, layer_3_to_instance_contribs = get_instance_contribs(
        layer_1_output_contrib, layer_2_output_contrib, layer_3_output_contrib,
        inference_results["annotations"], instance_to_ann_id_map, instance_names,
    )

    layer_1_contribs_normalized_by_instance, layer_2_contribs_normalized_by_instance, layer_3_contribs_normalized_by_instance \
            = get_normalized_contribs(layer_1_to_instance_contribs, layer_2_to_instance_contribs, layer_3_to_instance_contribs)

    return {
        "instance_names": instance_names,
        "layer_1_contribs_normalized_by_instance": layer_1_contribs_normalized_by_instance,
        "layer_2_contribs_normalized_by_instance": layer_2_contribs_normalized_by_instance,
        "layer_3_contribs_normalized_by_instance": layer_3_contribs_normalized_by_instance
    }


# MLP Contribs to Instances (Scatter Plot)

In [41]:
# Hex colour palette, assigned in order to the object categories of a frame in the scatter plots.
custom_colors = (
    '#1f77b4 #2ca02c #d62728 #e377c2 #7f7f7f '
    '#843c39 #7b4173 #5254a3 #fdc7c7 #637939 '
    '#bcbd22 #17becf #b1c2b3 #ff9896 #ff7f0e '
    '#f5c651 #02a3a3'
).split()

In [45]:
def plot_kernel_instance_contribs(all_variables_for_frame, obj_to_obj_name_idx, custom_colors, object_categories=None):
    """Scatter-plot per-instance contributions of a few randomly sampled neurons from each MLP layer.

    For each of layers 1-3, ``num_neuron_samples_per_layer`` rows of the normalized
    contribution matrix are drawn with ``torch.randperm`` (seed torch's global RNG for
    reproducible samples). Each instance is drawn as one marker series across the sampled
    neurons, with at most ``num_instances_per_obj`` instances per category. Exact-zero
    contributions are hidden.

    Args:
        all_variables_for_frame: dict from compute_all_variables_for_frame with
            ``"instance_names"`` and ``"layer_{1,2,3}_contribs_normalized_by_instance"``
            (``(n_neurons, n_instances)`` tensors).
        obj_to_obj_name_idx: category name -> position, as returned by get_instance_info.
        custom_colors: matplotlib colours, assigned to categories in order (cycled if too few).
        object_categories: category names in display order. Defaults to the keys of
            ``obj_to_obj_name_idx`` ordered by their index.

    Returns:
        dict with the instance names, categories, sampling settings and ``inst_kernel_contribs``,
        an ``(n_instances, 3 * num_neuron_samples_per_layer)`` tensor. Its columns are grouped
        by layer, layer 1 first.
    """
    num_instances_per_obj = 3
    num_neuron_samples_per_layer = 3
    layers = (1, 2, 3)
    num_columns = num_neuron_samples_per_layer * len(layers)

    if object_categories is None:
        # Recover the display order from the name -> index map rather than a notebook global.
        object_categories = sorted(obj_to_obj_name_idx, key=obj_to_obj_name_idx.get)

    layer_contribs = {
        layer: all_variables_for_frame[f"layer_{layer}_contribs_normalized_by_instance"]
        for layer in layers
    }

    # Randomly choose which neurons (rows) of each layer to show.
    layer_to_sampled_neurons = {
        layer: torch.randperm(contribs.shape[0])[:num_neuron_samples_per_layer]
        for layer, contribs in layer_contribs.items()
    }

    # Cycle the palette so frames with more categories than colours don't raise KeyError below.
    object_to_color_map = {
        obj: custom_colors[i % len(custom_colors)] for i, obj in enumerate(object_categories)
    }

    markers = [".", "*", "p", "X", "^", "s", "2", "d"]
    instance_names = all_variables_for_frame["instance_names"]

    fig, ax = plt.subplots(figsize=(6, 4), tight_layout=True)

    x = np.arange(num_columns)
    obj_names = list(obj_to_obj_name_idx.keys())
    # Shift each category slightly around the tick so overlapping markers stay distinguishable.
    x_offsets = np.linspace(-0.25, 0.25, len(obj_names))

    # inst_kernel_contribs[i, j] = contribution to instance i from the j-th sampled neuron.
    # Layer L (0-based position l) occupies columns l*n .. l*n + n-1.
    inst_kernel_contribs = torch.zeros((len(instance_names), num_columns))
    for layer_idx, layer in enumerate(layers):
        for neuron_idx, sampled_neuron_idx in enumerate(layer_to_sampled_neurons[layer]):
            column = layer_idx * num_neuron_samples_per_layer + neuron_idx
            inst_kernel_contribs[:, column] = layer_contribs[layer][sampled_neuron_idx, :].squeeze()

    def _object_of(inst_name):
        # Instance names are "<category>_<k>"; the category itself may contain underscores.
        return '_'.join(inst_name.split('_')[:-1])

    # One marker series per instance; the marker shape distinguishes instances of a category.
    leading_obj = _object_of(instance_names[0])
    marker_idx = -1
    for inst_num, inst_name in enumerate(instance_names):
        obj = _object_of(inst_name)
        if obj == leading_obj:
            marker_idx += 1
        else:
            # Reset the marker index for each new object category.
            marker_idx = 0
            leading_obj = obj

        if marker_idx >= num_instances_per_obj:
            continue

        row = inst_kernel_contribs[inst_num, :].detach().cpu().numpy()
        # Hide exact zeros: no contribution, or a column left empty when a layer has fewer
        # neurons than samples.
        data_with_nans = np.where(row == 0.0, np.nan, row)
        offset_for_obj = x_offsets[obj_names.index(leading_obj)]
        label = inst_name.replace(' ', '_').replace('vegitation', 'vegetation')
        ax.plot(x + offset_for_obj, data_with_nans, marker=markers[marker_idx], linestyle='None',
                label=label, color=object_to_color_map[leading_obj])

    ax.set_xticks(x)
    ax.set_xticklabels(['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z'])

    fig.suptitle("MLP Layers - Instance Contributions for Sampled Neurons", y=1.15)
    ax.set_ylabel("Contributions")
    ax.set_xlabel("Sampled Neurons from Layers 1, 2, 3")

    # Dashed separators between layer groups (x = 2.5 and 5.5 for 3 samples per layer).
    for boundary in range(1, len(layers)):
        ax.axvline(x=boundary * num_neuron_samples_per_layer - 0.5, linestyle='--', color='gray', alpha=0.4, linewidth=1)

    # A single figure-level legend above the axes.
    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', ncol=6, bbox_to_anchor=(0.5, 1.10), fontsize='small')

    neuron_to_inst_contrib_dict = {
        "instance_names": instance_names,
        "object_categories": object_categories,
        "obj_to_obj_name_idx": obj_to_obj_name_idx,

        "num_instances_per_obj": num_instances_per_obj,
        "num_neuron_samples_per_layer": num_neuron_samples_per_layer,
        "inst_kernel_contribs": inst_kernel_contribs,
    }
    return neuron_to_inst_contrib_dict

## Run the analysis on each selected video

In [None]:
per_vid_neuron_to_inst_contrib_dict = {}

for dataset_name in dataset_names:
    for vidname in vidnames[dataset_name]:
        cfg = cfg_dict[dataset_name][vidname]
        categories_dict = categories_dicts[dataset_name][vidname]
        ffn_model = ffn_models_dict[dataset_name][vidname]
        single_image_dataloader = dataloader_dict[dataset_name][vidname]
        categories = list(categories_dict.values())

        inference_results, categories_in_frame, object_categories = compute_inference_results(
            single_image_dataloader, ffn_model, cfg, categories_dict
        )

        # Keep these as module-level names: helpers in this notebook may look up
        # instance_to_ann_id_map, instance_names and object_categories as globals.
        (inst_id_to_cat_and_inst_suffix, instance_to_ann_id_map, instance_names,
         object_to_instances_map, obj_to_obj_name_idx, instance_names) = get_instance_info(
            inference_results, object_categories, categories
        )

        all_variables_for_frame = compute_all_variables_for_frame(inference_results, ffn_model)
        per_vid_neuron_to_inst_contrib_dict[vidname] = plot_kernel_instance_contribs(
            all_variables_for_frame, obj_to_obj_name_idx, custom_colors
        )

### Save the per-video contribution dictionary

In [46]:
# Persist the per-video neuron -> instance contribution data under plotting_source_data.
# (`pickle` is imported in the top-level imports cell.)
save_dir = '../plotting_source_data/MLP/C-INRs_perhaps_care_about_objects'
os.makedirs(save_dir, exist_ok=True)

save_path = os.path.join(save_dir, "per_vid_neuron_to_inst_contrib_dict.pkl")
with open(save_path, 'wb') as f:
    pickle.dump(per_vid_neuron_to_inst_contrib_dict, f)