In [21]:
# Ground modules
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from Bio import SeqIO
from itertools import product
import random
from collections import Counter
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

import logging
import subprocess
from multiprocessing.pool import ThreadPool
import joblib

# SCikitlearn modules :
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.metrics import classification_report , roc_auc_score, matthews_corrcoef

# Scipy modules : 
from scipy.stats import fisher_exact

> Import prophage data:

In [39]:
path_work = "/media/concha-eloko/Linux/PPT_clean"

DF_info = pd.read_csv(f"{path_work}/TropiGATv2.final_df_v2.tsv", sep = "\t" ,  header = 0)
df_prophages = DF_info.drop_duplicates(subset = ["Phage"], keep = "first")
dico_prophage_info = {row["Phage"] : {"prophage_strain" : row["prophage_id"] , "ancestor" : row["Infected_ancestor"]} for _,row in df_prophages.iterrows()}


In [40]:
def get_filtered_prophages(prophage) :
    combinations = []
    to_exclude = set()
    to_keep = set()
    to_keep.add(prophage)
    df_prophage_group = DF_info[(DF_info["prophage_id"] == dico_prophage_info[prophage]["prophage_strain"]) & (DF_info["Infected_ancestor"] == dico_prophage_info[prophage]["ancestor"])]
    if len(df_prophage_group) == 1 : 
        pass
    else :
        depo_set = set(df_prophage_group[df_prophage_group["Phage"] == prophage]["domain_seq"].values)
        for prophage_tmp in df_prophage_group["Phage"].unique().tolist() :
            if prophage_tmp != prophage :
                tmp_depo_set = set(df_prophage_group[df_prophage_group["Phage"] == prophage_tmp]["domain_seq"].values)
                if depo_set == tmp_depo_set :
                    to_exclude.add(prophage_tmp)
                else :
                    if tmp_depo_set not in combinations :
                        to_keep.add(prophage_tmp)
                        combinations.append(tmp_depo_set)
                    else :
                        to_exclude.add(prophage_tmp)
    return df_prophage_group , to_exclude , to_keep

good_prophages = set()
excluded_prophages = set()

for prophage, info_prophage in tqdm(dico_prophage_info.items()) :
    if prophage not in excluded_prophages and prophage not in good_prophages:
        _, excluded_members , kept_members = get_filtered_prophages(prophage) 
        good_prophages.update(kept_members)
        excluded_prophages.update(excluded_members)

DF_info_lvl_0_filtered = DF_info[DF_info["Phage"].isin(good_prophages)]
DF_info_lvl_0_final = DF_info_lvl_0_filtered[~DF_info_lvl_0_filtered["KL_type_LCA"].str.contains("\\|")]

DF_info_lvl_0 = DF_info_lvl_0_final.copy()

# useful dictionary :
KLtype_count = Counter(DF_info_lvl_0["KL_type_LCA"])
KLtype_pred = [kltype for kltype in KLtype_count if KLtype_count[kltype] >= 5]

dico_prophage_kltype_associated = {}
for negative_index,phage in tqdm(enumerate(DF_info_lvl_0["Phage"].unique().tolist())) :
    kltypes = set()
    dpos = DF_info_lvl_0[DF_info_lvl_0["Phage"] == phage]["index"]
    for dpo in dpos : 
        tmp_kltypes = DF_info_lvl_0[DF_info_lvl_0["index"] == dpo]["KL_type_LCA"].values
        kltypes.update(tmp_kltypes)
    dico_prophage_kltype_associated[phage] = kltypes

depo_domains_seq = {index: domain_seq for index, domain_seq in zip(DF_info_lvl_0["index"], DF_info_lvl_0['domain_seq'])}

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15981/15981 [00:17<00:00, 931.76it/s]
8892it [00:14, 632.13it/s]


> import the dico

In [None]:
rsync -avzhe ssh \
conchae@garnatxa.srv.cpd:/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/Seqbased_model/dico_cluster.cdhit__0.65.json \
/media/concha-eloko/Linux/PPT_clean/Seqbased_model


In [41]:
import json

path_seqbased = "/media/concha-eloko/Linux/PPT_clean/Seqbased_model"
path_db = f"{path_seqbased}/TropiSeq/TropiSeq_0.65.db"

dico_cluster = json.load(open(f"{path_seqbased}/dico_cluster.cdhit__0.65.json"))
dico_cluster_r = {ref_dpo : key_dpo for key_dpo,list_dpo in dico_cluster.items() for ref_dpo in list_dpo}


> Identify the best model : 

In [27]:
def get_best_performing_model(path_data_object):
    try:
        assert os.path.isfile(path_data_object)
        data_object = joblib.load(path_data_object)
        mcc_values = []
        index_track = []
        for i in data_object:
            y_test = data_object[i]["test_data"][0].values
            predictions = data_object[i]["test_data"][1]
            mcc = matthews_corrcoef(y_test, predictions)
            mcc_values.append(mcc)
            index_track.append(i)
        del data_object 
        max_index = mcc_values.index(max(mcc_values))
        best_model = index_track[max_index]        
        return (best_model, max(mcc_values), path_data_object.split("_LogReg_")[1].split(".")[0])
    except AssertionError:
        raise FileNotFoundError(f"File not found: {path_data_object}")

