In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell executes, so edits to project code are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import sys
import pickle
from itertools import product

import numpy as np
import pandas as pd
import scipy.stats as stats

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import seaborn as sns

from visualization.src.utils import COLOR_PALETTES, save_figs, COLORS
from visualization.src.visualize import plot_reg, plot_reg_bivariate, plot_confidence_intervals

In [None]:
# Presumably the repository root: two directory levels above the kernel's
# working directory.
# NOTE(review): the `visualization.*` imports in the imports cell execute
# before this path is registered -- confirm they resolve on a fresh kernel.
repo_dir = Path('../..')

# Register the directory on the import path only once, so re-running this
# cell does not append duplicates.
if sys.path.count(str(repo_dir)) == 0:
    sys.path.append(str(repo_dir))

In [None]:
# Both score tables live side by side under <repo_dir>/results.
results_path, results_private_path = (
    repo_dir / 'results' / filename
    for filename in ('benchmark_scores.csv', 'benchmark_scores_brainscore.csv')
)

# Read each CSV into its own frame.
df = pd.read_csv(results_path)
df_private = pd.read_csv(results_private_path)

In [None]:
# Columns of `df` carried into the simplified frame (order is preserved).
columns = [
    'model_id', 'arch', 'dataset',
    'n_params', 'macs', 'flops',
    'n_samples', 'samples_per_class', 'n_epochs', 'n_samples_seen',
    'seed', 'arch_family', 'acc',
    'is_pretrained', 'is_random', 'is_ssl', 'ssl_method',
    'is_ablation', 'is_open_clip', 'is_adv',
    'total_flops',
]

# Independent copy, so later edits to df_simple never touch `df`.
df_simple = df.loc[:, columns].copy()
df_simple

In [None]:
# Distinct model identifiers present in the main results table.
model_ids = df['model_id'].unique()

# Some exports name the identifier column 'model_name'; align it with `df`.
column_map = {'model_name': 'model_id'}
if 'model_name' in df_private.columns:
    df_private = df_private.rename(columns=column_map)

# Keep only rows whose model also appears in `df`.
df_private = df_private.loc[df_private['model_id'].isin(model_ids)]

df_private

In [None]:
# Inner-join the filtered scores with the simplified columns on model_id,
# then drop exact duplicate rows and renumber the index from zero.
df_analyze = (
    df_private
    .merge(df_simple, on='model_id', how='inner')
    .drop_duplicates()
    .reset_index(drop=True)
)
df_analyze

In [None]:
# Single colour shared by the regression plots below: the last entry of the
# 'regions' palette.
color_1 = COLOR_PALETTES['regions'][-1]

In [None]:

# Axis-label text for each scaling variable, keyed by its column name.
variable_map = dict(
    n_params='Number of Parameters (N)',
    n_samples_seen='Number of Samples Seen (D)',
    total_flops='Total FLOPs (C)',
)

In [None]:
# Single-panel view: one score column regressed on log10 of one scaling
# variable. Edit the two selectors to inspect another pairing.
score_src = 'behavior_vision'  # alternatives: 'neural_vision', 'average_vision'
variable = 'n_params'          # alternatives: 'n_samples_seen', 'total_flops'

fig, axes = plt.subplots(1, 1, figsize=(12, 6), dpi=300)
ax = axes

data = df_analyze

# Keep only rows a log-linear fit can use (finite values, strictly positive x).
# linregress does not skip NaN/inf by default, so without this a single bad
# row turns the legend into "nan" while seaborn still draws a fit from the
# remaining points.
pair = data[[variable, score_src]].replace([np.inf, -np.inf], np.nan).dropna()
pair = pair[pair[variable] > 0]

# Statistics shown in the legend: linear regression of score on log10(x).
x = np.log10(pair[variable])
y = pair[score_src]
reg = stats.linregress(x, y)
text = f"$R^2$: {reg.rvalue**2:.2f}\np-value ={reg.pvalue:.2e}"

# Plot against the raw variable on a log axis (same approach as the grid of
# fits in the following cell) so the tick values match the axis label, which
# names the variable itself rather than its logarithm.
sns.regplot(x=pair[variable], y=y, ax=ax, color=color_1, label=text, logx=True)
ax.set_xscale('log')

ax.set_xlabel(variable_map[variable], fontsize=12, fontweight='bold')
ax.set_ylabel(score_src, fontsize=12, fontweight='bold')
ax.grid(True)

ax.legend()

In [None]:
# Grid of log-linear fits: one row per score column, one column per scaling
# variable. The layout is derived from the two lists, so adding or removing an
# entry resizes the grid without touching the loop.
variable_list = [
    'n_params',
    'n_samples_seen',
    'total_flops',
]

score_src_list = [
    'behavior_vision',
    'neural_vision',
    'average_vision',
]

data = df_analyze
# Optional filter kept from exploration:
# data = data[data.samples_per_class == 0]

fig_multiplier = 1.5
base_figsize = (12, 6)
figsize = (base_figsize[0] * fig_multiplier, base_figsize[1] * fig_multiplier)

n_rows, n_cols = len(score_src_list), len(variable_list)
# squeeze=False keeps `axes` two-dimensional even for a 1xN or Nx1 layout.
fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, dpi=300, squeeze=False)

