In [1]:
import torch
import pathlib
import gc
import math
import random

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader, BatchSampler, RandomSampler
from sklearn.model_selection import train_test_split
from tqdm import tqdm



In [2]:
INPUT_PATH = pathlib.Path('/kaggle/input/stanford-ribonanza-rna-folding-converted')
MODEL_PATH = pathlib.Path('/kaggle/input/rna-folding-model/')
WORKING_PATH = pathlib.Path('/kaggle/working/')
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
full_df = pd.read_parquet(INPUT_PATH/"train_data.parquet")
test_df = pd.read_parquet(INPUT_PATH/"test_sequences.parquet")

df_2A3 = full_df[full_df.experiment_type =='2A3_MaP'].reset_index(drop=True)
df_DMS = full_df[full_df.experiment_type =='DMS_MaP'].reset_index(drop=True)
train_2A3, val_2A3, train_DMS, val_DMS= train_test_split(df_2A3, df_DMS, test_size=0.1, random_state=42)

In [4]:
class RNA_Dataset(Dataset):
    def __init__(self, df_2A3, df_DMS):
        # filter noisy data for now
        predicate = (df_2A3.SN_filter.values > 0) & (df_DMS.SN_filter.values > 0)
        df_2A3 = df_2A3[predicate].reset_index(drop=True)
        df_DMS = df_DMS[predicate].reset_index(drop=True)
        
        self.seq_map = {'A':1, 'C':2, 'G':3, 'U':4}
        self.seqs = df_2A3.sequence.values
        self.react_2A3 = df_2A3[[c for c in df_2A3.columns if \
                                 'reactivity_0' in c]].values
        self.react_DMS = df_DMS[[c for c in df_DMS.columns if \
                                 'reactivity_0' in c]].values
        
    def __len__(self):
        return len(self.seqs)
        
    def __getitem__(self, idx):
        seq = self.seqs[idx]
        seq_idx = torch.tensor([self.seq_map[s] for s in seq], dtype=torch.long)
        labels = torch.tensor(np.stack([self.react_2A3[idx],
                                           self.react_DMS[idx]], -1), dtype=torch.float32)
        return seq_idx, labels
    
# Useful for sampling batches of similar lengths to minimize padding
class GroupLengthBatchSampler(BatchSampler):
    def __iter__(self):
        dataset = self.sampler.data_source
        indices = [idx for idx in self.sampler]

        step = 100 * self.batch_size
        for i in range(0, len(dataset), step):
            pool = indices[i:i+step]
            pool = sorted(pool, key=lambda x: len(dataset[x][0]))
            for j in range(0, len(pool), self.batch_size):
                if j + self.batch_size > len(pool): # assume drop_last=True
                    break
                yield pool[j:j+self.batch_size]
        
def collate_fn(data):
    seq_idx, labels = zip(*data)
    padded_seqs = nn.utils.rnn.pad_sequence(seq_idx, batch_first=True)
    B, T = padded_seqs.shape
    labels = torch.stack(labels)[:, :T, :]
    return padded_seqs, labels

In [5]:
vocab_size = 5 # the 4 bases + padding
emb_dim = 256
num_layers = 12
nhead=8
batch_size = 128
itos = {0: "<PAD>", 1: "A", 2: "C", 3: "G", 4: "U"}

# we have to use fixed Positions because training data is 
# shorter than test data
class PositionEncoding(nn.Module):
    def __init__(self, emb_dim, max_len=512):
        super().__init__()
        positions = torch.arange(max_len).unsqueeze(1)
        evens = torch.arange(0, emb_dim, 2)
        frequencies = torch.exp(evens * (-math.log(10_000)/emb_dim))
        pos_embs = torch.zeros(max_len, emb_dim)
        pos_embs[:, 0::2] = torch.sin(positions * frequencies)
        pos_embs[:, 1::2] = torch.cos(positions * frequencies)
        self.register_buffer('pos_emb', pos_embs)
        
    def forward(self, x):
        return x + self.pos_emb[:x.size(1)]
                
