In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import sys
import pickle
from itertools import product

import numpy as np
import pandas as pd
import scipy.stats as stats

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import seaborn as sns

# Use the 'ticks' style for every figure in this notebook. Each set_theme call
# replaces the whole theme, so a preceding 'whitegrid' call had no lasting effect.
sns.set_theme(style='ticks')


In [None]:
# Repository root, expressed relative to the current working directory (two levels up).
repo_dir = Path('..', '..')

In [None]:
# Make the repository root importable (added at most once) so the project
# packages imported below can be resolved.
repo_dir_str = str(repo_dir)
if repo_dir_str not in sys.path:
    sys.path.append(repo_dir_str)
    
from analysis.curve_fitting.src.fitting_functions import LOSS_FUNCTIONS
from analysis.curve_fitting.src.utils import apply_filters, load_yaml, convert_loss_parameters, convert_loss_parameters_batch

from visualization.src.utils import COLOR_PALETTES, set_ticks, save_figs, COLORS
from visualization.src.visualize import plot_reg, plot_reg_bivariate, plot_confidence_intervals



In [None]:
# Notebook inputs: path of the benchmark-scores CSV inside the repository.
args = dict(results_csv=repo_dir / 'results' / 'benchmark_scores.csv')

In [None]:
# Read the benchmark scores table that all later filtering starts from.
results_csv = args['results_csv']
df_results = pd.read_csv(filepath_or_buffer=results_csv)

## Load Experiment Configuration

In [None]:
# Inline data-filter specification.
# Everything below lives under the single 'data_filters' key: the set filters,
# the boolean filters and the 'combine_arch_families' flag.
# NOTE(review): config2 is not read by any other cell visible in this notebook
# (the filters that are applied come from the YAML configs) - confirm it is still needed.
config2 = {
    'data_filters': {
        'set_filters': {
            'region': ['V1', 'V2', 'V4', 'IT', 'Behavioral'],
            'dataset': ['imagenet'],
            'ssl_method': ['simclr'],
        },
        'boolean_filters': {
            'equals_false': ['is_pretrained', 'is_random', 'is_adv', 'is_ablation'],
            'equals_true': ['is_ssl'],
        },
        # Disabled grouping step (kept for reference):
        # 'group_by': {
        #     'avg_score': {
        #         'keys': [
        #             'model_id', 'arch', 'dataset', 'flops', 'n_params',
        #             'n_samples', 'n_samples_seen', 'total_flops',
        #             'arch_family', 'samples_per_class',
        #         ],
        #         'reduce': {'score': 'mean'}}},
        'combine_arch_families': True,
    },
}



In [None]:
# Locations of the SimCLR curve-fitting configs and of the stored fitting results.
analysis_dir = repo_dir.joinpath('analysis')
curve_fitting_dir = analysis_dir / 'curve_fitting'
config_dir = curve_fitting_dir / 'configs' / 'ssl' / 'simclr'
results_dir = curve_fitting_dir / 'outputs' / 'fitting_results'
# Alternative results location:
# results_dir = curve_fitting_dir / 'outputs' / 'fitting_results_test'

In [None]:
# Experiment keys; each key names a 'simclr_<key>' config file and results folder.
region = ['v1', 'v2', 'v4', 'it', 'behavior', 'avg', 'neuro']

In [None]:
# Load one YAML experiment config per experiment key.
all_configs = {
    key: load_yaml(config_dir / f'simclr_{key}.yaml')
    for key in region
}

In [None]:
# Per-experiment lookups pulled out of the configs: the loss-function name used
# for fitting, the one used for visualization, and the X scaler as a float.
L_fit_dict, L_viz_dict, x_scale_dict = {}, {}, {}
for key, cfg in all_configs.items():
    fit_cfg = cfg['fitting_parameters']
    L_fit_dict[key] = fit_cfg['loss_function']
    L_viz_dict[key] = cfg['visualization']['loss_function']
    x_scale_dict[key] = float(fit_cfg['X_scaler'])

## Apply Data Filters

In [None]:
# Filter the benchmark table once per experiment, using that experiment's
# 'data_filters' section (an empty filter set when the section is absent).
all_df = {}
for name, cfg in all_configs.items():
    filters = cfg.get('data_filters', {})
    all_df[name] = apply_filters(df_results, filters)

