In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from os import listdir
from os.path import isfile, join

from sklearn import metrics
from sklearn.metrics import precision_recall_curve, accuracy_score, average_precision_score


from plot_results import get_dbl_metrics


def sigmoid(x):
    """Steep logistic curve centred at x = 0.5.

    Evaluates ``1 / (1 + exp(6 - 12*x))`` element-wise, i.e. the standard
    logistic function applied to ``12*x - 6``.
    """
    exponent = 6 - 12 * x
    return 1 / (1 + np.exp(exponent))

def get_all_dbl_metrics(test, scores, labels, colors, return_all=False):
    """Overlay ROC and precision-recall curves of several scorers in one figure.

    Left panel: ROC curves plus the diagonal chance line. Right panel: PR
    curves plus a dashed horizontal baseline at the positive-class rate, with
    the legend placed outside the axes.

    Parameters
    ----------
    test : array-like
        Ground-truth labels; 1 is treated as the positive class.
    scores : iterable of array-like
        One score vector per method, each aligned with ``test``.
    labels, colors : sequence
        Legend label and line colour for each entry of ``scores``.
    return_all : bool, default False
        False keeps the historical return value (metrics of the *last* entry
        of ``scores`` only); True returns one list per metric, in ``scores``
        order.

    Returns
    -------
    tuple
        ``(roc_auc, pr_auc, average_precision)``: floats for the last score
        vector, or lists covering every score vector when ``return_all``.

    Raises
    ------
    ValueError
        If ``scores`` is empty.
    """
    test = np.asarray(test)
    # Materialise once: `scores` is iterated for both panels.
    scores = list(scores)
    if not scores:
        raise ValueError('scores must contain at least one score vector')

    roc_aucs, pr_aucs, aps = [], [], []

    plt.figure(figsize=(7, 3), dpi=100)

    plt.subplot(1, 2, 1)
    # ROC
    for i, score in enumerate(scores):
        fpr, tpr, _ = metrics.roc_curve(test, score, pos_label=1)
        roc_aucs.append(metrics.auc(fpr, tpr))
        plt.plot(fpr, tpr, lw=2, label=labels[i], c=colors[i])
    plt.plot([0, 1], [0, 1], '--', color='black', lw=1)
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.title('ROC curve')

    plt.subplot(1, 2, 2)
    # precision recall curve
    for i, score in enumerate(scores):
        precision, recall, _ = precision_recall_curve(test, score, pos_label=1)
        pr_aucs.append(metrics.auc(recall, precision))
        # AP is computed per method here rather than once after the loop from
        # whatever `score` happened to be left over.
        aps.append(average_precision_score(test, score))
        plt.plot(recall, precision, lw=2, label=labels[i], c=colors[i])
    # Chance level of a PR curve: the fraction of positive labels.
    baseline = np.mean(test == 1)
    plt.plot([0, 1], [baseline, baseline], linestyle='--', c='black', lw=1)
    plt.xlabel("recall")
    plt.ylabel("precision")
    plt.title("PR curve")
    plt.legend(bbox_to_anchor=(1, 1), loc="upper left")

    plt.tight_layout()

    if return_all:
        return roc_aucs, pr_aucs, aps
    return roc_aucs[-1], pr_aucs[-1], aps[-1]

