In [1]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero , HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

from TropiGAT_functions import *
#from TropiGAT_functions import get_top_n_kltypes ,clean_print 

import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging
from multiprocessing.pool import ThreadPool
warnings.filterwarnings("ignore")

# *****************************************************************************
# Working directories. Set the TROPIGAT_PATH_WORK environment variable to relocate
# them; the default is the original hard-coded location.
path_work = os.environ.get("TROPIGAT_PATH_WORK", "/media/concha-eloko/Linux/PPT_clean")
path_ensemble = os.path.join(path_work, "ficheros_28032023", "ensemble_1908")


In [2]:
# *****************************************************************************
#logging.basicConfig(filename = f"{path_work}/train_nn/GATv2Conv.1608.log",format='%(asctime)s | %(levelname)s: %(message)s', level=logging.NOTSET, filemode='w')

class GNN(torch.nn.Module):
    """One heterogeneous message-passing layer restricted to a single edge type.

    Args:
        edge_type: ``(src, relation, dst)`` triple the layer operates on.
        conv: convolution class (e.g. ``GATv2Conv``) instantiated with lazy
            ``(-1, -1)`` input sizes, so the source/destination widths are inferred
            on the first forward pass.
        hidden_channels: per-head output width passed to ``conv``.
        heads: number of attention heads passed to ``conv``.
        dropout: attention dropout probability passed to ``conv``.
    """

    def __init__(self, edge_type, conv, hidden_channels, heads, dropout):
        super().__init__()
        # NOTE(review): in PyG's GATv2Conv the weight-sharing argument is named
        # `share_weights`. `shared_weights` is not one of its named parameters, so it is
        # forwarded through **kwargs. Depending on the PyG version it is either ignored
        # (weight sharing silently OFF) or rejected with a TypeError. Left as-is on
        # purpose: renaming it changes the layer's parameters (and requires equal B2/B1
        # feature widths), which may not match already-trained models. Confirm first.
        layer = conv(
            (-1, -1),
            hidden_channels,
            heads=heads,
            dropout=dropout,
            add_self_loops=False,
            shared_weights=True,
        )
        self.conv = layer
        self.hetero_conv = HeteroConv({edge_type: layer})

    def forward(self, x_dict, edge_index_dict):
        """Apply the wrapped convolution; returns HeteroConv's dict keyed by node type."""
        return self.hetero_conv(x_dict, edge_index_dict)

# Classifier, Binary :
class EdgeDecoder(torch.nn.Module):
    """Binary classifier that scores ("B1", "infects", "A") edges.

    For every labelled edge in ``graph_data[("B1", "infects", "A")].edge_label_index``
    the feature row of the target "A" node (row 1 of the index) is concatenated with
    the feature row of the source "B1" node (row 0). The result goes through a
    two-layer MLP that yields one raw score (no sigmoid) per edge.

    Args:
        hidden_channels: per-head width of the upstream encoder.
        heads: number of attention heads. The "B1" features are expected to be
            ``heads * hidden_channels`` wide (assumes the encoder concatenates its
            heads; confirm).
        a_in_channels: width of the "A" node features. Defaults to 127, the value
            that was previously hard-coded.
        mlp_channels: hidden width of the MLP. Defaults to the previous 512.
    """

    def __init__(self, hidden_channels, heads, a_in_channels=127, mlp_channels=512):
        super().__init__()
        self.lin1 = torch.nn.Linear(heads * hidden_channels + a_in_channels, mlp_channels)
        self.lin2 = torch.nn.Linear(mlp_channels, 1)

    def forward(self, x_dict_A, x_dict_B1, graph_data):
        """Return a 1-D tensor with one score per labelled ("B1", "infects", "A") edge."""
        edge_type = ("B1", "infects", "A")
        edge_label_index = graph_data[edge_type].edge_label_index
        edge_feat_A = x_dict_A["A"][edge_label_index[1]]
        edge_feat_B1 = x_dict_B1["B1"][edge_label_index[0]]
        # "A" features first, then "B1": lin1 was sized for this concatenation.
        features_phage = torch.cat((edge_feat_A, edge_feat_B1), dim=-1)
        hidden = self.lin1(features_phage).relu()
        return self.lin2(hidden).view(-1)

class Model(torch.nn.Module):
    """Encoder/decoder for ("B1", "infects", "A") link prediction.

    ``GNN`` runs one convolution over ("B2", "expressed", "B1") edges to produce
    updated "B1" features. ``EdgeDecoder`` then scores the labelled edges from the
    unmodified "A" features and those updated "B1" features.
    """

    def __init__(self, conv, hidden_channels, heads, dropout):
        super().__init__()
        encoder_edge = ("B2", "expressed", "B1")
        # Attribute names are kept unchanged: they determine the state_dict keys.
        self.single_layer_model = GNN(encoder_edge, conv, hidden_channels, heads, dropout)
        self.EdgeDecoder = EdgeDecoder(hidden_channels, heads)

    def forward(self, graph_data):
        """Return one raw score per edge in graph_data[("B1", "infects", "A")].edge_label_index."""
        encoded = self.single_layer_model(graph_data.x_dict, graph_data.edge_index_dict)
        raw_features = graph_data.x_dict
        return self.EdgeDecoder(raw_features, encoded, graph_data)


