In [1]:
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torchrbpnet.data import tfrecord_to_dataloader, dummy_dataloader

# Synthetic loader for smoke-testing the training loop; the same object is later used
# for both training and validation. The argument presumably sets the number of batches
# (the run reports 20 it/epoch for train + val) — confirm against dummy_dataloader.
N_DUMMY_BATCHES = 10
dataloader = dummy_dataloader(N_DUMMY_BATCHES)

In [16]:
# %%
import torch
import torch.nn as nn
import pytorch_lightning as pl

import torchmetrics
from torchrbpnet.losses import MultinomialNLLLossFromLogits
from torchrbpnet.metrics import BatchedPCC, MultinomialNLLFromLogits
from torchrbpnet.networks import MultiRBPNet

# %%
class BatchIdx(torchmetrics.MeanMetric):
    """Running mean of the ``batch_idx`` values passed to ``update``.

    Constructor arguments are handled unchanged by ``torchmetrics.MeanMetric``.
    Extra positional/keyword arguments to ``update`` are accepted and discarded
    (presumably so it can be called with the same arguments as other metrics).
    """

    def update(self, batch_idx: torch.Tensor, *args, **kwargs):
        # Only the batch index feeds the mean; every other argument is ignored.
        super().update(batch_idx)

# %%
class Model(pl.LightningModule):
    """LightningModule that trains ``network`` with a multinomial NLL loss.

    Optimizes ``MultinomialNLLLossFromLogits`` with Adam (lr=1e-3) and logs the
    ``loss`` (``MultinomialNLLFromLogits``) and ``pcc`` (``BatchedPCC``) metrics as
    ``train/<name>`` (per step and per epoch) and ``val/<name>`` (per epoch).

    Args:
        network: module called by ``forward``. For ``MultiRBPNet(n_tasks=7)`` the
            Lightning summary shows ``[2, 4, 101] -> [2, 101, 7]``; other networks are
            assumed to follow the same (batch, length, tasks) output layout — confirm.
    """

    def __init__(self, network):
        super().__init__()
        self.network = network
        self.loss_fn = MultinomialNLLLossFromLogits()
        # Separate metric state per partition. Lightning runs the validation loop inside
        # the training epoch (before training_epoch_end), so a single shared set would mix
        # training batches into the validation state. ``metrics`` keeps its original name
        # and serves every partition other than 'val'.
        self.metrics = self._build_metrics()
        self.val_metrics = self._build_metrics()
        # Example batch Lightning uses for the model summary and TensorBoard graph logging.
        self.example_input_array = torch.rand(2, 4, 101)
        # NOTE(review): not referenced by this class; kept so code reading
        # ``model.test_metric`` keeps working.
        self.test_metric = torchmetrics.MeanMetric()

    @staticmethod
    def _build_metrics():
        """Return a fresh ModuleDict with one instance of each logged metric."""
        return nn.ModuleDict({'loss': MultinomialNLLFromLogits(), 'pcc': BatchedPCC()})

    def _metrics_for(self, partition):
        """Return the metric dict holding state for ``partition`` ('val' or anything else)."""
        return self.val_metrics if partition == 'val' else self.metrics

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped network."""
        return self.network(*args, **kwargs)

    def configure_optimizers(self):
        """Adam over all parameters with a fixed learning rate of 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, batch, batch_idx, **kwargs):
        """Compute the loss for one ``(x, y)`` batch and log training metrics."""
        x, y = batch
        y_pred = self.forward(x)
        # NOTE(review): the loss is called as (y, y_pred) while the metrics take
        # (y_pred, y); confirm against MultinomialNLLLossFromLogits' signature.
        # dim=-2 is the length axis if outputs are (batch, length, tasks).
        loss = self.loss_fn(y, y_pred, dim=-2)
        self.compute_and_log_metrics(y_pred, y, partition='train')
        return loss

    # NOTE(review): *_epoch_end hooks are the Lightning 1.x API (removed in 2.0 in
    # favour of on_train_epoch_end / on_validation_epoch_end).
    def training_epoch_end(self, *args, **kwargs):
        """Clear accumulated training-metric state at the end of each training epoch."""
        self._reset_metrics('train')

    def validation_epoch_end(self, *args, **kwargs):
        """Clear accumulated validation-metric state at the end of each validation run."""
        self._reset_metrics('val')

    def validation_step(self, batch, batch_idx):
        """Run the network on one ``(x, y)`` batch and log validation metrics."""
        x, y = batch
        y_pred = self.forward(x)
        self.compute_and_log_metrics(y_pred, y, partition='val')

    def compute_and_log_metrics(self, y_pred, y, partition=None):
        """Update ``partition``'s metrics with one batch and log the batch values.

        Values are logged as ``f'{partition}/{name}'``. Per-step values are logged only
        for ``partition == 'train'``; epoch aggregates are logged for every partition.
        """
        on_step = partition == 'train'
        for name, metric in self._metrics_for(partition).items():
            # Calling a torchmetrics metric accumulates state AND returns the value for
            # this batch only (assumed to hold for these metrics too). Lightning averages
            # the logged batch values over the epoch; logging ``metric.compute()`` here
            # would instead average running values and over-weight early batches.
            batch_value = metric(y_pred, y)
            self.log(f'{partition}/{name}', batch_value, on_step=on_step, on_epoch=True, prog_bar=False)

    # Backward-compatible alias for the original (misspelled) method name.
    compute_and_log_metics = compute_and_log_metrics

    def _reset_metrics(self, partition=None):
        """Reset accumulated metric state.

        Args:
            partition: reset only that partition's metrics; ``None`` (the original
                no-argument behaviour) resets every partition.
        """
        if partition is None:
            metric_dicts = (self.metrics, self.val_metrics)
        else:
            metric_dicts = (self._metrics_for(partition),)
        for metric_dict in metric_dicts:
            for metric in metric_dict.values():
                metric.reset()

