In [38]:
import torch
from typing import Optional, Union
import torch

import torch

def create_temporal_mask(
    seq_len: int,
    look_back: int = -1,   # negative => no limit on the visible past
    look_ahead: int = 0,
    device=None,           # torch.device, str, or None
) -> torch.Tensor:
    """
    Construct a banded boolean visibility mask for a sequence of length T.

    Entry ``[0, 0, t, k]`` is True exactly when key position ``k`` satisfies
    ``t - look_back <= k <= t + look_ahead`` for query position ``t``.

    Args:
        seq_len:     number of timesteps T.
        look_back:   count of earlier positions visible from each timestep.
                     Any negative value is replaced by T - 1, i.e. the
                     entire past is visible.
        look_ahead:  count of later positions visible from each timestep.
        device:      torch.device, device string (e.g. 'cuda'), or None;
                     forwarded unchanged to ``torch.arange``.

    Returns:
        Boolean tensor of shape [1, 1, T, T]; the second-to-last axis indexes
        query positions and the last axis indexes key positions.
    """
    # A negative window means "unbounded": T - 1 reaches back to position 0
    # from every timestep.
    past_window = seq_len - 1 if look_back < 0 else look_back

    positions = torch.arange(seq_len, device=device)  # [T]

    # offset[t, k] = k - t: negative for past keys, positive for future keys.
    offset = positions[None, :] - positions[:, None]  # [T, T]

    not_too_old = offset >= -past_window
    not_too_new = offset <= look_ahead
    visible = not_too_old & not_too_new  # bool [T, T]

    # Prepend two singleton axes to reach the documented output shape.
    return visible[None, None]  # [1, 1, T, T]


In [39]:
# NOTE(review): `mask` is not assigned in the cells visible here — presumably
# it was built earlier, e.g. via create_temporal_mask(10, device='cuda');
# confirm this cell still works after Restart Kernel -> Run All.
mask

tensor([[[[ True, False, False, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False, False, False],
          [ True,  True,  True, False, False, False, False, False, False, False],
          [ True,  True,  True,  True, False, False, False, False, False, False],
          [ True,  True,  True,  True,  True, False, False, False, False, False],
          [ True,  True,  True,  True,  True,  True, False, False, False, False],
          [ True,  True,  True,  True,  True,  True,  True, False, False, False],
          [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
          [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
          [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]]],
       device='cuda:0')