In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

import parnet

In [None]:
# Instantiate the PanRBPNet architecture shipped with parnet.
# NOTE(review): 223 is presumably the number of output tracks (the commented-out smoke test in the
# next cell builds targets whose last dimension is 223) — confirm against parnet.networks.PanRBPNet.
network = parnet.networks.PanRBPNet(223)
# Bare last expression: shows the module repr as the cell output.
network

In [None]:
# # %%
# import torch
# import torch.nn as nn
# import torchmetrics

# # %%
# def log_likelihood_from_logits(y, y_pred, dim):
#     return torch.sum(torch.mul(torch.log_softmax(y_pred, dim=dim), y), dim=dim) + log_combinations(y, dim)

# def log_combinations(input, dim):
#     total_permutations = torch.lgamma(torch.sum(input, dim=dim) + 1)
#     counts_factorial = torch.lgamma(input + 1)
#     redundant_permutations = torch.sum(counts_factorial, dim=dim)
#     return total_permutations - redundant_permutations

# def multinomial_neg_log_probs(y, y_pred, dim=-1):
#     return log_likelihood_from_logits(y, y_pred, dim) * -1

# # %%
# class MultinomialNLLLossFromLogits(torchmetrics.MeanMetric):
#     def __init__(self, dim=-2, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.dim = dim

#     def update(self, y: torch.Tensor, y_pred: torch.Tensor):
#         assert y_pred.shape == y.shape

#         loss = torch.mean(multinomial_neg_log_probs(y, y_pred, dim=self.dim))

#         # update running mean
#         super().update(loss)

# y = torch.rand(2, 1000, 223)
# y_pred = network({'sequence': torch.rand(2, 4, 1000)})

# m_1 = multinomial_neg_log_probs
# m_2 = MultinomialNLLLossFromLogits()

# print(m_1(y, y_pred, dim=-2).shape)
# print(torch.mean(m_1(y, y_pred, dim=-2)))
# print(m_2(y, y_pred))

In [None]:
from parnet.losses import MultinomialNLLLossFromLogits
from parnet.metrics import PearsonCorrCoeff, FilteredPearsonCorrCoeff

class Model(pl.LightningModule):
    """LightningModule that trains ``network`` with a multinomial NLL loss on logits.

    Separate loss and metric instances are kept for the ``'TRAIN'`` and ``'VAL'``
    partitions so that state held by one partition never mixes with the other.

    Args:
        network: Module called by :meth:`forward`; receives the first element of each batch.
        _example_input: Optional example input. When given it is stored as Lightning's
            ``example_input_array`` (used e.g. by ``TensorBoardLogger(log_graph=True)``).
            Lightning calls ``model(x)`` for a tensor, ``model(*x)`` for a tuple and
            ``model(**x)`` for a dict, so wrap a dict input in a 1-tuple if the network
            expects the dict as a single positional argument.
        metrics: Optional mapping ``name -> metric class`` (classes, not instances). Each
            class is instantiated once per partition with no arguments.
        optimizer: Optimizer class; instantiated with ``self.parameters()`` only, i.e. with
            the optimizer's default hyper-parameters.
    """

    def __init__(self, network, _example_input=None, metrics=None, optimizer=torch.optim.Adam):
        super().__init__()
        self.network = network

        # Previously this argument was accepted but ignored; expose it to Lightning so that
        # graph logging / model summaries can use it. Left untouched when not provided.
        if _example_input is not None:
            self.example_input_array = _example_input

        # loss: one instance per partition, both constructed with dim=-2
        self.loss_fn = nn.ModuleDict({
            'TRAIN': MultinomialNLLLossFromLogits(dim=-2),
            'VAL': MultinomialNLLLossFromLogits(dim=-2),
        })

        # metrics: instantiate every metric class once per partition
        if metrics is None:
            metrics = {}
        self.metrics = nn.ModuleDict({
            'TRAIN': nn.ModuleDict({name: metric() for name, metric in metrics.items()}),
            'VAL': nn.ModuleDict({name: metric() for name, metric in metrics.items()}),
        })

        # optimizer class, instantiated lazily in configure_optimizers
        self.optimizer_cls = optimizer

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped network unchanged."""
        return self.network(*args, **kwargs)

    def configure_optimizers(self):
        """Build the optimizer over all parameters using the optimizer's defaults."""
        return self.optimizer_cls(self.parameters())

    def training_step(self, batch, batch_idx=None, **kwargs):
        """Run one training batch and return the loss for backpropagation."""
        return self._shared_step(batch, partition='TRAIN')

    def validation_step(self, batch, batch_idx=None, **kwargs):
        """Run one validation batch; loss and metrics are logged, nothing is returned."""
        self._shared_step(batch, partition='VAL')

    def _shared_step(self, batch, partition):
        """Forward a ``(inputs, targets)`` batch and log loss/metrics for ``partition``."""
        inputs, y = batch
        # Only the 'total' entry of the target mapping is used as the training target.
        y = y['total']
        y_pred = self.forward(inputs)
        loss = self.compute_and_log_loss(y, y_pred, partition=partition)
        self.compute_and_log_metrics(y, y_pred, partition=partition)
        return loss

    def compute_and_log_loss(self, y, y_pred, partition=None):
        """Compute the loss of ``partition`` and log it as ``'<partition>/loss'``.

        Returns:
            The value returned by the partition's loss module for this batch.
        """
        loss = self.loss_fn[partition](y, y_pred)
        self.log(f'{partition}/loss', loss, on_step=True, on_epoch=True, prog_bar=False)
        return loss

    def compute_and_log_metrics(self, y, y_pred, partition=None):
        """Update every metric of ``partition`` and log it as ``'<partition>/<name>'``."""
        for name, metric in self.metrics[partition].items():
            metric(y, y_pred)
            # The metric object itself (not its value) is handed to self.log.
            self.log(f'{partition}/{name}', metric, on_step=True, on_epoch=True, prog_bar=False)

    # Backward-compatible alias for the original, misspelled method name.
    compute_and_log_metics = compute_and_log_metrics

