In [None]:
import os
import sys
import shapiq
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import warnings
# Suppress every warning for the rest of the session to keep cell output clean.
# NOTE(review): this also hides deprecation warnings raised by the libraries
# used below — consider narrowing the filter to specific categories.
warnings.filterwarnings("ignore")

In [None]:
# Root directory under which the per-model timing results are read.
PATH_INPUT = "../results/mscoco"
# Model identifier; used as a sub-path of PATH_INPUT when loading results and
# (with "/" replaced by "-") in the file name of the saved figure.
MODEL_NAME = "google/siglip2-base-patch32-256"

In [None]:
# Load the timing table of every (budget, method) combination into one frame.
#
# Each file is read from
#   PATH_INPUT / MODEL_NAME / <budget> / <method> / "0.5" / "time.csv"
# and tagged with the `budget` and `method` it belongs to. The "0.5" path
# component is a fixed part of the results layout (its meaning is not visible
# in this notebook).
#
# The frames are gathered in a list and concatenated once: growing a DataFrame
# with pd.concat inside the loop re-copies all earlier rows on every iteration,
# and seeding it with an empty DataFrame relies on concat behaviour that newer
# pandas versions deprecate.
#
# The "shapley" baseline was disabled in the original cell and is therefore not
# loaded; note that its files were addressed without the "0.5" sub-directory.
budgets = [2**k for k in range(15, 22)]
methods = ["fixlip", "banzhaf"]  # per-budget order kept from the original loop

frames = []
for budget in budgets:
    for method in methods:
        path_time = os.path.join(PATH_INPUT, MODEL_NAME, str(budget), method, "0.5", "time.csv")
        frames.append(pd.read_csv(path_time).assign(budget=budget, method=method))

# Row index is intentionally not reset, matching the original result.
df_results = pd.concat(frames)

In [None]:
# Mean and standard deviation of both timing columns per (method, budget),
# written with named aggregation so the flat "<column>_<stat>" names are
# produced directly.
grouped = df_results.groupby(['method', 'budget']).agg(
    time_explanation_mean=('time_explanation', 'mean'),
    time_explanation_std=('time_explanation', 'std'),
    time_game_mean=('time_game', 'mean'),
    time_game_std=('time_game', 'std'),
)

# Move `method` from the row index into the columns: rows are budgets,
# columns are (statistic, method) pairs.
unstacked = grouped.unstack(level='method')

# Per-budget ratio of mean runtimes, banzhaf divided by fixlip.
mean_explanation = unstacked['time_explanation_mean']
mean_game = unstacked['time_game_mean']
ratio_time_explanation = mean_explanation['banzhaf'] / mean_explanation['fixlip']
ratio_time_game = mean_game['banzhaf'] / mean_game['fixlip']

In [None]:
# Per-budget ratio of mean explanation time, banzhaf / fixlip
# (values above 1 mean banzhaf took longer on average).
ratio_time_explanation

In [None]:
# Per-budget ratio of mean game time, banzhaf / fixlip
# (values above 1 mean banzhaf took longer on average).
ratio_time_game

In [None]:
# Reshape to long format for plotting: every column other than the id columns
# becomes a (curve, time) row, and the two timing columns get short names.
df_plot = (
    df_results
    .melt(id_vars=['id', 'method', 'budget'], var_name='curve', value_name='time')
    .assign(curve=lambda frame: frame['curve'].replace(
        {'time_game': 'game', 'time_explanation': 'explanation'}
    ))
)

In [None]:
SMALL_SIZE = 11
MEDIUM_SIZE = 12

# Global matplotlib font sizes used by the figure below.
plt.rc('font', size=SMALL_SIZE)                               # default text size
plt.rc('axes', titlesize=MEDIUM_SIZE, labelsize=MEDIUM_SIZE)  # axes title and x/y labels
for tick_group in ('xtick', 'ytick'):
    plt.rc(tick_group, labelsize=SMALL_SIZE)                  # tick labels on both axes
plt.rc('legend', fontsize=9)                                  # legend entries

In [None]:
# Runtime vs. budget figure: one line per (method, curve) pair on log-log axes,
# with standard-error bands, saved as a PDF named after MODEL_NAME.
fig, ax = plt.subplots(figsize=(4, 2.4))
sns.lineplot(
    data=df_plot,
    x='budget', y='time',
    hue='method', palette=['#4285F4', '#EA4335'],
    hue_order=["banzhaf", "fixlip"],
    style='curve', dashes=[(4, 1), (1, 1)],
    style_order=["explanation", "game"],
    errorbar="se",
    linewidth=2,
    ax=ax,
)
ax.set_yscale('log', base=10)
ax.set_xscale('log', base=2)

# Build the legend from the entries we actually want, selected by their label
# rather than by position. Seaborn adds extra "subtitle" entries (the hue and
# style variable names); filtering by name drops them and keeps the renaming
# correct even if the number or order of legend entries changes.
legend_names = {
    "banzhaf": "Model-agnostic",
    "fixlip": "Cross-modal",
    "explanation": "Explanation (total)",
    "game": "Game (inference)",
}
all_handles, all_labels = ax.get_legend_handles_labels()
legend_entries = [
    (handle, legend_names[label])
    for handle, label in zip(all_handles, all_labels)
    if label in legend_names
]
handles = [handle for handle, _ in legend_entries]
labels = [label for _, label in legend_entries]
ax.legend(handles, labels)
sns.move_legend(ax, "upper left", bbox_to_anchor=(-0.01, 1.03), ncol=2, columnspacing=0.5)

# x-axis: one tick per power of two, limits taken from the first/last tick.
exponents = range(15, 22)
xtick_values = [2**i for i in exponents]
ax.set_xlabel("Number of sampled masks (budget)")
ax.set_xlim(xtick_values[0], xtick_values[-1])
ax.set_xticks(xtick_values)
ax.set_xticklabels([f"$2^{{{i}}}$" for i in exponents])
xticks = ax.xaxis.get_majorticklabels()
# Right-align the last tick label so it does not extend past the axes edge.
xticks[-1].set_horizontalalignment('right')

# y-axis
ax.set_ylabel("Time in seconds (A100 GPU)")
ax.set_ylim(1, 10**4)
yticks = ax.yaxis.get_majorticklabels()
# NOTE(review): index -2 is kept from the original; presumably the last major
# tick label lies outside the visible range, so -2 is the topmost visible one.
# Re-check if the y-limits or the matplotlib version change.
yticks[-2].set_verticalalignment('top')

fig.tight_layout(pad=0.15)
# Set after tight_layout, as in the original, to position the y-label manually.
ax.yaxis.set_label_coords(-0.13, 0.37)
fig.savefig(f'time_{MODEL_NAME.replace("/", "-")}.pdf')