In [32]:
# Torch geometric modules
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero , HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

# Torch modules
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

# SKlearn modules
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

# Ground modules
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging
from multiprocessing.pool import ThreadPool
import json

# TropiGAT modules
import TropiGAT_graph
import TropiGAT_models

warnings.filterwarnings("ignore")

# *****************************************************************************
# Load the Dataframes :
path_work = "/media/concha-eloko/Linux/PPT_clean/reviewed_models/best_models"
DF_info = pd.read_csv(f"/media/concha-eloko/Linux/PPT_clean/TropiGATv2.final_df_v2.filtered.tsv", sep = "\t" ,  header = 0)

DF_info_lvl_0 = DF_info.copy()
df_prophages = DF_info_lvl_0.drop_duplicates(subset = ["Phage"])
dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))

# Load best parameters : 
with open(f"/media/concha-eloko/Linux/PPT_clean/trainer_best_parameters/DAG_models_best_para.json", "r") as f:
    best_parameters = json.load(f)
# *****************************************************************************
# Class : 

class TropiGAT_small_module_attention(torch.nn.Module):
    def __init__(self,hidden_channels, heads, edge_type = ("B2", "expressed", "B1") , dropout = 0.2, conv = GATv2Conv):
        super().__init__()
        # GATv2 module :
        self.conv = conv((-1,-1), hidden_channels, add_self_loops = False, heads = heads, dropout = dropout, shared_weights = True, return_attention_weights = True)
        self.hetero_conv = HeteroConv({edge_type: self.conv})
        # FNN layers : 
        self.linear_layers = nn.Sequential(nn.Linear(heads*hidden_channels, 1280),
                                           nn.BatchNorm1d(1280),
                                           nn.LeakyReLU(),
                                           torch.nn.Dropout(dropout),
                                           nn.Linear(1280, 480),
                                           nn.BatchNorm1d(480),
                                           nn.LeakyReLU(),
                                           torch.nn.Dropout(dropout),
                                           nn.Linear(480 , 1))
        
    def forward(self, graph_data):
        x_B1_dict, weights  = self.conv((graph_data.x_dict["B2"], graph_data.x_dict["B1"]), graph_data.edge_index_dict[("B2", "expressed", "B1")], return_attention_weights=True)
        x = self.linear_layers(x_B1_dict)
        return x.view(-1), weights 


# Functions :
def make_ensemble_TropiGAT_attention_review(path_ensemble, dico_best_para, big_heads = True , UF = False) : 
    """
    This function builds a dictionary with all the models that are part of the TropiGAT predictor
    Input : Path of the models
    Output : Dictionary
    # Make a json file with the versions of the GNN corresponding to each KL types
    # Load it
    # Create the correct model instance  
    """
    errors = []
    dico_ensemble = {}

    if not UF:
        local_dico = dico_best_para["TropiGAT"]
    else:
        local_dico = dico_best_para["TropiGAT_uf"]

    for GNN_model in os.listdir(path_ensemble):
        if GNN_model.endswith(".pt"):
            KL_type = GNN_model.split(".")[0]
            try:
                if big_heads == False :
                    att_heads = 1
                else :
                    att_heads = local_dico[KL_type.split("__")[0]]["att_heads"]
                model = TropiGAT_small_module_attention(
                    hidden_channels=1280,
                    heads=att_heads,
                    dropout=0
                )
                model.load_state_dict(torch.load(f"{path_ensemble}/{GNN_model}"))
                dico_ensemble[KL_type] = model
            except Exception as e:
                a = (KL_type, e)
                errors.append(a)

    return dico_ensemble, errors


@torch.no_grad()
def make_predictions(model, data):
	model.eval() 
	output, weigths = model(data)
	probabilities = torch.sigmoid(output)
	predictions = probabilities.round() 
	return predictions, probabilities , weigths