# Metric *classes* keyed by the name they are logged under; Model instantiates
# one copy of each per partition.
metric_classes = {
    'pcc': PearsonCorrCoeff,
    'filtered_pcc': FilteredPearsonCorrCoeff,
}
model = Model(network, metrics=metric_classes)
model

In [None]:
# Build the dataset first, then wrap it in a DataLoader.
# batch_size=None switches off the DataLoader's automatic batching; the dataset itself is
# given batch_size=4 (presumably it yields ready-made batches — confirm in parnet).
tfrecord_dataset = parnet.data.datasets.TFIterableDataset(
    '../example/data.matrix/head.tfrecord',
    batch_size=4,
    shuffle=1_000_000,
)
dataloader = torch.utils.data.DataLoader(tfrecord_dataset, batch_size=None)
dataloader

In [None]:
# Pull a single batch and look at the shape of its 'total' target entry.
sample_batch = next(iter(dataloader))
sample_batch[1]['total'].shape

In [None]:
import datetime
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, RichProgressBar, LearningRateMonitor

# Every run gets its own timestamped directory for logs and checkpoints.
root_log_dir = f'logs/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}'
loggers = [
    pl_loggers.TensorBoardLogger(root_log_dir+'/tensorboard', name='', version='', log_graph=True),
    # pl_loggers.CSVLogger(root_log_dir+'/tensorboard', name='', version=''),
]

# Save a checkpoint after every epoch and always keep the most recent one.
checkpoint_callback = ModelCheckpoint(dirpath=f'{root_log_dir}/checkpoints', every_n_epochs=1, save_last=True)

# The model logs its validation loss under 'VAL/loss' (upper-case partition name);
# the monitor key has to match that name exactly.
early_stop_callback = EarlyStopping(monitor="VAL/loss", min_delta=0.00, patience=3, verbose=False, mode="min")

bar = RichProgressBar()

# All callbacks created above are handed to the Trainer; previously the early-stopping
# callback and the progress bar were constructed but never used.
callbacks = [
    checkpoint_callback,
    early_stop_callback,
    bar,
    LearningRateMonitor('step', log_momentum=True),
]

trainer = pl.Trainer(default_root_dir=root_log_dir, max_epochs=10, logger=loggers, callbacks=callbacks, log_every_n_steps=1)
# NOTE(review): the same dataloader is used for training and validation here.
trainer.fit(model=model, train_dataloaders=dataloader, val_dataloaders=dataloader)
# Saves the whole network object (not only its state_dict) to 'test.pt'.
torch.save(model.network, 'test.pt')

In [None]:
# Draw ONE batch and use it for both the target and the prediction. Calling
# next(iter(dataloader)) twice creates two independent iterators over a shuffling
# dataset, so y and y_pred could otherwise come from different batches.
batch = next(iter(dataloader))
y = batch[1]['total']
y_pred = model(batch[0])