def plot_ROC(cols=7):
    """Draw one ROC panel per dataset in a grid (legacy, global-state version).

    Reads the notebook globals ``scores``, ``true`` and ``labels``.
    ``scores`` is indexed as ``scores[:, :, d]`` and its rows are iterated as
    one score vector per method (presumably methods x cells x datasets --
    TODO confirm). ``true`` is sliced as ``true[..., d]``; when that slice is
    2-D its i-th row is used for method i, otherwise it is shared by all
    methods.

    NOTE(review): a later cell redefines ``plot_ROC`` with an explicit
    signature, which shadows this one.

    Parameters
    ----------
    cols : int, default 7
        Number of panels per row.

    Returns
    -------
    float or None
        ROC AUC of the last curve drawn, or None if no curve was drawn.
    """
    scores_arr = np.asarray(scores)
    true_arr = np.asarray(true)
    num_data = scores_arr.shape[2]
    # np.ceil (np.ceiling does not exist); at least one row so subplots works.
    rows = max(1, int(np.ceil(num_data / cols)))

    # plt.subplots returns (fig, axes); plt.figure does not.
    fig, axes = plt.subplots(rows, cols, figsize=(8.267717, 10.8622), dpi=300, squeeze=False)
    axes = axes.ravel()

    rauc = None
    for d in range(num_data):
        ax = axes[d]
        scores_d = scores_arr[:, :, d]
        true_d = true_arr[..., d]

        for i, score in enumerate(scores_d):
            y_true = true_d[i] if true_d.ndim > 1 else true_d
            # calculate ROC
            fpr, tpr, _ = metrics.roc_curve(y_true, score, pos_label=1)
            rauc = metrics.auc(fpr, tpr)
            # plot it
            ax.plot(fpr, tpr, lw=2, label=labels[i])
        ax.plot([0, 1], [0, 1], '--', color='black', lw=1)
        # Label every panel, not just the pyplot "current" axes.
        ax.set_xlabel('FPR')
        ax.set_ylabel('TPR')
        ax.set_title('ROC curve')

    # Hide grid cells left over when num_data is not a multiple of cols.
    for ax in axes[num_data:]:
        ax.set_visible(False)

    fig.tight_layout()

    return rauc

In [None]:
def get_dbl_metrics(test, score):
    """Plot ROC and precision-recall curves for a single score vector.

    Note: this definition shadows the ``get_dbl_metrics`` imported from
    ``plot_results`` at the top of the notebook.

    Parameters
    ----------
    test : array-like
        Ground-truth labels; 1 is the positive class. ``test[test == 1]`` is
        used, so an array or Series is expected.
    score : array-like
        Per-cell scores aligned with ``test``.

    Returns
    -------
    tuple of float
        (ROC AUC, trapezoidal area under the PR curve, average precision).
    """
    plt.figure(figsize=(6, 3), dpi=100)

    # Left panel: ROC curve and the diagonal chance line; AUC in the title.
    plt.subplot(1, 2, 1)
    fpr, tpr, _ = metrics.roc_curve(test, score, pos_label=1)
    roc_area = metrics.auc(fpr, tpr)
    plt.plot(fpr, tpr, lw=2, label='ROC curve (area = %0.2f)' % roc_area)
    plt.plot([0, 1], [0, 1], '--', color='black', lw=1)
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.title('ROC (area = %0.2f)' % roc_area)

    # Right panel: PR curve and the positive-rate baseline; area in the title.
    plt.subplot(1, 2, 2)
    precision, recall, _ = precision_recall_curve(test, score, pos_label=1)
    pr_area = metrics.auc(recall, precision)
    plt.plot(recall, precision, lw=2)
    baseline = len(test[test == 1]) / len(test)
    plt.plot([0, 1], [baseline, baseline], linestyle='--', label='random', c='black', lw=1)
    plt.xlabel("recall")
    plt.ylabel("precision")
    plt.title("PR curve (area = %0.2f)" % pr_area)

    plt.tight_layout()

    avg_precision = average_precision_score(test, score)

    return roc_area, pr_area, avg_precision

In [None]:
save_path = '../results_manuscript/ROC_PR_Curves/'

path = '../results_benchmark/'
# Split each results filename on '_': column 0 is matched against dataset
# names and column 1 against method names further down.
# os.listdir order is arbitrary, so sort the filenames for a reproducible
# `files` order (np.sort on the split array would sort within each row).
files = [f.split('_') for f in sorted(listdir(path))
         if isfile(join(path, f)) and f.endswith('.csv')]
files = np.array(files)
files


In [None]:
# Method-name column of the split results filenames.
method_column = files[:, 1]
method_column

In [None]:
# Methods to compare, in plotting / stacking order.
methods = [
    'Scrublet', 'bcds', 'cxds', 'hybrid',
    'solo', 'DoubletFinder', 'scDblFinder', 'vaeda',
]
# np.unique already returns its values sorted.
data_names = np.unique(files[:, 0])

