In [1]:
# Ground modules
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from Bio import SeqIO
from itertools import product
import random
from collections import Counter
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

import logging
import subprocess
from multiprocessing.pool import ThreadPool
import joblib

# SCikitlearn modules :
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.metrics import classification_report , roc_auc_score

# Scipy modules : 
from scipy.stats import fisher_exact

***
# Make function that :
### A : blastp from a Dpo seq
### B : read the results and spot the hits
### C : Build a vector from the presence abscence
### D : Make prediction
***

> 77 phages

In [2]:
import json

path_seqbased = "/media/concha-eloko/Linux/PPT_clean/Seqbased_model"
path_db = f"{path_seqbased}/TropiSeq/TropiSeq_0.75.db"

dico_cluster = json.load(open(f"{path_seqbased}/dico_cluster.cdhit__0.75.json"))
dico_cluster_r = {ref_dpo : key_dpo for key_dpo,list_dpo in dico_cluster.items() for ref_dpo in list_dpo}


In [5]:
len(dico_cluster)

883

In [29]:
num_arrays = 883
list_of_arrays = [np.zeros(num_arrays) for _ in range(num_arrays)]
for i, arr in enumerate(list_of_arrays):
    arr[i] = 1

***
# Make predictions

In [31]:
import pickle
import os
from joblib import load

path_seqbased = "/media/concha-eloko/Linux/PPT_clean"

models_TropiSeq = {}

for rf_model in os.listdir(f"{path_seqbased}/selected_RF_3112") :
        kltype = rf_model.split("_RF_")[1].split(".")[0]
        with open(f"{path_seqbased}/selected_RF_3112/{rf_model}", 'rb') as file:
            models_TropiSeq[kltype] = load(file)

TropiSeq_results = {}


> Make the predictions : 

In [33]:
# Others part : 
for index,array in tqdm(enumerate(list_of_arrays)) :
    cluster_id = "cluster_" + str(index)
    tmp_positif = {}
    for kltype in models_TropiSeq :
        pred = models_TropiSeq[kltype].predict_proba(np.array(array).reshape(1, -1))
        if pred[0][1] >= 0.5 :
            tmp_positif[kltype] = pred[0][1]
    TropiSeq_results[cluster_id] = tmp_positif


883it [28:00,  1.90s/it]


In [35]:
import json 

with open("/media/concha-eloko/Linux/PPT_clean/Seqbased_model/cluster_KLtypes.json", "w") as outfile :
    json.dump(TropiSeq_results, outfile)

In [37]:
from collections import Counter
lengths = [len(TropiSeq_results[cluster]) for cluster in TropiSeq_results]


Counter({1: 382,
         0: 361,
         2: 76,
         3: 26,
         4: 11,
         5: 8,
         7: 5,
         6: 4,
         11: 3,
         8: 3,
         9: 2,
         12: 1,
         10: 1})

In [43]:
from itertools import combinations
pairs_list = []
associations_tropiseq = [set(kl for kl in TropiSeq_results[cluster]) for cluster in TropiSeq_results if len(TropiSeq_results[cluster])>0]


for s in associations_tropiseq:
    # Convert set to list for compatibility with combinations
    elements = list(s)
    pairs = combinations(elements, 2)
    pairs_list.extend(pairs)

# Convert pairs_list to a list of tuples
#pairs_list = list(pairs_list)




In [45]:
Counter(pairs_list)

Counter({('KL47', 'KL64'): 5,
         ('KL30', 'KL125'): 4,
         ('KL51', 'KL81'): 3,
         ('KL123', 'KL43'): 3,
         ('KL21', 'KL64'): 3,
         ('KL24', 'KL28'): 3,
         ('KL105', 'KL15'): 3,
         ('KL36', 'KL106'): 3,
         ('KL36', 'KL15'): 3,
         ('KL36', 'KL107'): 3,
         ('KL24', 'KL15'): 3,
         ('KL106', 'KL15'): 3,
         ('KL106', 'KL107'): 3,
         ('KL15', 'KL107'): 3,
         ('KL15', 'KL64'): 3,
         ('KL107', 'KL64'): 3,
         ('KL107', 'KL106'): 3,
         ('KL74', 'KL26'): 3,
         ('KL116', 'KL30'): 3,
         ('KL116', 'KL125'): 3,
         ('KL8', 'KL22'): 3,
         ('KL107', 'KL15'): 3,
         ('KL2', 'KL122'): 3,
         ('KL2', 'KL64'): 3,
         ('KL13', 'KL2'): 3,
         ('KL5', 'KL30'): 3,
         ('KL51', 'KL2'): 3,
         ('KL21', 'KL47'): 2,
         ('KL24', 'KL112'): 2,
         ('KL112', 'KL39'): 2,
         ('KL8', 'KL1'): 2,
         ('KL31', 'KL14'): 2,
         ('KL48', 'KL9'): 2,
