In [1]:
import torch.nn as nn
import torch
import einops
from math import sqrt

In [None]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, embed_size, num_heads, attention_store=None):
        super().__init__()
        self.queries_projection = nn.Linear(embed_size, embed_size)
        self.values_projection = nn.Linear(embed_size, embed_size)
        self.keys_projection = nn.Linear(embed_size, embed_size)
        self.final_projection = nn.Linear(embed_size, embed_size)
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.attention_store = attention_store

    def forward(self, x):
        assert len(x.shape) == 3
        keys = self.keys_projection(x)
        values = self.values_projection(x)
        queries = self.queries_projection(x)
        keys = einops.rearrange(keys, "b n (h e) -> b n h e", h=self.num_heads)
        queries = einops.rearrange(queries, "b n (h e) -> b n h e", h=self.num_heads)
        values = einops.rearrange(values, "b n (h e) -> b n h e", h=self.num_heads)
        energy_term = torch.einsum("bqhe, bkhe -> bqhk", queries, keys)
        print(energy_term.shape)
        divider = sqrt(self.embed_size)
        mh_out = torch.softmax(energy_term, -1)
        if self.attention_store is not None:
            self.attention_store.append(mh_out.detach().cpu())
        out = torch.einsum('bihv, bvhd -> bihd ', mh_out / divider, values)
        out = einops.rearrange(out, "b n h e -> b n (h e)")
        return self.final_projection(out)
    
x = torch.rand(4, 197, 768)
MultiHeadAttention(768, 8)(x).shape

In [2]:
import torch.nn.functional as F

class MultiHeadXCITAttention(torch.nn.Module):
    def __init__(self, embed_size, num_heads, attention_store=None):
        super().__init__()
        self.queries_projection = nn.Linear(embed_size, embed_size)
        self.values_projection = nn.Linear(embed_size, embed_size)
        self.keys_projection = nn.Linear(embed_size, embed_size)
        self.final_projection = nn.Linear(embed_size, embed_size)
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.attention_store = attention_store
        self.tau = torch.nn.Parameter(torch.ones(1))

    def forward(self, x):
        assert len(x.shape) == 3
        keys = self.keys_projection(x)
        values = self.values_projection(x)
        queries = self.queries_projection(x)
        keys = einops.rearrange(keys, "b n (h e) -> b n h e", h=self.num_heads)
        queries = einops.rearrange(queries, "b n (h e) -> b n h e", h=self.num_heads)
        values = einops.rearrange(values, "b n (h e) -> b n h e", h=self.num_heads)
        keys = F.normalize(keys, p=2, dim=1)
        queries = F.normalize(queries, p=2, dim=1)
        energy_term = torch.einsum("bnhe, bnhq -> behq", queries, keys)
        print(energy_term.shape)
        divider = sqrt(self.embed_size)
        mh_out = torch.softmax(energy_term, -1)
        if self.attention_store is not None:
            self.attention_store.append(mh_out.detach().cpu())
        out = torch.einsum('behq, bnhe -> bnhq ', mh_out / divider, values)
        print(out.shape)
        out = einops.rearrange(out, "b n h e -> b n (h e)")
        return self.final_projection(out)


MultiHeadXCITAttention(768, 8)(x).shape

torch.Size([4, 96, 8, 96])
torch.Size([4, 197, 8, 96])


torch.Size([4, 197, 768])

In [5]:
from einops.layers.torch import Reduce, Rearrange

class PatchEmbeddingPixelwise(torch.nn.Sequential):
    
    def __init__(self, stride, embedding_size, channels=3) -> None:
        reduce = Reduce("b c (w i) (h k) -> b (c i k) w h", "mean", i=stride, k=stride)
        rearange = Rearrange("b e h w -> b (h w) e")
        linear = torch.nn.Linear(stride * stride * channels, embedding_size)
        super().__init__(*[
            reduce,
            rearange,
            linear
        ])
        
imgs = torch.rand(4, 3, 512, 512)

