In [1]:
import torch
from torch import nn, einsum
import numpy as np
from einops import rearrange, repeat
import torch.nn.functional as f
from torchinfo import summary

In [2]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x
    
class PreNorm(nn.Module):
    ## in swin v2, we use post norm , so we can change the forward
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        #return self.fn(self.norm(x), **kwargs) ## pernorm for version 1
        return self.norm(self.fn(x, **kwargs))  ## postnorm for version 2

class FeedForward(nn.Module):
    ## hidden_dim is the 4 * input_channels
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)

In [3]:
class CyclicShift(nn.Module):
    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement

    def forward(self, x):
        return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))

In [4]:
#### test CyclicShift  ########

x = torch.linspace(1, 81, 81).view(9,9)
print(x)

print('-----------')
y = torch.roll(input = x, shifts = (-1, -1), dims = (0,1))
print(y)

#### end test CyclicShift #####

tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18.],
        [19., 20., 21., 22., 23., 24., 25., 26., 27.],
        [28., 29., 30., 31., 32., 33., 34., 35., 36.],
        [37., 38., 39., 40., 41., 42., 43., 44., 45.],
        [46., 47., 48., 49., 50., 51., 52., 53., 54.],
        [55., 56., 57., 58., 59., 60., 61., 62., 63.],
        [64., 65., 66., 67., 68., 69., 70., 71., 72.],
        [73., 74., 75., 76., 77., 78., 79., 80., 81.]])
-----------
tensor([[11., 12., 13., 14., 15., 16., 17., 18., 10.],
        [20., 21., 22., 23., 24., 25., 26., 27., 19.],
        [29., 30., 31., 32., 33., 34., 35., 36., 28.],
        [38., 39., 40., 41., 42., 43., 44., 45., 37.],
        [47., 48., 49., 50., 51., 52., 53., 54., 46.],
        [56., 57., 58., 59., 60., 61., 62., 63., 55.],
        [65., 66., 67., 68., 69., 70., 71., 72., 64.],
        [74., 75., 76., 77., 78., 79., 80., 81., 73.],
        [ 2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.,  1.

In [5]:
def create_mask(window_size, displacement, upper_lower, left_right):
    mask = torch.zeros(window_size ** 2, window_size ** 2)

    if upper_lower:
        mask[-displacement * window_size:, :-displacement * window_size] = float('-inf')
        mask[:-displacement * window_size, -displacement * window_size:] = float('-inf')

    if left_right:
        mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
        mask[:, -displacement:, :, :-displacement] = float('-inf')
        mask[:, :-displacement, :, -displacement:] = float('-inf')
        mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')

    return mask

In [6]:
#### test create_mask  ############
create_mask(window_size = 3 , displacement = 1, upper_lower= False, left_right = True)
#### end test create_mask  ########

tensor([[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
        [-inf, -inf, 0., -inf, -inf, 0., -inf, -inf, 0.]])

In [7]:
def get_relative_distances(window_size):
    indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))
    distances = indices[None, :, :] - indices[:, None, :]
    return distances

In [8]:
class WindowAttention(nn.Module):
    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):

        ## dim = hidden_dim = (96, 192, 384, 768)
        ## head = (3, 6, 12, 24)
        ## head_dim  = 32

        super().__init__()
        inner_dim = head_dim * heads  ## hidden_dim [96, 192, 384, 768]
        self.heads = heads
        self.scale = head_dim ** -0.5 ## like in transformer, we add 1/sqrt(d)
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted ## SW-MSA else W-MSA

        if self.shifted:
            displacement = window_size // 2 ## define how many pixel to shift: usually half of the window
            self.cyclic_shift = CyclicShift(-displacement)     ## cycleshift
            self.cyclic_back_shift = CyclicShift(displacement) ## shift back
            self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                             upper_lower=True, left_right=False), requires_grad=False)
            self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                            upper_lower=False, left_right=True), requires_grad=False)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if self.relative_pos_embedding:
            self.relative_indices = get_relative_distances(window_size) + window_size - 1
            self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
        else:
            ## absolute positional embedding
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))

        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        ## x shape:  [batchsize, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768)]
        if self.shifted:
            x = self.cyclic_shift(x)   ##  [batchsize, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768)]

        b, n_h, n_w, _, h = *x.shape, self.heads

        ## after self.to_qkv(x) shape is:  [batchsize, (56, 28, 14, 7), (56, 28, 14, 7), (288, 576, 1152, 2304)]
        qkv = self.to_qkv(x).chunk(3, dim=-1) ## divided three distict tensor based on the last dim
        ### qkv[0] : [batchsize, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768)]
        ### qkv[1] : [batchsize, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768)]
        ### qkv[2] : [batchsize, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768)]

        nw_h = n_h // self.window_size ## how many windows we have 
        nw_w = n_w // self.window_size ## [56 // 7, 28 // 7, 14 // 7, 7 // 7] => [8, 4, 2, 1]

        q, k, v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                                h=h, w_h=self.window_size, w_w=self.window_size), qkv)
        
        ### q: [batchsize, (3, 6, 12, 24), (64, 16, 4, 1), 49, 32]
        ### k: [batchsize, (3, 6, 12, 24), (64, 16, 4, 1), 49, 32]
        ### v: [batchsize, (3, 6, 12, 24), (64, 16, 4, 1), 49, 32]

        #### new version on swin v2 ####
        self.tau = nn.Parameter(torch.tensor(0.01), requires_grad=True)

        q = f.normalize(q, p=2 , dim = -1)
        k = f.normalize(k, p=2 , dim = -1)

        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) / self.tau
        #### ###################### ###

        # uncomment for swin v1
        # dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale

        if self.relative_pos_embedding:
            dots += self.pos_embedding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]]
        else:
            dots += self.pos_embedding
            ## b: batchsize
            ## h: # of heads: (3, 6, 12, 24)
            ## w: (64, 16, 4, 1)
            ## i: 49 
            ## j: 49

        if self.shifted:
            dots[:, :, -nw_w:] += self.upper_lower_mask
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask

        attn = dots.softmax(dim=-1)
        ## batchsize, h = (3, 6, 12, 24), w= (64, 16, 4, 1), i = 49, j = 49

        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        ## batchsize, h = (3, 6, 12, 24), (nw_h * nw_w) = (64, 16, 4, 1), (w_h * w_w) =49, 32)

        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        out = self.to_out(out)

        ## [batchsize, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768)]

        if self.shifted:
            out = self.cyclic_back_shift(out)
        return out