def run_prediction_attentive(dico_graph, dico_ensemble, KL_type) :
    KL_index = KL_type.split("_")[0]
    dico_predictions = {}
    query_graph = dico_graph[KL_index]
    model = dico_ensemble[KL_type]
    prediction, probabilities, weights = make_predictions(model, query_graph)
    dico_predictions[KL_index] = {"probabilitites" : probabilities, "weights" : weights}
    return dico_predictions

# *****************************************************************************
# Load the Models :
path_ensemble = f"{path_work}/best_models_TropiGAT_UF"


dico_models, errors = make_ensemble_TropiGAT_attention_review(path_ensemble, best_parameters, UF = True)
# *****************************************************************************
# Make graphs :
graph_baseline , dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(DF_info_lvl_0)
graph_dico = {kltype : TropiGAT_graph.build_graph_masking(graph_baseline , dico_prophage_kltype_associated, DF_info_lvl_0, kltype, 0, 1, 0, 0)
             for kltype in DF_info_lvl_0["KL_type_LCA"].unique()}

attention_data = {}
for kltype in dico_models : 
    out_dico = run_prediction_attentive(graph_dico , dico_models, kltype)
    attention_data.update(out_dico)
    

8871it [00:13, 640.35it/s]


In [30]:
for kltype in attention_data :
    att_weights = []
    for tensor in attention_data[kltype]["weights"][1] :
        mean_att = tensor.mean().item()
        att_weights.append(mean_att)
    attention_data[kltype]["probabilitites"] = attention_data[kltype]["probabilitites"].tolist()
    attention_data[kltype]["weights"] = att_weights



In [31]:
with open(f"{path_work}/attention_weights_dico.review.json", "w") as outfile :
    json.dump(attention_data , outfile)


In [44]:
def get_positive_instances(graph_dico, kltype) :
    positive_instances = []
    tensor_y = graph_dico[kltype]["B1"]["y"].tolist()
    for index, value in enumerate(tensor_y) :
        if value == 1 :
            positive_instances.append(index)
    return positive_instances
    

In [58]:
attention_data_raw = {}

for kltype in tqdm(attention_data) : 
    prophage_indices = get_positive_instances(graph_dico, kltype)
    tmp_dico = {}
    for _,prophage_index in enumerate(prophage_indices) :
        # If prediction is positive : 
        if attention_data[kltype]["probabilitites"][prophage_index] > 0.5 :
            prob = attention_data[kltype]["probabilitites"][prophage_index]
            real_prophage_name = DF_info["Phage"].unique().tolist()[prophage_index] 
            tmp_dpos = []
           # Look for the edges involving the prophage :
            for index_edge, prophage_edge in enumerate(attention_data[kltype]["weights"][0][1]) :
                if prophage_edge == prophage_index :
                    # Check the value of the attention coeff on the edge involving the prophage :
                    att_coeff = attention_data[kltype]["weights"][1][index_edge].float()
                    #if att_coeff > 0.5 : 
                    # Get the seq of the depo : 
                    depo_index = attention_data[kltype]["weights"][0][0][index_edge]
                    real_depo_index = DF_info["index"].unique().tolist()[depo_index] 
                    seq = DF_info[DF_info["index"] == real_depo_index]["seq"].values[0]
                    domain_seq = DF_info[DF_info["index"] == real_depo_index]["domain_seq"].values[0]
                    # Pack the data :
                    a = (real_depo_index,seq,att_coeff, prob,domain_seq)
                    tmp_dpos.append(a)
            tmp_dico[real_prophage_name] = tmp_dpos
    attention_data_raw[kltype] = tmp_dico

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 85/85 [04:25<00:00,  3.12s/it]


In [79]:
with open(f"{path_work}/attention_weights_dpos.review.raw.tsv", "w") as outfile:
    outfile.write(f"KL_type\tPhage\tdpo_index\tattention_coefficient\tprobability\tdomain seq\tseq\n")
    for kltype in attention_data_raw :
        for prophage in attention_data_raw[kltype] : 
            for dpo in attention_data_raw[kltype][prophage] :
                outfile.write(f"{kltype}\t{prophage}\t{dpo[0]}\t{float(dpo[2].mean().item())}\t{dpo[3]}\t{dpo[4]}\t{dpo[1]}\n")


