***
### Import modules : 

In [10]:
import os 
import pandas as pd 
from tqdm import tqdm
from collections import Counter
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, matthews_corrcoef


# Measure the performances formally :


## Opening the DF

> Open the results files : 

In [48]:
path_project = "/media/concha-eloko/Linux/PPT_clean"

# Classic version : 
# PPT_results.matrices.tailored.tsv : Tailored version
# PPT_results.classic_1112.tsv : Classic version 
# PPT_results.matrices.tailored_bit50.tsv : tailored and bit50
# PPT_results.matrices.1512.tsv
# PPT_results.classic_0101.bit50.tsv
tropigat_results = pd.read_csv(f"{path_project}/PPT_results.classic_1001.bit75.tsv", header = 0, sep = "\t")
df_folds = pd.read_csv(f"{path_project}/in_vitro/dpos_folds.all_matrices.tsv", header = 0, sep = "\t")
tropigat_results.protein_id.tolist()

['A1a_00002',
 'A1a_00014',
 'A1b_00048',
 'A1b_00036',
 'A1c_00046',
 'A1c_00034',
 'A1d_00013',
 'A1d_00009',
 'A1e_00024',
 'A1e_00012',
 'A1f_00024',
 'A1f_00012',
 'A1g_00045',
 'A1g_00057',
 'A1h_00021',
 'A1h_00009',
 'A1h_00013',
 'A1i_00041',
 'A1i_00037',
 'A1i_00049',
 'A1j_00040',
 'A1j_00049',
 'A1j_00002',
 'A1k_00018',
 'A1k_00014',
 'A1l_00058',
 'A1l_00005',
 'A1m_00045',
 'A1m_00049',
 'A1n_00050',
 'A1o_00045',
 'A1o_00041',
 'A1p_00055',
 'A1q_00023',
 'A1q_00010',
 'A1q_00019',
 'A1r_00013',
 'A1r_00009',
 'A2a_00010',
 'A2a_b_00022',
 'A2a_b_00036',
 'A2a_00049',
 'A2b_00022',
 'A2b_00008',
 'A3a_00002',
 'A3a_00045',
 'A3b_00021',
 'A3b_00016',
 'A3c_00044',
 'A3c_00039',
 'A3c_00045',
 'A3d_00041',
 'A3d_00042',
 'A3d_00036',
 'BLCJPOBP_00041',
 'D7b_00043',
 'D7c_00007',
 'DIMCIIMF_00039',
 'DIMCIIMF_00240',
 'DJLANJJD_00238',
 'EHPPICDA_00095',
 'EKPIEFBL_00113',
 'EKPIEFBL_00177',
 'EONHMLJF_00087',
 'FADJDIKG_00083',
 'FADJDIKG_00019',
 'GCLPFEGH_00240',
 'H

In [52]:
tropigat_results[tropigat_results["phage"] == "K17alfa62"]

Unnamed: 0,phage,protein_id,predictions_seqbased,predictions_tropigat
89,K17alfa62,K17alfa62__cds_66,No hits,KL62:1.0 ; KL43:0.999 ; KL29:0.986 ; KL52:0.93...
90,K17alfa62,K17alfa62__cds_64,KL17: 0.611,KL17:1.0 ; KL128:0.998 ; KL29:0.912 ; KL70:0.8...


***
### Read the matrices :

In [13]:
path_finetuning = "/media/concha-eloko/Linux/PPT_clean/in_vitro/fine_tuning"

bea_df = pd.read_csv(f"{path_finetuning}/bea_fine_tuning.df", sep = "\t", header = 0)
bea_df["Protein"] = bea_df["Protein"].apply(lambda x : x.replace("_", "__"))
pool_bea = set([kltype.strip() for kltypes in bea_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

ferriol_df = pd.read_csv(f"{path_finetuning}/ferriol_fine_tuning.df", sep = "\t", header = 0)
ferriol_df["Target"] = ferriol_df["Target"].apply(lambda x : x.replace("K", "KL"))
pool_ferriol = set([kltype.strip() for kltypes in ferriol_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

towndsend_df = pd.read_csv(f"{path_finetuning}/towndsend_fine_tuning.df", sep = "\t", header = 0)
towndsend_df["Protein"] = towndsend_df["Protein"].apply(lambda x : x.replace("_", "__"))
pool_towndsend = set([kltype.strip() for kltypes in towndsend_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

dico_matrices = {"ferriol" : {"matrix" : ferriol_df, "pool" : pool_ferriol}, 
                 "bea" : {"matrix": bea_df, "pool" : pool_bea}, 
                 "towndsend" : {"matrix" : towndsend_df, "pool" : pool_towndsend}}



> TropiGATv2 DF : 

In [64]:
from collections import Counter
DF_info = pd.read_csv(f"{path_project}/TropiGATv2.final_df_v2.filtered.tsv", sep = "\t" ,  header = 0)
#DF_info_lvl_0 = DF_info[~DF_info["KL_type_LCA"].str.contains("\\|")]
#DF_info_lvl_0 = DF_info_lvl_0.drop_duplicates(subset = ["Infected_ancestor","index","prophage_id"] , keep = "first").reset_index(drop=True)

DF_info_lvl_0 = DF_info.copy()
df_prophages = DF_info_lvl_0.drop_duplicates(subset = ["Phage"])
dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))

KLtypes = [kltype for kltype in dico_prophage_count if dico_prophage_count[kltype] >= 20]



In [14]:
# targets dico : 
dico_hits = {}
for author in dico_matrices :
    matrix = dico_matrices[author]["matrix"]
    for _, row in matrix.iterrows() : 
        for phage in matrix["Phages"].unique() : 
            all_targets = set()
            targets = matrix[matrix["Phages"] == phage]["Target"].values
            for calls in targets : 
                actual_targets = [x.strip() for x in calls.split(",")]
                all_targets.update(actual_targets)
            dico_hits[phage] = all_targets

In [15]:
dico_hits

{'K15PH90': {'KL15'},
 'K64PH164C4': {'KL64'},
 'K5lambda5': {'KL5'},
 'K11PH164C1': {'KL11', 'KL57'},
 'K57lambda1_2': {'KL57', 'KL68'},
 'K60PH164C1': {'KL18', 'KL60'},
 'K13PH07C1S': {'KL13'},
 'K2PH164C1': {'KL2'},
 'K41P2': {'KL39', 'KL41'},
 'K71PH129C1': {'KL71'},
 'K56PH164C1': {'KL56'},
 'K38PH09C2': {'KL38'},
 'K22PH164C1': {'KL22', 'KL37'},
 'K58PH129C2': {'KL58'},
 'K44PH129C1': {'KL37', 'KL44'},
 'K17alfa62': {'KL17', 'KL62'},
 'K24PH164C1': {'KL24'},
 'K37PH164C1': {'KL2', 'KL37'},
 'K35PH164C3': {'KL30', 'KL35', 'KL36'},
 'K46PH129': {'KL46'},
 'K25PH129C1': {'KL25'},
 'K48PH164C1': {'KL48'},
 'K2PH164C2': {'KL2'},
 'K23PH08C2': {'KL23', 'KL54', 'KL58'},
 'K2alfa62': {'KL2'},
 'K10PH82C1': {'KL10'},
 'K65PH164': {'KL27',
  'KL29',
  'KL32',
  'KL35',
  'KL36',
  'KL50',
  'KL59',
  'KL62',
  'KL65',
  'KL69',
  'KL7',
  'KL70',
  'KL72',
  'KL79',
  'KL80'},
 'K26PH128C1': {'KL26', 'KL74'},
 'K39PH122C2': {'KL36', 'KL39'},
 'K16PH164C3': {'KL16'},
 'K74PH129C2': {'KL74'}

***
### Make the raw results file : 

> Old 

In [None]:
top_n = 15

path_project = "/media/concha-eloko/Linux/PPT_clean"

# Classic version : 
# PPT_results.matrices.tailored.tsv : Tailored version
# PPT_results.classic_1112.tsv : Classic version 
# PPT_results.matrices.tailored_bit50.tsv : tailored and bit50
# classic_0101
# SAGE_0201
# PPT_results.classic_0101.bit50.tsv
tropigat_results = pd.read_csv(f"{path_project}/PPT_results.classic_1001.bit50.tsv", header = 0, sep = "\t")

with open(f"{path_project}/raw_metrics.classic_1001.bit50.tsv", "w") as outfile :
    outfile.write(f"Phage\tProtein\tTropiGAT_predictions\tTropiGAT_good_calls\tTropiSeq_predictions\tTropiSeq_good_calls\tTargets\n")
    for _, row in tropigat_results.iterrows() : 
        targets = dico_hits[row["phage"]]
        outfile.write(f"{row['phage']}\t{row['protein_id']}\t")
        # TropiGAT part : 
        # write the pred
        top_n_predictions = ";".join([x for x in row["predictions_tropigat"].split(";")][0:top_n-1])
        outfile.write(top_n_predictions + "\t")
        # check the calls
        tropigat_pred = [x.split(":")[0].strip() for x in row["predictions_tropigat"].split(";")]
        top_KLtypes_pred = set(tropigat_pred[0: top_n-1])
        good_calls = top_KLtypes_pred.intersection(targets)
        if len(good_calls) > 0 : 
            outfile.write(",".join(list(good_calls)) + "\t")
        else : 
            outfile.write("0" + "\t")
        # TropiSeq part : 
        # write the pred
        outfile.write(row["predictions_seqbased"] + "\t")
        if row["predictions_seqbased"] != "No hits" : 
            tropiseq_pred = [x.split(":")[0].strip() for x in row["predictions_seqbased"].split(";")]
            top_predictions = set(tropiseq_pred[0: top_n-1])
            good_calls = top_predictions.intersection(targets)
            if len(good_calls) > 0 : 
                outfile.write(",".join(list(good_calls)) + "\t")
            else :
                outfile.write("0" + "\t")
        else :
            outfile.write("0\t")
        target_clean = ",".join(list(targets))
        outfile.write(target_clean + "\n")

    

In [57]:
top_n = 15

path_project = "/media/concha-eloko/Linux/PPT_clean"

# Classic version : 
# PPT_results.matrices.tailored.tsv : Tailored version
# PPT_results.classic_1112.tsv : Classic version 
# PPT_results.matrices.tailored_bit50.tsv : tailored and bit50
# classic_0101
# SAGE_0201
# PPT_results.classic_0101.bit50.tsv
tropigat_results = pd.read_csv(f"{path_project}/PPT_results.classic_1001.bit75.tsv", header = 0, sep = "\t")

with open(f"{path_project}/raw_metrics.classic_1001.bit75.tsv", "w") as outfile :
    outfile.write(f"Phage\tProtein\tFolds\tTropiGAT_predictions\tTropiGAT_good_calls\tTropiSeq_predictions\tTropiSeq_good_calls\tTargets\n")
    for _, row in tropigat_results.iterrows() : 
        if row["phage"] in dico_hits :
            targets = dico_hits[row["phage"]]
            prot_id = row['protein_id'].replace("__cds", "_cds")
            try : 
                #targets = dico_hits[row["phage"]]
                #prot_id = row['protein_id'].replace("__cds", "_cds")
                fold = df_folds[df_folds["protein_id"] == prot_id]["Fold"].values[0]
            except Exception as e :
                fold = "unknown"
            outfile.write(f"{row['phage']}\t{row['protein_id']}\t{fold}\t")
            # TropiGAT part : 
            # write the pred
            top_n_predictions = ";".join([x for x in row["predictions_tropigat"].split(";")][0:top_n-1])
            outfile.write(top_n_predictions + "\t")
            # check the calls
            tropigat_pred = [x.split(":")[0].strip() for x in row["predictions_tropigat"].split(";")]
            top_KLtypes_pred = set(tropigat_pred[0: top_n-1])
            good_calls = top_KLtypes_pred.intersection(targets)
            if len(good_calls) > 0 : 
                outfile.write(",".join(list(good_calls)) + "\t")
            else : 
                outfile.write("0" + "\t")
            # TropiSeq part : 
            # write the pred
            outfile.write(row["predictions_seqbased"] + "\t")
            if row["predictions_seqbased"] != "No hits" : 
                tropiseq_pred = [x.split(":")[0].strip() for x in row["predictions_seqbased"].split(";")]
                top_predictions = set(tropiseq_pred[0: top_n-1])
                good_calls = top_predictions.intersection(targets)
                if len(good_calls) > 0 : 
                    outfile.write(",".join(list(good_calls)) + "\t")
                else :
                    outfile.write("0" + "\t")
            else :
                outfile.write("0\t")
            target_clean = ",".join(list(targets))
            outfile.write(target_clean + "\n")


    

***
## Working on the final matrices file :

In [30]:
import os 
import pandas as pd 

path_project = "/media/concha-eloko/Linux/PPT_clean"


raw_df = pd.read_csv(f"{path_project}/raw_metrics.classic_1001.bit75.tsv" , sep = "\t", header = 0)
# raw_df[raw_df["Phage"] == "K17alfa62"]

tropigat_results = raw_df.copy()
tropigat_results

Unnamed: 0,Phage,Protein,Folds,TropiGAT_predictions,TropiGAT_good_calls,TropiSeq_predictions,TropiSeq_good_calls,Targets
0,A1a,A1a_00002,6-bladed beta-propeller,KL111:0.983 ; KL123:0.982 ; KL45:0.973 ; KL24:...,0,KL102: 0.691,0,KL151
1,A1a,A1a_00014,right-handed beta-helix,KL128:0.987 ; KL29:0.979 ; KL70:0.958 ; KL24:0...,0,KL151: 0.698,KL151,KL151
2,A1b,A1b_00048,right-handed beta-helix,KL46:0.994 ; KL128:0.991 ; KL149:0.951 ; KL74:...,0,KL157: 0.729,KL157,KL157
3,A1b,A1b_00036,6-bladed beta-propeller,KL123:0.998 ; KL111:0.983 ; KL128:0.982 ; KL45...,0,KL102: 0.691,0,KL157
4,A1c,A1c_00046,6-bladed beta-propeller,KL123:0.994 ; KL24:0.982 ; KL45:0.953 ; KL111:...,0,KL102: 0.691,0,KL1
...,...,...,...,...,...,...,...,...
225,S13a,S13a_00036,right-handed beta-helix,KL60:0.97 ; KL46:0.96 ; KL27:0.95 ; KL23:0.922...,0,KL38: 0.822,0,"KL102,KL149"
226,S13b,S13b_00058,right-handed beta-helix,KL47:0.994 ; KL81:0.932 ; KL74:0.917 ; KL28:0....,0,KL63: 0.867,KL63,KL63
227,S13c,S13c_00055,right-handed beta-helix,KL12:0.992 ; KL38:0.98 ; KL128:0.974 ; KL27:0....,0,No hits,0,"KL102,KL149"
228,S13d,S13d_00057,right-handed beta-helix,KL14:0.986 ; KL128:0.968 ; KL46:0.965 ; KL15:0...,KL14,KL14: 0.736,KL14,KL14


In [11]:
from collections import Counter
DF_info = pd.read_csv(f"{path_project}/TropiGATv2.final_df_v2.filtered.tsv", sep = "\t" ,  header = 0)
#DF_info_lvl_0 = DF_info[~DF_info["KL_type_LCA"].str.contains("\\|")]
#DF_info_lvl_0 = DF_info_lvl_0.drop_duplicates(subset = ["Infected_ancestor","index","prophage_id"] , keep = "first").reset_index(drop=True)

DF_info_lvl_0 = DF_info.copy()
df_prophages = DF_info_lvl_0.drop_duplicates(subset = ["Phage"])
dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))

KLtypes = [kltype for kltype in dico_prophage_count if dico_prophage_count[kltype] >= 20]


In [None]:
path_finetuning = "/media/concha-eloko/Linux/PPT_clean/in_vitro/fine_tuning"

bea_df = pd.read_csv(f"{path_finetuning}/bea_fine_tuning.df", sep = "\t", header = 0)
bea_df["Protein"] = bea_df["Protein"].apply(lambda x : x.replace("_", "__"))
pool_bea = set([kltype.strip() for kltypes in bea_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

ferriol_df = pd.read_csv(f"{path_finetuning}/ferriol_fine_tuning.df", sep = "\t", header = 0)
ferriol_df["Target"] = ferriol_df["Target"].apply(lambda x : x.replace("K", "KL"))
pool_ferriol = set([kltype.strip() for kltypes in ferriol_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

towndsend_df = pd.read_csv(f"{path_finetuning}/towndsend_fine_tuning.df", sep = "\t", header = 0)
towndsend_df["Protein"] = towndsend_df["Protein"].apply(lambda x : x.replace("_", "__"))
pool_towndsend = set([kltype.strip() for kltypes in towndsend_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

dico_matrices = {"ferriol" : {"matrix" : ferriol_df, "pool" : pool_ferriol}, 
                 "bea" : {"matrix": bea_df, "pool" : pool_bea}, 
                 "towndsend" : {"matrix" : towndsend_df, "pool" : pool_towndsend}}



> check the folds :

In [6]:
from collections import Counter 

folds = dict(Counter(raw_df["Folds"]))
folds



{'6-bladed beta-propeller': 74,
 'right-handed beta-helix': 137,
 'triple-helix': 8,
 'unknown': 11}

raw_df[raw_df["Folds"] == "unknown"]

The 11 unknown are actually triple helix depolymerase

> Check what is the label when they predict the same thing : 

In [26]:
good_inter = 0
bad_inter = 0
inter_count = 0

for _, row in raw_df.iterrows() : 
    tropigat_pred = set([pred.split(":")[0].strip() for pred in row["TropiGAT_predictions"].split(";")])
    tropiseq_pred = set([pred.split(":")[0].strip() for pred in row["TropiSeq_predictions"].split(";") if row["TropiSeq_predictions"] != "No hits"])
    inter = tropigat_pred.intersection(tropiseq_pred)
    if len(inter) > 0 :
        inter_count += 1 
        targets = set(row["Targets"].split(","))
        inter_tar = inter.intersection(targets)
        if len(inter_tar) > 0 :
            good_inter += 1
        else :
            bad_inter += 1 


In [27]:
good_inter , bad_inter , inter_count

(22, 95, 117)

In [50]:
20/82*100

24.390243902439025

***
## Make the matrices files : 

In [22]:
from tqdm import tqdm

top_n = 15
labels_tropigat = {}
count_kltypes = {}

for kltype in tqdm(dico_prophage_count) : 
    n = 0
    pred_labels , real_labels = [] , []
    for author in dico_matrices :
        if kltype in dico_matrices[author]["pool"] : 
            matrix = dico_matrices[author]["matrix"]
            for phage in matrix["Phages"].unique() :
                top_predictions = set()
                predictions = tropigat_results[tropigat_results["Phage"] == phage]["TropiGAT_predictions"].values
                for calls in predictions : 
                    predicted_kltypes = [x.split(":")[0].strip() for x in calls.split(";")]
                    top_predictions.update(predicted_kltypes[0:top_n-1])
                if kltype in top_predictions : 
                    pred_labels.append(1)
                else : 
                    pred_labels.append(0)
                if kltype in dico_hits[phage] : 
                    real_labels.append(1)
                    n += 1 
                else :
                    real_labels.append(0)
    labels_tropigat[kltype] = {"y_pred" : pred_labels, "real_labels" : real_labels}
    count_kltypes[kltype] = n 

labels_tropigat
sorted_dict = dict(sorted(labels_tropigat.items(), key=lambda item: int(item[0].split("KL")[1])))


100%|████████████████████████████████████████| 128/128 [00:01<00:00, 102.78it/s]


In [44]:
def make_labels(pred_df, predictor = "tropigat" , top_n = 15) :
    dico_pred = {"tropigat" : "TropiGAT_predictions" ,
                 "tropiseq" : "TropiSeq_predictions"}
    col = dico_pred[predictor]
    labels_tropigat = {}
    count_kltypes = {}
    for kltype in tqdm(dico_prophage_count) : 
        n = 0
        pred_labels , real_labels = [] , []
        for author in dico_matrices :
            if kltype in dico_matrices[author]["pool"] : 
                matrix = dico_matrices[author]["matrix"]
                for phage in matrix["Phages"].unique() :
                    top_predictions = set()
                    predictions = tropigat_results[pred_df["Phage"] == phage][col].values
                    for calls in predictions : 
                        predicted_kltypes = [x.split(":")[0].strip() for x in calls.split(";")]
                        top_predictions.update(predicted_kltypes[0:top_n-1])
                    if kltype in top_predictions : 
                        pred_labels.append(1)
                    else : 
                        pred_labels.append(0)
                    if kltype in dico_hits[phage] : 
                        real_labels.append(1)
                        n += 1 
                    else :
                        real_labels.append(0)
        labels_tropigat[kltype] = {"y_pred" : pred_labels, "real_labels" : real_labels}
        sorted_dict = dict(sorted(labels_tropigat.items(), key=lambda item: int(item[0].split("KL")[1])))
        count_kltypes[kltype] = n 
    return sorted_dict


def decript_dic (sorted_dict) :
    for kltype in sorted_dict : 
        if len(labels_tropigat[kltype]["real_labels"]) > 1 : 
            if Counter(labels_tropigat[kltype]["y_pred"])[1] > 0 :
                f1 = f1_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"], average='binary')
                precision = precision_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"], average='binary')
                recall = recall_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"], average='binary')
                mcc = matthews_corrcoef(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"])
                accuracy = accuracy_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"])
                auc = roc_auc_score(labels_tropigat[kltype]["real_labels"], labels_tropigat[kltype]["y_pred"])
                print(kltype ,count_kltypes[kltype], round(f1,5),round(accuracy,5), round(recall,5),round(precision,5),round(auc,5), round(mcc,5), sep = "\t")
                #aucs.append(auc)

### Work on TropiGAT : 

> Full predictions : 

In [45]:
sorted_dic_tropigat = make_labels(tropigat_results)
decript_dic(sorted_dic_tropigat)

100%|████████████████████████████████████████| 128/128 [00:01<00:00, 116.49it/s]


KL2	9	0.5	0.92063	0.55556	0.45455	0.75214	0.46008
KL3	3	0.44444	0.77273	0.66667	0.33333	0.72807	0.35148
KL5	1	0.0	0.9697	0.0	0.0	0.49231	-0.01538
KL7	2	0.0	0.43939	0.0	0.0	0.22656	-0.18784
KL8	1	0.0	0.66667	0.0	0.0	0.33846	-0.08473
KL9	1	0.0	0.75758	0.0	0.0	0.38462	-0.06727
KL10	1	0.22222	0.89394	1.0	0.125	0.94615	0.33397
KL12	1	0.0	0.42424	0.0	0.0	0.21538	-0.1401
KL13	3	0.0	0.84615	0.0	0.0	0.43564	-0.06514
KL14	3	0.05405	0.32692	0.66667	0.02817	0.49175	-0.00593
KL15	1	0.04444	0.34848	1.0	0.02273	0.66923	0.08771
KL16	3	0.22222	0.88889	0.66667	0.13333	0.78049	0.26409
KL17	2	0.10256	0.4697	1.0	0.05405	0.72656	0.1565
KL18	1	0.11111	0.75758	1.0	0.05882	0.87692	0.21058
KL19	1	0.11111	0.75758	1.0	0.05882	0.87692	0.21058
KL21	1	0.18182	0.86364	1.0	0.1	0.93077	0.29352
KL22	10	0.21875	0.51923	0.7	0.12963	0.6	0.11801
KL23	1	0.04651	0.37879	1.0	0.02381	0.68462	0.09376
KL24	5	0.08451	0.375	0.6	0.04545	0.48182	-0.01615
KL25	3	0.26087	0.83654	1.0	0.15	0.91584	0.3532
KL27	4	0.09091	0.39394	0.5	0.05	0

> without propeller : 

In [38]:
tropigat_results_helix = tropigat_results[(~tropigat_results["Folds"].isin(["6-bladed beta-propeller"])) & (tropigat_results["Targets"].str.count("pass") == 0)]
tropigat_results_helix

Unnamed: 0,Phage,Protein,Folds,TropiGAT_predictions,TropiGAT_good_calls,TropiSeq_predictions,TropiSeq_good_calls,Targets
1,A1a,A1a_00014,right-handed beta-helix,KL128:0.987 ; KL29:0.979 ; KL70:0.958 ; KL24:0...,0,KL151: 0.698,KL151,KL151
2,A1b,A1b_00048,right-handed beta-helix,KL46:0.994 ; KL128:0.991 ; KL149:0.951 ; KL74:...,0,KL157: 0.729,KL157,KL157
5,A1c,A1c_00034,right-handed beta-helix,KL29:0.994 ; KL128:0.984 ; KL118:0.869 ; KL43:...,0,No hits,0,KL1
7,A1d,A1d_00009,right-handed beta-helix,KL28:0.988 ; KL43:0.972 ; KL74:0.969 ; KL3:0.9...,0,KL112: 0.966,0,KL20
13,A1g,A1g_00057,right-handed beta-helix,KL128:0.999 ; KL43:0.997 ; KL45:0.993 ; KL29:0...,KL16,KL16: 0.854,KL16,KL16
...,...,...,...,...,...,...,...,...
225,S13a,S13a_00036,right-handed beta-helix,KL60:0.97 ; KL46:0.96 ; KL27:0.95 ; KL23:0.922...,0,KL38: 0.822,0,"KL102,KL149"
226,S13b,S13b_00058,right-handed beta-helix,KL47:0.994 ; KL81:0.932 ; KL74:0.917 ; KL28:0....,0,KL63: 0.867,KL63,KL63
227,S13c,S13c_00055,right-handed beta-helix,KL12:0.992 ; KL38:0.98 ; KL128:0.974 ; KL27:0....,0,No hits,0,"KL102,KL149"
228,S13d,S13d_00057,right-handed beta-helix,KL14:0.986 ; KL128:0.968 ; KL46:0.965 ; KL15:0...,KL14,KL14: 0.736,KL14,KL14


In [None]:
from tqdm import tqdm

top_n = 15
labels_tropigat = {}
count_kltypes = {}

for kltype in tqdm(dico_prophage_count) : 
    n = 0
    pred_labels , real_labels = [] , []
    for author in dico_matrices :
        if kltype in dico_matrices[author]["pool"] : 
            matrix = dico_matrices[author]["matrix"]
            for phage in matrix["Phages"].unique() :
                top_predictions = set()
                predictions = tropigat_results[tropigat_results["Phage"] == phage]["TropiGAT_predictions"].values
                for calls in predictions : 
                    predicted_kltypes = [x.split(":")[0].strip() for x in calls.split(";")]
                    top_predictions.update(predicted_kltypes[0:top_n-1])
                if kltype in top_predictions : 
                    pred_labels.append(1)
                else : 
                    pred_labels.append(0)
                if kltype in dico_hits[phage] : 
                    real_labels.append(1)
                    n += 1 
                else :
                    real_labels.append(0)
    labels_tropigat[kltype] = {"y_pred" : pred_labels, "real_labels" : real_labels}
    count_kltypes[kltype] = n 

labels_tropigat
sorted_dict = dict(sorted(labels_tropigat.items(), key=lambda item: int(item[0].split("KL")[1])))


> Work on TropiSeq :

In [31]:
top_n = 15
labels_tropiseq = {}
for kltype in tqdm(dico_prophage_count) : 
    pred_labels , real_labels = [] , []
    for author in dico_matrices :
        if kltype in dico_matrices[author]["pool"] : 
            matrix = dico_matrices[author]["matrix"]
            for phage in matrix["Phages"].unique() :
                top_predictions = set()
                predictions = tropigat_results[tropigat_results["Phage"] == phage]["TropiSeq_predictions"].values
                for calls in predictions : 
                    predicted_kltypes = [x.split(":")[0].strip() for x in calls.split(";") if x != "No hits"]
                    top_predictions.update(predicted_kltypes[0:top_n-1])
                if kltype in top_predictions : 
                    pred_labels.append(1)
                else : 
                    pred_labels.append(0)
                if kltype in dico_hits[phage] : 
                    real_labels.append(1)
                else :
                    real_labels.append(0)
    labels_tropiseq[kltype] = {"y_pred" : pred_labels, "real_labels" : real_labels}

100%|████████████████████████████████████████| 128/128 [00:01<00:00, 116.32it/s]


In [32]:
aucs_tropiseq = []
for kltype in sorted_dict : 
    if len(labels_tropiseq[kltype]["real_labels"]) > 1 : 
        if Counter(labels_tropiseq[kltype]["y_pred"])[1] > 0 :
            f1 = f1_score(labels_tropiseq[kltype]["real_labels"], labels_tropiseq[kltype]["y_pred"], average='binary')
            precision = precision_score(labels_tropiseq[kltype]["real_labels"], labels_tropiseq[kltype]["y_pred"], average='binary')
            recall = recall_score(labels_tropiseq[kltype]["real_labels"], labels_tropiseq[kltype]["y_pred"], average='binary')
            mcc = matthews_corrcoef(labels_tropiseq[kltype]["real_labels"], labels_tropiseq[kltype]["y_pred"])
            accuracy = accuracy_score(labels_tropiseq[kltype]["real_labels"], labels_tropiseq[kltype]["y_pred"])
            auc = roc_auc_score(labels_tropiseq[kltype]["real_labels"], labels_tropiseq[kltype]["y_pred"])
            print(kltype ,count_kltypes[kltype], round(f1,5),round(accuracy,5), round(recall,5),round(precision,5),round(auc,5), round(mcc,5), sep = "\t")
            aucs_tropiseq.append(auc)

KL2	9	0.42857	0.93651	0.33333	0.6	0.65812	0.41721
KL3	3	0.66667	0.90909	0.66667	0.66667	0.80702	0.61404
KL4	2	0.66667	0.98864	0.5	1.0	0.75	0.70303
KL9	1	0.66667	0.98485	1.0	0.5	0.99231	0.70165
KL13	3	0.5	0.98077	0.33333	1.0	0.66667	0.57172
KL14	3	0.5	0.98077	0.33333	1.0	0.66667	0.57172
KL15	1	1.0	1.0	1.0	1.0	1.0	1.0
KL16	3	1.0	1.0	1.0	1.0	1.0	1.0
KL17	2	0.66667	0.98485	0.5	1.0	0.75	0.70165
KL18	1	0.0	0.95455	0.0	0.0	0.48462	-0.02193
KL20	3	0.0	0.96154	0.0	0.0	0.49505	-0.01698
KL21	1	0.5	0.9697	1.0	0.33333	0.98462	0.5684
KL23	1	0.66667	0.98485	1.0	0.5	0.99231	0.70165
KL24	5	0.4	0.94231	0.4	0.4	0.68485	0.3697
KL25	3	0.85714	0.99038	1.0	0.75	0.99505	0.86173
KL27	4	0.4	0.95455	0.25	1.0	0.625	0.48833
KL28	10	0.125	0.88889	0.1	0.16667	0.52845	0.07222
KL30	3	0.33333	0.96154	0.33333	0.33333	0.65677	0.31353
KL31	1	0.0	0.9697	0.0	0.0	0.49231	-0.01538
KL35	7	0.57143	0.93182	0.57143	0.57143	0.7672	0.53439
KL39	3	0.0	0.92424	0.0	0.0	0.48413	-0.03858
KL43	2	0.0	0.95455	0.0	0.0	0.49219	-0.02193
KL48	

In [95]:
kltype_see = "KL43"
labels_zip_seq = tuple(zip(labels_tropiseq[kltype_see]["y_pred"], labels_tropiseq[kltype_see]["real_labels"]))
labels_zip_gat = tuple(zip(labels_tropigat[kltype_see]["y_pred"], labels_tropigat[kltype_see]["real_labels"]))

# tuple(zip(labels_zip_seq , labels_zip_gat))

In [73]:
import statistics

mean = statistics.mean(aucs_tropiseq)
print(mean)

0.7271412423345761


***
# Good calls : 

In [76]:
# tailored_bit50
# 1512
raw_df = pd.read_csv(f"{path_project}/raw_metrics.classic_1001.bit75.tsv" , sep = "\t", header = 0)
raw_df

Unnamed: 0,Phage,Protein,Folds,TropiGAT_predictions,TropiGAT_good_calls,TropiSeq_predictions,TropiSeq_good_calls,Targets
0,A1a,A1a_00002,6-bladed beta-propeller,KL111:0.983 ; KL123:0.982 ; KL45:0.973 ; KL24:...,0,KL102: 0.691,0,KL151
1,A1a,A1a_00014,right-handed beta-helix,KL128:0.987 ; KL29:0.979 ; KL70:0.958 ; KL24:0...,0,KL151: 0.698,KL151,KL151
2,A1b,A1b_00048,right-handed beta-helix,KL46:0.994 ; KL128:0.991 ; KL149:0.951 ; KL74:...,0,KL157: 0.729,KL157,KL157
3,A1b,A1b_00036,6-bladed beta-propeller,KL123:0.998 ; KL111:0.983 ; KL128:0.982 ; KL45...,0,KL102: 0.691,0,KL157
4,A1c,A1c_00046,6-bladed beta-propeller,KL123:0.994 ; KL24:0.982 ; KL45:0.953 ; KL111:...,0,KL102: 0.691,0,KL1
...,...,...,...,...,...,...,...,...
225,S13a,S13a_00036,right-handed beta-helix,KL60:0.97 ; KL46:0.96 ; KL27:0.95 ; KL23:0.922...,0,KL38: 0.822,0,"KL102,KL149"
226,S13b,S13b_00058,right-handed beta-helix,KL47:0.994 ; KL81:0.932 ; KL74:0.917 ; KL28:0....,0,KL63: 0.867,KL63,KL63
227,S13c,S13c_00055,right-handed beta-helix,KL12:0.992 ; KL38:0.98 ; KL128:0.974 ; KL27:0....,0,No hits,0,"KL102,KL149"
228,S13d,S13d_00057,right-handed beta-helix,KL14:0.986 ; KL128:0.968 ; KL46:0.965 ; KL15:0...,KL14,KL14: 0.736,KL14,KL14


In [77]:

total_calls = 0
TropiGAT_good_calls = []
TropiSeq_good_calls = []
top_n = 15

for phage in raw_df["Phage"].unique() : 
    tmp_df = raw_df[raw_df["Phage"] == phage]
    targets_set = set(tmp_df["Targets"].tolist()[0].split(","))
    n_targets = len(targets_set)
    if "".join((list(targets_set)[0])).count("pass") == 0 : 
        total_calls += n_targets
        # TropiGAT calls : 
        call_tropigat = set()
        for calls in tmp_df["TropiGAT_predictions"] :
            top_n_predictions = [x.split(":")[0] for x in calls.split(";")][0:top_n-1]
            call_tropigat.update(top_n_predictions)
        good_calls_tropigat = call_tropigat.intersection(targets_set)
        if len(good_calls_tropigat) > 0 :
            for kltype in good_calls_tropigat : 
                a = (phage , kltype)
                TropiGAT_good_calls.append(a)
        # TropiSeq calls : 
        call_tropiseq = set()
        for calls in tmp_df["TropiSeq_predictions"] :
            top_n_predictions = [x.split(":")[0] for x in calls.split(";")][0:top_n-1]
            call_tropiseq.update(top_n_predictions)
        good_calls_tropiseq = call_tropiseq.intersection(targets_set)
        if len(good_calls_tropiseq) > 0 :
            for kltype in good_calls_tropiseq : 
                a = (phage , kltype)
                TropiSeq_good_calls.append(a)
        




In [78]:
len(TropiGAT_good_calls) , len(TropiSeq_good_calls)


(22, 40)

In [79]:
len(TropiGAT_good_calls) , len(TropiSeq_good_calls)


tropigat_list = [call[1] for call in TropiGAT_good_calls]
tropiseq_list = [call[1] for call in TropiSeq_good_calls]

Counter(tropiseq_list)

TropiGAT_good_calls

[('D7c', 'KL28'),
 ('K15PH90', 'KL15'),
 ('K17alfa61', 'KL17'),
 ('K17alfa62', 'KL62'),
 ('K17alfa62', 'KL17'),
 ('K19PH14C4P1', 'KL19'),
 ('K24PH164C1', 'KL24'),
 ('K26PH128C1', 'KL74'),
 ('K27PH129C1', 'KL27'),
 ('K38PH09C2', 'KL38'),
 ('K43PH164C1', 'KL43'),
 ('K54lambda1_1_1', 'KL24'),
 ('K57lambda1_2', 'KL57'),
 ('K60PH164C1', 'KL18'),
 ('K60PH164C1', 'KL60'),
 ('K74PH129C2', 'KL74'),
 ('NBNDMPCG', 'KL2'),
 ('OPBIHMGG', 'KL3'),
 ('PFOEGONH', 'KL3'),
 ('PP187', 'KL110'),
 ('S11a', 'KL25'),
 ('S13d', 'KL14')]

In [80]:
all_calls = set(TropiGAT_good_calls).union(set(TropiSeq_good_calls))
len(all_calls)

50

In [81]:
total_calls

213

In [82]:
all_calls

{('A1a', 'KL151'),
 ('A1b', 'KL157'),
 ('A1g', 'KL16'),
 ('A1h', 'KL13'),
 ('A1h', 'KL2'),
 ('A3b', 'KL30'),
 ('D7b', 'KL140'),
 ('D7c', 'KL28'),
 ('K11PH164C1', 'KL57'),
 ('K15PH90', 'KL15'),
 ('K16PH164C3', 'KL16'),
 ('K17alfa61', 'KL17'),
 ('K17alfa62', 'KL17'),
 ('K17alfa62', 'KL62'),
 ('K19PH14C4P1', 'KL19'),
 ('K21lambda1', 'KL21'),
 ('K23PH08C2', 'KL23'),
 ('K24PH164C1', 'KL24'),
 ('K25PH129C1', 'KL25'),
 ('K26PH128C1', 'KL74'),
 ('K27PH129C1', 'KL27'),
 ('K2PH164C2', 'KL2'),
 ('K35PH164C3', 'KL35'),
 ('K38PH09C2', 'KL38'),
 ('K43PH164C1', 'KL43'),
 ('K4PH164', 'KL4'),
 ('K51PH129C1', 'KL51'),
 ('K54lambda1_1_1', 'KL24'),
 ('K57lambda1_2', 'KL57'),
 ('K58PH129C2', 'KL58'),
 ('K60PH164C1', 'KL18'),
 ('K60PH164C1', 'KL60'),
 ('K63PH128', 'KL63'),
 ('K64PH164C4', 'KL64'),
 ('K74PH129C2', 'KL74'),
 ('K9PH25C2', 'KL9'),
 ('KBDEFBCI', 'KL35'),
 ('NBNDMPCG', 'KL2'),
 ('NBNDMPCG', 'KL35'),
 ('NJHLHPIG', 'KL35'),
 ('OPBIHMGG', 'KL3'),
 ('P4a', 'KL140'),
 ('P4b', 'KL140'),
 ('PFOEGONH', '

In [83]:
50/213*100

23.474178403755868

In [None]:
# Adjust the number of correct calls : remove the KLtypes for which there are no models ?
# Do the analysis for the new dpos 
# Check the folds
# Repeat the predictions with the final tailored

# Fine tune the final tailored
# Repeat the predictions 


In [84]:
TropiGAT_good_calls

[('D7c', 'KL28'),
 ('K15PH90', 'KL15'),
 ('K17alfa61', 'KL17'),
 ('K17alfa62', 'KL62'),
 ('K17alfa62', 'KL17'),
 ('K19PH14C4P1', 'KL19'),
 ('K24PH164C1', 'KL24'),
 ('K26PH128C1', 'KL74'),
 ('K27PH129C1', 'KL27'),
 ('K38PH09C2', 'KL38'),
 ('K43PH164C1', 'KL43'),
 ('K54lambda1_1_1', 'KL24'),
 ('K57lambda1_2', 'KL57'),
 ('K60PH164C1', 'KL18'),
 ('K60PH164C1', 'KL60'),
 ('K74PH129C2', 'KL74'),
 ('NBNDMPCG', 'KL2'),
 ('OPBIHMGG', 'KL3'),
 ('PFOEGONH', 'KL3'),
 ('PP187', 'KL110'),
 ('S11a', 'KL25'),
 ('S13d', 'KL14')]

> At least one positive prediction for each phage : 

In [85]:
towndsend_df

Unnamed: 0,Phages,Protein,Target
0,BLCJPOBP,BLCJPOBP__00041,KL2
1,DIMCIIMF,DIMCIIMF__00240,KL28
2,DIMCIIMF,DIMCIIMF__00039,KL28
3,DJLANJJD,DJLANJJD__00238,pass (baseplate)
4,EHPPICDA,EHPPICDA__00095,pass (baseplate)
5,EKPIEFBL,EKPIEFBL__00177,pass (baseplate)
6,EKPIEFBL,EKPIEFBL__00113,pass (baseplate)
7,EONHMLJF,EONHMLJF__00087,pass (KL107 no confidence)
8,FADJDIKG,FADJDIKG__00083,pass (baseplate)
9,FADJDIKG,FADJDIKG__00019,pass (baseplate)


In [86]:
raw_df

Unnamed: 0,Phage,Protein,Folds,TropiGAT_predictions,TropiGAT_good_calls,TropiSeq_predictions,TropiSeq_good_calls,Targets
0,A1a,A1a_00002,6-bladed beta-propeller,KL111:0.983 ; KL123:0.982 ; KL45:0.973 ; KL24:...,0,KL102: 0.691,0,KL151
1,A1a,A1a_00014,right-handed beta-helix,KL128:0.987 ; KL29:0.979 ; KL70:0.958 ; KL24:0...,0,KL151: 0.698,KL151,KL151
2,A1b,A1b_00048,right-handed beta-helix,KL46:0.994 ; KL128:0.991 ; KL149:0.951 ; KL74:...,0,KL157: 0.729,KL157,KL157
3,A1b,A1b_00036,6-bladed beta-propeller,KL123:0.998 ; KL111:0.983 ; KL128:0.982 ; KL45...,0,KL102: 0.691,0,KL157
4,A1c,A1c_00046,6-bladed beta-propeller,KL123:0.994 ; KL24:0.982 ; KL45:0.953 ; KL111:...,0,KL102: 0.691,0,KL1
...,...,...,...,...,...,...,...,...
225,S13a,S13a_00036,right-handed beta-helix,KL60:0.97 ; KL46:0.96 ; KL27:0.95 ; KL23:0.922...,0,KL38: 0.822,0,"KL102,KL149"
226,S13b,S13b_00058,right-handed beta-helix,KL47:0.994 ; KL81:0.932 ; KL74:0.917 ; KL28:0....,0,KL63: 0.867,KL63,KL63
227,S13c,S13c_00055,right-handed beta-helix,KL12:0.992 ; KL38:0.98 ; KL128:0.974 ; KL27:0....,0,No hits,0,"KL102,KL149"
228,S13d,S13d_00057,right-handed beta-helix,KL14:0.986 ; KL128:0.968 ; KL46:0.965 ; KL15:0...,KL14,KL14: 0.736,KL14,KL14


In [87]:
total_calls = 0
TropiGAT_good_calls = []
TropiSeq_good_calls = []
top_n = 15

for phage in raw_df["Phage"].unique() : 
    if phage in ferriol_df["Phages"].tolist() or towndsend_df["Phages"].tolist():
        kltype = "ok"
        tmp_df = raw_df[raw_df["Phage"] == phage]
        targets_set = set(tmp_df["Targets"].tolist()[0].split(","))
        n_targets = len(targets_set)
        if "".join((list(targets_set)[0])).count("pass") == 0 : 
            total_calls += 1
            # TropiGAT calls : 
            call_tropigat = set()
            for calls in tmp_df["TropiGAT_predictions"] :
                top_n_predictions = [x.split(":")[0] for x in calls.split(";")][0:top_n-1]
                call_tropigat.update(top_n_predictions)
            good_calls_tropigat = call_tropigat.intersection(targets_set)
            if len(good_calls_tropigat) > 0 :
                a = (phage , kltype)
                TropiGAT_good_calls.append(a)
            # TropiSeq calls : 
            call_tropiseq = set()
            for calls in tmp_df["TropiSeq_predictions"] :
                top_n_predictions = [x.split(":")[0] for x in calls.split(";")][0:top_n-1]
                call_tropiseq.update(top_n_predictions)
            good_calls_tropiseq = call_tropiseq.intersection(targets_set)
            if len(good_calls_tropiseq) > 0 :
                a = (phage , kltype)
                TropiSeq_good_calls.append(a)

In [88]:
len(TropiGAT_good_calls) , len(TropiSeq_good_calls)

(20, 38)

In [89]:
total_calls

112

In [90]:
all_calls = set(TropiGAT_good_calls).union(set(TropiSeq_good_calls))
len(all_calls)

45

In [91]:
TropiSeq_good_calls

[('A1a', 'ok'),
 ('A1b', 'ok'),
 ('A1g', 'ok'),
 ('A1h', 'ok'),
 ('A3b', 'ok'),
 ('D7b', 'ok'),
 ('D7c', 'ok'),
 ('K11PH164C1', 'ok'),
 ('K15PH90', 'ok'),
 ('K16PH164C3', 'ok'),
 ('K17alfa62', 'ok'),
 ('K21lambda1', 'ok'),
 ('K23PH08C2', 'ok'),
 ('K24PH164C1', 'ok'),
 ('K25PH129C1', 'ok'),
 ('K27PH129C1', 'ok'),
 ('K2PH164C2', 'ok'),
 ('K35PH164C3', 'ok'),
 ('K4PH164', 'ok'),
 ('K51PH129C1', 'ok'),
 ('K54lambda1_1_1', 'ok'),
 ('K57lambda1_2', 'ok'),
 ('K58PH129C2', 'ok'),
 ('K63PH128', 'ok'),
 ('K64PH164C4', 'ok'),
 ('K9PH25C2', 'ok'),
 ('KBDEFBCI', 'ok'),
 ('NBNDMPCG', 'ok'),
 ('NJHLHPIG', 'ok'),
 ('OPBIHMGG', 'ok'),
 ('P4a', 'ok'),
 ('P4b', 'ok'),
 ('PFOEGONH', 'ok'),
 ('PP187', 'ok'),
 ('S10a', 'ok'),
 ('S11a', 'ok'),
 ('S13b', 'ok'),
 ('S13d', 'ok')]