In [10]:
import torch.nn as nn
import torch
import einops
from math import sqrt


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, embed_size, num_heads, attention_store=None):
        super().__init__()
        self.queries_projection = nn.Linear(embed_size, embed_size)
        self.values_projection = nn.Linear(embed_size, embed_size)
        self.keys_projection = nn.Linear(embed_size, embed_size)
        self.final_projection = nn.Linear(embed_size, embed_size)
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.attention_store = attention_store

    def forward(self, x):
        assert len(x.shape) == 3
        keys = self.keys_projection(x)
        values = self.values_projection(x)
        queries = self.queries_projection(x)
        keys = einops.rearrange(keys, "b n (h e) -> b n h e", h=self.num_heads)
        queries = einops.rearrange(queries, "b n (h e) -> b n h e", h=self.num_heads)
        values = einops.rearrange(values, "b n (h e) -> b n h e", h=self.num_heads)
        energy_term = torch.einsum("bqhe, bkhe -> bqhk", queries, keys)
        print(energy_term.shape)
        divider = sqrt(self.embed_size)
        mh_out = torch.softmax(energy_term, -1)
        if self.attention_store is not None:
            self.attention_store.append(mh_out.detach().cpu())
        out = torch.einsum('bihv, bvhd -> bihd ', mh_out / divider, values)
        out = einops.rearrange(out, "b n h e -> b n (h e)")
        return self.final_projection(out)
    
x = torch.rand(4, 197, 768)
MultiHeadAttention(768, 8)(x).shape

torch.Size([4, 197, 8, 197])


torch.Size([4, 197, 768])

In [16]:
import torch.nn.functional as F

class MultiHeadXCITAttention(torch.nn.Module):
    def __init__(self, embed_size, num_heads, attention_store=None):
        super().__init__()
        self.queries_projection = nn.Linear(embed_size, embed_size)
        self.values_projection = nn.Linear(embed_size, embed_size)
        self.keys_projection = nn.Linear(embed_size, embed_size)
        self.final_projection = nn.Linear(embed_size, embed_size)
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.attention_store = attention_store
        self.tau = torch.nn.Parameter(torch.ones(1))

    def forward(self, x):
        assert len(x.shape) == 3
        keys = self.keys_projection(x)
        values = self.values_projection(x)
        queries = self.queries_projection(x)
        keys = einops.rearrange(keys, "b n (h e) -> b n h e", h=self.num_heads)
        queries = einops.rearrange(queries, "b n (h e) -> b n h e", h=self.num_heads)
        values = einops.rearrange(values, "b n (h e) -> b n h e", h=self.num_heads)
        keys = F.normalize(keys, p=2, dim=1)
        queries = F.normalize(queries, p=2, dim=1)
        energy_term = torch.einsum("bnhe, bnhq -> behq", queries, keys)
        print(energy_term.shape)
        divider = sqrt(self.embed_size)
        mh_out = torch.softmax(energy_term, -1)
        if self.attention_store is not None:
            self.attention_store.append(mh_out.detach().cpu())
        out = torch.einsum('behq, bnhe -> bnhq ', mh_out / divider, values)
        print(out.shape)
        out = einops.rearrange(out, "b n h e -> b n (h e)")
        return self.final_projection(out)


MultiHeadXCITAttention(768, 8)(x).shape

torch.Size([4, 96, 8, 96])
torch.Size([4, 197, 8, 96])


torch.Size([4, 197, 768])

In [5]:
from einops.layers.torch import Reduce, Rearrange

class PatchEmbeddingPixelwise(torch.nn.Sequential):
    
    def __init__(self, stride, embedding_size, channels=3) -> None:
        reduce = Reduce("b c (w i) (h k) -> b (c i k) w h", "mean", i=stride, k=stride)
        rearange = Rearrange("b e h w -> b (h w) e")
        linear = torch.nn.Linear(stride * stride * channels, embedding_size)
        super().__init__(*[
            reduce,
            rearange,
            linear
        ])
        
imgs = torch.rand(4, 3, 512, 512)

PatchEmbeddingPixelwise(4, 128)(imgs).shape

torch.Size([4, 16384, 128])

In [6]:
x = einops.reduce(imgs, "b c (w i) (h k) -> b (c i k) w h", "mean", i=4, k=4)

print(x[0, :, 0, 0])
imgs[0, :, :4, :4]

tensor([0.5891, 0.4272, 0.8274, 0.4417, 0.0085, 0.5833, 0.0616, 0.0424, 0.4950,
        0.2065, 0.5875, 0.8506, 0.8575, 0.5328, 0.0374, 0.1130, 0.5295, 0.4659,
        0.5719, 0.1116, 0.0251, 0.0025, 0.9558, 0.0884, 0.9224, 0.9056, 0.3549,
        0.4874, 0.3215, 0.8398, 0.1643, 0.6919, 0.9742, 0.6123, 0.4258, 0.9993,
        0.1846, 0.2410, 0.5798, 0.6895, 0.7292, 0.6438, 0.4387, 0.5016, 0.5945,
        0.5443, 0.7958, 0.2468])


tensor([[[0.5891, 0.4272, 0.8274, 0.4417],
         [0.0085, 0.5833, 0.0616, 0.0424],
         [0.4950, 0.2065, 0.5875, 0.8506],
         [0.8575, 0.5328, 0.0374, 0.1130]],

        [[0.5295, 0.4659, 0.5719, 0.1116],
         [0.0251, 0.0025, 0.9558, 0.0884],
         [0.9224, 0.9056, 0.3549, 0.4874],
         [0.3215, 0.8398, 0.1643, 0.6919]],

        [[0.9742, 0.6123, 0.4258, 0.9993],
         [0.1846, 0.2410, 0.5798, 0.6895],
         [0.7292, 0.6438, 0.4387, 0.5016],
         [0.5945, 0.5443, 0.7958, 0.2468]]])

In [18]:
x = torch.rand(4, 8, 197, 96)
x.transpose(-2, -1).shape

torch.Size([4, 8, 96, 197])