***
### Import modules : 

In [53]:
import os 
import pandas as pd 
from tqdm import tqdm
from collections import Counter
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, matthews_corrcoef


# Measure the performances formally :


> Open the results files : 

In [132]:
path_project = "/media/concha-eloko/Linux/PPT_clean"

# Classic version : 
# PPT_results.matrices.tailored.tsv : Tailored version
# PPT_results.classic_1112.tsv : Classic version 
# PPT_results.matrices.tailored_bit50.tsv : tailored and bit50
tropigat_results = pd.read_csv(f"{path_project}/PPT_results.matrices.tailored.tsv", header = 0, sep = "\t")
tropigat_results

Unnamed: 0,phage,protein_id,predictions_seqbased,predictions_tropigat
0,A1a,A1a_00002,No hits,KL123:0.996 ; KL7:0.981 ; KL9:0.974 ; KL110:0....
1,A1a,A1a_00014,KL151:0.708,KL74:0.974 ; KL70:0.907 ; KL29:0.823 ; KL110:0...
2,A1b,A1b_00048,KL157:0.57,KL53:0.989 ; KL60:0.989 ; KL34:0.865 ; KL128:0...
3,A1b,A1b_00036,No hits,KL123:0.999 ; KL7:0.998 ; KL110:0.981 ; KL9:0....
4,A1c,A1c_00046,No hits,KL123:0.996 ; KL7:0.992 ; KL110:0.967 ; KL9:0....
...,...,...,...,...
180,S13a,S13a_00036,No hits,KL12:0.979 ; KL70:0.965 ; KL136:0.934 ; KL123:...
181,S13b,S13b_00058,KL63:0.893,KL47:0.992 ; KL64:0.913 ; KL28:0.882 ; KL34:0....
182,S13c,S13c_00055,KL38:0.9,KL12:0.988 ; KL123:0.976 ; KL145:0.972 ; KL136...
183,S13d,S13d_00057,KL14:0.976,KL14:0.99 ; KL21:0.975 ; KL13:0.741 ; KL53:0.5...


***
### Read the matrices :

In [97]:
path_finetuning = "/media/concha-eloko/Linux/PPT_clean/in_vitro/fine_tuning"