PatchEmbeddingPixelwise(4, 128)(imgs).shape

torch.Size([4, 16384, 128])

In [6]:
x = einops.reduce(imgs, "b c (w i) (h k) -> b (c i k) w h", "mean", i=4, k=4)

print(x[0, :, 0, 0])
imgs[0, :, :4, :4]

tensor([0.5891, 0.4272, 0.8274, 0.4417, 0.0085, 0.5833, 0.0616, 0.0424, 0.4950,
        0.2065, 0.5875, 0.8506, 0.8575, 0.5328, 0.0374, 0.1130, 0.5295, 0.4659,
        0.5719, 0.1116, 0.0251, 0.0025, 0.9558, 0.0884, 0.9224, 0.9056, 0.3549,
        0.4874, 0.3215, 0.8398, 0.1643, 0.6919, 0.9742, 0.6123, 0.4258, 0.9993,
        0.1846, 0.2410, 0.5798, 0.6895, 0.7292, 0.6438, 0.4387, 0.5016, 0.5945,
        0.5443, 0.7958, 0.2468])


tensor([[[0.5891, 0.4272, 0.8274, 0.4417],
         [0.0085, 0.5833, 0.0616, 0.0424],
         [0.4950, 0.2065, 0.5875, 0.8506],
         [0.8575, 0.5328, 0.0374, 0.1130]],

        [[0.5295, 0.4659, 0.5719, 0.1116],
         [0.0251, 0.0025, 0.9558, 0.0884],
         [0.9224, 0.9056, 0.3549, 0.4874],
         [0.3215, 0.8398, 0.1643, 0.6919]],

        [[0.9742, 0.6123, 0.4258, 0.9993],
         [0.1846, 0.2410, 0.5798, 0.6895],
         [0.7292, 0.6438, 0.4387, 0.5016],
         [0.5945, 0.5443, 0.7958, 0.2468]]])

In [18]:
x = torch.rand(4, 8, 197, 96)
x.transpose(-2, -1).shape

torch.Size([4, 8, 96, 197])

In [12]:
class Conv3x3(nn.Sequential):
    
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__(*[
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels)
        ])


