## Analysis of the relationship between local contrast and IMI score

In [None]:
%cd ../..
from stimuli_generation import utils as sg_utils
%cd tools/data_analysis

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from lucent.optvis.hooks import ModelHook
import torch

from tqdm.notebook import tqdm as tqdm
from torch.cuda.amp import autocast
import torch.nn.functional as F

import pandas as pd
import scipy
from scipy import stats
import os

In [None]:
from matplotlib import rcParams

# Use DejaVu Sans for all text, and keep text as real text (not converted to
# paths) when figures are exported as SVG or PDF.
rcParams.update(
    {
        "font.family": "sans-serif",
        "font.sans-serif": ["DejaVu Sans"],
        "svg.fonttype": "none",
        "pdf.fonttype": "truetype",
    }
)

# Plot colours as [r, g, b] lists with components in [0, 1]; they are written
# down here as 0-255 integer triples and scaled by 1/255.
colors = {
    name: [channel / 255 for channel in rgb]
    for name, rgb in [
        ("synthetic", (71, 120, 158)),
        ("natural", (255, 172, 116)),

        # natural conditions, easy ... very hard
        ("natural easy", (255, 172, 116)),
        ("natural medium", (197, 135, 96)),
        ("natural hard", (150, 105, 75)),
        ("natural very hard", (109, 75, 52)),

        # synthetic conditions
        # NOTE(review): unlike the natural series there is no "synthetic hard"
        # entry — confirm this is intentional.
        ("synthetic easy", (71, 120, 158)),
        ("synthetic medium", (52, 81, 105)),
        ("synthetic very hard", (39, 61, 79)),

        ("natural low", (255, 172, 116)),
        ("natural high", (146, 100, 71)),

        ("synthetic low", (71, 120, 158)),
        ("synthetic high", (39, 61, 79)),

        # colour ramp c0 ... c8
        ("c0", (245, 181, 121)),
        ("c1", (244, 170, 113)),
        ("c2", (203, 147, 94)),
        ("c3", (196, 134, 91)),
        ("c4", (150, 108, 68)),
        ("c5", (142, 98, 67)),
        ("c6", (116, 72, 44)),
        ("c7", (103, 76, 54)),
        ("c8", (88, 73, 59)),
    ]
}

In [None]:
# ImageNet validation split whose images are used to measure activation contrast.
IMAGENET_VAL_PATH = os.path.join(sg_utils.IMAGENET_PATH, "val")
# Fail early with a message that names the missing directory (a bare assert
# gives no hint and is skipped under `python -O`).
if not os.path.exists(IMAGENET_VAL_PATH):
    raise FileNotFoundError(f"ImageNet validation directory not found: {IMAGENET_VAL_PATH}")

# Directory receiving the per-model and combined result pickles.
output_dir = "/mnt/qb/work/bethge/tklein16/int_comp_out/contrast/"
os.makedirs(output_dir, exist_ok=True)

In [None]:
@torch.no_grad()
def compute_batch_contrast_torch(acts):
    """
    Average local contrast of every unit's activation map, computed on the tensor's device.

    The local contrast at a spatial position is the sum of absolute differences
    between its value and its four direct neighbours (left, right, up, down).
    Positions on the border of the map are compared against a frame of zeros.
    The result is that quantity averaged over all spatial positions.

    :param acts: tensor containing activations for a batch and layer, shape (bs x n_units x h x w)
    :returns: tensor of average local contrasts of shape (bs x n_units)
    """
    assert acts.dim() == 4, "Expected a 4d tensor!"

    # Surround every map with a one-pixel frame of zeros (last two dims).
    padded = F.pad(acts, (1, 1, 1, 1), "constant", 0)

    # Absolute difference to the neighbour obtained by shifting the map by one
    # pixel along width, then height. The wrap-around of torch.roll only touches
    # the zero frame, which is cut off below.
    shifted_diffs = (
        torch.abs(torch.roll(padded, shifts=step, dims=dim) - padded)
        for dim in (-1, -2)
        for step in (1, -1)
    )
    total = torch.zeros_like(padded)
    for diff in shifted_diffs:
        total += diff

    # Drop the frame again and average over the spatial positions.
    return total[:, :, 1:-1, 1:-1].mean(dim=(-2, -1))


