In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Base colour associated with each input type.
palette = {"inverted": "plum", "standard": "darkturquoise"}

# Root directory that holds one sub-directory of result CSVs per dataset.
results_base_dir = "/root/univ-data/results"

# Epsilon tags as they appear in the CSV file names.
_EPS_TAGS = ("eps0", "eps025", "eps05", "eps1", "eps3")
# The CIFAR-10 entries list every tag except "eps3".
_EPS_TAGS_CIFAR10 = _EPS_TAGS[:-1]


def _jaccard_paths(subdir, eps_tags):
    """Build the ``{"jaccard": {eps_tag: csv_path}}`` entry for one results sub-directory."""
    return {
        "jaccard": {
            tag: f"{results_base_dir}/{subdir}/jac_{tag}_10.csv" for tag in eps_tags
        }
    }


# Layout: dataset -> input type -> measure -> eps tag -> CSV path.
# "standard" files sit in a "standard/" sub-directory of the dataset folder,
# "inverted" files sit directly in the dataset folder.
# Insertion order is kept as-is because it determines the load order downstream.
results = {
    "imagenet100": {
        "standard": _jaccard_paths("imagenet100/standard", _EPS_TAGS),
        "inverted": _jaccard_paths("imagenet100", _EPS_TAGS),
    },
    "imagenet1k": {
        "inverted": _jaccard_paths("imagenet1k", _EPS_TAGS),
        "standard": _jaccard_paths("imagenet1k/standard", _EPS_TAGS),
    },
    "cifar10": {
        "inverted": _jaccard_paths("cifar10", _EPS_TAGS_CIFAR10),
        "standard": _jaccard_paths("cifar10/standard", _EPS_TAGS_CIFAR10),
    },
}


# Global seaborn look for every figure in this notebook: "paper" context,
# dark grid background, fonts enlarged by 50%.
sns.set_theme(context="paper", style="darkgrid", font_scale=1.5)


def eps_to_float(eps: str) -> float:
    """Translate an epsilon tag from the result file names into its numeric value.

    Args:
        eps: One of "eps0", "eps025", "eps05", "eps1" or "eps3".

    Returns:
        The matching float: 0.0, 0.25, 0.5, 1.0 or 3.0.

    Raises:
        ValueError: If ``eps`` is not one of the known tags.
    """
    known_levels = (
        ("eps0", 0.0),
        ("eps025", 0.25),
        ("eps05", 0.5),
        ("eps1", 1.0),
        ("eps3", 3.0),
    )
    for tag, value in known_levels:
        if eps == tag:
            return value
    raise ValueError(f"Eps level not recognized: {eps}")

# Load every available score matrix and stack them into one long-format frame
# with columns: model1, model2, score, measure, eps, dataset, input, k.
dfs = []
for dataset, dataset_results in results.items():
    for input_type, input_type_results in dataset_results.items():
        for measure, path_dict in input_type_results.items():
            for eps, csv_path in path_dict.items():
                # The configured path names the k=10 file; everything before the
                # trailing "_<k>.csv" is shared by all neighbourhood sizes.
                path_stem = csv_path.rsplit("_", 1)[0]
                for k in [10, 100, 500]:
                    # The files on disk carry k formatted as a float, e.g. "..._10.0.csv".
                    modified_csv_path = f"{path_stem}_{k}.0.csv"
                    try:
                        df = pd.read_csv(modified_csv_path, index_col=0)
                    except FileNotFoundError as e:
                        # Best effort: report the missing file and carry on with the rest.
                        print(e)
                        continue

                    # Mask out comparisons between identical models, so we can drop them easier later
                    for index_val in df.index:
                        # np.nan rather than np.NaN: the capitalised alias was removed in NumPy 2.0.
                        df.loc[index_val, index_val] = np.nan

                    # Give name to index
                    df = df.rename_axis("model1")

                    # Bring data to long format (many rows) and remove self comparisons
                    df = df.reset_index().melt(id_vars="model1", var_name="model2", value_name="score")
                    df = df.dropna(axis=0)

                    df["measure"] = measure
                    df["eps"] = eps_to_float(eps)
                    df["dataset"] = dataset
                    df["input"] = input_type
                    df["k"] = k

                    # Rescale the "proc", "dis" and "jsd" scores with the
                    # measure-specific transforms below.
                    # NOTE(review): presumably these turn distances into similarities - confirm.
                    if measure == "proc":
                        df.loc[:, "score"] = (2 - df.loc[:, "score"]) / 2

                    if measure == "dis":
                        df.loc[:, "score"] = 1 - df.loc[:, "score"]

                    if measure == "jsd":
                        df.loc[:, "score"] = (np.log(2) - df.loc[:, "score"]) / np.log(2)

                    dfs.append(df)

