In [32]:
import torch
from torch import nn, einsum
import numpy as np
from einops import rearrange, repeat

In [33]:
class CyclicShift(nn.Module):
    def __init__(self, displacement):
        super().__init__()
        self.displacement = displacement

    def forward(self, x):
        print("torch_roll : " , torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2)).shape)
        return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))

In [34]:
## torch.roll : 주어진 차원을 따라 텐서를 굴림
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]).view(4,4)
x

tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])

In [35]:
torch.roll(x, 1, 0) # 0행을 1행으로 

tensor([[13, 14, 15, 16],
        [ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])

In [36]:
torch.roll(x, 2, 0)

tensor([[ 9, 10, 11, 12],
        [13, 14, 15, 16],
        [ 1,  2,  3,  4],
        [ 5,  6,  7,  8]])

In [37]:
torch.roll(x, 2, 1) # 1열까지를 2열 옮기기 

tensor([[ 3,  4,  1,  2],
        [ 7,  8,  5,  6],
        [11, 12,  9, 10],
        [15, 16, 13, 14]])

In [38]:
torch.roll(x, shifts=(2, 1), dims=(0, 1))  # 0행을 2행씩 밀고 / 1열을 1 열씩 밀자  dims = 0 : 행 / 1 : 열
#torch.roll(x, shifts=1, dims=1 )

tensor([[12,  9, 10, 11],
        [16, 13, 14, 15],
        [ 4,  1,  2,  3],
        [ 8,  5,  6,  7]])

In [39]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

In [40]:
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

In [41]:
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)

In [42]:
def create_mask(window_size, displacement, upper_lower, left_right):
    mask = torch.zeros(window_size ** 2, window_size ** 2)

    if upper_lower:
        mask[-displacement * window_size:, :-displacement * window_size] = float('-inf')
        mask[:-displacement * window_size, -displacement * window_size:] = float('-inf')

    if left_right:
        mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size)
        mask[:, -displacement:, :, :-displacement] = float('-inf')
        mask[:, :-displacement, :, -displacement:] = float('-inf')
        mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')
        print("mask shape:", mask.shape)
    return mask

In [43]:
window_size =  7 
displacement = window_size // 2
displacement

3

In [91]:
mask = torch.zeros(window_size ** 2, window_size ** 2)
mask #torch.Size([49, 49])

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [92]:
mask[-displacement * window_size:, :-displacement * window_size] # 21 / 49-21 = 28 

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [85]:
mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size, h2=window_size) # torch.Size([7, 7, 7, 7]) # 해제 
mask[:, -displacement:, :, :-displacement] #torch.Size([7, 3, 7, 4])
mask[:, :-displacement, :, -displacement:] #torch.Size([7, 4, 7, 3])

tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
    

In [46]:
mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)') ## 다시 결합 
mask.shape # torch.Size([49, 49])

torch.Size([49, 49])

In [47]:
def get_relative_distances(window_size):
    indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))
    distances = indices[None, :, :] - indices[:, None, :]
    print("distances.shape:", distances.shape)
    return distances

In [48]:
indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))
distances = indices[None, :, :] - indices[:, None, :]
distances + window_size - 1 # distances + window_size - 1

tensor([[[ 6,  6],
         [ 6,  7],
         [ 6,  8],
         ...,
         [12, 10],
         [12, 11],
         [12, 12]],

        [[ 6,  5],
         [ 6,  6],
         [ 6,  7],
         ...,
         [12,  9],
         [12, 10],
         [12, 11]],

        [[ 6,  4],
         [ 6,  5],
         [ 6,  6],
         ...,
         [12,  8],
         [12,  9],
         [12, 10]],

        ...,

        [[ 0,  2],
         [ 0,  3],
         [ 0,  4],
         ...,
         [ 6,  6],
         [ 6,  7],
         [ 6,  8]],

        [[ 0,  1],
         [ 0,  2],
         [ 0,  3],
         ...,
         [ 6,  5],
         [ 6,  6],
         [ 6,  7]],

        [[ 0,  0],
         [ 0,  1],
         [ 0,  2],
         ...,
         [ 6,  4],
         [ 6,  5],
         [ 6,  6]]], dtype=torch.int32)