def compute_local_contrast(df_model_name: str, main_df: pd.DataFrame) -> pd.DataFrame:
    """
    Compute the mean local contrast of a model's activations for its units in main_df.

    Args:
        df_model_name: The model name as used in the "model" column of main_df.
            The suffixes "_hard85"/"_hard95" are removed to obtain the name of the
            network (and data loader) that is actually loaded.
        main_df: The main dataframe to use; needs "model", "layer" and "channel"
            columns.

    Returns:
        A copy of the rows of main_df whose "model" equals df_model_name (with a
        fresh 0..n-1 index), with an added "mean_validation_contrast" column that
        holds, for each row's (layer, channel), the local contrast averaged over
        the validation set.
    """
    print(f"Calculating local contrast for model {df_model_name}")
    model_name = df_model_name.replace("_hard85", "").replace("_hard95", "")

    # Select rows by the dataframe name (not the stripped network name) so that a
    # "_hard" variant is annotated with its own units, not with the base model's.
    model_df = main_df[main_df["model"] == df_model_name].copy().reset_index(drop=True)
    if model_df.empty:
        # Nothing to annotate: skip loading the model and iterating the dataset.
        model_df["mean_validation_contrast"] = pd.Series(dtype=float)
        return model_df

    val_loader = sg_utils.get_dataloader(IMAGENET_VAL_PATH, model_name, batch_size=64)
    model = sg_utils.load_model(model_name)

    # Some models wrap their image encoder in a `visual` attribute; hook that part.
    if hasattr(model, "visual"):
        model = model.visual
    # Run normalization/dropout layers in inference mode; this is a no-op if
    # load_model already put the model into eval mode.
    model.eval()

    model_layers = list(model_df.layer.unique())
    channel_local_contrast = calc_contrast(model_name, model, model_layers, val_loader)

    model_df["mean_validation_contrast"] = [
        channel_local_contrast[layer][channel]
        for layer, channel in zip(model_df["layer"], model_df["channel"])
    ]
    return model_df


def calc_contrast(model_name, model, layers, val_loader):
    """
    Calculate the mean local contrast of every channel in the given layers.

    The model is run on every batch of val_loader. For each requested layer, the
    per-image, per-channel local contrast (see compute_batch_contrast_torch) is
    accumulated and finally averaged over all images.

    Args:
        model_name: The name of the model; passed to sg_utils.test_permute to
            decide whether a layer's activations must be permuted to (b, c, h, w).
        model: The model for which contrast is calculated.
        layers: Names of the layers for which contrast is calculated. A leading
            "visual_" prefix is stripped before the layer is hooked.
        val_loader: Data loader yielding (images, _, _) triples.

    Returns:
        A dict mapping each name in layers to a 1-d numpy array with the mean
        local contrast of each channel, averaged over all images.

    Raises:
        ValueError: If no images were processed (e.g. val_loader is empty).
    """
    hook_layer_names = [l if not l.startswith("visual_") else l[len("visual_"):] for l in layers]

    # Running per-channel sums (float64 to limit rounding error) and image counts
    # keep memory at O(channels) per layer instead of O(images x channels).
    contrast_sums = {}
    image_counts = dict.fromkeys(layers, 0)
    contrast_dtypes = {}

    with ModelHook(model, layer_names=hook_layer_names) as hook:
        with torch.no_grad():
            for x, _, _ in tqdm(val_loader):
                x = x.to(sg_utils.DEVICE)
                model(x)
                for ln, hln in zip(layers, hook_layer_names):
                    act = hook(hln)

                    # some layers need to be permuted from (b,h,w,c) to the canonical (b,c,h,w)
                    if sg_utils.test_permute(model_name, ln):
                        act = torch.permute(act, (0, 3, 1, 2)).contiguous()

                    channel_contrast = compute_batch_contrast_torch(act)
                    batch_sum = channel_contrast.sum(dim=0, dtype=torch.float64)
                    if ln in contrast_sums:
                        contrast_sums[ln] += batch_sum
                    else:
                        contrast_sums[ln] = batch_sum
                    image_counts[ln] += channel_contrast.shape[0]
                    contrast_dtypes[ln] = channel_contrast.dtype

    channel_contrast_values = {}
    for layer in layers:
        if image_counts[layer] == 0:
            raise ValueError(
                f"No images were processed for layer {layer!r}; is val_loader empty?"
            )
        mean = contrast_sums[layer] / image_counts[layer]
        # Cast back to the activations' contrast dtype so the result matches the
        # dtype of the per-batch contrast values.
        channel_contrast_values[layer] = mean.to(contrast_dtypes[layer]).numpy(force=True)

    return channel_contrast_values

In [None]:
# Load the per-trial responses, discard demo and catch trials, and average the
# remaining numeric column ("correct") per (model, layer, channel, mode).
# A pickled copy could be read instead via pd.read_pickle("responses_main.pd.pkl").
relevant_columns = ["model", "layer", "channel", "mode"]
main_df = (
    pd.read_csv("/mnt/qb/bethge/shared/interpretability_comparison/IMI/responses_main.csv")
    .loc[lambda df: ~df["is_demo"] & ~df["catch_trial"]]
    .filter(items=relevant_columns + ["correct"])
    .groupby(relevant_columns, as_index=False)
    .mean(numeric_only=True)
    .reset_index(drop=True)
)
main_df["channel"] = main_df["channel"].astype(int)

