# RNA Reactivity Training
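Trains a custom BERT-style model to predict per-position RNA reactivity for the Stanford Ribonanza RNA Folding Kaggle competition: https://www.kaggle.com/competitions/stanford-ribonanza-rna-folding/overview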

In [1]:
# Input CSVs, given relative to the notebook's working directory.
# Reactivity table (loaded into `train_data_pd` below).
TRAIN_DATA_PATH = '../data/small_sets/train_data_QUICK_START.csv'
# Table with a `sequence_ext` column per sequence (loaded into `train_extracted_pd` below).
TRAIN_DATA_EXT_PATH = '../data/small_sets/train_extracted.csv'

In [2]:
import pandas as pd

# Load the reactivity table and preview its first rows.
train_data_pd = pd.read_csv(TRAIN_DATA_PATH)
train_data_pd.head()

Unnamed: 0,sequence_id,sequence,experiment_type,dataset_name,reactivity_0001,reactivity_0002,reactivity_0003,reactivity_0004,reactivity_0005,reactivity_0006,...,reactivity_error_0197,reactivity_error_0198,reactivity_error_0199,reactivity_error_0200,reactivity_error_0201,reactivity_error_0202,reactivity_error_0203,reactivity_error_0204,reactivity_error_0205,reactivity_error_0206
0,0000d87cab97,GGGAACGACUCGAGUAGAGUCGAAAAAGAUCGCCACGCACUUACGA...,2A3_MaP,DasLabBigLib_OneMil_RFAM_windows_100mers_2A3,,,,,,,...,,,,,,,,,,
1,0000d87cab97,GGGAACGACUCGAGUAGAGUCGAAAAAGAUCGCCACGCACUUACGA...,DMS_MaP,DasLabBigLib_OneMil_RFAM_windows_100mers_DMS,,,,,,,...,,,,,,,,,,
2,0001ca9d21b0,GGGAACGACUCGAGUAGAGUCGAAAAGGUGGCCGGCAGAAUCGCGA...,2A3_MaP,DasLabBigLib_OneMil_OpenKnot_Round_2_train_2A3,,,,,,,...,,,,,,,,,,
3,0001ca9d21b0,GGGAACGACUCGAGUAGAGUCGAAAAGGUGGCCGGCAGAAUCGCGA...,DMS_MaP,DasLabBigLib_OneMil_OpenKnot_Round_2_train_DMS,,,,,,,...,,,,,,,,,,
4,00021f968267,GGGAACGACUCGAGUAGAGUCGAAAACAUUGUUAAUGCCUAUAUUA...,2A3_MaP,DasLabBigLib_OneMil_Replicates_from_previous_l...,,,,,,,...,,,,,,,,,,


In [3]:
# Load the table that pairs each sequence with its `sequence_ext` string
# (dot-bracket-looking notation in the preview below) and preview it.
train_extracted_pd = pd.read_csv(TRAIN_DATA_EXT_PATH)
train_extracted_pd.head()

Unnamed: 0,sequence,sequence_ext
0,GGGAACGACUCGAGUAGAGUCGAAAAAGAUCGCCACGCACUUACGA...,.....((((((.....)))))).....((((((((((((((....)...
1,GGGAACGACUCGAGUAGAGUCGAAAAGGUGGCCGGCAGAAUCGCGA...,.....((((((.....))))))........(((((..(.....).....
2,GGGAACGACUCGAGUAGAGUCGAAAACAUUGUUAAUGCCUAUAUUA...,.....((((((.....))))))........(((((.((((.........
3,GGGAACGACUCGAGUAGAGUCGAAAAGGAGAUCGAAGACGACUUAC...,.....((((((.....))))))....((((((((.....(.........
4,GGGAACGACUCGAGUAGAGUCGAAAAGAUAUGGACUGACGAAGUCG...,.....((((((.....))))))....(((..(((((((((..((((...


In [4]:
# Row counts of the two tables: (reactivity rows, extracted-sequence rows).
# NOTE(review): the 2:1 ratio presumably reflects one row per experiment type
# (2A3_MaP / DMS_MaP) for each sequence — confirm against the data.
len(train_data_pd), len(train_extracted_pd)

(40000, 20000)

In [5]:
import sys
# Make the parent directory importable so `python_scripts` can be found.
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

from python_scripts.transformers.dataset import RNADataset_train


# Build the training dataset from the reactivity table, the extracted-sequence
# table and the token vocabulary read from ../data/vocab.csv.
# NOTE(review): max_len=512 is presumably the padded sequence length (the model
# output checked in a later cell has a trailing dimension of 512) — confirm in
# RNADataset_train.
rna_dataset = RNADataset_train(
    data=train_data_pd,
    data_ext = train_extracted_pd,
    vocab=pd.read_csv('../data/vocab.csv'),
    max_len=512
)

In [6]:
# Number of samples exposed by the dataset (the recorded output, 20000, equals
# the row count of `train_extracted_pd` rather than that of `train_data_pd`).
len(rna_dataset)

20000

In [7]:
from torchinfo import summary

import sys
# Make the parent directory importable so `python_scripts` can be found.
# Guarded so that re-running this cell (or an earlier cell having already done
# the same) does not add duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

from python_scripts.transformers.model import BERTCustomRNAReactivity, BERTCustom
from torch.utils.data import DataLoader

# Encoder backbone; the vocabulary size is taken from the dataset so the token
# embedding matches the tokenizer.
# NOTE(review): hidden=128 is not a multiple of num_attn_head=12; presumably
# dim_k=16 sets the per-head width independently of hidden — confirm in
# BERTCustom.
bertmodel = BERTCustom(
    vocab_size=len(rna_dataset.vocab),
    hidden=128,
    dim_k=16,
    num_layer=12,
    num_attn_head=12
)
# Wrap the backbone in the reactivity model.
RNA_model = BERTCustomRNAReactivity(bertmodel)

# Print the per-layer parameter counts.
summary(RNA_model)

Layer (type:depth-idx)                             Param #
BERTCustomRNAReactivity                            --
├─BERTCustom: 1-1                                  --
│    └─CombEmbedding: 2-1                          --
│    │    └─TokenEmbedding: 3-1                    2,944
│    │    └─PositionEmbedding: 3-2                 --
│    │    └─Dropout: 3-3                           --
│    └─ModuleList: 2-2                             --
│    │    └─EncoderBlock: 3-4                      231,232
│    │    └─EncoderBlock: 3-5                      231,232
│    │    └─EncoderBlock: 3-6                      231,232
│    │    └─EncoderBlock: 3-7                      231,232
│    │    └─EncoderBlock: 3-8                      231,232
│    │    └─EncoderBlock: 3-9                      231,232
│    │    └─EncoderBlock: 3-10                     231,232
│    │    └─EncoderBlock: 3-11                     231,232
│    │    └─EncoderBlock: 3-12                     231,232
│    │    └─EncoderBlock: 3-1

In [8]:
# Smoke test: take one batch of 3 samples, feed the first element of that
# batch through the untrained model, and display the output shape.
RNA_model(next(iter(DataLoader(rna_dataset, 3)))[0]).shape

torch.Size([3, 2, 512])

In [9]:
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

import sys
# Make the parent directory importable so `python_scripts` can be found.
# Guarded so that re-running this cell (or an earlier cell having already done
# the same) does not add duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

from python_scripts.transformers.dataset import RNADataModule
from python_scripts.transformers.task import RNATask

# Wrap the full dataset in the project's datamodule with a batch size of 8.
# NOTE(review): presumably RNADataModule splits `whole_train_dataset` into
# train/val/test subsets itself — confirm in python_scripts.transformers.dataset.
rna_datamodule = RNADataModule(whole_train_dataset=rna_dataset, batch_size=8)

def rna_rmse_loss(x: torch.Tensor, y: torch.Tensor, ignore_index=-100):
    """Root-mean-square error over the target positions that are not masked.

    Args:
        x: Predictions; must be indexable by a boolean mask shaped like `y`.
        y: Targets; entries equal to `ignore_index` are excluded from the loss.
        ignore_index: Sentinel value marking target positions to skip.

    Returns:
        Scalar tensor holding the RMSE over the kept positions. When every
        position is masked, a zero that is still attached to `x`'s autograd
        graph is returned instead of NaN.
    """
    not_ignore = y != ignore_index
    if not not_ignore.any():
        # The mean of an empty selection is NaN; back-propagating it would
        # turn every weight into NaN. Return a graph-connected zero so the
        # step contributes zero gradients instead.
        return x.sum() * 0.0
    return torch.sqrt(torch.square(x[not_ignore] - y[not_ignore]).mean())

def rna_mse_loss(x: torch.Tensor, y: torch.Tensor, ignore_index=-100):
    """Mean squared error over the target positions that are not masked.

    Args:
        x: Predictions; must be indexable by a boolean mask shaped like `y`.
        y: Targets; entries equal to `ignore_index` are excluded from the loss.
        ignore_index: Sentinel value marking target positions to skip.

    Returns:
        Scalar tensor holding the MSE over the kept positions. When every
        position is masked, a zero that is still attached to `x`'s autograd
        graph is returned instead of NaN.
    """
    not_ignore = y != ignore_index
    if not not_ignore.any():
        # The mean of an empty selection is NaN; back-propagating it would
        # turn every weight into NaN. Return a graph-connected zero so the
        # step contributes zero gradients instead.
        return x.sum() * 0.0
    return torch.square(x[not_ignore] - y[not_ignore]).mean()

def rna_mae_loss(x: torch.Tensor, y: torch.Tensor, ignore_index=-100):
    """Mean absolute error over the target positions that are not masked.

    Args:
        x: Predictions; must be indexable by a boolean mask shaped like `y`.
        y: Targets; entries equal to `ignore_index` are excluded from the loss.
        ignore_index: Sentinel value marking target positions to skip.

    Returns:
        Scalar tensor holding the MAE over the kept positions. When every
        position is masked, a zero that is still attached to `x`'s autograd
        graph is returned instead of NaN.
    """
    not_ignore = y != ignore_index
    if not not_ignore.any():
        # The mean of an empty selection is NaN; back-propagating it would
        # turn every weight into NaN. Return a graph-connected zero so the
        # step contributes zero gradients instead.
        return x.sum() * 0.0
    return torch.abs(x[not_ignore] - y[not_ignore]).mean()

# rna_optimizer = torch.optim.Adam(RNA_model.parameters(), 1e-3)
# SGD with momentum. The OneCycleLR scheduler below takes over the learning
# rate from its first step, and with its default `cycle_momentum=True` it also
# cycles the momentum, so the values given here are only starting
# placeholders.
rna_optimizer = torch.optim.SGD(RNA_model.parameters(), lr=1e-3, momentum=0.9)
# rna_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#     rna_optimizer,
#     T_max=5,
#     eta_min=1e-4,
#     verbose=True,
# )
# rna_scheduler = torch.optim.lr_scheduler.MultiStepLR(
#     rna_optimizer,
#     [4, 7, 10, 13, 16, 19],
#     verbose=True,
#     gamma=0.3
# )
# rna_scheduler = torch.optim.lr_scheduler.CyclicLR(
#     optimizer=rna_optimizer,
#     base_lr=1e-6,
#     max_lr=1e-3,
#     step_size_up=3000,
#     step_size_down=7000,
#     verbose=True
# )
# One-cycle schedule: starts at max_lr / div_factor = 1e-5 and peaks at 1e-3.
# steps_per_epoch * epochs is the total number of scheduler steps, so 2000 is
# tied to the current data size and batch size (the recorded run shows 2000
# training batches per epoch) and to `max_epochs=1` on the Trainer; update it
# whenever any of those change.
# NOTE(review): assumes RNATask steps the scheduler once per optimizer step —
# confirm in python_scripts.transformers.task.
# `verbose` is no longer passed: the argument is deprecated on LR schedulers
# and `False` was its default, so behavior is unchanged.
rna_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=rna_optimizer,
    max_lr=1e-3,
    steps_per_epoch=2000,
    epochs=1,
    div_factor=1e2,
)

# Bundle model, loss, optimizer and scheduler into the task object that is
# handed to the Trainer. MAE is the loss used for this run; the RMSE and MSE
# functions defined above are alternatives.
rna_task = RNATask(
    model=RNA_model,
    loss_fn=rna_mae_loss,
    optimizer=rna_optimizer,
    scheduler=rna_scheduler,
)

# Trainer callbacks: keep the three checkpoints with the lowest `val_avg_loss`.
callbacks = [
    ModelCheckpoint(monitor='val_avg_loss', save_top_k=3, mode='min'),
]
# Early stopping is currently disabled:
# callbacks.append(EarlyStopping(
#     monitor='val_avg_loss',
#     min_delta=0.001,
#     patience=3,
#     verbose=True,
#     mode='min'
# ))

# Single-epoch run; `max_epochs=1` matches the `epochs=1` given to OneCycleLR
# above, so the two have to be changed together.
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=callbacks,
)

# To start from a saved checkpoint instead of a freshly initialized task:
# rna_task = RNATask.load_from_checkpoint(
#     checkpoint_path='./lightning_log/~~',
#     model=RNA_model,
#     loss_fn=rna_mae_loss,
#     optimizer=rna_optimizer,
#     scheduler=rna_scheduler,
# )

# To resume an interrupted run:
# trainer.fit(rna_task, ckpt_path="some/path/to/my_checkpoint.ckpt")
# trainer = pl.Trainer(resume_from_checkpoint='../notebooks/lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt')

# Train, then evaluate on the datamodule's test dataloader. `rna_task` is
# passed directly and no `ckpt_path` is given, so the test run uses the
# in-memory weights from the end of training rather than a saved checkpoint.
trainer.fit(rna_task, datamodule=rna_datamodule)
trainer.test(rna_task, datamodule=rna_datamodule)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                    | Params
--------------------------------------------------
0 | model | BERTCustomRNAReactivity | 2.8 M 
--------------------------------------------------
2.8 M     Trainable params
0         Non-trainable params
2.8 M     Total params
11.112    Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0:   0%|          | 1/2000 [00:01<37:43,  1.13s/it, v_num=1, train_loss=0.609]

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch 0: 100%|██████████| 2000/2000 [09:49<00:00,  3.39it/s, v_num=1, train_loss=0.222]
Epoch 0, Avg. Training Loss: 0.2720 Avg. Validation Loss: 0.2145
2.258876325242882e-09
Epoch 0: 100%|██████████| 2000/2000 [10:13<00:00,  3.26it/s, v_num=1, train_loss=0.222, val_loss=0.211]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 2000/2000 [10:13<00:00,  3.26it/s, v_num=1, train_loss=0.222, val_loss=0.211]


  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 250/250 [00:22<00:00, 10.99it/s]


[{'test_loss': 0.21089835464954376}]