In [25]:
import numpy as np
import torch

from einops import rearrange, repeat
from torch import nn, einsum

# Swin Transformer

## Patch Merging

If I have an image of shape `(3, 16, 16)` and I designate the patch size to be `(3, 4, 4)`, then I should have 16 patches, and each patch contains `3*4*4=48` values.

In [2]:
patch_size = 4
image_height, image_width = 16, 16

# Number of patches along each spatial axis. This is generally NOT equal to
# patch_size; the two only coincide here because 16 // 4 == 4.
grid_height = image_height // patch_size
grid_width = image_width // patch_size

unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
linear = nn.Linear(3 * patch_size ** 2, 32)

x = torch.rand(1, 3, image_height, image_width)
print('Image', x.shape)
# Unfold returns (batch, channels * patch_size**2, number_of_patches).
x = unfold(x)
print('Unfolded', x.shape)
# Restore the 2-D layout of the patch grid from the flattened patch axis.
x = x.view(1, -1, grid_height, grid_width)
print('View as image patches', x.shape)
# Put the per-patch values last so nn.Linear projects each patch independently.
x = x.permute(0, 2, 3, 1)
print('Move patch values to last axis', x.shape)
y = linear(x)
print('Final output', y.shape)

Image torch.Size([1, 3, 16, 16])
Unfolded torch.Size([1, 48, 16])
View as image patches torch.Size([1, 48, 4, 4])
Move patch values to last axis torch.Size([1, 4, 4, 48])
Final output torch.Size([1, 4, 4, 32])


## Functional Blocks

There are some interesting functional approaches to structuring residual connections and layer norm.

In [3]:
class Residual(nn.Module):
    """Wrap a callable so that its input is added back onto its output.

    Args:
        fn: The inner callable. It receives the input tensor plus any keyword
            arguments given to ``forward`` and must return something that can
            be added to the input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``."""
        branch_output = self.fn(x, **kwargs)
        return branch_output + x

Now I can compose residual with any other inner block. This is pretty neat.

In [4]:
# Squeeze 128 -> 32, mix at width 32, then expand back to 128 so that the
# output can be added to the 128-wide input by the residual wrapper.
squeeze = nn.Linear(128, 32)
mix = nn.Linear(32, 32)
expand = nn.Linear(32, 128)
bottleneck = nn.Sequential(squeeze, nn.ReLU(), mix, nn.ReLU(), expand, nn.ReLU())
res_block = Residual(bottleneck)

x = torch.rand(1, 4, 4, 128)
y = res_block(x)
y.shape

torch.Size([1, 4, 4, 128])

Extend the same functional concept to `LayerNorm`.

