In [1]:
import sys
import pandas as pd
import numpy as np
from pathlib import Path
import seaborn as sns
import matplotlib.pyplot as plt

from helper import save_or_show

sys.path.append('..')
from scripts.helper import parse_datasets

In [2]:
# Which comparison to plot: True -> concatenated-feature ("combined") models,
# False -> ensembles. The output file name is derived from this choice.
combined = True
if combined:
    suffix = 'combined_concat'
else:
    suffix = 'ensemble'

# Whether to write the figure to disk, and the directory it is written to.
SAVE = True
storing_path = Path('/home/space/diverse_priors/results/plots/combined_vs_mean_single')

In [3]:
# Dataset names are read from a text file; '/' is replaced by '_' because the
# names are later used as directory names under results/aggregated/.
datasets_file = "../scripts/webdatasets_wo_imagenet.txt"
datasets = sorted(ds.replace('/', '_') for ds in parse_datasets(datasets_file))

# Facet grid layout: round the row count UP so every dataset gets an axis
# (floor division silently dropped the remainder), and keep at least one row.
ncols = 4
nrows = max(1, (len(datasets) + ncols - 1) // ncols)

In [4]:
def get_data_combined(ds, combined=True,
                      base_dir='/home/space/diverse_priors/results/aggregated',
                      filename='results_hp_size_imagenet1k.pkl'):
    """Compare each model set's joint accuracy with the mean of its members' single-model accuracies.

    Reads the aggregated single-model results and either the concatenated
    (``combined_models_concat``) or the ``ensemble`` results for one dataset.
    It keeps only rows with ``regularization == 'weight_decay'`` and
    ``seed == 0``, and returns one row per unique model set.

    Args:
        ds: Dataset name; used as a sub-directory of ``base_dir``.
        combined: If True read ``combined_models_concat`` results, otherwise
            ``ensemble`` results.
        base_dir: Root directory laid out as ``<base_dir>/<ds>/<experiment>/<filename>``.
        filename: Pickle file name inside each experiment directory.

    Returns:
        DataFrame with columns ``modelset``, ``mean_modelset_perf``,
        ``modelset_perf``, ``n_models`` (categorical) and ``dataset``.
        When no model set survives the filtering, the frame is empty but
        still has these columns.

    Raises:
        KeyError: If a model of some model set has no single-model result.
    """
    result_columns = ['modelset', 'mean_modelset_perf', 'modelset_perf', 'n_models', 'dataset']
    experiment = 'combined_models_concat' if combined else 'ensemble'

    # NOTE: read_pickle can execute arbitrary code; only point base_dir at trusted result files.
    sm = pd.read_pickle(f'{base_dir}/{ds}/single_model/{filename}')
    cm = pd.read_pickle(f'{base_dir}/{ds}/{experiment}/{filename}')

    def _weight_decay_seed0(df):
        # Only look at weight_decay and seed 0.
        mask = (df['regularization'] == 'weight_decay') & (df['seed'] == 0)
        return df[mask].reset_index(drop=True)

    sm = _weight_decay_seed0(sm)
    # Only look at unique model sets (first occurrence wins).
    cm = _weight_decay_seed0(cm).drop_duplicates(subset='model_ids').reset_index(drop=True)

    # Single-model runs have exactly one id in `model_ids`; index by it for lookup.
    sm = sm.assign(single_model=sm['model_ids'].apply(lambda ids: ids[0])).set_index('single_model')

    records = [
        {
            'modelset': modelset,
            'mean_modelset_perf': np.mean([sm.loc[model, 'test_lp_acc1'] for model in modelset]),
            'modelset_perf': modelset_perf,
            'n_models': len(modelset),
            'dataset': ds,
        }
        for modelset, modelset_perf in zip(cm['model_ids'], cm['test_lp_acc1'])
    ]
    # Explicit columns keep the schema intact even when `records` is empty.
    res = pd.DataFrame(records, columns=result_columns)
    res['n_models'] = res['n_models'].astype('category')
    return res

In [5]:
def plot_one_facet(res, ax, show_xlbl=True, show_ylbl=True, title=None):
    """Scatter combined-model accuracy against mean single-model accuracy for one dataset.

    Draws one point per model set (hue = ``n_models``) plus a dotted y = x
    reference line. It also prints every model set whose combined accuracy is
    below the mean accuracy of its member models.

    Args:
        res: Frame with columns ``mean_modelset_perf``, ``modelset_perf``,
            ``n_models``, ``modelset`` and ``dataset`` (as built by
            ``get_data_combined``).
        ax: Matplotlib Axes to draw on.
        show_xlbl: Whether to label the x axis.
        show_ylbl: Whether to label the y axis.
        title: Axes title. Defaults to the first value of ``res['dataset']``,
            or an empty string if ``res`` is empty.
    """
    sns.scatterplot(
        res,
        x='mean_modelset_perf',
        y='modelset_perf',
        hue='n_models',
        ax=ax,
    )

    # Dotted y = x line over the range of both accuracy columns; points above
    # it beat the mean of their single models.
    perf_cols = ['mean_modelset_perf', 'modelset_perf']
    vmin = res[perf_cols].min().min()
    vmax = res[perf_cols].max().max()
    diag = np.linspace(vmin, vmax, 20)
    ax.plot(diag, diag, c='grey', ls=':', alpha=0.5, zorder=-1)

    # Report model sets whose combined accuracy is below their members' mean.
    bad_perf = res[res['modelset_perf'] < res['mean_modelset_perf']]
    if len(bad_perf) > 0:
        print(bad_perf['dataset'].unique()[0])
        for _, row in bad_perf.iterrows():
            print(f"modelset={row['modelset']}, mean_single={round(row['mean_modelset_perf'], 3)}, combined={round(row['modelset_perf'],3)}")
        print('\n\n')

    ax.set_xlabel('Mean acc. single models' if show_xlbl else '')
    ax.set_ylabel('Acc. combined models' if show_ylbl else '')

    if title is None:
        # Take the title from the data itself rather than from a global `ds`.
        title = res['dataset'].iloc[0] if len(res) > 0 else ''
    ax.set_title(title)

In [6]:
# Load the comparison table for every dataset. Pass the `combined` flag from the
# config cell (not a hard-coded True) so the loaded data matches the `suffix`
# used for the output file name.
data_per_ds = {ds: get_data_combined(ds, combined=combined) for ds in datasets}

In [None]:
# One facet per dataset on an nrows x ncols grid. squeeze=False guarantees a 2-D
# array of axes even for a 1x1 grid.
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(5*ncols, 5*nrows), squeeze=False)
axes_flat = axes.ravel()
assert len(data_per_ds) <= axes_flat.size, (
    f"grid {nrows}x{ncols} too small for {len(data_per_ds)} datasets"
)