# Number of output tracks; the trainer's model summary reports Out sizes [2, 101, 7].
N_TASKS = 7
model = Model(network=MultiRBPNet(n_tasks=N_TASKS))




In [20]:
# Standalone instance of the same metric set Model logs, for interactive inspection.
metrics_module_dict = nn.ModuleDict(dict(
    loss=MultinomialNLLFromLogits(),
    pcc=BatchedPCC(),
))

In [23]:
# ModuleDict entries are registered submodules, so attribute access yields the same
# object as indexing with 'pcc'.
metrics_module_dict.pcc

BatchedPCC()

In [4]:
# next(iter(model.modules()))

In [19]:
import datetime
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, RichProgressBar

# Each run writes logs and checkpoints under its own timestamped directory.
root_log_dir = f'logs/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}'
MAX_EPOCHS = 10
EARLY_STOPPING_PATIENCE = 3  # epochs without val/loss improvement before stopping

# log_graph=True asks the logger to record the model graph (traced with the model's
# example_input_array).
loggers = [
    pl_loggers.TensorBoardLogger(root_log_dir + '/tensorboard', name='', version='', log_graph=True),
]

# Checkpoint every epoch; save_last=True additionally keeps a last.ckpt.
checkpoint_callback = ModelCheckpoint(dirpath=f'{root_log_dir}/checkpoints', every_n_epochs=1, save_last=True)

# NOTE(review): train and val use the same dummy loader below, so early stopping here
# only exercises the wiring rather than measuring generalization.
early_stop_callback = EarlyStopping(
    monitor="val/loss", min_delta=0.00, patience=EARLY_STOPPING_PATIENCE, verbose=False, mode="min"
)

# To use the rich progress bar, add RichProgressBar() to `callbacks` (needs `rich`).
# NOTE(review): the run log reports a CUDA GPU as available but unused; consider
# accelerator='auto' with devices=1 to train on it.
trainer = Trainer(
    default_root_dir=root_log_dir,
    max_epochs=MAX_EPOCHS,
    logger=loggers,
    callbacks=[checkpoint_callback, early_stop_callback],
)
trainer.fit(model=model, train_dataloaders=dataloader, val_dataloaders=dataloader)

# Pickles the whole module, so loading it requires Model/MultiRBPNet to be importable
# the same way; the Lightning checkpoints above are the more portable artifact.
# Written to the working directory, so every run overwrites the previous file.
torch.save(model, 'model.pt')

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name        | Type                         | Params | In sizes    | Out sizes  
-----------------------------------------------------------------------------------------
0 | network     | MultiRBPNet                  | 1.8 M  | [2, 4, 101] | [2, 101, 7]
1 | loss_fn     | MultinomialNLLLossFromLogits | 0      | ?           | ?          
2 | test_metric | MeanMetric                   | 0      | ?           | ?          
-----------------------------------------------------------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.138     Total estimated model params size (MB)


Epoch 9: : 20it [00:03,  6.32it/s, loss=728, v_num=]                       

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: : 20it [00:03,  6.05it/s, loss=728, v_num=]


In [6]:
# Scratch experiment: how torchmetrics.MeanMetric behaves when called and reset.
m = torchmetrics.MeanMetric()

In [7]:
# Calling the metric accumulates state but returns the value for that call only:
# the displayed result is tensor(3.) (the last value), not the mean of 1, 2, 3.
# Re-running this cell without a reset keeps accumulating into `m`.
m(1)
m(2)
m(3)

tensor(3.)

In [11]:
# After reset(), the last call again shows only its own value (tensor(2.)), not 1.5.
m.reset()
m(1)
m(2)

tensor(2.)

In [13]:
# A list argument is averaged within the call: the result is tensor(3.) = mean(2, 4).
m([2,4])

tensor(3.)

In [10]:
# Mean of the integers 0..19 (= 9.5); presumably the expected value of a BatchIdx-style
# mean over 20 batch indices — TODO confirm intent.
sum(range(20)) / len(range(20))

9.5