In [None]:
# Convenience handles for the two experiments compared in the first figure.
df_behavior, df_neuro = (all_df[key] for key in ('behavior', 'neuro'))

## Load Fitting Results

In [None]:
optimized_params_dict = {}
opt_params_boot_dict = {}

# NOTE: pickle.load can execute arbitrary code; only load result files produced by this project.
for exp_name in all_configs:
    results_path = results_dir / f'simclr_{exp_name}' / 'results.pkl'
    with results_path.open('rb') as f:
        results = pickle.load(f)

    # Parameters are stored in the fitting loss form and converted to the visualization loss form.
    src_loss, dst_loss = L_fit_dict[exp_name], L_viz_dict[exp_name]

    # Point estimate
    optimized_params_dict[exp_name] = convert_loss_parameters(
        results['optimized_parameters'], src_loss, dst_loss
    )

    # Bootstrapped estimates
    opt_params_boot_dict[exp_name] = convert_loss_parameters_batch(
        params=results['optimized_parameters_bootstrapped'],
        src_loss=src_loss,
        dst_loss=dst_loss,
    )

In [None]:
# Point estimates and bootstrapped parameters for the two headline experiments.
optimized_params_neuro, optimized_params_behavior = (
    optimized_params_dict[key] for key in ('neuro', 'behavior')
)
opt_params_boot_neuro, opt_params_boot_behavior = (
    opt_params_boot_dict[key] for key in ('neuro', 'behavior')
)

# Loss function, X scaler and scaled total FLOPs, all taken from the 'neuro' experiment.
L = LOSS_FUNCTIONS[L_viz_dict['neuro']]
x_scaler = x_scale_dict['neuro']
X = df_neuro['total_flops'].values / x_scaler

## Visualize

#### Plotting settings

In [None]:
# Shared settings for the single-panel figures.
# Dead duplicate assignments (x_extend = 1.1, figsize = (12, 6)) were removed;
# the values below are the ones that were actually in effect.
x_extend = 10
X_str = r'$$\tilde{C}$$'
linewidth = 3.0
alpha_scatter = 1.0
alpha_ci = 0.2
alpha_fit = 1.0

# Final figure size: the base size scaled by fig_multiplier.
fig_multiplier = 0.7
base_figsize = (10, 6)
figsize = (fig_multiplier * base_figsize[0], fig_multiplier * base_figsize[1])

# Colour for a single fit line and palette for the samples-seen hue.
color = COLORS['cyan_dark']
color_palette = COLOR_PALETTES['samples']

# First and last entries of the regions palette.
color_palette_regions = COLOR_PALETTES['regions']
color_1, color_2 = color_palette_regions[0], color_palette_regions[-1]

In [None]:
# Single-panel figure comparing the 'neuro' and 'behavior' experiments.
# Each experiment is evaluated with its OWN visualization loss function,
# X scaler and scaled FLOPs; previously the behavioural fit reused the neuro
# experiment's L / x_scaler / X even though its parameters were converted to
# the behavioural visualization loss.
fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=300)

### Neuro
df_region = df_neuro
color = color_1
L_neuro = LOSS_FUNCTIONS[L_viz_dict['neuro']]
x_scaler_neuro = x_scale_dict['neuro']
X_neuro = df_region.total_flops.values / x_scaler_neuro
sns.scatterplot(data=df_region, x='total_flops', y='score', ax=ax, color=color, alpha=alpha_scatter)
plot_reg(X_neuro, optimized_params_neuro, L_neuro, ax, color=color, x_extend=x_extend, linestyle='-', X_str=X_str, x_scaler=x_scaler_neuro, show_x_scaler=False, linewidth=linewidth, legend=False, alpha=alpha_fit)
plot_confidence_intervals(X_neuro, opt_params_boot_neuro, L_neuro, ax, color=color, x_scaler=x_scaler_neuro, x_extend=x_extend, percentile=95.0, invert_y=True, alpha=alpha_ci)

