In [2]:
import torch
from einops import rearrange
from torch import nn

In [3]:
class LayerNorm2d(nn.LayerNorm):
    """``nn.LayerNorm`` for tensors laid out as ``b c h w``.

    The input is rearranged to ``b h w c`` so that the parent class
    normalizes over the trailing dimension(s) given by its
    ``normalized_shape``, and the result is rearranged back to ``b c h w``.
    """

    def forward(self, x):
        # Move channels to the last axis, where nn.LayerNorm operates.
        channels_last = rearrange(x, "b c h w -> b h w c")
        normalized = super().forward(channels_last)
        # Restore the original channels-first layout.
        return rearrange(normalized, "b h w c -> b c h w")


class OverlapPatchMerging(nn.Sequential):
    """Bias-free strided convolution followed by ``LayerNorm2d``.

    Operates on an input image tensor laid out as (b, c, h, w).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution; also
            the size handed to ``LayerNorm2d``.
        patch_size: Convolution kernel size. The padding is
            ``patch_size // 2``.
        overlap_size: Passed to the convolution as its ``stride``.
    """

    def __init__(
        self, in_channels: int, out_channels: int, patch_size: int, overlap_size: int,
    ):
        # The conv carries no bias term; the layer norm that follows has its own.
        projection = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=patch_size,
            stride=overlap_size,
            padding=patch_size // 2,
            bias=False,
        )
        norm = LayerNorm2d(out_channels)
        super().__init__(projection, norm)

### Test: `OverlapPatchMerging` output shape

In [4]:
# Smoke test: push a random batch through OverlapPatchMerging and show the
# resulting shape.
N, C, H, W = 5, 3, 512, 512
testinput = torch.randn(N, C, H, W)
merge_encode_768 = OverlapPatchMerging(
    in_channels=C,
    out_channels=768,
    patch_size=7,
    overlap_size=4,
)
out_test = merge_encode_768(testinput)
out_test.shape

torch.Size([5, 768, 128, 128])

In [5]:
from torch import einsum
class EfficientSelfAttention(nn.Module):
    """Multi-head self-attention over a 2D feature map with reduced keys/values.

    Queries come from a 1x1 convolution at full spatial resolution. Keys and
    values come from a single convolution whose kernel size and stride both
    equal ``reduction_ratio``, so they cover fewer spatial positions than
    the queries.

    Args:
        dim: Number of input and output channels.
        heads: Number of attention heads. Must be positive and divide ``dim``.
        reduction_ratio: Kernel size and stride of the key/value convolution.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, *, dim, heads, reduction_ratio):
        # Fail fast: without this check a non-divisible ``dim`` would only
        # surface later, as a shape error when ``forward`` splits the heads.
        if heads <= 0 or dim % heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive number of heads ({heads})"
            )

        super().__init__()
        self.scale = (dim // heads) ** -0.5
        self.heads = heads

        self.to_q = nn.Conv2d(dim, dim, 1, bias=False)
        # One conv yields keys and values together (2 * dim channels).
        self.to_kv = nn.Conv2d(dim, dim * 2, reduction_ratio, stride=reduction_ratio, bias=False)
        self.to_out = nn.Conv2d(dim, dim, 1, bias=False)

    def forward(self, x):
        """Apply attention to ``x`` and return a tensor of the same shape.

        Args:
            x: Feature map laid out as (b, dim, height, width).

        Returns:
            Tensor laid out as (b, dim, height, width).
        """
        height, width = x.shape[-2:]
        heads = self.heads

        q = self.to_q(x)
        k, v = self.to_kv(x).chunk(2, dim=1)
        # Fold the heads into the batch axis and flatten the spatial axes.
        q, k, v = (
            rearrange(t, 'b (h c) x y -> (b h) (x y) c', h=heads) for t in (q, k, v)
        )

        # q @ k.transpose(-2, -1) * self.scale
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = sim.softmax(dim=-1)

        # attn @ v
        out = einsum('b i j, b j d -> b i d', attn, v)
        # Unfold the heads and restore the query's spatial grid.
        out = rearrange(out, '(b h) (x y) c -> b (h c) x y', h=heads, x=height, y=width)
        return self.to_out(out)



### Test: `EfficientSelfAttention` output shape

In [7]:

# Smoke test: run the patch-merging output through EfficientSelfAttention.
heads = 2  # number of attention heads
reduction_ratio = 4  # kernel size and stride of the key/value convolution

efficient_attn = EfficientSelfAttention(
    dim=768,
    heads=heads,
    reduction_ratio=reduction_ratio,
)
output = efficient_attn(out_test)
print(output.shape)  # expected to equal out_test.shape

torch.Size([5, 768, 128, 128])


In [8]:
class MixMLP(nn.Sequential):
    """1x1 conv -> grouped 3x3 conv (widening) -> GELU -> 1x1 conv.

    Args:
        channels: Number of input and output channels.
        expansion: Factor by which the grouped 3x3 convolution widens the
            channel dimension before the last 1x1 convolution projects it
            back to ``channels``.
    """

    def __init__(self, channels: int, expansion: int = 4):
        hidden_channels = channels * expansion

        # Dense (pointwise) layer.
        dense_in = nn.Conv2d(channels, channels, kernel_size=1)
        # Depth-wise conv: one group per input channel, each group emitting
        # ``expansion`` output channels.
        depthwise = nn.Conv2d(
            channels,
            hidden_channels,
            kernel_size=3,
            groups=channels,
            padding=1,
        )
        # Dense (pointwise) layer back to the input width.
        dense_out = nn.Conv2d(hidden_channels, channels, kernel_size=1)

        super().__init__(dense_in, depthwise, nn.GELU(), dense_out)