In [94]:
class WindowAttention(nn.Module):
    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        inner_dim = head_dim * heads 

        self.heads = heads
        self.scale = head_dim ** -0.5 # dim scaling 
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted

        if self.shifted:
            displacement = window_size // 2
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement) # back
            self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                             upper_lower=True, left_right=False), requires_grad=False)
            self.left_right_mask = nn.Parameter(create_mask(window_size=window_size, displacement=displacement,
                                                            upper_lower=False, left_right=True), requires_grad=False)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) # dim > inner_dim * 3 channel 
  

        if self.relative_pos_embedding:
            self.relative_indices = get_relative_distances(window_size) + window_size - 1
            self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1))
        else:
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))

        self.to_out = nn.Linear(inner_dim, dim) # inner_dim > dim channel

    def forward(self, x):
        if self.shifted:
            x = self.cyclic_shift(x) # CyclicShift
        print("cyclic_shift", x.shape)
        b, n_h, n_w, _, h = *x.shape, self.heads
        print("b, n_h, n_w, _, h ", b, n_h, n_w, _, h )
        
        qkv = self.to_qkv(x).chunk(3, dim=-1) # 분할 
        print("to qkv :" , self.to_qkv(x).shape)
        print("qkv chunk :", len(qkv))
        nw_h = n_h // self.window_size
        nw_w = n_w // self.window_size

        q, k, v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                                h=h, w_h=self.window_size, w_w=self.window_size), qkv) # shape 변경 
        print("q:", q.shape)
        print("k:", q.shape)
        print("v:", q.shape)
        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale # query and key dots 
        print("dots : ", dots.shape)
        if self.relative_pos_embedding:
            dots += self.pos_embedding[self.relative_indices[:, :, 0].type(torch.long), self.relative_indices[:, :, 1].type(torch.long)]
        else:
            dots += self.pos_embedding
        print("relative_pos_embedding dots : ", dots.shape)
        if self.shifted:
            dots[:, :, -nw_w:] += self.upper_lower_mask
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask
        print("shifted dots : ", dots.shape)
        attn = dots.softmax(dim=-1)
        print("attn : ", len(attn))
        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        out = self.to_out(out)

        if self.shifted:
            out = self.cyclic_back_shift(out) 
        print("out attention fin : " ,out.shape)
        return out

In [95]:
As = torch.randn(3,2,5)
Bs = torch.randn(3,5,4)
torch.einsum('bij,bjk->bik', As, Bs).shape # 3 2 5 * 3 5 4 >  3 2 4 

torch.Size([3, 2, 4])

In [96]:
class SwinBlock(nn.Module):
    def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        self.attention_block = Residual(PreNorm(dim, WindowAttention(dim=dim,
                                                                     heads=heads,
                                                                     head_dim=head_dim,
                                                                     shifted=shifted,
                                                                     window_size=window_size,
                                                                     relative_pos_embedding=relative_pos_embedding)))
        self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim)))

    def forward(self, x):
        x = self.attention_block(x)
        print("attention_block shape:", x.shape)
        x = self.mlp_block(x)
        print("mlp_block shape:", x.shape)
        return x

In [97]:
class PatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        self.downscaling_factor = downscaling_factor
        self.patch_merge = nn.Unfold(kernel_size=downscaling_factor, stride=downscaling_factor, padding=0)
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)

    def forward(self, x):
        b, c, h, w = x.shape
        new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
        x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1)
        print("patch partition shape :", x.shape)
        x = self.linear(x)
        print("patch linear shape :", x.shape)
        return x

In [98]:
class StageModule(nn.Module):
    def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads, head_dim, window_size,
                 relative_pos_embedding):
        super().__init__()
        assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'

        self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension,
                                            downscaling_factor=downscaling_factor)

        self.layers = nn.ModuleList([])
        for _ in range(layers // 2):
            self.layers.append(nn.ModuleList([
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                          shifted=False, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, mlp_dim=hidden_dimension * 4,
                          shifted=True, window_size=window_size, relative_pos_embedding=relative_pos_embedding),
            ]))

    def forward(self, x):
        x = self.patch_partition(x)
        for regular_block, shifted_block in self.layers:
            x = regular_block(x)
            x = shifted_block(x)
        print("regular / shifted :", x.permute(0, 3, 1, 2).shape)
        return x.permute(0, 3, 1, 2)

