In [143]:
import os 
import pandas as pd
from tqdm import tqdm
from collections import Counter, defaultdict
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, matthews_corrcoef

path_benchmark = "/media/concha-eloko/Linux/PPT_clean/benchmark"


***
### Matrice results: 

In [52]:
tropigat_pred_m_df = pd.read_csv(f"{path_benchmark}/TropiGAT.matrices.tsv", sep = "\t", names = ["protein", "prediction"])
tropiseq_pred_m_df = pd.read_csv(f"{path_benchmark}/TropiSEQ.matrices.tsv", sep = "\t", names = ["protein", "prediction"])
SH_pred_m_df = pd.read_csv(f"{path_benchmark}/SpikeHunter_predictions.matrices.tsv", sep = "\t", names = ["protein", "prediction"])

# Pre-formating the names: 
tropigat_pred_m_df["protein_id"] = tropigat_pred_m_df["protein"].apply(lambda x : x.split("_Dpo")[0])

tropiseq_pred_m_df["protein_id_1"] = tropiseq_pred_m_df["protein"].apply(lambda x : x.split("_A")[0] if x.count("_A")>0 else "__".join(x.split(",")[0].split(" ")) if x.count(",")>0 else x)
tropiseq_pred_m_df["protein_id"] = tropiseq_pred_m_df["protein_id_1"].apply(lambda x : x.replace(" cds", "__cds"))

SH_pred_m_df["protein_id_1"] = SH_pred_m_df["protein"].apply(lambda x : x.split("_A")[0] if x.count("_A")>0 else "__".join(x.split(",")[0].split(" ")) if x.count(",")>0 else x)
SH_pred_m_df["protein_id"] = SH_pred_m_df["protein_id_1"].apply(lambda x : x.replace(" cds", "__cds"))

merged_df = tropigat_pred_m_df.merge(tropiseq_pred_m_df, on='protein_id', how='inner').merge(SH_pred_m_df, on='protein_id', how='inner')
merged_df["phage"] = merged_df["protein_id"].apply(lambda x : x.split("__")[0] if x.count("__")>0 else x.split("_")[0])


In [53]:
missing_prot = [prot for prot in list(tropigat_pred_m_df["protein_id"]) if prot not in list(SH_pred_m_df["protein_id"])]
missing_prot_prime = [prot for prot in list(SH_pred_m_df["protein_id"]) if prot not in list(tropigat_pred_m_df["protein_id"])]

In [54]:
missing_prot

['BLCJPOBP_00052']

In [55]:
missing_prot_prime

['K33PH14C2__cds_55', 'K62PH164C2__cds_25']

In [56]:
len(tropigat_pred_m_df), len(tropiseq_pred_m_df),len(SH_pred_m_df), len(merged_df)

(258, 260, 260, 260)

In [103]:
phages_to_eliminate = ["K80PH1317b", "K80PH1317a"] + ["K2064PH2","K2069PH1","OBHDAGOG","A1e","A1f","A3a","EONHMLJF"]
merged_df_clean_01 = merged_df[["phage", "protein_id", "prediction_x", "prediction_y", "prediction"]]
merged_df_clean_01.columns = ["phage", "protein", "TropiGAT_prediction", "TropiSEQ_prediction", "SpikeHunter_prediction"]
merged_df_clean_01 = merged_df_clean_01[~merged_df_clean_01["phage"].isin(phages_to_eliminate)] 
merged_df_clean_01

