In [2]:
import torch
import torch.nn as nn
import torch.functional as F
import pytorch_lightning as pl

from torchrbpnet.layers import Conv1DFirstLayer, Conv1DResBlock, IndexEmbeddingOutputHead
from torchrbpnet.losses import MultinomialNLLLossFromLogits
from torchrbpnet.data.utils import TFIterableDataset

In [21]:
class Network(nn.Module):
    def __init__(self, tasks, n_layers=9, n_body_filters=256):
        super(Network, self).__init__()

        self.tasks = tasks

        self.body = nn.Sequential(*[Conv1DFirstLayer(4, n_body_filters, 6)]+[(Conv1DResBlock(n_body_filters, n_body_filters, dilation=(2**i))) for i in range(n_layers)])
        self.head = IndexEmbeddingOutputHead(len(self.tasks), dims=n_body_filters)
    
    def forward(self, inputs, **kwargs):
        x = inputs

        for layer in self.body:
            x = layer(x)

        return self.head(x)

network = Network(tasks=list(range(223)))

In [22]:
class MultiRBPNet(pl.LightningModule):
    def __init__(self, network):
        super().__init__()
        self.network = network # Network(tasks=range(223))
        self.loss_fn = MultinomialNLLLossFromLogits()

    def training_step(self, batch, **kwargs):
        x, y = batch
        y_pred = self.network(x)
        loss = self.loss_fn(y, y_pred, dim=-2)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

model = MultiRBPNet(network)
model

MultiRBPNet(
  (network): Network(
    (body): Sequential(
      (0): Conv1DFirstLayer(
        (conv1d): Conv1d(4, 256, kernel_size=(6,), stride=(1,), padding=same)
        (act): ReLU()
      )
      (1): Conv1DResBlock(
        (conv1d): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=same)
        (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU()
        (dropout): Dropout(p=0.25, inplace=False)
      )
      (2): Conv1DResBlock(
        (conv1d): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=same, dilation=(2,))
        (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): ReLU()
        (dropout): Dropout(p=0.25, inplace=False)
      )
      (3): Conv1DResBlock(
        (conv1d): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=same, dilation=(4,))
        (batch_norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_runn

In [11]:
dataset = TFIterableDataset('example-data-matrix/windows.chr13.4.data.matrix.filtered.tfrecord', shuffle=1_000_000)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)

In [18]:
trainer = pl.Trainer(max_epochs=2)
trainer.fit(model=model, train_dataloaders=dataloader)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type                         | Params
---------------------------------------------------------
0 | network | Network                      | 1.8 M 
1 | loss_fn | MultinomialNLLLossFromLogits | 0     
---------------------------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.384     Total estimated model params size (MB)


Epoch 0: : 2it [00:28, 14.40s/it, loss=323, v_num=9]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [20]:
{f'chr{i}' for i in range(1, 24)}.difference({'chr1', 'chr2'})

{'chr10',
 'chr11',
 'chr12',
 'chr13',
 'chr14',
 'chr15',
 'chr16',
 'chr17',
 'chr18',
 'chr19',
 'chr20',
 'chr21',
 'chr22',
 'chr23',
 'chr3',
 'chr4',
 'chr5',
 'chr6',
 'chr7',
 'chr8',
 'chr9'}