In [None]:
# Torch geometric modules
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero , HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

# Torch modules
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# SKlearn modules
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

# Ground modules
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging
from multiprocessing.pool import ThreadPool

# TropiGAT modules
import TropiGAT_graph
import TropiGAT_models

warnings.filterwarnings("ignore")

# *****************************************************************************
# Load the dataframes:
path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"

# Read the table and keep a single row per protein name.
DF_info = pd.read_csv(
    f"{path_work}/train_nn/TropiGATv2.final_df_v2.tsv", sep="\t", header=0
).drop_duplicates(subset=["Protein_name"])

# One row per phage (first occurrence), used to record each phage's
# prophage id and infected ancestor.
df_prophages = DF_info.drop_duplicates(subset=["Phage"], keep="first")
dico_prophage_info = {
    record["Phage"]: {
        "prophage_strain": record["prophage_id"],
        "ancestor": record["Infected_ancestor"],
    }
    for _, record in df_prophages.iterrows()
}

# Maps each KL type to a seed, stored as a string.
# get_seeded_graph() looks the seed up here and uses it both as the `seed`
# argument of the graph builder and in the name of the saved graph file.
# NOTE(review): presumably the best-performing seed per KL type — confirm.
winner_dico = {'KL1': '4',
 'KL2': '4',
 'KL3': '5',
 'KL5': '3',
 'KL7': '1',
 'KL8': '1',
 'KL9': '5',
 'KL10': '5',
 'KL12': '4',
 'KL13': '2',
 'KL14': '1',
 'KL15': '3',
 'KL16': '3',
 'KL17': '2',
 'KL18': '3',
 'KL19': '5',
 'KL21': '3',
 'KL22': '5',
 'KL23': '5',
 'KL24': '2',
 'KL25': '5',
 'KL27': '4',
 'KL28': '1',
 'KL29': '2',
 'KL30': '2',
 'KL34': '4',
 'KL36': '4',
 'KL38': '3',
 'KL39': '5',
 'KL43': '4',
 'KL45': '3',
 'KL46': '2',
 'KL47': '5',
 'KL48': '5',
 'KL51': '2',
 'KL52': '3',
 'KL53': '4',
 'KL55': '5',
 'KL57': '2',
 'KL60': '5',
 'KL62': '5',
 'KL63': '3',
 'KL64': '2',
 'KL70': '2',
 'KL74': '3',
 'KL81': '2',
 'KL102': '5',
 'KL105': '3',
 'KL106': '3',
 'KL107': '3',
 'KL108': '4',
 'KL110': '1',
 'KL111': '4',
 'KL112': '5',
 'KL114': '3',
 'KL116': '4',
 'KL118': '1',
 'KL122': '3',
 'KL123': '2',
 'KL125': '3',
 'KL127': '2',
 'KL128': '5',
 'KL136': '4',
 'KL145': '2',
 'KL149': '2',
 'KL151': '3',
 'KL169': '3'}

def get_filtered_prophages(prophage) :
    """
    Sort the prophages that share `prophage`'s strain and ancestor into kept and redundant ones.

    The group is every row of DF_info whose "prophage_id" and "Infected_ancestor"
    match the values recorded for `prophage` in dico_prophage_info. Within that
    group, a phage is redundant when its set of "domain_seq" values equals the
    set of `prophage` itself, or the set of a phage already kept.

    Inputs : a phage identifier (a key of dico_prophage_info)
    Outputs : (group dataframe, set of phages to exclude, set of phages to keep);
              `prophage` itself is always in the set to keep
    """
    info = dico_prophage_info[prophage]
    same_strain = DF_info["prophage_id"] == info["prophage_strain"]
    same_ancestor = DF_info["Infected_ancestor"] == info["ancestor"]
    df_prophage_group = DF_info[same_strain & same_ancestor]

    to_keep = {prophage}
    to_exclude = set()

    # A single-row group has nothing to compare against.
    if len(df_prophage_group) == 1 :
        return df_prophage_group , to_exclude , to_keep

    def domains_of(phage) :
        # Set of domain sequences carried by one phage of the group.
        is_phage = df_prophage_group["Phage"] == phage
        return set(df_prophage_group.loc[is_phage, "domain_seq"].values)

    reference_domains = domains_of(prophage)
    seen_domain_sets = []
    for candidate in df_prophage_group["Phage"].unique().tolist() :
        if candidate == prophage :
            continue
        candidate_domains = domains_of(candidate)
        if candidate_domains == reference_domains or candidate_domains in seen_domain_sets :
            # Same depolymerase content as the reference or as a phage already kept.
            to_exclude.add(candidate)
        else :
            to_keep.add(candidate)
            seen_domain_sets.append(candidate_domains)
    return df_prophage_group , to_exclude , to_keep

