### Swin Transformer: Hierarchical Vision Transformer using Shifted Windows ([paper, arXiv:2103.14030](https://arxiv.org/pdf/2103.14030.pdf))

In [12]:
import torch
import timm

# Swin-B from timm: patch size 4, window size 7, 224x224 input.
model = timm.create_model("swin_base_patch4_window7_224")
# numel() counts the elements of each parameter tensor; the f-string ',' adds thousands separators.
print(f"the number of model parameters: {sum(p.numel() for p in model.parameters()):,}")
# This configuration has no absolute position embedding (prints None below).
print(model.absolute_pos_embed)

the number of model parameters: 87,768,224
None


In [4]:
model = timm.create_model("swin_base_patch4_window7_224")

# Dummy batch: two 3-channel 224x224 images.
inputs = torch.randn((2, 3, 224, 224))

# Patch embedding: 224 / 4 = 56 patches per side -> 56 * 56 = 3136 tokens of width 128.
inputs = model.patch_embed(inputs)
print(inputs.shape)  # (B, 3136, 128)

# The first Swin block of stage 1 leaves the token layout unchanged.
inputs = model.layers[0].blocks[0](inputs)
print(inputs.shape)  # (B, 3136, 128)

torch.Size([2, 3136, 128])
torch.Size([2, 3136, 128])


In [5]:
# Patch grid of the stem, not the image size: 224 px / patch_size 4 = 56 patches per side -> (56, 56).
print(model.patch_embed.grid_size)
print()

# Token-grid resolution entering each stage; it halves from stage to stage (56 -> 28 -> 14 -> 7).
for stage in model.layers:
    print(stage.input_resolution)

(56, 56)

(56, 56)
(28, 28)
(14, 14)
(7, 7)


In [6]:
# For every stage, print the input resolution of its downsampling module, then its block count.
# The last stage is skipped for the downsample print, matching the code's assumption that it has none.
num_stages = len(model.layers)
for i, layer in enumerate(model.layers):
    if i != num_stages - 1:
        print("downsample", layer.downsample.input_resolution)
    print(len(layer.blocks))
    print("-" * 10)

downsample (56, 56)
2
----------
downsample (28, 28)
2
----------
downsample (14, 14)
18
----------
2
----------


In [37]:
# ViT-B/16 for comparison; its parameter count is close to Swin-B's (see outputs).
vit = timm.create_model("vit_base_patch16_224")
print(f"the number of model parameters: {sum(p.numel() for p in vit.parameters()):,}")

the number of model parameters: 86,567,656


In [13]:
# Toy (1, 4, 4) map: take the four stride-2 sub-grids that together cover every 2x2 neighbourhood.
x = torch.Tensor([
    [[1, 2, 3, 4],
     [6, 7, 8, 9],
     [11, 12, 13, 14],
     [16, 17, 18, 19]]
])

print(x)
print("-" * 10)
# (row offset, column offset): top-left, bottom-left, top-right, bottom-right element of each 2x2 cell.
x1, x2, x3, x4 = (x[:, row::2, col::2] for row, col in ((0, 0), (1, 0), (0, 1), (1, 1)))
for part in (x1, x2, x3, x4):
    print(part)
    print("-" * 10)



tensor([[[ 1.,  2.,  3.,  4.],
         [ 6.,  7.,  8.,  9.],
         [11., 12., 13., 14.],
         [16., 17., 18., 19.]]])
----------
tensor([[[ 1.,  3.],
         [11., 13.]]])
----------
tensor([[[ 6.,  8.],
         [16., 18.]]])
----------
tensor([[[ 2.,  4.],
         [12., 14.]]])
----------
tensor([[[ 7.,  9.],
         [17., 19.]]])
----------


In [30]:
import torch
import swin_transformer
import warnings
# NOTE(review): this silences every warning for the rest of the kernel session.
warnings.filterwarnings("ignore")

