In [None]:
import pandas as pd
import numpy as np
from glob import glob
import plotly.express as px
from typing import Tuple, Dict, List
from os import path
from sklearn import metrics
from tqdm import tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import sklearn
from tqdm.notebook import tqdm
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import maximum_bipartite_matching

import torchvision.transforms as T
from PIL import Image
from glob import glob
from os import path, makedirs

In [None]:
def plot_rdm(rdm_paths: List[Tuple[Tuple[int, int], str]]) -> go.Figure:
    """Draw a 4x4 grid of RDM heatmaps, one per subplot location.

    Parameters
    ----------
    rdm_paths:
        ``((row, col), csv_path)`` pairs. ``row``/``col`` are 1-based subplot
        positions; each CSV is read with its first (unnamed) column as index.

    Returns
    -------
    go.Figure
        The assembled figure. It is neither shown nor saved here.

    Notes
    -----
    Image labels are rewritten with the module-level ``replacements`` list of
    ``(old, new)`` substring pairs, which must be defined before calling.
    """
    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'],
                        column_titles=[30, 260, 500, 1000],
                        y_title='Training domain', x_title='Number of classes',
                        vertical_spacing=0.13, horizontal_spacing=0.13)

    for (loc, pth) in tqdm(rdm_paths):
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')
        imgs = list(rdm.index)
        for i in range(len(imgs)):
            for s, r in replacements:
                imgs[i] = imgs[i].replace(s, r)

        # Rows and columns get the same rewritten labels, then both axes are
        # sorted so every heatmap shows the images in the same order.
        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm.loc[imgs, imgs]

        fig.add_trace(go.Heatmap(z=rdm, showscale=False, x=rdm.index, y=rdm.index),
                      row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4, template='plotly_white')
    return fig
    

In [None]:
# Heatmaps of the eden_stimuli RDMs for the 30/260/500/1000-faces networks.
# Add layer names (e.g. 'input', 'conv1', ..., 'fc8') to the list to plot more layers.
for layer in ['fc7']:
    # (row, col) are subplot positions in plot_rdm's 4x4 grid; bar_plot would
    # treat them as bar labels and is only defined in a later cell.
    rdm_locations = [
        ((1, col), f'/home/ssd_storage/experiments/eden_stimuli/eden_stimuli_{n_cls}_faces/vgg16/results/eden_stim/eden_stim_{layer}.csv')
        for col, n_cls in enumerate((30, 260, 500, 1000), start=1)
    ]

    fig = plot_rdm(rdm_locations)
    fig.show()

In [None]:
# fc7 RDM from the eden_stimuli_1000_faces vgg16 results; first column is the image index.
rdm1 = pd.read_csv(
    '/home/ssd_storage/experiments/eden_stimuli/eden_stimuli_1000_faces/vgg16/results/eden_stim/eden_stim_fc7.csv',
    index_col='Unnamed: 0',
)

In [None]:
# Identity pairs treated as negatives (informational; not used in this cell).
negatives = [('tamar', 'marina'), ('dom', 'mat'), ('omer','tal'), ('taz','tess')]

# Take every row label starting with 'marina' except the last two, zero those
# rows of rdm1 in place, and display them.
marina = rdm1.filter(
    rdm1.index[rdm1.index.str.startswith('marina')].tolist()[:-2],
    axis=0,
).index
rdm1.loc[marina] = 0
rdm1.loc[marina]


In [None]:
def bar_plot(rdm_paths: List[Tuple[str, str]], calc: str = 'roc',
             ref_path: str = '/home/ssd_storage/experiments/eden_stimuli/eden_stimuli_1000_faces/vgg16/results/eden_stim/eden_stim_fc7.csv') -> None:
    """Bar chart of same- vs different-pair AUC for several RDMs.

    One balanced set of positive/negative image pairs is drawn from the RDM at
    ``ref_path`` with ``random_balanced_set``. That same set is then scored on
    every RDM in ``rdm_paths``, so all bars are evaluated on identical pairs.

    Parameters
    ----------
    rdm_paths:
        ``(bar_label, csv_path)`` pairs. Each CSV is read with its first
        (unnamed) column as the index.
    calc:
        ``'roc'`` for area under the ROC curve, ``'prc'`` for area under the
        precision-recall curve.
    ref_path:
        CSV of the RDM whose labels are used to sample the pairs.

    The figure is shown and written to
    ``/home/ssd_storage/experiments/eden_stimuli/balaned_{calc}_auc.html``
    (file name kept as-is, typo included).

    Raises
    ------
    KeyError
        If a sampled image label is missing from one of the RDMs.
    """
    assert (calc == 'prc') or (calc == 'roc')

    def _pair_values(frame: pd.DataFrame, rows: List[str], cols: List[str]) -> np.ndarray:
        # Label-based element lookup; replaces DataFrame.lookup (removed in pandas 2.0).
        r = frame.index.get_indexer(rows)
        c = frame.columns.get_indexer(cols)
        if (r < 0).any() or (c < 0).any():
            raise KeyError('One or more sampled image labels are missing from the RDM')
        return frame.to_numpy()[r, c]

    ref = pd.read_csv(ref_path, index_col='Unnamed: 0')
    imgs = sorted(ref.index)
    pos, neg = random_balanced_set(ref.loc[imgs, imgs])
    labels = [1] * len(pos[0]) + [0] * len(neg[0])

    aucs = []
    names = []
    for (name, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')
        dists = np.concatenate((_pair_values(rdm, pos[0], pos[1]),
                                _pair_values(rdm, neg[0], neg[1])))

        if calc == 'prc':
            # Score with 1 - distance so that label 1 (positive pair) ranks high.
            precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, 1 - dists, pos_label=1)
            auc = sklearn.metrics.auc(recall, precision)
        else:
            # Raw distances rank label 0 (negative pair) high, hence pos_label=0.
            fpr, tpr, _ = sklearn.metrics.roc_curve(labels, dists, pos_label=0)
            auc = sklearn.metrics.auc(fpr, tpr)
        aucs.append(auc)
        names.append(name)

    fig = go.Figure(data=[go.Bar(name=names[-1] if names else None, x=names, y=aucs)])
    fig.update_yaxes(range=(0, 1))
    fig.update_layout(template='plotly_white')

    fig.show()
    fig.write_html(f'/home/ssd_storage/experiments/eden_stimuli/balaned_{calc}_auc.html')

In [None]:
# AUC bar chart for the 30/260/500/1000-faces networks.
# Add layer names (e.g. 'input', 'conv1', ..., 'fc8') to the list to cover more layers.
for layer in ['fc7']:
    rdm_locations = [
        (str(n_cls), f'/home/ssd_storage/experiments/eden_stimuli/eden_stimuli_{n_cls}_faces/vgg16/results/eden_stim/eden_stim_{layer}.csv')
        for n_cls in (30, 260, 500, 1000)
    ]

    fig = bar_plot(rdm_locations)

In [None]:
rdm = pd.read_csv('/home/ssd_storage/experiments/eden_stimuli/eden_stimuli_1000_faces/vgg16/results/eden_stim/eden_stim_fc7.csv', index_col='Unnamed: 0')

# Ten images per identity. Identities are ordered so that the pairs listed in
# `negatives` (dom/matt, marina/tamar, omer/tal, taz/tess) sit next to each other.
imgs = [
    f'{identity}_{k:02d}.png'
    for identity in ('dom', 'matt', 'marina', 'tamar', 'omer', 'tal', 'taz', 'tess')
    for k in range(1, 11)
]

rdm = rdm.loc[imgs, imgs]
px.imshow(rdm).write_html('/home/ssd_storage/experiments/eden_stimuli/1000_faces_net_rdm.html')


