In [14]:
import torch
import torch.nn as nn

In [38]:
class MultinomialNLLLossFromLogits(nn.Module):
    def __init__(self, reduction=torch.mean):
        super(MultinomialNLLLossFromLogits, self).__init__()
        self.reduction = reduction
    
    def __call__(self, y, y_pred):
        return self.log_likelihood_from_logits(y, y_pred)

    def log_likelihood_from_logits(self, y, y_pred):
        log_prob = -torch.sum(torch.mul(torch.log_softmax(y_pred, dim=-1), y), dim=-1) * self.log_combinations(y)
        if self.reduction is not None:
            return self.reduction(log_prob)
        return log_prob

    def log_combinations(self, input):
        total_permutations = torch.lgamma(torch.sum(input, dim=-1) + 1)
        counts_factorial = torch.lgamma(input + 1)
        redundant_permutations = torch.sum(counts_factorial, dim=-1)
        return total_permutations - redundant_permutations

In [15]:
class Conv1DFirstLayer(nn.Module):
    def __init__(self, in_chan, filters=128, kernel_size=12):
        super(Conv1DFirstLayer, self).__init__()

        self.conv1d = nn.Conv1d(in_chan, filters, kernel_size=kernel_size, padding='same')
        self.act = nn.ReLU()
    
    def forward(self, inputs, **kwargs):
        x = self.conv1d(inputs)
        x = self.act(x)
        return x

In [16]:
class Conv1DResBlock(nn.Module):
    def __init__(self, in_chan, filters=128, kernel_size=3, dropout=0.25, dilation=1, residual=True):
        super(Conv1DResBlock, self).__init__()

        self.conv1d = nn.Conv1d(in_chan, filters, kernel_size=kernel_size, dilation=dilation, padding='same')
        self.batch_norm = nn.BatchNorm1d(filters)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.residual = residual
    
    def forward(self, inputs, **kwargs):
        x = self.conv1d(inputs)
        x = self.batch_norm(x)
        x = self.act(x)
        x = self.dropout(x)
        if self.residual:
            x = inputs + x
        return x

# %%
class IndexEmbeddingOutputHead(nn.Module):
    def __init__(self, n_tasks, dims):
        super(IndexEmbeddingOutputHead, self).__init__()

        # protein/experiment embedding of shape (p, d)
        self.embedding = torch.nn.Embedding(n_tasks, dims)
    
    def forward(self, bottleneck, **kwargs):
        # bottleneck of shape (batch, d, n) --> (batch, n, d)
        bottleneck = torch.transpose(bottleneck, -1, -2)
        
        # embedding of (batch, p, d) --> (batch, d, p)
        embedding = torch.transpose(self.embedding.weight, 0, 1)

        logits = torch.matmul(bottleneck, embedding) # torch.transpose(self.embedding.weight, 0, 1)  
        return logits

In [31]:
class IndexEmbeddingOutputHead(nn.Module):
    def __init__(self, n_tasks, dims):
        super(IndexEmbeddingOutputHead, self).__init__()

        # protein/experiment embedding of shape (p, d)
        self.embedding = torch.nn.Embedding(n_tasks, dims)
    
    def forward(self, bottleneck, **kwargs):
        # bottleneck of shape (batch, d, n) --> (batch, n, d)
        bottleneck = torch.transpose(bottleneck, -1, -2)
        
        # embedding of (batch, p, d) --> (batch, d, p)
        embedding = torch.transpose(self.embedding.weight, 0, 1)

        logits = torch.matmul(bottleneck, embedding) # torch.transpose(self.embedding.weight, 0, 1)  
        return logits

In [32]:
class Network(nn.Module):
    def __init__(self, tasks, nlayers=9):
        super(Network, self).__init__()

        self.tasks = tasks

        self.body = nn.Sequential(*[Conv1DFirstLayer(4, 128)]+[(Conv1DResBlock(128, dilation=(2**i))) for i in range(nlayers)])
        self.head = IndexEmbeddingOutputHead(len(self.tasks), dims=128)
    
    def forward(self, inputs, **kwargs):
        x = inputs

        for layer in self.body:
            x = layer(x)

        return self.head(x)

