In [1]:
import torch

In [5]:
def window_partition(x, window_size):
    """
    Split a channels-last feature map into non-overlapping square windows.

    Args:
        x: (b, h, w, c) tensor; h and w are expected to be multiples of
            window_size, otherwise the reshape below raises.
        window_size (int): side length of each square window

    Returns:
        windows: (num_windows*b, window_size, window_size, c), ordered by
            batch index, then window row, then window column.
    """
    batch, height, width, channels = x.shape
    n_rows, n_cols = height // window_size, width // window_size
    # Expose the window grid as (b, n_rows, ws, n_cols, ws, c).
    grid = x.view(batch, n_rows, window_size, n_cols, window_size, channels)
    # Move the two grid axes next to each other so each window is contiguous.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, channels)

def window_reverse(windows, window_size, h, w):
    """
    Reassemble windows produced by ``window_partition`` into a feature map.

    Args:
        windows: (num_windows*b, window_size, window_size, c)
        window_size (int): Window size
        h (int): Height of image; must be a multiple of window_size
        w (int): Width of image; must be a multiple of window_size

    Returns:
        x: (b, h, w, c)

    Raises:
        ValueError: if h or w is not a multiple of window_size, or the number
            of windows is not a whole multiple of the windows per image.
    """
    if h % window_size or w % window_size:
        raise ValueError(
            f"h={h} and w={w} must be multiples of window_size={window_size}"
        )
    windows_per_image = (h // window_size) * (w // window_size)
    if windows_per_image == 0 or windows.shape[0] % windows_per_image:
        raise ValueError(
            f"{windows.shape[0]} windows cannot be split into images of "
            f"{windows_per_image} windows each"
        )
    # Exact integer arithmetic: a float round-trip with int() truncation would
    # silently accept an inconsistent window count and let view(-1) infer a
    # wrong channel dimension.
    b = windows.shape[0] // windows_per_image
    x = windows.view(
        b, h // window_size, w // window_size, window_size, window_size, -1
    )
    # (b, rows, cols, ws, ws, c) -> (b, rows, ws, cols, ws, c) -> (b, h, w, c)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
    return x

In [12]:
class test_mask:
    """Notebook harness that builds the shifted-window (SW-MSA) attention
    mask step by step, optionally printing each intermediate result."""

    def __init__(self, window_size=4, shift_size=None) -> None:
        """
        Args:
            window_size (int): side length of each square window (default 4,
                the value this notebook originally hard-coded).
            shift_size (int, optional): width of the trailing shift strip;
                defaults to ``window_size // 2``.
        """
        self.window_size = window_size
        self.shift_size = window_size // 2 if shift_size is None else shift_size

    def calculate_mask(self, x_size, verbose=True):
        """Calculate the attention mask for SW-MSA.

        Every pixel of an (h, w) canvas is labelled with the index of the
        (row band, column band) region it falls in; the canvas is split into
        windows and pixel pairs inside each window are compared by label.

        Args:
            x_size (tuple[int, int]): (h, w) of the feature map. Both are
                expected to be multiples of ``window_size`` because
                ``window_partition`` reshapes by it.
            verbose (bool): print the slices and intermediate tensors.

        Returns:
            torch.Tensor: shape (num_windows, window_size**2, window_size**2);
            entry [n, i, j] is 0.0 when pixels i and j of window n share a
            region label and -100.0 otherwise.
        """
        h, w = x_size
        img_mask = torch.zeros((1, h, w, 1))  # 1 h w 1
        # Three bands per axis: everything before the last window, the part of
        # the last window not covered by the shift, and the trailing shift strip.
        h_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        w_slices = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        if verbose:
            print(f"h_slices : {h_slices}")
            print(f"w_slices : {w_slices}")

        # Distinct loop names so the image size (h, w) is not shadowed.
        cnt = 0
        for h_slice in h_slices:
            for w_slice in w_slices:
                if verbose:
                    print(f"{h_slice}-{w_slice}")
                img_mask[:, h_slice, w_slice, :] = cnt
                cnt += 1
        if verbose:
            print(f'image_mask : {img_mask}')

        mask_windows = window_partition(
            img_mask, self.window_size
        )  # nw, window_size, window_size, 1
        if verbose:
            print(f'mask_windows : {mask_windows}')
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        # Pairwise label difference: non-zero exactly when two pixels of the
        # same window carry different region labels.
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0)
        return attn_mask

In [4]:
# Small random input image, laid out as (batch, channels, height, width).
batch_size, num_channel = 1, 3
input_resolution = (8, 8)
image = torch.rand(batch_size, num_channel, *input_resolution)
image.shape

torch.Size([1, 3, 8, 8])

In [13]:
# Build the SW-MSA mask for the demo resolution; intermediate steps are printed.
mask = test_mask()
mask.calculate_mask(x_size=input_resolution);

h_slices : (slice(0, -4, None), slice(-4, -2, None), slice(-2, None, None))
w_slices : (slice(0, -4, None), slice(-4, -2, None), slice(-2, None, None))
slice(0, -4, None)-slice(0, -4, None)
slice(0, -4, None)-slice(-4, -2, None)
slice(0, -4, None)-slice(-2, None, None)
slice(-4, -2, None)-slice(0, -4, None)
slice(-4, -2, None)-slice(-4, -2, None)
slice(-4, -2, None)-slice(-2, None, None)
slice(-2, None, None)-slice(0, -4, None)
slice(-2, None, None)-slice(-4, -2, None)
slice(-2, None, None)-slice(-2, None, None)
image_mask : tensor([[[[0.],
          [0.],
          [0.],
          [0.],
          [1.],
          [1.],
          [2.],
          [2.]],

         [[0.],
          [0.],
          [0.],
          [0.],
          [1.],
          [1.],
          [2.],
          [2.]],

         [[0.],
          [0.],
          [0.],
          [0.],
          [1.],
          [1.],
          [2.],
          [2.]],

         [[0.],
          [0.],
          [0.],
          [0.],
          [1.],