In [None]:
# Sorted, de-duplicated first components of the results filenames (dataset names).
data_names

In [None]:
# Example results-file record: the first dataset paired with the first method.
# Defined explicitly here; previously `f` only existed if a later loop cell had
# already been run, so this cell failed on Restart & Run All.
f = files[(files[:, 0] == data_names[0]) & (files[:, 1] == methods[0])]
f[0]

In [None]:
# Rebuild the full path of a results file from its '_'-split filename parts.
# Requires `f` (an array of split filename records) to already be defined in the kernel.
path + '_'.join(f[0])

In [None]:
# One colour per entry of `methods`: blues for Scrublet/bcds/cxds/hybrid,
# greens for solo/DoubletFinder/scDblFinder, and a red for vaeda.
# Keep the entries distinct: Scrublet and bcds used to share '#6BBBDB'.
colors = ['#6BBBDB', '#7BC2DF', '#8CCAE3', '#9CD2E7', '#8AD08D', '#A7DCA9', '#BEE5BF', '#d81e5b']

for data_name in data_names:
    fs = files[files[:, 0] == data_name]

    ano_path = '../data/mtx_files/' + data_name + '_anno.csv'

    #- READ IN BARCODE ANNOTATIONS
    ano = pd.read_csv(ano_path)
    labels = ano.x
    # pd.factorize numbers classes by first appearance; flip the codes when the
    # first barcode is a 'doublet' so that 'doublet' == 1 and the other class == 0.
    true = pd.factorize(labels)[0]
    if not np.isin(true, [0, 1]).all():
        raise ValueError('%s: expected at most two annotation classes and no missing values, got %s'
                         % (ano_path, list(pd.unique(labels))))
    if labels[0] == 'doublet':
        true = 1 - true

    method_scores = []
    for method in methods:
        f = fs[fs[:, 1] == method]
        if len(f) == 0:
            raise FileNotFoundError("no results file for dataset '%s', method '%s' in %s"
                                    % (data_name, method, path))
        results = pd.read_csv(path + '_'.join(f[0]))
        method_scores.append(results.doublet_scores)
    # One row per method, in `methods` order.
    all_preds = np.vstack(method_scores)

    get_all_dbl_metrics(true, all_preds, methods, colors)
    plt.show()
        
    

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from os import listdir
from os.path import isfile, join

from sklearn import metrics
from sklearn.metrics import precision_recall_curve, accuracy_score, average_precision_score

def sigmoid(x):
    # Logistic function of (12*x - 6): 0.5 at x = 0.5, about 0.0025 at x = 0
    # and about 0.9975 at x = 1.
    # NOTE(review): identical to the sigmoid defined in the first cell; one copy can go.
    return 1 / (np.exp(6 - 12 * x) + 1)

In [None]:
save_path = '../results_manuscript/ROC_PR_Curves/'

path = '../results_benchmark/'
# Each results filename is split on '_'; column 0 holds the dataset name and
# column 1 the method name used to select files later on.
# Sort the directory listing: os.listdir returns entries in arbitrary order.
csv_names = sorted(name for name in listdir(path)
                   if isfile(join(path, name)) and name.endswith('.csv'))
files = np.array([name.split('_') for name in csv_names])
files

