In [19]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torchinfo import summary
import timm
from pytorch_models_imp.swin_transformer import PatchEmbeddings, PatchMerging, partition_window, reverse_partition, WindowSelfAttention, SwinTransformer, SwinBlock, TransformerBlock

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Preprocessing: convert to a tensor, resize the smaller edge to 224 px (the model input size),
# then normalise each channel with mean 0.5 / std 0.5, mapping ToTensor's [0, 1] range to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(224),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

batch_size = 128


def _cifar10_split(train, shuffle):
    """Return ``(dataset, loader)`` for one CIFAR-10 split, downloading into ./data if needed."""
    dataset = torchvision.datasets.CIFAR10(
        root='./data', train=train, download=True, transform=transform
    )
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=5
    )
    return dataset, loader


# Train split is shuffled; test split keeps a fixed order.
trainset, trainloader = _cifar10_split(train=True, shuffle=True)
testset, testloader = _cifar10_split(train=False, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Pull a single training batch; it drives all the shape checks in the rest of the notebook.
images, labels = next(iter(trainloader))

In [4]:
# Input batch dimensions: (batch, channels, height, width).
# NOTE: C, H and W are rebound to token-grid sizes in the window-partition cell further down.
N, C, H, W = images.shape

### Images to patches

In [5]:
# Stem configuration: 224x224 inputs cut into 4x4 patches, each embedded into 96 channels
# (the same [1, 3136, 96] patch-embedding output the timm reference reports).
IMAGE_SIZE, PATCH_SIZE, EMBED_SIZE = 224, 4, 96
patch_embedder = PatchEmbeddings(IMAGE_SIZE, PATCH_SIZE, EMBED_SIZE)

In [6]:
# Side length of the square patch grid (56 here, per the printed output of the next cell).
patch_resolution = patch_embedder.patch_resolution

In [7]:
num_patches = patch_embedder.num_of_patches
print(f"NUM OF PATCHES: {num_patches}. Patch resolution: {patch_resolution}")
patches_out = patch_embedder(images)
# One EMBED_SIZE-dim token per patch: (batch, number_of_patches, embed_size).
assert patches_out.shape == (N, num_patches, EMBED_SIZE)

NUM OF PATCHES: 3136. Patch resolution: 56


### Patch merging

In [8]:
# Downsample the patch grid: each side is halved and the channel count doubled (asserted in the next cell).
patch_merging = PatchMerging((patch_resolution,) * 2, EMBED_SIZE)
patch_merging_out = patch_merging(patches_out)

In [9]:
merged_side = patch_resolution // 2
# batch, merged number_of_patches, doubled embed_size
assert patch_merging_out.shape == (N, merged_side * merged_side, 2 * EMBED_SIZE)

### Window partition and reverse partition

In [21]:
window_size = 7
# Restore the 2-D token grid: (N, num_patches, EMBED_SIZE) -> (N, patch_resolution, patch_resolution, EMBED_SIZE).
patches_with_resolution = patches_out.reshape(N, patch_resolution, patch_resolution, -1)
# NOTE: this rebinds H, W and C from the image-level sizes to the token-grid sizes (56, 56, EMBED_SIZE);
# later cells rely on the rebound values.
B, H, W, C = patches_with_resolution.shape
partitioned_out = partition_window(patches_with_resolution, window_size)
# Windows stay 2-D here: (B * num_windows, window_size, window_size, C), not flattened to window_size**2.
assert partitioned_out.shape == (B * (H // window_size) * (W // window_size), window_size, window_size, C)
partitioned_out.shape  # (B * num_windows, window_size, window_size, C)

torch.Size([8192, 7, 7, 96])

In [22]:
# Undo the window partition (H, W are the token-grid sizes from the partition cell) and check the round trip is lossless.
reversed_out = reverse_partition(partitioned_out, window_size, H, W)
print(reversed_out.shape)
# torch.equal requires identical shapes as well as values; `(a == b).all()` would broadcast
# and could let a shape mismatch pass.
assert torch.equal(patches_with_resolution, reversed_out)

torch.Size([128, 56, 56, 96])


### Window attention with relative position bias

In [23]:
# Window self-attention hyper-parameters.
HEADS, DROPOUT = 3, 0.0
self_attention = WindowSelfAttention(HEADS, (window_size,) * 2, EMBED_SIZE, DROPOUT)
# Flatten every (window_size, window_size) window into a sequence of window_size**2 tokens.
# NOTE(review): this rebinds partitioned_out, so re-running the reverse_partition cell afterwards
# would receive the flattened layout instead of the 4-D window grid.
partitioned_out = partitioned_out.view(-1, window_size ** 2, C)

In [24]:
# Self-attention within each window: the same window tokens are passed as all three attention inputs.
self_attention_out = self_attention(partitioned_out, partitioned_out, partitioned_out)

In [25]:
# Attention must preserve the per-window token layout: (total windows, window_size**2 tokens, C channels).
assert self_attention_out.shape == (partitioned_out.shape[0], window_size * window_size, C)

### Transformer block

In [49]:
# Shifted-window transformer block: windows are shifted by half a window (window_size // 2).
transformer_block = TransformerBlock(
    heads=HEADS,
    shift_size=window_size // 2,
    embed_size=EMBED_SIZE,
    window_size=window_size,
    # The constructor keyword is `patches_resolution`, but the notebook variable is `patch_resolution`
    # (`patches_resolution` is never defined and raises NameError on a fresh kernel).
    patches_resolution=(patch_resolution, patch_resolution),
    forward_expansion=4,
    dropout=0.1,
)

In [50]:
transformer_out = transformer_block(patches_out)
# The block preserves the token layout: (batch, number_of_patches, embed_size).
# Uses patch_resolution (patches_resolution is undefined) and EMBED_SIZE rather than the rebound `C`.
assert transformer_out.shape == (N, patch_resolution**2, EMBED_SIZE)

### Swin block

In [51]:
# One Swin stage: `depth` transformer blocks followed by PatchMerging downsampling.
# Uses patch_resolution (the previously referenced `patches_resolution` is never defined).
swin_block = SwinBlock(
    (patch_resolution, patch_resolution),
    heads=HEADS,
    depth=2,
    embed_size=EMBED_SIZE,
    expansion=4,
    window_size=window_size,
    dropout=0.0,
    downsample=PatchMerging,
)

In [52]:
# Run the stage (configured with depth=2 and downsample=PatchMerging) on the patch tokens.
swin_block_out = swin_block(patches_out)

In [53]:
# Downsampling halves each grid side (4x fewer tokens) and doubles the channel dimension.
assert swin_block_out.shape == (N, patches_out.shape[1] // 4, patches_out.shape[2] * 2)

### Swin transformer

In [89]:
# 1000-way head to match num_classes=1000 of the timm reference model below
# (CIFAR-10 itself has only 10 classes; this model is built for the comparison, not for training here).
NUM_CLASSES = 1000
swin_transformer = SwinTransformer(IMAGE_SIZE, NUM_CLASSES)

In [85]:
# Smoke test: full forward pass of the from-scratch model on the CIFAR-10 batch.
# NOTE(review): this cell ran as In [85], before the In [89] cell that builds swin_transformer;
# re-run the notebook top to bottom to confirm it still passes.
swin_transformer_out = swin_transformer(images)

### Compare with timm

In [86]:
# Pretrained timm Swin (patch 4, window 7, 224 px input) used as the reference implementation.
swin_timm = timm.create_model("swin_tiny_patch4_window7_224", num_classes=1000, pretrained=True)

In [95]:
# Layer-by-layer parameter summary of the timm reference model for one 3x224x224 input
# (224 is fixed by the timm model name, so it is not tied to IMAGE_SIZE).
summary(swin_timm, input_size=(1, 3, 224, 224), device='cpu')

Layer (type:depth-idx)                             Output Shape              Param #
SwinTransformer                                    --                        --
├─Sequential: 1                                    --                        --
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-1                        --                        224,694
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-2                        --                        891,756
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-3                        --                        10,658,952
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-4                        --                        14,183,856
├─PatchEmbed: 1-1                                  [1, 3136, 96]             --
│    └─Co

In [96]:
# Same summary for the from-scratch model; per-stage parameter counts should match the timm reference.
# The probe input uses IMAGE_SIZE so it always matches the resolution the model was built with.
summary(swin_transformer, input_size=(1, 3, IMAGE_SIZE, IMAGE_SIZE), device='cpu')

Layer (type:depth-idx)                                  Output Shape              Param #
SwinTransformer                                         --                        --
├─ModuleList: 1-1                                       --                        --
│    └─SwinBlock: 2                                     --                        --
│    │    └─ModuleList: 3-1                             --                        224,694
│    └─SwinBlock: 2                                     --                        --
│    │    └─ModuleList: 3-2                             --                        891,756
│    └─SwinBlock: 2                                     --                        --
│    │    └─ModuleList: 3-3                             --                        10,658,952
│    └─SwinBlock: 2                                     --                        --
│    │    └─ModuleList: 3-4                             --                        14,183,856
├─PatchEmbeddings: 1-2            