In [9]:
class SwinBlock(nn.Module):
    def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
        ## dim = hidden_dim = (96, 192, 384, 768)
        ## head = (3, 6, 12, 24)
        ## mlp_dim = hidden_dim * 4
        super().__init__()
        self.attention_block = Residual(PreNorm(dim, WindowAttention(dim=dim,
                                                                     heads=heads,
                                                                     head_dim=head_dim,
                                                                     shifted=shifted,
                                                                     window_size=window_size,
                                                                     relative_pos_embedding=relative_pos_embedding)))
        self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim)))

    def forward(self, x):
        x = self.attention_block(x)  # [batchsize, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768)]
        x = self.mlp_block(x)        # [batchsize, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768)]
        return x    

In [10]:
class PatchMerging_Conv(nn.Module):
    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()

        self.patch_merge = nn.Conv2d(
            in_channels  = in_channels,
            out_channels = out_channels,
            kernel_size  = downscaling_factor,
            stride = downscaling_factor,
            padding = 0
        )

    def forward(self, x): 
        # x shape: (batchsize, (3, 96, 192, 384), (224, 56, 28, 14), (224, 56, 28, 14))
        x = self.patch_merge(x)      # [batchsize, (96, 192, 384, 768), (56, 28, 14, 7), (56, 28, 14, 7) ]
        return x.permute(0, 2, 3, 1) # [batchsize, (56, 28, 14, 7),     (56, 28, 14, 7), (96, 192, 384, 768)]


