In [None]:
import pandas as pd
import numpy as np
from glob import glob
import plotly.express as px
from typing import Tuple, Dict, List
from os import path
from sklearn import metrics
from tqdm import tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import sklearn
from tqdm.notebook import tqdm
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import maximum_bipartite_matching

import torchvision.transforms as T
from PIL import Image
from glob import glob
from os import path, makedirs

In [None]:
# Bird species names as they appear in the stimulus file names, grouped by family.
finches = [
    'American Goldfinch', 'Black Rosy', 'Brambling', 'Cassins', 'Common Redpoll',
    'Evening Grosbeak', 'House Finch', 'Pine Grosbeak', 'Pine Siskin', 'Whitewinged Crossbill',
]
warblers = [
    'AmericanRedstart', 'Baybreasted', 'BayBreasted', 'Black Burnian', 'Black Throated Green Warbler',
    'Blue Winged', 'Canada', 'Chestnut Sided', 'Golden Winged', 'Nashville', 'Wilsons',
]
# (substring, replacement) pairs that prefix each species name with its family, e.g.
# 'Brambling' -> 'Finches: Brambling'; finches first, then warblers.
replacements = [
    (species, f'{family}: {species}')
    for family, members in (('Finches', finches), ('Warblers', warblers))
    for species in members
]

In [None]:
def plot_rdm(rdm_paths: List[Tuple[Tuple[int, int], str]],
             out_path: str = '/home/ssd_storage/experiments/Sheinbug/rdms.html') -> None:
    """
    Plot a 4x4 grid of RDM heatmaps, one per (training domain, number of classes) model.

    :param rdm_paths: list of ((row, col), csv_path). (row, col) is the 1-based subplot
        position; csv_path is an RDM CSV whose header-less first column holds the image
        names (read as the index) and whose columns are in the same order.
    :param out_path: where the interactive HTML figure is written.
    """
    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)

    for (loc, pth) in tqdm(rdm_paths):
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')

        # Prefix bird images with their family so sorting groups finches and warblers together
        imgs = []
        for img in rdm.index:
            for s, r in replacements:
                img = img.replace(s, r)
            imgs.append(img)

        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        # Reorder rows and columns together so the matrix stays symmetric in the new order
        rdm = rdm.loc[imgs, imgs]

        fig.add_trace(go.Heatmap(z=rdm, showscale=False, x=rdm.index, y=rdm.index),
                      row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4, template='plotly_white')

    fig.show()
    fig.write_html(out_path)

In [None]:
# VGG16 fc7 RDM CSVs, each placed at its 1-based (training-domain row, #classes column)
# subplot position; only the run directory differs between entries.
rdm_locations = [
    (loc, f'/home/ssd_storage/experiments/Sheinbug/Sheinbug_{run}/vgg16/results/Sheinbug/Sheinbug_fc7.csv')
    for loc, run in [
        ((1, 1), '30_inanimates'),
        ((1, 2), '260_inanimates'),
        ((1, 3), '500_inanimates'),
        ((1, 4), '1000_object_net'),
        ((2, 1), '30_species'),
        ((2, 2), '260_species'),
        ((3, 1), '30_faces'),
        ((3, 2), '260_faces'),
        ((3, 3), '500_faces'),
        ((3, 4), '1000_faces'),
        ((4, 1), '30_soc_weav'),
    ]
]

plot_rdm(rdm_locations)

In [None]:
def plot_prc(rdm_paths: List[Tuple[Tuple[int, int], str]]) -> None:
    """
    Plot precision-recall curves over all image pairs, grouping images by species label.

    Each RDM's strict upper triangle gives one distance per unordered image pair. The
    indicator "same species" is built per pair and the curve is computed with
    ``pos_label=0``, i.e. different-species pairs are the positive class, scored by distance.

    :param rdm_paths: list of ((row, col), csv_path); (row, col) is the 1-based subplot position.
    """
    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)
    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')

        # Prefix bird images with their family so sorting groups them
        imgs = []
        for img in rdm.index:
            for s, r in replacements:
                img = img.replace(s, r)
            imgs.append(img)

        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm.loc[imgs, imgs]

        # Species label: drop the last 4 chars (presumably the file extension) and keep letters only
        labels = np.array([''.join(ch for ch in img[:-4] if ch.isalpha()) for img in imgs])
        issame = (labels[:, None] == labels[None, :]).astype(float)

        triu = np.triu_indices(len(imgs), 1)
        dists = rdm.to_numpy()[triu]
        label = issame[triu]
        precision, recall, thresholds = sklearn.metrics.precision_recall_curve(label, dists, pos_label=0)
        auc = sklearn.metrics.auc(recall, precision)

        fig.add_trace(go.Scatter(x=recall, y=precision, name=f'{auc}'),
                      row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4, template='plotly_white')

    fig.show()
    fig.write_html('/home/ssd_storage/experiments/Sheinbug/prc.html')

In [None]:
# Same model grid as for the RDM heatmaps: (1-based subplot position, run directory name).
rdm_locations = [
    (loc, f'/home/ssd_storage/experiments/Sheinbug/Sheinbug_{run}/vgg16/results/Sheinbug/Sheinbug_fc7.csv')
    for loc, run in [
        ((1, 1), '30_inanimates'),
        ((1, 2), '260_inanimates'),
        ((1, 3), '500_inanimates'),
        ((1, 4), '1000_object_net'),
        ((2, 1), '30_species'),
        ((2, 2), '260_species'),
        ((3, 1), '30_faces'),
        ((3, 2), '260_faces'),
        ((3, 3), '500_faces'),
        ((3, 4), '1000_faces'),
        ((4, 1), '30_soc_weav'),
    ]
]

plot_prc(rdm_locations)

In [None]:
def plot_fw_prc(rdm_paths: List[Tuple[Tuple[int, int], str]]) -> None:
    """
    Plot precision-recall curves over all image pairs, grouping images by family label.

    The family label is the first 8 characters of the renamed image name (e.g. the
    'Finches:' / 'Warblers' prefix added via ``replacements``). Curves use ``pos_label=0``,
    i.e. different-family pairs are the positive class, scored by distance.

    :param rdm_paths: list of ((row, col), csv_path); (row, col) is the 1-based subplot position.
    """
    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)
    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')

        imgs = []
        for img in rdm.index:
            for s, r in replacements:
                img = img.replace(s, r)
            imgs.append(img)

        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm.loc[imgs, imgs]

        labels = np.array([img[:8] for img in imgs])
        # Vectorized pairwise label equality (replaces a Python double loop)
        issame = (labels[:, None] == labels[None, :]).astype(float)

        triu = np.triu_indices(len(imgs), 1)
        dists = rdm.to_numpy()[triu]
        label = issame[triu]
        precision, recall, thresholds = sklearn.metrics.precision_recall_curve(label, dists, pos_label=0)
        auc = sklearn.metrics.auc(recall, precision)

        fig.add_trace(go.Scatter(x=recall, y=precision, name=f'{auc}'),
                      row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4, template='plotly_white')

    fig.show()
    fig.write_html('/home/ssd_storage/experiments/Sheinbug/fw_prc.html')