In [None]:
def multi_layer_bar_plot(rdm_paths: List[Tuple[str, str]], calc: str = 'roc',
                         layers: Tuple[str, ...] = ('input', 'conv1', 'conv2', 'conv3', 'conv4',
                                                    'conv5', 'fc6', 'fc7', 'fc8')) -> None:
    """Plot per-layer AUC lines (one line per network) for the Win/Lose stimuli.

    For each layer, a fresh balanced set of positive/negative image pairs is
    drawn with ``random_balanced_set`` from the first network's RDM. That set is
    then scored on every network's RDM for the same layer.

    Parameters
    ----------
    rdm_paths:
        ``(network_label, path_template)`` pairs. ``path_template.format(layer)``
        is the CSV for a layer.
    calc:
        ``'roc'`` or ``'prc'``.
    layers:
        Layer names substituted into the templates; also the x-axis values.

    Notes
    -----
    Rows/columns are kept only when their label starts with ``Lose_{mode}`` or
    ``Win_{mode}``. ``mode`` is not a parameter: a module-level ``mode`` must
    exist before calling. The figure is shown and written to
    ``/home/ssd_storage/experiments/win_lose/clip_multi_layer_balanced_{calc}_auc.html``.

    Raises
    ------
    KeyError
        If a sampled image label is missing from one of the RDMs.
    """
    assert (calc == 'prc') or (calc == 'roc')
    layers = list(layers)

    def _load(pth_template: str, layer: str) -> pd.DataFrame:
        # Read one layer's RDM, keep the Win/Lose stimuli of `mode`, and sort both axes.
        rdm = pd.read_csv(pth_template.format(layer), index_col='Unnamed: 0')
        keep = rdm.index.str.startswith(f'Lose_{mode}') | rdm.index.str.startswith(f'Win_{mode}')
        rdm = rdm.loc[keep, keep]
        imgs = sorted(rdm.index)
        return rdm.loc[imgs, imgs]

    def _pair_values(frame: pd.DataFrame, rows: List[str], cols: List[str]) -> np.ndarray:
        # Label-based element lookup; replaces DataFrame.lookup (removed in pandas 2.0).
        r = frame.index.get_indexer(rows)
        c = frame.columns.get_indexer(cols)
        if (r < 0).any() or (c < 0).any():
            raise KeyError('One or more sampled image labels are missing from the RDM')
        return frame.to_numpy()[r, c]

    traces: Dict[str, List[float]] = {}
    for layer in layers:
        pos, neg = random_balanced_set(_load(rdm_paths[0][1], layer))
        labels = [1] * len(pos[0]) + [0] * len(neg[0])

        for (name, pth) in rdm_paths:
            rdm = _load(pth, layer)
            dists = np.concatenate((_pair_values(rdm, pos[0], pos[1]),
                                    _pair_values(rdm, neg[0], neg[1])))

            if calc == 'prc':
                # Score with 1 - distance so that label 1 (positive pair) ranks high.
                precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, 1 - dists, pos_label=1)
                auc = sklearn.metrics.auc(recall, precision)
            else:
                # Raw distances rank label 0 (negative pair) high, hence pos_label=0.
                fpr, tpr, _ = sklearn.metrics.roc_curve(labels, dists, pos_label=0)
                auc = sklearn.metrics.auc(fpr, tpr)
            traces.setdefault(name, []).append(auc)

    fig = go.Figure()
    for name, aucs in traces.items():
        fig.add_trace(go.Scatter(name=name, x=layers, y=aucs))

    fig.update_yaxes(range=(0, 1))
    fig.update_layout(height=1000, template='plotly_white')

    fig.show()
    fig.write_html(f'/home/ssd_storage/experiments/win_lose/clip_multi_layer_balanced_{calc}_auc.html')

In [None]:
# One '{0}' layer placeholder per network; multi_layer_bar_plot fills it per layer.
rdm_locations = [
    (label, f'/home/ssd_storage/experiments/win_lose/win_lose_{net_dir}/vgg16/results/win_lose/win_lose_{{0}}.csv')
    for label, net_dir in [
        ('Inanimates 30', '30_inanimates'),
        ('Inanimates 260', '260_inanimates'),
        ('Inanimates 500', '500_inanimates'),
        ('ImageNet 1000', '1000_object_net'),
        ('Bird Species 30', '30_species'),
        ('Bird Species 260', '260_species'),
        ('Faces 30', '30_faces'),
        ('Faces 260', '260_faces'),
        ('Faces 500', '500_faces'),
        ('Faces 1000', '1000_faces'),
        ('Sociable Weavers 30', '30_soc_weav'),
    ]
]

multi_layer_bar_plot(rdm_locations)

In [None]:
# Draw a balanced sample from rdm1 and print its negative pairs as
# [first images, second images]. random_balanced_set itself prints both sets too.
idx = random_balanced_set(rdm1)[-1]
print(idx)


In [None]:
def shuffle_along_axis(a, axis):
    """Independently permute the entries of ``a`` along ``axis``.

    Each 1-D slice along ``axis`` gets its own permutation. The permutation is
    obtained by argsorting uniform noise drawn from NumPy's global RNG.
    """
    noise = np.random.rand(*a.shape)
    order = np.argsort(noise, axis=axis)
    return np.take_along_axis(a, order, axis=axis)


def create_positive_pairs(cls2imgs: Dict[str, List[str]]) -> List[np.ndarray]:
    """Sample disjoint same-class image pairs for every class.

    For each class, ``len(images) // 2`` pairs are drawn without replacement,
    so no image appears twice. With an odd count, one image is left out.

    Parameters
    ----------
    cls2imgs:
        Maps class name -> image names of that class.

    Returns
    -------
    list of np.ndarray
        One ``(n_pairs, 2)`` array of image names per class, in ``cls2imgs``
        iteration order. ``random_balanced_set`` concatenates them.
    """
    pos = []
    for cls_imgs in cls2imgs.values():
        num_pairs = len(cls_imgs) // 2
        pos.append(np.random.choice(cls_imgs, [num_pairs, 2], replace=False))
    return pos


def get_cls_idx(imgs: List[str]) -> Dict[str, List[str]]:
    """Group image names by class, taken as the text before the first underscore.

    Parameters
    ----------
    imgs:
        Image names such as ``'dom_01.png'``. A name without an underscore
        becomes its own class.

    Returns
    -------
    dict
        Class name -> image names of that class. Classes appear in order of
        first appearance; images keep their input order.
    """
    # Single pass. The old version rebuilt each class list for every image
    # because its membership check compared image names against class keys.
    class_idx: Dict[str, List[str]] = {}
    for img in imgs:
        class_idx.setdefault(img.split('_')[0], []).append(img)
    return class_idx


def create_class_pairs_negatives(
        cls2imgs: Dict[str, List[str]],
        negatives: Tuple[Tuple[str, str], ...] = (('tamar', 'marina'), ('dom', 'mat'),
                                                  ('omer', 'tal'), ('taz', 'tess')),
) -> np.ndarray:
    """Sample disjoint different-class image pairs for the given class pairings.

    1. Each class's images are shuffled and split alternately into a "row" half
       and a "column" half.
    2. A row image may pair with a column image only when their names start
       with the two prefixes of one entry of ``negatives``, in either order.
    3. A maximum bipartite matching on that graph, with the rows in random
       order, picks the pairs, so each image is used at most once.

    Parameters
    ----------
    cls2imgs:
        Maps class name -> image names. The lists are copied before shuffling,
        so the caller's dict is left untouched.
    negatives:
        Pairs of name prefixes that form negative pairs.

    Returns
    -------
    np.ndarray
        ``(n_pairs, 2)`` array of ``[row_image, column_image]`` names. The shape
        is ``(0, 2)`` when no pair can be formed.
    """
    idx: List[str] = []
    cols: List[str] = []
    for imgs in cls2imgs.values():
        shuffled = list(imgs)
        np.random.shuffle(shuffled)
        # Even positions become candidate rows, odd positions candidate columns.
        idx += shuffled[0::2]
        cols += shuffled[1::2]

    if not idx or not cols:
        return np.empty((0, 2), dtype=str)

    connections = pd.DataFrame(np.zeros([len(idx), len(cols)]), index=idx, columns=cols)
    # Mark only the allowed negative combinations with 1.
    for cls1, cls2 in negatives:
        cls1_idx = connections.index.str.startswith(cls1)
        cls2_idx = connections.index.str.startswith(cls2)
        cls1_cols = connections.columns.str.startswith(cls1)
        cls2_cols = connections.columns.str.startswith(cls2)
        connections.loc[cls1_idx, cls2_cols] = 1
        connections.loc[cls2_idx, cls1_cols] = 1

    # Shuffle the rows so the chosen matching does not simply follow class order.
    connections = connections.sample(frac=1.0)
    matches = maximum_bipartite_matching(csr_matrix(connections.to_numpy()), perm_type='column')

    # matches[i] is the column matched to row i, or -1 when row i is unmatched.
    diff_pairs = [[connections.index[i], connections.columns[j]]
                  for i, j in enumerate(matches) if j >= 0]
    if not diff_pairs:
        return np.empty((0, 2), dtype=str)
    return np.array(diff_pairs)