In [99]:
class SwinTransformer(nn.Module):
    def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        super().__init__()

        self.stage1 = StageModule(in_channels=channels, hidden_dimension=hidden_dim, layers=layers[0],
                                  downscaling_factor=downscaling_factors[0], num_heads=heads[0], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        self.stage2 = StageModule(in_channels=hidden_dim, hidden_dimension=hidden_dim * 2, layers=layers[1],
                                  downscaling_factor=downscaling_factors[1], num_heads=heads[1], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        self.stage3 = StageModule(in_channels=hidden_dim * 2, hidden_dimension=hidden_dim * 4, layers=layers[2],
                                  downscaling_factor=downscaling_factors[2], num_heads=heads[2], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)
        self.stage4 = StageModule(in_channels=hidden_dim * 4, hidden_dimension=hidden_dim * 8, layers=layers[3],
                                  downscaling_factor=downscaling_factors[3], num_heads=heads[3], head_dim=head_dim,
                                  window_size=window_size, relative_pos_embedding=relative_pos_embedding)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim * 8),
            nn.Linear(hidden_dim * 8, num_classes)
        )

    def forward(self, img):
        x = self.stage1(img)
        print("stage1 shape:", x.shape)
        x = self.stage2(x)
        print("stage2 shape:", x.shape)
        x = self.stage3(x)
        print("stage3 shape:", x.shape)
        x = self.stage4(x)
        print("stage4 shape:", x.shape)
        x = x.mean(dim=[2, 3])
        print("mean shape:", x.shape)
        print("mlp_head shape:", self.mlp_head(x).shape)
        return self.mlp_head(x)

In [100]:
  
def swin_t(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)


def swin_s(hidden_dim=96, layers=(2, 2, 18, 2), heads=(3, 6, 12, 24), **kwargs):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)


def swin_b(hidden_dim=128, layers=(2, 2, 18, 2), heads=(4, 8, 16, 32), **kwargs):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)


def swin_l(hidden_dim=192, layers=(2, 2, 18, 2), heads=(6, 12, 24, 48), **kwargs):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)

# Swin-T model
10장

In [101]:
net = SwinTransformer(
    hidden_dim=96,
    layers=(2, 2, 6, 2), #architecture swin block 
    heads=(3, 6, 12, 24),
    channels=3,
    num_classes=3,
    head_dim=32,
    window_size=7,
    downscaling_factors=(4, 2, 2, 2),
    relative_pos_embedding=True
)

distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])


In [102]:
net = swin_t()

distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])


In [103]:
dummy_x = torch.randn(10, 3, 224, 224)
dummy_x.shape

torch.Size([10, 3, 224, 224])

In [104]:
logits = net(dummy_x)  # (1,3)
logits

patch partition shape : torch.Size([10, 56, 56, 48])
patch linear shape : torch.Size([10, 56, 56, 96])
cyclic_shift torch.Size([10, 56, 56, 96])
b, n_h, n_w, _, h  10 56 56 96 3
to qkv : torch.Size([10, 56, 56, 288])
qkv chunk : 3
q: torch.Size([10, 3, 64, 49, 32])
k: torch.Size([10, 3, 64, 49, 32])
v: torch.Size([10, 3, 64, 49, 32])
dots :  torch.Size([10, 3, 64, 49, 49])
relative_pos_embedding dots :  torch.Size([10, 3, 64, 49, 49])
shifted dots :  torch.Size([10, 3, 64, 49, 49])
attn :  10
out attention fin :  torch.Size([10, 56, 56, 96])
attention_block shape: torch.Size([10, 56, 56, 96])
mlp_block shape: torch.Size([10, 56, 56, 96])
torch_roll :  torch.Size([10, 56, 56, 96])
cyclic_shift torch.Size([10, 56, 56, 96])
b, n_h, n_w, _, h  10 56 56 96 3
to qkv : torch.Size([10, 56, 56, 288])
qkv chunk : 3
q: torch.Size([10, 3, 64, 49, 32])
k: torch.Size([10, 3, 64, 49, 32])
v: torch.Size([10, 3, 64, 49, 32])
dots :  torch.Size([10, 3, 64, 49, 49])
relative_pos_embedding dots :  torch.S

