# RNA Masked Training
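Masked-token pretraining of a custom BERT encoder on RNA sequences from the Stanford Ribonanza RNA Folding competition: https://www.kaggle.com/competitions/stanford-ribonanza-rna-folding/overview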

In [1]:
# Input locations, all relative to the notebook's working directory.
DATA_DIR = '../data'

# Extracted training table (sequence + dot-bracket structure) used for masked pretraining.
TRAIN_DATA_EXT_MASKED_PATH = DATA_DIR + '/small_sets/train_extracted.csv'
# Token vocabulary consumed by the dataset.
VOCAB_PATH = DATA_DIR + '/vocab.csv'

In [2]:
import pandas as pd

# Read the extracted training table into memory and preview its first rows.
# NOTE(review): the recorded preview shows an `Unnamed: 0` column, which looks like
# an index written out when the CSV was saved. Passing `index_col=0` would drop it,
# but only do so after confirming MaskedDataset does not rely on column positions.
train_extracted_pd = pd.read_csv(filepath_or_buffer=TRAIN_DATA_EXT_MASKED_PATH)
train_extracted_pd.head(n=5)

Unnamed: 0,sequence,sequence_ext
0,GGGAACGACUCGAGUAGAGUCGAAAAAGAUCGCCACGCACUUACGA...,.....((((((.....)))))).....((((((((((((((....)...
1,GGGAACGACUCGAGUAGAGUCGAAAAGGUGGCCGGCAGAAUCGCGA...,.....((((((.....))))))........(((((..(.....).....
2,GGGAACGACUCGAGUAGAGUCGAAAACAUUGUUAAUGCCUAUAUUA...,.....((((((.....))))))........(((((.((((.........
3,GGGAACGACUCGAGUAGAGUCGAAAAGGAGAUCGAAGACGACUUAC...,.....((((((.....))))))....((((((((.....(.........
4,GGGAACGACUCGAGUAGAGUCGAAAAGAUAUGGACUGACGAAGUCG...,.....((((((.....))))))....(((..(((((((((..((((...


In [3]:
# Number of rows in the extracted training table.
train_extracted_pd.shape[0]

20000

In [4]:
import sys

# The project package lives one directory above the notebook.
sys.path.append('..')

from python_scripts.transformers.dataset import MaskedDataset

# Only the first 1000 rows are used here to keep the experiment small.
train_subset_pd = train_extracted_pd.iloc[:1000]
vocab_pd = pd.read_csv(VOCAB_PATH)

# max_len=512 is forwarded as-is; the recorded model output later in the notebook
# has a length-512 sequence axis, so this presumably sets the padded length.
masked_dataset = MaskedDataset(data_ext=train_subset_pd, vocab=vocab_pd, max_len=512)

In [5]:
# Sanity check: number of samples the dataset exposes. The recorded output is 1000,
# the same as the number of rows sliced from the training table when it was built.
len(masked_dataset)

1000

In [6]:
import sys
sys.path.append('..')

from torch.utils.data import DataLoader
from torchinfo import summary

from python_scripts.transformers.model import BERTCustom, BERTCustomMasked

# Hyperparameters of the encoder backbone, gathered in one place.
bert_hparams = {
    'vocab_size': len(masked_dataset.vocab),
    'hidden': 512,
    'dim_k': 64,
    'num_layer': 12,
    'num_attn_head': 8,
}

# Backbone first, then the masked-token wrapper around it. `bertmodel` is kept as
# its own name because its weights are saved and restored separately further down.
bertmodel = BERTCustom(**bert_hparams)
masked_model = BERTCustomMasked(bertmodel)

# Layer-by-layer parameter counts.
summary(masked_model)

Layer (type:depth-idx)                             Param #
BERTCustomMasked                                   --
├─BERTCustom: 1-1                                  --
│    └─CombEmbedding: 2-1                          --
│    │    └─TokenEmbedding: 3-1                    11,776
│    │    └─PositionEmbedding: 3-2                 --
│    │    └─Dropout: 3-3                           --
│    └─ModuleList: 2-2                             --
│    │    └─EncoderBlock: 3-4                      3,152,384
│    │    └─EncoderBlock: 3-5                      3,152,384
│    │    └─EncoderBlock: 3-6                      3,152,384
│    │    └─EncoderBlock: 3-7                      3,152,384
│    │    └─EncoderBlock: 3-8                      3,152,384
│    │    └─EncoderBlock: 3-9                      3,152,384
│    │    └─EncoderBlock: 3-10                     3,152,384
│    │    └─EncoderBlock: 3-11                     3,152,384
│    │    └─EncoderBlock: 3-12                     3,152,384
│    │    

In [7]:
# Smoke test: run a single batch of three samples through the model and show the
# shape of what comes back (recorded output: torch.Size([3, 512, 23])).
smoke_loader = DataLoader(masked_dataset, batch_size=3)
smoke_batch = next(iter(smoke_loader))
smoke_output = masked_model(smoke_batch)
smoke_output.shape

torch.Size([3, 512, 23])

In [9]:
import sys

import torch
from torchmetrics import Accuracy
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

# Make the project package importable; skip the append when an earlier cell (or a
# re-run of this one) has already added the entry.
if '..' not in sys.path:
    sys.path.append('..')

from python_scripts.transformers.dataset import MaskedDataModule
from python_scripts.transformers.task import MaskingTask

# --- Training configuration ---------------------------------------------------
# Single source of truth for the epoch count: both the LR schedule and the Trainer
# read it, so the schedule length cannot drift away from the actual run length.
MAX_EPOCHS = 10
BATCH_SIZE = 8
# OneCycleLR must know the optimizer steps per epoch up front.
# NOTE(review): 100 presumably equals 800 training samples / batch size 8 (an 80%
# train split of the 1000-sample dataset) — confirm against MaskedDataModule and
# update it whenever the dataset size, split or BATCH_SIZE changes.
STEPS_PER_EPOCH = 100
PEAK_LR = 1e-4

masked_datamodule = MaskedDataModule(whole_dataset=masked_dataset, batch_size=BATCH_SIZE)

# The lr given to Adam is only nominal: OneCycleLR takes over the learning rate of
# every parameter group as soon as it is constructed.
masked_optimizer = torch.optim.Adam(masked_model.parameters(), 1e-3)

# One-cycle schedule: ramps from PEAK_LR / div_factor up to PEAK_LR during the
# first 10% of the steps (pct_start), then anneals back down.
# `verbose` is intentionally not passed: it is deprecated in recent PyTorch
# releases and False was already the default.
masked_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=masked_optimizer,
    max_lr=PEAK_LR,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=MAX_EPOCHS,
    div_factor=1e2,
    pct_start=0.1,
)

# Loss and accuracy use the same ignore index (-100 is also CrossEntropyLoss's
# default), so positions labelled -100 are excluded from both.
maskingtask = MaskingTask(
    model=masked_model,
    loss_fn=torch.nn.CrossEntropyLoss(),
    optimizer=masked_optimizer,
    scheduler=masked_scheduler,
    acc_fn=Accuracy(task='multiclass', num_classes=len(masked_dataset.vocab), ignore_index=-100)
)

# Keep the three checkpoints with the highest validation accuracy.
# NOTE(review): 'val_avg_accuracy' must match a metric name logged by MaskingTask.
callbacks = [
    ModelCheckpoint(
        monitor='val_avg_accuracy',
        save_top_k=3,
        mode='max'
    ),
]

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=callbacks
)

# To resume an interrupted run instead of starting over, pass the checkpoint to
# fit: trainer.fit(maskingtask, datamodule=masked_datamodule, ckpt_path='<path>.ckpt')
trainer.fit(maskingtask, datamodule=masked_datamodule)
trainer.test(maskingtask, datamodule=masked_datamodule)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [12]:
from pathlib import Path

import torch

# Persist only the backbone weights (not the masked-token wrapper) so they can be
# loaded into a fresh BERTCustom later.
pretrained_bert_path = Path('./lightning_logs/masked_5epoch/checkpoints/pretrained_bert.pt')

# This folder is a hand-picked name, not one of Lightning's own `version_N` run
# directories, so it may not exist yet; create it rather than let the save fail.
pretrained_bert_path.parent.mkdir(parents=True, exist_ok=True)

torch.save(bertmodel.state_dict(), pretrained_bert_path)

In [10]:
import torch

# Restore the saved backbone weights into `bertmodel`.
# map_location='cpu' makes deserialization independent of the accelerator the
# tensors were saved from (the training run above used MPS); load_state_dict then
# copies the values into the model's parameters on whatever device they live on.
pretrained_bert_state = torch.load(
    './lightning_logs/masked_5epoch/checkpoints/pretrained_bert.pt',
    map_location='cpu',
)

# Last expression on purpose: the returned report of missing/unexpected keys is
# displayed as the cell output.
bertmodel.load_state_dict(pretrained_bert_state)

<All keys matched successfully>