In [None]:
# Automatically re-import edited modules (e.g. the project's `analysis` / `visualization` helpers) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import sys
import pickle
from itertools import product

import numpy as np
import pandas as pd
import scipy.stats as stats

from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import seaborn as sns


# `sns.set_theme` resets every seaborn style parameter on each call, so a
# preceding `set_theme(style='whitegrid')` would be fully overridden; only the
# 'ticks' style is in effect for the figures below.
sns.set_theme(style='ticks')



In [None]:
# Repository root, two directory levels above the notebook's working directory.
repo_dir = Path('..', '..')

In [None]:
# Register the repository root on the import path (only once, so re-running
# this cell does not add duplicates) so the project packages below resolve.
if not any(entry == str(repo_dir) for entry in sys.path):
    sys.path.append(str(repo_dir))
    
from analysis.curve_fitting.src.fitting_functions import LOSS_FUNCTIONS
from analysis.curve_fitting.src.utils import apply_filters, load_yaml, convert_loss_parameters, convert_loss_parameters_batch

from visualization.src.utils import COLOR_PALETTES, set_ticks, save_figs, COLORS
from visualization.src.visualize import plot_reg, plot_reg_bivariate, plot_confidence_intervals


In [None]:
# Paths of the two benchmark-score tables, under `<repo>/results/`.
args = dict(
    results_csv=repo_dir / 'results' / 'benchmark_scores.csv',
    results_csv_ckpts=repo_dir / 'results' / 'benchmark_scores_ckpts.csv',
)

In [None]:
# Read both score tables configured in `args` (the second one holds the
# per-checkpoint scores used throughout this notebook).
results_csv, results_csv_ckpts = args['results_csv'], args['results_csv_ckpts']
df_results, df_results_ckpts = (pd.read_csv(csv_path) for csv_path in (results_csv, results_csv_ckpts))

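## Experiment Configuration

The input score tables are configured in `args` above; the subset of models to compare is selected in the filtering cells below.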

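## Apply Data Filters

Drop checkpoint rows flagged as pretrained, random, or ablation runs, then keep only the SimCLR, DINO, and supervised models compared in the figures.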

In [None]:
# Keep only checkpoint rows whose `is_pretrained`, `is_random` and `is_ablation`
# flags are all False. NOTE: rows are not filtered on `is_adv`; add it to the
# column list below to exclude those as well.
df_ckpts = df_results_ckpts[
    df_results_ckpts[['is_pretrained', 'is_random', 'is_ablation']].eq(False).all(axis=1)
]

# Subset of the retained checkpoints flagged as self-supervised.
df_ssl = df_ckpts[df_ckpts['is_ssl'].eq(True)]

In [None]:
# sorted(df_ckpts.model_id.unique())

In [None]:
# df_ssl.model_id.unique()

In [None]:
# Model runs compared in the figures below.
SELECTED_MODEL_IDS = [
    'simclr_resnet50_imagenet_full_seed-0',
    'simclr_vit_small_imagenet_full_seed-0',
    'dino_resnet50_imagenet_full_seed-0',
    'dino_vit_small_imagenet_full_seed-0',
    'resnet50_imagenet_full',
    'deit_small_imagenet_full_seed-0',
]

# Display names for raw `arch` values; DeiT variants share the ViT label of the
# same size so supervised and SSL ViTs are drawn with the same line style.
ARCH_LABELS = {
    'resnet50': 'ResNet50',
    'resnet18': 'ResNet18',
    'vit_small': 'ViT-S',
    'vit_base': 'ViT-B',
    'deit_small': 'ViT-S',
    'deit_base': 'ViT-B',
}

# Display names for `ssl_method`; non-SSL rows are labelled 'Supervised' first.
METHOD_LABELS = {
    'simclr': 'SimCLR',
    'dino': 'DINO',
    'Supervised': 'Supervised',
}

df = df_ckpts[df_ckpts['model_id'].isin(SELECTED_MODEL_IDS)].copy()

missing_ids = sorted(set(SELECTED_MODEL_IDS) - set(df['model_id']))
assert not missing_ids, f'model ids not found in checkpoint results: {missing_ids}'

df['arch'] = df['arch'].map(ARCH_LABELS)
df.loc[~df['is_ssl'], 'ssl_method'] = 'Supervised'
df['ssl_method'] = df['ssl_method'].map(METHOD_LABELS)

# `Series.map(dict)` turns unmapped values into NaN, and the `groupby` used for
# aggregation drops NaN keys by default -- fail loudly instead of silently
# losing rows from the plots.
unmapped = df[['arch', 'ssl_method']].isna().any()
assert not unmapped.any(), f'unmapped label values in columns: {list(unmapped[unmapped].index)}'

# Human-readable column names used as seaborn legend titles.
df['Model'] = df['arch']
df['Learning Method'] = df['ssl_method']

df

In [None]:
# Mean `score` for each unique combination of the keys below (one row per model
# checkpoint plus its descriptive columns), collapsing e.g. the per-region rows.
# Rows with NaN in any key are dropped by `groupby` (default `dropna=True`).
df_avg = (
    df.groupby(
        ['model_id', 'arch', 'dataset', 'arch_family', 'ckpt', 'n_samples_seen',
         'is_ssl', 'ssl_method', 'Learning Method', 'Model']
    )
    .agg(score=('score', 'mean'))
    .reset_index()
)
df_avg

## Visualize

### Plotting settings

In [None]:
# Shared plotting constants: the base 10x6 inch figure is shrunk by `fig_multiplier`.
fig_multiplier = 0.7
linewidth = 3.0
figsize = tuple(fig_multiplier * side for side in (10, 6))

color_palette, color_palette_regions = COLOR_PALETTES['models'], COLOR_PALETTES['regions']

In [None]:
fig, axes = plt.subplots(1, 1, figsize=figsize, dpi=300)
ax = axes  # with nrows = ncols = 1, `subplots` returns a single Axes

data_plot = df_avg.copy()

# Mean alignment score per checkpoint: colour = learning method,
# line style / markers = backbone.
sns.lineplot(
    data=data_plot, x='ckpt', y='score',
    hue='Learning Method', style='Model', markers=True,
    palette=color_palette[::2], linewidth=linewidth, ax=ax,
)

ax.set_xlabel('Training Epoch', fontsize=16, fontweight='bold')
ax.set_ylabel('Alignment Score (S)', fontsize=16, fontweight='bold')
ax.set_title('Alignment During Training - SSL', fontsize=20, fontweight='bold')

# Switch to log scale BEFORE customizing ticks: `set_xscale` re-installs the
# scale's default locators/formatters, which would discard x ticks set earlier.
ax.set_xscale('log')
set_ticks(ax, xticks_mode='log', yticks_mode=None, yticks=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])

ax.spines[['right', 'top']].set_visible(False)
plt.tight_layout()

figures_dir = '../figures'
fig_name = 'fig7_ssl'
formats = ['pdf', 'png', 'svg']
save_figs(figures_dir, fig_name, formats=formats)

In [None]:
# Cell-local figure size: reusing the global `figsize` here would silently
# change the size of the previous figure if that cell is re-run afterwards.
figsize_regions = (fig_multiplier * 12, fig_multiplier * 12)
fig, axes = plt.subplots(2, 1, figsize=figsize_regions, dpi=300)

# One panel per backbone; un-averaged rows, line style distinguishes regions.
for ax, model in zip(axes, ['ResNet50', 'ViT-S']):
    data_plot = df[df.arch == model].copy()

    sns.lineplot(data=data_plot, x='ckpt', y='score', markers=True, style='region',
                 hue='Learning Method', ax=ax, palette=color_palette[::2])

    ax.set_xlabel('Training Epoch', fontsize=16, fontweight='bold')
    ax.set_ylabel('Alignment Score (S)', fontsize=16, fontweight='bold')
    ax.set_title(model, fontsize=20, fontweight='bold')

    # Log scale first: `set_xscale` resets tick locators/formatters.
    ax.set_xscale('log')
    set_ticks(ax, xticks_mode='log', yticks_mode=None, yticks=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5])

    # Drop the per-panel legend; one shared legend is added below.
    ax.legend().remove()

# Shared legend built from the bottom panel's handles (assumes both panels show
# the same hue/style levels), placed outside the axes on the right.
handles, labels = axes[-1].get_legend_handles_labels()
axes[-1].legend(handles, labels, bbox_to_anchor=(1.05, 1))
plt.suptitle('Alignment During Training - Per Region', fontsize=24, fontweight='bold')
plt.tight_layout()

# Uncomment to export this figure.
# figures_dir = '../figures'
# fig_name = 'fig11_ssl_regions'
# formats = ['pdf', 'png', 'svg']
# save_figs(figures_dir, fig_name, formats=formats)