In [5]:
class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm`` to the input before handing it to a callable.

    Args:
        embed_dim: Size passed to ``nn.LayerNorm`` (the normalized shape).
        fn: The inner callable. It receives the normalized tensor plus any
            keyword arguments given to ``forward``.
    """

    def __init__(self, embed_dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

In [6]:
# Same 128 -> 32 -> 32 -> 128 bottleneck, this time fed with layer-normalized input.
bottleneck_layers = [
    nn.Linear(128, 32),
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
    nn.Linear(32, 128),
    nn.ReLU(),
]
bottleneck = nn.Sequential(*bottleneck_layers)
norm_block = PreNorm(128, bottleneck)

x = torch.rand(1, 4, 4, 128)
y = norm_block(x)
y.shape

torch.Size([1, 4, 4, 128])

## Relative Position Embedding

Each window contains `(M, M)` patches. The `M` is the window size. Now each patch needs to learn a relative position embedding. `M**2` is the number of patches in each window.

If we don't use relative position, then the position embedding is a matrix of `(M**2, M**2)`. It's every position to every position.

If we use relative position, then relative position along each axis lies in the range of `[-M + 1, M - 1]`, i.e. if `M = 4`, then we have `[-3, -2, -1, 0, 1, 2, 3]` for each axis.

In [7]:
window_size = 4
# Row-major list of every (row, col) coordinate inside one window.
coords = [[row, col] for row in range(window_size) for col in range(window_size)]
indices = torch.tensor(np.array(coords))
print(indices.shape)
print(indices)
# Broadcasting a (1, N, 2) view against an (N, 1, 2) view:
# entry [i, j] holds indices[j] - indices[i].
distances = indices[None] - indices[:, None]
print(distances.shape)

torch.Size([16, 2])
tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [1, 0],
        [1, 1],
        [1, 2],
        [1, 3],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 3],
        [3, 0],
        [3, 1],
        [3, 2],
        [3, 3]])
torch.Size([16, 16, 2])


We have 16-to-16 position pairs; `distances[i][j]` caches the offset from `positions[i]` to `positions[j]`.

In [8]:
# Offset going from flattened position 1 to flattened position 0.
distances[1, 0]

tensor([ 0, -1])

In [9]:
# Offset going from flattened position 0 to flattened position 1.
distances[0, 1]

tensor([0, 1])

This will return all the `i` offsets for all 16 positions.

In [10]:
# First coordinate of every pairwise offset.
distances[..., 0]

tensor([[ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3],
        [ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3],
        [ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3],
        [ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3],
        [-1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2],
        [-1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2],
        [-1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2],
        [-1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2],
        [-2, -2, -2, -2, -1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1],
        [-2, -2, -2, -2, -1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1],
        [-2, -2, -2, -2, -1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1],
        [-2, -2, -2, -2, -1, -1, -1, -1,  0,  0,  0,  0,  1,  1,  1,  1],
        [-3, -3, -3, -3, -2, -2, -2, -2, -1, -1, -1, -1,  0,  0,  0,  0],
        [-3, -3, -3, -3, -2, -2, -2, -

This will return all the `j` offsets for all 16 positions.

In [11]:
# Second coordinate of every pairwise offset.
distances[..., 1]

tensor([[ 0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3],
        [-1,  0,  1,  2, -1,  0,  1,  2, -1,  0,  1,  2, -1,  0,  1,  2],
        [-2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1],
        [-3, -2, -1,  0, -3, -2, -1,  0, -3, -2, -1,  0, -3, -2, -1,  0],
        [ 0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3],
        [-1,  0,  1,  2, -1,  0,  1,  2, -1,  0,  1,  2, -1,  0,  1,  2],
        [-2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1],
        [-3, -2, -1,  0, -3, -2, -1,  0, -3, -2, -1,  0, -3, -2, -1,  0],
        [ 0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3],
        [-1,  0,  1,  2, -1,  0,  1,  2, -1,  0,  1,  2, -1,  0,  1,  2],
        [-2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1, -2, -1,  0,  1],
        [-3, -2, -1,  0, -3, -2, -1,  0, -3, -2, -1,  0, -3, -2, -1,  0],
        [ 0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3,  0,  1,  2,  3],
        [-1,  0,  1,  2, -1,  0,  1,  

In [12]:
# One learnable scalar per (first-axis offset, second-axis offset) pair;
# each axis has 2 * window_size - 1 possible offsets.
num_offsets = 2 * window_size - 1
pos_embedding = nn.Parameter(torch.randn(num_offsets, num_offsets))
pos_embedding.shape

torch.Size([7, 7])

In [13]:
# The offsets include negative values, which index from the end of each table
# axis. With 2 * window_size - 1 entries per axis, every offset in
# [-window_size + 1, window_size - 1] still selects a distinct entry.
row_offsets = distances[..., 0]
col_offsets = distances[..., 1]
pos_embedding[row_offsets, col_offsets].shape

torch.Size([16, 16])

In [14]:
# Gather one table entry per (query position, key position) pair; negative
# offsets wrap around to the end of the table axis.
bias_rows = distances[..., 0]
bias_cols = distances[..., 1]
pos_embedding[bias_rows, bias_cols]

tensor([[ 0.8387,  1.0895,  0.3368,  0.2537, -1.8590,  1.5533, -1.0229,  0.9228,
         -1.3652, -3.2351,  0.1753,  0.3024, -0.6617, -0.2192,  0.3134,  1.2471],
        [ 0.0322,  0.8387,  1.0895,  0.3368,  0.3960, -1.8590,  1.5533, -1.0229,
          0.0525, -1.3652, -3.2351,  0.1753,  1.7376, -0.6617, -0.2192,  0.3134],
        [-0.1150,  0.0322,  0.8387,  1.0895,  1.0195,  0.3960, -1.8590,  1.5533,
          0.8233,  0.0525, -1.3652, -3.2351,  0.6204,  1.7376, -0.6617, -0.2192],
        [-0.0605, -0.1150,  0.0322,  0.8387, -0.8598,  1.0195,  0.3960, -1.8590,
          1.9468,  0.8233,  0.0525, -1.3652, -0.3867,  0.6204,  1.7376, -0.6617],
        [ 1.1581, -0.2017,  1.1784,  1.4307,  0.8387,  1.0895,  0.3368,  0.2537,
         -1.8590,  1.5533, -1.0229,  0.9228, -1.3652, -3.2351,  0.1753,  0.3024],
        [-1.0957,  1.1581, -0.2017,  1.1784,  0.0322,  0.8387,  1.0895,  0.3368,
          0.3960, -1.8590,  1.5533, -1.0229,  0.0525, -1.3652, -3.2351,  0.1753],
        [ 1.0194, -1.0

Since we defined position embedding to be a parameter, these values will be learned and updated.

## Cyclic Shift

Suppose the input has 8 by 8 patches and each patch has embedding dimension 32. Let's create 4 windows. When we apply a cyclic shift, the displacement shifts every element by `(i, j)`. In the following example, the shift pushes every element by `(2, 2)`.

In [31]:
x = torch.rand(1, 8, 8, 32)

# Make it more readable, assign 1...64 to the first element of each embedding.
x[..., 0] = torch.arange(1, 65).view(8, 8)

window_size = 4
displacement = window_size // 2
print(x[0, ..., 0])
# Roll the two patch axes (dims 1 and 2) together; elements pushed past the
# end of an axis re-enter at its start.
rolled_x = x.roll(shifts=(displacement, displacement), dims=(1, 2))
print(rolled_x[0, ..., 0])

rolled_x.shape

tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12., 13., 14., 15., 16.],
        [17., 18., 19., 20., 21., 22., 23., 24.],
        [25., 26., 27., 28., 29., 30., 31., 32.],
        [33., 34., 35., 36., 37., 38., 39., 40.],
        [41., 42., 43., 44., 45., 46., 47., 48.],
        [49., 50., 51., 52., 53., 54., 55., 56.],
        [57., 58., 59., 60., 61., 62., 63., 64.]])
tensor([[55., 56., 49., 50., 51., 52., 53., 54.],
        [63., 64., 57., 58., 59., 60., 61., 62.],
        [ 7.,  8.,  1.,  2.,  3.,  4.,  5.,  6.],
        [15., 16.,  9., 10., 11., 12., 13., 14.],
        [23., 24., 17., 18., 19., 20., 21., 22.],
        [31., 32., 25., 26., 27., 28., 29., 30.],
        [39., 40., 33., 34., 35., 36., 37., 38.],
        [47., 48., 41., 42., 43., 44., 45., 46.]])


torch.Size([1, 8, 8, 32])

## Masking for Attention

The input will be rearranged into windows. With a window size of 4, each window holds `(4, 4)` patches. The mask is applied after the position embedding.

In [23]:
window_size = 4
displacement = window_size // 2

num_tokens = window_size ** 2
upper_lower_mask = torch.zeros(num_tokens, num_tokens)
print(upper_lower_mask.shape) # Same shape as relative position embedding.

# In the row-major flattening, the last `displacement` rows of the window are
# the final displacement * window_size positions.
split = displacement * window_size
# Block attention between those last positions and all earlier ones, in both directions.
upper_lower_mask[-split:, :-split] = float('-inf')
upper_lower_mask[:-split, -split:] = float('-inf')

upper_lower_mask

torch.Size([16, 16])


tensor([[0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0., 0., 0., 0., 0.],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0., 0., 0., 0., 0.],
        [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0., 0., 0., 0., 0.],
        [-

In [32]:
window_size = 4
displacement = window_size // 2

num_tokens = window_size ** 2
left_right_mask = torch.zeros(num_tokens, num_tokens)
print(left_right_mask.shape) # Same shape as relative position embedding.

# Split both flattened axes into (row, col) so the column axes can be sliced directly.
left_right_mask = rearrange(left_right_mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
print(left_right_mask.shape)
last_cols = slice(-displacement, None)
first_cols = slice(None, -displacement)
# Block attention between the last `displacement` columns and the other columns, in both directions.
left_right_mask[:, last_cols, :, first_cols] = float('-inf')
left_right_mask[:, first_cols, :, last_cols] = float('-inf')
# Flatten back to (positions, positions).
left_right_mask = rearrange(left_right_mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')
print(left_right_mask.shape)

left_right_mask

torch.Size([16, 16])
torch.Size([4, 4, 4, 4])
torch.Size([16, 16])


tensor([[0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf],
        [0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf],
        [-inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0.],
        [-inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0.],
        [0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf],
        [0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf],
        [-inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0.],
        [-inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0.],
        [0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf],
        [0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf],
        [-inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0., -inf, -inf, 0., 0.],
        [-

## Window Attention

After the window partition, my input is a batch tensor with patches and embedding dimensions.

In [61]:
# Problem sizes. H, W and window_size are counted in patches.
B, H, W = 4, 16, 16
embed_dim, head_dim, num_heads = 64, 128, 8
window_size = 4

# Width of all attention heads concatenated together.
inner_dim = num_heads * head_dim

x = torch.rand(B, H, W, embed_dim)
to_qkv = nn.Linear(embed_dim, 3 * inner_dim)
to_out = nn.Linear(inner_dim, embed_dim)

# A single projection yields Q, K and V side by side; split it into three equal parts.
qkv = torch.chunk(to_qkv(x), 3, dim=-1)
# Number of windows along the height and width axes.
num_win_h, num_win_w = H // window_size, W // window_size

Each query, key, value tensor is of shape `(batch_size, num_heads, num_windows, num_patches, head_dim)`

In [57]:
def _to_windows(t):
    """Split the last axis into heads and regroup the patch grid into windows."""
    return rearrange(
        t,
        'b (num_win_h win_h) (num_win_w win_w) (h d) -> b h (num_win_h num_win_w) (win_h win_w) d',
        h=num_heads,
        win_h=window_size,
        win_w=window_size,
    )


q, k, v = (_to_windows(t) for t in qkv)

for label, tensor in (('Q', q), ('K', k), ('V', v)):
    print(label, tensor.shape)

Q torch.Size([4, 8, 16, 16, 128])
K torch.Size([4, 8, 16, 16, 128])
V torch.Size([4, 8, 16, 16, 128])


The attention score will be computed with dot product of `Q` and `K`.

In [58]:
# Scaled dot-product attention: multiply the raw Q·K scores by 1/sqrt(head_dim)
# so their magnitude does not grow with the head dimension and saturate the
# softmax. The head dimension is the last axis of q.
scale = q.shape[-1] ** -0.5
# For every batch, head and window: similarity of query patch i with key patch j.
dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * scale
dots.shape

torch.Size([4, 8, 16, 16, 16])

The positional embedding and the mask would be added to the dot product before the softmax is performed. I will skip them here.

In [59]:
# Normalize the scores over the key axis.
attn = torch.softmax(dots, dim=-1)
print('Softmax', attn.shape)

# Weighted sum of the values for every query patch.
out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
print('Another matrix product', out.shape)

# Axis sizes needed to undo the window partition and merge the heads again.
merge_sizes = dict(
    h=num_heads,
    win_h=window_size,
    win_w=window_size,
    num_win_h=num_win_h,
    num_win_w=num_win_w,
)
out = rearrange(
    out,
    'b h (num_win_h num_win_w) (win_h win_w) d -> b (num_win_h win_h) (num_win_w win_w) (h d)',
    **merge_sizes,
)
print('Rearranged back to patch format', out.shape)

Softmax torch.Size([4, 8, 16, 16, 16])
Another matrix product torch.Size([4, 8, 16, 16, 128])
Rearranged back to patch format torch.Size([4, 16, 16, 1024])


In [63]:
# Project the concatenated heads back to the embedding dimension.
projected = to_out(out)
projected.shape

torch.Size([4, 16, 16, 64])

Since each window has 16 patches (`window_size**2`), the attention score within each window is a 16-to-16 self-attention.