In [2]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, to_hetero , SAGEConv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

import torch
from torch import nn 
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns

import os 
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
warnings.filterwarnings("ignore") 



In [78]:
import TropiGAT_functions
import 

In [6]:
# *****************************************************************************
# Load the Dataframes :
path_work = "/media/concha-eloko/Linux/PPT_clean"
#path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"

    # Open the DF
DF_info = pd.read_csv(f"{path_work}/TropiGATv2.final_df.tsv", sep = "\t" ,  header = 0)
dico_prophage_kltype = {row["Phage"]:row["KL_type_LCA"] for _,row in DF_info_lvl_0.drop_duplicates(subset = ["Phage"]).iterrows()}

# Ambiguous ones :
# level 0 :
DF_info_lvl_0 = DF_info[~DF_info["KL_type_LCA"].str.contains("\\|")]
DF_info_lvl_0 = DF_info_lvl_0.drop_duplicates(subset = ["Infected_ancestor","index","prophage_id"] , keep = "first").reset_index(drop=True)

# level 1 :
#DF_info_lvl_1 = pd.read_csv(f"{path_work}/TropiGATv2.ambiguity.lvl_1.tsv", sep = "," ,  header = 0)
#DF_info_lvl_1 = DF_info_lvl_1.drop_duplicates(subset = ["Infected_ancestor","index","prophage_id"] , keep = "first").reset_index(drop=True)

# level 2 :
#DF_info_lvl_2 = pd.read_csv(f"{path_work}/TropiGATv2.ambiguity.lvl_2.tsv", sep = "," ,  header = 0)
#DF_info_lvl_2 = DF_info_lvl_2.drop_duplicates(subset = ["Infected_ancestor","index","prophage_id"] , keep = "first").reset_index(drop=True)


In [7]:
DF_info_lvl_0

Unnamed: 0,Phage,Protein_name,KL_type_LCA,Infected_ancestor,index,Dataset,seq,domain_seq,1,2,...,1272,1273,1274,1275,1276,1277,1278,1279,1280,prophage_id
0,GCF_016651625.1__phage29,GCF_016651625.1__phage29__142,KL43,GCF_016651625.1,ppt__1,ppt,MSVPNQTPYNIYTANGLTTVFTYEFYIISASDLRVSINGDVVTSGY...,KDFVNINDYWFPTDGDDFYPALNKALSVSPHVLIPPGKHYLKSTVS...,-0.018416,0.022387,...,0.004437,0.087907,0.015800,0.025778,0.065790,0.034045,-0.070899,0.016068,0.065339,prophage_12186
1,GCF_016651625.1__phage29,GCF_016651625.1__phage29__150,KL43,GCF_016651625.1,anubis__0,anubis,MRANLIKTNFTAGEISPRLMGRVDIARYANGAKIIENAVCVVQGGV...,QAASPGAWTREDTVWTEEFGYPGAVTLYQQRLVLAGSPQYPQTIWW...,0.036016,0.005938,...,-0.037612,0.008772,0.010556,-0.049738,-0.012549,0.092624,-0.136602,-0.191378,0.135658,prophage_12186
2,GCF_016651625.1__phage12,GCF_016651625.1__phage12__59,KL43,GCF_016651625.1,ppt__4,ppt,MSISKRNFLKAVSCAYFFYSFKALTKVNQPIEDYISTKDKNTWPSK...,NTWPSKVHRVEEFYTSTDRDYSDAILRGINYCSLNNCVLFFSDKYK...,0.026004,0.024372,...,-0.026018,0.018206,0.036751,-0.032549,0.064112,0.061520,-0.024423,-0.027998,0.028089,prophage_924
3,GCF_019928025.1__phage0,GCF_019928025.1__phage0__10,KL43,n1471,ppt__4,ppt,MSISKRNFLKAVSCAYFFYSFKALTKVNQPIEDYISTKDKNTWPSK...,NTWPSKVHRVEEFYTSTDRDYSDAILRGINYCSLNNCVLFFSDKYK...,0.026004,0.024372,...,-0.026018,0.018206,0.036751,-0.032549,0.064112,0.061520,-0.024423,-0.027998,0.028089,prophage_2929
4,GCF_004313505.1__phage4,GCF_004313505.1__phage4__113,KL14,GCF_004313505.1,anubis__5,anubis,MSEYDTGNPVPSASMPDAWDNMQSIDKFVNSSDETITTRTGQQLDT...,KAIFDAWLDFGIDWNGNESISLQLQTAVNYVSKLPYGGEIVLRPGV...,-0.023648,0.052674,...,-0.025991,0.068538,-0.051192,0.026481,0.069100,0.017813,-0.103797,0.018961,0.117058,prophage_11091
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11097,GCF_002248635.1__phage4,GCF_002248635.1__phage4__44,KL102,n320,anubis_return__4264,anubis_return,MVSLKGMGSTFRDCTALISLPSGLLDGCINLTSLTLTFSGCTSLAL...,MVSLKGMGSTFRDCTALISLPSGLLDGCINLTSLTLTFSGCTSLAL...,-0.000585,-0.087093,...,0.036749,0.048489,0.020484,0.023950,-0.048109,0.134457,-0.101326,0.088485,0.037368,prophage_3054
11098,GCF_001905235.1__phage21,GCF_001905235.1__phage21__0,KL107,n35403540,anubis_return__4272,anubis_return,MLKHSLAIATCLAFSSSVMGNEANLLYTNTMQFPYKHNADGYMVFD...,VMGNEANLLYTNTMQFPYKHNADGYMVFDIHGKLVVPPEGHFDTLN...,0.076721,0.027635,...,0.042391,-0.004292,-0.004047,-0.011631,-0.026469,0.070159,-0.077212,-0.077950,-0.034630,prophage_313
11099,GCF_004312845.1__phage3,GCF_004312845.1__phage3__38,KL9,GCF_004312845.1,anubis_return__4275,anubis_return,MAILITGKSMTRLPESSSWEEEIELITRSERVAGGLDGPANRPLKS...,DAVIRRDLASDKGTSGVGKLGDKPLVAISYYKSKGQSDQDAVQAAF...,0.032196,0.048856,...,-0.016331,0.084711,0.056063,0.001793,0.073958,0.090169,-0.060105,0.023726,0.086452,prophage_12656
11100,GCF_900172635.1__phage2,GCF_900172635.1__phage2__1608,KL124,GCF_900172635.1,anubis_return__4287,anubis_return,MADLSISVISDQASESNQAGWWHPLDSFQGVEYYGLCKEYGTAGYH...,MADLSISVISDQASESNQAGWWHPLDSFQGVEYYGLCKEYGTAGYH...,-0.011089,-0.005328,...,0.034656,0.046130,0.012586,-0.021702,-0.023386,0.105700,-0.099147,-0.057367,0.091427,prophage_12780


