In [1]:
import torch
import pathlib
import gc
import math

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm



In [2]:
INPUT_PATH = pathlib.Path('/kaggle/input/stanford-ribonanza-rna-folding-converted')
WORKING_PATH = pathlib.Path('/kaggle/working/')
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
full_df = pd.read_parquet(INPUT_PATH/"train_data.parquet")
test_df = pd.read_parquet(INPUT_PATH/"test_sequences.parquet")

df_2A3 = full_df[full_df.experiment_type =='2A3_MaP'].reset_index(drop=True)
df_DMS = full_df[full_df.experiment_type =='DMS_MaP'].reset_index(drop=True)
train_2A3, val_2A3, train_DMS, val_DMS= train_test_split(df_2A3, df_DMS, test_size=0.1, random_state=42)

In [4]:
class RNA_Dataset(Dataset):
    def __init__(self, df_2A3, df_DMS):
        # filter noisy data for now
        df_2A3 = df_2A3[(df_2A3.SN_filter > 0) & (df_DMS.SN_filter > 0)]
        
        self.seq_map = {'A':1, 'C':2, 'G':3, 'U':4}
        self.seqs = df_2A3.sequence.values
        self.react_2A3 = df_2A3[[c for c in df_2A3.columns if \
                                 'reactivity_0' in c]].values
        self.react_DMS = df_DMS[[c for c in df_DMS.columns if \
                                 'reactivity_0' in c]].values
        
    def __len__(self):
        return len(self.seqs)
        
    def __getitem__(self, idx):
        seq = self.seqs[idx]
        seq_idx = torch.tensor([self.seq_map[s] for s in seq], dtype=torch.long)
        labels = torch.tensor(np.stack([self.react_2A3[idx],
                                           self.react_DMS[idx]], -1), dtype=torch.float32)
        return seq_idx, labels
        
def collate_fn(data):
    seq_idx, labels = zip(*data)
    padded_seqs = nn.utils.rnn.pad_sequence(seq_idx, batch_first=True)
    B, T = padded_seqs.shape
    labels = torch.stack(labels)[:, :T]
    return padded_seqs, labels

In [23]:
vocab_size = 5 # the 4 bases + padding
emb_dim = 256
num_layers = 14
nhead=8
batch_size = 64

# we have to use fixed Positions because training data is 
# shorter than test data
class PositionEncoding(nn.Module):
    def __init__(self, emb_dim, max_len=512):
        super().__init__()
        positions = torch.arange(max_len).unsqueeze(1)
        evens = torch.arange(0, emb_dim, 2)
        frequencies = torch.exp(evens * (-math.log(10_000)/emb_dim))
        pos_embs = torch.zeros(max_len, emb_dim)
        pos_embs[:, 0::2] = torch.sin(positions * frequencies)
        pos_embs[:, 1::2] = torch.cos(positions * frequencies)
        self.register_buffer('pos_emb', pos_embs)
        
    def forward(self, x):
        return x + self.pos_emb[:x.size(1)]
                
class RNA_Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = PositionEncoding(emb_dim)
        enc_layer = nn.TransformerEncoderLayer(emb_dim, nhead,
                                               dim_feedforward=4*emb_dim,
                                               batch_first=True, norm_first=True,
                                               activation="gelu")
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers)
        self.regression_head = nn.Linear(emb_dim, 2)
        
    def forward(self, x, targets=None):
        B, T = x.shape
        x = self.token_emb(x)
        x = self.pos_emb(x)
        x = self.encoder(x)
        preds = self.regression_head(x)
        
        if targets is None:
            loss = None
        else:
            preds = preds.view(B*T, 2)
            targets = targets.contiguous().view(B*T, 2).clamp(0, 1)
            loss = F.l1_loss(preds, targets, reduction='none')
            loss = loss[~loss.isnan()].mean()
            
        return preds, loss

In [24]:
train_dataset, val_dataset = RNA_Dataset(train_2A3, train_DMS), RNA_Dataset(val_2A3, val_DMS)
trainloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
validloader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)

