In [1]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import HeteroConv , GATv2Conv 
#from torch_geometric.utils import negative_sampling
#from torch_geometric.loader import LinkNeighborLoader

import torch
from torch import nn 
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns

import os 
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
warnings.filterwarnings("ignore") 



In [1]:
import TropiGAT_models
import TropiGAT_graph

In [9]:
def build_graph_baseline(df_info, embedding_dim=1280) :
    """Build the label-free heterogeneous graph shared by every KL-type task.

    Node types (ids follow the order of first appearance in ``df_info``):
      * ``A``  : unique ``Infected_ancestor`` values (bacteria). Feature is the
                 one-hot encoding of the ``KL_type_LCA`` found on the first row
                 of that ancestor; columns follow the sorted unique KL types.
      * ``B1`` : unique ``Phage`` values (prophages). Feature matrix is empty,
                 shape ``(n_B1, 0)``.
      * ``B2`` : unique ``index`` values (depolymerases). Feature is the
                 embedding stored in columns ``"1"`` .. ``str(embedding_dim)``
                 of the first row carrying that ``index``.

    Edge types:
      * ``('B1', 'infects', 'A')``   : de-duplicated (phage, ancestor) pairs.
      * ``('A', 'harbors', 'B1')``   : the same pairs, reversed.
      * ``('B2', 'expressed', 'B1')``: one edge per row, grouped by phage.

    Parameters
    ----------
    df_info : pandas.DataFrame
        Needs the columns ``Infected_ancestor``, ``Phage``, ``index``,
        ``KL_type_LCA`` and the embedding columns. ``KL_type_LCA`` is assumed
        to hold no missing values (they cannot be sorted with strings).
    embedding_dim : int, default 1280
        Number of embedding columns to read.

    Returns
    -------
    (graph_data, dico_prophage_kltype_associated)
        ``graph_data`` is the populated ``HeteroData``. The dictionary maps each
        phage to the set of KL types seen on *any* row that shares one of the
        phage's depolymerase ``index`` values.
    """
    def to_edge_index(pairs) :
        # (n_edges, 2) -> (2, n_edges); the explicit reshape keeps the (2, 0)
        # layout when there is no edge at all.
        return torch.tensor(pairs, dtype=torch.long).reshape(len(pairs), 2).t().contiguous()

    # **************************************************************
    # initialize the graph
    graph_data = HeteroData()
    # Indexation process
    ID_nodes_A = {item: idx for idx, item in enumerate(df_info["Infected_ancestor"].unique().tolist())}
    ID_nodes_B1 = {item: idx for idx, item in enumerate(df_info["Phage"].unique().tolist())}
    ID_nodes_B2 = {item: idx for idx, item in enumerate(df_info["index"].unique().tolist())}
    # **************************************************************
    # Make the node feature file :
    # Row c of the identity matrix is the one-hot vector of category c. The
    # previous code looked the vector up by *data row* position, which paired
    # KL types with unrelated encodings.
    kl_categories = sorted(df_info["KL_type_LCA"].unique().tolist())
    kl_column = {kl: col for col, kl in enumerate(kl_categories)}
    one_hot = torch.eye(len(kl_categories), dtype=torch.float)
    # drop_duplicates(keep="first") yields rows in the same order as unique(),
    # i.e. row i describes node i.
    first_rows_A = df_info.drop_duplicates(subset=["Infected_ancestor"], keep="first")
    kl_of_A = torch.tensor([kl_column[kl] for kl in first_rows_A["KL_type_LCA"]], dtype=torch.long)
    node_feature_A = one_hot[kl_of_A]
    node_feature_B1 = torch.zeros((len(ID_nodes_B1), 0), dtype=torch.float)
    embeddings_columns = [str(i) for i in range(1, embedding_dim + 1)]
    first_rows_B2 = df_info.drop_duplicates(subset=["index"], keep="first")
    node_feature_B2 = torch.tensor(first_rows_B2[embeddings_columns].to_numpy(dtype=float), dtype=torch.float)
    # feed the graph
    graph_data["A"].x = node_feature_A
    graph_data["B1"].x = node_feature_B1
    graph_data["B2"].x = node_feature_B2
    # **************************************************************
    # Make edge file
    # Node B1 (prophage) - Node A (bacteria) : unique pairs, first-seen order
    # (dict keys are ordered and unique).
    pairs_B1_A = list(dict.fromkeys(
        (ID_nodes_B1[phage], ID_nodes_A[ancestor])
        for phage, ancestor in zip(df_info["Phage"], df_info["Infected_ancestor"])
    ))
    edge_index_B1_A = to_edge_index(pairs_B1_A)
    # Node A (bacteria) - Node B1 (prophage) : same pairs, source/target swapped.
    edge_index_A_B1 = edge_index_B1_A.flip(0)
    # Single pass collecting, per phage, its depolymerases (row order kept) and,
    # per depolymerase, every KL type it is seen with.
    dpos_of_phage = {phage: [] for phage in ID_nodes_B1}
    kltypes_of_dpo = {}
    for phage, dpo, kltype in zip(df_info["Phage"], df_info["index"], df_info["KL_type_LCA"]) :
        dpos_of_phage[phage].append(dpo)
        kltypes_of_dpo.setdefault(dpo, set()).add(kltype)
    # Node B2 (depolymerase) - Node B1 (prophage) :
    # NOTE(review): unlike the two edge types above, repeated (depolymerase,
    # phage) rows are NOT de-duplicated here — kept as in the original; confirm
    # that parallel edges are intended.
    pairs_B2_B1 = [
        (ID_nodes_B2[dpo], ID_nodes_B1[phage])
        for phage, dpos in dpos_of_phage.items()
        for dpo in dpos
    ]
    edge_index_B2_B1 = to_edge_index(pairs_B2_B1)
    # feed the graph
    graph_data['B1', 'infects', 'A'].edge_index = edge_index_B1_A
    graph_data['B2', 'expressed', 'B1'].edge_index = edge_index_B2_B1
    # That one is optional
    graph_data['A', 'harbors', 'B1'].edge_index = edge_index_A_B1
    # **************************************************************
    # KL types associated with each prophage through its depolymerases
    dico_prophage_kltype_associated = {}
    for phage in tqdm(dpos_of_phage) :
        kltypes = set()
        for dpo in dpos_of_phage[phage] :
            kltypes.update(kltypes_of_dpo[dpo])
        dico_prophage_kltype_associated[phage] = kltypes
    return graph_data , dico_prophage_kltype_associated


