In [6]:
from torchvision import transforms
from PIL import Image

In [7]:
# Load the stage-II result and convert it to a tensor.
# torchvision's ToTensor turns an 8-bit PIL image into a float (C, H, W) tensor scaled to [0, 1].
path_to_image = "../results/motion/if_stage_II.png"
to_tensor = transforms.ToTensor()
image = Image.open(path_to_image)
tensor_image = to_tensor(image)

In [8]:
import numpy as np
import torch
import torch.nn.functional as F

In [9]:
def motion_blur_factorization(x: torch.Tensor, blur_length: int = 7):
    """
    Factorizes image(s) into motion blurred and residual components using a diagonal motion blur kernel.

    Each channel is blurred independently by averaging ``blur_length`` pixels along the
    main (top-left to bottom-right) diagonal. ``padding="same"`` zero-pads the borders,
    so pixels near the image edges are darkened in the blurred component.

    Args:
        x (torch.Tensor): Floating-point image tensor of shape (C, H, W) or (B, C, H, W).
        blur_length (int): Size of the square diagonal blur kernel (must be >= 1). Default is 7.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (motion_blurred, residual), each of shape
        (B, C, H, W) with the same dtype and device as ``x``. A 3-D input is treated as a
        batch of one, so both outputs are (1, C, H, W). In all cases
        ``motion_blurred + residual`` reconstructs the (batched) input.

    Raises:
        ValueError: If ``x`` is not 3-D or 4-D, or ``blur_length`` is less than 1.
    """
    if blur_length < 1:
        raise ValueError(f"blur_length must be >= 1, got {blur_length}")
    if x.dim() not in (3, 4):
        raise ValueError(
            f"expected a (C, H, W) or (B, C, H, W) tensor, got shape {tuple(x.shape)}"
        )

    # Ensure batch dimension; the outputs keep it.
    if x.dim() == 3:
        x = x.unsqueeze(0)  # Shape: (1, C, H, W)

    channels = x.shape[1]

    # Normalized diagonal kernel, built in x's dtype/device: conv2d requires the input
    # and weight dtypes to match, so a fixed float32 kernel fails on float64/float16 images.
    kernel = torch.eye(blur_length, dtype=x.dtype, device=x.device) / blur_length

    # Depthwise weight of shape (C, 1, kH, kW): with groups=C every channel is
    # convolved only with its own copy of the kernel.
    weight = kernel.expand(channels, 1, blur_length, blur_length)

    motion_blurred = F.conv2d(x, weight, padding="same", groups=channels)
    residual = x - motion_blurred

    return motion_blurred, residual

In [10]:
# Motion-blurred component of the stage-II image; for the 3-D tensor_image the
# factorization adds a batch dimension, so this is (1, C, H, W).
blur_ii = motion_blur_factorization(tensor_image)[0]

# PIL expects an (H, W, C) uint8 array. Clamp to [0, 1] before scaling so any
# out-of-range value saturates instead of overflowing in the uint8 cast.
arr_m = blur_ii[0].detach().cpu().permute(1, 2, 0).clamp(0.0, 1.0).numpy()
arr_m = (arr_m * 255).round().astype(np.uint8)

# Image.fromarray cannot build an image from an (H, W, 1) array, so drop the
# channel axis for single-channel (grayscale) inputs.
if arr_m.shape[-1] == 1:
    arr_m = arr_m[..., 0]

Image.fromarray(arr_m).save("if_stage_II_motion.png")