In [1]:
import gc
import math
import pathlib
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler, SequentialSampler
from torchmetrics.regression import MeanAbsoluteError
from tqdm import tqdm



In [2]:
INPUT_PATH = pathlib.Path('/kaggle/input/stanford-ribonanza-rna-folding-converted')
MODEL_PATH = pathlib.Path('/kaggle/input/rna-folding-model/')
WORKING_PATH = pathlib.Path('/kaggle/working/')
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed every RNG the notebook may touch (Python, NumPy, and torch's global RNG,
# which drives weight init, RandomSampler shuffling and dropout) so runs are
# reproducible. torch.manual_seed also seeds CUDA devices.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
full_df = pd.read_parquet(INPUT_PATH/"train_data.parquet")

df_2A3 = full_df[full_df.experiment_type == '2A3_MaP'].reset_index(drop=True)
df_DMS = full_df[full_df.experiment_type == 'DMS_MaP'].reset_index(drop=True)

# The split below and RNA_Dataset pair the two experiments row-by-row, so row i of
# each frame must describe the same RNA. Fail loudly instead of training on
# mismatched labels.
assert len(df_2A3) == len(df_DMS), \
    f"2A3 and DMS frames differ in length: {len(df_2A3)} vs {len(df_DMS)}"
assert (df_2A3.sequence.values == df_DMS.sequence.values).all(), \
    "2A3 and DMS rows are not aligned by sequence"

train_2A3, val_2A3, train_DMS, val_DMS = train_test_split(df_2A3, df_DMS, test_size=0.05, random_state=42)

In [4]:
class RNA_Dataset(Dataset):
    """Pairs RNA sequences with their 2A3 and DMS reactivity profiles.

    Row ``i`` of ``df_2A3`` and row ``i`` of ``df_DMS`` are assumed to describe
    the same RNA; the sequence string is read from ``df_2A3``. Reactivities whose
    reported error is not strictly below ``max_error`` (including missing/NaN
    errors) are replaced by NaN so downstream losses can ignore them.

    Each item is ``(seq_idx, labels)``:
      * ``seq_idx``: LongTensor of length ``len(sequence)`` mapping A/C/G/U -> 1..4.
      * ``labels``: float32 tensor of shape ``(n_reactivity_columns, 2)``;
        channel 0 holds 2A3, channel 1 holds DMS.

    Raises:
        ValueError: if the two frames have different numbers of rows.
    """
    def __init__(self, df_2A3, df_DMS, max_error=0.5):
        if len(df_2A3) != len(df_DMS):
            raise ValueError(
                f"df_2A3 and df_DMS must have the same number of rows "
                f"({len(df_2A3)} != {len(df_DMS)})")
        self.seq_map = {'A':1, 'C':2, 'G':3, 'U':4}
        self.seqs = df_2A3.sequence.values
        self.react_2A3 = self._masked_reactivity(df_2A3, max_error)
        self.react_DMS = self._masked_reactivity(df_DMS, max_error)

    @staticmethod
    def _masked_reactivity(df, max_error):
        """Return the reactivity matrix of ``df`` with unreliable entries set to NaN."""
        react = df[[c for c in df.columns if 'reactivity_0' in c]].values
        error = df[[c for c in df.columns if 'reactivity_error_0' in c]].values
        # Comparisons against NaN are False, so a missing error also masks the value.
        return np.where(error < max_error, react, float("nan"))

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        seq = self.seqs[idx]
        seq_idx = torch.tensor([self.seq_map[s] for s in seq], dtype=torch.long)
        labels = torch.tensor(np.stack([self.react_2A3[idx],
                                        self.react_DMS[idx]], -1), dtype=torch.float32)
        return seq_idx, labels
    
class GroupLengthBatchSampler(BatchSampler):
    """Batch sampler that groups indices of similar sequence length to minimize padding.

    Indices from the wrapped sampler are cut into consecutive pools of
    ``pool_factor * batch_size``. Each pool is sorted by sequence length and then
    sliced into batches, so pools still follow the wrapped sampler's order (e.g.
    random) while padding inside a batch stays small. A trailing partial batch is
    dropped only when ``drop_last`` is True. Because the pool size is a multiple of
    ``batch_size``, only the final pool can produce a partial batch, so the
    inherited ``__len__`` stays accurate for both ``drop_last`` settings.
    """

    # Number of batches drawn into one length-sorted pool.
    pool_factor = 100

    def __iter__(self):
        dataset = self.sampler.data_source
        indices = list(self.sampler)

        # Prefer the raw sequence strings when the dataset exposes them (as
        # RNA_Dataset does); this avoids building full (seq, label) tensors via
        # __getitem__ for every index on every epoch just to read a length.
        seqs = getattr(dataset, "seqs", None)
        if seqs is not None:
            def length_of(idx):
                return len(seqs[idx])
        else:
            def length_of(idx):
                return len(dataset[idx][0])

        step = self.pool_factor * self.batch_size
        for start in range(0, len(indices), step):
            pool = sorted(indices[start:start + step], key=length_of)
            for j in range(0, len(pool), self.batch_size):
                batch = pool[j:j + self.batch_size]
                if len(batch) < self.batch_size and self.drop_last:
                    break
                yield batch
        
def collate_fn(data):
    """Collate ``(seq_idx, labels)`` pairs into padded batch tensors.

    Token sequences are right-padded with 0 up to the longest sequence in the
    batch, giving a ``(B, T)`` tensor. The per-item label tensors are stacked and
    cut down to the same ``T`` positions along dim 1.
    """
    sequences = [item[0] for item in data]
    targets = [item[1] for item in data]
    batch = nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    max_len = batch.size(1)
    return batch, torch.stack(targets)[:, :max_len, :]

In [5]:
# Token id -> symbol; id 0 is the padding token used by pad_sequence.
itos = {0: "<PAD>", 1: "A", 2: "C", 3: "G", 4: "U"}
vocab_size = len(itos)  # the 4 bases + padding

# Model size.
emb_dim = 256
n_layers = 10
n_heads = 8

# Training.
batch_size = 128

def precompute_freqs_cis(dim, end=500, theta=10000.0):
    """Precompute the rotary-embedding (RoPE) cos/sin tables.

    Returns:
        ``(cos, sin)``, each of shape ``(end, dim // 2)``, where entry ``[t, k]``
        is the cos/sin of ``t * theta ** (-2k / dim)``.
    """
    exponents = torch.arange(0, dim, 2)[: dim // 2].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    return torch.cos(angles), torch.sin(angles)

def reshape_for_broadcast(freqs_cis, x):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)

def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
    """Apply rotary position embeddings (RoPE) to query and key tensors.

    Consecutive pairs along the last dimension are treated as the (real, imag)
    parts of a complex number and rotated by the angles encoded in
    ``freqs_cos`` / ``freqs_sin``. The math runs in float32 and the results are
    cast back to the dtypes of ``xq`` / ``xk``. The final ``flatten(3)`` assumes
    4-D inputs (as produced by Attention: (B, T, heads, head_size)).
    """
    def split_pairs(t):
        # View the last dim as (..., dim // 2, 2) and split into its two halves.
        return t.float().reshape(t.shape[:-1] + (-1, 2)).unbind(-1)

    q_real, q_imag = split_pairs(xq)
    k_real, k_imag = split_pairs(xk)

    # Broadcastable views of the tables, shaped against the query halves.
    cos = reshape_for_broadcast(freqs_cos, q_real)
    sin = reshape_for_broadcast(freqs_sin, q_real)

    def rotate(real, imag):
        # (real + i*imag) * (cos + i*sin), re-interleaved and flattened back to dim.
        rotated = torch.stack([real * cos - imag * sin, real * sin + imag * cos], dim=-1)
        return rotated.flatten(3)

    return rotate(q_real, q_imag).type_as(xq), rotate(k_real, k_imag).type_as(xk)

class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Dimensions come from the module-level ``emb_dim`` / ``n_heads``
    hyperparameters. ``dropout`` is used both for the attention weights and for
    the output projection.
    """
    def __init__(self, dropout=0.1):
        super().__init__()
        if emb_dim % n_heads != 0:
            raise ValueError(f"emb_dim ({emb_dim}) must be divisible by n_heads ({n_heads})")
        self.dropout = dropout
        self.n_heads = n_heads
        self.emb_dim = emb_dim
        self.head_size = emb_dim // n_heads
        self.c_attn = nn.Linear(emb_dim, 3*emb_dim, bias=False)  # fused q, k, v projection
        self.c_proj = nn.Linear(emb_dim, emb_dim, bias=False)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x, freqs_cos, freqs_sin, attn_mask=None):
        """Self-attend over ``x``.

        Args:
            x: input of shape (B, T, emb_dim).
            freqs_cos, freqs_sin: RoPE tables of shape (T, head_size // 2).
            attn_mask: optional mask passed to ``F.scaled_dot_product_attention``
                (boolean ``True`` = may attend, or an additive float mask),
                broadcastable to (B, n_heads, T, T). ``None`` (the default)
                lets every position attend to every other, padding included.

        Returns:
            Tensor of shape (B, T, emb_dim).
        """
        B, T, _ = x.shape
        xq, xk, xv = self.c_attn(x).split(self.emb_dim, dim=2)
        xq = xq.view(B, T, self.n_heads, self.head_size)
        xk = xk.view(B, T, self.n_heads, self.head_size)
        xv = xv.view(B, T, self.n_heads, self.head_size)

        # RoPE
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # (B, T, H, D) -> (B, H, T, D) as expected by scaled_dot_product_attention.
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # scaled_dot_product_attention applies dropout_p unconditionally; it does not
        # follow module.train()/eval(), so disable it explicitly outside training.
        dropout_p = self.dropout if self.training else 0.0
        out = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=attn_mask, dropout_p=dropout_p)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.proj_dropout(self.c_proj(out))
    
class FeedForward(nn.Module):
    """SwiGLU feed-forward layer: ``dropout(w2(silu(w1 x) * w3 x))``.

    The hidden width is ``4 * emb_dim`` (module-level hyperparameter); all
    projections are bias-free.
    """
    def __init__(self, dropout=0.1):
        super().__init__()
        hidden = 4 * emb_dim
        self.w1 = nn.Linear(emb_dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, emb_dim, bias=False)
        self.w3 = nn.Linear(emb_dim, hidden, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.dropout(self.w2(gate * value))
    
class EncoderBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``h = x + Attention(LayerNorm(x))`` followed by
    ``h + FeedForward(LayerNorm(h))``. ``dropout`` is forwarded to both
    sub-layers.
    """
    def __init__(self, dropout=0.1):
        super().__init__()
        # Pass the block's dropout through; previously the sub-layers always used
        # their own default regardless of this argument.
        self.attention = Attention(dropout=dropout)
        self.feed_forward = FeedForward(dropout=dropout)
        self.attention_norm = nn.LayerNorm(emb_dim)
        self.ffn_norm = nn.LayerNorm(emb_dim)

    def forward(self, x, freqs_cos, freqs_sin):
        # Call the sub-modules (not .forward) so registered hooks still run.
        h = x + self.attention(self.attention_norm(x), freqs_cos, freqs_sin)
        return h + self.feed_forward(self.ffn_norm(h))
    
class RNA_Transformer(nn.Module):
    """Encoder-only transformer predicting two reactivity values per base.

    Token embedding -> ``n_layers`` EncoderBlocks with rotary attention -> a
    linear head with 2 outputs per position. The two channels line up with the
    label channels built by RNA_Dataset (0 = 2A3, 1 = DMS).
    """
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, emb_dim)
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(EncoderBlock())
        self.regression_head = nn.Linear(emb_dim, 2)
        # RoPE tables are derived data, so keep them out of the state_dict.
        freqs_cos, freqs_sin = precompute_freqs_cis(emb_dim//n_heads)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, x, targets=None):
        """Predict reactivities and optionally compute the masked L1 loss.

        Args:
            x: LongTensor of token ids, shape (B, T).
            targets: optional float tensor of shape (B, T, 2). NaN entries mark
                missing/unreliable measurements and are excluded from the loss;
                the remaining targets are clamped to [0, 1].

        Returns:
            ``(preds, loss)``. Without targets, ``preds`` is (B, T, 2) and ``loss``
            is None. With targets, ``preds`` is flattened to (B*T, 2) and ``loss``
            is the mean L1 error over non-NaN targets (0 if there are none).

        Raises:
            ValueError: if T exceeds the length of the precomputed RoPE tables.
        """
        B, T = x.shape
        max_len = self.freqs_cos.size(0)
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds the precomputed RoPE table length {max_len}")
        z = self.token_emb(x)
        freqs_cos, freqs_sin = self.freqs_cos[:T], self.freqs_sin[:T]

        for layer in self.layers:
            z = layer(z, freqs_cos, freqs_sin)
        preds = self.regression_head(z)

        if targets is None:
            return preds, None

        preds = preds.view(B*T, 2)
        targets = targets.contiguous().view(B*T, 2).clamp(0, 1)
        # Mask on the targets rather than on the loss so NaN *predictions*
        # (divergence) surface instead of being silently dropped.
        valid = ~targets.isnan()
        if valid.any():
            loss = F.l1_loss(preds[valid], targets[valid])
        else:
            # No labelled positions in this batch: return a zero loss that is still
            # attached to the graph instead of the NaN produced by an empty mean.
            loss = preds.sum() * 0.0
        return preds, loss

In [6]:
train_dataset = RNA_Dataset(train_2A3, train_DMS)
val_dataset = RNA_Dataset(val_2A3, val_DMS)

# Training pools are shuffled each epoch. Validation uses a fixed order, so the
# same length-grouped batches (and hence the same padding) are evaluated every
# epoch, which keeps val losses comparable for checkpoint selection.
trainsampler = GroupLengthBatchSampler(RandomSampler(train_dataset), batch_size, drop_last=True)
valsampler = GroupLengthBatchSampler(SequentialSampler(val_dataset), batch_size, drop_last=True)

trainloader = DataLoader(train_dataset, batch_sampler=trainsampler, collate_fn=collate_fn)
validloader = DataLoader(val_dataset, batch_sampler=valsampler, collate_fn=collate_fn)

In [7]:
# To resume from a saved checkpoint instead, use:
#   model = torch.load(MODEL_PATH/"best_model.pth", map_location=device)
model = RNA_Transformer().to(device)

In [8]:
epochs = 30
optimizer = torch.optim.AdamW(params=model.parameters(), lr=5e-4)
# The training loop steps the scheduler once per batch, so the cosine schedule
# spans every batch of every epoch.
train_steps = len(trainloader) * epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=train_steps)

In [9]:
@torch.no_grad()
def eval_loop():
    """Compute the mean validation loss over ``validloader``.

    Uses the module-level ``model``, ``validloader`` and ``device``. The model is
    switched to eval mode for the pass. Its previous train/eval mode is restored
    afterwards, even if the loop is interrupted (e.g. KeyboardInterrupt), so
    training never resumes silently in eval mode.

    Returns:
        float: mean of the per-batch losses actually evaluated (NaN if the
        loader yields no batches).
    """
    was_training = model.training
    model.eval()
    batch_losses = []
    try:
        for x, labels in tqdm(validloader, total=len(validloader)):
            _, loss = model(x.to(device), labels.to(device))
            batch_losses.append(loss.item())
    finally:
        model.train(was_training)
    val_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")
    print(f"Val Loss: {val_loss}")
    return val_loss
            
eval_distance = 500  # report the running training loss every `eval_distance` steps
min_loss = 0.2       # only checkpoint once the validation loss drops below this
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Training model with {n_params:,} parameters...")
loss_dict = {"train_loss": [], "val_loss": []}  # per-epoch history
for epoch in range(epochs):
    losses = []  # per-step training losses for this epoch
    pbar = tqdm(enumerate(trainloader), total=len(trainloader))
    pbar.set_description(f"Epoch {epoch}")
    for i, (x, y) in pbar:
        _, loss = model(x.to(device), y.to(device))
        losses.append(loss.item())

        if i >= eval_distance and i % eval_distance == 0:
            # Mean loss over the previous `eval_distance` steps.
            train_loss = sum(losses[i - eval_distance:i]) / eval_distance
            pbar.set_postfix({"Loss": train_loss})

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 3.0)
        optimizer.step()
        scheduler.step()

    # Keep the returned validation loss: it drives checkpoint selection below.
    val_loss = eval_loop()
    epoch_train_loss = sum(losses) / len(losses) if losses else float("nan")
    loss_dict["train_loss"].append(epoch_train_loss)
    loss_dict["val_loss"].append(val_loss)
    if min_loss > val_loss:
        print("Saving new best model...")
        min_loss = val_loss
        torch.save(model, WORKING_PATH/"best_model.pth")

Training model with 10,497,794 parameters...


Epoch 0:   1%|          | 64/6099 [00:28<45:24,  2.22it/s] 


KeyboardInterrupt: 

## TODOS
* Mask Padding in Attention
* Filter out sequences whose measurements are all bad in the original data (for efficiency)
* Deal with duplicate sequences in the original data
* Improve initialization of the Transformer