In [None]:
# Copy the ensemble_1908 directory (the .pt checkpoints loaded further down) from the
# cluster over SSH into the local PPT_clean tree (-a archive, -v verbose, -z compress,
# -h human-readable sizes, -e ssh remote shell).
# NOTE(review): inside a Python notebook cell this needs a leading "!" (or a %%bash cell magic).
rsync -avzhe ssh \
conchae@garnatxa.srv.cpd:/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/ensemble_1908 \
/media/concha-eloko/Linux/PPT_clean/ficheros_28032023


In [5]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero , HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging
from multiprocessing.pool import ThreadPool
warnings.filterwarnings("ignore")

# *****************************************************************************
# Paths: project working directory and the folder holding the ensemble checkpoints.
path_work = "/media/concha-eloko/Linux/PPT_clean"
ensemble_dir_name = "ficheros_28032023/ensemble_1908"
path_ensemble = f"{path_work}/{ensemble_dir_name}"

In [2]:
# *****************************************************************************
#logging.basicConfig(filename = f"{path_work}/train_nn/GATv2Conv.1608.log",format='%(asctime)s | %(levelname)s: %(message)s', level=logging.NOTSET, filemode='w')

class GNN(torch.nn.Module):
    """One message-passing layer applied to a single heterogeneous edge type.

    A single ``conv`` instance (GATv2Conv in this notebook) is built with lazy
    input sizes ``(-1, -1)`` and wrapped in a ``HeteroConv`` keyed by ``edge_type``.
    Both ``conv`` and ``hetero_conv`` are registered submodules, so their names
    appear in checkpoint state_dict keys.
    """

    def __init__(self, edge_type , conv, hidden_channels, heads, dropout):
        super().__init__()
        conv_kwargs = dict(
            add_self_loops=False,
            heads=heads,
            dropout=dropout,
            shared_weights=True,
        )
        self.conv = conv((-1, -1), hidden_channels, **conv_kwargs)
        self.hetero_conv = HeteroConv({edge_type: self.conv})

    def forward(self, x_dict, edge_index_dict):
        """Apply the wrapped HeteroConv and return its output unchanged."""
        return self.hetero_conv(x_dict, edge_index_dict)

# Classifier, Binary :
class EdgeDecoder(torch.nn.Module):
    """Binary scorer for ("B1", "infects", "A") candidate edges.

    For every candidate edge in ``graph_data[("B1", "infects", "A")].edge_label_index``
    the target A-node feature vector and the source B1 embedding are concatenated
    (in that order) and passed through a two-layer MLP, giving one logit per edge.

    Args:
        hidden_channels: Per-head output width of the upstream GNN layer.
        heads: Number of attention heads; the B1 embedding width is
            ``heads * hidden_channels``.
        a_feature_dim: Width of the raw A-node features. Defaults to 127, which
            matches ``graph_data["A"].x`` and the saved checkpoints.
        mlp_hidden: Width of the hidden MLP layer. Defaults to 512, as in the
            saved checkpoints.
    """

    def __init__(self, hidden_channels, heads, a_feature_dim=127, mlp_hidden=512):
        super().__init__()
        self.lin1 = torch.nn.Linear(heads * hidden_channels + a_feature_dim, mlp_hidden)
        self.lin2 = torch.nn.Linear(mlp_hidden, 1)

    def forward(self, x_dict_A , x_dict_B1, graph_data):
        """Score each candidate edge.

        Args:
            x_dict_A: Dict holding the A-node feature matrix under key "A".
            x_dict_B1: Dict holding the B1-node embeddings under key "B1".
            graph_data: Graph whose ("B1", "infects", "A") store has
                ``edge_label_index`` (row 0 = B1 indices, row 1 = A indices).

        Returns:
            1-D tensor of logits, one per candidate edge.
        """
        edge_type = ("B1", "infects", "A")
        edge_label_index = graph_data[edge_type].edge_label_index
        edge_feat_A = x_dict_A["A"][edge_label_index[1]]
        edge_feat_B1 = x_dict_B1["B1"][edge_label_index[0]]
        # The A-then-B1 order must match the order lin1 was trained with.
        features_phage = torch.cat((edge_feat_A, edge_feat_B1), dim=-1)
        x = self.lin1(features_phage).relu()
        x = self.lin2(x)
        return x.view(-1)

