In [1]:
import numpy as np
import pandas as pd

from collections import Counter

from sklearn.model_selection import train_test_split

from datasets import Dataset, load_metric
from transformers import (
    AutoTokenizer, 
    EsmForSequenceClassification, 
    TrainingArguments, 
    Trainer
)

In [2]:
MODEL_PRETRAINED = "esm2_t6_8M_UR50D" # Path to the pre-trained ESM model
BATCH_SIZE = 10 # Training batch size

In [3]:
def read_fasta(path: str, min_length=200, max_length=1024):
    """
    Reads a fasta file and returns a list of sequences, and a list of sequence identifiers. 
    Sequences are filtered out if shorter than `min_length`, and truncated if longer than `max_length`.
    """
    sequences, ids = [], []
    sequence = ""
    with open(path, "r") as fasta:
        for line in fasta:
            if line.startswith(">"):
                if len(sequence) >= min_length:
                    ids.append(line[1:].strip())
                    sequences.append(sequence[:max_length])
                sequence = ""
                continue
            sequence += line.strip()
    if len(sequence) >= min_length:
        ids.append(line[1:].strip())
        sequences.append(sequence[:max_length])
    
    return sequences, ids

# Data Loading and preparation

We use a dataset derived from SerratusL for training: 

## Importing and formatting data

The data used as input for the ESM based classifier is the set of raw sequences, truncated to a max length of 1024.  
The labels are `1` for RDRP sequences and `0` for Non-RDRP. 

In [4]:
df = pd.read_csv("datasets/eval/serratusL-negdepleted.csv")

In [7]:
pos_sequences = df[df["type"] == "rdrp"]["seq"].values.tolist()
pos_labels = [1 for _ in pos_sequences]

In [8]:
neg_sequences = df[df["type"] != "rdrp"]["seq"].values.tolist()
neg_labels = [0 for _ in neg_sequences]

In [10]:
sequences = pos_sequences + neg_sequences
labels = pos_labels + neg_labels

## Making datasets 
The data is then split into a training and a testing set, the sequences in each set are then tokenized using the included pretrained Tokenizer.

In [11]:
# Splitting dataset
train_sequences, test_sequences, train_labels, test_labels = (
    train_test_split(sequences, labels, test_size=0.25, shuffle=True)
)

In [12]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PRETRAINED)

In [13]:
# Tokenize inputs
train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)

In [14]:
# Build datasets
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)

# Model training

The sequence classifier is trained on the created dataset. Since the training data is imbalanced *(5k pos for 10k neg)* we use the F1 score to train model.  
For evaluation we compute several metrics, to get an overall view of how the model performs:  
 - F1
 - Accuracy
 - Precision
 - Recall

In [15]:
metric_paths = ["./metrics/f1", "./metrics/accuracy", "./metrics/precision", "./metrics/recall"]
metrics = [load_metric(path) for path in metric_paths]

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    scores = dict()
    for metric in metrics:
        scores.update(
            metric.compute(predictions=predictions, references=labels)
        )

    return scores

  metrics = [load_metric(path) for path in metric_paths]


In [16]:
num_labels = 2 # Binary classifier
model = EsmForSequenceClassification.from_pretrained(
    MODEL_PRETRAINED, num_labels=num_labels
)

