In [43]:
# Torch geometric modules
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero , HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

# Torch modules
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# SKlearn modules
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

# Ground modules
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging
from multiprocessing.pool import ThreadPool

# TropiGAT modules
import TropiGAT_graph
import TropiGAT_models

warnings.filterwarnings("ignore")

# *****************************************************************************
# Load the Dataframes :
#path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"
path_work = "/media/concha-eloko/Linux/PPT_clean"

DF_info = pd.read_csv(f"{path_work}/TropiGATv2.final_df_v2.tsv", sep = "\t" ,  header = 0)
DF_info = DF_info.drop_duplicates(subset = ["Protein_name"])
DF_info = DF_info[~DF_info["KL_type_LCA"].str.contains("\\|")]

df_prophages = DF_info.drop_duplicates(subset = ["Phage"], keep = "first")
dico_prophage_info = {row["Phage"] : {"prophage_strain" : row["prophage_id"] , "ancestor" : row["Infected_ancestor"], "KL_type" : row["KL_type_LCA"]} for _,row in df_prophages.iterrows()}


In [44]:
DF_info

Unnamed: 0,Phage,Protein_name,KL_type_LCA,Infected_ancestor,index,Dataset,seq,domain_seq,1,2,...,1272,1273,1274,1275,1276,1277,1278,1279,1280,prophage_id
0,GCF_902164905.1__phage1,GCF_902164905.1__phage1__34,KL41,GCF_902164905.1,minibatch__460,minibatch,MPATPQDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDE...,QDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDERTITT...,0.025276,0.053137,...,-0.011464,0.081105,0.012011,0.042917,0.009402,0.093175,-0.080562,0.000897,0.111854,prophage_11309
4,GCF_017310305.1__phage5,GCF_017310305.1__phage5__1353,KL30,n4996,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_5
10,GCF_001701985.1__phage2,GCF_001701985.1__phage2__357,KL30,n4988,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_6465
12,GCF_001611095.1__phage5,GCF_001611095.1__phage5__1365,KL30,n49894989,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_5
15,GCF_902156555.1__phage3,GCF_902156555.1__phage3__511,KL30,GCF_902156555.1,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_1828
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21342,GCF_900506765.1__phage17,GCF_900506765.1__phage17__90,KL149,GCF_900506765.1,anubis_return__4216,anubis_return,MMTTLNEHPQWESDIYLIKRSDLVAGGRGGIANMQAQQLANRTAFL...,NRRWFRRFTGNIRAEWSGIHDLSQSSAPVDSYIYRLLLASAVGSPD...,0.053704,0.083858,...,0.032803,0.109572,0.010032,0.024949,0.094129,0.028693,-0.061396,0.006824,0.046220,prophage_15598
21344,GCF_003255785.1__phage1,GCF_003255785.1__phage1__10,KL127,GCF_003255785.1,anubis_return__4239,anubis_return,MNGLNHNALTCSAVPIPPWERSLQTVEAQPYFSVSQASLVLEGIVF...,MNGLNHNALTCSAVPIPPWERSLQTVEAQPYFSVSQASLVLEGIVF...,0.010626,-0.025389,...,0.045372,0.009262,-0.008319,-0.050856,0.034115,0.101663,-0.108278,-0.135629,0.102486,prophage_3577
21346,GCF_002186895.1__phage9,GCF_002186895.1__phage9__5,KL57,GCF_002186895.1,anubis_return__4260,anubis_return,MRYRFIALALCLLSGSKVAISAGFDCSLANLSPTEKTICSNEYLSG...,ITDSPWLVKKIFSSDSFEGGINLEGMNVSSILTYQEIKNDLYIYIS...,0.073450,0.046651,...,0.035302,0.012151,0.003563,-0.022575,0.014130,0.063376,-0.050646,-0.085156,-0.010849,prophage_6002
21347,GCF_004312845.1__phage3,GCF_004312845.1__phage3__38,KL9,GCF_004312845.1,anubis_return__4275,anubis_return,MAILITGKSMTRLPESSSWEEEIELITRSERVAGGLDGPANRPLKS...,DAVIRRDLASDKGTSGVGKLGDKPLVAISYYKSKGQSDQDAVQAAF...,0.032196,0.048856,...,-0.016331,0.084711,0.056063,0.001793,0.073958,0.090169,-0.060105,0.023726,0.086452,prophage_12656