bea_df = pd.read_csv(f"{path_finetuning}/bea_fine_tuning.df", sep = "\t", header = 0)
bea_df["Protein"] = bea_df["Protein"].apply(lambda x : x.replace("_", "__"))
pool_bea = set([kltype.strip() for kltypes in bea_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

ferriol_df = pd.read_csv(f"{path_finetuning}/ferriol_fine_tuning.df", sep = "\t", header = 0)
ferriol_df["Target"] = ferriol_df["Target"].apply(lambda x : x.replace("K", "KL"))
pool_ferriol = set([kltype.strip() for kltypes in ferriol_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

towndsend_df = pd.read_csv(f"{path_finetuning}/towndsend_fine_tuning.df", sep = "\t", header = 0)
towndsend_df["Protein"] = towndsend_df["Protein"].apply(lambda x : x.replace("_", "__"))
pool_towndsend = set([kltype.strip() for kltypes in towndsend_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

dico_matrices = {"ferriol" : {"matrix" : ferriol_df, "pool" : pool_ferriol}, 
                 "bea" : {"matrix": bea_df, "pool" : pool_bea}, 
                 "towndsend" : {"matrix" : towndsend_df, "pool" : pool_towndsend}}



> TropiGATv2 DF : 

In [98]:
DF_info = pd.read_csv(f"{path_project}/TropiGATv2.final_df.tsv", sep = "\t" ,  header = 0)
DF_info_lvl_0 = DF_info[~DF_info["KL_type_LCA"].str.contains("\\|")]
DF_info_lvl_0 = DF_info_lvl_0.drop_duplicates(subset = ["Infected_ancestor","index","prophage_id"] , keep = "first").reset_index(drop=True)

df_prophages = DF_info_lvl_0.drop_duplicates(subset = ["Phage"])
dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))

KLtypes = [kltype for kltype in dico_prophage_count if dico_prophage_count[kltype] >= 20]



In [99]:
# targets dico : 
dico_hits = {}
for author in dico_matrices :
    matrix = dico_matrices[author]["matrix"]
    for _, row in matrix.iterrows() : 
        for phage in matrix["Phages"].unique() : 
            all_targets = set()
            targets = matrix[matrix["Phages"] == phage]["Target"].values
            for calls in targets : 
                actual_targets = [x.strip() for x in calls.split(",")]
                all_targets.update(actual_targets)
            dico_hits[phage] = all_targets

***
### Make the raw results file : 

In [131]:
top_n = 15

path_project = "/media/concha-eloko/Linux/PPT_clean"

# Classic version : 
# PPT_results.matrices.tailored.tsv : Tailored version
# PPT_results.classic_1112.tsv : Classic version 
# PPT_results.matrices.tailored_bit50.tsv : tailored and bit50
tropigat_results = pd.read_csv(f"{path_project}/PPT_results.classic_1112.tsv", header = 0, sep = "\t")

with open(f"{path_project}/raw_metrics.classic.tsv", "w") as outfile :
    outfile.write(f"Phage\tProtein\tTropiGAT_predictions\tTropiGAT_good_calls\tTropiSeq_predictions\tTropiSeq_good_calls\tTargets\n")
    for _, row in tropigat_results.iterrows() : 
        targets = dico_hits[row["phage"]]
        outfile.write(f"{row['phage']}\t{row['protein_id']}\t")
        # TropiGAT part : 
        # write the pred
        top_n_predictions = ";".join([x for x in row["predictions_tropigat"].split(";")][0:top_n-1])
        outfile.write(top_n_predictions + "\t")
        # check the calls
        tropigat_pred = [x.split(":")[0].strip() for x in row["predictions_tropigat"].split(";")]
        top_KLtypes_pred = set(tropigat_pred[0: top_n-1])
        good_calls = top_KLtypes_pred.intersection(targets)
        if len(good_calls) > 0 : 
            outfile.write(",".join(list(good_calls)) + "\t")
        else : 
            outfile.write("0" + "\t")
        # TropiSeq part : 
        # write the pred
        outfile.write(row["predictions_seqbased"] + "\t")
        if row["predictions_seqbased"] != "No hits" : 
            tropiseq_pred = [x.split(":")[0].strip() for x in row["predictions_seqbased"].split(";")]
            top_predictions = set(tropiseq_pred[0: top_n-1])
            good_calls = top_predictions.intersection(targets)
            if len(good_calls) > 0 : 
                outfile.write(",".join(list(good_calls)) + "\t")
            else :
                outfile.write("0" + "\t")
        else :
            outfile.write("0\t")
        target_clean = ",".join(list(targets))
        outfile.write(target_clean + "\n")

    

In [130]:
raw_df = pd.read_csv(f"{path_project}/raw_metrics.tailored_bit50.tsv" , sep = "\t", header = 0)
raw_df

Unnamed: 0,Phage,Protein,TropiGAT_predictions,TropiGAT_good_calls,TropiSeq_predictions,TropiSeq_good_calls,Targets
0,A1a,A1a_00002,KL123:0.996 ; KL7:0.981 ; KL9:0.974 ; KL110:0....,0,KL7:0.989;KL102:0.569,0,KL151
1,A1a,A1a_00014,KL74:0.974 ; KL70:0.907 ; KL29:0.823 ; KL110:0...,0,KL151:0.708,KL151,KL151
2,A1b,A1b_00048,KL53:0.989 ; KL60:0.989 ; KL34:0.865 ; KL128:0...,0,KL157:0.57,KL157,KL157
3,A1b,A1b_00036,KL123:0.999 ; KL7:0.998 ; KL110:0.981 ; KL9:0....,0,KL7:0.989;KL102:0.569,0,KL157
4,A1c,A1c_00046,KL123:0.996 ; KL7:0.992 ; KL110:0.967 ; KL9:0....,0,KL7:0.989;KL102:0.569,0,KL1
...,...,...,...,...,...,...,...
180,S13a,S13a_00036,KL12:0.979 ; KL70:0.965 ; KL136:0.934 ; KL123:...,0,KL38:0.699,0,"KL149,KL102"
181,S13b,S13b_00058,KL47:0.992 ; KL64:0.913 ; KL28:0.882 ; KL34:0....,0,KL63:0.893,KL63,KL63
182,S13c,S13c_00055,KL12:0.988 ; KL123:0.976 ; KL145:0.972 ; KL136...,0,KL38:0.9,0,"KL149,KL102"
183,S13d,S13d_00057,KL14:0.99 ; KL21:0.975 ; KL13:0.741 ; KL53:0.5...,KL14,KL14:0.976,KL14,KL14


***
### Make the matrices files : 

In [133]:
top_n = 15
labels_tropigat = {}
for kltype in tqdm(dico_prophage_count) : 
    pred_labels , real_labels = [] , []
    for author in dico_matrices :
        if kltype in dico_matrices[author]["pool"] : 
            matrix = dico_matrices[author]["matrix"]
            for phage in matrix["Phages"].unique() :
                top_predictions = set()
                predictions = tropigat_results[tropigat_results["phage"] == phage]["predictions_tropigat"].values
                for calls in predictions : 
                    predicted_kltypes = [x.split(":")[0].strip() for x in calls.split(";")]
                    top_predictions.update(predicted_kltypes[0:top_n-1])
                if kltype in top_predictions : 
                    pred_labels.append(1)
                else : 
                    pred_labels.append(0)
                if kltype in dico_hits[phage] : 
                    real_labels.append(1)
                else :
                    real_labels.append(0)
    labels_tropigat[kltype] = {"y_pred" : pred_labels, "real_labels" : real_labels}

100%|████████████████████████████████████████| 129/129 [00:00<00:00, 158.93it/s]


In [134]:
labels_tropigat
sorted_dict = dict(sorted(labels_tropigat.items(), key=lambda item: int(item[0].split("KL")[1])))


In [135]:
aucs = []


for kltype in sorted_dict : 
    if len(labels_tropigat[kltype]["real_labels"]) > 1 : 
        if Counter(labels_tropigat[kltype]["y_pred"])[1] > 0 :
            f1 = f1_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"], average='binary')
            precision = precision_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"], average='binary')
            recall = recall_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"], average='binary')
            mcc = matthews_corrcoef(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"])
            accuracy = accuracy_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"])
            auc = roc_auc_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"])
            
            print(kltype , round(f1,5), round(precision,5), round(recall,5), round(accuracy,5), round(mcc,5), round(auc,5), sep = "\t")
            aucs.append(auc)

KL1	0.0	0.0	0.0	0.94318	-0.02865	0.48824
KL2	0.4	0.36364	0.44444	0.89091	0.34274	0.68757
KL3	0.57143	0.5	0.66667	0.86364	0.49951	0.7807
KL5	0.0	0.0	0.0	0.84	-0.05764	0.42857
KL7	0.0	0.0	0.0	0.76	-0.07587	0.38776
KL8	0.0	0.0	0.0	0.9	-0.04213	0.45918
KL10	0.25	0.14286	1.0	0.88	0.35407	0.93878
KL13	0.09091	0.05	0.5	0.77273	0.09924	0.63953
KL14	0.13793	0.07407	1.0	0.71591	0.22922	0.85465
KL16	0.17647	0.09677	1.0	0.74545	0.2673	0.86916
KL17	0.4	0.33333	0.5	0.94	0.37819	0.72917
KL18	0.16667	0.09091	1.0	0.8	0.26899	0.89796
KL20	0.66667	0.5	1.0	0.97368	0.69749	0.98649
KL22	0.0	0.0	0.0	0.86364	-0.0546	0.48718
KL24	0.33333	0.25	0.5	0.90909	0.31053	0.71429
KL25	0.2069	0.11538	1.0	0.73864	0.29011	0.86471
KL27	0.33333	0.2	1.0	0.92	0.42857	0.95918
KL28	0.4	0.25	1.0	0.86364	0.46291	0.92857
KL36	0.0	0.0	0.0	0.9	-0.05157	0.47872
KL38	0.66667	0.5	1.0	0.98	0.69985	0.9898
KL39	0.14286	0.08333	0.5	0.76	0.12427	0.63542
KL43	0.125	0.06667	1.0	0.72	0.21822	0.85714
KL45	0.10526	0.05556	1.0	0.66	0.19048	0.82653

In [93]:
top_n = 15
labels_tropiseq = {}
for kltype in tqdm(dico_prophage_count) : 
    pred_labels , real_labels = [] , []
    for author in dico_matrices :
        if kltype in dico_matrices[author]["pool"] : 
            matrix = dico_matrices[author]["matrix"]
            for phage in matrix["Phages"].unique() :
                top_predictions = set()
                predictions = tropigat_results[tropigat_results["phage"] == phage]["predictions_seqbased"].values
                for calls in predictions : 
                    predicted_kltypes = [x.split(":")[0].strip() for x in calls.split(";") if x != "No hits"]
                    top_predictions.update(predicted_kltypes[0:top_n-1])
                if kltype in top_predictions : 
                    pred_labels.append(1)
                else : 
                    pred_labels.append(0)
                if kltype in dico_hits[phage] : 
                    real_labels.append(1)
                else :
                    real_labels.append(0)
    labels_tropiseq[kltype] = {"y_pred" : pred_labels, "real_labels" : real_labels}

100%|████████████████████████████████████████| 129/129 [00:00<00:00, 186.42it/s]


In [94]:
aucs_tropiseq = []
for kltype in sorted_dict : 
    if len(labels_tropiseq[kltype]["real_labels"]) > 1 : 
        if Counter(labels_tropiseq[kltype]["y_pred"])[1] > 0 :
            f1 = f1_score(labels_tropiseq[kltype]["real_labels"], labels_tropiseq[kltype]["y_pred"], average='binary')
            precision = precision_score(labels_tropiseq[kltype]["real_labels"], labels_tropiseq[kltype]["y_pred"], average='binary')
            recall = recall_score(labels_tropiseq[kltype]["real_labels"], labels_tropiseq[kltype]["y_pred"], average='binary')
            mcc = matthews_corrcoef(labels_tropiseq[kltype]["real_labels"], labels_tropiseq[kltype]["y_pred"])
            accuracy = accuracy_score(labels_tropiseq[kltype]["real_labels"], labels_tropiseq[kltype]["y_pred"])
            auc = roc_auc_score(labels_tropiseq[kltype]["real_labels"], labels_tropiseq[kltype]["y_pred"])
            print(kltype , round(f1,5), round(precision,5), round(recall,5), round(accuracy,5), round(mcc,5), round(auc,5), sep = "\t")
            aucs_tropiseq.append(auc)

KL2	0.66667	0.83333	0.55556	0.95455	0.65858	0.77283
KL3	0.66667	0.66667	0.66667	0.90909	0.61404	0.80702
KL13	0.66667	1.0	0.5	0.98864	0.70303	0.75
KL14	0.66667	1.0	0.5	0.98864	0.70303	0.75
KL16	1.0	1.0	1.0	1.0	1.0	1.0
KL17	0.66667	1.0	0.5	0.98	0.69985	0.75
KL18	0.5	0.33333	1.0	0.96	0.56544	0.97959
KL22	0.16667	0.5	0.1	0.88636	0.18565	0.54359
KL23	1.0	1.0	1.0	1.0	1.0	1.0
KL24	0.44444	0.4	0.5	0.94318	0.41776	0.73214
KL25	0.85714	0.75	1.0	0.98864	0.86092	0.99412
KL27	0.66667	0.5	1.0	0.98	0.69985	0.9898
KL30	0.66667	1.0	0.5	0.98864	0.70303	0.75
KL38	1.0	1.0	1.0	1.0	1.0	1.0
KL39	0.0	0.0	0.0	0.94	-0.02916	0.48958
KL43	0.0	0.0	0.0	0.96	-0.02041	0.4898
KL45	0.66667	0.5	1.0	0.98	0.69985	0.9898
KL51	1.0	1.0	1.0	1.0	1.0	1.0
KL56	1.0	1.0	1.0	1.0	1.0	1.0
KL60	0.0	0.0	0.0	0.94	-0.02916	0.47959
KL63	0.8	0.66667	1.0	0.98864	0.81174	0.99419
KL64	1.0	1.0	1.0	1.0	1.0	1.0
KL66	1.0	1.0	1.0	1.0	1.0	1.0
KL71	0.0	0.0	0.0	0.94	-0.02916	0.48958
KL74	0.66667	0.5	1.0	0.96	0.69222	0.97917
KL102	0.0	0.0	0.0	0.31579	

In [95]:
kltype_see = "KL43"
labels_zip_seq = tuple(zip(labels_tropiseq[kltype_see]["y_pred"], labels_tropiseq[kltype_see]["real_labels"]))
labels_zip_gat = tuple(zip(labels_tropigat[kltype_see]["y_pred"], labels_tropigat[kltype_see]["real_labels"]))

# tuple(zip(labels_zip_seq , labels_zip_gat))

In [68]:
import statistics

mean = statistics.mean(aucs_tropiseq)
print(mean)

0.8093837131614184


In [64]:
tropigat_results

Unnamed: 0,phage,protein_id,predictions_seqbased,predictions_tropigat
0,A1a,A1a_00002,No hits,KL123:0.996 ; KL7:0.981 ; KL9:0.974 ; KL110:0....
1,A1a,A1a_00014,KL151:0.708,KL74:0.974 ; KL70:0.907 ; KL29:0.823 ; KL110:0...
2,A1b,A1b_00048,KL157:0.57,KL53:0.989 ; KL60:0.989 ; KL34:0.865 ; KL13:0....
3,A1b,A1b_00036,No hits,KL123:0.999 ; KL7:0.998 ; KL110:0.981 ; KL9:0....
4,A1c,A1c_00046,No hits,KL123:0.996 ; KL7:0.992 ; KL110:0.967 ; KL9:0....
...,...,...,...,...
180,S13a,S13a_00036,No hits,KL70:0.965 ; KL136:0.934 ; KL123:0.921 ; KL12:...
181,S13b,S13b_00058,KL63:0.893,KL47:0.992 ; KL64:0.913 ; KL28:0.882 ; KL3:0.7...
182,S13c,S13c_00055,KL38:0.9,KL123:0.976 ; KL145:0.972 ; KL38:0.969 ; KL136...
183,S13d,S13d_00057,KL14:0.976,KL14:0.99 ; KL21:0.975 ; KL13:0.917 ; KL62:0.8...
