### Ablation of finetuning vs. probing


In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sys

import matplotlib.pyplot as plt
import seaborn as sns

sys.path.append('..')
sys.path.append('../..')

from constants import BASE_PATH_PROJECT, FOLDER_SUBSTRING
from helper import init_plotting_params, save_or_show

In [2]:
init_plotting_params()

{
  "agg.path.chunksize": 0,
  "axes.labelsize": 13.0,
  "axes.titlesize": 14.0,
  "axes3d.trackballsize": 0.667,
  "boxplot.flierprops.markersize": 6.0,
  "boxplot.meanprops.markersize": 6.0,
  "errorbar.capsize": 0.0,
  "figure.figsize": [
    6.4,
    4.8
  ],
  "figure.labelsize": "large",
  "figure.titlesize": "large",
  "font.cursive": [
    "Apple Chancery",
    "Textile",
    "Zapf Chancery",
    "Sand",
    "Script MT",
    "Felipa",
    "Comic Neue",
    "Comic Sans MS",
    "cursive"
  ],
  "font.family": [
    "sans-serif"
  ],
  "font.fantasy": [
    "Chicago",
    "Charcoal",
    "Impact",
    "Western",
    "xkcd script",
    "fantasy"
  ],
  "font.monospace": [
    "DejaVu Sans Mono",
    "Bitstream Vera Sans Mono",
    "Computer Modern Typewriter",
    "Andale Mono",
    "Nimbus Mono L",
    "Courier New",
    "Courier",
    "Fixed",
    "Terminal",
    "monospace"
  ],
  "font.sans-serif": [
    "DejaVu Sans",
    "Bitstream Vera Sans",
    "Computer Modern Sans Serif

In [None]:
SAVE = "both"

base_storing_path = (
    BASE_PATH_PROJECT / f"results_{FOLDER_SUBSTRING}_rebuttal/plots/finetuning"
)
if SAVE:
    base_storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
all_runs = pd.read_pickle(
    BASE_PATH_PROJECT
    / f"results_{FOLDER_SUBSTRING}_end2end_finetuning/aggregated/complete_set_of_run.pkl"
)
print(all_runs.shape)

(2044, 87)


Filter experiment runs


In [None]:
all_runs = (
    all_runs[~all_runs["base_model"].str.startswith("mae")]
    .copy()
    .reset_index(drop=True)
)

In [None]:
allowed_datasets = ["wds/fer2013", "wds/gtsrb", "wds/vtab/cifar100", "wds/vtab/eurosat"]
allowed_experiments = [
    "CLS last layer",
    "All tokens last layer (attentive)",
    "CLS last layer (finetuning)",
    "CLS+AP layers from all blocks (attentive)",
]

In [None]:
all_runs = (
    all_runs[
        all_runs["dataset"].isin(allowed_datasets)
        & all_runs["Experiment"].isin(allowed_experiments)
    ]
    .copy()
    .reset_index(drop=True)
)
all_runs = (
    all_runs.drop(
        index=all_runs[
            (all_runs["nr_layers"] == 1) & all_runs["contains_intermediate"]
        ].index
    )
    .copy()
    .reset_index(drop=True)
)
all_runs = (
    all_runs[all_runs["probe_type"].isin(["cae", "linear"])]
    .copy()
    .reset_index(drop=True)
)
all_runs = all_runs[all_runs["model_size"] == "base"].copy().reset_index(drop=True)

In [None]:
all_runs["Experiment"].value_counts().sort_index()

Experiment
All tokens last layer (attentive)            12
CLS last layer                               12
CLS last layer (finetuning)                  12
CLS+AP layers from all blocks (attentive)    12
Name: count, dtype: int64

In [None]:
all_runs.loc[all_runs["Experiment"] == "CLS last layer (finetuning)", "Experiment"] = (
    "Finetuning (CLS last layer)"
)

In [None]:
all_runs["Experiment"] = all_runs["Experiment"].map(
    {
        "AP last layer": "Last layer (AP, linear)",
        "CLS last layer": "Last layer (CLS, linear)",
        "Finetuning (CLS last layer)": "Finetuning",
        "All tokens last layer (attentive)": "Last layer (all tokens, attentive)",
        "CLS+AP last layer (attentive)": "Last layer (CLS+AP, attentive)",
        "CLS+AP layers from all blocks (attentive)": "All layers (CLS+AP, attentive)",
    }
)

In [None]:
hue_order = [
    "Last layer (CLS, linear)",
    "Last layer (all tokens, attentive)",
    "Finetuning",
    "All layers (CLS+AP, attentive)",
]

In [None]:
tab20c = plt.cm.tab20c.colors
palette_list = list(tab20c[:8])
reversed_palette = []
for group_start in [0, 4]:
    group = palette_list[group_start : group_start + 4]
    reversed_group = group[::-1]  # reverse the group
    reversed_palette.extend(reversed_group)

reversed_palette = [tab20c[17]] + reversed_palette

colors = [
    palette_list[1],
    tab20c[17],
    tab20c[18],
    palette_list[5],
]

g = sns.catplot(
    all_runs,
    y="test_lp_bal_acc1",
    x="base_model_fmt",
    col="dataset_fmt",
    hue="Experiment",
    hue_order=hue_order,
    kind="bar",
    sharey=False,
    palette=colors,
    col_wrap=4,
    # col_order=['FER2013', 'GTSRB', 'CIFAR-100', 'EuroSAT'],
    col_order=["GTSRB", "CIFAR-100", "EuroSAT"],
    gap=0,
)
g.set_titles("{col_name}", size=12)
g.set_xlabels("Model", size=11)
g.set_ylabels("Balanced Accuracy", size=11)

for ax in g.axes.flatten():
    for p in ax.patches:
        height = p.get_height()
        x = p.get_x() + p.get_width() / 2
        y = height
        if y == 0:
            continue

        ax.annotate(
            f"{height:.2f}",  # number formatting
            (x, y),
            ha="center",
            va="bottom",
            xytext=(0, 3),
            textcoords="offset points",
            fontsize=9.5,
        )

# sns.move_legend(g, loc="upper center", bbox_to_anchor=(0.475, 0), title='', ncols=4)
sns.move_legend(g, loc="upper center", bbox_to_anchor=(0.4, 0), title="", ncols=4)
g.fig.tight_layout()
save_or_show(
    g.fig, base_storing_path / "performance_overview_wo_fer.pdf", SAVE, show_path=False
)

stored img at.


In [None]:
all_runs["training_time"] = all_runs["training_time"].astype(float) / 60

#### Training time plot


In [None]:
order = all_runs.groupby("Experiment")["training_time"].describe().sort_values("mean")

In [None]:
# Set figure size
plt.figure(figsize=(10, 6))

# Create barplot without redundant hue
g = sns.barplot(
    data=all_runs,
    x="Experiment",
    y="training_time",
    order=order.index.tolist(),
    palette="viridis",  # or 'Set2', 'husl', etc.
)

for container in g.containers:
    g.bar_label(container, fmt="%.1f", padding=15, fontsize=10)

g.set_xlabel("")
g.set_ylabel("Training Time (min)")
g.set_title("Training Time by Experiment")

g.set_xticklabels(
    ["\n(".join(val.get_text().split(" (")) for val in g.get_xticklabels()]
)

g.set_yscale("log")
plt.grid(axis="y", alpha=0.3, linestyle="--")
plt.tight_layout()
save_or_show(plt.gcf(), base_storing_path / "training_times.pdf", SAVE, show_path=False)


Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  g = sns.barplot(
  g.set_xticklabels(["\n(".join(val.get_text().split(" (")) for val in g.get_xticklabels()]);


stored img at.