In [81]:
%%time
dico_prophage_kltype_associated = {}
for negative_index,phage in tqdm(enumerate(DF_info_lvl_0["Phage"].unique().tolist())) :
    kltypes = set()
    dpos = DF_info_lvl_0[DF_info_lvl_0["Phage"] == phage]["index"]
    for dpo in dpos : 
        tmp_kltypes = DF_info_lvl_0[DF_info_lvl_0["index"] == dpo]["KL_type_LCA"].values
        kltypes.update(tmp_kltypes)
    dico_prophage_kltype_associated[phage] = kltypes

8815it [00:17, 500.81it/s]

CPU times: user 18 s, sys: 116 ms, total: 18.1 s
Wall time: 17.6 s





In [None]:
dico_prophage_kltype = {row["Phage"]:row["KL_type_LCA"] for _,row in DF_info_lvl_0.drop_duplicates(subset = ["Phage"]).iterrows()}

In [82]:
import pprint
clean_print(dico_prophage_kltype_associated)

{'GCF_000019565.1__phage12': {'KL30'},
 'GCF_000240185.1__phage5': {'KL10', 'KL103', 'KL105', 'KL106', 'KL107', 'KL108', 'KL111', 'KL112', 'KL125', 'KL14', 'KL15', 'KL155', 'KL19', 'KL2',
                             'KL21', 'KL23', 'KL24', 'KL27', 'KL28', 'KL36', 'KL47', 'KL52', 'KL55', 'KL64', 'KL74'},
 'GCF_000240185.1__phage6': {'KL103', 'KL105', 'KL106', 'KL107', 'KL108', 'KL111', 'KL112', 'KL15', 'KL155', 'KL19', 'KL2', 'KL21', 'KL24', 'KL27',
                             'KL28', 'KL36', 'KL38', 'KL46', 'KL47', 'KL52', 'KL64'},
 'GCF_000276705.2__phage3': {'KL145'},
 'GCF_000281335.1__phage14': {'KL10', 'KL103', 'KL105', 'KL106', 'KL107', 'KL108', 'KL111', 'KL112', 'KL125', 'KL14', 'KL15', 'KL155', 'KL19', 'KL2',
                              'KL21', 'KL23', 'KL24', 'KL27', 'KL28', 'KL36', 'KL47', 'KL52', 'KL55', 'KL64', 'KL74'},
 'GCF_000281335.1__phage16': {'KL15', 'KL47', 'KL107', 'KL106', 'KL27', 'KL36', 'KL24', 'KL21', 'KL64', 'KL14'},
 'GCF_000281335.1__phage25': {'KL107'},

In [14]:
dico_prophage_kltype = {row["Phage"]:row["KL_type_LCA"] for _,row in DF_info_lvl_0.drop_duplicates(subset = ["Phage"]).iterrows()}




In [99]:
def build_graph_baseline(df_info) : 
    # **************************************************************
    # initialize the graph
    graph_data = HeteroData()
    # Indexation process  
    indexation_nodes_A = df_info["Infected_ancestor"].unique().tolist()  
    indexation_nodes_B1 = df_info["Phage"].unique().tolist()
    indexation_nodes_B2 = df_info["index"].unique().tolist() 
    ID_nodes_A = {item:index for index, item in enumerate(indexation_nodes_A)}
    ID_nodes_A_r = {index:item for index, item in enumerate(indexation_nodes_A)}
    ID_nodes_B1 = {item:index for index, item in enumerate(indexation_nodes_B1)}
    ID_nodes_B1_r = {index:item for index, item in enumerate(indexation_nodes_B1)}
    ID_nodes_B2 = {item:index for index, item in enumerate(indexation_nodes_B2)}
    ID_nodes_B2_r = {index:item for index, item in enumerate(indexation_nodes_B2)}
    # **************************************************************
    # Make the node feature file : 
    OHE = OneHotEncoder(sparse=False)
    one_hot_encoded = OHE.fit_transform(df_info[["KL_type_LCA"]])
    label_mapping = {label: one_hot_encoded[i] for i, label in enumerate(OHE.categories_[0])}
    embeddings_columns = [str(i) for i in range(1, 1281)]
    node_feature_A = torch.tensor([label_mapping[df_info[df_info["Infected_ancestor"] == ID_nodes_A_r[i]]["KL_type_LCA"].values[0]] for i in range(0,len(ID_nodes_A_r))], dtype=torch.float)
    node_feature_B1 = torch.zeros((len(ID_nodes_B1), 0), dtype=torch.float)
    node_feature_B2 = torch.tensor([df_info[df_info["index"] == ID_nodes_B2_r[i]][embeddings_columns].values[0].tolist() for i in range(0,len(ID_nodes_B2_r))] , dtype=torch.float)
    # feed the graph
    graph_data["A"].x = node_feature_A
    graph_data["B1"].x = node_feature_B1
    graph_data["B2"].x = node_feature_B2
    # **************************************************************
    # Make edge file
    # Node B1 (prophage) - Node A (bacteria) :
    edge_index_B1_A = []
    track_B1_A = set()
    for _, row in df_info.iterrows() :
        pair = [ID_nodes_B1[row["Phage"]], ID_nodes_A[row["Infected_ancestor"]]]
        if tuple(pair) not in track_B1_A : 
            track_B1_A.add(tuple(pair))
            edge_index_B1_A.append(pair)
        else :
            continue
    edge_index_B1_A = torch.tensor(edge_index_B1_A , dtype=torch.long)
    # Node A (bacteria) - Node B1 (prophage) :
    edge_index_A_B1 = []
    track_A_B1 = set()
    for _, row in df_info.iterrows() :
        pair = [ID_nodes_A[row["Infected_ancestor"]] , ID_nodes_B1[row["Phage"]]]
        if tuple(pair) not in track_A_B1 :
            track_A_B1.add(tuple(pair))
            edge_index_A_B1.append(pair)
    edge_index_A_B1 = torch.tensor(edge_index_A_B1 , dtype=torch.long)
    # Node B2 (depolymerase) - Node B1 (prophage) :
    edge_index_B2_B1 = []
    for phage in df_info.Phage.unique() :
        all_data_phage = df_info[df_info["Phage"] == phage]
        for _, row in all_data_phage.iterrows() :
            edge_index_B2_B1.append([ID_nodes_B2[row["index"]], ID_nodes_B1[row["Phage"]]])
    edge_index_B2_B1 = torch.tensor(edge_index_B2_B1 , dtype=torch.long)
    # feed the graph
    graph_data['B1', 'infects', 'A'].edge_index = edge_index_B1_A.t().contiguous()
    graph_data['B2', 'expressed', 'B1'].edge_index = edge_index_B2_B1.t().contiguous()
    # That one is optional  
    graph_data['A', 'harbors', 'B1'].edge_index = edge_index_A_B1.t().contiguous()
    
    return graph_data 



In [100]:
graph_lvl_0 = build_graph_baseline(DF_info_lvl_0)

In [24]:
graph_lvl_0 

HeteroData(
  [1mA[0m={ x=[5069, 129] },
  [1mB1[0m={ x=[8815, 0] },
  [1mB2[0m={ x=[4105, 1280] },
  [1m(B1, infects, A)[0m={ edge_index=[2, 8815] },
  [1m(B2, expressed, B1)[0m={ edge_index=[2, 11102] },
  [1m(A, harbors, B1)[0m={ edge_index=[2, 8815] }
)

In [71]:
indexation_nodes_A = DF_info_lvl_0["Infected_ancestor"].unique().tolist()  
indexation_nodes_B1 = DF_info_lvl_0["Phage"].unique().tolist()
indexation_nodes_B2 = DF_info_lvl_0["index"].unique().tolist() 
ID_nodes_A = {item:index for index, item in enumerate(indexation_nodes_A)}
ID_nodes_A_r = {index:item for index, item in enumerate(indexation_nodes_A)}
ID_nodes_B1 = {item:index for index, item in enumerate(indexation_nodes_B1)}
ID_nodes_B1_r = {index:item for index, item in enumerate(indexation_nodes_B1)}
ID_nodes_B2 = {item:index for index, item in enumerate(indexation_nodes_B2)}
ID_nodes_B2_r = {index:item for index, item in enumerate(indexation_nodes_B2)}

In [None]:
DF_info_lvl_0_phages = DF_info_lvl_0.drop_duplicates(subset = ["Phage"], keep = "first")

In [54]:
%%time
y = DF_info_lvl_0.drop_duplicates(subset = ["Phage"], keep = "first")["KL_type_LCA"].apply(lambda x : 1 if x == "KL64" else 0).to_list()
y_torch = torch.tensor(y)

CPU times: user 19.4 ms, sys: 20 ms, total: 39.4 ms
Wall time: 37.9 ms


In [101]:
def build_graph_masking(graph_data , df_info, KL_type, ratio , f_train, f_test, f_eval) : 
    # **************************************************************
    # Indexation process  
    indexation_nodes_A = df_info["Infected_ancestor"].unique().tolist()  
    indexation_nodes_B1 = df_info["Phage"].unique().tolist()
    indexation_nodes_B2 = df_info["index"].unique().tolist() 
    ID_nodes_A = {item:index for index, item in enumerate(indexation_nodes_A)}
    ID_nodes_A_r = {index:item for index, item in enumerate(indexation_nodes_A)}
    ID_nodes_B1 = {item:index for index, item in enumerate(indexation_nodes_B1)}
    ID_nodes_B1_r = {index:item for index, item in enumerate(indexation_nodes_B1)}
    ID_nodes_B2 = {item:index for index, item in enumerate(indexation_nodes_B2)}
    ID_nodes_B2_r = {index:item for index, item in enumerate(indexation_nodes_B2)}
    # **************************************************************
    # Make the Y file : 
    B1_labels = df_info.drop_duplicates(subset = ["Phage"], keep = "first")["KL_type_LCA"].apply(lambda x : 1 if x == KL_type else 0).to_list()
    graph_data["B1"].y = torch.tensor(B1_labels)
    # **************************************************************
    # Make mask files :
    # get the positive and negative indices lists :
    positive_indices = [index for index,label in enumerate(B1_labels) if label==1]
    negative_indices = []
    for negative_index,phage in enumerate(DF_info_lvl_0["Phage"].unique().tolist()) :
        if KL_type not in dico_prophage_kltype_associated[ID_nodes_B1_r[negative_index]] :
            negative_indices.append(negative_index)
    # make the train, test, val lists : 
    n_samples = len(positive_indices)
    #train_indices, test_indices, val_indices = [],[],[]
    # make train : 
    train_pos = random.sample(positive_indices, int(f_train*n_samples))
    train_neg = random.sample(negative_indices, int(f_train*n_samples*ratio))
    train_indices = train_pos + train_neg
    train_mask = [1 if n in train_indices else 0 for n in range(0,len(B1_labels))]
    # make test : 
    pool_positives_test = list(set(positive_indices) - set(train_pos))
    pool_negatives_test = list(set(negative_indices) - set(train_neg))
    test_pos = random.sample(pool_positives_test, int(f_test*n_samples))
    test_neg = random.sample(pool_negatives_test, int(f_test*n_samples*ratio))
    test_indices = test_pos + test_neg
    test_mask = [1 if n in test_indices else 0 for n in range(0,len(B1_labels))]
    # make eval
    pool_positives_eval = list(set(positive_indices) - set(train_pos) - set(test_pos))
    pool_negatives_eval = list(set(negative_indices) - set(train_neg) - set(test_neg))
    eval_pos = random.sample(pool_positives_eval, int(f_eval*n_samples))
    eval_neg = random.sample(pool_negatives_eval, int(f_eval*n_samples*ratio))
    eval_indices = eval_pos + eval_neg
    eval_mask = [1 if n in eval_indices else 0 for n in range(0,len(B1_labels))]
    # Transfer data to graph :
    graph_data["B1"].train_mask = torch.tensor(train_mask)
    graph_data["B1"].test_mask = torch.tensor(test_mask)
    graph_data["B1"].eval_mask = torch.tensor(eval_mask)

    return graph_data




In [103]:
graph_lvl_0_KL64 = build_graph_masking(graph_lvl_0 ,DF_info_lvl_0, "KL64", 2, 0.8,0.1,0.1 )

graph_lvl_0_KL64

HeteroData(
  [1mA[0m={ x=[5069, 129] },
  [1mB1[0m={
    x=[8815, 0],
    y=[8815],
    train_mask=[8815],
    test_mask=[8815],
    eval_mask=[8815]
  },
  [1mB2[0m={ x=[4105, 1280] },
  [1m(B1, infects, A)[0m={ edge_index=[2, 8815] },
  [1m(B2, expressed, B1)[0m={ edge_index=[2, 11102] },
  [1m(A, harbors, B1)[0m={ edge_index=[2, 8815] }
)

In [None]:
def build_graph_masking(graph_data , df_info, KL_type, ratio , f_train, f_test, f_eval) : 


In [86]:
negative_indices = []
for negative_index,phage in enumerate(DF_info_lvl_0["Phage"].unique().tolist()) :
    if "KL64" not in dico_prophage_kltype_associated[ID_nodes_B1_r[negative_index]] :
        negative_indices.append(negative_index)
len(negative_indices)

6386

In [69]:
y = DF_info_lvl_0.drop_duplicates(subset = ["Phage"], keep = "first")["KL_type_LCA"].apply(lambda x : 1 if x == "KL64" else 0).to_list()
positive_indices = [index for index,label in enumerate(y) if label==1]

positive_indices

[379,
 381,
 384,
 390,
 392,
 393,
 396,
 401,
 402,
 404,
 411,
 412,
 413,
 419,
 423,
 427,
 428,
 429,
 432,
 454,
 458,
 461,
 466,
 481,
 486,
 492,
 497,
 498,
 500,
 501,
 503,
 504,
 508,
 512,
 515,
 516,
 519,
 521,
 522,
 525,
 528,
 536,
 541,
 543,
 547,
 548,
 549,
 552,
 557,
 560,
 562,
 563,
 566,
 568,
 569,
 572,
 574,
 578,
 582,
 586,
 588,
 589,
 590,
 600,
 610,
 613,
 614,
 615,
 616,
 619,
 623,
 625,
 626,
 627,
 636,
 641,
 644,
 646,
 651,
 653,
 654,
 662,
 663,
 666,
 668,
 673,
 984,
 989,
 990,
 991,
 992,
 993,
 994,
 996,
 997,
 998,
 999,
 1000,
 1001,
 1002,
 1003,
 1004,
 1005,
 1006,
 1009,
 1010,
 1011,
 1013,
 1015,
 1016,
 1020,
 1022,
 1025,
 1027,
 1028,
 1030,
 1031,
 1032,
 1033,
 1034,
 1040,
 1041,
 1043,
 1045,
 1046,
 1047,
 1048,
 1050,
 1051,
 1052,
 1054,
 1055,
 1057,
 1062,
 1063,
 1064,
 1065,
 1066,
 1067,
 1069,
 1070,
 1072,
 1073,
 1074,
 1076,
 1077,
 1079,
 1083,
 1084,
 1086,
 1089,
 1094,
 1096,
 1099,
 1100,
 1101,
 1102

In [96]:
n_samples = len(positive_indices)
train_pos = random.sample(positive_indices, int(0.8*n_samples))
train_neg = random.sample(negative_indices, int(0.8*n_samples*2))
len(train_neg) , len(train_pos)

train_indices = train_neg + train_pos
train_mask = [1 if n in train_indices else 0 for n in range(0,10000)]

train_mask

[0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 1,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 1,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,


In [75]:
DF_info_lvl_0[DF_info_lvl_0["Phage"] == "GCF_003037915.1__phage4"]

Unnamed: 0,Phage,Protein_name,KL_type_LCA,Infected_ancestor,index,Dataset,seq,domain_seq,1,2,...,1272,1273,1274,1275,1276,1277,1278,1279,1280,prophage_id
636,GCF_003037915.1__phage4,GCF_003037915.1__phage4__93,KL64,n2575,anubis__13,anubis,MSEYDTGNPVPSASMPDAWDNMQSIDKFVNSSEETITTRTGEQLDT...,ALSNEVEIYRNGNRDNPRDRVLYREFSRIGRNGALTERIVKDIPTG...,-0.017392,0.117156,...,0.001901,0.111698,-0.000487,0.01816,0.046598,0.024505,-0.104585,0.059359,0.046143,prophage_14


In [65]:
#graph_lvl_0_KL64 = build_graph_y_files(graph_lvl_0, DF_info_lvl_0, "KL64")
graph_lvl_0_KL64

HeteroData(
  [1mA[0m={ x=[5069, 129] },
  [1mB1[0m={
    x=[8815, 0],
    y=[8815]
  },
  [1mB2[0m={ x=[4105, 1280] },
  [1m(B1, infects, A)[0m={ edge_index=[2, 8815] },
  [1m(B2, expressed, B1)[0m={ edge_index=[2, 11102] },
  [1m(A, harbors, B1)[0m={ edge_index=[2, 8815] }
)

In [64]:
dico_count = dict(Counter(DF_info_lvl_0.drop_duplicates(subset = ["Phage"], keep = "first")["KL_type_LCA"]))
TropiGAT_functions.clean_print(dico_count)

{'KL1': 166,
 'KL10': 101,
 'KL101': 1,
 'KL102': 269,
 'KL103': 12,
 'KL104': 5,
 'KL105': 123,
 'KL106': 463,
 'KL107': 1066,
 'KL108': 28,
 'KL109': 15,
 'KL11': 6,
 'KL110': 62,
 'KL111': 69,
 'KL112': 69,
 'KL113': 4,
 'KL114': 22,
 'KL115': 3,
 'KL116': 27,
 'KL117': 18,
 'KL118': 26,
 'KL119': 5,
 'KL12': 40,
 'KL120': 1,
 'KL121': 1,
 'KL122': 35,
 'KL123': 35,
 'KL124': 15,
 'KL125': 30,
 'KL126': 7,
 'KL127': 36,
 'KL128': 20,
 'KL13': 71,
 'KL130': 1,
 'KL131': 3,
 'KL132': 3,
 'KL134': 4,
 'KL136': 32,
 'KL137': 6,
 'KL139': 11,
 'KL14': 123,
 'KL140': 13,
 'KL141': 8,
 'KL142': 15,
 'KL143': 11,
 'KL144': 1,
 'KL145': 28,
 'KL146': 6,
 'KL147': 3,
 'KL148': 2,
 'KL149': 64,
 'KL15': 214,
 'KL150': 3,
 'KL151': 45,
 'KL152': 10,
 'KL153': 17,
 'KL154': 1,
 'KL155': 12,
 'KL157': 11,
 'KL158': 6,
 'KL159': 13,
 'KL16': 22,
 'KL162': 4,
 'KL163': 5,
 'KL164': 7,
 'KL165': 1,
 'KL166': 8,
 'KL169': 29,
 'KL17': 461,
 'KL170': 2,
 'KL18': 23,
 'KL19': 79,
 'KL2': 364,
 'KL20': 

In [None]:
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    mask = data['paper'].train_mask
    loss = F.cross_entropy(out['paper'][mask], data['paper'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)