In [45]:
def get_filtered_prophages(prophage) :
    combinations = []
    to_exclude = set()
    to_keep = set()
    to_keep.add(prophage)
    df_prophage_group = DF_info[(DF_info["prophage_id"] == dico_prophage_info[prophage]["prophage_strain"]) & (DF_info["Infected_ancestor"] == dico_prophage_info[prophage]["ancestor"])]
    if len(df_prophage_group) == 1 :
        pass
    else :
        depo_set = set(df_prophage_group[df_prophage_group["Phage"] == prophage]["domain_seq"].values)
        for prophage_tmp in df_prophage_group["Phage"].unique().tolist() :
            if prophage_tmp != prophage :
                tmp_depo_set = set(df_prophage_group[df_prophage_group["Phage"] == prophage_tmp]["domain_seq"].values)
                if depo_set == tmp_depo_set :
                    to_exclude.add(prophage_tmp)
                else :
                    if tmp_depo_set not in combinations :
                        to_keep.add(prophage_tmp)
                        combinations.append(tmp_depo_set)
                    else :
                        to_exclude.add(prophage_tmp)
    return df_prophage_group , to_exclude , to_keep

good_prophages = set()
excluded_prophages = set()

for prophage, info_prophage in tqdm(dico_prophage_info.items()) :
    if prophage not in excluded_prophages and prophage not in good_prophages:
        _, excluded_members , kept_members = get_filtered_prophages(prophage)
        good_prophages.update(kept_members)
        excluded_prophages.update(excluded_members)

DF_info_lvl_0_filtered = DF_info[DF_info["Phage"].isin(good_prophages)]

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15218/15218 [00:15<00:00, 995.73it/s]


In [13]:
DF_info_lvl_0_filtered

Unnamed: 0,Phage,Protein_name,KL_type_LCA,Infected_ancestor,index,Dataset,seq,domain_seq,1,2,...,1272,1273,1274,1275,1276,1277,1278,1279,1280,prophage_id
0,GCF_902164905.1__phage1,GCF_902164905.1__phage1__34,KL41,GCF_902164905.1,minibatch__460,minibatch,MPATPQDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDE...,QDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDERTITT...,0.025276,0.053137,...,-0.011464,0.081105,0.012011,0.042917,0.009402,0.093175,-0.080562,0.000897,0.111854,prophage_11309
4,GCF_017310305.1__phage5,GCF_017310305.1__phage5__1353,KL30,n4996,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_5
10,GCF_001701985.1__phage2,GCF_001701985.1__phage2__357,KL30,n4988,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_6465
12,GCF_001611095.1__phage5,GCF_001611095.1__phage5__1365,KL30,n49894989,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_5
15,GCF_902156555.1__phage3,GCF_902156555.1__phage3__511,KL30,GCF_902156555.1,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_1828
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21342,GCF_900506765.1__phage17,GCF_900506765.1__phage17__90,KL149,GCF_900506765.1,anubis_return__4216,anubis_return,MMTTLNEHPQWESDIYLIKRSDLVAGGRGGIANMQAQQLANRTAFL...,NRRWFRRFTGNIRAEWSGIHDLSQSSAPVDSYIYRLLLASAVGSPD...,0.053704,0.083858,...,0.032803,0.109572,0.010032,0.024949,0.094129,0.028693,-0.061396,0.006824,0.046220,prophage_15598
21344,GCF_003255785.1__phage1,GCF_003255785.1__phage1__10,KL127,GCF_003255785.1,anubis_return__4239,anubis_return,MNGLNHNALTCSAVPIPPWERSLQTVEAQPYFSVSQASLVLEGIVF...,MNGLNHNALTCSAVPIPPWERSLQTVEAQPYFSVSQASLVLEGIVF...,0.010626,-0.025389,...,0.045372,0.009262,-0.008319,-0.050856,0.034115,0.101663,-0.108278,-0.135629,0.102486,prophage_3577
21346,GCF_002186895.1__phage9,GCF_002186895.1__phage9__5,KL57,GCF_002186895.1,anubis_return__4260,anubis_return,MRYRFIALALCLLSGSKVAISAGFDCSLANLSPTEKTICSNEYLSG...,ITDSPWLVKKIFSSDSFEGGINLEGMNVSSILTYQEIKNDLYIYIS...,0.073450,0.046651,...,0.035302,0.012151,0.003563,-0.022575,0.014130,0.063376,-0.050646,-0.085156,-0.010849,prophage_6002
21347,GCF_004312845.1__phage3,GCF_004312845.1__phage3__38,KL9,GCF_004312845.1,anubis_return__4275,anubis_return,MAILITGKSMTRLPESSSWEEEIELITRSERVAGGLDGPANRPLKS...,DAVIRRDLASDKGTSGVGKLGDKPLVAISYYKSKGQSDQDAVQAAF...,0.032196,0.048856,...,-0.016331,0.084711,0.056063,0.001793,0.073958,0.090169,-0.060105,0.023726,0.086452,prophage_12656


