In [1]:
import torch
#from srm import SpatialRearrangementUnit 
#from srm_kevin import SpatialRearrangementUnit
from srm import SpatialRearrangementUnit
from srm import WindowPartitioningUnit
from srm import SpatialProjectionUnit
from srm import WindowMergingUnit
from srm import SpatialRearrangementRestorationUnit


# Set print formatting for readability
torch.set_printoptions(linewidth=200)

# Test case 1: 4x4 input with window_size=2
dummy_input = torch.arange(1, 17).reshape(1, 1, 4, 4).float()
print("Test Case 1 - 4x4 input, window_size=2")
print("Input:")
print(dummy_input)

srm = SpatialRearrangementUnit(window_size=2)
output = srm(dummy_input)
print("\nOutput:")
print(output)

wp = WindowPartitioningUnit(window_size=2)
output_wp = wp(output)
print("\nOutput Window Partition:")
print(output_wp)
print("\nShape of Window Partition Output:", output_wp.shape)

sp = SpatialProjectionUnit(window_size=2,in_channels=1)
output_sp = sp(output_wp)
print("\nOutput Spatial Projection:")
print(output_sp)
print("\nShape of Spatial Projection Output:", output_wp.shape)

wm = WindowMergingUnit(window_size=2,original_height=4, original_width=4)
output_wm = wm(output_sp)
print("\nOutput Window merging:")
print(output_wm)

restorer = SpatialRearrangementRestorationUnit(window_size=2)  # Only pass window_size here
output_srr = restorer(output_wm) 
print("\nOutput spatial rearrangment restoration:")
print(output_srr)

# Test case 2: 16x16 input with window_size=4
dummy_input_large = torch.arange(1, 257).reshape(1, 1, 16, 16).float()
print("\nTest Case 2 - 16x16 input, window_size=4")
print("Input shape:", dummy_input_large.shape)

srm_large = SpatialRearrangementUnit(window_size=4)
output_large = srm_large(dummy_input_large)
print("Output shape:", output_large.shape)

wp = WindowPartitioningUnit(window_size=4)
output_wp = wp(output_large)
print("\nOutput Window Partition:")
print(output_wp)
print("\nShape of Window Partition Output:", output_wp.shape)

sp = SpatialProjectionUnit(window_size=4,in_channels=1)
output_sp = sp(output_wp)
print("\nOutput Spatial Projection:")
print(output_sp)
print("\nShape of Spatial Projection Output:", output_wp.shape)

wm = WindowMergingUnit(window_size=4,original_height=16, original_width=16)
output_wm = wm(output_sp)
print("\nOutput Window merging:")
print(output_wm)

restorer = SpatialRearrangementRestorationUnit(window_size=2)  # Only pass window_size here
output_srr = restorer(output_wm) 
print("\nOutput spatial rearrangment restoration:")
print(output_srr)


# Print middle 8x8 squares
print("\nMiddle 8x8 of input:")
print(dummy_input_large[0, 0, 4:12, 4:12])
print("\nMiddle 8x8 of output:")
print(output_large[0, 0, 4:12, 4:12])

print("all input")
print(dummy_input_large)
print("all output")
print(output_large)

Test Case 1 - 4x4 input, window_size=2
Input:
tensor([[[[ 1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.],
          [ 9., 10., 11., 12.],
          [13., 14., 15., 16.]]]])

Output:
tensor([[[[ 1.,  3.,  2.,  4.],
          [ 9., 11., 10., 12.],
          [ 5.,  7.,  6.,  8.],
          [13., 15., 14., 16.]]]])

Output Window Partition:
tensor([[[[ 1.,  3.],
          [ 9., 11.]]],


        [[[ 2.,  4.],
          [10., 12.]]],


        [[[ 5.,  7.],
          [13., 15.]]],


        [[[ 6.,  8.],
          [14., 16.]]]])

Shape of Window Partition Output: torch.Size([4, 1, 2, 2])

Output Spatial Projection:
tensor([[[[-1.0264,  1.0168],
          [ 4.6176,  7.4693]]],


        [[[-1.4199,  1.2270],
          [ 4.7987,  7.9604]]],


        [[[-2.6006,  1.8575],
          [ 5.3419,  9.4336]]],


        [[[-2.9942,  2.0677],
          [ 5.5230,  9.9247]]]], grad_fn=<ViewBackward0>)

Shape of Spatial Projection Output: torch.Size([4, 1, 2, 2])

Output Window merging:
tensor([[[[-

In [None]:
import os
import matplotlib.pyplot as plt
import torchvision.transforms as T

# Load Input Blurry Image
load_dir = './Datasets/train/GoPro/input_crops'
name = '000001-1'
img_path = os.path.join(load_dir, name + '.png')

# Load full RGB image (H, W, 3)
input_img = plt.imread(img_path)
print("Loaded shape:", input_img.shape)  # Should be (256, 256, 3)

# Convert to tensor: (1, 3, H, W)
transform = T.Compose([
    T.ToTensor(),  # Converts to [0, 1], shape (C, H, W)
])

input_tensor = transform(input_img).unsqueeze(0)  # → shape: (1, 3, 256, 256)

print("Tensor shape for SRM:", input_tensor.shape)

In [None]:
from srm import SpatialRearrangementUnit

srm = SpatialRearrangementUnit(window_size=4)
output = srm(input_tensor)

print("SRM Output Shape:", output.shape)