In [7]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

import parnet

In [8]:
network = parnet.networks.PanRBPNet(223)
network

PanRBPNet(
  (stem): StemConv1D(
    (conv1d): Conv1d(4, 128, kernel_size=(12,), stride=(1,), padding=same)
    (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): ReLU()
  )
  (body): Conv1DTower(
    (tower): Sequential(
      (0): Sequential(
        (0): ResConv1DBlock(
          (conv1d): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=same)
          (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): ReLU()
          (dropout): Dropout1d(p=0.25, inplace=False)
          (linear_upsample): Conv1d(128, 256, kernel_size=(1,), stride=(1,), bias=False)
        )
        (1): ResConv1DBlock(
          (conv1d): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=same)
          (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): ReLU()
          (dropout): Dropout1d(p=0.25, inplace=False)
        )
   

In [9]:
# %%
import torch
import torch.nn as nn
import torchmetrics

# %%
def log_likelihood_from_logits(y, y_pred, dim):
    return torch.sum(torch.mul(torch.log_softmax(y_pred, dim=dim), y), dim=dim) + log_combinations(y, dim)

def log_combinations(input, dim):
    total_permutations = torch.lgamma(torch.sum(input, dim=dim) + 1)
    counts_factorial = torch.lgamma(input + 1)
    redundant_permutations = torch.sum(counts_factorial, dim=dim)
    return total_permutations - redundant_permutations

def multinomial_neg_log_probs(y, y_pred, dim=-1):
    return log_likelihood_from_logits(y, y_pred, dim) * -1

# %%
class MultinomialNLLLossFromLogits(torchmetrics.MeanMetric):
    def __init__(self, dim=-2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dim = dim

    def update(self, y: torch.Tensor, y_pred: torch.Tensor):
        assert y_pred.shape == y.shape

        loss = torch.mean(multinomial_neg_log_probs(y, y_pred, dim=self.dim))

        # update running mean
        super().update(loss)

y = torch.rand(2, 1000, 223)
y_pred = network({'sequence': torch.rand(2, 4, 1000)})

m_1 = multinomial_neg_log_probs
m_2 = MultinomialNLLLossFromLogits()

print(m_1(y, y_pred, dim=-2).shape)
print(torch.mean(m_1(y, y_pred, dim=-2)))
print(m_2(y, y_pred))

torch.Size([2, 223])
tensor(1062.9937, grad_fn=<MeanBackward0>)
tensor(1062.9937, grad_fn=<SqueezeBackward0>)


In [10]:
from parnet.losses import MultinomialNLLLossFromLogits
from parnet.metrics import BatchedPearsonCorrCoef

# class Model(pl.LightningModule):
#     def __init__(self, network, _example_input=None, metrics=None, optimizer=torch.optim.Adam):
#         super().__init__()
#         self.network = network

#         # loss
#         self.loss_fn = nn.ModuleDict({
#             'TRAIN': MultinomialNLLLossFromLogits(dim=-2),
#             'VAL': MultinomialNLLLossFromLogits(dim=-2),
#         })
        
#         # metrics
#         if metrics is None:
#             metrics = {}
#         self.metrics = nn.ModuleDict({
#             'TRAIN': nn.ModuleDict({name: metric() for name, metric in metrics.items()}),
#             'VAL': nn.ModuleDict({name: metric() for name, metric in metrics.items()}),
#         })
        
#         # optimizer
#         self.optimizer_cls = optimizer
    
#     def forward(self, *args, **kwargs):
#         return self.network(*args, **kwargs)

#     def configure_optimizers(self):
#         optimizer = self.optimizer_cls(self.parameters())
#         return optimizer

#     def training_step(self, batch, batch_idx=None, **kwargs):
#         inputs, y = batch
#         y = y['total']
#         print('y:', torch.sum(y))
#         y_pred = self.forward(inputs)
#         # loss = self.loss_fn(y, y_pred)
#         loss = self.compute_and_log_loss(y, y_pred, partition='TRAIN')
#         self.compute_and_log_metics(y, y_pred, partition='TRAIN')
#         return loss

#     def validation_step(self, batch, batch_idx=None, **kwargs):
#         inputs, y = batch
#         y = y['total']
#         y_pred = self.forward(inputs)
#         self.compute_and_log_loss(y, y_pred, partition='VAL')
#         self.compute_and_log_metics(y, y_pred, partition='VAL')
    
#     def compute_and_log_loss(self, y, y_pred, partition=None):
#         on_step = False
#         if partition == 'TRAIN':
#             on_step = True

#         loss = self.loss_fn[partition](y, y_pred)
#         self.log(f'{partition}/loss', loss, on_step=on_step, on_epoch=True, prog_bar=False)
#         return loss

#     def compute_and_log_metics(self, y, y_pred, partition=None):
#         on_step = False
#         if partition == 'TRAIN':
#             on_step = True

#         for name, metric in self.metrics[partition].items():
#             metric(y, y_pred)
#             self.log(f'{partition}/{name}', metric, on_step=on_step, on_epoch=True, prog_bar=False)

class Model(pl.LightningModule):
    def __init__(self, network, _example_input=None, metrics=None, optimizer=torch.optim.Adam):
        super().__init__()
        self.network = network

        # loss
        self.loss_fn = nn.ModuleDict({
            'TRAIN': MultinomialNLLLossFromLogits(dim=-2),
            'VAL': MultinomialNLLLossFromLogits(dim=-2),
        })
        
        # metrics
        if metrics is None:
            metrics = {}
        self.metrics = nn.ModuleDict({
            'TRAIN': nn.ModuleDict({name: metric() for name, metric in metrics.items()}),
            'VAL': nn.ModuleDict({name: metric() for name, metric in metrics.items()}),
        })
        
        # optimizer
        self.optimizer_cls = optimizer
    
    def forward(self, *args, **kwargs):
        return self.network(*args, **kwargs)

    def configure_optimizers(self):
        optimizer = self.optimizer_cls(self.parameters())
        return optimizer

    def training_step(self, batch, batch_idx=None, **kwargs):
        inputs, y = batch
        y = y['total']
        y_pred = self.forward(inputs)
        # loss = self.loss_fn(y, y_pred)
        loss = self.compute_and_log_loss(y, y_pred, partition='TRAIN')
        self.compute_and_log_metics(y, y_pred, partition='TRAIN')
        return loss

    def validation_step(self, batch, batch_idx=None, **kwargs):
        inputs, y = batch
        y = y['total']
        y_pred = self.forward(inputs)
        self.compute_and_log_loss(y, y_pred, partition='VAL')
        self.compute_and_log_metics(y, y_pred, partition='VAL')
    
    def compute_and_log_loss(self, y, y_pred, partition=None):
        # on_step = False
        # if partition == 'TRAIN':
        #     on_step = True

        loss = self.loss_fn[partition](y, y_pred)
        self.log(f'{partition}/loss', loss, on_step=True, on_epoch=True, prog_bar=False)
        return loss

    def compute_and_log_metics(self, y, y_pred, partition=None):
        # on_step = False
        # if partition == 'TRAIN':
        #     on_step = True

        for name, metric in self.metrics[partition].items():
            metric(y, y_pred)
            self.log(f'{partition}/{name}', metric, on_step=True, on_epoch=True, prog_bar=False)

model = Model(network, metrics={'pcc': BatchedPearsonCorrCoef})
model

Model(
  (network): PanRBPNet(
    (stem): StemConv1D(
      (conv1d): Conv1d(4, 128, kernel_size=(12,), stride=(1,), padding=same)
      (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU()
    )
    (body): Conv1DTower(
      (tower): Sequential(
        (0): Sequential(
          (0): ResConv1DBlock(
            (conv1d): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=same)
            (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act): ReLU()
            (dropout): Dropout1d(p=0.25, inplace=False)
            (linear_upsample): Conv1d(128, 256, kernel_size=(1,), stride=(1,), bias=False)
          )
          (1): ResConv1DBlock(
            (conv1d): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=same)
            (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act): ReLU()
         

In [11]:
dataloader = torch.utils.data.DataLoader(parnet.data.datasets.TFIterableDataset('../example/data.matrix/head.tfrecord', batch_size=4, shuffle=1_000_000), batch_size=None)
dataloader

<torch.utils.data.dataloader.DataLoader at 0x7ff3b03638e0>

In [12]:
import datetime
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, RichProgressBar, LearningRateMonitor

root_log_dir = f'logs/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}'
loggers = [
    pl_loggers.TensorBoardLogger(root_log_dir+'/tensorboard', name='', version='', log_graph=True),
    # pl_loggers.CSVLogger(root_log_dir+'/tensorboard', name='', version=''),
]

checkpoint_callback = ModelCheckpoint(dirpath=f'{root_log_dir}/checkpoints', every_n_epochs=1, save_last=True)

early_stop_callback = EarlyStopping(monitor="val/loss", min_delta=0.00, patience=3, verbose=False, mode="min")

bar = RichProgressBar()

trainer = pl.Trainer(default_root_dir=root_log_dir, max_epochs=10, logger=loggers, callbacks=[checkpoint_callback, LearningRateMonitor('step', log_momentum=True)], log_every_n_steps=1)
trainer.fit(model=model, train_dataloaders=dataloader, val_dataloaders=dataloader)
torch.save(model.network, 'test.pt')

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | network | PanRBPNet  | 1.6 M 
1 | loss_fn | ModuleDict | 0     
2 | metrics | ModuleDict | 0     
---------------------------------------
1.6 M     Trainable params
0         Non-trainable params
1.6 M     Total params
6.308     Total estimated model params size (MB)


Epoch 9: : 4it [00:00,  4.36it/s, loss=1.33, v_num=]                       

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: : 4it [00:01,  3.92it/s, loss=1.33, v_num=]
