In [1]:
import torch
import pathlib
import gc
import math
import random

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader, BatchSampler, RandomSampler
from sklearn.model_selection import train_test_split
from tqdm import tqdm



In [2]:
INPUT_PATH = pathlib.Path('/kaggle/input/stanford-ribonanza-rna-folding-converted')
MODEL_PATH = pathlib.Path('/kaggle/input/rna-folding-model/')
WORKING_PATH = pathlib.Path('/kaggle/working/')
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
full_df = pd.read_parquet(INPUT_PATH/"train_data.parquet")

df_2A3 = full_df[full_df.experiment_type =='2A3_MaP'].reset_index(drop=True)
df_DMS = full_df[full_df.experiment_type =='DMS_MaP'].reset_index(drop=True)
train_2A3, val_2A3, train_DMS, val_DMS= train_test_split(df_2A3, df_DMS, test_size=0.1, random_state=42)

In [4]:
class RNA_Dataset(Dataset):
    def __init__(self, df_2A3, df_DMS):
        # filter noisy data for now
        predicate = (df_2A3.SN_filter.values > 0) & (df_DMS.SN_filter.values > 0)
        df_2A3 = df_2A3[predicate].reset_index(drop=True)
        df_DMS = df_DMS[predicate].reset_index(drop=True)
        
        self.seq_map = {'A':1, 'C':2, 'G':3, 'U':4}
        self.seqs = df_2A3.sequence.values
        self.react_2A3 = df_2A3[[c for c in df_2A3.columns if \
                                 'reactivity_0' in c]].values
        self.react_DMS = df_DMS[[c for c in df_DMS.columns if \
                                 'reactivity_0' in c]].values
        
    def __len__(self):
        return len(self.seqs)
        
    def __getitem__(self, idx):
        seq = self.seqs[idx]
        seq_idx = torch.tensor([self.seq_map[s] for s in seq], dtype=torch.long)
        labels = torch.tensor(np.stack([self.react_2A3[idx],
                                           self.react_DMS[idx]], -1), dtype=torch.float32)
        return seq_idx, labels
    
# Useful for sampling batches of similar lengths to minimize padding
class GroupLengthBatchSampler(BatchSampler):
    def __iter__(self):
        dataset = self.sampler.data_source
        indices = [idx for idx in self.sampler]

        step = 100 * self.batch_size
        for i in range(0, len(dataset), step):
            pool = indices[i:i+step]
            pool = sorted(pool, key=lambda x: len(dataset[x][0]))
            for j in range(0, len(pool), self.batch_size):
                if j + self.batch_size > len(pool): # assume drop_last=True
                    break
                yield pool[j:j+self.batch_size]
        
def collate_fn(data):
    seq_idx, labels = zip(*data)
    padded_seqs = nn.utils.rnn.pad_sequence(seq_idx, batch_first=True)
    B, T = padded_seqs.shape
    labels = torch.stack(labels)[:, :T, :]
    return padded_seqs, labels

In [5]:
vocab_size = 5 # the 4 bases + padding
emb_dim = 256
n_layers = 12
n_heads =8
batch_size = 128
itos = {0: "<PAD>", 1: "A", 2: "C", 3: "G", 4: "U"}

def precompute_freqs_cis(dim, end=500, theta=10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return freqs_cos, freqs_sin

def reshape_for_broadcast(freqs_cis, x):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)

def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):

    # reshape xq and xk to match the complex representation
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    # reshape freqs_cos and freqs_sin for broadcasting
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    # apply rotation using real numbers
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # flatten last two dimensions
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

class Attention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = dropout
        self.n_heads = n_heads
        self.emb_dim = emb_dim
        self.head_size = emb_dim // n_heads
        self.c_attn = nn.Linear(emb_dim, 3*emb_dim, bias=False)
        self.c_proj = nn.Linear(emb_dim, emb_dim, bias=False)
        self.proj_dropout = nn.Dropout(dropout)
        
    def forward(self, x, freqs_cos, freqs_sin):
        B, T, _ = x.shape
        xq, xk, xv = self.c_attn(x).split(self.emb_dim, dim=2)
        xq = xq.view(B, T, self.n_heads, self.head_size)
        xk = xk.view(B, T, self.n_heads, self.head_size)
        xv = xv.view(B, T, self.n_heads, self.head_size)
        
        # RoPE
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
        
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)
        
        out = F.scaled_dot_product_attention(xq, xk, xv, dropout_p=self.dropout)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.proj_dropout(self.c_proj(out))
    
