In [None]:
# Automatically reload edited project modules (e.g. the analysis/ and
# visualization/ helpers imported below) before each cell executes, so
# changes to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import sys
import pickle
from itertools import product
import copy

import numpy as np
import pandas as pd
import scipy.stats as stats

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import seaborn as sns

# Global seaborn theme for every figure in this notebook. Only one call is
# needed: a later set_theme call fully replaces the style set by an earlier one.
sns.set_theme(style='ticks')


In [None]:
# Repository root, relative to the notebook's working directory (two levels up).
repo_dir = Path('..', '..')

In [None]:
# Put the repository root on the import path (only once, so re-running this
# cell does not grow sys.path) so the project-local packages imported below resolve.
if not any(entry == str(repo_dir) for entry in sys.path):
    sys.path.append(str(repo_dir))
    
from analysis.curve_fitting.src.fitting_functions import LOSS_FUNCTIONS
from analysis.curve_fitting.src.utils import apply_filters, load_yaml, convert_loss_parameters, convert_loss_parameters_batch

from visualization.src.utils import COLOR_PALETTES, set_ticks, save_figs, COLORS
from visualization.src.visualize import plot_reg, plot_reg_bivariate, plot_confidence_intervals



In [None]:
# Notebook inputs; paths are built from the repository root.
args = dict(
    results_csv=repo_dir.joinpath('results', 'benchmark_scores.csv'),
)

In [None]:
# Load the benchmark score table and list the architecture families it contains.
results_csv = args['results_csv']
df_results = pd.read_csv(results_csv)
df_results['arch_family'].unique()

## Load Experiment Configuration

In [None]:
# Filter configuration for the adversarially finetuned models, passed to
# apply_filters below. Restricts rows to the five region/behavioral groups,
# ImageNet, and the single adversarial recipe 'ffgsm_eps-1_alpha-125-ep10'.
config_adv = {
    'data_filters': {
        # Per-column allowed values.
        'set_filters': {
            'region': ['V1', 'V2', 'V4', 'IT', 'Behavioral'],
            'dataset': ['imagenet'],
            'adv_method': ['ffgsm_eps-1_alpha-125-ep10'],
        },
        # Boolean flag columns required to be False / True.
        'boolean_filters': {
            'equals_false': ['is_pretrained', 'is_random', 'is_ssl', 'is_ablation'],
            'equals_true': ['is_adv'],
        },
        # Group rows by these model-level keys and take the mean score.
        'group_by': {
            'avg_score': {
                'keys': [
                    'model_id', 'arch', 'dataset', 'flops', 'n_params',
                    'n_samples', 'n_samples_seen', 'total_flops',
                    'arch_family', 'samples_per_class', 'adv_method',
                ],
                'reduce': {'score': 'mean'},
            },
        },
        'combine_arch_families': True,
    },
}



In [None]:
analysis_dir = repo_dir / 'analysis'
# Active config/results locations. Alternatives that were previously assigned
# and immediately overwritten are kept here as comments for switching back:
#   config_dir  = analysis_dir / 'curve_fitting/configs/adv'
#   results_dir = analysis_dir / 'curve_fitting/outputs/fitting_results'
config_dir = analysis_dir / 'curve_fitting/configs/model/resnet'
# NOTE(review): fits are read from the *_test* output directory — confirm this
# is intended for the figure saved by this notebook.
results_dir = analysis_dir / 'curve_fitting/outputs/fitting_results_test'

In [None]:
# Name of the fitted experiment. It selects both the YAML config loaded here
# and the results sub-directory (model_<experiment_name>) read further down.
# Previously-overwritten alternative: "adv_scaling" (presumably paired with the
# curve_fitting/configs/adv directory).
experiment_name = "resnet_avg"
config_nonadv = load_yaml(config_dir / f'{experiment_name}.yaml')