class Model(torch.nn.Module):
    """Link-prediction model: a GNN layer on ("B2", "expressed", "B1") feeding an EdgeDecoder.

    The attribute names ``single_layer_model`` and ``EdgeDecoder`` are part of the
    state_dict keys of the saved checkpoints, so they must stay as they are.
    """

    def __init__(self, conv, hidden_channels, heads, dropout):
        super().__init__()
        message_edge = ("B2", "expressed", "B1")
        self.single_layer_model = GNN(message_edge, conv, hidden_channels, heads, dropout)
        self.EdgeDecoder = EdgeDecoder(hidden_channels, heads)

    def forward(self, graph_data):
        """Return the EdgeDecoder output for ``graph_data``.

        The GNN output supplies the B1 embeddings; the raw ``x_dict`` supplies the
        A-node features.
        """
        b1_embeddings = self.single_layer_model(graph_data.x_dict, graph_data.edge_index_dict)
        return self.EdgeDecoder(graph_data.x_dict, b1_embeddings, graph_data)


In [7]:
# *****************************************************************************
# Relevant functions :
@torch.no_grad()
def make_predictions(model, data):
    """Run ``model`` on ``data`` in eval mode with gradient tracking disabled.

    Returns:
        (predictions, probabilities): ``probabilities`` is the sigmoid of the model
        output; ``predictions`` is that tensor rounded to 0./1. (same dtype).
    """
    model.eval()
    logits = model(data)
    probabilities = logits.sigmoid()
    return probabilities.round(), probabilities

# models object :
# Rebuild one Model per negative-sampling ratio and load its saved weights.
Dpo_classifier_models = {}
hidden_channels = 1000
conv = GATv2Conv
heads = 1
dropout = 0.1

ensemble = {i : f"model_ratio_{i}" for i in [1,2,3,4]}


def _ratio_from_checkpoint_name(file_name):
    """Return the integer N parsed from the 4th dot-field ('<N>Neg...') of a .pt file name, else None."""
    if not file_name.endswith(".pt"):
        return None
    fields = file_name.split(".")
    if len(fields) < 4:
        return None
    try:
        return int(fields[3].split("Neg")[0])
    except ValueError:
        return None


# sorted() makes the load order, and therefore the value left in `model`, reproducible
# (os.listdir order is arbitrary).
# NOTE(review): the prediction cell uses `model`, i.e. the last checkpoint loaded here,
# not the whole ensemble in Dpo_classifier_models. Confirm that is intended.
for file in sorted(os.listdir(path_ensemble)) :
    ratio = _ratio_from_checkpoint_name(file)
    if ratio not in ensemble :
        continue
    model = Model(conv, hidden_channels, heads, dropout)
    # map_location="cpu": all downstream tensors live on the CPU (.numpy() is called on outputs).
    model.load_state_dict(torch.load(f"{path_ensemble}/{file}", map_location="cpu"))
    Dpo_classifier_models[ensemble[ratio]] = model



> Pre-processing: load the Dpo tables and build the node-index mappings

In [10]:
# *****************************************************************************
path_work = "/media/concha-eloko/Linux/PPT_clean"

# Depolymerase metadata (tab-separated) and their embeddings (no header; column 0 is
# the protein identifier, renamed to "index" so both frames share the merge key).
DF_info = pd.read_csv(f"{path_work}/DF_Dpo.final.2705.tsv", sep="\t", header=0)
DF_embeddings = (
    pd.read_csv(f"{path_work}/Dpo.2705.embeddings.ultimate.csv", sep=",", header=None)
    .rename(columns={0: 'index'})
)

# Rows whose KL_type_LCA contains a "|" (presumably several KL types) are set aside.
has_multiple_kl = DF_info["KL_type_LCA"].str.contains("\\|")
DF_info_filtered = DF_info[~has_multiple_kl]
DF_info_ToReLabel = DF_info[has_multiple_kl]
all_data = pd.merge(DF_info_filtered, DF_embeddings, on="index")

# Mind the over-representation of outbreaks: one row per (ancestor, Dpo, prophage).
all_data = (
    all_data
    .drop_duplicates(subset=["Infected_ancestor", "index", "prophage_id"], keep="first")
    .reset_index(drop=True)
)

df_kltype = (
    all_data[all_data["KL_type_LCA"] == "KL27"]
    .drop_duplicates(subset=["Phage"], keep="first")
    .reset_index(drop=True)
)

# Node orderings: A = ancestors, B1 = known phages followed by one placeholder per Dpo
# to predict, B2 = Dpo embeddings.
dpo_ids = DF_embeddings["index"].unique().tolist()
indexation_nodes_A = all_data["Infected_ancestor"].unique().tolist()
indexation_nodes_B1 = all_data["Phage"].unique().tolist() + [f"Dpo_to_predict_{n}" for n in dpo_ids]
indexation_nodes_B2 = list(dpo_ids)

# Forward (name -> position) and reverse (position -> name) lookups for each node type.
ID_nodes_A = {name: position for position, name in enumerate(indexation_nodes_A)}
ID_nodes_A_r = dict(enumerate(indexation_nodes_A))

ID_nodes_B1 = {name: position for position, name in enumerate(indexation_nodes_B1)}
ID_nodes_B1_r = dict(enumerate(indexation_nodes_B1))

