In [1]:
import os
import sys

import matplotlib.pyplot as plt
import torch
from torch.utils.data import Subset


from analysis_utils import *

sys.path.append("../")

from get_mappings import ComputeContributions
from model_all_analysis import HNeRV
from vps_datasets import CityscapesVPSVideoDataSet, VIPSegVideoDataSet

In [None]:
dataset_names = ["cityscapes", "vipseg"]
vidnames = {
    "cityscapes": ["0005", "0175"],
    "vipseg": ["12_n-ytHkMceew", "26_cblDl5vCZnw"],
}

# Per-dataset, per-video registries: registry[dataset_name][vidname] -> value.
args_dict, dataloader_dict, weights_dict, models_dict, categories_dicts = (
    {} for _ in range(5)
)

for dataset_name in dataset_names:
    for registry in (
        weights_dict,
        args_dict,
        dataloader_dict,
        models_dict,
        categories_dicts,
    ):
        registry[dataset_name] = {}

    # Crop spec written to args.crop_list: "-1" for Cityscapes, "640_1280" otherwise.
    crop_spec = "-1" if dataset_name == "cityscapes" else "640_1280"

    for vidname in vidnames[dataset_name]:
        # TODO: placeholder path — every video currently points at the same directory;
        # set this to the checkpoint directory of (dataset_name, vidname).
        checkpoint_dir = "path/to/checkpoint/"
        weights_dict[dataset_name][vidname] = checkpoint_dir

        args = load_model_args()
        args.weight = os.path.join(checkpoint_dir, "model_best.pth")
        args.crop_list = crop_spec

        # Build the model and load its weights from args.weight.
        model = load_model_checkpoint(HNeRV(args), args)
        models_dict[dataset_name][vidname] = model

        # Dataset-specific args also yield the per-video category mapping.
        args, video_categories = load_dataset_specific_args(
            args, dataset_name, vidname
        )
        categories_dicts[dataset_name][vidname] = video_categories
        args_dict[dataset_name][vidname] = args

In [7]:
# Build one (unshuffled) training dataloader per video. Frame order matters:
# the analysis below relies on the first batch being the first training frame.
for dataset_name in dataset_names:
    for vidname in vidnames[dataset_name]:
        # NOTE: this is the same object stored in args_dict; the assignments
        # below mutate it in place.
        args = args_dict[dataset_name][vidname]

        if dataset_name == "cityscapes":
            full_dataset = CityscapesVPSVideoDataSet(args)
        else:
            full_dataset = VIPSegVideoDataSet(args)

        args.final_size = full_dataset.final_size
        args.full_data_length = len(full_dataset)
        # data_split is "a_b_c"-style, e.g. "1_1_1"; parsed into ints for data_split().
        split_num_list = [int(x) for x in args.data_split.split("_")]
        train_ind_list, args.val_ind_list = data_split(
            list(range(args.full_data_length)), split_num_list, args.shuffle_data, 0
        )

        train_dataset = Subset(full_dataset, train_ind_list)
        train_sampler = (
            torch.utils.data.distributed.DistributedSampler(train_dataset)
            if args.distributed
            else None
        )

        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batchSize,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            sampler=train_sampler,
            drop_last=True,
            worker_init_fn=worker_init_fn,
        )

        dataloader_dict[dataset_name][vidname] = train_dataloader

