In [1]:
import datetime
import gzip
import sys

import pandas as pd

from datasets import Dataset, load_metric
from transformers import (
    AutoTokenizer, 
    EsmForSequenceClassification, 
    TrainingArguments, 
    Trainer
)

# Pretrained model loading

In [2]:
MODEL_NAME = "esm2_t6_8M_UR50D-finetuned-serratusL"
BATCH_SIZE = 100

In [3]:
metric_paths = ["./metrics/f1", "./metrics/accuracy", "./metrics/precision", "./metrics/recall"]
metrics = [load_metric(path) for path in metric_paths]

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    scores = dict()
    for metric in metrics:
        scores.update(
            metric.compute(predictions=predictions, references=labels)
        )

    return scores

  metrics = [load_metric(path) for path in metric_paths]


In [4]:
trainer_args = TrainingArguments(
    f"{MODEL_NAME}",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
)

In [5]:
eval_tokenizer = AutoTokenizer.from_pretrained(f"{MODEL_NAME}")
eval_model = EsmForSequenceClassification.from_pretrained(f"{MODEL_NAME}", num_labels=2)
trainer = Trainer(
    eval_model,
    trainer_args,
    tokenizer=eval_tokenizer,
    compute_metrics=compute_metrics,
)

# Study memory and time requirements
Here I just load 10,000 sequences to estimate how long the inference will run for on the whole dataset of 34M sequences. 

In [6]:
TOTAL_SEQ_NUM = 34_217_821 # number of sequences in tar.gz file

In [8]:
sequences, ids = [], []
with open("datasets/serratus_S.fasta", "r") as fasta:
    for i, line in enumerate(fasta):
        if len(sequences) >= 10000:
            break
        elif line.startswith(">"):
            ids.append(line[1:].split(";")[0])
        else:
            sequences.append(line.strip()[:1024])

In [9]:
size_per_seq = (sum(sys.getsizeof(x) for x in sequences) + sys.getsizeof(sequences)) / len(sequences)
(TOTAL_SEQ_NUM * size_per_seq) / (1024 * 1024 * 1024) # approx size of sequences in GB

6.390019190481212

In [10]:
tokens = eval_tokenizer(sequences)

In [11]:
size_per_token_seq = (sum(sys.getsizeof(x) for x in tokens) + sys.getsizeof(tokens)) / len(tokens)
(TOTAL_SEQ_NUM * size_per_token_seq) / (1024 * 1024 * 1024) # approx size of tokenized sequences in GB

2.692831563297659

In [12]:
dataset = Dataset.from_dict(tokens)

In [13]:
pred_probas = trainer.predict(dataset)

***** Running Prediction *****
  Num examples = 10000
  Batch size = 100


In [14]:
seconds_per_seq = 1 / pred_probas.metrics["test_samples_per_second"]
print(datetime.timedelta(seconds = TOTAL_SEQ_NUM * seconds_per_seq))

17:10:26.399855


# Full Test run on "small dataset"

In [15]:
def infer_batch(tokens, ids, output_path, trainer):
    dataset = Dataset.from_list(tokens)
    pred_probas = trainer.predict(dataset)
    preds = pred_probas.predictions.argmax(axis=1)
    with open(output_path, "a") as outfile:
        for id_, pred in zip(ids, preds):
            outfile.write(f"{id_}\t{pred}\n")

In [16]:
%%time
tokens, ids = [], []
counter = 0
res_path = f"serratus_S.{MODEL_NAME}.inference.test.tsv"
with open("datasets/serratus_S.fasta", "r") as fasta:
    for i, line in enumerate(fasta):
        if len(tokens) >= 10_000: # infer and predict in batches
            infer_batch(tokens, ids, res_path, trainer)
            tokens, ids = [], []
        if counter >= 100_000: # max inferences
            break
        elif line.startswith(">"):
            ids.append(line[1:].split(";")[0])
        else:
            tokens.append(eval_tokenizer(line.strip()[:1024]))
            counter += 1

if len(tokens) > 0:
    infer_batch(tokens, ids, res_path, trainer)

***** Running Prediction *****
  Num examples = 10000
  Batch size = 100


***** Running Prediction *****
  Num examples = 10000
  Batch size = 100


***** Running Prediction *****
  Num examples = 10000
  Batch size = 100


***** Running Prediction *****
  Num examples = 10000
  Batch size = 100


***** Running Prediction *****
  Num examples = 10000
  Batch size = 100


***** Running Prediction *****
  Num examples = 10000
  Batch size = 100


***** Running Prediction *****
  Num examples = 10000
  Batch size = 100


***** Running Prediction *****
  Num examples = 10000
  Batch size = 100


***** Running Prediction *****
  Num examples = 10000
  Batch size = 100


***** Running Prediction *****
  Num examples = 10000
  Batch size = 100


CPU times: user 3min 38s, sys: 721 ms, total: 3min 38s
Wall time: 3min 42s


# Full inference
We infer an RDRP status for all $36\cdot10^{6}$ sequences of `datasets/serratus_S.fasta` using our `esm2_t6_UR50D` model fine-tuned on `serratusL`.  
The inference is done on batches of 1 million sequences which are written to the output file incrementally. 

In [None]:
%%time 

tokens, ids = [], []
res_path = f"serratus_S.{MODEL_NAME}.inference.tsv"

with open("datasets/serratus_S.fasta", "r") as fasta:
    for i, line in enumerate(fasta):
        
        # Infer if batch is ready
        if len(tokens) >= 1_000_000: 
            infer_batch(tokens, ids, res_path, trainer)
            tokens, ids = [], []

        # Fill up batch
        if line.startswith(">"):
            ids.append(line[1:].split(";")[0])
        else:
            tokens.append(eval_tokenizer(line.strip()[:1024]))

# Flush incomplete batch
if len(tokens) > 0:
    infer_batch(tokens, ids, res_path, trainer)

***** Running Prediction *****
  Num examples = 1000000
  Batch size = 100


***** Running Prediction *****
  Num examples = 1000000
  Batch size = 100


***** Running Prediction *****
  Num examples = 1000000
  Batch size = 100


***** Running Prediction *****
  Num examples = 1000000
  Batch size = 100


***** Running Prediction *****
  Num examples = 1000000
  Batch size = 100


(the cell above did run but the output is messed up)
## Results

In [24]:
results = pd.read_csv("serratus_S.esm2_t6_8M_UR50D-finetuned-serratusL.inference.tsv", sep="\t", header=None)
results.columns = ["id", "rdrp"]

In [25]:
results.head()

Unnamed: 0,id,rdrp
0,SRR6181341_3_1339_1890_1_ID=58989_3,0
1,SRR7619050_2_274_1434_1_ID=41851_2,0
2,ERR1879302_1_3_1253_1_ID=59412_1,0
3,SRR6917556_1_2_787_1_ID=14462_1,0
4,SRR11673796_1_24_431_-1_ID=520071_1,0


In [36]:
num_hits = (results['rdrp'] == 1).sum()
print(f"{num_hits:,} sequences classified as rdrp ({100 * num_hits / results.shape[0]:.2f}%)")

135,861 sequences classified as rdrp (0.40%)