class FeedForward(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.w1 = nn.Linear(emb_dim, 4*emb_dim, bias=False)
        self.w2 = nn.Linear(4*emb_dim, emb_dim, bias=False)
        self.w3 = nn.Linear(emb_dim, 4*emb_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
    
class EncoderBlock(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.attention = Attention()
        self.feed_forward = FeedForward()
        self.attention_norm = nn.LayerNorm(emb_dim)
        self.ffn_norm = nn.LayerNorm(emb_dim)

    def forward(self, x, freqs_cos, freqs_sin):
        h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out
    
class RNA_Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, emb_dim)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(EncoderBlock())
        self.regression_head = nn.Linear(emb_dim, 2)
        freqs_cos, freqs_sin = precompute_freqs_cis(emb_dim//n_heads)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
        
    def forward(self, x, targets=None):
        B, T = x.shape
        z = self.token_emb(x)
        freqs_cos, freqs_sin = self.freqs_cos[:T], self.freqs_sin[:T]
        
        for layer in self.layers:
            z = layer(z, freqs_cos, freqs_sin)
        preds = self.regression_head(z)
        
        if targets is None:
            loss = None
        else:
            preds = preds.view(B*T, 2)
            targets = targets.contiguous().view(B*T, 2).clamp(0, 1)
            loss = F.l1_loss(preds, targets, reduction='none')
            loss = loss[~loss.isnan()].mean()
        return preds, loss

In [6]:
train_dataset, val_dataset = RNA_Dataset(train_2A3, train_DMS), RNA_Dataset(val_2A3, val_DMS)
trainsampler = GroupLengthBatchSampler(RandomSampler(train_dataset), batch_size, drop_last=True)
valsampler = GroupLengthBatchSampler(RandomSampler(val_dataset), batch_size, drop_last=True)
trainloader = DataLoader(train_dataset, batch_sampler=trainsampler, collate_fn=collate_fn)
validloader = DataLoader(val_dataset, batch_sampler=valsampler, collate_fn=collate_fn)

In [7]:
model = RNA_Transformer() #torch.load(MODEL_PATH/"best_model.pth", map_location=device)
model.to(device);

In [8]:
epochs = 50
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
train_steps = epochs * len(trainloader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, train_steps)

In [9]:
@torch.no_grad()
def eval_loop():
    model.eval()
    losses = torch.zeros(len(validloader))
    for i, (x, y) in tqdm(enumerate(validloader), total=len(validloader)):
        _, loss = model(x.to(device), y.to(device))
        losses[i] = loss.item()
    model.train()
    val_loss = losses.mean().item()
    print(f"Val Loss: {val_loss}")
    return val_loss
            
eval_distance = 500
min_loss = 0.138
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Training model with {n_params:,} parameters...")
loss_dict = {"train_loss": [], "val_loss": []}
for epoch in range(epochs):
    losses = torch.zeros(len(trainloader))
    pbar = tqdm(enumerate(trainloader), total=len(trainloader))
    pbar.set_description(f"Epoch {epoch}")
    for i, (x, y) in pbar:
        _, loss = model(x.to(device), y.to(device))
        losses[i] = loss.item()
        
        if i >= eval_distance and i % eval_distance == 0:
            train_loss = losses[i-eval_distance:i].mean().item()
            pbar.set_postfix({"Loss":  train_loss})
        
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 3.0)
        optimizer.step()
        scheduler.step()
    val_loss = eval_loop()
    loss_dict["train_loss"].append(train_loss)
    loss_dict["val_loss"].append(val_loss)
    if min_loss > val_loss:
        print("Saving new best model...")
        min_loss = val_loss
        torch.save(model, WORKING_PATH/"best_model.pth")

Training model with 12,596,994 parameters...


Epoch 0: 100%|██████████| 1274/1274 [10:00<00:00,  2.12it/s, Loss=0.205]
100%|██████████| 141/141 [00:23<00:00,  6.01it/s]


Val Loss: 0.19259080290794373


Epoch 1: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.186]
100%|██████████| 141/141 [00:23<00:00,  6.11it/s]


Val Loss: 0.1819210648536682


Epoch 2: 100%|██████████| 1274/1274 [09:56<00:00,  2.13it/s, Loss=0.178]
100%|██████████| 141/141 [00:23<00:00,  6.12it/s]


Val Loss: 0.1742209792137146


Epoch 3: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.172]
100%|██████████| 141/141 [00:23<00:00,  6.11it/s]


Val Loss: 0.17077411711215973


Epoch 4: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.168]
100%|██████████| 141/141 [00:23<00:00,  6.06it/s]


Val Loss: 0.16777914762496948


Epoch 5: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.162]
100%|██████████| 141/141 [00:23<00:00,  6.11it/s]


Val Loss: 0.16105113923549652


Epoch 6: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.158]
100%|██████████| 141/141 [00:23<00:00,  6.11it/s]