In [8]:
def compute_inference_results(
    dataset_name, train_dataloader, model, args, first_frame_only=True, device="cuda"
):
    """Run ``model`` over ``train_dataloader`` and collect per-frame outputs.

    For ``dataset_name == "vipseg"`` only 6 frames, evenly spaced over the
    iterated frame indices (endpoints included), are processed; all other
    frames are skipped. For any other dataset every batch is processed.

    Args:
        dataset_name: Dataset key; ``"vipseg"`` enables frame sub-sampling.
        train_dataloader: DataLoader yielding dicts with ``"img"``,
            ``"norm_idx"`` and ``"idx"`` tensors. Batch size 1 is assumed,
            since ``idx.item()`` is used as the result key.
        model: Callable taking ``norm_idx`` and returning a 5-tuple whose 4th
            element is the decoder results and 5th the output image.
        args: Run arguments. Kept for interface compatibility; the batch size
            is read from ``train_dataloader.batch_size`` instead.
        first_frame_only: If True, stop after the first processed frame.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        dict mapping frame index (int) ->
        ``{"decoder_results": ..., "img_out": ..., "img_gt": ...}``.
    """
    sampled_img_indices = None
    if dataset_name == "vipseg":
        # Number of frames actually iterated (drop_last-aware), taken from the
        # dataloader itself rather than a separately stored args field.
        batch_size = train_dataloader.batch_size or 1
        num_indices = len(train_dataloader) * batch_size
        num_samples = 6
        sampled_img_indices = {
            i * (num_indices - 1) // (num_samples - 1) for i in range(num_samples)
        }

    inference_results = {}
    with torch.no_grad():
        for batch in train_dataloader:
            img_idx = batch["idx"].item()
            # Decide before any host->device copy so skipped frames cost nothing.
            if sampled_img_indices is not None and img_idx not in sampled_img_indices:
                continue

            images = batch["img"].to(device)
            norm_idx = batch["norm_idx"].to(device)
            _, _, _, decoder_results, img_out = model(norm_idx)

            # Save all input-output information related to annotated images
            inference_results[img_idx] = {
                "decoder_results": decoder_results,
                "img_out": img_out,
                "img_gt": images[0],
            }

            if first_frame_only:
                break
    return inference_results

# Pixels Per Neuron

In [37]:
def _count_pixels_above_uniform_share(abs_contrib):
    """Per neuron, count pixels where its share of |contribution| beats 1/num_neurons.

    ``abs_contrib`` is a 2-D (num_neurons, num_pixels) tensor of absolute
    contributions. Each entry is divided by its pixel's column total; a neuron
    "meaningfully" contributes to a pixel when that fraction exceeds the
    uniform share ``1 / num_neurons``. Columns summing to 0 yield NaN
    fractions, which compare False and are therefore never counted.

    Returns the per-neuron counts sorted in ascending order.
    """
    fractions = abs_contrib / abs_contrib.sum(dim=0)
    counts = (fractions > (1 / abs_contrib.size(0))).sum(dim=1)
    return torch.sort(counts).values


def plot_pixels_per_neuron(inference_results, model, args):
    """Plot, for frame 0, how many pixels each neuron meaningfully contributes to.

    Contributions are computed for the head layer and the last NeRV block
    (block 3) via ``ComputeContributions``; both sorted count curves are drawn
    on one axis.

    Args:
        inference_results: Output of ``compute_inference_results``; must
            contain key 0.
        model: The model the contributions are computed for.
        args: Run arguments forwarded to ``ComputeContributions``.

    Returns:
        ``{"head": counts, "blk_3": counts}`` — 1-D CPU tensors of per-neuron
        pixel counts, sorted ascending.
    """
    img_idx = 0
    img_out = inference_results[img_idx]["img_out"]
    decoder_results = inference_results[img_idx]["decoder_results"]

    # Get model contributions
    compute_contrib_obj = ComputeContributions(
        model, args, decoder_results, img_out.detach().clone()[0]
    )
    head_layer_output_contrib = compute_contrib_obj.compute_head_mappings()
    nerv_blk_3_output_contrib, _ = (
        compute_contrib_obj.compute_last_nerv_block_mappings()
    )

    # Merge dims (0, 1) into neurons/kernels and dims (2, 3) into pixels.
    # Assumes 4-D contribution tensors — TODO confirm against get_mappings.
    head_counts = _count_pixels_above_uniform_share(
        torch.abs(head_layer_output_contrib).flatten(0, 1).flatten(1, 2)
    ).cpu()
    blk_3_counts = _count_pixels_above_uniform_share(
        torch.abs(nerv_blk_3_output_contrib).flatten(0, 1).flatten(1, 2)
    ).cpu()

    fig, ax = plt.subplots(1, 1, figsize=(20, 10))
    # .cpu() above: matplotlib cannot convert CUDA tensors to numpy.
    ax.plot(
        head_counts.numpy(),
        color="b",
        label="Head layer - threshold 1/num_kernels",
    )
    ax.plot(
        blk_3_counts.numpy(),
        color="g",
        label="Block 3 - threshold 1/num_kernels",
    )
    ax.set_xlabel("Neuron rank (sorted by pixel count)")
    ax.set_ylabel("Number of pixels")
    ax.set_title("Head layer and NeRV block 3")

    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels)
    fig.suptitle(
        f"Frame {img_idx} - Number of pixels contributed to by neuron (above Threshold)"
    )
    plt.show()

    return {"head": head_counts, "blk_3": blk_3_counts}

