# In [None]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import random
from collections import Counter
import logging
import subprocess
from multiprocessing.pool import ThreadPool
import joblib

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import classification_report, roc_auc_score, matthews_corrcoef

from skopt import BayesSearchCV
from skopt.space import Real, Categorical, Integer

# Configure logging: timestamped, INFO-level messages on the root logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Constants: every input and output of this script lives under PATH_WORK.
PATH_WORK = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023/Seqbased_model"
# Destination of the per-KL-type joblib dumps.
PATH_MODELS = f"{PATH_WORK}/RF_1309_models"
PATH_TESTING = f"{PATH_WORK}/RF_1309_data"
# Multi-FASTA of domain sequences that is fed to cd-hit.
PATH_MULTI_FASTA = f"{PATH_WORK}/Dpo_domains.2912.multi.fasta"
# Directory holding the cd-hit outputs (one "<threshold>.out" per run).
PATH_TMP_CDHIT = f"{PATH_WORK}/cdhit_clusters_2912"

# Create the output directories up front. PATH_TMP_CDHIT is included because
# the script writes cd-hit output into it and later lists it with os.listdir,
# which raises FileNotFoundError when the directory is missing.
os.makedirs(PATH_MODELS, exist_ok=True)
os.makedirs(PATH_TESTING, exist_ok=True)
os.makedirs(PATH_TMP_CDHIT, exist_ok=True)

def load_data():
    """Read the prophage table and index per-phage metadata.

    Returns:
        tuple: ``(df_info, dico_prophage_info)``. ``df_info`` is the full
        table read from ``TropiGATv2.final_df_v2.tsv`` under ``PATH_WORK``.
        ``dico_prophage_info`` maps each ``Phage`` value to a dict with the
        keys ``"prophage_strain"`` (that row's ``prophage_id``) and
        ``"ancestor"`` (that row's ``Infected_ancestor``), taken from the
        first row in which the phage appears.
    """
    table_path = f"{PATH_WORK}/TropiGATv2.final_df_v2.tsv"
    df_info = pd.read_csv(table_path, sep="\t", header=0)

    # One representative row per phage: its first occurrence in file order.
    first_rows = df_info.drop_duplicates(subset=["Phage"], keep="first")

    dico_prophage_info = {}
    for _, record in first_rows.iterrows():
        dico_prophage_info[record["Phage"]] = {
            "prophage_strain": record["prophage_id"],
            "ancestor": record["Infected_ancestor"],
        }

    return df_info, dico_prophage_info

def get_filtered_prophages(prophage, df_info, dico_prophage_info):
    """Split the phages sharing ``prophage``'s group into kept and redundant.

    The group is every row of ``df_info`` with the same ``prophage_id`` and
    ``Infected_ancestor`` as ``prophage``. Within it, a phage is redundant
    when its set of ``domain_seq`` values equals a set already seen, starting
    with the set of ``prophage`` itself.

    Args:
        prophage: ``Phage`` value used as the reference of the group.
        df_info: table with ``Phage``, ``prophage_id``, ``Infected_ancestor``
            and ``domain_seq`` columns.
        dico_prophage_info: mapping ``Phage`` -> dict with the keys
            ``"prophage_strain"`` and ``"ancestor"``.

    Returns:
        tuple: ``(df_prophage_group, to_exclude, to_keep)`` -- the group rows,
        the set of redundant phages and the set of retained phages (which
        always contains ``prophage``).
    """
    reference = dico_prophage_info[prophage]
    same_strain = df_info["prophage_id"] == reference["prophage_strain"]
    same_ancestor = df_info["Infected_ancestor"] == reference["ancestor"]
    df_prophage_group = df_info[same_strain & same_ancestor]

    to_exclude = set()
    to_keep = {prophage}

    # A group of at most one row has nothing to compare against.
    if len(df_prophage_group) <= 1:
        return df_prophage_group, to_exclude, to_keep

    def domains_of(name):
        # Set of domain sequences carried by one phage of the group.
        rows = df_prophage_group["Phage"] == name
        return set(df_prophage_group.loc[rows, "domain_seq"].values)

    seen_sets = [domains_of(prophage)]
    for candidate in df_prophage_group["Phage"].unique():
        if candidate == prophage:
            continue
        candidate_set = domains_of(candidate)
        if candidate_set in seen_sets:
            to_exclude.add(candidate)
            continue
        to_keep.add(candidate)
        seen_sets.append(candidate_set)

    return df_prophage_group, to_exclude, to_keep