In [None]:
# Model grid for the family-level PR curves: (1-based subplot position, run directory name).
rdm_locations = [
    (loc, f'/home/ssd_storage/experiments/Sheinbug/Sheinbug_{run}/vgg16/results/Sheinbug/Sheinbug_fc7.csv')
    for loc, run in [
        ((1, 1), '30_inanimates'),
        ((1, 2), '260_inanimates'),
        ((1, 3), '500_inanimates'),
        ((1, 4), '1000_object_net'),
        ((2, 1), '30_species'),
        ((2, 2), '260_species'),
        ((3, 1), '30_faces'),
        ((3, 2), '260_faces'),
        ((3, 3), '500_faces'),
        ((3, 4), '1000_faces'),
        ((4, 1), '30_soc_weav'),
    ]
]

plot_fw_prc(rdm_locations)

In [None]:
def shuffle_along_axis(a, axis):
    """
    Independently permute the entries of ``a`` along ``axis``.

    For a 2-D array and ``axis=1`` every row is shuffled on its own, e.g. randomly
    swapping the two members of each pair. Uses the global NumPy RNG.
    """
    random_keys = np.random.rand(*a.shape)
    order = np.argsort(random_keys, axis=axis)
    return np.take_along_axis(a, order, axis=axis)


def create_positive_pairs(cls2imgs: Dict[str, List[int]]) -> List[np.ndarray]:
    """
    Create pairs of images for the SAME cls pairs.

    For every class, ``len(imgs) // 2`` disjoint pairs are drawn without replacement,
    so no image is used twice; with an odd count one image is left out.

    :param cls2imgs: class label -> list of image indices.
    :return: one array of shape (len(imgs) // 2, 2) per class, in the dict's
        iteration order.
    """
    pos = []
    for cls in cls2imgs:
        cls_imgs = cls2imgs[cls]
        num_pairs = len(cls_imgs) // 2
        # replace=False over a (num_pairs, 2) draw makes all pairs disjoint
        pos.append(np.random.choice(cls_imgs, [num_pairs, 2], replace=False))
    return pos


def get_cls_idx(imgs: List[str]) -> Dict[str, List[int]]:
    """
    Group positions by label.

    :param imgs: per-image class labels.
    :return: label -> ascending list of positions where that label occurs; keys are
        in first-appearance order.
    """
    class_idx: Dict[str, List[int]] = {}
    # Single pass instead of rescanning the full list for every new label (O(n) vs O(n * k))
    for j, label in enumerate(imgs):
        class_idx.setdefault(label, []).append(j)
    return class_idx


def get_cls_same(class_idx: Dict[str, List[int]]) -> List[np.ndarray]:
    """
    Sample same-class (positive) pairs for every class.

    :param class_idx: class label -> list of image indices (e.g. the output of get_cls_idx).
    :return: one (len(members) // 2, 2) array of disjoint index pairs per class, in the
        dict's iteration order.
    """
    # create_positive_pairs iterates over every class itself, so it must receive the whole
    # mapping; passing a single class's index list would index that list by its own values.
    return create_positive_pairs(class_idx)


def create_negative_pairs(cls2imgs: Dict[str, List[int]]) -> np.ndarray:
    """
    Create a list of pairs of images for the DIFF cls pairs.

    Each class's images are shuffled and split alternately into a "row" half and a
    "column" half. A bipartite graph links every row image to every column image of a
    *different* class. A maximum bipartite matching then selects pairs in which no
    image is used more than once.

    :param cls2imgs: class label -> list of image indices. The lists are not modified.
    :return: array of shape (num_pairs, 2); each row is [row_img, col_img] from two
        different classes. Shape is (0, 2) when no cross-class pair can be formed.
    """
    idx: List[int] = []
    cols: List[int] = []
    block_sizes = []  # (height, width) of each class's within-class block, in order

    for cls in cls2imgs:
        # Shuffle a copy so the caller's lists are left untouched
        shuffled = list(cls2imgs[cls])
        np.random.shuffle(shuffled)
        curr_idx = shuffled[0::2]
        curr_cols = shuffled[1::2]
        idx.extend(curr_idx)
        cols.extend(curr_cols)
        block_sizes.append((len(curr_idx), len(curr_cols)))

    if not idx or not cols:
        return np.empty((0, 2), dtype=int)

    connections = pd.DataFrame(np.ones([len(idx), len(cols)]), index=idx, columns=cols)
    start_idx = 0
    start_col = 0
    for height, width in block_sizes:
        # Zero the within-class block so an image is never paired with its own class
        connections.iloc[start_idx:start_idx + height, start_col:start_col + width] = 0
        start_idx += height
        start_col += width

    # Shuffle the rows so the matching does not favour any class ordering
    connections = connections.sample(frac=1.0)
    # perm_type='column': matches[i] is the column matched to row i, or -1 if unmatched
    graph = csr_matrix(connections.to_numpy())
    matches = maximum_bipartite_matching(graph, perm_type='column')

    diff_pairs = [[connections.index[i], connections.columns[j]]
                  for i, j in enumerate(matches) if j >= 0]
    if not diff_pairs:
        # Keep a 2-D shape so callers can still shuffle/slice along axis 1
        return np.empty((0, 2), dtype=int)
    return np.array(diff_pairs)