Some weights of the model checkpoint at esm2_t6_8M_UR50D were not used when initializing EsmForSequenceClassification: ['lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing EsmForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at esm2_t6_8M_UR50D and are newly initialized: ['classifier.out_proj.weight', 'classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bi

In [17]:
trainer_args = TrainingArguments(
    f"{MODEL_PRETRAINED}-finetuned-serratusL",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)

In [18]:
trainer = Trainer(
    model,
    trainer_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [19]:
trainer.train()

***** Running training *****
  Num examples = 32976
  Num Epochs = 3
  Instantaneous batch size per device = 10
  Total train batch size (w. parallel, distributed & accumulation) = 10
  Gradient Accumulation steps = 1
  Total optimization steps = 9894
  Number of trainable parameters = 7840642


Epoch,Training Loss,Validation Loss,F1,Accuracy,Precision,Recall
1,0.1229,0.119739,0.967083,0.968341,0.949304,0.985541
2,0.0749,0.11224,0.972338,0.973435,0.955858,0.989397
3,0.0526,0.115589,0.972746,0.973981,0.961749,0.983998


***** Running Evaluation *****
  Num examples = 10992
  Batch size = 10
Saving model checkpoint to esm2_t6_8M_UR50D-finetuned-serratusL/checkpoint-3298
Configuration saved in esm2_t6_8M_UR50D-finetuned-serratusL/checkpoint-3298/config.json
Model weights saved in esm2_t6_8M_UR50D-finetuned-serratusL/checkpoint-3298/pytorch_model.bin
tokenizer config file saved in esm2_t6_8M_UR50D-finetuned-serratusL/checkpoint-3298/tokenizer_config.json
Special tokens file saved in esm2_t6_8M_UR50D-finetuned-serratusL/checkpoint-3298/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 10992
  Batch size = 10
Saving model checkpoint to esm2_t6_8M_UR50D-finetuned-serratusL/checkpoint-6596
Configuration saved in esm2_t6_8M_UR50D-finetuned-serratusL/checkpoint-6596/config.json
Model weights saved in esm2_t6_8M_UR50D-finetuned-serratusL/checkpoint-6596/pytorch_model.bin
tokenizer config file saved in esm2_t6_8M_UR50D-finetuned-serratusL/checkpoint-6596/tokenizer_config.json
Special tokens

TrainOutput(global_step=9894, training_loss=0.10503221353377047, metrics={'train_runtime': 516.1023, 'train_samples_per_second': 191.683, 'train_steps_per_second': 19.171, 'total_flos': 1135465496038080.0, 'train_loss': 0.10503221353377047, 'epoch': 3.0})

In [20]:
trainer.save_model()
trainer.save_state()

Saving model checkpoint to esm2_t6_8M_UR50D-finetuned-serratusL
Configuration saved in esm2_t6_8M_UR50D-finetuned-serratusL/config.json
Model weights saved in esm2_t6_8M_UR50D-finetuned-serratusL/pytorch_model.bin
tokenizer config file saved in esm2_t6_8M_UR50D-finetuned-serratusL/tokenizer_config.json
Special tokens file saved in esm2_t6_8M_UR50D-finetuned-serratusL/special_tokens_map.json


# Model Evaluation

the model is then evaluated on a number of datasets:  
 - wolf2018 positive sequences (`datasets/wolf2018.fa`)
 - tricky decoys (`datasets/rdrp_decoy.10k.fa`)
 - Other SerratusL negative sequences (`datasets/eval/other-negs.csv`)
 - ~~CFDL sequences (`datasets/eval/CFDL-sample.fa`)~~
 - Palmcore decoys (`datasets/eval/palmcores.fa`)
 

In [21]:
eval_tokenizer = AutoTokenizer.from_pretrained(f"{MODEL_PRETRAINED}-finetuned-serratusL")
eval_model = EsmForSequenceClassification.from_pretrained(f"{MODEL_PRETRAINED}-finetuned-serratusL", num_labels=2)
trainer = Trainer(
    eval_model,
    trainer_args,
    tokenizer=eval_tokenizer,
    compute_metrics=compute_metrics,
)

loading file vocab.txt
loading file added_tokens.json
loading file special_tokens_map.json
loading file tokenizer_config.json
loading configuration file esm2_t6_8M_UR50D-finetuned-serratusL/config.json
Model config EsmConfig {
  "_name_or_path": "esm2_t6_8M_UR50D",
  "architectures": [
    "EsmForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "emb_layer_norm_before": false,
  "esmfold_config": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 320,
  "initializer_range": 0.02,
  "intermediate_size": 1280,
  "is_folding_model": false,
  "layer_norm_eps": 1e-05,
  "mask_token_id": 32,
  "max_position_embeddings": 1026,
  "model_type": "esm",
  "num_attention_heads": 20,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "position_embedding_type": "rotary",
  "problem_type": "single_label_classification",
  "token_dropout": true,
  "torch_dtype": "float32",
  "transformers_version": "4.24.0",
  "use_cache": true,

In [26]:
def compute_specificity(predictions, labels):
    counts = Counter(zip(labels, predictions))
    
    TP, TN = counts[(1,1)], counts[(0,0)]
    FP, FN = counts[(0,1)], counts[(1,0)]
    
    if TN + FP == 0: return 0 
    
    return TN / (TN + FP)

### Wolf2018 classes

In [41]:
sequences, ids = read_fasta("datasets/wolf2018.fa")
labels = [1 for _ in sequences]
tokens = eval_tokenizer(sequences)

wolf2018_dataset = Dataset.from_dict(tokens)
pred_probas = trainer.predict(wolf2018_dataset)
wolf2018_metrics = trainer.compute_metrics((pred_probas.predictions, labels))

wolf2018_metrics["specificity"] = compute_specificity(
    np.argmax(pred_probas.predictions, axis=-1), labels
)

print(wolf2018_metrics)

predictions = np.argmax(pred_probas.predictions, axis=1)
couples = [(p,t) for p,t in zip(predictions, labels)]
tp_idx = [i for i, (p,t) in enumerate(couples) if p==t]
fn_idx = [i for i, (p,t) in enumerate(couples) if p!=t]

wolf2018_fns = [ids[x] for x in fn_idx]
wolf2018_fns

***** Running Prediction *****
  Num examples = 4627
  Batch size = 10


{'f1': 0.9821821381434228, 'accuracy': 0.964988113248325, 'precision': 1.0, 'recall': 0.964988113248325, 'specificity': 0}


['1123592258',
 '501421962',
 '489297600',
 '919371179',
 '946694699',
 '1062933037',
 '923750907',
 '1025017799',
 '390353251',
 'APG77107.1|Changjiang_narna_like_virus_5',
 'APG77159.1|Hubei_narna_like_virus_20',
 'YP_009330065.1|Hubei_narna_like_virus_18',
 'YP_009336688.1|Sanxia_water_strider_virus_13',
 'YP_009329842.1|Hubei_narna_like_virus_24',
 'YP_009336767.1|Hubei_narna_like_virus_23',
 'AIF33766.2|Heterobasidion_mitovirus_1',
 'YP_009009144.1|Sclerotinia_sclerotiorum_mitovirus_6',
 'AHI43534.1|Fusarium_circinatum_mitovirus_2_1',
 'YP_009259483.1|Cronartium_ribicola_mitovirus_4',
 'AGW51760.1|Mitovirus_AEF_2013',
 'AHX84135.1|Sclerotinia_sclerotiorum_mitovirus_7',
 'AHF48631.1|Sclerotinia_sclerotiorum_mitovirus_15',
 'AHX84129.1|Sclerotinia_sclerotiorum_mitovirus_2',
 'ALM62243.1|Soybean_leaf_associated_mitovirus_3',
 'AAA66950.1|Southern_bean_mosaic_virus',
 'YP_004869651.2|Soybean_yellow_common_mosaic_virus',
 'YP_009140472.1|Cymbidium_chlorotic_mosaic_virus',
 'NP_941957.2

### Tricky decoy 

In [39]:
sequences, ids = read_fasta("datasets/rdrp_decoy.10k.fa")
labels = [0 for _ in sequences]
tokens = eval_tokenizer(sequences)

tricky_dataset = Dataset.from_dict(tokens)
pred_probas = trainer.predict(tricky_dataset)
tricky_metrics = trainer.compute_metrics((pred_probas.predictions, labels))

tricky_metrics["specificity"] = compute_specificity(
    np.argmax(pred_probas.predictions, axis=-1), labels
)

print(tricky_metrics)

predictions = np.argmax(pred_probas.predictions, axis=1)
decoys = [(p,t) for p,t in zip(predictions, labels) if t == 0]
tn_idx = [i for i, (p,t) in enumerate(decoys) if p==t]
fp_idx = [i for i, (p,t) in enumerate(decoys) if p!=t]

tricky_fps = [ids[x] for x in fp_idx]
tricky_fps

***** Running Prediction *****
  Num examples = 7366
  Batch size = 10


{'f1': 0.0, 'accuracy': 0.9933478142818355, 'precision': 0.0, 'recall': 0.0, 'specificity': 0.9933478142818355}


  _warn_prf(average, modifier, msg_start, len(result))


['WP_025353893_1 peptidase M24 family protein [Kutzneria albida]',
 'WP_043739048_1 transketolase [Thioalkalivibrio nitratireducens]',
 'WP_013762692_1 aspartate aminotransferase family protein [Haliscomenobacter hydrossis]',
 'WP_011771135_1 50S ribosomal protein L13 [Psychromonas ingrahamii]',
 'WP_011414694_1 nucleoside-diphosphate kinase [Erythrobacter litoralis]',
 'WP_099339856_1 D-amino-acid transaminase [Candidatus Fonsibacter ubiquis]',
 'YP_003613465_1 galactokinase [Enterobacter cloacae subsp. cloacae ATCC 13047]',
 'WP_014261837_1 glutamine synthetase [Filifactor alocis]',
 'WP_015935410_1 MULTISPECIES: type I glutamate--ammonia ligase [Anaeromyxobacter]',
 'WP_011769510_1 ribonucleotide-diphosphate reductase subunit beta [Psychromonas ingrahamii]',
 'WP_012909532_1 vitamin B12-dependent ribonucleotide reductase [Pirellula staleyi]',
 'WP_066635279_1 glucose-6-phosphate dehydrogenase [Serinicoccus sp. JLT9]',
 'WP_081598702_1 citrate synthase/methylcitrate synthase [Tistrel

### SerratusL other negs

In [42]:
df = pd.read_csv("datasets/eval/other-negs.csv", header=None)
df.columns = ["seq", "type"]
labels = [1 if label == "rdrp" else 0 for label in df["type"].values]
tokens = eval_tokenizer(df["seq"].values.tolist())

serratusOther_dataset = Dataset.from_dict(tokens)

pred_probas = trainer.predict(serratusOther_dataset)
serratusOther_metrics = trainer.compute_metrics((pred_probas.predictions, labels))

serratusOther_metrics["specificity"] = compute_specificity(
    np.argmax(pred_probas.predictions, axis=-1), labels
)

serratusOther_metrics

***** Running Prediction *****
  Num examples = 10000
  Batch size = 10


  _warn_prf(average, modifier, msg_start, len(result))


{'f1': 0.0,
 'accuracy': 0.9891,
 'precision': 0.0,
 'recall': 0.0,
 'specificity': 0.9891}

### Palmcore decoys

In [44]:
seqs, ids = read_fasta("datasets/eval/palmcores.fa")
labels = [0 if "decoy" in id else 1 for id in ids]

filtered = [
    seq for (seq, label) in zip(seqs, labels)
    if label == 0
]
labels = [0 for _ in filtered]

tokens = eval_tokenizer(filtered)

palmcore_dataset = Dataset.from_dict(tokens)

pred_probas = trainer.predict(palmcore_dataset)
palmcore_metrics = trainer.compute_metrics((pred_probas.predictions, labels))

palmcore_metrics["specificity"] = compute_specificity(
    np.argmax(pred_probas.predictions, axis=-1), labels
)

print(palmcore_metrics)

predictions = np.argmax(pred_probas.predictions, axis=1)
decoys = [(p,t) for p,t in zip(predictions, labels) if t == 0]
tn_idx = [i for i, (p,t) in enumerate(decoys) if p==t]
fp_idx = [i for i, (p,t) in enumerate(decoys) if p!=t]

palmcore_fps = [ids[x] for x in fp_idx]
palmcore_fps

***** Running Prediction *****
  Num examples = 433
  Batch size = 10


{'f1': 0.0, 'accuracy': 0.9930715935334873, 'precision': 0.0, 'recall': 0.0, 'specificity': 0.9930715935334873}


  _warn_prf(average, modifier, msg_start, len(result))


['decoy.2DY4_B', 'decoy.A0A0K0F8E6_STRVS', 'decoy.A0A4W6CRS7_LATCA']

### Joined metrics

In [45]:
metrics_df = pd.DataFrame({
    "wolf2018 (+)": wolf2018_metrics,
    "rdrp_decoy.10k (-)": tricky_metrics,
    "serratusL_other_negs (-)": serratusOther_metrics,
    "palmcore_decoys (-)": palmcore_metrics,
})

metrics_df.transpose()[["recall", "specificity"]]

Unnamed: 0,recall,specificity
wolf2018 (+),0.964988,0.0
rdrp_decoy.10k (-),0.0,0.993348
serratusL_other_negs (-),0.0,0.9891
palmcore_decoys (-),0.0,0.993072


In [46]:
metrics_df.transpose()

Unnamed: 0,f1,accuracy,precision,recall,specificity
wolf2018 (+),0.982182,0.964988,1.0,0.964988,0.0
rdrp_decoy.10k (-),0.0,0.993348,0.0,0.0,0.993348
serratusL_other_negs (-),0.0,0.9891,0.0,0.0,0.9891
palmcore_decoys (-),0.0,0.993072,0.0,0.0,0.993072