In [28]:
path_models = "/media/concha-eloko/Linux/PPT_clean/Seqbased_model/LogReg_models"

In [30]:
dico_best_logreg = {}

for model in os.listdir(path_models) :
    best_model = get_best_performing_model(f"{path_models}/{model}")
    dico_best_logreg[best_model[2]] = best_model[0]

    

In [None]:
path_metrics = "/media/concha-eloko/Linux/PPT_clean/ficheros_28032023/review_work/SeqBased_model/metric_files"

threshold = 5
names_metric_col = ["KL_type", "Count", "file_name","mean_mcc"]
final_weighted_dico = {}

for file in os.listdir(path_metrics):
    metric_df = pd.read_csv(f"{path_metrics}/{file}", sep = "\t", names = names_metric_col)
    metric_eval_df = metric_df[metric_df["Count"] > threshold]
    metric_eval_df["c_value"] = metric_eval_df["file_name"].astype(str).apply(lambda x: x.split("_")[0])
    weighted_mcc_dico = {}
    for cluster in metric_eval_df["c_value"].unique().tolist() :
        mcc_sum = 0
        cl_df = metric_eval_df[metric_eval_df["c_value"] == cluster]
        for _,row in cl_df.iterrows() : 
            mcc_sum = mcc_sum + row["mean_mcc"] * row["Count"]
        weighted_mcc = mcc_sum / (sum(cl_df["Count"]))
        weighted_mcc_dico[cluster] = weighted_mcc
    final_weighted_dico[file] = weighted_mcc_dico

In [31]:
import pickle
import os
from joblib import load

path_seqbased = "/media/concha-eloko/Linux/PPT_clean"

models_TropiLR = {}

for lr_model in os.listdir(f"{path_seqbased}/Seqbased_model/LogReg_models") :
    kltype = lr_model.split("_LogReg_")[1].split(".")[0]
    best_model = dico_best_logreg[kltype]
    with open(f"{path_seqbased}/Seqbased_model/LogReg_models/{lr_model}", 'rb') as file:
        models_TropiLR[kltype] = load(file)[best_model]["model"]

TropiLR_results = {}

***
# Approach 1 : Predictions with probability > 0.5

In [33]:
num_arrays = 836
list_of_arrays = [np.zeros(num_arrays) for _ in range(num_arrays)]
for i, arr in enumerate(list_of_arrays):
    arr[i] = 1

In [37]:
for index,array in tqdm(enumerate(list_of_arrays)) :
    cluster_id = "cluster_" + str(index)
    tmp_positif = {}
    for kltype in models_TropiLR :
        pred = models_TropiLR[kltype].predict_proba(np.array(array).reshape(1, -1))
        if pred[0][1] >= 0.5 :
            tmp_positif[kltype] = pred[0][1]
    TropiLR_results[cluster_id] = tmp_positif

import json 
with open("/media/concha-eloko/Linux/PPT_clean/Seqbased_model/prediction_based.labeling.LogReg.json", "w") as outfile :
    json.dump(TropiLR_results, outfile)

836it [00:02, 282.37it/s]


In [38]:
dico_pred = json.load(open("/media/concha-eloko/Linux/PPT_clean/Seqbased_model/prediction_based.labeling.LogReg.json"))
dico_pred_correct_name = {f"Dpo_cdhit_{cluster.split('_')[1]}":hits  for cluster, hits in dico_pred.items()}
dico_pred_correct_name