In [None]:
# Loss parameterization used when fitting vs. the one used for plotting, plus the
# factor that compute values are divided by before evaluating the fit.
fit_params = config_nonadv['fitting_parameters']
L_fit = fit_params['loss_function']
L_viz = config_nonadv['visualization']['loss_function']
# float(): the YAML value may arrive as a string (e.g. PyYAML reads '1e15' as str).
x_scaler = float(fit_params['X_scaler'])

## Apply Data Filters

In [None]:
# Adversarially finetuned models; config_adv groups by model-level keys and
# takes the mean score. .copy() makes sure the column assignments below never
# act on a view of df_results.
df_adv = apply_filters(df_results, config_adv.get('data_filters', {})).copy()

# Recompute training compute. Reading of the constants (TODO confirm):
# 3 * flops per sample (presumably forward + backward); 10 passes over
# `n_samples` for finetuning (cf. '-ep10' in adv_method); and 100 passes over
# 1,281,167 images (the ImageNet-1k train split) for the pretrained start point.
# Cast to float first: the pretraining term is ~3.8e8 * flops and would
# silently overflow int64 (max ~9.2e18) if `flops` were an integer column.
flops_adv = df_adv['flops'].astype(float)
df_adv['total_flops'] = 3 * flops_adv * df_adv['n_samples'] * 10 + 3 * flops_adv * 100 * 1281167
# NOTE(review): this adds 100 * n_samples, while total_flops above assumes the
# 100 pretraining epochs covered all 1,281,167 images — confirm which is intended.
df_adv['n_samples_seen'] = df_adv['n_samples_seen'] + 100 * df_adv['n_samples']
df_adv['is_adv'] = True

In [None]:
# Clean (non-adversarial) reference models: disable architecture-family merging,
# then tag every row as not adversarially finetuned.
# NOTE: this mutates config_nonadv in place; later cells deep-copy it and
# therefore inherit combine_arch_families = False.
config_nonadv['data_filters']['combine_arch_families'] = False

df_nonadv = (
    apply_filters(df_results, config_nonadv.get('data_filters', {}))
    .assign(is_adv=False)
)
df_nonadv['arch'].unique()

In [None]:
# config3 = {
#     'data_filters': {
#         'set_filters': {
#             'region': [
#                 'V1',
#                 'V2',
#                 'V4',
#                 'IT',
#                 'Behavioral'
#                 ],
#             'dataset': [
#                 'imagenet',
#                 ],
#             'arch_family': [
#                 'ResNet',
#                 ],
#             },
#     }
# }

# df_scratch = apply_filters(df_results, config3.get('data_filters', {}))

# df_scratch.adv_method.unique()

In [None]:
# df_scratch = df_scratch[df_scratch.adv_method.isin([
#     'scratch-ffgsm_eps-1_alpha-125_lr-01',
#     'scratch-ffgsm_eps-2_alpha-25_lr-01',
#     'scratch-ffgsm_eps-4_alpha-5_lr-01',
# ])]

# df_scratch['total_flops'] = df_scratch['flops'] * df_scratch['n_samples'] * 100

In [None]:
# [i for i in df_results[df_results.is_adv].model_id.unique() if 'scratch' in i]

## Load Fitting Results

In [None]:
# Load the curve-fitting output for this experiment.
# NOTE(review): pickle.load can execute arbitrary code — only load locally produced files.
results_pkl = results_dir / f'model_{experiment_name}' / 'results.pkl'
with results_pkl.open('rb') as f:
    results = pickle.load(f)

# Re-express the point estimate and the bootstrap samples in the plotting
# loss's parameterization (L_fit -> L_viz).
optimized_params = convert_loss_parameters(results['optimized_parameters'], L_fit, L_viz)
opt_params_boot = convert_loss_parameters_batch(
    params=results['optimized_parameters_bootstrapped'],
    src_loss=L_fit,
    dst_loss=L_viz,
)

## Visualize

#### Plotting settings

