# RNA Reactivity Training
https://www.kaggle.com/competitions/stanford-ribonanza-rna-folding/overview

In [1]:
# Input locations, given as relative paths under a single data directory.
DATA_DIR = '../data'
TRAIN_DATA_SLICED = f'{DATA_DIR}/small_sets/train_data_sliced.pt'  # loaded with torch.load
VOCAB_PATH = f'{DATA_DIR}/vocab.csv'  # read with pandas and handed to the dataset

In [2]:
import torch

# Load the pre-sliced training data from disk.
# NOTE(review): torch.load unpickles the file, so only point it at trusted data.
# The structure of the loaded object is not visible here; it is passed as-is
# to RNAdataset_sliced_train.
train_data_sliced = torch.load(TRAIN_DATA_SLICED)

In [3]:
import pandas as pd
import sys
sys.path.append('..')

from python_scripts.transformers.dataset import RNAdataset_sliced_train


# Read the vocabulary CSV first, then wrap the loaded data and the vocabulary
# in the project's training dataset.
vocab_table = pd.read_csv(VOCAB_PATH)
rna_dataset = RNAdataset_sliced_train(data=train_data_sliced, vocab=vocab_table)

In [4]:
# Number of samples in the training dataset.
len(rna_dataset)

1000

In [5]:
# Shapes of the first two elements of sample 0. Element 0 is what the
# forward-pass check later in the notebook feeds to the model.
rna_dataset[0][0].shape, rna_dataset[0][1].shape

(torch.Size([100]), torch.Size([2, 100, 4]))

In [6]:
from torchinfo import summary

import sys
sys.path.append('..')

from python_scripts.transformers.model import BERTCustomRNAReactivity, BERTCustom
from torch.utils.data import DataLoader

# Encoder hyperparameters, gathered in one place before the model is built.
bert_config = dict(
    vocab_size=len(rna_dataset.vocab),
    hidden=512,
    dim_k=64,
    num_layer=12,
    num_attn_head=8,
)
bertmodel = BERTCustom(**bert_config)

# Wrap the encoder in the reactivity model that is actually trained.
RNA_model = BERTCustomRNAReactivity(bertmodel)

# Display the per-layer parameter counts.
summary(RNA_model)

Layer (type:depth-idx)                             Param #
BERTCustomRNAReactivity                            --
├─BERTCustom: 1-1                                  --
│    └─CombEmbedding: 2-1                          --
│    │    └─TokenEmbedding: 3-1                    11,776
│    │    └─PositionEmbedding: 3-2                 --
│    │    └─Dropout: 3-3                           --
│    └─ModuleList: 2-2                             --
│    │    └─EncoderBlock: 3-4                      3,152,384
│    │    └─EncoderBlock: 3-5                      3,152,384
│    │    └─EncoderBlock: 3-6                      3,152,384
│    │    └─EncoderBlock: 3-7                      3,152,384
│    │    └─EncoderBlock: 3-8                      3,152,384
│    │    └─EncoderBlock: 3-9                      3,152,384
│    │    └─EncoderBlock: 3-10                     3,152,384
│    │    └─EncoderBlock: 3-11                     3,152,384
│    │    └─EncoderBlock: 3-12                     3,152,384
│    │    

In [7]:
# Smoke test: push one batch of 3 samples (element 0 of the batch) through the
# model and display the output shape.
RNA_model(next(iter(DataLoader(rna_dataset, 3)))[0]).shape

torch.Size([3, 2, 100, 4])

In [8]:
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

import sys
sys.path.append('..')

from python_scripts.transformers.dataset import RNADataModule
from python_scripts.transformers.task import RNATask

# Data module built from the whole training dataset, with batches of 8 samples.
# NOTE(review): presumably it carves out the train/val/test splits internally —
# confirm in RNADataModule.
rna_datamodule = RNADataModule(whole_train_dataset=rna_dataset, batch_size=8)

def rna_rmse_loss(x: torch.Tensor, y: torch.Tensor, ignore_index=-100) -> torch.Tensor:
    """Root-mean-square error over the targets that are not masked out.

    Args:
        x: Predictions. They are indexed with a boolean mask built from ``y``,
            so their shape has to line up with ``y``'s.
        y: Targets. Elements equal to ``ignore_index`` are excluded.
        ignore_index: Sentinel value marking targets to skip.

    Returns:
        Scalar tensor holding the RMSE of the kept elements. If every target
        equals ``ignore_index``, the mean runs over an empty selection and the
        result is NaN.
    """
    not_ignore = y != ignore_index
    return torch.sqrt(torch.square(x[not_ignore] - y[not_ignore]).mean())

def rna_mse_loss(x: torch.Tensor, y: torch.Tensor, ignore_index=-100) -> torch.Tensor:
    """Mean squared error over the targets that are not masked out.

    Args:
        x: Predictions. They are indexed with a boolean mask built from ``y``,
            so their shape has to line up with ``y``'s.
        y: Targets. Elements equal to ``ignore_index`` are excluded.
        ignore_index: Sentinel value marking targets to skip.

    Returns:
        Scalar tensor holding the MSE of the kept elements. If every target
        equals ``ignore_index``, the mean runs over an empty selection and the
        result is NaN.
    """
    not_ignore = y != ignore_index
    return torch.square(x[not_ignore] - y[not_ignore]).mean()

