In [None]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.io as pio
import seaborn as sns

from constants import similarity_metrics

In [None]:
# Run configuration: which clustering, similarity metric and result variant to plot.
suffix = '_w_wd'  # appended to the result and plot folder names below
num_clusters = 7
similarity_metric = similarity_metrics[2]  # alternative used before: similarity_metrics[-2]

In [None]:
# Output folder for every figure of this run; it is only created when saving is on.
SAVE = True
plot_dir = (f'/home/space/diverse_priors/results/plots/single_models{suffix}/'
            f'{similarity_metric}/num_clusters_{num_clusters}/cluster_qr')
storing_path = Path(plot_dir)
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
base_path_perf = Path(
    f'/home/lciernik/projects/divers-priors/diverse_priors/benchmark/scripts/test_results/max_performance_per_model{suffix}')

# One JSON file per dataset, named max_performance_per_model_<dataset>.json. Each file is
# loaded as a Series whose keys become the row index (presumably model ids).
file_prefix = 'max_performance_per_model_'
all_series = {}
# Sorted so that the dataset/column order is reproducible; raw glob order depends on the filesystem.
for path in sorted(base_path_perf.glob(f'{file_prefix}*.json')):
    ds = path.stem[len(file_prefix):]  # the glob guarantees the stem starts with the prefix
    with open(path, 'r') as f:
        all_series[ds] = pd.Series(json.load(f))
# Fail here with a clear message instead of a KeyError several cells later.
assert all_series, f'No {file_prefix}*.json files found in {base_path_perf}'

In [None]:
# Wide table: one column per dataset, one row per key of the loaded JSON files.
df = pd.DataFrame.from_dict(all_series)
df.head()

In [None]:
path_clustering = Path(
    f'/home/space/diverse_priors/clustering/imagenet-subset-10k/{similarity_metric}/'
    f'num_clusters_{num_clusters}/cluster_qr/cluster_labels.csv'
)
# Cluster labels indexed by model id. The index is renamed to 'models'; 'Unnamed: 0' is
# pandas' name for an unlabeled leading CSV column (presumably a saved default index).
clustering = (
    pd.read_csv(path_clustering)
    .set_index('model_id')
    .rename_axis('models')
    .drop(columns=['Unnamed: 0'])
)

In [None]:
# Attach each model's cluster label, aligned on the model-id index. Models without a label
# get NaN and are filtered out further below. Select the single label column explicitly:
# assigning a whole DataFrame to one column does not reliably keep the category dtype
# across pandas versions.
assert clustering.shape[1] == 1, f'expected exactly one label column, got {list(clustering.columns)}'
df['cluster'] = clustering.iloc[:, 0].astype('category')

In [None]:
# Turn the model-id index into a regular 'models' column.
df = df.rename_axis('models').reset_index()

In [None]:
# Column order for all downstream tables and plots: identifiers, the ImageNet-1k reference,
# then every downstream dataset. The ImageNet shift sets (wds_imagenet-a, wds_imagenet-r,
# wds_imagenet_sketch, wds_imagenetv2) are currently excluded.
# (Name kept as `col_oders` because later cells reference it.)
col_oders = [
    'models', 'cluster', 'wds_imagenet1k',
    'cifar100-coarse', 'entity13', 'entity30', 'living17',
    'nonliving26', 'wds_cars', 'wds_country211', 'wds_fer2013',
    'wds_fgvc_aircraft', 'wds_gtsrb', 'wds_stl10', 'wds_voc2007',
    'wds_vtab_caltech101', 'wds_vtab_cifar10', 'wds_vtab_cifar100', 'wds_vtab_diabetic_retinopathy',
    'wds_vtab_dmlab', 'wds_vtab_dtd', 'wds_vtab_eurosat', 'wds_vtab_flowers',
    'wds_vtab_pcam', 'wds_vtab_pets', 'wds_vtab_resisc45', 'wds_vtab_svhn'
]
# Keep only the selected columns and only models that received a cluster label.
df = (
    df[col_oders]
    .dropna(subset=['cluster'])
    .reset_index(drop=True)
)