{'Dpo_cdhit_0': {},
 'Dpo_cdhit_1': {},
 'Dpo_cdhit_2': {},
 'Dpo_cdhit_3': {},
 'Dpo_cdhit_4': {},
 'Dpo_cdhit_5': {},
 'Dpo_cdhit_6': {},
 'Dpo_cdhit_7': {},
 'Dpo_cdhit_8': {},
 'Dpo_cdhit_9': {},
 'Dpo_cdhit_10': {},
 'Dpo_cdhit_11': {},
 'Dpo_cdhit_12': {},
 'Dpo_cdhit_13': {},
 'Dpo_cdhit_14': {'KL14': 0.7400550954866159},
 'Dpo_cdhit_15': {},
 'Dpo_cdhit_16': {},
 'Dpo_cdhit_17': {},
 'Dpo_cdhit_18': {},
 'Dpo_cdhit_19': {},
 'Dpo_cdhit_20': {},
 'Dpo_cdhit_21': {},
 'Dpo_cdhit_22': {},
 'Dpo_cdhit_23': {},
 'Dpo_cdhit_24': {},
 'Dpo_cdhit_25': {},
 'Dpo_cdhit_26': {},
 'Dpo_cdhit_27': {},
 'Dpo_cdhit_28': {},
 'Dpo_cdhit_29': {},
 'Dpo_cdhit_30': {},
 'Dpo_cdhit_31': {},
 'Dpo_cdhit_32': {},
 'Dpo_cdhit_33': {},
 'Dpo_cdhit_34': {},
 'Dpo_cdhit_35': {},
 'Dpo_cdhit_36': {'KL6': 0.5231290342168122},
 'Dpo_cdhit_37': {},
 'Dpo_cdhit_38': {},
 'Dpo_cdhit_39': {},
 'Dpo_cdhit_40': {'KL2': 0.748525864533034},
 'Dpo_cdhit_41': {},
 'Dpo_cdhit_42': {'KL19': 0.5359952594098107},
 'Dpo_

> Write the DF :

In [42]:
DF_info_depo = DF_info_lvl_0.drop_duplicates(subset = ["seq"])
DF_info_depo_clean = DF_info_depo[["index","seq","domain_seq"]]


In [43]:
with open(f"{path_seqbased}/labeling_depo_clusters.pred.LogReg.tsv", "w") as outfile :
    outfile.write(f"index\tseq\tdomain_seq\tdepo_cluster\tTropiLR_KL_types\tTropiLR_scores\n")
    for _,row in DF_info_depo_clean.iterrows() :
        depo_cluster = dico_cluster_r[row["index"]]
        hits_dico = dico_pred_correct_name.get(depo_cluster, "No_association")
        outfile.write(f"{row['index']}\t{row['seq']}\t{row['domain_seq']}\t{depo_cluster}\t")
        if isinstance(hits_dico, dict) :
            targets = ",".join(list(hits_dico.keys()))
            outfile.write(f"{targets}\t")
            hits_info_list = [f"{hit} : {hits_dico[hit]}" for hit in hits_dico]
            outfile.write(f"{','.join(hits_info_list)}\n")
        else :
            outfile.write(f"None\tNone\n")
                

In [44]:
tropiseq_labeled_seq = pd.read_csv(f"{path_seqbased}/labeling_depo_clusters.pred.LogReg.tsv", sep = "\t", header = 0)
tropiseq_labeled_seq_annot = tropiseq_labeled_seq[tropiseq_labeled_seq["TropiLR_KL_types"] != "None"]
tropiseq_labeled_seq_annot

Unnamed: 0,index,seq,domain_seq,depo_cluster,TropiLR_KL_types,TropiLR_scores
0,minibatch__460,MPATPQDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDE...,QDRLYGLTTSVAVKPPVFISVDYDVARFGEQTITSKTPTDERTITT...,Dpo_cdhit_73,,
1,minibatch__1084,MTVSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVT...,VSTQVSRNEYTGNGATTQYDFTFRILDKSHLLVQTMDTSENIVTLT...,Dpo_cdhit_11,,
2,minibatch__1741,MAFNPELGSSSPEVLLDNAKRLDELTNGPAATVPDRAGEPLDSWRK...,ELGSSSPEVLLDNAKRLDELTNGPAATVPDRAGEPLDSWRKMQEDN...,Dpo_cdhit_117,KL6,KL6 : 0.5231484010327386
3,minibatch__467,MNRSRRLLMRGIGYLTLFPLLFLFSKKVSSAPNGLTEKVKNRKIEK...,RSRRLLMRGIGYLTLFPLLFLFSKKVSSAPNGLTEKVKNRKIEKDV...,Dpo_cdhit_186,,
4,minibatch__15,MYHLDNTSGVPEMPEPKEQQSISPRWFGESQEQGGISWPGADWFNT...,YHLDNTSGVPEMPEPKEQQSISPRWFGESQEQGGISWPGADWFNTV...,Dpo_cdhit_221,,
...,...,...,...,...,...,...
3912,anubis_return__4216,MMTTLNEHPQWESDIYLIKRSDLVAGGRGGIANMQAQQLANRTAFL...,NRRWFRRFTGNIRAEWSGIHDLSQSSAPVDSYIYRLLLASAVGSPD...,Dpo_cdhit_182,,
3913,anubis_return__4239,MNGLNHNALTCSAVPIPPWERSLQTVEAQPYFSVSQASLVLEGIVF...,MNGLNHNALTCSAVPIPPWERSLQTVEAQPYFSVSQASLVLEGIVF...,Dpo_cdhit_562,"KL27,KL62","KL27 : 0.8679564244246406,KL62 : 0.89953390257..."
3914,anubis_return__4260,MRYRFIALALCLLSGSKVAISAGFDCSLANLSPTEKTICSNEYLSG...,ITDSPWLVKKIFSSDSFEGGINLEGMNVSSILTYQEIKNDLYIYIS...,Dpo_cdhit_643,,
3915,anubis_return__4275,MAILITGKSMTRLPESSSWEEEIELITRSERVAGGLDGPANRPLKS...,DAVIRRDLASDKGTSGVGKLGDKPLVAISYYKSKGQSDQDAVQAAF...,Dpo_cdhit_769,,


In [45]:
print(path_seqbased)

/media/concha-eloko/Linux/PPT_clean/Seqbased_model