def build_graph_masking(graph_data, dico_prophage_kltype_associated , df_info, KL_type, ratio , f_train, f_test, f_eval, seed=None, inplace=False) :
    """Attach binary labels and train/test/eval masks for one KL type to the B1 nodes.

    Parameters
    ----------
    graph_data : HeteroData-like
        Baseline graph; only ``graph_data["B1"]`` is written to.
    dico_prophage_kltype_associated : dict
        Phage -> set of KL types associated with it. A phage is a candidate
        negative only when ``KL_type`` is absent from its set.
    df_info : pandas.DataFrame
        Same table the graph was built from (columns ``Phage`` and
        ``KL_type_LCA`` are read). B1 node ``i`` is the i-th unique phage.
    KL_type : str
        Target KL type; a phage is positive when the ``KL_type_LCA`` of its
        first row equals it.
    ratio : number
        Negatives drawn per positive in every split.
    f_train, f_test, f_eval : float
        Fraction of the positives assigned to each split (the three splits are
        disjoint, so the fractions should sum to at most 1).
    seed : int or None, default None
        When given, sampling uses a private ``random.Random(seed)``; otherwise
        the global ``random`` state is used, as before.
    inplace : bool, default False
        ``False`` annotates and returns a shallow copy, so graphs built for
        different KL types from one baseline do not overwrite each other.
        ``True`` restores the legacy behaviour of writing into ``graph_data``.

    Returns
    -------
    The annotated graph, with ``y`` (0/1 labels) and boolean ``train_mask``,
    ``test_mask`` and ``eval_mask`` set on the ``"B1"`` store.

    Raises
    ------
    ValueError
        From ``random.sample`` when a pool holds fewer candidates than requested
        (e.g. not enough negatives for the requested ``ratio``).
    """
    import copy  # local import: only needed here

    rng = random if seed is None else random.Random(seed)
    labelled_graph = graph_data if inplace else copy.copy(graph_data)
    # **************************************************************
    # Make the Y file :
    phages = df_info["Phage"].unique().tolist()
    first_kltypes = df_info.drop_duplicates(subset = ["Phage"], keep = "first")["KL_type_LCA"]
    B1_labels = [1 if kltype == KL_type else 0 for kltype in first_kltypes]
    labelled_graph["B1"].y = torch.tensor(B1_labels)
    # **************************************************************
    # Make mask files :
    # get the positive and negative indices lists :
    positive_indices = [index for index, label in enumerate(B1_labels) if label == 1]
    negative_indices = [
        index for index, phage in enumerate(phages)
        if KL_type not in dico_prophage_kltype_associated[phage]
    ]
    n_samples = len(positive_indices)

    def draw(pool_positives, pool_negatives, fraction) :
        # Positives first, then negatives: same order of random draws as before.
        chosen_pos = rng.sample(pool_positives, int(fraction*n_samples))
        chosen_neg = rng.sample(pool_negatives, int(fraction*n_samples*ratio))
        return chosen_pos, chosen_neg

    def to_mask(indices) :
        # Boolean mask: an int64 0/1 tensor would be read by torch as a list of
        # node positions (all 0s and 1s), not as a mask.
        mask = torch.zeros(len(B1_labels), dtype=torch.bool)
        mask[torch.tensor(indices, dtype=torch.long)] = True
        return mask

    # make train :
    train_pos, train_neg = draw(positive_indices, negative_indices, f_train)
    # make test : drawn from what train left over
    pool_positives_test = list(set(positive_indices) - set(train_pos))
    pool_negatives_test = list(set(negative_indices) - set(train_neg))
    test_pos, test_neg = draw(pool_positives_test, pool_negatives_test, f_test)
    # make eval : drawn from what train and test left over
    pool_positives_eval = list(set(positive_indices) - set(train_pos) - set(test_pos))
    pool_negatives_eval = list(set(negative_indices) - set(train_neg) - set(test_neg))
    eval_pos, eval_neg = draw(pool_positives_eval, pool_negatives_eval, f_eval)
    # Transfer data to graph :
    labelled_graph["B1"].train_mask = to_mask(train_pos + train_neg)
    labelled_graph["B1"].test_mask = to_mask(test_pos + test_neg)
    labelled_graph["B1"].eval_mask = to_mask(eval_pos + eval_neg)

    return labelled_graph