# Swin-B configuration for the local implementation; its parameter count matches the timm model above.
model_kwargs = {
    "patch_size": 4,
    "window_size": 7,
    "embed_dim": 128,
    "depths": (2, 2, 18, 2),
    "num_heads": (4, 8, 16, 32),
}
model = swin_transformer.SwinTransformer(**model_kwargs)
print(f"the number of model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Full forward pass on a dummy batch of two 224x224 images -> one row of logits per image.
x = torch.randn((2, 3, 224, 224))
print(model(x).shape)

the number of model parameters: 87,768,224
torch.Size([2, 1000])


In [26]:
print(repr(model))  # module tree: patch embed, stages with their blocks, head

SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): Sequential(
    (0): BasicLayer(
      (blocks): Sequential(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=128, out_features=384, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=128, out_features=128, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=128, out_features=512, bias=True)
            (act): GELU(approximate=none)
            (

In [31]:
# Patch embedding of the image batch `x` (2, 3, 224, 224) created in the cell above.
# The result is stored under a new name instead of overwriting `x`. Overwriting made this cell
# non-re-runnable: a second run would feed (2, 3136, 128) tokens back into the 3-channel conv stem and raise.
patch_embed = model.patch_embed
tokens = patch_embed(x)
print(patch_embed)
print(tokens.shape)  # 56x56 = 3136

PatchEmbed(
  (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)
torch.Size([2, 3136, 128])


In [1]:
import torch
import swin_transformer
import warnings
# NOTE(review): this silences every warning for the rest of the kernel session.
warnings.filterwarnings("ignore")

model_kwargs = dict(
    patch_size=4,
    window_size=7,
    embed_dim=128,
    depths=(2, 2, 18, 2),
    num_heads=(4, 8, 16, 32),
)
model = swin_transformer.SwinTransformer(**model_kwargs)


def _describe_block(block, tokens, show_mask_shape=False):
    """Print a Swin block's shift size, attention-mask type, optionally the mask shape, and output shape.

    Args:
        block: one entry of ``model.layers[i].blocks``.
        tokens: token tensor fed to the block (here (B, H*W, C)).
        show_mask_shape: also print ``block.attn_mask.shape``; only use it for blocks that have a mask.
    """
    print("shift_size:", block.shift_size)
    print("attn_mask:", type(block.attn_mask))
    if show_mask_shape:
        print("attn_mask:", block.attn_mask.shape)
    print(block(tokens).shape)


x = torch.randn((2, 3, 224, 224))
x = model.patch_embed(x)
print(x.shape)

# Each block below receives the same `x`; its output is only inspected for its shape.
# Stage 1, block 1.
_describe_block(model.layers[0].blocks[0], x)
print("-----------")
# Stage 1, block 2 (shifted windows, so it carries an attention mask).
_describe_block(model.layers[0].blocks[1], x, show_mask_shape=True)
print("-----------")
# Downsampling between stage 1 and stage 2.
x = model.layers[0].downsample(x)
print(x.shape)
print("###############################")
# Stage 2, block 1.
_describe_block(model.layers[1].blocks[0], x)

--------------------------------------------------
H, W: 56, 56
torch.Size([64, 49, 49])
--------------------------------------------------
--------------------------------------------------
H, W: 28, 28
torch.Size([16, 49, 49])
--------------------------------------------------
--------------------------------------------------
H, W: 14, 14
torch.Size([4, 49, 49])
--------------------------------------------------
--------------------------------------------------
H, W: 14, 14
torch.Size([4, 49, 49])
--------------------------------------------------
--------------------------------------------------
H, W: 14, 14
torch.Size([4, 49, 49])
--------------------------------------------------
--------------------------------------------------
H, W: 14, 14
torch.Size([4, 49, 49])
--------------------------------------------------
--------------------------------------------------
H, W: 14, 14
torch.Size([4, 49, 49])
--------------------------------------------------
-------------------------

In [13]:
# Find the first window of stage 1, block 2 whose shifted-window mask contains -100 entries.
# Those windows have a negative sum; windows whose tokens all share one region are all-zero.
shifted_mask = model.layers[0].blocks[1].attn_mask  # (num_windows, 49, 49); (64, 49, 49) here
masked_window_ids = (shifted_mask.sum(dim=(1, 2)) < 0.).nonzero().flatten()
# None instead of silently reusing the last loop index when no window is masked.
i = int(masked_window_ids[0]) if masked_window_ids.numel() > 0 else None
print(i)
# Inspect a few windows around the first masked one, plus the last two windows.
for window_idx in (6, 7, 8, 9, 62, 63):
    if window_idx < shifted_mask.shape[0]:
        print(shifted_mask[window_idx])

7
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
tensor([[   0.,    0.,    0.,  ..., -100., -100., -100.],
        [   0.,    0.,    0.,  ..., -100., -100., -100.],
        [   0.,    0.,    0.,  ..., -100., -100., -100.],
        ...,
        [-100., -100., -100.,  ...,    0.,    0.,    0.],
        [-100., -100., -100.,  ...,    0.,    0.,    0.],
        [-100., -100., -100.,  ...,    0.,    0.,    0.]])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        

In [4]:
# Within a stage, the shift size alternates between 0 and 3 from block to block (see output).
for block_idx in range(2):
    print(model.layers[0].blocks[block_idx].shift_size)
print()
for block_idx in range(5):
    print(model.layers[2].blocks[block_idx].shift_size)

0
3

0
3
0
3
0


In [17]:
import torch

# unbind along dim 0 splits a stacked (3, 32, 32) tensor into three (32, 32) views, like a packed qkv tensor.
x = torch.randn((3, 32, 32))
q, k, v = torch.unbind(x, dim=0)
for t in (q, k, v):
    print(t.shape)

torch.Size([32, 32])
torch.Size([32, 32])
torch.Size([32, 32])


In [12]:
# slice objects print as slice(start, stop, step); an omitted step shows as None.
x = tuple(slice(*bounds) for bounds in [(0, 1), (3, 2), (1, 5, None)])
print(x)

(slice(0, 1, None), slice(3, 2, None), slice(1, 5, None))


In [13]:
# Iterating a sequence of slices yields the slice objects themselves.
for h in [slice(0, 1), slice(3, 2), slice(1, 5, 2)]:
    print(h)

slice(0, 1, None)
slice(3, 2, None)
slice(1, 5, 2)


In [57]:
import torch
from swin_transformer import window_partition

# Rebuild the shifted-window attention mask by hand on a small map to see how regions are kept apart.
window_size = 7
shift_size = 3
H, W = 14, 14  # toy feature-map size; a multiple of window_size, giving 2x2 windows

# Label the 3x3 grid of regions created by the shift with distinct ids 0..8.
x = torch.zeros((1, H, W))
cnt = 0
region_slices = (
    slice(0, -window_size),
    slice(-window_size, -shift_size),
    slice(-shift_size, None),
)
for h in region_slices:
    for w in region_slices:
        x[:, h, w] = cnt
        cnt += 1

x = x.unsqueeze(-1)  # add a trailing channel dim -> (1, H, W, 1)
print(x.shape)
mask_windows1 = window_partition(x, window_size)  # (num_windows, window_size, window_size, 1)
print(mask_windows1.shape)
mask_windows = mask_windows1.view(-1, window_size * window_size)  # one flat row of region ids per window
print(mask_windows.shape)
print(mask_windows.unsqueeze(1).shape)
print(mask_windows.unsqueeze(2).shape)
# Broadcast pairwise difference: attn_mask[n, i, j] == 0 exactly when tokens i and j of window n share a region.
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
print(attn_mask.shape)

torch.Size([1, 14, 14, 1])
torch.Size([4, 7, 7, 1])
torch.Size([4, 49])
torch.Size([4, 1, 49])
torch.Size([4, 49, 1])
torch.Size([4, 49, 49])


In [58]:
print(x[..., 0])  # drop the trailing channel dim to show the (1, 14, 14) region-id map

tensor([[[0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 2., 2., 2.],
         [0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 2., 2., 2.],
         [0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 2., 2., 2.],
         [0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 2., 2., 2.],
         [0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 2., 2., 2.],
         [0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 2., 2., 2.],
         [0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 2., 2., 2.],
         [3., 3., 3., 3., 3., 3., 3., 4., 4., 4., 4., 5., 5., 5.],
         [3., 3., 3., 3., 3., 3., 3., 4., 4., 4., 4., 5., 5., 5.],
         [3., 3., 3., 3., 3., 3., 3., 4., 4., 4., 4., 5., 5., 5.],
         [3., 3., 3., 3., 3., 3., 3., 4., 4., 4., 4., 5., 5., 5.],
         [6., 6., 6., 6., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8.],
         [6., 6., 6., 6., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8.],
         [6., 6., 6., 6., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8.]]])


In [60]:
# Show each 7x7 window's region ids (trailing channel dim dropped).
for win in mask_windows1:
    print(win[..., 0])

tensor([[0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 2., 2., 2.],
        [1., 1., 1., 1., 2., 2., 2.],
        [1., 1., 1., 1., 2., 2., 2.],
        [1., 1., 1., 1., 2., 2., 2.],
        [1., 1., 1., 1., 2., 2., 2.],
        [1., 1., 1., 1., 2., 2., 2.],
        [1., 1., 1., 1., 2., 2., 2.]])
tensor([[3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3., 3., 3.],
        [6., 6., 6., 6., 6., 6., 6.],
        [6., 6., 6., 6., 6., 6., 6.],
        [6., 6., 6., 6., 6., 6., 6.]])
tensor([[4., 4., 4., 4., 5., 5., 5.],
        [4., 4., 4., 4., 5., 5., 5.],
        [4., 4., 4., 4., 5., 5., 5.],
        [4., 4., 4., 4., 5., 5., 5.],
        [7., 7., 7., 7., 8., 8., 8.],
        [

In [61]:
# Same windows, flattened to 49 region ids each.
for win in mask_windows.unbind(0):
    print(win)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.])
tensor([1., 1., 1., 1., 2., 2., 2., 1., 1., 1., 1., 2., 2., 2., 1., 1., 1., 1.,
        2., 2., 2., 1., 1., 1., 1., 2., 2., 2., 1., 1., 1., 1., 2., 2., 2., 1.,
        1., 1., 1., 2., 2., 2., 1., 1., 1., 1., 2., 2., 2.])
tensor([3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
        3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 6., 6., 6., 6., 6., 6., 6., 6.,
        6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6., 6.])
tensor([4., 4., 4., 4., 5., 5., 5., 4., 4., 4., 4., 5., 5., 5., 4., 4., 4., 4.,
        5., 5., 5., 4., 4., 4., 4., 5., 5., 5., 7., 7., 7., 7., 8., 8., 8., 7.,
        7., 7., 7., 8., 8., 8., 7., 7., 7., 7., 8., 8., 8.])


In [67]:
# Dump each window's 49x49 region-id difference matrix (0 = same region), 4 characters per entry.
for win in attn_mask:
    for row in win.tolist():
        print("".join(f"{value:4}" for value in row))
    print('-'*10)

 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
 0.0 0.0 0.0 0.

In [69]:
# Map differences to 0 (same region) / -100 (different regions); presumably used as an additive
# attention bias so cross-region pairs are suppressed.
is_cross_region = attn_mask != 0
attn_mask = attn_mask.masked_fill(is_cross_region, float(-100.0)).masked_fill(~is_cross_region, float(0.0))
for win in attn_mask:
    for row in win.tolist():
        print("".join(f"{value:7}" for value in row))
    print('-'*10)

    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0
    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0
    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    0.0    

In [53]:
# Broadcasting (2, 3, 1) against (2, 1, 3) gives a (2, 3, 3) pairwise-difference tensor,
# the same trick used above to build the shifted-window mask.
x = torch.Tensor([[[1], [2], [3]], [[2], [3], [4]]])
y = torch.Tensor([[[1, 2, 3]], [[1, 2, 3]]])
for sample in (x, y):
    print(sample.shape)
    print(sample)
    print("----------")
print(x - y)
print("----------")
print(y - x)
print("----------")
attn = y - x
is_diff = attn != 0
attn = attn.masked_fill(is_diff, float(-100.0)).masked_fill(~is_diff, float(0.0))
print(attn)

torch.Size([2, 3, 1])
tensor([[[1.],
         [2.],
         [3.]],

        [[2.],
         [3.],
         [4.]]])
----------
torch.Size([2, 1, 3])
tensor([[[1., 2., 3.]],

        [[1., 2., 3.]]])
----------
tensor([[[ 0., -1., -2.],
         [ 1.,  0., -1.],
         [ 2.,  1.,  0.]],

        [[ 1.,  0., -1.],
         [ 2.,  1.,  0.],
         [ 3.,  2.,  1.]]])
----------
tensor([[[ 0.,  1.,  2.],
         [-1.,  0.,  1.],
         [-2., -1.,  0.]],

        [[-1.,  0.,  1.],
         [-2., -1.,  0.],
         [-3., -2., -1.]]])
----------
tensor([[[   0., -100., -100.],
         [-100.,    0., -100.],
         [-100., -100.,    0.]],

        [[-100.,    0., -100.],
         [-100., -100.,    0.],
         [-100., -100., -100.]]])


In [6]:
import torch

# PixelUnshuffle folds each 2x2 spatial block into channels: (1, 3, 32, 32) -> (1, 12, 16, 16).
y = torch.randn((1, 3, 32, 32))
torch.nn.PixelUnshuffle(downscale_factor=2)(y).shape

torch.Size([1, 12, 16, 16])