class RNA_Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = PositionEncoding(emb_dim)
        enc_layer = nn.TransformerEncoderLayer(emb_dim, nhead,
                                               dim_feedforward=4*emb_dim,
                                               batch_first=True, norm_first=True,
                                               activation="gelu")
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers)
        self.regression_head = nn.Linear(emb_dim, 2)
        
    def forward(self, x, targets=None):
        B, T = x.shape
        z = self.token_emb(x)
        z = self.pos_emb(z)
        z = self.encoder(z)
        preds = self.regression_head(z)
        
        if targets is None:
            loss = None
        else:
            preds = preds.view(B*T, 2)
            targets = targets.contiguous().view(B*T, 2).clamp(0, 1)
            loss = F.l1_loss(preds, targets, reduction='none')
            loss = loss[~loss.isnan()].mean()
        return preds, loss

In [6]:
train_dataset, val_dataset = RNA_Dataset(train_2A3, train_DMS), RNA_Dataset(val_2A3, val_DMS)
trainsampler = GroupLengthBatchSampler(RandomSampler(train_dataset), batch_size, drop_last=True)
valsampler = GroupLengthBatchSampler(RandomSampler(val_dataset), batch_size, drop_last=True)
trainloader = DataLoader(train_dataset, batch_sampler=trainsampler, collate_fn=collate_fn)
validloader = DataLoader(val_dataset, batch_sampler=valsampler, collate_fn=collate_fn)

In [7]:
model = torch.load(MODEL_PATH/"best_model.pth", map_location=device)
model.to(device);

In [8]:
epochs = 30
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
train_steps = epochs * len(trainloader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, train_steps)

In [9]:
@torch.no_grad()
def eval_loop():
    model.eval()
    losses = torch.zeros(len(validloader))
    for i, (x, y) in tqdm(enumerate(validloader), total=len(validloader)):
        _, loss = model(x.to(device), y.to(device))
        losses[i] = loss.item()
    model.train()
    val_loss = losses.mean().item()
    print(f"Val Loss: {val_loss}")
    return val_loss
            
eval_distance = 500
min_loss = 0.138
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Training model with {n_params:,} parameters...")
loss_dict = {"train_loss": [], "val_loss": []}
for epoch in range(epochs):
    losses = torch.zeros(len(trainloader))
    pbar = tqdm(enumerate(trainloader), total=len(trainloader))
    pbar.set_description(f"Epoch {epoch}")
    for i, (x, y) in pbar:
        _, loss = model(x.to(device), y.to(device))
        losses[i] = loss.item()
        
        if i >= eval_distance and i % eval_distance == 0:
            train_loss = losses[i-eval_distance:i].mean().item()
            pbar.set_postfix({"Loss":  train_loss})
        
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 3.0)
        optimizer.step()
        scheduler.step()
    val_loss = eval_loop()
    loss_dict["train_loss"].append(train_loss)
    loss_dict["val_loss"].append(val_loss)
    if min_loss > val_loss:
        print("Saving new best model...")
        min_loss = val_loss
        torch.save(model, WORKING_PATH/"best_model.pth")

Training model with 9,478,914 parameters...


Epoch 0: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.136]
100%|██████████| 141/141 [00:15<00:00,  9.17it/s]


Val Loss: 0.13866277039051056


Epoch 1: 100%|██████████| 1274/1274 [08:38<00:00,  2.46it/s, Loss=0.135]
100%|██████████| 141/141 [00:15<00:00,  8.99it/s]


Val Loss: 0.13956549763679504


Epoch 2: 100%|██████████| 1274/1274 [08:39<00:00,  2.45it/s, Loss=0.135]
100%|██████████| 141/141 [00:15<00:00,  9.11it/s]


Val Loss: 0.13846315443515778


Epoch 3: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.135]
100%|██████████| 141/141 [00:15<00:00,  9.14it/s]


Val Loss: 0.1387830227613449


Epoch 4: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.135]
100%|██████████| 141/141 [00:15<00:00,  9.18it/s]


Val Loss: 0.13940268754959106


Epoch 5: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.134]
100%|██████████| 141/141 [00:15<00:00,  9.19it/s]


Val Loss: 0.13798388838768005
Saving new best model...


Epoch 6: 100%|██████████| 1274/1274 [08:39<00:00,  2.45it/s, Loss=0.134]
100%|██████████| 141/141 [00:15<00:00,  9.02it/s]


Val Loss: 0.13823768496513367


Epoch 7: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.133]
100%|██████████| 141/141 [00:15<00:00,  9.12it/s]


Val Loss: 0.13788586854934692
Saving new best model...


Epoch 8: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.133]
100%|██████████| 141/141 [00:15<00:00,  9.05it/s]


