In [None]:
import numpy as np
import pandas as pd
from sklearn import metrics
from typing import Tuple, Dict


In [None]:
# Names of the columns that `formalize_pair_type` adds to a pairs-distance frame
# and that the AUC / ROC helpers in this notebook read back.
# Dataset name: every token of the pair 'type' before the last three.
DATASET_LABEL = 'dataset'
# Orientation token of the pair 'type' (plot_ROC filters on 'upright' / 'inverted').
ORIENTATION_LABEL = 'orientation'
# Boolean pair label, used as y_true by the ROC / AUC computations.
SAME_LABEL = 'same'
# Last token of the pair 'type', kept as a string.
SAMPLE_IDX_LABEL = 'sample_idx'

In [None]:
# Model name -> path of the `dists.csv` artifact for that model's run.
# Read by the AUC summary loop further down (`for model in paths`).
# `paths` is rebound by later cells in this notebook.
paths = {
    '260_inanimates': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/369a729bcbe64d83ac49b37577db79b8/artifacts/dists.csv',
    '260_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/8675b29e48f84a0b99870f006a74c71d/artifacts/dists.csv',
    # NOTE(review): the commented-out entry below looks like an earlier '260_faces' run that was
    # replaced by the '260_faces_fixed' one — confirm before deleting it.
#     '260_faces': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/ddc0ee34ee8c4b94b3b16b5f7afa6fe5/artifacts/dists.csv',
    '260_faces': '/home/hdd_storage/mlflow/artifact_store/260_faces_fixed/13a8cd83f47f4c58a714dfe035e6cd56/artifacts/dists.csv',
    '30_sociable_weavers': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/c6b3733cbac74bee87ad03213ff13188/artifacts/dists.csv',
    '30_faces': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/03aaf9b8374847beaa18c550f4213d4f/artifacts/dists.csv',
    '30_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/eb293b67f8624e698b06e5fa382ef005/artifacts/dists.csv',
    '30_inanimates': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/10f88d6edd764503b4322b127f8c87ae/artifacts/dists.csv'
}


In [None]:
# Same seven runs as `paths`, reordered so that the '260_*' and '30_*' entries of each
# dataset family sit next to each other (inanimates, species, faces, then sociable weavers).
# NOTE(review): `diagonal_paths` is rebound by a later cell before any visible use of this
# mapping — confirm whether this cell is still needed.
diagonal_paths = {
    '260_inanimates': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/369a729bcbe64d83ac49b37577db79b8/artifacts/dists.csv',
    '30_inanimates': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/10f88d6edd764503b4322b127f8c87ae/artifacts/dists.csv',
    '260_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/8675b29e48f84a0b99870f006a74c71d/artifacts/dists.csv',
    '30_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/eb293b67f8624e698b06e5fa382ef005/artifacts/dists.csv',
#     '260_faces': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/ddc0ee34ee8c4b94b3b16b5f7afa6fe5/artifacts/dists.csv',
    '260_faces': '/home/hdd_storage/mlflow/artifact_store/260_faces_fixed/13a8cd83f47f4c58a714dfe035e6cd56/artifacts/dists.csv',
    '30_faces': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/03aaf9b8374847beaa18c550f4213d4f/artifacts/dists.csv',
    '30_sociable_weavers': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/c6b3733cbac74bee87ad03213ff13188/artifacts/dists.csv',
}


In [None]:
# Same seven runs as `paths`, ordered with all '260_*' models first and the '30_*' models after.
# NOTE(review): nothing in the visible part of this notebook reads `cls_diagonal_paths` —
# confirm whether it is still needed.
cls_diagonal_paths = {
    '260_inanimates': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/369a729bcbe64d83ac49b37577db79b8/artifacts/dists.csv',
    '260_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/8675b29e48f84a0b99870f006a74c71d/artifacts/dists.csv',
    '260_faces': '/home/hdd_storage/mlflow/artifact_store/260_faces_fixed/13a8cd83f47f4c58a714dfe035e6cd56/artifacts/dists.csv',
    '30_inanimates': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/10f88d6edd764503b4322b127f8c87ae/artifacts/dists.csv',
    '30_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/eb293b67f8624e698b06e5fa382ef005/artifacts/dists.csv',
    '30_faces': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/03aaf9b8374847beaa18c550f4213d4f/artifacts/dists.csv',
    '30_sociable_weavers': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/c6b3733cbac74bee87ad03213ff13188/artifacts/dists.csv'
}


