In [None]:
# import the necessary packages
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras import backend as K
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import random
import os
from tqdm import tqdm

In [None]:
from sklearn.model_selection import StratifiedKFold, RepeatedKFold
from typing import List, Dict
from scipy import interp
from itertools import cycle
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score

In [None]:
# one plot colour per class; consumed in order by the figure cells below
colors = [
    'red',
    'blue',
    'green',
    'pink',
    'gray',
    'brown',
    'purple',
    'darkorange',
    'cyan',
]

In [None]:
# PHROG index table; the cells below read its 'Category' and '#phrog' columns
phrog_metadata = pd.read_csv(
    filepath_or_buffer='../PHROG_index_downloaded_01232022.csv',
)

In [None]:
# keep only rows whose Category is present and is not the 'unknown function' placeholder
has_category = ~phrog_metadata['Category'].isna()
is_unknown = phrog_metadata['Category'].isin(['unknown function'])
phrog_known = phrog_metadata[has_category & ~is_unknown]

# the distinct categories that remain after filtering
cs = set(phrog_known['Category'])

## dict for family:label -> {fl}
# built category by category, mapping every '#phrog' value of a category to that category
fl = {
    family: category
    for category in cs
    for family in phrog_known.loc[phrog_known['Category'] == category, '#phrog']
}

In [None]:
## fit a label binarizer to the full set of classes up front, so that the
## categories (and their order) are the same in every split
lb = LabelBinarizer().fit(list(cs))

In [None]:
# number of cross-validation folds the result tables are iterated over
n_splits = 5

# saved cross-validation outputs, each read with its first column as the index:
#   report -> per-class metrics ('auroc' / 'auprc' columns are used below)
#   rocs   -> ROC curve points  ('class', 'fold', 'fpr', 'tpr')
#   prcs   -> PR curve points   ('class', 'fold', 'precision', 'recall')
report, rocs, prcs = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (
        'protbert_bfd_embeddings_phrog/5CV_report.csv',
        'protbert_bfd_embeddings_phrog/5CV_rocs.csv',
        'protbert_bfd_embeddings_phrog/5CV_prcs.csv',
    )
)

## figure 2A: mean ROC curve per class across the cross-validation folds
plt.figure(figsize=(8,6))
lw = 1

# shared FPR grid; every fold's curve is resampled onto it so that the folds
# can be averaged point-wise
mean_fpr = np.linspace(0, 1, 100)

# cycle() repeats the palette when there are more classes than colours;
# a plain zip(..., colors) would silently drop the surplus classes
for i, color in zip(range(len(lb.classes_)), cycle(colors)):
    # rows for this class only -- filtered once here instead of once per fold
    class_rocs = rocs[rocs["class"] == lb.classes_[i]]

    a_tpr = []
    for j in range(n_splits):
        df = class_rocs[class_rocs["fold"] == j]
        # assumes the stored 'fpr' values are in non-decreasing order, as
        # np.interp requires -- TODO confirm against the code that wrote the CSV
        a_tpr.append(np.interp(mean_fpr, df["fpr"], df["tpr"]))
        a_tpr[-1][0] = 0.0  # pin every fold's curve to start at (0, 0)

    mean_tpr = np.mean(a_tpr, axis=0)
    mean_tpr[-1] = 1.0  # pin the mean curve to end at (1, 1)
    # AUC is taken from the averaged curve, while the SD comes from the
    # 'auroc' values stored in the report for this class
    mean_auc = auc(mean_fpr, mean_tpr)
    std_auc = np.std(report.loc[lb.classes_[i]]['auroc'])
    # NOTE(review): the class short name is passed as argument 0 but the
    # format string never references {0}, so legend entries carry no class
    # name -- confirm this is intentional for the published figure
    plt.plot(mean_fpr, mean_tpr, color=color,
             label='AUC={1:0.2f}, SD={2:0.2f}' ''.format(lb.classes_[i].split(' ')[0], mean_auc, std_auc),
             lw=lw)

    # shaded band: mean +/- one standard deviation across folds, clipped to [0, 1]
    std_tpr = np.std(a_tpr, axis=0)
    tpr_upper = np.minimum(mean_tpr + std_tpr, 1)
    tpr_lower = np.maximum(mean_tpr - std_tpr, 0)
    plt.fill_between(mean_fpr, tpr_lower, tpr_upper, color=color, alpha=.1)


#plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate', fontsize=14)
plt.ylabel('True Positive Rate', fontsize=14)

plt.legend(loc="lower right")
# save before show(): the figure is written to disk first, then displayed
plt.savefig('performance_5CFV_AUROC.png', dpi=300)
plt.show()




### figure 2B: mean precision-recall curve per class across the cross-validation folds
plt.figure(figsize=(8,6))

# shared recall grid; every fold's curve is resampled onto it so that the
# folds can be averaged point-wise
mean_recall = np.linspace(0, 1, 100)

# cycle() repeats the palette when there are more classes than colours;
# a plain zip(..., colors) would silently drop the surplus classes
for i, color in zip(range(len(lb.classes_)), cycle(colors)):
    # rows for this class only -- filtered once here instead of once per fold
    class_prcs = prcs[prcs["class"] == lb.classes_[i]]

    a_prec = []
    for j in range(n_splits):
        df = class_prcs[class_prcs["fold"] == j]
        # The stored points are reversed before interpolation; presumably
        # recall is stored in decreasing order (as sklearn's
        # precision_recall_curve returns it) while np.interp needs
        # increasing x values -- TODO confirm against the code that wrote the CSV
        prec_fold = df['precision'][::-1]
        recall_fold = df['recall'][::-1]
        a_prec.append(np.interp(mean_recall, recall_fold, prec_fold))

    # shaded band: mean +/- one standard deviation across folds, clipped to [0, 1]
    mean_prec = np.mean(a_prec, axis=0)
    std_prec = np.std(a_prec, axis=0)
    prec_upper = np.minimum(mean_prec + std_prec, 1)
    prec_lower = np.maximum(mean_prec - std_prec, 0)
    plt.fill_between(mean_recall, prec_lower, prec_upper, color=color, alpha=.1)


    # both the mean and the SD come from the 'auprc' values stored in the
    # report for this class (not from the averaged curve)
    mean_auc = np.mean(report.loc[lb.classes_[i]]['auprc'])
    std_auc = np.std(report.loc[lb.classes_[i]]['auprc'])
    # NOTE(review): the class short name is passed as argument 0 but the
    # format string never references {0}, so legend entries carry no class
    # name -- confirm this is intentional for the published figure
    plt.plot(mean_recall, mean_prec, color=color,
             label='AUC={1:0.2f}, SD={2:0.2f}' ''.format(lb.classes_[i].split(' ')[0], mean_auc, std_auc),
             lw=1)


#plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall', fontsize=14)
plt.ylabel('Precision', fontsize=14)

plt.legend(loc="lower left")
# save before show(): the figure is written to disk first, then displayed
plt.savefig('performance_5CFV_AUPRC.png', dpi=300)
plt.show()