# Keep `ds` as the loop variable: plot_one_facet may take its title from a global `ds`.
for ax, (ds, res_data) in zip(axes_flat, data_per_ds.items()):
    plot_one_facet(res_data, ax)

# Hide grid cells that received no dataset instead of showing empty axes.
for ax in axes_flat[len(data_per_ds):]:
    ax.set_visible(False)
fig.tight_layout()

save_or_show(fig, storing_path / f'{suffix}.pdf', SAVE)

entity13
modelset=('OpenCLIP_ViT-L-14_laion400m_e32', 'beit_large_patch16_224', 'dino-vit-base-p16_gLocal', 'efficientnet_b6', 'mae-vit-huge-p14', 'swin_base_patch4_window7_224', 'vgg19'), mean_single=0.701, combined=0.691
modelset=('OpenCLIP_ViT-L-14_laion400m_e32_gLocal', 'beit_base_patch16_224', 'efficientnet_b7', 'mae-vit-huge-p14'), mean_single=0.699, combined=0.69
modelset=('dino-vit-base-p16_gLocal', 'mae-vit-huge-p14', 'seresnet50', 'vgg16'), mean_single=0.701, combined=0.694
modelset=('beit_base_patch16_224.in22k_ft_in22k', 'efficientnet_b3', 'vgg19'), mean_single=0.664, combined=0.656
modelset=('beit_large_patch16_224.in22k_ft_in22k', 'dino-vit-base-p16_gLocal', 'efficientnet_b6'), mean_single=0.651, combined=0.575
modelset=('OpenCLIP_ViT-B-16_openai', 'mae-vit-huge-p14', 'resnet152'), mean_single=0.715, combined=0.711
modelset=('beit_base_patch16_224.in22k_ft_in22k', 'efficientnet_b7', 'vit_base_patch16_224.augreg_in21k'), mean_single=0.667, combined=0.644
modelset=('OpenCLI