In [None]:
def formalize_pair_type(pairs_list: pd.DataFrame) -> pd.DataFrame:
    """
    Expand the 'type' column of a pairs-distance list into separate columns.

    Every 'type' value is split on '_' and read from the right, so the dataset
    name itself may contain underscores:
        [DATASET_LABEL]     = all tokens before the last three, re-joined with '_'
        [ORIENTATION_LABEL] = third token from the end
        [SAME_LABEL]        = True when the second token from the end is NOT 'same'
        [SAMPLE_IDX_LABEL]  = last token (the batch sample idx), kept as a string

    The columns are written into ``pairs_list`` itself, which is also returned.
    """
    tokens = pairs_list['type'].str.split('_')

    # (target column, extractor applied to the token list of each row), in insertion order
    extractors = (
        (DATASET_LABEL, lambda parts: '_'.join(parts[:-3])),
        (ORIENTATION_LABEL, lambda parts: parts[-3]),
        (SAME_LABEL, lambda parts: not parts[-2] == 'same'),
        (SAMPLE_IDX_LABEL, lambda parts: parts[-1]),
    )
    for column, extract in extractors:
        pairs_list[column] = tokens.apply(extract)
    return pairs_list

In [None]:
def get_aucs(df: pd.DataFrame) -> Tuple[pd.Series, pd.DataFrame]:
    """
    Compute ROC AUCs from a formalized pairs-distance list.

    One AUC is computed per (dataset, orientation, sample idx) group, using the
    SAME_LABEL column as the label and the 'fc7' column as the score.

    returns (series of per-group AUCs,
             frame with the 'means' and 'std' of those AUCs per (dataset, orientation))
    """
    condition_keys = [DATASET_LABEL, ORIENTATION_LABEL]

    def batch_auc(batch: pd.DataFrame):
        return metrics.roc_auc_score(batch[SAME_LABEL], batch['fc7'])

    aucs = df.groupby(condition_keys + [SAMPLE_IDX_LABEL]).apply(batch_auc)

    # Collapse the sample idx level: one mean / std per (dataset, orientation).
    per_condition = aucs.groupby(condition_keys)
    summary = pd.DataFrame({'means': per_condition.mean(), 'std': per_condition.std()})
    return aucs, summary

In [None]:
# For every model in `paths`: load its pair distances, derive the per-sample AUCs
# and keep the per-condition summary, tagged with the model name.
dfs, models_aucs_series = {}, {}
total_summary = []

for model, csv_path in paths.items():
    df = formalize_pair_type(pd.read_csv(csv_path))
    dfs[model] = df
    curr_aucs, curr_summary = get_aucs(df)
    models_aucs_series[model] = curr_aucs
    curr_summary['model'] = model  # keeps rows attributable after the concat below
    total_summary.append(curr_summary)

# One AUC column per model, aligned on the (dataset, orientation, sample idx) index.
models_aucs = pd.DataFrame(models_aucs_series)
total_summary = pd.concat(total_summary)

In [None]:
# Display the dict of per-model AUC series (model name -> Series returned by get_aucs).
models_aucs_series

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots


In [None]:
def _mean_roc_curve(subset: pd.DataFrame, sample_indices, fpr_grid: np.ndarray):
    """
    Average the per-sample ROC curves of ``subset`` on a common FPR grid.

    For every sample idx in ``sample_indices`` the rows whose SAMPLE_IDX_LABEL
    (compared as a string) matches are turned into a ROC curve, with SAME_LABEL as
    the label and 'fc7' as the score, and the TPR is interpolated onto ``fpr_grid``.
    Sample indices with no rows are skipped.

    returns (mean TPR, std of TPR) over the samples, or None when no sample had rows
    """
    sample_ids = subset[SAMPLE_IDX_LABEL].astype(str)
    interp_tprs = []
    for sample_idx in sample_indices:
        batch = subset[sample_ids == str(sample_idx)]
        if batch.empty:
            continue
        fpr, tpr, _ = metrics.roc_curve(batch[SAME_LABEL], batch['fc7'], pos_label=1, drop_intermediate=False)
        interp_tpr = np.interp(fpr_grid, fpr, tpr)
        interp_tpr[0] = 0.0  # force every curve to start at the origin
        interp_tprs.append(interp_tpr)
    if not interp_tprs:
        return None
    return np.mean(interp_tprs, axis=0), np.std(interp_tprs, axis=0)


