In [2]:
import torch.nn as nn
import torch
import einops
from math import sqrt

In [None]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, embed_size, num_heads, attention_store=None):
        super().__init__()
        self.queries_projection = nn.Linear(embed_size, embed_size)
        self.values_projection = nn.Linear(embed_size, embed_size)
        self.keys_projection = nn.Linear(embed_size, embed_size)
        self.final_projection = nn.Linear(embed_size, embed_size)
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.attention_store = attention_store

    def forward(self, x):
        assert len(x.shape) == 3
        keys = self.keys_projection(x)
        values = self.values_projection(x)
        queries = self.queries_projection(x)
        keys = einops.rearrange(keys, "b n (h e) -> b n h e", h=self.num_heads)
        queries = einops.rearrange(queries, "b n (h e) -> b n h e", h=self.num_heads)
        values = einops.rearrange(values, "b n (h e) -> b n h e", h=self.num_heads)
        energy_term = torch.einsum("bqhe, bkhe -> bqhk", queries, keys)
        print(energy_term.shape)
        divider = sqrt(self.embed_size)
        mh_out = torch.softmax(energy_term, -1)
        if self.attention_store is not None:
            self.attention_store.append(mh_out.detach().cpu())
        out = torch.einsum('bihv, bvhd -> bihd ', mh_out / divider, values)
        out = einops.rearrange(out, "b n h e -> b n (h e)")
        return self.final_projection(out)
    
x = torch.rand(4, 197, 768)
MultiHeadAttention(768, 8)(x).shape

In [None]:
import torch.nn.functional as F

class MultiHeadXCITAttention(torch.nn.Module):
    def __init__(self, embed_size, num_heads, attention_store=None):
        super().__init__()
        self.queries_projection = nn.Linear(embed_size, embed_size)
        self.values_projection = nn.Linear(embed_size, embed_size)
        self.keys_projection = nn.Linear(embed_size, embed_size)
        self.final_projection = nn.Linear(embed_size, embed_size)
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.attention_store = attention_store
        self.tau = torch.nn.Parameter(torch.ones(1))

    def forward(self, x):
        assert len(x.shape) == 3
        keys = self.keys_projection(x)
        values = self.values_projection(x)
        queries = self.queries_projection(x)
        keys = einops.rearrange(keys, "b n (h e) -> b n h e", h=self.num_heads)
        queries = einops.rearrange(queries, "b n (h e) -> b n h e", h=self.num_heads)
        values = einops.rearrange(values, "b n (h e) -> b n h e", h=self.num_heads)
        keys = F.normalize(keys, p=2, dim=1)
        queries = F.normalize(queries, p=2, dim=1)
        energy_term = torch.einsum("bnhe, bnhq -> behq", queries, keys)
        print(energy_term.shape)
        divider = sqrt(self.embed_size)
        mh_out = torch.softmax(energy_term, -1)
        if self.attention_store is not None:
            self.attention_store.append(mh_out.detach().cpu())
        out = torch.einsum('behq, bnhe -> bnhq ', mh_out / divider, values)
        print(out.shape)
        out = einops.rearrange(out, "b n h e -> b n (h e)")
        return self.final_projection(out)


MultiHeadXCITAttention(768, 8)(x).shape

In [3]:
from einops.layers.torch import Reduce, Rearrange

class PatchEmbeddingPixelwise(torch.nn.Sequential):
    
    def __init__(self, stride, embedding_size, channels=3) -> None:
        reduce = Reduce("b c (w i) (h k) -> b (c i k) w h", "mean", i=stride, k=stride)
        rearange = Rearrange("b e h w -> b (h w) e")
        linear = torch.nn.Linear(stride * stride * channels, embedding_size)
        super().__init__(*[
            reduce,
            rearange,
            linear
        ])
        
imgs = torch.rand(4, 3, 512, 512)

PatchEmbeddingPixelwise(4, 128)(imgs).shape