def random_balanced_set(rdm: pd.DataFrame, verbose: bool = True) -> Tuple[List[List[str]], List[List[str]]]:
    """Draw equally sized sets of positive and negative image pairs from an RDM's labels.

    Only the sorted row labels of ``rdm`` are used; its values are not read and
    the frame is not modified.

    - Positives are disjoint same-class pairs (``create_positive_pairs``).
    - Negatives are disjoint pairs between the class pairings used by
      ``create_class_pairs_negatives``.
    - Both sets are cut to the smaller size, and the two members of every pair
      are put in random order.

    Parameters
    ----------
    rdm:
        Square RDM whose index holds image names of the form ``<class>_...``.
    verbose:
        When True (default), print the negative and then the positive pairs.

    Returns
    -------
    tuple
        ``(pos, neg)``. Each is ``[first_images, second_images]``: two parallel
        lists of image labels.
    """
    imgs = sorted(rdm.index)

    cls_idx = get_cls_idx(imgs)
    pos = np.concatenate(create_positive_pairs(cls_idx))
    neg = create_class_pairs_negatives(cls_idx)

    num_samples = min(neg.shape[0], pos.shape[0])
    # Randomize which image of each pair comes first.
    pos = shuffle_along_axis(pos, 1)[:num_samples]
    neg = shuffle_along_axis(neg, 1)[:num_samples]
    if verbose:
        print(neg)
        print(pos)
    return pos.T.tolist(), neg.T.tolist()

In [None]:
def plot_balanced_curves(rdm_paths: List[Tuple[Tuple[int, int], str]], calc: str, mode: str) -> None:
    """Plot ROC or precision-recall curves for a 4x4 grid of RDMs on one shared pair set.

    Pairs come from ``random_balanced_set_mode`` (defined elsewhere) applied to
    the first RDM in ``rdm_paths``. ``mode`` is forwarded to that function and
    used in the output file name.

    Parameters
    ----------
    rdm_paths:
        ``((row, col), csv_path)`` pairs. ``(row, col)`` is the 1-based subplot position.
    calc:
        ``'roc'`` or ``'prc'``.
    mode:
        ``'family'``, ``'species'`` or ``'individual'``.

    The figure is shown and written to
    ``/home/ssd_storage/experiments/win_lose/balaned_{mode}_{calc}_curves.html``.

    Notes
    -----
    Labels are rewritten with the module-level ``replacements`` list before sorting.
    NOTE(review): ``pos``/``neg`` are used as positional (row, col) indices into each
    sorted matrix; confirm that ``random_balanced_set_mode`` returns integer positions.
    """
    assert calc in ('prc', 'roc')
    assert mode in ('family', 'species', 'individual')

    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'],
                        column_titles=[30, 260, 500, 1000],
                        y_title='Training domain', x_title='Number of classes',
                        vertical_spacing=0.13, horizontal_spacing=0.13)
    _, ref_pth = rdm_paths[0]
    pos, neg = random_balanced_set_mode(rdm=pd.read_csv(ref_pth, index_col='Unnamed: 0'), mode=mode)
    labels = [1] * len(pos[0]) + [0] * len(neg[0])

    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')
        imgs = list(rdm.index)
        for i in range(len(imgs)):
            for s, r in replacements:
                imgs[i] = imgs[i].replace(s, r)

        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm.loc[imgs, imgs]

        values = rdm.to_numpy()
        dists = np.concatenate((values[tuple(pos)], values[tuple(neg)]))

        if calc == 'prc':
            # Score with 1 - distance so that label 1 (positive pair) ranks high.
            precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, 1 - dists, pos_label=1)
            auc = sklearn.metrics.auc(recall, precision)
            fig.add_trace(go.Scatter(x=recall, y=precision, name=f'{auc}'),
                          row=loc[0], col=loc[1])
        else:
            # Raw distances rank label 0 (negative pair) high, hence pos_label=0.
            fpr, tpr, _ = sklearn.metrics.roc_curve(labels, dists, pos_label=0)
            auc = sklearn.metrics.auc(fpr, tpr)
            fig.add_trace(go.Scatter(x=fpr, y=tpr, name=f'{auc}'),
                          row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4, template='plotly_white')

    fig.show()
    fig.write_html(f'/home/ssd_storage/experiments/win_lose/balaned_{mode}_{calc}_curves.html')

In [None]:
# fc7 RDMs of every win_lose network, keyed by (row, col) position in the 4x4 grid.
rdm_locations = [
    (loc, f'/home/ssd_storage/experiments/win_lose/win_lose_{net_dir}/vgg16/results/win_lose/win_lose_fc7.csv')
    for loc, net_dir in [
        ((1, 1), '30_inanimates'),
        ((1, 2), '260_inanimates'),
        ((1, 3), '500_inanimates'),
        ((1, 4), '1000_object_net'),
        ((2, 1), '30_species'),
        ((2, 2), '260_species'),
        ((3, 1), '30_faces'),
        ((3, 2), '260_faces'),
        ((3, 3), '500_faces'),
        ((3, 4), '1000_faces'),
        ((4, 1), '30_soc_weav'),
    ]
]

plot_balanced_curves(rdm_locations, 'roc', 'family')
plot_balanced_curves(rdm_locations, 'roc', 'species')
plot_balanced_curves(rdm_locations, 'roc', 'species')

In [None]:
def plot_bar_auc(rdm_paths: Dict[str, List[Tuple[str, str]]], calc: str, mode: str) -> None:
    """Grouped bar chart of AUC per training domain and number of classes.

    One pair set is drawn with ``random_balanced_set_mode`` (defined elsewhere)
    from the first RDM of the first domain, and it is scored on every RDM.

    Parameters
    ----------
    rdm_paths:
        Domain name -> list of ``(bar_label, csv_path)``. Each domain becomes one bar trace.
    calc:
        ``'roc'`` or ``'prc'``.
    mode:
        ``'family'``, ``'species'`` or ``'individual'``. It is forwarded to
        ``random_balanced_set_mode`` and used in the output file name.

    The figure is shown and written to
    ``/home/ssd_storage/experiments/win_lose/balanced_{mode}_au{calc}.html``.

    Raises
    ------
    ValueError
        If ``rdm_paths`` is empty.

    Notes
    -----
    Labels are rewritten with the module-level ``replacements`` list before sorting.
    NOTE(review): ``pos``/``neg`` are used as positional indices into each sorted
    matrix; confirm that ``random_balanced_set_mode`` returns integer positions.
    """
    assert calc in ('prc', 'roc')
    assert mode in ('family', 'species', 'individual')
    if not rdm_paths:
        raise ValueError('rdm_paths is empty')

    first_domain = next(iter(rdm_paths))
    _, ref_pth = rdm_paths[first_domain][0]
    pos, neg = random_balanced_set_mode(rdm=pd.read_csv(ref_pth, index_col='Unnamed: 0'), mode=mode)
    labels = [1] * len(pos[0]) + [0] * len(neg[0])

    fig = go.Figure()
    for domain, entries in rdm_paths.items():
        aucs = []
        cls = []
        for (loc, pth) in tqdm(entries):
            cls.append(loc)
            rdm = pd.read_csv(pth, index_col='Unnamed: 0')
            imgs = list(rdm.index)
            for i in range(len(imgs)):
                for s, r in replacements:
                    imgs[i] = imgs[i].replace(s, r)

            rdm.index = imgs
            rdm.columns = imgs
            imgs = sorted(imgs)
            rdm = rdm.loc[imgs, imgs]

            values = rdm.to_numpy()
            dists = np.concatenate((values[tuple(pos)], values[tuple(neg)]))

            if calc == 'prc':
                # Score with 1 - distance so that label 1 (positive pair) ranks high.
                precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, 1 - dists, pos_label=1)
                auc = sklearn.metrics.auc(recall, precision)
            else:
                # Raw distances rank label 0 (negative pair) high, hence pos_label=0.
                fpr, tpr, _ = sklearn.metrics.roc_curve(labels, dists, pos_label=0)
                auc = sklearn.metrics.auc(fpr, tpr)
            aucs.append(auc)
        fig.add_trace(go.Bar(x=cls, y=aucs, text=aucs, textposition='auto', name=domain))

    fig.update_layout(height=500, width=1000, template='plotly_white')
    fig.update_yaxes(range=(0, 1))

    fig.show()
    fig.write_html(f'/home/ssd_storage/experiments/win_lose/balanced_{mode}_au{calc}.html')