In [9]:
import os 
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings

import TropiGAT_functions
warnings.filterwarnings("ignore") 

# *****************************************************************************
# Load the dataframes.
# Folder that holds the TropiGAT input tables (machine-specific absolute path).
path_work = "/media/concha-eloko/Linux/PPT_clean"
#path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"

# Full table, tab-separated, first line used as header.
DF_info = pd.read_csv(
    f"{path_work}/TropiGATv2.final_df_v2.tsv",
    sep="\t",
    header=0,
)

# Ambiguity level 0: keep the rows whose KL type is a single label, i.e. whose
# KL_type_LCA does not contain the "|" separator.
DF_info_lvl_0 = DF_info.loc[lambda frame: ~frame["KL_type_LCA"].str.contains("\\|")]
# Same rows with repeated (Infected_ancestor, index, prophage_id) triplets
# collapsed to their first occurrence and the row index renumbered.
DF_info_lvl_0_filter1 = (
    DF_info_lvl_0
    .drop_duplicates(subset=["Infected_ancestor", "index", "prophage_id"], keep="first")
    .reset_index(drop=True)
)
#dico_prophage_kltype = {row["Phage"]:row["KL_type_LCA"] for _,row in DF_info_lvl_0.drop_duplicates(subset = ["Phage"]).iterrows()}

# Ambiguity level 1 (disabled) :
#DF_info_lvl_1 = pd.read_csv(f"{path_work}/TropiGATv2.ambiguity.lvl_1.tsv", sep = "," ,  header = 0)
#DF_info_lvl_1 = DF_info_lvl_1.drop_duplicates(subset = ["Infected_ancestor","index","prophage_id"] , keep = "first").reset_index(drop=True)

