# Model metrics :

> Open DF : 

In [None]:
import pandas as pd 
import os 

path_seqbased = "/media/concha-eloko/Linux/PPT_clean/RF_2912_models_info"

header_metric = ["KL_type", "Effectifs","MCC","F1","recall","Accuracy","AUC"]
seqbased_df = pd.read_csv(f"{path_seqbased}/RF_report.0.75.2912.tsv", sep = "\t", names = header_metric)

In [None]:
df_metrics_sorted = seqbased_df.sort_values(by='KL_type', key=lambda x: x.str.split("KL").str[1].astype(int))
for col in df_metrics_sorted.columns : 
    if col not in ["KL_type","n_prophages"] : 
        df_metrics_sorted[col] = round(df_metrics_sorted[col] , 4)

> Compute metrics 

In [None]:
g1_group = ["KL2", "KL17", "KL47", "KL64", "KL106", "KL107"]
g2_group = ["KL1", "KL3", "KL14", "KL15", "KL23", "KL24", "KL25", "KL27", "KL51", "KL62", "KL102"]

big_groups = g1_group + g2_group

g3_group = [kltype for kltype in df_metrics_sorted["KL_type"].tolist() if kltype not in big_groups]

In [None]:
import statistics

mcc_g1 = [float(row["MCC"]) for _,row in df_metrics_sorted.iterrows() if row["KL_type"] in g1_group]
mean_g1 = statistics.mean(mcc_g1)
std_g1 = statistics.stdev(mcc_g1)


mcc_g2 = [float(row["MCC"]) for _,row in df_metrics_sorted.iterrows() if row["KL_type"] in g2_group]
mean_g2 = statistics.mean(mcc_g2)
std_g2 = statistics.stdev(mcc_g2)


mcc_g3 = [float(row["MCC"]) for _,row in df_metrics_sorted.iterrows() if row["KL_type"] in g3_group]
mean_g3 = statistics.mean(mcc_g3)
std_g3 = statistics.stdev(mcc_g3)

In [None]:
round(mean_g3, 4) , round(std_g3 , 4)

***
# Get the attention weights :

In [None]:
rsync -avzhe ssh \
/media/concha-eloko/Linux/PPT_clean/TropiGATv2.final_df_v2.filtered.tsv \
conchae@garnatxa.srv.cpd:/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn \



In [1]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero , HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

import TropiGAT_functions 
import TropiGAT_graph
#from TropiGAT_functions import get_top_n_kltypes ,clean_print 

import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging
from multiprocessing.pool import ThreadPool
warnings.filterwarnings("ignore")

# *****************************************************************************
# Load the Dataframes :
path_work = "/media/concha-eloko/Linux/PPT_clean"
path_ensemble = f"{path_work}/ficheros_28032023/ensemble_2812"

In [2]:
DF_info = pd.read_csv(f"{path_work}/TropiGATv2.final_df_v2.filtered.tsv", sep = "\t" ,  header = 0)
DF_info_lvl_0 = DF_info.copy()
df_prophages = DF_info_lvl_0.drop_duplicates(subset = ["Phage"])
dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))


#path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"
#DF_info = pd.read_csv(f"{path_work}/train_nn/TropiGATv2.final_df_v2.filtered.tsv", sep = "\t" ,  header = 0)



In [46]:
positive_indices = [index for index,row in DF_info_lvl_0.drop_duplicates(subset = ["Phage"]).iterrows() if row["KL_type_LCA"] == "KL47"]


In [40]:
DF_info[DF_info["Phage"] == "GCF_902164905.1__phage1"]

Unnamed: 0,Phage,Protein_name,KL_type_LCA,Infected_ancestor,index,Dataset,seq,domain_seq,1,2,...,1272,1273,1274,1275,1276,1277,1278,1279,1280,prophage_id
0,GCF_902164905.1__phage1,GCF_902164905.1__phage1__34,KL41,GCF_902164905.1,minibatch__460,minibatch,MPATPQDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDE...,QDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDERTITT...,0.025276,0.053137,...,-0.011464,0.081105,0.012011,0.042917,0.009402,0.093175,-0.080562,0.000897,0.111854,prophage_11309
8786,GCF_902164905.1__phage1,GCF_902164905.1__phage1__26,KL41,GCF_902164905.1,anubis__1158,anubis,MPRNNVPLLAFNRGIISPLALARTDIERLALSAEVQTNWMPRLLGS...,TGTVRITAVNSRTSATGIVLSDLGGTSATADWYEGAFSAKNGFPGA...,0.026319,-0.003804,...,-0.04776,0.030186,0.027976,-0.026192,0.000397,0.044378,-0.103954,-0.161963,0.121896,prophage_11309


