In [19]:
import torch
import torch.nn as nn

# Video: https://www.youtube.com/watch?v=zVDDITt4XEA&t=181s
#
# Demonstrates how nn.Unfold cuts an image into non-overlapping square patches
# and how to lay them out as one flattened patch per row.

# Demo configuration: a single 1-channel 4x4 image split into 2x2 patches.
BATCH_SIZE, CHANNELS, HEIGHT, WIDTH = 1, 1, 4, 4
PATCH_SIZE = 2  # used as both kernel_size and stride -> patches do not overlap

# Sample input of shape (batch_size, channels, height, width) holding the
# values 0..15 in row-major order, so every patch value is easy to trace back.
x = torch.arange(
    BATCH_SIZE * CHANNELS * HEIGHT * WIDTH, dtype=torch.float32
).view(BATCH_SIZE, CHANNELS, HEIGHT, WIDTH)
print("Input tensor:")
print(x)

# nn.Unfold extracts sliding local blocks: kernel_size is the patch size and
# stride the step between patches. padding=0 adds no border, and dilation is
# left at its default (1), so each patch is a contiguous block of pixels.
unfold = nn.Unfold(kernel_size=PATCH_SIZE, stride=PATCH_SIZE, padding=0)

# The raw Unfold output has shape (batch_size, C * kernel_h * kernel_w, L),
# where L is the number of sliding windows. We permute it to
# (batch_size, L, C * kernel_h * kernel_w) so that each row is one flattened
# patch -- the shape printed below is this permuted one.
patches: torch.Tensor = unfold(x).permute(0, 2, 1)

# Sanity check: with kernel_size == stride and no padding, Unfold is equivalent
# to reshaping the image into a grid of PATCH_SIZE x PATCH_SIZE blocks, with
# patches ordered row by row and each patch flattened as (channel, row, col).
assert HEIGHT % PATCH_SIZE == 0 and WIDTH % PATCH_SIZE == 0
n_patch_rows, n_patch_cols = HEIGHT // PATCH_SIZE, WIDTH // PATCH_SIZE
expected_patches = (
    x.view(BATCH_SIZE, CHANNELS, n_patch_rows, PATCH_SIZE, n_patch_cols, PATCH_SIZE)
    .permute(0, 2, 4, 1, 3, 5)  # -> (B, patch_row, patch_col, C, kh, kw)
    .reshape(BATCH_SIZE, n_patch_rows * n_patch_cols, CHANNELS * PATCH_SIZE * PATCH_SIZE)
)
assert patches.shape == expected_patches.shape
assert torch.equal(patches, expected_patches)

print("\nUnfolded tensor shape:", patches.shape)
print("Unfolded tensor:")
print(patches)
# First two patches of the first image (for this 4x4 demo: the top-left and
# top-right 2x2 blocks).
print(patches[0, 0], patches[0, 1])

Input tensor:
tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]]]])

Unfolded tensor shape: torch.Size([1, 4, 4])
Unfolded tensor:
tensor([[[ 0.,  1.,  4.,  5.],
         [ 2.,  3.,  6.,  7.],
         [ 8.,  9., 12., 13.],
         [10., 11., 14., 15.]]])
tensor([0., 1., 4., 5.]) tensor([2., 3., 6., 7.])
