In [1]:
import torch

In [3]:
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all this file holds (a dict of tensors, listed by the next cell), and it
# prevents arbitrary pickled code from running on load.
data = torch.load('data/test-sample.pt', weights_only=True)

In [4]:
# One line per stored tensor: key, dtype, shape.
for name, tensor in data.items():
    print(name, tensor.dtype, tensor.size())

sequence torch.float32 torch.Size([131072, 4])
target torch.float32 torch.Size([896, 5313])


In [14]:
# 2**17 = 131072 matches the sequence length listed above; dividing by
# 2**7 = 128 yields 1024.0.
# NOTE(review): 128 is presumably the model's bin / total pooling factor — confirm.
pow(2, 17) / pow(2, 7)

1024.0

In [11]:
# The file holds a single tensor (its shape is displayed by the next cell), so the
# restricted weights_only unpickler suffices and no arbitrary pickled code can run.
tf_gammas = torch.load('enformer_pytorch/precomputed/tf_gammas.pt', weights_only=True)

In [12]:
# Same torch.Size as `.shape`, via the method form.
tf_gammas.size()

torch.Size([3071, 32])

In [21]:
# All-ones dummy input laid out as (batch=10, dim=100, seq_len=20), the layout
# AttentionPool.forward expects.
x = torch.ones(10, 100, 20)

In [22]:
from torch import nn
from einops.layers.torch import Rearrange


class AttentionPool(nn.Module):
    """
    This class implements the attention pooling mechanism for the Enformer model.

    The last (sequence) axis is split into non-overlapping windows of
    ``pool_size`` positions. Inside each window, per-channel logits from a 1x1
    ``Conv2d`` are softmax-normalised over the window positions, and the
    attention-weighted sum becomes one output position.

    Args:
        dim: number of channels (axis 1 of the input).
        pool_size: number of consecutive positions collapsed into one output.
    """

    def __init__(self, dim, pool_size=2):
        super().__init__()
        self.pool_size = pool_size
        # (b, d, n * p) -> (b, d, n, p): group consecutive positions into windows.
        self.pool_fn = Rearrange("b d (n p) -> b d n p", p=pool_size)

        self.to_attn_logits = nn.Conv2d(dim, dim, 1, bias=False)

        # Dirac init makes the 1x1 conv an identity map; doubling the weight
        # means the initial logits are simply 2 * x.
        nn.init.dirac_(self.to_attn_logits.weight)

        with torch.no_grad():
            self.to_attn_logits.weight.mul_(2)

    def forward(self, x):
        """
        x: torch.Tensor, shape (batch, dim, seq_len)

        Returns:
        - torch.Tensor, shape (batch, dim, ceil(seq_len / pool_size))

        A seq_len that is not a multiple of pool_size is right-padded with
        zeros; padded positions are masked out of the softmax.
        """
        b, _, n = x.shape
        # Distance to the next multiple of pool_size (0 when already aligned).
        # Padding by `n % pool_size` would be wrong: e.g. n=7, p=4 -> 10.
        pad = (-n) % self.pool_size
        needs_padding = pad > 0

        if needs_padding:
            # nn.functional is used directly so no extra `F` import is required.
            x = nn.functional.pad(x, (0, pad), value=0)
            # True marks the padded positions.
            mask = torch.zeros((b, 1, n), dtype=torch.bool, device=x.device)
            mask = nn.functional.pad(mask, (0, pad), value=True)

        x = self.pool_fn(x)
        logits = self.to_attn_logits(x)

        if needs_padding:
            # Most negative finite value -> ~zero attention weight after softmax;
            # the (b, 1, n, p) mask broadcasts across channels.
            mask_value = -torch.finfo(logits.dtype).max
            logits = logits.masked_fill(self.pool_fn(mask), mask_value)

        attn = logits.softmax(dim=-1)

        return (x * attn).sum(dim=-1)


In [23]:
# dim must equal the channel axis of x (100). pool_size=2 is the default value,
# written out so the halving of the sequence axis is explicit.
pool = AttentionPool(dim=100, pool_size=2)

In [26]:
# Input shape next to pooled shape: the last axis shrinks by pool_size.
(x.shape, pool(x).shape)

(torch.Size([10, 100, 20]), torch.Size([10, 100, 10]))

In [28]:
import math


def exponential_linspace_int(start, end, num, divisible_by=1):
    """Return `num` integers spaced geometrically from `start` to `end`.

    Each value is start * base**i with base = (end / start) ** (1 / (num - 1)).
    It is then rounded to the nearest multiple of `divisible_by` using Python's
    `round`, so exact halves go to the even multiple.

    Args:
        start: first value; `end / start` must be positive (it goes through log).
        end: last value.
        num: number of values to return; `num <= 0` yields `[]`.
        divisible_by: every returned value is a multiple of this.

    Returns:
        list[int] of length max(num, 0).
    """
    def _round(x):
        return int(round(x / divisible_by) * divisible_by)

    if num <= 0:
        return []
    if num == 1:
        # Matches np.linspace(start, end, 1) -> [start]. It also avoids the
        # division by (num - 1) == 0 below, which used to raise ZeroDivisionError.
        return [_round(start)]

    base = math.exp(math.log(end / start) / (num - 1))
    return [_round(start * base**i) for i in range(num)]


[768, 896, 1024, 1152, 1280, 1536]