ID_nodes_B2 = {name: position for position, name in enumerate(indexation_nodes_B2)}
ID_nodes_B2_r = dict(enumerate(indexation_nodes_B2))

# First row per KL type gives one representative ancestor; keep its A-node position.
instances_bacteria = all_data.drop_duplicates(subset=["KL_type_LCA"], keep="first").reset_index(drop=True)
index_interest = [ID_nodes_A[ancestor] for ancestor in instances_bacteria.Infected_ancestor]


In [11]:
# *****************************************************************************
# Load the training graph (HeteroData).
# map_location="cpu": later cells call .numpy() on its node tensors, which only works
# on CPU tensors, so force CPU even if the graph was saved from a GPU session.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
path_work = "/media/concha-eloko/Linux/PPT_clean"
graph_data = torch.load(f'{path_work}/graph_file.2607.OHE.pt', map_location="cpu")
graph_data

HeteroData(
  A={ x=[4530, 127] },
  B1={ x=[11339, 0] },
  B2={ x=[3608, 1280] },
  (B1, infects, A)={
    edge_index=[2, 7731],
    y=[7731]
  },
  (B2, expressed, B1)={
    edge_index=[2, 13285],
    y=[13285]
  },
  (A, harbors, B1)={
    edge_index=[2, 7731],
    y=[7731]
  }
)

In [12]:
def get_nodes_id(B1A_index_file) :
    """Translate a B1 -> A index tensor into (phage name, ancestor name) pairs.

    Row 0 holds B1 node positions and row 1 holds A node positions; each is mapped
    back through the reverse lookups ``ID_nodes_B1_r`` / ``ID_nodes_A_r``.
    Returns a list in edge order.
    """
    index_array = B1A_index_file.numpy()
    b1_row, a_row = index_array[0], index_array[1]
    return [
        (ID_nodes_B1_r[b1_pos], ID_nodes_A_r[a_pos])
        for b1_pos, a_pos in zip(b1_row, a_row)
    ]

In [13]:
import pandas as pd 
import os 

path_project = "/media/concha-eloko/Linux/77_strains_phage_project"
path_Dpo_domain_org = "/media/concha-eloko/Linux/depolymerase_building/clean_77_phages_depo"

# ESM-2 embeddings of the 77-phage Dpo domains (headerless CSV). Column 1281 is dropped
# and column 0 (the protein id) becomes the index; the remaining columns are presumably
# the 1280-d embedding (later cells slice iloc[:, :1280]).
dpo_embeddings = (
    pd.read_csv(f"{path_project}/rbp_work/Dpo_domains_77.esm2.embedding.csv", sep = "," , header = None)
    .drop([1281], axis = 1)
    .set_index([0])
)

# Adding the phage column: the part of the protein id before the first "__".
dpo_embeddings["phage"] = dpo_embeddings.index.map(lambda protein_id: protein_id.partition("__")[0])

In [14]:
# Getting the nodes A features :
# One representative ancestor (first row) per KL type, and its A-node position.
instances_bacteria = all_data.drop_duplicates(subset = ["KL_type_LCA"] , keep = "first").reset_index(drop=True)
index_interest = [ID_nodes_A[ancestor] for ancestor in instances_bacteria.Infected_ancestor]

# Feature rows of those representatives; stacked_tensor is the "A" matrix of the prediction graph.
a_node_features = graph_data.x_dict["A"]
tensor_interest = [a_node_features[i] for i in index_interest]
stacked_tensor = torch.stack(tensor_interest)

# Map each representative's feature vector (as a hashable tuple) to the KL type it was
# selected for. Taking the label from the representative row itself keeps node and label
# consistent, and avoids rescanning all_data once per KL type.
# NOTE(review): two KL types whose representatives have identical feature vectors would
# collide on the same key.
dico_kltype = {
    tuple(features.numpy()): kltype
    for features, kltype in zip(tensor_interest, instances_bacteria["KL_type_LCA"])
}