In [33]:
net = Network(tasks=list(range(223)))
net

Network(
  (body): Sequential(
    (0): Conv1DFirstLayer(
      (conv1d): Conv1d(4, 128, kernel_size=(12,), stride=(1,), padding=same)
      (act): ReLU()
    )
    (1): Conv1DResBlock(
      (conv1d): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=same)
      (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU()
      (dropout): Dropout(p=0.25, inplace=False)
    )
    (2): Conv1DResBlock(
      (conv1d): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=same, dilation=(2,))
      (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU()
      (dropout): Dropout(p=0.25, inplace=False)
    )
    (3): Conv1DResBlock(
      (conv1d): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=same, dilation=(4,))
      (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU()
      (dropout): Dropout(p=0.

In [20]:
y_pred = net(torch.rand(2, 4, 201))
y_pred.shape

torch.Size([2, 201, 128])
torch.Size([128, 223])


  return F.conv1d(input, weight, bias, self.stride,


torch.Size([2, 201, 223])

In [156]:
from bioflow import io
import tensorflow as tf
import torch

def load_tf_dataset_to_torch(filepath, features_filepath=None, batch_size=64, cache=True, shuffle=None):
    dataset = io.dataset_ops.load_tfrecord(filepath, deserialize=False)

    # cache
    if cache:
        dataset = dataset.cache()

    if shuffle:
        dataset = dataset.shuffle(shuffle)

    # deserialize
    if features_filepath is None:
        features_filepath = filepath + '.features.json'
    features = io.dataset_ops.features_from_json_file(features_filepath)
    dataset = io.dataset_ops.deserialize_dataset(dataset, features)

    # batch
    dataset = dataset.batch(batch_size)

    # format dataset
    dataset = dataset.map(lambda e: (tf.transpose(e['inputs']['input'], perm=[0, 2, 1]), e['outputs']))

    for example in dataset.as_numpy_iterator():
        # yield example
        yield tf.nest.map_structure(lambda x: torch.tensor(x).to(torch.float32), example)

torch_dataset = load_tf_dataset_to_torch('example-data-matrix/windows.chr13.4.data.matrix.filtered.tfrecord', shuffle=1_000_000)

In [159]:
class TFIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, filepath, features_filepath=None, batch_size=64, cache=True, shuffle=None):
        super(TFIterableDataset).__init__()

        self.dataset = io.dataset_ops.load_tfrecord(filepath, deserialize=False)

        # cache
        if cache:
            self.dataset = self.dataset.cache()

        if shuffle:
            self.dataset = self.dataset.shuffle(shuffle)

        # deserialize
        if features_filepath is None:
            features_filepath = filepath + '.features.json'
        self.features = io.dataset_ops.features_from_json_file(features_filepath)
        self.dataset = io.dataset_ops.deserialize_dataset(self.dataset, self.features)

        # batch
        self.dataset = self.dataset.batch(batch_size)

        # format dataset
        self.dataset = self.dataset.map(lambda e: (tf.transpose(e['inputs']['input'], perm=[0, 2, 1]), e['outputs']))
        
    def __iter__(self):
        for example in self.dataset.as_numpy_iterator():
            yield tf.nest.map_structure(lambda x: torch.tensor(x).to(torch.float32), example)

dataset = TFIterableDataset('example-data-matrix/windows.chr13.4.data.matrix.filtered.tfrecord', shuffle=1_000_000)

In [161]:
next(iter(dataset))

(tensor([[[0., 0., 0.,  ..., 1., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [1., 1., 1.,  ..., 0., 1., 1.]],
 
         [[1., 1., 1.,  ..., 1., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 1.],
          [0., 0., 0.,  ..., 0., 1., 0.]],
 
         [[1., 1., 1.,  ..., 1., 0., 1.],
          [0., 0., 0.,  ..., 0., 1., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         ...,
 
         [[1., 0., 1.,  ..., 0., 1., 1.],
          [0., 0., 0.,  ..., 1., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 1., 0.,  ..., 0., 0., 0.]],
 
         [[1., 0., 1.,  ..., 1., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 1.],
          [0., 1., 0.,  ..., 0., 1., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[1., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 1., 0., 0.],
          [0., 0., 1.,  ..., 0., 0., 1.],
   

In [167]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)

In [168]:
for s in dataloader:
    print(len(s))
    print(s[0].shape)
    print(s)
    break

2
torch.Size([64, 4, 1000])
[tensor([[[1., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 1.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 1., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.],
         [1., 1., 0.,  ..., 1., 0., 0.]],

        [[1., 1., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.]],

        ...,

        [[0., 0., 0.,  ..., 1., 0., 1.],
         [0., 0., 1.,  ..., 0., 1., 0.],
         [1., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 0., 1.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 1., 0.,  ..., 1., 0., 0.]],

        [[0., 1., 0.,  ..., 1., 1., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
    

In [157]:
batch = next(torch_dataset)

In [158]:
print(batch[0].shape)
print(batch[1]['signal']['total'].shape)

torch.Size([64, 4, 1000])
torch.Size([64, 223, 1000])


In [72]:
def example_dataset_generator(n=1000):
    for _ in range(n):
        yield (torch.rand(8, 4, 101, dtype=torch.float32), {'signal': {'total': torch.randint(10, (8, 101, 7)).to(torch.float32)}})

next(iter(example_dataset_generator()))

(tensor([[[0.8432, 0.0831, 0.4262,  ..., 0.0482, 0.0173, 0.9287],
          [0.1853, 0.7028, 0.1899,  ..., 0.9724, 0.4613, 0.1646],
          [0.2861, 0.9356, 0.3750,  ..., 0.1189, 0.9248, 0.2968],
          [0.5126, 0.7694, 0.4358,  ..., 0.6782, 0.0509, 0.2679]],
 
         [[0.8185, 0.8322, 0.4912,  ..., 0.7084, 0.9014, 0.2759],
          [0.8016, 0.5364, 0.0559,  ..., 0.9678, 0.4356, 0.8085],
          [0.7543, 0.9849, 0.1166,  ..., 0.6654, 0.8289, 0.0895],
          [0.7539, 0.3302, 0.0449,  ..., 0.7764, 0.1359, 0.3103]],
 
         [[0.4504, 0.8426, 0.1631,  ..., 0.5507, 0.8572, 0.5975],
          [0.5861, 0.5429, 0.3307,  ..., 0.4199, 0.4708, 0.8016],
          [0.2089, 0.6339, 0.0161,  ..., 0.0156, 0.8499, 0.1226],
          [0.7547, 0.6272, 0.6041,  ..., 0.4308, 0.2640, 0.5104]],
 
         ...,
 
         [[0.6375, 0.2446, 0.0256,  ..., 0.1221, 0.9895, 0.5450],
          [0.8393, 0.1714, 0.8652,  ..., 0.7372, 0.8369, 0.2513],
          [0.5807, 0.2437, 0.5308,  ..., 0.8825, 0.

In [37]:
from tqdm import tqdm

for epoch in range(5):
    print(f'Epoch: {epoch}/5')
    for sample in tqdm(example_dataset_generator(100), total=100):
        _ = net(sample[0])

Epoch: 0/5


100%|██████████| 100/100 [00:01<00:00, 58.97it/s]


Epoch: 1/5


100%|██████████| 100/100 [00:01<00:00, 77.32it/s]


Epoch: 2/5


100%|██████████| 100/100 [00:01<00:00, 76.85it/s]


Epoch: 3/5


100%|██████████| 100/100 [00:01<00:00, 77.65it/s]


Epoch: 4/5


100%|██████████| 100/100 [00:01<00:00, 81.45it/s]


In [79]:
dataset = lambda: example_dataset_generator(100)

In [81]:
for i in dataset():
    print(i)

(tensor([[[0.4751, 0.9533, 0.0266,  ..., 0.6050, 0.2872, 0.3567],
         [0.5962, 0.5101, 0.5655,  ..., 0.8972, 0.8150, 0.7629],
         [0.5315, 0.6408, 0.8128,  ..., 0.5156, 0.2693, 0.4139],
         [0.1764, 0.9361, 0.8698,  ..., 0.6896, 0.8076, 0.6014]],

        [[0.9005, 0.3587, 0.7187,  ..., 0.0264, 0.4518, 0.3712],
         [0.9135, 0.3575, 0.5092,  ..., 0.3477, 0.9583, 0.7306],
         [0.8526, 0.4434, 0.8368,  ..., 0.7812, 0.9579, 0.3542],
         [0.8791, 0.0073, 0.8408,  ..., 0.9809, 0.1059, 0.0236]],

        [[0.5058, 0.4194, 0.6513,  ..., 0.2173, 0.6942, 0.1623],
         [0.6693, 0.2856, 0.2852,  ..., 0.5943, 0.7653, 0.9163],
         [0.4995, 0.5578, 0.8971,  ..., 0.3253, 0.0456, 0.3702],
         [0.1815, 0.7807, 0.5370,  ..., 0.7297, 0.4280, 0.7986]],

        ...,

        [[0.5855, 0.3257, 0.8845,  ..., 0.9049, 0.6775, 0.3236],
         [0.9162, 0.0987, 0.6347,  ..., 0.3225, 0.9180, 0.7494],
         [0.2932, 0.8945, 0.1478,  ..., 0.1152, 0.5237, 0.8453],
    

In [82]:
import tqdm

test_net = Network(tasks=list(range(7)))
test_net

import torch.optim as optim
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
criterion = MultinomialNLLLossFromLogits()

def train(net, dataset, epochs=2):
    for epoch in tqdm.trange(epochs):
        epoch_running_loss = 0.0
        print(f'Epoch {epoch}')
        for sample in dataset():
            x, y = sample

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            y_pred = net(x)
            loss = criterion(y['signal']['total'], y_pred)
            loss.backward()
            optimizer.step()
            
            # add to running loss
            epoch_running_loss += loss.item()
        print(f'Loss {epoch_running_loss}')

train(test_net, lambda: example_dataset_generator(100), epochs=3)


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 0


 33%|███▎      | 1/3 [00:03<00:07,  3.73s/it]

Loss 11511124.8671875
Epoch 1


 67%|██████▋   | 2/3 [00:07<00:03,  3.55s/it]

Loss 11502714.40625
Epoch 2


100%|██████████| 3/3 [00:10<00:00,  3.55s/it]

Loss 11502822.1171875





In [5]:
print(next(iter(torch_dataset))[0].shape)
print(next(iter(torch_dataset))[1].keys())

InvalidArgumentError: {{function_node __wrapped__IteratorGetNext_output_types_5_device_/job:localhost/replica:0/task:0/device:CPU:0}} Cannot add tensor to the batch: number of elements does not match. Shapes are: [tensor]: [223,641], [batch]: [223,1000] [Op:IteratorGetNext]

In [48]:
total_lengths = []
for e in io.load_tfrecord('example-data-matrix/windows.chr15.4.data.matrix.filtered.tfrecord'):
    # total_lengths.append(int(e['outputs']['signal']['total'].shape[1]))
    total_lengths.append(int(e['inputs']['input'].shape[0]))

In [49]:
set(total_lengths)

{1000}

In [93]:
res = torch.softmax(y_pred, dim=-2)
print(res[0][:,0].shape)
print(torch.sum(res[0][:,3]))

torch.Size([201])
tensor(1.0000, grad_fn=<SumBackward0>)


In [47]:
ex_pred = torch.rand(2, 201, 128)
print(ex_pred.shape)

embed = torch.rand(128, 223)
print(embed.shape)

print(torch.unsqueeze(embed, dim=0).shape)

torch.Size([2, 201, 128])
torch.Size([128, 223])
torch.Size([1, 128, 223])


In [48]:
torch.matmul(ex_pred, embed).shape

torch.Size([2, 201, 223])

In [32]:
torch.mul(torch.rand(2, 3, 4), torch.rand(3, 4))

tensor([[[0.1262, 0.1029, 0.6163, 0.6493],
         [0.1689, 0.0365, 0.2958, 0.1062],
         [0.3376, 0.0367, 0.6692, 0.0428]],

        [[0.0301, 0.8762, 0.1694, 0.0185],
         [0.1460, 0.4245, 0.0510, 0.2159],
         [0.3140, 0.1952, 0.1188, 0.3553]]])

In [54]:
import torch
import torchmetrics

In [66]:
corr = torchmetrics.PearsonCorrCoef(num_outputs=4)
corr(torch.rand(101, 4), torch.rand(101, 4))

tensor([-0.0724,  0.2136, -0.1253,  0.0681])

In [67]:
def transparent_corr(x, y):
    print(x.shape, y.shape)
    return corr(x, y)

In [68]:
import functorch

vmap_corr = functorch.vmap(transparent_corr, in_dims=0, out_dims=0)
vmap_corr(torch.rand(2, 101, 4), torch.rand(2, 101, 4))

torch.Size([101, 4]) torch.Size([101, 4])


RuntimeError: output with shape [4] doesn't match the broadcast shape [2, 4]

In [90]:
multinomial = torch.distributions.Multinomial(total_count=42, logits=torch.tensor([2, 3.2, 5, 1.9]))
nll = -multinomial.log_prob(torch.tensor([7, 8, 20, 7]))
nll

tensor(38.4663)

In [97]:
from torch.distributions import Multinomial

In [94]:
y, y_pred = torch.randint(0, 10, size=(4, 42, 7)), torch.rand(4, 42, 7)

In [143]:
manual_nll = []
for i in range(y.shape[0]):
    for j in range(y.shape[2]):
        single_y, single_y_pred = y[i, :, j], y_pred[i, :, j]
        # print(Multinomial(total_count=torch.sum(single_y), logits=single_y_pred))
        manual_nll.append(-Multinomial(int(torch.sum(single_y)), logits=single_y_pred).log_prob(single_y))
true_nll = torch.mean(torch.tensor(manual_nll))
true_nll

tensor(114.9822)

In [170]:
class MultinomialNLLLossFromLogits(nn.Module):
    def __init__(self, reduction=torch.mean):
        super(MultinomialNLLLossFromLogits, self).__init__()
        self.reduction = reduction
    
    def __call__(self, y, y_pred, dim=-1):
        neg_log_probs = self.log_likelihood_from_logits(y, y_pred, dim) * -1
        if self.reduction is not None:
            return self.reduction(neg_log_probs)
        return neg_log_probs

    def log_likelihood_from_logits(self, y, y_pred, dim):
        return torch.sum(torch.mul(torch.log_softmax(y_pred, dim=dim), y), dim=dim) + self.log_combinations(y, dim)

    def log_combinations(self, input, dim):
        total_permutations = torch.lgamma(torch.sum(input, dim=dim) + 1)
        counts_factorial = torch.lgamma(input + 1)
        redundant_permutations = torch.sum(counts_factorial, dim=dim)
        return total_permutations - redundant_permutations

print(y.shape)
print(y_pred.shape)

nll_loss = MultinomialNLLLossFromLogits(reduction=torch.mean)
nll = nll_loss(y, y_pred, dim=-2)
nll

tensor(114.9822)

In [147]:
assert bool(true_nll == nll)

In [130]:
Multinomial(total_count=int(torch.sum(single_y)), logits=single_y_pred).log_prob(single_y)

tensor(-125.4803)

In [137]:
nll_loss(single_y, single_y_pred)

tensor(125.4803)

In [107]:
torch.sum(single_y)

tensor(190)