In [17]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero , HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

import TropiGAT_functions 
#from TropiGAT_functions import get_top_n_kltypes ,clean_print 

import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging
from multiprocessing.pool import ThreadPool
warnings.filterwarnings("ignore")

# *****************************************************************************
# Paths to the trained TropiGAT ensembles (standard and hyper-parameter optimized).
path_work = "/media/concha-eloko/Linux/PPT_clean"
path_ensemble = path_work + "/ficheros_28032023/winning_ensemble_1202"
path_ensemble_opt = path_work + "/ficheros_28032023/winning_ensemble_12092024_optimized"
# Earlier ensemble, kept for reference: path_work + "/ficheros_28032023/ensemble_tailored_0612"

# Standard ensemble: one entry per file in path_ensemble, keyed by the file name up to
# its first "." (presumably the KL-type), each model configured with 5 attention heads.
dico_regular_head = dict(
    (model_file.split(".")[0], {"para_heads": 5})
    for model_file in os.listdir(path_ensemble)
)

In [2]:
# Optimized hyper-parameters per KL-type model, one row per model:
#   (KL-type, attention heads, learning rate, weight decay, dropout)
dico_regular_head_opt = {
    kltype: dict(zip(("para_heads", "para_lr", "para_wd", "para_dropout"), params))
    for kltype, *params in [
        ("KL64", 5, 0.000246, 0.000080, 0.063113),
        ("KL1", 5, 0.0009415397708661039, 1.132790862878068e-06, 0.007657626670776924),
        ("KL10", 2, 0.0006633594884735811, 3.7430295738223034e-06, 0.4493747213067273),
        ("KL15", 5, 0.00017766142057218653, 5.245213610566463e-05, 0.15214512795626994),
        ("KL17", 2, 0.0002068133316219641, 5.303964308479191e-05, 0.4810681327179018),
        ("KL19", 5, 0.00028386856144729176, 6.667568504410857e-07, 0.4460345479421262),
        ("KL2", 2, 0.0006115983973072073, 3.521041854903662e-06, 0.16320044607028428),
        ("KL47", 2, 0.0007352151826846244, 8.666317429082471e-06, 0.1877399746783721),
        ("KL74", 1, 0.0004137122657073261, 3.5238343953806846e-05, 0.3464829958840639),
    ]
}


In [3]:
# Sanity check: the number of optimized hyper-parameter sets should match the number of files in the optimized ensemble.
(len(dico_regular_head_opt), len(os.listdir(path_ensemble_opt)))

(9, 9)

> Build the models (standard and optimized ensembles):

In [18]:
# Build both ensembles. Each call returns a pair, presumably (models, errors).
# NOTE(review): errors_reg / errors_opt are not inspected anywhere below — check they are empty.
dico_models, errors_reg = TropiGAT_functions.make_ensemble_TropiGAT_optimized(
    path_ensemble, dico_regular_head
)
dico_models_opt, errors_opt = TropiGAT_functions.make_ensemble_TropiGAT_optimized(
    path_ensemble_opt, dico_regular_head_opt
)

> Ferriol

In [8]:
import os
import pandas as pd

# Ferriol dataset: ESM-2 embeddings of the depolymerase domains (77-phage project),
# first CSV column used as the row index (domain id).
path_project = "/media/concha-eloko/Linux/77_strains_phage_project"
path_Dpo_domain_org = "/media/concha-eloko/Linux/depolymerase_building/clean_77_phages_depo"

dpo_embeddings = pd.read_csv(
    path_project + "/rbp_work/Dpo_domains_77.esm2.embedding.2406.csv",
    sep=",",
    header=None,
    index_col=0,
)


> Bea

In [12]:
import os
import pandas as pd

path_project = "/media/concha-eloko/Linux/PPT_clean/in_vitro"

# Bea dataset: ESM-2 embeddings, protein id in column 0 (used as index).
# Column 1281 is dropped — presumably an empty trailing column; TODO confirm.
bea_embeddings = (
    pd.read_csv(path_project + "/Bea_phages.esm2.embedding.csv", sep=",", header=None)
    .drop(columns=[1281])
    .set_index([0])
)


> Townsend

In [11]:
import os
import pandas as pd

path_project = "/media/concha-eloko/Linux/PPT_clean/in_vitro"

# Townsend dataset: ESM-2 embeddings, protein id in column 0 (used as index).
# Column 1281 is dropped — presumably an empty trailing column; TODO confirm.
towndsend_embeddings = (
    pd.read_csv(path_project + "/Townsed_phages.esm2.embedding.1112.csv", sep=",", header=None)
    .drop(columns=[1281])
    .set_index([0])
)


