### Trying to learn BERT for RNA sequences

Get the data

In [None]:
!wget https://www.dropbox.com/s/sosuzcpzngwiknq/pair_dataset_large.tsv?dl=1 -O pair_dataset_large.tsv

Imports and data loading

In [None]:
import pandas as pd
from torch.nn import CrossEntropyLoss
from transformers import DistilBertConfig, DistilBertForMaskedLM
from transformers import BertTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import notebook
from torch.utils.data import DataLoader, Dataset
from torch import LongTensor
import torch
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from IPython.display import clear_output

In [None]:
# Read the columns used downstream as text. The mask columns are later searched
# character by character for '1', so they must not be parsed as numbers
# (which would also drop leading zeros).
df = pd.read_csv('./pair_dataset_large.tsv', sep="\t",
                 dtype={'sequence': str, 'mask': str, 'ans': str})

In [None]:
# Hold out 33% of the rows, stratified on the sequence column.
# No random_state is passed, so the split differs between runs.
train, test = train_test_split(df, stratify=df.sequence, test_size=0.33)

In [None]:
# Work on explicit copies: assigning a column on the frames returned by
# train_test_split otherwise triggers pandas' SettingWithCopyWarning.
train = train.copy()
test = test.copy()

# Training rows keep the first half of each sequence, held-out rows keep the
# remainder (for odd lengths the extra character goes to the second half).
train['sequence'] = train['sequence'].apply(lambda seq: seq[:len(seq) // 2])
test['sequence'] = test['sequence'].apply(lambda seq: seq[len(seq) // 2:])

Set params

In [None]:
# Upper bound on the number of passes over the training DataLoader.
EPOCHS = 128
# Number of samples per batch for both the training and validation DataLoaders.
BATCH_SIZE = 4
# Number of batches between optimizer/scheduler steps in the training loop.
BATCHES_PER_STEP = 1
# Device used by the training loop; the trailing comment shows the GPU alternative.
DEVICE = torch.device('cpu') # torch.device('cuda')

In [None]:
# Length every sequence is truncated/padded to by the dataset, also passed to
# the model config as max_position_embeddings.
# Use df.sequence.apply(len).max() to cover the longest sequence; a smaller
# value gives better performance at the cost of a worse score.
MAXLEN = 100

In [None]:
import math

# Number of scheduler steps during which the learning rate ramps up linearly.
WARMUP_STEPS = 0

# The scheduler is stepped once per optimizer step, i.e. once every
# BATCHES_PER_STEP batches -- not once per sample. Counting samples here would
# stretch the linear decay so far that the learning rate barely decreases.
_batches_per_epoch = math.ceil(train.shape[0] / BATCH_SIZE)
_optimizer_steps_per_epoch = math.ceil(_batches_per_epoch / BATCHES_PER_STEP)
TOTAL_SCHEDULER_STEPS = _optimizer_steps_per_epoch * EPOCHS

In [None]:
# Dropout probability passed to the DistilBERT configuration.
DROPOUT = 0.3

In [None]:
# Intended on/off switch for early stopping.
EARLY_STOPPING = True
# Training stops once the count of "bad" epochs exceeds this value.
EARLY_STOPPING_PATIENCE = 3
# An epoch is "bad" when its validation loss is at least this much above the best seen so far.
EARLY_STOPPING_TOLERANCE = 0.01

# Training also stops as soon as the validation loss falls below this value.
FITTED_THRESHOLD = 1e-8

# Train/validation losses are recorded every LOSS_LOG_EACH training batches.
LOSS_LOG_EACH = 10

Write vocabulary file

In [None]:
# Special tokens BertTokenizer expects to find in its vocabulary. "[PAD]" comes
# first so that token id 0 -- the value used for padding -- never collides with
# a real sequence character.
SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]

# One vocabulary entry per distinct character. Sorting makes the token ids
# reproducible between runs (set iteration order is not), so a saved model
# stays compatible with a regenerated vocabulary file.
alphabet = sorted(set().union(*df.sequence.apply(set)))
vocab = SPECIAL_TOKENS + alphabet

# BertTokenizer reads one token per line; the line number is the token id.
vocab_file = '\n'.join(vocab)
with open("./vocabulary.txt", "w", encoding="utf-8") as f:
    f.write(vocab_file)

Define model and tokenizer

In [None]:
# `vocab_size` is the configuration keyword that sizes the embedding matrix and
# the output layer. The legacy `vocab_size_or_config_json_file` name is not
# recognised by current releases, which would silently keep the default
# 30522-token vocabulary.
config = DistilBertConfig(vocab_size=len(vocab), dropout=DROPOUT, max_position_embeddings=MAXLEN)
# Freshly initialised (not pre-trained) masked-language-model head on DistilBERT.
model = DistilBertForMaskedLM(config)
# Case-sensitive tokenizer built from the vocabulary file written earlier.
tokenizer = BertTokenizer("./vocabulary.txt", do_basic_tokenize=True, do_lower_case=False)

Uncomment to move the model to the GPU (CUDA)

In [None]:
# model.cuda(0)

Create batch generator

In [None]:
class rfam(Dataset):
    """Map-style dataset of character-tokenized sequences with position masks.

    Each item is a tuple ``(encoded_seq, attention_mask, q, a)``:

    * ``encoded_seq`` -- LongTensor of ``maxlen`` token ids; the sequence is
      truncated to ``maxlen`` and right-padded with id 0.
    * ``attention_mask`` -- LongTensor of ``maxlen`` entries, 1 for real tokens
      and 0 for padding.
    * ``q``, ``a`` -- index of the first ``'1'`` character in the row's
      question mask and answer mask respectively.

    Args:
        sequences: iterable of sequence strings.
        question_masks: iterable of mask strings, one per sequence.
        answer_masks: iterable of mask strings, one per sequence.
        maxlen: fixed length every encoded sequence is truncated/padded to.
        vocabulary_file: path of the vocabulary file given to BertTokenizer.
    """

    def __init__(self, sequences, question_masks, answer_masks, maxlen=512, vocabulary_file="./vocabulary.txt"):
        self.df = pd.DataFrame({'sequences': sequences, 'question_masks': question_masks, 'answer_masks': answer_masks})
        self.tokenizer = BertTokenizer(vocabulary_file, do_basic_tokenize=True, do_lower_case=False)
        self.maxlen = maxlen

    def __getitem__(self, i):
        """Return ``(encoded_seq, attention_mask, q, a)`` for row ``i``.

        Raises:
            ValueError: if either mask of the row contains no ``'1'``.
        """
        row = self.df.iloc[i]
        # Space-separate the characters so each one is tokenized on its own.
        spaced = ' '.join(row.sequences[:self.maxlen])
        token_ids = self.tokenizer.encode(spaced, add_special_tokens=False)[:self.maxlen]
        # Pad according to the number of tokens actually produced so the
        # result always has exactly `maxlen` entries.
        n_pad = self.maxlen - len(token_ids)
        encoded_seq = LongTensor(token_ids + [0] * n_pad)
        # The mask must be derived from the unpadded length; building it from
        # the padded tensor would mark every padding position as a real token.
        attention_mask = LongTensor([1] * len(token_ids) + [0] * n_pad)
        q = list(row.question_masks).index('1')
        a = list(row.answer_masks).index('1')
        return encoded_seq, attention_mask, q, a

    def __len__(self):
        """Return the number of rows in the dataset."""
        return self.df.shape[0]
    
# Wrap the train / held-out frames; both are truncated/padded to MAXLEN tokens.
train_data = rfam(train.sequence, train['mask'], train.ans, maxlen=MAXLEN)
val_data = rfam(test.sequence, test['mask'], test.ans, maxlen=MAXLEN)
# Reshuffle the training samples every epoch so that batches are not presented
# in the same order each time; validation order does not matter.
batch_generator = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
validation_sampler = DataLoader(val_data, batch_size=BATCH_SIZE)

Training loop

In [None]:
import os

# Targets equal to this value are skipped by the loss (used for padding).
IGNORE_INDEX = -100
# Directory the best model (lowest validation loss) is written to.
CHECKPOINT_DIR = "bert_for_rna_seqs"

# Put the model on the target device before the optimizer captures its parameters.
model.to(DEVICE)

criterion = CrossEntropyLoss(ignore_index=IGNORE_INDEX)
optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=WARMUP_STEPS,
                                            num_training_steps=TOTAL_SCHEDULER_STEPS)


def _reconstruction_loss(batch_seq, batch_attn):
    """Return the cross-entropy between the model's logits and the input tokens.

    Positions whose attention mask is 0 (padding) are excluded from the loss.
    NOTE(review): the targets are the unmasked inputs themselves and the q/a
    positions yielded by the dataset are not used -- confirm this is intended.
    """
    batch_seq = batch_seq.to(DEVICE)
    batch_attn = batch_attn.to(DEVICE)
    # CrossEntropyLoss wants the class dimension second, hence the transpose.
    logits = model(batch_seq, attention_mask=batch_attn)[0].transpose(1, 2)
    targets = batch_seq.masked_fill(batch_attn == 0, IGNORE_INDEX)
    return criterion(logits, targets)


def _standard_error(values):
    """Return the standard error of the mean of ``values`` (0.0 if empty)."""
    if not values:
        return 0.0
    return np.std(values) / np.sqrt(len(values))


def _validation_losses():
    """Return the per-batch losses over the whole validation loader."""
    model.eval()  # disable dropout while evaluating
    with torch.no_grad():  # no gradients needed: saves memory and time
        batch_losses = [_reconstruction_loss(val_seq, val_attn).item()
                        for val_seq, val_attn, _, _ in validation_sampler]
    model.train()
    return batch_losses


# bad_epochs_od is not used in this cell; it is kept in case later cells read it.
bad_epochs_es, bad_epochs_od, min_epoch_val_loss = 0, 0, 9000

epoch_train_losses, epoch_val_losses = [], []
epoch_train_loss_errs, epoch_val_loss_errs = [], []

model.train()
for epoch in range(EPOCHS):
    losses = []
    n_batches = len(batch_generator)
    optimizer.zero_grad()

    for batch_idx, (seq, attn, q, a) in enumerate(
            notebook.tqdm(batch_generator, desc=f"Epoch {epoch+1}/{EPOCHS}")):
        loss = _reconstruction_loss(seq, attn)
        losses.append(loss.item())
        # Scale so that accumulated gradients average over the accumulated batches.
        (loss / BATCHES_PER_STEP).backward()

        # Step after every BATCHES_PER_STEP batches, and on the final batch so
        # that left-over gradients are not carried into the next epoch.
        is_last_batch = batch_idx + 1 == n_batches
        if (batch_idx + 1) % BATCHES_PER_STEP == 0 or is_last_batch:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if batch_idx % LOSS_LOG_EACH == 0:
            # Summarise the training losses gathered since the last log point;
            # the statistics must be taken before the buffer is cleared.
            epoch_train_loss = np.mean(losses)
            epoch_train_losses.append(epoch_train_loss)
            epoch_train_loss_errs.append(_standard_error(losses))
            losses = []

            val_losses = _validation_losses()
            epoch_val_loss = np.mean(val_losses)
            epoch_val_losses.append(epoch_val_loss)
            epoch_val_loss_errs.append(_standard_error(val_losses))

    # Redraw the loss curves after every epoch.
    plt.clf()
    clear_output()
    x_axis = [(i+1)*LOSS_LOG_EACH for i in range(len(epoch_train_losses))]
    plt.errorbar(x=x_axis, y=epoch_train_losses, yerr=epoch_train_loss_errs, fmt='o-', capsize=10, label="Train")
    plt.errorbar(x=x_axis, y=epoch_val_losses, yerr=epoch_val_loss_errs, fmt='o-', capsize=10, label="Val")
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.legend()
    plt.grid()
    plt.show()

    # The most recently logged validation loss represents the epoch.
    if epoch_val_loss < min_epoch_val_loss:
        # New best: checkpoint it and restart the patience counter.
        min_epoch_val_loss = epoch_val_loss
        bad_epochs_es = 0
        # Older library releases require the target directory to exist.
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        model.save_pretrained(CHECKPOINT_DIR)
    elif epoch_val_loss >= min_epoch_val_loss + EARLY_STOPPING_TOLERANCE:
        bad_epochs_es += 1

    if EARLY_STOPPING and bad_epochs_es > EARLY_STOPPING_PATIENCE:
        print("Break by early stopping")
        break
    if epoch_val_loss < FITTED_THRESHOLD:
        print("Break due to holdout loss being under threshold")
        break