In [37]:
import torch
import torch.nn as nn
from torch import Tensor

class Canon(nn.Module):
    """
    "Canon" (i.e. modulated delay) layer, as proposed in phys lm 4.1 @darknoon.

    Realised as a causal 1-D convolution along the time axis: the output at
    position ``i`` is computed from the inputs at positions
    ``i - kernel_size + 1`` through ``i`` only, never from later tokens.
    Document boundaries are not accounted for yet, so the kernel can mix
    tokens that belong to different documents.

    Args:
        dim: number of channels; the convolution maps ``dim -> dim``.
        kernel_size: how many positions (current one included) each output sees.
    """

    def __init__(self, dim: int, kernel_size: int = 4):
        super().__init__()
        # Conv1d pads both ends by kernel_size - 1; forward() discards the
        # trailing outputs, which leaves an effective left-only (causal) pad.
        causal_pad = kernel_size - 1
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel_size,
            padding=causal_pad,
        )

    def init_weights(self):
        """Reset the convolution: zero bias, weights drawn from N(0, 0.02**2)."""
        nn.init.zeros_(self.conv.bias)
        nn.init.normal_(self.conv.weight, mean=0.0, std=0.02)

    def forward(self, x: Tensor):
        """
        Apply the causal convolution.

        Args:
            x: input laid out as [B, T, C].

        Returns:
            Tensor laid out as [B, T, C], same sequence length as ``x``.
        """
        seq_len = x.size(1)
        # Conv1d wants channels first: [B, T, C] -> [B, C, T]
        channels_first = x.transpose(1, 2)
        padded_out = self.conv(channels_first)
        # The padded conv yields T + kernel_size - 1 steps; the first T are the
        # ones that only look backwards.
        causal_out = padded_out[..., :seq_len]
        # Back to [B, T, C]
        return causal_out.transpose(1, 2)


In [42]:
# Seed the global RNG so the layer's initial weights and the random input
# (and therefore the printed result) are the same on every Restart & Run All.
torch.manual_seed(0)

eos = 999  # sentinel token id used as the document separator
c = 3  # channel width of the toy input
tokens = torch.tensor([1,2,3,4,5,eos,7,8,9,10,eos,11,12,13])
# Running count of EOS tokens: every EOS bumps the id, so the EOS token itself
# is grouped with the document that follows it.
doc_ids = (tokens == eos).cumsum(0)
print(doc_ids)
t = len(tokens)  # sequence length
b = 1  # batch size
layer = Canon(c, kernel_size=3)

with torch.no_grad():
    seq = torch.randn(b, t, c)
    result = layer(seq)
    # The layer must hand back the same [B, T, C] layout it was given.
    assert result.shape == (b, t, c), f"unexpected output shape {tuple(result.shape)}"
    print(result.shape)
    print(result)

tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2])
torch.Size([1, 16, 3])
tensor([[[-0.1507, -0.6324, -0.2129],
         [ 0.2614,  0.7770,  0.8635],
         [-0.5542, -0.3305, -0.3751],
         [-0.4898,  0.1173,  0.3476],
         [ 0.1862, -1.2137, -0.0988],
         [ 0.0308, -0.9513,  0.2161],
         [ 0.2579,  1.6207,  1.1352],
         [ 0.4274, -1.3634, -0.8486],
         [ 0.4619, -0.3607,  0.2520],
         [-0.5408,  0.1816,  0.5895],
         [ 0.5442, -0.3975,  0.1204],
         [ 0.3375, -0.2826,  0.1290],
         [ 0.9161, -0.0607,  0.0474],
         [ 0.8632, -0.4486, -0.4880],
         [ 0.3471, -0.8270, -0.3924],
         [ 0.2075, -0.5210,  0.1839]]])


In [60]:
# Position indices shaped for broadcasting: keys run along the columns (1, t),
# queries along the rows (t, 1). The mask is per-position and independent of
# the batch size, so no batch dimension is expanded here (expanding to (b, t)
# and transposing only broadcasts to (t, t) when b == 1).
kv_index = torch.arange(t).unsqueeze(0)
q_index = torch.arange(t).unsqueeze(1)

# 1 where the query position is at or after the key position.
causal_mask = (q_index >= kv_index).long()
print(causal_mask)

# 1 where query and key fall in the same document.
doc_mask = (doc_ids[q_index] == doc_ids[kv_index]).long()
# Allowed only when both conditions hold.
document_causal = causal_mask & doc_mask
assert document_causal.shape == (t, t), "mask must be square over sequence positions"

print(document_causal)

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0,