Val Loss: 0.15783005952835083


Epoch 7: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.155]
100%|██████████| 141/141 [00:23<00:00,  6.10it/s]


Val Loss: 0.15639294683933258


Epoch 8: 100%|██████████| 1274/1274 [09:58<00:00,  2.13it/s, Loss=0.153]
100%|██████████| 141/141 [00:23<00:00,  5.95it/s]


Val Loss: 0.1525498777627945


Epoch 9: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.15]
100%|██████████| 141/141 [00:23<00:00,  6.09it/s]


Val Loss: 0.15135781466960907


Epoch 10: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.148]
100%|██████████| 141/141 [00:22<00:00,  6.13it/s]


Val Loss: 0.1505119651556015


Epoch 11: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.147]
100%|██████████| 141/141 [00:23<00:00,  6.09it/s]


Val Loss: 0.14868344366550446


Epoch 12: 100%|██████████| 1274/1274 [09:58<00:00,  2.13it/s, Loss=0.145]
100%|██████████| 141/141 [00:23<00:00,  6.02it/s]


Val Loss: 0.1477375626564026


Epoch 13: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.143]
100%|██████████| 141/141 [00:23<00:00,  6.02it/s]


Val Loss: 0.14668205380439758


Epoch 14: 100%|██████████| 1274/1274 [09:56<00:00,  2.14it/s, Loss=0.142]
100%|██████████| 141/141 [00:23<00:00,  6.09it/s]


Val Loss: 0.1458306759595871


Epoch 15: 100%|██████████| 1274/1274 [09:56<00:00,  2.14it/s, Loss=0.141]
100%|██████████| 141/141 [00:23<00:00,  6.10it/s]


Val Loss: 0.14558008313179016


Epoch 16: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.139]
100%|██████████| 141/141 [00:23<00:00,  6.01it/s]


Val Loss: 0.14430512487888336


Epoch 17: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.138]
100%|██████████| 141/141 [00:23<00:00,  6.13it/s]


Val Loss: 0.14443637430667877


Epoch 18: 100%|██████████| 1274/1274 [09:56<00:00,  2.13it/s, Loss=0.138]
100%|██████████| 141/141 [00:23<00:00,  6.10it/s]


Val Loss: 0.14434711635112762


Epoch 19: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.136]
100%|██████████| 141/141 [00:23<00:00,  6.04it/s]


Val Loss: 0.14332620799541473


Epoch 20: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.135]
100%|██████████| 141/141 [00:22<00:00,  6.15it/s]


Val Loss: 0.1433216631412506


Epoch 21: 100%|██████████| 1274/1274 [09:55<00:00,  2.14it/s, Loss=0.135]
100%|██████████| 141/141 [00:23<00:00,  6.01it/s]


Val Loss: 0.1424853652715683


Epoch 22: 100%|██████████| 1274/1274 [09:56<00:00,  2.14it/s, Loss=0.133]
100%|██████████| 141/141 [00:23<00:00,  5.98it/s]


Val Loss: 0.14218170940876007


Epoch 23: 100%|██████████| 1274/1274 [09:56<00:00,  2.14it/s, Loss=0.133]
100%|██████████| 141/141 [00:23<00:00,  6.12it/s]


Val Loss: 0.14201928675174713


Epoch 24: 100%|██████████| 1274/1274 [09:58<00:00,  2.13it/s, Loss=0.132]
100%|██████████| 141/141 [00:23<00:00,  6.04it/s]


Val Loss: 0.1414187103509903


Epoch 25: 100%|██████████| 1274/1274 [09:56<00:00,  2.13it/s, Loss=0.131]
100%|██████████| 141/141 [00:23<00:00,  6.11it/s]


Val Loss: 0.14124636352062225


Epoch 26: 100%|██████████| 1274/1274 [09:56<00:00,  2.13it/s, Loss=0.131]
100%|██████████| 141/141 [00:22<00:00,  6.16it/s]


Val Loss: 0.1413615494966507


Epoch 27: 100%|██████████| 1274/1274 [09:56<00:00,  2.14it/s, Loss=0.13]
100%|██████████| 141/141 [00:23<00:00,  6.12it/s]


Val Loss: 0.14097827672958374


Epoch 28: 100%|██████████| 1274/1274 [09:56<00:00,  2.13it/s, Loss=0.129]
100%|██████████| 141/141 [00:23<00:00,  6.13it/s]


Val Loss: 0.14072704315185547


Epoch 29: 100%|██████████| 1274/1274 [09:58<00:00,  2.13it/s, Loss=0.128]
100%|██████████| 141/141 [00:23<00:00,  6.12it/s]