def filter_prophages(df_info, dico_prophage_info):
    """Keep one phage per distinct domain set in each prophage group.

    Every phage not yet classified is passed to ``get_filtered_prophages``;
    the phages it reports as kept or excluded are accumulated. The result
    contains only rows of kept phages, minus the rows whose ``KL_type_LCA``
    contains a ``|`` character.

    Args:
        df_info: full prophage table.
        dico_prophage_info: mapping ``Phage`` -> group metadata; its keys
            are the phages that get visited.

    Returns:
        pandas.DataFrame: the filtered subset of ``df_info``.
    """
    kept = set()
    dropped = set()

    for prophage in tqdm(dico_prophage_info.keys()):
        # Already classified while handling an earlier phage.
        if prophage in dropped or prophage in kept:
            continue
        _, group_dropped, group_kept = get_filtered_prophages(prophage, df_info, dico_prophage_info)
        kept |= group_kept
        dropped |= group_dropped

    df_kept = df_info[df_info["Phage"].isin(kept)]
    has_pipe = df_kept["KL_type_LCA"].str.contains("\\|")
    return df_kept[~has_pipe]

def ultrafilter_prophages(df_info):
    """Drop phages whose domain set repeats within the same KL type.

    For each ``KL_type_LCA`` value, phages are visited in ``groupby("Phage")``
    order. The first phage carrying a given set of ``domain_seq`` values is
    retained; every later phage with an identical set is marked as a
    duplicate. All rows of a duplicate phage are removed from the result.

    Args:
        df_info: table with ``Phage``, ``KL_type_LCA`` and ``domain_seq``
            columns.

    Returns:
        pandas.DataFrame: ``df_info`` without the rows of duplicate phages.
    """
    # Sets give O(1) membership tests; the frozensets are hashable, so a
    # list scan per phage is unnecessary.
    duplicate_prophage = set()
    for kltype in df_info["KL_type_LCA"].unique():
        df_kl = df_info.loc[df_info["KL_type_LCA"] == kltype, ["Phage", "domain_seq"]]
        seen_domain_sets = set()
        for phage, group in df_kl.groupby("Phage"):
            domain_set = frozenset(group["domain_seq"].values)
            if domain_set in seen_domain_sets:
                duplicate_prophage.add(phage)
            else:
                seen_domain_sets.add(domain_set)

    return df_info[~df_info["Phage"].isin(duplicate_prophage)]

