In [8]:
import torch
# torch rounding of 0.7 (method form), converted to a Python int for display.
int(torch.tensor(0.7).round())

1

In [15]:
d = 1.0
# Print round(d ** k) for k = 0..9; with d == 1.0 every line is 1.
for power in range(10):
    print(int(torch.tensor(d ** power).round()))

1
1
1
1
1
1
1
1
1
1


In [None]:
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

In [None]:
import torch
import os

# Show where the kernel runs: the example files below are referenced via relative '../example/...' paths.
print(os.getcwd())

from torchrbpnet.data import datasets, dummy_dataloader, tfrecord_to_dataloader

# Unused alternative loader over the plain TFRecord dataset (kept for reference):
# dataloader = torch.utils.data.DataLoader(datasets.TFIterableDataset('../example/head.tfrecord', batch_size=2, shuffle=100), batch_size=None)

In [None]:
import torchmetrics

class BatchedPearsonCorrCoef(torchmetrics.MeanMetric):
    """Running mean of per-example, per-experiment Pearson correlations.

    Each call to :meth:`update` computes one Pearson correlation per (example,
    experiment) pair of the batch. It keeps only the pairs that pass ALL of the
    filters below and averages them. That single batch value is fed into
    ``torchmetrics.MeanMetric``, so every batch carries equal weight.

    A correlation passes when:

    * it is not NaN (e.g. a constant track gives an undefined correlation),
    * ``max(y)`` over dim -2 is ``>= min_height`` (filter disabled if ``None``),
    * ``sum(y)`` over dim -2 is ``>= min_count`` (filter disabled if ``None``).

    If nothing in a batch passes, the batch value is NaN (0/0). What happens
    then is decided by ``MeanMetric``'s ``nan_strategy``.

    Args:
        min_height: Minimum peak height of the target track, or ``None``.
        min_count: Minimum total count of the target track, or ``None``.
        **kwargs: Forwarded to ``torchmetrics.MeanMetric`` (e.g. ``nan_strategy``).
    """

    def __init__(self, min_height=2, min_count=None, **kwargs):
        super().__init__(**kwargs)

        self.min_height = min_height
        self.min_count = min_count

    def update(self, y_pred: torch.Tensor, y: torch.Tensor):
        """Accumulate the filtered mean correlation of one batch.

        Args:
            y_pred: Predictions, same shape as ``y``.
            y: Targets; assumed (batch, length, experiments) — TODO confirm.

        Raises:
            ValueError: If ``y_pred`` and ``y`` differ in shape.
        """
        if y_pred.shape != y.shape:
            raise ValueError(f'shapes y_pred {y_pred.shape} and y {y.shape} are not the same. ')

        cc_values = self.compute_cc(y_pred, y)
        cc_mean = self.compute_mean(cc_values, y)

        # update
        super().update(cc_mean)

    def compute_cc(self, y_pred: torch.Tensor, y: torch.Tensor):
        """Return the Pearson correlations of every example, stacked along dim 0.

        For 2-D per-example inputs ``pearson_corrcoef`` returns one value per
        column, so the result is expected to be (batch_size, experiments).
        """
        # (preds, target) is the documented argument order. The coefficient is
        # symmetric, so this does not change any value.
        values = [
            torchmetrics.functional.pearson_corrcoef(y_pred[i], y[i])
            for i in range(y.shape[0])
        ]
        return torch.stack(values)

    def compute_mean(self, values: torch.Tensor, y: torch.Tensor):
        """Average ``values`` over the entries that pass the NaN/height/count filters.

        Args:
            values: Correlations from :meth:`compute_cc`.
            y: Targets used to evaluate the height and count filters.

        Returns:
            Scalar tensor: the sum of the passing correlations divided by their
            number (NaN if nothing passes).
        """
        # Start from "is a valid number" and AND in every enabled filter. Deriving
        # every mask from `values`/`y` keeps it on the same device as the data.
        passed = torch.logical_not(torch.isnan(values))
        if self.min_height is not None:
            # max over dim -2 -> should be shape (batch_size, experiments)
            passed = passed & (torch.max(y, dim=-2).values >= self.min_height)
        if self.min_count is not None:
            # sum over dim -2 -> should be shape (batch_size, experiments)
            passed = passed & (torch.sum(y, dim=-2) >= self.min_count)

        # Zero everything that did not pass (including NaNs) before summing.
        values_masked = torch.where(passed, values, torch.zeros_like(values))

        # Divide only by the number of entries that passed.
        return torch.sum(values_masked) / torch.sum(passed)