In [None]:
# Inspect the five models with the lowest ImageNet-1k accuracy.
df.sort_values(by='wds_imagenet1k').head(n=5)

In [None]:
# Pearson r between ImageNet-1k accuracy and each downstream dataset's accuracy.
# Series.corr excludes rows where EITHER value is missing. Dropping NaNs only in the
# downstream column would turn r into NaN whenever an ImageNet-1k value is missing.
r_values = {
    col: df['wds_imagenet1k'].corr(df[col])
    for col in col_oders[3:]  # skip 'models', 'cluster' and the reference column itself
}

In [None]:
# Long format: one row per (model, downstream dataset) pair. The model id, its cluster and
# its ImageNet-1k accuracy are carried along on every row.
df_melted = df.melt(
    id_vars=['models', 'cluster', 'wds_imagenet1k'],
    var_name='Dataset',
    value_name='Top-1 Acc of dataset',
)

In [None]:
x_col = "Top-1 Acc of dataset"
y_col = "wds_imagenet1k"
split_col = "Dataset"
hue_col = 'cluster'
# One panel per downstream dataset: that dataset's accuracy (x) against ImageNet-1k accuracy
# (y, shared across panels), coloured by cluster.
g = sns.relplot(data=df_melted, y=y_col, x=x_col, hue=hue_col, col=split_col, col_wrap=4, height=3, aspect=1,
                facet_kws={'sharex': False, 'sharey': True})

g.set_axis_labels(x_col, "ImageNet1k Top-1 Val Acc")
g.set_titles("Dataset: {col_name}", fontsize=16)


def annotate(data, **kws):
    """Write the precomputed Pearson r of one facet onto the current axes.

    Meant for ``FacetGrid.map_dataframe``. It receives one facet's rows as ``data``,
    all sharing one ``split_col`` value, and draws on ``plt.gca()``. Any extra keyword
    arguments seaborn passes are ignored.
    """
    r = r_values[data[split_col].unique()[0]]
    ax = plt.gca()
    ax.text(.7, .1, f'r = {r:.2f}', transform=ax.transAxes,
            fontsize=12, verticalalignment='top')


g.map_dataframe(annotate);

if SAVE:
    # Save through the grid's own figure instead of pyplot's implicit "current figure".
    for ext in ('pdf', 'png'):
        g.savefig(storing_path / f'scatter_in1k_vs_all_ds.{ext}', bbox_inches='tight')

In [None]:
pio.renderers.default = 'iframe'

x_col = "Top-1 Acc of dataset"
y_col = "wds_imagenet1k"
split_col = "Dataset"
hue_col = 'cluster'

# Interactive counterpart of the seaborn grid; hovering a point shows the model id.
fig = px.scatter(
    df_melted,
    x=x_col,
    y=y_col,
    color=hue_col,
    facet_col=split_col,
    facet_col_wrap=4,
    hover_data=['models'],
    width=1200,
    height=1750
)
# Each facet is a different dataset, so every subplot gets its own axis range.
fig.update_xaxes(matches=None, showticklabels=True)
fig.update_yaxes(matches=None, showticklabels=True)

# Resolve which axes belong to which facet title. In current Plotly Express versions the
# facet grid is built from the bottom-left cell and empty cells get no title, so the
# i-th title is not guaranteed to sit on axis i+1. A facet title is centred on its
# subplot's x-domain and placed at the top of its y-domain (paper coordinates), so
# titles are matched to subplots by that position instead.
layout = fig.to_dict()['layout']
subplot_refs = []  # (x-domain centre, y-domain top, xref, yref) for every subplot
for x_key, x_axis in layout.items():
    if not x_key.startswith('xaxis'):
        continue
    y_ref = x_axis['anchor']                      # 'y', 'y2', ...
    y_axis = layout['yaxis' + y_ref[1:]]
    x_ref = 'x' + x_key[len('xaxis'):]            # 'xaxis' -> 'x', 'xaxis2' -> 'x2'
    subplot_refs.append((sum(x_axis['domain']) / 2, y_axis['domain'][1], x_ref, y_ref))