***
> Run the predictions

Standard

In [20]:
# Run the predictions ferriol :
# one TropiGAT prediction per depolymerase domain; a failing domain is reported
# (id + error) and skipped so the remaining domains still get predicted.
ferriol_predictions = {}
for dpo in dpo_embeddings.index:
    try:
        graph_dpo = TropiGAT_functions.make_query_graph([dpo_embeddings.loc[dpo].values])
        ferriol_predictions[dpo] = TropiGAT_functions.run_prediction(graph_dpo, dico_models)
    except Exception as e:
        print(dpo, e)

In [21]:
# Inspect the per-KL-type scores returned for a single Ferriol depolymerase domain
ferriol_predictions["K17alfa62__cds_66_Dpo_domain"]

{'KL12': 0.7018,
 'KL142': 0.8005,
 'KL9': 0.9277,
 'KL169': 0.5998,
 'KL128': 0.9998,
 'KL46': 0.8162,
 'KL36': 0.7505,
 'KL34': 0.9996,
 'KL52': 0.9946,
 'KL62': 1.0,
 'KL109': 0.9145,
 'KL14': 0.7897}

In [23]:
# Display all Ferriol predictions: {domain id: {KL-type: score}}
ferriol_predictions

{'K15PH90__cds_54_Dpo_domain': {'KL124': 0.9131,
  'KL6': 0.9956,
  'KL57': 1.0,
  'KL74': 0.9844,
  'KL1': 0.9777,
  'KL45': 0.9611,
  'KL125': 0.9949,
  'KL43': 0.9348,
  'KL103': 1.0,
  'KL38': 0.9987,
  'KL116': 0.9998,
  'KL8': 0.9999,
  'KL145': 0.918,
  'KL39': 0.9242,
  'KL149': 0.9844,
  'KL112': 0.7082,
  'KL63': 0.936,
  'KL60': 0.9993,
  'KL155': 0.9997,
  'KL106': 1.0,
  'KL21': 0.7232,
  'KL153': 0.79,
  'KL31': 0.8904,
  'KL109': 0.9813,
  'KL14': 0.9096},
 'K7PH164C4__cds_20_Dpo_domain': {'KL111': 0.5666,
  'KL30': 0.8581,
  'KL142': 0.9961,
  'KL9': 0.9396,
  'KL124': 0.9993,
  'KL57': 1.0,
  'KL74': 1.0,
  'KL47': 0.9999,
  'KL117': 0.9352,
  'KL108': 0.9996,
  'KL125': 1.0,
  'KL128': 0.9663,
  'KL43': 0.8053,
  'KL136': 0.6705,
  'KL123': 0.9995,
  'KL46': 0.9962,
  'KL36': 1.0,
  'KL103': 1.0,
  'KL116': 1.0,
  'KL8': 0.9977,
  'KL34': 0.5395,
  'KL149': 1.0,
  'KL71': 0.9957,
  'KL2': 0.9926,
  'KL112': 0.7514,
  'KL63': 0.9991,
  'KL155': 0.5407,
  'KL106': 1.0,


In [22]:
# Format the results: regroup the per-domain predictions by phage
# (Ferriol ids look like "<phage>__<cds>_Dpo_domain", hence sep="__").
ferriol_pred_formated = TropiGAT_functions.format_predictions(
    ferriol_predictions,
    sep="__",
)
TropiGAT_functions.clean_print(ferriol_pred_formated)

{'K10PH82C1': {'KL102': 0.9998,
               'KL103': 0.9136,
               'KL108': 1.0,
               'KL109': 0.992,
               'KL110': 0.8868,
               'KL111': 1.0,
               'KL112': 0.9991,
               'KL114': 0.9142,
               'KL116': 0.9976,
               'KL117': 0.9985,
               'KL118': 0.5297,
               'KL12': 0.9999,
               'KL122': 0.9986,
               'KL123': 0.9997,
               'KL124': 0.9996,
               'KL125': 0.9548,
               'KL127': 1.0,
               'KL128': 1.0,
               'KL13': 0.9998,
               'KL136': 0.9614,
               'KL14': 0.9971,
               'KL140': 0.9744,
               'KL142': 0.7919,
               'KL149': 0.9941,
               'KL15': 0.6394,
               'KL151': 0.9449,
               'KL153': 0.9983,
               'KL16': 0.9963,
               'KL169': 0.9862,
               'KL17': 0.971,
               'KL18': 0.9991,
               'KL19': 0.9939

Optimized

***

In [24]:
# Run the predictions Bea :
# one TropiGAT prediction per protein with the standard ensemble; there is no
# error handling here, so a failing protein stops the loop.
bea_predictions = {}
for dpo in bea_embeddings.index:
    graph_dpo = TropiGAT_functions.make_query_graph([bea_embeddings.loc[dpo].values])
    bea_predictions[dpo] = TropiGAT_functions.run_prediction(graph_dpo, dico_models)

In [None]:
# Format the results: regroup the Bea predictions by phage (id prefix before "_").
bea_pred_formated = TropiGAT_functions.format_predictions(
    bea_predictions,
    sep="_",
)
TropiGAT_functions.clean_print(bea_pred_formated)

In [25]:
# Run the predictions Towndsend :
# one TropiGAT prediction per protein with the standard ensemble; there is no
# error handling here, so a failing protein stops the loop.
towndsend_predictions = {}
for dpo in towndsend_embeddings.index:
    graph_dpo = TropiGAT_functions.make_query_graph([towndsend_embeddings.loc[dpo].values])
    towndsend_predictions[dpo] = TropiGAT_functions.run_prediction(graph_dpo, dico_models)

In [None]:
# Format the results: regroup the Townsend predictions by phage (id prefix before "_").
towndsend_pred_formated = TropiGAT_functions.format_predictions(
    towndsend_predictions,
    sep="_",
)
TropiGAT_functions.clean_print(towndsend_pred_formated)

***

In [49]:
import os
import pandas as pd

path_project = "/media/concha-eloko/Linux/PPT_clean/in_vitro"

# "Others" proteins: ESM-2 embeddings with the protein id in column 0 (used as index).
# Unlike the Bea/Townsend files, no trailing column is dropped here.
others_embeddings = pd.read_csv(
    path_project + "/Others_all.esm2.embedding.csv", sep=",", header=None
).set_index([0])

In [50]:
# Summary of the "Others" embedding table (row count, columns, dtypes)
others_embeddings.info()

<class 'pandas.core.frame.DataFrame'>
Index: 56 entries, NC_025418.1_prot_YP_009098385.1_34 to ON146449.1_prot_UPW35138.1_1
Columns: 1280 entries, 1 to 1280
dtypes: float64(1280)
memory usage: 560.4+ KB


In [51]:
# Protein ids available for prediction
others_embeddings.index

Index(['NC_025418.1_prot_YP_009098385.1_34',
       'NC_029099.1_prot_YP_009226011.1_50', 'MT966873.1_prot_QOV05502.1_43',
       'NC_013649.2_prot_YP_003347651.1_57', 'AB716666.1_prot_BAP15736.1_24',
       'MK903728.1_prot_QDF14645.1_43', 'NC_031246.1_prot_YP_009302756.1_52',
       'MZ826764.1_prot_UCR74083.1_31', 'MW655991.1_prot_QUU29414.1_2',
       'AB897757.1_prot_BAQ02839.1_59', 'NC_031246.1_prot_YP_009302745.1_41',
       'AB897757.1_prot_BAQ02841.1_61', 'MZ826764.1_prot_UCR74085.1_33',
       'NC_025418.1_prot_YP_009098375.1_24', 'AB897757.1_prot_BAQ02844.1_64',
       'MK903728.1_prot_QDF14639.1_37', 'AB897757.1_prot_BAQ02843.1_63',
       'AB897757.1_prot_BAQ02838.1_58', 'MT542697.1_prot_QKY78353.1_44',
       'MZ826764.1_prot_UCR74084.1_32', 'ON146449.1_prot_UPW35150.1_13',
       'MK903728.1_prot_QDF14644.1_42', 'AB716666.1_prot_BAP15746.1_34',
       'MT542697.1_prot_QKY78347.1_38', 'MZ826764.1_prot_UCR74082.1_30',
       'MT966873.1_prot_QOV05496.1_37', 'MN781108.1_pro

In [52]:
# Run the predictions Others :
# one TropiGAT prediction per protein with the standard ensemble; failures are
# reported (with the exception type) and skipped.
other_predictions = {}
for dpo in others_embeddings.index:
    # Excluded deliberately in the original analysis (reason not recorded here).
    if dpo in ("MN781108.1_prot_QGZ15323.1_262",):
        continue
    try:
        graph_dpo = TropiGAT_functions.make_query_graph([others_embeddings.loc[dpo].values])
        pred = TropiGAT_functions.run_prediction(graph_dpo, dico_models)
        other_predictions[dpo] = pred
    except Exception as e:
        # Print the exception type too: some failures carry an empty message.
        # NOTE(review): the failing ids were printed twice in a previous run, which
        # suggests they are duplicated in the index (loc then returns several rows) — confirm.
        print(f"{type(e).__name__}: {e}", dpo)

 ON146449.1_prot_UPW35150.1_13
 ON146449.1_prot_UPW35138.1_1
 ON146449.1_prot_UPW35150.1_13
 ON146449.1_prot_UPW35138.1_1


In [53]:
# Display the raw "Others" predictions: {protein id: {KL-type: score}}
other_predictions

{'NC_025418.1_prot_YP_009098385.1_34': {'KL142': 0.5892,
  'KL19': 0.9123,
  'KL6': 0.5754,
  'KL29': 0.8795,
  'KL43': 0.9364,
  'KL145': 0.6923,
  'KL140': 0.992,
  'KL34': 0.9997,
  'KL13': 0.8376,
  'KL52': 0.9813,
  'KL21': 0.7839,
  'KL62': 0.9994,
  'KL109': 0.897,
  'KL22': 0.5955,
  'KL14': 0.9941},
 'NC_029099.1_prot_YP_009226011.1_50': {'KL23': 0.522,
  'KL142': 0.9751,
  'KL9': 0.7554,
  'KL15': 0.935,
  'KL19': 0.9125,
  'KL6': 0.7027,
  'KL18': 0.8889,
  'KL57': 0.8899,
  'KL74': 0.7166,
  'KL157': 0.907,
  'KL47': 0.9537,
  'KL108': 0.9897,
  'KL43': 0.9706,
  'KL36': 0.8661,
  'KL103': 0.9961,
  'KL28': 0.9786,
  'KL8': 0.9747,
  'KL34': 0.9999,
  'KL149': 0.9734,
  'KL71': 0.6871,
  'KL63': 0.9063,
  'KL155': 0.9998,
  'KL70': 0.8962,
  'KL21': 0.9459,
  'KL26': 0.9751,
  'KL109': 0.9989},
 'MT966873.1_prot_QOV05502.1_43': {'KL20': 0.6084,
  'KL6': 0.8573,
  'KL18': 0.8924,
  'KL57': 0.9902,
  'KL29': 0.7542,
  'KL74': 0.7665,
  'KL1': 0.9778,
  'KL157': 0.9305,
  'KL4

In [None]:
# Format the results: regroup the "Others" predictions by genome accession
# (ids look like "<accession>_prot_<protein>", hence sep="_prot_"), then print
# the formatted dict — consistent with the Ferriol/Bea/Townsend cells.
others_pred_formated = TropiGAT_functions.format_predictions(other_predictions, sep="_prot_")
TropiGAT_functions.clean_print(others_pred_formated)

***
# Write the results : 

> Others : 

In [54]:
# Write one TSV line per protein, sorted by protein id:
#   "<protein>\t<KL:score ; KL:score ...>" with hits ordered by decreasing
#   rounded score, or "<protein>\tNo hits" when the prediction is empty.
predictions = [other_predictions]

with open("/media/concha-eloko/Linux/PPT_clean/TropiGAT.Others.all.results.standard.tsv", "w") as outfile:
    for prediction in predictions:
        prediction_sorted = dict(sorted(prediction.items()))
        for prot, kl_scores in prediction_sorted.items():
            if kl_scores != "No hits" and len(kl_scores) > 0:
                hits = [f"{kltype}:{round(score, 3)}" for kltype, score in kl_scores.items()]
                hits.sort(key=lambda hit: float(hit.split(":")[1]), reverse=True)
                outfile.write(f"{prot}\t" + " ; ".join(hits) + "\n")
            else:
                outfile.write(f"{prot}\tNo hits\n")

In [55]:
import os
import pandas as pd

path_project = "/media/concha-eloko/Linux/PPT_clean"

# Standard TropiGAT and sequence-based (bit75) predictions for the "Others" proteins,
# both two-column TSVs without a header.
tropigat_others = pd.read_csv(
    path_project + "/TropiGAT.Others.all.results.standard.tsv", sep="\t", names=["protein", "predicitons"]
)
tropiseq_others = pd.read_csv(
    path_project + "/Seqbased_model.results.bit75.2406.Others.tsv", sep="\t", names=["protein", "predicitons"]
)
merged_df = tropigat_others.merge(tropiseq_others, on="protein", how="inner")

# Reference table with the expected target KL-type and phage name of each protein.
# Entries such as "<id> (<extra>)" are reduced to the text before the first space.
info_others_df = pd.read_csv(path_project + "/in_vitro/other_naming_KL.tsv", sep="\t", header=0)
info_others_df["protein"] = info_others_df["Proteins"].apply(
    lambda name: name.split(" ")[0] if "(" in name else name
)


In [56]:
# Keep only proteins present in both files; the default suffixes mark the
# TropiGAT (_x) and sequence-based (_y) prediction columns.
merged_df = tropigat_others.merge(tropiseq_others, how="inner", on="protein")
merged_df_sorted = merged_df.sort_values("protein")


In [57]:
# Attach the expected target KL-type and the phage name of each protein, taken
# from the first matching row of info_others_df (a protein missing there raises IndexError).
for new_col, info_col in (("Target", "Target KLtype"), ("Phage", "Phage")):
    merged_df_sorted[new_col] = merged_df_sorted["protein"].apply(
        lambda x, col=info_col: info_others_df[info_others_df["protein"] == x][col].values[0]
    )

merged_df_sorted.to_csv(path_project + "/Other_predictions.standard.raw.tsv", sep="\t", header=True, index=False)

In [58]:
# Display the merged table (predicitons_x = TropiGAT, predicitons_y = sequence-based)
merged_df_sorted

Unnamed: 0,protein,predicitons_x,predicitons_y,Target,Phage
0,AB716666.1_prot_BAP15736.1_24,KL56:1.0 ; KL4:1.0 ; KL123:1.0 ; KL7:1.0 ; KL1...,KL102:0.737,K1,Klebsiella phage NTUH-K2044-K1-1
1,AB716666.1_prot_BAP15746.1_34,KL34:1.0 ; KL62:0.999 ; KL52:0.994 ; KL140:0.9...,No_hits,K1,Klebsiella phage NTUH-K2044-K1-1
2,AB897757.1_prot_BAQ02835.1_55,KL52:0.998 ; KL145:0.997 ; KL70:0.992 ; KL12:0...,No_associations,K11*,Klebsiella Phage ΦK64-1
3,AB897757.1_prot_BAQ02836.1_56,KL128:1.0 ; KL12:0.998 ; KL70:0.997 ; KL47:0.9...,KL70:0.888 ; KL21:0.674,KN4†,Klebsiella Phage ΦK64-1
4,AB897757.1_prot_BAQ02837.1_57,KL128:1.0 ; KL26:1.0 ; KL151:0.998 ; KL103:0.9...,KL123:0.711,K21*,Klebsiella Phage ΦK64-1
5,AB897757.1_prot_BAQ02838.1_58,KL128:1.0 ; KL151:0.999 ; KL109:0.999 ; KL26:0...,No_hits,KN5*,Klebsiella Phage ΦK64-1
6,AB897757.1_prot_BAQ02839.1_59,KL6:0.997 ; KL26:0.995 ; KL29:0.992 ; KL8:0.99...,No_associations,K25†,Klebsiella Phage ΦK64-1
7,AB897757.1_prot_BAQ02840.1_60,KL128:1.0 ; KL36:1.0 ; KL26:1.0 ; KL151:0.999 ...,KL35:0.66,K35†,Klebsiella Phage ΦK64-1
8,AB897757.1_prot_BAQ02841.1_61,KL81:0.999 ; KL63:0.998 ; KL14:0.996 ; KL109:0...,No_hits,K1*,Klebsiella Phage ΦK64-1
9,AB897757.1_prot_BAQ02842.1_62,KL128:1.0 ; KL103:1.0 ; KL64:1.0 ; KL155:1.0 ;...,KL64:0.904,K64*,Klebsiella Phage ΦK64-1


> Predictions : 

In [26]:
# Write the Ferriol, Bea and Townsend predictions (in that order, insertion order
# within each) as "<protein>\t<KL:score ; ...>" lines, hits ordered by decreasing
# rounded score, or "<protein>\tNo hits" when the prediction is empty.
predictions = [ferriol_predictions, bea_predictions, towndsend_predictions]

with open("/media/concha-eloko/Linux/PPT_clean/TropiGAT.results.standard.1309.tsv", "w") as outfile:
    for prediction in predictions:
        for prot, kl_scores in prediction.items():
            if kl_scores != "No hits" and len(kl_scores) > 0:
                hits = [f"{kltype}:{round(score, 3)}" for kltype, score in kl_scores.items()]
                hits.sort(key=lambda hit: float(hit.split(":")[1]), reverse=True)
                outfile.write(f"{prot}\t" + " ; ".join(hits) + "\n")
            else:
                outfile.write(f"{prot}\tNo hits\n")

> Parse the TropiGAT and sequence-based (Seqbased) results:

In [38]:
import os
import pandas as pd

path_project = "/media/concha-eloko/Linux/PPT_clean"

# Standard TropiGAT predictions written above (one "<protein>\t<hits>" line each).
tropigat_results = pd.read_csv(
    path_project + "/TropiGAT.results.standard.1309.tsv", sep="\t", names=["protein", "predictions_tropigat"]
)

# Sequence-based predictions (file tagged "bit75"). Other runs available:
#   Seqbased_model.results.bit50.0101.tsv
#   Seqbased_model.0101.results.tsv
#   Seqbased_model.1001.results.tsv
seqbased_results = pd.read_csv(
    path_project + "/Seqbased_model.results.bit75.2406.tsv", sep="\t", names=["protein", "predictions_seqbased"]
)


In [39]:
# Display the sequence-based predictions ("No_hits" / "No_associations" when no KL-type is called)
seqbased_results

Unnamed: 0,protein,predictions_seqbased
0,K10PH82C1__cds_49,No_hits
1,K13PH07C1L__cds_10,KL3:0.511
2,K13PH07C1L__cds_11,KL13:0.527
3,K13PH07C1L__cds_12,No_associations
4,K15PH90__cds_54,No_hits
...,...,...
255,A2a_b_00022,No_associations
256,A2a_b_00036,KL102:0.737
257,A1i_00037,KL102:0.737
258,A1i_00041,KL48:0.568


In [31]:
# Display the standard TropiGAT predictions
tropigat_results

Unnamed: 0,protein,predictions_tropigat
0,K15PH90__cds_54_Dpo_domain,KL57:1.0 ; KL103:1.0 ; KL116:1.0 ; KL8:1.0 ; K...
1,K7PH164C4__cds_20_Dpo_domain,KL57:1.0 ; KL74:1.0 ; KL47:1.0 ; KL108:1.0 ; K...
2,K32PH164C1__cds_20_Dpo_domain,KL74:1.0 ; KL46:1.0 ; KL36:1.0 ; KL103:1.0 ; K...
3,K18PH07C1__cds_245_Dpo_domain,KL63:1.0 ; KL6:0.999 ; KL52:0.999 ; KL39:0.994...
4,K13PH07C1L__cds_11_Dpo_domain,KL36:1.0 ; KL13:1.0 ; KL31:0.999 ; KL46:0.992 ...
...,...,...
254,NBNDMPCG_00163,KL36:0.999 ; KL2:0.993 ; KL71:0.99 ; KL18:0.98...
255,NJHLHPIG_00061,KL46:1.0 ; KL18:0.999 ; KL128:0.999 ; KL67:0.9...
256,HIIECEMK_00054,KL45:0.999 ; KL6:0.996 ; KL2:0.996 ; KL71:0.99...
257,PP187_gp237,KL103:1.0 ; KL63:1.0 ; KL109:1.0 ; KL71:0.998 ...


In [40]:
# Harmonise protein ids between the two result tables before joining:
#  - TropiGAT ids lose their "_Dpo..." domain suffix;
#  - sequence-based ids lose any "_A..." suffix, or, when they contain a comma,
#    keep the text before it with every space replaced by "__".
tropigat_results["protein_id"] = tropigat_results["protein"].map(lambda name: name.split("_Dpo")[0])
seqbased_results["protein_id"] = seqbased_results["protein"].map(
    lambda name: name.split("_A")[0] if "_A" in name
    else (name.split(",")[0].replace(" ", "__") if "," in name else name)
)

merged_df = tropigat_results.merge(seqbased_results, how="inner", on="protein_id")
# Phage name: text before "__" when present, otherwise text before the first "_".
merged_df["phage"] = merged_df["protein_id"].map(
    lambda pid: pid.split("__")[0] if "__" in pid else pid.split("_")[0]
)

merged_df_sorted = merged_df.sort_values("phage")
merged_df_sorted

Unnamed: 0,protein_x,predictions_tropigat,protein_id,protein_y,predictions_seqbased,phage
129,A1a_00014,KL70:0.997 ; KL117:0.995 ; KL34:0.986 ; KL31:0...,A1a_00014,A1a_00014,KL151:0.599,A1a
144,A1a_00002,KL4:1.0 ; KL123:1.0 ; KL14:1.0 ; KL56:0.999 ; ...,A1a_00002,A1a_00002,KL102:0.737,A1a
130,A1b_00048,KL149:1.0 ; KL155:1.0 ; KL128:0.999 ; KL28:0.9...,A1b_00048,A1b_00048,No_associations,A1b
122,A1b_00036,KL56:1.0 ; KL4:1.0 ; KL123:1.0 ; KL7:1.0 ; KL1...,A1b_00036,A1b_00036,KL102:0.737,A1b
153,A1c_00046,KL56:1.0 ; KL4:1.0 ; KL123:1.0 ; KL7:1.0 ; KL1...,A1c_00046,A1c_00046,KL102:0.737,A1c
...,...,...,...,...,...,...
152,S13a_00036,KL70:1.0 ; KL12:0.997 ; KL149:0.995 ; KL56:0.9...,S13a_00036,S13a_00036,No_associations,S13a
114,S13b_00058,KL47:1.0 ; KL103:1.0 ; KL34:1.0 ; KL149:1.0 ; ...,S13b_00058,S13b_00058,KL63:0.641,S13b
146,S13c_00055,KL70:1.0 ; KL38:0.996 ; KL12:0.993 ; KL56:0.99...,S13c_00055,S13c_00055,No_associations,S13c
145,S13d_00057,KL155:1.0 ; KL14:0.998 ; KL149:0.991 ; KL64:0....,S13d_00057,S13d_00057,KL14:0.951,S13d


In [41]:
# Keep one row per protein with both model outputs side by side.
final_df = merged_df_sorted.loc[:, ["phage", "protein_id", "predictions_seqbased", "predictions_tropigat"]]



In [42]:
# Persist the combined TropiGAT / sequence-based table (tab-separated, with header, no index).
final_df.to_csv(path_project + "/PPT_results.standard.1309.bit75.tsv", index=False, header=True, sep="\t")

In [43]:
# Display the final combined table
final_df

Unnamed: 0,phage,protein_id,predictions_seqbased,predictions_tropigat
129,A1a,A1a_00014,KL151:0.599,KL70:0.997 ; KL117:0.995 ; KL34:0.986 ; KL31:0...
144,A1a,A1a_00002,KL102:0.737,KL4:1.0 ; KL123:1.0 ; KL14:1.0 ; KL56:0.999 ; ...
130,A1b,A1b_00048,No_associations,KL149:1.0 ; KL155:1.0 ; KL128:0.999 ; KL28:0.9...
122,A1b,A1b_00036,KL102:0.737,KL56:1.0 ; KL4:1.0 ; KL123:1.0 ; KL7:1.0 ; KL1...
153,A1c,A1c_00046,KL102:0.737,KL56:1.0 ; KL4:1.0 ; KL123:1.0 ; KL7:1.0 ; KL1...
...,...,...,...,...
152,S13a,S13a_00036,No_associations,KL70:1.0 ; KL12:0.997 ; KL149:0.995 ; KL56:0.9...
114,S13b,S13b_00058,KL63:0.641,KL47:1.0 ; KL103:1.0 ; KL34:1.0 ; KL149:1.0 ; ...
146,S13c,S13c_00055,No_associations,KL70:1.0 ; KL38:0.996 ; KL12:0.993 ; KL56:0.99...
145,S13d,S13d_00057,KL14:0.951,KL155:1.0 ; KL14:0.998 ; KL149:0.991 ; KL64:0....


> Make the raw per-protein metrics file (predictions vs. experimental targets):

In [46]:
# Fold assignment of each depolymerase (columns used below: "protein_id", "Fold").
df_folds = pd.read_csv(path_project + "/in_vitro/dpos_folds.all_matrices.tsv", sep="\t", header=0)


In [47]:
path_finetuning = "/media/concha-eloko/Linux/PPT_clean/in_vitro/fine_tuning"

# Experimental host-range tables (columns used: "Protein", "Phages", "Target";
# "Target" holds comma-separated KL-types). Each pool is the set of KL-types that
# author tested, excluding entries mentioning "wzi" or "pass".
bea_df = pd.read_csv(f"{path_finetuning}/bea_fine_tuning.df", sep = "\t", header = 0)
bea_df["Protein"] = bea_df["Protein"].apply(lambda x : x.replace("_", "__"))
pool_bea = set([kltype.strip() for kltypes in bea_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

ferriol_df = pd.read_csv(f"{path_finetuning}/ferriol_fine_tuning.df", sep = "\t", header = 0)
# NOTE(review): assumes Ferriol targets are written "K<n>"; a value already
# written "KL<n>" would become "KLL<n>" — confirm.
ferriol_df["Target"] = ferriol_df["Target"].apply(lambda x : x.replace("K", "KL"))
pool_ferriol = set([kltype.strip() for kltypes in ferriol_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

towndsend_df = pd.read_csv(f"{path_finetuning}/towndsend_fine_tuning.df", sep = "\t", header = 0)
towndsend_df["Protein"] = towndsend_df["Protein"].apply(lambda x : x.replace("_", "__"))
pool_towndsend = set([kltype.strip() for kltypes in towndsend_df["Target"] for kltype in kltypes.split(",") if kltype.count("wzi") == 0 if kltype.count("pass") == 0])

dico_matrices = {"ferriol" : {"matrix" : ferriol_df, "pool" : pool_ferriol},
                 "bea" : {"matrix": bea_df, "pool" : pool_bea},
                 "towndsend" : {"matrix" : towndsend_df, "pool" : pool_towndsend}}

# targets dico : phage -> union of every target listed for that phage in its matrix.
# Each phage is processed once (no per-row outer loop, which only recomputed the
# same result for every row). Unlike the pools, "wzi"/"pass" entries are kept here.
# If a phage appears in several matrices, the later author overwrites the earlier one.
dico_hits = {}
for author in dico_matrices :
    matrix = dico_matrices[author]["matrix"]
    for phage in matrix["Phages"].unique() :
        all_targets = set()
        for calls in matrix.loc[matrix["Phages"] == phage, "Target"].values :
            all_targets.update(x.strip() for x in calls.split(","))
        dico_hits[phage] = all_targets



In [48]:
top_n = 40  # number of top-ranked KL-type predictions scored per protein

path_project = "/media/concha-eloko/Linux/PPT_clean"

# Earlier result files, kept for provenance:
# PPT_results.matrices.tailored.tsv : Tailored version
# PPT_results.classic_1112.tsv : Classic version
# PPT_results.matrices.tailored_bit50.tsv : tailored and bit50
# classic_0101
# SAGE_0201
# PPT_results.classic_0101.bit50.tsv
tropigat_results = pd.read_csv(f"{path_project}/PPT_results.standard.1309.bit75.tsv", header = 0, sep = "\t")

# One line per protein of a phage with experimental targets:
# phage, protein, fold, top-n TropiGAT calls, TropiGAT hits among targets,
# TropiSeq calls, TropiSeq hits among targets, targets.
# Hit and target lists are sorted so the file is identical across runs.
with open(f"{path_project}/raw_metrics.standard.bit75.top{top_n}.detailed.tsv", "w") as outfile :
    outfile.write("Phage\tProtein\tFolds\tTropiGAT_predictions\tTropiGAT_good_calls\tTropiSeq_predictions\tTropiSeq_good_calls\tTargets\n")
    for _, row in tropigat_results.iterrows() :
        if row["phage"] not in dico_hits :
            continue  # no experimental targets for this phage -> nothing to score
        targets = dico_hits[row["phage"]]

        # df_folds stores ids with a single underscore before "cds".
        prot_id = row["protein_id"].replace("__cds", "_cds")
        fold_hits = df_folds.loc[df_folds["protein_id"] == prot_id, "Fold"].values
        fold = fold_hits[0] if len(fold_hits) > 0 else "unknown"
        outfile.write(f"{row['phage']}\t{row['protein_id']}\t{fold}\t")

        # TropiGAT part: "KLtype:score" entries separated by ";", highest score first
        # (as written by the prediction-export cell). Keep exactly the first top_n.
        tropigat_entries = row["predictions_tropigat"].split(";")[:top_n]
        outfile.write(";".join(tropigat_entries) + "\t")
        tropigat_top = {x.split(":")[0].strip() for x in tropigat_entries}
        good_calls = tropigat_top & targets
        outfile.write((",".join(sorted(good_calls)) if good_calls else "0") + "\t")

        # TropiSeq part: both spellings of the "no call" markers are recognised
        # (the sequence-based file uses "No_hits" / "No_associations").
        seqbased = row["predictions_seqbased"]
        outfile.write(seqbased + "\t")
        if seqbased in ("No hits", "No predictions", "No_hits", "No_associations") :
            outfile.write("0\t")
        else :
            tropiseq_top = {x.split(":")[0].strip() for x in seqbased.split(";")[:top_n]}
            good_calls = tropiseq_top & targets
            outfile.write((",".join(sorted(good_calls)) if good_calls else "0") + "\t")

        outfile.write(",".join(sorted(targets)) + "\n")