In [None]:
# fc7 RDMs of the class-level (win_lose_cls) stimuli, grouped by training domain.
rdms = {
    domain: [
        (f'{n_cls} classes', f'/home/ssd_storage/experiments/win_lose/win_lose_{net_dir}/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv')
        for n_cls, net_dir in nets
    ]
    for domain, nets in [
        ('Objects', [(30, '30_inanimates'), (260, '260_inanimates'),
                     (500, '500_inanimates'), (1000, '1000_object_net')]),
        ('Faces', [(30, '30_faces'), (260, '260_faces'),
                   (500, '500_faces'), (1000, '1000_faces')]),
        ('Bird species', [(30, '30_species'), (260, '260_species')]),
        ('Sociable Weavers', [(30, '30_soc_weav')]),
    ]
}

plot_bar_auc(rdms, 'roc', 'family')
plot_bar_auc(rdms, 'roc', 'species')
plot_bar_auc(rdms, 'roc', 'individual')

In [None]:
def plot_roc(rdm_paths: List[Tuple[Tuple[int, int], str]]) -> None:
    """Plot same- vs different-class ROC curves for several RDMs on one shared pair set.

    Pairs are drawn once with ``random_balanced_set`` from the first RDM in
    ``rdm_paths``. That function returns image *labels*, so the pairs are looked
    up by label, after applying the same module-level ``replacements`` renaming
    that is applied to every RDM's labels.

    Parameters
    ----------
    rdm_paths:
        ``((row, col), csv_path)`` pairs. ``(row, col)`` is the 1-based subplot
        position. The grid has 4 columns and as many rows as the largest ``row``.

    The figure is shown and written to
    ``/home/ssd_storage/experiments/win_lose/fw_roc.html``.

    Raises
    ------
    KeyError
        If a sampled image label is missing from one of the RDMs.

    Notes
    -----
    A later cell defines another ``plot_roc``. This version is also bound to
    ``plot_fw_roc`` so it stays reachable under that name.
    """
    def _rename(label: str) -> str:
        for s, r in replacements:
            label = label.replace(s, r)
        return label

    def _pair_values(frame: pd.DataFrame, pairs) -> np.ndarray:
        # Look up (first, second) label pairs in the renamed RDM.
        rows = frame.index.get_indexer([_rename(x) for x in pairs[0]])
        cols = frame.columns.get_indexer([_rename(x) for x in pairs[1]])
        if (rows < 0).any() or (cols < 0).any():
            raise KeyError('One or more sampled image labels are missing from the RDM')
        return frame.to_numpy()[rows, cols]

    n_rows = max((loc[0] for loc, _ in rdm_paths), default=1)
    fig = make_subplots(rows=n_rows, cols=4,
                        column_titles=[30, 260, 500, 1000],
                        x_title='Number of classes', vertical_spacing=0.13, horizontal_spacing=0.13)
    _, ref_pth = rdm_paths[0]
    pos, neg = random_balanced_set(rdm=pd.read_csv(ref_pth, index_col='Unnamed: 0'))
    labels = [1] * len(pos[0]) + [0] * len(neg[0])

    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')
        imgs = [_rename(img) for img in rdm.index]
        rdm.index = imgs
        rdm.columns = imgs

        dists = np.concatenate((_pair_values(rdm, pos), _pair_values(rdm, neg)))
        # RDM entries are dissimilarities, so same-class pairs should score low.
        # Use 1 - distance with pos_label=1; raw distances would invert the curve.
        fpr, tpr, _ = sklearn.metrics.roc_curve(labels, 1 - dists, pos_label=1)
        auc = sklearn.metrics.auc(fpr, tpr)

        fig.add_trace(go.Scatter(x=fpr, y=tpr, name=f'{auc}'),
                      row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4, template='plotly_white')

    fig.show()
    fig.write_html('/home/ssd_storage/experiments/win_lose/fw_roc.html')


plot_fw_roc = plot_roc

In [None]:
# fc7 RDMs of every win_lose network, keyed by (row, col) subplot position.
rdm_locations = [
    (loc, f'/home/ssd_storage/experiments/win_lose/win_lose_{net_dir}/vgg16/results/win_lose/win_lose_fc7.csv')
    for loc, net_dir in [
        ((1, 1), '30_inanimates'),
        ((1, 2), '260_inanimates'),
        ((1, 3), '500_inanimates'),
        ((1, 4), '1000_object_net'),
        ((2, 1), '30_species'),
        ((2, 2), '260_species'),
        ((3, 1), '30_faces'),
        ((3, 2), '260_faces'),
        ((3, 3), '500_faces'),
        ((3, 4), '1000_faces'),
        ((4, 1), '30_soc_weav'),
    ]
]

# NOTE(review): plot_fw_roc must be defined before this cell runs.
plot_fw_roc(rdm_locations)

In [None]:
def plot_roc(rdm_paths: List[Tuple[Tuple[int, int], str]]) -> None:
    """Plot same- vs different-class ROC curves in a 4x4 grid on one shared pair set.

    Pairs come from ``random_balanced_set_cls`` (defined elsewhere) applied to
    the first RDM in ``rdm_paths``.

    Parameters
    ----------
    rdm_paths:
        ``((row, col), csv_path)`` pairs. ``(row, col)`` is the 1-based subplot position.

    The figure is shown and written to
    ``/home/ssd_storage/experiments/win_lose/roc.html``.

    Notes
    -----
    NOTE(review): this definition shadows an earlier ``plot_roc`` in the notebook.
    Labels are rewritten with the module-level ``replacements`` list before sorting.
    ``pos``/``neg`` are used as positional (row, col) indices into each sorted
    matrix; confirm that ``random_balanced_set_cls`` returns integer positions.
    """
    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'],
                        column_titles=[30, 260, 500, 1000],
                        y_title='Training domain', x_title='Number of classes',
                        vertical_spacing=0.13, horizontal_spacing=0.13)
    _, ref_pth = rdm_paths[0]
    pos, neg = random_balanced_set_cls(rdm=pd.read_csv(ref_pth, index_col='Unnamed: 0'))
    labels = [1] * len(pos[0]) + [0] * len(neg[0])

    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')
        imgs = list(rdm.index)
        for i in range(len(imgs)):
            for s, r in replacements:
                imgs[i] = imgs[i].replace(s, r)

        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm.loc[imgs, imgs]

        values = rdm.to_numpy()
        dists = np.concatenate((values[tuple(pos)], values[tuple(neg)]))
        # RDM entries are dissimilarities, so same-class pairs should score low.
        # Use 1 - distance with pos_label=1; raw distances would invert the curve.
        fpr, tpr, _ = sklearn.metrics.roc_curve(labels, 1 - dists, pos_label=1)
        auc = sklearn.metrics.auc(fpr, tpr)

        fig.add_trace(go.Scatter(x=fpr, y=tpr, name=f'{auc}'),
                      row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4, template='plotly_white')

    fig.show()
    fig.write_html('/home/ssd_storage/experiments/win_lose/roc.html')

