In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm


In [13]:
def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H//window_size, window_size, W//window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5)
    x = x.reshape(-1, window_size, window_size, C)

    return x

In [14]:
def window_reverse(windows, window, H, W):
    B = int(windows.shape[0] / (H//window * W//window))
    x = windows.view(B, H//window, W//window, window, window, -1)
    x = x.permute(0, 1, 3, 2, 4, 5)
    x = x.reshape(B, H, W, -1)
    return x

In [15]:
class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = self.dim // self.num_heads
        self.scale = self.head_dim ** -0.5
        self.window_size = window

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

        # Relative positional bias and mask value
        coords_h = torch.arange(self.window_size)
        coords_w = torch.arange(self.window_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size - 1  # shift to be positive
        relative_coords[:, :, 1] += self.window_size - 1
        relative_coords[:, :, 0] *= 2 * self.window_size - 1 # This ensures unique indices
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window - 1) * (2 * window - 1), num_heads))
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)


    def forward(self, x, mask=None):
        B_, N, C = x.shape

        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2] # q, k, v: (B_, num_heads, N, head_dim)

        q = q * self.scale
        attention = (q @ k.transpose(-2, -1)) # (B_, num_heads, N, N)

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size * self.window_size, self.window_size * self.window_size, -1) # N, N, num_heads
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # num_heads, N, N
        attention = attention + relative_position_bias.unsqueeze(0) # (B_, num_heads, N, N) + (1, num_heads, N, N)

        if mask is not None:
            num_window = mask.shape[0]
            attention = attention.view(B_ // num_window, num_window, self.num_heads, N, N)
            attention = attention + mask.unsqueeze(1).unsqueeze(0)
            attention = attention.view(-1, self.num_heads, N, N)

        attention = attention.softmax(dim = -1)
        output = (attention @ v).transpose(1, 2).reshape(B_, N, C)
        output = self.proj(output)
        return output

In [16]:
class SwinBlock(nn.Module):
    def __init__(self, dim, resolution, window, shift, heads):
        super().__init__()
        self.dim = dim
        self.resolution = resolution
        self.num_window = window
        self.shift = shift
        self.norm_1 = nn.LayerNorm(dim)
        self.attention = WindowAttention(dim, heads, window)
        self.norm_2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4*dim),
            nn.GELU(),
            nn.Linear(4*dim, dim),
        )
        H, W = resolution
        if shift > 0:
            self.mask = self.create_mask(H, W, window, shift)

        else:
            self.mask = None

    def create_mask(self, H, W, window, shift):
        img_mask = torch.zeros((1, H, W, 1))
        count = 0

        for h in (slice(0, -window), (slice(-window, -shift)), slice(-shift, None)):
            for w in (slice(0, -window), (slice(-window, -shift)), slice(-shift, None)):
                img_mask[:, h, w, :] = count
                count += 1

        mask = window_partition(img_mask, window)
        mask = mask.view(-1, window * window)
        mask = mask.unsqueeze(1) - mask.unsqueeze(2)
        mask = mask.masked_fill(mask!=0, -10000.0)
        return mask

    def forward(self, x):
        B, L, C = x.shape
        H, W = self.resolution

        residual = x
        x = self.norm_1(x)
        x = x.view(B, H, W, C)

        if self.shift > 0:
            x = torch.roll(x, shifts=(-self.shift, -self.shift), dims=(1,2))

        window_x = window_partition(x, self.num_window).view(-1, self.num_window*self.num_window, C)

        attention_output = self.attention(window_x, self.mask.to(x.device) if self.mask is not None else None)

        x = window_reverse(attention_output, self.num_window, H, W)

        if self.shift > 0:
            x = torch.roll(x, shifts=(self.shift, -self.shift), dims=(1, 2))

        x = residual + x.view(B, L, C)

        residual_2 = x
        x = self.norm_2(x)
        x = self.mlp(x)
        x = x + residual_2

        return x

In [17]:
class PatchMerging(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        self.reduction = nn.Linear(4*dim, 2*dim, bias=False)
        self.norm = nn.LayerNorm(4*dim)

    def forward(self, x, H, W):
        B, L, C = x.shape
        x = x.view(B, H, W, C)

        x_0 = x[:, 0::2, 0::2, :]
        x_1 = x[:, 1::2, 0::2, :]
        x_2 = x[:, 0::2, 1::2, :]
        x_3 = x[:, 1::2, 1::2, :]

        x_0 = x_0.reshape(B, -1, C)
        x_1 = x_1.reshape(B, -1, C)
        x_2 = x_2.reshape(B, -1, C)
        x_3 = x_3.reshape(B, -1, C)

        x = torch.cat([x_0, x_1, x_2, x_3], -1)

        x = self.norm(x)
        x = self.reduction(x)

        return x, H//2, W//2

In [18]:
class TwoStageSwinMNIST(nn.Module):
    def __init__(self, embed_dim= 48, heads=3, window=7, num_classes=10):
        super().__init__()
        self.embedding_dim = embed_dim
        self.patch_embed = nn.Conv2d(1, embed_dim, kernel_size=2, stride=2)
        initial_resolution = (14, 14)

        # Stage 1
        self.stage_1_blocks = nn.Sequential(
            SwinBlock(embed_dim, initial_resolution, heads=heads, window=window, shift=0),
            SwinBlock(embed_dim, initial_resolution,  heads=heads, window=window, shift=3)
        )

        self.patch_merge = PatchMerging(embed_dim)
        merged_dim = embed_dim * 2
        stage_2_resolution = (initial_resolution[0]//2, initial_resolution[1]//2)

        # Stage 2
        self.stage_2_blocks = nn.Sequential(
            SwinBlock(merged_dim, stage_2_resolution, heads=heads, window=window, shift=0),
            SwinBlock(merged_dim, stage_2_resolution, heads=heads, window=window, shift=window // 2)
        )

        self.layer_norm = nn.LayerNorm(merged_dim)
        self.fully_connected_layer = nn.Linear(merged_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        L = H * W

        x = self.stage_1_blocks(x)
        x, H, W = self.patch_merge(x, H, W)
        L = H * W
        C = C * 2

        x = self.stage_2_blocks(x)

        x = self.layer_norm(x)

        x = x.mean(dim = 1)
        x = self.fully_connected_layer(x)

        return x

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307, ), (0.3081,))
])

train_data = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_data = datasets.MNIST("./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=256, shuffle=False)

model = TwoStageSwinMNIST().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()

In [20]:
def test(model):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for img, lable in test_loader:
            img = img.to(device)
            lable = lable.to(device)

            output = model(img)
            pred = output.argmax(1)
            correct += pred.eq(lable).sum().item()
            total += lable.size(0)

    return correct / total

In [21]:
for epoch in range(5):
    model.train()
    for img, lable in tqdm(train_loader):
        img = img.to(device)
        lable = lable.to(device)

        output = model(img)
        loss = loss_fn(output, lable)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    acc = test(model)
    print("Epoch", epoch + 1, "Accuracy = ", acc * 100)

100%|██████████| 469/469 [00:20<00:00, 22.83it/s]


Epoch 1 Accuracy =  89.23


100%|██████████| 469/469 [00:20<00:00, 22.87it/s]


Epoch 2 Accuracy =  94.01


100%|██████████| 469/469 [00:20<00:00, 23.09it/s]


Epoch 3 Accuracy =  95.19999999999999


100%|██████████| 469/469 [00:20<00:00, 23.19it/s]


Epoch 4 Accuracy =  94.81


100%|██████████| 469/469 [00:20<00:00, 23.30it/s]


Epoch 5 Accuracy =  95.87
