In [None]:
# Change into the directory two levels up before importing `stimuli_generation`
# (presumably the repository root, so that the package is importable - TODO confirm),
# then change back into tools/data_analysis for the rest of the notebook.
# NOTE: `%cd` is stateful; re-running this cell without restarting the kernel
# moves the working directory relative to wherever it currently is.
%cd ../..
from stimuli_generation import utils as sg_utils
%cd tools/data_analysis

In [None]:
from matplotlib import rcParams

# Global matplotlib settings: sans-serif text rendered with DejaVu Sans, and
# font handling for vector exports.
rcParams.update({
    "font.family": "sans-serif",
    "font.sans-serif": ["DejaVu Sans"],
    # output text as text and not paths
    "svg.fonttype": "none",
    "pdf.fonttype": "truetype",
})

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from lucent.optvis.hooks import ModelHook
import torch

from tqdm.notebook import tqdm as tqdm
from torch.cuda.amp import autocast

import pandas as pd
import scipy
import os

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Directory passed to the validation dataloader. The default is the original
# machine-local location; set the IMAGENET_VAL_PATH environment variable to
# point the notebook at a different copy without editing the code.
IMAGENET_VAL_PATH = os.environ.get(
    "IMAGENET_VAL_PATH", "/scratch_local/datasets/ImageNet2012/val")

In [None]:
# Set the TORCH_HOME directory to ~/.torch. The tilde is expanded here because
# environment variables are not shell-expanded; a consumer that does not call
# expanduser itself may otherwise treat "~" as a literal directory name.
os.environ["TORCH_HOME"] = os.path.expanduser("~/.torch")

# Fail early, with a readable message, if the validation data is missing.
assert os.path.exists(IMAGENET_VAL_PATH), (
    f"ImageNet validation path does not exist: {IMAGENET_VAL_PATH}")

In [None]:
def compute_sparsity(df_model_name: str, main_df: pd.DataFrame) -> pd.DataFrame:
    """Attach mean validation-set activation sparsity to one model's rows.

    Args:
        df_model_name: The model identifier as it appears in the "model" column
            of `main_df`. A "_hard85" / "_hard95" suffix is stripped to obtain
            the name under which the network and its dataloader are looked up.
        main_df: The main dataframe. Must contain the columns "model", "layer"
            and "channel"; "channel" is used to index the per-layer sparsity
            arrays.

    Returns:
        A copy (with a fresh index) of the rows of `main_df` whose "model"
        equals `df_model_name`, extended by the columns
        "mean_validation_pixel_sparsity" and "mean_validation_channel_sparsity"
        (looked up per row by layer and channel) and the placeholder columns
        "validation_pixel_sparsity" and "validation_channel_sparsity" (all None).
    """
    # The suffix only distinguishes dataframe entries; the network itself is
    # loaded under the base name.
    model_name = df_model_name.replace("_hard85", "").replace("_hard95", "")

    val_loader = sg_utils.get_dataloader(IMAGENET_VAL_PATH, model_name, batch_size=128)
    model = sg_utils.load_model(model_name)

    # Models exposing a `visual` attribute are hooked on that sub-module only.
    if hasattr(model, "visual"):
        model = model.visual

    # Select rows by the name used in the dataframe. Filtering by the stripped
    # `model_name` would return the base model's rows for a "_hard85"/"_hard95"
    # entry and silently drop the suffixed rows from the result.
    model_df = main_df[main_df["model"] == df_model_name].copy().reset_index(drop=True)
    model_layers = list(model_df.layer.unique())

    pixel_sparsity_values, channel_sparsity_values = count_sparsity(
        model, model_layers, val_loader)

    for sparsity_kind, per_layer_values in zip(
            ("pixel_sparsity", "channel_sparsity"),
            (pixel_sparsity_values, channel_sparsity_values)
    ):
        # The lambda is evaluated immediately by `apply`, so capturing the loop
        # variable is safe here.
        model_df[f"mean_validation_{sparsity_kind}"] = model_df.apply(
            lambda row: per_layer_values[row["layer"]][row["channel"]], axis=1)
        model_df[f"validation_{sparsity_kind}"] = None

    return model_df

