In [1]:
import sys
import os
# Make the current working directory importable: append it to this kernel's
# module search path and also publish it through the PYTHONPATH variable.
pwd = os.getcwd()
sys.path.append(pwd)
os.environ['PYTHONPATH'] = pwd

In [17]:
import glob
import math
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from Bio import SeqIO  # for file parsing
from torch.utils.data import Dataset, DataLoader

In [47]:
from multimolecule import RnaTokenizer
# Tokenizer built with nmers=1. The probe string below lists the special tokens
# followed by the alphabet symbols, so the printed ids show how each is encoded.
tokenizer = RnaTokenizer(nmers=1)
vocab_probe = '<pad><cls><eos><unk><mask><null>ACGUNRYSWKMBDHV.X*-I'
print(tokenizer(vocab_probe)["input_ids"])

[1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 2]


In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    BertConfig,
    BertForMaskedLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
import random
import numpy as np
from transformers.trainer_utils import get_last_checkpoint

# BERT config for the RNA sequence encoder

In [40]:
from transformers import BertConfig, AutoModelForMaskedLM

# Seed torch's global RNG before building the model: `from_config` creates
# randomly initialised weights, so without a seed every kernel restart would
# start training from a different model.
SEED = 42
torch.manual_seed(SEED)

# Architecture of the BERT-style masked-LM encoder. Vocabulary size and padding id
# are taken from the tokenizer so the embedding table matches its token ids.
config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=640,               # match RNA-FM
    num_hidden_layers=12,
    num_attention_heads=20,        # 640 / 20 = 32 dimensions per head
    intermediate_size=5120,
    max_position_embeddings=1024,  # cap at 1024 tokens total (incl. specials)
    type_vocab_size=1,             # no token type (segment) embeddings needed
    pad_token_id=tokenizer.pad_token_id,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    layer_norm_eps=1e-12,
    initializer_range=0.02,
    position_embedding_type="absolute",
    # tie_word_embeddings is left at its default (True in HF PreTrainedModel; keeps
    # the output logits tied to the input embeddings)
)

# Build an untrained model from the config (no pretrained weights are loaded).
model = AutoModelForMaskedLM.from_config(config)  # constructs a BertForMaskedLM

print(f"\nModel initialized with {model.num_parameters():,} parameters.")


Model initialized with 99,520,026 parameters.


In [41]:
# Directory (relative to the working directory) that is searched for RF*.fasta files.
data_dir = "rnaseq_data"

In [42]:
# Collect every RF*.fasta file under data_dir, in sorted (lexicographic) order,
# and report how many were found.
fasta_paths = sorted(glob.glob(f'{data_dir}/RF*.fasta'))
print(len(fasta_paths))

3


# Loading RNA sequences (~80k nucleotides in total) for self-supervised (SSL) training

In [43]:
# Number of FASTA files (one per RNA family) to load, taken from the start of the
# sorted `fasta_paths` list. Only the first file is used here.
MAX_FAMILIES = 1

rfam_list = []  # list of RNA families (FASTA file stems, e.g. 'RF00001')

seqs = []  # list of two-element tuples [(sequence ID, sequence),]
labels = []  # one label per entry in `seqs`; the label is the RNA family of that sequence

for fasta_path in fasta_paths[:MAX_FAMILIES]:
    # The family name is the file name without its extension.
    rfam = Path(fasta_path).stem
    rfam_list.append(rfam)
    print(rfam)

    # Parse the whole file once, then split the records into sequences and IDs.
    records = list(SeqIO.parse(fasta_path, 'fasta'))
    fasta_seqs = [str(record.seq) for record in records]
    fasta_seq_names = [record.id for record in records]

    seqs += list(zip(fasta_seq_names, fasta_seqs))
    labels += [rfam] * len(fasta_seq_names)

    # Running totals; the two numbers should always be equal.
    print(len(seqs), len(labels))



RF00001
712 712


In [44]:
from datasets import Dataset
# Keep only the raw sequence strings; the sequence IDs are not needed for MLM training.
texts = [seq for _seq_id, seq in seqs]   # e.g., ['CUUGA...', 'UACGG...']

# Corpus size in characters (one character per nucleotide symbol), measured
# before any tokenization or truncation.
count = sum(len(text) for text in texts)
print("total # nucleotides = ", count)

# --- Build a Dataset with a single "text" column
ds = Dataset.from_dict({"text": texts})

# --- Tokenize, truncating every sequence to at most MAX_TOKENS tokens
MAX_TOKENS = 500        # well below the model's max_position_embeddings (1024)


def tok_fn(batch):
    """Tokenize one batch of the dataset.

    Args:
        batch: mapping with a "text" entry holding a list of sequence strings.

    Returns:
        The tokenizer output for the batch, truncated to MAX_TOKENS tokens and
        including the special-tokens mask.
    """
    return tokenizer(
        batch["text"],
        truncation=True,
        max_length=MAX_TOKENS,
        # no need to pad here; the collator will handle dynamic batch padding
        return_special_tokens_mask=True,  # helps the collator avoid masking specials
    )


tokenized = ds.map(tok_fn, batched=True, remove_columns=["text"])

# --- Data collator for MLM (masks 20% of tokens, using BERT's 80/10/10 replacement rule)
MLM_PROBABILITY = 0.20

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=MLM_PROBABILITY,
)

total # nucleotides =  83024


## Setting up training config and trainer

In [45]:
# Settings for the MLM training run, grouped by concern.
training_args = TrainingArguments(
    # -- checkpoints: where they are written, how often, and how many are kept
    output_dir="./rnafm_mlm",
    save_steps=500,
    save_total_limit=2,
    # -- optimisation
    num_train_epochs=8,
    per_device_train_batch_size=8,
    learning_rate=2e-5,
    fp16=False,                 # True if your GPU supports it
    # -- logging
    logging_steps=100,
    report_to="none",           # or "tensorboard"
)

In [46]:
# Wire the model, training arguments, tokenized dataset and MLM collator together.
# The tokenizer is passed as `processing_class`; the older `tokenizer=` keyword is
# deprecated in recent transformers releases and emits a warning.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    processing_class=tokenizer,
    data_collator=data_collator,
)

# Run training; as the cell's last expression, the returned TrainOutput is displayed.
trainer.train()

  trainer = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 2, 'bos_token_id': 1}.


Step,Training Loss
100,1.2205
200,1.0876
300,1.0458
400,1.0517
500,1.0479
600,1.0222
700,1.019




TrainOutput(global_step=712, training_loss=1.0698299407958984, metrics={'train_runtime': 576.0845, 'train_samples_per_second': 9.887, 'train_steps_per_second': 1.236, 'total_flos': 415097340431136.0, 'train_loss': 1.0698299407958984, 'epoch': 8.0})

In [39]:
# Display the summary of the raw dataset built from the "text" column (not the tokenized one).
ds

Dataset({
    features: ['text'],
    num_rows: 712
})