In [None]:
# investigate all models that are not ViTs, and ignore the 'hard' models, as those are redundant
relevant_models = [m for m in main_df["model"].unique() if ('vit' not in m and 'hard' not in m)]
print("Will analyze models", relevant_models)

In [None]:
# calculate contrast for all models, write each one to its own pickle file
for model_name in relevant_models:
    model_contrast_df = compute_local_contrast(model_name, main_df)
    model_contrast_df.to_pickle(
        os.path.join(
            output_dir,
            f"local_contrast_{model_name}.pd.pkl"
        )
    )

In [None]:
# read in the pickle-files, concat them to final df
updated_main_df = pd.concat(
    [pd.read_pickle(
        os.path.join(output_dir, f"local_contrast_{model_name}.pd.pkl")
    ) for model_name in relevant_models],
    axis=0, ignore_index=True
).reset_index(drop=True)

updated_main_df.to_pickle(
    os.path.join(
        output_dir,
        "responses_main_with_local_contrast.pd.pkl"
    )
)

In [None]:
# read in the main-df from pickle
updated_main_df = pd.read_pickle(
    os.path.join(
        output_dir,
        "responses_main_with_local_contrast.pd.pkl"
    )
)
print("Read main-df with models:", updated_main_df.model.unique())

In [None]:
import collections

# Display names of the models (dataframe name -> label), in plotting order.
model_name_lut = collections.OrderedDict(
    [
        ("googlenet", "GoogLeNet"),
        ("densenet_201", "DenseNet"),
        ("resnet50", "ResNet"),
        ("resnet50-l2", "Robust ResNet"),
        ("clip-resnet50", "Clip ResNet"),
        ("wide_resnet50", "WideResNet"),
        ("convnext_b", "ConvNeXT"),
    ]
)

In [None]:
# Output directory for the figures; exist_ok makes re-running the cell harmless.
os.makedirs("results", exist_ok=True)

In [None]:

# Per model, scatter the mean local contrast of each unit's activation maps
# against the unit's proportion of correct answers on natural-image trials.
k = "mean_validation_contrast"  # dataframe column holding the contrast values
s = "local_contrast"  # tag used in the output file name

n_cols = 3
n_rows = int(np.ceil(len(model_name_lut) / n_cols))
fig, ax = plt.subplots(n_rows, n_cols, sharey=True)
ax = ax.flatten()
fig.set_size_inches(n_cols*4*0.9, n_rows*3*0.9)
keys = list(model_name_lut.keys())
# Placeholder before the last model: leaves the lower-left panel empty so that
# the last model's panel ends up centered in the bottom row.
keys.insert(-1, 'none')
last_idx = len(keys) - 1  # panel index of the last model
for i, m in enumerate(keys):
    if m == "none":
        continue
    relevant_df = updated_main_df[updated_main_df["model"] == m]
    relevant_df = relevant_df[relevant_df["mode"] == "natural"]
    if relevant_df.empty:
        # No data for this model: leave its panel blank (switched off below)
        # instead of failing on np.max of an empty column.
        continue
    # Named `spearman`, not `stats`, so the `stats` module imported from scipy
    # at the top of the notebook is not shadowed.
    spearman = scipy.stats.spearmanr(relevant_df[k], relevant_df["correct"])
    ax[i].set_title(
        f"""{model_name_lut[m]} ($\\rho = {spearman.correlation:.2f}, p = {spearman.pvalue:.2f}$)""")
    # NOTE(review): panel 7 uses C8 instead of C7, presumably to avoid the grey
    # C7 of the default colour cycle.
    ax[i].scatter(relevant_df[k], relevant_df["correct"], s=5, clip_on=False, color=f"C{i}" if i != 7 else "C8")

    # Marker used below to switch off all panels that received no plot.
    ax[i].plotted = True
    ax[i].spines["top"].set_visible(False)
    ax[i].spines["right"].set_visible(False)

    ax[i].spines["left"].set_bounds(0.3, 1)
    ax[i].set_ylim(0.275, 1)

    x_max = np.max(relevant_df[k])
    ax[i].spines["bottom"].set_bounds(0.0, x_max)
    ax[i].set_xlim(-0.025, x_max)

    ax[i].tick_params(axis='both', which='both', labelsize=11)

ax[3].set_ylabel("Proportion Correct (Natural)", fontsize=12)
ax[last_idx].set_xlabel("Local Contrast in Activation Maps", fontsize=12)
for a in ax:
    if not hasattr(a, "plotted"):
        a.axis("off")

# With sharey=True, y tick labels are only drawn in the first column. The last
# panel is centered in the bottom row, so turn its tick labels back on instead
# of placing text at hard-coded data coordinates.
ax[last_idx].yaxis.set_tick_params(which="both", labelleft=True)
plt.tight_layout(rect=[0, 0.0, 1, 0.99])
plt.savefig(f"results/imi_{s}.pdf", bbox_inches="tight")
plt.show()
