In [1]:
import torch
from einops import rearrange

# Example tensor with shape (b=1, c=1, h=2, w=2)
input_tensor = torch.arange(1, (12*8)+1).reshape(1, 1, 12, 8)
patch_height = 3
patch_width = 4

output_tensor = rearrange(input_tensor, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def sliding_chunks(x, chunk_size=32, stride=4):
    """
    x: Tensor of shape (B, T, C)
    Returns: Tensor of shape (B, M, chunk_size, C)
    """
    B, T, C = x.shape

    # Unfold the time dimension (dim=1) using torch.nn.functional.unfold logic
    x = x.unfold(dimension=1, size=chunk_size, step=stride)  # (B, M, chunk_size, C)
    return x

In [5]:

# B=1, T=12, C=1 -> we’ll fill time dimension with ascending numbers
T = 12
x = torch.arange(T).view(1, T, 1).float()

# Apply sliding
chunks = sliding_chunks(x, chunk_size=4, stride=2)

print("Input:")
print(x.squeeze(-1))  # just to see time steps clearly

print("\nSliding chunks (B x M x chunk_size x C):")
print(chunks.squeeze(0).squeeze(-1))  # Remove batch and channel dims for display

Input:
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.]])

Sliding chunks (B x M x chunk_size x C):
tensor([[[ 0.,  1.,  2.,  3.]],

        [[ 2.,  3.,  4.,  5.]],

        [[ 4.,  5.,  6.,  7.]],

        [[ 6.,  7.,  8.,  9.]],

        [[ 8.,  9., 10., 11.]]])


In [4]:
100/4

25.0

What einops rearange does is it changes the input so to be of shape num patches.    
Each patch is of shape patch width * patch height. Patches are arranged so that
patches nearby in time are next to each other.      
So every block of T/patch_height patches
should be able to attend to each other. 

In [31]:
def make_blockwise_mask(num_patches, patch_width, num_features, device='cpu'):
    """
    Creates a (num_patches, num_patches) boolean mask where each non-overlapping group of N patches
    can attend only to each other.

    Args:
        num_patches (int): total number of tokens (patches)
        N (int): number of patches per block (must divide num_patches evenly)
        device (str or torch.device): where to create the mask

    Returns:
        mask (torch.Tensor): boolean tensor of shape (num_patches, num_patches)
                             with True where attention is allowed
    """
    
    N = num_features / patch_width
    assert num_patches % N == 0, "num_patches must be divisible by N"
    
    block_id = torch.arange(num_patches, device=device) // N  # Shape: (num_patches,)
    mask = block_id.unsqueeze(0) == block_id.unsqueeze(1)     # Shape: (num_patches, num_patches)
    
    return mask  # dtype: bool

mask = make_blockwise_mask(num_patches=9, num_features=6, patch_width=2)

print(mask.int())

tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1]], dtype=torch.int32)