def plot_ROC(dists: Dict[str, pd.DataFrame], fn: str, show_tpr_variance: bool = True,
             sample_indices: Tuple[int, ...] = tuple(range(1, 30)),
             out_dir: str = '/home/ssd_storage/experiments/Expertise/experiment2/new_design') -> None:
    """
    Plot a grid of mean ROC curves: one row per model, one column per dataset,
    with an 'upright' and an 'inverted' curve in every panel.

    dists: model name -> formalized pairs-distance frame (see formalize_pair_type).
        The dict order sets the row order.
    fn: output file name without extension; the figure is written to
        '{out_dir}/{fn}.html' and that path is printed.
    show_tpr_variance: when True, also draw a filled band of +/- 2 standard
        deviations of the TPR around each mean curve (upper edge clipped to [0, 1]).
    sample_indices: sample idx values whose ROC curves are averaged. The default
        is 1..29. NOTE(review): confirm that idx 0 / 30 are meant to be left out.
    out_dir: directory that receives the HTML file.

    Column layout: 2, 3 or 4 models use the fixed dataset lists below (2 models
    also hide the legend). Any other number of models uses the dataset names found
    in the frames, in order of first appearance, with the legend shown.
    A (model, dataset, orientation) combination without any rows leaves its curve out
    instead of failing.

    raises ValueError: when ``dists`` is empty or no dataset column can be derived.
    """
    if not dists:
        raise ValueError('plot_ROC needs at least one model in `dists`')

    colors = {
        'upright': {'fill': 'darkred', 'line': 'darkred', 'line_main': 'darkred'},
        'inverted': {'fill': 'darkblue', 'line': 'darkblue', 'line_main': 'darkblue'},
    }
    grid_color = 'rgba(189, 195, 199, 0.5)'

    # number of models -> (dataset columns, show legend)
    known_layouts = {
        4: (['inanimate_objects', 'species', 'faces', 'sociable_weavers'], True),
        3: (['inanimate_objects', 'species', 'faces'], True),
        2: (['inanimate_objects', 'faces'], False),
    }
    if len(dists) in known_layouts:
        datasets, show_legend = known_layouts[len(dists)]
    else:
        # No fixed layout for this many models: use whatever datasets the data holds.
        datasets = list(dict.fromkeys(
            str(name) for frame in dists.values() for name in frame[DATASET_LABEL].unique()))
        show_legend = True
    if not datasets:
        raise ValueError('plot_ROC found no dataset to plot in `dists`')

    n_rows = len(dists)
    n_cols = len(datasets)
    fig = make_subplots(rows=n_rows, cols=n_cols, column_titles=datasets,
                        row_titles=list(dists.keys()), horizontal_spacing=0.07)

    fpr_mean = np.linspace(0, 1, 100)
    for k, (model, model_df) in enumerate(dists.items()):
        for j, dataset in enumerate(datasets):
            dataset_df = model_df[model_df[DATASET_LABEL] == dataset]
            for orientation in ['upright', 'inverted']:
                subset = dataset_df[dataset_df[ORIENTATION_LABEL] == orientation]
                curve = _mean_roc_curve(subset, sample_indices, fpr_mean)
                if curve is None:
                    continue
                tpr_mean, tpr_spread = curve
                tpr_std = 2 * tpr_spread
                tpr_upper = np.clip(tpr_mean + tpr_std, 0, 1)
                tpr_lower = tpr_mean - tpr_std
                palette = colors[orientation]
                if show_tpr_variance:
                    # 'lower' must directly follow 'upper': fill='tonexty' fills up to the previous trace.
                    fig.add_trace(
                        go.Scatter(
                            x=fpr_mean, y=tpr_upper, line=dict(color=palette['line'], width=1), hoverinfo="skip",
                            showlegend=False, name='upper'),
                        row=k + 1, col=j + 1)
                    fig.add_trace(
                        go.Scatter(
                            x=fpr_mean, y=tpr_lower, fill='tonexty', fillcolor=palette['fill'],
                            line=dict(color=palette['line'], width=1), hoverinfo="skip",
                            showlegend=False, name='lower'),
                        row=k + 1, col=j + 1)
                fig.add_trace(
                    go.Scatter(
                        x=fpr_mean, y=tpr_mean, line=dict(color=palette['line_main'], width=2), hoverinfo="skip",
                        showlegend=show_legend, name=f'{model} trained, tested on {orientation} {dataset}'),
                    row=k + 1, col=j + 1)

    # The legend, when shown, gets 400 extra pixels of width.
    fig.update_layout(
        template='plotly_white', width=200 * n_cols + show_legend * 400, height=200 * n_rows)
    fig.update_yaxes(range=[0, 1], gridcolor=grid_color, scaleanchor="x", scaleratio=1, linecolor='black')
    fig.update_xaxes(range=[0, 1], gridcolor=grid_color, constrain='domain', linecolor='black')
    fig.show()

    out_path = f'{out_dir}/{fn}.html'
    fig.write_html(out_path)
    print(out_path)

