# Import

In [2]:
import os
import torch
import numpy as np
import pytorch_lightning as pl
from time import time_ns
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from rnaquanet.network.graph_regression_network import GraphRegressionNetwork
from rnaquanet.network.grn_data_module import GRNDataModule
from rnaquanet.utils.rnaquanet_config import RnaquanetConfig
from pytorch_lightning.loggers import TensorBoardLogger

# Wstępna konfiguracja

In [3]:
# Select the 'high' float32 matrix-multiplication precision mode in PyTorch.
torch.set_float32_matmul_precision('high')

# Configuration values can be overridden on the fly when the config object is
# created; here the `network.num_workers` entry is set to 4.
config = RnaquanetConfig(
    '/app/configs/config_rnaquanetdataset.yml',
    override={'network': {'num_workers': 4}},
)

# Build the data module from the config and run its preparation step up front.
data = GRNDataModule(config)
data.prepare_data()

In [4]:
from pytorch_lightning import Callback
import copy


class MetricsCallback(Callback):
    """Lightning callback that records metrics and mirrors losses to a writer.

    After every validation run a deep copy of ``trainer.callback_metrics`` is
    appended to ``self.metrics`` and, if present, ``val_loss`` is written under
    the tag ``Loss/val``. At the end of every training epoch (only when
    ``what == "epochs"``) ``train_loss_epoch`` is written under ``Loss/train``
    and the epoch counter ``self.epochs`` is advanced.

    Args:
        writer: Object exposing ``add_scalar(tag, value, step)``.
        what: Granularity selector; train-loss logging and epoch counting only
            happen when this equals ``"epochs"``.
    """

    def __init__(self, writer, what="epochs"):
        super().__init__()
        self.what = what
        self.metrics = []
        self.writer = writer
        # 1-based index of the epoch currently in progress; used as the step
        # for every scalar written by this callback.
        self.epochs = 1

    def _write_scalar(self, tag, metrics, key):
        """Write ``metrics[key]`` under ``tag`` at step ``self.epochs``, if present."""
        value = metrics.get(key)
        if value is None:
            # The metric has not been produced (yet); skipping is expected and
            # replaces the former blanket ``except: pass``, which also hid
            # genuine errors and KeyboardInterrupt.
            return
        self.writer.add_scalar(tag, value.item(), self.epochs)

    def on_train_epoch_end(self, trainer, pl_module):
        """Log the epoch's training loss, then advance the epoch counter."""
        if self.what != "epochs":
            return
        # Lightning runs validation (and on_validation_end) before this hook
        # within the same epoch, so the train loss must be written *before*
        # the counter is bumped for Loss/train and Loss/val to share a step.
        self._write_scalar('Loss/train', trainer.callback_metrics, 'train_loss_epoch')
        self.epochs += 1

    def on_validation_end(self, trainer, pl_module):
        """Snapshot the current callback metrics and log the validation loss."""
        # Deep copy so later updates by the trainer do not alter the snapshot.
        self.metrics.append(copy.deepcopy(trainer.callback_metrics))
        self._write_scalar('Loss/val', trainer.callback_metrics, 'val_loss')

# Trening

In [8]:
# For a one-off experiment, a property of the loaded configuration can be
# changed by hand before the model and trainer are built.
config.network.max_epochs = 25

# Each run gets its own output directory named after the current time in
# nanoseconds; exist_ok=False makes an accidental collision raise instead of
# silently reusing an existing directory.
path = os.path.join(config.network.model_output_path, str(time_ns()))
os.makedirs(path, exist_ok=False)

model = GraphRegressionNetwork(config)
logger = TensorBoardLogger("tb_logs", name="RnaQALightning")

trainer = pl.Trainer(
    max_epochs=config.network.max_epochs,
    log_every_n_steps=1,
    callbacks=[
        EarlyStopping('val_loss'),
        ModelCheckpoint(dirpath=path, save_top_k=3, monitor='val_loss'),
        MetricsCallback(logger.experiment),
    ],
    logger=logger,
)
trainer.fit(model, data)

# NOTE(review): 'final.cpkt' looks like a typo for '.ckpt'; left unchanged here
# because other code may load the checkpoint by this exact name — confirm.
trainer.save_checkpoint(os.path.join(path, 'final.cpkt'))

# The model graph only needs to be recorded once, so a single training batch
# is used as example input instead of re-adding the graph for every batch.
for item in data.train_dataloader():
    logger.experiment.add_graph(
        model, [item.x, item.edge_index, item.edge_attr, item.batch]
    )
    break

# Record the hyperparameters of this run together with its last logged losses.
logger.experiment.add_hparams(
    {
        "HP_hidden_dim": config.network.hidden_dim,
        "HP_num_of_layers": config.network.num_of_layers,
        "HP_num_of_node_features": config.network.num_of_node_features,
        "HP_batch_norm": config.network.batch_norm,
        "HP_gat_dropout": config.network.gat_dropout,
        "HP_lr": config.network.lr,
        "HP_weight_decay": config.network.weight_decay,
        "HP_scheduler_step_size": config.network.scheduler_step_size,
        "HP_scheduler_gamma": config.network.scheduler_gamma,
    },
    {
        'Loss_train': trainer.logged_metrics['train_loss_epoch'].item(),
        'Loss_val': trainer.logged_metrics['val_loss'].item(),
    },
    run_name='',
)
model.custom_histogram_adder()


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type       | Params
------------------------------------------
0 | batch_norm | ModuleList | 0     
1 | conv       | ModuleList | 37.6 K
------------------------------------------
37.6 K    Trainable params
0         Non-trainable params
37.6 K    Total params
0.151     Total estimated model params size (MB)


Epoch 6: 100%|██████████| 11/11 [00:01<00:00,  9.82it/s, v_num=3, train_loss_step=100.0, val_loss=161.0, train_loss_epoch=100.0]