def random_balanced_set_mode(rdm: pd.DataFrame, mode: str) -> Tuple:
    """
    Sample an equal number of same-class (positive) and different-class (negative) image pairs.

    Image names are renamed with ``replacements`` and sorted, which matches the reordering
    the plotting functions apply to each RDM. The returned positions therefore index those
    reordered matrices.

    :param rdm: RDM whose index holds the image names. It is not modified.
    :param mode: what defines a class:
        'family'     - first 8 characters of the renamed name,
        'species'    - name without its last 4 characters, letters only,
        'individual' - the part of the name before '_rand_'.
    :return: (pos, neg), each ``[row_positions, col_positions]`` (two equal-length lists),
        suitable for ``matrix[tuple(pos)]``.
    """
    assert (mode == 'family') or (mode == 'species') or (mode == 'individual')

    # Build the sorted renamed labels without touching the caller's DataFrame
    imgs = []
    for img in rdm.index:
        for s, r in replacements:
            img = img.replace(s, r)
        imgs.append(img)
    imgs = sorted(imgs)

    if mode == 'family':
        labels = [img[:8] for img in imgs]
    elif mode == 'species':
        labels = [''.join(ch for ch in img[:-4] if ch.isalpha()) for img in imgs]
    else:
        labels = [img.split('_rand_')[0] for img in imgs]

    cls_idx = get_cls_idx(labels)
    pos = np.concatenate(create_positive_pairs(cls_idx))
    neg = create_negative_pairs(cls_idx)

    num_samples = min(neg.shape[0], pos.shape[0])
    # Positive pairs come grouped by class; permute rows so truncation below samples
    # across all classes instead of dropping the last ones.
    pos = pos[np.random.permutation(pos.shape[0])]
    # Randomize which member of each pair is the row vs the column
    pos = shuffle_along_axis(pos, 1)
    neg = shuffle_along_axis(neg, 1)
    pos = pos[:num_samples]
    neg = neg[:num_samples]

    return pos.T.tolist(), neg.T.tolist()


def random_balanced_set_family(rdm: pd.DataFrame) -> Tuple:
    """
    Balanced positive/negative pair sample where a class is the 8-character family prefix.

    Thin wrapper kept for backward compatibility. The logic lives in
    ``random_balanced_set_mode`` so the three grouping variants cannot drift apart.

    :param rdm: RDM whose index holds the image names.
    :return: (pos, neg), each ``[row_positions, col_positions]``.
    """
    return random_balanced_set_mode(rdm, 'family')


def random_balanced_set_species(rdm: pd.DataFrame) -> Tuple:
    """
    Balanced positive/negative pair sample where a class is the species label.

    The species label is the image name without its last 4 characters, letters only.
    Thin wrapper around ``random_balanced_set_mode``.

    :param rdm: RDM whose index holds the image names.
    :return: (pos, neg), each ``[row_positions, col_positions]``.
    """
    return random_balanced_set_mode(rdm, 'species')


def random_balanced_set_individual(rdm: pd.DataFrame) -> Tuple:
    """
    Balanced positive/negative pair sample where a class is the name before '_rand_'.

    Thin wrapper around ``random_balanced_set_mode``. It also prints the number of pairs
    kept per side, as before.

    :param rdm: RDM whose index holds the image names.
    :return: (pos, neg), each ``[row_positions, col_positions]``.
    """
    pos, neg = random_balanced_set_mode(rdm, 'individual')
    # Each side holds exactly num_samples pairs after balancing
    print(len(pos[0]))
    return pos, neg


In [None]:
def plot_balanced_curves(rdm_paths: List[Tuple[Tuple[int, int], str]], calc: str, mode: str) -> None:
    """
    Plot PR or ROC curves on one balanced pair sample shared by every RDM.

    The balanced set is drawn once from the first RDM via ``random_balanced_set_mode``. It is
    then evaluated on every RDM, so all curves score the same image pairs.

    :param rdm_paths: list of ((row, col), csv_path); (row, col) is the 1-based subplot position.
    :param calc: 'prc' (same-class pairs positive, scored by 1 - distance) or
        'roc' (different-class pairs positive, scored by distance).
    :param mode: grouping that defines a positive pair: 'family', 'species' or 'individual'.
    """
    assert (calc == 'prc') or (calc == 'roc')
    assert (mode == 'family') or (mode == 'species') or (mode == 'individual')

    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)
    _, pth = rdm_paths[0]
    pos, neg = random_balanced_set_mode(rdm = pd.read_csv(pth, index_col='Unnamed: 0'), mode = mode)
    labels = [1] * len(pos[0]) + [0] * len(neg[0])

    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')

        # Same rename + sort as random_balanced_set_mode so pos/neg positions line up
        imgs = []
        for img in rdm.index:
            for s, r in replacements:
                img = img.replace(s, r)
            imgs.append(img)
        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm.loc[imgs, imgs]

        dists = rdm.to_numpy()
        dists = np.concatenate((dists[tuple(pos)], dists[tuple(neg)]))

        if calc == 'prc':
            precision, recall, thresholds = sklearn.metrics.precision_recall_curve(labels, 1 - dists, pos_label=1)
            auc = sklearn.metrics.auc(recall, precision)
            fig.add_trace(go.Scatter(x=recall, y=precision, name=f'{auc}'),
                          row=loc[0], col=loc[1])
        else:
            # Larger distance => more likely a different-class pair, hence pos_label=0
            fpr, tpr, thresholds = sklearn.metrics.roc_curve(labels, dists, pos_label=0)
            auc = sklearn.metrics.auc(fpr, tpr)
            fig.add_trace(go.Scatter(x=fpr, y=tpr, name=f'{auc}'),
                          row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4, template='plotly_white')

    fig.show()
    fig.write_html(f'/home/ssd_storage/experiments/Sheinbug/balanced_{mode}_{calc}_curves.html')

In [None]:
# Model grid for the balanced ROC curves: (1-based subplot position, run directory name).
rdm_locations = [
    (loc, f'/home/ssd_storage/experiments/Sheinbug/Sheinbug_{run}/vgg16/results/Sheinbug/Sheinbug_fc7.csv')
    for loc, run in [
        ((1, 1), '30_inanimates'),
        ((1, 2), '260_inanimates'),
        ((1, 3), '500_inanimates'),
        ((1, 4), '1000_object_net'),
        ((2, 1), '30_species'),
        ((2, 2), '260_species'),
        ((3, 1), '30_faces'),
        ((3, 2), '260_faces'),
        ((3, 3), '500_faces'),
        ((3, 4), '1000_faces'),
        ((4, 1), '30_soc_weav'),
    ]
]

plot_balanced_curves(rdm_locations, 'roc', 'family')
plot_balanced_curves(rdm_locations, 'roc', 'species')