# Neurons Per Pixel

In [11]:
def compute_contrib_thresh_using_auc(abs_contrib_map, target_area=0.05):
    """Pick a contribution threshold from the cumulative-mass curve.

    Values are sorted in descending order and accumulated; the threshold is
    the value at which the running sum first reaches ``(1 - target_area)`` of
    the total. Hence all elements ``>=`` the returned threshold together hold
    at least that fraction of the total mass.

    Args:
        abs_contrib_map: Tensor of non-negative (absolute) contributions, any shape.
        target_area: Fraction of total mass left *below* the cutoff (default 0.05).

    Returns:
        0-d tensor holding the threshold value. If floating-point rounding
        keeps the running sum from ever reaching the cutoff (e.g. target_area
        close to 0), the smallest value is returned instead of failing.

    Raises:
        ValueError: if ``abs_contrib_map`` is empty.
    """
    flat = abs_contrib_map.flatten()
    if flat.numel() == 0:
        raise ValueError("abs_contrib_map must contain at least one element")

    cutoff_contrib_sum = flat.sum() * (1 - target_area)
    sorted_desc, _ = torch.sort(flat, descending=True)
    cum_sum = torch.cumsum(sorted_desc, dim=0)

    hits = torch.nonzero(cum_sum >= cutoff_contrib_sum, as_tuple=False)
    # The last partial sum can differ from flat.sum() by rounding; fall back
    # to the smallest element rather than raising IndexError.
    idx = hits[0, 0].item() if hits.numel() > 0 else flat.numel() - 1
    return sorted_desc[idx]