def make_cdhit_cluster(threshold):
    """Run cd-hit on the domain multi-FASTA at one identity threshold.

    The output is written to ``{PATH_TMP_CDHIT}/{threshold}.out``. cd-hit's
    stdout and stderr are captured rather than printed; on failure they are
    available on the raised exception.

    Args:
        threshold: sequence identity threshold passed to cd-hit's ``-c``.

    Raises:
        subprocess.CalledProcessError: if cd-hit exits with a non-zero status.
        FileNotFoundError: if the ``cd-hit`` executable cannot be found.
    """
    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters from being re-interpreted.
    # NOTE(review): the cd-hit user guide ties the word size (-n) to the
    # identity threshold and recommends -n 4 for thresholds in 0.6-0.7; the
    # default is left untouched here -- confirm the low thresholds behave as
    # intended.
    cdhit_command = [
        "cd-hit",
        "-i", PATH_MULTI_FASTA,
        "-o", f"{PATH_TMP_CDHIT}/{threshold}.out",
        "-c", f"{threshold}",
        "-G", "0",
        "-aL", "0.8",
    ]
    subprocess.run(cdhit_command, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

def make_cluster_dico(cdhit_out):
    """Parse a cd-hit ``.clstr`` file into a cluster -> members mapping.

    The mapping is also dumped as JSON to
    ``{PATH_WORK}/dico_cluster.cdhit__{threshold}.json``.

    Args:
        cdhit_out: path of a cd-hit output file ending in ``.out``; the
            companion ``{cdhit_out}.clstr`` file is the one that is read.

    Returns:
        tuple: ``(dico_cluster, threshold)``. ``dico_cluster`` maps
        ``"Dpo_cdhit_<n>"`` (numbered in file order) to the list of member
        identifiers; ``threshold`` is the file name with ``.out`` stripped,
        as a string.
    """
    import json

    threshold = cdhit_out.split("/")[-1].split(".out")[0]
    cluster_file = f"{cdhit_out}.clstr"

    dico_cluster = {}
    with open(cluster_file) as f:
        for cluster in f.read().split(">Cluster")[1:]:
            id_cluster = f"Dpo_cdhit_{len(dico_cluster)}"
            # The first line is the cluster number; the rest are members.
            # Blank lines are skipped explicitly instead of slicing off the
            # last element, so the final member survives when the file does
            # not end with a newline.
            member_lines = cluster.split("\n")[1:]
            # Member identifier: the text after ">" up to the first ".".
            # NOTE(review): identifiers that themselves contain "." are cut
            # at that point -- confirm this is intended for the ids in use.
            dico_cluster[id_cluster] = [
                line.split(">")[1].split(".")[0]
                for line in member_lines
                if line.strip()
            ]

    with open(f"{PATH_WORK}/dico_cluster.cdhit__{threshold}.json", "w") as outfile:
        json.dump(dico_cluster, outfile)

    return dico_cluster, threshold

def make_DF_binaries(df_info, dico_cluster, threshold):
    """Build and save the phage x cluster presence/absence matrix.

    Cell ``(phage, cluster)`` is 1 when at least one value in the phage's
    ``index`` column belongs to the cluster's member list, else 0. The matrix
    is written to ``{PATH_WORK}/DF_binaries_{threshold}.csv``.

    Args:
        df_info: table with ``Phage`` and ``index`` columns.
        dico_cluster: mapping cluster id -> list of member identifiers.
        threshold: label used in the output file name.

    Returns:
        pandas.DataFrame: integer matrix indexed by phage (order of first
        appearance in ``df_info``), with one column per cluster.
    """
    phages = df_info["Phage"].unique()
    # Convert each member list to a set once instead of once per phage.
    cluster_sets = [set(members) for members in dico_cluster.values()]

    presence = np.zeros((len(phages), len(cluster_sets)), dtype=int)
    for row, phage in enumerate(phages):
        phage_ids = set(df_info.loc[df_info["Phage"] == phage, "index"].values)
        # NOTE(review): members parsed from a .clstr file are strings; if
        # the "index" column is numeric no intersection can ever be found --
        # confirm the column dtype.
        presence[row] = [int(not members.isdisjoint(phage_ids)) for members in cluster_sets]

    df_dpo_prophages = pd.DataFrame(presence, index=phages, columns=list(dico_cluster))
    df_dpo_prophages.to_csv(f"{PATH_WORK}/DF_binaries_{threshold}.csv", sep=",", index=True)
    return df_dpo_prophages

def make_DF_kltype(df_info, df, KL_type, dico_cluster, ratio=5, collapse=False):
    """Assemble the training matrix and labels for one KL type.

    Positives are the rows of ``df`` whose index is a phage annotated with
    ``KL_type`` in ``df_info``, with duplicated feature rows dropped.
    Negatives are randomly sampled phages whose associated KL types (looked
    up in the module-level ``dico_prophage_kltype_associated``) do not
    include ``KL_type``.

    Args:
        df_info: table with ``Phage`` and ``KL_type_LCA`` columns.
        df: feature matrix indexed by phage.
        KL_type: KL type treated as the positive class.
        dico_cluster: unused; kept so existing callers keep working.
        ratio: requested number of negatives per positive.
        collapse: when true, drop the feature columns whose sum is not
            positive in the assembled matrix.

    Returns:
        tuple: ``(df_kltype, labels)`` with positives first (label 1)
        followed by negatives (label 0).
    """
    positive_phages = df_info.loc[df_info["KL_type_LCA"] == KL_type, "Phage"].unique()
    df_positives = df[df.index.isin(positive_phages)].drop_duplicates()
    n_samples = len(df_positives)

    negative_candidates = [
        phage for phage in df_info["Phage"].unique()
        if KL_type not in dico_prophage_kltype_associated[phage]
    ]
    # random.sample raises ValueError when asked for more items than exist,
    # so cap the request at the number of available negatives.
    n_requested = int(n_samples * ratio)
    n_negatives = min(n_requested, len(negative_candidates))
    if n_negatives < n_requested:
        logging.warning(
            "KL type %s: only %d negative phages available (%d requested)",
            KL_type, n_negatives, n_requested,
        )
    negative_phages = random.sample(negative_candidates, n_negatives)

    df_kltype = pd.concat([df_positives, df[df.index.isin(negative_phages)]])
    # Label order mirrors the concatenation order above.
    labels = [1] * n_samples + [0] * (len(df_kltype) - n_samples)

    if collapse:
        df_kltype = df_kltype.loc[:, df_kltype.sum() > 0]

    return df_kltype, labels

def fit_rf_model_random_search(df_kl, all_labels, KL_type, threshold, n_splits=5, n_iters=100):
    """Train and evaluate random forests for one KL type across CV folds.

    Nothing is done when the output file
    ``{PATH_MODELS}/{threshold}_RF_{KL_type}.full_data.joblib`` already
    exists. Otherwise the data is split with a shuffled ``StratifiedKFold``.
    A ``BayesSearchCV`` hyper-parameter search runs on the first fold only;
    its best parameters are reused to fit a fresh forest on every later fold.
    The per-fold results are dumped with joblib to the output file.

    Args:
        df_kl: feature matrix, one row per sample.
        all_labels: class labels aligned with the rows of ``df_kl``.
        KL_type: KL type label, used in the output file name.
        threshold: clustering threshold label, used in the output file name.
        n_splits: number of outer cross-validation folds.
        n_iters: number of ``BayesSearchCV`` iterations.

    Returns:
        None. Each dumped fold entry holds ``"best_parameters"``,
        ``"model"``, ``"test_data"`` as ``(y_test, predictions)``,
        ``"test_&_model_predictions"`` as ``(X_test, y_test)`` and
        ``"iteration"``.
    """
    output_path = f'{PATH_MODELS}/{threshold}_RF_{KL_type}.full_data.joblib'
    if os.path.isfile(output_path):
        return

    # NOTE(review): 'auto' for max_features is rejected by recent
    # scikit-learn releases -- confirm against the installed version.
    search_space = {
        'bootstrap': Categorical([True, False]),
        'max_depth': Integer(10, 100),
        'max_features': Categorical(['auto', 'sqrt']),
        'min_samples_leaf': Integer(1, 4),
        'min_samples_split': Integer(2, 10),
        'n_estimators': Integer(200, 800)
    }

    labels = pd.Series(all_labels)
    folds = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=243)
    data_kltype = {}

    for n, (train_index, test_index) in enumerate(folds.split(df_kl, all_labels)):
        X_train = df_kl.iloc[train_index]
        X_test = df_kl.iloc[test_index]
        y_train = labels.iloc[train_index]
        y_test = labels.iloc[test_index]

        if n > 0:
            # Later folds: reuse the parameters found on the first fold.
            best_model = RandomForestClassifier(random_state=42, **best_params)
            best_model.fit(X_train, y_train)
        else:
            # First fold: search the hyper-parameter space.
            forest = RandomForestClassifier(random_state=42)
            search = BayesSearchCV(forest, search_space, n_iter=n_iters, cv=4, n_jobs=-1)
            search.fit(X_train, y_train)
            best_params = search.best_params_
            best_model = search.best_estimator_

        predictions = best_model.predict(X_test)
        data_kltype[n] = {
            "best_parameters": best_params,
            "model": best_model,
            "test_data": (y_test, predictions),
            "test_&_model_predictions": (X_test, y_test),
            "iteration": n
        }

    joblib.dump(data_kltype, output_path)