In [None]:
def plot_bar_auc(rdm_paths: Dict[str, List[Tuple[str, str]]], calc: str, mode: str) -> None:
    """
    Grouped bar chart of balanced-set AUCs, one bar group per training domain.

    One balanced pair sample is drawn (via ``random_balanced_set_mode``) from the first RDM
    of the first domain and reused for every RDM, so all bars score the same image pairs.

    :param rdm_paths: training domain -> list of (bar label, csv_path); the bar label
        (e.g. '30 classes') is used as the x value.
    :param calc: 'prc' (area under the PR curve) or 'roc' (area under the ROC curve).
    :param mode: grouping that defines a positive pair: 'family', 'species' or 'individual'.
    :raises ValueError: if ``rdm_paths`` is empty.
    """
    assert (calc == 'prc') or (calc == 'roc')
    assert (mode == 'family') or (mode == 'species') or (mode == 'individual')
    if not rdm_paths:
        raise ValueError('rdm_paths must contain at least one domain')

    fig = go.Figure()
    first_domain = next(iter(rdm_paths))
    _, pth = rdm_paths[first_domain][0]
    pos, neg = random_balanced_set_mode(rdm = pd.read_csv(pth, index_col='Unnamed: 0'), mode = mode)
    labels = [1] * len(pos[0]) + [0] * len(neg[0])

    for domain in rdm_paths:
        aucs = []
        cls = []
        for (loc, pth) in tqdm(rdm_paths[domain]):
            cls.append(loc)
            rdm = pd.read_csv(pth, index_col='Unnamed: 0')

            # Same rename + sort as random_balanced_set_mode so pos/neg positions line up
            imgs = []
            for img in rdm.index:
                for s, r in replacements:
                    img = img.replace(s, r)
                imgs.append(img)
            rdm.index = imgs
            rdm.columns = imgs
            imgs = sorted(imgs)
            rdm = rdm.loc[imgs, imgs]

            dists = rdm.to_numpy()
            dists = np.concatenate((dists[tuple(pos)], dists[tuple(neg)]))

            if calc == 'prc':
                precision, recall, thresholds = sklearn.metrics.precision_recall_curve(labels, 1 - dists, pos_label=1)
                auc = sklearn.metrics.auc(recall, precision)
            else:
                # Larger distance => more likely a different-class pair, hence pos_label=0
                fpr, tpr, thresholds = sklearn.metrics.roc_curve(labels, dists, pos_label=0)
                auc = sklearn.metrics.auc(fpr, tpr)
            aucs.append(auc)
        fig.add_trace(go.Bar(x=cls, y=aucs, text=aucs, textposition='auto', name=domain))

    fig.update_layout(height=500, width=1000, template='plotly_white')
    fig.update_yaxes(range=(0,1))

    fig.show()
    fig.write_html(f'/home/ssd_storage/experiments/Sheinbug/balanced_{mode}_au{calc}.html')

In [None]:
# Classification-set ("Sheinbug_cls") fc7 RDMs per training domain, as (bar label, csv path).
rdms = {
    domain: [
        (label, f'/home/ssd_storage/experiments/Sheinbug/Sheinbug_{run}/vgg16/results/Sheinbug_cls/Sheinbug_cls_fc7.csv')
        for label, run in runs
    ]
    for domain, runs in [
        ('Objects', [('30 classes', '30_inanimates'), ('260 classes', '260_inanimates'),
                     ('500 classes', '500_inanimates'), ('1000 classes', '1000_object_net')]),
        ('Faces', [('30 classes', '30_faces'), ('260 classes', '260_faces'),
                   ('500 classes', '500_faces'), ('1000 classes', '1000_faces')]),
        ('Bird species', [('30 classes', '30_species'), ('260 classes', '260_species')]),
        ('Sociable Weavers', [('30 classes', '30_soc_weav')]),
    ]
}

plot_bar_auc(rdms, 'roc', 'family')
plot_bar_auc(rdms, 'roc', 'species')
plot_bar_auc(rdms, 'roc', 'individual')

In [None]:
def plot_fw_roc(rdm_paths: List[Tuple[Tuple[int, int], str]]) -> None:
    """
    Plot family-level ROC curves on one balanced pair sample shared by every RDM.

    The balanced set is drawn once from the first RDM with ``mode='family'`` and then
    evaluated on every RDM. Different-family pairs are the positive class, scored by
    distance, matching ``plot_balanced_curves``.

    :param rdm_paths: list of ((row, col), csv_path); (row, col) is the 1-based subplot position.
    """
    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)
    _, pth = rdm_paths[0]
    # `random_balanced_set` is not defined in this notebook; family grouping is what the
    # [:8] labels here referred to.
    pos, neg = random_balanced_set_mode(rdm = pd.read_csv(pth, index_col='Unnamed: 0'), mode = 'family')
    labels = [1] * len(pos[0]) + [0] * len(neg[0])

    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')

        imgs = []
        for img in rdm.index:
            for s, r in replacements:
                img = img.replace(s, r)
            imgs.append(img)
        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm.loc[imgs, imgs]

        dists = rdm.to_numpy()
        dists = np.concatenate((dists[tuple(pos)], dists[tuple(neg)]))

        # Scores are distances: larger => different family, so the positive class is label 0.
        # pos_label=1 on raw distances would invert the curve (AUC -> 1 - AUC).
        fpr, tpr, thresholds = sklearn.metrics.roc_curve(labels, dists, pos_label=0)
        auc = sklearn.metrics.auc(fpr, tpr)

        fig.add_trace(go.Scatter(x=fpr, y=tpr, name=f'{auc}'),
                      row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4, template='plotly_white')

    fig.show()
    fig.write_html('/home/ssd_storage/experiments/Sheinbug/fw_roc.html')

In [None]:
# Model grid for the family-level ROC curves: (1-based subplot position, run directory name).
rdm_locations = [
    (loc, f'/home/ssd_storage/experiments/Sheinbug/Sheinbug_{run}/vgg16/results/Sheinbug/Sheinbug_fc7.csv')
    for loc, run in [
        ((1, 1), '30_inanimates'),
        ((1, 2), '260_inanimates'),
        ((1, 3), '500_inanimates'),
        ((1, 4), '1000_object_net'),
        ((2, 1), '30_species'),
        ((2, 2), '260_species'),
        ((3, 1), '30_faces'),
        ((3, 2), '260_faces'),
        ((3, 3), '500_faces'),
        ((3, 4), '1000_faces'),
        ((4, 1), '30_soc_weav'),
    ]
]

plot_fw_roc(rdm_locations)