In [None]:
def plot_ROC(scores, true, dataset_names, method_names, cols=8):
    """Draw a grid of ROC curves: one panel per dataset, every method overlaid.

    Parameters
    ----------
    scores : sequence
        ``scores[d]`` iterates over one score vector per method for dataset ``d``.
    true : sequence
        ``true[d]`` holds the labels for dataset ``d`` (1 = positive).
    dataset_names : sequence of str
        Panel titles, parallel to ``scores``.
    method_names : sequence of str
        Line labels, parallel to the score vectors in each ``scores[d]``.
    cols : int, default 8
        Number of panels per row.

    The figure is 8.267717 in (210 mm) wide, and its height grows with the
    number of rows. It is left as the current figure so the caller can
    ``plt.savefig`` it. Returns None.
    """
    num_data = len(scores)
    # At least one row so plt.subplots also works for empty input.
    rows = max(1, int(np.ceil(num_data / cols)))

    width = 8.267717  # 210 mm in inches
    width_p_plot = width / cols
    # Square panels plus 0.3 in per row (presumably headroom for panel titles).
    height = rows * width_p_plot + rows * 0.3

    # Create the panel grid up front. Previously a single full-figure axes was
    # created and panels were drawn over it with plt.subplot(); newer
    # matplotlib no longer removes that overlapped axes (auto-removal was
    # deprecated in 3.6), so it showed through behind the grid.
    fig, axes = plt.subplots(rows, cols, figsize=(width, height), dpi=300, squeeze=False)
    axes = axes.ravel()

    for d in range(num_data):
        ax = axes[d]
        for i, score in enumerate(scores[d]):
            fpr, tpr, _ = metrics.roc_curve(true[d], score, pos_label=1)
            ax.plot(fpr, tpr, lw=1, label=method_names[i])
        ax.plot([0, 1], [0, 1], '--', color='black', lw=0.5)
        ax.set_title(dataset_names[d], fontsize=7)

    # Hide grid cells left over when num_data is not a multiple of cols.
    for ax in axes[num_data:]:
        ax.set_visible(False)

    fig.suptitle('ROC curves')
    # Invisible full-figure axes whose only job is to carry the shared labels.
    label_ax = fig.add_subplot(111, frameon=False)
    label_ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
    label_ax.set_xlabel("False Positive Rate")
    label_ax.set_ylabel("True Positive Rate")

    fig.tight_layout()

    return

def plot_PRC(scores, true, dataset_names, method_names, cols=8):
    """Draw a grid of precision-recall curves: one panel per dataset.

    Parameters
    ----------
    scores : sequence
        ``scores[d]`` iterates over one score vector per method for dataset ``d``.
    true : sequence
        ``true[d]`` holds the labels for dataset ``d`` (1 = positive); arrays
        and plain lists are both accepted.
    dataset_names : sequence of str
        Panel titles, parallel to ``scores``.
    method_names : sequence of str
        Line labels, parallel to the score vectors in each ``scores[d]``.
    cols : int, default 8
        Number of panels per row.

    Each panel also shows a dashed baseline at the dataset's positive-class
    rate. The figure is 8.267717 in (210 mm) wide, and its height grows with
    the number of rows. It is left as the current figure so the caller can
    ``plt.savefig`` it. Returns None.
    """
    num_data = len(scores)
    # At least one row so plt.subplots also works for empty input.
    rows = max(1, int(np.ceil(num_data / cols)))

    width = 8.267717  # 210 mm in inches
    width_p_plot = width / cols
    # Square panels plus 0.3 in per row (presumably headroom for panel titles).
    height = rows * width_p_plot + rows * 0.3

    # Build the real panel grid instead of drawing plt.subplot() panels over a
    # stray full-figure axes (which newer matplotlib no longer removes).
    fig, axes = plt.subplots(rows, cols, figsize=(width, height), dpi=300, squeeze=False)
    axes = axes.ravel()

    for d in range(num_data):
        ax = axes[d]
        true_d = np.asarray(true[d])
        for i, score in enumerate(scores[d]):
            precision, recall, _ = precision_recall_curve(true_d, score, pos_label=1)
            ax.plot(recall, precision, lw=1, label=method_names[i])
        # Chance level of a PR curve: the fraction of positive labels.
        baseline = np.mean(true_d == 1)
        ax.plot([0, 1], [baseline, baseline], linestyle='--', c='black', lw=0.5)
        ax.set_title(dataset_names[d], fontsize=7)

    # Hide grid cells left over when num_data is not a multiple of cols.
    for ax in axes[num_data:]:
        ax.set_visible(False)

    fig.suptitle('PRC curves')
    # Invisible full-figure axes whose only job is to carry the shared labels.
    label_ax = fig.add_subplot(111, frameon=False)
    label_ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
    label_ax.set_xlabel("recall")
    label_ax.set_ylabel("precision")

    fig.tight_layout()

    return

