In [1]:
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torchrbpnet.data import tfrecord_to_dataloader, dummy_dataloader

dataloader = dummy_dataloader(5)

In [6]:
# %%
import torch
import torch.nn as nn
import pytorch_lightning as pl

from torchrbpnet.losses import MultinomialNLLLossFromLogits
from torchrbpnet.metrics import batched_pearson_corrcoef
from torchrbpnet.networks import MultiRBPNet

# %%
class Model(pl.LightningModule):
    def __init__(self, network):
        super().__init__()
        self.network = network
        self.loss_fn = MultinomialNLLLossFromLogits()
        self.metrics = [batched_pearson_corrcoef]
        self.example_input_array = torch.rand(2, 4, 101)

        print(self.hparams)
    
    def forward(self, *args, **kwargs):
        return self.network(*args, **kwargs)

    def training_step(self, batch, *args, **kwargs):
        x, y = batch
        y_pred = self.network(x)
        loss = self.loss_fn(y, y_pred, dim=-2)
        self.log_dict(self._compute_metrics(y, y_pred, partition='train'), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, *args, **kwargs):
        x, y = batch
        y_pred = self.network(x)
        loss = self.loss_fn(y, y_pred, dim=-2)
        self.log_dict(self._compute_metrics(y, y_pred, partition='val'), on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def _compute_loss(self, y, y_pred):
        return self.loss_fn(y, y_pred)

    def _compute_metrics(self, y, y_pred, partition=''):
        results = dict()
        for metric_fn in self.metrics:
            results[f'{partition}/{metric_fn.__name__}'] = metric_fn(y, y_pred)
        results[f'{partition}/loss'] = self._compute_loss(y, y_pred)
        return results

model = Model(network=MultiRBPNet(n_tasks=7))




In [7]:
import datetime
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

root_log_dir = f'logs/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}'
loggers = [
    pl_loggers.TensorBoardLogger(root_log_dir+'/tensorboard', name='', version='', log_graph=True),
    # pl_loggers.CSVLogger(root_log_dir+'/tensorboard', name='', version=''),
]

checkpoint_callback = ModelCheckpoint(dirpath=f'{root_log_dir}/checkpoints', every_n_epochs=1, save_last=True)

early_stop_callback = EarlyStopping(monitor="val/loss", min_delta=0.00, patience=3, verbose=False, mode="min")

trainer = pl.Trainer(default_root_dir=root_log_dir, max_epochs=3, logger=loggers, callbacks=[checkpoint_callback, early_stop_callback])
trainer.fit(model=model, train_dataloaders=dataloader, val_dataloaders=dataloader)
torch.save(model, 'model.pt')

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
  return F.conv1d(input, weight, bias, self.stride,

  | Name    | Type                         | Params | In sizes    | Out sizes  
-------------------------------------------------------------------------------------
0 | network | MultiRBPNet                  | 1.8 M  | [2, 4, 101] | [2, 101, 7]
1 | loss_fn | MultinomialNLLLossFromLogits | 0      | ?           | ?          
-------------------------------------------------------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.138     Total estimated model params size (MB)


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  8.69it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 2: : 10it [00:01,  5.17it/s, loss=2.93e+03, v_num=, train/batched_pearson_corrcoef_step=0.126, train/loss_step=193.0, val/batched_pearson_corrcoef_step=0.106, val/loss_step=95.20, val/batched_pearson_corrcoef_epoch=0.0209, val/loss_epoch=95.80, train/batched_pearson_corrcoef_epoch=0.0295, train/loss_epoch=193.0]  

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: : 10it [00:02,  4.91it/s, loss=2.93e+03, v_num=, train/batched_pearson_corrcoef_step=0.126, train/loss_step=193.0, val/batched_pearson_corrcoef_step=0.106, val/loss_step=95.20, val/batched_pearson_corrcoef_epoch=0.0209, val/loss_epoch=95.80, train/batched_pearson_corrcoef_epoch=0.0295, train/loss_epoch=193.0]