In [3]:
def make_ensemble_TropiGAT_attention(path_ensemble) : 
	"""
	This function builds a dictionary with all the models that are part of the TropiGAT predictor
	Input : Path of the models
	Output : Dictionary , attention weights
	# Make a json file with the versions of the GNN corresponding to each KL types
	# Load it
	# Create the correct model instance (TropiGAT_small_module or TropiGAT_big_module)
	"""
	errors = []
	DF_info = pd.read_csv(f"{path_work}/TropiGATv2.final_df_v2.filtered.tsv", sep = "\t" ,  header = 0)
	DF_info_lvl_0 = DF_info.copy()
	df_prophages = DF_info_lvl_0.drop_duplicates(subset = ["Phage"])
	dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))
	dico_ensemble = {}
	for GNN_model in os.listdir(path_ensemble) :
		if GNN_model[-2:] == "pt" :
			KL_type = GNN_model.split(".")[0]
			try :
				if dico_prophage_count[KL_type] >= 125 : 
					model = TropiGAT_models.TropiGAT_big_module_attention(hidden_channels = 1280 , heads = 1)
				else :
					model = TropiGAT_models.TropiGAT_small_module_attention(hidden_channels = 1280 , heads = 1)
				model.load_state_dict(torch.load(f"{path_ensemble}/{GNN_model}"))
				dico_ensemble[KL_type] = model
			except Exception as e :
				a = (KL_type , dico_prophage_count[KL_type], e)
				errors.append(a)
	return dico_ensemble , errors

def build_graph_masking_attention(graph_data_input, dico_prophage_kltype_associated , df_info, KL_type) : 
    # **************************************************************
    # Indexation process  
	graph_data = graph_data_input.clone()
	indexation_nodes_A = df_info["Infected_ancestor"].unique().tolist()  
	indexation_nodes_B1 = df_info["Phage"].unique().tolist()
	indexation_nodes_B2 = df_info["index"].unique().tolist() 
	ID_nodes_A = {item:index for index, item in enumerate(indexation_nodes_A)}
	ID_nodes_A_r = {index:item for index, item in enumerate(indexation_nodes_A)}
	ID_nodes_B1 = {item:index for index, item in enumerate(indexation_nodes_B1)}
	ID_nodes_B1_r = {index:item for index, item in enumerate(indexation_nodes_B1)}
	ID_nodes_B2 = {item:index for index, item in enumerate(indexation_nodes_B2)}
	ID_nodes_B2_r = {index:item for index, item in enumerate(indexation_nodes_B2)}
    
	# **************************************************************
	# Make the Y file : 
	B1_labels = df_info.drop_duplicates(subset = ["Phage"], keep = "first")["KL_type_LCA"].apply(lambda x : 1 if x == KL_type else 0).to_list()
	graph_data["B1"].y = torch.tensor(B1_labels)
	# **************************************************************
	# Make mask files :
	# get the positive and negative indices lists :
	positive_indices = [index for index,label in enumerate(B1_labels) if label==1]
	negative_indices = []
	for negative_index,phage in enumerate(df_info["Phage"].unique().tolist()) :
		if KL_type not in dico_prophage_kltype_associated[ID_nodes_B1_r[negative_index]] :
			negative_indices.append(negative_index)
	return graph_data  
    
@torch.no_grad()
def make_predictions(model, data):
	model.eval() 
	output, weigths = model(data)
	probabilities = torch.sigmoid(output)
	predictions = probabilities.round() 
	return predictions, probabilities , weigths


def run_prediction_attentive(dico_graph, dico_ensemble, KL_type) :
    dico_predictions = {}
    query_graph = dico_graph[KL_type]["graph"]
    model = dico_ensemble[KL_type]
    prediction, probabilities, weights = make_predictions(model, query_graph)
    dico_predictions[KL_type] = {"probabilitites" : probabilities, "weights" : weights}
    return dico_predictions

In [5]:
dico_models, errors = TropiGAT_functions.make_ensemble_TropiGAT_attention(path_ensemble)
#dico_models, errors = make_ensemble_TropiGAT_attention(path_ensemble)