In [None]:
def plot_prc_balanced(rdm_paths: List[Tuple[Tuple[int, int], str]], mode: str = 'family',
                      balanced: bool = True, calc: str = 'prc') -> None:
    """Plot PR or ROC curves for a 4x4 grid of RDMs, on balanced pairs or on all pairs.

    Parameters
    ----------
    rdm_paths:
        ``((row, col), csv_path)`` pairs. ``(row, col)`` is the 1-based subplot position.
    mode:
        ``'family'`` draws pairs with ``random_balanced_set`` and takes the first
        8 characters of a name as its class. Any other value draws pairs with
        ``random_balanced_set_cls`` (defined elsewhere) and takes the alphabetic
        characters of the name, minus its last 4 characters, as its class.
    balanced:
        True scores the sampled pairs. False scores every upper-triangle pair,
        labelled 1 when both images share a class.
    calc:
        ``'prc'`` for precision-recall; anything else gives ROC.

    The figure is shown and written to
    ``/home/ssd_storage/experiments/win_lose/{calc}_(balanced={balanced})_{mode}.html``.

    Raises
    ------
    KeyError
        If label pairs are used and a label is missing from one of the RDMs.

    Notes
    -----
    Labels are rewritten with the module-level ``replacements`` list.
    """
    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'],
                        column_titles=[30, 260, 500, 1000],
                        y_title='Training domain', x_title='Number of classes',
                        vertical_spacing=0.13, horizontal_spacing=0.13)
    _, ref_pth = rdm_paths[0]
    if mode == 'family':
        pos, neg = random_balanced_set(rdm=pd.read_csv(ref_pth, index_col='Unnamed: 0'))
    else:
        pos, neg = random_balanced_set_cls(rdm=pd.read_csv(ref_pth, index_col='Unnamed: 0'))

    def _rename(label: str) -> str:
        for s, r in replacements:
            label = label.replace(s, r)
        return label

    def _pair_values(frame: pd.DataFrame, pairs) -> np.ndarray:
        # Integer pairs are positions in the sorted matrix. Anything else is
        # treated as image labels; random_balanced_set returns labels.
        rows, cols = np.asarray(pairs[0]), np.asarray(pairs[1])
        values = frame.to_numpy()
        if rows.dtype.kind in 'iu' and cols.dtype.kind in 'iu':
            return values[rows, cols]
        r = frame.index.get_indexer([_rename(x) for x in rows])
        c = frame.columns.get_indexer([_rename(x) for x in cols])
        if (r < 0).any() or (c < 0).any():
            raise KeyError('One or more sampled image labels are missing from the RDM')
        return values[r, c]

    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')
        imgs = [_rename(img) for img in rdm.index]
        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm.loc[imgs, imgs]

        if mode == 'family':
            groups = [img[:8] for img in imgs]
        else:
            groups = [''.join(ch for ch in img[:-4] if ch.isalpha()) for img in imgs]

        if balanced:
            labels = [1] * len(pos[0]) + [0] * len(neg[0])
            dists = 1 - np.concatenate((_pair_values(rdm, pos), _pair_values(rdm, neg)))
        else:
            # Same-class indicator for every image pair, vectorized.
            group_arr = np.asarray(groups)
            issame = (group_arr[:, None] == group_arr[None, :]).astype(float)
            triu = np.triu_indices(len(groups), 1)
            dists = 1 - rdm.to_numpy()[triu]
            labels = issame[triu]

        if calc == 'prc':
            precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, dists, pos_label=1)
            auc = sklearn.metrics.auc(recall, precision)
            fig.add_trace(go.Scatter(x=recall, y=precision, name=f'{auc}'),
                          row=loc[0], col=loc[1])
        else:
            fpr, tpr, _ = sklearn.metrics.roc_curve(labels, dists, pos_label=1)
            auc = sklearn.metrics.auc(fpr, tpr)
            fig.add_trace(go.Scatter(x=fpr, y=tpr, name=f'{auc}'),
                          row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4, template='plotly_white')
    fig.update_xaxes(range=(0, 1))
    fig.update_yaxes(range=(0, 1))

    fig.show()
    fig.write_html(f'/home/ssd_storage/experiments/win_lose/{calc}_(balanced={balanced})_{mode}.html')

In [None]:
def plot_hist_balanced(rdm_paths: List[Tuple[Tuple[int, int], str]], balanced: bool, mode: str='family') -> None:
    """Overlay histograms of same-label and different-label pair distances, one subplot per RDM.

    Parameters
    ----------
    rdm_paths : list of ((row, col), csv_path)
        1-based subplot location and path of an RDM CSV whose 'Unnamed: 0'
        column holds the image names.
    balanced : bool
        True: use the balanced positive/negative pair sample drawn once from
        the first RDM (random_balanced_set for mode 'family',
        random_balanced_set_cls otherwise).
        False: use every pair in the strict upper triangle (each unordered
        pair once, self-pairs excluded).
    mode : str, default 'family'
        'family' labels an image by the first 8 characters of its name; any
        other value labels it by the alphabetic characters of the name with
        its last 4 characters (presumably the file extension) removed.

    The figure is displayed and written to
    /home/ssd_storage/experiments/win_lose/hist_(balanced=<balanced>)_<mode>.html.
    """
    fig = make_subplots(rows=4, cols=4, 
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)

    if balanced:
        # Draw the pair sample once so every subplot compares exactly the same pairs.
        _, first_pth = rdm_paths[0]
        first_rdm = pd.read_csv(first_pth, index_col='Unnamed: 0')
        if mode == 'family':
            pos, neg = random_balanced_set(rdm=first_rdm)
        else:
            pos, neg = random_balanced_set_cls(rdm=first_rdm)

    for plot_idx, (loc, pth) in enumerate(rdm_paths):
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')
        imgs = list(rdm.index)
        for i in range(len(imgs)):
            for s, r in replacements:
                imgs[i] = imgs[i].replace(s, r)

        # Normalise names and sort rows/columns so pair indices line up across RDMs.
        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm[imgs]
        rdm = rdm.loc[imgs]

        # Copy so nothing below can write back into `rdm`.
        dists = rdm.to_numpy(dtype=float, copy=True)
        if balanced:
            pos_dists = dists[tuple(pos)]
            neg_dists = dists[tuple(neg)]
        else:
            if mode == 'family':
                labels = [name[:8] for name in rdm.index]
            else:
                labels = [''.join(ch for ch in name[:-4] if ch.isalpha()) for name in rdm.index]
            label_arr = np.asarray(labels)
            issame = label_arr[:, None] == label_arr[None, :]
            # Strict upper triangle (k=1): every unordered pair once, no self-pairs.
            # Masking with np.tril_indices(n, 1) would also hide the first
            # super-diagonal and silently drop all (i, i+1) pairs.
            keep = np.triu(np.ones(dists.shape, dtype=bool), k=1) & ~np.isnan(dists)
            pos_dists = dists[keep & issame]
            neg_dists = dists[keep & ~issame]

        # Fixed colours + one legend entry per group so all subplots read the same way.
        show_legend = plot_idx == 0
        fig.add_trace(go.Histogram(x=pos_dists, histnorm='probability density', name='same',
                                   marker_color='#636EFA', legendgroup='same', showlegend=show_legend),
                      row=loc[0], col=loc[1])
        fig.add_trace(go.Histogram(x=neg_dists, histnorm='probability density', name='different',
                                   marker_color='#EF553B', legendgroup='different', showlegend=show_legend),
                      row=loc[0], col=loc[1])

    fig.update_layout(barmode='overlay')
    fig.update_traces(opacity=0.75)
    fig.update_layout(height=500 * 4, width=500 * 4,  template='plotly_white')

    fig.show()
    fig.write_html(f'/home/ssd_storage/experiments/win_lose/hist_(balanced={balanced})_{mode}.html')