torch.Size([4, 16384, 128])

In [None]:
x = einops.reduce(imgs, "b c (w i) (h k) -> b (c i k) w h", "mean", i=4, k=4)

print(x[0, :, 0, 0])
imgs[0, :, :4, :4]

In [None]:
x = torch.rand(4, 8, 197, 96)
x.transpose(-2, -1).shape

In [None]:
class Conv3x3(nn.Sequential):
    
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__(*[
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels)
        ])


class ConvPatchEmbedding(nn.Module):
    
    def __init__(self, stride=16, embed_dim=768):
        super().__init__()
        num_conv_layers = int(torch.log2(torch.tensor(stride)))
        self.patch_embedding = self._get_patch_embedding(num_conv_layers, stride, embed_dim=embed_dim)
        print(self)
        
    def _get_patch_embedding(self, num_conv_layers, stride, embed_dim):
        embedding = [Conv3x3(in_channels=3, out_channels=embed_dim // (stride // 2), stride=2), nn.GELU()]
        for idx in range(num_conv_layers - 1, 1, -1):
            embedding += [Conv3x3(in_channels=embed_dim // (2 ** idx), out_channels=embed_dim // (2 ** (idx - 1)), stride=2),  nn.GELU()]
        embedding += [Conv3x3(in_channels=embed_dim // 2, out_channels=embed_dim, stride=2)]
        return nn.Sequential(*embedding)
        
    def forward(self, image):
        embed = self.patch_embedding(image)
        _, _, w, h = embed.shape
        return einops.rearrange(embed, "b c w h -> b (w h) c"), (w, h)

ConvPatchEmbedding()(torch.rand(2, 3, 128, 128))[0].shape

In [None]:
class LPI(nn.Module):
    """
    Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows
    to augment the implicit communcation performed by the block diagonal scatter attention.
    Implemented using 2 layers of separable 3x3 convolutions with GeLU and BatchNorm2d
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 drop=0., kernel_size=3):
        super().__init__()
        out_features = out_features or in_features

        padding = kernel_size // 2

        self.conv1 = torch.nn.Conv2d(in_features, out_features, kernel_size=kernel_size,
                                     padding=padding, groups=out_features)
        self.act = act_layer()
        self.bn = nn.BatchNorm2d(in_features)
        self.conv2 = torch.nn.Conv2d(in_features, out_features, kernel_size=kernel_size,
                                     padding=padding, groups=out_features)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        x = self.conv1(x)
        x = self.act(x)
        x = self.bn(x)
        x = self.conv2(x)
        x = x.reshape(B, C, N).permute(0, 2, 1)

        return x
    

LPI(768).cuda()(torch.rand(2, 196, 768).cuda(), 14, 14)

In [None]:
class MLP(torch.nn.Sequential):
    def __init__(self, embed_size=768, expansion=4):
        super().__init__(*[
            nn.Linear(embed_size, embed_size * expansion),
            nn.GELU(),
            nn.Linear(embed_size * expansion, embed_size)
        ])
        

MLP(768)(torch.rand(2, 197, 768))

In [None]:
class ClassAttention(nn.Module):

    def __init__(self, embed_size, num_heads):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.divider = sqrt(self.embed_size // self.num_heads)
        self.projection = nn.Linear(embed_size, embed_size * 3)
        self.output_projection = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        qkv = einops.rearrange(self.projection(x), "b n (a c) -> a b n c", a=3)
        q, k, v = qkv[0, ...], qkv[1, ...], qkv[2, ...]
        keys = einops.rearrange(k, "b n (h e) -> b n h e", h=self.num_heads)
        queries = einops.rearrange(q, "b n (h e) -> b n h e", h=self.num_heads)
        values = einops.rearrange(v, "b n (h e) -> b h n e", h=self.num_heads)
        queries = queries[:, 0:1, :, :]
        attention = (queries * keys).sum(-1) / self.divider
        attention = einops.rearrange(attention.softmax(1), "b h n -> b n h")
        attention = einops.rearrange(attention.unsqueeze(2) @ values, "b h t e -> b t (h e)")
        token = self.output_projection(attention)
        return torch.cat([token, x[:, 1:, :]], dim=1)
    

ClassAttention(768, 8)(torch.rand(2, 197, 768))

In [None]:
class ClassAttentionLayer(nn.Module):

    def __init__(self, embed_size, num_heads, use_token_norm):
        super(ClassAttentionLayer, self).__init__()
        self.attention = ClassAttention(embed_size=embed_size, num_heads=num_heads)
        self.use_token_norm = use_token_norm
        self.mlp = MLP(embed_size=embed_size)
        self.norm_attention = nn.LayerNorm(embed_size)
        self.norm_mlp = nn.LayerNorm(embed_size)

    def forward(self, x):
        x = x + self.attention(self.norm_attention(x))

        if self.use_token_norm:
            x = self.norm_mlp(x)

        else:
            x[:, 0:1, :] = self.norm_mlp(x[:, 0:1, :])

        cls_token = x[:, 0:1, :]
        cls_token = cls_token + self.mlp(cls_token)
        out_x = torch.cat([cls_token, x[:, 1:, :]], dim=1)
        return x + out_x
    

ClassAttentionLayer(768, 8, True)(torch.rand(2, 197, 768))

In [None]:
x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
# x = torch.nn.functional.pad(x, (0, 1, 0, 1))
torch.roll(x, dims=(1, 0), shifts=(-2, -2))


In [18]:
class MLP(torch.nn.Sequential):
    def __init__(self, embed_size=768, expansion=4):
        super().__init__(*[
            nn.Linear(embed_size, embed_size * expansion),
            nn.GELU(),
            nn.Linear(embed_size * expansion, embed_size)
        ])


class ResidualAdd(torch.nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        return x + self.block(x)
    

class SwinMSA(nn.Module):

    def __init__(self, embed_dim, num_heads, window_size, attention_mask=None,
                 attention_dropout=0.0, projection_dropout=0.0):
        super(SwinMSA, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.projection = nn.Linear(embed_dim, embed_dim)
        self.scale = (self.embed_dim // self.num_heads) ** (-0.5)

        self.attention_dropout = nn.Dropout(attention_dropout)
        self.projection_dropout = nn.Dropout(projection_dropout)
        self.relative_position_bias_table = torch.nn.Parameter(torch.zeros(
            (2 * window_size - 1) * (2 * window_size - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        self.register_buffer("attention_mask", attention_mask)
        self.register_buffer("relative_bias_index", self._get_relative_bias_index(window_size=window_size))

    @staticmethod
    def _get_relative_bias_index(window_size):
        coords_h = torch.arange(window_size)
        coords_w = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords_new = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords_new.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_bias_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        return relative_bias_index

    def forward(self, x):
        """
        Multihead self-attention for swin module.
        Args:
            x: tensor of shape: b n c

        Returns:

        """
        print(x.shape)
        qkv = self.qkv(x)
        qkv = einops.rearrange(qkv, "b n (h d c) -> d b n h c", h=self.num_heads, d=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        energy_term = torch.einsum("bqhe, bkhe -> bqhk", queries, keys)
        energy_term *= self.scale
        relative_position_bias = self.relative_position_bias_table[self.relative_bias_index.view(-1)].view(
            self.window_size * self.window_size, self.window_size * self.window_size, -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(0, 2, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        print(relative_position_bias.shape)
        print(energy_term.shape)
        energy_term = energy_term + relative_position_bias
        if self.attention_mask is not None:
            number_of_windows = self.attention_mask.shape[0]
            energy_term = einops.rearrange(energy_term, "(b nw) h width height -> b nw h width height", nw=number_of_windows)
            print(self.attention_mask.shape)
            energy_term = energy_term + self.attention_mask.unsqueeze(2).unsqueeze(0)
            energy_term = einops.rearrange(energy_term, "b nw h width height -> (b nw) h width height")
        energy_term = energy_term.softmax(dim=-1)
        out = torch.einsum('bihv, bvhd -> bihd ', energy_term, values)
        print(out.shape)
        out = einops.rearrange(out, "b n h e -> b n (h e)")
        out = self.projection(out)
        return self.projection_dropout(out)


class SwinBlock(nn.Module):

    def __init__(self, window_size, embed_dim, num_heads, image_resolution, shift_size=0,
                 attention_dropout=0.0, projection_dropout=0.0):
        super(SwinBlock, self).__init__()
        self.window_size = window_size
        self.resolution = image_resolution
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.shift_size = shift_size
        self.attention_norm = nn.LayerNorm(embed_dim)
        attention_mask = self._get_attention_mask(shift_size, window_size, image_resolution[1], image_resolution[0])
        self.attention = SwinMSA(
            embed_dim,
            num_heads,
            window_size=window_size,
            attention_mask=attention_mask,
            attention_dropout=attention_dropout,
            projection_dropout=projection_dropout
        )
        self.mlp = ResidualAdd(nn.Sequential(*[nn.LayerNorm(embed_dim), MLP(embed_size=embed_dim)]))

    def _get_attention_mask(self, shift, window_size, h, w):
        if self.shift_size > 0:
            image_mask = torch.zeros((1, h, w, 1))
            h_slices = (slice(0, -window_size),
                        slice(-window_size, -shift),
                        slice(-shift, None))
            w_slices = (slice(0, -window_size),
                        slice(-window_size, -shift),
                        slice(-shift, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    image_mask[:, h, w, :] = cnt
                    cnt += 1
            attention_mask = self.window_partition(image_mask)
            attention_mask = einops.rearrange(attention_mask, "num_windows (window_size_y window_size_x) channels -> "
                                                              "(num_windows channels) (window_size_y window_size_x)", 
                                                              window_size_y=window_size, 
                                                              window_size_x=window_size)
            attention_mask = attention_mask.unsqueeze(1) - attention_mask.unsqueeze(2)
            attention_mask = attention_mask.masked_fill(attention_mask != 0, -100.).masked_fill(attention_mask == 0, 0.)
            return attention_mask
        else:
            return None

    def cyclic_shift(self, tensor):
        return torch.roll(tensor, (-self.shift_size, -self.shift_size), dims=(1, 2))

    def reverse_cyclic_shift(self, tensor):
        return torch.roll(tensor, (self.shift_size, self.shift_size), dims=(1, 2))

    def window_reverse(self, tensor, h, w):
        image = einops.rearrange(tensor, "(b h w) (shifty shiftx) c -> b shifty h shiftx w c",
                                 shiftx=self.window_size, shifty=self.window_size,
                                 h=h // self.window_size, w=w // self.window_size)
        return einops.rearrange(image, "b sy h sx w c -> b (sy h) (sx w) c")

    def window_partition(self, tensor):
        windows = einops.rearrange(tensor, "b (wy h) (wx w) c -> b wy h wx w c",
                                   wx=self.window_size, wy=self.window_size)
        return einops.rearrange(windows, "b wy h wx w c -> (b h w) (wy wx) c")

    def forward(self, x):
        """
        Single swin block execution
        Args:
            x: (B (H W) C) tensor.

        Returns:

        """
        b, n, c = x.shape
        assert n == self.resolution[0] * self.resolution[1]
        img = einops.rearrange(x, "b (h w) c ->  b h w c", h=self.resolution[0], w=self.resolution[1])
        if self.shift_size:
            img = self.cyclic_shift(img)
        print(img.shape)
        norm1 = self.attention_norm(img)
        print(norm1.shape)
        partitions = self.window_partition(norm1)
        print(partitions.shape)
        attention = self.attention(partitions)
        reverse_shifted = self.window_reverse(attention, *self.resolution)
        if self.shift_size:
            reverse_shifted = self.reverse_cyclic_shift(reverse_shifted)
        msa_out = einops.rearrange(reverse_shifted, "b h w c -> b (h w) c")
        msa_out = x + msa_out
        print(msa_out.shape)
        return msa_out + self.mlp(msa_out)

block = SwinBlock(window_size=7, embed_dim=768, num_heads=8, image_resolution=(28, 28), shift_size=3)
out = block(torch.rand(1, (28 * 28), 768))
out.shape

torch.Size([1, 28, 28, 768])
torch.Size([1, 28, 28, 768])
torch.Size([16, 49, 768])
torch.Size([16, 49, 768])
torch.Size([49, 8, 49])
torch.Size([16, 49, 8, 49])
torch.Size([16, 49, 49])
torch.Size([16, 49, 8, 96])
torch.Size([1, 784, 768])


torch.Size([1, 784, 768])

In [None]:
num_heads = 8
window_size = (7, 7)
relative_position_bias_table = torch.rand(
            (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)  # 2*Wh-1 * 2*Ww-1, nH

# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords_new = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords_new.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
relative_position_bias = relative_position_bias_table[relative_position_index.view(-1)].view(
            window_size[0] * window_size[1], window_size[0] * window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

In [31]:

class PatchMergingLayer(nn.Module):

    def __init__(self, in_channels, out_channels, image_resolution, dropout_rate=0.1):
        super().__init__()
        self.merging_dropout = nn.Dropout(dropout_rate)
        self.linear = nn.Linear(in_channels * 4, out_channels)
        self.norm = nn.LayerNorm(out_channels)
        self.image_resolution = image_resolution


    def forward(self, x):
        h, w = self.image_resolution
        x = einops.rearrange(x, "b (h w) c -> b h w c", h=h, w=w)
        top_left_corner = x[:, 0::2, 0::2, :]
        top_right_corner = x[:, 1::2, 0::2, :]
        bottom_right_corner = x[:, 1::2, 1::2, :]
        bottom_left_corner = x[:, 0::2, 1::2, :]
        features = torch.cat([
            top_left_corner,
            top_right_corner,
            bottom_left_corner,
            bottom_right_corner
            ], dim=-1)
        return einops.rearrange(
            self.merging_dropout(self.norm(self.linear(features))),
            "b h w c -> b (h w) c"
            )
        


PatchMergingLayer(128, 256, image_resolution=(28, 28), dropout_rate=0.5)(torch.rand(2, (28 * 28), 128))

tensor([[[ 0.0000,  0.3176,  1.3113,  ..., -0.0000, -2.0796, -1.0306],
         [ 0.0000,  0.0000,  0.7587,  ...,  1.7577, -1.5120,  0.0000],
         [ 0.0000,  0.0000, -0.1454,  ...,  0.6666, -0.0000, -0.0896],
         ...,
         [ 0.0000,  0.0000,  0.5860,  ...,  1.2830, -0.0000, -0.4082],
         [ 0.0000,  0.0000,  1.6942,  ...,  0.0000, -1.6115, -0.0000],
         [ 0.0000,  0.1895, -0.4519,  ..., -0.7924, -0.0000, -0.0000]],

        [[-0.0000,  1.1288,  0.7217,  ...,  0.0000, -0.0000, -0.0000],
         [ 2.2393,  0.0000,  0.0000,  ...,  0.8778, -0.0000, -0.0000],
         [ 0.0000,  0.0000, -0.9392,  ..., -0.0000, -2.6729, -1.0166],
         ...,
         [ 0.0000,  2.1562,  0.0000,  ...,  0.9410, -0.0000, -1.7900],
         [ 0.0000,  0.5322, -0.0000,  ...,  0.0000, -0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.9249,  ...,  1.9310, -0.0000,  0.1573]]],
       grad_fn=<ViewBackward>)