In [1]:
import os
import sys

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(project_root)

In [2]:
import numpy as np
import torch

from einops import rearrange, repeat
from torch import nn, einsum

In [3]:
from swin.model import PatchMerging, CyclicShift, WindowAttention

# Functional Blocks

There are some interesting functional approaches to structure resdiual and layer norm.

In [4]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

Now I can compose residual with any other inner block. This is pretty neat.

In [5]:
bottleneck = nn.Sequential(
    nn.Linear(128, 32),
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
    nn.Linear(32, 128),
    nn.ReLU(),
)
res_block = Residual(bottleneck)

x = torch.rand(1, 4, 4, 128)
y = res_block(x)
y.shape

torch.Size([1, 4, 4, 128])

Extend the same functional concept to `LayerNorm`.

In [6]:
class PreNorm(nn.Module):
    def __init__(self, embed_dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

In [7]:
norm_block = PreNorm(128, nn.Sequential(
    nn.Linear(128, 32),
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
    nn.Linear(32, 128),
    nn.ReLU(),
))

x = torch.rand(1, 4, 4, 128)
y = norm_block(x)
y.shape

torch.Size([1, 4, 4, 128])

# Patch Merging

If I have an image `(3, 16, 16)`, and I designate patch size to be `(3, 4, 4)`. Then I should have 16 patches and each patch contains 48 values from `3x4x4`.

The `Unfold` module extracts sliding local blocks from a batched input tensor.

In [8]:
patch_size = 4
out_dim = 32
chan_dim = 3

unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
linear = nn.Linear(chan_dim * patch_size * patch_size, out_dim)

x = torch.rand(1, chan_dim, 16, 16)
print('Image', x.shape)
x = unfold(x)
print('Unfolded', x.shape, 'contains 16 patches with each has 4x4x3 dimensions')
x = x.view(1, -1, patch_size, patch_size)
print('View as image patches', x.shape)
x = x.permute(0, 2, 3, 1)
print('Move patch values to last axis', x.shape)
y = linear(x)
print('Final output', y.shape)

Image torch.Size([1, 3, 16, 16])
Unfolded torch.Size([1, 48, 16]) contains 16 patches with each has 4x4x3 dimensions
View as image patches torch.Size([1, 48, 4, 4])
Move patch values to last axis torch.Size([1, 4, 4, 48])
Final output torch.Size([1, 4, 4, 32])


In [9]:
x = torch.rand(1, 3, 16, 16)
y = PatchMerging(chan_dim, out_dim, downscaling_factor=4)(x)
y.shape

torch.Size([1, 4, 4, 32])

# Window Attention

## Cyclic Shift

Suppose the input has 8 by 8 patches and each patch has embedding dimension 32, let's create 4 windows. When we apply cyclic shift, the displacement will shift the element by `(i, j)` amount. In the following example, the shift pushes every element by `(2, 2)`.

In [10]:
x = torch.rand(1, 8, 8, 32)

# Make it more readable, assign 1...64 to the first element of each embedding.
x[:, :, :, 0] = torch.arange(1, 65).view(8, 8)

window_size = 4
displacement = window_size // 2

print(x[0, :, :, 0])
shifted_x = torch.roll(x, shifts=(displacement, displacement), dims=(1, 2))
print(shifted_x[0, :, :, 0])
shifted_x.shape

tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12., 13., 14., 15., 16.],
        [17., 18., 19., 20., 21., 22., 23., 24.],
        [25., 26., 27., 28., 29., 30., 31., 32.],
        [33., 34., 35., 36., 37., 38., 39., 40.],
        [41., 42., 43., 44., 45., 46., 47., 48.],
        [49., 50., 51., 52., 53., 54., 55., 56.],
        [57., 58., 59., 60., 61., 62., 63., 64.]])
