In [1]:
import pandas as pd
import sys
sys.path.append('../')
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

## SignalP 5.0 Benchmark set (Eukarya Subset)

Evaluate the performance of custom models on the SignalP 5.0 benchmark set, to see whether it is comparable to SignalP.  
__Important:__ This is only an initial order-of-magnitude check. Models still need to pass a proper evaluation, either on a validation set or via cross-validation.

In [2]:
from train_scripts.utils.signalp_dataset import PartitionThreeLineFastaDataset

# Benchmark FASTA; kingdom_id=['EUKARYA'] presumably restricts the dataset to
# the Eukarya entries (see the section title) -- confirm in the dataset class.
ds = PartitionThreeLineFastaDataset(
    '../data/signalp_5_data/benchmark_set.fasta',
    kingdom_id=['EUKARYA'],
)

# Batches of 20, assembled by the dataset's own collate function.
dl = torch.utils.data.DataLoader(ds, batch_size=20, collate_fn=ds.collate_fn)

#### To be solved
- Why do the benchmark set sequences have a partition number?

#### AWDLSTM-CRF Eukarya performance

In [3]:
from models.sp_tagging_awd_lstm import ProteinAWDLSTMSequenceTaggingCRF

# Restore the tagging model from a saved checkpoint directory.
# NOTE(review): the directory is named 'best_mlp_tagging_model' while the class
# loaded here is the CRF variant -- confirm this is the intended checkpoint.
model = ProteinAWDLSTMSequenceTaggingCRF.from_pretrained(
    '../model_checkpoints/sp_prediction_models/best_mlp_tagging_model/'
)

In [4]:
# Per-batch accumulators. Only the two global_* lists are concatenated into
# tensors at the end of this cell; the others stay as lists of batch outputs.
losses = []
global_probs, probs, pos_preds = [], [], []
global_targets, targets = [], []

model.eval()
for batch_idx, batch in tqdm(enumerate(dl), total=len(dl)):
    target = batch[1]
    # Sequence-level label: 1 when any position of the row equals 0, else 0
    # (tag 0 presumably marks signal-peptide positions -- TODO confirm).
    global_target = (target == 0).any(axis=1) * 1

    # Forward pass only; no gradients are tracked.
    with torch.no_grad():
        outputs = model(batch[0], targets=target, global_targets=global_target)
    loss, global_prob, prob, pos_pred = outputs

    losses.append(loss)
    global_probs.append(global_prob)
    probs.append(prob)
    pos_preds.append(pos_pred)
    targets.append(target)
    global_targets.append(global_target)

# Stack the sequence-level labels and probabilities across all batches.
global_targets = torch.cat(global_targets)
global_probs = torch.cat(global_probs)

HBox(children=(FloatProgress(value=0.0, max=373.0), HTML(value='')))




In [92]:
# NOTE: taking the argmax is fine while the global output is binary
# (= Eukarya model); for more classes this needs to be revisited.
from sklearn.metrics import matthews_corrcoef, accuracy_score, precision_score

# Predicted class per sequence = index of the largest value along axis 1.
detection_preds = np.argmax(global_probs.numpy(), axis=1)
mcc_detection = matthews_corrcoef(global_targets.numpy(), detection_preds)
print(f'MCC for SP detection : {mcc_detection:.3f}')

MCC for SP detection : 0.902


In [53]:
from analysis_utils import tagged_seq_to_cs

# Convert per-position tags (true and predicted) into one value per sequence
# via tagged_seq_to_cs (presumably the cleavage-site position -- see helper).
stacked_targets = torch.cat(targets).numpy()
stacked_preds = torch.cat(pos_preds).numpy()
true_cs = tagged_seq_to_cs(stacked_targets)
pred_cs = tagged_seq_to_cs(stacked_preds)

# The metric functions cannot work with NaN, so NaN entries are replaced
# in place by the sentinel value -1.
true_missing = np.isnan(true_cs)
true_cs[true_missing] = -1
pred_missing = np.isnan(pred_cs)
pred_cs[pred_missing] = -1

In [97]:
# Select the entries whose true value is not NaN.
# NOTE(review): if the NaN -> -1 replacement cell has already run, true_cs has
# no NaN left and this mask keeps every entry (the -1 sentinel is then scored
# as a regular class) -- confirm that this is intended.
cs_mask = ~np.isnan(true_cs)
cs_true_sel = true_cs[cs_mask]
cs_pred_sel = pred_cs[cs_mask]

mcc_cs = matthews_corrcoef(cs_true_sel, cs_pred_sel)
#precision = precision_score(cs_true_sel, cs_pred_sel)
accuracy = accuracy_score(cs_true_sel, cs_pred_sel)
print(f'MCC for CS tagging (no tolerance window) : {mcc_cs:.3f}')
print(f'Accuracy: {accuracy:.3f}')

MCC for CS tagging (no tolerance window) : 0.687
Accuracy: 0.981


In [None]:
%%script false
#generate .fasta format for SignalP web server
with open('euk_benchmark.fasta', 'w') as f:
    for x,y  in zip(ds.identifiers[5000:], ds.sequences[5000:]):
        f.write(x)
        f.write('\n')
        f.write(y)
        f.write('\n')