In [None]:
def plot_rdm_balanced(rdm_paths: List[Tuple[Tuple[int, int], str]], mode: str='family') -> None:
    """Show each RDM as a heatmap restricted to the balanced pair sample.

    The positive/negative pairs are drawn once from the first RDM
    (random_balanced_set for mode 'family', random_balanced_set_cls
    otherwise). Every other cell is NaN, so only the sampled pairs are
    coloured. The figure is displayed and written to
    /home/ssd_storage/experiments/win_lose/rdm_balanced_<mode>.html.

    Parameters
    ----------
    rdm_paths : list of ((row, col), csv_path)
        1-based subplot location and path of an RDM CSV whose 'Unnamed: 0'
        column holds the image names.
    mode : str, default 'family'
        Selects which pair sampler is used.
    """
    fig = make_subplots(rows=4, cols=4, 
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)

    _, first_pth = rdm_paths[0]
    first_rdm = pd.read_csv(first_pth, index_col='Unnamed: 0')
    if mode == 'family':
        pos, neg = random_balanced_set(rdm=first_rdm)
    else:
        pos, neg = random_balanced_set_cls(rdm=first_rdm)

    # Row / column indices of all sampled pairs. list() makes `+` concatenate
    # even if a sampler returns arrays (where `+` would add element-wise).
    rows = list(pos[0]) + list(neg[0])
    cols = list(pos[1]) + list(neg[1])

    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')
        imgs = list(rdm.index)
        for i in range(len(imgs)):
            for s, r in replacements:
                imgs[i] = imgs[i].replace(s, r)

        # Normalise names and sort rows/columns so pair indices line up across RDMs.
        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm[imgs]
        rdm = rdm.loc[imgs]

        dists = rdm.to_numpy(dtype=float)
        # Start fully masked and copy in only the sampled cells.
        sampled = np.full(dists.shape, np.nan)
        sampled[rows, cols] = dists[rows, cols]

        fig.add_trace(go.Heatmap(z=sampled, showscale=False),
                      row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4,  template='plotly_white')

    fig.show()
    fig.write_html(f'/home/ssd_storage/experiments/win_lose/rdm_balanced_{mode}.html')

In [None]:
# Heatmaps of the balanced pair sample, once per labelling scheme.
for label_mode in ('family', 'species'):
    plot_rdm_balanced(rdm_locations, label_mode)

In [None]:
# Distance histograms: all pairs first, then the balanced sample,
# each for both labelling schemes.
for use_balanced in (False, True):
    for label_mode in ('family', 'species'):
        plot_hist_balanced(rdm_locations, use_balanced, label_mode)

In [None]:
# Subplot (row, col) -> fc7 RDM CSV computed on the win_lose stimuli.
# Rows match the subplot row_titles (1 Objects, 2 Bird species, 3 Faces,
# 4 Sociable Weavers); columns match column_titles (30, 260, 500, 1000 classes).
# Subplot cells without an entry are simply left empty.
rdm_locations = [
    ((1,1), '/home/ssd_storage/experiments/win_lose/win_lose_30_inanimates/vgg16/results/win_lose/win_lose_fc7.csv'),
    ((1,2), '/home/ssd_storage/experiments/win_lose/win_lose_260_inanimates/vgg16/results/win_lose/win_lose_fc7.csv'),
    ((1,3), '/home/ssd_storage/experiments/win_lose/win_lose_500_inanimates/vgg16/results/win_lose/win_lose_fc7.csv'),
    # NOTE: the 1000-class object model lives under a differently named run ("object_net").
    ((1,4), '/home/ssd_storage/experiments/win_lose/win_lose_1000_object_net/vgg16/results/win_lose/win_lose_fc7.csv'),
    
    ((2,1), '/home/ssd_storage/experiments/win_lose/win_lose_30_species/vgg16/results/win_lose/win_lose_fc7.csv'),
    ((2,2), '/home/ssd_storage/experiments/win_lose/win_lose_260_species/vgg16/results/win_lose/win_lose_fc7.csv'),
    
    ((3,1), '/home/ssd_storage/experiments/win_lose/win_lose_30_faces/vgg16/results/win_lose/win_lose_fc7.csv'),
    ((3,2), '/home/ssd_storage/experiments/win_lose/win_lose_260_faces/vgg16/results/win_lose/win_lose_fc7.csv'),
    ((3,3), '/home/ssd_storage/experiments/win_lose/win_lose_500_faces/vgg16/results/win_lose/win_lose_fc7.csv'),
    ((3,4), '/home/ssd_storage/experiments/win_lose/win_lose_1000_faces/vgg16/results/win_lose/win_lose_fc7.csv'),
    
    ((4,1), '/home/ssd_storage/experiments/win_lose/win_lose_30_soc_weav/vgg16/results/win_lose/win_lose_fc7.csv'),
]

# ROC curves for every (labelling scheme, balanced) combination.
for label_mode in ('family', 'species'):
    for use_balanced in (False, True):
        plot_prc_balanced(rdm_locations, label_mode, use_balanced, 'roc')

In [None]:
# Precision-recall ('prc') curves for the species labelling over all pairs
# (balanced=False) -- argument order presumably (rdm_paths, mode, balanced, calc).
plot_prc_balanced(rdm_locations, 'species', False, 'prc')

In [None]:
# Default number of augmented copies written per source image.
NUM_GENERATED = 20


def create_cls_from_img(pth: str, output_path: str, num_generated: int = NUM_GENERATED):
    """Write `num_generated` randomly warped copies of a single image.

    Each copy gets a random perspective distortion (distortion_scale=0.5,
    always applied) followed by a random affine transform (rotation in
    [-40, 40] degrees, translation up to 0.05 of width/height, no scaling).
    Each copy is saved as '<output_path>_rand_<i>.jpg'.

    Parameters
    ----------
    pth : str
        Path of the source image.
    output_path : str
        Path prefix for the generated files (directory must already exist).
    num_generated : int, default NUM_GENERATED
        Number of copies to generate.
    """
    tt = T.Compose([T.RandomPerspective(distortion_scale=0.5, p=1.0), T.RandomAffine(degrees=(-40, 40), translate=(0.05, 0.05), scale=(1.0, 1.0))])
    # The context manager closes the source file handle once the copies exist;
    # a bare Image.open() leaves it open until garbage collection.
    with Image.open(pth) as img:
        # JPEG cannot store alpha/palette modes (RGBA, P, LA, ...); convert those first.
        if img.mode not in ('RGB', 'L', 'CMYK'):
            img = img.convert('RGB')
        cls = [tt(img) for _ in range(num_generated)]
    for i, im in enumerate(cls):
        im.save(output_path + f'_rand_{i}.jpg')

def create_classes(dir_pth: str, output_path: str):
    """Turn every image under dir_pth/<species>/ into its own synthetic class.

    The i-th image (sorted by path) of species directory <species> gets the
    output directory output_path/<species>_<i>/. create_cls_from_img fills it
    with files named <image stem>_rand_<k>.jpg.

    Parameters
    ----------
    dir_pth : str
        Root directory containing one sub-directory per species.
    output_path : str
        Root directory under which the class directories are created.
    """
    # glob order is filesystem-dependent and the index i becomes part of the
    # class name, so sort to make the naming reproducible.
    for species_dir in sorted(glob(path.join(dir_pth, '*'))):
        species_output = path.join(output_path, path.relpath(species_dir, dir_pth))
        # Only regular files can be opened as images; skip nested directories.
        img_paths = sorted(p for p in glob(path.join(species_dir, '*')) if path.isfile(p))
        for i, img_pth in enumerate(img_paths):
            cls_output = species_output + f'_{i}'
            makedirs(cls_output, exist_ok=True)
            # splitext handles any extension length (.jpeg, .png, ...), not only 3 letters.
            stem = path.splitext(path.basename(img_pth))[0]
            create_cls_from_img(img_pth, path.join(cls_output, stem))

In [None]:
# create_classes('/home/ssd_storage/datasets/win_lose/rdm_folders', '/home/ssd_storage/datasets/win_lose/rdm_cls_folders')

In [None]:


