In [4]:
import numpy as np
import pandas as pd

from collections import Counter

from sklearn.model_selection import train_test_split

from datasets import Dataset, load_metric
from transformers import (
    AutoTokenizer, 
    EsmForSequenceClassification, 
    TrainingArguments, 
    Trainer
)

In [5]:
MODEL_PRETRAINED = "esm2_t6_8M_UR50D" # Path to the pre-trained ESM model
BATCH_SIZE = 10 # Training batch size

In [24]:
def read_fasta(path: str, min_length=200, max_length=1024):
    """
    Reads a fasta file and returns a list of sequences, and a list of sequence identifiers. 
    Sequences are filtered out if shorter than `min_length`, and truncated if longer than `max_length`.
    """
    sequences, ids = [], []
    sequence = ""
    with open(path, "r") as fasta:
        for line in fasta:
            if line.startswith(">"):
                if len(sequence) >= min_length:
                    ids.append(line[1:].strip())
                    sequences.append(sequence[:max_length])
                sequence = ""
                continue
            sequence += line.strip()
    if len(sequence) >= min_length:
        ids.append(line[1:].strip())
        sequences.append(sequence[:max_length])
    
    return sequences, ids

# Data Loading and preparation

We use 2 datasets for training: 

 - 10k Non-RDRP protein sequences in `datasets/rdrp_decoy.10k.fa`
 - 4627 RDRP sequences from wolf2018 in `datasets/wolf2018.fa`

These are read and combined to form the training set.

## Importing and formatting data

The data used as input for the ESM based classifier is the set of raw sequences, truncated to a max length of 1024.  
The labels are `1` for RDRP sequences and `0` for Non-RDRP. 

In [5]:
pos_sequences, pos_ids = read_fasta("datasets/wolf2018.fa")
pos_labels = [1 for _ in pos_sequences]

In [6]:
neg_sequences, neg_ids = read_fasta("datasets/rdrp_decoy.10k.fa")
neg_labels = [0 for _ in neg_sequences]

In [7]:
sequences = pos_sequences + neg_sequences
labels = pos_labels + neg_labels
ids = pos_ids + neg_ids

## Making datasets 
The data is then split into a training and a testing set, the sequences in each set are then tokenized using the included pretrained Tokenizer.

In [11]:
# Splitting dataset
train_sequences, test_sequences, train_labels, test_labels = (
    train_test_split(sequences, labels, test_size=0.25, shuffle=True)
)

In [12]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_PRETRAINED)

In [13]:
# Tokenize inputs
train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)

In [14]:
# Build datasets
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)

# Model training

The sequence classifier is trained on the created dataset. Since the training data is imbalanced *(5k pos for 10k neg)* we use the F1 score to train model.  
For evaluation we compute several metrics, to get an overall view of how the model performs:  
 - F1
 - Accuracy
 - Precision
 - Recall

In [9]:
metric_paths = ["./metrics/f1", "./metrics/accuracy", "./metrics/precision", "./metrics/recall"]
metrics = [load_metric(path) for path in metric_paths]

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    scores = dict()
    for metric in metrics:
        scores.update(
            metric.compute(predictions=predictions, references=labels)
        )

    return scores

  metrics = [load_metric(path) for path in metric_paths]


In [16]:
num_labels = 2 # Binary classifier
model = EsmForSequenceClassification.from_pretrained(
    MODEL_PRETRAINED, num_labels=num_labels
)