In [None]:
def plot_roc(rdm_paths: List[Tuple[Tuple[int, int], str]]) -> None:
    """
    Plot species-level ROC curves on one balanced pair sample shared by every RDM.

    The balanced set is drawn once from the first RDM with ``mode='species'`` and then
    evaluated on every RDM. Different-species pairs are the positive class, scored by
    distance, matching ``plot_balanced_curves``.

    :param rdm_paths: list of ((row, col), csv_path); (row, col) is the 1-based subplot position.
    """
    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)
    _, pth = rdm_paths[0]
    # `random_balanced_set_cls` is not defined in this notebook; species grouping is what the
    # letters-only labels here referred to.
    pos, neg = random_balanced_set_mode(rdm = pd.read_csv(pth, index_col='Unnamed: 0'), mode = 'species')
    labels = [1] * len(pos[0]) + [0] * len(neg[0])

    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')

        imgs = []
        for img in rdm.index:
            for s, r in replacements:
                img = img.replace(s, r)
            imgs.append(img)
        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm.loc[imgs, imgs]

        dists = rdm.to_numpy()
        dists = np.concatenate((dists[tuple(pos)], dists[tuple(neg)]))

        # Scores are distances: larger => different species, so the positive class is label 0.
        fpr, tpr, thresholds = sklearn.metrics.roc_curve(labels, dists, pos_label=0)
        auc = sklearn.metrics.auc(fpr, tpr)

        fig.add_trace(go.Scatter(x=fpr, y=tpr, name=f'{auc}'),
                      row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4, template='plotly_white')

    fig.show()
    fig.write_html('/home/ssd_storage/experiments/Sheinbug/roc.html')

In [None]:
def plot_prc_balanced(rdm_paths: List[Tuple[Tuple[int, int], str]], mode: str='family', balanced = True, calc ='prc') -> None:
    """
    Plot PR or ROC curves for same-class detection, on a balanced sample or on all pairs.

    Scores are ``1 - distance`` and same-class pairs are the positive class (pos_label=1).

    :param rdm_paths: list of ((row, col), csv_path); (row, col) is the 1-based subplot position.
    :param mode: 'family' groups by the 8-character family prefix; any other value groups
        by species (name without its last 4 characters, letters only).
    :param balanced: if True, score one balanced pair sample drawn from the first RDM and
        shared across RDMs; otherwise score every unordered image pair.
    :param calc: 'prc' for precision-recall, anything else for ROC.
    """
    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)
    if balanced:
        # The previously called random_balanced_set / random_balanced_set_cls are not defined
        # in this notebook; random_balanced_set_mode covers both groupings.
        set_mode = 'family' if mode == 'family' else 'species'
        _, pth = rdm_paths[0]
        pos, neg = random_balanced_set_mode(rdm = pd.read_csv(pth, index_col='Unnamed: 0'), mode = set_mode)
        balanced_labels = [1] * len(pos[0]) + [0] * len(neg[0])

    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')

        imgs = []
        for img in rdm.index:
            for s, r in replacements:
                img = img.replace(s, r)
            imgs.append(img)
        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm.loc[imgs, imgs]

        dists = rdm.to_numpy()
        if balanced:
            scores = 1 - np.concatenate((dists[tuple(pos)], dists[tuple(neg)]))
            labels = balanced_labels
        else:
            if mode == 'family':
                names = np.array([img[:8] for img in imgs])
            else:
                names = np.array([''.join(ch for ch in img[:-4] if ch.isalpha()) for img in imgs])
            # Vectorized pairwise label equality (replaces a Python double loop)
            issame = (names[:, None] == names[None, :]).astype(float)
            triu = np.triu_indices(len(imgs), 1)
            scores = 1 - dists[triu]
            labels = issame[triu]

        if calc == 'prc':
            precision, recall, thresholds = sklearn.metrics.precision_recall_curve(labels, scores, pos_label=1)
            auc = sklearn.metrics.auc(recall, precision)
            fig.add_trace(go.Scatter(x=recall, y=precision, name=f'{auc}'),
                          row=loc[0], col=loc[1])
        else:
            fpr, tpr, thresholds = sklearn.metrics.roc_curve(labels, scores, pos_label=1)
            auc = sklearn.metrics.auc(fpr, tpr)
            fig.add_trace(go.Scatter(x=fpr, y=tpr, name=f'{auc}'),
                          row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4, template='plotly_white')
    fig.update_xaxes(range=(0,1))
    fig.update_yaxes(range=(0,1))

    fig.show()
    fig.write_html(f'/home/ssd_storage/experiments/Sheinbug/{calc}_(balanced={balanced})_{mode}.html')

In [None]:
def plot_hist_balanced(rdm_paths: List[Tuple[Tuple[int, int], str]], balanced: bool, mode: str='family') -> None:
    """Overlay histograms of same-label vs. different-label RDM distances, one subplot per model.

    Args:
        rdm_paths: ``((row, col), csv_path)`` pairs. Each CSV is a square RDM indexed by
            image name; it is drawn in cell ``(row, col)`` of a 4x4 subplot grid.
        balanced: if True, use positive/negative index pairs sampled once from the first
            RDM (``random_balanced_set`` for ``mode='family'``, otherwise
            ``random_balanced_set_cls``). If False, use every pair ``i < j`` of the
            upper triangle, split by whether the two images share a label.
        mode: ``'family'`` labels an image by the first 8 characters of its renamed
            name; any other value uses the letters of the name after dropping its last
            4 characters (presumably the file extension).

    Side effects: shows the figure and writes it to
    ``/home/ssd_storage/experiments/Sheinbug/hist_(balanced=<balanced>)_<mode>.html``.
    """
    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)

    if balanced:
        # Sample the pair indices once so every model is scored on the same pairs.
        # NOTE(review): assumes every RDM contains the same images as the first one.
        _, first_pth = rdm_paths[0]
        sampler = random_balanced_set if mode == 'family' else random_balanced_set_cls
        pos, neg = sampler(rdm=pd.read_csv(first_pth, index_col='Unnamed: 0'))

    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')
        imgs = list(rdm.index)
        for i in range(len(imgs)):
            for s, r in replacements:
                imgs[i] = imgs[i].replace(s, r)
        rdm.index = imgs
        rdm.columns = imgs
        imgs = sorted(imgs)
        rdm = rdm.loc[imgs, imgs]
        dists = rdm.to_numpy()

        if balanced:
            pos_dists = dists[tuple(pos)]
            neg_dists = dists[tuple(neg)]
        else:
            if mode == 'family':
                labels = [img[:8] for img in imgs]
            else:
                labels = [''.join(ch for ch in img[:-4] if ch.isalpha()) for img in imgs]
            labels = np.asarray(labels)
            issame = labels[:, None] == labels[None, :]

            # Strict upper triangle (k=1): each unordered pair once, self-pairs excluded.
            # Masking with np.tril_indices(n, 1) instead would also hide the first
            # superdiagonal and drop every (i, i+1) pair.
            upper = np.triu_indices(len(imgs), 1)
            pair_dists = dists[upper]
            pair_same = issame[upper]
            valid = ~np.isnan(pair_dists)
            pos_dists = pair_dists[pair_same & valid]
            neg_dists = pair_dists[~pair_same & valid]

        fig.add_trace(go.Histogram(x=pos_dists, histnorm='probability density', name='same label'),
                      row=loc[0], col=loc[1])
        fig.add_trace(go.Histogram(x=neg_dists, histnorm='probability density', name='different label'),
                      row=loc[0], col=loc[1])
    fig.update_layout(barmode='overlay')
    fig.update_traces(opacity=0.75)

    fig.update_layout(height=500 * 4, width=500 * 4,  template='plotly_white')

    fig.show()
    fig.write_html(f'/home/ssd_storage/experiments/Sheinbug/hist_(balanced={balanced})_{mode}.html')

