# Model metrics :

> Open DF : 

In [None]:
import pandas as pd 
import os 

path_seqbased = "/media/concha-eloko/Linux/PPT_clean/RF_2912_models_info"

header_metric = ["KL_type", "Effectifs","MCC","F1","recall","Accuracy","AUC"]
seqbased_df = pd.read_csv(f"{path_seqbased}/RF_report.0.75.2912.tsv", sep = "\t", names = header_metric)

In [None]:
df_metrics_sorted = seqbased_df.sort_values(by='KL_type', key=lambda x: x.str.split("KL").str[1].astype(int))
for col in df_metrics_sorted.columns : 
    if col not in ["KL_type","n_prophages"] : 
        df_metrics_sorted[col] = round(df_metrics_sorted[col] , 4)

> Compute metrics 

In [None]:
g1_group = ["KL2", "KL17", "KL47", "KL64", "KL106", "KL107"]
g2_group = ["KL1", "KL3", "KL14", "KL15", "KL23", "KL24", "KL25", "KL27", "KL51", "KL62", "KL102"]

big_groups = g1_group + g2_group

g3_group = [kltype for kltype in df_metrics_sorted["KL_type"].tolist() if kltype not in big_groups]

In [None]:
import statistics

mcc_g1 = [float(row["MCC"]) for _,row in df_metrics_sorted.iterrows() if row["KL_type"] in g1_group]
mean_g1 = statistics.mean(mcc_g1)
std_g1 = statistics.stdev(mcc_g1)


mcc_g2 = [float(row["MCC"]) for _,row in df_metrics_sorted.iterrows() if row["KL_type"] in g2_group]
mean_g2 = statistics.mean(mcc_g2)
std_g2 = statistics.stdev(mcc_g2)


mcc_g3 = [float(row["MCC"]) for _,row in df_metrics_sorted.iterrows() if row["KL_type"] in g3_group]
mean_g3 = statistics.mean(mcc_g3)
std_g3 = statistics.stdev(mcc_g3)

In [None]:
round(mean_g3, 4) , round(std_g3 , 4)

***
# Get the attention weights :

In [None]:
rsync -avzhe ssh \
/media/concha-eloko/Linux/PPT_clean/TropiGATv2.final_df_v2.filtered.tsv \
conchae@garnatxa.srv.cpd:/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/train_nn \



In [None]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero , HeteroConv , GATv2Conv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score , matthews_corrcoef

import TropiGAT_functions 
import TropiGAT_graph
#from TropiGAT_functions import get_top_n_kltypes ,clean_print 

import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
import logging
from multiprocessing.pool import ThreadPool
warnings.filterwarnings("ignore")

# *****************************************************************************
# Load the Dataframes :
path_work = "/media/concha-eloko/Linux/PPT_clean"
path_ensemble = f"{path_work}/ficheros_28032023/ensemble_2812"

In [None]:
DF_info = pd.read_csv(f"{path_work}/TropiGATv2.final_df_v2.filtered.tsv", sep = "\t" ,  header = 0)
DF_info_lvl_0 = DF_info.copy()
df_prophages = DF_info_lvl_0.drop_duplicates(subset = ["Phage"])
dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))


#path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"
#DF_info = pd.read_csv(f"{path_work}/train_nn/TropiGATv2.final_df_v2.filtered.tsv", sep = "\t" ,  header = 0)



In [None]:
DF_info_lvl_0