In [None]:
def plot_ordered(csv_pths: Dict[str, str], fn: str, show_tpr_variance: bool = True,
                 aucs_csv: str = '/home/ssd_storage/experiments/expertise_cls_size_control/faces_260_aucs.csv') -> None:
    """
    Load the pair distances of several models, report their AUCs and plot their ROCs.

    csv_pths: model name -> path of that model's dists CSV. The dict order is
        passed on to plot_ROC, where it sets the row order.
    fn: output file name (without extension) handed to plot_ROC.
    show_tpr_variance: handed to plot_ROC.
    aucs_csv: where the table of per-sample AUCs (one column per model) is written.
        The default is the single fixed path this function has always used, so
        successive calls overwrite each other unless a path is passed.

    For every model this prints the per-sample AUC series, followed by the mean AUC
    per (dataset, orientation). The means are taken from the summary that get_aucs
    already computes, rather than by averaging a frame that still carries the
    string sample idx column, which pandas >= 2.0 refuses to average.
    """
    dfs = {}
    models_aucs_series = {}

    for model, csv_path in csv_pths.items():
        df = formalize_pair_type(pd.read_csv(csv_path))
        dfs[model] = df
        curr_aucs, curr_summary = get_aucs(df)
        print(curr_aucs)
        print(curr_summary['means'])
        models_aucs_series[model] = curr_aucs

    models_aucs = pd.DataFrame(models_aucs_series)
    models_aucs.to_csv(aucs_csv)
    plot_ROC(dfs, fn, show_tpr_variance)

In [None]:
# Seven models, ordered by dataset family: the '260_*' and '30_*' entries of inanimates,
# species and faces side by side, then sociable weavers. The dict order is the subplot
# row order. This rebinds `diagonal_paths`, replacing the mapping from the earlier cell.
diagonal_paths = {
    '260_inanimates': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/665cb3935bac4f8f9e462be5cda28745/artifacts/dists.csv',
    '30_inanimates': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/4a897372f2414e5e9bab33a0f0fb94c4/artifacts/dists.csv',
    '260_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/8675b29e48f84a0b99870f006a74c71d/artifacts/dists.csv',
    '30_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/eb293b67f8624e698b06e5fa382ef005/artifacts/dists.csv',
    '260_faces': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/ad27097ceff64b719e8da3c9593e58b9/artifacts/dists.csv',
    '30_faces': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/214575675c834788ab7141d87789f314/artifacts/dists.csv',
    '30_sociable_weavers': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/c6b3733cbac74bee87ad03213ff13188/artifacts/dists.csv',
}
# Third argument True: draw the filled TPR band around every mean ROC curve.
plot_ordered(diagonal_paths, 'ROCS_domain_diagonal_datasets_with_variance_fixed_faces_objects_with_variance', True)


In [None]:
# Same model -> CSV mapping (and order) as the previous cell, plotted again without the
# TPR band and written under a '..._no_variance' file name.
diagonal_paths = {
    '260_inanimates': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/665cb3935bac4f8f9e462be5cda28745/artifacts/dists.csv',
    '30_inanimates': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/4a897372f2414e5e9bab33a0f0fb94c4/artifacts/dists.csv',
    '260_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/8675b29e48f84a0b99870f006a74c71d/artifacts/dists.csv',
    '30_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/eb293b67f8624e698b06e5fa382ef005/artifacts/dists.csv',
    '260_faces': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/ad27097ceff64b719e8da3c9593e58b9/artifacts/dists.csv',
    '30_faces': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/214575675c834788ab7141d87789f314/artifacts/dists.csv',
    '30_sociable_weavers': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/c6b3733cbac74bee87ad03213ff13188/artifacts/dists.csv',
}
# Third argument False: mean ROC curves only.
plot_ordered(diagonal_paths, 'ROCS_domain_diagonal_datasets_with_variance_fixed_faces_objects_no_variance', False)


In [None]:
# Only the three '260_*' models (inanimates, species, faces); mean ROC curves without the TPR band.
diagonal_paths = {
    '260_inanimates': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/665cb3935bac4f8f9e462be5cda28745/artifacts/dists.csv',
    '260_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/8675b29e48f84a0b99870f006a74c71d/artifacts/dists.csv',
    '260_faces': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/ad27097ceff64b719e8da3c9593e58b9/artifacts/dists.csv'
}
plot_ordered(diagonal_paths, 'ROCS_260_cls', False)


