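# RNA Masked Training
Masked-token training of a small BERT-style model (`BERTCustom` wrapped in `BERTCustomMasked`) for the [Stanford Ribonanza RNA Folding](https://www.kaggle.com/competitions/stanford-ribonanza-rna-folding/overview) Kaggle competition.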

In [None]:
# Extracted training subset; the relative path is resolved from the notebook's working directory.
TRAIN_DATA_EXT_PATH = '../data/small_sets/train_extracted.csv'

In [None]:
import pandas as pd

# Load the extracted training table and display its first five rows.
train_extracted_pd = pd.read_csv(filepath_or_buffer=TRAIN_DATA_EXT_PATH)
train_extracted_pd.head(n=5)

In [None]:
# Row count of the loaded training table.
train_extracted_pd.shape[0]

In [None]:
import sys

# Make the repository root importable. Guarded so that re-running the cell
# does not keep appending duplicate '..' entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

from python_scripts.transformers.dataset import MaskedDataset

# Only the first N_TRAIN_ROWS rows (by position) are used, to keep this
# experiment small.
N_TRAIN_ROWS = 1000

masked_dataset = MaskedDataset(
    data=train_extracted_pd.iloc[:N_TRAIN_ROWS],
    vocab=pd.read_csv('../data/vocab.csv'),
    # presumably the maximum tokenized sequence length — see MaskedDataset
    max_len=512,
)

In [None]:
# Number of samples exposed by the masked dataset.
masked_dataset.__len__()

In [None]:
from torchinfo import summary

import sys

# Make the repository root importable (idempotent across cell re-runs).
if '..' not in sys.path:
    sys.path.append('..')

from python_scripts.transformers.model import BERTCustomMasked, BERTCustom
from torch.utils.data import DataLoader  # used by the shape check in the next cell

# Small encoder; vocab_size is taken from the dataset's vocabulary table.
bertmodel = BERTCustom(
    vocab_size=len(masked_dataset.vocab),
    hidden=32,
    dim_k=4,
)
# Wrap the encoder in BERTCustomMasked for the masking task.
masked_model = BERTCustomMasked(bertmodel)

summary(masked_model)

In [None]:
# Sanity check: push one batch of 3 samples through the model and show the output shape.
shape_check_loader = DataLoader(masked_dataset, batch_size=3)
shape_check_batch = next(iter(shape_check_loader))
masked_model(shape_check_batch).shape

In [None]:
import torch
from torchmetrics import Accuracy
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

import sys

# Make the repository root importable (idempotent across cell re-runs).
if '..' not in sys.path:
    sys.path.append('..')

from python_scripts.transformers.dataset import MaskedDataModule
from python_scripts.transformers.task import MaskingTask

# presumably splits whole_dataset into train/val/test loaders — see MaskedDataModule
masked_datamodule = MaskedDataModule(whole_dataset=masked_dataset, batch_size=8)

masked_optimizer = torch.optim.Adam(masked_model.parameters(), 1e-3)
# Cosine decay from lr=1e-3 to eta_min=1e-4 over T_max=5 scheduler steps.
# T_max matches max_epochs below, assuming MaskingTask steps the scheduler
# once per epoch — TODO confirm.
# `verbose` is deprecated in recent PyTorch (inspect the LR with
# masked_scheduler.get_last_lr() instead), so it is not passed.
masked_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=masked_optimizer,
    T_max=5,
    eta_min=1e-4,
)
# Alternative schedulers — uncomment one to replace masked_scheduler:
# masked_scheduler = torch.optim.lr_scheduler.MultiStepLR(
#     optimizer=masked_optimizer,
#     milestones=[4, 7, 10, 13, 16, 19],
#     gamma=0.3,
# )
# masked_scheduler = torch.optim.lr_scheduler.CyclicLR(
#     optimizer=masked_optimizer,
#     base_lr=1e-6,
#     max_lr=1e-3,
#     step_size_up=3000,
#     step_size_down=7000,
# )
# masked_scheduler = torch.optim.lr_scheduler.OneCycleLR(
#     optimizer=masked_optimizer,
#     max_lr=1e-3,
#     steps_per_epoch=100,
#     epochs=5,
#     div_factor=1e2,
# )

# Loss and accuracy both ignore target id -100 (presumably the label for
# unmasked positions), so they are computed over the same tokens.
maskingtask = MaskingTask(
    model=masked_model,
    loss_fn=torch.nn.CrossEntropyLoss(ignore_index=-100),
    optimizer=masked_optimizer,
    scheduler=masked_scheduler,
    acc_fn=Accuracy(task='multiclass', num_classes=len(masked_dataset.vocab), ignore_index=-100)
)

# Keep the three checkpoints with the highest validation accuracy.
callbacks = [
    ModelCheckpoint(
        monitor='val_avg_accuracy',
        mode='max',
        save_top_k=3,
    ),
]
# Optional early stopping on the same metric.
# NOTE(review): if val_avg_accuracy is in [0, 1], min_delta=0.1 demands a
# 10-point gain to count as an improvement — confirm this is intended.
# callbacks.append(EarlyStopping(
#     monitor='val_avg_accuracy',
#     min_delta=0.1,
#     patience=3,
#     verbose=True,
#     mode='max'
# ))

trainer = pl.Trainer(callbacks=callbacks, max_epochs=5)

# To continue from a saved checkpoint instead of training from scratch,
# either restore only the task's weights:
# maskingtask = MaskingTask.load_from_checkpoint(
#     './lightning_logs/version_0/checkpoints/epoch=0-step=33562.ckpt',
#     model=masked_model,
#     loss_fn=torch.nn.CrossEntropyLoss(),
#     optimizer=masked_optimizer,
#     scheduler=masked_scheduler,
#     acc_fn=Accuracy(task='multiclass', num_classes=len(masked_dataset.vocab), ignore_index=-100)
# )
# or resume the full training state. Lightning 2.x removed
# `Trainer(resume_from_checkpoint=...)`; pass `ckpt_path` to `fit` instead:
# trainer.fit(
#     maskingtask,
#     datamodule=masked_datamodule,
#     ckpt_path='../notebooks/lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt',
# )

trainer.fit(maskingtask, datamodule=masked_datamodule)
# Evaluate the best checkpoint tracked by ModelCheckpoint (highest
# val_avg_accuracy), not the final-epoch weights left in memory after fit.
trainer.test(maskingtask, datamodule=masked_datamodule, ckpt_path='best')