Unnamed: 0,phage,protein,TropiGAT_prediction,TropiSEQ_prediction,SpikeHunter_prediction
0,K15PH90,K15PH90__cds_54,KL123:0.997 ; KL27:0.988 ; KL112:0.985 ; KL14:...,No_hits,K106
1,K7PH164C4,K7PH164C4__cds_20,KL39:1.0 ; KL123:0.999 ; KL22:0.996 ; KL114:0....,No_hits,No_hits
2,K32PH164C1,K32PH164C1__cds_20,KL36:1.0 ; KL39:0.998 ; KL25:0.992 ; KL116:0.9...,No_hits,No_hits
3,K18PH07C1,K18PH07C1__cds_245,KL3:0.998 ; KL63:0.984 ; KL43:0.94 ; KL145:0.9...,KL63:0.594,K63
4,K13PH07C1L,K13PH07C1L__cds_11,KL23:0.992 ; KL10:0.99 ; KL30:0.975 ; KL60:0.9...,KL13:0.527,K2
...,...,...,...,...,...
255,PFOEGONH,PFOEGONH_00078,KL3:1.0 ; KL81:0.985 ; KL14:0.957 ; KL52:0.857...,KL3:0.674 ; KL35:0.736,K3
256,NBNDMPCG,NBNDMPCG_00163,KL13:0.958 ; KL2:0.954 ; KL57:0.94 ; KL60:0.92...,KL2:0.684,K2
257,NJHLHPIG,NJHLHPIG_00061,KL46:0.99 ; KL128:0.975 ; KL18:0.971 ; KL52:0....,KL18:0.522 ; KL46:0.937,K46
258,HIIECEMK,HIIECEMK_00054,KL60:0.995 ; KL18:0.991 ; KL23:0.943 ; KL14:0....,KL13:0.527,K2


> Add the target data: 

In [85]:
# Fold data
path_project = "/media/concha-eloko/Linux/PPT_clean"
df_folds = pd.read_csv(f"{path_project}/in_vitro/dpos_folds.all_matrices.tsv", header = 0, sep = "\t")

