In [1]:
from pathlib import Path
from rnamodif.data_utils.data_paths import name_to_files

# Training file lists: the 'train' split of the nia_2022 positive and negative
# entries in name_to_files.
train_pos_files, train_neg_files = (
    name_to_files[dataset]['train']
    for dataset in ('nia_2022_pos', 'nia_2022_neg')
)

# Validation sets, keyed by the experiment label they are reported under.
# Each label maps to the 'test' split of one name_to_files entry; the pairs
# below are (experiment label, name_to_files key).
valid_exp_to_files_pos = {
    label: name_to_files[dataset]['test']
    for label, dataset in (
        ('5eu_2020_pos', 'nia_2020_pos'),
        ('Nanoid_pos', 'nano_pos_1'),
        ('5eu_2022_chr1_pos', 'nia_2022_pos'),
    )
}

# Negative-class counterparts of the validation sets above.
# NOTE(review): 'UNM_2020' does not follow the '<name>_neg' pattern of the
# other labels -- presumably intentional (label used in logs); confirm.
valid_exp_to_files_neg = {
    label: name_to_files[dataset]['test']
    for label, dataset in (
        ('UNM_2020', 'nia_2020_neg'),
        ('Nanoid_neg', 'nano_neg_1'),
        ('5eu_2022_chr1_neg', 'nia_2022_neg'),
    )
}

In [1]:
import os

from rnamodif.model import RodanPretrained
from rnamodif.data_utils.dataloading_5eu import TrainingDatamodule
import pytorch_lightning as pl
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.callbacks import ModelCheckpoint


# Run label: used as the checkpoint sub-directory name here and referenced
# again further down when the logger is created.
experiment_name = 'TEMPLATE'

# Model under training; all hyperparameters are passed explicitly.
model = RodanPretrained(
    lr=1e-4,
    warmup_steps=3000,
    frozen_layers=0,
    gru_layers=1,
    gru_dropout=0.5,
    gru_hidden=32,
)

# Datamodule wiring the train/validation file lists (defined in the first
# cell) to the loaders.
datamodule_options = dict(
    train_pos_files=train_pos_files,
    train_neg_files=train_neg_files,
    valid_exp_to_files_pos=valid_exp_to_files_pos,
    valid_exp_to_files_neg=valid_exp_to_files_neg,
    batch_size=64,
    window=4096,
    per_dset_read_limit=250,
    shuffle_valid=True,
    workers=8,
)
dm = TrainingDatamodule(**datamodule_options)

# Keep the two best checkpoints ranked by the monitored "valid_loss" metric,
# plus the most recent one; full training state is saved, not just weights.
checkpoint_dir = f"/home/jovyan/RNAModif/rnamodif/checkpoints_pl/{experiment_name}"
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    monitor="valid_loss",
    save_top_k=2,
    save_last=True,
    save_weights_only=False,
)

# Comet credentials must never be committed with the notebook. The key is read
# from the COMET_API_KEY environment variable; when it is unset, None is passed
# and CometLogger falls back to its own lookup (environment / ~/.comet.config).
# NOTE(review): the key that used to be hardcoded on this line has been exposed
# in the notebook history and should be revoked/rotated in Comet.
logger = CometLogger(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name='RNAModif',
    experiment_name=experiment_name,
)

# GPU training in 16-bit precision, bounded by step count rather than epochs.
# Validation runs and metrics are logged every 1000 training steps.
trainer = pl.Trainer(
    max_steps=1000000,
    logger=logger,
    accelerator='gpu',
    auto_lr_find=False,
    val_check_interval=1000,
    log_every_n_steps=1000,
    benchmark=True,
    precision=16,
    callbacks=[checkpoint_callback],
    # To resume an interrupted run, uncomment and point at an existing checkpoint:
    # resume_from_checkpoint=f'/home/jovyan/RNAModif/rnamodif/checkpoints_pl/{experiment_name}/lastX.ckpt'
)


trainer.fit(model, dm)