In [None]:
# Inspect the shape of the target fetched from the dataloader in the previous cell.
y.shape

In [None]:
import torch
import torchmetrics

# Toy targets: one (2, 4) sample duplicated along a new leading batch axis -> (2, 2, 4).
y = torch.tensor(
    [
        [1, 2, 3, 2],
        [1, 0, 0, 7],
    ], dtype=torch.float32
).unsqueeze(0).repeat(2, 1, 1)
print(y.shape)

# Toy predictions, built the same way.
y_pred = torch.tensor(
    [
        [1, 4, 3, 2],
        [1, 0, 2, 2],
    ], dtype=torch.float32
).unsqueeze(0).repeat(2, 1, 1)
print(y_pred.shape)

# Reference Pearson correlation of each row pair of the first batch element,
# computed by torchmetrics on the 1-D rows.
for row in range(2):
    print(torchmetrics.functional.pearson_corrcoef(y[0, row], y_pred[0, row]))

In [None]:
from parnet import metrics

# Instantiate the packaged metric and print what it returns for the toy tensors.
pcc_metric = metrics.PearsonCorrCoeff()
print(pcc_metric(y, y_pred))

In [None]:
def pearson_corrcoef(x, y, dim=-1):
    """Pearson correlation coefficient of ``x`` and ``y`` along ``dim``.

    Both inputs are centred along ``dim``; the result is the sum of products of the
    centred values divided by the square root of the product of their sums of squares.
    ``dim`` is reduced away in the output. A constant slice yields a zero denominator.
    """
    x_centered = x - x.mean(dim, keepdim=True)
    y_centered = y - y.mean(dim, keepdim=True)

    cross_sum = (x_centered * y_centered).sum(dim)
    x_sq_sum = x_centered.pow(2).sum(dim)
    y_sq_sum = y_centered.pow(2).sum(dim)

    return cross_sum / torch.sqrt(x_sq_sum * y_sq_sum)

# Average the per-row coefficients into a single scalar.
pearson_corrcoef(y, y_pred).mean()

In [None]:
# torchmetrics result on the first batch element only.
# NOTE(review): y[0] is 2-D here, so torchmetrics presumably treats dim 0 as samples and dim 1 as
# separate outputs — confirm this is the axis you intend to correlate over.
torchmetrics.functional.pearson_corrcoef(y[0], y_pred[0])

In [None]:
# Merge every dimension except the last into one leading axis ...
y_flat = y.flatten(start_dim=0, end_dim=-2)
y_pred_flat = y_pred.flatten(start_dim=0, end_dim=-2)

# ... then swap the two remaining axes so the former last dimension comes first.
y_flat_t = y_flat.transpose(-1, -2)
y_pred_flat_t = y_pred_flat.transpose(-1, -2)
print(y_pred_flat_t)

In [None]:
# Correlate the transposed, flattened matrices in one call and fold the result back to
# y's leading dimensions.
# NOTE(review): assumes torchmetrics returns one coefficient per column for 2-D input — confirm.
torchmetrics.functional.pearson_corrcoef(y_flat_t, y_pred_flat_t).reshape(y.shape[:-1])

In [None]:
def batched_pearson_corrcoef(y_batch, y_pred_batch, reduction=None):
    """Apply torchmetrics' ``pearson_corrcoef`` to each item along dim 0 and stack the results.

    Args:
        y_batch: Targets; iterated over ``y_batch.shape[0]``.
        y_pred_batch: Predictions, indexed with the same positions as ``y_batch``.
        reduction: Optional callable applied to the stacked tensor (e.g. ``torch.mean``).

    Returns:
        The stacked per-item results, or ``reduction(stacked)`` when ``reduction`` is given.
    """
    per_item = []
    for idx in range(y_batch.shape[0]):
        per_item.append(torchmetrics.functional.pearson_corrcoef(y_batch[idx], y_pred_batch[idx]))
    stacked = torch.stack(per_item)
    return stacked if reduction is None else reduction(stacked)

batched_pearson_corrcoef(y, y_pred)

In [None]:
from parnet import metrics

# Same computation via the implementation packaged in parnet.metrics; reduction=None is passed
# explicitly (presumably yielding the unreduced per-item coefficients — confirm in parnet).
values = metrics.batched_pearson_corrcoef(y, y_pred, reduction=None)
# Display the shape of the returned value as the cell output.
values.shape