# Target data
path_finetuning = "/media/concha-eloko/Linux/PPT_clean/in_vitro/fine_tuning"
bea_df = pd.read_csv(f"{path_finetuning}/bea_fine_tuning.df", sep = "\t", header = 0)
bea_df["Protein"] = bea_df["Protein"].apply(lambda x : x.replace("_", "__"))
pool_bea = set([kltype.strip() for kltypes in bea_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

ferriol_df = pd.read_csv(f"{path_finetuning}/ferriol_fine_tuning.df", sep = "\t", header = 0)
#ferriol_df["Target"] = ferriol_df["Target"].apply(lambda x : x.replace("K", "KL"))
pool_ferriol = set([kltype.strip() for kltypes in ferriol_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

towndsend_df = pd.read_csv(f"{path_finetuning}/towndsend_fine_tuning.df", sep = "\t", header = 0)
towndsend_df["Protein"] = towndsend_df["Protein"].apply(lambda x : x.replace("_", "__"))
pool_towndsend = set([kltype.strip() for kltypes in towndsend_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

dico_matrices = {"ferriol" : {"matrix" : ferriol_df, "pool" : pool_ferriol}, 
                 "bea" : {"matrix": bea_df, "pool" : pool_bea}, 
                 "towndsend" : {"matrix" : towndsend_df, "pool" : pool_towndsend}}

# targets dico : 
dico_hits = {}
for author in dico_matrices :
    matrix = dico_matrices[author]["matrix"]
    for _, row in matrix.iterrows() : 
        for phage in matrix["Phages"].unique() : 
            all_targets = set()
            targets = matrix[matrix["Phages"] == phage]["Target"].values
            for calls in targets : 
                actual_targets = [x.strip() for x in calls.split(",")]
                all_targets.update(actual_targets)
            dico_hits[phage] = all_targets


In [115]:
df_folds["protein_id"].tolist()

['A1a_00002',
 'A1a_00014',
 'A1b_00048',
 'A1b_00036',
 'A1c_00046',
 'A1c_00034',
 'A1d_00013',
 'A1d_00009',
 'A1e_00024',
 'A1e_00012',
 'A1f_00024',
 'A1f_00012',
 'A1g_00045',
 'A1g_00057',
 'A1h_00021',
 'A1h_00009',
 'A1h_00013',
 'A1i_00041',
 'A1i_00037',
 'A1i_00049',
 'A1j_00040',
 'A1j_00049',
 'A1j_00002',
 'A1k_00018',
 'A1k_00014',
 'A1l_00058',
 'A1l_00005',
 'A1m_00045',
 'A1m_00049',
 'A1n_00050',
 'A1o_00045',
 'A1o_00041',
 'A1p_00055',
 'A1q_00023',
 'A1q_00010',
 'A1q_00019',
 'A1r_00013',
 'A1r_00009',
 'A2a_00010',
 'A2a_b_00022',
 'A2a_b_00036',
 'A2a_00049',
 'A2b_00022',
 'A2b_00008',
 'A3a_00002',
 'A3a_00045',
 'A3b_00021',
 'A3b_00016',
 'A3c_00044',
 'A3c_00039',
 'A3c_00045',
 'A3d_00041',
 'A3d_00042',
 'A3d_00036',
 'BLCJPOBP_00041',
 'D7b_00043',
 'D7c_00007',
 'DIMCIIMF_00039',
 'DIMCIIMF_00240',
 'DJLANJJD_00238',
 'EHPPICDA_00095',
 'EKPIEFBL_00113',
 'EKPIEFBL_00177',
 'EONHMLJF_00087',
 'FADJDIKG_00083',
 'FADJDIKG_00019',
 'GCLPFEGH_00240',
 'H

In [117]:
# Initialize new columns :
target_serie = []
fold_serie = []
tropigat_pred_serie, tropiseq_pred_serie, sh_pred_serie = [], [], [] 

# Feed comumns : 
for _, row in merged_df_clean_01.iterrows() :
    if row["phage"] not in dico_hits :
        print(f"weird phage: {row['phage']}")
    else :
        targets = dico_hits[row["phage"]]
        target_serie.append(",".join(targets))
        prot_it = row['protein'].replace("__cds", "_cds")
        try :
            fold = df_folds[df_folds["protein_id"] == prot_it]["Fold"].values[0]
        except IndexError  :
            fold = df_folds[df_folds["protein_id"] == row['protein']]["Fold"].values[0]
        except Exception as e :
            fold = "unknown"
        fold_serie.append(fold)
        # TropiGAT part : 
        tropigat_pred = set([x.split(":")[0].strip() for x in row["TropiGAT_prediction"].split(";")])
        good_calls = tropigat_pred.intersection(targets)
        if len(good_calls) > 0 : 
            tropigat_pred_serie.append(",".join(list(good_calls)))
        else : 
            tropigat_pred_serie.append(0)
        # TropiSEQ part : 
        if row["TropiSEQ_prediction"][0] != "N" : 
            tropiseq_pred = [x.split(":")[0].strip() for x in row["TropiSEQ_prediction"].split(";")]
            top_predictions = set(tropiseq_pred)
            good_calls = top_predictions.intersection(targets)
            if len(good_calls) > 0 : 
                tropiseq_pred_serie.append(",".join(list(good_calls)))
            else :
                tropiseq_pred_serie.append(0)
        else :
            tropiseq_pred_serie.append(0)
        # SH part : 
        if row["SpikeHunter_prediction"][0] != "N" : 
            sh_pred = [x.split(":")[0].strip().replace("K", "KL") for x in row["SpikeHunter_prediction"].split(";")]
            top_predictions = set(sh_pred)
            good_calls = top_predictions.intersection(targets)
            if len(good_calls) > 0 : 
                sh_pred_serie.append(",".join(list(good_calls)))
            else :
                sh_pred_serie.append(0)
        else :
            sh_pred_serie.append(0)
        
    

In [118]:
merged_df_clean_01["TropiGAT_good_calls"] = tropigat_pred_serie
merged_df_clean_01["TropiSEQ_good_calls"] = tropiseq_pred_serie
merged_df_clean_01["SpikeHunter_good_calls"] = sh_pred_serie
merged_df_clean_01["Targets"] = target_serie
merged_df_clean_01["Folds"] = fold_serie

merged_df_clean_01

Unnamed: 0,phage,protein,TropiGAT_prediction,TropiSEQ_prediction,SpikeHunter_prediction,TropiGAT_good_calls,TropiSEQ_good_calls,SpikeHunter_good_calls,Targets,Folds
0,K15PH90,K15PH90__cds_54,KL123:0.997 ; KL27:0.988 ; KL112:0.985 ; KL14:...,No_hits,K106,0,0,0,KL15,6-bladed beta-propeller
1,K7PH164C4,K7PH164C4__cds_20,KL39:1.0 ; KL123:0.999 ; KL22:0.996 ; KL114:0....,No_hits,No_hits,"KL39,KL36,KL70",0,0,"KL30,KL70,KL7,KL69,KL65,KL27,KL36,KL64,KL29,KL...",triple-helix
2,K32PH164C1,K32PH164C1__cds_20,KL36:1.0 ; KL39:0.998 ; KL25:0.992 ; KL116:0.9...,No_hits,No_hits,0,0,0,"KL56,KL32",triple-helix
3,K18PH07C1,K18PH07C1__cds_245,KL3:0.998 ; KL63:0.984 ; KL43:0.94 ; KL145:0.9...,KL63:0.594,K63,0,0,0,KL18,right-handed beta-helix
4,K13PH07C1L,K13PH07C1L__cds_11,KL23:0.992 ; KL10:0.99 ; KL30:0.975 ; KL60:0.9...,KL13:0.527,K2,KL13,KL13,0,KL13,right-handed beta-helix
...,...,...,...,...,...,...,...,...,...,...
255,PFOEGONH,PFOEGONH_00078,KL3:1.0 ; KL81:0.985 ; KL14:0.957 ; KL52:0.857...,KL3:0.674 ; KL35:0.736,K3,KL3,KL3,KL3,KL3,right-handed beta-helix
256,NBNDMPCG,NBNDMPCG_00163,KL13:0.958 ; KL2:0.954 ; KL57:0.94 ; KL60:0.92...,KL2:0.684,K2,KL2,KL2,KL2,"KL35,KL2",right-handed beta-helix
257,NJHLHPIG,NJHLHPIG_00061,KL46:0.99 ; KL128:0.975 ; KL18:0.971 ; KL52:0....,KL18:0.522 ; KL46:0.937,K46,0,0,0,KL35,right-handed beta-helix
258,HIIECEMK,HIIECEMK_00054,KL60:0.995 ; KL18:0.991 ; KL23:0.943 ; KL14:0....,KL13:0.527,K2,KL2,0,KL2,KL2,right-handed beta-helix


In [119]:
merged_df_clean_02 = merged_df_clean_01[["phage", "protein", "Folds", "TropiGAT_prediction", "TropiGAT_good_calls", "TropiSEQ_prediction", "TropiSEQ_good_calls", "SpikeHunter_prediction", "SpikeHunter_good_calls", "Targets"]]
merged_df_clean_02.sort_values(by='phage', ascending=True)
merged_df_clean_02.to_csv(f"{path_benchmark}/Prediction_results.matrices.tsv", sep = "\t", header = True, index = False)


***
### Get metrics:

In [172]:
import os 
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import Counter 

path_benchmark = "/media/concha-eloko/Linux/PPT_clean/benchmark"

results_bench_df = pd.read_csv(f"{path_benchmark}/Prediction_results.matrices.tsv", sep = "\t", index_col = False, header = 0)


> Functions: 

In [191]:
# Useful functions: 

def make_labels(pred_df, predictor = "tropigat" , top_n = 15, to_ignore = []) :
    dico_pred = {"tropigat" : "TropiGAT_prediction" ,
                 "tropiseq" : "TropiSEQ_prediction", 
                 "spikehunter" : "SpikeHunter_prediction"}
    col = dico_pred[predictor]
    labels_tropigat = {}
    all_kl_types = set(kltype.strip() for row in results_bench_df["Targets"] for kltype in row.split(","))
    count_kltypes = {}
    dico_prophage_count = all_kl_types
    for kltype in tqdm(dico_prophage_count) : 
        n = 0
        pred_labels , real_labels = [] , []
        for author in dico_matrices :
            if kltype in dico_matrices[author]["pool"] : 
                matrix = dico_matrices[author]["matrix"]
                for phage in matrix["Phages"].unique() :
                    #if phage not in to_ignore : 
                    if phage not in to_ignore and phage in pred_df["phage"].tolist(): 
                        top_predictions = set()
                        predictions = pred_df[pred_df["phage"] == phage][col].values
                        for calls in predictions : 
                            if predictor == "spikehunter" :
                                predicted_kltypes = [x.split(":")[0].strip().replace("K", "KL") for x in calls.split(";")]
                            else :
                                predicted_kltypes = [x.split(":")[0].strip() for x in calls.split(";")]
                            top_predictions.update(predicted_kltypes[0:top_n-1])
                        if kltype in top_predictions : 
                            pred_labels.append(1)
                        else : 
                            pred_labels.append(0)
                        if kltype in dico_hits[phage] : 
                            real_labels.append(1)
                            n += 1 
                        else :
                            real_labels.append(0)
        labels_tropigat[kltype] = {"y_pred" : pred_labels, "real_labels" : real_labels}
        sorted_dict = dict(sorted(labels_tropigat.items(), key=lambda item: int(item[0].split("KL")[1])))
        count_kltypes[kltype] = n 
    return sorted_dict , labels_tropigat, count_kltypes

def decript_dic(sorted_dict, labels_tropigat, count_kltypes, print_df=True):
    # Initialize lists to collect metrics when print_df is False
    kltypes, counts, f1_scores, accuracies, recalls, precisions, aucs, mccs = ([] for _ in range(8))
    
    if print_df:
        print("KL type", "Count", "F1", "Accuracy", "Recall", "Precision", "AUC", "MCC", sep="\t")
    
    for kltype in sorted_dict:
        if len(labels_tropigat[kltype]["real_labels"]) > 1:
            if Counter(labels_tropigat[kltype]["y_pred"])[1] == 0:
                if print_df:
                    print(kltype, count_kltypes[kltype], "error", sep="\t")
                else:
                    # Append NaN for each metric when Counter condition is met
                    kltypes.append(kltype)
                    counts.append(count_kltypes[kltype])
                    f1_scores.append(np.nan)
                    accuracies.append(np.nan)
                    recalls.append(np.nan)
                    precisions.append(np.nan)
                    aucs.append(np.nan)
                    mccs.append(np.nan)
            else:
                try:
                    # Calculate each metric
                    f1 = f1_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"], average='binary')
                    precision = precision_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"], average='binary')
                    recall = recall_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"], average='binary')
                    mcc = matthews_corrcoef(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"])
                    accuracy = accuracy_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"])
                    auc = roc_auc_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"])
                    
                    if print_df:
                        print(kltype, count_kltypes[kltype], round(f1, 5), round(accuracy, 5), round(recall, 5), 
                              round(precision, 5), round(auc, 5), round(mcc, 5), sep="\t")
                    else:
                        # Append each metric to the corresponding list
                        kltypes.append(kltype)
                        counts.append(count_kltypes[kltype])
                        f1_scores.append(f1)
                        accuracies.append(accuracy)
                        recalls.append(recall)
                        precisions.append(precision)
                        aucs.append(auc)
                        mccs.append(mcc)                        
                except ValueError:
                    if print_df:
                        print(kltype, "error", sep="\t")
                    else:
                        # Append NaN values when a ValueError occurs
                        kltypes.append(kltype)
                        counts.append(count_kltypes[kltype])
                        f1_scores.append(np.nan)
                        accuracies.append(np.nan)
                        recalls.append(np.nan)
                        precisions.append(np.nan)
                        aucs.append(np.nan)
                        mccs.append(np.nan)

    # Return lists when print_df is False
    if not print_df:
        dataframe = pd.DataFrame({
            "KL type": kltypes,
            "Count": counts,
            "F1": f1_scores,
            "Accuracy": accuracies,
            "Recall": recalls,
            "Precision": precisions,
            "AUC": aucs,
            "MCC": mccs
        })
        dataframe = dataframe[dataframe["Count"] > 0]
        return dataframe


In [196]:
# Full dataset: 
sorted_dic_tropigat, labels_tropigat, count_kltypes_tropigat = make_labels(results_bench_df, predictor = "tropigat", top_n = 40)
sorted_dic_tropiseq, labels_tropiseq, count_kltypes_tropiseq = make_labels(results_bench_df, predictor = "tropiseq", top_n = 40)
sorted_dic_tropish, labels_tropish, count_kltypes_sh = make_labels(results_bench_df, predictor = "spikehunter", top_n = 40)

tropigat_full_df = decript_dic(sorted_dic_tropigat, labels_tropigat, count_kltypes_tropigat, print_df=False)
tropiseq_full_df = decript_dic(sorted_dic_tropiseq, labels_tropiseq, count_kltypes_tropiseq, print_df=False)
tropish_full_df = decript_dic(sorted_dic_tropish, labels_tropish, count_kltypes_sh, print_df=False)

tropigat_full_df.to_csv(f"{path_benchmark}/TropiGAT.KL_metrics.full.tsv", sep = "\t", index = False, header = True)
tropiseq_full_df.to_csv(f"{path_benchmark}/TropiSEQ.KL_metrics.full.tsv", sep = "\t", index = False, header = True)
tropish_full_df.to_csv(f"{path_benchmark}/SpikeHunter.KL_metrics.full.tsv", sep = "\t", index = False, header = True)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:01<00:00, 75.61it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:01<00:00, 82.09it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:01<00:00, 83.97it/s]


> Comparative metrics, full: 

In [222]:
comparative_full_df = pd.DataFrame({
            "KL type": tropigat_full_df["KL type"],
            "Count": tropigat_full_df["Count"],
            # recall
            "Recall_TropiGAT": tropigat_full_df["Recall"].round(3),
            "Recall_TropiSEQ": tropiseq_full_df["Recall"].round(3),
            "Recall_TropiSH": tropish_full_df["Recall"].round(3),
            # precision: 
            "Precision_TropiGAT": tropigat_full_df["Precision"].round(3),
            "Precision_TropiSEQ": tropiseq_full_df["Precision"].round(3),
            "Precision_TropiSH": tropish_full_df["Precision"].round(3),
            # MCC: 
            "MCC_TropiGAT": tropigat_full_df["MCC"].round(3),
            "MCC_TropiSEQ": tropiseq_full_df["MCC"].round(3),
            "MCC_TropiSH": tropish_full_df["MCC"].round(3),
        })

comparative_full_df = comparative_full_df.fillna("-")
comparative_full_df.to_csv(f"{path_benchmark}/Comparative.Full.tsv", sep = "\t", index = False, header = True)


In [223]:
comparative_full_df

Unnamed: 0,KL type,Count,Recall_TropiGAT,Recall_TropiSEQ,Recall_TropiSH,Precision_TropiGAT,Precision_TropiSEQ,Precision_TropiSH,MCC_TropiGAT,MCC_TropiSEQ,MCC_TropiSH
0,KL1,3,0.0,-,0.0,0.0,-,0.0,-0.082,-,-0.017
1,KL2,16,0.375,0.188,0.25,0.286,1.0,0.667,0.213,0.409,0.362
2,KL3,12,0.333,0.25,0.25,0.571,1.0,1.0,-0.043,0.343,0.343
3,KL4,7,-,-,-,-,-,-,-,-,-
4,KL5,1,0.0,-,-,0.0,-,-,-0.096,-,-
...,...,...,...,...,...,...,...,...,...,...,...
84,KL144,2,-,-,-,-,-,-,-,-,-
85,KL149,2,0.0,-,-,0.0,-,-,-0.061,-,-
86,KL151,1,-,1.0,1.0,-,1.0,1.0,-,1.0,1.0
87,KL157,1,-,-,1.0,-,-,1.0,-,-,1.0


> Comparative metrics, helices:

In [197]:
# Helices:  
helices_results_bench_df = results_bench_df[results_bench_df["Folds"].isin(["triple-helix","right-handed beta-helix"])]

sorted_dic_tropigat_helices, labels_tropigat_helices, count_kltypes_tropigat_helices = make_labels(helices_results_bench_df, predictor = "tropigat", top_n = 40)
sorted_dic_tropiseq_helices, labels_tropiseq_helices, count_kltypes_tropiseq_helices = make_labels(helices_results_bench_df, predictor = "tropiseq", top_n = 40)
sorted_dic_tropish_helices, labels_tropish_helices, count_kltypes_sh_helices = make_labels(helices_results_bench_df, predictor = "spikehunter", top_n = 40)

tropigat_helices_df = decript_dic(sorted_dic_tropigat_helices, labels_tropigat_helices, count_kltypes_tropigat_helices, print_df=False)
tropiseq_helices_df = decript_dic(sorted_dic_tropiseq_helices, labels_tropiseq_helices, count_kltypes_tropiseq_helices, print_df=False)
tropish_helices_df = decript_dic(sorted_dic_tropish_helices, labels_tropish_helices, count_kltypes_sh_helices, print_df=False)

tropigat_helices_df.to_csv(f"{path_benchmark}/TropiGAT.KL_metrics.helices.tsv", sep = "\t", index = False, header = True)
tropiseq_helices_df.to_csv(f"{path_benchmark}/TropiSEQ.KL_metrics.helices.tsv", sep = "\t", index = False, header = True)
tropish_helices_df.to_csv(f"{path_benchmark}/SpikeHunter.KL_metrics.helices.tsv", sep = "\t", index = False, header = True)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:01<00:00, 87.74it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 96.94it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 89/89 [00:00<00:00, 95.52it/s]
  _warn_prf(average, modifier, msg_start, len(result))


In [220]:
comparative_helices_df = pd.DataFrame({
            "KL type": tropigat_helices_df["KL type"],
            "Count": tropigat_helices_df["Count"],
            # recall
            "Recall_TropiGAT": tropigat_helices_df["Recall"].round(3),
            "Recall_TropiSEQ": tropiseq_helices_df["Recall"].round(3),
            "Recall_TropiSH": tropish_helices_df["Recall"].round(3),
            # precision: 
            "Precision_TropiGAT": tropigat_helices_df["Precision"].round(3),
            "Precision_TropiSEQ": tropiseq_helices_df["Precision"].round(3),
            "Precision_TropiSH": tropish_helices_df["Precision"].round(3),
            # MCC: 
            "MCC_TropiGAT": tropigat_helices_df["MCC"].round(3),
            "MCC_TropiSEQ": tropiseq_helices_df["MCC"].round(3),
            "MCC_TropiSH": tropish_helices_df["MCC"].round(3),
        })

comparative_helices_df = comparative_helices_df.fillna("-")
comparative_helices_df.to_csv(f"{path_benchmark}/Comparative.Helices.tsv", sep = "\t", index = False, header = True)

In [221]:
comparative_helices_df

Unnamed: 0,KL type,Count,Recall_TropiGAT,Recall_TropiSEQ,Recall_TropiSH,Precision_TropiGAT,Precision_TropiSEQ,Precision_TropiSH,MCC_TropiGAT,MCC_TropiSEQ,MCC_TropiSH
0,KL1,3,0.0,-,0.0,0.0,-,0.0,-0.085,-,-0.018
1,KL2,11,0.545,0.273,0.364,0.286,1.0,0.667,0.305,0.503,0.455
2,KL3,7,0.571,0.429,0.429,0.571,1.0,1.0,0.143,0.522,0.522
3,KL4,2,-,-,-,-,-,-,-,-,-
4,KL5,1,0.0,-,-,0.0,-,-,-0.027,-,-
...,...,...,...,...,...,...,...,...,...,...,...
83,KL140,3,-,-,1.0,-,-,1.0,-,-,1.0
85,KL149,2,0.0,-,-,0.0,-,-,-0.065,-,-
86,KL151,1,-,1.0,1.0,-,1.0,1.0,-,1.0,1.0
87,KL157,1,-,-,1.0,-,-,1.0,-,-,1.0


***
### Experimentally validated depolymerase results: 

In [6]:
tropigat_pred_xp_df = pd.read_csv(f"{path_benchmark}/TropiGAT.exp_val_depolymerase.tsv", sep = "\t", names = ["protein", "prediction"])
tropiseq_pred_xp_df = pd.read_csv(f"{path_benchmark}/TropiSEQ.exp_val_depolymerase.tsv", sep = "\t", names = ["protein", "prediction"])
SH_pred_xp_df = pd.read_csv(f"{path_benchmark}/SpikeHunter_predictions.exp_val_depolymerase.tsv", sep = "\t", names = ["protein", "prediction"])