# Walk over every phage once; a phage already classified through an earlier
# group (kept or excluded) is not used as a reference again.
good_prophages = set()
excluded_prophages = set()

for prophage in tqdm(dico_prophage_info) :
    if prophage in excluded_prophages or prophage in good_prophages :
        continue
    _, excluded_members , kept_members = get_filtered_prophages(prophage)
    good_prophages |= kept_members
    excluded_prophages |= excluded_members

# Keep the rows of retained phages, then drop rows whose KL type contains a "|".
DF_info_lvl_0_filtered = DF_info[DF_info["Phage"].isin(good_prophages)]
DF_info_lvl_0_final = DF_info_lvl_0_filtered[~DF_info_lvl_0_filtered["KL_type_LCA"].str.contains("\\|")]
DF_info_lvl_0 = DF_info_lvl_0_final.copy()


# Working directories (path_training_data receives the saved graphs;
# path_ensemble is not used in the code visible here) :
path_ensemble = f"{path_work}/train_nn/ensemble_0702"
path_training_data = f"{path_work}/train_nn/training_data_1302"

# Number of retained phages per KL type (one row per phage).
# Note: this rebinds df_prophages, which earlier held the unfiltered table.
df_prophages = DF_info_lvl_0.drop_duplicates(subset = ["Phage"])
dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))

# Only KL types represented by at least 20 phages are processed.
KLtypes = [kltype for kltype, n_phages in dico_prophage_count.items() if n_phages >= 20]

# *****************************************************************************
# Make graphs :
graph_baseline , dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(DF_info_lvl_0)

# *****************************************************************************
def get_seeded_graph(KL_type) :
    """
    Build the masked graph of one KL type with its recorded seed and save it to disk.

    Inputs : a KL type (a key of winner_dico)
    Outputs : None; the graph is written to
              "{path_training_data}/{KL_type}__{seed}.graph.pt"
    Raises : KeyError if KL_type has no entry in winner_dico
    """
    # NOTE(review): the seed is stored as a string in winner_dico and is passed
    # as-is; confirm that build_graph_masking_v2 accepts a str seed.
    seed = winner_dico[KL_type]
    # NOTE(review): the meaning of the positional values 5, 0.7, 0.2, 0.1 is not
    # visible here (presumably a count and train/val/test fractions) — confirm
    # against TropiGAT_graph.build_graph_masking_v2.
    graph_data_kltype = TropiGAT_graph.build_graph_masking_v2(
        graph_baseline , dico_prophage_kltype_associated , DF_info_lvl_0 ,
        KL_type , 5 , 0.7 , 0.2 , 0.1 , seed = seed
    )
    torch.save(graph_data_kltype, f"{path_training_data}/{KL_type}__{seed}.graph.pt")


if __name__ == '__main__':
    # Build and save one seeded graph per selected KL type, using a pool of
    # five worker threads. The worker is get_seeded_graph: the previously
    # referenced name `train_graph` is not defined in this script.
    with ThreadPool(5) as p:
        p.map(get_seeded_graph, KLtypes)


In [None]:
#!/bin/bash
# Slurm batch script: runs get_training_data.py inside the torch_geometric
# conda environment. Every directive must start with "#SBATCH"; a line
# starting with "#BATCH" is treated as a plain comment and ignored.
#SBATCH --job-name=get_training_TropiGAT
#SBATCH --qos=short
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=6
#SBATCH --mem=20gb
#SBATCH --time=1-00:00:00
#SBATCH --output=get_training_TropiGAT__%j.log

module restore la_base
conda activate torch_geometric

python /home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn/script_files/get_training_data.py

> Move the files back 

In [None]:
# Copy the training_data_1302 directory from the cluster to the local disk
# over ssh (-a archive, -v verbose, -z compress, -h human-readable sizes).
rsync -avzhe ssh \
conchae@garnatxa.srv.cpd:/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn/training_data_1302 \
/media/concha-eloko/Linux/PPT_clean/ficheros_28032023




In [5]:

# Each entry is "<KL type>__<seed>"; the entries are split on "__" further
# down in this cell to build the KL type -> seed dictionary.
winner = ['KL1__4',
 'KL2__4',
 'KL3__5',
 'KL5__3',
 'KL7__1',
 'KL8__1',
 'KL9__5',
 'KL10__5',
 'KL12__4',
 'KL13__2',
 'KL14__1',
 'KL15__3',
 'KL16__3',
 'KL17__2',
 'KL18__3',
 'KL19__5',
 'KL21__3',
 'KL22__5',
 'KL23__5',
 'KL24__2',
 'KL25__5',
 'KL27__4',
 'KL28__1',
 'KL29__2',
 'KL30__2',
 'KL34__4',
 'KL36__4',
 'KL38__3',
 'KL39__5',
 'KL43__4',
 'KL45__3',
 'KL46__2',
 'KL47__5',
 'KL48__5',
 'KL51__2',
 'KL52__3',
 'KL53__4',
 'KL55__5',
 'KL57__2',
 'KL60__5',
 'KL62__5',
 'KL63__3',
 'KL64__2',
 'KL70__2',
 'KL74__3',
 'KL81__2',
 'KL102__5',
 'KL105__3',
 'KL106__3',
 'KL107__3',
 'KL108__4',
 'KL110__1',
 'KL111__4',
 'KL112__5',
 'KL114__3',
 'KL116__4',
 'KL118__1',
 'KL122__3',
 'KL123__2',
 'KL125__3',
 'KL127__2',
 'KL128__5',
 'KL136__4',
 'KL145__2',
 'KL149__2',
 'KL151__3',
 'KL169__3']