In [None]:
def plot_rdm_balanced(rdm_paths: List[Tuple[Tuple[int, int], str]], mode: str='family') -> None:
    """Heatmap each RDM showing only the entries in one balanced pair sample.

    The positive/negative pairs are sampled once from the first RDM
    (``random_balanced_set`` for ``mode='family'``, otherwise ``random_balanced_set_cls``).
    Every other entry is set to NaN, so it renders blank.

    Args:
        rdm_paths: ``((row, col), csv_path)`` pairs; each square RDM CSV (indexed by image
            name) is drawn in cell ``(row, col)`` of a 4x4 subplot grid.
        mode: selects the sampler and the output file name.

    Side effects: shows the figure and writes it to
    ``/home/ssd_storage/experiments/Sheinbug/rdm_balanced_<mode>.html``.
    """
    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)
    _, first_pth = rdm_paths[0]
    if mode == 'family':
        pos, neg = random_balanced_set(rdm=pd.read_csv(first_pth, index_col='Unnamed: 0'))
    else:
        pos, neg = random_balanced_set_cls(rdm=pd.read_csv(first_pth, index_col='Unnamed: 0'))

    # Row/column indices of every sampled pair. list() makes '+' concatenate even when
    # the sampler returns numpy rows instead of lists.
    pair_idx = (list(pos[0]) + list(neg[0]), list(pos[1]) + list(neg[1]))

    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')
        names = list(rdm.index)
        for i in range(len(names)):
            for s, r in replacements:
                names[i] = names[i].replace(s, r)
        rdm.index = names
        rdm.columns = names
        order = sorted(names)
        rdm = rdm.loc[order, order]

        # Float copy, so NaN can be stored and the DataFrame's buffer is never written to.
        dists = rdm.to_numpy(dtype=float, copy=True)
        keep = np.zeros(dists.shape, dtype=bool)
        keep[pair_idx] = True
        dists[~keep] = np.nan

        fig.add_trace(go.Heatmap(z=dists, showscale=False),
                      row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4,  template='plotly_white')

    fig.show()
    fig.write_html(f'/home/ssd_storage/experiments/Sheinbug/rdm_balanced_{mode}.html')

In [None]:
# Balanced-sample RDM heatmaps at both label granularities.
# NOTE(review): `rdm_locations` is (re)assigned in a later cell; on a fresh kernel this
# cell only works if it was also defined earlier in the notebook. Confirm the run order.
for _mode in ('family', 'species'):
    plot_rdm_balanced(rdm_locations, _mode)

In [None]:
# Distance histograms: all pairs first, then the balanced sample; each at both granularities.
for _balanced in (False, True):
    for _mode in ('family', 'species'):
        plot_hist_balanced(rdm_locations, _balanced, _mode)

In [None]:
# fc7 RDM of each trained VGG16 on the Sheinbug images, keyed by subplot (row, col):
# row = training domain, col = number of training classes.
rdm_locations = [
    (loc, f'/home/ssd_storage/experiments/Sheinbug/Sheinbug_{run}/vgg16/results/Sheinbug/Sheinbug_fc7.csv')
    for loc, run in [
        ((1, 1), '30_inanimates'),
        ((1, 2), '260_inanimates'),
        ((1, 3), '500_inanimates'),
        ((1, 4), '1000_object_net'),
        ((2, 1), '30_species'),
        ((2, 2), '260_species'),
        ((3, 1), '30_faces'),
        ((3, 2), '260_faces'),
        ((3, 3), '500_faces'),
        ((3, 4), '1000_faces'),
        ((4, 1), '30_soc_weav'),
    ]
]

# ROC curves for each granularity, first over all pairs and then over the balanced sample.
for _mode in ('family', 'species'):
    for _balanced in (False, True):
        plot_prc_balanced(rdm_locations, _mode, _balanced, 'roc')

In [None]:
# Species-level precision-recall curves over all pairs (not the balanced sample).
# Arguments are presumably (rdm_paths, mode, balanced, calc); the signature is in an earlier cell.
plot_prc_balanced(rdm_locations, 'species', False, 'prc')

In [None]:
NUM_GENERATED = 20


def create_cls_from_img(pth: str, output_path: str, num_generated: int = NUM_GENERATED):
    """Save ``num_generated`` randomly warped copies of the image at ``pth``.

    Each copy gets a random perspective distortion (``distortion_scale=0.5``) and a random
    affine transform (rotation in [-40, 40] degrees, translate up to 0.05, no scaling).
    Copy ``i`` is written to ``f'{output_path}_rand_{i}.jpg'``.

    Args:
        pth: path of the source image.
        output_path: path prefix for the generated files (without extension).
        num_generated: number of copies; defaults to ``NUM_GENERATED``.
    """
    tt = T.Compose([T.RandomPerspective(distortion_scale=0.5, p=1.0), T.RandomAffine(degrees=(-40, 40), translate=(0.05, 0.05), scale=(1.0, 1.0))])
    # `with` closes the file handle Image.open keeps open; the copies are generated inside it.
    with Image.open(pth) as img:
        # JPEG cannot store alpha/palette modes such as RGBA or P; saving those as .jpg raises.
        if img.mode not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
            img = img.convert('RGB')
        cls = [tt(img) for _ in range(num_generated)]
    for i, im in enumerate(cls):
        im.save(output_path + f'_rand_{i}.jpg')