In [None]:
# Settings for the single-panel adversarial-finetuning figure.
x_extend = 10  # forwarded to plot_reg / plot_confidence_intervals (presumably how far past the data the fit is drawn)
# The fit's X is compute (total_flops / x_scaler below). Label it C-tilde to match
# the x-axis label 'Total Flops (C)' and the X_str used for the region grid.
X_str = r'$$\tilde{C}$$'
linewidth = 3.0
alpha_scatter = 1.0
alpha_ci = 0.2
alpha_fit = 1.0
fig_multiplier = 0.7
figsize = (fig_multiplier * 10, fig_multiplier * 6)

# Two-colour palette for the boolean hue: first and last entries of the model palette.
color_palette = [COLOR_PALETTES['models'][0], COLOR_PALETTES['models'][-1]]
color = "#023e8a"  # colour of the fitted curve and its confidence band

L = LOSS_FUNCTIONS[L_viz]
X = df_nonadv.total_flops.values / x_scaler



#### Plot

In [None]:
# Combine adversarially finetuned and clean models into one plotting frame.
# Rows are ordered by ResNet depth (presumably so the 'Model' marker legend
# lists depths in increasing order).
resnet_archs = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
model_sort = {arch: rank for rank, arch in enumerate(resnet_archs)}
model_name_map = {arch: 'ResNet-' + arch[len('resnet'):] for arch in resnet_archs}

df_concat = (
    pd.concat([df_adv, df_nonadv])
    .reset_index(drop=True)
    .assign(sort=lambda d: d['arch'].map(model_sort))
    .sort_values('sort')
    .reset_index(drop=True)
    .replace({'arch': model_name_map})
)
# Display-friendly aliases used as seaborn hue/style variables (they become legend titles).
df_concat['Adversarially Finetuned'] = df_concat['is_adv']
df_concat['Model'] = df_concat['arch']

In [None]:
# Single-panel figure: adversarially finetuned vs. clean ResNets, with the
# scaling-law fit loaded above (and its 95% bootstrap band) overlaid.
fig, axes = plt.subplots(1, 1, figsize=figsize, dpi=300)
ax = axes

sns.scatterplot(
    data=df_concat,
    x='total_flops',
    y='score',
    style='Model',
    hue='Adversarially Finetuned',
    ax=ax,
    s=120,
    palette=color_palette,
)

# Keyword arguments shared by the fit line and its confidence band.
fit_kwargs = dict(color=color, x_extend=x_extend, x_scaler=x_scaler)
plot_reg(X, optimized_params, L, ax, linestyle='-', X_str=X_str, show_x_scaler=False,
         linewidth=linewidth, legend=True, alpha=alpha_fit, **fit_kwargs)
plot_confidence_intervals(X, opt_params_boot, L, ax, alpha=alpha_ci, percentile=95.0,
                          invert_y=True, **fit_kwargs)