def make_prediction_file(path_file):
    """Train the per-KL-type models for one cd-hit clustering output.

    Parses the clustering, builds the binary phage x cluster matrix from the
    module-level ``DF_info_lvl_0``, then fits a model for every KL type in
    the module-level ``KLtype_count`` that has a count of at least 5 and no
    saved model file yet.

    Args:
        path_file: path of a cd-hit ``.out`` file.
    """
    dico_cluster, threshold = make_cluster_dico(path_file)
    df_binaries = make_DF_binaries(DF_info_lvl_0, dico_cluster, threshold)

    for KL_type, count in KLtype_count.items():
        if count < 5:
            continue
        if os.path.isfile(f'{PATH_MODELS}/{threshold}_RF_{KL_type}.full_data.joblib'):
            # A model for this KL type and threshold is already on disk.
            continue

        logging.info(f"Processing KL type: {KL_type}")
        df_kl, all_labels = make_DF_kltype(DF_info_lvl_0, df_binaries, KL_type, dico_cluster, collapse=False)
        fit_rf_model_random_search(df_kl, all_labels, KL_type, threshold)

if __name__ == '__main__':
    # Load the table, then apply both de-duplication passes.
    DF_info, dico_prophage_info = load_data()
    DF_info_lvl_0 = ultrafilter_prophages(filter_prophages(DF_info, dico_prophage_info))

    # Number of rows per KL type; read by make_prediction_file.
    KLtype_count = Counter(DF_info_lvl_0["KL_type_LCA"])

    # Phage -> set of KL types attached to it; read by make_DF_kltype.
    dico_prophage_kltype_associated = {}
    for phage in DF_info_lvl_0["Phage"].unique():
        phage_rows = DF_info_lvl_0["Phage"] == phage
        dico_prophage_kltype_associated[phage] = set(DF_info_lvl_0.loc[phage_rows, "KL_type_LCA"].values)

    # Write the domain sequences as a multi-FASTA, one record per "index".
    depo_domains_seq = dict(zip(DF_info_lvl_0["index"], DF_info_lvl_0['domain_seq']))
    with open(PATH_MULTI_FASTA, "w") as outfile:
        outfile.writelines(f">{index}\n{seq}\n" for index, seq in depo_domains_seq.items())

    # Cluster the domains once per identity threshold.
    cdhit_thresholds = [0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95, 0.975]
    for threshold in cdhit_thresholds:
        make_cdhit_cluster(threshold)

    # Pick up every ".out" file present in the cd-hit directory.
    cdhit_files = []
    for entry in os.listdir(PATH_TMP_CDHIT):
        if entry.endswith(".out"):
            cdhit_files.append(f"{PATH_TMP_CDHIT}/{entry}")

    # One worker thread per clustering output, at most 10 at a time.
    with ThreadPool(10) as pool:
        pool.map(make_prediction_file, cdhit_files)