In [None]:
def make_ensemble_TropiGAT_attention(path_ensemble) : 
	"""
	This function builds a dictionary with all the models that are part of the TropiGAT predictor
	Input : Path of the models
	Output : Dictionary , attention weights
	# Make a json file with the versions of the GNN corresponding to each KL types
	# Load it
	# Create the correct model instance (TropiGAT_small_module or TropiGAT_big_module)
	"""
	errors = []
	DF_info = pd.read_csv(f"{path_work}/TropiGATv2.final_df_v2.filtered.tsv", sep = "\t" ,  header = 0)
	DF_info_lvl_0 = DF_info.copy()
	df_prophages = DF_info_lvl_0.drop_duplicates(subset = ["Phage"])
	dico_prophage_count = dict(Counter(df_prophages["KL_type_LCA"]))
	dico_ensemble = {}
	for GNN_model in os.listdir(path_ensemble) :
		if GNN_model[-2:] == "pt" :
			KL_type = GNN_model.split(".")[0]
			try :
				if dico_prophage_count[KL_type] >= 125 : 
					model = TropiGAT_models.TropiGAT_big_module_attention(hidden_channels = 1280 , heads = 1)
				else :
					model = TropiGAT_models.TropiGAT_small_module_attention(hidden_channels = 1280 , heads = 1)
				model.load_state_dict(torch.load(f"{path_ensemble}/{GNN_model}"))
				dico_ensemble[KL_type] = model
			except Exception as e :
				a = (KL_type , dico_prophage_count[KL_type], e)
				errors.append(a)
	return dico_ensemble , errors
    
@torch.no_grad()
def make_predictions(model, data):
	model.eval() 
	output, weigths = model(data)
	probabilities = torch.sigmoid(output)
	predictions = probabilities.round() 
	return predictions, probabilities , weigths


def run_prediction_attentive(dico_graph, dico_ensemble, KL_type) :
    dico_predictions = {}
    query_graph = dico_graph[KL_type]["graph"]
    model = dico_ensemble[KL_type]
    prediction, probabilities, weights = make_predictions(model, query_graph)
    dico_predictions[KL_type] = {"probabilitites" : probabilities, "weights" : weights}
    return dico_predictions

In [None]:
dico_models, errors = TropiGAT_functions.make_ensemble_TropiGAT_attention(path_ensemble)
#dico_models, errors = make_ensemble_TropiGAT_attention(path_ensemble)


In [None]:
%% time
# *****************************************************************************
# Make graphs : 
graph_baseline , dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(DF_info_lvl_0)
graph_dico = {kltype : {"graph" : TropiGAT_graph.build_graph_masking(graph_baseline , dico_prophage_kltype_associated,DF_info_lvl_0, kltype, 0, 1, 0, 0), 
                        "positive_indices" : [index for index,kltype_ in enumerate(DF_info_lvl_0.drop_duplicates(subset = ["Phage"])["KL_type_LCA"].tolist()) if kltype_ == kltype]}
             for kltype in DF_info_lvl_0["KL_type_LCA"].unique()}

In [None]:
attention_data = {}
for kltype in dico_models : 
    out_dico = run_prediction_attentive(graph_dico , dico_models, kltype)
    attention_data.update(out_dico)

#with open(f"{path_work}/attention_weights_dico.json", "w") as outfile :
#    json.dump(attention_data , outfile)

In [None]:
attention_data_final = {}

for kltype in tqdm(attention_data) : 
    prophage_indices = graph_dico[kltype]["positive_indices"]
    tmp_dico = {}
    for _,prophage_index in enumerate(prophage_indices) :
        # If prediction is positive : 
        if attention_data[kltype]["probabilitites"][prophage_index] > 0.5 :
            prob = attention_data[kltype]["probabilitites"][prophage_index]
            real_prophage_name = DF_info["Phage"].unique().tolist()[prophage_index] 
            tmp_dpos = []
           # Look for the edges involving the prophage :
            for index_edge, prophage_edge in enumerate(attention_data[kltype]["weights"][0][1]) :
                if prophage_edge == prophage_index :
                    # Check the value of the attention coeff on the edge involving the prophage :
                    att_coeff = attention_data[kltype]["weights"][1][index_edge]
                    if att_coeff > 0.5 : 
                        # Get the seq of the depo : 
                        depo_index = attention_data[kltype]["weights"][0][0][index_edge]
                        real_depo_index = DF_info["index"].unique().tolist()[depo_index] 
                        seq = DF_info[DF_info["index"] == real_depo_index]["seq"].values[0]
                        # Pack the data :
                        a = (real_depo_index,seq,att_coeff, prob)
                        tmp_dpos.append(a)
        tmp_dico[real_prophage_name] = a
    attention_data_final[kltype] = tmp_dico
                        
                    
    

