In [1]:
import torch
from typing import Union, Tuple

def sliding_window(
    data: torch.Tensor,
    window_size: Union[int, Tuple[int, ...]],
    stride: Union[int, Tuple[int, ...]] = 1,
    padding: str = 'valid',
) -> torch.Tensor:
    """
    Split a tensor into smaller tensors using a sliding window approach.

    Args:
        data (torch.Tensor): The input tensor of any dimension.
        window_size (int or tuple): The size of the sliding window. If an integer, the same size is used for all
                                     dimensions. If a tuple, it must hold one window size per dimension of ``data``.
        stride (int or tuple, optional): The stride of the sliding window. If an integer, the same stride is used for
                                          all dimensions. If a tuple, it must hold one stride per dimension of
                                          ``data``. Default is 1.
        padding (str, optional): The padding mode. 'valid' applies no padding, so only windows that fit entirely
                                 inside ``data`` are produced. 'same' zero-pads ``data`` (split as evenly as possible
                                 between both ends, the extra element going to the end) so that every dimension
                                 yields ``ceil(size / stride)`` windows. Default is 'valid'.

    Returns:
        torch.Tensor: A tensor of shape ``(n_windows_0, ..., n_windows_k, window_size_0, ..., window_size_k)``.
            The result is a strided view: with 'valid' padding (or when 'same' needs no padding) it shares memory
            with ``data``, and overlapping windows alias the same elements, so avoid in-place writes to it.

    Raises:
        ValueError: If ``window_size`` or ``stride`` does not have one entry per dimension, if any entry is not
            positive, if ``padding`` is not 'valid' or 'same', or if a window is larger than the (padded) input.
    """
    ndim = data.ndim

    # Normalise scalar arguments to one entry per dimension.
    window_size = (window_size,) * ndim if isinstance(window_size, int) else tuple(window_size)
    stride = (stride,) * ndim if isinstance(stride, int) else tuple(stride)

    if len(window_size) != ndim or len(stride) != ndim:
        raise ValueError(
            f"window_size and stride must each have {ndim} entries (one per dimension), "
            f"got {len(window_size)} and {len(stride)}"
        )
    if any(win < 1 for win in window_size) or any(step < 1 for step in stride):
        raise ValueError("window_size and stride entries must be positive integers")
    if padding not in ('valid', 'same'):
        raise ValueError(f"padding must be 'valid' or 'same', got {padding!r}")

    if padding == 'same':
        # Padding has to be applied to the input *before* windowing.
        # torch.nn.functional.pad expects a flat sequence ordered from the last
        # dimension to the first: (last_before, last_after, ..., first_before, first_after).
        pad_args = []
        for size, win, step in zip(reversed(data.shape), reversed(window_size), reversed(stride)):
            n_windows = -(-size // step)  # ceil(size / step)
            total = max((n_windows - 1) * step + win - size, 0)
            before = total // 2
            pad_args.extend((before, total - before))
        if any(pad_args):
            data = torch.nn.functional.pad(data, pad_args, mode='constant')

    # Number of window positions along each dimension.
    counts = []
    for axis, (size, win, step) in enumerate(zip(data.shape, window_size, stride)):
        if win > size:
            raise ValueError(
                f"window size {win} exceeds input size {size} along dimension {axis}"
            )
        counts.append((size - win) // step + 1)

    # The first ndim axes step between windows (stride * element stride); the
    # last ndim axes walk inside a window using the tensor's own element strides.
    element_strides = tuple(data.stride())
    window_strides = tuple(step * es for step, es in zip(stride, element_strides))
    return data.as_strided(size=tuple(counts) + window_size, stride=window_strides + element_strides)

# Example usage: slide a 2x3x4 window over a random 5x6x7 integer tensor.
data = torch.randint(low=0, high=10, size=(5, 6, 7))
window_size = (2, 3, 4)  # window extent along each of the three dimensions
stride = (1, 2, 2)  # step between consecutive windows along each dimension
subsequences = sliding_window(
    data,
    window_size=window_size,
    stride=stride,
    padding='valid',
)

# Report the shapes before and after windowing.
print(f"Input tensor shape: {data.shape}")
print(f"Output tensor shape: {subsequences.shape}")


Input tensor shape: torch.Size([5, 6, 7])
Output tensor shape: torch.Size([4, 2, 2, 2, 3, 4])


: 