def _channels_first_flat(act):
    """Bring an activation tensor into (batch, channels, positions) layout.

    A tensor whose second dimension is smaller than its last one is treated as
    channels-last and permuted so that channels come second. This is a size
    heuristic, not a layout guarantee. 4-D tensors then get their two trailing
    dimensions flattened into one.
    """
    if act.shape[1] < act.shape[-1]:
        if act.ndim == 4:
            act = torch.permute(act, (0, 3, 1, 2)).contiguous()
        else:
            # Non-4-D inputs on this path are assumed to be 3-D.
            act = torch.permute(act, (0, 2, 1)).contiguous()
    if act.ndim == 4:
        act = act.view(*act.shape[:-2], -1)
    return act


def count_sparsity(model, layers, val_loader):
    """Measure how often each channel of the given layers is inactive (<= 0).

    Args:
        model: The model to run; activations are read through `ModelHook`.
        layers: Layer names to evaluate. A leading "visual_" is stripped to get
            the name used for hooking; the returned dictionaries are keyed by
            the names as passed in.
        val_loader: Iterable yielding 3-tuples whose first element is the input
            batch; the other two elements are ignored.

    Returns:
        A tuple `(pixel_sparsity, channel_sparsity)` of dictionaries mapping
        each layer name to a numpy array that has already been averaged over
        all images, i.e. the batch dimension is gone. For 3-D/4-D activations
        this leaves one value per channel.
        Pixel sparsity is the fraction of positions within a channel that are
        <= 0; channel sparsity is the fraction of images for which every
        position of the channel is <= 0.

    Raises:
        ValueError: If `val_loader` yields no batches while `layers` is
            non-empty, since there is nothing to average.
    """
    pixel_sparsity_values = {name: [] for name in layers}
    channel_sparsity_values = {name: [] for name in layers}

    prefix = "visual_"
    hook_layer_names = [
        name[len(prefix):] if name.startswith(prefix) else name for name in layers]

    with ModelHook(model, layer_names=hook_layer_names) as hook, torch.no_grad():
        for x, _, _ in tqdm(val_loader, leave=False):
            # The forward pass is run only for the activations the hook records.
            model(x.to(device))
            for layer_name, hook_name in zip(layers, hook_layer_names):
                inactive = _channels_first_flat(hook(hook_name)) <= 0
                pixel_sparsity_values[layer_name].append(
                    inactive.float().mean(-1).cpu().numpy())
                channel_sparsity_values[layer_name].append(
                    torch.all(inactive, -1).float().cpu().numpy())

    # Collapse the per-batch, per-image values into one mean per layer.
    for sparsity_values in (pixel_sparsity_values, channel_sparsity_values):
        for name in layers:
            if len(sparsity_values[name]) == 0:
                raise ValueError(
                    "val_loader yielded no batches; cannot average sparsity values")
            sparsity_values[name] = np.mean(np.concatenate(sparsity_values[name], 0), 0)

    return pixel_sparsity_values, channel_sparsity_values

In [None]:
# Dummy loading all models.
# We do this to see if every model can properly be loaded on this device.
# Every key of `sg_utils.accuracies` is loaded once and the result is discarded,
# so a model that cannot be loaded fails here rather than inside the sparsity
# computation further down.
# NOTE(review): this iterates over `sg_utils.accuracies`, not over the models
# present in the response dataframe - confirm both sets match.
for model_name in sg_utils.accuracies.keys():
    sg_utils.load_model(model_name)

In [None]:
# Load the responses, keep only regular trials (neither demo nor catch
# trials) and average `correct` per (model, layer, channel, mode) group.
relevant_columns = ["model", "layer", "channel", "mode"]
kept_columns = relevant_columns + ["correct"]

main_df = pd.read_pickle("responses_main.pd.pkl")
is_regular_trial = (~main_df["is_demo"]) & (~main_df["catch_trial"])
main_df = main_df[is_regular_trial]
main_df = main_df.drop(
    columns=[column for column in main_df.columns if column not in kept_columns])
main_df = main_df.groupby(relevant_columns, as_index=False).mean(numeric_only=True)
main_df = main_df.reset_index(drop=True)
main_df["channel"] = main_df["channel"].astype(int)

In [None]:
# Run the sparsity computation once per model listed in the dataframe, stack
# the per-model results into one frame and persist it.
per_model_frames = []
for model_name in tqdm(main_df["model"].unique()):
    per_model_frames.append(compute_sparsity(model_name, main_df))