# Ambiguity level 2 (disabled) :
#DF_info_lvl_2 = pd.read_csv(f"{path_work}/TropiGATv2.ambiguity.lvl_2.tsv", sep = "," ,  header = 0)
#DF_info_lvl_2 = DF_info_lvl_2.drop_duplicates(subset = ["Infected_ancestor","index","prophage_id"] , keep = "first").reset_index(drop=True)


In [8]:
# Show the full table as loaded from disk, ambiguous "A|B" KL types included.
DF_info

Unnamed: 0,Phage,Protein_name,KL_type_LCA,Infected_ancestor,index,Dataset,seq,domain_seq,1,2,...,1272,1273,1274,1275,1276,1277,1278,1279,1280,prophage_id
0,GCF_902164905.1__phage1,GCF_902164905.1__phage1__34,KL41,GCF_902164905.1,minibatch__460,minibatch,MPATPQDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDE...,QDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDERTITT...,0.025276,0.053137,...,-0.011464,0.081105,0.012011,0.042917,0.009402,0.093175,-0.080562,0.000897,0.111854,prophage_11309
1,GCF_015910145.1__phage5,GCF_015910145.1__phage5__1351,KL122|KL106,n4984,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_5
2,GCF_900502315.1__phage13,GCF_900502315.1__phage13__356,KL122|KL106,n4984,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_5
3,GCF_004803085.1__phage3,GCF_004803085.1__phage3__24,KL122|KL106,n4984,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_5
4,GCF_017310305.1__phage5,GCF_017310305.1__phage5__1353,KL30,n4996,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21345,GCF_900407145.1__phage13,GCF_900407145.1__phage13__48,KL105,GCF_900407145.1,anubis_return__4247,anubis_return,MAEVPLPTPTQAPVPSTDIRNAVFAGAKLDEEVTGTGEFYTDRLGV...,NIASYDVTWFGAVASDDTATYTAANTVSIQNALNAAEKAGLAAVWF...,-0.002267,0.042625,...,0.013049,0.112762,0.002557,0.024400,0.066527,0.052364,-0.092847,0.054733,0.077084,prophage_8459
21346,GCF_002186895.1__phage9,GCF_002186895.1__phage9__5,KL57,GCF_002186895.1,anubis_return__4260,anubis_return,MRYRFIALALCLLSGSKVAISAGFDCSLANLSPTEKTICSNEYLSG...,ITDSPWLVKKIFSSDSFEGGINLEGMNVSSILTYQEIKNDLYIYIS...,0.073450,0.046651,...,0.035302,0.012151,0.003563,-0.022575,0.014130,0.063376,-0.050646,-0.085156,-0.010849,prophage_6002
21347,GCF_004312845.1__phage3,GCF_004312845.1__phage3__38,KL9,GCF_004312845.1,anubis_return__4275,anubis_return,MAILITGKSMTRLPESSSWEEEIELITRSERVAGGLDGPANRPLKS...,DAVIRRDLASDKGTSGVGKLGDKPLVAISYYKSKGQSDQDAVQAAF...,0.032196,0.048856,...,-0.016331,0.084711,0.056063,0.001793,0.073958,0.090169,-0.060105,0.023726,0.086452,prophage_12656
21348,GCF_900172635.1__phage2,GCF_900172635.1__phage2__1608,KL124,GCF_900172635.1,anubis_return__4287,anubis_return,MADLSISVISDQASESNQAGWWHPLDSFQGVEYYGLCKEYGTAGYH...,MADLSISVISDQASESNQAGWWHPLDSFQGVEYYGLCKEYGTAGYH...,-0.011089,-0.005328,...,0.034656,0.046130,0.012586,-0.021702,-0.023386,0.105700,-0.099147,-0.057367,0.091427,prophage_12780


In [10]:
# Show the level-0 table after dropping repeated
# (Infected_ancestor, index, prophage_id) rows.
DF_info_lvl_0_filter1

