# PhageHostLearn.*klebsiella* - inference

This notebook offers complete functionality to make predictions for new bacteria, phages or both, using a trained PhageHostLearn prediction model for Klebsiella phage-host interactions.

**Overview of this notebook**
1. Initial set-up
2. Processing phage genomes and bacterial genomes into RBPs and K-locus proteins, respectively
3. Computing feature representations based on ESM-2
4. Predicting new interactions and ranking
5. Reading and interpreting the results

**Architecture of the PhageHostLearn framework**: 
- Multi-RBP setting: phages consisting of one or more RBPs (multi-instance)
- K-loci proteins (multi-instance) 
- Embeddings for both, based on the ESM-2 language model
- An XGBoost model on top of language embeddings to make predictions
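
The multi-instance idea can be sketched as follows: an entity with several proteins (a phage with multiple RBPs, or a bacterium with multiple K-locus proteins) is reduced to one fixed-length vector, and the two vectors are joined into a single feature row for the classifier. This is a minimal sketch with random stand-in vectors. The mean pooling, the helper names `pool_instances` and `pair_features`, and the toy dimension are assumptions made for illustration and may differ from what the PhageHostLearn code does internally.

```python
import numpy as np

def pool_instances(embeddings):
    """Average a (n_instances, dim) array into one (dim,) vector (assumed pooling)."""
    embeddings = np.atleast_2d(np.asarray(embeddings, dtype=float))
    return embeddings.mean(axis=0)

def pair_features(loci_vector, rbp_vector):
    """Concatenate the bacterial and phage vectors into one feature row."""
    return np.concatenate([loci_vector, rbp_vector])

rng = np.random.default_rng(0)
dim = 8  # toy dimension; real ESM-2 embeddings are much larger
rbps_of_phage = rng.normal(size=(3, dim))   # a phage with three RBPs
loci_proteins = rng.normal(size=(5, dim))   # a K-locus with five proteins

row = pair_features(pool_instances(loci_proteins), pool_instances(rbps_of_phage))
print(row.shape)  # (16,)
```

A row like this, one per bacterium-phage pair, is the kind of input a gradient-boosted classifier such as XGBoost can score.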

## 1. Initial set-up

In [1]:
# data paths
path = '/Users/eliottvalette/Documents/Clones/PhageHostLearn/data'
phages_path = path + '/phages_genomes'
bacteria_path = path + '/bacteria_genomes'
pfam_path = '/Users/eliottvalette/Documents/Clones/PhageHostLearn/code/RBPdetect_phageRBPs.hmm'
xgb_path = '/Users/eliottvalette/Documents/Clones/PhageHostLearn/code/RBPdetect_xgb_hmm.json'
kaptive_db_path = path + '/Klebsiella_k_locus_primary_reference.gbk'
suffix = 'inference'

hmmer_path = path + '/hmmer-3.4'


## 2. Data processing

The data processing of PhageHostLearn consists of four consecutive steps: (1) phage gene calling with PHANOTATE, (2) phage protein embedding with bio_embeddings, (3) phage RBP detection and (4) bacterial genome processing with Kaptive.

Expected outputs: (1) an RBPbase.csv file with detected RBPs, (2) a Locibase.json file with detected K-loci proteins.
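
The cells in this section repeatedly check whether an output file already exists, first with the suffix and then without it. That pattern could be factored into one small helper. The sketch below is hypothetical: `resolve_output` is not part of `phagehostlearn_processing`, and the file-naming scheme (`stem + suffix + extension`) is simply copied from the checks used in this notebook.

```python
import os
import tempfile

def resolve_output(data_dir, stem, suffix, ext):
    """Return the suffixed file if present, else the unsuffixed one, else None."""
    candidates = [
        os.path.join(data_dir, f"{stem}{suffix}{ext}"),
        os.path.join(data_dir, f"{stem}{ext}"),
    ]
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None

# Demo in a throwaway directory where only the unsuffixed RBPbase file exists.
with tempfile.TemporaryDirectory() as demo_dir:
    open(os.path.join(demo_dir, "RBPbase.csv"), "w").close()
    found = resolve_output(demo_dir, "RBPbase", "inference", ".csv")
    missing = resolve_output(demo_dir, "Locibase", "inference", ".json")
    found_name = os.path.basename(found)

print(found_name)  # RBPbase.csv
print(missing)     # None
```

A processing step would then be skipped whenever the helper returns a path and run only when it returns `None`.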

In [6]:
import phagehostlearn_processing as phlp
import time
import os

In [3]:
# run Phanotate
phanotate_path = '/Users/eliottvalette/Documents/Clones/PhageHostLearn/.venv/bin/phanotate.py'
phlp.phanotate_processing(path, phages_path, phanotate_path, data_suffix=suffix, num_phages=2)

Number of phage files: 105
Processing only the first  2  phages


Processing phage genomes: 100%|██████████| 2/2 [00:16<00:00,  8.19s/it]

Completed PHANOTATE
Number of phage genes: 145





In [4]:
print('All phage genomes processed')
time.sleep(1)

All phage genomes processed


In [7]:
# Check if RBPbase already exists (if so, we can skip steps 2.2 and 2.3)
rbpbase_path = os.path.join(path, f'RBPbase{suffix}.csv')
rbpbase_path_fallback = os.path.join(path, 'RBPbase.csv')

if os.path.exists(rbpbase_path) or os.path.exists(rbpbase_path_fallback):
    print('RBPbase already exists - skipping protein embeddings computation.')
else:
    # Check for existing embeddings first
    embeddings_path = os.path.join(path, f'phage_protein_embeddings{suffix}.csv')
    embeddings_path_fallback = os.path.join(path, 'phage_protein_embeddings.csv')
    
    if os.path.exists(embeddings_path) or os.path.exists(embeddings_path_fallback):
        print('Embedding file already exists. Skipping computation.')
    else:
        # run PTB embeddings (can be done faster in the cloud, see PTB_embeddings.ipynb)
        phlp.compute_protein_embeddings(path, data_suffix=suffix)

RBPbase already exists - skipping protein embeddings computation.


In [8]:
# Check if RBPbase already exists
if os.path.exists(rbpbase_path) or os.path.exists(rbpbase_path_fallback):
    print('RBPbase already exists - skipping PhageRBPdetect.')
else:
    # run PhageRBPdetect
    gene_embeddings_file = os.path.join(path, f'phage_protein_embeddings{suffix}.csv')
    gene_embeddings_file_fallback = os.path.join(path, 'phage_protein_embeddings.csv')
    if not os.path.exists(gene_embeddings_file) and os.path.exists(gene_embeddings_file_fallback):
        gene_embeddings_file = gene_embeddings_file_fallback
    phlp.phageRBPdetect(path, pfam_path, hmmer_path, xgb_path, gene_embeddings_file, data_suffix=suffix)

RBPbase already exists - skipping PhageRBPdetect.


In [9]:
# Check if Locibase already exists
locibase_path = os.path.join(path, f'Locibase{suffix}.json')
locibase_path_fallback = os.path.join(path, 'Locibase.json')

if os.path.exists(locibase_path) or os.path.exists(locibase_path_fallback):
    print('Locibase already exists - skipping Kaptive processing.')
else:
    # run Kaptive
    phlp.process_bacterial_genomes(path, bacteria_path, kaptive_db_path, data_suffix=suffix)

Locibase already exists - skipping Kaptive processing.


## 3. Feature construction

This step starts from the RBPbase.csv and Locibase.json files in the data path. Expected outputs: (1) a .csv file with RBP embeddings, (2) a .csv file with loci embeddings. The last function returns two Python objects: the ESM-2 feature matrix and groups_bact. If computing the ESM-2 embeddings takes too long, consider running this step in the cloud or on a high-performance computer.
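
To make the two returned objects concrete, the sketch below builds a toy feature matrix with one row per bacterium-phage pair and a matching group vector that holds the bacterium index of each row. The row ordering (all phages for the first bacterium, then all phages for the second, and so on) is an assumption, chosen to be consistent with how `groups_bact` is used in section 4. `build_pair_matrix` is a hypothetical helper, not a function of `phagehostlearn_features`.

```python
import numpy as np

def build_pair_matrix(loci_vectors, phage_vectors):
    """Return (features, groups): one row per bacterium-phage pair."""
    rows, groups = [], []
    for i, loci_vec in enumerate(loci_vectors):
        for phage_vec in phage_vectors:
            rows.append(np.concatenate([loci_vec, phage_vec]))
            groups.append(i)  # index of the bacterium this row belongs to
    return np.vstack(rows), np.asarray(groups)

loci_vectors = np.arange(6, dtype=float).reshape(2, 3)    # 2 bacteria, dim 3
phage_vectors = np.arange(12, dtype=float).reshape(3, 4)  # 3 phages, dim 4

features, groups = build_pair_matrix(loci_vectors, phage_vectors)
print(features.shape)  # (6, 7)
print(groups)          # [0 0 0 1 1 1]
```

With this layout, selecting `features[groups == i]` yields every candidate phage for bacterium `i`, which is what the ranking step relies on.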

In [10]:
import phagehostlearn_features as phlf

In [11]:
# ESM-2 features for RBPs
rbp_embeddings_path = os.path.join(path, f'esm2_embeddings_rbp{suffix}.csv')
rbp_embeddings_path_fallback = os.path.join(path, 'esm2_embeddings_rbp.csv')

if os.path.exists(rbp_embeddings_path):
    print(f'RBP embeddings file already exists at: {rbp_embeddings_path}')
    print('Skipping ESM-2 embeddings computation for RBPs.')
elif os.path.exists(rbp_embeddings_path_fallback):
    print(f'Using existing RBP embeddings file (without suffix): {rbp_embeddings_path_fallback}')
    print('Skipping ESM-2 embeddings computation for RBPs.')
else:
    phlf.compute_esm2_embeddings_rbp(path, data_suffix=suffix)

Using existing RBP embeddings file (without suffix): /Users/eliottvalette/Documents/Clones/PhageHostLearn/data/esm2_embeddings_rbp.csv
Skipping ESM-2 embeddings computation for RBPs.


In [12]:
# ESM-2 features for loci
loci_embeddings_path_check = os.path.join(path, f'esm2_embeddings_loci{suffix}.csv')
loci_embeddings_path_fallback = os.path.join(path, 'esm2_embeddings_loci.csv')

if os.path.exists(loci_embeddings_path_check):
    print(f'Loci embeddings file already exists at: {loci_embeddings_path_check}')
    print('Skipping ESM-2 embeddings computation for loci.')
elif os.path.exists(loci_embeddings_path_fallback):
    print(f'Using existing loci embeddings file (without suffix): {loci_embeddings_path_fallback}')
    print('Skipping ESM-2 embeddings computation for loci.')
else:
    phlf.compute_esm2_embeddings_loci(path, data_suffix=suffix)

Using existing loci embeddings file (without suffix): /Users/eliottvalette/Documents/Clones/PhageHostLearn/data/esm2_embeddings_loci.csv
Skipping ESM-2 embeddings computation for loci.


In [13]:
# Construct feature matrices
# Use the paths already defined above, with fallback if suffixed files don't exist
if not os.path.exists(rbp_embeddings_path):
    rbp_embeddings_path = rbp_embeddings_path_fallback
if not os.path.exists(loci_embeddings_path_check):
    loci_embeddings_path = loci_embeddings_path_fallback
else:
    loci_embeddings_path = loci_embeddings_path_check

features_esm2, groups_bact = phlf.construct_feature_matrices(path, suffix, loci_embeddings_path, rbp_embeddings_path, mode='test')

Dimensions match? True


## 4. Predict and rank new interactions

We make a prediction for every bacterium-phage pair and then use the prediction scores to rank the candidate phages for each bacterium.
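
The ranking step can be illustrated on toy data before running it on the real scores. The sketch below assumes a flat score vector with one entry per bacterium-phage pair and a group vector giving the bacterium index of each entry. The scores, the phage labels and the variable names are made up for illustration.

```python
import numpy as np

scores = np.array([0.10, 0.90, 0.40,   # bacterium 0 vs phages A, B, C
                   0.70, 0.20, 0.95])  # bacterium 1 vs phages A, B, C
groups = np.array([0, 0, 0, 1, 1, 1])
phage_ids = ["A", "B", "C"]

ranking = {}
for group in np.unique(groups):
    group_scores = scores[groups == group]
    order = np.argsort(-group_scores)  # indices sorted by descending score
    ranking[int(group)] = [(phage_ids[j], float(group_scores[j])) for j in order]

print(ranking[0])  # [('B', 0.9), ('C', 0.4), ('A', 0.1)]
print(ranking[1])  # [('C', 0.95), ('A', 0.7), ('B', 0.2)]
```

The first entry of each list is the phage predicted most likely to interact with that bacterium.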

In [14]:
# load the needed libraries
import pickle
import numpy as np
import pandas as pd
from xgboost import XGBClassifier
%matplotlib inline

In [15]:
# Load the XGBoost model and make predictions
xgb = XGBClassifier()
xgb.load_model('phagehostlearn_esm2_xgb.json')
scores_xgb = xgb.predict_proba(features_esm2)[:,1]



In [16]:
# save prediction scores in an interaction matrix (rows: bacteria, columns: phages)
groups_bact = np.asarray(groups_bact)
loci_embeddings = pd.read_csv(loci_embeddings_path)
rbp_embeddings = pd.read_csv(rbp_embeddings_path)
bacteria = list(loci_embeddings['accession'])
# assumes this phage order matches the one used when features_esm2 was built
phages = list(set(rbp_embeddings['phage_ID']))

score_matrix = np.zeros((len(bacteria), len(phages)))
# np.unique returns the group indices in sorted order, so row i lines up with bacteria[i]
for i, group in enumerate(np.unique(groups_bact)):
    scores_this_group = scores_xgb[groups_bact == group]
    score_matrix[i, :] = scores_this_group
results = pd.DataFrame(score_matrix, index=bacteria, columns=phages)
# index=False: the bacterial accessions (row labels) are not written to the CSV
results.to_csv(path+'/prediction_results'+suffix+'.csv', index=False)

In [17]:
# rank the phages per bacterium
ranked = {}
for group in list(set(groups_bact)):
    scores_this_group = scores_xgb[groups_bact == group]
    ranked_phages = [(x, y) for y, x in sorted(zip(scores_this_group, phages), reverse=True)]
    ranked[bacteria[group]] = ranked_phages

# save results
with open(path+'/ranked_results'+suffix+'.pickle', 'wb') as f:
    pickle.dump(ranked, f)

## 5. Read & interpret results

In [18]:
# read results
with open(path+'/ranked_results'+suffix+'.pickle', 'rb') as f:
    ranked_results = pickle.load(f)

In [19]:
# show the top prediction scores per bacterium
top = 5
scores = np.zeros((len(ranked_results), top))
for i, acc in enumerate(ranked_results.keys()):
    # each entry of ranked_results[acc] is a (phage_ID, score) tuple, best first
    topscores = [round(score, 3) for (_, score) in ranked_results[acc][:top]]
    scores[i, :] = topscores
pd.DataFrame(scores, index=list(ranked_results.keys()))

| | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| KP_HGUA02_071 | 1.000 | 1.000 | 0.999 | 0.998 | 0.975 |
| 205KP-HG | 0.065 | 0.065 | 0.021 | 0.004 | 0.003 |
| 52KP-HG | 0.001 | 0.000 | 0.000 | 0.000 | 0.000 |
| Kpcas042 | 1.000 | 1.000 | 1.000 | 0.985 | 0.961 |
| HGUA4_08 | 0.999 | 0.095 | 0.024 | 0.007 | 0.003 |
| ... | ... | ... | ... | ... | ... |
| HGV2C_28 | 0.999 | 0.234 | 0.088 | 0.075 | 0.022 |
| HGV2C_36_contigs | 1.000 | 0.990 | 0.809 | 0.015 | 0.007 |
| NTUH | 1.000 | 0.999 | 0.998 | 0.947 | 0.087 |
| CU451 | 1.000 | 1.000 | 1.000 | 0.501 | 0.313 |