***
## Check for duplicates : 

There are duplicate sets of depolymerases, encoded by prophages. We argue that they follow the natural distribution and therefore allow for weighting the importance of depolymerase.

Then what we want is to finetune esm2 models.

In [46]:
duplicate_prophage = []
dico_kltype_duplica = {}
for kltype in DF_info_lvl_0_filtered["KL_type_LCA"].unique():
    df_kl = DF_info_lvl_0_filtered[DF_info_lvl_0_filtered["KL_type_LCA"] == kltype][["Phage", "Protein_name", "KL_type_LCA", "Infected_ancestor", "index", "seq", "domain_seq"]]
    prophages_tmp_list = df_kl["Phage"].unique().tolist()
    set_sets_depo = []
    duplicated = {}  
    for prophage_tmp in prophages_tmp_list: 
        set_depo = frozenset(df_kl[df_kl["Phage"] == prophage_tmp]["domain_seq"].values)
        for past_set in set_sets_depo:
            if past_set == set_depo:
                duplicated[past_set] = duplicated.get(past_set, 0) + 1
                duplicate_prophage.append(prophage_tmp)
                break
        else:
            set_sets_depo.append(set_depo)
            duplicated[set_depo] = 1
    dico_kltype_duplica[kltype] = duplicated

DF_info_lvl_0_final

In [47]:
for kltype in dico_kltype_duplica :
    print(kltype, dico_kltype_duplica[kltype].values(), len(dico_kltype_duplica[kltype].values()), sep = "\n")
    print("\n")

KL41
dict_values([1, 1, 1, 1, 1, 1, 1, 1, 2])
9


KL30
dict_values([3, 3, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 7, 3, 1, 1, 1, 2, 1, 1, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 2, 1, 1, 1, 2, 1, 1, 2, 3, 1])
63


KL6
dict_values([1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1])
11