### Behavioral
df_region = df_behavior
color = color_2
L_behavior = LOSS_FUNCTIONS[L_viz_dict['behavior']]
x_scaler_behavior = x_scale_dict['behavior']
X_behavior = df_region.total_flops.values / x_scaler_behavior
sns.scatterplot(data=df_region, x='total_flops', y='score', ax=ax, color=color, alpha=alpha_scatter)
# NOTE(review): only this plot_reg call passes invert_y=True (the neuro call above
# does not) - confirm against plot_reg's default whether the asymmetry is intended.
plot_reg(X_behavior, optimized_params_behavior, L_behavior, ax, color=color, x_extend=x_extend, linestyle='-', X_str=X_str, invert_y=True, x_scaler=x_scaler_behavior, show_x_scaler=False, linewidth=linewidth, legend=False, alpha=alpha_fit)
plot_confidence_intervals(X_behavior, opt_params_boot_behavior, L_behavior, ax, color=color, x_scaler=x_scaler_behavior, x_extend=x_extend, percentile=95.0, invert_y=True, alpha=alpha_ci)


### Formatting
ax.set_xscale('log')
ax.set_xlabel('Total Flops (C)', fontsize=16, fontweight='bold')
ax.set_ylabel('Alignment Score (S)', fontsize=16, fontweight='bold')
ax.set_title('SimCLR Training', fontsize=20, fontweight='bold')
ax = set_ticks(ax, xticks_mode='log', yticks_mode=None, yticks=[0.1, 0.2, 0.3, 0.4, 0.5])

### Legend
# Two separate legends: the first labelled handle goes top-left, the second is
# anchored near the bottom. A second ax.legend() call replaces the first, so
# the first legend is re-attached with add_artist.
handles, labels = ax.get_legend_handles_labels()
labels = [
    'Neural  Alignment\n' + labels[0],
    'Behavioral  Alignment\n'  + labels[1]
]
l1 = ax.legend([handles[0]], [labels[0]], fontsize=12, loc='upper left')
l2 = ax.legend([handles[1]], [labels[1]], fontsize=12, bbox_to_anchor=(0.6, 0.0), loc='lower right')
ax.add_artist(l1)

ax.spines[['right', 'top']].set_visible(False)
plt.tight_layout()


### Save
figures_dir = '../figures'
fig_name = 'fig7_simclr'
formats = ['pdf', 'png', 'svg']
save_figs(figures_dir, fig_name, formats=formats)

#### Plot

In [None]:


# Single-panel figure for the 'avg' experiment: scatter of all models
# (hue = samples seen, size = parameter count, marker = architecture family)
# with the fitted curve and its 95% bootstrap band.
fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=300)

### Data and fit for the 'avg' experiment
df_plot = all_df['avg']
L = LOSS_FUNCTIONS[L_viz_dict['avg']]
x_scaler = x_scale_dict['avg']
X = df_plot.total_flops.values / x_scaler
optimized_params = optimized_params_dict['avg']
opt_params_boot = opt_params_boot_dict['avg']


sns.scatterplot(data=df_plot, x='total_flops', y='score', style='arch_family', size='n_params', hue='n_samples_seen', ax=ax, hue_norm=LogNorm(), size_norm=LogNorm(), palette=color_palette)

# NOTE(review): `color` is not set in this cell; it holds whatever the most
# recently executed cell assigned (the settings cell or a plotting cell), so the
# fit colour depends on execution order - consider setting it explicitly here.
plot_reg(X, optimized_params, L, ax, color=color, x_extend=x_extend, linestyle='-', X_str=X_str, x_scaler=x_scaler, show_x_scaler=False, linewidth=linewidth, legend=True, alpha=alpha_fit)
plot_confidence_intervals(X, opt_params_boot, L, ax, color=color, x_extend=x_extend, x_scaler=x_scaler, alpha=alpha_ci, percentile=95.0, invert_y=True)


### Colorbar
# Log-scaled colour bar spanning the observed range of samples seen.
sm = plt.cm.ScalarMappable(cmap=color_palette, norm=LogNorm())
sm.set_clim(df_plot['n_samples_seen'].min(), df_plot['n_samples_seen'].max())
cbar = plt.colorbar(sm, ax=ax)
cbar.set_label('Number of Samples Seen', rotation=270, labelpad=15)


### Formatting
ax.set_xscale('log')
ax.set_xlabel('Total Flops (C)', fontsize=16, fontweight='bold')
ax.set_ylabel('Alignment Score (S)', fontsize=16, fontweight='bold')
ax.set_title('SimCLR Training', fontsize=20, fontweight='bold')
ax.grid(False)
ax = set_ticks(ax, xticks_mode='log', yticks_mode=None, yticks=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])