Unnamed: 0,Phage,Protein_name,KL_type_LCA,Infected_ancestor,index,Dataset,seq,domain_seq,1,2,...,1272,1273,1274,1275,1276,1277,1278,1279,1280,prophage_id
0,GCF_902164905.1__phage1,GCF_902164905.1__phage1__34,KL41,GCF_902164905.1,minibatch__460,minibatch,MPATPQDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDE...,QDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDERTITT...,0.025276,0.053137,...,-0.011464,0.081105,0.012011,0.042917,0.009402,0.093175,-0.080562,0.000897,0.111854,prophage_11309
1,GCF_017310305.1__phage5,GCF_017310305.1__phage5__1353,KL30,n4996,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_5
2,GCF_001701985.1__phage2,GCF_001701985.1__phage2__357,KL30,n4988,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_6465
3,GCF_001611095.1__phage5,GCF_001611095.1__phage5__1365,KL30,n49894989,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_5
4,GCF_902156555.1__phage3,GCF_902156555.1__phage3__511,KL30,GCF_902156555.1,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_1828
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11937,GCF_900407145.1__phage13,GCF_900407145.1__phage13__48,KL105,GCF_900407145.1,anubis_return__4247,anubis_return,MAEVPLPTPTQAPVPSTDIRNAVFAGAKLDEEVTGTGEFYTDRLGV...,NIASYDVTWFGAVASDDTATYTAANTVSIQNALNAAEKAGLAAVWF...,-0.002267,0.042625,...,0.013049,0.112762,0.002557,0.024400,0.066527,0.052364,-0.092847,0.054733,0.077084,prophage_8459
11938,GCF_002186895.1__phage9,GCF_002186895.1__phage9__5,KL57,GCF_002186895.1,anubis_return__4260,anubis_return,MRYRFIALALCLLSGSKVAISAGFDCSLANLSPTEKTICSNEYLSG...,ITDSPWLVKKIFSSDSFEGGINLEGMNVSSILTYQEIKNDLYIYIS...,0.073450,0.046651,...,0.035302,0.012151,0.003563,-0.022575,0.014130,0.063376,-0.050646,-0.085156,-0.010849,prophage_6002
11939,GCF_004312845.1__phage3,GCF_004312845.1__phage3__38,KL9,GCF_004312845.1,anubis_return__4275,anubis_return,MAILITGKSMTRLPESSSWEEEIELITRSERVAGGLDGPANRPLKS...,DAVIRRDLASDKGTSGVGKLGDKPLVAISYYKSKGQSDQDAVQAAF...,0.032196,0.048856,...,-0.016331,0.084711,0.056063,0.001793,0.073958,0.090169,-0.060105,0.023726,0.086452,prophage_12656
11940,GCF_900172635.1__phage2,GCF_900172635.1__phage2__1608,KL124,GCF_900172635.1,anubis_return__4287,anubis_return,MADLSISVISDQASESNQAGWWHPLDSFQGVEYYGLCKEYGTAGYH...,MADLSISVISDQASESNQAGWWHPLDSFQGVEYYGLCKEYGTAGYH...,-0.011089,-0.005328,...,0.034656,0.046130,0.012586,-0.021702,-0.023386,0.105700,-0.099147,-0.057367,0.091427,prophage_12780


In [2]:
# Keep the first row of every prophage, then count prophages per KL type.
df_prophages = DF_info_lvl_0.drop_duplicates(subset=["Phage"], keep="first")
dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))

# KL types represented by at least 20 prophages.
KLtypes = [
    kltype
    for kltype, n_prophages in dico_prophage_count.items()
    if n_prophages >= 20
]


In [10]:
# Re-order the counts by the numeric part of the label (KL1, KL2, ..., KL10, ...)
# rather than alphabetically; every key is expected to look like "KL<number>".
sorted_dico_prophage_count = {
    kltype: dico_prophage_count[kltype]
    for kltype in sorted(dico_prophage_count, key=lambda name: float(name.split("KL")[1]))
}
sorted_dico_prophage_count