tensor([[55., 56., 49., 50., 51., 52., 53., 54.],
        [63., 64., 57., 58., 59., 60., 61., 62.],
        [ 7.,  8.,  1.,  2.,  3.,  4.,  5.,  6.],
        [15., 16.,  9., 10., 11., 12., 13., 14.],
        [23., 24., 17., 18., 19., 20., 21., 22.],
        [31., 32., 25., 26., 27., 28., 29., 30.],
        [39., 40., 33., 34., 35., 36., 37., 38.],
        [47., 48., 41., 42., 43., 44., 45., 46.]])


torch.Size([1, 8, 8, 32])

In [11]:
x = torch.rand(1, 8, 8, 32)
x[:, :, :, 0] = torch.arange(1, 65).view(8, 8)

print(x[0, :, :, 0])
shifted_x = CyclicShift(displacement)(x)
print(shifted_x[0, :, :, 0])

tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12., 13., 14., 15., 16.],
        [17., 18., 19., 20., 21., 22., 23., 24.],
        [25., 26., 27., 28., 29., 30., 31., 32.],
        [33., 34., 35., 36., 37., 38., 39., 40.],
        [41., 42., 43., 44., 45., 46., 47., 48.],
        [49., 50., 51., 52., 53., 54., 55., 56.],
        [57., 58., 59., 60., 61., 62., 63., 64.]])
tensor([[55., 56., 49., 50., 51., 52., 53., 54.],
        [63., 64., 57., 58., 59., 60., 61., 62.],
        [ 7.,  8.,  1.,  2.,  3.,  4.,  5.,  6.],
        [15., 16.,  9., 10., 11., 12., 13., 14.],
        [23., 24., 17., 18., 19., 20., 21., 22.],
        [31., 32., 25., 26., 27., 28., 29., 30.],
        [39., 40., 33., 34., 35., 36., 37., 38.],
        [47., 48., 41., 42., 43., 44., 45., 46.]])


## Relative Position Embedding

Each window contains `(M, M)` patches. The `M` is the window size. Now each patch needs to learn a relative position embedding. `M**2` is the number of patches in each window.

If we don't use relative position, then the position embedding is a matrix of `(M**2, M**2)`. It's every position to every position.

If we use relative position, then relative position along each axis lies in the range of `[-M + 1, M - 1]`, i.e. if `M = 4`, then we have `[-3, -2, -1, 0, 1, 2, 3]` for each axis.

In [12]:
window_size = 4
indices = np.array([[x, y] for x in range(window_size) for y in range(window_size)])
indices = torch.tensor(indices)
print(indices.shape)
print(indices)
distances = indices[None, :, :] - indices[:, None, :]
print(distances.shape)

torch.Size([16, 2])
tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [1, 0],
        [1, 1],
        [1, 2],
        [1, 3],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 3],
        [3, 0],
        [3, 1],
        [3, 2],
        [3, 3]])
torch.Size([16, 16, 2])


We have 16 to 16 positions, the distances cache the offset between `positions[i]` to `positions[j]`.

In [13]:
distances[1][0]

tensor([ 0, -1])

In [14]:
distances[0][1]

tensor([0, 1])

This will return all the `i` offsets for all 16 positions.

In [15]:
distances[:, :, 0]

