## Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [21]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from vidore_benchmark.evaluation.eval_manager import EvalManager
from vidore_benchmark.utils.constants import OUTPUT_DIR

# Directory that receives every figure written by this notebook.
# NOTE(review): this is created before the working directory changes below; if
# OUTPUT_DIR is a relative path the folder is created relative to the notebook's
# starting directory -- confirm OUTPUT_DIR is absolute.
RESULTS_DIR = OUTPUT_DIR / "token_pooling"
RESULTS_DIR.mkdir(exist_ok=True, parents=True)

sns.set_style("whitegrid")

# Walk up the directory tree until the working directory contains `experiments`,
# so the relative glob used to load the metrics resolves.
while "experiments" not in os.listdir():
    current_dir = Path.cwd()
    if current_dir.parent == current_dir:
        # At the filesystem root `os.chdir("..")` is a no-op, so without this
        # check the loop would spin forever when no ancestor matches.
        raise FileNotFoundError(
            "Could not find an 'experiments' directory in the starting working directory or any of its parents."
        )
    os.chdir("..")

## Load data

In [None]:
# Collect every `metrics/` directory produced by the experiment. The glob result is
# sorted so the row order of the concatenated frame does not depend on filesystem order.
metrics_paths = sorted(
    Path.cwd().glob("./experiments/2024-08-06_impact_of_pool_factor_on_retrieval/outputs/**/metrics/")
)
if not metrics_paths:
    # Without this check, `pd.concat` below fails with an opaque
    # "No objects to concatenate" error.
    raise FileNotFoundError(
        "No metrics directory found under "
        "'experiments/2024-08-06_impact_of_pool_factor_on_retrieval/outputs/'. "
        "Run the experiment before executing this notebook."
    )

# One EvalManager per metrics directory; `.melted` is the per-run frame that gets
# concatenated (later cells read its `model`, `dataset`, `metric` and `score` columns).
eval_managers = [EvalManager.from_dir(str(path)) for path in metrics_paths]
list_df = [eval_manager.melted for eval_manager in eval_managers]

df = pd.concat(list_df, ignore_index=True)
df

## Data preprocessing

In [23]:
# Extract pool factor and model name from file name.
# Guarded so the cell can be re-run: once `model` has been overwritten below, the
# pool factor can no longer be parsed from it.
if "pool_factor" not in df.columns:
    # `expand=False` returns a Series (a single capture group), not a one-column frame.
    pool_factor = df["model"].str.extract(r"pool_factor_(\d+)", expand=False)
    if pool_factor.isna().any():
        unmatched = df.loc[pool_factor.isna(), "model"].unique().tolist()
        raise ValueError(f"Could not parse a pool factor from model name(s): {unmatched}")
    df["pool_factor"] = pool_factor.astype(int)
    df["model"] = "vidore/colpali"

# Keep only metric of interest
METRIC_OF_INTEREST = "ndcg_at_5"
df = df[df["metric"] == METRIC_OF_INTEREST].copy()

# Compute relative performance: each score is divided by the score obtained at the
# smallest pool factor of the same (model, dataset) group. The transform result keeps
# the original index, so the division aligns row by row.
baseline_score = df.sort_values("pool_factor").groupby(["model", "dataset"])["score"].transform("first")
df["relative_performance"] = df["score"] / baseline_score
df["relative_performance_percent"] = df["relative_performance"] * 100

# Compute relative storage as the inverse of the pool factor
df["relative_storage"] = 1 / df["pool_factor"]
df["relative_storage_percent"] = df["relative_storage"] * 100

## Sanitize DataFrame

In [24]:
# Display labels for the columns shown on plot axes and legends
# (internal column name -> human-readable name).
column_mapping = dict(
    [
        ("model", "Model"),
        ("dataset", "Dataset"),
        ("pool_factor", "Pool Factor"),
        ("score", "NDCG@5"),
        ("relative_performance", "Relative NDCG@5"),
        ("relative_performance_percent", "Relative NDCG@5 (%)"),
        ("relative_storage", "Relative Storage"),
        ("relative_storage_percent", "Relative Storage (%)"),
    ]
)

# Renamed copy used by every plot; `df` itself is left untouched.
df_sanitized = df.rename(column_mapping, axis="columns")

## Plots

### Figure 1

In [None]:
# Absolute NDCG@5 as a function of the pool factor, one line per dataset.
fig, ax = plt.subplots(
    figsize=(9, 5),
)

# Draw on the axes created above explicitly rather than on pyplot's implicit
# "current axes".
sns.lineplot(data=df_sanitized, x="Pool Factor", y="NDCG@5", hue="Dataset", ax=ax)
ax.set_title(
    "Impact of pool factor on retrieval performance of ColPali on the ViDoRe benchmark",
    fontsize=14,
    fontweight="bold",
)