In [25]:
model = RNA_Transformer() #torch.load(WORKING_PATH/"best_model.pth")
model.to(device);

In [29]:
epochs = 30
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
train_steps = epochs * len(trainloader)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=5e-4, total_steps=train_steps, pct_start=0.02)

In [None]:
@torch.no_grad()
def eval_loop():
    model.eval()
    losses = torch.zeros(len(validloader))
    for i, (x, y) in tqdm(enumerate(validloader), total=len(validloader)):
        _, loss = model(x.to(device), y.to(device))
        losses[i] = loss.item()
    model.train()
    val_loss = losses.mean().item()
    print(f"Val Loss: {val_loss}")
    return val_loss
            
eval_distance = 500
min_loss = 0.19
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Training model with {n_params:,} parameters...")
loss_dict = {"train_loss": [], "val_loss": []}
for epoch in range(epochs):
    losses = torch.zeros(len(trainloader))
    pbar = tqdm(enumerate(trainloader), total=len(trainloader))
    pbar.set_description(f"Epoch {epoch}")
    for i, (x, y) in pbar:
        _, loss = model(x.to(device), y.to(device))
        losses[i] = loss.item()
        
        if i >= eval_distance and i % eval_distance == 0:
            train_loss = losses[i-eval_distance:i].mean().item()
            pbar.set_postfix({"Loss":  train_loss})
        
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 3.0)
        optimizer.step()
        scheduler.step()
    val_loss = eval_loop()
    loss_dict["train_loss"].append(train_loss)
    loss_dict["val_loss"].append(val_loss)
    if min_loss > val_loss:
        min_loss = val_loss
        torch.save(model, "best_model.pth")

Training model with 11,058,434 parameters...


  nn.utils.clip_grad_norm(model.parameters(), 3.0)
Epoch 0: 100%|██████████| 2550/2550 [11:50<00:00,  3.59it/s, Loss=0.242]
100%|██████████| 284/284 [00:20<00:00, 14.15it/s]


Val Loss: 0.2453475147485733


Epoch 1: 100%|██████████| 2550/2550 [11:49<00:00,  3.59it/s, Loss=0.227]
100%|██████████| 284/284 [00:19<00:00, 14.36it/s]


Val Loss: 0.22963878512382507


Epoch 2: 100%|██████████| 2550/2550 [11:49<00:00,  3.60it/s, Loss=0.22] 
100%|██████████| 284/284 [00:19<00:00, 14.29it/s]


Val Loss: 0.22220367193222046


Epoch 3: 100%|██████████| 2550/2550 [11:49<00:00,  3.59it/s, Loss=0.216]
100%|██████████| 284/284 [00:20<00:00, 14.20it/s]


Val Loss: 0.22220933437347412


Epoch 4: 100%|██████████| 2550/2550 [11:49<00:00,  3.60it/s, Loss=0.212]
100%|██████████| 284/284 [00:19<00:00, 14.37it/s]


Val Loss: 0.21231424808502197


Epoch 5: 100%|██████████| 2550/2550 [11:46<00:00,  3.61it/s, Loss=0.208]
100%|██████████| 284/284 [00:19<00:00, 14.25it/s]


Val Loss: 0.2078811079263687


Epoch 6: 100%|██████████| 2550/2550 [11:46<00:00,  3.61it/s, Loss=0.205]
100%|██████████| 284/284 [00:19<00:00, 14.28it/s]


Val Loss: 0.20532453060150146


Epoch 7: 100%|██████████| 2550/2550 [11:47<00:00,  3.61it/s, Loss=0.201]
100%|██████████| 284/284 [00:19<00:00, 14.39it/s]


Val Loss: 0.2011735737323761


Epoch 8:  84%|████████▍ | 2141/2550 [09:55<01:55,  3.53it/s, Loss=0.199]

In [None]:
pd.DataFrame.from_dict(loss_dict)