In [None]:
# Example dataset built from the local example files.
# The dataset is given batch_size=2 itself.
esm_masked_dataset = datasets.MeanESMEmbeddingMaskedTFIterableDataset(
    embedding_matrix_filepath='../example/esm2_t33_650M_UR50D.ENCODE.idx2mean.pt',
    filepath='../example/head.tfrecord',
    batch_size=2,
    shuffle=100,
)
# batch_size=None turns off the DataLoader's own automatic batching.
dataloader = torch.utils.data.DataLoader(esm_masked_dataset, batch_size=None)

# Keys of element 0 (the model inputs) of the first batch.
print(next(iter(dataloader))[0].keys())

In [None]:
# %%
import torch
import torch.nn as nn
import pytorch_lightning as pl

import torchmetrics
from torchrbpnet.losses import MultinomialNLLLossFromLogits
from torchrbpnet.metrics import MultinomialNLLFromLogits #BatchedPCC
from torchrbpnet.networks import MultiRBPNet, ProteinEmbeddingMultiRBPNet

class Model(pl.LightningModule):
    """Lightning wrapper around ``network`` trained with ``MultinomialNLLLossFromLogits``.

    Args:
        network: Module invoked by :meth:`forward`.
        metrics: Optional mapping ``name -> metric``. Each metric is called as
            ``metric(y_pred, y)`` and must provide ``compute()`` and ``reset()``.
            Training uses ``self.metrics``. Validation uses independent deep
            copies in ``self.val_metrics``, so the two partitions never share
            accumulated state.
        optimizer: Optimizer class, instantiated in :meth:`configure_optimizers`
            with ``self.parameters()`` only (i.e. its default hyper-parameters).
    """

    def __init__(self, network, metrics=None, optimizer=torch.optim.Adam):
        import copy

        super().__init__()
        self.network = network
        self.loss_fn = MultinomialNLLLossFromLogits()

        # Training metrics (attribute name kept for backward compatibility).
        self.metrics = nn.ModuleDict({} if metrics is None else metrics)
        # Separate validation copies. Sharing one object between partitions would
        # mix training batches into the validation value. It would also let the
        # validation-epoch reset wipe the training state mid-epoch.
        self.val_metrics = nn.ModuleDict(
            {name: copy.deepcopy(metric) for name, metric in self.metrics.items()}
        )

        # optimizer
        self.optimizer_cls = optimizer

    def forward(self, *args, **kwargs):
        return self.network(*args, **kwargs)

    def configure_optimizers(self):
        optimizer = self.optimizer_cls(self.parameters())
        return optimizer

    def training_step(self, batch, batch_idx, **kwargs):
        return self._shared_step(batch, partition='train')

    def training_epoch_end(self, *args, **kwargs):
        self._reset_metrics('train')

    def validation_epoch_end(self, *args, **kwargs):
        self._reset_metrics('val')

    def validation_step(self, batch, batch_idx):
        # The loss is logged as 'val/loss', the quantity EarlyStopping monitors.
        self._shared_step(batch, partition='val')

    def _shared_step(self, batch, partition):
        """Run forward + loss for one batch, log loss and metrics, and return the loss."""
        inputs, y = batch
        y = y['total']
        y_pred = self.forward(inputs)
        loss = self.loss_fn(y, y_pred, dim=-2)
        self.log(f'{partition}/loss', loss, on_step=(partition == 'train'), on_epoch=True, prog_bar=False)
        self.compute_and_log_metrics(y_pred, y, partition=partition)
        return loss

    def compute_and_log_metrics(self, y_pred, y, partition=None):
        """Update the partition's metrics with this batch and log their running value.

        Training values are logged per step and per epoch. Other partitions are
        logged per epoch only.
        """
        on_step = partition == 'train'

        for name, metric in self._metrics_for(partition).items():
            metric(y_pred, y)
            self.log(f'{partition}/{name}', metric.compute(), on_step=on_step, on_epoch=True, prog_bar=False)

    # Backward-compatible alias for the original (misspelled) method name.
    compute_and_log_metics = compute_and_log_metrics

    def _metrics_for(self, partition):
        """Return the validation metrics for 'val', the training metrics otherwise."""
        return self.val_metrics if partition == 'val' else self.metrics

    def _reset_metrics(self, partition=None):
        """Reset the metrics of ``partition`` ('train'/'val'), or all of them if ``None``."""
        if partition is None:
            collections = (self.metrics, self.val_metrics)
        else:
            collections = (self._metrics_for(partition),)
        for collection in collections:
            for metric in collection.values():
                metric.reset()