def create_classes(dir_pth: str, output_path: str):
    """Turn every image under ``dir_pth/<species>/`` into its own augmented class folder.

    The ``i``-th file (in sorted order) of ``dir_pth/<species>`` gets the folder
    ``output_path/<species>_<i>/``. ``create_cls_from_img`` fills that folder with files
    named ``<image stem>_rand_<k>.jpg``.

    Args:
        dir_pth: root folder whose entries are per-species folders.
        output_path: root folder for the generated class folders (created as needed).
    """
    # sorted(): glob order is filesystem-dependent; sorting keeps the `_<i>` numbering reproducible.
    for s in sorted(glob(path.join(dir_pth, '*'))):
        s_output = path.join(output_path, path.relpath(s, dir_pth))
        for i, img in enumerate(sorted(glob(path.join(s, '*')))):
            cls_output = s_output + f'_{i}'
            makedirs(cls_output, exist_ok=True)
            # splitext handles any extension length (e.g. '.jpeg'), unlike slicing off 4 chars.
            stem = path.splitext(path.basename(img))[0]
            create_cls_from_img(img, path.join(cls_output, stem))

In [None]:
# create_classes('/home/ssd_storage/datasets/Sheinbug/rdm_folders', '/home/ssd_storage/datasets/Sheinbug/rdm_cls_folders')

In [None]:


def plot_roc_cls(rdm_paths: List[Tuple[Tuple[int, int], str]]) -> None:
    """Plot one ROC curve per model for separating same-class from different-class pairs.

    Image names are renamed with ``replacements`` and sorted. An image's class is its name
    up to ``'_rand_'`` (the augmented copies of one source image). Every pair ``i < j`` is
    scored by similarity ``1 - distance``; the trace name is the ROC AUC.

    Args:
        rdm_paths: ``((row, col), csv_path)`` pairs; each square RDM CSV is drawn in cell
            ``(row, col)`` of a 4x4 subplot grid.

    Side effects: shows the figure and writes it to
    ``/home/ssd_storage/experiments/Sheinbug/roc_cls.html``.
    """
    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)

    for (loc, pth) in tqdm(rdm_paths):
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')
        names = list(rdm.index)
        for i in range(len(names)):
            for s, r in replacements:
                names[i] = names[i].replace(s, r)
        rdm.index = names
        rdm.columns = names
        order = sorted(names)
        rdm = rdm.loc[order, order]

        classes = np.array([name.split('_rand_')[0] for name in order])
        issame = classes[:, None] == classes[None, :]

        # All pairs i < j (unbalanced); each unordered pair once, self-pairs excluded.
        upper = np.triu_indices(len(order), 1)
        scores = 1 - rdm.to_numpy()[upper]
        labels = issame[upper].astype(int)

        fpr, tpr, _ = sklearn.metrics.roc_curve(labels, scores, pos_label=1)
        auc = sklearn.metrics.auc(fpr, tpr)

        fig.add_trace(go.Scatter(x=fpr, y=tpr, name=f'{auc}'),
                      row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4,  template='plotly_white')

    fig.show()
    fig.write_html('/home/ssd_storage/experiments/Sheinbug/roc_cls.html')

In [None]:
def random_balanced_set_cls_cls(rdm: pd.DataFrame) -> Tuple:
    """Sample equally many same-class and different-class index pairs from an RDM.

    Rows/columns are renamed with ``replacements`` and sorted, so the returned indices
    address the sorted matrix. An image's class is its name up to ``'_rand_'``. The
    caller's DataFrame is not modified.

    Returns:
        ``(pos, neg)``. Each is ``[row_indices, col_indices]`` (two equal-length lists),
        usable as ``dists[tuple(pos)]``.
    """
    names = list(rdm.index)
    for i in range(len(names)):
        for s, r in replacements:
            names[i] = names[i].replace(s, r)

    # Relabel a copy instead of overwriting the caller's index/columns.
    rdm = rdm.copy()
    rdm.index = names
    rdm.columns = names
    order = sorted(names)
    rdm = rdm.loc[order, order]

    imgs = [name.split('_rand_')[0] for name in rdm.index]

    cls_idx = get_cls_idx(imgs)
    pos = np.concatenate(create_positive_pairs(cls_idx))
    neg = np.asarray(create_negative_pairs(cls_idx))

    num_samples = min(neg.shape[0], pos.shape[0])
    # Keep a random subset of whole pairs (rows). shuffle_along_axis(pairs, 1) presumably
    # only permuted the two indices inside each (i, j) pair, so truncating afterwards kept
    # the same leading pairs every time.
    pos = pos[np.random.permutation(pos.shape[0])[:num_samples]]
    neg = neg[np.random.permutation(neg.shape[0])[:num_samples]]

    return pos.T.tolist(), neg.T.tolist()


def _load_cls_pair_scores(pth: str) -> Tuple[np.ndarray, np.ndarray]:
    """Return ``(similarity, is_same_class)`` for every pair ``i < j`` of one RDM CSV.

    Names are renamed with ``replacements`` and sorted. An image's class is its name up to
    ``'_rand_'``. Similarity is ``1 - distance``; ``is_same_class`` holds 1/0.
    """
    rdm = pd.read_csv(pth, index_col='Unnamed: 0')
    names = list(rdm.index)
    for i in range(len(names)):
        for s, r in replacements:
            names[i] = names[i].replace(s, r)
    rdm.index = names
    rdm.columns = names
    order = sorted(names)
    rdm = rdm.loc[order, order]

    classes = np.array([name.split('_rand_')[0] for name in order])
    issame = classes[:, None] == classes[None, :]
    upper = np.triu_indices(len(order), 1)
    return 1 - rdm.to_numpy()[upper], issame[upper].astype(int)