Some weights of the model checkpoint at esm2_t6_8M_UR50D were not used when initializing EsmForSequenceClassification: ['lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing EsmForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight', 'classifier.dense.bi

In [7]:
trainer_args = TrainingArguments(
    f"{MODEL_PRETRAINED}-finetuned-rdrp",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)

In [18]:
trainer = Trainer(
    model,
    trainer_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [19]:
trainer.train()

***** Running training *****
  Num examples = 8994
  Num Epochs = 3
  Instantaneous batch size per device = 10
  Total train batch size (w. parallel, distributed & accumulation) = 10
  Gradient Accumulation steps = 1
  Total optimization steps = 2700
  Number of trainable parameters = 7840642


Epoch,Training Loss,Validation Loss,F1,Accuracy,Precision,Recall
1,0.0867,0.005274,0.999142,0.999333,0.999142,0.999142
2,0.0022,0.003123,0.999571,0.999667,0.999142,1.0
3,0.0006,0.002903,0.999571,0.999667,0.999142,1.0


***** Running Evaluation *****
  Num examples = 2999
  Batch size = 10
Saving model checkpoint to esm2_t6_8M_UR50D-finetuned-rdrp/checkpoint-900
Configuration saved in esm2_t6_8M_UR50D-finetuned-rdrp/checkpoint-900/config.json
Model weights saved in esm2_t6_8M_UR50D-finetuned-rdrp/checkpoint-900/pytorch_model.bin
tokenizer config file saved in esm2_t6_8M_UR50D-finetuned-rdrp/checkpoint-900/tokenizer_config.json
Special tokens file saved in esm2_t6_8M_UR50D-finetuned-rdrp/checkpoint-900/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 2999
  Batch size = 10
Saving model checkpoint to esm2_t6_8M_UR50D-finetuned-rdrp/checkpoint-1800
Configuration saved in esm2_t6_8M_UR50D-finetuned-rdrp/checkpoint-1800/config.json
Model weights saved in esm2_t6_8M_UR50D-finetuned-rdrp/checkpoint-1800/pytorch_model.bin
tokenizer config file saved in esm2_t6_8M_UR50D-finetuned-rdrp/checkpoint-1800/tokenizer_config.json
Special tokens file saved in esm2_t6_8M_UR50D-finetuned-rdrp/check

TrainOutput(global_step=2700, training_loss=0.017897054202578686, metrics={'train_runtime': 386.5795, 'train_samples_per_second': 69.797, 'train_steps_per_second': 6.984, 'total_flos': 889892406284424.0, 'train_loss': 0.017897054202578686, 'epoch': 3.0})

In [20]:
trainer.save_model()
trainer.save_state()

Saving model checkpoint to esm2_t6_8M_UR50D-finetuned-rdrp
Configuration saved in esm2_t6_8M_UR50D-finetuned-rdrp/config.json
Model weights saved in esm2_t6_8M_UR50D-finetuned-rdrp/pytorch_model.bin
tokenizer config file saved in esm2_t6_8M_UR50D-finetuned-rdrp/tokenizer_config.json
Special tokens file saved in esm2_t6_8M_UR50D-finetuned-rdrp/special_tokens_map.json


# Model Evaluation

the model is then evaluated on a number of datasets:  
 - "Original" `rchikhi/palmesm` SerratusL training set (`datasets/eval/serratusL-negdepleted.csv`)
 - Other SerratusL negative sequences (`datasets/eval/other-negs.csv`)
 - ~~CFDL sequences (`datasets/eval/CFDL-sample.fa`)~~
 - Palmcore decoys (`datasets/eval/palmcores.fa`)
 

In [10]:
eval_tokenizer = AutoTokenizer.from_pretrained(f"{MODEL_PRETRAINED}-finetuned-rdrp")
eval_model = EsmForSequenceClassification.from_pretrained(f"{MODEL_PRETRAINED}-finetuned-rdrp", num_labels=2)
trainer = Trainer(
    eval_model,
    trainer_args,
    tokenizer=eval_tokenizer,
    compute_metrics=compute_metrics,
)

In [19]:
def compute_specificity(predictions, labels):
    counts = Counter(zip(labels, predictions))
    
    TP, TN = counts[(1,1)], counts[(0,0)]
    FP, FN = counts[(0,1)], counts[(1,0)]
    
    return TN / (TN + FP)

### SerratusL 2 classes

In [11]:
df = pd.read_csv("datasets/eval/serratusL-negdepleted.csv")
labels = [1 if label == "rdrp" else 0 for label in df["type"].values]
tokens = eval_tokenizer(df["seq"].values.tolist())

serratusL_dataset = Dataset.from_dict(tokens)
pred_probas = trainer.predict(serratusL_dataset)
serratusL_metrics = trainer.compute_metrics((pred_probas.predictions, labels))

serratusL_metrics["specificity"] = compute_specificity(
    np.argmax(pred_probas.predictions, axis=-1), labels
)

serratusL_metrics

***** Running Prediction *****
  Num examples = 43968
  Batch size = 10


{'f1': 0.7154921729439832,
 'accuracy': 0.6283888282387191,
 'precision': 0.5578787302794146,
 'recall': 0.997233278322493}

### SerratusL other negs

In [22]:
df = pd.read_csv("datasets/eval/other-negs.csv", header=None)
df.columns = ["seq", "type"]
labels = [1 if label == "rdrp" else 0 for label in df["type"].values]
tokens = eval_tokenizer(df["seq"].values.tolist())

serratusOther_dataset = Dataset.from_dict(tokens)

pred_probas = trainer.predict(serratusOther_dataset)
serratusOther_metrics = trainer.compute_metrics((pred_probas.predictions, labels))

serratusOther_metrics["specificity"] = compute_specificity(
    np.argmax(pred_probas.predictions, axis=-1), labels
)

serratusOther_metrics

***** Running Prediction *****
  Num examples = 10000
  Batch size = 10


  _warn_prf(average, modifier, msg_start, len(result))


{'f1': 0.0,
 'accuracy': 0.686,
 'precision': 0.0,
 'recall': 0.0,
 'specificity': 0.686}

### Palmcore sequences

In [25]:
seqs, ids = read_fasta("datasets/eval/palmcores.fa")
labels = [0 if "decoy" in id else 1 for id in ids]
tokens = eval_tokenizer(seqs)

palmcore_dataset = Dataset.from_dict(tokens)

pred_probas = trainer.predict(palmcore_dataset)
palmcore_metrics = trainer.compute_metrics((pred_probas.predictions, labels))

palmcore_metrics["specificity"] = compute_specificity(
    np.argmax(pred_probas.predictions, axis=-1), labels
)

palmcore_metrics

***** Running Prediction *****
  Num examples = 5046
  Batch size = 10


{'f1': 0.9557816068092174,
 'accuracy': 0.915576694411415,
 'precision': 0.9169488149770962,
 'recall': 0.9980489919791893,
 'specificity': 0.03695150115473441}

In [38]:
predictions = np.argmax(pred_probas.predictions, axis=1)
decoys = [(p,t) for p,t in zip(predictions, labels) if t == 0]
tn_idx = [i for i, (p,t) in enumerate(decoys) if p==t]
fp_idx = [i for i, (p,t) in enumerate(decoys) if p!=t]

In [41]:
[ids[x] for x in fp_idx]

['decoy.1D8Y_A',
 'decoy.1XHZ_B',
 'decoy.2ATQ_A',
 'decoy.2DY4_B',
 'decoy.2OYQ_B',
 'decoy.2QV6_A',
 'decoy.3GV5_B',
 'decoy.5VU8_A',
 'decoy.5XOX_C',
 'decoy.6M7O_A',
 'decoy.7D9U_B',
 'decoy.A0A091DUR7_FUKDA',
 'decoy.A0A0K0F8E6_STRVS',
 'decoy.A0A151RAL1_CAJCA',
 'decoy.A0A151RPT4_CAJCA',
 'decoy.A0A158P8V2_ANGCA',
 'decoy.A0A165CCP3_9APHY',
 'decoy.A0A251UDF3_HELAN',
 'decoy.A0A2K3MUJ9_TRIPR',
 'decoy.A0A2P7YV37_9ASCO',
 'decoy.A0A2U9C161_SCOMX',
 'decoy.A0A319DXT4_ASPSB',
 'decoy.A0A4E0RLC3_FASHE',
 'decoy.A0A4W6CRS7_LATCA',
 'decoy.A0A5F9C6V0_RABIT',
 'decoy.A0A5F9CQF1_RABIT',
 'decoy.A0A5F9D223_RABIT',
 'decoy.A0A5F9DLH3_RABIT',
 'decoy.A0A5N6P6T5_9ASTR',
 'decoy.A0A656KFX9_BLUGR',
 'decoy.A0A6A3BX58_HIBSY',
 'decoy.A0A6D2JJ90_9BRAS',
 'decoy.A0A6D2KBZ0_9BRAS',
 'decoy.A0A7M7HK94_STRPU',
 'decoy.U6EAV9_9EURY',
 'decoy.giii.WP_000108316',
 'decoy.giii.WP_000169371',
 'decoy.giii.WP_002615862',
 'decoy.giii.WP_003021071',
 'decoy.giii.WP_004152765',
 'decoy.giii.WP_005935480',
 

### Joined metrics

In [26]:
metrics_df = pd.DataFrame({
    "serratusL": serratusL_metrics,
    "serratusL_other": serratusOther_metrics,
    "palmcore": palmcore_metrics,
})

metrics_df.transpose()

Unnamed: 0,f1,accuracy,precision,recall,specificity
serratusL,0.715492,0.628389,0.557879,0.997233,0.303176
serratusL_other,0.0,0.686,0.0,0.0,0.686
palmcore,0.955782,0.915577,0.916949,0.998049,0.036952