KL19
dict_values([6, 1, 2, 1, 4, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 12, 2, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
45


KL25
dict_values([4, 1, 2, 1, 1, 1, 1, 1, 1, 1, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 3, 1, 5, 5, 1, 1, 1, 3, 10, 1, 2, 7, 10, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 4, 1, 1, 1, 4, 2, 1, 1, 2, 1, 1, 2, 3, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 6, 1, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 14, 1, 3, 7, 1, 1, 1, 1, 1, 1, 1, 23, 1, 3, 1, 3, 1, 1])
141


KL123
dict_values([2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [48]:
DF_info_lvl_0_ultra_filtered = DF_info_lvl_0_filtered[~DF_info_lvl_0_filtered["Phage"].isin(duplicate_prophage)]
DF_info_lvl_0_ultra_filtered

Unnamed: 0,Phage,Protein_name,KL_type_LCA,Infected_ancestor,index,Dataset,seq,domain_seq,1,2,...,1272,1273,1274,1275,1276,1277,1278,1279,1280,prophage_id
0,GCF_902164905.1__phage1,GCF_902164905.1__phage1__34,KL41,GCF_902164905.1,minibatch__460,minibatch,MPATPQDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDE...,QDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDERTITT...,0.025276,0.053137,...,-0.011464,0.081105,0.012011,0.042917,0.009402,0.093175,-0.080562,0.000897,0.111854,prophage_11309
4,GCF_017310305.1__phage5,GCF_017310305.1__phage5__1353,KL30,n4996,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_5
10,GCF_001701985.1__phage2,GCF_001701985.1__phage2__357,KL30,n4988,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,0.004905,0.040896,...,-0.040657,0.087288,0.022292,0.024434,0.025246,0.083449,-0.123537,0.047648,0.061250,prophage_6465
21,GCF_900622625.1__phage2,GCF_900622625.1__phage2__2892,KL6,GCF_900622625.1,minibatch__1741,minibatch,MAFNPELGSSSPEVLLDNAKRLDELTNGPAATVPDRAGEPLDSWRK...,ELGSSSPEVLLDNAKRLDELTNGPAATVPDRAGEPLDSWRKMQEDN...,-0.003624,-0.032987,...,-0.075218,-0.010726,0.008995,-0.001741,-0.103979,0.119088,-0.124593,0.011745,0.147883,prophage_4098
22,GCF_011044795.1__phage17,GCF_011044795.1__phage17__11,KL19,80.7/1001331,minibatch__467,minibatch,MNRSRRLLMRGIGYLTLFPLLFLFSKKVSSAPNGLTEKVKNRKIEK...,RSRRLLMRGIGYLTLFPLLFLFSKKVSSAPNGLTEKVKNRKIEKDV...,0.038219,0.037305,...,0.010699,0.013225,0.038260,-0.001471,0.040612,0.066368,-0.078655,0.031434,0.080821,prophage_4997
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21342,GCF_900506765.1__phage17,GCF_900506765.1__phage17__90,KL149,GCF_900506765.1,anubis_return__4216,anubis_return,MMTTLNEHPQWESDIYLIKRSDLVAGGRGGIANMQAQQLANRTAFL...,NRRWFRRFTGNIRAEWSGIHDLSQSSAPVDSYIYRLLLASAVGSPD...,0.053704,0.083858,...,0.032803,0.109572,0.010032,0.024949,0.094129,0.028693,-0.061396,0.006824,0.046220,prophage_15598
21344,GCF_003255785.1__phage1,GCF_003255785.1__phage1__10,KL127,GCF_003255785.1,anubis_return__4239,anubis_return,MNGLNHNALTCSAVPIPPWERSLQTVEAQPYFSVSQASLVLEGIVF...,MNGLNHNALTCSAVPIPPWERSLQTVEAQPYFSVSQASLVLEGIVF...,0.010626,-0.025389,...,0.045372,0.009262,-0.008319,-0.050856,0.034115,0.101663,-0.108278,-0.135629,0.102486,prophage_3577
21346,GCF_002186895.1__phage9,GCF_002186895.1__phage9__5,KL57,GCF_002186895.1,anubis_return__4260,anubis_return,MRYRFIALALCLLSGSKVAISAGFDCSLANLSPTEKTICSNEYLSG...,ITDSPWLVKKIFSSDSFEGGINLEGMNVSSILTYQEIKNDLYIYIS...,0.073450,0.046651,...,0.035302,0.012151,0.003563,-0.022575,0.014130,0.063376,-0.050646,-0.085156,-0.010849,prophage_6002
21347,GCF_004312845.1__phage3,GCF_004312845.1__phage3__38,KL9,GCF_004312845.1,anubis_return__4275,anubis_return,MAILITGKSMTRLPESSSWEEEIELITRSERVAGGLDGPANRPLKS...,DAVIRRDLASDKGTSGVGKLGDKPLVAISYYKSKGQSDQDAVQAAF...,0.032196,0.048856,...,-0.016331,0.084711,0.056063,0.001793,0.073958,0.090169,-0.060105,0.023726,0.086452,prophage_12656


In [50]:
DF_info_lvl_0_ultra_filtered[DF_info_lvl_0_ultra_filtered["KL_type_LCA"] == "KL25"].drop_duplicates(subset = ["Phage"])

Unnamed: 0,Phage,Protein_name,KL_type_LCA,Infected_ancestor,index,Dataset,seq,domain_seq,1,2,...,1272,1273,1274,1275,1276,1277,1278,1279,1280,prophage_id
28,GCF_019096335.1__phage21,GCF_019096335.1__phage21__173,KL25,n12421242,minibatch__15,minibatch,MYHLDNTSGVPEMPEPKEQQSISPRWFGESQEQGGISWPGADWFNT...,YHLDNTSGVPEMPEPKEQQSISPRWFGESQEQGGISWPGADWFNTV...,0.008351,0.003703,...,0.019730,0.102699,0.003058,0.044005,0.042251,0.042038,-0.098792,0.009710,0.120689,prophage_8486
233,GCF_900502245.1__phage6,GCF_900502245.1__phage6__383,KL25,n5220,minibatch__771,minibatch,MANIEKLGSSSPEVLLKNATNLDKLVNGRESESLPDRFGVLRKTWH...,IEKLGSSSPEVLLKNATNLDKLVNGRESESLPDRFGVLRKTWHGME...,0.023976,0.031226,...,-0.068705,0.049401,0.043436,0.039093,-0.013564,0.048744,-0.172154,0.071909,0.124039,prophage_806
308,GCF_003861575.1__phage3,GCF_003861575.1__phage3__452,KL25,n16531653,minibatch__946,minibatch,MVENDTSSVEYQLSTSTGPFSIPFYFIENGHIVAELYTQNGDDFNK...,SVEYQLSTSTGPFSIPFYFIENGHIVAELYTQNGDDFNKTTLTIDV...,0.021638,0.017694,...,-0.025857,0.033861,0.014796,0.026119,-0.022077,0.088668,-0.115063,0.022781,0.078048,prophage_926
310,GCF_900501985.1__phage18,GCF_900501985.1__phage18__125,KL25,n1650,minibatch__946,minibatch,MVENDTSSVEYQLSTSTGPFSIPFYFIENGHIVAELYTQNGDDFNK...,SVEYQLSTSTGPFSIPFYFIENGHIVAELYTQNGDDFNKTTLTIDV...,0.021638,0.017694,...,-0.025857,0.033861,0.014796,0.026119,-0.022077,0.088668,-0.115063,0.022781,0.078048,prophage_926
311,GCF_001463555.1__phage10,GCF_001463555.1__phage10__3,KL25,n1642,minibatch__946,minibatch,MVENDTSSVEYQLSTSTGPFSIPFYFIENGHIVAELYTQNGDDFNK...,SVEYQLSTSTGPFSIPFYFIENGHIVAELYTQNGDDFNKTTLTIDV...,0.021638,0.017694,...,-0.025857,0.033861,0.014796,0.026119,-0.022077,0.088668,-0.115063,0.022781,0.078048,prophage_926
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18570,GCF_021135995.1__phage14,GCF_021135995.1__phage14__58,KL25,n52215221,anubis__145,anubis,MTANYPASILPPNATAVERAIDRASAAALARLPVYLIRWVKDPDSC...,VGAFDDLPNIQKCTSIFRGCSSLTELPEGLFARFTGATDFSAAFYG...,0.003956,-0.036481,...,0.048554,0.023865,0.013784,0.008874,-0.075574,0.094992,-0.049823,0.057132,0.015287,prophage_1832
18956,GCF_001456135.1__phage22,GCF_001456135.1__phage22__149,KL25,n2288,anubis__1038,anubis,MDIIDKVFQQEDFSRQDLSDSRFRRCRFYQCDFSHCQLQDASFEDC...,MDIIDKVFQQEDFSRQDLSDSRFRRCRFYQCDFSHCQLQDASFEDC...,-0.023045,0.004479,...,0.012213,0.047369,0.019383,0.048012,0.008921,0.047179,-0.008280,0.001435,0.028852,prophage_4461
18961,GCF_900511055.1__phage8,GCF_900511055.1__phage8__19,KL25,n16151615,anubis__1050,anubis,MTANYPASILPPNATAVERAIDRASAAALERLPVYLIRWVKDPDSC...,LMAIRPGAFDDLPNVNNCKNIFTNCSSLAGIPASLFSRMKIEDFSD...,-0.006649,-0.064026,...,0.065778,0.032808,-0.003482,0.002955,-0.080683,0.106188,-0.045880,0.030399,0.008590,prophage_59
20332,GCF_021135995.1__phage4,GCF_021135995.1__phage4__16,KL25,n52225222,anubis_return__374,anubis_return,MNALNHNALTCSAVPIPPWERSLQTVEAQPYFNVSQASLVLEGIVF...,MNALNHNALTCSAVPIPPWERSLQTVEAQPYFNVSQASLVLEGIVF...,0.010255,-0.019249,...,0.033402,0.006714,-0.010221,-0.048146,0.033665,0.106182,-0.113253,-0.126760,0.098701,prophage_2032


> On the server: 

In [None]:
duplicate_prophage = []
dico_kltype_duplica = {}
for kltype in DF_info_lvl_0_final["KL_type_LCA"].unique():
    df_kl = DF_info_lvl_0_final[DF_info_lvl_0_final["KL_type_LCA"] == kltype][["Phage", "Protein_name", "KL_type_LCA", "Infected_ancestor", "index", "seq", "domain_seq"]]
    prophages_tmp_list = df_kl["Phage"].unique().tolist()
    set_sets_depo = []
    duplicated = {}  
    for prophage_tmp in prophages_tmp_list: 
        set_depo = frozenset(df_kl[df_kl["Phage"] == prophage_tmp]["domain_seq"].values)
        for past_set in set_sets_depo:
            if past_set == set_depo:
                duplicated[past_set] = duplicated.get(past_set, 0) + 1
                duplicate_prophage.append(prophage_tmp)
                break
        else:
            set_sets_depo.append(set_depo)
            duplicated[set_depo] = 1
    dico_kltype_duplica[kltype] = duplicated
    

DF_info_lvl_0_final_ultrafiltered = DF_info_lvl_0_final[~DF_info_lvl_0_final["Phage"].isin(duplicate_prophage)]
    


***
### Next :

In [32]:
dico_kltype_duplica["KL64"].values()

dict_values([5, 1, 2, 1, 2, 7, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 9, 6, 1, 84, 3, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 85, 8, 38, 1, 1, 1, 83, 2, 1, 1, 13, 5, 29, 1, 1, 2, 2, 9, 1, 8, 26, 3, 2, 1, 124, 1, 2, 1, 1, 1, 9, 6, 12, 9, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 3, 5, 1, 14, 1, 1, 1, 1, 7, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 81, 8, 3, 3, 1, 11, 2, 1, 1, 1])

In [26]:
test_list = [set([1,2]), set([1,2,3]), set([2,1])]
test_list[0] == test_list[2]
set(test_list)

TypeError: unhashable type: 'set'

In [None]:
def get_filtered_prophages(prophage) :
    combinations = []
    to_exclude = set()
    to_keep = set()
    to_keep.add(prophage)
    df_prophage_group = DF_info[(DF_info["prophage_id"] == dico_prophage_info[prophage]["prophage_strain"]) & (DF_info["Infected_ancestor"] == dico_prophage_info[prophage]["ancestor"])]
    if len(df_prophage_group) == 1 :
        pass
    else :
        depo_set = set(df_prophage_group[df_prophage_group["Phage"] == prophage]["domain_seq"].values)
        for prophage_tmp in df_prophage_group["Phage"].unique().tolist() :
            if prophage_tmp != prophage :
                tmp_depo_set = set(df_prophage_group[df_prophage_group["Phage"] == prophage_tmp]["domain_seq"].values)
                if depo_set == tmp_depo_set :
                    to_exclude.add(prophage_tmp)
                else :
                    if tmp_depo_set not in combinations :
                        to_keep.add(prophage_tmp)
                        combinations.append(tmp_depo_set)
                    else :
                        to_exclude.add(prophage_tmp)
    return df_prophage_group , to_exclude , to_keep

good_prophages = set()
excluded_prophages = set()

for prophage, info_prophage in tqdm(dico_prophage_info.items()) :
    if prophage not in excluded_prophages and prophage not in good_prophages:
        _, excluded_members , kept_members = get_filtered_prophages(prophage)
        good_prophages.update(kept_members)
        excluded_prophages.update(excluded_members)

DF_info_lvl_0_filtered = DF_info[DF_info["Phage"].isin(good_prophages)]
DF_info_lvl_0_final = DF_info_lvl_0_filtered[~DF_info_lvl_0_filtered["KL_type_LCA"].str.contains("\\|")]

DF_info_lvl_0 = DF_info_lvl_0_final.copy()



# Log file :
path_ensemble = f"{path_work}/train_nn/ensemble_0702"

df_prophages = DF_info_lvl_0.drop_duplicates(subset = ["Phage"])
dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))

KLtypes = [kltype for kltype in dico_prophage_count if dico_prophage_count[kltype] >= 20]

# *****************************************************************************
# Make graphs :
graph_baseline , dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(DF_info_lvl_0)
#graph_dico = {kltype : TropiGAT_graph.build_graph_masking(graph_baseline , dico_prophage_kltype_associated,DF_info_lvl_0, kltype, 5, 1, 0, 0)
#             for kltype in DF_info_lvl_0["KL_type_LCA"].unique()}




# *****************************************************************************
def train_graph(KL_type) :
    for seed in range(1,6) :
        with open(f"{path_work}/train_nn/ensemble_0702_log_files/{KL_type}__{seed}__node_classification.0702.log" , "w") as log_outfile :
            n_prophage = dico_prophage_count[KL_type]
            graph_data_kltype = TropiGAT_graph.build_graph_masking_v2(graph_baseline , dico_prophage_kltype_associated, DF_info_lvl_0, KL_type, 5, 0.7, 0.2, 0.1, seed = seed)
            if n_prophage <= 125 : 
                model = TropiGAT_models.TropiGAT_small_module(hidden_channels = 1280, heads = 1)
                n = "small"
            else : 
                model = TropiGAT_models.TropiGAT_big_module(hidden_channels = 1280, heads = 1)
                n = "big"
            model(graph_data_kltype)
            optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001 , weight_decay= 0.000001)
            scheduler = ReduceLROnPlateau(optimizer, 'min')
            criterion = torch.nn.BCEWithLogitsLoss()
            early_stopping = TropiGAT_models.EarlyStopping(patience=60, verbose=True, path=f"{path_ensemble}/{KL_type}__{seed}.TropiGATv2.0702.pt", metric='MCC')
            try : 
                for epoch in range(200):
                    train_loss = TropiGAT_models.train(model, graph_data_kltype, optimizer,criterion)
                    if epoch % 5 == 0:
                        # Get all metrics
                        test_loss, metrics = TropiGAT_models.evaluate(model, graph_data_kltype,criterion, graph_data_kltype["B1"].test_mask)
                        info_training_concise = f'Epoch: {epoch}\tTrain Loss: {train_loss}\tTest Loss: {test_loss}\tMCC: {metrics[3]}\tAUC: {metrics[5]}\tAccuracy: {metrics[4]}\n'
                        info_training = f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss},F1 Score: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, MCC: {metrics[3]},Accuracy: {metrics[4]}, AUC: {metrics[5]}'
                        log_outfile.write(info_training_concise)
                        scheduler.step(test_loss)
                        early_stopping(metrics[3], model, epoch)
                        if early_stopping.early_stop:
                            log_outfile.write(f"Early stopping at epoch = {epoch}\n")
                            break
                else :
                    torch.save(model, f"{path_ensemble}/{KL_type}__{seed}.TropiGATv2.0702.pt")
                # The final eval :
                if n == "small" : 
                    model_final = TropiGAT_models.TropiGAT_small_module(hidden_channels = 1280, heads = 1)
                else :
                    model_final = TropiGAT_models.TropiGAT_big_module(hidden_channels = 1280, heads = 1)
                model_final.load_state_dict(torch.load(f"{path_ensemble}/{KL_type}__{seed}.TropiGATv2.0702.pt"))
                eval_loss, metrics = TropiGAT_models.evaluate(model_final, graph_data_kltype, criterion,graph_data_kltype["B1"].eval_mask)
                with open(f"{path_ensemble}/Metric_Report.0702.tsv", "a+") as metric_outfile :
                    metric_outfile.write(f"{KL_type}__{seed}\t{n_prophage}\t{metrics[0]}\t{metrics[1]}\t{metrics[2]}\t{metrics[3]}\t{metrics[4]}\t{metrics[5]}\n")
                info_eval = f'Epoch: {epoch}, F1 Score: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, MCC: {metrics[3]},Accuracy: {metrics[4]}, AUC: {metrics[5]}'
                log_outfile.write(f"Final evaluation ...\n{info_eval}")
            except Exception as e :
                log_outfile.write(f"***Issue here : {e}")
                with open(f"{path_ensemble}/Metric_Report.0702.tsv", "a+") as metric_outfile :
                    n_prophage = dico_prophage_count[KL_type]
                    metric_outfile.write(f"{KL_type}__{seed}\t{n_prophage}\t***Issue***\n")