In [None]:
# *****************************************************************************
# Make graphs : 
graph_baseline , dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(DF_info_lvl_0)
graph_dico = {kltype : {"graph" : TropiGAT_graph.build_graph_masking(graph_baseline , dico_prophage_kltype_associated,DF_info_lvl_0, kltype, 0, 1, 0, 0), 
                        "positive_indices" : [index for index,row in DF_info_lvl_0.drop_duplicates(subset = ["Phage"]).iterrows() if row["KL_type_LCA"] == kltype]}
             for kltype in DF_info_lvl_0["KL_type_LCA"].unique()}

8871it [00:18, 468.73it/s]


In [36]:
graph_baseline.x_dict["B2"]

tensor([[ 0.0253,  0.0531,  0.0029,  ..., -0.0806,  0.0009,  0.1119],
        [ 0.0049,  0.0409, -0.0270,  ..., -0.1235,  0.0476,  0.0613],
        [-0.0036, -0.0330, -0.0306,  ..., -0.1246,  0.0117,  0.1479],
        ...,
        [ 0.0735,  0.0467,  0.0106,  ..., -0.0506, -0.0852, -0.0108],
        [ 0.0322,  0.0489, -0.0175,  ..., -0.0601,  0.0237,  0.0865],
        [-0.0111, -0.0053, -0.0120,  ..., -0.0991, -0.0574,  0.0914]])

In [42]:
sum(graph_dico["KL47"]["B1"].y)

tensor(551)

In [37]:

attention_data = {}
for kltype in dico_models : 
    out_dico = run_prediction_attentive(graph_dico , dico_models, kltype)
    attention_data.update(out_dico)


#with open(f"{path_work}/attention_weights_dico.json", "w") as outfile :
#    json.dump(attention_data , outfile)

In [38]:
attention_data["KL128"]

{'probabilitites': tensor([0.9852, 0.1982, 0.8031,  ..., 0.6596, 0.9903, 0.9990]),
 'weights': (tensor([[   0, 2430,    1,  ..., 4032, 4033, 4034],
          [   0,    0,    1,  ..., 8868, 8869, 8870]]),
  tensor([[0.5194],
          [0.4806],
          [0.3267],
          ...,
          [1.0000],
          [1.0000],
          [1.0000]]))}

> Write the data in Json 

In [15]:
def convert_to_serializable(data):
    if isinstance(data, torch.Tensor):
        return data.tolist()  # Convert tensors to lists
    elif isinstance(data, dict):
        return {k: convert_to_serializable(v) for k, v in data.items()}  # Recursively apply to dictionaries
    elif isinstance(data, list):
        return [convert_to_serializable(v) for v in data]  # Apply to each item in the list
    elif isinstance(data, tuple):
        return tuple(convert_to_serializable(v) for v in data)  # Convert items in tuples
    else:
        return data  # Return as is for serializable types

# Apply conversion to the entire attention_data
serializable_attention_data = convert_to_serializable(attention_data)

# Now, serializable_attention_data contains only JSON-serializable items
with open(f"{path_work}/attention_weights_dico.json", "w") as outfile:
    json.dump(serializable_attention_data, outfile)

In [20]:
serializable_attention_data["KL47"]

{'probabilitites': [0.20913609862327576,
  0.24970440566539764,
  0.10275932401418686,
  0.24970440566539764,
  0.10275932401418686,
  0.10275932401418686,
  0.24970440566539764,
  0.18879470229148865,
  0.021939411759376526,
  0.021939411759376526,
  0.021939411759376526,
  0.021939411759376526,
  0.021939411759376526,
  0.021939411759376526,
  0.07345058768987656,
  0.07345058768987656,
  0.07345058768987656,
  0.07345058768987656,
  0.024546796455979347,
  0.020584257319569588,
  0.020584257319569588,
  0.026085859164595604,
  0.026085859164595604,
  0.026085859164595604,
  0.06736677885055542,
  0.026572464033961296,
  0.026572464033961296,
  0.026572464033961296,
  0.026572464033961296,
  0.026572464033961296,
  0.7907968163490295,
  0.09482117742300034,
  0.09482117742300034,
  0.09482117742300034,
  0.09482117742300034,
  0.09482117742300034,
  0.09482117742300034,
  0.09482117742300034,
  0.09482117742300034,
  0.09482117742300034,
  0.09482117742300034,
  0.09482117742300034,


In [22]:
import statistics
kltype = "KL47"

mean_KLtype = statistics.mean(serializable_attention_data[kltype]["probabilitites"])

mean_KLtype


0.2596760921555255

In [23]:
len(serializable_attention_data[kltype]["probabilitites"])

8871