In [12]:
def plot_neurons_per_pixel_heatmap(
    vidname, inference_results, model, args, target_areas=(0.1, 0.5)
):
    """Plot, for frame 0, the fraction of neurons meaningfully contributing to each pixel.

    For each ``target_area`` a threshold is derived with
    ``compute_contrib_thresh_using_auc``. For the head layer and NeRV block 3,
    the per-pixel fraction of neurons whose |contribution| exceeds that
    threshold is drawn as a heatmap next to the ground-truth image.

    Args:
        vidname: Video name, shown in the figure title.
        inference_results: Output of ``compute_inference_results``; must
            contain key 0.
        model: The model the contributions are computed for.
        args: Run arguments forwarded to ``ComputeContributions``.
        target_areas: Iterable of ``target_area`` values, one figure each.

    Returns:
        ``{target_area: {"head": frac_map, "blk_3": frac_map}}`` with CPU
        tensors of per-pixel neuron fractions in [0, 1].
    """
    img_idx = 0
    gt_img = inference_results[img_idx]["img_gt"]
    img_out = inference_results[img_idx]["img_out"]
    decoder_results = inference_results[img_idx]["decoder_results"]

    # Get model contributions
    compute_contrib_obj = ComputeContributions(
        model, args, decoder_results, img_out.detach().clone()[0]
    )

    # The following can be extended to inner blocks of NeRV as well
    head_layer_output_contrib = compute_contrib_obj.compute_head_mappings()
    nerv_blk_3_output_contrib, _ = (
        compute_contrib_obj.compute_last_nerv_block_mappings()
    )
    # Merge dims (0, 1) into one neuron axis; the remaining dims are spatial
    # (assumed (H, W) after the merge — TODO confirm against get_mappings).
    head_layer_output_contrib_abs = torch.abs(head_layer_output_contrib).flatten(0, 1)
    nerv_blk_3_output_contrib_abs = torch.abs(nerv_blk_3_output_contrib).flatten(0, 1)

    gt_img_np = torch.clamp(gt_img, 0, 1).permute(1, 2, 0).cpu().numpy()

    num_kernels_with_meaningful_contrib = {}

    for target_area in target_areas:
        head_thresh = compute_contrib_thresh_using_auc(
            abs_contrib_map=head_layer_output_contrib_abs, target_area=target_area
        )
        blk_3_thresh = compute_contrib_thresh_using_auc(
            abs_contrib_map=nerv_blk_3_output_contrib_abs, target_area=target_area
        )

        # Per pixel: number of neurons above threshold / total number of neurons.
        head_num_kernels_with_meaningful_contrib = (
            (head_layer_output_contrib_abs > head_thresh).sum(dim=0)
            / head_layer_output_contrib_abs.size(0)
        ).cpu()
        blk_3_num_kernels_with_meaningful_contrib = (
            (nerv_blk_3_output_contrib_abs > blk_3_thresh).sum(dim=0)
            / nerv_blk_3_output_contrib_abs.size(0)
        ).cpu()

        fig, axs = plt.subplots(1, 3, figsize=(20, 10))
        axs[0].imshow(gt_img_np)
        axs[0].set_title("Ground Truth Image")

        # Draw the computed per-pixel maps (moved to CPU for matplotlib).
        head_im = axs[1].imshow(
            head_num_kernels_with_meaningful_contrib.numpy(), cmap="viridis"
        )
        axs[1].set_title(f"Head Layer Heatmap - thresh={head_thresh:.4f}")
        fig.colorbar(head_im, ax=axs[1], fraction=0.046, pad=0.04)

        blk_3_im = axs[2].imshow(
            blk_3_num_kernels_with_meaningful_contrib.numpy(), cmap="viridis"
        )
        axs[2].set_title(f"Block 3 Heatmap - thresh={blk_3_thresh:.4f}")
        fig.colorbar(blk_3_im, ax=axs[2], fraction=0.046, pad=0.04)

        for ax in axs:
            ax.axis("off")

        num_kernels_with_meaningful_contrib[target_area] = {
            "head": head_num_kernels_with_meaningful_contrib,
            "blk_3": blk_3_num_kernels_with_meaningful_contrib,
        }

        fig.suptitle(
            f"{vidname} - Neurons per Pixel Heatmap (target_area={target_area})",
            y=0.7,
        )
        plt.show()

    return num_kernels_with_meaningful_contrib

Analyze each video, then save the raw per-video statistics for downstream visualization.

In [None]:
# Raw per-video statistics, keyed by vidname (vidnames must be unique across datasets).
per_vid_num_pixels_with_meaningful_contrib = {}
per_vid_num_kernels_with_meaningful_contrib = {}

for dataset_name in dataset_names:
    for vidname in vidnames[dataset_name]:
        assert vidname not in per_vid_num_pixels_with_meaningful_contrib, (
            f"duplicate vidname {vidname!r} would overwrite earlier results"
        )
        train_dataloader = dataloader_dict[dataset_name][vidname]
        model = models_dict[dataset_name][vidname]
        args = args_dict[dataset_name][vidname]

        # Argument lists match the function definitions above; vidname is only
        # taken by plot_neurons_per_pixel_heatmap.
        inference_results = compute_inference_results(
            dataset_name, train_dataloader, model, args
        )
        per_vid_num_pixels_with_meaningful_contrib[vidname] = plot_pixels_per_neuron(
            inference_results, model, args
        )
        per_vid_num_kernels_with_meaningful_contrib[vidname] = (
            plot_neurons_per_pixel_heatmap(vidname, inference_results, model, args)
        )

In [42]:
import pickle

save_dir = "../analysis_data/NeRV/representation_is_distributed"
os.makedirs(save_dir, exist_ok=True)

# Persist the raw per-video statistics for downstream visualization.
for pickle_name, payload in (
    (
        "per_vid_num_pixels_with_meaningful_contrib.pkl",
        per_vid_num_pixels_with_meaningful_contrib,
    ),
    (
        "per_vid_num_kernels_with_meaningful_contrib.pkl",
        per_vid_num_kernels_with_meaningful_contrib,
    ),
):
    with open(os.path.join(save_dir, pickle_name), "wb") as f:
        pickle.dump(payload, f)