### Legend
# Keep only the last legend entry; the scatter hue/size/style entries are dropped.
handles, labels = ax.get_legend_handles_labels()
handles, labels = handles[-1:], labels[-1:]
ax.legend(handles, labels, loc='lower right')


ax.spines[['right', 'top']].set_visible(False)

plt.tight_layout()


# figures_dir = '../figures'
# fig_name = 'fig7_simclr'
# formats = ['pdf', 'png', 'svg']
# save_figs(figures_dir, fig_name, formats=formats)

In [None]:
# Settings for the multi-panel per-region figure.
# The dead `alpha_scatter = 0.2` assignment was removed; 1 is the value in effect.
x_extend = 1.1
X_str = r'$$\tilde{C}$$'
linewidth = 3.0
alpha_scatter = 1
alpha_ci = 0.2
alpha_fit = 1.0

# Final figure size: the base size scaled by fig_multiplier.
fig_multiplier = 0.75
base_figsize = (24, 12)
figsize = (fig_multiplier * base_figsize[0], fig_multiplier * base_figsize[1])

color_palette_models = COLOR_PALETTES['models']
color_palette_regions = COLOR_PALETTES['regions']
# First and last entries of the regions palette.
color_1, color_2 = color_palette_regions[0], color_palette_regions[-1]

In [None]:
# Panel titles keyed by experiment name, in plotting order.
# 'neuro' ('Neural') is intentionally left out of the per-region grid.
regionNames = dict(
    v1='V1',
    v2='V2',
    v4='V4',
    it='IT',
    behavior='Behavioral',
    avg='Average',
)

In [None]:
# 2x3 grid: one panel per entry of regionNames, each with the scatter of scores,
# the fitted curve and its 95% bootstrap band.
fig, axes = plt.subplots(2, 3, figsize=figsize, dpi=300)

if len(regionNames) > axes.size:
    raise ValueError(f'{len(regionNames)} regions do not fit into {axes.size} panels')

for ax, (exp_name, region_title) in zip(axes.flatten(), regionNames.items()):
    # Per-region data and fit. Loop-local names are used so the notebook-level
    # optimized_params_neuro / opt_params_boot_neuro are not overwritten.
    df_region = all_df[exp_name]
    region_params = optimized_params_dict[exp_name]
    region_params_boot = opt_params_boot_dict[exp_name]
    L = LOSS_FUNCTIONS[L_viz_dict[exp_name]]
    x_scaler = x_scale_dict[exp_name]
    X = df_region.total_flops.values / x_scaler

    color = color_2
    sns.scatterplot(data=df_region, x='total_flops', y='score', ax=ax, color=color, alpha=alpha_scatter)
    plot_reg(X, region_params, L, ax, color=color, x_extend=x_extend, linestyle='-', X_str=X_str, x_scaler=x_scaler, show_x_scaler=False, linewidth=linewidth, legend=False, alpha=alpha_fit)
    # x_extend is passed here as well so the band is requested over the same
    # extension as the fit line, matching the other figures in this notebook.
    plot_confidence_intervals(X, region_params_boot, L, ax, color=color, x_scaler=x_scaler, x_extend=x_extend, alpha=alpha_ci, percentile=95.0, invert_y=True)

    ### Formatting
    ax.set_xscale('log')
    ax.set_ylim(0, 0.5)
    ax.set_xlabel('Total Flops (C)', fontsize=16, fontweight='bold')
    ax.set_ylabel('Alignment Score (S)', fontsize=16, fontweight='bold')
    ax.set_title(region_title, fontsize=20, fontweight='bold')
    ax = set_ticks(ax, xticks_mode='log', yticks_mode=None, yticks=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])

    ### Legend
    ax.legend(loc='lower right')

    ax.spines[['right', 'top']].set_visible(False)

# Figure-level title, set once instead of on every loop iteration.
fig.suptitle('SimCLR Training - Regions', fontsize=24, fontweight='bold')

plt.tight_layout()


### Save
figures_dir = '../figures'
fig_name = 'fig13_simclr_regions'
formats = ['pdf', 'png', 'svg']
save_figs(figures_dir, fig_name, formats=formats)