if not dfs:
    # pd.concat([]) would raise an opaque "No objects to concatenate" ValueError.
    raise ValueError(f"No result CSVs could be read below {results_base_dir}")

data = pd.concat(dfs, axis=0)

# Exclude tiny_vit_5m results on ImageNet100 and 50, because we mistakenly used checkpoints for IN1k.
# Both sides of the comparison are checked against the same model name.
involves_tiny_vit = (data.model1 == "tiny_vit_5m") | (data.model2 == "tiny_vit_5m")
data = data.loc[~(involves_tiny_vit & data.dataset.isin(["imagenet100", "imagenet50"]))]

# Exclude broken densenet161 on IN1k
involves_densenet = (data.model1 == "densenet161") | (data.model2 == "densenet161")
data = data.loc[~(involves_densenet & data.dataset.isin(["imagenet1k"]))]


data.head()

In [None]:
# Overview of all loaded scores: box plots over neighbourhood size k, coloured
# by eps, with one column per dataset and one row per input type.
sns.catplot(
    data=data,
    x="k",
    y="score",
    hue="eps",
    row="input",
    col="dataset",
    kind="box",
)

In [None]:
from matplotlib.colors import Normalize, LinearSegmentedColormap


# One figure per input type: three side-by-side box plots (one per dataset) of
# the score over neighbourhood size k, coloured by eps. Each figure is written
# to ../figs/jaccard_ks_<input_type>.pdf.
datasets = ["cifar10", "imagenet1k", "imagenet100"]

# Panels from left to right: (value in the relabelled "dataset" column, panel title).
panels = [
    ("ImageNet1k", "ImageNet1k"),
    ("ImageNet100", "ImageNet100"),
    ("CIFAR 10", "CIFAR-10"),
]
n_panels = len(panels)

# Per input type: base colour of the eps shading and the y-axis label, which
# is shown on the left-most panel only.
base_colors = {"standard": "darkturquoise", "inverted": "deeppink"}
ylabels = {
    "inverted": "Mechanistic Jaccard Similarity",
    "standard": "Regular Jaccard Similarity",
}

# The eps -> colour scale is derived from *all* loaded data, so an eps level
# gets the same shade and the same slot in every panel, even when a dataset
# has no results for it.
norm = Normalize(vmin=data['eps'].min(), vmax=data['eps'].max())
hue_order = list(sorted(data['eps'].unique()))

for input_type in ["inverted", "standard"]:
    plotdata = data.loc[(data.input==input_type) & data.dataset.isin(datasets)].copy()
    plotdata["dataset"] = plotdata["dataset"].map({"imagenet1k": "ImageNet1k", "cifar10": "CIFAR 10", "imagenet100": "ImageNet100"})

    # Light-to-saturated ramp of the input type's base colour.
    cmap = sns.light_palette(base_colors[input_type], as_cmap=True)
    # Resampled over the full 0.0-1.0 range, i.e. nothing is cut off despite the
    # "truncated" name; narrow the linspace bounds to drop the extremes.
    truncated_cmap = LinearSegmentedColormap.from_list(
        'truncated_viridis', cmap(np.linspace(0.0, 1.0, 256))
    )
    # One colour per eps level, in the same order as hue_order.
    palette = [truncated_cmap(norm(eps)) for eps in hue_order]

    fig, axes = plt.subplots(1, n_panels, figsize=(n_panels*3*1.61, 3))

    for ax_idx, (ax, (dataset_label, title)) in enumerate(zip(axes, panels)):
        subplotdata = plotdata[plotdata.dataset == dataset_label]
        sns.boxplot(data=subplotdata, x="k", y="score", hue="eps", ax=ax, legend=False, palette=palette, hue_order=hue_order)
        ax.set_title(title)
        # Only the left-most panel carries the y-axis label.
        ax.set_ylabel(ylabels[input_type] if ax_idx == 0 else "")
        ax.set_xlabel("Neighborhood size k")

    fig.savefig(f"../figs/jaccard_ks_{input_type}.pdf", bbox_inches="tight")