fig.tight_layout()
savepath = RESULTS_DIR / f"pool_factor_vs_{METRIC_OF_INTEREST}.png"
fig.savefig(str(savepath), bbox_inches="tight")

### Figure 1bis

In [None]:
# Mean relative NDCG@5 over all rows sharing a pool factor, tagged with constant
# `Model` / `Dataset` labels so it can be plotted next to the per-dataset lines.
df_with_mean = (
    df_sanitized.groupby("Pool Factor")[["Relative NDCG@5 (%)"]]
    .mean()
    .reset_index()
    .assign(Model="vidore/colpali", Dataset="Average")
)

df_with_mean

In [None]:
# Relative NDCG@5 per dataset (thin, semi-transparent lines) with the cross-dataset
# average overlaid as a thick red line.
fig, ax = plt.subplots(figsize=(9, 5))

# Both layers are drawn on `ax` explicitly rather than on pyplot's implicit
# "current axes".
sns.lineplot(
    data=df_sanitized,
    x="Pool Factor",
    y="Relative NDCG@5 (%)",
    hue="Dataset",
    linewidth=1.5,
    alpha=0.5,
    ax=ax,
)
sns.lineplot(
    data=df_with_mean,
    x="Pool Factor",
    y="Relative NDCG@5 (%)",
    color="red",
    linewidth=3.0,
    alpha=1,
    label="Average",
    ax=ax,
)
ax.legend(title="Dataset")
ax.set_title(
    "Impact of pool factor on relative retrieval performance of ColPali\non the ViDoRe benchmark",
    fontsize=14,
    fontweight="bold",
)

fig.tight_layout()
savepath = RESULTS_DIR / f"pool_factor_vs_relative_{METRIC_OF_INTEREST}.png"
fig.savefig(str(savepath), bbox_inches="tight")

### Figure 2

In [None]:
# Average the relative score and relative storage per (model, pool factor).
averaged_columns = ["Relative NDCG@5 (%)", "Relative Storage (%)"]
df_sanitized_agg = df_sanitized.groupby(["Model", "Pool Factor"])[averaged_columns].mean().reset_index()

# Set relative performance to 100 for pool factor 1
is_pool_factor_one = df_sanitized_agg["Pool Factor"] == 1
df_sanitized_agg.loc[is_pool_factor_one, "Relative NDCG@5 (%)"] = 100

df_sanitized_agg

In [None]:
# Storage / performance trade-off: one point per pool factor, sized by the pool factor.
fig, ax = plt.subplots(figsize=(6, 5))

# Everything is drawn on `ax` explicitly rather than through pyplot's implicit
# "current axes".
sns.scatterplot(
    data=df_sanitized_agg,
    x="Relative Storage (%)",
    y="Relative NDCG@5 (%)",
    size="Pool Factor",
    ax=ax,
)
ax.set_title(
    "Trade-off between relative storage and retrieval performance\nfor ColPali on the ViDoRe benchmark",
    fontsize=14,
    fontweight="bold",
)

# Move the legend outside of the plot
sns.move_legend(ax, loc="center left", bbox_to_anchor=(1, 0.5))

# Reference lines at 100% storage and 100% relative performance
ax.axvline(x=100, color="red", linestyle="--")
ax.axhline(y=100, color="red", linestyle="--")

fig.tight_layout()
savepath = RESULTS_DIR / f"storage_vs_{METRIC_OF_INTEREST}.png"

fig.savefig(str(savepath), bbox_inches="tight")

## Paper version

In [30]:
# Short dataset names used in the paper figures.
ds_clean_mapping = {
    "vidore/shiftproject_test": "Shift",
    "vidore/infovqa_test_subsampled": "InfoVQA",
    "vidore/syntheticDocQA_energy_test": "Energy",
    "vidore/tabfquad_test_subsampled": "TabFQuad",
    "vidore/docvqa_test_subsampled": "DocVQA",
    "vidore/arxivqa_test_subsampled": "ArxivQA",
}


df_sanitized_clean = df_sanitized.copy()
# Datasets without a short name keep their original identifier. A plain
# `.map(ds_clean_mapping)` would turn them into NaN labels, while the average
# computed from `df_sanitized` would still include them.
df_sanitized_clean["Dataset"] = df_sanitized_clean["Dataset"].map(lambda name: ds_clean_mapping.get(name, name))

In [None]:
# Paper version of the relative NDCG@5 plot: per-dataset lines with markers, plus the
# cross-dataset average as a dashed black line.
fig, ax = plt.subplots(figsize=(5, 3))

sns.lineplot(
    data=df_sanitized_clean,
    x="Pool Factor",
    y="Relative NDCG@5 (%)",
    hue="Dataset",
    linewidth=1,
    alpha=0.5,
    marker="o",
    linestyle="-",
    ax=ax,
)