class ConvPatchEmbedding(nn.Module):
    
    def __init__(self, stride=16, embed_dim=768):
        super().__init__()
        num_conv_layers = int(torch.log2(torch.tensor(stride)))
        self.patch_embedding = self._get_patch_embedding(num_conv_layers, stride, embed_dim=embed_dim)
        print(self)
        
    def _get_patch_embedding(self, num_conv_layers, stride, embed_dim):
        embedding = [Conv3x3(in_channels=3, out_channels=embed_dim // (stride // 2), stride=2), nn.GELU()]
        for idx in range(num_conv_layers - 1, 1, -1):
            embedding += [Conv3x3(in_channels=embed_dim // (2 ** idx), out_channels=embed_dim // (2 ** (idx - 1)), stride=2),  nn.GELU()]
        embedding += [Conv3x3(in_channels=embed_dim // 2, out_channels=embed_dim, stride=2)]
        return nn.Sequential(*embedding)
        
    def forward(self, image):
        embed = self.patch_embedding(image)
        _, _, w, h = embed.shape
        return einops.rearrange(embed, "b c w h -> b (w h) c"), (w, h)

ConvPatchEmbedding()(torch.rand(2, 3, 128, 128))[0].shape

ConvPatchEmbedding(
  (patch_embedding): Sequential(
    (0): Conv3x3(
      (0): Conv2d(3, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): GELU()
    (2): Conv3x3(
      (0): Conv2d(96, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (3): GELU()
    (4): Conv3x3(
      (0): Conv2d(192, 384, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (5): GELU()
    (6): Conv3x3(
      (0): Conv2d(384, 768, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)


torch.Size([2, 64, 768])

In [12]:
class LPI(nn.Module):
    """
    Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows
    to augment the implicit communcation performed by the block diagonal scatter attention.
    Implemented using 2 layers of separable 3x3 convolutions with GeLU and BatchNorm2d
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 drop=0., kernel_size=3):
        super().__init__()
        out_features = out_features or in_features

        padding = kernel_size // 2

        self.conv1 = torch.nn.Conv2d(in_features, out_features, kernel_size=kernel_size,
                                     padding=padding, groups=out_features)
        self.act = act_layer()
        self.bn = nn.BatchNorm2d(in_features)
        self.conv2 = torch.nn.Conv2d(in_features, out_features, kernel_size=kernel_size,
                                     padding=padding, groups=out_features)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        x = self.conv1(x)
        x = self.act(x)
        x = self.bn(x)
        x = self.conv2(x)
        x = x.reshape(B, C, N).permute(0, 2, 1)

        return x
    

LPI(768).cuda()(torch.rand(2, 196, 768).cuda(), 14, 14)

tensor([[[-3.9201e-02,  7.3577e-02, -3.3857e-01,  ...,  7.1404e-02,
          -5.9091e-01, -3.7378e-01],
         [-5.1637e-02, -1.1477e-01,  1.2482e-01,  ..., -7.0015e-01,
          -6.8895e-01,  2.8070e-01],
         [-8.1479e-02, -3.8260e-02,  8.0394e-01,  ..., -6.1427e-01,
           2.7029e-01, -3.6985e-01],
         ...,
         [-4.1867e-01,  1.3993e+00,  6.3790e-02,  ..., -6.3944e-02,
          -2.9100e-01, -1.2993e-01],
         [-4.8697e-01, -9.8769e-01, -2.9202e-01,  ...,  3.1321e-01,
          -5.1861e-01,  6.3304e-01],
         [-2.0910e-01,  3.9316e-01,  2.8611e-01,  ..., -2.3512e-01,
          -3.9437e-02,  4.2751e-01]],

        [[-6.8820e-01,  8.9558e-01,  4.4689e-01,  ..., -4.7872e-01,
          -3.6535e-01, -1.9128e-01],
         [-2.6235e-01, -5.6604e-01,  9.8445e-02,  ..., -1.9249e-01,
          -7.7222e-01,  3.0269e-01],
         [ 6.8488e-01,  4.2476e-01,  4.2119e-01,  ..., -5.8732e-01,
          -1.3842e-01,  5.9836e-01],
         ...,
         [-9.3575e-02,  5

In [14]:
class MLP(torch.nn.Sequential):
    def __init__(self, embed_size=768, expansion=4):
        super().__init__(*[
            nn.Linear(embed_size, embed_size * expansion),
            nn.GELU(),
            nn.Linear(embed_size * expansion, embed_size)
        ])
        

MLP(768)(torch.rand(2, 197, 768))

tensor([[[-0.0119,  0.0649,  0.0795,  ...,  0.0373, -0.0676,  0.0408],
         [ 0.0905, -0.0137,  0.1502,  ..., -0.0574, -0.0048, -0.0458],
         [ 0.1419, -0.0640,  0.1212,  ..., -0.0503,  0.0452,  0.0219],
         ...,
         [-0.0061,  0.0251,  0.0915,  ..., -0.0029, -0.0027, -0.0728],
         [ 0.0985,  0.0010,  0.0708,  ...,  0.0574, -0.0171, -0.1115],
         [-0.0151, -0.0368,  0.1491,  ..., -0.0172, -0.0019, -0.0692]],

        [[ 0.0244,  0.0497,  0.1023,  ..., -0.0365, -0.0204, -0.1134],
         [-0.0109,  0.0144,  0.1729,  ...,  0.0245, -0.0728, -0.0241],
         [-0.0014,  0.0451,  0.0597,  ..., -0.0188, -0.0286, -0.0613],
         ...,
         [ 0.0993, -0.0834,  0.1267,  ...,  0.1155, -0.0147, -0.0741],
         [-0.0208, -0.0269,  0.0578,  ...,  0.0185,  0.0106,  0.0194],
         [ 0.0376,  0.1106,  0.0475,  ...,  0.0610, -0.0200, -0.0344]]],
       grad_fn=<AddBackward0>)

In [13]:
class ClassAttention(nn.Module):

    def __init__(self, embed_size, num_heads):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.divider = sqrt(self.embed_size // self.num_heads)
        self.projection = nn.Linear(embed_size, embed_size * 3)
        self.output_projection = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        qkv = einops.rearrange(self.projection(x), "b n (a c) -> a b n c", a=3)
        q, k, v = qkv[0, ...], qkv[1, ...], qkv[2, ...]
        keys = einops.rearrange(k, "b n (h e) -> b n h e", h=self.num_heads)
        queries = einops.rearrange(q, "b n (h e) -> b n h e", h=self.num_heads)
        values = einops.rearrange(v, "b n (h e) -> b h n e", h=self.num_heads)
        queries = queries[:, 0:1, :, :]
        attention = (queries * keys).sum(-1) / self.divider
        attention = einops.rearrange(attention.softmax(1), "b h n -> b n h")
        attention = einops.rearrange(attention.unsqueeze(2) @ values, "b h t e -> b t (h e)")
        token = self.output_projection(attention)
        return torch.cat([token, x[:, 1:, :]], dim=1)
    

ClassAttention(768, 8)(torch.rand(2, 197, 768))

tensor([[[ 0.2034,  0.0966, -0.2158,  ...,  0.2347,  0.0722, -0.0249],
         [ 0.7054,  0.0563,  0.4442,  ...,  0.3818,  0.0695,  0.5631],
         [ 0.8876,  0.6540,  0.5337,  ...,  0.1968,  0.0086,  0.4894],
         ...,
         [ 0.9510,  0.3464,  0.7857,  ...,  0.2829,  0.1989,  0.7497],
         [ 0.7678,  0.3197,  0.3608,  ...,  0.0874,  0.1468,  0.6992],
         [ 0.3224,  0.8372,  0.1662,  ...,  0.1762,  0.3569,  0.1786]],

        [[ 0.2092,  0.0929, -0.2311,  ...,  0.2383,  0.0860, -0.0233],
         [ 0.3607,  0.3273,  0.5632,  ...,  0.3444,  0.9594,  0.7298],
         [ 0.0330,  0.7662,  0.6852,  ...,  0.9630,  0.2035,  0.9838],
         ...,
         [ 0.7344,  0.7691,  0.9999,  ...,  0.6485,  0.0166,  0.0254],
         [ 0.0321,  0.9927,  0.1681,  ...,  0.4650,  0.9918,  0.3013],
         [ 0.1197,  0.0235,  0.6346,  ...,  0.0768,  0.5337,  0.8871]]],
       grad_fn=<CatBackward>)

In [15]:
class ClassAttentionLayer(nn.Module):

    def __init__(self, embed_size, num_heads, use_token_norm):
        super(ClassAttentionLayer, self).__init__()
        self.attention = ClassAttention(embed_size=embed_size, num_heads=num_heads)
        self.use_token_norm = use_token_norm
        self.mlp = MLP(embed_size=embed_size)
        self.norm_attention = nn.LayerNorm(embed_size)
        self.norm_mlp = nn.LayerNorm(embed_size)

    def forward(self, x):
        x = x + self.attention(self.norm_attention(x))

        if self.use_token_norm:
            x = self.norm_mlp(x)

        else:
            x[:, 0:1, :] = self.norm_mlp(x[:, 0:1, :])

        cls_token = x[:, 0:1, :]
        cls_token = cls_token + self.mlp(cls_token)
        out_x = torch.cat([cls_token, x[:, 1:, :]], dim=1)
        return x + out_x
    

ClassAttentionLayer(768, 8, True)(torch.rand(2, 197, 768))

tensor([[[ 3.7040,  2.9413, -3.4869,  ...,  2.8102, -2.9792, -2.0274],
         [ 1.9408, -2.4944, -3.2384,  ...,  2.7298, -1.0727, -3.4231],
         [ 0.8058, -2.1132,  3.2575,  ...,  2.2010, -0.9828,  2.5714],
         ...,
         [-0.8156, -0.9888, -1.0129,  ...,  1.6347,  2.2150,  3.0470],
         [ 0.8213, -0.8219, -1.4050,  ..., -2.3250, -1.0813, -0.6954],
         [-2.2497,  2.6547, -2.1102,  ...,  2.0017,  3.2088, -1.3781]],

        [[-2.9991,  1.3460, -2.2611,  ...,  0.2030, -0.2257,  1.5942],
         [-2.1537, -0.5351,  0.5013,  ..., -1.5655, -0.6691,  2.5270],
         [ 0.2471, -3.4727, -1.9128,  ...,  2.9075, -1.7280,  2.2539],
         ...,
         [-2.7479,  2.4425, -0.2620,  ...,  2.4858,  0.2755, -1.7910],
         [-1.9909, -0.0783,  2.4433,  ...,  3.3677, -1.7459,  2.4705],
         [ 2.1301,  2.6681,  3.4611,  ..., -1.9404, -3.2601, -0.8471]]],
       grad_fn=<AddBackward0>)

In [4]:
x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
# x = torch.nn.functional.pad(x, (0, 1, 0, 1))
torch.roll(x, dims=(1, 0), shifts=(-2, -2))


tensor([[11, 12,  9, 10],
        [15, 16, 13, 14],
        [ 3,  4,  1,  2],
        [ 7,  8,  5,  6]])

In [26]:
class MLP(torch.nn.Sequential):
    def __init__(self, embed_size=768, expansion=4):
        super().__init__(*[
            nn.Linear(embed_size, embed_size * expansion),
            nn.GELU(),
            nn.Linear(embed_size * expansion, embed_size)
        ])


class ResidualAdd(torch.nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        return x + self.block(x)
    

class SwinMSA(nn.Module):
    
    def __init__(self, embed_dim, num_heads, attention_mask=None,
                 attention_dropout=0.0, projection_dropout=0.0):
        super(SwinMSA, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.projection = nn.Linear(embed_dim, embed_dim)
        self.scale = (self.embed_dim // self.num_heads) ** (-0.5)

        self.attention_dropout = nn.Dropout(attention_dropout)
        self.projection_dropout = nn.Dropout(projection_dropout)

        self.register_buffer("attention_mask", attention_mask)

    def forward(self, x):
        """
        Multihead self-attention for swin module.
        Args:
            x: tensor of shape: b n c

        Returns:

        """
        print(x.shape)
        qkv = self.qkv(x)
        qkv = einops.rearrange(qkv, "b n (h d c) -> d b n h c", h=self.num_heads, d=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        energy_term = torch.einsum("bqhe, bkhe -> bqhk", queries, keys)
        energy_term *= self.scale

        if self.attention_mask is not None:
            energy_term = energy_term + self.attention_mask

        energy_term = energy_term.softmax(dim=-1)
        out = torch.einsum('bihv, bvhd -> bihd ', energy_term, values)
        print(out.shape)
        out = einops.rearrange(out, "b n h e -> b n (h e)")
        out = self.projection(out)
        return self.projection_dropout(out)


class SwinBlock(nn.Module):
    
    def __init__(self, shift_size, embed_dim, num_heads, image_resolution,
                 shifted_block=False, attention_dropout=0.0, projection_dropout=0.0):
        super(SwinBlock, self).__init__()
        self.resolution = image_resolution
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.shifted = shifted_block
        self.shift_size = shift_size
        self.attention_norm = nn.LayerNorm(embed_dim)
        attention_mask = self._get_attention_mask()

        self.attention = SwinMSA(
            embed_dim,
            num_heads,
            attention_mask=attention_mask,
            attention_dropout=attention_dropout,
            projection_dropout=projection_dropout
        )
        self.mlp = ResidualAdd(nn.Sequential(*[nn.LayerNorm(embed_dim), MLP(embed_size=embed_dim)]))
        
    def _get_attention_mask(self):
        if self.shifted:
            # TODO: fill
            return torch.ones(1)
        else:
            return None

    def cyclic_shift(self, tensor):
        return torch.roll(tensor, (-self.shift_size, -self.shift_size), dims=(1, 2))

    def reverse_cyclic_shift(self, tensor):
        return torch.roll(tensor, (self.shift_size, self.shift_size), dims=(1, 2))

    def window_reverse(self, tensor, h , w):
        image = einops.rearrange(tensor, "(b h w) (shifty shiftx) c -> b shifty h shiftx w c",
                                 shiftx=self.shift_size, shifty=self.shift_size, h=h//self.shift_size, w=w//self.shift_size)
        return einops.rearrange(image, "b sy h sx w c -> b (sy h) (sx w) c")

    def window_partition(self, tensor):
        windows = einops.rearrange(tensor, "b (sy h) (sx w) c -> b sy h sx w c", sx=self.shift_size, sy=self.shift_size)
        return einops.rearrange(windows, "b sy h sx w c -> (b h w) (sy sx) c")

    def forward(self, x):
        """
        Single swin block execution
        Args:
            img: (B (H W) C) tensor.

        Returns:

        """
        b, n, c = x.shape
        assert n == self.resolution[0] * self.resolution[1]
        img = einops.rearrange(x, "b (h w) c ->  b h w c", h=self.resolution[0], w=self.resolution[1])
        if self.shifted:
            img = self.cyclic_shift(img)
        print(img.shape)
        norm1 = self.attention_norm(img)
        print(norm1.shape)
        partitions = self.window_partition(norm1)
        print(partitions.shape)
        attention = self.attention(partitions)
        reverse_shifted = self.window_reverse(attention, *self.resolution)
        if self.shifted:
            reverse_shifted = self.reverse_cyclic_shift(reverse_shifted)
        msa_out = einops.rearrange(reverse_shifted, "b h w c -> b (h w) c")
        msa_out = x + msa_out
        print(msa_out.shape)
        return msa_out + self.mlp(msa_out)
    
block = SwinBlock(7, 256, 8, (224, 224), False)(torch.rand(2, 224 * 224, 256))

torch.Size([2, 224, 224, 256])
torch.Size([2, 224, 224, 256])
torch.Size([2048, 49, 256])
torch.Size([2048, 49, 256])


In [7]:
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows



H, W = 15, 15
window_size = 3
shift_size = 1
img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        print(h)
        print(w)
        img_mask[:, h, w, :] = cnt
        cnt += 1

torch.set_printoptions(threshold=10_000)

mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask2 = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
x = mask_windows[24, :]
print(img_mask.squeeze())
print(mask_windows.squeeze())
print(attn_mask2.shape)
attn_mask = attn_mask2.masked_fill(attn_mask2 != 0, float(-100.0)).masked_fill(attn_mask2 == 0, float(0.0))

slice(0, -3, None)
slice(0, -3, None)
slice(0, -3, None)
slice(-3, -1, None)
slice(0, -3, None)
slice(-1, None, None)
slice(-3, -1, None)
slice(0, -3, None)
slice(-3, -1, None)
slice(-3, -1, None)
slice(-3, -1, None)
slice(-1, None, None)
slice(-1, None, None)
slice(0, -3, None)
slice(-1, None, None)
slice(-3, -1, None)
slice(-1, None, None)
slice(-1, None, None)
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 2.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 2.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 2.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 2.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 2.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 2.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 2.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 2.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 2.],
    

In [2]:
num_heads = 8
window_size = (7, 7)
relative_position_bias_table = torch.rand(
            (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)  # 2*Wh-1 * 2*Ww-1, nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords_new = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords_new.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)].view(
            window_size[0] * window_size[1], window_size[0] * window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

In [3]:
relative_position_bias_table[x, :].shape

NameError: name 'x' is not defined