In [11]:
class StageModule(nn.Module):
    def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads, head_dim, window_size,
                 relative_pos_embedding):
        super().__init__()
        assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'

        self.patch_partition = PatchMerging_Conv(in_channels=in_channels, out_channels=hidden_dimension,
                                            downscaling_factor=downscaling_factor)

        self.layers = nn.ModuleList([])
        for _ in range(layers // 2):
            self.layers.append(nn.ModuleList([
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                          shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                          shifted=True, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
            ]))

    def forward(self, x):
        ## x shape: (batchsize, (3, 96, 192, 384), (224, 56, 28, 14), (224, 56, 28, 14))
        x = self.patch_partition(x)     # [batchsize, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768)]
        for regular_block, shifted_block in self.layers:
            x = regular_block(x)        # [batchsize, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768)]
            x = shifted_block(x)        # [batchsize, (56, 28, 14, 7), (56, 28, 14, 7), (96, 192, 384, 768)]
        return x.permute(0, 3, 1, 2)    # [batchsize, 768 , 7, 7]

In [12]:
class SwinTransformer(nn.Module):
    def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        super().__init__()

        self.stage1 = StageModule(in_channels=channels, hidden_dimension=hidden_dim, layers=layers[0],
                                  downscaling_factor=downscaling_factors[0], num_heads=heads[0], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        self.stage2 = StageModule(in_channels=hidden_dim, hidden_dimension=hidden_dim * 2, layers=layers[1],
                                  downscaling_factor=downscaling_factors[1], num_heads=heads[1], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        self.stage3 = StageModule(in_channels=hidden_dim * 2, hidden_dimension=hidden_dim * 4, layers=layers[2],
                                  downscaling_factor=downscaling_factors[2], num_heads=heads[2], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        self.stage4 = StageModule(in_channels=hidden_dim * 4, hidden_dimension=hidden_dim * 8, layers=layers[3],
                                  downscaling_factor=downscaling_factors[3], num_heads=heads[3], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim * 8),
            nn.Linear(hidden_dim * 8, num_classes)
        )

    def forward(self, img):
        ### image shape (batchsize, 3, 224, 224)
        x = self.stage1(img)       ## [batchsize, 96,  56, 56]
        x = self.stage2(x)         ## [batchsize, 192, 28 ,28]
        x = self.stage3(x)         ## [batchsize, 384, 14, 14]
        x = self.stage4(x)         ## [batchsize, 768, 7,  7 ]
        x = x.mean(dim=[2, 3])     ## [batchsize, 768]
        return self.mlp_head(x)    ## [batchsize, 1000]


def swin_t(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)

def swin_s(hidden_dim=96, layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)

def swin_b(hidden_dim=128, layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)

def swin_l(hidden_dim=192, layers=(2, 2, 18, 2), heads=(6, 12, 24, 48), **kwargs):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)

In [13]:
net = swin_t(channels = 3, 
             num_classes = 3, 
             head_dim = 32, 
             window_size = 7, 
             downscaling_factors = (4,2,2,2), 
             relative_pos_embedding = False)

img = torch.randn((1,3,224,224))
net(img).shape

torch.Size([1, 3])

In [14]:
summary(model=net, 
        input_size=(1, 3, 224, 224), # (batch_size, color_channels, height, width)
        # col_names=["input_size"], # uncomment for smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                                                     Input Shape          Output Shape         Param #              Trainable
SwinTransformer (SwinTransformer)                                           [1, 3, 224, 224]     [1, 3]               --                   Partial
├─StageModule (stage1)                                                      [1, 3, 224, 224]     [1, 96, 56, 56]      --                   Partial
│    └─PatchMerging_Conv (patch_partition)                                  [1, 3, 224, 224]     [1, 56, 56, 96]      --                   True
│    │    └─Conv2d (patch_merge)                                            [1, 3, 224, 224]     [1, 96, 56, 56]      4,704                True
│    └─ModuleList (layers)                                                  --                   --                   --                   Partial
│    │    └─ModuleList (0)                                                  --                   --                   232,