Val Loss: 0.14065854251384735


Epoch 30: 100%|██████████| 1274/1274 [09:56<00:00,  2.14it/s, Loss=0.128]
100%|██████████| 141/141 [00:22<00:00,  6.14it/s]


Val Loss: 0.13967639207839966


Epoch 31: 100%|██████████| 1274/1274 [09:56<00:00,  2.14it/s, Loss=0.127]
100%|██████████| 141/141 [00:23<00:00,  6.12it/s]


Val Loss: 0.1400841325521469


Epoch 32: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.126]
100%|██████████| 141/141 [00:23<00:00,  5.91it/s]


Val Loss: 0.14003320038318634


Epoch 33: 100%|██████████| 1274/1274 [09:57<00:00,  2.13it/s, Loss=0.126]
100%|██████████| 141/141 [00:22<00:00,  6.16it/s]


Val Loss: 0.13949140906333923


Epoch 34: 100%|██████████| 1274/1274 [09:56<00:00,  2.13it/s, Loss=0.125]
100%|██████████| 141/141 [00:23<00:00,  6.13it/s]


Val Loss: 0.13947300612926483


Epoch 35: 100%|██████████| 1274/1274 [09:56<00:00,  2.13it/s, Loss=0.126]
100%|██████████| 141/141 [00:23<00:00,  6.04it/s]


Val Loss: 0.13946256041526794


Epoch 36: 100%|██████████| 1274/1274 [10:00<00:00,  2.12it/s, Loss=0.125]
100%|██████████| 141/141 [00:23<00:00,  6.05it/s]


Val Loss: 0.13932353258132935


Epoch 37: 100%|██████████| 1274/1274 [10:00<00:00,  2.12it/s, Loss=0.124]
100%|██████████| 141/141 [00:23<00:00,  6.06it/s]


Val Loss: 0.13929443061351776


Epoch 38: 100%|██████████| 1274/1274 [09:59<00:00,  2.12it/s, Loss=0.124]
100%|██████████| 141/141 [00:23<00:00,  6.01it/s]


Val Loss: 0.13916917145252228


Epoch 39: 100%|██████████| 1274/1274 [09:59<00:00,  2.13it/s, Loss=0.123]
100%|██████████| 141/141 [00:23<00:00,  6.06it/s]


Val Loss: 0.13888515532016754


Epoch 40: 100%|██████████| 1274/1274 [09:59<00:00,  2.13it/s, Loss=0.123]
100%|██████████| 141/141 [00:23<00:00,  5.99it/s]


Val Loss: 0.138809934258461


Epoch 41: 100%|██████████| 1274/1274 [10:02<00:00,  2.12it/s, Loss=0.122]
100%|██████████| 141/141 [00:23<00:00,  6.06it/s]


Val Loss: 0.13877438008785248


Epoch 42: 100%|██████████| 1274/1274 [09:58<00:00,  2.13it/s, Loss=0.122]
100%|██████████| 141/141 [00:23<00:00,  6.09it/s]


Val Loss: 0.1387282758951187


Epoch 43: 100%|██████████| 1274/1274 [09:59<00:00,  2.13it/s, Loss=0.122]
100%|██████████| 141/141 [00:23<00:00,  6.04it/s]


Val Loss: 0.13873112201690674


Epoch 44: 100%|██████████| 1274/1274 [09:59<00:00,  2.12it/s, Loss=0.122]
100%|██████████| 141/141 [00:23<00:00,  6.01it/s]


Val Loss: 0.13879764080047607


Epoch 45: 100%|██████████| 1274/1274 [09:58<00:00,  2.13it/s, Loss=0.122]
100%|██████████| 141/141 [00:23<00:00,  6.10it/s]


Val Loss: 0.1386171281337738


Epoch 46: 100%|██████████| 1274/1274 [09:59<00:00,  2.12it/s, Loss=0.122]
100%|██████████| 141/141 [00:23<00:00,  6.00it/s]


Val Loss: 0.13850858807563782


Epoch 47: 100%|██████████| 1274/1274 [10:00<00:00,  2.12it/s, Loss=0.122]
100%|██████████| 141/141 [00:23<00:00,  6.12it/s]


Val Loss: 0.13863098621368408


Epoch 48: 100%|██████████| 1274/1274 [09:59<00:00,  2.12it/s, Loss=0.121]
100%|██████████| 141/141 [00:23<00:00,  6.01it/s]


Val Loss: 0.13861501216888428


Epoch 49: 100%|██████████| 1274/1274 [10:00<00:00,  2.12it/s, Loss=0.121]
100%|██████████| 141/141 [00:23<00:00,  5.97it/s]

Val Loss: 0.1385551393032074