tensor([[ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3],
        [ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3],
        [ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3],
        [ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3],
        [-1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2],
        [-1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2],
        [-1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2],
        [-1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2],
        [-2, -2, -2, -2, -1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1],
        [-2, -2, -2, -2, -1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1],
        [-2, -2, -2, -2, -1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1],
        [-2, -2, -2, -2, -1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1],
        [-3, -3, -3, -3, -2, -2, -2, -2, -1, -1, -1, -1,  0,  0,  0,  0],
        [-3, -3, -3, -3, -2, -2, -2, -

This will return all the `j` offsets for all 16 positions.

In [16]:
distances[:, :, 1]

tensor([[ 0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3],
        [-1,  0,  1,  2, -1,  0,  1,  2, -1,  0,  1,  2, -1,  0,  1,  2],
        [-2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1],
        [-3, -2, -1,  0, -3, -2, -1,  0, -3, -2, -1,  0, -3, -2, -1,  0],
        [ 0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3],
        [-1,  0,  1,  2, -1,  0,  1,  2, -1,  0,  1,  2, -1,  0,  1,  2],
        [-2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1],
        [-3, -2, -1,  0, -3, -2, -1,  0, -3, -2, -1,  0, -3, -2, -1,  0],
        [ 0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3],
        [-1,  0,  1,  2, -1,  0,  1,  2, -1,  0,  1,  2, -1,  0,  1,  2],
        [-2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1],
        [-3, -2, -1,  0, -3, -2, -1,  0, -3, -2, -1,  0, -3, -2, -1,  0],
        [ 0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3],
        [-1,  0,  1,  2, -1,  0,  1,  

In [17]:
pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
pos_embedding.shape

torch.Size([7, 7])

In [18]:
pos_embedding[distances[:, :, 0], distances[:, :, 1]].shape

torch.Size([16, 16])

In [19]:
pos_embedding[distances[:, :, 0], distances[:, :, 1]]

tensor([[-0.7837, -0.5602, -1.3962, -0.1333,  1.1046, -0.1810,  0.4121,  0.3423,
         -0.7281,  0.1334, -1.5607, -0.5599, -1.7256,  1.0132,  0.3025,  0.6772],
        [ 1.0940, -0.7837, -0.5602, -1.3962,  0.3335,  1.1046, -0.1810,  0.4121,
         -0.6759, -0.7281,  0.1334, -1.5607, -2.1741, -1.7256,  1.0132,  0.3025],
        [-0.1669,  1.0940, -0.7837, -0.5602, -0.1543,  0.3335,  1.1046, -0.1810,
         -0.0026, -0.6759, -0.7281,  0.1334, -0.1207, -2.1741, -1.7256,  1.0132],
        [ 1.2826, -0.1669,  1.0940, -0.7837,  0.0882, -0.1543,  0.3335,  1.1046,
         -0.2839, -0.0026, -0.6759, -0.7281,  0.0742, -0.1207, -2.1741, -1.7256],
        [-0.1886,  0.0454,  0.5713, -0.0612, -0.7837, -0.5602, -1.3962, -0.1333,
          1.1046, -0.1810,  0.4121,  0.3423, -0.7281,  0.1334, -1.5607, -0.5599],
        [ 1.0942, -0.1886,  0.0454,  0.5713,  1.0940, -0.7837, -0.5602, -1.3962,
          0.3335,  1.1046, -0.1810,  0.4121, -0.6759, -0.7281,  0.1334, -1.5607],
        [-0.3393,  1.0

Since we defined position embedding to be a parameter, these values will be learned and updated.

## Local Window Masking for Attention

The input will be re-arrange into windows. Within each window, we have `(4, 4)` patches with window size 4. The mask is applied after position embedding.

In [20]:
window_size = 4
displacement = window_size // 2

upper_lower_mask = torch.zeros(window_size**2, window_size**2)
print(upper_lower_mask.shape) # Same shape as relative position embedding.

upper_lower_mask[-displacement * window_size:, :-displacement * window_size] = float('-inf')
upper_lower_mask[:-displacement * window_size, -displacement * window_size:] = float('-inf')

upper_lower_mask

torch.Size([16, 16])


tensor([[0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0., 0., 0., 0., 0.],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0., 0., 0., 0., 0.],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0., 0., 0., 0., 0.],
        [-

In [21]:
window_size = 4
displacement = window_size // 2

left_right_mask = torch.zeros(window_size**2, window_size**2)
print(left_right_mask.shape) # Same shape as relative position embedding.

left_right_mask = rearrange(left_right_mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
print(left_right_mask.shape)
left_right_mask[:, -displacement:, :, :-displacement] = float('-inf')
left_right_mask[:, :-displacement, :, -displacement:] = float('-inf')
left_right_mask = rearrange(left_right_mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')
print(left_right_mask.shape)

left_right_mask

torch.Size([16, 16])
torch.Size([4, 4, 4, 4])
torch.Size([16, 16])


tensor([[0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf],
        [0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf],
        [-inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0.],
        [-inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0.],
        [0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf],
        [0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf],
        [-inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0.],
        [-inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0.],
        [0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf],
        [0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf],
        [-inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0.],
        [-

## Window Attention

After the window partition, my input is a batch tensor with patches and embedding dimensions.

In [22]:
B = 4
H = 16 # unit of patches
W = 16 # unit of patches
embed_dim = 64
head_dim = 128
num_heads = 8
window_size = 4 # unit of patches

x = torch.rand(B, H, W, embed_dim)
to_qkv = nn.Linear(embed_dim, 3 * num_heads * head_dim)
to_out = nn.Linear(num_heads * head_dim, embed_dim)

In [23]:
qkv = to_qkv(x).chunk(3, dim=-1)
num_win_h = H // window_size # Number of windows along height axis
num_win_w = W // window_size # Number of windows along width axis

Each query, key, value tensor is of shape `(batch_size, num_heads, num_windows, num_patches, head_dim)`. In the example below, there are

- Each sample has 8 heads
- Each head has 16 windows, because (16,16) patches can be divided into 16 (4, 4) windows.
- Each window has 16 patches
- Each patch has 128 head dimension for computing attention score

In [24]:
q, k, v = map(
    lambda t: rearrange(t, 'b (num_win_h win_h) (num_win_w win_w) (h d) -> b h (num_win_h num_win_w) (win_h win_w) d',
                        h=num_heads,
                        win_h=window_size,
                        win_w=window_size), qkv)

print('Q', q.shape)
print('K', k.shape)
print('V', v.shape)

Q torch.Size([4, 8, 16, 16, 128])
K torch.Size([4, 8, 16, 16, 128])
V torch.Size([4, 8, 16, 16, 128])


The attention score will be computed with dot product of `Q` and `K`.

In [25]:
dots = einsum('b h w i d, b h w j d -> b h w i j', q, k)
dots.shape

torch.Size([4, 8, 16, 16, 16])

Positional embedding and masking will be added to the dot product and then perform softmax. I will skip it here.

In [26]:
attn = dots.softmax(dim=-1)
print('Softmax', attn.shape)

out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
print('Another matrix product', out.shape)

out = rearrange(out, 'b h (num_win_h num_win_w) (win_h win_w) d -> b (num_win_h win_h) (num_win_w win_w) (h d)',
                h=num_heads,
                win_h=window_size,
                win_w=window_size,
                num_win_h=num_win_h,
                num_win_w=num_win_w)
print('Rearranged back to patch format', out.shape)

Softmax torch.Size([4, 8, 16, 16, 16])
Another matrix product torch.Size([4, 8, 16, 16, 128])
Rearranged back to patch format torch.Size([4, 16, 16, 1024])


In [27]:
to_out(out).shape

torch.Size([4, 16, 16, 64])

Since the input has 16 patches, the attention score is 16 to 16 self-attention.

## Shifted Window Attention

Same as above but I will apply a cyclic shift before computing attention score.

In [28]:
B = 4
H = 16 # unit of patches
W = 16 # unit of patches
embed_dim = 64
head_dim = 128
num_heads = 8
window_size = 4 # unit of patches

In [29]:
x = torch.rand(B, H, W, embed_dim)
shift_forward = CyclicShift(-window_size // 2)
shift_backward = CyclicShift(window_size // 2)

to_qkv = nn.Linear(embed_dim, 3 * num_heads * head_dim)
to_out = nn.Linear(num_heads * head_dim, embed_dim)

In [30]:
shifted_x = shift_forward(x)
qkv = to_qkv(shifted_x).chunk(3, dim=-1)
num_win_h = H // window_size # Number of windows along height axis
num_win_w = W // window_size # Number of windows along width axis
q, k, v = map(
    lambda t: rearrange(t, 'b (num_win_h win_h) (num_win_w win_w) (h d) -> b h (num_win_h num_win_w) (win_h win_w) d',
                        h=num_heads,
                        win_h=window_size,
                        win_w=window_size), qkv)
print('Q', q.shape)
print('K', k.shape)
print('V', v.shape)

Q torch.Size([4, 8, 16, 16, 128])
K torch.Size([4, 8, 16, 16, 128])
V torch.Size([4, 8, 16, 16, 128])


The tensor is structured as
- Batch size: 4
- Number of heads: 8
- Number of windows: 16,
- Number of patches per window: 16
- Head dimension: 128

In [31]:
q_dot_k = einsum('b h w i d, b h w j d -> b h w i j', q, k)
print('Dot Product', q_dot_k.shape)
print(torch.round(q_dot_k[0, 0, 0, :, :], decimals=2))

print("Apply mask", torch.round(upper_lower_mask))
q_dot_k[:, :, -num_win_w:] += upper_lower_mask
print("Results", torch.round(q_dot_k[0, 0, -1, :, :], decimals=2))

print("Apply mask", torch.round(left_right_mask))
q_dot_k[:, :, num_win_w - 1::num_win_w] += left_right_mask
print("Results", torch.round(q_dot_k[0, 0, -1, :, :], decimals=2))

softmax_qk = q_dot_k.softmax(dim=-1)
attn = einsum("b h w i j, b h w j d -> b h w i d", softmax_qk, v)
attn = rearrange(
    attn,
    "b h (num_win_h num_win_w) (win_h win_w) d -> b (num_win_h win_h) (num_win_w win_w) (h d)",
    h=num_heads,
    win_h=window_size,
    win_w=window_size,
    num_win_h=num_win_h,
    num_win_w=num_win_w,
)
attn = shift_backward(attn)
print('Final attention', attn.shape)

Dot Product torch.Size([4, 8, 16, 16, 16])
tensor([[ 0.7100,  0.8800,  1.0900,  0.5600,  1.8900,  1.7700,  1.0700,  0.6800,
          1.3700, -0.2300,  0.6400,  1.7000,  1.4900,  0.4600,  0.9400,  0.6600],
        [ 0.6400,  0.5700,  0.4400,  0.2000,  1.3600,  1.4600,  1.1700,  0.7900,
          1.1700,  0.0700, -0.2100,  1.1500,  1.5100,  0.9400,  1.3700,  1.3700],
        [ 1.9200,  0.9300,  1.7700,  0.7800,  2.0600,  2.1400,  2.3500,  1.4700,
          1.5600,  0.9100,  0.6800,  2.2600,  2.0600,  1.6800,  2.0800,  2.1200],
        [ 0.1100,  0.2800, -0.1200, -0.4800,  0.9900,  0.8100,  0.5700, -0.2700,
          0.9100, -0.8200, -0.4000,  0.7100,  1.1400,  0.3500,  0.6800,  0.5900],
        [ 0.4600,  0.7800,  0.2900,  0.1200,  1.6800,  1.4300,  0.4400,  0.6100,
          0.8400, -0.4200,  0.1200,  0.9300,  1.2500, -0.0900,  0.8400,  0.6900],
        [ 1.3000,  1.2900,  1.1500,  0.7600,  1.8600,  1.7900,  1.9400,  1.2700,
          1.4100,  0.4400,  0.3200,  1.9200,  2.0300,  1.5700

The final attention will organize the data back into patches. Each patch has an attention score that is comprised of `(num_of_heads, head_dim)` which is `8*128=1024`. Then it goes through a final linear mapping to reduce dimension.

In [32]:
out = to_out(attn)
out.shape

torch.Size([4, 16, 16, 64])

In [33]:
x = torch.rand(B, H, W, embed_dim)
attn = WindowAttention(embed_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, shifted=True, window_size=window_size, relative_pos_embedding=True)(x)
print(attn.shape)

torch.Size([4, 16, 16, 64])