def plot_auroc(rdm_paths: Dict[str, List[Tuple[str, str]]]) -> None:
    """Grouped bar chart of ROC AUC (same- vs. different-class pairs) per domain and model.

    Args:
        rdm_paths: maps a training-domain name (one bar trace) to ``(x_label, csv_path)``
            pairs, e.g. ``('30 classes', '.../Sheinbug_cls_fc7.csv')``.

    All pairs ``i < j`` are used (unbalanced). Shows the figure and writes it to
    ``/home/ssd_storage/experiments/Sheinbug/aucs_(balanced=False)_cls.html``.
    """
    fig = go.Figure()
    for domain, runs in rdm_paths.items():
        cls, aucs = [], []
        for label, pth in tqdm(runs):
            scores, same = _load_cls_pair_scores(pth)
            fpr, tpr, _ = sklearn.metrics.roc_curve(same, scores, pos_label=1)
            cls.append(label)
            aucs.append(sklearn.metrics.auc(fpr, tpr))
        fig.add_trace(go.Bar(x=cls, y=aucs, text=[f'{a:.3f}' for a in aucs], textposition='auto', name=domain))

    fig.update_layout(height=500, width=1000,  template='plotly_white')
    fig.update_yaxes(range=(0,1))

    fig.show()
    fig.write_html('/home/ssd_storage/experiments/Sheinbug/aucs_(balanced=False)_cls.html')


def plot_auprc(rdm_paths: Dict[str, List[Tuple[str, str]]]) -> None:
    """Grouped bar chart of PR-curve area (same-class pairs positive) per domain and model.

    Args:
        rdm_paths: maps a training-domain name (one bar trace) to ``(x_label, csv_path)`` pairs.

    All pairs ``i < j`` are used. The area is the trapezoidal ``auc(recall, precision)``;
    sklearn's ``average_precision_score`` is the usual non-interpolated alternative.
    Shows the figure and writes it to ``/home/ssd_storage/experiments/Sheinbug/auprcs_cls.html``.
    """
    fig = go.Figure()
    for domain, runs in rdm_paths.items():
        cls, aucs = [], []
        for label, pth in tqdm(runs):
            scores, same = _load_cls_pair_scores(pth)
            precision, recall, _ = sklearn.metrics.precision_recall_curve(same, scores, pos_label=1)
            cls.append(label)
            aucs.append(sklearn.metrics.auc(recall, precision))
        fig.add_trace(go.Bar(x=cls, y=aucs, text=[f'{a:.3f}' for a in aucs], textposition='auto', name=domain))

    fig.update_layout(height=500, width=1000,  template='plotly_white')
    fig.update_yaxes(range=(0,1))

    fig.show()
    fig.write_html('/home/ssd_storage/experiments/Sheinbug/auprcs_cls.html')

In [None]:
# Class-folder ("_cls") fc7 RDMs grouped by training domain. Each entry is
# (bar-chart x label, csv path); domain order sets the bar-trace order.
rdms = {
    domain: [
        (f'{n} classes', f'/home/ssd_storage/experiments/Sheinbug/Sheinbug_{n}_{run}/vgg16/results/Sheinbug_cls/Sheinbug_cls_fc7.csv')
        for n, run in runs
    ]
    for domain, runs in [
        ('Objects', [(30, 'inanimates'), (260, 'inanimates'), (500, 'inanimates'), (1000, 'object_net')]),
        ('Faces', [(30, 'faces'), (260, 'faces'), (500, 'faces'), (1000, 'faces')]),
        ('Bird species', [(30, 'species'), (260, 'species')]),
        ('Sociable Weavers', [(30, 'soc_weav')]),
    ]
}

plot_auroc(rdms)

In [None]:
# Class-folder ("_cls") fc7 RDMs keyed by subplot (row, col):
# row = training domain, col = number of training classes.
rdm_locations = [
    (loc, f'/home/ssd_storage/experiments/Sheinbug/Sheinbug_{run}/vgg16/results/Sheinbug_cls/Sheinbug_cls_fc7.csv')
    for loc, run in [
        ((1, 1), '30_inanimates'),
        ((1, 2), '260_inanimates'),
        ((1, 3), '500_inanimates'),
        ((1, 4), '1000_object_net'),
        ((2, 1), '30_species'),
        ((2, 2), '260_species'),
        ((3, 1), '30_faces'),
        ((3, 2), '260_faces'),
        ((3, 3), '500_faces'),
        ((3, 4), '1000_faces'),
        ((4, 1), '30_soc_weav'),
    ]
]

plot_roc_cls(rdm_locations)

In [None]:
def plot_prc(rdm_paths: List[Tuple[Tuple[int, int], str]]) -> None:
    """Plot one precision-recall curve per model; the trace name is the PR-curve area.

    Two images share a label when the letters of their renamed names, after dropping the
    last 4 characters (presumably the file extension), are equal. Every pair ``i < j`` is
    scored by its raw RDM distance.

    NOTE(review): the positive class here is *different*-label pairs (``pos_label=0`` with
    distance as the score). The other PR code in this notebook scores ``1 - distance``
    with ``pos_label=1`` (same-label pairs positive), so the areas are not directly
    comparable. Confirm which is intended.

    Args:
        rdm_paths: ``((row, col), csv_path)`` pairs; each square RDM CSV is drawn in cell
            ``(row, col)`` of a 4x4 subplot grid.

    Side effects: shows the figure and writes it to
    ``/home/ssd_storage/experiments/Sheinbug/prc.html``.
    """
    fig = make_subplots(rows=4, cols=4,
                        row_titles=['Objects', 'Bird species', 'Faces', 'Sociable Weavers'], column_titles=[30,260,500,1000],
                        y_title='Training domain', x_title='Number of classes', vertical_spacing = 0.13, horizontal_spacing = 0.13)
    for (loc, pth) in rdm_paths:
        rdm = pd.read_csv(pth, index_col='Unnamed: 0')
        names = list(rdm.index)
        for i in range(len(names)):
            for s, r in replacements:
                names[i] = names[i].replace(s, r)
        rdm.index = names
        rdm.columns = names
        order = sorted(names)
        rdm = rdm.loc[order, order]

        labels = np.array([''.join(ch for ch in name[:-4] if ch.isalpha()) for name in order])
        issame = labels[:, None] == labels[None, :]

        upper = np.triu_indices(len(order), 1)
        dists = rdm.to_numpy()[upper]
        label = issame[upper].astype(int)
        precision, recall, _ = sklearn.metrics.precision_recall_curve(label, dists, pos_label=0)
        auc = sklearn.metrics.auc(recall, precision)

        fig.add_trace(go.Scatter(x=recall, y=precision, name=f'{auc}'),
                      row=loc[0], col=loc[1])

    fig.update_layout(height=500 * 4, width=500 * 4,  template='plotly_white')

    fig.show()
    fig.write_html('/home/ssd_storage/experiments/Sheinbug/prc.html')