In [38]:
import torch
from einops import rearrange,einsum
from jaxtyping import Int

In [39]:
# Number 25 consecutive positions, then fold the flat range into 5 batches of 5 positions each.
positions = rearrange(
    torch.arange(0, 25),
    "(batch seq) -> batch seq",
    batch=5,
)
positions

tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24]])

In [40]:
# Column index of every (row, col) pair: tile each sequence 5 times end-to-end along the
# last axis, leaving the batch axis untouched (0,1,2,3,4,0,1,2,3,4,...).
cols = torch.tile(positions, (1, 5))
cols

tensor([[ 0,  1,  2,  3,  4,  0,  1,  2,  3,  4,  0,  1,  2,  3,  4,  0,  1,  2,
          3,  4,  0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9,  5,  6,  7,  8,  9,  5,  6,  7,  8,  9,  5,  6,  7,
          8,  9,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14, 10, 11, 12, 13, 14, 10, 11, 12, 13, 14, 10, 11, 12,
         13, 14, 10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19, 15, 16, 17, 18, 19, 15, 16, 17, 18, 19, 15, 16, 17,
         18, 19, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 20, 21, 22, 23, 24, 20, 21, 22, 23, 24, 20, 21, 22,
         23, 24, 20, 21, 22, 23, 24]])

In [41]:
# Row index of every (row, col) pair: give each position a trailing axis of 5 copies, then
# flatten so each value appears 5 times back-to-back (0,0,0,0,0,1,1,1,1,1,...).
rows = rearrange(
    positions.unsqueeze(-1).expand(-1, -1, 5),
    "... seq1 seq2 -> ... (seq1 seq2)",
    seq2=5,
)
rows

tensor([[ 0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  3,  3,  3,
          3,  3,  4,  4,  4,  4,  4],
        [ 5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7,  8,  8,  8,
          8,  8,  9,  9,  9,  9,  9],
        [10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 13, 13, 13,
         13, 13, 14, 14, 14, 14, 14],
        [15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 18, 18, 18,
         18, 18, 19, 19, 19, 19, 19],
        [20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 23, 23, 23,
         23, 23, 24, 24, 24, 24, 24]])

In [42]:
# Pair the row and column indices on a new trailing "coords" axis: (batch, seq*seq, 2).
combined = torch.stack((rows, cols), dim=-1)
combined.shape

torch.Size([5, 25, 2])

In [43]:
# Unfold the flattened pair axis back into a (seq1, seq2) grid of (row, col) coordinates.
mask = rearrange(
    combined,
    "... (seq1 seq2) coords -> ... seq1 seq2 coords",
    seq2=5,
)
# Keep a pair when its column value does not exceed its row value (lower triangle incl. diagonal).
mask = torch.le(mask[..., 1], mask[..., 0])
mask

tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]],

        [[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]],

        [[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]],

        [[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]],

        [[ True, False, False, False, False],
         [ True,  True, Fa

In [None]:
def batch_cartesian_prod(
    a: Int[torch.Tensor, "... seq_a"], b: Int[torch.Tensor, "... seq_b"]
) -> torch.Tensor:
    """Compare every element of ``a`` with every element of ``b`` along the last axis.

    Despite the name, the Cartesian product of the two sequences is never
    returned; only the result of the ``>=`` comparison over all pairs is.
    When ``a`` and ``b`` are the same tensor of increasing position indices,
    the result is a lower-triangular mask (diagonal included) per batch entry.

    Args:
        a: Tensor of shape ``(..., seq_a)``; its values index the rows of the result.
        b: Tensor of shape ``(..., seq_b)``; its values index the columns of the
            result. Leading dimensions must be broadcastable with those of ``a``.

    Returns:
        Boolean tensor of shape ``(..., seq_a, seq_b)`` where
        ``out[..., i, j] == (a[..., i] >= b[..., j])``.
    """
    # Broadcasting (..., seq_a, 1) against (..., 1, seq_b) evaluates all pairs
    # directly, without first tiling both inputs to seq*seq elements and
    # stacking them. It also allows different sequence lengths and empty
    # sequences.
    return a.unsqueeze(-1) >= b.unsqueeze(-2)

In [45]:
# Comparing the position grid against itself should reproduce the mask built step by step above.
batch_cartesian_prod(a=positions, b=positions)

tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]],

        [[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]],

        [[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]],

        [[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]],

        [[ True, False, False, False, False],
         [ True,  True, Fa