for row, score_src in enumerate(score_src_list):
    for col, variable in enumerate(variable_list):
        ax = axes[row, col]

        # Use the same finite, strictly positive rows for the statistics and
        # for the plotted fit; linregress does not skip NaN/inf by default.
        pair = data[[variable, score_src]].replace([np.inf, -np.inf], np.nan).dropna()
        pair = pair[pair[variable] > 0]

        x = np.log10(pair[variable])
        y = pair[score_src]

        reg = stats.linregress(x, y)
        text = f"$R^2$: {reg.rvalue**2:.2f}\np-value ={reg.pvalue:.2e}"

        # logx=True fits y ~ log(x) while drawing in the raw x units.
        sns.regplot(x=pair[variable], y=y, ax=ax, color=color_1, label=text, logx=True)

        # Label only the outer edge of the grid: y on the first column,
        # x on the last row.
        if col == 0:
            ax.set_ylabel(score_src, fontsize=12, fontweight='bold')
        else:
            ax.set_ylabel('')

        if row == n_rows - 1:
            ax.set_xlabel(variable_map[variable], fontsize=12, fontweight='bold')
        else:
            ax.set_xlabel('')

        ax.set_xscale('log')

        ax.grid(True)
        ax.legend(loc='lower right')

plt.tight_layout()

In [None]:
# Model identifier plus every df_analyze column whose name contains 'Coggan'.
columns = ['model_id'] + [col for col in df_analyze.columns if 'Coggan' in col]
data = df_analyze[columns]
data

In [None]:
# Geirhos2021 error-consistency column names: the un-suffixed entry first
# (empty condition), followed by one entry per named condition.
geirhos_beh = [
    f'Geirhos2021{condition}-error_consistency'
    for condition in (
        '',
        'colour',
        'contrast',
        'cueconflict',
        'edge',
        'eidolonI',
        'eidolonII',
        'eidolonIII',
        'falsecolour',
        'highpass',
        'lowpass',
        'phasescrambling',
        'powerequalisation',
        'rotation',
        'silhouette',
        'sketch',
        'stylized',
        'uniformnoise',
    )
]

In [None]:
# Scatter + linear fit of one behavioural benchmark column against another.
fig_multiplier = 1.5
base_figsize = (12, 6)
figsize = (base_figsize[0] * fig_multiplier, base_figsize[1] * fig_multiplier)
fig, axes = plt.subplots(1, 1, figsize=figsize, dpi=300)
ax = axes

data = df_analyze
# `data` is the same object as df_analyze, so this adds (or overwrites) the
# 'geirhos_avg' column on df_analyze itself: the row-wise mean of the
# Geirhos2021 error-consistency columns.
data['geirhos_avg'] = data[geirhos_beh].mean(axis=1)

# Optional filter kept from exploration:
# data = data[data.samples_per_class == 0]

# Columns to compare. Other y columns tried during exploration:
#   'behavior_vision', 'geirhos_avg', 'Ferguson2024', 'Hebart2023-match',
#   'Malania2007', 'tong.Coggan2024_behavior-ConditionWiseAccuracySimilarity',
#   'Baker2022', 'BMD2024'
x_col = 'Rajalingham2018-i2n'
y_col = 'Maniquet2024'

# Drop rows missing either score so the legend statistics are computed on
# exactly the points that get plotted.
pair = data[[x_col, y_col]].dropna()
x = pair[x_col]
y = pair[y_col]

# Compute the legend text for THIS pair. Previously the label reused `text`
# left over from the preceding grid cell, i.e. the R^2 / p-value of an
# unrelated regression.
reg = stats.linregress(x, y)
text = f"$R^2$: {reg.rvalue**2:.2f}\np-value ={reg.pvalue:.2e}"

sns.regplot(x=x, y=y, ax=ax, color=color_1, label=text, logx=False)

ax.grid(True)
ax.legend(loc='lower right')

plt.tight_layout()

In [None]:
# Behavioural benchmark column names, grouped by benchmark family. Within
# each family the un-suffixed name (where one exists) comes first, followed
# by its sub-benchmark names built from a shared pattern.
behavior_benchmarks = [
    # Geirhos2021
    *(
        f'Geirhos2021{condition}-error_consistency'
        for condition in (
            '', 'colour', 'contrast', 'cueconflict', 'edge',
            'eidolonI', 'eidolonII', 'eidolonIII', 'falsecolour',
            'highpass', 'lowpass', 'phasescrambling', 'powerequalisation',
            'rotation', 'silhouette', 'sketch', 'stylized', 'uniformnoise',
        )
    ),
    # Scialom2024
    'Scialom2024',
    *(
        f'Scialom2024_{variant}-behavioral_accuracy'
        for variant in (
            'rgb', 'phosphenes-100', 'segments-100',
            'phosphenes-all', 'segments-all',
        )
    ),
    # Maniquet2024
    'Maniquet2024',
    'Maniquet2024-confusion_similarity',
    'Maniquet2024-tasks_consistency',
    # Ferguson2024
    'Ferguson2024',
    *(
        f'Ferguson2024{task}-value_delta'
        for task in (
            'tilted_line', 'eighth', 'circle_line', 'lle', 'juncture',
            'convergence', 'round_f', 'gray_hard', 'gray_easy', 'round_v',
            'color', 'llh', 'half', 'quarter',
        )
    ),
    # Hebart2023
    'Hebart2023-match',
    # Malania2007
    'Malania2007',
    *(
        f'Malania2007.{condition}'
        for condition in (
            'long2', 'equal16', 'long16', 'short2',
            'vernieracuity-threshold',
            'short8', 'short6', 'short4', 'short16', 'equal2',
        )
    ),
    # Coggan2024
    'tong.Coggan2024_behavior-ConditionWiseAccuracySimilarity',
    # Baker2022
    'Baker2022',
    *(
        f'Baker2022{condition}-accuracy_delta'
        for condition in ('frankenstein', 'inverted', 'fragmented')
    ),
    # BMD2024
    'BMD2024',
    *(
        f'BMD2024.{variant}Behavioral-accuracy_distance'
        for variant in ('texture_1', 'dotted_2', 'texture_2', 'dotted_1')
    ),
]

In [None]:
# Show every column name available in the merged analysis frame.
df_analyze.columns.tolist()