In [None]:
attention_data_raw = {}

for kltype in tqdm(attention_data) : 
    prophage_indices = graph_dico[kltype]["positive_indices"]
    tmp_dico = {}
    for _,prophage_index in enumerate(prophage_indices) :
        # If prediction is positive : 
        if attention_data[kltype]["probabilitites"][prophage_index] > 0.5 :
            prob = attention_data[kltype]["probabilitites"][prophage_index]
            real_prophage_name = DF_info["Phage"].unique().tolist()[prophage_index] 
            tmp_dpos = []
           # Look for the edges involving the prophage :
            for index_edge, prophage_edge in enumerate(attention_data[kltype]["weights"][0][1]) :
                if prophage_edge == prophage_index :
                    # Check the value of the attention coeff on the edge involving the prophage :
                    att_coeff = attention_data[kltype]["weights"][1][index_edge].float()
                    #if att_coeff > 0.5 : 
                    # Get the seq of the depo : 
                    depo_index = attention_data[kltype]["weights"][0][0][index_edge]
                    real_depo_index = DF_info["index"].unique().tolist()[depo_index] 
                    seq = DF_info[DF_info["index"] == real_depo_index]["seq"].values[0]
                    # Pack the data :
                    a = (real_depo_index,seq,att_coeff, prob)
                    tmp_dpos.append(a)
        tmp_dico[real_prophage_name] = tmp_dpos
    attention_data_raw[kltype] = tmp_dico

In [None]:
attention_data_raw["KL47"]["GCF_020405285.1__phage9"]

In [None]:
with open(f"{path_work}/attention_weights_dpos.raw.tsv", "w") as outfile:
    outfile.write(f"KL_type\tPhage\tdpo_index\tattention_coefficient\tprobability\tseq\n")
    for kltype in attention_data_raw : 
        for prophage in attention_data_raw[kltype] : 
            for dpo in attention_data_raw[kltype][prophage] :
                outfile.write(f"{kltype}\t{prophage}\t{dpo[0]}\t{float(dpo[2])}\t{dpo[3]}\t{dpo[1]}\n")

> Write the dico attention_weights_dico.raw.json: 

In [None]:
def convert_to_serializable(data):
    if isinstance(data, torch.Tensor):
        return data.tolist()  # Convert tensors to lists
    elif isinstance(data, dict):
        return {k: convert_to_serializable(v) for k, v in data.items()}  # Recursively apply to dictionaries
    elif isinstance(data, list):
        return [convert_to_serializable(v) for v in data]  # Apply to each item in the list
    elif isinstance(data, tuple):
        return tuple(convert_to_serializable(v) for v in data)  # Convert items in tuples
    else:
        return data  # Return as is for serializable types

# Apply conversion to the entire attention_data
serializable_attention_data = convert_to_serializable(attention_data_final)

with open(f"{path_work}/attention_weights_dico.raw.json", "w") as outfile:
    json.dump(serializable_attention_data, outfile)

In [None]:
with open(f"{path_work}/attention_weights_dpos.tsv", "w") as outfile:
    outfile.write(f"KL_type\tPhage\tdpo_index\tattention_coefficient\tseq\n")
    for kltype in attention_data_final : 
        tmp_seq = set()
        for prophage in attention_data_final[kltype] : 
            for dpo in attention_data_final[kltype][prophage] :
                print(len(dpo))
                if dpo[1] not in tmp_seq :
                    tmp_seq.add(dpo[1])
                    outfile.write(f"{kltype}\t{prophage}\t{dpo[0]}\t{dpo[2]}\t{dpo[1]}\n")

