# Import

In [7]:
import os
import torch
import pytorch_lightning as pl
from time import time_ns
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from rnaquanet.network.graph_regression_network import GraphRegressionNetwork
from rnaquanet.network.grn_data_module import GRNDataModule
from rnaquanet.utils.rnaquanet_config import RnaquanetConfig
from pytorch_lightning.loggers import TensorBoardLogger

# Wstępna konfiguracja

In [8]:
# Select PyTorch's 'high' precision mode for float32 matrix multiplications.
torch.set_float32_matmul_precision('high')

# Values loaded from the YAML file can be overridden on the fly while the
# config object is being constructed; here only network.num_workers is replaced.
config = RnaquanetConfig(
    '/app/configs/config_ares_tiny.yml',
    override={'network': {'num_workers': 4}},
)

# Build the data module from the config and run its preparation step up front.
data = GRNDataModule(config)
data.prepare_data()

# Trening

In [9]:
# To run an experiment, any property of the loaded configuration can be
# changed by hand before the model and trainer are created.
config.network.max_epochs = 25

# Every run gets its own output directory named after the current nanosecond
# timestamp; exist_ok=False makes a name collision raise instead of silently
# reusing an existing directory.
path = os.path.join(config.network.model_output_path, str(time_ns()))
os.makedirs(path, exist_ok=False)

model = GraphRegressionNetwork(config)
trainer = pl.Trainer(
    max_epochs=config.network.max_epochs,
    log_every_n_steps=1,
    callbacks=[
        # Both callbacks watch the same 'val_loss' metric.
        EarlyStopping(monitor='val_loss'),
        ModelCheckpoint(dirpath=path, save_top_k=3, monitor='val_loss'),
    ],
    logger=TensorBoardLogger('tb_logs', name='RnaQALightning'),
)
trainer.fit(model, data)

# Save the trainer state reached at the end of fitting next to the checkpoints
# written by ModelCheckpoint.
# NOTE(review): the '.cpkt' extension looks like a typo for '.ckpt'; it is kept
# unchanged here because other code may load this exact file name - confirm.
trainer.save_checkpoint(os.path.join(path, 'final.cpkt'))

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type       | Params
------------------------------------------
0 | batch_norm | ModuleList | 0     
1 | conv       | ModuleList | 37.6 K
------------------------------------------
37.6 K    Trainable params
0         Non-trainable params
37.6 K    Total params
0.151     Total estimated model params size (MB)


Epoch 12: 100%|██████████| 1/1 [00:00<00:00,  4.15it/s, v_num=3, train_loss_step=15.40, val_loss=9.710, train_loss_epoch=15.40]
