## 1. Initial set-up

In [None]:
import os
# The project root is taken to be the parent of the current working directory;
# input data and results live in its 'data' and 'results' sub-folders.
project_root = os.path.split(os.getcwd())[0]
general_path, results_path = (
    os.path.join(project_root, folder) for folder in ('data', 'results')
)
# Suffix inserted into the names of the generated files (empty string = none).
data_suffix = ''

In [None]:
# Utilities originally defined in phagehostlearn_processing.py
import json
import os
import subprocess
import time
from os import listdir

import numpy as np
import pandas as pd
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SearchIO import HmmerIO
from tqdm.notebook import tqdm
from xgboost import XGBClassifier
from bio_embeddings.embed import ProtTransBertBFDEmbedder

## 2. Data processing

The data processing of PhageHostLearn consists of five consecutive steps: (1) phage gene calling with PHANOTATE, (2) phage protein embedding with bio_embeddings, (3) phage RBP detection, (4) bacterial genome processing with Kaptive and (5) processing the interaction matrix.

Expected outputs: (1) an RBPbase.csv file with detected RBPs, (2) a Locibase.json file with detected K-loci proteins, (3) a .csv file of processed interaction data.

In [None]:
def hmmpress_python(hmm_path, pfam_file):
    """Run ``hmmpress`` on a profile database so that it can be scanned.

    The command is executed through the shell after changing into ``hmm_path``.

    Returns:
        The ``(stdout, stderr)`` pair from ``Popen.communicate()``. Because
        stderr is redirected into stdout, the second element is always ``None``.
    """
    shell_command = '; '.join(['cd ' + hmm_path, 'hmmpress ' + pfam_file])
    process = subprocess.Popen(
        shell_command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    return process.communicate()

def single_hmmscan_python(hmm_path, pfam_file, fasta_file):
    """Run hmmscan for a given FASTA file of one (or multiple) sequences.

    ``hmmscan`` is invoked through the shell in the current working directory
    and its text report is written to a temporary ``hmmscan_out.txt`` file,
    which is parsed and then removed.

    Args:
        hmm_path: Kept for backward compatibility; not used. The previous
            implementation ran ``cd hmm_path`` in a separate shell process,
            which cannot change the directory of the later ``hmmscan`` call.
        pfam_file: Path to the (pressed) HMM profile database.
        fasta_file: Path to the FASTA file with the query sequence(s).

    Returns:
        A list with one parsed query result per query in the hmmscan report.
    """
    output_file = 'hmmscan_out.txt'
    scan_str = 'hmmscan ' + pfam_file + ' ' + fasta_file + ' > ' + output_file
    scan_process = subprocess.Popen(scan_str, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    scan_process.communicate()

    try:
        with open(output_file) as results_handle:
            # The parser reads lazily from the handle, so it has to be fully
            # consumed while the file is still open.
            scan_res = list(HmmerIO.Hmmer3TextParser(results_handle))
    finally:
        if os.path.exists(output_file):
            os.remove(output_file)
    return scan_res


def hmmscan_python(hmm_path, pfam_file, sequences_file, threshold=18):
    """Scan sequences for domains using hmmscan.

    Every record of ``sequences_file`` is written to a temporary
    ``single_sequence.fasta`` file and scanned on its own.

    Args:
        hmm_path: Forwarded to ``single_hmmscan_python``.
        pfam_file: Path to the (pressed) HMM profile database.
        sequences_file: FASTA file with the sequences to scan.
        threshold: Minimum hit bitscore for a domain to be reported.

    Returns:
        Four parallel lists ``(domains, scores, biases, ranges)``. A domain id
        is reported once across the whole file (its first qualifying hit);
        ``ranges`` holds the query range of the hit's first HSP.
    """
    temp_fasta_path = 'single_sequence.fasta'
    domains = []
    scores = []
    biases = []
    ranges = []
    try:
        for sequence in SeqIO.parse(sequences_file, 'fasta'):
            with open(temp_fasta_path, 'w') as temp_fasta:
                temp_fasta.write('>' + sequence.id + '\n' + str(sequence.seq) + '\n')

            scan_res = single_hmmscan_python(hmm_path, pfam_file, temp_fasta_path)
            for line in scan_res:
                try:
                    for hit in line.hits:
                        hsp = hit._items[0]
                        aln_start = hsp.query_range[0]
                        aln_stop = hsp.query_range[1]
                        if (hit.bitscore >= threshold) and (hit.id not in domains):
                            domains.append(hit.id)
                            scores.append(hit.bitscore)
                            biases.append(hit.bias)
                            ranges.append((aln_start, aln_stop))
                except IndexError:
                    # A hit without HSPs: the remaining hits of this query
                    # result are skipped, as before.
                    pass
    finally:
        # The temporary file only exists if at least one record was read, and
        # it must not be left behind when scanning fails.
        if os.path.exists(temp_fasta_path):
            os.remove(temp_fasta_path)
    return domains, scores, biases, ranges


def gene_domain_scan(hmmpath, pfam_file, gene_hits, threshold=18):
    """Translate DNA gene hits and scan the resulting proteins with hmmscan.

    Each translated sequence has its last character removed and is written to
    a temporary ``protein_hits.fasta`` file, which is deleted after scanning.

    Returns:
        The ``(domains, scores, biases, ranges)`` tuple from ``hmmscan_python``.
    """
    fasta_name = 'protein_hits.fasta'
    with open(fasta_name, 'w') as handle:
        for index, dna in enumerate(gene_hits):
            protein = str(Seq(dna).translate())[:-1]
            handle.write(f'>{index}_proteindomain_hit\n{protein}\n')
    scan_results = hmmscan_python(hmmpath, pfam_file, fasta_name, threshold)
    os.remove(fasta_name)
    domains, scores, biases, ranges = scan_results
    return domains, scores, biases, ranges


def kaptive_python(database_path, file_path, output_path):
    """Call ``kaptive.py`` through the shell for one assembly.

    The assembly ``file_path`` is passed with ``-a``, the reference database
    with ``-k`` and ``output_path + '/'`` with ``-o``; ``--no_table`` is set.
    The call blocks until the process has finished; its output is discarded.
    """
    command = ' '.join([
        'python kaptive.py -a', file_path,
        '-k', database_path,
        '-o', output_path + '/',
        '--no_table',
    ])
    process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    process.communicate()


def xlsx_database_to_csv(xlsx_file, save_path, index_col=0, header=0, export=True):
    """Read an XLSX interaction matrix and export it as CSV or return it.

    Args:
        xlsx_file: Path of the Excel file to read.
        save_path: Output path without extension; ``'.csv'`` is appended.
        index_col: Forwarded to ``pandas.read_excel``.
        header: Forwarded to ``pandas.read_excel``.
        export: When truthy, write the CSV and return ``None``; otherwise
            return the DataFrame without writing anything.
    """
    interactions_matrix = pd.read_excel(xlsx_file, index_col=index_col, header=header)
    if not export:
        return interactions_matrix
    interactions_matrix.to_csv(save_path + '.csv')
    return None


#### 2.1 PHANOTATE

In [None]:
def phanotate_processing(general_path, phage_genomes_path, phanotate_path, data_suffix='', add=False, test=False, num_phages=None):
    """Run PHANOTATE on each phage genome and build the gene database.

    For every genome file the PHANOTATE output is parsed as a tab-separated
    table (its first two lines are skipped and commas are removed) with the
    columns ``#START``, ``STOP`` and ``FRAME``. Each predicted ORF is cut out
    of the genome sequence; ORFs whose frame is not ``'+'`` are
    reverse-complemented. The result is written to
    ``<general_path>/phage_genes<data_suffix>.csv`` with the columns
    ``phage_ID``, ``gene_ID`` and ``gene_sequence``.

    Args:
        general_path: Directory for the gene database and the temporary TSV.
        phage_genomes_path: Directory holding one FASTA file per phage genome.
        phanotate_path: Command (or path) used to invoke PHANOTATE.
        data_suffix: Suffix inserted into the file names read and written.
        add: If true, only process phages whose id is absent from the existing
            RBPbase file and append their genes to the existing gene database.
        test: If true, keep the temporary ``phage_results.tsv`` file.
        num_phages: If given, only process the first ``num_phages`` files.

    Raises:
        RuntimeError: PHANOTATE exited with a non-zero status.
        ValueError: PHANOTATE returned no ORF predictions or an unparsable table.
    """
    # os.listdir returns entries in arbitrary order; sorting makes the
    # ``num_phages`` subset and the gene order reproducible between runs.
    phage_files = sorted(entry for entry in listdir(phage_genomes_path) if entry != '.DS_Store')
    print('Number of phage files:', len(phage_files))
    if add:
        rbp_base = pd.read_csv(general_path + '/RBPbase' + data_suffix + '.csv')
        # NOTE(review): phages are skipped based on RBPbase only, so a phage
        # without detected RBPs appears to be re-processed (and its genes
        # appended again) on every add=True run - confirm this is intended.
        phage_ids = set(rbp_base['phage_ID'])
        phage_files = [x for x in phage_files if x.split('.fasta')[0] not in phage_ids]
        print('Processing ', len(phage_files), ' more phages (add=True)')
    if num_phages is not None:
        print('Processing only the first ', num_phages, ' phages')
        phage_files = phage_files[:num_phages]

    temp_tab_path = os.path.join(general_path, 'phage_results.tsv')
    bar = tqdm(total=len(phage_files), position=0, leave=True, desc='Processing phage genomes')
    name_list = []
    gene_list = []
    gene_ids = []

    for file in phage_files:
        file_dir = phage_genomes_path + '/' + file
        raw_str = phanotate_path + ' ' + file_dir
        process = subprocess.Popen(raw_str, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        stdout, _ = process.communicate()
        stdout_text = stdout.decode('utf-8', errors='ignore')
        if process.returncode != 0:
            raise RuntimeError(
                f"PHANOTATE command failed for {file_dir} with exit code {process.returncode}. "
                f"Command output:\n{stdout_text}"
            )
        # The first two output lines are dropped before the table is parsed.
        std_splits = stdout.split(sep=b'\n')[2:]
        if not any(split.strip() for split in std_splits):
            raise ValueError(
                f"PHANOTATE did not return ORF predictions for {file_dir}. "
                f"Command output:\n{stdout_text}"
            )

        with open(temp_tab_path, 'wb') as temp_tab:
            for split in std_splits:
                temp_tab.write(split.replace(b',', b'') + b'\n')
        try:
            results_orfs = pd.read_csv(temp_tab_path, sep='\t', lineterminator='\n', index_col=False)
        except pd.errors.EmptyDataError as exc:
            with open(temp_tab_path, 'r', encoding='utf-8', errors='ignore') as temp_in:
                temp_preview = temp_in.read()
            raise ValueError(
                f"PHANOTATE output for {file_dir} produced an empty or invalid TSV file.\n"
                f"Command output:\n{stdout_text}\n"
                f"Temporary TSV content:\n{temp_preview}"
            ) from exc

        name = file.split('.fasta')[0]
        sequence = str(SeqIO.read(file_dir, 'fasta').seq)
        orf_rows = zip(results_orfs['#START'], results_orfs['STOP'], results_orfs['FRAME'])
        for count, (start, stop, strand) in enumerate(orf_rows, start=1):
            if strand == '+':
                gene = sequence[start - 1:stop]
            else:
                # For other frames the coordinates are used in swapped order
                # and the slice is reverse-complemented.
                sequence_part = sequence[stop - 1:start]
                gene = str(Seq(sequence_part).reverse_complement())
            name_list.append(name)
            gene_list.append(gene)
            gene_ids.append(name + '_gp' + str(count))
        bar.update(1)
    bar.close()

    if not test and os.path.exists(temp_tab_path):
        os.remove(temp_tab_path)

    genebase = pd.DataFrame({'phage_ID': name_list, 'gene_ID': gene_ids, 'gene_sequence': gene_list})
    if add:
        old_genebase = pd.read_csv(general_path + '/phage_genes' + data_suffix + '.csv')
        genebase = pd.concat([old_genebase, genebase], axis=0)
    genebase.to_csv(general_path + '/phage_genes' + data_suffix + '.csv', index=False)
    return

In [None]:
# Gene calling on the phage genomes stored in <general_path>/phages_genomes.
phage_genomes_path = general_path+'/phages_genomes'
# NOTE(review): hard-coded, machine-specific absolute path - adapt it to the local PHANOTATE installation.
phanotate_path = '/Users/eliottvalette/Documents/Clones/PhageHostLearn/.venv/bin/phanotate.py'
# num_phages=5 restricts the run to the first five genome files.
phanotate_processing(general_path, phage_genomes_path, phanotate_path, data_suffix=data_suffix, num_phages=5)

#### 2.2 Protein embeddings

In [None]:
def compute_protein_embeddings(general_path, data_suffix='', add=False, num_genes=None):
    """Compute protein embeddings using ProtTransBertBFD.

    Reads ``<general_path>/phage_genes<data_suffix>.csv``, translates every
    gene sequence (the last character of each translation is dropped -
    presumably the stop-codon symbol, TODO confirm), embeds the proteins and
    writes one row per protein (``ID`` followed by the reduced embedding) to
    ``<general_path>/phage_protein_embeddings<data_suffix>.csv``.

    Args:
        general_path: Directory holding the gene database and the output file.
        data_suffix: Suffix inserted into the file names read and written.
        add: If true, only embed genes whose ``gene_ID`` is absent from the
            existing embeddings file and append them to it. When there is
            nothing new, the existing file is left untouched.
        num_genes: If given, only consider the first ``num_genes`` genes.
    """
    embeddings_file = general_path + '/phage_protein_embeddings' + data_suffix + '.csv'
    genebase = pd.read_csv(general_path + '/phage_genes' + data_suffix + '.csv')
    if num_genes is not None:
        print('Processing only the first ', num_genes, ' phage genes')
        genebase = genebase.head(num_genes)
    print('Number of phage genes:', len(genebase))

    if add:
        print('Adding new protein embeddings')
        old_embeddings_df = pd.read_csv(embeddings_file)
        known_ids = set(old_embeddings_df['ID'])
        new_genes = [
            (gene_id, gene_sequence)
            for gene_id, gene_sequence in zip(genebase['gene_ID'], genebase['gene_sequence'])
            if gene_id not in known_ids
        ]
        names = [gene_id for gene_id, _ in new_genes]
        sequences = [str(Seq(gene_sequence).translate())[:-1] for _, gene_sequence in new_genes]
        if not sequences:
            # Stacking an empty result onto the old table would fail on
            # mismatching shapes; there is simply nothing to add.
            print('No new protein sequences to embed')
            return
    else:
        print('Computing protein embeddings for all phage genes')
        names = list(genebase['gene_ID'])
        sequences = [str(Seq(sequence).translate())[:-1] for sequence in genebase['gene_sequence']]
    print('Number of protein sequences to embed:', len(sequences))

    # The embedder is only created once it is known that there is work to do.
    time_start = time.time()
    embedder = ProtTransBertBFDEmbedder()
    time_end = time.time()
    print('Time taken to initialize embedder:', time_end - time_start)
    print('Embedder initialized')

    protein_embeddings = []
    progress_bar = tqdm(sequences, desc='Computing protein embeddings', unit='protein')
    for protein_sequence in progress_bar:
        reduced_embedding = embedder.reduce_per_protein(embedder.embed(protein_sequence))
        protein_embeddings.append(reduced_embedding)
    embeddings_df = pd.concat([pd.DataFrame({'ID': names}), pd.DataFrame(protein_embeddings)], axis=1)
    if add:
        # Stack positionally and reuse the old column labels.
        embeddings_df = pd.DataFrame(np.vstack([old_embeddings_df, embeddings_df]), columns=old_embeddings_df.columns)
    embeddings_df.to_csv(embeddings_file, index=False)
    return

In [None]:
# Embed the phage proteins; num_genes=5 limits the run to the first five rows of the gene table.
# NOTE(review): phageRBPdetect concatenates these embeddings row-wise with per-gene HMM scores
# for every gene in the gene table, so the row counts need to match - confirm before running it.
compute_protein_embeddings(general_path, data_suffix=data_suffix, num_genes=5)

#### 2.3 PhageRBPdetect

In [None]:

def phageRBPdetect(general_path, pfam_path, hmmer_path, xgb_path, gene_embeddings_path, data_suffix=''):
    """Detect receptor-binding proteins using PhageRBPdetect.

    Every gene in ``<general_path>/phage_genes<data_suffix>.csv`` is scanned
    against the HMM database; the resulting per-domain bitscores are appended
    to the protein embeddings and scored with a saved XGBoost classifier.
    Genes with a score above 0.5 whose protein is between 200 and 1500
    residues long (inclusive) are written to
    ``<general_path>/RBPbase<data_suffix>.csv``.

    Args:
        general_path: Directory holding the gene database and the output file.
        pfam_path: Path of the HMM profile database.
        hmmer_path: Directory passed on to the HMMER wrapper functions.
        xgb_path: Path of the saved XGBoost model.
        gene_embeddings_path: CSV with an id column followed by the embedding
            columns, one row per gene in the same order as the gene database.
        data_suffix: Suffix inserted into the file names read and written.

    Raises:
        ValueError: The embeddings file and the gene database differ in their
            number of rows.
    """
    genebase = pd.read_csv(general_path + '/phage_genes' + data_suffix + '.csv')
    new_blocks = ['Phage_T7_tail', 'Tail_spike_N', 'Prophage_tail', 'BppU_N', 'Mtd_N', 'Head_binding', 'DUF3751',
                  'End_N_terminal', 'phage_tail_N', 'Prophage_tailD1', 'DUF2163', 'Phage_fiber_2', 'unknown_N0',
                  'unknown_N1', 'unknown_N2', 'unknown_N3', 'unknown_N4', 'unknown_N6', 'unknown_N10', 'unknown_N11',
                  'unknown_N12', 'unknown_N13', 'unknown_N17', 'unknown_N19', 'unknown_N23', 'unknown_N24',
                  'unknown_N26', 'unknown_N29', 'unknown_N36', 'unknown_N45', 'unknown_N48', 'unknown_N49',
                  'unknown_N53', 'unknown_N57', 'unknown_N60', 'unknown_N61', 'unknown_N65', 'unknown_N73',
                  'unknown_N82', 'unknown_N83', 'unknown_N101', 'unknown_N114', 'unknown_N119', 'unknown_N122',
                  'unknown_N163', 'unknown_N174', 'unknown_N192', 'unknown_N200', 'unknown_N206', 'unknown_N208',
                  'Lipase_GDSL_2', 'Pectate_lyase_3', 'gp37_C', 'Beta_helix', 'Gp58', 'End_beta_propel',
                  'End_tail_spike', 'End_beta_barrel', 'PhageP22-tail', 'Phage_spike_2', 'gp12-short_mid', 'Collar',
                  'unknown_C2', 'unknown_C3', 'unknown_C8', 'unknown_C15', 'unknown_C35', 'unknown_C54', 'unknown_C76',
                  'unknown_C100', 'unknown_C105', 'unknown_C112', 'unknown_C123', 'unknown_C179', 'unknown_C201',
                  'unknown_C203', 'unknown_C228', 'unknown_C234', 'unknown_C242', 'unknown_C258', 'unknown_C262',
                  'unknown_C267', 'unknown_C268', 'unknown_C274', 'unknown_C286', 'unknown_C292', 'unknown_C294',
                  'Peptidase_S74', 'Phage_fiber_C', 'S_tail_recep_bd', 'CBM_4_9', 'DUF1983', 'DUF3672']

    # Fail fast: the features are built by concatenating embeddings and HMM
    # scores row by row, so a mismatch would otherwise only surface after the
    # (slow) scan of every gene.
    embeddings_df = pd.read_csv(gene_embeddings_path)
    if len(embeddings_df) != len(genebase):
        raise ValueError(
            f"Number of embeddings ({len(embeddings_df)}) in {gene_embeddings_path} does not match "
            f"the number of phage genes ({len(genebase)})."
        )
    # The first column holds the ids; the remaining columns are the features.
    embeddings = np.asarray(embeddings_df.iloc[:, 1:])

    output, _ = hmmpress_python(hmmer_path, pfam_path)
    print(output)

    phage_genes = genebase['gene_sequence']
    hmm_scores = {item: [0] * len(phage_genes) for item in new_blocks}
    bar = tqdm(total=len(phage_genes), position=0, leave=True, desc='Scanning phage genes')
    for i, sequence in enumerate(phage_genes):
        hits, scores, _, _ = gene_domain_scan(hmmer_path, pfam_path, [sequence])
        for dom, score in zip(hits, scores):
            hmm_scores[dom][i] = score
        bar.update(1)
    bar.close()
    hmm_scores_array = np.asarray(pd.DataFrame(hmm_scores))

    features = np.concatenate((embeddings, hmm_scores_array), axis=1)

    xgb_saved = XGBClassifier()
    xgb_saved.load_model(xgb_path)

    score_xgb = xgb_saved.predict_proba(features)[:, 1]
    preds_xgb = (score_xgb > 0.5) * 1

    rbp_base = {'phage_ID': [], 'protein_ID': [], 'protein_sequence': [], 'dna_sequence': [], 'xgb_score': []}
    for i, dna_sequence in enumerate(genebase['gene_sequence']):
        if preds_xgb[i] != 1:
            continue
        protein_sequence = str(Seq(dna_sequence).translate())[:-1]
        # Predicted RBPs outside the 200-1500 residue window are discarded.
        if len(protein_sequence) < 200 or len(protein_sequence) > 1500:
            continue
        rbp_base['phage_ID'].append(genebase['phage_ID'][i])
        rbp_base['protein_ID'].append(genebase['gene_ID'][i])
        rbp_base['protein_sequence'].append(protein_sequence)
        rbp_base['dna_sequence'].append(dna_sequence)
        rbp_base['xgb_score'].append(score_xgb[i])
    rbp_base_df = pd.DataFrame(rbp_base)
    rbp_base_df.to_csv(general_path + '/RBPbase' + data_suffix + '.csv', index=False)
    return


In [None]:
# Inputs for PhageRBPdetect: HMM database, HMMER directory, saved XGBoost model and gene embeddings.
pfam_path = general_path+'/RBPdetect_phageRBPs.hmm'
# NOTE(review): hard-coded, machine-specific absolute path - adapt it to the local HMMER installation.
hmmer_path = '/Users/Dimi/hmmer-3.3.1'
xgb_path = general_path+'/RBPdetect_xgb_hmm.json'
gene_embeddings_path = general_path+'/phage_protein_embeddings'+data_suffix+'.csv'
phageRBPdetect(general_path, pfam_path, hmmer_path, xgb_path, gene_embeddings_path, data_suffix=data_suffix)

#### 2.4 Kaptive

In [None]:
def process_bacterial_genomes(general_path, bact_genomes_path, database_path, data_suffix='', add=False):
    """Process bacterial genomes with Kaptive to extract K-locus proteins.

    Kaptive is run on every file in ``bact_genomes_path``. From its JSON
    report the best-match type and the protein sequence of every locus gene
    are collected (the tblastn result when present, otherwise the reference
    sequence; the last character of each protein is dropped). The per-genome
    Kaptive output and the ``.ndb/.not/.ntf/.nto`` files next to the genome
    are deleted afterwards.

    Files written to ``general_path``:
        ``kaptive_results_all_loci.fasta``: one concatenated locus sequence
        per genome; ``serotypes<data_suffix>.csv``: the best-match types;
        ``Locibase<data_suffix>.json``: accession -> list of proteins.

    Args:
        general_path: Directory for the Kaptive output and the result files.
        bact_genomes_path: Directory with the bacterial genome FASTA files.
        database_path: Kaptive reference database passed with ``-k``.
        data_suffix: Suffix inserted into the file names read and written.
        add: If true, skip genomes already present in the existing Locibase
            and merge the new results into the existing files.
    """
    fastas = listdir(bact_genomes_path)
    if '.DS_Store' in fastas:
        fastas.remove('.DS_Store')
    if add:
        with open(general_path + '/Locibase' + data_suffix + '.json') as dict_file:
            old_locibase = json.load(dict_file)
        loci_accessions = list(old_locibase.keys())
        fastas = [x for x in fastas if x.split('.fasta')[0] not in loci_accessions]
        print('Processing ', len(fastas), ' more bacteria (add=True)')
    accessions = [file.split('.fasta')[0] for file in fastas]
    serotypes = []
    loci_results = {}
    pbar = tqdm(total=len(fastas), desc='Processing bacterial genomes')
    with open(general_path + '/kaptive_results_all_loci.fasta', 'w') as big_fasta:
        for i, file in enumerate(fastas):
            file_path = bact_genomes_path + '/' + file
            kaptive_python(database_path, file_path, general_path)

            # Close the report explicitly instead of leaking one open handle
            # per genome; the file is deleted again further down.
            with open(general_path + '/kaptive_results.json') as results_file:
                results = json.load(results_file)
            serotypes.append(results[0]['Best match']['Type'])
            for gene in results[0]['Locus genes']:
                try:
                    protein = gene['tblastn result']['Protein sequence']
                    protein = protein.replace('-', '').replace('*', '')
                except KeyError:
                    # No tblastn result for this gene: use the reference protein.
                    protein = gene['Reference']['Protein sequence']
                loci_results.setdefault(accessions[i], []).append(protein[:-1])

            loci_sequence = ''
            for record in SeqIO.parse(general_path + '/kaptive_results_' + file, 'fasta'):
                loci_sequence += str(record.seq)
            big_fasta.write('>' + accessions[i] + '\n' + loci_sequence + '\n')

            # Remove the per-genome intermediate files.
            for extension in ['.ndb', '.not', '.ntf', '.nto']:
                os.remove(file_path + extension)
            os.remove(general_path + '/kaptive_results.json')
            os.remove(general_path + '/kaptive_results_' + file)
            pbar.update(1)
    pbar.close()

    sero_df = pd.DataFrame(serotypes, columns=['sero'])
    if add:
        loci_results = {**old_locibase, **loci_results}
        old_seros = pd.read_csv(general_path + '/serotypes' + data_suffix + '.csv')
        sero_df = pd.concat([old_seros, sero_df], axis=0)
    sero_df.to_csv(general_path + '/serotypes' + data_suffix + '.csv', index=False)
    with open(general_path + '/Locibase' + data_suffix + '.json', 'w') as dict_file:
        json.dump(loci_results, dict_file)
    return

In [None]:
# Run Kaptive on the bacterial genomes using the K-locus reference database stored in general_path.
bact_genomes_path = general_path+'/klebsiella_genomes/fasta_files'
kaptive_database_path = general_path+'/Klebsiella_k_locus_primary_reference.gbk'
process_bacterial_genomes(general_path, bact_genomes_path, kaptive_database_path, data_suffix=data_suffix)

#### 2.5 Process the interaction matrix

In [None]:
def process_interactions(general_path, interactions_xlsx_path, data_suffix=''):
    """Export the XLSX interaction matrix as a CSV file.

    The output is written to
    ``<general_path>/phage_host_interactions<data_suffix>.csv``.
    """
    csv_stem = ''.join([general_path, '/phage_host_interactions', data_suffix])
    xlsx_database_to_csv(interactions_xlsx_path, csv_stem)

In [None]:
# Convert the Excel interaction matrix to <general_path>/phage_host_interactions<data_suffix>.csv.
interactions_xlsx_path = general_path+'/klebsiella_phage_host_interactions.xlsx'
process_interactions(general_path, interactions_xlsx_path, data_suffix=data_suffix)

If you want to combine interaction data from separate sources, you can use the code blocks below.

In [None]:
# NOTE(review): this cell repeats the PHANOTATE cell of section 2.1 verbatim and does not
# touch the interaction data - it looks like a stray copy; confirm whether it belongs here.
phage_genomes_path = general_path+'/phages_genomes'
# NOTE(review): hard-coded, machine-specific absolute path.
phanotate_path = '/Users/eliottvalette/Documents/Clones/PhageHostLearn/.venv/bin/phanotate.py'
phanotate_processing(general_path, phage_genomes_path, phanotate_path, data_suffix=data_suffix, num_phages=5)

In [None]:
# Add a second interaction file to the existing interaction CSV.
# NOTE(review): `add_to_database` is neither defined nor imported in the visible part of
# this notebook; presumably it lives in phagehostlearn_processing.py - confirm and
# import/define it before running this cell.
output = general_path+'/phage_host_interactions'+data_suffix
new_file = general_path+'/klebsiella_interactions_part2.xlsx' # part 2
add_to_database(output+'.csv', new_file, output)

## 3. Feature construction

Starts from the RBPbase.csv and the Locibase.json files that should be stored in the general_path. If you wish to reproduce our analyses, you can download these files from our [Zenodo repository](https://doi.org/10.5281/zenodo.8095914).

Expected outputs: (1) a .csv file with RBP embeddings, (2) a .csv file with loci embeddings. The last function outputs the following Python objects: ESM-2 feature matrix, labels, groups_loci and groups_phage (for evaluation). If the ESM-2 embeddings take too long, you might opt to do this step in the cloud or on a high-performance computer.

If you're retraining a model with the same data but new validated interactions, you can simply run the `construct_feature_matrices` function to construct updated feature matrices and labels and train models anew.

In [None]:
import phagehostlearn_features as phlf

In [None]:
# Compute the ESM-2 embeddings of the RBPs (implemented in phagehostlearn_features).
phlf.compute_esm2_embeddings_rbp(general_path, data_suffix=data_suffix)

In [None]:
# Compute the ESM-2 embeddings of the loci (implemented in phagehostlearn_features).
phlf.compute_esm2_embeddings_loci(general_path, data_suffix=data_suffix)

In [None]:
# Locations of the RBP and loci embedding files used to build the feature matrices.
rbp_embeddings_path = f'{general_path}/esm2_embeddings_rbp{data_suffix}.csv'
loci_embeddings_path = f'{general_path}/esm2_embeddings_loci{data_suffix}.csv'

In [None]:
# Build the ESM-2 feature matrix together with the labels and the loci/phage group
# assignments that are used for evaluation further down.
features_esm2, labels, groups_loci, groups_phage = phlf.construct_feature_matrices(general_path, 
                                                                            data_suffix, loci_embeddings_path, 
                                                                            rbp_embeddings_path)

## 4. Training and evaluating models

In [None]:
import random
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import phagehostlearn_utils as phlu
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from xgboost import XGBClassifier
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.metrics import roc_auc_score, auc, precision_recall_curve, roc_curve
%matplotlib inline

#### 4.1 Training both models and saving them for later use

In [None]:
# Number of parallel jobs handed to XGBoost (n_jobs).
cpus=6
# Convert the labels to a NumPy array so they can be indexed with index arrays later on.
labels = np.asarray(labels)

In [None]:
# Train the ESM-2 + XGBoost model on all data and save it for later use.
# The positive class is re-weighted by the inverse of the positive/negative ratio.
n_positive = sum(1 for label in labels if label == 1)
n_negative = sum(1 for label in labels if label == 0)
imbalance = n_positive / n_negative
xgb_params = dict(
    scale_pos_weight=1/imbalance,
    learning_rate=0.3,
    n_estimators=250,
    max_depth=7,
    n_jobs=cpus,
    eval_metric='logloss',
    use_label_encoder=False,
)
xgb = XGBClassifier(**xgb_params)
xgb.fit(features_esm2, labels)
xgb.save_model('phagehostlearn_vbeta.json')

#### 4.2 LOGOCV with the combined model

In [None]:
# Optional: merge loci into shared groups when their pairwise score reaches a threshold.
# The score matrix is read from a tab-delimited text file.
# NOTE(review): this cell overwrites `groups_loci`, so running it a second time operates on
# the regrouped values rather than on the original ones.
matrix = np.loadtxt(general_path+'/all_loci_score_matrix.txt', delimiter='\t')
threshold = 0.995
# String form of the threshold, used in the names of the result files.
threshold_str='995'
group_i = 0
# One slot per sample; NaN marks "no new group assigned yet".
new_groups = [np.nan] * len(groups_loci)
for i in range(matrix.shape[0]):
    # Column indices whose score in row i reaches the threshold.
    cluster = np.where(matrix[i,:] >= threshold)[0]
    # Positions in groups_loci whose (old) group belongs to that cluster.
    oldgroups_i = [k for k, x in enumerate(groups_loci) if x in cluster]
    # Only open a new group if the first sample of old group i is still unassigned.
    # NOTE(review): `.index(i)` needs groups_loci to be a list that contains every row index i;
    # otherwise it raises - confirm against construct_feature_matrices.
    if np.isnan(new_groups[groups_loci.index(i)]):
        for ogi in oldgroups_i:
            new_groups[ogi] = group_i
        group_i += 1
groups_loci = new_groups
print('Number of unique groups: ', len(set(groups_loci)))

In [None]:
# Leave-one-group-out cross-validation: for every loci group an XGBoost model is trained on
# all other groups and used to score the held-out group.
logo = LeaveOneGroupOut()
cpus = 6
scores_lan = []
label_list = []
labels = np.asarray(labels)
pbar = tqdm(total=len(set(groups_loci)))
for train_index, test_index in logo.split(features_esm2, labels, groups_loci):
    # labels of the training and the held-out fold
    y_train = labels[train_index]
    y_test = labels[test_index]

    # class imbalance of the training fold, used to re-weight the positives
    n_positive = sum(1 for label in y_train if label == 1)
    n_negative = sum(1 for label in y_train if label == 0)
    imbalance = n_positive / n_negative

    # ESM-2 embeddings + XGBoost
    xgb = XGBClassifier(scale_pos_weight=1/imbalance, learning_rate=0.3, n_estimators=250, max_depth=7,
                        n_jobs=cpus, eval_metric='logloss', use_label_encoder=False)
    xgb.fit(features_esm2[train_index], y_train)
    score_xgb = xgb.predict_proba(features_esm2[test_index])[:, 1]

    # keep the scores and labels of the held-out fold for the evaluation
    scores_lan.append(score_xgb)
    label_list.append(y_test)

    pbar.update(1)
pbar.close()

In [None]:
# Store the per-group labels and scores as a pickle; the file name carries the grouping threshold.
logo_results = dict(labels=label_list, scores_language=scores_lan)
results_file = results_path + '/v3.4/combined_logocv_results_v34_' + threshold_str + '.pickle'
with open(results_file, 'wb') as f:
    pickle.dump(logo_results, f)

## 5. Results interpretation

In [None]:
# read results
# NOTE: pickle.load can execute arbitrary code; only load result files written by this notebook.
with open(results_path+'/v3.4/combined_logocv_results_v34_'+threshold_str+'.pickle', 'rb') as f:
    logo_results = pickle.load(f)
scores_lan = logo_results['scores_language']
label_list = logo_results['labels']

# compute performance: per held-out group, the labels ranked by descending score
rqueries_lan = []
# Iterate over the stored folds themselves so the loop cannot run out of sync with
# `groups_loci` when results from an earlier run are loaded.
for score_lan, y_test in zip(scores_lan, label_list):
    # Groups that contain a single class are skipped. This used to happen implicitly
    # through an (otherwise unused) roc_auc_score call wrapped in a bare `except`,
    # which also hid every other error.
    if np.unique(y_test).size < 2:
        continue
    ranked_lan = [x for _, x in sorted(zip(score_lan, y_test), reverse=True)]
    rqueries_lan.append(ranked_lan)

#### ROC AUC curve

In [None]:
# ROC curve over the pooled predictions of all held-out groups.
labels = np.concatenate(label_list).ravel()
scoreslr = np.concatenate(scores_lan).ravel()

fig, ax = plt.subplots(figsize=(10,8))
fpr, tpr, thrs = roc_curve(labels, scoreslr)
rauclr = round(auc(fpr, tpr), 3)
ax.plot(fpr, tpr, c='#124559', linewidth=2.5, label=f'ESM-2 + XGBoost (AUC= {rauclr!s})')
ax.set_xlabel('False positive rate', size=24)
ax.set_ylabel('True positive rate', size=24)
ax.legend(loc=4, prop={'size': 20})
ax.grid(True, linestyle=':')
for axis in (ax.yaxis, ax.xaxis):
    axis.set_tick_params(labelsize=14)
# export as PNG and SVG
fig.savefig(results_path+'/vbeta/logocv_rocauc.png', dpi=400)
fig.savefig(results_path+'/vbeta/logocv_rocauc_svg.svg', format='svg', dpi=400)

#### Hit ratio against microbiologist approach

In [None]:
import pandas as pd
import numpy as np
import json
import random

# Baseline ("informed microbiologist"): for every bacterium, suggest the phages that infect
# other bacteria with the same serotype, topped up with the phages that have the most
# interactions, and count for k = 1..50 whether a suggestion hits.
# prep the data
interactions1 = general_path+'/klebsiella_phage_host_interactions.xlsx'
interactions2 = general_path+'/klebsiella_interactions_part2.xlsx' # for part 1 NO SUGGESTIONS POSSIBLE -> ALL UNIQUE K-TYPES
matrix1 = pd.read_excel(interactions1, index_col=0, header=0)
matrix2 = pd.read_excel(interactions2, index_col=0, header=0)
locipath = general_path+'/LocibaseValencia.json'
seros = pd.read_csv(general_path+'/serotypesValencia.csv')
with open(locipath) as f:
    locibase = json.load(f)

# do the informed approach
# hits[k]: number of bacteria (both matrices together) with a hit at k
hits = {i: 0 for i in range(1, 51)}
total = 0
# --------------------
# MATRIX 1
# --------------------
# Map accession -> serotype.
# NOTE(review): assumes the rows of the serotype CSV are in the same order as the Locibase keys - confirm.
loci_serotype = {}
for i, accession in enumerate(locibase.keys()):
    loci_serotype[accession] = seros['sero'][i]
    
# phages sorted by their column sum, largest first
sorted_phages = matrix1.sum().sort_values(ascending=False).index.tolist()

# delete keys not in this matrix (only suggestions within the matrix)
rownames = list(matrix1.index.values)
# Row labels listed here are excluded from the evaluation.
no_genome = ['K2', 'K21', 'K23', 'K27', 'K28', 'K40', 'K45', 'K48', 'K52', 'K53', 'K67', 'K69', 'K70', 'K71', 'K72']
rownames = [str(i) for i in rownames if i not in no_genome]
for key in list(loci_serotype.keys()):
    if key not in rownames:
        del loci_serotype[key]
        
# iterate over all accessions in matrix1
for i, accession in enumerate(rownames):
    # only compute hit ratio when we can find something
    if sum(matrix1.loc[accession]) > 0:
        # get the serotype
        serotype = loci_serotype[str(accession)]
        # search other bacteria with the same serotype
        same_serotype = [key for key, value in loci_serotype.items() if value == serotype]
        same_serotype.remove(str(accession))
        # get phage suggestions: columnnames of corresponding bacteria in matrix1 with value = 1
        phage_suggestions = []
        for j, acc in enumerate(same_serotype):
            # These accessions are looked up by their integer label instead of the string.
            if acc in ['132', '779', '806', '228', '245', '406', '1210', '1446', '1468', '1572', '2164']:
                acc = int(acc)
            colnames = matrix1.columns[matrix1.loc[acc] == 1].tolist()
            phage_suggestions.append(colnames)
        # flatten the list and drop duplicates
        phage_suggestions = list(set([item for sublist in phage_suggestions for item in sublist]))
        # sort by column sum in descending order (reverse=True): phages with the most interactions first
        phage_suggestions.sort(key=lambda x: matrix1[x].sum(), reverse=True)
        
        total += 1
        for k in range(1, 51):
            # If there are fewer than k suggestions, top up from `sorted_phages` in order
            # (no random sampling), skipping phages that are already suggested.
            if k > len(phage_suggestions):
                sample_pool = [sugg for sugg in sorted_phages if sugg not in phage_suggestions]
                to_pick = k-len(phage_suggestions)
                if len(sample_pool) < to_pick:
                    phage_suggestions = phage_suggestions + sample_pool
                else:
                    phage_suggestions = phage_suggestions + sample_pool[:to_pick]

            #suggested = random.sample(phage_suggestions, k)
            # NOTE(review): this checks ALL current suggestions, not only the first k; whenever
            # there are more than k serotype-based suggestions the hit at k may be
            # overestimated - confirm whether `phage_suggestions[:k]` was intended.
            if any([matrix1.loc[accession, sugg] == 1 for sugg in phage_suggestions]):
                hits[k] += 1
                
# --------------------
# MATRIX 2
# --------------------
# Same procedure for the second matrix; hits2/total2 count this matrix only, while
# hits/total keep accumulating over both matrices.
hits2 = {i: 0 for i in range(1, 51)}
total2 = 0
loci_serotype = {}
for i, accession in enumerate(locibase.keys()):
    loci_serotype[accession] = seros['sero'][i]
    
sorted_phages = matrix2.sum().sort_values(ascending=False).index.tolist()

# delete keys not in this matrix (only suggestions within the matrix)
rownames = list(matrix2.index.values)
rownames = [str(i) for i in rownames]
for key in list(loci_serotype.keys()):
    if key not in rownames:
        del loci_serotype[key]

# iterate over all accessions in matrix2
for i, accession in enumerate(matrix2.index.values):
    # only compute hit ratio when we can find something
    if sum(matrix2.loc[accession]) > 0:
        # get the serotype
        serotype = loci_serotype[str(accession)]
        # search other bacteria with the same serotype
        same_serotype = [key for key, value in loci_serotype.items() if value == serotype]
        same_serotype.remove(str(accession))
        # get phage suggestions: columnnames of corresponding bacteria in matrix2 with value = 1
        phage_suggestions = []
        for j, acc in enumerate(same_serotype):
            # These accessions are looked up by their integer label instead of the string.
            if acc in ['132', '779', '806', '228', '245', '406', '1210', '1446', '1468', '1572', '2164']:
                acc = int(acc)
            colnames = matrix2.columns[matrix2.loc[acc] == 1].tolist()
            phage_suggestions.append(colnames)
        # flatten the list and drop duplicates
        phage_suggestions = list(set([item for sublist in phage_suggestions for item in sublist]))
        # sort by column sum in descending order (reverse=True): phages with the most interactions first
        phage_suggestions.sort(key=lambda x: matrix2[x].sum(), reverse=True)

        total += 1
        total2 += 1
        for k in range(1, 51):
            # If there are fewer than k suggestions, top up from `sorted_phages` in order
            # (no random sampling), skipping phages that are already suggested.
            if k > len(phage_suggestions):
                sample_pool = [sugg for sugg in sorted_phages if sugg not in phage_suggestions]
                to_pick = k-len(phage_suggestions)
                if len(sample_pool) < to_pick:
                    phage_suggestions = phage_suggestions + sample_pool
                else:
                    phage_suggestions = phage_suggestions + sample_pool[:to_pick]
            
            # NOTE(review): as for matrix 1, all current suggestions are checked, not only the first k.
            if any([matrix2.loc[accession, sugg] == 1 for sugg in phage_suggestions]):
                hits[k] += 1
                hits2[k] += 1

# Hit ratio at every k: over both matrices, and over matrix 2 only.
informed_hitratio = {k: v/total for k, v in hits.items()}
informed_hitratio2 = {k: v/total2 for k, v in hits2.items()}

In [None]:
# results, hit ratios @ K: hit ratio of the model for k = 1..50
ks = np.linspace(1, 50, 50)
hits_lan = [phlu.hitratio(rqueries_lan, int(k)) for k in ks]
fig, ax = plt.subplots(figsize=(10,8))
ax.plot(ks, hits_lan, c='#124559', linewidth=2.5, label='ESM-2 + XGBoost')
# Optional baselines (disabled):
#ax.plot(ks, hits_ens, c='#124559', linewidth=2.5, label='Combined model')
#ax.plot(ks, hits_random, c='#81B29A', linewidth=2.5, ls=':', label='Random guess')
#ax.plot(ks, list(informed_hitratio.values()), c='#E15554', linewidth=2.5, ls='-.', label='Informed microbiologist')
#ax.plot(ks, list(informed_hitratio2.values()), c='#E15554', linewidth=2.5, ls=':', label='Informed guess (Bea only)')
# Raw strings: '\i' is an invalid escape sequence in a normal string literal and triggers a
# warning on recent Python versions; the resulting label text is unchanged.
ax.set_xlabel(r'$\it{k}$', size=24)
ax.set_ylabel(r'Hit ratio @ $\it{k}$', size=24)
ax.set_ylim(0.1, 1)
ax.legend(loc=4, prop={'size': 24})
ax.grid(True, linestyle=':')
ax.yaxis.set_tick_params(labelsize = 14)
ax.xaxis.set_tick_params(labelsize = 14)
fig.savefig(results_path+'/vbeta/logocv_hitratio_informed.png', dpi=400)
fig.savefig(results_path+'/vbeta/logocv_hitratio_informed_svg.svg', format='svg', dpi=400)

#### Performance per K-type

https://medium.com/@curryrowan/simplified-logistic-regression-classification-with-categorical-variables-in-python-1ce50c4b137

In [None]:
# read results
# NOTE(review): the file suffix is hard-coded to '_100' here instead of using `threshold_str`
# as the cells above do - confirm that this is the intended result file.
with open(results_path+'/v3.4/combined_logocv_results_v34_100.pickle', 'rb') as f:
    logo_results = pickle.load(f)
scores_lan = logo_results['scores_language']
label_list = logo_results['labels']

# read K-types (column 'sero'); presumably one row per held-out group, in the same order - verify
seros = pd.read_csv(general_path+'/serotypesValencia.csv')

In [None]:
# Top-10 hit ratio per K-type, plus the number of positive labels of every group of that K-type.
unique_seros = list(set(seros['sero']))
performance_ktypes = {}
labelcount_ktypes = {}
for unique in unique_seros:
    # select the groups (scores and labels) that belong to this K-type
    indices = seros['sero'] == unique
    subscores_lan = [group_scores for keep, group_scores in zip(indices, scores_lan) if keep]
    sublabels = [group_labels for keep, group_labels in zip(indices, label_list) if keep]
    labelcount_ktypes[unique] = [sum(group_labels) for group_labels in sublabels]

    # rank the labels by descending score, only for groups with at least one positive label
    rqueries_lan = [
        [x for _, x in sorted(zip(score_lan, y_test), reverse=True)]
        for score_lan, y_test in zip(subscores_lan, sublabels)
        if sum(y_test) > 0
    ]
    # K-types without any usable group get no entry
    if rqueries_lan:
        hr_lan = round(phlu.hitratio(rqueries_lan, 10), 3)
        performance_ktypes[unique] = [('HR_XGB', hr_lan)]

In [None]:
# Histogram of the top-10 hit ratio over all K-types.
performance_hr_xgb = [metrics[0][1] for metrics in performance_ktypes.values()]
# (K-type, hit ratio) pairs, ordered by descending hit ratio
ranked_scores = sorted(zip(performance_hr_xgb, performance_ktypes.keys()), reverse=True)
sortedpairs = [(ktype, score) for score, ktype in ranked_scores]
fig, ax = plt.subplots(figsize=(16,6))
ax.hist(performance_hr_xgb, bins=25, color='#124559')
ax.set_xlabel('Mean top-10 hit ratio', size=22)
ax.set_ylabel('Number of K-types', size=22)
for axis in (ax.yaxis, ax.xaxis):
    axis.set_tick_params(labelsize=14)
fig.tight_layout()
fig.savefig(results_path+'/vbeta/histogram_ktypes_svg.svg', format='svg', dpi=400)

#### Hit ratio per K-type versus number of pos labels

In [None]:
# Split the K-types by hit ratio (exactly 1, exactly 0, in between) and pool the
# positive-label counts of their groups.
top10 = [ktype for ktype, hit_ratio in sortedpairs if hit_ratio == 1]
bottom10 = [ktype for ktype, hit_ratio in sortedpairs if hit_ratio == 0]
middle = [ktype for ktype, hit_ratio in sortedpairs if hit_ratio not in (0, 1)]
countst10, countsb10, countsmid = [], [], []
for key, label_counts in labelcount_ktypes.items():
    if key in top10:
        countst10.extend(label_counts)
    elif key in bottom10:
        countsb10.extend(label_counts)
    elif key in middle:
        countsmid.extend(label_counts)

In [None]:
# One histogram of confirmed-interaction counts per K-type category.
countlist = [countst10, countsmid, countsb10]
binlist = [15, 15, 15]

for i, (count, n_bins) in enumerate(zip(countlist, binlist)):
    fig, ax = plt.subplots(figsize=(8, 8))
    sns.histplot(count, ax=ax, color='#221150', bins=n_bins)
    ax.set_xlim(0, 10)
    ax.set_xlabel('Number of confirmed interactions', size=22)
    ax.set_ylabel('Number of bacteria', size=22)
    for axis in (ax.yaxis, ax.xaxis):
        axis.set_tick_params(labelsize=14)
    fig.savefig(results_path+'/v3.4/ktypecounts_svg'+str(i)+'.svg', format='svg', dpi=400)