def make_legend(method_names):
    """Create a small figure whose only purpose is a legend of line entries.

    One dummy line is drawn per name in ``method_names`` (cycling through the
    default colour cycle), and a legend is added on the same axes.
    """
    _, ax = plt.subplots(figsize=(2, 4), sharex=True, sharey=True, dpi=300)

    for name in method_names:
        ax.plot([0, 0, 1], [0, 1, 1], label=name)

    ax.legend()
    
def make_legend_scatter(method_names):
    """Create a small figure whose only purpose is a legend of marker entries.

    One dummy scatter of three points is drawn per name in ``method_names``,
    and a legend is added on the same axes.
    """
    _, ax = plt.subplots(figsize=(2, 4), sharex=True, sharey=True, dpi=300)

    for name in method_names:
        ax.scatter([0, 0, 1], [0, 1, 1], label=name)

    ax.legend()

In [None]:
# Collect, for every dataset, a (methods x cells) score matrix, the binary
# ground truth, and the dataset name, in parallel lists.
scores = []
true_list = []
dataset_names = []

for data_name in data_names:
    fs = files[files[:, 0] == data_name]

    ano_path = '../data/mtx_files/' + data_name + '_anno.csv'

    #- READ IN BARCODE ANNOTATIONS
    ano = pd.read_csv(ano_path)
    labels = ano.x
    # pd.factorize numbers classes by first appearance; flip the codes when the
    # first barcode is a 'doublet' so that 'doublet' == 1 and the other class == 0.
    true = pd.factorize(labels)[0]
    if not np.isin(true, [0, 1]).all():
        raise ValueError('%s: expected at most two annotation classes and no missing values, got %s'
                         % (ano_path, list(pd.unique(labels))))
    if labels[0] == 'doublet':
        true = 1 - true

    method_scores = []
    for method in methods:
        f = fs[fs[:, 1] == method]
        if len(f) == 0:
            raise FileNotFoundError("no results file for dataset '%s', method '%s' in %s"
                                    % (data_name, method, path))
        results = pd.read_csv(path + '_'.join(f[0]))
        method_scores.append(results.doublet_scores)
    # One row per method, in `methods` order.
    all_preds = np.vstack(method_scores)

    scores.append(all_preds)
    true_list.append(true)
    dataset_names.append(data_name)

    

In [None]:
import os

# Small tick labels for the dense panel grids.
plt.rc('xtick', labelsize=5)
plt.rc('ytick', labelsize=5)

# plt.savefig does not create missing directories.
os.makedirs(save_path, exist_ok=True)

plot_ROC(scores, true_list, dataset_names, methods, cols=6)
plt.savefig(save_path + 'ROC.png', dpi=300)
plt.show()
plt.close()

plot_PRC(scores, true_list, dataset_names, methods, cols=6)
plt.savefig(save_path +'PRC.png', dpi=300)
plt.show()
plt.close()

make_legend(methods)
plt.savefig(save_path +'LEGEND.png', dpi=300)
plt.show()
plt.close()

In [None]:

# Stand-alone marker legend, saved next to the curve figures.
make_legend_scatter(methods)
legend_file = save_path + 'LEGEND_scatter.png'
plt.savefig(legend_file, dpi=300)
plt.show()
plt.close()
    

In [None]:
# Directory prefix that the figure filenames in this notebook are appended to.
save_path

In [None]:
# seaborn is otherwise only imported further down, so this cell failed with a
# NameError on a fresh kernel.
import seaborn as sns

# Per-dataset violin + swarm plots of AUPRC ('score') against 'frac', with the
# matching 'cell num' values on a secondary top x-axis.
# FIXME: `df_pr` (columns used: data_name, method, frac, score, 'cell num') is
# not built anywhere in this notebook; load it before running this cell.
if 'df_pr' not in globals():
    raise NameError("df_pr is not defined; load the per-dataset AUPRC results before running this cell")

# Style settings do not change per dataset: apply them once.
sns.set(rc={"figure.figsize":(16, 4)})
sns.set_style("white")