Val Loss: 0.1382211297750473


Epoch 9: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.132]
100%|██████████| 141/141 [00:15<00:00,  9.11it/s]


Val Loss: 0.1378752589225769
Saving new best model...


Epoch 10: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.132]
100%|██████████| 141/141 [00:15<00:00,  9.19it/s]


Val Loss: 0.1378674954175949
Saving new best model...


Epoch 11: 100%|██████████| 1274/1274 [08:41<00:00,  2.44it/s, Loss=0.132]
100%|██████████| 141/141 [00:15<00:00,  9.16it/s]


Val Loss: 0.13749952614307404
Saving new best model...


Epoch 12: 100%|██████████| 1274/1274 [08:41<00:00,  2.44it/s, Loss=0.132]
100%|██████████| 141/141 [00:15<00:00,  9.14it/s]


Val Loss: 0.13738518953323364
Saving new best model...


Epoch 13: 100%|██████████| 1274/1274 [08:41<00:00,  2.44it/s, Loss=0.131]
100%|██████████| 141/141 [00:15<00:00,  8.82it/s]


Val Loss: 0.13736163079738617
Saving new best model...


Epoch 14: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.131]
100%|██████████| 141/141 [00:15<00:00,  9.03it/s]


Val Loss: 0.1375550925731659


Epoch 15: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.131]
100%|██████████| 141/141 [00:15<00:00,  9.07it/s]


Val Loss: 0.1374318152666092


Epoch 16: 100%|██████████| 1274/1274 [08:39<00:00,  2.45it/s, Loss=0.13]
100%|██████████| 141/141 [00:15<00:00,  9.11it/s]


Val Loss: 0.13695305585861206
Saving new best model...


Epoch 17: 100%|██████████| 1274/1274 [08:39<00:00,  2.45it/s, Loss=0.13]
100%|██████████| 141/141 [00:15<00:00,  8.87it/s]


Val Loss: 0.13749806582927704


Epoch 18: 100%|██████████| 1274/1274 [08:39<00:00,  2.45it/s, Loss=0.13]
100%|██████████| 141/141 [00:15<00:00,  9.00it/s]


Val Loss: 0.1371692270040512


Epoch 19: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.13]
100%|██████████| 141/141 [00:15<00:00,  9.17it/s]


Val Loss: 0.13695111870765686
Saving new best model...


Epoch 20: 100%|██████████| 1274/1274 [08:41<00:00,  2.44it/s, Loss=0.13]
100%|██████████| 141/141 [00:15<00:00,  9.08it/s]


Val Loss: 0.1370006650686264


Epoch 21: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.129]
100%|██████████| 141/141 [00:15<00:00,  8.87it/s]


Val Loss: 0.13686436414718628
Saving new best model...


Epoch 22: 100%|██████████| 1274/1274 [08:39<00:00,  2.45it/s, Loss=0.129]
100%|██████████| 141/141 [00:15<00:00,  9.09it/s]


Val Loss: 0.13669878244400024
Saving new best model...


Epoch 23: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.129]
100%|██████████| 141/141 [00:15<00:00,  9.04it/s]


Val Loss: 0.13671106100082397


Epoch 24: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.129]
100%|██████████| 141/141 [00:15<00:00,  9.08it/s]


Val Loss: 0.13667283952236176
Saving new best model...


Epoch 25: 100%|██████████| 1274/1274 [08:41<00:00,  2.44it/s, Loss=0.129]
100%|██████████| 141/141 [00:15<00:00,  9.11it/s]


Val Loss: 0.1368371546268463


Epoch 26: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.129]
100%|██████████| 141/141 [00:15<00:00,  8.90it/s]


Val Loss: 0.13671113550662994


Epoch 27: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.129]
100%|██████████| 141/141 [00:15<00:00,  9.15it/s]


Val Loss: 0.13677354156970978


Epoch 28: 100%|██████████| 1274/1274 [08:39<00:00,  2.45it/s, Loss=0.129]
100%|██████████| 141/141 [00:15<00:00,  9.03it/s]


Val Loss: 0.1366683393716812
Saving new best model...


Epoch 29: 100%|██████████| 1274/1274 [08:40<00:00,  2.45it/s, Loss=0.129]
100%|██████████| 141/141 [00:15<00:00,  8.99it/s]

Val Loss: 0.13668011128902435