{'KL1': 166,
 'KL2': 364,
 'KL3': 127,
 'KL4': 11,
 'KL5': 23,
 'KL6': 8,
 'KL7': 36,
 'KL8': 22,
 'KL9': 23,
 'KL10': 101,
 'KL11': 6,
 'KL12': 40,
 'KL13': 71,
 'KL14': 123,
 'KL15': 214,
 'KL16': 22,
 'KL17': 461,
 'KL18': 23,
 'KL19': 79,
 'KL20': 29,
 'KL21': 77,
 'KL22': 88,
 'KL23': 117,
 'KL24': 276,
 'KL25': 281,
 'KL26': 16,
 'KL27': 142,
 'KL28': 113,
 'KL29': 36,
 'KL30': 101,
 'KL31': 13,
 'KL33': 1,
 'KL34': 20,
 'KL35': 19,
 'KL36': 89,
 'KL37': 7,
 'KL38': 92,
 'KL39': 25,
 'KL40': 1,
 'KL41': 12,
 'KL42': 10,
 'KL43': 49,
 'KL45': 63,
 'KL46': 81,
 'KL47': 582,
 'KL48': 27,
 'KL49': 3,
 'KL51': 128,
 'KL52': 40,
 'KL53': 48,
 'KL54': 13,
 'KL55': 25,
 'KL56': 13,
 'KL57': 57,
 'KL58': 7,
 'KL59': 1,
 'KL60': 69,
 'KL61': 6,
 'KL62': 126,
 'KL63': 59,
 'KL64': 886,
 'KL66': 10,
 'KL67': 13,
 'KL70': 30,
 'KL71': 16,
 'KL74': 78,
 'KL81': 43,
 'KL82': 5,
 'KL101': 1,
 'KL102': 269,
 'KL103': 12,
 'KL104': 5,
 'KL105': 123,
 'KL106': 463,
 'KL107': 1066,
 'KL108': 28,
 'K

In [11]:
# Number of distinct KL types among the level-0 prophages.
len(sorted_dico_prophage_count)

129

In [6]:
%%time
# Build the label-free graph once; it is then passed to build_graph_masking
# for each KL type.
# NOTE(review): the graph is built from DF_info_lvl_0; the de-duplicated
# DF_info_lvl_0_filter1 is not used in the cells shown here — confirm which
# table is intended.
graph_baseline , dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(DF_info_lvl_0)

8815it [00:20, 427.60it/s]

CPU times: user 41.3 s, sys: 176 ms, total: 41.4 s
Wall time: 41.8 s





In [16]:
%%time
# Label the prophage nodes for KL64 and draw the train/test/eval masks.
# Positional arguments, assuming the module function shares the signature of the
# notebook-local build_graph_masking (TODO confirm): ratio=2 negatives per
# positive, then f_train=0.80, f_test=0.1, f_eval=0.1.
graph_KL64 = TropiGAT_graph.build_graph_masking(graph_baseline , 
                                 dico_prophage_kltype_associated, 
                                 DF_info_lvl_0, 
                                 "KL64",
                                 2, 0.80, 0.1,0.1)

CPU times: user 225 ms, sys: 36.4 ms, total: 261 ms
Wall time: 261 ms


In [24]:
# Inspect the prophage (B1) store: empty features, labels and the three masks.
graph_KL64["B1"]

{'x': tensor([], size=(8815, 0)), 'y': tensor([0, 0, 0,  ..., 0, 0, 0]), 'train_mask': tensor([False, False, False,  ...,  True,  True, False]), 'test_mask': tensor([False, False, False,  ..., False, False, False]), 'eval_mask': tensor([False, False, False,  ..., False, False, False])}

In [26]:
%%time
# Same labelling for KL37 (ratio=2, f_train=0.8, f_test=0.1, f_eval=0.1).
# NOTE(review): this calls the notebook-local build_graph_masking, whereas
# graph_KL64 above was built with TropiGAT_graph.build_graph_masking — confirm
# the two implementations agree. The local one writes into the graph it is
# given, so graph_KL37 and graph_baseline are the same object.
graph_KL37 = build_graph_masking(graph_baseline , 
                                 dico_prophage_kltype_associated, 
                                 DF_info_lvl_0, 
                                 "KL37",
                                 2, 0.8, 0.1,0.1)

CPU times: user 83.8 ms, sys: 32.5 ms, total: 116 ms
Wall time: 529 ms


In [28]:
# Class balance of the B1 labels in the KL64 graph (label values -> node count).
Counter(graph_KL64["B1"].y.numpy())

Counter({0: 7929, 1: 886})

In [31]:
# Number of B1 nodes selected by the training mask (relies on train_mask being
# a boolean tensor).
len(graph_KL64["B1"].y[graph_KL64["B1"].train_mask])

2125

In [32]:
# Total number of B1 (prophage) nodes carrying a label.
len(graph_KL64["B1"].y)

8815