### SignalP webserver results on same dataset

In [87]:
# SignalP web server output for the same benchmark sequences (tab-separated).
df = pd.read_csv('signalp_results_benchmark_set_euk.tsv', sep='\t')

# Turn the textual 'Prediction' column into integer category codes.
# NOTE(review): codes follow the sorted order of the distinct labels, so this
# only lines up with global_targets (0/1) if exactly two labels occur and the
# negative label sorts first -- verify against the file contents.
signalp_detection = df['Prediction'].astype('category').cat.codes
mcc_detection = matthews_corrcoef(global_targets.numpy(), signalp_detection)

print(f'MCC for SP detection : {mcc_detection:.3f}')

MCC for SP detection : 0.920


In [88]:
# Parse the cleavage-site position out of the 'CS Position' text column: the
# run of digits directly before the first '-' is captured.
# The pattern uses [0-9]+ rather than [0-9]{1,2}: with a two-digit limit a
# position such as "100-101" would be captured as "00" and silently mis-scored.
cs_pos_raw = df['CS Position'].str.extract(r'([0-9]+)-')

# Rows without a match come back as NaN; convert the captured strings to
# numbers, mark missing positions with the sentinel -1 (same convention as
# true_cs / pred_cs) and return an integer array of shape (n_rows, 1).
cs_pos = cs_pos_raw.apply(pd.to_numeric).fillna(-1).values.astype(int)

In [98]:
# Same evaluation as for the model predictions, now scoring the positions
# parsed from the SignalP output (cs_pos) against true_cs.
# NOTE(review): once NaN in true_cs has been replaced by -1, this mask keeps
# every entry -- confirm that this is intended.
signalp_mask = ~np.isnan(true_cs)
cs_true_sel = true_cs[signalp_mask]
cs_signalp_sel = cs_pos[signalp_mask]

mcc_cs = matthews_corrcoef(cs_true_sel, cs_signalp_sel)
accuracy = accuracy_score(cs_true_sel, cs_signalp_sel)
print(f'MCC for CS tagging (no tolerance window) : {mcc_cs:.3f}')
print(f'Accuracy: {accuracy:.3f}')

MCC for CS tagging (no tolerance window) : 0.840
Accuracy: 0.991


### Summary
- SignalP outperforms AWDLSTM-CRF.
- SignalP seems to perform better than stated in the supplementary material.
- Clarification needed on the nested cross-validation: what is the final model trained on? If it is trained on all partitions, this would partly explain the weaker AWDLSTM-CRF results: the model has seen 40% less data than SignalP, so it has a harder time predicting those homology groups.

- __AWDLSTM-CRF performance still looks competitive. Together with the better performance seen on Plasmodium, it is worth pursuing further.__

______________________

### Test split (Famsa partition 0 = 20% test set)

In [68]:
# Rebinds ds / dl to partition 0 of the Famsa split; no kingdom filter is
# passed here, unlike for the benchmark set.
ds = PartitionThreeLineFastaDataset(
    '../data/signalp_5_data/famsa_225_partitions/partition_0.0.fasta'
)

# Batches of 20, assembled by the dataset's own collate function.
dl = torch.utils.data.DataLoader(ds, batch_size=20, collate_fn=ds.collate_fn)

In [31]:
# Fresh per-batch accumulators; this discards the tensors/lists left over
# from the benchmark-set evaluation under the same names.
losses = []
global_probs, probs, pos_preds = [], [], []
global_targets, targets = [], []

model.eval()
for batch_idx, batch in tqdm(enumerate(dl), total=len(dl)):
    target = batch[1]
    # Sequence-level label: 1 when any position of the row equals 0, else 0
    # (tag 0 presumably marks signal-peptide positions -- TODO confirm).
    global_target = (target == 0).any(axis=1) * 1

    # Forward pass only; no gradients are tracked.
    with torch.no_grad():
        outputs = model(batch[0], targets=target, global_targets=global_target)
    loss, global_prob, prob, pos_pred = outputs

    losses.append(loss)
    global_probs.append(global_prob)
    probs.append(prob)
    pos_preds.append(pos_pred)
    targets.append(target)
    global_targets.append(global_target)


HBox(children=(FloatProgress(value=0.0, max=248.0), HTML(value='')))




In [32]:
# Collapse the per-batch lists into single tensors under the same names.
# NOTE: because the names are rebound from lists to tensors, this cell can
# only be run once after the inference loop; re-run the loop cell first.
global_targets, global_probs = torch.cat(global_targets), torch.cat(global_probs)
pos_preds, targets = torch.cat(pos_preds), torch.cat(targets)

# Predicted class per sequence = index of the largest value along axis 1;
# the MCC is the cell's displayed output.
partition_preds = np.argmax(global_probs.numpy(), axis=1)
matthews_corrcoef(global_targets.numpy(), partition_preds)

0.9218481802331078

In [51]:
#SignalP performance on partition 0
signalp_preds = pd.read_csv('partition_0_signalp.txt', sep ='\t')['Prediction'].astype('category').cat.codes.values
matthews_corrcoef(global_targets.numpy(), signalp_preds)

0.9786515298497976