In [None]:
# Pick up edits to the imported project modules (analysis/, visualization/) live.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import sys
import pickle
from itertools import product

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Global seaborn theme for every figure in this notebook. A single call suffices:
# a second set_theme() would simply replace the style chosen by the first.
sns.set_theme(style='ticks')

In [None]:
# Repository root, two levels above the notebook's working directory.
repo_dir = Path('..', '..')

In [None]:
# Make the repository's `analysis` and `visualization` packages importable.
# Register the resolved absolute path: a relative sys.path entry is re-interpreted
# against the *current* working directory on every import, so it silently stops
# working if the cwd changes later in the session.
repo_path = str(repo_dir.resolve())
if repo_path not in sys.path:
    sys.path.append(repo_path)
    
from analysis.curve_fitting.src.fitting_functions import LOSS_FUNCTIONS
from analysis.curve_fitting.src.utils import apply_filters, load_yaml, convert_loss_parameters, convert_loss_parameters_batch

from visualization.src.utils import COLOR_PALETTES, set_ticks, COLORS, save_figs
from visualization.src.visualize import plot_reg, plot_reg_bivariate, plot_confidence_intervals



In [None]:
# Notebook inputs: the benchmark scores CSV under the repository's results/ folder.
args = dict(
    results_csv=repo_dir.joinpath('results', 'benchmark_scores.csv'),
)

In [None]:
# Load the raw, unfiltered benchmark results.
results_csv = args['results_csv']
df_results = pd.read_csv(filepath_or_buffer=results_csv)

## Load Experiment Configuration

In [None]:
# Experiment configuration. Everything passed to `apply_filters` lives under
# 'data_filters' -- including 'boolean_filters' and 'combine_arch_families',
# which sit alongside 'set_filters' at the same nesting level.
config = {
    'data_filters': {
        # Allowed values per column (presumably rows with other values are dropped --
        # confirm in apply_filters).
        'set_filters': {
            'region': ['V1', 'V2', 'V4', 'IT', 'Behavioral'],
            'dataset': ['imagenet', 'ecoset'],
        },
        # Flag columns listed under 'equals_false'; by the key name, rows are kept only
        # where these flags are False -- confirm in apply_filters.
        'boolean_filters': {
            'equals_false': [
                'is_pretrained',
                'is_random',
                'is_ssl',
                'is_ablation',
                'is_adv',
            ],
        },
        # NOTE(review): a disabled 'group_by' step used to sit here. It reduced 'score'
        # with 'mean' (under the name 'avg_score') over the keys model_id, arch, dataset,
        # flops, n_params, n_samples, n_samples_seen, total_flops, arch_family,
        # samples_per_class, adv_method.
        'combine_arch_families': True,
    },
}



## Apply Data Filters

In [None]:
# Restrict the raw results to the subset described by config['data_filters'].
df = apply_filters(df_results, config.get('data_filters', {}))
# Fail early with a clear message: the grid plot below cannot be built from zero rows
# (plt.subplots would otherwise fail with an unrelated-looking error).
assert not df.empty, 'apply_filters removed every row; check config["data_filters"]'


## Visualize

#### Plotting settings

In [None]:
# Shared styling constants for the figures in this notebook.
linewidth = 3.0
alpha_scatter = 1.0
alpha_ci = 0.2
alpha_fit = 1.0

# Uniform scale applied to the base 32x32 figure size (e.g. 0.7 for a smaller export).
fig_multiplier = 1
figsize = (fig_multiplier * 32, fig_multiplier * 32)

color_palette = COLOR_PALETTES['regions']
color_palaette = color_palette  # alias kept for the original (misspelled) name
color = color_palette[-1]  # colour used by the grid-plot scatter panels
# color = "#023e8a"

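#### Grid Plot

Lower-triangular grid over the filtered regions. Each diagonal panel plots a region's alignment score against validation accuracy. Each off-diagonal panel plots two regions' alignment scores for the same models against each other.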

In [None]:
# Order regions as listed in the config's region filter (V1, V2, V4, IT, Behavioral),
# so the grid rows/columns are stable rather than following first appearance in the CSV.
# Regions present in the data but missing from that list are appended in appearance order.
present_regions = list(df['region'].unique())
region_order = config.get('data_filters', {}).get('set_filters', {}).get('region', [])
regions = [r for r in region_order if r in present_regions]
regions += [r for r in present_regions if r not in regions]

In [None]:
# Lower-triangular grid of scatter plots:
#   - diagonal: alignment score of one region vs. validation accuracy;
#   - below the diagonal: alignment scores of two regions for the same models.
data_plot = df.rename(columns={'dataset': 'Dataset', 'arch_family': 'Architecture'})

# Columns that identify one trained model; used to pair its scores on two regions.
pair_keys = ['model_id', 'arch', 'n_samples', 'seed', 'acc', 'Architecture', 'Dataset']


def _plot_accuracy_panel(ax, data, region):
    """Scatter the alignment score of `region` against validation accuracy on `ax`."""
    data_region = data[data['region'] == region]
    sns.scatterplot(data=data_region, x='acc', y='score', ax=ax, markers=True,
                    style='Architecture', color=color, legend=False)
    ax.set_title(f'{region} x Accuracy', fontsize=20, fontweight='bold')
    ax.set_xlabel('Validation Accuracy', fontsize=16, fontweight='bold')
    ax.set_ylabel(f'Alignment Score for {region}', fontsize=16, fontweight='bold')


def _plot_region_pair_panel(ax, data, region_y, region_x):
    """Scatter per-model alignment scores of `region_y` (y-axis) vs `region_x` (x-axis)."""
    # NOTE(review): if a region has several benchmarks per model, this inner merge pairs
    # every benchmark combination; aggregate per model first if that is not intended.
    paired = pd.merge(
        data[data['region'] == region_y],
        data[data['region'] == region_x],
        on=pair_keys,
        suffixes=('_1', '_2'),
        how='inner',
    ).rename(columns={'score_1': region_y, 'score_2': region_x})
    sns.scatterplot(data=paired, x=region_x, y=region_y, ax=ax, markers=True,
                    style='Architecture', color=color, legend=False)
    ax.set_title(f'{region_y} x {region_x}', fontsize=20, fontweight='bold')
    ax.set_xlabel(f'Alignment Score for {region_x}', fontsize=16, fontweight='bold')
    ax.set_ylabel(f'Alignment Score for {region_y}', fontsize=16, fontweight='bold')


def _style_panel(ax):
    """Apply the shared grid lines and hide the top/right spines of one panel."""
    ax.grid(which='minor', alpha=0.2)
    ax.grid(which='major', alpha=0.8)
    ax.grid(True)
    ax.spines[['right', 'top']].set_visible(False)


n_benchmarks = len(regions)
# squeeze=False keeps `axes` 2-D even when only a single region survives the filters.
fig, axes = plt.subplots(
    nrows=n_benchmarks,
    ncols=n_benchmarks,
    figsize=figsize,
    dpi=300,
    squeeze=False,
)
for i, region1 in enumerate(regions):
    for j, region2 in enumerate(regions):
        ax = axes[i, j]
        if j > i:
            # The upper triangle would mirror the lower one; drop it.
            ax.remove()
            continue
        if i == j:
            _plot_accuracy_panel(ax, data_plot, region1)
        else:
            _plot_region_pair_panel(ax, data_plot, region1, region2)
        _style_panel(ax)

figures_dir = '../figures'
fig_name = 'fig14_cartesian'
formats = ['pdf', 'png', 'svg']
save_figs(figures_dir, fig_name, formats=formats)

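#### Pairplot

The same pairwise comparison as a seaborn `PairGrid`, built from one row per model with one score column per region. It shows histograms on the diagonal, bivariate histograms above it, and filled KDEs below it.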

In [None]:




data_plot = df.rename(columns={'dataset': 'Dataset', 'arch_family': 'Architecture'})

# Reshape to one row per model: one score column per region, plus the model's metadata
# taken from its first row. Unlike transposing row by row, this keeps the metadata
# dtypes intact. pivot() raises ValueError if a model has more than one score for the
# same region; aggregate first in that case.
region_scores = data_plot.pivot(index='model_id', columns='region', values='score')
model_meta = (
    data_plot
    .drop(columns=['region', 'score', 'benchmark_id', 'benchmark_name'])
    .drop_duplicates(subset='model_id', keep='first')
    .set_index('model_id')
)
data_stacked = model_meta.join(region_scores).reset_index()

grid = sns.PairGrid(data_stacked, x_vars=regions, y_vars=regions)
grid.map_upper(sns.histplot)             # bivariate histograms above the diagonal
grid.map_lower(sns.kdeplot, fill=True)   # filled KDEs below the diagonal
grid.map_diag(sns.histplot, kde=True)    # per-region histogram + KDE on the diagonal
plt.tight_layout()

figures_dir = '../figures'
fig_name = 'fig15_pair'
formats = ['pdf', 'png', 'svg']
save_figs(figures_dir, fig_name, formats=formats)


In [None]:
# Preview the filtered results instead of rendering the entire frame.
df.head()

In [None]:
# Optional inspection: list the model ids that remain after filtering (uncomment to run).
# df.model_id.unique()