def clean_print(dico) :
    """
    Pretty-print a dictionary on standard output.

    Inputs : a dico
    Outputs : None; the dico is printed with a width of 150 characters,
              in compact form, keeping its insertion order
    """
    import pprint
    # pprint.pprint writes to stdout and returns None, which is what gets returned.
    return pprint.pprint(dico, width = 150, sort_dicts = False, compact = True)
    
# Turn each "<KL type>__<seed>" entry into a KL type -> seed (string) pair.
winner_dico = {
    kl_type: seed
    for kl_type, seed in (entry.split("__") for entry in winner)
}
clean_print(winner_dico)

{'KL1': '4',
 'KL2': '4',
 'KL3': '5',
 'KL5': '3',
 'KL7': '1',
 'KL8': '1',
 'KL9': '5',
 'KL10': '5',
 'KL12': '4',
 'KL13': '2',
 'KL14': '1',
 'KL15': '3',
 'KL16': '3',
 'KL17': '2',
 'KL18': '3',
 'KL19': '5',
 'KL21': '3',
 'KL22': '5',
 'KL23': '5',
 'KL24': '2',
 'KL25': '5',
 'KL27': '4',
 'KL28': '1',
 'KL29': '2',
 'KL30': '2',
 'KL34': '4',
 'KL36': '4',
 'KL38': '3',
 'KL39': '5',
 'KL43': '4',
 'KL45': '3',
 'KL46': '2',
 'KL47': '5',
 'KL48': '5',
 'KL51': '2',
 'KL52': '3',
 'KL53': '4',
 'KL55': '5',
 'KL57': '2',
 'KL60': '5',
 'KL62': '5',
 'KL63': '3',
 'KL64': '2',
 'KL70': '2',
 'KL74': '3',
 'KL81': '2',
 'KL102': '5',
 'KL105': '3',
 'KL106': '3',
 'KL107': '3',
 'KL108': '4',
 'KL110': '1',
 'KL111': '4',
 'KL112': '5',
 'KL114': '3',
 'KL116': '4',
 'KL118': '1',
 'KL122': '3',
 'KL123': '2',
 'KL125': '3',
 'KL127': '2',
 'KL128': '5',
 'KL136': '4',
 'KL145': '2',
 'KL149': '2',
 'KL151': '3',
 'KL169': '3'}


In [None]:
# KL type -> seed mapping written out as a literal; seeds are kept as strings.
winner_dico = {'KL1': '4',
 'KL2': '4',
 'KL3': '5',
 'KL5': '3',
 'KL7': '1',
 'KL8': '1',
 'KL9': '5',
 'KL10': '5',
 'KL12': '4',
 'KL13': '2',
 'KL14': '1',
 'KL15': '3',
 'KL16': '3',
 'KL17': '2',
 'KL18': '3',
 'KL19': '5',
 'KL21': '3',
 'KL22': '5',
 'KL23': '5',
 'KL24': '2',
 'KL25': '5',
 'KL27': '4',
 'KL28': '1',
 'KL29': '2',
 'KL30': '2',
 'KL34': '4',
 'KL36': '4',
 'KL38': '3',
 'KL39': '5',
 'KL43': '4',
 'KL45': '3',
 'KL46': '2',
 'KL47': '5',
 'KL48': '5',
 'KL51': '2',
 'KL52': '3',
 'KL53': '4',
 'KL55': '5',
 'KL57': '2',
 'KL60': '5',
 'KL62': '5',
 'KL63': '3',
 'KL64': '2',
 'KL70': '2',
 'KL74': '3',
 'KL81': '2',
 'KL102': '5',
 'KL105': '3',
 'KL106': '3',
 'KL107': '3',
 'KL108': '4',
 'KL110': '1',
 'KL111': '4',
 'KL112': '5',
 'KL114': '3',
 'KL116': '4',
 'KL118': '1',
 'KL122': '3',
 'KL123': '2',
 'KL125': '3',
 'KL127': '2',
 'KL128': '5',
 'KL136': '4',
 'KL145': '2',
 'KL149': '2',
 'KL151': '3',
 'KL169': '3'}

