In [1]:
import einops
import torch
from torch import Tensor
from torch.nested import as_nested_tensor, nested_tensor, to_padded_tensor
from torch.nn import Identity, Linear, Module, Parameter, Sequential
from torch.nn.functional import scaled_dot_product_attention as sdpa


class ColumnarAttention(Module):
    """Attend from a fixed set of learned query "columns" to each packed document.

    The input packs several documents along the sequence axis of a single
    batch row. Keys and values are projected from the tokens and split per
    document into a nested tensor; every document is then attended to by the
    same ``columns`` learned query vectors, so each document is summarised as
    a ``(columns, hidden_dim)`` block regardless of its length.
    """

    def __init__(
        self, hidden_dim: int, num_heads: int, dropout: float = 0.0, columns: int = 128
    ):
        """
        Args:
            hidden_dim: Feature size of the inputs and outputs.
            num_heads: Number of attention heads; must divide ``hidden_dim``.
            dropout: Attention dropout probability (used in training mode only).
            columns: Number of learned query vectors shared by all documents.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.columns = columns
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.head_dim = hidden_dim // num_heads

        # Learned queries, one row per column, drawn with the same 0.02 std as the Linears.
        self.q = Parameter(torch.randn(columns, hidden_dim) * 0.02)
        self.k_proj = Linear(hidden_dim, hidden_dim)
        self.v_proj = Linear(hidden_dim, hidden_dim)
        self.out_proj = Linear(hidden_dim, hidden_dim)

        self.apply(self._init_weight)

    def _init_weight(self, module):
        """Initialise ``Linear`` layers with N(0, 0.02) weights and zero bias."""
        if isinstance(module, Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def prepare_kv(self, x: Tensor, seq_lens: Tensor, is_k: bool) -> Tensor:
        """Project packed tokens to keys or values and split them per document.

        Args:
            x: Packed tokens of shape ``(1, total_tokens, hidden_dim)``.
            seq_lens: 1-D tensor of per-document token counts; must sum to
                ``total_tokens``.
            is_k: Use the key projection when true, the value projection otherwise.

        Returns:
            Nested tensor with one ``(heads, seq_i, head_dim)`` entry per document.
        """
        proj = self.k_proj if is_k else self.v_proj
        # Flatten the packed axis directly. Factoring it as (docs x seq) would
        # require total_tokens to be a multiple of the document count, which
        # does not hold for ragged documents.
        tokens = proj(x).squeeze(0).unflatten(-1, (self.num_heads, self.head_dim))
        per_doc = tokens.split(seq_lens.tolist(), dim=0)
        nest = as_nested_tensor(list(per_doc))
        return nest.transpose(1, 2).contiguous()  # batch x heads x seq x headdim

    def forward(self, x, doc_ids: Tensor) -> Tensor:
        """Summarise every packed document into ``columns`` vectors.

        Args:
            x: Packed tokens of shape ``(1, total_tokens, hidden_dim)``.
            doc_ids: Non-negative integer document id per token, shape
                ``(1, total_tokens)``. Lengths come from ``torch.bincount`` and
                tokens are split in order, so tokens must be grouped by
                document with ids ascending from 0.

        Returns:
            Tensor of shape ``(num_docs, columns, hidden_dim)`` where
            ``num_docs`` is the length of the bincount (largest id + 1).
        """
        batch, seq, dim = x.shape
        assert batch == 1, "Batch size must be 1 for packed sequences"
        assert doc_ids.shape == (batch, seq)

        seq_lens = torch.bincount(doc_ids.flatten())
        num_docs = len(seq_lens)

        k = self.prepare_kv(x, seq_lens, is_k=True)
        v = self.prepare_kv(x, seq_lens, is_k=False)

        # Every document is queried by the same learned columns.
        queries = self.q.unflatten(-1, (self.num_heads, self.head_dim))
        nested_queries = as_nested_tensor([queries] * num_docs)
        nested_queries = nested_queries.transpose(
            1, 2
        ).contiguous()  # batch x heads x seq x headdim

        # The functional SDPA applies dropout whenever dropout_p > 0, so it has
        # to be disabled explicitly outside of training.
        dropout_p = self.dropout if self.training else 0.0
        out = sdpa(query=nested_queries, key=k, value=v, dropout_p=dropout_p)

        # note that the padded tensor here is just to make it so it is easier to work downstream,
        # because all queries have the same length, there won't be any actual padding
        out = to_padded_tensor(
            out.contiguous(), padding=0.0
        )  # batch x heads x seq x headdim
        out = out.transpose(1, 2).flatten(-2)  # batch x seq x (heads headdim)
        return self.out_proj(out)


class AntiColumnarAttention(Module):
    """Attend from packed document tokens back to a fixed-size per-document context.

    Queries are projected from the packed tokens and split per document; keys
    and values are projected from ``context``, which holds one fixed-length
    block per document. The result is re-packed to the layout of the input.
    """

    def __init__(self, hidden_dim: int, num_heads: int, dropout: float = 0.0):
        """
        Args:
            hidden_dim: Feature size of the inputs and outputs.
            num_heads: Number of attention heads; must divide ``hidden_dim``.
            dropout: Attention dropout probability (used in training mode only).
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.dropout = dropout
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.head_dim = hidden_dim // num_heads

        self.q_proj = Linear(hidden_dim, hidden_dim)
        self.k_proj = Linear(hidden_dim, hidden_dim)
        self.v_proj = Linear(hidden_dim, hidden_dim)
        self.out_proj = Linear(hidden_dim, hidden_dim)

        self.apply(self._init_weight)

    def _init_weight(self, module):
        """Initialise ``Linear`` layers with N(0, 0.02) weights and zero bias."""
        if isinstance(module, Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def prepare_q(self, x: Tensor, seq_lens: Tensor) -> Tensor:
        """Project packed tokens to queries and split them per document.

        Args:
            x: Packed tokens of shape ``(1, total_tokens, hidden_dim)``.
            seq_lens: 1-D tensor of per-document token counts; must sum to
                ``total_tokens``.

        Returns:
            Nested tensor with one ``(heads, seq_i, head_dim)`` entry per document.
        """
        # Flatten the packed axis directly. Factoring it as (docs x seq) would
        # require total_tokens to be a multiple of the document count, which
        # does not hold for ragged documents.
        tokens = self.q_proj(x).squeeze(0).unflatten(
            -1, (self.num_heads, self.head_dim)
        )
        per_doc = tokens.split(seq_lens.tolist(), dim=0)
        nest = as_nested_tensor(list(per_doc))
        return nest.transpose(1, 2).contiguous()  # batch x heads x seq x headdim

    def prepare_kv(self, x: Tensor, is_k: bool) -> Tensor:
        """Project the per-document context to keys or values.

        Args:
            x: Context of shape ``(batch, fixed_seq, hidden_dim)``.
            is_k: Use the key projection when true, the value projection otherwise.

        Returns:
            Nested tensor with one ``(heads, fixed_seq, head_dim)`` entry per row.
        """
        proj = self.k_proj if is_k else self.v_proj
        heads = proj(x).unflatten(-1, (self.num_heads, self.head_dim))
        # Convert to nested tensor to match q structure
        nest = as_nested_tensor(list(heads.unbind(0)))
        return nest.transpose(1, 2).contiguous()  # batch x heads x seq x headdim

    def forward(self, x: Tensor, context: Tensor, doc_ids: Tensor) -> Tensor:
        """Let every packed token attend to the context block of its document.

        Args:
            x: Packed tokens of shape ``(1, total_tokens, hidden_dim)``.
            context: Per-document context, ``(num_docs, fixed_seq, hidden_dim)``.
            doc_ids: Non-negative integer document id per token, shape
                ``(1, total_tokens)``. Lengths come from ``torch.bincount`` and
                tokens are split in order, so tokens must be grouped by
                document with ids ascending from 0.

        Returns:
            Tensor of shape ``(1, total_tokens, hidden_dim)``.
        """
        batch, seq, dim = x.shape
        assert batch == 1, "Batch size must be 1 for packed sequences"
        assert doc_ids.shape == (batch, seq)
        seq_lens = torch.bincount(doc_ids.flatten())
        assert context.shape[0] == len(
            seq_lens
        ), "context must have one row per document"

        q = self.prepare_q(x, seq_lens)
        k = self.prepare_kv(context, is_k=True)
        v = self.prepare_kv(context, is_k=False)

        # The functional SDPA applies dropout whenever dropout_p > 0, so it has
        # to be disabled explicitly outside of training.
        dropout_p = self.dropout if self.training else 0.0
        out = sdpa(query=q, key=k, value=v, dropout_p=dropout_p)

        # out is batch x heads x seq x headdim (nested)
        out = out.transpose(1, 2)  # batch x seq x heads x headdim
        out = torch.cat(out.unbind(0), dim=0)  # total_seq x heads x headdim
        out = out.flatten(-2).unsqueeze(0)  # 1 x total_seq x (heads headdim)

        return self.out_proj(out)


class PillarMan(Module):
    """Columnar attention, an optional middle module, then anti-columnar attention.

    Packed tokens are first summarised per document by ``ColumnarAttention``,
    the summaries are passed through ``middle``, and ``AntiColumnarAttention``
    then lets an MLP-transformed copy of the original tokens attend back to
    those summaries.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        columns: int = 128,
        middle: Module | None = None,
    ):
        """
        Args:
            hidden_dim: Feature size of the inputs and outputs.
            num_heads: Number of attention heads in both attention modules.
            dropout: Attention dropout probability passed to both attention modules.
            columns: Number of learned query columns in ``ColumnarAttention``.
            middle: Module applied to the columnar output; ``Identity`` when omitted.
        """
        super().__init__()
        self.columnar_attn = ColumnarAttention(hidden_dim, num_heads, dropout, columns)
        self.anti_columnar_attn = AntiColumnarAttention(hidden_dim, num_heads, dropout)
        self.middle = middle if middle is not None else Identity()

        self.residual = Sequential(
            Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            Linear(hidden_dim, hidden_dim),
        )

        # Initialise only the layers created here. The attention modules
        # initialise themselves in their own constructors, and applying the
        # init to the whole tree would overwrite the weights of a
        # caller-supplied `middle`.
        self.residual.apply(self._init_weight)

    def _init_weight(self, module):
        """Initialise ``Linear`` layers with N(0, 0.02) weights and zero bias."""
        if isinstance(module, Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, original_x: Tensor, seq_idx: Tensor) -> Tensor:
        """Run the columnar -> middle -> anti-columnar pipeline.

        Args:
            original_x: Packed tokens; the attention modules assert a leading
                batch dimension of 1.
            seq_idx: Per-token document ids with the same leading shape as
                ``original_x`` minus the feature axis.

        Returns:
            The output of ``AntiColumnarAttention`` for the packed tokens.
        """
        # Despite its name, `residual` is not added back: its output is the
        # query-side input of the anti-columnar attention.
        res = self.residual(original_x)
        x = self.columnar_attn(original_x, seq_idx)
        x = self.middle(x)
        x = self.anti_columnar_attn(res, x, seq_idx)
        return x


In [2]:
# Seed every torch generator so weights, inputs and dropout masks are reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# Small dimensions for the demo model.
hidden_size, num_heads, cols = 16, 2, 6

# `Module.to` returns the module itself, so construction and placement can be chained.
pillarman = PillarMan(
    hidden_dim=hidden_size, num_heads=num_heads, dropout=0.1, columns=cols
).to(device)
pillarman

PillarMan(
  (columnar_attn): ColumnarAttention(
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (anti_columnar_attn): AntiColumnarAttention(
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (middle): Identity()
  (residual): Sequential(
    (0): Linear(in_features=16, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=16, bias=True)
  )
)

In [4]:
sequences = [1, 2, 3]  # token count of each packed document

# Pack all documents along the sequence axis of a single batch row. The tensor
# is moved to the target device *before* gradients are enabled so that `xs`
# itself is the leaf tensor; enabling grad first would turn the device copy
# into a non-leaf whose `.grad` is never populated by `backward()`.
xs = torch.cat(
    [torch.randn(1, seq, hidden_size) for seq in sequences], dim=1
).to(device)
xs.requires_grad_(True)

# One document id per token: document `i` contributes `sequences[i]` copies of `i`.
doc_ids = torch.tensor(
    [[doc for doc, length in enumerate(sequences) for _ in range(length)]],
    device=device,
)

In [5]:
# Forward pass over the packed batch.
res = pillarman(original_x=xs, seq_idx=doc_ids)

  return torch._nested_tensor_from_tensor_list(ts, dtype, None, device, None)
  out = sdpa(query=nested_queries, key=k, value=v, dropout_p=self.dropout)
  out = sdpa(query=nested_queries, key=k, value=v, dropout_p=self.dropout)


In [6]:
# Gradient of the summed output with respect to the packed inputs; the graph is
# retained so it can be traversed again.
(grad,) = torch.autograd.grad(outputs=res.sum(), inputs=xs, retain_graph=True)

In [7]:
# Step the inputs against the gradient with a step size of 0.1.
new_xs = xs.sub(grad.mul(0.1))

In [8]:
# Random target with the same shape, dtype and device as the inputs.
some_target = torch.randn_like(input=xs)

In [9]:
# Mean squared error between the target and the inputs.
# NOTE(review): `new_xs` from the earlier cell is not used here — confirm that
# comparing against `xs` is intended.
l = torch.mean(torch.pow(some_target - xs, 2))

In [10]:
# Back-propagate the loss into the leaf tensors it depends on.
torch.autograd.backward(l)