# Formatting
ax.set_xscale('log')
ax.set_xlabel('Total Flops (C)', fontsize=16, fontweight='bold')
ax.set_ylabel('Alignment Score (S)', fontsize=16, fontweight='bold')
ax.set_title('Adversarial Finetuning', fontsize=20, fontweight='bold')
ax.grid(False)
ax = set_ticks(ax, xticks_mode='log', yticks_mode=None, yticks=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
ax.spines[['right', 'top']].set_visible(False)
plt.tight_layout()

figures_dir = '../figures'
fig_name = 'fig7_adv'
formats = ['pdf', 'png', 'svg']
save_figs(figures_dir, fig_name, formats=formats)

In [None]:
# Settings for the per-region 2x3 grid figure.
x_extend = 1.1
X_str = r'$$\tilde{C}$$'
linewidth = 3.0
alpha_scatter, alpha_ci, alpha_fit = 1, 0.2, 1.0
fig_multiplier = 0.75
figsize = (fig_multiplier * 24, fig_multiplier * 12)

color_palette_models = COLOR_PALETTES['models']
color_palette_regions = COLOR_PALETTES['regions']
# color_1 / color_2: first and last colours of the region palette.
color_1, color_2 = color_palette_regions[0], color_palette_regions[-1]

In [None]:
# Lower-case region keys -> labels used in the results table's `region` column.
regionNames = dict(zip(
    ['v1', 'v2', 'v4', 'it', 'behavior'],
    ['V1', 'V2', 'V4', 'IT', 'Behavioral'],
))

In [None]:
# Same filters as for df_nonadv / df_adv but with grouping disabled, presumably
# keeping one row per benchmark so scores can be split by `region` below.
# .copy() makes sure the column assignments never act on a view of df_results.
config_nonadv2 = copy.deepcopy(config_nonadv)
config_nonadv2['data_filters']['group_by'] = {}
df_nonadv2 = apply_filters(df_results, config_nonadv2.get('data_filters', {})).copy()

config_adv2 = copy.deepcopy(config_adv)
config_adv2['data_filters']['group_by'] = {}
df_adv2 = apply_filters(df_results, config_adv2.get('data_filters', {})).copy()

# Same compute accounting as for df_adv: 3 * flops per sample (presumably
# forward + backward), 10 finetuning passes over n_samples, and 100 passes over
# 1,281,167 images (ImageNet-1k train) for pretraining — TODO confirm.
# Cast to float first: the pretraining term is ~3.8e8 * flops and would
# silently overflow int64 if `flops` were an integer column.
flops_adv2 = df_adv2['flops'].astype(float)
df_adv2['total_flops'] = 3 * flops_adv2 * df_adv2['n_samples'] * 10 + 3 * flops_adv2 * 100 * 1281167
# NOTE(review): adds 100 * n_samples, while total_flops assumes 100 epochs over
# all 1,281,167 images — confirm which pretraining set size is intended.
df_adv2['n_samples_seen'] = df_adv2['n_samples_seen'] + 100 * df_adv2['n_samples']

df_adv2['is_adv'] = True
df_nonadv2['is_adv'] = False
df_concat2 = pd.concat([df_adv2, df_nonadv2]).reset_index(drop=True)

In [None]:
def _format_alignment_panel(ax, title):
    """Apply the shared log-compute x-axis, axis labels, title and spines to one panel."""
    ax.set_xscale('log')
    ax.set_xlabel('Total Flops (C)', fontsize=16, fontweight='bold')
    ax.set_ylabel('Alignment Score (S)', fontsize=16, fontweight='bold')
    ax.set_title(title, fontsize=20, fontweight='bold')
    ax.spines[['right', 'top']].set_visible(False)


fig, axes = plt.subplots(2, 3, figsize=figsize, dpi=300)
panels = axes.ravel()
# One panel per region in regionNames, plus a final panel for the average.
assert len(regionNames) < panels.size, 'need one spare panel for the average'

for ax, region_label in zip(panels, regionNames.values()):
    df_region = df_concat2[df_concat2.region == region_label]
    sns.scatterplot(data=df_region, x='total_flops', y='score', ax=ax, hue='is_adv',
                    alpha=alpha_scatter, palette=color_palette)
    # TODO: overlay per-region fits (plot_reg / plot_confidence_intervals) once
    # per-region fitting results are loaded in this notebook.
    _format_alignment_panel(ax, region_label)

# Last panel: each model's mean score over all of its benchmark rows.
# reset_index() turns the group keys back into columns, so seaborn does not
# have to resolve 'total_flops' as an index level.
ax = panels[-1]
df_avg = (
    df_concat2
    .groupby(['model_id', 'total_flops', 'arch', 'n_samples', 'is_adv'])
    .agg({'score': 'mean'})
    .reset_index()
)
sns.scatterplot(data=df_avg, x='total_flops', y='score', hue='is_adv', ax=ax,
                alpha=alpha_scatter, palette=color_palette)
_format_alignment_panel(ax, 'Average')

plt.tight_layout()

# figures_dir = '../figures'
# fig_name = 'fig5_regions_compare'
# formats = ['pdf', 'png', 'svg']
# save_figs(figures_dir, fig_name, formats=formats)