In [None]:
attention_data["KL128"]

len(attention_data["KL128"]["probabilitites"]) , len(attention_data["KL128"]["weights"][1])

attention_data["KL128"]

> Write the data in Json 

In [None]:
def convert_to_serializable(data):
    if isinstance(data, torch.Tensor):
        return data.tolist()  # Convert tensors to lists
    elif isinstance(data, dict):
        return {k: convert_to_serializable(v) for k, v in data.items()}  # Recursively apply to dictionaries
    elif isinstance(data, list):
        return [convert_to_serializable(v) for v in data]  # Apply to each item in the list
    elif isinstance(data, tuple):
        return tuple(convert_to_serializable(v) for v in data)  # Convert items in tuples
    else:
        return data  # Return as is for serializable types

# Apply conversion to the entire attention_data
serializable_attention_data = convert_to_serializable(attention_data)

# Now, serializable_attention_data contains only JSON-serializable items
with open(f"{path_work}/attention_weights_dico.json", "w") as outfile:
    json.dump(serializable_attention_data, outfile)

In [None]:
serializable_attention_data["KL47"]

In [None]:
import statistics
kltype = "KL47"

mean_KLtype = statistics.mean(serializable_attention_data[kltype]["probabilitites"])

mean_KLtype


In [None]:
len(serializable_attention_data[kltype]["probabilitites"])

***
# Work on the attention coefficient 

In [1]:
import os
import pandas as pd 
from collections import Counter

path_work = "/media/concha-eloko/Linux/PPT_clean"

df_coeff = pd.read_csv(f"{path_work}/attention_weights_dpos.raw.tsv" , sep = "\t", header= 0)

***
### Fun with plotting the attention coeff :

In [None]:
df_coeff

In [None]:
kltype = "KL106"

test_coeff = df_coeff[df_coeff["KL_type"] == kltype]["attention_coefficient"].tolist()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

def plot_histogram(data):
    """
    Plot a histogram of the provided data with fixed bins between 0 and 1, with a step of 0.10.
    Args:
    data (list): A list of floats between 0 and 1.
    """
    # Define the bin edges for the histogram
    bin_edges = np.arange(0, 1.1, 0.0250)  # Bins from 0 to 1 with a step of 0.10
    # Create the histogram using seaborn
    sns.histplot(data, bins=bin_edges, kde=False, color='blue', edgecolor='black')
    # Adding labels and title for clarity
    plt.xlabel('Value')
    plt.ylabel('Count')
    plt.title('Distribution of Floats in Windows of 0.10')
    # Show the plot
    plt.show()

In [None]:
plot_histogram(test_coeff)

***
### Seek pairs of KL types :

In [None]:
df_coeff

In [None]:
phage = "GCF_900502235.1__phage19"
ceoff = 0.7300820350646973

df_coeff[df_coeff["attention_coefficient"] == ceoff]

In [4]:
df_coeff_clean = df_coeff[(df_coeff["probability"] > 0.8) & (df_coeff["attention_coefficient"] > 0.5)]

In [None]:
Counter(df_coeff_clean["attention_coefficient"])

In [9]:
from itertools import combinations

def generate_pairs(input_set):
    """
    Generate all possible pairs of elements from the input set.
    Args:
    input_set (set): A set of elements.
    Returns:
    list of sets: A list containing sets of all possible pairs.
    """
    return [set(pair) for pair in combinations(input_set, 2)]

def count_sets(list_of_sets):
    """
    Count occurrences of unique sets in a list of sets.
    Args:
    list_of_sets (list of sets): A list containing sets.
    Returns:
    Counter: A Counter object with counts of each unique set (converted to tuple).
    """
    # Convert each set to a tuple
    tuples = [tuple(sorted(s)) for s in list_of_sets]
    # Use Counter to count occurrences
    return Counter(tuples)