def graph_single_Dpo_pred(df_embeddings, a_features=None) :
    """Build a prediction graph that scores every Dpo in ``df_embeddings`` against every A node.

    Args:
        df_embeddings: DataFrame whose first 1280 columns are used as the B2 node
            features (one row per Dpo).
        a_features: Optional A-node feature matrix; defaults to the module-level
            ``stacked_tensor``.

    Returns:
        HeteroData with
        - "A".x = ``a_features``; "B1".x = empty (n_dpo, 0); "B2".x = the embeddings,
        - ('B2', 'expressed', 'B1').edge_index linking B2 node i to B1 node i,
        - ('B1', 'infects', 'A').edge_label_index = the full B1 x A cross product,
          ordered by B1 position then A position.
    """
    if a_features is None:
        a_features = stacked_tensor
    pred_data_single = HeteroData()
    l_dpos = len(df_embeddings)
    n_a = len(a_features)
    # Defining the nodes :
    pred_data_single["A"].x = a_features
    pred_data_single["B1"].x = torch.empty((l_dpos, 0))
    pred_data_single["B2"].x = torch.tensor(df_embeddings.iloc[:, :1280].values , dtype=torch.float)
    # Defining the edges. arange-based construction keeps (2, E) shapes even when E == 0
    # and avoids building an O(n_dpo * n_a) Python list.
    dpo_idx = torch.arange(l_dpos, dtype=torch.long)
    pred_data_single['B2', 'expressed', 'B1'].edge_index = torch.stack([dpo_idx, dpo_idx])
    src = dpo_idx.repeat_interleave(n_a)
    dst = torch.arange(n_a, dtype=torch.long).repeat(l_dpos)
    pred_data_single['B1', 'infects', 'A'].edge_label_index = torch.stack([src, dst])
    return pred_data_single

In [16]:
# Building the graph prediction :
# Same node/edge construction as graph_single_Dpo_pred (defined in the previous cell),
# applied to the 77-phage Dpo domains; call it rather than duplicating the code here.
pred_data_single = graph_single_Dpo_pred(dpo_embeddings)

import json

class CustomEncoder(json.JSONEncoder):
    """JSON encoder that also serialises NumPy scalars and arrays.

    The stock encoder rejects NumPy scalars such as the np.float32 scores produced
    by ``probabilities.numpy()``. NumPy floating and integer scalars are converted
    to Python float/int and arrays to (nested) lists. Anything else is delegated to
    the base class, which raises TypeError.
    """

    def default(self, obj):
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
        
def get_nodes_id_single(B1A_index_file) :
    """Translate prediction-graph B1 -> A edges into (Dpo id, KL type) pairs.

    Row 0 positions index ``dpo_embeddings.index``. Row 1 positions select a tensor
    from ``tensor_interest``, whose tuple form is looked up in ``dico_kltype``.
    Returns a list in edge order.
    """
    index_array = B1A_index_file.numpy()
    pairs = []
    for dpo_pos, a_pos in zip(index_array[0], index_array[1]):
        dpo_name = dpo_embeddings.index[dpo_pos]
        kltype = dico_kltype[tuple(tensor_interest[a_pos].numpy())]
        pairs.append((dpo_name, kltype))
    return pairs

> Run the predictions and export the positive hits to JSON

In [None]:

# Score every (Dpo, KL type) pair of pred_data_single and keep the positive calls.
# NOTE(review): `model` is whichever checkpoint the loading cell assigned last; the other
# ensemble members in Dpo_classifier_models are not used here. Confirm that is intended.
predictions, probabilities = make_predictions(model, pred_data_single)
ids = get_nodes_id_single(pred_data_single[("B1", "infects", "A")].edge_label_index)
results = tuple(zip(ids, predictions.numpy(), probabilities.numpy()))
positive_results = [pred for pred in results if int(pred[1]) == 1]

# prot -> {kltype: score}. A positive call means the rounded probability is 1, so its
# score is already > 0 and needs no further filtering.
clean_results = {}
for (prot, kltype), _, score in positive_results:
    clean_results.setdefault(prot, {})[kltype] = score

# Write only after predictions succeeded, so a failure cannot leave an empty JSON file behind.
with open(f"{path_work}/77_Dpos.PPT_pred.1408.json", "w") as outfile :
    json.dump(clean_results, outfile, cls=CustomEncoder)

import pprint
pp = pprint.PrettyPrinter(width = 250, sort_dicts = True, compact = True)
pp.pprint(clean_results)

In [None]:
# Collapse per-protein hits into per-phage hits, keeping the highest score per KL type.
# Protein ids are "<phage>__<...>"; the phage is the part before the first "__".
final_results = {}
for protein, hits in clean_results.items():
    phage_hits = final_results.setdefault(protein.split("__")[0], {})
    for kltype, score in hits.items():
        if kltype not in phage_hits or score > phage_hits[kltype]:
            phage_hits[kltype] = score

import pprint
pp = pprint.PrettyPrinter(width = 150, sort_dicts = True, compact = True)
pp.pprint(final_results)

In [None]:
kltype_interest = ["KL112", "KL17", "KL2", "KL24", "KL27","KL64"]

# For each KL type of interest (in instances_bacteria order), print the phages
# predicted to target it with their best score, then a blank separator.
selected_kltypes = (k for k in instances_bacteria.KL_type_LCA if k in kltype_interest)
for kltype in selected_kltypes:
    print(kltype)
    for phage, phage_hits in final_results.items():
        if kltype in phage_hits:
            print(phage, phage_hits[kltype])
    print("\n")