In [None]:
# The same seven runs as the 'domain diagonal' cells, but ordered with all '260_*' models
# first and the '30_*' models after, which changes the subplot row order.
# Plotted without the TPR band.
diagonal_paths = {
    '260_inanimates': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/665cb3935bac4f8f9e462be5cda28745/artifacts/dists.csv',
    '260_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/8675b29e48f84a0b99870f006a74c71d/artifacts/dists.csv',
    '260_faces': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/ad27097ceff64b719e8da3c9593e58b9/artifacts/dists.csv',
    '30_inanimates': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/4a897372f2414e5e9bab33a0f0fb94c4/artifacts/dists.csv',
    '30_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/eb293b67f8624e698b06e5fa382ef005/artifacts/dists.csv',
    '30_faces': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/214575675c834788ab7141d87789f314/artifacts/dists.csv',
    '30_sociable_weavers': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/c6b3733cbac74bee87ad03213ff13188/artifacts/dists.csv',
}
plot_ordered(diagonal_paths, 'ROCS_cls_diagonal_datasets_with_variance_fixed_faces_objects_no_variance', False)


In [None]:
# Same mapping and order as the previous cell ('260_*' models first, then '30_*'),
# plotted again with the filled TPR band.
diagonal_paths = {
    '260_inanimates': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/665cb3935bac4f8f9e462be5cda28745/artifacts/dists.csv',
    '260_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/8675b29e48f84a0b99870f006a74c71d/artifacts/dists.csv',
    '260_faces': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/ad27097ceff64b719e8da3c9593e58b9/artifacts/dists.csv',
    '30_inanimates': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/4a897372f2414e5e9bab33a0f0fb94c4/artifacts/dists.csv',
    '30_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/eb293b67f8624e698b06e5fa382ef005/artifacts/dists.csv',
    '30_faces': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/214575675c834788ab7141d87789f314/artifacts/dists.csv',
    '30_sociable_weavers': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/c6b3733cbac74bee87ad03213ff13188/artifacts/dists.csv',
}
plot_ordered(diagonal_paths, 'ROCS_cls_diagonal_datasets_with_variance_fixed_faces_objects_with_variance', True)


In [None]:
# Three '260_*' models again, but '260_faces' points at a different run (7ada2c4b...)
# than in the other cells; the inanimates and species entries are unchanged.
# NOTE(review): 'csl' in the output name looks like a typo for 'cls'; it is left as is
# because it determines the name of the written HTML file.
diagonal_paths = {
    '260_inanimates': '/home/hdd_storage/mlflow/artifact_store/faces_objects_inversion/665cb3935bac4f8f9e462be5cda28745/artifacts/dists.csv',
    '260_species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/8675b29e48f84a0b99870f006a74c71d/artifacts/dists.csv',
    '260_faces': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/7ada2c4bc4e447a4a21082e062694ec7/artifacts/dists.csv'
}
plot_ordered(diagonal_paths, '260_csl_faces_300_imgs', False)


In [None]:
# The two '1000_*' models (inanimates and faces); mean ROC curves without the TPR band.
# This rebinds `paths`, replacing the mapping used by the AUC summary cell above.
paths = {
    '1000_inanimates': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/b36e60a0041348e893ba71847ea3839a/artifacts/dists.csv',
    '1000_faces': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/8623d0f3e5064f3fa73ff4dad165ca4d/artifacts/dists.csv',
}
plot_ordered(paths, 'ROCS_1000_faces_objects', False)

In [None]:
# Four '30 ...' models; three of them come from the 'expertise_cls_size_control' runs and
# '30 species' from the 'inversion_effect' run. The keys are used verbatim as subplot row
# titles and in the legend entries. Rebinds `paths` once more.
paths = {
    '30 inanimate objects': '/home/hdd_storage/mlflow/artifact_store/expertise_cls_size_control/b0cd2de5763943d8a1074ac2fd61135a/artifacts/dists.csv',
    '30 species': '/home/hdd_storage/mlflow/artifact_store/inversion_effect/eb293b67f8624e698b06e5fa382ef005/artifacts/dists.csv',
    '30 faces': '/home/hdd_storage/mlflow/artifact_store/expertise_cls_size_control/000213465fe34a84ac1cecdcc6a73518/artifacts/dists.csv',
    '30 sociable Weavers':'/home/hdd_storage/mlflow/artifact_store/expertise_cls_size_control/0c1a24c77e084e7e885101e4f75766fd/artifacts/dists.csv'
    }
plot_ordered(paths, 'soc_weav_cls_size_100', False)
# HTML file written (and printed) by plot_ROC for this call:
#/home/ssd_storage/experiments/Expertise/experiment2/new_design/soc_weav_cls_size_100.html