In [None]:
rsync -avzhe ssh \
conchae@garnatxa.srv.cpd:/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn/ensemble_2809 \
/media/concha-eloko/Linux/PPT_clean/ficheros_28032023

rsync -avzhe ssh \
conchae@garnatxa.srv.cpd:/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn/ensemble_2809_log_files \
/media/concha-eloko/Linux/PPT_clean/ficheros_28032023

rsync -avzhe ssh \
conchae@garnatxa.srv.cpd:/home/conchae/77_strains_phage_project/CAD5239776.1.fasta.esm_out \
/media/concha-eloko/Linux/PPT_clean/ficheros_28032023


In [78]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero , HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

import TropiGAT_models  
#from TropiGAT_functions import get_top_n_kltypes ,clean_print 

import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging
from multiprocessing.pool import ThreadPool
warnings.filterwarnings("ignore")

# *****************************************************************************
# Load the Dataframes :
path_work = "/media/concha-eloko/Linux/PPT_clean"
path_ensemble = f"{path_work}/ficheros_28032023/ensemble_2809"


In [75]:
DF_info = pd.read_csv(f"{path_work}/TropiGATv2.final_df.tsv", sep = "\t" ,  header = 0)
DF_info_lvl_0 = DF_info[~DF_info["KL_type_LCA"].str.contains("\\|")]
DF_info_lvl_0 = DF_info_lvl_0.drop_duplicates(subset = ["Infected_ancestor","index","prophage_id"] , keep = "first").reset_index(drop=True)

df_prophages = DF_info_lvl_0.drop_duplicates(subset = ["Phage"])
dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))



In [91]:
pp.pprint(dico_prophage_count)

