In [1]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero , HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

import TropiGAT_functions 
#from TropiGAT_functions import get_top_n_kltypes ,clean_print 

import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging
from multiprocessing.pool import ThreadPool
warnings.filterwarnings("ignore")

# *****************************************************************************
# Load the Dataframes :
path_work = "/media/concha-eloko/Linux/PPT_clean"
path_ensemble = f"{path_work}/ficheros_28032023/ensemble_3112_SAGE"
#path_ensemble = f"{path_work}/ficheros_28032023/ensemble_tailored_0612"


> Make model : 

In [2]:
dico_models_ = TropiGAT_functions.make_ensemble_TropiSAGE(path_ensemble)
dico_models = dico_models_[0]

In [3]:
len(dico_models) , dico_models

(66,
 {'KL136': TropiGAT_small_sage_module(
    (conv): SAGEConv((-1, -1), 1280, aggr=mean)
    (hetero_conv): HeteroConv(num_relations=1)
    (linear_layers): Sequential(
      (0): Linear(in_features=1280, out_features=1280, bias=True)
      (1): BatchNorm1d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
      (3): Dropout(p=0.2, inplace=False)
      (4): Linear(in_features=1280, out_features=480, bias=True)
      (5): BatchNorm1d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): LeakyReLU(negative_slope=0.01)
      (7): Dropout(p=0.2, inplace=False)
      (8): Linear(in_features=480, out_features=1, bias=True)
    )
  ),
  'KL116': TropiGAT_small_sage_module(
    (conv): SAGEConv((-1, -1), 1280, aggr=mean)
    (hetero_conv): HeteroConv(num_relations=1)
    (linear_layers): Sequential(
      (0): Linear(in_features=1280, out_features=1280, bias=True)
      (1): BatchNorm1d(1280, eps=1e-05, 

> Ferriol

In [4]:
import pandas as pd 
import os 

path_project = "/media/concha-eloko/Linux/77_strains_phage_project"
path_Dpo_domain_org = "/media/concha-eloko/Linux/depolymerase_building/clean_77_phages_depo"

dpo_embeddings = pd.read_csv(f"{path_project}/rbp_work/Dpo_domains_77.esm2.embedding.1512.csv", sep = "," , header = None)
dpo_embeddings = dpo_embeddings.drop([1281] , axis = 1)
dpo_embeddings.set_index([0], inplace = True)
dpo_embeddings.index = [x if x.count("__cds") > 0 else x.replace("_cds", "__cds") for x in dpo_embeddings.index]
dpo_embeddings.index = [x.replace("__CDS","__cds") if x.count("__CDS") > 0 else x for x in dpo_embeddings.index]

> Bea

In [5]:
import pandas as pd 
import os 

path_project = "/media/concha-eloko/Linux/PPT_clean/in_vitro"

bea_embeddings = pd.read_csv(f"{path_project}/Bea_phages.esm2.embedding.csv", sep = "," , header = None)
bea_embeddings = bea_embeddings.drop([1281] , axis = 1)
bea_embeddings.set_index([0], inplace = True)


> Towndsend

In [6]:
import pandas as pd 
import os 

path_project = "/media/concha-eloko/Linux/PPT_clean/in_vitro"

towndsend_embeddings = pd.read_csv(f"{path_project}/Townsed_phages.esm2.embedding.1112.csv", sep = "," , header = None)
towndsend_embeddings = towndsend_embeddings.drop([1281] , axis = 1)
towndsend_embeddings.set_index([0], inplace = True)


>Others old

In [None]:
import pandas as pd 
import os 

path_project = "/media/concha-eloko/Linux/PPT_clean/in_vitro"

others_embeddings = pd.read_csv(f"{path_project}/Others_phages.esm2.embedding.csv", sep = "," , header = None)
others_embeddings = others_embeddings.drop([1281] , axis = 1)
others_embeddings.set_index([0], inplace = True)

namesother_df = pd.read_csv(f"{path_project}/Others/index_others.tsv", sep = "\t" , names = ["index_phage", "index_prot","prot_name"])
new_index = []
for _,index in enumerate(others_embeddings.index) : 
    i_phage = int(index.split("__")[0])
    i_prot = int(index.split("__")[1])
    prot = namesother_df[(namesother_df["index_phage"] == i_phage) & (namesother_df["index_prot"] == i_prot)]["prot_name"].values[0]
    new_index.append(prot)

others_embeddings.index = new_index
others_embeddings.index.name = 0

> Others 

In [7]:
import pandas as pd 
import os 

path_project = "/media/concha-eloko/Linux/PPT_clean/in_vitro"

others_embeddings = pd.read_csv(f"{path_project}/Others_all.esm2.embedding.csv", sep = "," , header = None)
others_embeddings.set_index([0], inplace = True)

***
> run the predictions 

In [8]:
# Run the predictions ferriol :
ferriol_predictions = {}
for dpo in dpo_embeddings.index : 
    graph_dpo = TropiGAT_functions.make_query_graph([dpo_embeddings.loc[dpo].values])
    pred = TropiGAT_functions.run_prediction(graph_dpo,dico_models)
    ferriol_predictions[dpo] = pred

In [9]:
ferriol_predictions

{'K15PH90__cds_55_Dpo_domain': {'KL136': 0.998,
  'KL116': 0.7466,
  'KL9': 0.7634,
  'KL13': 0.9998,
  'KL39': 0.9816,
  'KL27': 0.5414,
  'KL21': 0.7751,
  'KL81': 0.5072,
  'KL15': 0.9996,
  'KL3': 0.8273,
  'KL38': 0.9017,
  'KL55': 0.5635,
  'KL43': 0.9927,
  'KL34': 0.9354,
  'KL48': 0.759,
  'KL8': 0.9777,
  'KL128': 0.533,
  'KL62': 0.9299,
  'KL14': 0.8073,
  'KL70': 0.9989,
  'KL22': 0.7735,
  'KL52': 0.8518,
  'KL12': 0.5162,
  'KL18': 0.8459,
  'KL122': 0.9852,
  'KL16': 0.7517,
  'KL46': 0.5115,
  'KL125': 0.5202,
  'KL29': 0.5218,
  'KL7': 0.8498},
 'K80PH1317b__cds_54_Dpo_domain': {'KL63': 0.7917,
  'KL112': 0.6195,
  'KL9': 0.5027,
  'KL13': 0.9231,
  'KL19': 0.9521,
  'KL108': 0.9591,
  'KL39': 0.9127,
  'KL36': 0.5201,
  'KL30': 0.5641,
  'KL23': 0.9132,
  'KL81': 0.9272,
  'KL60': 0.903,
  'KL128': 0.9196,
  'KL70': 0.9752,
  'KL52': 0.7926,
  'KL57': 0.6212,
  'KL18': 0.7743,
  'KL151': 0.9722,
  'KL28': 0.6049,
  'KL74': 0.6675,
  'KL29': 0.6114,
  'KL114': 0.6775}

In [10]:
# format the results : 
ferriol_pred_formated = TropiGAT_functions.format_predictions(ferriol_predictions , sep = "__")
TropiGAT_functions.clean_print(ferriol_pred_formated)

{'K10PH82C1': {'KL1': 0.5725,
               'KL102': 0.9813,
               'KL105': 0.7636,
               'KL108': 0.9754,
               'KL110': 0.97,
               'KL114': 0.8724,
               'KL116': 0.7557,
               'KL12': 0.8195,
               'KL122': 0.6229,
               'KL123': 1.0,
               'KL125': 0.9529,
               'KL127': 0.9253,
               'KL128': 0.9903,
               'KL13': 0.9998,
               'KL136': 0.9932,
               'KL14': 0.8753,
               'KL145': 0.5768,
               'KL149': 0.7221,
               'KL15': 0.8412,
               'KL151': 0.8708,
               'KL16': 0.7488,
               'KL169': 0.79,
               'KL18': 0.889,
               'KL19': 0.975,
               'KL2': 0.9496,
               'KL21': 0.5883,
               'KL22': 0.5143,
               'KL23': 0.7528,
               'KL24': 0.9942,
               'KL25': 0.5081,
               'KL27': 0.9951,
               'KL29': 0.9986,
   

***

In [11]:
# Run the predictions Bea :
bea_predictions = {}
for dpo in bea_embeddings.index : 
    graph_dpo = TropiGAT_functions.make_query_graph([bea_embeddings.loc[dpo].values])
    pred = TropiGAT_functions.run_prediction(graph_dpo,dico_models)
    bea_predictions[dpo] = pred

In [12]:
# format the results : 
bea_pred_formated = TropiGAT_functions.format_predictions(bea_predictions , sep = "_")
TropiGAT_functions.clean_print(bea_pred_formated)

{'A1a': {'KL102': 0.9484,
         'KL105': 0.9427,
         'KL108': 0.9525,
         'KL110': 0.7318,
         'KL114': 0.6595,
         'KL116': 0.9695,
         'KL12': 0.5511,
         'KL123': 1.0,
         'KL125': 0.9132,
         'KL128': 0.9435,
         'KL13': 0.9882,
         'KL14': 0.9881,
         'KL151': 0.9859,
         'KL16': 0.7754,
         'KL169': 0.7015,
         'KL19': 0.9996,
         'KL2': 0.988,
         'KL21': 0.5839,
         'KL23': 0.871,
         'KL24': 0.5377,
         'KL25': 0.7727,
         'KL27': 0.9992,
         'KL29': 0.9921,
         'KL3': 0.8865,
         'KL30': 0.8726,
         'KL36': 0.6254,
         'KL39': 0.5858,
         'KL45': 0.9786,
         'KL48': 0.7095,
         'KL5': 0.9782,
         'KL51': 0.8713,
         'KL52': 0.987,
         'KL55': 0.9821,
         'KL57': 0.5729,
         'KL7': 0.9288,
         'KL70': 0.998,
         'KL74': 0.996,
         'KL8': 0.5404,
         'KL81': 0.5274,
         'KL9': 0.8794},
 '

In [13]:
# Run the predictions Towndsend :
towndsend_predictions = {}
for dpo in towndsend_embeddings.index : 
    graph_dpo = TropiGAT_functions.make_query_graph([towndsend_embeddings.loc[dpo].values])
    pred = TropiGAT_functions.run_prediction(graph_dpo,dico_models)
    towndsend_predictions[dpo] = pred

In [14]:
# format the results : 
towndsend_pred_formated = TropiGAT_functions.format_predictions(towndsend_predictions , sep = "_")
TropiGAT_functions.clean_print(towndsend_pred_formated)

{'BLCJPOBP': {'KL102': 0.9639,
              'KL105': 0.8514,
              'KL108': 0.9688,
              'KL110': 0.7847,
              'KL114': 0.614,
              'KL12': 0.7794,
              'KL123': 1.0,
              'KL125': 0.914,
              'KL127': 0.6878,
              'KL128': 0.9953,
              'KL13': 0.9946,
              'KL136': 0.9005,
              'KL14': 0.9877,
              'KL151': 0.9865,
              'KL16': 0.8097,
              'KL169': 0.8471,
              'KL18': 0.7734,
              'KL19': 0.9996,
              'KL23': 0.881,
              'KL25': 0.8724,
              'KL27': 0.9996,
              'KL29': 0.9433,
              'KL3': 0.6783,
              'KL30': 0.903,
              'KL34': 0.7991,
              'KL36': 0.5605,
              'KL39': 0.9945,
              'KL43': 0.6329,
              'KL45': 0.7762,
              'KL48': 0.9376,
              'KL51': 0.8723,
              'KL55': 0.9808,
              'KL57': 0.6442,
      

***

In [None]:
others_embeddings.info()

In [None]:
others_embeddings.index

In [15]:
# Run the predictions Others :
other_predictions = {}
for index, dpo in enumerate(others_embeddings.index) :
    if dpo not in ["MN781108.1_prot_QGZ15323.1_262"] :
        try : 
            graph_dpo = TropiGAT_functions.make_query_graph([others_embeddings.loc[dpo].values])
            pred = TropiGAT_functions.run_prediction(graph_dpo,dico_models)
            other_predictions[dpo] = pred
        except Exception as e :
            print(e, dpo)

running_mean should contain 1 elements not 1280 ON146449.1_prot_UPW35150.1_13
running_mean should contain 1 elements not 1280 ON146449.1_prot_UPW35138.1_1
running_mean should contain 1 elements not 1280 ON146449.1_prot_UPW35150.1_13
running_mean should contain 1 elements not 1280 ON146449.1_prot_UPW35138.1_1


In [16]:
# format the results : 
others_pred_formated = TropiGAT_functions.format_predictions(other_predictions , sep = "_prot_")
TropiGAT_functions.clean_print(other_predictions)

{'AB716666.1_prot_BAP15736.1_24': {'KL102': 0.9546,
                                   'KL108': 0.9017,
                                   'KL114': 0.6542,
                                   'KL123': 1.0,
                                   'KL125': 0.9712,
                                   'KL128': 0.9464,
                                   'KL13': 0.959,
                                   'KL14': 0.9986,
                                   'KL16': 0.8834,
                                   'KL169': 0.9017,
                                   'KL19': 0.9986,
                                   'KL25': 0.7771,
                                   'KL27': 0.9997,
                                   'KL3': 0.9092,
                                   'KL30': 0.9858,
                                   'KL36': 0.5253,
                                   'KL38': 0.5397,
                                   'KL39': 0.8535,
                                   'KL45': 0.7187,
                             

***
# Write the results : 

> Others : 

In [17]:
predictions = [other_predictions]

with open("/media/concha-eloko/Linux/PPT_clean/TropiGAT.Others.all.results.SAGE_0201.tsv", "w") as outfile:
    for prediction in predictions:
        prediction_sorted = dict(sorted(prediction.items()))
        for prot in prediction_sorted:
            if prediction_sorted[prot] == "No hits" or len(prediction_sorted[prot]) == 0:
                outfile.write(f"{prot}\tNo hits\n")
            else:
                outfile.write(f"{prot}\t")
                hits = [f"{kltype}:{round(score, 3)}" for kltype, score in prediction_sorted[prot].items()]
                sorted_hits = " ; ".join(sorted(hits, key=lambda x: float(x.split(":")[1]), reverse=True))
                outfile.write(sorted_hits)
                outfile.write("\n")

> Predictions : 

In [18]:
predictions = [ferriol_predictions , bea_predictions , towndsend_predictions]

with open("/media/concha-eloko/Linux/PPT_clean/TropiGAT.results.SAGE_0201.tsv", "w") as outfile:
    for prediction in predictions:
        for prot in prediction:
            if prediction[prot] == "No hits" or len(prediction[prot]) == 0:
                outfile.write(f"{prot}\tNo hits\n")
            else:
                outfile.write(f"{prot}\t")
                hits = [f"{kltype}:{round(score, 3)}" for kltype, score in prediction[prot].items()]
                sorted_hits = " ; ".join(sorted(hits, key=lambda x: float(x.split(":")[1]), reverse=True))
                outfile.write(sorted_hits)
                outfile.write("\n")

> Parse TropiGAT and Seqbased results :

In [22]:
import pandas as pd 
import os 

path_project = "/media/concha-eloko/Linux/PPT_clean"

tropigat_results = pd.read_csv(f"{path_project}/TropiGAT.results.SAGE_0201.tsv", sep = "\t", names = ["protein", "predictions_tropisage"])
seqbased_results = pd.read_csv(f"{path_project}/Seqbased_model.0101.results.tsv", sep = "\t", names = ["protein", "predictions_seqbased"])


In [23]:
tropigat_results["protein_id"] = tropigat_results["protein"].apply(lambda x : x.split("_Dpo")[0])
seqbased_results["protein_id"] = seqbased_results["protein"].apply(lambda x : x.split("_A_")[0])

merged_df = tropigat_results.merge(seqbased_results, on='protein_id', how='inner')
merged_df["phage"] = merged_df["protein_id"].apply(lambda x : x.split("__")[0] if x.count("__")>0 else x.split("_")[0])

merged_df_sorted = merged_df.sort_values(by='phage', ascending=True)
merged_df_sorted

Unnamed: 0,protein_x,predictions_tropisage,protein_id,protein_y,predictions_seqbased,phage
120,A1a_00002,KL19:1.0 ; KL123:1.0 ; KL27:0.999 ; KL14:0.988...,A1a_00002,A1a_00002,KL102: 0.691,A1a
105,A1a_00014,KL70:0.998 ; KL74:0.996 ; KL29:0.992 ; KL13:0....,A1a_00014,A1a_00014,KL151: 0.698,A1a
106,A1b_00048,KL38:0.999 ; KL39:0.993 ; KL128:0.985 ; KL48:0...,A1b_00048,A1b_00048,KL157: 0.729,A1b
98,A1b_00036,KL123:1.0 ; KL27:0.999 ; KL14:0.999 ; KL74:0.9...,A1b_00036,A1b_00036,KL102: 0.691,A1b
129,A1c_00046,KL27:1.0 ; KL123:1.0 ; KL19:0.999 ; KL74:0.998...,A1c_00046,A1c_00046,KL102: 0.691,A1c
...,...,...,...,...,...,...
128,S13a_00036,KL27:0.996 ; KL38:0.995 ; KL57:0.974 ; KL3:0.9...,S13a_00036,S13a_00036,KL38: 0.822,S13a
90,S13b_00058,KL47:0.992 ; KL23:0.99 ; KL112:0.979 ; KL70:0....,S13b_00058,S13b_00058,KL63: 0.867,S13b
122,S13c_00055,KL12:0.999 ; KL38:0.998 ; KL27:0.992 ; KL57:0....,S13c_00055,S13c_00055,No hits,S13c
121,S13d_00057,KL21:1.0 ; KL14:0.999 ; KL74:0.979 ; KL38:0.95...,S13d_00057,S13d_00057,KL14: 0.736,S13d


In [24]:
final_df = merged_df_sorted[["phage","protein_id", "predictions_seqbased", "predictions_tropisage"]]

final_df

Unnamed: 0,phage,protein_id,predictions_seqbased,predictions_tropisage
120,A1a,A1a_00002,KL102: 0.691,KL19:1.0 ; KL123:1.0 ; KL27:0.999 ; KL14:0.988...
105,A1a,A1a_00014,KL151: 0.698,KL70:0.998 ; KL74:0.996 ; KL29:0.992 ; KL13:0....
106,A1b,A1b_00048,KL157: 0.729,KL38:0.999 ; KL39:0.993 ; KL128:0.985 ; KL48:0...
98,A1b,A1b_00036,KL102: 0.691,KL123:1.0 ; KL27:0.999 ; KL14:0.999 ; KL74:0.9...
129,A1c,A1c_00046,KL102: 0.691,KL27:1.0 ; KL123:1.0 ; KL19:0.999 ; KL74:0.998...
...,...,...,...,...
128,S13a,S13a_00036,KL38: 0.822,KL27:0.996 ; KL38:0.995 ; KL57:0.974 ; KL3:0.9...
90,S13b,S13b_00058,KL63: 0.867,KL47:0.992 ; KL23:0.99 ; KL112:0.979 ; KL70:0....
122,S13c,S13c_00055,No hits,KL12:0.999 ; KL38:0.998 ; KL27:0.992 ; KL57:0....
121,S13d,S13d_00057,KL14: 0.736,KL21:1.0 ; KL14:0.999 ; KL74:0.979 ; KL38:0.95...


In [25]:
final_df.to_csv(f"{path_project}/PPT_results.SAGE_0201.tsv", sep = "\t", header = True, index = False)