# PhageHostLearn - v3.4.klebsiella - human in the loop

An AI-based Phage-Host interaction prediction framework with K-loci and receptor-binding proteins (RBPs) at its core. This version of PhageHostLearn is built for phages that infect *Klebsiella pneumoniae*.

This notebook offers the functionality to add new data to the PhageHostLearn framework and retrain the PhageHostLearn prediction models, without having to process all data from scratch. Here, it is assumed that you have completed the initial set-up that is carried out in the `phagehostlearn_training.ipynb`.

**Overview of this notebook**
- Setting A: new validated interactions for the same data
- Setting B or C: new interactions for either new phages or bacteria (= new data)
- Setting D: new interactions for new combinations of phages AND bacteria

**Architecture of the PhageHostLearn framework**: 
- Multi-RBP setting: phages consisting of one or more RBPs (multi-instance)
- K-loci proteins (multi-instance)
- Embeddings for both based on ESM-2 language models and HDC
- XGBoost model (on the ESM-2 language embeddings) and Random Forest (on the HDC embeddings), whose predictions are combined
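The last bullet merges two probabilities into one score; later in the notebook this is done with `phlu.uninorm`, whose definition is not shown here. As a minimal sketch, assuming the cross-ratio (3-Π) uninorm with neutral element 0.5 (which may differ from what `phlu.uninorm` actually implements; `cross_ratio_uninorm` is a hypothetical name):

```python
def cross_ratio_uninorm(x: float, y: float) -> float:
    """Combine two probabilities with the cross-ratio (3-Pi) uninorm.

    Hypothetical stand-in for phlu.uninorm. The neutral element is 0.5:
    two scores above 0.5 reinforce each other, two scores below 0.5
    weaken each other, and conflicting scores are pulled toward the middle.
    """
    num = x * y
    den = num + (1.0 - x) * (1.0 - y)
    if den == 0.0:
        # only happens for (0, 1) or (1, 0), where the operator is undefined;
        # use the conjunctive convention and return 0
        return 0.0
    return num / den


print(cross_ratio_uninorm(0.9, 0.6))  # roughly 0.93: two positive votes reinforce each other
```

The appeal of a uninorm over a plain average is that agreement between the two models pushes the combined score further from 0.5 than either input.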

In [4]:
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from Bio import SeqIO
from os import listdir
from joblib import dump, load
from tqdm.notebook import tqdm
from xgboost import XGBClassifier
import phagehostlearn_utils as phlu
import phagehostlearn_features as phlf
import phagehostlearn_processing as phlp
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc

In [2]:
general_path = '/Users/dimi/GoogleDrive/PhD/4_PHAGEHOST_LEARNING/42_DATA/Valencia_data'
results_path = '/Users/dimi/GoogleDrive/PhD/4_PHAGEHOST_LEARNING/43_RESULTS/models'
data_suffix = 'Valencia'

## Setting A: adding validated interactions for the same data

In this setting, we're not adding new phages or bacterial hosts, but we have tested new interactions for the phages and bacteria that are already present in the dataset. In this scenario, we only need to add those new interactions to our interaction matrix and retrain from there.

#### A.1 Manually add the validated interactions to the interactions .xlsx file
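If you prefer to script this step, the sheet can also be edited with pandas. A minimal sketch, assuming the layout described in D.2 (bacterial accessions as the row index, phage names as column headers, 1 for infection, 0 for no infection, empty for untested); `add_validated_interactions` is a hypothetical helper, not part of `phagehostlearn_processing`:

```python
import pandas as pd


def add_validated_interactions(matrix: pd.DataFrame, new: list[tuple[str, str, int]]) -> pd.DataFrame:
    """Return a copy of a host x phage matrix with new (host, phage, label) outcomes filled in.

    Hypothetical helper. Setting A only covers hosts and phages that are already
    in the matrix, so unknown names raise a KeyError instead of silently adding rows.
    """
    updated = matrix.copy()
    for host, phage, label in new:
        if host not in updated.index or phage not in updated.columns:
            raise KeyError(f"{host!r} x {phage!r} is not in the matrix; use Setting B/C or D instead")
        updated.loc[host, phage] = label
    return updated


# toy matrix: NaN marks a pair that has not been tested yet
matrix = pd.DataFrame({"P1": [1, None], "P2": [None, 0]}, index=["K1", "K2"])
matrix = add_validated_interactions(matrix, [("K2", "P1", 1), ("K1", "P2", 0)])
print(matrix)
```

With the real file, the sheet could be read with `pd.read_excel(interactions_xlsx_path, index_col=0)` and written back with `DataFrame.to_excel` (both need `openpyxl` installed).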

#### A.2 Reconstruct interaction matrix and feature matrices

In [None]:
interactions_xlsx_path = general_path+'/klebsiella_phage_host_interactions.xlsx'
phlp.process_interactions(general_path, interactions_xlsx_path, data_suffix=data_suffix)

In [None]:
features_esm2, features_hdc, labels, groups_loci, groups_phage = phlf.construct_feature_matrices(general_path, 
                                                                                            data_suffix=data_suffix)
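In the multi-instance setting, a phage can contribute several RBP embeddings and a host several K-locus protein embeddings, while XGBoost and the Random Forest need one fixed-length vector per host-phage pair. How `phlf.construct_feature_matrices` does this is not shown here; a minimal sketch, assuming mean pooling over the instances followed by concatenation (`pair_feature` is a hypothetical name):

```python
import numpy as np


def pair_feature(loci_embeddings: np.ndarray, rbp_embeddings: np.ndarray) -> np.ndarray:
    """Build one feature vector for a (host, phage) pair from multi-instance embeddings.

    Hypothetical helper, assuming mean pooling.
    loci_embeddings: (n_loci_proteins, d_loci) array for the host's K-locus proteins
    rbp_embeddings:  (n_rbps, d_rbp) array for the phage's RBPs
    Returns a (d_loci + d_rbp,) vector: mean host embedding, then mean phage embedding.
    """
    host_vec = loci_embeddings.mean(axis=0)
    phage_vec = rbp_embeddings.mean(axis=0)
    return np.concatenate([host_vec, phage_vec])


rng = np.random.default_rng(0)
loci = rng.normal(size=(4, 8))  # toy data: 4 K-locus proteins, 8-dimensional embeddings
rbps = rng.normal(size=(2, 8))  # toy data: 2 RBPs
x = pair_feature(loci, rbps)
print(x.shape)  # (16,)
```

Mean pooling keeps the vector length independent of how many RBPs or locus proteins are present; other choices (max pooling, per-RBP scoring followed by a max) are equally plausible.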

#### A.3 Retrain & save models

In [None]:
cpus=6
labels = np.asarray(labels)
model_suffix = ''

In [None]:
# ESM-2 FEATURES + XGBoost model
imbalance = sum([1 for i in labels if i==1]) / sum([1 for i in labels if i==0])
xgb = XGBClassifier(scale_pos_weight=1/imbalance, learning_rate=0.3, n_estimators=250, max_depth=5,
                    n_jobs=cpus, eval_metric='logloss', use_label_encoder=False)
xgb.fit(features_esm2, labels)
xgb.save_model('phagehostlearn_esm2_xgb'+model_suffix+'.json')
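Here `imbalance` is the ratio of positive to negative labels, so `scale_pos_weight=1/imbalance` equals the number of negatives divided by the number of positives, the value the XGBoost documentation suggests for imbalanced binary data. The same quantity, computed directly on a toy label vector, alongside the weights that `class_weight='balanced'` gives the Random Forest:

```python
import numpy as np

# toy label vector (not the notebook's labels): 2 positive and 6 negative interactions
toy_labels = np.array([1, 0, 0, 0, 1, 0, 0, 0])
n_pos = int((toy_labels == 1).sum())
n_neg = int((toy_labels == 0).sum())

# what the notebook computes: imbalance = n_pos / n_neg, scale_pos_weight = 1 / imbalance
scale_pos_weight = n_neg / n_pos  # 3.0

# RandomForestClassifier(class_weight='balanced') weights each class by
# n_samples / (n_classes * class_count); the positive/negative weight ratio is the same
w_pos = toy_labels.size / (2 * n_pos)
w_neg = toy_labels.size / (2 * n_neg)
print(scale_pos_weight, w_pos / w_neg)
```

Both models therefore up-weight positive interactions by the same factor, which keeps their scores comparable before they are combined.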

In [None]:
# HDC FEATURES + RF model
rf = RandomForestClassifier(n_estimators=1000, max_depth=5, class_weight='balanced', n_jobs=cpus)
rf.fit(features_hdc, labels)
dump(rf, 'phagehostlearn_hdc_rf'+model_suffix+'.joblib')

## Setting B/C: adding new phages or bacteria + interactions

In this setting, you're adding either new phages tested against the known bacteria, or new bacteria tested against the known phages. This entails adding the new genome FASTA files to the respective folders (see `phagehostlearn_training.ipynb`) and manually adding the new rows or columns to the interactions.xlsx file. Alternatively, you can make a new_interactions.xlsx file and combine it with the old interaction matrix in Python.
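Combining a separate new_interactions.xlsx with the existing matrix in Python can be sketched with pandas `DataFrame.combine_first`, which aligns both sheets on host and phage names and takes the union of their rows and columns. This assumes both sheets share the layout described in D.2; `merge_interaction_matrices` is a hypothetical helper (D.4 uses the notebook's own `phlp.add_to_database` for a similar merge):

```python
import pandas as pd


def merge_interaction_matrices(old: pd.DataFrame, new: pd.DataFrame) -> pd.DataFrame:
    """Union of two host x phage matrices; where both define a cell, the new value wins.

    Hypothetical helper. Pairs tested in neither sheet stay NaN (untested).
    """
    return new.combine_first(old)


old = pd.DataFrame({"P1": [1, 0]}, index=["K1", "K2"])
# new phage P2 tested against K1-K3, including the new host K3
new = pd.DataFrame({"P2": [1, None, 0]}, index=["K1", "K2", "K3"])
merged = merge_interaction_matrices(old, new)
print(merged)
```

Both sheets would be read with `pd.read_excel(path, index_col=0)` first, so that the bacterial accessions become the row index.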

#### BC.1 Manually add the new phage genomes or bacterial genomes to their designated folders

#### BC.2 Quality check: undetermined nts & identical genomes

In [6]:
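# undetermined nts: flag genomes that contain anything other than A, C, G or T
phage_genomes_path = general_path+'/phages_genomes'
phage_files = [file for file in listdir(phage_genomes_path) if not file.startswith('.')]  # skip hidden files such as .DS_Store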
undetermined = []
for file in phage_files:
    file_dir = phage_genomes_path+'/'+file
    for record in SeqIO.parse(file_dir, 'fasta'):
        sequence = str(record.seq)
        approved_letters = sequence.count('A') + sequence.count('C') + sequence.count('T') + sequence.count('G')
        if approved_letters != len(sequence):
            undetermined.append(record.id)
print('Genomes with undetermined nts: ', undetermined)

Genomes with undetermined nts:  ['P4a', 'A1m']


In [None]:
# CD-HIT: cluster the genomes (e.g. with cd-hit-est) to flag identical or near-identical ones
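CD-HIT clusters near-identical sequences and runs as an external tool. For the narrower check of *exactly* identical genomes, hashing each sequence is enough. A sketch in plain Python, assuming one genome per FASTA record; `parse_fasta` and `identical_genomes` are hypothetical helpers, and the minimal reader stands in for `SeqIO.parse` so the example runs on in-memory text:

```python
import hashlib
from collections import defaultdict


def parse_fasta(text: str) -> dict[str, str]:
    """Minimal FASTA reader: record id (first word of the header) -> uppercase sequence."""
    records: dict[str, list[str]] = {}
    current = None
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            current = line[1:].split()[0]
            records[current] = []
        elif current is not None:
            records[current].append(line.upper())
    return {rid: "".join(parts) for rid, parts in records.items()}


def identical_genomes(sequences: dict[str, str]) -> list[list[str]]:
    """Group record ids whose sequences are exactly identical (same strand and start)."""
    by_hash = defaultdict(list)
    for rid, seq in sequences.items():
        by_hash[hashlib.sha256(seq.encode()).hexdigest()].append(rid)
    return [sorted(ids) for ids in by_hash.values() if len(ids) > 1]


fasta = """>phageA
ACGTACGT
>phageB
acgtacgt
>phageC
ACGTTTTT
"""
print(identical_genomes(parse_fasta(fasta)))  # [['phageA', 'phageB']]
```

Reverse-complemented or differently rotated copies of the same genome are not detected this way; that is what a clustering tool such as CD-HIT is for.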

#### BC.3 Rerun the relevant processing steps with the add=True parameter

If you've added new phage genomes, rerun PHANOTATE, the protein embedding computation and PhageRBPdetect. If you've added new bacterial genomes, rerun Kaptive. Afterwards, rerun the processing of the interaction matrix.
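The `add=True` runs save time by handling only inputs that are not yet in the existing outputs, as the "Processing 3 more phages (add=True)" message from PHANOTATE suggests. The bookkeeping inside `phagehostlearn_processing` is not shown here; a sketch of the idea, assuming genome files are named after the phage and the processed table has a `phage_ID` column (both are assumptions, and `unprocessed_genomes` is a hypothetical helper):

```python
from pathlib import Path

import pandas as pd


def unprocessed_genomes(genome_files: list[str], processed: pd.DataFrame, id_column: str = "phage_ID") -> list[str]:
    """Return genome files whose stem (file name without extension) is not yet in the processed table.

    Hypothetical helper; the column name and file naming convention are assumptions.
    """
    done = set(processed[id_column].astype(str))
    return sorted(f for f in genome_files if Path(f).stem not in done)


processed = pd.DataFrame({"phage_ID": ["P1", "P2"], "gene_ID": ["P1_gp1", "P2_gp1"]})
files = ["P1.fasta", "P2.fasta", "P3.fasta", "P4.fna"]
print(unprocessed_genomes(files, processed))  # ['P3.fasta', 'P4.fna']
```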

In [3]:
# PHANOTATE
phage_genomes_path = general_path+'/phages_genomes'
phanotate_path = '/opt/homebrew/Caskroom/miniforge/base/envs/ML1/bin/phanotate.py'
phlp.phanotate_processing(general_path, phage_genomes_path, phanotate_path, data_suffix=data_suffix, add=True)

Processing  3  more phages (add=True)



In [None]:
# Protein embeddings (alternatively run in Google Colab or Kaggle)
phlp.compute_protein_embeddings(general_path, data_suffix=data_suffix, add=True)

In [3]:
# PhageRBPdetect
pfam_path = general_path+'/RBPdetect_phageRBPs.hmm'
hmmer_path = '/Users/Dimi/hmmer-3.3.1'
xgb_path = general_path+'/RBPdetect_xgb_hmm.json'
gene_embeddings_path = general_path+'/phage_protein_embeddings'+data_suffix+'.csv'
phlp.phageRBPdetect(general_path, pfam_path, hmmer_path, xgb_path, gene_embeddings_path, data_suffix=data_suffix)

b'Working...    done.\nPressed and indexed 92 HMMs (92 names and 29 accessions).\nModels pressed into binary file:   /Users/dimi/GoogleDrive/PhD/4_PHAGEHOST_LEARNING/42_DATA/Valencia_data/RBPdetect_phageRBPs.hmm.h3m\nSSI index for binary model file:   /Users/dimi/GoogleDrive/PhD/4_PHAGEHOST_LEARNING/42_DATA/Valencia_data/RBPdetect_phageRBPs.hmm.h3i\nProfiles (MSV part) pressed into:  /Users/dimi/GoogleDrive/PhD/4_PHAGEHOST_LEARNING/42_DATA/Valencia_data/RBPdetect_phageRBPs.hmm.h3f\nProfiles (remainder) pressed into: /Users/dimi/GoogleDrive/PhD/4_PHAGEHOST_LEARNING/42_DATA/Valencia_data/RBPdetect_phageRBPs.hmm.h3p\n'



In [None]:
# Kaptive
bact_genomes_path = general_path+'/klebsiella_genomes/fasta_files'
kaptive_database_path = general_path+'/Klebsiella_k_locus_primary_reference.gbk'
phlp.process_bacterial_genomes(general_path, bact_genomes_path, kaptive_database_path, 
                               data_suffix=data_suffix, add=True)

#### BC.4 Manually add the new bacteria (as rows) or phages (as columns) and their interactions to the interactions.xlsx Excel sheet

#### BC.5 Reconstruct interaction matrix and feature matrices

In [7]:
interactions_xlsx_path = general_path+'/klebsiella_phage_host_interactions.xlsx'
phlp.process_interactions(general_path, interactions_xlsx_path, data_suffix=data_suffix)

In [4]:
# compute ESM-2 RBP embeddings
phlf.compute_esm2_embeddings_rbp(general_path, data_suffix=data_suffix, add=True)

Processing  8  more sequences (add=True)


100%|█████████████████████████████████████████████| 8/8 [01:34<00:00, 11.86s/it]


In [None]:
# compute ESM-2 loci embeddings
phlf.compute_esm2_embeddings_loci(general_path, data_suffix=data_suffix, add=True)

In [None]:
# compute HDC embeddings
locibase_path = general_path+'/Locibase'+data_suffix+'.json'
rbpbase_path = general_path+'/RBPbase'+data_suffix+'.csv'
phlf.compute_hdc_embedding(general_path, data_suffix, locibase_path, rbpbase_path, mode='train')
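HDC (hyperdimensional computing) embeddings represent a protein as one very long vector assembled from random vectors for its building blocks. The encoding actually used by `phlf.compute_hdc_embedding` is not shown in this notebook. The sketch below is a generic HDC k-mer encoder under stated assumptions: random bipolar (+1/-1) hypervectors per amino acid, k-mers bound by element-wise multiplication of position-shifted (rolled) vectors, and a protein bundled as the sign of the sum of its k-mers. `make_item_memory`, `encode_protein` and `cosine` are hypothetical names.

```python
import numpy as np

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def make_item_memory(dim: int = 4096, seed: int = 0) -> dict[str, np.ndarray]:
    """One random bipolar (+1/-1) hypervector per amino acid."""
    rng = np.random.default_rng(seed)
    return {aa: rng.choice([-1, 1], size=dim) for aa in AMINO_ACIDS}


def encode_protein(seq: str, memory: dict[str, np.ndarray], k: int = 3) -> np.ndarray:
    """Bundle all k-mers of seq; each k-mer binds its residues after rolling by position."""
    if len(seq) < k:
        raise ValueError("sequence shorter than k")
    dim = len(next(iter(memory.values())))
    total = np.zeros(dim, dtype=np.int64)
    for i in range(len(seq) - k + 1):
        kmer = np.ones(dim, dtype=np.int64)
        for pos, aa in enumerate(seq[i:i + k]):
            kmer *= np.roll(memory[aa], pos)  # binding; the roll encodes the position in the k-mer
        total += kmer  # bundling
    return np.sign(total)  # majority vote per dimension; ties stay 0


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


memory = make_item_memory()
a = encode_protein("MKTAYIAKQRQISFVKSHFSRQ", memory)
b = encode_protein("MKTAYIAKQRQISFVKSHFSRE", memory)  # one substitution at the end
c = encode_protein("GSHMLEDPVDAFQLAKWNPEVL", memory)  # unrelated toy sequence
print(cosine(a, b), cosine(a, c))  # similar sequences -> much higher similarity
```

Because different k-mers map to nearly orthogonal vectors, the cosine similarity between two bundled proteins roughly tracks how many k-mers they share.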

In [None]:
rbp_embeddings_path = general_path+'/esm2_embeddings_rbp'+data_suffix+'.csv'
loci_embeddings_path = general_path+'/esm2_embeddings_loci'+data_suffix+'.csv'
hdc_embeddings_path = general_path+'/hdc_features'+data_suffix+'.txt'
features_esm2, features_hdc, labels, groups_loci, groups_phage = phlf.construct_feature_matrices(general_path, 
                                                                            data_suffix, loci_embeddings_path, 
                                                                            rbp_embeddings_path, hdc_embeddings_path)

#### BC.6 Retrain and save the models

In [None]:
cpus = 6
labels = np.asarray(labels)
model_suffix = ''

In [None]:
# ESM-2 FEATURES + XGBoost model
imbalance = sum([1 for i in labels if i==1]) / sum([1 for i in labels if i==0])
xgb = XGBClassifier(scale_pos_weight=1/imbalance, learning_rate=0.3, n_estimators=250, max_depth=5,
                    n_jobs=cpus, eval_metric='logloss', use_label_encoder=False)
xgb.fit(features_esm2, labels)
xgb.save_model('phagehostlearn_esm2_xgb'+model_suffix+'.json')

In [None]:
# HDC FEATURES + RF model
rf = RandomForestClassifier(n_estimators=1000, max_depth=5, class_weight='balanced', n_jobs=cpus)
rf.fit(features_hdc, labels)
dump(rf, 'phagehostlearn_hdc_rf'+model_suffix+'.joblib')

## Setting D: adding new phages AND bacteria + interactions

In this setting, you're adding interactions for new combinations of phages AND bacteria. Here, you rerun all the processing steps with the add=True argument and then merge the new interactions, stored in a separate .xlsx file, into the existing interaction data.

#### D.1 Manually add the new phage genomes and bacterial genomes to their designated folders

#### D.2 Make a new .xlsx file for the new interactions, with the bacterial accessions in the first column (one row per new bacterium) and the phage names in the first row (one column per new phage). Store the .xlsx file in the general folder.
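Because the new sheet has to match the genomes added in D.1, a quick consistency check before rerunning the pipeline can save a failed run. A sketch, assuming genome files are named after the phage or bacterial accession (an assumption about the naming convention) and the sheet is read with `pd.read_excel(path, index_col=0)`; `missing_genomes` is a hypothetical helper:

```python
from pathlib import Path

import pandas as pd


def missing_genomes(new_matrix: pd.DataFrame, phage_files: list[str], host_files: list[str]) -> dict[str, list[str]]:
    """List phages (columns) and bacteria (rows) in the new sheet without a matching genome file.

    Hypothetical helper; file names are compared without their extension.
    """
    phage_ids = {Path(f).stem for f in phage_files}
    host_ids = {Path(f).stem for f in host_files}
    return {
        "phages": sorted(str(p) for p in new_matrix.columns if str(p) not in phage_ids),
        "bacteria": sorted(str(h) for h in new_matrix.index if str(h) not in host_ids),
    }


new_matrix = pd.DataFrame({"P7": [1, 0], "P8": [0, None]}, index=["KP_101", "KP_102"])
print(missing_genomes(new_matrix, ["P7.fasta"], ["KP_101.fasta", "KP_102.fasta"]))
# {'phages': ['P8'], 'bacteria': []}
```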

#### D.3 Rerun all the processing steps with add=True

In [None]:
# PHANOTATE
phage_genomes_path = general_path+'/phages_genomes'
phanotate_path = '/opt/homebrew/Caskroom/miniforge/base/envs/ML1/bin/phanotate.py'
phlp.phanotate_processing(general_path, phage_genomes_path, phanotate_path, data_suffix=data_suffix, add=True)

In [None]:
# Protein embeddings (alternatively run in Google Colab or Kaggle)
phlp.compute_protein_embeddings(general_path, data_suffix=data_suffix, add=True)

In [None]:
# PhageRBPdetect
pfam_path = general_path+'/RBPdetect_phageRBPs.hmm'
hmmer_path = '/Users/Dimi/hmmer-3.3.1'
xgb_path = general_path+'/RBPdetect_xgb_hmm.json'
gene_embeddings_path = general_path+'/phage_protein_embeddings'+data_suffix+'.csv'
phlp.phageRBPdetect(general_path, pfam_path, hmmer_path, xgb_path, gene_embeddings_path, data_suffix=data_suffix)

In [None]:
# Kaptive
bact_genomes_path = general_path+'/klebsiella_genomes/fasta_files'
kaptive_database_path = general_path+'/Klebsiella_k_locus_primary_reference.gbk'
phlp.process_bacterial_genomes(general_path, bact_genomes_path, kaptive_database_path, 
                               data_suffix=data_suffix, add=True)

#### D.4 Integrate the interactions

In [None]:
original_interactions_xlsx_path = general_path+'/klebsiella_phage_host_interactions.xlsx'
new_interactions_xlsx_path = ...  # set this to the path of the new interactions .xlsx file from D.2
interactions_path = general_path+'/phage_host_interactions'+data_suffix
phlp.process_interactions(general_path, original_interactions_xlsx_path, data_suffix=data_suffix)
phlp.add_to_database(interactions_path+'.csv', new_interactions_xlsx_path, interactions_path)

#### D.5 Reconstruct features

In [None]:
# compute ESM-2 RBP embeddings
phlf.compute_esm2_embeddings_rbp(general_path, data_suffix=data_suffix, add=True)

In [None]:
# compute ESM-2 loci embeddings
phlf.compute_esm2_embeddings_loci(general_path, data_suffix=data_suffix, add=True)

In [None]:
# compute HDC embeddings
locibase_path = general_path+'/Locibase'+data_suffix+'.json'
rbpbase_path = general_path+'/RBPbase'+data_suffix+'.csv'
phlf.compute_hdc_embedding(general_path, data_suffix, locibase_path, rbpbase_path, mode='train')

In [None]:
features_esm2, features_hdc, labels, groups_loci, groups_phage = phlf.construct_feature_matrices(general_path, 
                                                                                            data_suffix=data_suffix)

#### D.6 Retrain and save models

In [None]:
cpus = 6
labels = np.asarray(labels)
model_suffix = ''

In [None]:
# ESM-2 FEATURES + XGBoost model
imbalance = sum([1 for i in labels if i==1]) / sum([1 for i in labels if i==0])
xgb = XGBClassifier(scale_pos_weight=1/imbalance, learning_rate=0.3, n_estimators=250, max_depth=5,
                    n_jobs=cpus, eval_metric='logloss', use_label_encoder=False)
xgb.fit(features_esm2, labels)
xgb.save_model('phagehostlearn_esm2_xgb'+model_suffix+'.json')

In [None]:
# HDC FEATURES + RF model
rf = RandomForestClassifier(n_estimators=1000, max_depth=5, class_weight='balanced', n_jobs=cpus)
rf.fit(features_hdc, labels)
dump(rf, 'phagehostlearn_hdc_rf'+model_suffix+'.joblib')

## Evaluation: leave-one-group-out cross-validation (LOGOCV)
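Leave-one-group-out cross-validation holds out every group in `groups_loci` in turn, so each model is tested on hosts whose group was never seen during training. The splitting that scikit-learn's `LeaveOneGroupOut` performs can be sketched in a few lines of NumPy (`leave_one_group_out` is a hypothetical name, shown only to make the splits explicit):

```python
import numpy as np


def leave_one_group_out(groups):
    """Yield (train_index, test_index) pairs, holding out each unique group (in sorted order)."""
    groups = np.asarray(groups)
    for g in np.unique(groups):
        test_mask = groups == g
        yield np.flatnonzero(~test_mask), np.flatnonzero(test_mask)


toy_groups = ["KL1", "KL2", "KL1", "KL3"]
for train_idx, test_idx in leave_one_group_out(toy_groups):
    print(train_idx, test_idx)
```

This is also why the evaluation loop below iterates `len(set(groups_loci))` times: there is exactly one fold per group.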

In [None]:
logo = LeaveOneGroupOut()
cpus = 6
scores_lan = []
scores_hdc = []
label_list = []
labels = np.asarray(labels)
pbar = tqdm(total=len(set(groups_loci)))
for train_index, test_index in logo.split(features_esm2, labels, groups_loci):
    # get the training and test data
    Xlan_train, Xlan_test = features_esm2[train_index], features_esm2[test_index]
    Xhdc_train, Xhdc_test = features_hdc[train_index], features_hdc[test_index]
    y_train, y_test = labels[train_index], labels[test_index]
    imbalance = sum([1 for i in y_train if i==1]) / sum([1 for i in y_train if i==0])

    ## ESM-2 EMBEDDINGS: XGBoost model
    xgb = XGBClassifier(scale_pos_weight=1/imbalance, learning_rate=0.3, n_estimators=250, max_depth=7,
                        n_jobs=cpus, eval_metric='logloss', use_label_encoder=False)
    xgb.fit(Xlan_train, y_train)
    score_xgb = xgb.predict_proba(Xlan_test)[:,1]
    scores_lan.append(score_xgb)
    
    ## HDC EMBEDDINGS: RF model
    rf = RandomForestClassifier(n_estimators=1250, max_depth=3, class_weight='balanced', n_jobs=cpus)
    rf.fit(Xhdc_train, y_train)
    score_rf = rf.predict_proba(Xhdc_test)[:,1]
    scores_hdc.append(score_rf)
    
    # save labels for later
    label_list.append(y_test)
    
    # pbar update
    pbar.update(1)
pbar.close()

In [None]:
# save results
logo_results = {'labels': label_list, 'scores_language': scores_lan, 'scores_hdc': scores_hdc}   
with open(results_path+'/v3.4/combined_logocv_results_v34celia.pickle', 'wb') as f:
    pickle.dump(logo_results, f)
    
# read results
#with open(results_path+'/v3.3/combined_logocv_results.pickle', 'rb') as f:
#    logo_results = pickle.load(f)
#scores_lan = logo_results['scores_language']
#scores_hdc = logo_results['scores_hdc']
#label_list = logo_results['labels']

# compute performance
rocaucs = []
praucs = []
rqueries_lan = []
rqueries_hdc = []
rqueries_ens = []
for i in range(len(set(groups_loci))):
    score_lan = scores_lan[i]
    score_hdc = scores_hdc[i]
    score_ens = [phlu.uninorm(score_lan[j], score_hdc[j]) for j in range(len(score_lan))]
    y_test = label_list[i]
    if len(np.unique(y_test)) < 2:
        # only one class in this held-out group: ROC AUC and rankings are undefined, skip it
        continue
    roc_auc = roc_auc_score(y_test, score_ens)
    rocaucs.append(roc_auc)
    precision, recall, thresholds = precision_recall_curve(y_test, score_ens)
    praucs.append(round(auc(recall, precision), 3))
    # sort on the score only, so tied scores are not broken in favour of positive labels
    ranked_lan = [x for _, x in sorted(zip(score_lan, y_test), key=lambda pair: pair[0], reverse=True)]
    ranked_hdc = [x for _, x in sorted(zip(score_hdc, y_test), key=lambda pair: pair[0], reverse=True)]
    ranked_ens = [x for _, x in sorted(zip(score_ens, y_test), key=lambda pair: pair[0], reverse=True)]
    rqueries_lan.append(ranked_lan)
    rqueries_hdc.append(ranked_hdc)
    rqueries_ens.append(ranked_ens)

In [None]:
# results, part 1
print('Mean average recall @ K=50:')
print('ESM-2 + XGBoost: ', phlu.marecallatk(rqueries_lan, 50))
print('HDC + Random Forest: ', phlu.marecallatk(rqueries_hdc, 50))
print('Combined model: ', phlu.marecallatk(rqueries_ens, 50))
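Each entry of `rqueries_*` is one held-out query: its true labels sorted from highest to lowest predicted score. The helpers in `phagehostlearn_utils` are not shown here; a sketch of the usual definitions behind the hit-ratio and recall curves, assuming hit ratio @ K is the fraction of queries with at least one positive in the top K, and recall @ K is the per-query fraction of positives in the top K averaged over queries that contain a positive (`hit_ratio_at_k` and `recall_at_k` are hypothetical stand-ins for `phlu.hitratio` and `phlu.recallatk`):

```python
def hit_ratio_at_k(ranked_queries: list[list[int]], k: int) -> float:
    """Fraction of queries with at least one positive label among the top-k ranked items."""
    return sum(1 for q in ranked_queries if any(q[:k])) / len(ranked_queries)


def recall_at_k(ranked_queries: list[list[int]], k: int) -> float:
    """Mean, over queries with at least one positive, of (positives in top k) / (all positives)."""
    recalls = [sum(q[:k]) / sum(q) for q in ranked_queries if sum(q) > 0]
    return sum(recalls) / len(recalls)


queries = [[0, 1, 0, 1], [1, 0, 0, 0], [0, 0, 0, 1]]
print(hit_ratio_at_k(queries, 2))  # 2 of 3 queries have a hit in the top 2 -> 0.666...
print(recall_at_k(queries, 2))  # (0.5 + 1.0 + 0.0) / 3 = 0.5
```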

In [None]:
# results, hit ratios @ K
ks = np.linspace(1, 50, 50)
hits_lan = [phlu.hitratio(rqueries_lan, int(k)) for k in ks]
hits_hdc = [phlu.hitratio(rqueries_hdc, int(k)) for k in ks]
hits_ens = [phlu.hitratio(rqueries_ens, int(k)) for k in ks]
fig, ax = plt.subplots(figsize=(10,7))
ax.plot(ks, hits_lan, c='#E15554', linewidth=2.5, ls=':', label='ESM-2 + XGBoost')
ax.plot(ks, hits_hdc, c='#E15554', linewidth=2.5, ls='-.', label='HDC + Random Forest')
ax.plot(ks, hits_ens, c='#124559', linewidth=2.5, label='Combined model')
ax.set_xlabel(r'$\it{K}$', size=12)
ax.set_ylabel(r'Hit ratio @ $\it{K}$', size=12)
ax.set_ylim(0.1, 1)
ax.legend(loc=4, prop={'size': 12})
ax.grid(True)
fig.savefig(results_path+'/v3.4/logocv_hitratio_v34celia.png', dpi=400)
fig.savefig(results_path+'/v3.4/logocv_hitratio_svg_v34celia.svg', format='svg', dpi=400)

In [None]:
# results, recalls @ K
ks = np.linspace(1, 50, 50)
recalls_lan = [phlu.recallatk(rqueries_lan, int(k)) for k in ks]
recalls_hdc = [phlu.recallatk(rqueries_hdc, int(k)) for k in ks]
recalls_ens = [phlu.recallatk(rqueries_ens, int(k)) for k in ks]
fig, ax = plt.subplots(figsize=(10,7))
ax.plot(ks, recalls_lan, c='#E15554', linewidth=2.5, ls=':', label='ESM-2 + XGBoost')
ax.plot(ks, recalls_hdc, c='#E15554', linewidth=2.5, ls='-.', label='HDC + Random Forest')
ax.plot(ks, recalls_ens, c='#124559', linewidth=2.5, label='Combined model')
ax.set_xlabel(r'$\it{K}$', size=12)
ax.set_ylabel(r'Recall @ $\it{K}$', size=12)
ax.set_ylim(0.1, 1)
ax.legend(loc=4, prop={'size': 12})
ax.grid(True)
fig.savefig(results_path+'/v3.4/logocv_recall_v34celia.png', dpi=400)
fig.savefig(results_path+'/v3.4/logocv_recall_svg_v34celia.svg', format='svg', dpi=400)