# PhageHostLearn.*klebsiella* - inference

This notebook offers complete functionality to make predictions for new bacteria, phages or both, using a trained PhageHostLearn prediction model for Klebsiella phage-host interactions.

**Overview of this notebook**
1. Initial set-up
2. Processing phage genomes and bacterial genomes into RBPs and K-locus proteins, respectively
3. Computing feature representations based on ESM-2
4. Predicting new interactions and ranking
5. Reading and interpreting the results

**Architecture of the PhageHostLearn framework**: 
- Multi-RBP setting: phages consisting of one or more RBPs (multi-instance)
- K-loci proteins (multi-instance) 
- Embeddings for both based on the ESM-2 language model.
- An XGBoost model on top of language embeddings to make predictions
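The multi-instance idea can be illustrated with a toy table. This is a minimal sketch that assumes the several RBP embeddings of one phage are mean-pooled into a single phage-level vector. Mean pooling is one common choice and is an assumption here. The column names `dim0`/`dim1` and the two-dimensional vectors are also illustrative; real ESM-2 embeddings are much longer.

```python
import pandas as pd

# Toy RBP embedding table: one row per RBP, tagged with the phage it belongs to
rbp_table = pd.DataFrame({
    "phage_ID": ["phageA", "phageA", "phageB"],
    "dim0": [1.0, 3.0, 0.0],
    "dim1": [3.0, 5.0, 2.0],
})

# Mean-pool all RBPs of each phage into one vector per phage
phage_vectors = rbp_table.groupby("phage_ID")[["dim0", "dim1"]].mean()
print(phage_vectors)
```

The same pooling applies on the bacterial side, where the K-locus proteins of one genome are the instances.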

## 1. Initial set-up

In [1]:
# data paths
path = '/Users/eliottvalette/Documents/Clones/PhageHostLearn/data'
phages_path = path + '/phages_genomes'
bacteria_path = path + '/bacteria_genomes'
pfam_path = '/Users/eliottvalette/Documents/Clones/PhageHostLearn/code/RBPdetect_phageRBPs.hmm'
xgb_path = '/Users/eliottvalette/Documents/Clones/PhageHostLearn/code/RBPdetect_xgb_hmm.json'
kaptive_db_path = path + '/Klebsiella_k_locus_primary_reference.gbk'
suffix = 'inference'

hmmer_path = path + '/hmmer-3.4'
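Every later step reads from these hard-coded locations, so it can help to fail early when one of them is missing. The following is a minimal standard-library sketch. `check_paths` is a hypothetical helper and is not part of PhageHostLearn. A temporary directory stands in for `path`.

```python
import tempfile
from pathlib import Path

def check_paths(paths):
    """Return the names (keys of `paths`) whose path does not exist on disk."""
    return [name for name, p in paths.items() if not Path(p).exists()]

# Demonstration with a temporary directory standing in for `path`
with tempfile.TemporaryDirectory() as tmp:
    missing = check_paths({
        "path": tmp,
        "phages_path": tmp + "/phages_genomes",  # never created, so reported as missing
    })
print(missing)
```

In the notebook you would pass the real variables (`phages_path`, `bacteria_path`, `pfam_path`, `xgb_path`, `kaptive_db_path`, `hmmer_path`) and raise an error if the returned list is non-empty.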


## 2. Data processing

The data processing of PhageHostLearn consists of four consecutive steps: (1) phage gene calling with PHANOTATE, (2) phage protein embedding with bio_embeddings (ProtTransBertBFD), (3) phage RBP detection with PhageRBPdetect, and (4) bacterial genome processing with Kaptive.

Expected outputs: (1) an RBPbase.csv file with detected RBPs, (2) a Locibase.json file with detected K-loci proteins.
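These steps are slow, so a common pattern is to skip any step whose output file already exists. The sketch below assumes that output files follow the `name + suffix + extension` pattern used elsewhere in this notebook (for example `'phage_protein_embeddings' + suffix + '.csv'`). The exact RBPbase file name and the helper `run_if_missing` are assumptions, not part of `phagehostlearn_processing`.

```python
import os
import tempfile

def run_if_missing(output_file, step, *args, **kwargs):
    """Run step(*args, **kwargs) only if output_file does not exist yet.

    Returns True if the step was run, False if it was skipped.
    """
    if os.path.exists(output_file):
        print(f"Skipping, found {output_file}")
        return False
    step(*args, **kwargs)
    return True

# Toy 'step' that writes its own output file
def fake_step(out):
    with open(out, "w") as f:
        f.write("done")

with tempfile.TemporaryDirectory() as tmp:
    out = os.path.join(tmp, "RBPbase" + "inference" + ".csv")  # assumed naming pattern
    first = run_if_missing(out, fake_step, out)   # runs and creates the file
    second = run_if_missing(out, fake_step, out)  # skipped: the file now exists
```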

In [2]:
import phagehostlearn_processing as phlp
import time

In [3]:
# run Phanotate
phanotate_path = '/Users/eliottvalette/Documents/Clones/PhageHostLearn/.venv/bin/phanotate.py'
phlp.phanotate_processing(path, phages_path, phanotate_path, data_suffix=suffix, num_phages=2)

Number of phage files: 105
Processing only the first  2  phages


Processing phage genomes: 100%|██████████| 2/2 [00:14<00:00,  7.20s/it]

Completed PHANOTATE
Number of phage genes: 145





In [4]:
print('All phage genomes processed')
time.sleep(1)

All phage genomes processed



In [None]:
# compute ProtTransBertBFD (PTB) protein embeddings (can be done faster in the cloud, see PTB_embeddings.ipynb)
phlp.compute_protein_embeddings(path, data_suffix=suffix)

Starting compute_protein_embeddings...
Python version: 3.9.18 (main, Nov 12 2025, 09:03:55) 
[Clang 16.0.0 (clang-1600.0.26.6)]
Available memory: 1.18 GB
Total memory: 8.00 GB
The ProtTransBertBFD model requires at least 3-4 GB of free RAM.
Recommendations:
1. Restart the kernel to free memory from previous cells
2. Close other applications using memory
3. Use num_genes parameter to process fewer genes at a time
4. Consider using the cloud notebook (compute_embeddings_cloud.ipynb)
5. Consider setting num_genes=10 or num_genes=20 to test with fewer genes
Loading genebase from: /Users/eliottvalette/Documents/Clones/PhageHostLearn/data/phage_genesinference.csv
Number of phage genes: 145
Importing ProtTransBertBFDEmbedder...
If the kernel crashes here, you may need to:
1. Free up memory by restarting the kernel and closing other applications
2. Use the cloud notebook (compute_embeddings_cloud.ipynb) instead
3. Process fewer genes at a time
Memory before import: 1.18 GB available
Starting i
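The log above recommends processing fewer genes at a time when free RAM is below the 3-4 GB that ProtTransBertBFD needs. A generic way to do this is to embed sequences in fixed-size batches and stack the results. In the sketch below, `embed_fn` is a placeholder for whatever embedder you wrap, and `toy_embed` is a stand-in that computes two trivial features. Neither is the signature of `compute_protein_embeddings`. Batching does not reduce the memory needed to load the model itself. It only limits how much a single embedding call has to hold at once.

```python
import numpy as np

def embed_in_batches(sequences, embed_fn, batch_size=10):
    """Embed `sequences` in chunks of `batch_size` and stack the results.

    embed_fn: any callable mapping a list of sequences to an array-like of shape (len(batch), d).
    """
    chunks = []
    for start in range(0, len(sequences), batch_size):
        batch = sequences[start:start + batch_size]
        chunks.append(np.asarray(embed_fn(batch)))
    return np.vstack(chunks)

# Stand-in embedder: (sequence length, fraction of methionines) per sequence
def toy_embed(batch):
    return [[len(s), s.count("M") / len(s)] for s in batch]

seqs = ["MKT", "MMAA", "GGG", "MK", "AAAA"]
emb = embed_in_batches(seqs, toy_embed, batch_size=2)
print(emb.shape)  # (5, 2)
```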

In [None]:
# run PhageRBPdetect
gene_embeddings_file = path+'/phage_protein_embeddings'+suffix+'.csv'
phlp.phageRBPdetect(path, pfam_path, hmmer_path, xgb_path, gene_embeddings_file, data_suffix=suffix)

In [None]:
# run Kaptive
phlp.process_bacterial_genomes(path, bacteria_path, kaptive_db_path, data_suffix=suffix)

## 3. Feature construction

This step starts from RBPbase.csv and Locibase.json in `path`. Expected outputs: (1) a .csv file with RBP embeddings and (2) a .csv file with loci embeddings. The last function returns two Python objects: the ESM-2 feature matrix (`features_esm2`) and `groups_bact`, which records the bacterium that each row of the feature matrix belongs to. If computing the ESM-2 embeddings takes too long, you might opt to do this step in the cloud or on a high-performance computer.
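The prediction cells below make two assumptions about the test-mode feature matrix. First, its rows are grouped by bacterium. Second, within each bacterium the phages appear in the order of `phages`. The sketch below builds a matrix with that layout from toy pooled embeddings. Each row is a bacterium-phage pair formed by concatenating the two vectors, and `groups_bact` holds the bacterium's row index. This is an assumed reconstruction of the layout for illustration, not the code of `construct_feature_matrices`. `build_test_features` is a hypothetical name.

```python
import numpy as np

def build_test_features(loci_matrix, rbp_matrix):
    """Pair every bacterium (row of loci_matrix) with every phage (row of rbp_matrix).

    Returns features of shape (n_bact * n_phage, d_loci + d_rbp) and groups_bact,
    where groups_bact[k] is the bacterium index of feature row k.
    """
    n_bact, n_phage = loci_matrix.shape[0], rbp_matrix.shape[0]
    loci_part = np.repeat(loci_matrix, n_phage, axis=0)  # each bacterium repeated once per phage
    rbp_part = np.tile(rbp_matrix, (n_bact, 1))          # full phage block repeated per bacterium
    features = np.hstack([loci_part, rbp_part])
    groups_bact = np.repeat(np.arange(n_bact), n_phage)
    return features, groups_bact

loci = np.array([[0.1], [0.2]])         # 2 bacteria, 1-dimensional toy embeddings
rbps = np.array([[1.0], [2.0], [3.0]])  # 3 phages, 1-dimensional toy embeddings
features, groups_bact = build_test_features(loci, rbps)
print(groups_bact)  # [0 0 0 1 1 1]
```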

In [None]:
import phagehostlearn_features as phlf

In [None]:
# ESM-2 features for RBPs
phlf.compute_esm2_embeddings_rbp(path, data_suffix=suffix)

In [None]:
# ESM-2 features for loci
phlf.compute_esm2_embeddings_loci(path, data_suffix=suffix)

In [None]:
# Construct feature matrices
rbp_embeddings_path = path+'/esm2_embeddings_rbp'+suffix+'.csv'
loci_embeddings_path = path+'/esm2_embeddings_loci'+suffix+'.csv'
features_esm2, groups_bact = phlf.construct_feature_matrices(path, suffix, loci_embeddings_path, rbp_embeddings_path, mode='test')

## 4. Predict and rank new interactions

For each bacterium, we predict an interaction score for every phage and then use these scores to rank the candidate phages for that bacterium.
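The core of the next cells can be shown on toy numbers. The flat score vector, ordered bacterium by bacterium, is reshaped into a bacteria × phages matrix using `groups_bact`. Each row is then sorted in descending order to rank the phages. The scores below are made up for illustration. In the notebook they come from `xgb.predict_proba(features_esm2)[:, 1]`, the predicted probability of the positive (interaction) class.

```python
import numpy as np

# Toy scores for 2 bacteria x 3 phages, in bacterium-major order
phages = ["phageA", "phageB", "phageC"]
groups_bact = np.array([0, 0, 0, 1, 1, 1])
scores = np.array([0.2, 0.9, 0.5, 0.7, 0.1, 0.4])

# One row per bacterium, one column per phage
score_matrix = np.zeros((2, len(phages)))
for group in np.unique(groups_bact):
    score_matrix[group, :] = scores[groups_bact == group]

# Rank phages per bacterium from highest to lowest score
ranking = {}
for group in range(score_matrix.shape[0]):
    order = np.argsort(-score_matrix[group])  # column indices by descending score
    ranking[group] = [phages[j] for j in order]
print(ranking)
```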

In [None]:
# load the needed libraries
import pickle
import numpy as np
import pandas as pd
from xgboost import XGBClassifier
%matplotlib inline

In [None]:
# Load the XGBoost model and make predictions
xgb = XGBClassifier()
xgb.load_model('phagehostlearn_esm2_xgb.json')
scores_xgb = xgb.predict_proba(features_esm2)[:,1]

In [None]:
# save prediction scores in an interaction matrix (rows = bacteria, columns = phages)
groups_bact = np.asarray(groups_bact)
loci_embeddings = pd.read_csv(loci_embeddings_path)
rbp_embeddings = pd.read_csv(rbp_embeddings_path)
bacteria = list(loci_embeddings['accession'])
phages = list(set(rbp_embeddings['phage_ID']))  # assumed to match the phage order in features_esm2

score_matrix = np.zeros((len(bacteria), len(phages)))
for group in np.unique(groups_bact):
    # group is the bacterium's row index; its scores follow the order of `phages`
    score_matrix[group, :] = scores_xgb[groups_bact == group]
results = pd.DataFrame(score_matrix, index=bacteria, columns=phages)
# keep the index so that the bacterial accessions are written to the file
results.to_csv(path+'/prediction_results'+suffix+'.csv')

In [None]:
# rank the phages per bacterium, from highest to lowest score
ranked = {}
for group in np.unique(groups_bact):
    scores_this_group = scores_xgb[groups_bact == group]
    # list of (phage_ID, score) tuples sorted by descending score
    ranked_phages = sorted(zip(phages, scores_this_group), key=lambda pair: pair[1], reverse=True)
    ranked[bacteria[group]] = ranked_phages

# save results
with open(path+'/ranked_results'+suffix+'.pickle', 'wb') as f:
    pickle.dump(ranked, f)

## 5. Read & interpret results

In [None]:
# read results
with open(path+'/ranked_results'+suffix+'.pickle', 'rb') as f:
    ranked_results = pickle.load(f)

In [None]:
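# show the top prediction scores per bacterium
top = 5
# cap at the number of scored phages (fewer than 5 when only a few phages were processed)
top = min(top, min(len(ranking) for ranking in ranked_results.values()))
scores = np.zeros((len(ranked_results), top))
for i, acc in enumerate(ranked_results.keys()):
    topscores = [round(float(score), 3) for (_, score) in ranked_results[acc]][:top]
    scores[i, :] = topscores
pd.DataFrame(scores, index=list(ranked_results.keys()), columns=[f'rank {k+1}' for k in range(top)])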
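You may want the names of the top phages rather than only their scores. These can also be read straight from a score matrix laid out like `results` (rows = bacteria, columns = phages) with pandas `Series.nlargest`. This is a sketch on a toy matrix. The accessions, phage names and scores are made up.

```python
import pandas as pd

# Toy prediction matrix with the same layout as `results`
results = pd.DataFrame(
    [[0.2, 0.9, 0.5], [0.7, 0.1, 0.4]],
    index=["bact1", "bact2"],
    columns=["phageA", "phageB", "phageC"],
)

top = 2
# For each bacterium, the names of the `top` highest-scoring phages
top_phages = {acc: row.nlargest(top).index.tolist() for acc, row in results.iterrows()}
print(top_phages)
```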