In [26]:
import torch
from einops import rearrange

# Toy input holding the values 1..72 with shape (b=1, c=1, H=12, W=6).
patch_height, patch_width = 3, 2
input_tensor = torch.arange(1, 73).reshape(1, 1, 12, 6)

# Cut H into h blocks of p1 rows and W into w blocks of p2 columns, then flatten
# every (p1, p2, c) patch into one row: (b, c, H, W) -> (b, h*w, p1*p2*c).
output_tensor = rearrange(
    input_tensor,
    'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
    p1=patch_height,
    p2=patch_width,
)


In [27]:
# Show the raw 12x6 grid (values 1..72 in row-major order) before patching.
input_tensor

tensor([[[[ 1,  2,  3,  4,  5,  6],
          [ 7,  8,  9, 10, 11, 12],
          [13, 14, 15, 16, 17, 18],
          [19, 20, 21, 22, 23, 24],
          [25, 26, 27, 28, 29, 30],
          [31, 32, 33, 34, 35, 36],
          [37, 38, 39, 40, 41, 42],
          [43, 44, 45, 46, 47, 48],
          [49, 50, 51, 52, 53, 54],
          [55, 56, 57, 58, 59, 60],
          [61, 62, 63, 64, 65, 66],
          [67, 68, 69, 70, 71, 72]]]])

In [28]:
# Show the patched result: one row per patch, each row being a flattened
# patch_height x patch_width (3x2) window; the column-block index varies fastest.
output_tensor

tensor([[[ 1,  2,  7,  8, 13, 14],
         [ 3,  4,  9, 10, 15, 16],
         [ 5,  6, 11, 12, 17, 18],
         [19, 20, 25, 26, 31, 32],
         [21, 22, 27, 28, 33, 34],
         [23, 24, 29, 30, 35, 36],
         [37, 38, 43, 44, 49, 50],
         [39, 40, 45, 46, 51, 52],
         [41, 42, 47, 48, 53, 54],
         [55, 56, 61, 62, 67, 68],
         [57, 58, 63, 64, 69, 70],
         [59, 60, 65, 66, 71, 72]]])

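What einops `rearrange` does here is reshape the input into a sequence of patches with shape `(batch, num_patches, patch_size)`.
Each patch is flattened to `patch_height * patch_width * c` values. Taking rows as time steps (T) and columns as features,
patches that cover the same time window end up next to each other in the sequence.
The sequence therefore consists of `T / patch_height` consecutive blocks, each holding `num_features / patch_width` patches
(here 12 / 3 = 4 blocks of 6 / 2 = 3 patches), and the patches within each block should be able to attend to each other.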

In [29]:
def make_blockwise_mask(num_patches, patch_width=None, num_features=None, device='cpu', *, N=None):
    """
    Creates a (num_patches, num_patches) boolean mask where each non-overlapping group of N
    consecutive patches can attend only to each other.

    The group size is either passed directly as ``N`` or derived from the patch layout as
    ``num_features // patch_width`` (the number of patches that one row of patches spans).

    Args:
        num_patches (int): total number of tokens (patches)
        patch_width (int, optional): width of one patch along the feature axis;
            used together with ``num_features`` when ``N`` is not given
        num_features (int, optional): size of the feature axis before patching;
            must be divisible by ``patch_width``
        device (str or torch.device): where to create the mask
        N (int, optional, keyword-only): number of patches per block
            (must divide num_patches evenly)

    Returns:
        mask (torch.Tensor): boolean tensor of shape (num_patches, num_patches)
                             with True where attention is allowed

    Raises:
        ValueError: if the block size cannot be determined, is not positive, or the
            explicit ``N`` disagrees with ``num_features // patch_width``
        AssertionError: if num_patches is not divisible by the block size
    """
    derived = None
    if patch_width is not None and num_features is not None:
        if patch_width <= 0 or num_features % patch_width != 0:
            raise ValueError("num_features must be divisible by a positive patch_width")
        derived = num_features // patch_width

    if N is None:
        if derived is None:
            raise ValueError("provide either N or both patch_width and num_features")
        N = derived
    elif derived is not None and derived != N:
        # Refuse contradictory inputs instead of silently preferring one of them.
        raise ValueError("N does not match num_features // patch_width")

    if N <= 0:
        # Guard the floor division / modulo below against a zero or negative block size.
        raise ValueError("N must be a positive integer")
    assert num_patches % N == 0, "num_patches must be divisible by N"

    # Patches i and j belong to the same block exactly when i // N == j // N.
    block_id = torch.arange(num_patches, device=device) // N  # Shape: (num_patches,)
    mask = block_id.unsqueeze(0) == block_id.unsqueeze(1)     # Shape: (num_patches, num_patches)

    return mask  # dtype: bool

# Demo: 9 patches in groups of 3.
mask = make_blockwise_mask(num_patches=9, N=3)

# Cast the boolean mask to int32 so it prints as 0/1 instead of False/True.
print(mask.to(torch.int32))

tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1]], dtype=torch.int32)