sns.lineplot(
    data=df_with_mean,
    x="Pool Factor",
    y="Relative NDCG@5 (%)",
    color="black",
    linewidth=2.0,
    alpha=1,
    linestyle="--",
    label="Average",
    ax=ax,
)

ax.legend(title="Dataset")

ax.set_xlabel("Pool Factor", fontsize=14)
ax.set_ylabel("Relative NDCG@5 (%)", fontsize=14)

# Create the output directory before saving; `savefig` does not create missing
# parent directories, so a fresh run would otherwise fail here.
(RESULTS_DIR / "paper_version").mkdir(exist_ok=True, parents=True)

savepath = RESULTS_DIR / "paper_version" / "token_pooling_ndcg.png"
fig.savefig(str(savepath), bbox_inches="tight")

savepath = RESULTS_DIR / "paper_version" / "token_pooling_ndcg.pdf"
fig.savefig(str(savepath), bbox_inches="tight")

In [None]:
# Paper version of the storage / performance trade-off scatter plot.
fig, ax = plt.subplots(figsize=(3, 3))

sns.scatterplot(
    data=df_sanitized_agg,
    x="Relative Storage (%)",
    y="Relative NDCG@5 (%)",
    size="Pool Factor",
    ax=ax,
    legend="full",
)

# Re-anchor the legend: its centre-left edge is placed at (0.5, 0.4) in axes
# coordinates, i.e. inside the plotting area.
sns.move_legend(ax, loc="center left", bbox_to_anchor=(0.5, 0.4))

# Add vertical and horizontal lines at 100%
ax.axvline(x=100, color="lightgray", linestyle="--", zorder=0)
ax.axhline(y=100, color="lightgray", linestyle="--", zorder=0)

ax.set_xlabel("Relative storage (%)", fontsize=14)
ax.set_ylabel("Relative nDCG@5 (%)", fontsize=14)

# Adjust layout to prevent overlap
fig.tight_layout()

# Create the output directory before saving; `savefig` does not create missing
# parent directories, so a fresh run would otherwise fail here.
(RESULTS_DIR / "paper_version").mkdir(exist_ok=True, parents=True)

savepath = RESULTS_DIR / "paper_version" / "token_pooling_relative_storage.png"
fig.savefig(str(savepath), bbox_inches="tight")

savepath = RESULTS_DIR / "paper_version" / "token_pooling_relative_storage.pdf"
fig.savefig(str(savepath), bbox_inches="tight")

In [None]:
# Create a figure with two subplots side by side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3), gridspec_kw={"width_ratios": [2, 1]})

# -----------------------------
# First Plot: Line Plots on ax1
# -----------------------------

sns.lineplot(
    data=df_sanitized_clean,
    x="Pool Factor",
    y="Relative NDCG@5 (%)",
    hue="Dataset",
    linewidth=1,
    alpha=0.5,
    marker="o",
    linestyle="-",
    ax=ax1,
)

sns.lineplot(
    data=df_with_mean,
    x="Pool Factor",
    y="Relative NDCG@5 (%)",
    color="black",
    linewidth=2.0,
    alpha=1,
    linestyle="--",
    label="Average",
    ax=ax1,
)

ax1.legend(title="Dataset")

ax1.set_xlabel("Pool Factor", fontsize=14)
ax1.set_ylabel("Relative NDCG@5 (%)", fontsize=14)

# -------------------------------
# Second Plot: Scatter Plot on ax2
# -------------------------------

sns.scatterplot(
    data=df_sanitized_agg,
    x="Relative Storage (%)",
    y="Relative NDCG@5 (%)",
    size="Pool Factor",
    ax=ax2,
    legend="full",
)

# Move the legend outside of the plot
sns.move_legend(ax2, loc="center left", bbox_to_anchor=(0.6, 0.4))

# Add vertical and horizontal lines at 100%
ax2.axvline(x=100, color="lightgray", linestyle="--", zorder=0)
ax2.axhline(y=100, color="lightgray", linestyle="--", zorder=0)

ax2.set_xlabel("Relative storage (%)", fontsize=14)
ax2.set_ylabel("Relative nDCG@5 (%)", fontsize=14)

# Adjust layout to prevent overlap
fig.tight_layout()

# Create the directory if it doesn't exist
(RESULTS_DIR / "paper_version").mkdir(exist_ok=True, parents=True)

# Save the figure
savepath = RESULTS_DIR / "paper_version" / "token_pooling.png"
fig.savefig(str(savepath), bbox_inches="tight")

savepath = RESULTS_DIR / "paper_version" / "token_pooling.pdf"
fig.savefig(str(savepath), bbox_inches="tight")