In [81]:
df_coeff = pd.read_csv(f"{path_work}/attention_weights_dpos.review.raw.tsv" , sep = "\t", header= 0)
df_coeff

Unnamed: 0,KL_type,Phage,dpo_index,attention_coefficient,probability,domain seq,seq
0,KL4,GCF_014218685.1__phage13,ppt__780,1.000000,0.986413,MSAKFSPSLMCMDLTQFKEQITAMNKKADFYHVDIMDGNYVRNITL...,MSAKFSPSLMCMDLTQFKEQITAMNKKADFYHVDIMDGNYVRNITL...
1,KL4,GCF_900451495.1__phage28,ppt__2295,0.566855,0.988421,TVNDLANSVYQTSVSRITDHGIGFANWPQGKAVTFNNNLYVGYNYA...,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTLDTSESIVT...
2,KL4,GCF_900451495.1__phage28,anubis__118,0.433145,0.988421,GKLFYMEQKAVDSVGRWETDKDIGIGDECRYQENFYRCVDGGSNGT...,MAYSLVQPSLAGGEISPSLYGRIDLEKYQTSLRRCRNFIVRQSGGI...
3,KL4,GCF_001598715.1__phage7,anubis__118,1.000000,0.994869,GKLFYMEQKAVDSVGRWETDKDIGIGDECRYQENFYRCVDGGSNGT...,MAYSLVQPSLAGGEISPSLYGRIDLEKYQTSLRRCRNFIVRQSGGI...
4,KL4,GCF_900451495.1__phage29,ppt__2301,0.389807,0.991711,YQTSVSRITDHGIGFANWPQGKAVKFNNNLYVGYNYATAHGSVVQD...,MSVPNQIPYNIYTANGLTTVFTYQFYIISASDLEVSINGSVVASGY...
...,...,...,...,...,...,...,...
9369,KL102,GCF_900500995.1__phage8,anubis__1637,1.000000,0.999374,LGLDYPNEYYLQDFSGDTDIEWIQNAMDWVHDAGGGWLILSSDYVK...,MAEVPLPTPTDNAVPSTDIRDAVYAGAMLDKVVTSTDLKYTDRLGV...
9370,KL102,GCF_001583485.1__phage2,anubis_return__379,1.000000,0.645472,LLLPAAALAARERVEVLQNQLEHPWALAFLPDDRGILMTLRGGELR...,MRQTITLMIALTALLLPAAALAARERVEVLQNQLEHPWALAFLPDD...
9371,KL102,GCF_900501135.1__phage2,anubis_return__506,1.000000,0.881703,ANEMFDFHSDPIRVVFAGVTFYDCAHLIMVNGRSLLSVRPPQEVSS...,MANTNDHGLPRTIPEGVKREIRQRCGFGCVICGLGFYDYEHFAPDF...
9372,KL102,GCF_019336245.1__phage12,anubis_return__506,1.000000,0.881703,ANEMFDFHSDPIRVVFAGVTFYDCAHLIMVNGRSLLSVRPPQEVSS...,MANTNDHGLPRTIPEGVKREIRQRCGFGCVICGLGFYDYEHFAPDF...


In [82]:
df_coeff[df_coeff["Phage"] == "GCF_900451495.1__phage29"]

Unnamed: 0,KL_type,Phage,dpo_index,attention_coefficient,probability,domain seq,seq
4,KL4,GCF_900451495.1__phage29,ppt__2301,0.389807,0.991711,YQTSVSRITDHGIGFANWPQGKAVKFNNNLYVGYNYATAHGSVVQD...,MSVPNQIPYNIYTANGLTTVFTYQFYIISASDLEVSINGSVVASGY...
5,KL4,GCF_900451495.1__phage29,anubis__448,0.610193,0.991711,TQAASPGAWTREDSVWTDEFGYPGAVTLYQQRLVLAGSPQYPQTIW...,MRANLIKTNFTAGEISPRLMGRVDIDRYANGAKTLENSMVVVQGGV...