> Pre-processing: load the ESM2 embeddings of the Dpo domains

In [3]:
import pandas as pd 
import os 

# Project locations. Set the environment variables to relocate them; the defaults
# are the original hard-coded paths.
path_project = os.environ.get("TROPIGAT_PATH_PROJECT", "/media/concha-eloko/Linux/77_strains_phage_project")
path_Dpo_domain_org = os.environ.get("TROPIGAT_PATH_DPO_DOMAINS", "/media/concha-eloko/Linux/depolymerase_building/clean_77_phages_depo")

# Headerless CSV of ESM2 embeddings. Column 0 (presumably the Dpo domain id) becomes
# the index.
# NOTE(review): column 1281 is dropped, presumably an empty trailing column (e.g. from
# a trailing comma in the CSV); confirm.
# Built as a single chain instead of in-place mutation, so the frame is never
# observed half-processed.
dpo_embeddings = (
    pd.read_csv(f"{path_project}/rbp_work/Dpo_domains_77.esm2.embedding.csv", sep=",", header=None)
    .drop(columns=[1281])
    .set_index(0)
)

> Run the predictions

In [4]:
# Run the predictions :
# graph_single_Dpo_pred / run_predictions are not defined in this notebook; presumably
# they come from `from TropiGAT_functions import *`.
# NOTE(review): the saved output of this cell is
# `NameError: name 'dico_kltype' is not defined`. No `dico_kltype` appears in this
# notebook, so it is presumably a global expected by TropiGAT_functions. If the error
# is raised inside that module, defining `dico_kltype` here will not fix it, because
# functions look up globals in their own module. TODO confirm.
# NOTE(review): a later cell prints `formatted_pred`, but its only assignment (below) is
# commented out, so that cell fails on a fresh kernel.
graph_test = graph_single_Dpo_pred(dpo_embeddings)
r_pred = run_predictions(graph_test,dpo_embeddings)
#formatted_pred = format_predictions(r_pred)


NameError: name 'dico_kltype' is not defined

In [None]:
# Display the graph object returned by graph_single_Dpo_pred in the prediction cell.
graph_test

In [None]:
# Print the predictions : 

In [None]:

# Aggregate positive predictions over the ensemble of Dpo classifiers.
# For each model ratio, keep (protein, KL-type) pairs that are predicted positive AND
# scored above 0.5. Then sum those scores across models into `round_prediction`
# (protein -> {KL-type: summed score}).
# NOTE(review): `pred_data_single` is not assigned in any visible cell; the graph built
# earlier in this notebook is `graph_test`. Confirm which one is intended.
round_prediction = {}

for ratio in [1, 2, 3, 4, 6, 7]:
    # protein -> {KL-type: score} for the current model only
    clean_results = {}
    model = Dpo_classifier_models[f"model_ratio_{ratio}"]
    predictions, probabilities = make_predictions(model, pred_data_single)
    ids = get_nodes_id_single(pred_data_single[("B1", "infects", "A")].edge_label_index)
    results = tuple(zip(ids, predictions.numpy(), probabilities.numpy()))

    positive_results = [hit for hit in results if int(hit[1]) == 1]
    for node_ids, _label, score in positive_results:
        if score > 0.5:
            prot, kltype = node_ids[0], node_ids[1]
            # a repeated (prot, kltype) pair keeps the last score seen
            clean_results.setdefault(prot, {})[kltype] = score

    # Fold this model's hits into the running per-protein sums.
    for prot, hits in clean_results.items():
        if prot not in round_prediction:
            round_prediction[prot] = hits
            continue
        merged = round_prediction[prot]
        for kltype, score in hits.items():
            if kltype in merged:
                merged[kltype] = merged[kltype] + score
            else:
                merged[kltype] = score



In [None]:
import pprint
# NOTE(review): `formatted_pred` is only assigned in a commented-out line of the
# prediction cell, so this cell raises NameError on a fresh kernel.
pp = pprint.PrettyPrinter(compact=True, sort_dicts=True, width=250)
pp.pprint(formatted_pred)

In [None]:
# Collapse protein-level predictions to phage level. The phage name is the part of
# the protein id before the first "__". For each phage, keep the highest score seen
# for every KL-type across its proteins.
final_results = {}

for protein, hits in round_prediction.items():
    phage = protein.split("__")[0]
    if phage not in final_results:
        # First protein of this phage: copy its hits so later merges do not mutate
        # the dicts stored in round_prediction.
        final_results[phage] = dict(hits)
        continue
    best = final_results[phage]
    for kltype, score in hits.items():
        # Strictly greater replaces; ties (and NaN comparisons) keep the existing
        # value, as the previous if/elif chain did.
        if kltype not in best or score > best[kltype]:
            best[kltype] = score

import pprint
# Per-phage KL-type scores, keys sorted for a stable, diff-friendly display.
pp = pprint.PrettyPrinter(compact=True, sort_dicts=True, width=150)
pp.pprint(final_results)

In [None]:
import pprint
# NOTE(review): `output_dict` is not assigned in any visible cell of this notebook,
# so this cell raises NameError on a fresh kernel. TODO confirm what it should show.
pp = pprint.PrettyPrinter(compact=True, sort_dicts=True, width=150)
pp.pprint(output_dict)