{'KL1': 166,
 'KL10': 101,
 'KL101': 1,
 'KL102': 269,
 'KL103': 12,
 'KL104': 5,
 'KL105': 123,
 'KL106': 463,
 'KL107': 1066,
 'KL108': 28,
 'KL109': 15,
 'KL11': 6,
 'KL110': 62,
 'KL111': 69,
 'KL112': 69,
 'KL113': 4,
 'KL114': 22,
 'KL115': 3,
 'KL116': 27,
 'KL117': 18,
 'KL118': 26,
 'KL119': 5,
 'KL12': 40,
 'KL120': 1,
 'KL121': 1,
 'KL122': 35,
 'KL123': 35,
 'KL124': 15,
 'KL125': 30,
 'KL126': 7,
 'KL127': 36,
 'KL128': 20,
 'KL13': 71,
 'KL130': 1,
 'KL131': 3,
 'KL132': 3,
 'KL134': 4,
 'KL136': 32,
 'KL137': 6,
 'KL139': 11,
 'KL14': 123,
 'KL140': 13,
 'KL141': 8,
 'KL142': 15,
 'KL143': 11,
 'KL144': 1,
 'KL145': 28,
 'KL146': 6,
 'KL147': 3,
 'KL148': 2,
 'KL149': 64,
 'KL15': 214,
 'KL150': 3,
 'KL151': 45,
 'KL152': 10,
 'KL153': 17,
 'KL154': 1,
 'KL155': 12,
 'KL157': 11,
 'KL158': 6,
 'KL159': 13,
 'KL16': 22,
 'KL162': 4,
 'KL163': 5,
 'KL164': 7,
 'KL165': 1,
 'KL166': 8,
 'KL169': 29,
 'KL17': 461,
 'KL170': 2,
 'KL18': 23,
 'KL19': 79,
 'KL2': 364,
 'KL20': 

In [76]:
def make_query_graph(embeddings) :
    """
    This function builds the query graph for the ensemble model.
    Input : A list of the ESM2 embeddings of the depolymerase 
    Output : The query graph
    """
    query_graph = HeteroData()
    query_graph["B1"].x = torch.empty((1, 0))
    query_graph["B2"].x = torch.tensor(embeddings , dtype=torch.float)
    edge_index_B2_B1 = torch.tensor([[0,0]] , dtype=torch.long)
    query_graph['B2', 'expressed', 'B1'].edge_index = edge_index_B2_B1.t().contiguous()
    
    return query_graph
    
def make_ensemble_TropiGAT(path_ensemble) : 
    """
    This function builds a dictionary with all the models that are part of the TropiGAT predictor
    Input : Path of the models
    Output : Dictionary

    # Make a json file with the versions of the GNN corresponding to each KL types
    # Load it
    # Create the correct model instance (TropiGAT_small_module or TropiGAT_big_module)
    """
    dico_ensemble = {}
    for GNN_model in os.listdir(path_ensemble) :
        if GNN_model[-2:] == "pt" : 
            KL_type = GNN_model.split(".")[0]
            if dico_prophage_count[KL_type] >= 125 : 
                model = TropiGAT_models.TropiGAT_big_module(hidden_channels = 1280 , heads = 1)
            else :
                model = TropiGAT_models.TropiGAT_small_module(hidden_channels = 1280 , heads = 1)
            model.load_state_dict(torch.load(f"{path_ensemble}/{GNN_model}"))
            dico_ensemble[KL_type] = model
        
    return dico_ensemble

@torch.no_grad()
def make_predictions(model, data):
    """
    This generic function run the prediction of a binary model
    Inputs : The model, the query data
    Ouput : the prediction and associated probability
    """
    model.eval() 
    output = model(data)
    probabilities = torch.sigmoid(output)
    predictions = probabilities.round() 
    
    return predictions, round(probabilities.item() , 4) 
        
def run_prediction(query_graph, dico_ensemble) :
    dico_predictions = {}
    for KL_type in dico_ensemble :
        model = dico_ensemble[KL_type]
        prediction, probabilities = make_predictions(model, query_graph)
        if int(prediction) == 1 :
            dico_predictions[KL_type] = probabilities
        else :
            continue

    return dico_predictions

def format_predictions(predictions, sep = "__") : 
    final_results = {}
    for protein,hits in predictions.items() : 
        phage = protein.split(sep)[0]
        if phage not in final_results : 
            tmp_hits = {}
            for kltype in hits : 
                if kltype in tmp_hits and hits[kltype] > tmp_hits[kltype]:
                    tmp_hits[kltype] = hits[kltype]
                elif kltype in tmp_hits and hits[kltype] < tmp_hits[kltype]:
                    pass
                elif kltype not in tmp_hits : 
                    tmp_hits[kltype] = hits[kltype]
            final_results[phage] = tmp_hits
        else :
            for kltype in hits : 
                if kltype in final_results[phage] and hits[kltype] > final_results[phage][kltype]:
                    final_results[phage][kltype] = hits[kltype]
                elif kltype in final_results[phage] and hits[kltype] < final_results[phage][kltype]:
                    pass
                elif kltype not in final_results[phage] : 
                    final_results[phage][kltype] = hits[kltype]
    return final_results
        

> Build model :

In [79]:
dico_models = make_ensemble_TropiGAT(path_ensemble)

In [80]:
path_project = "/media/concha-eloko/Linux/77_strains_phage_project"
path_Dpo_domain_org = "/media/concha-eloko/Linux/depolymerase_building/clean_77_phages_depo"

dpo_embeddings = pd.read_csv(f"{path_project}/rbp_work/Dpo_domains_77.esm2.embedding.csv", sep = "," , header = None)
dpo_embeddings = dpo_embeddings.drop([1281] , axis = 1)
dpo_embeddings.set_index([0], inplace = True)

In [85]:
dpo_embeddings

Unnamed: 0_level_0,1,2,3,4,5,6,7,8,9,10,...,1271,1272,1273,1274,1275,1276,1277,1278,1279,1280
0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
K15PH90__cds_55_Dpo_domain,-0.028760,0.046677,-0.010773,0.028452,-0.090442,0.027041,0.004249,-0.083708,0.022172,0.090119,...,-0.032166,-0.012386,0.079159,0.012298,0.027317,0.037254,0.069599,-0.097522,0.067495,0.062502
K80PH1317b__cds_54_Dpo_domain,0.007689,0.036850,-0.006928,-0.056424,-0.090723,0.018707,0.014913,-0.070090,0.073792,0.055322,...,0.017004,-0.000657,0.059184,-0.006782,0.023955,0.035585,0.048035,-0.081247,0.043776,0.118674
K64PH164C4__cds_24_Dpo_domain,0.015762,0.062429,-0.003427,-0.003609,-0.101109,0.028121,0.004342,-0.096114,0.062562,0.027864,...,0.042829,0.029579,0.095159,0.024894,0.002837,0.046701,0.062497,-0.084956,0.027426,0.051051
K5lambda5__cds_196_Dpo_domain,0.040111,0.046436,-0.012045,-0.043877,-0.100054,-0.028328,0.028640,-0.047144,0.065727,0.047312,...,0.017347,0.008084,0.096149,-0.031008,0.040423,0.082593,0.050161,-0.105612,0.023642,0.081104
K11PH164C1__cds_46_Dpo_domain,0.017319,0.077582,-0.001212,-0.030026,-0.070916,-0.011639,0.006673,-0.078486,0.072836,0.046921,...,0.005777,-0.032808,0.099658,-0.028466,0.020794,0.082009,0.030658,-0.091195,0.047744,0.105303
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
K61PH164C1__cds_10_Dpo_domain,0.038221,0.054935,0.010196,-0.022653,-0.098234,0.015615,0.016274,-0.050646,0.050069,0.042166,...,0.005035,0.002913,0.103508,-0.006524,0.026563,0.044091,0.034147,-0.095066,0.054782,0.054492
K60PH164C1__cds_96_Dpo_domain,0.031801,0.045489,-0.020759,-0.023341,-0.087468,0.019555,-0.035902,-0.042261,0.063424,0.040513,...,0.017682,0.015435,0.078165,0.023697,0.016046,0.060408,0.054158,-0.090769,0.066702,0.084330
K63PH128__cds_22_Dpo_domain,0.012132,0.044426,-0.010644,-0.017182,-0.043079,-0.016834,-0.010763,-0.082545,0.046603,0.070069,...,0.050513,0.032082,0.108994,0.016834,-0.003678,0.051204,0.083290,-0.103134,0.023146,0.039394
K74PH129C2__cds_52_Dpo_domain,-0.001554,0.026739,0.023782,-0.063908,-0.059218,-0.031451,0.004133,-0.061047,0.099100,0.021165,...,0.019009,0.054984,0.081528,-0.015591,0.010591,0.056652,0.030724,-0.028994,-0.001747,0.080028


In [88]:
predictions = {}
for dpo in dpo_embeddings.index : 
    graph_dpo = make_query_graph([dpo_embeddings.loc[dpo].values])
    pred = run_prediction(graph_dpo,dico_models)
    predictions[dpo] = pred
    

In [90]:
final_results = format_predictions(predictions)

import pprint
pp = pprint.PrettyPrinter(width = 150, sort_dicts = True, compact = True)
pp.pprint(final_results)

{'K10PH82C1': {'KL10': 0.5707,
               'KL110': 0.9517,
               'KL118': 0.5704,
               'KL12': 0.9234,
               'KL127': 0.8499,
               'KL128': 0.9694,
               'KL136': 0.7698,
               'KL2': 0.6289,
               'KL25': 0.6036,
               'KL3': 0.8678,
               'KL34': 0.9265,
               'KL45': 0.7,
               'KL46': 0.7351,
               'KL48': 0.8131,
               'KL53': 0.9169,
               'KL60': 0.9817,
               'KL62': 0.9791,
               'KL63': 0.5017,
               'KL74': 0.796},
 'K11PH164C1': {'KL105': 0.5814,
                'KL111': 0.7759,
                'KL127': 0.9122,
                'KL145': 0.9378,
                'KL24': 0.5383,
                'KL3': 0.9619,
                'KL39': 0.5719,
                'KL45': 0.6045,
                'KL46': 0.7452,
                'KL57': 0.5124,
                'KL60': 0.8816,
                'KL62': 0.8212,
                'KL64': 