updated_main_df = pd.concat(per_model_frames, axis=0, ignore_index=True)
updated_main_df = updated_main_df.reset_index(drop=True)

updated_main_df.to_pickle("responses_main_with_sparsity.pd.pkl")

In [None]:
# Reload the dataframe written by the sparsity-computation cell; presumably this
# lets the plotting cells run from the saved file without recomputing the
# sparsity values. Only unpickle files this notebook produced itself.
updated_main_df = pd.read_pickle("responses_main_with_sparsity.pd.pkl")

In [None]:
import collections

# Display names for the models, keyed by the identifier used in the "model"
# column. The insertion order determines the panel order in the figures.
model_name_lut = collections.OrderedDict([
    ("googlenet", "GoogLeNet"),
    ("densenet_201", "DenseNet"),
    ("resnet50", "ResNet"),
    ("resnet50-l2", "Robust ResNet"),
    ("clip-resnet50", "Clip ResNet"),
    ("wide_resnet50", "WideResNet"),
    ("in1k-vit_b32", "ViT"),
    ("clip-vit_b32", "Clip ViT"),
    ("convnext_b", "ConvNeXT"),
])

In [None]:
# Create the output directory for the figures. Unlike `!mkdir`, this stays
# silent when the directory already exists and does not depend on a shell.
os.makedirs("results", exist_ok=True)

In [None]:
# One figure per sparsity measure: a grid of scatter plots (one panel per
# model) of mean validation sparsity against the proportion of correct
# responses in the "natural" mode, titled with the Spearman rank correlation.
n_cols = 3
n_rows = int(np.ceil(len(model_name_lut) / n_cols))

# Panels carrying the shared axis labels: the first column of the middle row
# (y) and the middle column of the bottom row (x). For the 3x3 grid produced by
# nine models these are panels 3 and 7, i.e. the previously hard-coded ax[3]
# and ax[-2].
ylabel_panel = (n_rows // 2) * n_cols
xlabel_panel = (n_rows - 1) * n_cols + n_cols // 2

for sparsity_column, file_tag, x_label in zip(
        ("mean_validation_channel_sparsity", "mean_validation_pixel_sparsity"),
        ("channel_sparsity", "pixel_sparsity"),
        ("Channelwise Sparsity", "Pixelwise Sparsity")):
    fig, ax = plt.subplots(n_rows, n_cols, sharey=True)

    ax = ax.flatten()
    fig.set_size_inches(n_cols*4*0.9, n_rows*3*0.9)
    for i, m in enumerate(model_name_lut.keys()):
        relevant_df = updated_main_df[updated_main_df["model"] == m]
        relevant_df = relevant_df[relevant_df["mode"] == "natural"]
        stats = scipy.stats.spearmanr(relevant_df[sparsity_column], relevant_df["correct"])
        ax[i].set_title(
            f"{model_name_lut[m]} ($\\rho = {stats.statistic:.2f}, p = {stats.pvalue:.2f}$)")
        ax[i].scatter(
            relevant_df[sparsity_column], relevant_df["correct"],
            s=5, clip_on=False, color=f"C{i}")

        # Open frame: hide the top/right spines and draw the remaining two only
        # over the data range, leaving a small margin around it.
        ax[i].spines["top"].set_visible(False)
        ax[i].spines["right"].set_visible(False)

        ax[i].spines["left"].set_bounds(0.3, 1)
        ax[i].set_ylim(0.275, 1)

        ax[i].spines["bottom"].set_bounds(0.0, 1)
        ax[i].set_xlim(-0.025, 1)

        ax[i].tick_params(axis='both', which='both', labelsize=11)

    ax[ylabel_panel].set_ylabel("Proportion Correct (Natural)", fontsize=12)
    ax[xlabel_panel].set_xlabel(x_label, fontsize=12)

    # Panels beyond the number of models stay empty; hide them. This replaces
    # tagging plotted Axes with an ad-hoc `plotted` attribute.
    for unused_ax in ax[len(model_name_lut):]:
        unused_ax.axis("off")

    plt.tight_layout(rect=[0, 0.0, 1, 0.99])
    plt.savefig(f"results/imi_{file_tag}_natural.pdf", bbox_inches="tight")
    plt.show()