tensor([[-0.0329,  0.4933,  1.0318,  ..., -0.3727,  0.2248,  0.1930],
        [-0.2262,  0.3669,  1.2623,  ..., -0.2189,  0.2762,  0.1446],
        [-0.0309,  0.5681,  1.1785,  ..., -0.1589,  0.1900, -0.0177],
        ...,
        [-0.1591,  0.4525,  1.2227,  ..., -0.3526,  0.3906,  0.2651],
        [-0.1657,  0.3989,  1.3800,  ..., -0.1764,  0.3076,  0.2358],
        [-0.0176,  0.4378,  1.1322,  ..., -0.1509,  0.3652,  0.1394]],
       grad_fn=<AddmmBackward>)

In [63]:
print(net)

SwinTransformer(
  (stage1): StageModule(
    (patch_partition): PatchMerging(
      (patch_merge): Unfold(kernel_size=4, dilation=1, padding=0, stride=4)
      (linear): Linear(in_features=48, out_features=96, bias=True)
    )
    (layers): ModuleList(
      (0): ModuleList(
        (0): SwinBlock(
          (attention_block): Residual(
            (fn): PreNorm(
              (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
              (fn): WindowAttention(
                (to_qkv): Linear(in_features=96, out_features=288, bias=False)
                (to_out): Linear(in_features=96, out_features=96, bias=True)
              )
            )
          )
          (mlp_block): Residual(
            (fn): PreNorm(
              (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
              (fn): FeedForward(
                (net): Sequential(
                  (0): Linear(in_features=96, out_features=384, bias=True)
                  (1): GELU()
               

In [64]:
print(logits.shape) # 10장의 1000 개 class 예측 확률

torch.Size([10, 1000])


# Swin-B model
5장

In [80]:
net_B = SwinTransformer(
    hidden_dim=96,
    layers=(2, 2, 6, 2), #architecture swin block 
    heads=(3, 6, 12, 24),
    channels=3,
    num_classes=3,
    head_dim=32,
    window_size=7,
    downscaling_factors=(4, 2, 2, 2),
    relative_pos_embedding=True
)

distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])


In [81]:
net_B = swin_b()

distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distances.shape: torch.Size([49, 49, 2])
distances.shape: torch.Size([49, 49, 2])
mask shape: torch.Size([49, 49])
distan

In [82]:
dummy_y = torch.randn(5, 3, 224, 224)
dummy_y.shape


torch.Size([5, 3, 224, 224])

In [83]:
logits_y = net_B(dummy_y)  
logits_y

patch partition shape : torch.Size([5, 56, 56, 48])
patch linear shape : torch.Size([5, 56, 56, 128])
cyclic_shift torch.Size([5, 56, 56, 128])
b, n_h, n_w, _, h  5 56 56 128 4
qkv: 3
q: torch.Size([5, 4, 64, 49, 32])
k: torch.Size([5, 4, 64, 49, 32])
v: torch.Size([5, 4, 64, 49, 32])
dots :  torch.Size([5, 4, 64, 49, 49])
relative_pos_embedding dots :  torch.Size([5, 4, 64, 49, 49])
shifted dots :  torch.Size([5, 4, 64, 49, 49])
attn :  5
out attention fin :  torch.Size([5, 56, 56, 128])
attention_block shape: torch.Size([5, 56, 56, 128])
mlp_block shape: torch.Size([5, 56, 56, 128])
torch_roll :  torch.Size([5, 56, 56, 128])
cyclic_shift torch.Size([5, 56, 56, 128])
b, n_h, n_w, _, h  5 56 56 128 4
qkv: 3
q: torch.Size([5, 4, 64, 49, 32])
k: torch.Size([5, 4, 64, 49, 32])
v: torch.Size([5, 4, 64, 49, 32])
dots :  torch.Size([5, 4, 64, 49, 49])
relative_pos_embedding dots :  torch.Size([5, 4, 64, 49, 49])
shifted dots :  torch.Size([5, 4, 64, 49, 49])
attn :  5
torch_roll :  torch.Siz

tensor([[-1.0724, -0.5715, -0.1288,  ..., -1.0682, -0.7317, -0.0189],
        [-1.1644, -0.4917, -0.2390,  ..., -0.9966, -0.8087, -0.0611],
        [-1.0819, -0.5840, -0.0381,  ..., -0.9123, -0.6584, -0.2453],
        [-1.1570, -0.3997, -0.1508,  ..., -0.9247, -0.6607, -0.1020],
        [-1.1046, -0.4589, -0.2040,  ..., -0.9864, -0.7731, -0.1307]],
       grad_fn=<AddmmBackward>)

In [84]:
logits_y.shape

torch.Size([5, 1000])