def rna_mae_loss(x: torch.Tensor, y: torch.Tensor, ignore_index=-100) -> torch.Tensor:
    """Mean absolute error over the targets that are not masked out.

    Args:
        x: Predictions. They are indexed with a boolean mask built from ``y``,
            so their shape has to line up with ``y``'s.
        y: Targets. Elements equal to ``ignore_index`` are excluded.
        ignore_index: Sentinel value marking targets to skip.

    Returns:
        Scalar tensor holding the MAE of the kept elements. If every target
        equals ``ignore_index``, the mean runs over an empty selection and the
        result is NaN.
    """
    not_ignore = y != ignore_index
    return torch.abs(x[not_ignore] - y[not_ignore]).mean()

# Adam over every model parameter. The 1e-3 passed here does not drive training:
# the OneCycleLR schedule below takes over the learning rate, starting from
# max_lr / div_factor.
rna_optimizer = torch.optim.Adam(RNA_model.parameters(), 1e-3)

# Optimizer / scheduler alternatives, kept commented out for reference.
# rna_optimizer = torch.optim.SGD(RNA_model.parameters(), 1e-3, 0.9)
# rna_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#     rna_optimizer,
#     T_max=5,
#     eta_min=1e-4,
#     verbose=True,
# )
# rna_scheduler = torch.optim.lr_scheduler.MultiStepLR(
#     rna_optimizer,
#     [4, 7, 10, 13, 16, 19],
#     verbose=True,
#     gamma=0.3
# )
# rna_scheduler = torch.optim.lr_scheduler.CyclicLR(
#     optimizer=rna_optimizer,
#     base_lr=1e-6,
#     max_lr=1e-3,
#     step_size_up=3000,
#     step_size_down=7000,
#     verbose=True
# )

# One-cycle schedule: warm up to max_lr over the first 10% of the steps, then
# anneal. Its total length is steps_per_epoch * epochs optimizer steps.
rna_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=rna_optimizer,
    max_lr=1e-4,
    steps_per_epoch=100,  # has to equal the number of training batches per epoch
    epochs=5,  # has to equal max_epochs given to pl.Trainer
    div_factor=1e2,
    pct_start=0.1,
)

# The Lightning task couples the model with the MAE loss, the optimizer and the
# learning-rate schedule.
rna_task = RNATask(model=RNA_model, loss_fn=rna_mae_loss, optimizer=rna_optimizer, scheduler=rna_scheduler)

# Keep the three checkpoints with the lowest 'val_avg_loss'.
callbacks = [
    ModelCheckpoint(monitor='val_avg_loss', mode='min', save_top_k=3),
]
# Early stopping is left switched off; uncomment to enable it.
# callbacks.append(EarlyStopping(
#     monitor='val_avg_loss',
#     min_delta=0.001,
#     patience=3,
#     verbose=True,
#     mode='min'
# ))

# max_epochs is the same value as `epochs` in the OneCycleLR schedule.
trainer = pl.Trainer(callbacks=callbacks, max_epochs=5)

# rna_task = RNATask.load_from_checkpoint(
#     checkpoint_path='./lightning_log/~~',
#     model=RNA_model,
#     loss_fn=rna_mae_loss,
#     optimizer=rna_optimizer,
#     scheduler=rna_scheduler,
# )

# trainer.fit(rna_task, ckpt_path="some/path/to/my_checkpoint.ckpt")
# trainer = pl.Trainer(resume_from_checkpoint='../notebooks/lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt')

# Train using the data module, then run the test loop on the same data module.
# trainer.test is the cell's last expression, so the list of test metrics it
# returns is displayed as the cell output.
trainer.fit(rna_task, datamodule=rna_datamodule)
trainer.test(rna_task, datamodule=rna_datamodule)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                    | Params
--------------------------------------------------
0 | model | BERTCustomRNAReactivity | 37.8 M
--------------------------------------------------
37.8 M    Trainable params
0         Non-trainable params
37.8 M    Total params
151.378   Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0:   0%|          | 0/100 [00:00<?, ?it/s] 

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch 0: 100%|██████████| 100/100 [00:26<00:00,  3.76it/s, v_num=8, train_loss=0.283]
Epoch 0, Avg. Training Loss: 0.3601 Avg. Validation Loss: 0.2654
The learing_rate is set to:  9.68641026104951e-05
Epoch 1: 100%|██████████| 100/100 [00:23<00:00,  4.33it/s, v_num=8, train_loss=0.265, val_loss=0.213]
Epoch 1, Avg. Training Loss: 0.2587 Avg. Validation Loss: 0.2154
The learing_rate is set to:  7.469711863211823e-05
Epoch 2: 100%|██████████| 100/100 [00:22<00:00,  4.53it/s, v_num=8, train_loss=0.242, val_loss=0.215]
Epoch 2, Avg. Training Loss: 0.2392 Avg. Validation Loss: 0.2054
The learing_rate is set to:  4.097410176342927e-05
Epoch 3: 100%|██████████| 100/100 [00:22<00:00,  4.51it/s, v_num=8, train_loss=0.226, val_loss=0.205]
Epoch 3, Avg. Training Loss: 0.2282 Avg. Validation Loss: 0.2033
The learing_rate is set to:  1.1474426386872682e-05
Epoch 4: 100%|██████████| 100/100 [00:21<00:00,  4.55it/s, v_num=8, train_loss=0.224, val_loss=0.203]
Epoch 4, Avg. Training Loss: 0.2248 Avg. V

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 100/100 [00:23<00:00,  4.32it/s, v_num=8, train_loss=0.224, val_loss=0.202]


  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 13/13 [00:00<00:00, 22.67it/s]


[{'test_loss': 0.2014801800251007}]