<a href="https://colab.research.google.com/github/dimoynwa/Computer-vision-tasks/blob/main/SWIN_transformer_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## SWIN transformers from scratch

Implement SWIN transformers from scratch using PyTorch.

Following: https://github.com/berniwal/swin-transformer-pytorch


![](https://drive.google.com/uc?export=view&id=16xWyp2Q5oio-m8GtuaAdCZBRT4wTwJeF)


In [None]:
import torch
from torch import nn, einsum
import numpy as np

!pip install einops # Clear and reliable tensor transformations
from einops import rearrange
from einops import einsum

import torch.nn.functional as F

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m702.2 kB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


## Create classes

Every **Stage** in the image above is represented by the `StageModule` class.
What does it do? First it performs **Patch Merging**. The idea of Patch Merging is to change the spatial size and create a hierarchy. The output of Patch Merging at Stage 2 will be [b_s, 28, 28, 192]. Patch Merging is implemented in the `PatchMerging_Conv` class. The **output** of the SwinTransformerBlock at Stage 2 will also be [b_s, 28, 28, 192]: the **input** and **output** of a **SwinTransformerBlock** have the same shape.

![](https://drive.google.com/uc?export=view&id=1UkEOxgGyd8MK9u1fyQ9oRwluykv-1aea)


In [None]:
class PatchMerging_Conv(nn.Module):
  """Patch merging implemented as a strided convolution.

  The kernel size and the stride are both `downscalling_factor`, so the
  convolution sees non-overlapping patches and shrinks height and width by
  that factor while mapping `in_channels` to `out_channels`.
  """

  def __init__(self, in_channels, out_channels, downscalling_factor):
    super().__init__()
    # Same value for kernel and stride -> non-overlapping patches, no padding.
    self.patch_merging = nn.Conv2d(in_channels, out_channels,
                                   kernel_size=downscalling_factor,
                                   stride=downscalling_factor,
                                   padding=0)

  def forward(self, x):
    """Map (batch, in_channels, H, W) to (batch, H', W', out_channels)."""
    feature_map = self.patch_merging(x)  # (batch, out_channels, H', W')
    # Move the channel axis to the end: the attention blocks work channels-last.
    return feature_map.movedim(1, -1)

### PatchMerging different approach

To remove the **Conv2d** layer we can use a different approach for downscaling the image and increasing the number of channels, using `torch.nn.Unfold` and `torch.nn.Linear`.

In [None]:
class PatchMerging(nn.Module):
  """Patch merging without a convolution: unfold patches, then project them.

  `nn.Unfold` flattens every non-overlapping
  `downscalling_factor x downscalling_factor` patch into a vector of
  `in_channels * downscalling_factor ** 2` values, and a linear layer maps
  that vector to `out_channels`.
  """

  def __init__(self, in_channels, out_channels, downscalling_factor):
    super().__init__()
    self.downscalling_factor = downscalling_factor

    self.patch_merge = nn.Unfold(kernel_size=downscalling_factor,
                                 stride=downscalling_factor)
    self.linear = nn.Linear(in_channels * downscalling_factor ** 2, out_channels)

  def forward(self, x):
    """Map (batch, in_channels, H, W) to (batch, H', W', out_channels).

    H' and W' are H and W floor-divided by `downscalling_factor`.
    """
    b, _, h, w = x.shape
    new_h, new_w = h // self.downscalling_factor, w // self.downscalling_factor

    # Unfold exactly once: (b, in_channels * factor**2, new_h * new_w),
    # then restore the 2-D patch grid.
    x = self.patch_merge(x).view(b, -1, new_h, new_w)
    # Channels-last, so the linear layer acts on each flattened patch.
    x = x.permute(0, 2, 3, 1)

    return self.linear(x)




#### Test PatchMerging_Conv

In [None]:
# Smoke test: push a random batch of two 224x224 RGB images through the first merging layer.
dummy_in = torch.randn(2, 3, 224, 224)

# One merging layer per stage; only pm1 is exercised below.
pm1 = PatchMerging_Conv(in_channels=3, out_channels=96, downscalling_factor=4)
pm2 = PatchMerging_Conv(in_channels=96, out_channels=192, downscalling_factor=2)
pm3 = PatchMerging_Conv(in_channels=192, out_channels=384, downscalling_factor=2)

out = pm1(dummy_in)
print(f'Out shape after PatchMerging_1: {out.shape}')

Out shape after PatchMerging_1: torch.Size([2, 56, 56, 96])


In [None]:
class Residual(nn.Module):
  """Wrap a callable with a skip connection: output = fn(x, **kwargs) + x."""

  def __init__(self, fn):
    super().__init__()
    self.fn = fn

  def forward(self, x, **kwargs):
    """Run the wrapped callable on `x` and add `x` back onto its result."""
    branch = self.fn(x, **kwargs)
    return branch + x

In [None]:
class PreNorm(nn.Module):
  """Combine a callable with a LayerNorm, before or after it depending on `version`.

  Args:
    dim: size of the last dimension, passed to `nn.LayerNorm`.
    fn: the wrapped callable/module.
    version: 2 (default) normalizes the output of `fn`; any other value
      normalizes the input of `fn`.
  """

  def __init__(self, dim, fn, version=2):
    super().__init__()
    self.fn = fn
    self.norm = nn.LayerNorm(dim)
    self.version = version

  def forward(self, x, **kwargs):
    '''
    In SWIN transformers V1 the normalization is applied
    before the block: fn(norm(x)).

    In SWIN transformers V2 the normalization is applied
    after the block: norm(fn(x)).

    Extra keyword arguments are always forwarded to `fn`, never to the
    LayerNorm (which does not accept any).
    '''
    if self.version == 2:
      return self.norm(self.fn(x, **kwargs))
    return self.fn(self.norm(x), **kwargs)

In [None]:
class FeedForward(nn.Module):
  """Two-layer MLP: Linear(dim -> hidden_dim), GELU, Linear(hidden_dim -> dim)."""

  # The stages in this notebook pass hidden_dim = dim * 4 with dim in (96, 192, 384, 768).
  def __init__(self, dim, hidden_dim):
    super().__init__()
    expand = nn.Linear(dim, hidden_dim)
    activation = nn.GELU()
    project = nn.Linear(hidden_dim, dim)
    self.net = nn.Sequential(expand, activation, project)

  def forward(self, x):
    """Apply the MLP over the last dimension of `x`."""
    out = self.net(x)
    return out

## Swin Block

![](https://drive.google.com/uc?export=view&id=1fqzDh6Hj-C3l5CD4LwuwrrAByNo6_8Xd)

The first part in red is the WindowAttention and the second part is the MlpBlock, which is implemented in the `FeedForward` class. LN is just a LayerNorm, implemented in the `PreNorm` class. Each of them has a **residual connection**.

## Window Attention

The most important class of the whole SWIN Transformer implementation


![](https://drive.google.com/uc?export=view&id=1YgesmeyPdAB1vsQb9GCODI3b5_P5Y2lr)


#### Shifting

When `self.shifted` is `True` we shift all of the windows right and down at the same time by `window_size // 2`. But when we shift, the regions on the right and at the bottom will be empty and we need to pad them.

In the paper they suggest 2 types of padding: `naive padding` and `cyclic padding`. **Naive padding** just adds `zeros` at those positions. **Cyclic Padding**:

![](https://drive.google.com/uc?export=view&id=1LrSzZ6ZC65mOofeOsEKO6FKMnvMksZWE)

</br>

Padding speed:
![](https://drive.google.com/uc?export=view&id=1IzWGcsnDJLTGw-4hxOCavAAW40P2fMtz)

**!NOTE:** So, the `torch.roll` shifts should be negative to move the windows in the `right & down` direction, and positive for the `left & up` direction (which undoes the shift).

With this type of padding (**Cyclic padding**) we have a problem with the last row and the last column. After the shifting and padding, the first row of the original image becomes a neighbour of the last row of the original image. The same applies to the first and last columns. We need to handle this so that such positions do not attend to each other; this is called **Masking**.

![](https://drive.google.com/uc?export=view&id=1PkmGJX6A9SiDbGlRPK6w3H3GlR2Vj9Mp)


![](https://drive.google.com/uc?export=view&id=1vNhXyd3RHx2hr1DDeDheTMdyrHUzrvk_)

In [None]:
class CyclicShift(nn.Module):
  """Roll a tensor by the same `displacement` along dims 1 and 2, wrapping around."""

  def __init__(self, displacement):
    super().__init__()
    self.displacement = displacement

  def forward(self, x):
    """Return `x` rolled by `displacement` along dim 1 and dim 2."""
    step = self.displacement
    return x.roll(shifts=(step, step), dims=(1, 2))

In [None]:
# 4x4 grid holding 1..16, so each cell is easy to follow after a roll.
x = torch.linspace(1, 16, 16).reshape(4, 4)
print(x)

# Negative shifts: the content moves up and to the left, wrapping around.
y = x.roll(shifts=(-1, -1), dims=(0, 1))
print(y)

# Positive shifts: the content moves down and to the right, wrapping around.
x = x.roll(shifts=(1, 1), dims=(0, 1))
print(x)

tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
tensor([[ 6.,  7.,  8.,  5.],
        [10., 11., 12.,  9.],
        [14., 15., 16., 13.],
        [ 2.,  3.,  4.,  1.]])
tensor([[16., 13., 14., 15.],
        [ 4.,  1.,  2.,  3.],
        [ 8.,  5.,  6.,  7.],
        [12.,  9., 10., 11.]])


In [None]:
def create_mask(window_size, displacement, upper_lower, left_right):
  """Build an additive attention mask for one window.

  The result is a (window_size**2, window_size**2) tensor of zeros in which
  the selected query/key pairs are set to -inf.

  Args:
    window_size: side length of the square window.
    displacement: how many rows/columns form the "far" side of the window.
    upper_lower: if true, block attention between the last `displacement`
      rows of the window and the remaining rows (both directions).
    left_right: if true, block attention between the last `displacement`
      columns of the window and the remaining columns (both directions).
  """
  tokens = window_size ** 2
  blocked = float('-inf')
  mask = torch.zeros(tokens, tokens)  # 49x49 for window_size=7

  if upper_lower:
    # Tokens are row-major, so the last `displacement` rows are one contiguous tail.
    split = displacement * window_size
    mask[-split:, :-split] = blocked
    mask[:-split, -split:] = blocked

  if left_right:
    # View each token axis as (row, column) so that columns can be sliced directly;
    # the view shares storage with `mask`, so the writes below land in it.
    grid = mask.view(window_size, window_size, window_size, window_size)
    grid[:, -displacement:, :, :-displacement] = blocked
    grid[:, :-displacement, :, -displacement:] = blocked
    mask = grid.reshape(tokens, tokens)

  return mask

In [None]:
# NOTE: despite the variable name, this call requests the left/right mask
# (upper_lower=False, left_right=True) for a 3x3 window with displacement 1.
mask_upper_lower = create_mask(window_size=3, displacement=1,
                               upper_lower=False, left_right=True)
print(mask_upper_lower)

tensor([[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.]])


## Relative positional embeddings

With relative positional embedding we reduce the number of parameters from **(n * n)** to **(2*window_size - 1) * (2*window_size - 1)**.

We can achieve that by:

`self.pos_embeddings = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))`

![](https://drive.google.com/uc?export=view&id=1uo2JhjAISwybYW-8FPoQd5FxTrVh5WLQ)

### Performance comparison:

![](https://drive.google.com/uc?export=view&id=19M_OHqdWn3nVr_1f8XhWsQfl0GUK6ZPM)

In [None]:
import torch
import numpy as np

def get_relative_distances(window_size: int):
  """Return the pairwise (row, column) offsets between all cells of a square window.

  The result has shape (window_size**2, window_size**2, 2); entry [i, j] is
  the coordinate of cell j minus the coordinate of cell i, with cells
  numbered row by row.
  """
  # divmod(flat, window_size) turns a row-major cell number into (row, column).
  coords = np.array([divmod(flat, window_size) for flat in range(window_size ** 2)])
  indices = torch.tensor(coords)
  # Broadcast "every cell j" against "every cell i".
  return indices[None, :, :] - indices[:, None, :]

In [None]:
# Relative offsets for a 3x3 window: one (row, column) offset per pair of cells.
rel_dist = get_relative_distances(3)
print(f'Shape: {rel_dist.shape}')
# Offset stored for the pair (cell 2, cell 1).
print(f'First: {rel_dist[2,1,:]}')

Shape: torch.Size([9, 9, 2])
First: tensor([ 0, -1])


In [None]:
import torch

# Example of looking values up in a (2, 2) positional-embedding table
p = torch.tensor([[1, 2],
                  [3, 4]])
print(f'p.size: {p.size()}')

# (row, column) index pairs into p, laid out as a (3, 3, 2) tensor
rel_pos_embed = torch.tensor([[[0, 0]] * 3,
                              [[1, 1]] * 3,
                              [[0, 1]] * 3])
print(f'rel_pos_embed.size: {rel_pos_embed.size()}')

# Advanced indexing: the last axis supplies the row and the column of p for every cell
print(p[rel_pos_embed[..., 0], rel_pos_embed[..., 1]])



p.size: torch.Size([2, 2])
rel_pos_embed.size: torch.Size([3, 3, 2])
tensor([[1, 1, 1],
        [4, 4, 4],
        [2, 2, 2]])


### Version 2 of SWIN transformers changes

In Version 2 of SWIN transformers the authors introduce **cosine similarity** instead of a plain **dot product** between the Q and K matrices. The result is then divided by TAU, a **trainable parameter** which is set to a value > 0.01.



In [None]:
class WindowAttention(nn.Module):
  """Multi-head self-attention inside non-overlapping windows, optionally shifted.

  Input and output are channels-last tensors of shape
  (batch, height, width, dim); height and width are split into
  `window_size x window_size` windows and attention is computed inside each
  window independently. Attention logits are the cosine similarity of
  queries and keys divided by the trainable scalar `tau`, plus a positional
  embedding.

  Args:
    dim: number of input/output channels, e.g. (96, 192, 384, 768).
    heads: number of attention heads, e.g. (3, 6, 12, 24).
    head_dim: channels per head, e.g. 32.
    shifted: if true, the input is cyclically shifted by `window_size // 2`
      before attention (and shifted back afterwards) and the wrap-around
      positions are masked.
    window_size: side length of a window, 7 by default in this notebook.
    relative_pos_embeddings: if true use a (2*window_size-1, 2*window_size-1)
      relative table, otherwise a (window_size**2, window_size**2) absolute one.
  """

  def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embeddings):
    super().__init__()
    inner_dim = head_dim * heads  # 32 * 3 = 96, 6 * 32 = 192, 12 * 32 = 384, 24 * 32 = 768
    self.heads = heads
    # Scaling of the V1 dot-product attention; not used by the cosine-similarity
    # forward pass below, kept as an attribute for compatibility.
    self.scale = head_dim ** -.5
    self.window_size = window_size
    self.relative_pos_embeddings = relative_pos_embeddings
    self.shifted = shifted

    # The TAU parameter for Version 2 cosine similarity.
    # NOTE(review): tau is trainable and nothing in this class keeps it above
    # 0.01 (or even positive) during training -- confirm whether it should be
    # constrained.
    self.tau = nn.Parameter(torch.tensor(.01), requires_grad=True)

    # If shifted is True, the windows are shifted right and down (a cyclic roll
    # of the feature map by a negative displacement) and shifted back at the end.
    if self.shifted:
      displacement = window_size // 2   # 7 // 2 = 3
      self.cyclic_shift = CyclicShift(-displacement)
      self.cyclic_back_shift = CyclicShift(displacement)

      # (49, 49) masks are NOT learnable parameters; requires_grad is False
      self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size,
                                                       displacement=displacement,
                                                       upper_lower=True,
                                                       left_right=False),
                                           requires_grad=False)
      self.left_right_mask = nn.Parameter(create_mask(window_size=window_size,
                                                      displacement=displacement,
                                                      upper_lower=False,
                                                      left_right=True),
                                          requires_grad=False)

    # Queries, Keys and Values come from one projection and are split afterwards.
    self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

    if relative_pos_embeddings:
      # Offsets lie in [-(window_size-1), window_size-1]; shift them to be valid
      # indices into the table. Registered as a buffer so that it follows the
      # module across devices; non-persistent so state_dict keys are unchanged.
      self.register_buffer('relative_indices',
                           get_relative_distances(window_size) + window_size - 1,
                           persistent=False)
      # (13, 13) for window_size=7: 6 possible offsets before and after a cell, plus 0
      self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
    else:
      # Absolute positional embedding
      self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2)) # (49, 49)

    # Project the concatenated heads back to `dim` channels.
    self.to_out = nn.Linear(inner_dim, dim)

  def forward(self, x):
    """Apply window attention to `x` of shape (batch, height, width, dim).

    Returns a tensor of the same shape.
    """
    if self.shifted:
      x = self.cyclic_shift(x)

    # x.shape is batch_size, height, width, channels
    _, height, width, _ = x.shape
    heads = self.heads

    # chunk splits the tripled channel axis into 3 tensors (Q, K and V)
    qkv = self.to_qkv(x).chunk(3, dim=-1)

    # Number of windows along height and width (56 // 7 = 8, 28 // 7 = 4, ...)
    n_wh = height // self.window_size
    n_ww = width // self.window_size

    # (b, H, W, heads*d) -> (b, heads, windows, tokens per window, d)
    q, k, v = map(lambda t: rearrange(t, 'b (n_wh w_h) (n_ww w_w) (h d) -> b h (n_wh n_ww) (w_h w_w) d',
                                      h = heads, w_h = self.window_size, w_w = self.window_size),
                  qkv)

    # Cosine similarity: normalize every query/key vector to unit length first.
    # (The V1 variant would instead use the plain dot product times self.scale.)
    q = F.normalize(q, p=2.0, dim=-1)
    k = F.normalize(k, p=2.0, dim=-1)

    # cosine similarity divided by self.tau
    dots = einsum(q, k, 'b h w i d, b h w j d -> b h w i j') / self.tau
    # b-batch size, h-heads, w-windows, i = j = tokens per window (49)

    # Add the positional embedding, broadcast over batch, heads and windows
    if self.relative_pos_embeddings:
      dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
    else:
      # Absolute embeddings
      dots += self.pos_embedding

    # Add masking
    if self.shifted:
      # Windows are numbered row by row: the last n_ww windows are the bottom row ...
      dots[:, :, -n_ww:] += self.upper_lower_mask
      # ... and every n_ww-th window starting at n_ww-1 is the right-most column.
      dots[:, :, n_ww-1::n_ww] += self.left_right_mask

    attn = dots.softmax(dim=-1) # (batch_size, heads, windows, i=49, j=49)

    # Weighted sum of the values
    out = einsum(attn, v, 'b h w i j, b h w j d -> b h w i d')

    # Merge heads and windows back into (b, H, W, heads*d)
    out = rearrange(out, 'b h (n_wh n_ww) (w_h w_w) d -> b (n_wh w_h) (n_ww w_w) (h d)',
                    h=heads, w_h=self.window_size, w_w=self.window_size, n_wh=n_wh, n_ww=n_ww)

    out = self.to_out(out)

    if self.shifted:
      # Undo the cyclic shift so that the next block sees the original layout
      out = self.cyclic_back_shift(out)

    return out

#### Test WindowAttention

In [None]:
# Shifted window attention with the first-stage settings used in this notebook.
w_attn = WindowAttention(
    dim=96,
    heads=3,
    head_dim=32,
    shifted=True,
    window_size=7,
    relative_pos_embeddings=True,
)

# Channels-last input: (batch, height, width, channels).
dummy_in = torch.randn(2, 56, 56, 96)

out = w_attn(dummy_in)
print(f'Out shape: {out.shape}')

Out shape: torch.Size([2, 56, 56, 96])


In [None]:
class SwinBlock(nn.Module):
  """One Swin block: window attention followed by an MLP.

  Each of the two sub-blocks is wrapped as Residual(PreNorm(dim, ...)).
  In this notebook dim is one of (96, 192, 384, 768), heads one of
  (3, 6, 12, 24) and mlp_dim = dim * 4.
  """

  def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embeddings):
    super().__init__()
    attention = WindowAttention(dim=dim, heads=heads, head_dim=head_dim,
                                shifted=shifted, window_size=window_size,
                                relative_pos_embeddings=relative_pos_embeddings)
    self.attention_block = Residual(PreNorm(dim, attention))

    feed_forward = FeedForward(dim=dim, hidden_dim=mlp_dim)
    self.mlp_block = Residual(PreNorm(dim, feed_forward))

  def forward(self, x):
    """Run attention and then the MLP; the tensor shape is unchanged."""
    attended = self.attention_block(x)
    return self.mlp_block(attended)


#### Test SwinBlock

In [None]:
# Regular (non-shifted) block with the first-stage settings; the MLP is 4x wider than dim.
sb = SwinBlock(
    dim=96,
    heads=3,
    head_dim=32,
    mlp_dim=96*4,
    shifted=False,
    window_size=7,
    relative_pos_embeddings=True,
)

# Channels-last input: (batch, height, width, channels).
dummy_in = torch.randn(2, 56, 56, 96)

out = sb(dummy_in)
print(f'Out shape: {out.shape}')

Out shape: torch.Size([2, 56, 56, 96])


In [None]:
class StageModule(nn.Module):
  """One Swin stage: patch merging followed by (regular, shifted) block pairs."""

  def __init__(self, in_channels, hidden_dim, layers, downscalling_factor,
               num_heads, head_dim, window_size, relative_pos_embeddings):
    super().__init__()
    assert layers % 2 == 0, 'Stage Layers must be divisible by 2 for regular and shifted blocks'

    self.patch_partition = PatchMerging_Conv(in_channels=in_channels, out_channels=hidden_dim,
                                             downscalling_factor=downscalling_factor)

    # Everything the two blocks of a pair share; only `shifted` differs.
    block_kwargs = dict(dim=hidden_dim, heads=num_heads, head_dim=head_dim,
                        mlp_dim=hidden_dim*4, window_size=window_size,
                        relative_pos_embeddings=relative_pos_embeddings)

    self.layers = nn.ModuleList([
        nn.ModuleList([
            SwinBlock(shifted=False, **block_kwargs),
            SwinBlock(shifted=True, **block_kwargs),
        ])
        for _ in range(layers // 2)
    ])

  def forward(self, x):
    '''
    Takes a channels-first tensor (batch, in_channels, H, W) and returns a
    channels-first tensor (batch, hidden_dim, H', W').

    With the settings used in this notebook the four stages receive
      (batch, 3, 224, 224), (batch, 96, 56, 56),
      (batch, 192, 28, 28) and (batch, 384, 14, 14)
    and, after patch merging, work on channels-last tensors of shape
      (batch, 56, 56, 96), (batch, 28, 28, 192),
      (batch, 14, 14, 384) and (batch, 7, 7, 768).
    The Swin blocks do not change that shape.
    '''
    x = self.patch_partition(x)  # channels-first -> channels-last, downscaled

    for regular_block, shifted_block in self.layers:
      x = shifted_block(regular_block(x))

    # Back to channels-first for the next stage, e.g. (1, 768, 7, 7) after the last one.
    return x.permute(0, 3, 1, 2)

In [None]:
class SWINTransformer(nn.Module):
  """Four-stage Swin classifier: stages, global average pooling, linear head.

  The channel width doubles at every stage (hidden_dim, x2, x4, x8);
  `layers`, `heads` and `downscaling_factors` are indexed per stage.
  """

  def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=1000,
               head_dim=32, window_size=7, downscaling_factors=(4, 2, 2, 2),
               relative_pos_embeddings=True):
    super().__init__()
    # Settings that are identical for all four stages.
    shared = dict(head_dim=head_dim, window_size=window_size,
                  relative_pos_embeddings=relative_pos_embeddings)
    widths = (hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8)

    self.stage1 = StageModule(in_channels=channels, hidden_dim=widths[0],
                              layers=layers[0], downscalling_factor=downscaling_factors[0],
                              num_heads=heads[0], **shared)
    self.stage2 = StageModule(in_channels=widths[0], hidden_dim=widths[1],
                              layers=layers[1], downscalling_factor=downscaling_factors[1],
                              num_heads=heads[1], **shared)
    self.stage3 = StageModule(in_channels=widths[1], hidden_dim=widths[2],
                              layers=layers[2], downscalling_factor=downscaling_factors[2],
                              num_heads=heads[2], **shared)
    self.stage4 = StageModule(in_channels=widths[2], hidden_dim=widths[3],
                              layers=layers[3], downscalling_factor=downscaling_factors[3],
                              num_heads=heads[3], **shared)

    self.mlp_head = nn.Sequential(
        nn.LayerNorm(widths[3]),
        nn.Linear(widths[3], num_classes),
    )

  def forward(self, img):
    """Return class logits of shape (batch, num_classes) for a channels-first image batch."""
    x = img
    for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
      x = stage(x)

    # Global average pooling over the two spatial axes.
    pooled = x.mean(dim=[2, 3])

    return self.mlp_head(pooled)

In [None]:
def swin_t(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
  """Build a SWINTransformer with these defaults; extra keyword arguments are passed through."""
  config = dict(kwargs, hidden_dim=hidden_dim, layers=layers, heads=heads)
  return SWINTransformer(**config)

In [None]:
# Build the model for a 3-class problem and run one random 224x224 RGB image through it.
net = swin_t(
    channels=3, num_classes=3,
    hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24),
    head_dim=32, window_size=7, downscaling_factors=(4, 2, 2, 2),
    relative_pos_embeddings=True,
)

dummy_h = torch.randn(1, 3, 224, 224)

# One row of logits, one value per class.
loggits = net(dummy_h)
print(loggits)

tensor([[-1.2585, -0.5348, -1.1451]], grad_fn=<AddmmBackward0>)