if __name__ == '__main__':
    with ThreadPool(5) as p:
        p.map(train_graph, KLtypes)


> Move the files back 

In [None]:
rsync -avzhe ssh \
conchae@garnatxa.srv.cpd:/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn/ensemble_2812 \
/media/concha-eloko/Linux/PPT_clean/ficheros_28032023

rsync -avzhe ssh \
conchae@garnatxa.srv.cpd:/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn/ensemble_2812_log_files \
/media/concha-eloko/Linux/PPT_clean/ficheros_28032023

In [None]:

import os
import pandas as pd
from sklearn.model_selection import StratifiedKFold

def train_graph(KL_type):
    n_splits = 3
    n_prophage = dico_prophage_count[KL_type]
    graph_data_kltype = graph_dico[KL_type]

    # Create a directory to store the models for each cross-validation step
    model_dir = f"{path_ensemble}/{KL_type}/models"
    os.makedirs(model_dir, exist_ok=True)

    # Split the data into n folds using StratifiedKFold
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    fold_idx = 1
    for train_indices, val_indices in skf.split(graph_data_kltype["B1"].x, graph_data_kltype["B1"].y):
        print(f"Fold {fold_idx}:")
        print("Train indices:", train_indices)
        print("Validation indices:", val_indices)
        
        # Split the graph data into train and validation sets
        graph_data_train = graph_data_kltype[train_indices]
        graph_data_val = graph_data_kltype[val_indices]
        
        # Train the model for this fold
        model = TropiGAT_models.TropiGAT_small_module(hidden_channels=1280, heads=1)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.000001)
        scheduler = ReduceLROnPlateau(optimizer, 'min')
        criterion = torch.nn.BCEWithLogitsLoss()
        early_stopping = TropiGAT_models.EarlyStopping(patience=60, verbose=True, path=f"{model_dir}/fold{fold_idx}.pt", metric='MCC')
        train_loss = TropiGAT_models.train(model, graph_data_train, optimizer, criterion)
        
        # Evaluate the model on the validation set and save the metrics to a file
        test_loss, metrics = TropiGAT_models.evaluate(model, graph_data_val, criterion, graph_data_val["B1"].eval_mask)
        metrics_df = pd.DataFrame(metrics, index=[fold_idx])
        if os.path.exists(f"{model_dir}/metrics.csv"):
            metrics_df.to_csv(f"{model_dir}/metrics.csv", mode='a', header=False, index=True)
        else:
            metrics_df.to_csv(f"{model_dir}/metrics.csv", index=True)
        
        # Save the trained model for this fold
        torch.save(model, f"{model_dir}/fold{fold_idx}.pt")
        
        fold_idx += 1