# MultiRBPNet with 223 output tasks. Alternatives tried previously:
#   Model(network=ProteinEmbeddingMultiRBPNet())
#   MultiRBPNet(..., _example_input_shape=next(iter(dataloader))[0].shape)
model = Model(MultiRBPNet(n_tasks=223))

In [None]:
# Build a throwaway optimizer just to display what configure_optimizers() returns.
optimizer_preview = model.configure_optimizers()
optimizer_preview

In [None]:
# Take the first batch from the dataloader; element 0 is passed to the model in the next cell.
example = next(iter(dataloader))

In [None]:
# Forward one example batch through the model and display the output shape.
example_inputs = example[0]
model(example_inputs).shape

In [None]:
# optimizer = torch.optim.Adam()

In [None]:
import datetime
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, RichProgressBar, LearningRateMonitor

# Every run writes logs and checkpoints under its own timestamped directory.
root_log_dir = f'logs/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}'
loggers = [
    # NOTE(review): graph logging needs `model.example_input_array`, which Model
    # does not set here — confirm whether the graph actually gets logged.
    pl_loggers.TensorBoardLogger(f'{root_log_dir}/tensorboard', name='', version='', log_graph=True),
    # pl_loggers.CSVLogger(root_log_dir+'/tensorboard', name='', version=''),
]

# Checkpoint after every epoch and additionally keep the latest one (save_last=True).
checkpoint_callback = ModelCheckpoint(dirpath=f'{root_log_dir}/checkpoints', every_n_epochs=1, save_last=True)

# Stops after 3 epochs without improvement of 'val/loss'; the model must log that key.
early_stop_callback = EarlyStopping(monitor="val/loss", min_delta=0.00, patience=3, verbose=False, mode="min")

lr_monitor = LearningRateMonitor('step', log_momentum=True)

# A RichProgressBar() used to be constructed here but was never passed to the
# Trainer; add `RichProgressBar()` to `callbacks` if the rich progress bar is wanted.
trainer = Trainer(
    default_root_dir=root_log_dir,
    max_epochs=2,
    logger=loggers,
    callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],
)
# NOTE(review): the same dataloader is used for training and validation, so the
# validation numbers are not held-out estimates.
trainer.fit(model=model, train_dataloaders=dataloader, val_dataloaders=dataloader)

# Pickles the whole network module (not a state_dict); loading needs the class importable.
torch.save(model.network, 'test.pt')

In [None]:
# Display the optimizer(s) the LightningModule reports after fitting.
optimizers_after_fit = model.optimizers()
optimizers_after_fit