In [35]:
import torch
from torch import einsum
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from vit_pytorch.max_vit import MaxViT, MBConv
from einops import rearrange

In [52]:
# pool = nn.MaxPool2d(kernel_size=2, stride=2)
pool = nn.AvgPool2d(kernel_size=2, stride=2)

In [53]:
x = torch.randn(1, 1, 24, 24)
pool(x).shape

torch.Size([1, 1, 12, 12])

In [51]:
x = torch.randn(2, 3)
torch.stack([x]).shape

torch.Size([1, 2, 3])

In [20]:
window_size = 7
pos = torch.arange(window_size)
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij')) # (2, 7, 7)
grid = rearrange(grid, 'c i j -> (i j) c') # (49, 2)
rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
rel_pos += window_size - 1
rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

In [21]:
rel_pos_indices.shape

torch.Size([49, 49])

In [22]:
rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, 4)

In [23]:
bias = rel_pos_bias(rel_pos_indices)
bias.shape

torch.Size([49, 49, 4])

In [24]:
rel_pos_indices.shape

torch.Size([49, 49])

In [25]:
rearrange(grid, 'i ... -> i 1 ...').shape

torch.Size([49, 1, 2])

In [26]:
rearrange(grid, 'j ... -> 1 j ...').shape

torch.Size([1, 49, 2])

In [27]:
rel_pos.shape

torch.Size([49, 49, 2])

In [28]:
(rel_pos * torch.tensor([2 * window_size - 1, 1])).shape

torch.Size([49, 49, 2])

In [29]:
torch.meshgrid(pos, pos, indexing = 'ij')

(tensor([[0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3, 3, 3],
         [4, 4, 4, 4, 4, 4, 4],
         [5, 5, 5, 5, 5, 5, 5],
         [6, 6, 6, 6, 6, 6, 6]]),
 tensor([[0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6],
         [0, 1, 2, 3, 4, 5, 6]]))

In [30]:
rel_pos_indices.shape

torch.Size([49, 49])

In [31]:
torch.stack(torch.meshgrid(pos, pos, indexing = 'ij')).shape

torch.Size([2, 7, 7])

In [32]:
grid.shape

torch.Size([49, 2])

In [49]:
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 32,
        dropout = 0.,
        window_size = 7
    ):
        super().__init__()
        assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

        self.attend = nn.Sequential(
            nn.Softmax(dim = -1),
            nn.Dropout(dropout)
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias = False),
            nn.Dropout(dropout)
        )

        # relative positional bias

        self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

        pos = torch.arange(window_size)
        grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
        grid = rearrange(grid, 'c i j -> (i j) c')
        rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
        rel_pos += window_size - 1
        rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)

        self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)

    def forward(self, x):
        batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads

        # flatten

        x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')

        # project for queries, keys, values

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # split heads

        q, k, v = map(lambda t: rearrange(t, 'b n (h d ) -> b h n d', h = h), (q, k, v))
        print(q.shape)
        print(k.shape)
        print(v.shape)
        # scale

        q = q * self.scale

        # sim

        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        print(sim.shape)
        # add positional bias

        bias = self.rel_pos_bias(self.rel_pos_indices)
        print(bias.shape)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        # attention

        attn = self.attend(sim)

        # aggregate

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads
        print('out')
        print(out.shape)

        out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)

        # combine heads out
        print(out.shape)

        out = self.to_out(out)
        print(out.shape)
        return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)

In [50]:
attn = Attention(768, 32, 0, 28)
x = torch.randn((2, 2, 2, 28, 28, 768))
attn(x).shape

torch.Size([8, 24, 784, 32])
torch.Size([8, 24, 784, 32])
torch.Size([8, 24, 784, 32])
torch.Size([8, 24, 784, 784])
torch.Size([784, 784, 24])
out
torch.Size([8, 24, 784, 32])
torch.Size([8, 28, 28, 768])
torch.Size([8, 28, 28, 768])


torch.Size([2, 2, 2, 28, 28, 768])