In [10]:
plural_seq = []
combinations_kltypes = []
for seq in df_coeff_clean['seq'].unique().tolist() :
    df_seq = df_coeff_clean[df_coeff_clean["seq"] == seq]["KL_type"]
    kltypes = set(df_seq.values)
    len_group = len(kltypes)
    if len_group > 1 :
        if len_group == 2 :
            combinations_kltypes.append(kltypes)
        else :
            combinations_kltypes = combinations_kltypes + generate_pairs(kltypes)
        plural_seq.append(seq)


In [11]:
count_sets(combinations_kltypes)

Counter({('KL47', 'KL64'): 7,
         ('KL107', 'KL15'): 7,
         ('KL106', 'KL24'): 6,
         ('KL106', 'KL107'): 6,
         ('KL106', 'KL15'): 6,
         ('KL15', 'KL24'): 6,
         ('KL106', 'KL64'): 5,
         ('KL106', 'KL47'): 5,
         ('KL24', 'KL64'): 4,
         ('KL24', 'KL47'): 4,
         ('KL107', 'KL24'): 4,
         ('KL2', 'KL24'): 4,
         ('KL15', 'KL36'): 4,
         ('KL106', 'KL36'): 4,
         ('KL24', 'KL28'): 4,
         ('KL27', 'KL64'): 3,
         ('KL106', 'KL27'): 3,
         ('KL27', 'KL47'): 3,
         ('KL107', 'KL47'): 3,
         ('KL14', 'KL64'): 3,
         ('KL2', 'KL51'): 3,
         ('KL106', 'KL14'): 3,
         ('KL24', 'KL51'): 3,
         ('KL112', 'KL24'): 3,
         ('KL106', 'KL23'): 3,
         ('KL108', 'KL36'): 3,
         ('KL36', 'KL47'): 3,
         ('KL24', 'KL36'): 3,
         ('KL107', 'KL36'): 3,
         ('KL108', 'KL15'): 3,
         ('KL106', 'KL108'): 3,
         ('KL108', 'KL47'): 3,
         ('KL108', 'KL

In [12]:
len(plural_seq)

100

***
### Check the fold of the assigned dpos


In [2]:
DF_info_light_folded = pd.read_csv(f"{path_work}/TropiGAT_v2.light_folded.tsv", sep = "\t", header = 0)


In [4]:
DF_info_light_folded.drop_duplicates(subset = ["seq"])

Unnamed: 0,Phage,Protein_name,KL_type_LCA,Infected_ancestor,index,Dataset,seq,domain_seq,prophage_id,Fold
0,GCF_902164905.1__phage1,GCF_902164905.1__phage1__34,KL41,GCF_902164905.1,minibatch__460,minibatch,MPATPQDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDE...,QDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDERTITT...,prophage_11309,right-handed beta-helix
1,GCF_017310305.1__phage5,GCF_017310305.1__phage5__1353,KL30,n4996,minibatch__1084,minibatch,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,prophage_5,right-handed beta-helix
7,GCF_900622625.1__phage2,GCF_900622625.1__phage2__2892,KL6,GCF_900622625.1,minibatch__1741,minibatch,MAFNPELGSSSPEVLLDNAKRLDELTNGPAATVPDRAGEPLDSWRK...,ELGSSSPEVLLDNAKRLDELTNGPAATVPDRAGEPLDSWRKMQEDN...,prophage_4098,TIM beta/alpha-barrel
8,GCF_011044795.1__phage17,GCF_011044795.1__phage17__11,KL19,80.7/1001331,minibatch__467,minibatch,MNRSRRLLMRGIGYLTLFPLLFLFSKKVSSAPNGLTEKVKNRKIEK...,RSRRLLMRGIGYLTLFPLLFLFSKKVSSAPNGLTEKVKNRKIEKDV...,prophage_4997,right-handed beta-helix
14,GCF_019096335.1__phage21,GCF_019096335.1__phage21__173,KL25,n12421242,minibatch__15,minibatch,MYHLDNTSGVPEMPEPKEQQSISPRWFGESQEQGGISWPGADWFNT...,YHLDNTSGVPEMPEPKEQQSISPRWFGESQEQGGISWPGADWFNTV...,prophage_8486,right-handed beta-helix
...,...,...,...,...,...,...,...,...,...,...
11462,GCF_900506765.1__phage17,GCF_900506765.1__phage17__90,KL149,GCF_900506765.1,anubis_return__4216,anubis_return,MMTTLNEHPQWESDIYLIKRSDLVAGGRGGIANMQAQQLANRTAFL...,NRRWFRRFTGNIRAEWSGIHDLSQSSAPVDSYIYRLLLASAVGSPD...,prophage_15598,right-handed beta-helix
11463,GCF_003255785.1__phage1,GCF_003255785.1__phage1__10,KL127,GCF_003255785.1,anubis_return__4239,anubis_return,MNGLNHNALTCSAVPIPPWERSLQTVEAQPYFSVSQASLVLEGIVF...,MNGLNHNALTCSAVPIPPWERSLQTVEAQPYFSVSQASLVLEGIVF...,prophage_3577,6-bladed beta-propeller
11464,GCF_002186895.1__phage9,GCF_002186895.1__phage9__5,KL57,GCF_002186895.1,anubis_return__4260,anubis_return,MRYRFIALALCLLSGSKVAISAGFDCSLANLSPTEKTICSNEYLSG...,ITDSPWLVKKIFSSDSFEGGINLEGMNVSSILTYQEIKNDLYIYIS...,prophage_6002,6-bladed beta-propeller
11465,GCF_004312845.1__phage3,GCF_004312845.1__phage3__38,KL9,GCF_004312845.1,anubis_return__4275,anubis_return,MAILITGKSMTRLPESSSWEEEIELITRSERVAGGLDGPANRPLKS...,DAVIRRDLASDKGTSGVGKLGDKPLVAISYYKSKGQSDQDAVQAAF...,prophage_12656,right-handed beta-helix


In [18]:
assgined_folds = []

for _, seq in enumerate(df_coeff["seq"].unique().tolist()) : 
    fold = DF_info_light_folded[DF_info_light_folded["seq"] == seq]["Fold"].values[0]
    assgined_folds.append(fold)


In [19]:
Counter(assgined_folds)

Counter({'right-handed beta-helix': 2314,
         '6-bladed beta-propeller': 526,
         'TIM beta/alpha-barrel': 202,
         'Alpha/Beta hydrolase fold': 71,
         'alpha/alpha toroid': 23,
         'triple-helix': 17,
         'unknown': 1})

In [23]:
Counter(DF_info_light_folded.drop_duplicates(subset = ["seq"])["Fold"])

Counter({'right-handed beta-helix': 2722,
         '6-bladed beta-propeller': 714,
         'TIM beta/alpha-barrel': 294,
         'Alpha/Beta hydrolase fold': 114,
         'triple-helix': 32,
         'alpha/alpha toroid': 29,
         'unknown': 3})

In [25]:
fold_plural = []

for _, seq in enumerate(plural_seq) : 
    fold = DF_info_light_folded[DF_info_light_folded["seq"] == seq]["Fold"].values[0]
    fold_plural.append(fold)

In [26]:
Counter(fold_plural)

Counter({'right-handed beta-helix': 46,
         '6-bladed beta-propeller': 23,
         'TIM beta/alpha-barrel': 14,
         'Alpha/Beta hydrolase fold': 12,
         'triple-helix': 2,
         'alpha/alpha toroid': 2,
         'unknown': 1})

In [27]:
714/3933*100

18.154080854309687

In [28]:
23/100*100

23.0

In [29]:
len(fold_plural)

100