# Protein embeddings improve phage-host interaction prediction

**Mark Edward M. Gonzales<sup>1, 2</sup>, Jennifer C. Ureta<sup>1, 2</sup> & Anish M.S. Shrestha<sup>1, 2</sup>**

<sup>1</sup> Bioinformatics Laboratory, Advanced Research Institute for Informatics, Computing and Networking, De La Salle University, Manila, Philippines <br>
<sup>2</sup> Department of Software Technology, College of Computer Studies, De La Salle University, Manila, Philippines 

{mark_gonzales, jennifer.ureta, anish.shrestha}@dlsu.edu.ph

<hr>

## 💡 Phage-Host-Features CSV Files

This notebook assumes that you already have the phage-host-features CSV files (from running [`5. Data Consolidation.ipynb`](https://github.com/bioinfodlsu/phage-host-prediction/blob/main/experiments/5.%20Data%20Consolidation.ipynb)).

Alternatively, you may download the CSV files from [Google Drive](https://drive.google.com/drive/folders/1xNoA6dxkN4jzVNCg_7YNjdPZzl51Jo9M?usp=sharing) and save the downloaded `data` folder inside the `inphared` directory located in the same folder as this notebook. The folder structure should look like this:

`experiments` (parent folder of this notebook) <br> 
↳ `inphared` <br>
&nbsp; &nbsp;↳ `data` <br>
&nbsp; &nbsp;&nbsp; &nbsp; ↳ `rbp.csv` <br>
&nbsp; &nbsp;&nbsp; &nbsp; ↳ `rbp_embeddings_esm.csv` <br>
&nbsp; &nbsp;&nbsp; &nbsp; ↳ ... <br>
↳ `6. Classifier Building & Evaluation.ipynb` (this notebook) <br>

<hr>

## 📁 Output Files

If you would like to skip running this notebook, you may download the trained models from [Google Drive](https://drive.google.com/drive/folders/1U5ugmkhD4LHElYnLj3B8Xt2TcPx-TOjB?usp=sharing). Save the downloaded `models` folder in the same folder as this notebook. The folder structure should look like this:

`experiments` (parent folder of this notebook) <br> 
↳ `inphared` <br>
↳ `models` <br>
&nbsp; &nbsp;↳ `boeckaerts.joblib` <br>
&nbsp; &nbsp;↳ `esm.joblib` <br>
&nbsp; &nbsp;↳ ... <br>
↳ `6. Classifier Building & Evaluation.ipynb` (this notebook) <br>

Intermediate output files (i.e., `temp/feature_importance.pickle` and those saved in `temp/results`) should have already been included when the repository was cloned.

This notebook also generates `rbp_embeddings_boeckaerts.csv`, the phage-host-features CSV file for the state-of-the-art tool by Boeckaerts <i>et al.</i> (2021) with which we compared model performance. It should have already been downloaded as part of the [`data`](https://drive.google.com/drive/folders/1xNoA6dxkN4jzVNCg_7YNjdPZzl51Jo9M?usp=sharing) folder.

<hr>

## Part I: Preliminaries

Import the necessary libraries and modules.

In [1]:
import math
import pickle
import os
import warnings

import pandas as pd
import numpy as np
import sklearn

from ConstantsUtil import ConstantsUtil
from ClassificationUtil import ClassificationUtil
import boeckaerts as RBP_f

%load_ext autoreload
%autoreload 2



In [2]:
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', 50)

pd.options.mode.chained_assignment = None

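# Register the filter directly: inside `warnings.catch_warnings()`, it would be
# discarded as soon as the `with` block exits and would not affect later cells.
from sklearn.exceptions import UndefinedMetricWarning
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)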

In [3]:
constants = ConstantsUtil()
util = ClassificationUtil()

<hr>

## Part II: Classifier Building and Evaluation

For each protein language model, train a random forest classifier that takes the RBP embeddings as input features and predicts the host genus. The evaluation results are printed for each confidence threshold *k*.

In [4]:
models = list(constants.PLM_EMBEDDINGS_CSV.keys())

# Skip the last entry, which is the benchmark model
for model in models[:-1]:
    # Feature importances are saved only for the PROTT5 model
    save_feature_importance = (model == 'PROTT5')
    util.classify(model, save_feature_importance)

Constructing training and test sets...
Training the model...
Saving evaluation results...
Confidence threshold k: 0.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.3529    0.5217        17
    acinetobacter     0.9385    0.6224    0.7485        98
        aeromonas     0.8485    0.5895    0.6957        95
    agrobacterium     1.0000    0.4167    0.5882        12
     arthrobacter     0.9500    0.7308    0.8261        26
         bacillus     0.7215    0.8382    0.7755       204
      bacteroides     0.9032    0.9032    0.9032        31
    brevundimonas     1.0000    0.6000    0.7500        10
     burkholderia     0.6333    0.4872    0.5507        39
    campylobacter     0.8559    0.8145    0.8347       124
      caulobacter     1.0000    0.8400    0.9130        25
      citrobacter     0.3077    0.2353    0.2667        34
      clostridium     0.5000    0.6667    0.5714        24
      cronobacter     0.5000    0.3333    0.4000      

Confidence threshold k: 30.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.2941    0.4545        17
    acinetobacter     1.0000    0.4184    0.5899        98
        aeromonas     0.9714    0.3579    0.5231        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.5000    0.6667        26
         bacillus     0.9467    0.6961    0.8023       204
      bacteroides     1.0000    0.7419    0.8519        31
    brevundimonas     1.0000    0.2000    0.3333        10
     burkholderia     1.0000    0.2821    0.4400        39
    campylobacter     0.9394    0.7500    0.8341       124
      caulobacter     1.0000    0.6800    0.8095        25
      citrobacter     0.6000    0.0882    0.1538        34
      clostridium     0.6875    0.4583    0.5500        24
      cronobacter     0.6667    0.1111    0.1905        36
          dickeya     0.9706    0.6346    0.7674        52
     edwardsiella     1.0

Confidence threshold k: 60.0%
                   precision    recall  f1-score   support

    achromobacter     0.0000    0.0000    0.0000        17
    acinetobacter     1.0000    0.2449    0.3934        98
        aeromonas     1.0000    0.2000    0.3333        95
    agrobacterium     1.0000    0.2500    0.4000        12
     arthrobacter     1.0000    0.2692    0.4242        26
         bacillus     0.9658    0.5539    0.7040       204
      bacteroides     1.0000    0.6452    0.7843        31
    brevundimonas     1.0000    0.2000    0.3333        10
     burkholderia     1.0000    0.1026    0.1860        39
    campylobacter     0.9737    0.5968    0.7400       124
      caulobacter     1.0000    0.6400    0.7805        25
      citrobacter     0.0000    0.0000    0.0000        34
      clostridium     0.8000    0.1667    0.2759        24
      cronobacter     1.0000    0.0278    0.0541        36
          dickeya     1.0000    0.5000    0.6667        52
     edwardsiella     1.0

Confidence threshold k: 90.0%
                   precision    recall  f1-score   support

    achromobacter     0.0000    0.0000    0.0000        17
    acinetobacter     1.0000    0.0510    0.0971        98
        aeromonas     1.0000    0.0211    0.0412        95
    agrobacterium     0.0000    0.0000    0.0000        12
     arthrobacter     1.0000    0.1154    0.2069        26
         bacillus     1.0000    0.2206    0.3614       204
      bacteroides     1.0000    0.0645    0.1212        31
    brevundimonas     0.0000    0.0000    0.0000        10
     burkholderia     1.0000    0.0256    0.0500        39
    campylobacter     1.0000    0.2661    0.4204       124
      caulobacter     1.0000    0.2800    0.4375        25
      citrobacter     0.0000    0.0000    0.0000        34
      clostridium     1.0000    0.0417    0.0800        24
      cronobacter     0.0000    0.0000    0.0000        36
          dickeya     1.0000    0.3654    0.5352        52
     edwardsiella     0.0

Confidence threshold k: 10.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.2941    0.4545        17
    acinetobacter     1.0000    0.5714    0.7273        98
        aeromonas     0.9600    0.5053    0.6621        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.6154    0.7619        26
         bacillus     0.8901    0.7941    0.8394       204
      bacteroides     1.0000    0.9032    0.9492        31
    brevundimonas     1.0000    0.4000    0.5714        10
     burkholderia     0.8824    0.3846    0.5357        39
    campylobacter     0.9706    0.7984    0.8761       124
      caulobacter     1.0000    0.7600    0.8636        25
      citrobacter     0.5000    0.2059    0.2917        34
      clostridium     0.4643    0.5417    0.5000        24
      cronobacter     0.6842    0.3611    0.4727        36
          dickeya     0.9744    0.7308    0.8352        52
     edwardsiella     1.0

Confidence threshold k: 40.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.2353    0.3810        17
    acinetobacter     1.0000    0.4082    0.5797        98
        aeromonas     0.9630    0.2737    0.4262        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.4231    0.5946        26
         bacillus     0.9504    0.6569    0.7768       204
      bacteroides     1.0000    0.7419    0.8519        31
    brevundimonas     1.0000    0.3000    0.4615        10
     burkholderia     0.9000    0.2308    0.3673        39
    campylobacter     0.9787    0.7419    0.8440       124
      caulobacter     1.0000    0.6400    0.7805        25
      citrobacter     0.6000    0.0882    0.1538        34
      clostridium     0.6875    0.4583    0.5500        24
      cronobacter     0.7500    0.0833    0.1500        36
          dickeya     1.0000    0.5769    0.7317        52
     edwardsiella     1.0

Confidence threshold k: 70.0%
                   precision    recall  f1-score   support

    achromobacter     0.0000    0.0000    0.0000        17
    acinetobacter     1.0000    0.1327    0.2342        98
        aeromonas     1.0000    0.1368    0.2407        95
    agrobacterium     1.0000    0.2500    0.4000        12
     arthrobacter     1.0000    0.2692    0.4242        26
         bacillus     0.9798    0.4755    0.6403       204
      bacteroides     1.0000    0.4839    0.6522        31
    brevundimonas     1.0000    0.1000    0.1818        10
     burkholderia     1.0000    0.0513    0.0976        39
    campylobacter     0.9730    0.5806    0.7273       124
      caulobacter     1.0000    0.6000    0.7500        25
      citrobacter     0.0000    0.0000    0.0000        34
      clostridium     0.8333    0.2083    0.3333        24
      cronobacter     1.0000    0.0278    0.0541        36
          dickeya     1.0000    0.4423    0.6133        52
     edwardsiella     0.0

Confidence threshold k: 100.0%
                   precision    recall  f1-score   support

    achromobacter     0.0000    0.0000    0.0000        17
    acinetobacter     1.0000    0.0102    0.0202        98
        aeromonas     0.0000    0.0000    0.0000        95
    agrobacterium     0.0000    0.0000    0.0000        12
     arthrobacter     0.0000    0.0000    0.0000        26
         bacillus     1.0000    0.0098    0.0194       204
      bacteroides     0.0000    0.0000    0.0000        31
    brevundimonas     0.0000    0.0000    0.0000        10
     burkholderia     0.0000    0.0000    0.0000        39
    campylobacter     1.0000    0.0484    0.0923       124
      caulobacter     1.0000    0.0400    0.0769        25
      citrobacter     0.0000    0.0000    0.0000        34
      clostridium     0.0000    0.0000    0.0000        24
      cronobacter     0.0000    0.0000    0.0000        36
          dickeya     1.0000    0.2885    0.4478        52
     edwardsiella     0.

Confidence threshold k: 20.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.2941    0.4545        17
    acinetobacter     1.0000    0.4592    0.6294        98
        aeromonas     0.9556    0.4526    0.6143        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.5385    0.7000        26
         bacillus     0.9394    0.7598    0.8401       204
      bacteroides     1.0000    0.8065    0.8929        31
    brevundimonas     1.0000    0.4000    0.5714        10
     burkholderia     1.0000    0.2564    0.4082        39
    campylobacter     0.9604    0.7823    0.8622       124
      caulobacter     1.0000    0.7200    0.8372        25
      citrobacter     0.5000    0.1765    0.2609        34
      clostridium     0.7222    0.5417    0.6190        24
      cronobacter     0.7692    0.2778    0.4082        36
          dickeya     0.9487    0.7115    0.8132        52
     edwardsiella     1.0

Confidence threshold k: 50.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.1176    0.2105        17
    acinetobacter     1.0000    0.3265    0.4923        98
        aeromonas     0.9600    0.2526    0.4000        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.3462    0.5143        26
         bacillus     0.9615    0.6127    0.7485       204
      bacteroides     1.0000    0.6452    0.7843        31
    brevundimonas     1.0000    0.2000    0.3333        10
     burkholderia     1.0000    0.1538    0.2667        39
    campylobacter     0.9674    0.7177    0.8241       124
      caulobacter     1.0000    0.6400    0.7805        25
      citrobacter     0.5000    0.0588    0.1053        34
      clostridium     0.8333    0.4167    0.5556        24
      cronobacter     0.6667    0.0556    0.1026        36
          dickeya     1.0000    0.5385    0.7000        52
     edwardsiella     1.0

Confidence threshold k: 80.0%
                   precision    recall  f1-score   support

    achromobacter     0.0000    0.0000    0.0000        17
    acinetobacter     1.0000    0.1122    0.2018        98
        aeromonas     1.0000    0.1053    0.1905        95
    agrobacterium     1.0000    0.2500    0.4000        12
     arthrobacter     1.0000    0.2308    0.3750        26
         bacillus     0.9747    0.3775    0.5442       204
      bacteroides     1.0000    0.3871    0.5581        31
    brevundimonas     0.0000    0.0000    0.0000        10
     burkholderia     1.0000    0.0513    0.0976        39
    campylobacter     1.0000    0.4274    0.5989       124
      caulobacter     1.0000    0.6000    0.7500        25
      citrobacter     0.0000    0.0000    0.0000        34
      clostridium     0.6667    0.0833    0.1481        24
      cronobacter     1.0000    0.0278    0.0541        36
          dickeya     1.0000    0.4423    0.6133        52
     edwardsiella     0.0

Finished
Constructing training and test sets...
Training the model...
Saving evaluation results...
Confidence threshold k: 0.0%
                   precision    recall  f1-score   support

    achromobacter     0.9000    0.5294    0.6667        17
    acinetobacter     0.9565    0.6735    0.7904        98
        aeromonas     0.8267    0.6526    0.7294        95
    agrobacterium     0.8333    0.4167    0.5556        12
     arthrobacter     0.8333    0.7692    0.8000        26
         bacillus     0.8128    0.8725    0.8416       204
      bacteroides     0.9032    0.9032    0.9032        31
    brevundimonas     1.0000    0.6000    0.7500        10
     burkholderia     0.6452    0.5128    0.5714        39
    campylobacter     0.8655    0.8306    0.8477       124
      caulobacter     0.9545    0.8400    0.8936        25
      citrobacter     0.3448    0.2941    0.3175        34
      clostridium     0.5000    0.6667    0.5714        24
      cronobacter     0.5217    0.3333    0.4

Confidence threshold k: 30.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.2941    0.4545        17
    acinetobacter     1.0000    0.4898    0.6575        98
        aeromonas     0.9574    0.4737    0.6338        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.6538    0.7907        26
         bacillus     0.9503    0.7500    0.8384       204
      bacteroides     1.0000    0.7419    0.8519        31
    brevundimonas     1.0000    0.5000    0.6667        10
     burkholderia     1.0000    0.3333    0.5000        39
    campylobacter     0.8857    0.7500    0.8122       124
      caulobacter     1.0000    0.7600    0.8636        25
      citrobacter     0.5714    0.1176    0.1951        34
      clostridium     0.7222    0.5417    0.6190        24
      cronobacter     0.8000    0.2222    0.3478        36
          dickeya     1.0000    0.6538    0.7907        52
     edwardsiella     1.0

Confidence threshold k: 60.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.1176    0.2105        17
    acinetobacter     1.0000    0.3163    0.4806        98
        aeromonas     1.0000    0.2316    0.3761        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.5000    0.6667        26
         bacillus     0.9638    0.6520    0.7778       204
      bacteroides     1.0000    0.5806    0.7347        31
    brevundimonas     1.0000    0.2000    0.3333        10
     burkholderia     1.0000    0.1282    0.2273        39
    campylobacter     0.9610    0.5968    0.7363       124
      caulobacter     1.0000    0.6400    0.7805        25
      citrobacter     1.0000    0.0588    0.1111        34
      clostridium     0.6667    0.1667    0.2667        24
      cronobacter     1.0000    0.0556    0.1053        36
          dickeya     1.0000    0.4808    0.6494        52
     edwardsiella     0.0

Confidence threshold k: 90.0%
                   precision    recall  f1-score   support

    achromobacter     0.0000    0.0000    0.0000        17
    acinetobacter     1.0000    0.0816    0.1509        98
        aeromonas     1.0000    0.0316    0.0612        95
    agrobacterium     0.0000    0.0000    0.0000        12
     arthrobacter     1.0000    0.1154    0.2069        26
         bacillus     1.0000    0.3137    0.4776       204
      bacteroides     1.0000    0.0323    0.0625        31
    brevundimonas     0.0000    0.0000    0.0000        10
     burkholderia     1.0000    0.0513    0.0976        39
    campylobacter     1.0000    0.1774    0.3014       124
      caulobacter     1.0000    0.3200    0.4848        25
      citrobacter     0.0000    0.0000    0.0000        34
      clostridium     0.6667    0.0833    0.1481        24
      cronobacter     0.0000    0.0000    0.0000        36
          dickeya     1.0000    0.3654    0.5352        52
     edwardsiella     0.0

Confidence threshold k: 10.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.3529    0.5217        17
    acinetobacter     1.0000    0.6224    0.7673        98
        aeromonas     0.9592    0.4947    0.6528        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.6923    0.8182        26
         bacillus     0.8895    0.8284    0.8579       204
      bacteroides     1.0000    0.8710    0.9310        31
    brevundimonas     1.0000    0.5000    0.6667        10
     burkholderia     1.0000    0.4103    0.5818        39
    campylobacter     0.9259    0.8065    0.8621       124
      caulobacter     1.0000    0.8400    0.9130        25
      citrobacter     0.5263    0.2941    0.3774        34
      clostridium     0.5600    0.5833    0.5714        24
      cronobacter     0.6471    0.3056    0.4151        36
          dickeya     0.9286    0.7500    0.8298        52
     edwardsiella     1.0

Confidence threshold k: 40.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.2941    0.4545        17
    acinetobacter     1.0000    0.4694    0.6389        98
        aeromonas     0.9677    0.3158    0.4762        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.5000    0.6667        26
         bacillus     0.9586    0.6814    0.7966       204
      bacteroides     1.0000    0.7097    0.8302        31
    brevundimonas     1.0000    0.3000    0.4615        10
     burkholderia     1.0000    0.2564    0.4082        39
    campylobacter     0.9175    0.7177    0.8054       124
      caulobacter     1.0000    0.6400    0.7805        25
      citrobacter     0.7500    0.0882    0.1579        34
      clostridium     0.7059    0.5000    0.5854        24
      cronobacter     0.6667    0.1111    0.1905        36
          dickeya     1.0000    0.5962    0.7470        52
     edwardsiella     1.0

Confidence threshold k: 70.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.0588    0.1111        17
    acinetobacter     1.0000    0.2143    0.3529        98
        aeromonas     1.0000    0.1579    0.2727        95
    agrobacterium     1.0000    0.2500    0.4000        12
     arthrobacter     1.0000    0.3846    0.5556        26
         bacillus     0.9735    0.5392    0.6940       204
      bacteroides     1.0000    0.5161    0.6809        31
    brevundimonas     1.0000    0.1000    0.1818        10
     burkholderia     1.0000    0.1026    0.1860        39
    campylobacter     0.9697    0.5161    0.6737       124
      caulobacter     1.0000    0.6000    0.7500        25
      citrobacter     0.0000    0.0000    0.0000        34
      clostridium     0.6667    0.0833    0.1481        24
      cronobacter     1.0000    0.0278    0.0541        36
          dickeya     1.0000    0.4423    0.6133        52
     edwardsiella     0.0

Confidence threshold k: 100.0%
                   precision    recall  f1-score   support

    achromobacter     0.0000    0.0000    0.0000        17
    acinetobacter     1.0000    0.0204    0.0400        98
        aeromonas     0.0000    0.0000    0.0000        95
    agrobacterium     0.0000    0.0000    0.0000        12
     arthrobacter     0.0000    0.0000    0.0000        26
         bacillus     1.0000    0.0294    0.0571       204
      bacteroides     0.0000    0.0000    0.0000        31
    brevundimonas     0.0000    0.0000    0.0000        10
     burkholderia     0.0000    0.0000    0.0000        39
    campylobacter     1.0000    0.0565    0.1069       124
      caulobacter     0.0000    0.0000    0.0000        25
      citrobacter     0.0000    0.0000    0.0000        34
      clostridium     0.0000    0.0000    0.0000        24
      cronobacter     0.0000    0.0000    0.0000        36
          dickeya     1.0000    0.2885    0.4478        52
     edwardsiella     0.

Confidence threshold k: 20.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.3529    0.5217        17
    acinetobacter     1.0000    0.5510    0.7105        98
        aeromonas     0.9423    0.5158    0.6667        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.6923    0.8182        26
         bacillus     0.9364    0.7941    0.8594       204
      bacteroides     1.0000    0.7742    0.8727        31
    brevundimonas     1.0000    0.4000    0.5714        10
     burkholderia     1.0000    0.3590    0.5283        39
    campylobacter     0.9314    0.7661    0.8407       124
      caulobacter     1.0000    0.7600    0.8636        25
      citrobacter     0.4615    0.1765    0.2553        34
      clostridium     0.6667    0.5833    0.6222        24
      cronobacter     0.7692    0.2778    0.4082        36
          dickeya     0.9737    0.7115    0.8222        52
     edwardsiella     1.0

Confidence threshold k: 50.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.1765    0.3000        17
    acinetobacter     1.0000    0.4082    0.5797        98
        aeromonas     0.9630    0.2737    0.4262        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.5385    0.7000        26
         bacillus     0.9638    0.6520    0.7778       204
      bacteroides     1.0000    0.7097    0.8302        31
    brevundimonas     1.0000    0.3000    0.4615        10
     burkholderia     1.0000    0.1795    0.3043        39
    campylobacter     0.9333    0.6774    0.7850       124
      caulobacter     1.0000    0.6400    0.7805        25
      citrobacter     0.6667    0.0588    0.1081        34
      clostridium     0.6667    0.4167    0.5128        24
      cronobacter     0.6000    0.0833    0.1463        36
          dickeya     1.0000    0.5000    0.6667        52
     edwardsiella     1.0

Confidence threshold k: 80.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.0588    0.1111        17
    acinetobacter     1.0000    0.1429    0.2500        98
        aeromonas     1.0000    0.1263    0.2243        95
    agrobacterium     1.0000    0.1667    0.2857        12
     arthrobacter     1.0000    0.2692    0.4242        26
         bacillus     0.9785    0.4461    0.6128       204
      bacteroides     1.0000    0.4516    0.6222        31
    brevundimonas     0.0000    0.0000    0.0000        10
     burkholderia     1.0000    0.0769    0.1429        39
    campylobacter     1.0000    0.4274    0.5989       124
      caulobacter     1.0000    0.4400    0.6111        25
      citrobacter     0.0000    0.0000    0.0000        34
      clostridium     0.6667    0.0833    0.1481        24
      cronobacter     1.0000    0.0278    0.0541        36
          dickeya     1.0000    0.4423    0.6133        52
     edwardsiella     0.0

Finished
Constructing training and test sets...
Training the model...
Saving evaluation results...
Confidence threshold k: 0.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.4118    0.5833        17
    acinetobacter     0.8767    0.6531    0.7485        98
        aeromonas     0.9508    0.6105    0.7436        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.6923    0.8182        26
         bacillus     0.7576    0.8578    0.8046       204
      bacteroides     1.0000    0.9355    0.9667        31
    brevundimonas     1.0000    0.5000    0.6667        10
     burkholderia     0.6667    0.4103    0.5079        39
    campylobacter     0.8559    0.7661    0.8085       124
      caulobacter     0.8400    0.8400    0.8400        25
      citrobacter     0.5000    0.2941    0.3704        34
      clostridium     0.4242    0.5833    0.4912        24
      cronobacter     0.6000    0.4167    0.4

Confidence threshold k: 30.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.2353    0.3810        17
    acinetobacter     1.0000    0.4184    0.5899        98
        aeromonas     0.9394    0.3263    0.4844        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.5000    0.6667        26
         bacillus     0.9346    0.7010    0.8011       204
      bacteroides     1.0000    0.8065    0.8929        31
    brevundimonas     1.0000    0.4000    0.5714        10
     burkholderia     1.0000    0.2821    0.4400        39
    campylobacter     0.9663    0.6935    0.8075       124
      caulobacter     1.0000    0.6400    0.7805        25
      citrobacter     0.6000    0.0882    0.1538        34
      clostridium     0.6923    0.3750    0.4865        24
      cronobacter     0.7778    0.1944    0.3111        36
          dickeya     1.0000    0.6538    0.7907        52
     edwardsiella     1.0

Confidence threshold k: 60.0%
                   precision    recall  f1-score   support

    achromobacter     0.0000    0.0000    0.0000        17
    acinetobacter     1.0000    0.2653    0.4194        98
        aeromonas     1.0000    0.1789    0.3036        95
    agrobacterium     1.0000    0.2500    0.4000        12
     arthrobacter     1.0000    0.3077    0.4706        26
         bacillus     0.9750    0.5735    0.7222       204
      bacteroides     1.0000    0.4839    0.6522        31
    brevundimonas     1.0000    0.2000    0.3333        10
     burkholderia     1.0000    0.1282    0.2273        39
    campylobacter     0.9583    0.5565    0.7041       124
      caulobacter     1.0000    0.6400    0.7805        25
      citrobacter     0.0000    0.0000    0.0000        34
      clostridium     0.7143    0.2083    0.3226        24
      cronobacter     1.0000    0.0556    0.1053        36
          dickeya     1.0000    0.5000    0.6667        52
     edwardsiella     1.0

Confidence threshold k: 90.0%
                   precision    recall  f1-score   support

    achromobacter     0.0000    0.0000    0.0000        17
    acinetobacter     1.0000    0.0408    0.0784        98
        aeromonas     1.0000    0.0211    0.0412        95
    agrobacterium     0.0000    0.0000    0.0000        12
     arthrobacter     1.0000    0.1154    0.2069        26
         bacillus     1.0000    0.2157    0.3548       204
      bacteroides     1.0000    0.0645    0.1212        31
    brevundimonas     0.0000    0.0000    0.0000        10
     burkholderia     1.0000    0.0256    0.0500        39
    campylobacter     0.9375    0.2419    0.3846       124
      caulobacter     1.0000    0.2000    0.3333        25
      citrobacter     0.0000    0.0000    0.0000        34
      clostridium     1.0000    0.0833    0.1538        24
      cronobacter     0.0000    0.0000    0.0000        36
          dickeya     1.0000    0.3654    0.5352        52
     edwardsiella     0.0
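
The reports above are grouped by a confidence threshold *k*. As a minimal sketch of how such a threshold can be applied, assume that a prediction is kept only when the highest class probability is at least *k* and is otherwise replaced by a catch-all label. This rule is an assumption made for illustration; the function `apply_confidence_threshold` and the label `"others"` are hypothetical names and are not taken from `ClassificationUtil`.

```python
def apply_confidence_threshold(class_probs, k, fallback="others"):
    """Return one label per sample: the most probable class if its
    probability is at least k, otherwise the fallback label."""
    labels = []
    for probs in class_probs:
        best_label = max(probs, key=probs.get)
        labels.append(best_label if probs[best_label] >= k else fallback)
    return labels


# Toy class-probability outputs for three RBPs (illustrative values only)
toy_probs = [
    {"bacillus": 0.85, "vibrio": 0.10, "dickeya": 0.05},
    {"bacillus": 0.40, "vibrio": 0.35, "dickeya": 0.25},
    {"bacillus": 0.05, "vibrio": 0.25, "dickeya": 0.70},
]

for k in (0.0, 0.5, 0.8):
    print(f"k = {k:.0%}:", apply_confidence_threshold(toy_probs, k))
```

Under this rule, raising *k* trades recall for precision, which is consistent with the pattern in the reports: per-genus precision tends toward 1.0 while recall drops as *k* grows.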

Display training and test dataset statistics.

In [5]:
plm = 'PROTT5'

# Load data
rbp_embeddings = pd.read_csv(f'{constants.INPHARED}/{constants.DATA}/{constants.PLM_EMBEDDINGS_CSV[plm]}', 
                             low_memory = False)
rbp_embeddings['Modification Date'] = pd.to_datetime(rbp_embeddings['Modification Date'])

# Get only the top 25% hosts
all_counts = rbp_embeddings['Host'].value_counts()
TOP_X_PERCENT = 0.25
top_x = math.floor(all_counts.shape[0] * TOP_X_PERCENT)

top_genus = set()
genus_counts = all_counts.index
for entry in genus_counts[:top_x]:
    top_genus.add(entry)

# Construct the training and test sets
print("Constructing training and test sets...")

rbp_embeddings_top = rbp_embeddings[rbp_embeddings['Host'].isin(top_genus)]

counts, X_train, X_test, y_train, y_test = util.random_train_test_split(rbp_embeddings_top, 'Host',
                                                                        embeddings_size = rbp_embeddings.shape[1] - constants.INPHARED_EXTRA_COLS)

counts_df = pd.DataFrame(counts, columns = ['Genus', 'Train', 'Test', 'Total'])

unknown_hosts_X, unknown_hosts_y = util.get_unknown_hosts(rbp_embeddings[~rbp_embeddings['Host'].isin(top_genus)], 'Host',
                                                          embeddings_size = rbp_embeddings.shape[1] - constants.INPHARED_EXTRA_COLS)

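# `DataFrame.append` and `Series.append` were removed in pandas 2.0; `pd.concat` is the equivalent
X_test = pd.concat([X_test, unknown_hosts_X])
y_test = pd.concat([y_test, unknown_hosts_y])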

Constructing training and test sets...
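
To make the host-selection step concrete, the sketch below applies the same `math.floor(n_genera * TOP_X_PERCENT)` rule to a toy list of host labels using only the standard library. The helper `top_fraction_of_hosts` and the toy labels are illustrative assumptions and are not part of the notebook's utilities.

```python
import math
from collections import Counter


def top_fraction_of_hosts(hosts, fraction):
    """Return the set of the most frequent host genera,
    keeping floor(number of distinct genera * fraction) of them."""
    counts = Counter(hosts)
    n_keep = math.floor(len(counts) * fraction)
    return {genus for genus, _ in counts.most_common(n_keep)}


toy_hosts = (
    ["escherichia"] * 6 + ["salmonella"] * 4 + ["vibrio"] * 3 + ["bacillus"] * 2
    + ["dickeya", "erwinia", "klebsiella", "listeria"]
)

# 8 distinct genera * 0.25 = 2, so the two most frequent genera are kept
print(top_fraction_of_hosts(toy_hosts, 0.25))
```

In the notebook, RBPs whose host falls outside this set are not discarded: they are retrieved with `util.get_unknown_hosts` and appended to the test set.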


In [6]:
counts_df

| | Genus | Train | Test | Total |
|---|---|---|---|---|
| 0 | escherichia | 3021 | 1295 | 4316 |
| 1 | salmonella | 1474 | 632 | 2106 |
| 2 | synechococcus | 1216 | 521 | 1737 |
| 3 | pseudomonas | 1196 | 513 | 1709 |
| 4 | vibrio | 1079 | 463 | 1542 |
| 5 | klebsiella | 926 | 397 | 1323 |
| 6 | erwinia | 667 | 286 | 953 |
| 7 | mycobacterium | 578 | 248 | 826 |
| 8 | staphylococcus | 568 | 244 | 812 |
| 9 | bacillus | 475 | 204 | 679 |


In [7]:
print('Training Set:\t', counts_df['Train'].sum())
print('Test Set:\t', len(X_test))

Training Set:	 16636
Test Set:	 8116


Investigate the extent of class imbalance by calculating the fraction of RBPs associated with the top 25% of hosts.

In [8]:
counts_df['Total'].sum() / rbp_embeddings.shape[0]

0.9601648351648352
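
The ratio above is the share of all RBPs whose host is among the selected genera. A small worked example with made-up counts shows the arithmetic; the numbers and names below are illustrative and are not taken from the dataset.

```python
# Hypothetical per-genus RBP counts; only the first two genera are "top" hosts here
genus_totals = {"escherichia": 60, "salmonella": 30, "dickeya": 6, "listeria": 4}
top_genera = {"escherichia", "salmonella"}

rbps_in_top_genera = sum(
    total for genus, total in genus_totals.items() if genus in top_genera
)
all_rbps = sum(genus_totals.values())

# Share of RBPs covered by the top genera: 90 / 100
coverage = rbps_in_top_genera / all_rbps
print(f"{coverage:.2%} of RBPs belong to the top genera")
```

A value close to 1, such as the 0.96 reported above, means that a quarter of the genera accounts for nearly all RBPs, so the host distribution is highly skewed.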

<hr>

## Part III: Benchmarking

Compare the performance of our model with the state-of-the-art phage-host interaction prediction tool by Boeckaerts <i>et al.</i> (2021).

In [9]:
rbp_embeddings = pd.read_csv(f'{constants.INPHARED}/{constants.DATA}/{constants.INPHARED_RBP_DATA}')
rbp_embeddings['Modification Date'] = pd.to_datetime(rbp_embeddings['Modification Date'])

The code in the cell below is taken from https://github.com/dimiboeckaerts/BacteriophageHostPrediction

**Reference**:
Boeckaerts, D., Stock, M., Criel, B., Gerstmans, H., De Baets, B., & Briers, Y. (2021). Predicting bacteriophage hosts based on sequences of annotated receptor-binding proteins. *Scientific Reports, 11*(1), 1467. doi:10.1038/s41598-021-81063-4

In [10]:
dna_list = list(rbp_embeddings.iloc[:,-1])
dna_feats = RBP_f.dna_features(dna_list)

protein_list = list(rbp_embeddings.iloc[:,-2])
protein_feats = RBP_f.protein_features(protein_list)

# Additional protein features: CTDC (3 values) + CTDT (39 values) + Z-scale (5 values) = 47 columns
extra_feats = np.zeros((rbp_embeddings.shape[0], 47))

for i, item in enumerate(protein_list):
    feature_lst = []
    feature_lst += RBP_f.CTDC(item)
    feature_lst += RBP_f.CTDT(item)
    feature_lst += RBP_f.zscale(item)
    extra_feats[i, :] = feature_lst
    
extra_feats_df = pd.DataFrame(extra_feats, columns=['CTDC1', 'CTDC2', 'CTDC3', 'CTDT1', 'CTDT2', 'CTDT3', 
                        'CTDT4', 'CTDT5', 'CTDT6', 'CTDT7', 'CTDT8', 'CTDT9', 'CTDT10', 'CTDT11', 'CTDT12', 'CTDT13', 
                        'CTDT14', 'CTDT15', 'CTDT16', 'CTDT17', 'CTDT18', 'CTDT19', 'CTDT20', 'CTDT21', 'CTDT22',
                        'CTDT23', 'CTDT24', 'CTDT25', 'CTDT26', 'CTDT27', 'CTDT28', 'CTDT29', 'CTDT30', 'CTDT31', 
                        'CTDT32', 'CTDT33', 'CTDT34', 'CTDT35', 'CTDT36', 'CTDT37', 'CTDT38', 'CTDT39', 'Z1', 'Z2',
                        'Z3', 'Z4', 'Z5'])

# all features, not split
features = pd.concat([dna_feats, protein_feats, extra_feats_df], axis=1)
rbp_embeddings = pd.concat([rbp_embeddings, features], axis = 1)
rbp_embeddings.head()

Index,Protein ID,Accession,Description,Classification,Genome Length (bp),Jumbophage,molGC (%),Molecule,Modification Date,Number CDS,Positive Strand (%),Negative Strand (%),Coding Capacity(%),Low Coding Capacity Warning,tRNAs,Host,Lowest Taxa,Genus,Sub-family,Family,Order,Class,Phylum,Kingdom,Realm,...,CTDT20,CTDT21,CTDT22,CTDT23,CTDT24,CTDT25,CTDT26,CTDT27,CTDT28,CTDT29,CTDT30,CTDT31,CTDT32,CTDT33,CTDT34,CTDT35,CTDT36,CTDT37,CTDT38,CTDT39,Z1,Z2,Z3,Z4,Z5
0,BAF36105.1,AB231700,Microcystis virus Ma-LMM01,Microcystis virus Ma-LMM01 Fukuivirus Caudovir...,162109,False,45.953,DNA,2021-07-14,189,34.391534,65.608466,93.542616,,2,microcystis,Fukuivirus,Fukuivirus,Unclassified,Unclassified,Unclassified,Caudoviricetes,Uroviricota,Heunggongvirae,Duplodnaviria,...,0.289209,0.220144,0.271942,0.185612,0.211511,0.208633,0.246043,0.205755,0.264748,0.155396,0.241727,0.155396,0.025899,0.181295,0.208633,0.235971,0.201439,0.292086,0.21295,0.145324,0.162356,-0.329641,-0.20477,-0.290445,0.273951
1,BAF36110.1,AB231700,Microcystis virus Ma-LMM01,Microcystis virus Ma-LMM01 Fukuivirus Caudovir...,162109,False,45.953,DNA,2021-07-14,189,34.391534,65.608466,93.542616,,2,microcystis,Fukuivirus,Fukuivirus,Unclassified,Unclassified,Unclassified,Caudoviricetes,Uroviricota,Heunggongvirae,Duplodnaviria,...,0.266467,0.260479,0.350299,0.134731,0.125749,0.266467,0.212575,0.212575,0.347305,0.101796,0.158683,0.10479,0.002994,0.140719,0.236527,0.182635,0.206587,0.266467,0.239521,0.146707,-0.129463,-0.544507,-0.264687,-0.483254,0.406597
2,BAF36131.1,AB231700,Microcystis virus Ma-LMM01,Microcystis virus Ma-LMM01 Fukuivirus Caudovir...,162109,False,45.953,DNA,2021-07-14,189,34.391534,65.608466,93.542616,,2,microcystis,Fukuivirus,Fukuivirus,Unclassified,Unclassified,Unclassified,Caudoviricetes,Uroviricota,Heunggongvirae,Duplodnaviria,...,0.274648,0.232394,0.380282,0.112676,0.126761,0.246479,0.204225,0.246479,0.401408,0.091549,0.147887,0.084507,0.0,0.183099,0.267606,0.260563,0.21831,0.323944,0.232394,0.112676,0.035315,-0.836783,-0.17,-0.553846,0.415035
3,BAF36132.1,AB231700,Microcystis virus Ma-LMM01,Microcystis virus Ma-LMM01 Fukuivirus Caudovir...,162109,False,45.953,DNA,2021-07-14,189,34.391534,65.608466,93.542616,,2,microcystis,Fukuivirus,Fukuivirus,Unclassified,Unclassified,Unclassified,Caudoviricetes,Uroviricota,Heunggongvirae,Duplodnaviria,...,0.161417,0.307087,0.314961,0.153543,0.173228,0.232283,0.181102,0.228346,0.279528,0.133858,0.192913,0.102362,0.007874,0.102362,0.173228,0.216535,0.259843,0.228346,0.244094,0.173228,-0.017569,-0.509137,-0.090314,-0.423137,0.288863
4,BAF36193.1,AB231700,Microcystis virus Ma-LMM01,Microcystis virus Ma-LMM01 Fukuivirus Caudovir...,162109,False,45.953,DNA,2021-07-14,189,34.391534,65.608466,93.542616,,2,microcystis,Fukuivirus,Fukuivirus,Unclassified,Unclassified,Unclassified,Caudoviricetes,Uroviricota,Heunggongvirae,Duplodnaviria,...,0.262931,0.25431,0.317888,0.15625,0.135776,0.274784,0.21444,0.177802,0.330819,0.123922,0.168103,0.136853,0.018319,0.164871,0.226293,0.230603,0.226293,0.246767,0.262931,0.134698,-0.10099,-0.675307,-0.202917,-0.397621,0.265425


Generate the phage-host-features CSV file.

In [11]:
rbp_embeddings.to_csv(os.path.join(f'{constants.INPHARED}/{constants.DATA}', constants.PLM_EMBEDDINGS_CSV['BOECKAERTS']), 
                      index = False)

In [12]:
rbp_embeddings = pd.read_csv(f'{constants.INPHARED}/{constants.DATA}/{constants.PLM_EMBEDDINGS_CSV["BOECKAERTS"]}',
                             low_memory = False)
features = rbp_embeddings.columns
features = features[constants.INPHARED_EXTRA_COLS:]

Train the model, and display the results.

In [13]:
important_features = util.classify(models[-1], display_feature_importance = True, feature_columns = list(features))

Constructing training and test sets...
Training the model...
Saving evaluation results...
Confidence threshold k: 0.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.3529    0.5217        17
    acinetobacter     0.9839    0.6224    0.7625        98
        aeromonas     0.9444    0.5368    0.6846        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.5000    0.6667        26
         bacillus     0.7155    0.8382    0.7720       204
      bacteroides     1.0000    0.9355    0.9667        31
    brevundimonas     1.0000    0.5000    0.6667        10
     burkholderia     0.7188    0.5897    0.6479        39
    campylobacter     0.8240    0.8306    0.8273       124
      caulobacter     0.8400    0.8400    0.8400        25
      citrobacter     0.5000    0.2647    0.3462        34
      clostridium     0.4000    0.5833    0.4746        24
      cronobacter     0.7368    0.3889    0.5091      

Confidence threshold k: 30.0%
                   precision    recall  f1-score   support

    achromobacter     1.0000    0.2353    0.3810        17
    acinetobacter     1.0000    0.3878    0.5588        98
        aeromonas     0.9375    0.3158    0.4724        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.3462    0.5143        26
         bacillus     0.9500    0.6520    0.7733       204
      bacteroides     1.0000    0.8387    0.9123        31
    brevundimonas     1.0000    0.4000    0.5714        10
     burkholderia     1.0000    0.2051    0.3404        39
    campylobacter     0.9500    0.7661    0.8482       124
      caulobacter     1.0000    0.7200    0.8372        25
      citrobacter     0.6667    0.1765    0.2791        34
      clostridium     0.5882    0.4167    0.4878        24
      cronobacter     0.7143    0.1389    0.2326        36
          dickeya     1.0000    0.6346    0.7765        52
     edwardsiella     1.0

Confidence threshold k: 60.0%
                   precision    recall  f1-score   support

    achromobacter     0.0000    0.0000    0.0000        17
    acinetobacter     1.0000    0.1837    0.3103        98
        aeromonas     1.0000    0.2211    0.3621        95
    agrobacterium     1.0000    0.3333    0.5000        12
     arthrobacter     1.0000    0.2308    0.3750        26
         bacillus     0.9691    0.4608    0.6246       204
      bacteroides     1.0000    0.4839    0.6522        31
    brevundimonas     1.0000    0.1000    0.1818        10
     burkholderia     1.0000    0.1026    0.1860        39
    campylobacter     0.9762    0.6613    0.7885       124
      caulobacter     1.0000    0.6000    0.7500        25
      citrobacter     0.0000    0.0000    0.0000        34
      clostridium     0.6250    0.2083    0.3125        24
      cronobacter     1.0000    0.0278    0.0541        36
          dickeya     1.0000    0.4423    0.6133        52
     edwardsiella     0.0

Confidence threshold k: 90.0%
                   precision    recall  f1-score   support

    achromobacter     0.0000    0.0000    0.0000        17
    acinetobacter     1.0000    0.0204    0.0400        98
        aeromonas     1.0000    0.0316    0.0612        95
    agrobacterium     0.0000    0.0000    0.0000        12
     arthrobacter     1.0000    0.0769    0.1429        26
         bacillus     1.0000    0.1127    0.2026       204
      bacteroides     1.0000    0.3548    0.5238        31
    brevundimonas     0.0000    0.0000    0.0000        10
     burkholderia     1.0000    0.0513    0.0976        39
    campylobacter     0.9459    0.2823    0.4348       124
      caulobacter     1.0000    0.2000    0.3333        25
      citrobacter     0.0000    0.0000    0.0000        34
      clostridium     1.0000    0.0833    0.1538        24
      cronobacter     0.0000    0.0000    0.0000        36
          dickeya     1.0000    0.3654    0.5352        52
     edwardsiella     0.0
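
The reports above are grouped by a confidence threshold k. One plausible reading, sketched here under that assumption, is that a prediction is kept only when the classifier's highest class probability reaches k and is otherwise replaced by a catch-all label. The function name, the `others` label, and the toy probabilities are hypothetical; `util.classify` may implement this differently.

```python
import numpy as np

def predict_with_threshold(proba, classes, k, fallback="others"):
    """Return the most probable class per row, or `fallback` if its probability is below k."""
    proba = np.asarray(proba, dtype=float)
    best = proba.argmax(axis=1)
    # Fancy indexing returns a copy, so the class list itself is not modified.
    labels = np.asarray(classes, dtype=object)[best]
    labels[proba.max(axis=1) < k] = fallback
    return labels.tolist()

classes = ["bacillus", "vibrio"]
proba = [[0.90, 0.10],
         [0.55, 0.45],
         [0.20, 0.80]]

print(predict_with_threshold(proba, classes, k=0.0))  # ['bacillus', 'bacillus', 'vibrio']
print(predict_with_threshold(proba, classes, k=0.6))  # ['bacillus', 'others', 'vibrio']
```

Under this reading, raising k discards low-confidence predictions, which would be consistent with the trend in the reports above: per-class precision tends to rise while recall falls.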

<hr>

## Part IV: Handpicked Features + Embeddings

### Features with Highest Importance (Global)

Investigate the performance when the ProtT5 embeddings are combined with a single sequence property, trying in turn each of the five properties that registered the highest Gini importance after training the phage-host interaction prediction tool by Boeckaerts <i>et al.</i> (2021).

In [14]:
for i in range(5):
    util.classify_handpicked_embeddings([important_features[i]], important_features[i])

Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...
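
Combining embeddings with a handcrafted property can be pictured as a column-wise concatenation of two feature tables. The sketch below illustrates this on toy data; `util.classify_handpicked_embeddings` is not shown in this notebook, so the `combine` helper, the `emb_*` columns, and all values are hypothetical.

```python
import pandas as pd

# Toy embedding columns (a real protein embedding has many more dimensions).
embeddings = pd.DataFrame({"emb_0": [0.1, 0.2], "emb_1": [0.3, 0.4]})

# Toy handcrafted properties; the values are made up.
handcrafted = pd.DataFrame({"GC": [0.45, 0.52], "A_freq": [0.27, 0.24]})

def combine(embeddings, handcrafted, selected):
    """Append the selected handcrafted columns to the embedding columns."""
    return pd.concat([embeddings, handcrafted[list(selected)]], axis=1)

single = combine(embeddings, handcrafted, ["GC"])
print(list(single.columns))  # ['emb_0', 'emb_1', 'GC']
```

Passing a one-element list corresponds to the "individual" setting; passing several names corresponds to adding multiple properties at once.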


Investigate the performance when the ProtT5 embeddings are combined with the **<i>n</i> sequence properties** that registered the highest Gini importance (for <i>n</i> = 1 to 5) after training the phage-host interaction prediction tool by Boeckaerts <i>et al.</i> (2021).

In [15]:
for i in range(1, 6):
    util.classify_handpicked_embeddings(important_features[:i], f'handcrafted_{i}')

Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...
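
For readers unfamiliar with Gini importance, the following sketch shows how a top-n ranking could be obtained from a tree ensemble. It assumes scikit-learn's `RandomForestClassifier`, whose `feature_importances_` attribute reports impurity-based (Gini) importance; the synthetic data and the column names `signal`, `noise_a`, and `noise_b` are invented for illustration and do not come from the benchmarked model.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
n = 200
y = rng.integers(0, 2, size=n)

# One informative column (closely tracks the label) and two pure-noise columns.
X = pd.DataFrame({
    "noise_a": rng.normal(size=n),
    "signal": y + rng.normal(scale=0.1, size=n),
    "noise_b": rng.normal(size=n),
})

clf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

# Rank features by impurity-based importance, highest first.
ranked = pd.Series(clf.feature_importances_, index=X.columns).sort_values(ascending=False)
top_n = list(ranked.index[:2])

print(top_n[0])  # signal
```

Slicing the ranked names (`ranked.index[:n]`) for increasing n gives nested feature sets of the kind explored in the loop above.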


### Protein Features with Highest Importance

Investigate the performance when the ProtT5 embeddings are combined with a single **protein** sequence property, trying in turn each of the five protein properties that registered the highest Gini importance after training the phage-host interaction prediction tool by Boeckaerts <i>et al.</i> (2021).

In [16]:
important_protein_features = ['K', 'pI', 'Z4', 'CTDC2', 'mol_weight']

In [17]:
for i in range(5):
    util.classify_handpicked_embeddings([important_protein_features[i]], important_protein_features[i])

Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...


Investigate the performance when the ProtT5 embeddings are combined with the **<i>n</i> protein sequence properties** that registered the highest Gini importance (for <i>n</i> = 1 to 5) after training the phage-host interaction prediction tool by Boeckaerts <i>et al.</i> (2021).

In [18]:
for i in range(1, 6):
    util.classify_handpicked_embeddings(important_protein_features[:i], f'handcrafted_protein_{i}')

Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...
Constructing training and test sets...
Training the model...
Saving evaluation results...


<hr>

## Part V: Results

Load the pickle files that store the evaluation results for each model.

In [19]:
model_results = []
for model in models:
    with open(constants.PLM_RESULTS[model], 'rb') as f:
        model_results.append(pickle.load(f))

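Declare constants for indexing into the loaded results, which keeps the subsequent code readable.

In [20]:
# Result type (per-class or averaged): second index, after the threshold index
CLASS = 0
MICRO = 1
MACRO = 2
WEIGHTED = 3

# Metric: third index, within the chosen result type
PRECISION = 0
RECALL = 1
F1 = 2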

The shaded cells in the subsequent tables mark the highest score in each column (i.e., the best performance at that confidence threshold for the specified evaluation metric).
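
The tables that follow report weighted precision, recall, and F1. As a reminder of what "weighted" usually means, the sketch below computes support-weighted averages of per-class scores in plain Python. It assumes the stored `WEIGHTED` entries follow this common convention (the one used by scikit-learn's classification report); the function name and toy labels are hypothetical.

```python
from collections import Counter

def weighted_scores(y_true, y_pred):
    """Return (precision, recall, F1), each averaged over classes weighted by support."""
    support = Counter(y_true)
    total = len(y_true)
    precision = recall = f1 = 0.0
    for label, n_true in support.items():
        tp = sum(t == p == label for t, p in zip(y_true, y_pred))
        n_pred = sum(p == label for p in y_pred)
        p_c = tp / n_pred if n_pred else 0.0
        r_c = tp / n_true
        f_c = 2 * p_c * r_c / (p_c + r_c) if (p_c + r_c) else 0.0
        weight = n_true / total  # share of true samples belonging to this class
        precision += weight * p_c
        recall += weight * r_c
        f1 += weight * f_c
    return precision, recall, f1

y_true = ["a", "a", "a", "b"]
y_pred = ["a", "a", "b", "b"]
print(tuple(round(v, 4) for v in weighted_scores(y_true, y_pred)))  # (0.875, 0.75, 0.7667)
```

The returned tuple is ordered precision, recall, F1, matching the `PRECISION`, `RECALL`, and `F1` indices declared above.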

### Protein Language Models vs Handcrafted

In [21]:
results = []
for model in models:
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[models.index(model)][threshold][WEIGHTED][PRECISION] * 100)
        
        result.append(f'{metric}%')
        
    results.append(result)

print("Weighted Precision")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = models)
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted Precision


Model,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
PROTTRANSBERT,63.26%,79.31%,82.93%,83.89%,84.12%,84.52%,84.96%,84.55%,84.20%,83.69%,73.74%
PROTXLNET,62.18%,79.28%,82.92%,83.93%,84.23%,84.61%,85.66%,84.67%,84.06%,82.80%,77.07%
PROTTRANSALBERT,62.63%,80.24%,83.37%,84.26%,84.22%,84.65%,84.91%,84.61%,83.97%,83.03%,76.06%
PROTT5,63.57%,80.47%,82.88%,84.13%,84.38%,84.67%,85.43%,84.98%,84.32%,83.51%,77.23%
ESM,63.32%,80.15%,82.67%,83.89%,84.49%,84.66%,85.31%,84.62%,83.66%,82.68%,76.51%
ESM1B,63.38%,79.80%,83.27%,84.11%,84.29%,84.75%,84.71%,84.99%,84.40%,83.42%,76.89%
SEQVEC,63.30%,80.52%,83.30%,84.31%,84.80%,84.99%,85.65%,84.56%,84.26%,83.83%,73.20%
BOECKAERTS,64.98%,79.65%,82.99%,84.57%,84.97%,85.76%,85.55%,84.88%,84.05%,83.58%,73.09%


In [22]:
results = []
for model in models:
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[models.index(model)][threshold][WEIGHTED][RECALL] * 100)
        
        result.append(f'{metric}%')
        
    results.append(result)

print("Weighted Recall")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = models)
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted Recall


Model,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
PROTTRANSBERT,69.70%,73.57%,71.66%,68.36%,64.82%,61.27%,56.65%,51.60%,46.16%,39.48%,27.02%
PROTXLNET,68.31%,72.12%,69.99%,67.11%,64.17%,60.46%,56.32%,50.95%,45.50%,38.50%,26.56%
PROTTRANSALBERT,68.97%,73.60%,70.91%,67.41%,64.17%,60.71%,56.67%,51.21%,45.53%,38.90%,26.38%
PROTT5,70.91%,75.85%,73.41%,70.39%,67.05%,63.38%,59.15%,53.72%,48.57%,41.03%,27.16%
ESM,70.54%,75.07%,72.55%,69.68%,66.56%,63.13%,58.33%,53.38%,47.92%,40.69%,27.34%
ESM1B,70.81%,75.20%,73.41%,70.01%,66.68%,62.95%,58.49%,53.30%,47.86%,40.59%,27.24%
SEQVEC,69.20%,73.40%,71.13%,68.04%,64.79%,61.26%,56.73%,51.22%,45.97%,38.87%,25.89%
BOECKAERTS,70.33%,73.95%,71.48%,68.26%,64.21%,60.24%,55.48%,49.84%,44.50%,37.81%,24.84%


In [23]:
results = []
for model in models:
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[models.index(model)][threshold][WEIGHTED][F1] * 100)
        
        result.append(f'{metric}%')
        
    results.append(result)

print("Weighted F1")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = models)
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted F1


Model,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
PROTTRANSBERT,65.00%,74.39%,73.61%,71.07%,68.05%,64.87%,60.58%,55.27%,49.29%,41.07%,24.59%
PROTXLNET,63.71%,73.31%,72.18%,69.93%,67.44%,64.11%,60.18%,54.54%,48.46%,40.20%,23.92%
PROTTRANSALBERT,64.31%,74.57%,73.01%,70.34%,67.42%,64.32%,60.45%,54.82%,48.51%,40.57%,23.80%
PROTT5,66.10%,76.51%,74.96%,72.78%,70.03%,66.78%,62.95%,57.51%,51.98%,43.05%,24.82%
ESM,65.72%,75.77%,74.24%,72.12%,69.60%,66.60%,62.21%,57.13%,51.18%,42.70%,25.18%
ESM1B,65.92%,75.85%,75.07%,72.46%,69.73%,66.40%,62.27%,57.16%,51.29%,42.66%,25.04%
SEQVEC,64.56%,74.43%,73.14%,70.76%,68.10%,64.97%,60.64%,54.90%,49.12%,40.52%,23.20%
BOECKAERTS,65.64%,74.47%,73.33%,70.84%,67.39%,63.89%,59.35%,53.41%,47.20%,38.97%,21.61%
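
The table-building cells in this part repeat the same loop with a different metric each time. One way to factor that out is sketched below; `metric_table` and the toy results are hypothetical, and the nested layout `results[threshold][average][metric]` is assumed from how `model_results` is indexed above.

```python
import pandas as pd

WEIGHTED = 3
PRECISION, RECALL, F1 = 0, 1, 2

def metric_table(results_by_name, metric, average=WEIGHTED):
    """Build one row per model and one column per confidence threshold (0% to 100%)."""
    columns = [f"{k}%" for k in range(0, 101, 10)]
    rows = {
        name: [f"{res[t][average][metric] * 100:.2f}%" for t in range(11)]
        for name, res in results_by_name.items()
    }
    return pd.DataFrame.from_dict(rows, orient="index", columns=columns)

def fake_results(p, r, f):
    # One entry per threshold; slots 0-2 (class, micro, macro) are left empty in this toy.
    return [[None, None, None, (p, r, f)] for _ in range(11)]

toy = {"MODEL_A": fake_results(0.5, 0.25, 0.3), "MODEL_B": fake_results(0.75, 0.5, 0.6)}
table = metric_table(toy, RECALL)
print(table.loc["MODEL_A", "0%"])  # 25.00%
```

Because the cells hold formatted strings, any maximum taken over them is a string comparison; keeping the values numeric and formatting only at display time (for example with pandas' `Styler.format`) is one way to make such comparisons numeric.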


### Global: ProtT5 + Handcrafted (Individually)

In [24]:
model_results = []
for model in important_features[:5]:
    with open(f'{constants.TEMP_RESULTS}/prott5_{model}.pickle', 'rb') as f:
        model_results.append(pickle.load(f))

In [25]:
results = []
for model in important_features[:5]:
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[important_features[:5].index(model)][threshold][WEIGHTED][PRECISION] * 100)
        result.append(f'{metric}%')

    results.append(result)
        
print("Weighted Precision")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = important_features[:5])
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted Precision


Feature,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
A_freq,63.80%,80.24%,83.29%,84.03%,84.42%,84.74%,85.86%,85.05%,84.48%,83.47%,76.52%
GC,63.94%,80.30%,83.21%,84.05%,84.50%,84.84%,85.89%,84.93%,84.22%,84.10%,75.47%
C_freq,64.08%,80.61%,83.05%,84.20%,84.49%,84.69%,85.44%,84.80%,84.43%,83.79%,77.00%
TTA,63.83%,80.22%,83.17%,84.06%,84.57%,84.73%,85.46%,84.81%,84.34%,83.60%,77.03%
TTA_b,63.56%,80.65%,83.30%,83.86%,84.47%,84.77%,85.69%,84.78%,84.89%,84.01%,75.16%


In [26]:
results = []
for model in important_features[:5]:
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[important_features[:5].index(model)][threshold][WEIGHTED][RECALL] * 100)
        result.append(f'{metric}%')

    results.append(result)
        
print("Weighted Recall")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = important_features[:5])
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted Recall


Feature,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
A_freq,71.39%,76.27%,74.05%,70.93%,67.76%,64.11%,59.71%,54.51%,49.14%,41.29%,27.16%
GC,71.48%,76.50%,74.30%,71.17%,67.89%,64.42%,60.20%,54.77%,49.15%,41.74%,27.70%
C_freq,71.51%,76.38%,74.32%,71.18%,68.05%,64.58%,59.91%,54.62%,48.98%,41.38%,27.24%
TTA,71.30%,75.86%,73.55%,70.58%,67.36%,63.85%,59.51%,53.77%,48.39%,41.01%,27.23%
TTA_b,71.11%,76.21%,73.77%,70.65%,67.11%,63.79%,59.41%,53.78%,48.32%,40.99%,27.07%


In [27]:
results = []
for model in important_features[:5]:
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[important_features[:5].index(model)][threshold][WEIGHTED][F1] * 100)
        result.append(f'{metric}%')

    results.append(result)
        
print("Weighted F1")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = important_features[:5])
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted F1


Feature,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
A_freq,66.41%,76.78%,75.58%,73.23%,70.63%,67.47%,63.49%,58.36%,52.56%,43.37%,24.93%
GC,66.49%,76.92%,75.76%,73.42%,70.75%,67.74%,63.91%,58.60%,52.60%,43.99%,25.57%
C_freq,66.63%,76.93%,75.77%,73.43%,70.83%,67.87%,63.57%,58.47%,52.44%,43.48%,25.09%
TTA,66.39%,76.50%,75.15%,72.99%,70.33%,67.26%,63.25%,57.60%,51.75%,43.01%,24.97%
TTA_b,66.19%,76.81%,75.33%,72.99%,70.07%,67.18%,63.22%,57.56%,51.70%,43.05%,24.63%


### Protein Features Only: ProtT5 + Handcrafted (Individually)

In [28]:
model_results = []
for model in important_protein_features[:5]:
    with open(f'{constants.TEMP_RESULTS}/prott5_{model}.pickle', 'rb') as f:
        model_results.append(pickle.load(f))

In [29]:
results = []
for model in important_protein_features[:5]:
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[important_protein_features[:5].index(model)][threshold][WEIGHTED][PRECISION] * 100)
        result.append(f'{metric}%')

    results.append(result)
        
print("Weighted Precision")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = important_protein_features[:5])
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted Precision


Feature,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
K,63.62%,80.09%,83.08%,84.27%,84.53%,84.62%,85.54%,84.87%,84.24%,83.98%,78.06%
pI,63.82%,80.51%,83.13%,84.28%,84.55%,84.46%,85.28%,85.34%,84.40%,83.71%,77.09%
Z4,63.69%,80.09%,83.18%,83.99%,84.33%,84.60%,85.37%,84.77%,84.32%,83.79%,77.44%
CTDC2,63.76%,79.91%,83.14%,84.07%,84.28%,84.62%,84.95%,84.96%,84.34%,83.47%,76.82%
mol_weight,63.84%,80.20%,83.24%,84.14%,84.45%,84.46%,85.17%,85.06%,84.33%,83.26%,75.83%


In [30]:
results = []
for model in important_protein_features[:5]:
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[important_protein_features[:5].index(model)][threshold][WEIGHTED][RECALL] * 100)
        result.append(f'{metric}%')

    results.append(result)
        
print("Weighted Recall")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = important_protein_features[:5])
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted Recall


Feature,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
K,70.92%,75.73%,73.58%,70.68%,67.20%,63.64%,59.22%,53.67%,48.63%,41.13%,27.23%
pI,71.00%,75.91%,73.41%,70.45%,67.04%,63.39%,59.13%,53.55%,48.44%,40.93%,27.51%
Z4,70.81%,75.64%,73.35%,70.59%,67.42%,63.54%,59.09%,53.63%,48.18%,41.01%,27.23%
CTDC2,70.87%,75.68%,73.51%,70.49%,67.18%,63.71%,59.01%,53.82%,48.36%,40.88%,27.58%
mol_weight,70.92%,75.76%,73.36%,70.52%,67.26%,63.27%,59.22%,53.87%,48.46%,40.86%,27.34%


In [31]:
results = []
for model in important_protein_features[:5]:
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[important_protein_features[:5].index(model)][threshold][WEIGHTED][F1] * 100)
        result.append(f'{metric}%')

    results.append(result)
        
print("Weighted F1")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = important_protein_features[:5])
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted F1


Feature,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
K,66.10%,76.32%,75.15%,73.03%,70.18%,66.96%,63.00%,57.49%,52.00%,43.13%,25.07%
pI,66.24%,76.57%,75.02%,72.86%,70.02%,66.76%,62.90%,57.36%,51.76%,43.01%,25.27%
Z4,66.03%,76.32%,74.95%,72.92%,70.31%,66.91%,62.86%,57.40%,51.48%,43.06%,24.97%
CTDC2,66.10%,76.26%,75.12%,72.84%,70.14%,67.09%,62.78%,57.65%,51.64%,42.91%,25.42%
mol_weight,66.16%,76.35%,75.02%,72.94%,70.26%,66.66%,62.98%,57.74%,51.79%,42.87%,25.15%


### Global: ProtT5 + Handcrafted (Top n)

In [32]:
model_results = []
for i in range(1, 6):
    with open(f'{constants.TEMP_RESULTS}/prott5_handcrafted_{i}.pickle', 'rb') as f:
        model_results.append(pickle.load(f))

In [33]:
results = []
for i in range(1, 6):
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[i - 1][threshold][WEIGHTED][PRECISION] * 100)
        result.append(f'{metric}%')

    results.append(result)
        
print("Weighted Precision")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = [i for i in range(1, 6)])
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted Precision


n,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
1,63.80%,80.24%,83.29%,84.03%,84.42%,84.74%,85.86%,85.05%,84.48%,83.47%,76.52%
2,64.17%,80.40%,83.04%,84.33%,84.82%,84.86%,85.80%,85.01%,84.59%,84.14%,76.98%
3,64.42%,80.47%,83.11%,84.38%,84.53%,84.97%,85.99%,85.18%,84.44%,84.31%,78.18%
4,64.21%,80.49%,83.00%,84.46%,84.78%,84.94%,85.46%,85.44%,84.41%,84.26%,74.61%
5,64.55%,80.41%,82.87%,84.32%,84.70%,85.02%,85.31%,85.58%,84.65%,84.11%,77.88%


In [34]:
results = []
for i in range(1, 6):
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[i - 1][threshold][WEIGHTED][RECALL] * 100)
        result.append(f'{metric}%')

    results.append(result)
        
print("Weighted Recall")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = [i for i in range(1, 6)])
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted Recall


n,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
1,71.39%,76.27%,74.05%,70.93%,67.76%,64.11%,59.71%,54.51%,49.14%,41.29%,27.16%
2,71.83%,76.54%,74.56%,71.64%,68.20%,64.85%,60.55%,54.95%,49.22%,41.51%,27.37%
3,71.85%,76.50%,74.42%,71.66%,68.21%,64.77%,60.45%,55.01%,49.37%,41.93%,27.39%
4,71.77%,76.74%,74.47%,71.61%,68.48%,65.08%,60.31%,55.04%,49.25%,41.67%,26.81%
5,72.09%,76.58%,74.38%,71.82%,68.37%,65.07%,60.33%,54.87%,49.33%,41.94%,27.13%


In [35]:
results = []
for i in range(1, 6):
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[i - 1][threshold][WEIGHTED][F1] * 100)
        result.append(f'{metric}%')

    results.append(result)
        
print("Weighted F1")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = [i for i in range(1, 6)])
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted F1


n,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
1,66.41%,76.78%,75.58%,73.23%,70.63%,67.47%,63.49%,58.36%,52.56%,43.37%,24.93%
2,66.79%,76.97%,75.95%,73.87%,71.04%,68.08%,64.26%,58.76%,52.70%,43.75%,25.13%
3,66.82%,76.91%,75.82%,73.84%,71.02%,68.07%,64.17%,58.83%,52.81%,44.20%,25.33%
4,66.72%,77.14%,75.81%,73.83%,71.30%,68.34%,64.03%,58.83%,52.67%,43.85%,24.44%
5,67.03%,77.04%,75.79%,73.98%,71.17%,68.33%,64.01%,58.68%,52.77%,44.21%,24.80%


### Protein Features Only: ProtT5 + Handcrafted (Top n)

In [36]:
model_results = []
for i in range(1, 6):
    with open(f'{constants.TEMP_RESULTS}/prott5_handcrafted_protein_{i}.pickle', 'rb') as f:
        model_results.append(pickle.load(f))

In [37]:
results = []
for i in range(1, 6):
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[i - 1][threshold][WEIGHTED][PRECISION] * 100)
        result.append(f'{metric}%')

    results.append(result)
        
print("Weighted Precision")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = [i for i in range(1, 6)])
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted Precision


n,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
1,63.62%,80.09%,83.08%,84.27%,84.53%,84.62%,85.54%,84.87%,84.24%,83.98%,78.06%
2,63.57%,80.56%,83.14%,84.02%,84.49%,84.78%,85.53%,84.88%,84.29%,83.24%,75.05%
3,63.82%,80.40%,83.27%,84.00%,84.30%,84.79%,85.19%,85.32%,84.12%,83.90%,78.62%
4,63.67%,80.24%,83.21%,84.28%,84.50%,84.67%,85.18%,84.98%,83.99%,83.86%,77.62%
5,63.68%,80.34%,83.25%,83.93%,84.33%,84.40%,85.42%,85.29%,84.51%,83.83%,75.81%


In [38]:
results = []
for i in range(1, 6):
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[i - 1][threshold][WEIGHTED][RECALL] * 100)
        result.append(f'{metric}%')

    results.append(result)
        
print("Weighted Recall")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = [i for i in range(1, 6)])
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted Recall


n,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
1,70.92%,75.73%,73.58%,70.68%,67.20%,63.64%,59.22%,53.67%,48.63%,41.13%,27.23%
2,70.85%,76.12%,73.56%,70.39%,67.26%,63.44%,59.23%,53.82%,48.47%,40.91%,27.37%
3,70.95%,75.79%,73.63%,70.50%,66.94%,63.60%,59.15%,53.63%,48.10%,41.03%,27.44%
4,70.92%,75.79%,73.57%,70.69%,67.41%,63.58%,59.19%,53.84%,48.34%,41.26%,27.53%
5,70.81%,75.63%,73.60%,70.50%,67.07%,63.48%,58.99%,54.00%,48.51%,40.89%,27.11%


In [39]:
results = []
for i in range(1, 6):
    result = []
    for threshold in range(0, 11):
        metric = "{:.2f}".format(model_results[i - 1][threshold][WEIGHTED][F1] * 100)
        result.append(f'{metric}%')

    results.append(result)
        
print("Weighted F1")
results_df = pd.DataFrame(results, columns = [str(_) + '%' for _ in range(0, 101, 10)], index = [i for i in range(1, 6)])
results_df.style.highlight_max(color = 'lightgreen', axis = 0)

Weighted F1


n,0%,10%,20%,30%,40%,50%,60%,70%,80%,90%,100%
1,66.10%,76.32%,75.15%,73.03%,70.18%,66.96%,63.00%,57.49%,52.00%,43.13%,25.07%
2,66.07%,76.69%,75.16%,72.78%,70.24%,66.85%,62.96%,57.57%,51.79%,42.87%,25.17%
3,66.18%,76.44%,75.23%,72.86%,69.93%,67.02%,62.95%,57.47%,51.38%,43.09%,25.19%
4,66.11%,76.40%,75.18%,73.06%,70.36%,66.95%,62.92%,57.63%,51.67%,43.36%,25.34%
5,66.05%,76.29%,75.16%,72.83%,70.01%,66.86%,62.79%,57.82%,51.83%,42.86%,24.80%