# Snapshot the facet titles, because add_annotation below appends to fig.layout.annotations.
for title in list(fig.layout.annotations):
    ds = title.text.split('=', 1)[1]              # titles read "<split_col>=<dataset>"
    x_ref, y_ref = next(
        (xr, yr) for x_mid, y_top, xr, yr in subplot_refs
        if np.isclose(x_mid, title.x) and np.isclose(y_top, title.y)
    )
    # Place r at 10% / 90% of the subplot area ("domain" refs), independent of the data range.
    fig.add_annotation(
        x=0.1,
        y=0.9,
        xref=f"{x_ref} domain",
        yref=f"{y_ref} domain",
        text=f'r = {r_values[ds]:.3f}',
        showarrow=False,
        xanchor='center',
        yanchor='bottom',
    )

# Save the figure as an HTML file. No auto-open: fig.show() already renders it inline.
if SAVE:
    fig.write_html(storing_path / "scatter_in1k_vs_all_ds.html")
fig.show()

In [None]:
# Deliberate stop for "Run All": the cell below is legacy code that needs a `tmp`
# DataFrame which this notebook never defines.
raise ValueError('Intentional stop: the remaining cell is legacy code and expects an undefined `tmp` DataFrame.')

In [None]:
# NOTE(review): legacy cell kept for reference. It expects a `tmp` DataFrame with columns
# 'models', 'OOD Dataset', 'Top-1 Acc' and 'wds_imagenet1k', which this notebook never
# defines; the preceding cell deliberately stops "Run All" before this point.
pio.renderers.default = 'iframe'

# Create the Plotly express scatter plot
fig = px.scatter(tmp, x="Top-1 Acc", y="wds_imagenet1k", facet_col="OOD Dataset", facet_col_wrap=4,
                 hover_data=['models'])

# Pearson r per OOD dataset; Series.corr ignores rows where either accuracy is missing.
corr_per_ds = {
    ood_dataset: group['Top-1 Acc'].corr(group['wds_imagenet1k'])
    for ood_dataset, group in tmp.groupby('OOD Dataset')
}

# Match facet titles to their subplot axes by position. A title is centred on its subplot's
# x-domain and sits at the top of its y-domain. The order of unique() values does not
# follow Plotly Express' axis numbering once the facets wrap onto more than one row.
layout = fig.to_dict()['layout']
subplot_refs = []  # (x-domain centre, y-domain top, xref, yref) for every subplot
for x_key, x_axis in layout.items():
    if not x_key.startswith('xaxis'):
        continue
    y_ref = x_axis['anchor']                      # 'y', 'y2', ...
    y_axis = layout['yaxis' + y_ref[1:]]
    x_ref = 'x' + x_key[len('xaxis'):]            # 'xaxis' -> 'x', 'xaxis2' -> 'x2'
    subplot_refs.append((sum(x_axis['domain']) / 2, y_axis['domain'][1], x_ref, y_ref))

# Snapshot the facet titles, because add_annotation below appends to fig.layout.annotations.
for title in list(fig.layout.annotations):
    ood_dataset = title.text.split('=', 1)[1]     # titles read "OOD Dataset=<name>"
    x_ref, y_ref = next(
        (xr, yr) for x_mid, y_top, xr, yr in subplot_refs
        if np.isclose(x_mid, title.x) and np.isclose(y_top, title.y)
    )
    fig.add_annotation(
        x=0.2,
        y=0.9,
        xref=f"{x_ref} domain",
        yref=f"{y_ref} domain",
        xanchor='center',
        yanchor='bottom',
        text=f'r = {corr_per_ds[ood_dataset]:.2f}',
        showarrow=False,
        font=dict(size=12)
    )

# Save the figure as an HTML file (respecting SAVE like the other cells).
if SAVE:
    fig.write_html(storing_path / "scatter_in1k_vs_ood.html")

fig.show()