def plot_roc_cls(rdm_paths: List[Tuple[Tuple[int, int], str]], balanced: bool = False) -> None:
    """Plot one ROC curve per RDM for same-class vs. different-class image pairs.

    An image's class is its name up to '_rand_' (the suffix written by
    create_cls_from_img). Pairs are scored by similarity = 1 - distance, with
    same-class pairs as the positive class. Each trace is named by its AUC.

    Parameters
    ----------
    rdm_paths : list of ((row, col), csv_path)
        1-based subplot location and path of an RDM CSV whose 'Unnamed: 0'
        column holds the image names.
    balanced : bool, default False
        False (the original behaviour): score every pair in the strict upper
        triangle. True: score only the balanced positive/negative sample
        drawn once from the first RDM by random_balanced_set_cls_cls.

    The figure is displayed and written to
    /home/ssd_storage/experiments/win_lose/roc_cls.html
    (roc_cls_(balanced=True).html when balanced).
    """
    fig = make_subplots(rows=4, cols=4, 
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)

    if balanced:
        # Sample once so every subplot scores exactly the same pairs.
        _, first_pth = rdm_paths[0]
        pos, neg = random_balanced_set_cls_cls(rdm=pd.read_csv(first_pth, index_col='Unnamed: 0'))

    for (loc, pth) in tqdm(rdm_paths):
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')
        imgs = list(rdm.index)
        for i in range(len(imgs)):
            for s, r in replacements:
                imgs[i] = imgs[i].replace(s, r)

        # Normalise names and sort rows/columns so pair indices line up across RDMs.
        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm[imgs]
        rdm = rdm.loc[imgs]

        classes = np.asarray([name.split('_rand_')[0] for name in rdm.index])
        dists = rdm.to_numpy(dtype=float)
        if balanced:
            labels = np.concatenate((np.ones(len(pos[0])), np.zeros(len(neg[0]))))
            scores = 1 - np.concatenate((dists[tuple(pos)], dists[tuple(neg)]))
        else:
            # Strict upper triangle: every unordered pair once, no self-pairs.
            upper = np.triu_indices(len(classes), 1)
            labels = (classes[:, None] == classes[None, :])[upper].astype(float)
            scores = 1 - dists[upper]

        fpr, tpr, _ = sklearn.metrics.roc_curve(labels, scores, pos_label=1)
        auc = sklearn.metrics.auc(fpr, tpr)

        fig.add_trace(go.Scatter(x=fpr, y=tpr, name=f'{auc}'),
                      row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4,  template='plotly_white')

    fig.show()
    suffix = '_(balanced=True)' if balanced else ''
    fig.write_html(f'/home/ssd_storage/experiments/win_lose/roc_cls{suffix}.html')

In [None]:
def random_balanced_set_cls_cls(rdm: pd.DataFrame) -> Tuple:
    """Sample equally many same-class and different-class image-index pairs.

    Names are normalised with `replacements`, then sorted. Each name is reduced
    to a class by dropping everything from '_rand_' on (the suffix
    create_cls_from_img writes). Returned indices therefore refer to the sorted
    name order, the same order the plotting functions re-sort their RDMs into.

    Parameters
    ----------
    rdm : pd.DataFrame
        RDM indexed by image name. It is not modified.

    Returns
    -------
    (pos, neg) : tuple of two lists
        Each is [row_indices, col_indices] with the same number of pairs;
        pos are pairs from create_positive_pairs (same class), neg from
        create_negative_pairs.
    """
    imgs = list(rdm.index)
    for i in range(len(imgs)):
        for s, r in replacements:
            imgs[i] = imgs[i].replace(s, r)

    # Relabel a copy so the caller's DataFrame keeps its original index/columns.
    rdm = rdm.copy()
    rdm.index = imgs
    rdm.columns = imgs
    imgs = sorted(imgs)

    rdm = rdm[imgs]
    rdm = rdm.loc[imgs]

    imgs = [name.split('_rand_')[0] for name in rdm.index]

    cls_idx = get_cls_idx(imgs)
    pos = create_positive_pairs(cls_idx)
    pos = np.concatenate(pos)
    neg = create_negative_pairs(cls_idx)

    num_samples = min(neg.shape[0], pos.shape[0])
    print(num_samples)
    pos = shuffle_along_axis(pos, 1)
    neg = shuffle_along_axis(neg, 1)
    # Pick a uniformly random subset of pairs (rows) before truncating. Otherwise
    # the kept pairs are the first ones in generation order unless
    # shuffle_along_axis already permutes rows.
    # NOTE(review): it is called with axis=1 -- confirm what it shuffles.
    pos = pos[np.random.permutation(pos.shape[0])[:num_samples]]
    neg = neg[np.random.permutation(neg.shape[0])[:num_samples]]

    return pos.T.tolist(), neg.T.tolist()


def plot_auroc(rdm_paths: Dict[str, List[Tuple[str, str]]], balanced: bool = False) -> None:
    """Bar chart of ROC AUC per training domain and number of training classes.

    An image's class is its name up to '_rand_'. Pairs are scored by
    similarity = 1 - distance, with same-class pairs positive.

    Parameters
    ----------
    rdm_paths : dict
        Maps a domain name (one bar trace) to a list of
        (x-axis label, RDM csv path).
    balanced : bool, default False
        False (the original behaviour): score every pair in the strict upper
        triangle. True: score only the balanced sample drawn once, from the
        first RDM of the first domain, by random_balanced_set_cls_cls.

    The figure is displayed and written to
    /home/ssd_storage/experiments/win_lose/aucs_(balanced=<balanced>)_cls.html.
    """
    fig = go.Figure()

    if balanced:
        # Sample once so every bar is computed on exactly the same pairs.
        first_domain = next(iter(rdm_paths))
        _, first_pth = rdm_paths[first_domain][0]
        pos, neg = random_balanced_set_cls_cls(rdm=pd.read_csv(first_pth, index_col='Unnamed: 0'))

    for domain, entries in rdm_paths.items():
        aucs = []
        cls = []
        for (loc, pth) in tqdm(entries):
            cls.append(loc)
            rdm = pd.read_csv(pth, index_col='Unnamed: 0')
            imgs = list(rdm.index)
            for i in range(len(imgs)):
                for s, r in replacements:
                    imgs[i] = imgs[i].replace(s, r)

            # Normalise names and sort rows/columns so pair indices line up across RDMs.
            rdm.index = imgs
            rdm.columns = imgs
            imgs = sorted(imgs)
            rdm = rdm[imgs]
            rdm = rdm.loc[imgs]

            classes = np.asarray([name.split('_rand_')[0] for name in rdm.index])
            dists = rdm.to_numpy(dtype=float)
            if balanced:
                labels = np.concatenate((np.ones(len(pos[0])), np.zeros(len(neg[0]))))
                scores = 1 - np.concatenate((dists[tuple(pos)], dists[tuple(neg)]))
            else:
                # Strict upper triangle: every unordered pair once, no self-pairs.
                upper = np.triu_indices(len(classes), 1)
                labels = (classes[:, None] == classes[None, :])[upper].astype(float)
                scores = 1 - dists[upper]

            fpr, tpr, _ = sklearn.metrics.roc_curve(labels, scores, pos_label=1)
            aucs.append(sklearn.metrics.auc(fpr, tpr))
        fig.add_trace(go.Bar(x=cls, y=aucs, text=[f'{a:.3f}' for a in aucs], textposition='auto', name=domain))

    fig.update_layout(height=500, width=1000,  template='plotly_white')
    fig.update_yaxes(range=(0,1))

    fig.show()
    fig.write_html(f'/home/ssd_storage/experiments/win_lose/aucs_(balanced={balanced})_cls.html')
    
    