for data_name in data_names:
    df = df_pr[df_pr.data_name == data_name]

    fig, ax1 = plt.subplots()
    v = sns.violinplot(x='frac', y='score', data=df, inner='quartile',
                       hue='method', palette="Set2", linewidth=0, hue_order=methods)
    plt.setp(v.collections, alpha=.3)

    sns.swarmplot(x='frac', y='score', data=df, hue='method', palette="Set2", dodge=True, size=3,
                  hue_order=methods)
    # Per (method, frac) means, drawn as outlined markers on top of the swarm.
    df_means = df.groupby(['method','frac'])['score'].agg('mean').reset_index().sort_values(['frac'], ascending=False)
    sns.swarmplot(x='frac', y='score', data=df_means, marker='o', hue='method', palette="Set2", s=2, dodge=True,
                  linewidth=1, edgecolor='black', hue_order=methods)

    # Secondary x-axis on top, reusing the bottom tick positions for cell numbers.
    cell_num = np.unique(df['cell num'])
    ax3 = ax1.twiny()
    ax3.set_xlim([ax1.get_xlim()[0],ax1.get_xlim()[1]])
    ax3.set_xticks(ax1.get_xticks())
    ax3.set_xticklabels(cell_num)
    ax3.tick_params(top=True)
    ax3.set_xlabel('cell number')
    ax3.spines['top'].set_visible(True)

    ax1.set_ylabel('AUPRC')
    plt.title(data_name, fontsize=15)
    ax1.legend([],[], frameon=False)

    ax1.grid(axis='y')

    plt.savefig(save_path + data_name + '_PR_violin.png', dpi=300)
    plt.show()
    plt.close(fig)

In [None]:
true = true_list

plt.rc('xtick', labelsize=5)
plt.rc('ytick', labelsize=5)

# For every dataset and method, take the TOP_FRAC highest-scoring cells as the
# method's calls and record which fraction of the label-1 cells they contain.
TOP_FRAC = 0.05

dfs = []

for d in range(len(scores)):
    scores_d = scores[d]
    true_d = np.asarray(true[d])
    total = np.sum(true_d)  # number of label-1 cells in dataset d

    for i, score in enumerate(scores_d):
        n_cells = len(score)
        num = int(np.round(n_cells * TOP_FRAC))
        # Keep the `num` highest-scoring cells. Slice from n_cells - num rather
        # than -num: when num == 0, `[-0:]` would select every cell.
        call = true_d[np.argsort(score)[n_cells - num:]]

        # A dataset without label-1 cells has no meaningful fraction.
        frac = np.sum(call) / total if total > 0 else np.nan
        m = methods[i]
        data = data_names[d]

        dfs.append(pd.DataFrame({'%captured':[frac], 'method':[m], 'data_name':[data]}))
        
        


In [None]:
# Each frame in `dfs` is a single row labelled 0; ignore_index gives the result
# a clean RangeIndex instead of a fully duplicated one (which seaborn can reject).
df = pd.concat(dfs, ignore_index=True)
df

In [None]:
import seaborn as sns

In [None]:
# Reversed method order for the summary swarm plot's hue_order. Assigned
# explicitly (instead of methods.reverse()) so re-running this cell does not
# flip the order back. Keep in sync with the original `methods` list.
methods = ['vaeda', 'scDblFinder', 'DoubletFinder', 'solo', 'hybrid', 'cxds', 'bcds', 'Scrublet']

In [None]:
# Summary: fraction of label-1 cells captured in each method's top-scored cells,
# one column per dataset, saved as top_scores.png.
sns.set(rc={"figure.figsize": (16, 4)})
sns.set_style("whitegrid")

swarm_kwargs = dict(
    x='data_name',
    y='%captured',
    data=df,
    hue='method',
    palette="Set2",
    dodge=True,
    size=5,
    hue_order=methods,
)
ax = sns.swarmplot(**swarm_kwargs)
ax.legend(bbox_to_anchor=(1, 1), loc="upper left")

tick_labels = ax.get_xticklabels()
ax.set_xticklabels(tick_labels, rotation=45, horizontalalignment='right')

plt.tight_layout()
plt.savefig(save_path + 'top_scores.png', dpi=300)