In [2]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero , HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

#from TropiGAT_functions import *
#from TropiGAT_functions import get_top_n_kltypes ,clean_print 

import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging
from multiprocessing.pool import ThreadPool
# Silence every warning raised from here on in the notebook
warnings.filterwarnings("ignore")

# *****************************************************************************
# Working directories:
#   path_work     -- root of the cleaned project data
#   path_ensemble -- folder holding the saved per-KL-type models (*.pt) of the ensemble
path_work = "/media/concha-eloko/Linux/PPT_clean"
path_ensemble = path_work + "/ficheros_28032023/ensemble_1908"

In [None]:
path_project = "/media/concha-eloko/Linux/77_strains_phage_project"
path_Dpo_domain_org = "/media/concha-eloko/Linux/depolymerase_building/clean_77_phages_depo"

# ESM2 embeddings of the depolymerase domains (headerless CSV).
# Column 1281 is discarded (presumably an empty trailing column -- verify against the file),
# and column 0 (presumably the domain identifier) becomes the index.
dpo_embeddings = (
    pd.read_csv(f"{path_project}/rbp_work/Dpo_domains_77.esm2.embedding.csv", sep = "," , header = None)
    .drop(columns = [1281])
    .set_index(0)
)

In [None]:
def make_query_graph(embeddings) :
    """
    This function builds the query graph for the ensemble model.

    The query is a single featureless B1 node (presumably the query phage) to which
    every depolymerase (one B2 node per embedding) is linked through the
    ('B2', 'expressed', 'B1') relation, as in the training graph.

    Input : The ESM2 embeddings of the depolymerases, as a 2-D array-like of shape
            (n_dpos, embedding_dim). A single 1-D embedding is accepted and treated
            as one depolymerase.
    Output : The query graph (HeteroData)
    Raises : ValueError if no embedding is given.
    """
    dpo_features = torch.tensor(embeddings, dtype=torch.float)
    if dpo_features.dim() == 1:
        # A single embedding vector: promote it to a one-row feature matrix.
        dpo_features = dpo_features.unsqueeze(0)
    if dpo_features.numel() == 0:
        raise ValueError("make_query_graph needs at least one depolymerase embedding")
    n_dpos = dpo_features.size(0)

    query_graph = HeteroData()
    query_graph["B1"].x = torch.empty((1, 0))
    query_graph["B2"].x = dpo_features
    # edge_index must have shape (2, num_edges): row 0 holds the B2 source indices,
    # row 1 the B1 target index (always node 0, the single query node).
    edge_index_B2_B1 = torch.stack([
        torch.arange(n_dpos, dtype=torch.long),
        torch.zeros(n_dpos, dtype=torch.long),
    ])
    query_graph['B2', 'expressed', 'B1'].edge_index = edge_index_B2_B1

    return query_graph
    
def make_ensemble_TropiGAT(path_ensemble, model_factory = None, map_location = "cpu") : 
    """
    This function builds a dictionary with all the models that are part of the TropiGAT predictor.

    Every file ending in ".pt" inside path_ensemble is loaded as the state dict of one
    model. Its KL type is the file name up to the first dot (e.g. "KL1.pt" -> "KL1").

    Inputs :
        path_ensemble : directory containing the saved state dicts
        model_factory : zero-argument callable returning a fresh model instance to load
                        each state dict into. Defaults to
                        TropiGAT_models.TropiGAT_big_module(hidden_channels = 1280 , heads = 1)
                        (TropiGAT_models must then be importable in this namespace).
        map_location  : passed to torch.load. "cpu" lets GPU-saved checkpoints load on
                        CPU-only machines.
    Output : Dictionary {KL_type : model}, every model switched to eval mode.

    TODO :
    # Make a json file with the versions of the GNN corresponding to each KL types
    # Load it
    # Create the correct model instance (TropiGAT_small_module or TropiGAT_big_module)
    """
    if model_factory is None :
        def model_factory() :
            return TropiGAT_models.TropiGAT_big_module(hidden_channels = 1280 , heads = 1)

    dico_ensemble = {}
    # sorted() gives a deterministic loading order regardless of the filesystem.
    for GNN_model in sorted(os.listdir(path_ensemble)) :
        # Require the real ".pt" extension; a bare "pt" suffix would also accept e.g. "slides.ppt".
        if not GNN_model.endswith(".pt") :
            continue
        KL_type = GNN_model.split(".")[0]
        model = model_factory()
        state_dict = torch.load(os.path.join(path_ensemble, GNN_model), map_location = map_location)
        model.load_state_dict(state_dict)
        model.eval()
        dico_ensemble[KL_type] = model

    return dico_ensemble

@torch.no_grad()
def make_predictions(model, data):
    """
    Score the query data with one binary model, without tracking gradients.

    Inputs : model -- a torch module whose output is assumed to be raw logits
                      (the sigmoid is applied here);
             data  -- the input passed straight to the model's forward call
    Output : (predictions, probabilities), where probabilities = sigmoid(output)
             and predictions = probabilities rounded to 0 / 1
    """
    model.eval()  # put layers such as dropout into inference mode before scoring
    scores = model(data).sigmoid()
    return scores.round(), scores
        
def run_prediction(query_graph, dico_ensemble) :
    """
    This function runs every model of the ensemble on the query graph.

    Inputs :
        query_graph   : the graph to score (e.g. the output of make_query_graph)
        dico_ensemble : dictionary {KL_type : model}
    Output : Dictionary {KL_type : probabilities} containing only the KL types whose
             model predicts the positive class (1).
    """
    dico_predictions = {}
    for KL_type, model in dico_ensemble.items() :
        # Score the graph handed in by the caller, not a global variable.
        prediction, probabilities = make_predictions(model, query_graph)
        # int() assumes each model returns a single prediction for the query.
        if int(prediction) == 1 :
            dico_predictions[KL_type] = probabilities

    return dico_predictions
        

In [7]:
# Reload the heterogeneous training graph saved in the working directory and display it
path_work = "/media/concha-eloko/Linux/PPT_clean"
graph_file = f"{path_work}/Tropi_graph.lvl_1.1909.pt"
graph_data = torch.load(graph_file)
graph_data

HeteroData(
  A={ x=[4585, 127] },
  B1={ x=[7640, 0] },
  B2={ x=[3449, 1280] },
  (B1, infects, A)={
    edge_index=[2, 7707],
    y=[7707]
  },
  (B2, expressed, B1)={
    edge_index=[2, 9626],
    y=[9626]
  },
  (A, harbors, B1)={
    edge_index=[2, 7707],
    y=[7707]
  }
)

In [9]:
# B2 -> B1 ('expressed') edges: row 0 = B2 node indices, row 1 = B1 node indices
graph_data["B2", "expressed", "B1"].edge_index

tensor([[   0,    1,    2,  ..., 3446, 3447, 3448],
        [   0,    0,    1,  ..., 7637, 7638, 7639]])