def plot_auprc(rdm_paths: Dict[str, List[Tuple[str, str]]], balanced: bool = False) -> None:
    """Bar chart of precision-recall AUC per training domain and number of classes.

    An image's class is its name up to '_rand_'. Pairs are scored by
    similarity = 1 - distance, with same-class pairs positive.

    Parameters
    ----------
    rdm_paths : dict
        Maps a domain name (one bar trace) to a list of
        (x-axis label, RDM csv path).
    balanced : bool, default False
        False (the original behaviour): score every pair in the strict upper
        triangle. True: score only the balanced sample drawn once, from the
        first RDM of the first domain, by random_balanced_set_cls_cls.

    The figure is displayed and written to
    /home/ssd_storage/experiments/win_lose/auprcs_cls.html
    (auprcs_(balanced=True)_cls.html when balanced).
    """
    fig = go.Figure()

    if balanced:
        # Sample once so every bar is computed on exactly the same pairs.
        first_domain = next(iter(rdm_paths))
        _, first_pth = rdm_paths[first_domain][0]
        pos, neg = random_balanced_set_cls_cls(rdm=pd.read_csv(first_pth, index_col='Unnamed: 0'))

    for domain, entries in rdm_paths.items():
        aucs = []
        cls = []
        for (loc, pth) in tqdm(entries):
            cls.append(loc)
            rdm = pd.read_csv(pth, index_col='Unnamed: 0')
            imgs = list(rdm.index)
            for i in range(len(imgs)):
                for s, r in replacements:
                    imgs[i] = imgs[i].replace(s, r)

            # Normalise names and sort rows/columns so pair indices line up across RDMs.
            rdm.index = imgs
            rdm.columns = imgs
            imgs = sorted(imgs)
            rdm = rdm[imgs]
            rdm = rdm.loc[imgs]

            classes = np.asarray([name.split('_rand_')[0] for name in rdm.index])
            dists = rdm.to_numpy(dtype=float)
            if balanced:
                labels = np.concatenate((np.ones(len(pos[0])), np.zeros(len(neg[0]))))
                scores = 1 - np.concatenate((dists[tuple(pos)], dists[tuple(neg)]))
            else:
                # Strict upper triangle: every unordered pair once, no self-pairs.
                upper = np.triu_indices(len(classes), 1)
                labels = (classes[:, None] == classes[None, :])[upper].astype(float)
                scores = 1 - dists[upper]

            precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, scores, pos_label=1)
            # NOTE: trapezoidal AUC over a PR curve can be optimistic; sklearn's
            # average_precision_score is the step-wise alternative.
            aucs.append(sklearn.metrics.auc(recall, precision))
        fig.add_trace(go.Bar(x=cls, y=aucs, text=[f'{a:.3f}' for a in aucs], textposition='auto', name=domain))

    fig.update_layout(height=500, width=1000,  template='plotly_white')
    fig.update_yaxes(range=(0,1))

    fig.show()
    suffix = '_(balanced=True)' if balanced else ''
    fig.write_html(f'/home/ssd_storage/experiments/win_lose/auprcs{suffix}_cls.html')

In [None]:
# Training domain -> [(x-axis label, fc7 RDM CSV computed on the win_lose_cls
# stimuli)]. plot_auroc / plot_auprc draw one bar trace per domain.
rdms = {
    'Objects': [('30 classes', '/home/ssd_storage/experiments/win_lose/win_lose_30_inanimates/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    ('260 classes', '/home/ssd_storage/experiments/win_lose/win_lose_260_inanimates/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    ('500 classes', '/home/ssd_storage/experiments/win_lose/win_lose_500_inanimates/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    ('1000 classes', '/home/ssd_storage/experiments/win_lose/win_lose_1000_object_net/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),],
    
    'Faces': [('30 classes', '/home/ssd_storage/experiments/win_lose/win_lose_30_faces/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    ('260 classes', '/home/ssd_storage/experiments/win_lose/win_lose_260_faces/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    ('500 classes', '/home/ssd_storage/experiments/win_lose/win_lose_500_faces/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    ('1000 classes', '/home/ssd_storage/experiments/win_lose/win_lose_1000_faces/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),],
    
    'Bird species': [('30 classes', '/home/ssd_storage/experiments/win_lose/win_lose_30_species/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    ('260 classes', '/home/ssd_storage/experiments/win_lose/win_lose_260_species/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),],
    
    'Sociable Weavers': [('30 classes', '/home/ssd_storage/experiments/win_lose/win_lose_30_soc_weav/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv')],
}

# Bar chart of ROC AUC per domain / number of training classes.
# plot_roc_cls(rdm_locations)
# plot_auroc(rdms)
plot_auroc(rdms)

In [None]:
# Subplot (row, col) -> fc7 RDM CSV computed on the win_lose_cls stimuli
# (augmented classes produced by create_classes). Rows match the subplot
# row_titles (1 Objects, 2 Bird species, 3 Faces, 4 Sociable Weavers);
# columns match column_titles (30, 260, 500, 1000 classes).
rdm_locations = [
    ((1,1), '/home/ssd_storage/experiments/win_lose/win_lose_30_inanimates/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    ((1,2), '/home/ssd_storage/experiments/win_lose/win_lose_260_inanimates/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    ((1,3), '/home/ssd_storage/experiments/win_lose/win_lose_500_inanimates/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    ((1,4), '/home/ssd_storage/experiments/win_lose/win_lose_1000_object_net/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    
    ((2,1), '/home/ssd_storage/experiments/win_lose/win_lose_30_species/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    ((2,2), '/home/ssd_storage/experiments/win_lose/win_lose_260_species/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    
    ((3,1), '/home/ssd_storage/experiments/win_lose/win_lose_30_faces/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    ((3,2), '/home/ssd_storage/experiments/win_lose/win_lose_260_faces/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    ((3,3), '/home/ssd_storage/experiments/win_lose/win_lose_500_faces/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    ((3,4), '/home/ssd_storage/experiments/win_lose/win_lose_1000_faces/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
    
    ((4,1), '/home/ssd_storage/experiments/win_lose/win_lose_30_soc_weav/vgg16/results/win_lose_cls/win_lose_cls_fc7.csv'),
]

# One ROC curve per subplot for same-class vs. different-class pairs.
plot_roc_cls(rdm_locations)
# plot_rdm(rdm_locations)

In [None]:
def plot_prc(rdm_paths: List[Tuple[Tuple[int, int], str]]) -> None:
    """Plot one precision-recall curve per RDM, scored by raw pair distance.

    An image's label is the alphabetic characters of its name with the last
    4 characters (presumably the file extension) removed. Pairs in the strict
    upper triangle are scored by their distance. Label 1 means same label,
    and pos_label=0 treats *different*-label pairs as the positive class.
    Each trace is named by its AUC.

    NOTE(review): plot_auprc instead scores 1 - distance with same-label
    pairs positive. With mostly different-label pairs, the PR-AUC of the
    majority class here will be close to 1 -- confirm this is intended.

    Parameters
    ----------
    rdm_paths : list of ((row, col), csv_path)
        1-based subplot location and path of an RDM CSV whose 'Unnamed: 0'
        column holds the image names.

    The figure is displayed and written to
    /home/ssd_storage/experiments/win_lose/prc.html.
    """
    fig = make_subplots(rows=4, cols=4, 
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)
    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')
        imgs = list(rdm.index)
        for i in range(len(imgs)):
            for s, r in replacements:
                imgs[i] = imgs[i].replace(s, r)

        # Normalise names and sort rows/columns so pair indices line up across RDMs.
        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm[imgs]
        rdm = rdm.loc[imgs]

        labels_arr = np.asarray([''.join(ch for ch in name[:-4] if ch.isalpha()) for name in rdm.index])

        # Strict upper triangle: every unordered pair once, no self-pairs.
        upper = np.triu_indices(len(labels_arr), 1)
        dists = rdm.to_numpy()[upper]
        label = (labels_arr[:, None] == labels_arr[None, :])[upper].astype(float)

        precision, recall, _ = sklearn.metrics.precision_recall_curve(label, dists, pos_label=0)
        auc = sklearn.metrics.auc(recall, precision)

        fig.add_trace(go.Scatter(x=recall, y=precision, name=f'{auc}'),
                      row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4,  template='plotly_white')

    fig.show()
    fig.write_html(f'/home/ssd_storage/experiments/win_lose/prc.html')