In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
batch = 4
channel = 32
height = 128
width = 128

In [3]:
test_t = torch.rand(batch, channel, height, width)
test_t.shape

torch.Size([4, 32, 128, 128])

In [4]:
dim = 32
num_heads = 4
dim_head = 32
inner_dim = num_heads * dim_head

In [5]:
to_qkv = nn.Conv2d(in_channels=dim, out_channels=inner_dim * 3, kernel_size=1, bias=False)

In [6]:
qkv = to_qkv(test_t)
qkv.shape

torch.Size([4, 384, 128, 128])

In [7]:
qkv = qkv.chunk(3, dim=1)
len(qkv)

3

In [8]:
qkv[0].shape

torch.Size([4, 128, 128, 128])

In [9]:
q_reshaped = qkv[0].view(batch, inner_dim, -1)
q_reshaped.shape

torch.Size([4, 128, 16384])

In [10]:
q_reshaped = q_reshaped.view(batch, num_heads, dim_head, -1)
q_reshaped.shape

torch.Size([4, 4, 32, 16384])

In [11]:
q, k, v = map(lambda x: x.view(batch, num_heads, dim_head, -1), qkv)
q.shape

torch.Size([4, 4, 32, 16384])

In [12]:
torch.eq(q, q_reshaped)

tensor([[[[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ...

In [13]:
s = q.matmul(k.swapdims(-1, -2))
s.shape

torch.Size([4, 4, 32, 32])

In [14]:
attn = F.softmax(s, dim=-1)
attn.shape

torch.Size([4, 4, 32, 32])

In [15]:
out = attn.matmul(v)
out.shape

torch.Size([4, 4, 32, 16384])

In [16]:
out2 = out.view(batch, inner_dim, height, width)
out2.shape

torch.Size([4, 128, 128, 128])

In [19]:
q2, k2, v2 = map(lambda x: x.view(batch, num_heads, dim_head, -1), qkv)
q2.shape

torch.Size([4, 4, 32, 16384])

In [20]:
q2.softmax(dim=-2)
k2.softmax(dim=-1)

tensor([[[[5.3062e-05, 7.8847e-05, 7.9855e-05,  ..., 6.2761e-05,
           8.2297e-05, 6.9736e-05],
          [4.9579e-05, 5.7943e-05, 5.4345e-05,  ..., 5.3442e-05,
           6.1293e-05, 5.5299e-05],
          [6.6173e-05, 6.9737e-05, 7.9594e-05,  ..., 7.1509e-05,
           8.0747e-05, 7.1145e-05],
          ...,
          [7.8785e-05, 5.2889e-05, 5.4667e-05,  ..., 7.3684e-05,
           4.9925e-05, 6.5018e-05],
          [8.3834e-05, 7.5841e-05, 6.5846e-05,  ..., 5.3267e-05,
           5.9850e-05, 7.8551e-05],
          [6.1350e-05, 5.8024e-05, 5.6659e-05,  ..., 7.2465e-05,
           5.5833e-05, 6.5400e-05]],

         [[4.2616e-05, 4.9286e-05, 5.0862e-05,  ..., 4.7694e-05,
           5.5041e-05, 4.5461e-05],
          [4.6266e-05, 5.6972e-05, 6.0124e-05,  ..., 6.2430e-05,
           6.8484e-05, 6.9369e-05],
          [4.8719e-05, 3.5395e-05, 4.9985e-05,  ..., 7.2277e-05,
           5.0612e-05, 4.5900e-05],
          ...,
          [7.3793e-05, 6.7570e-05, 6.1700e-05,  ..., 6.1805

In [21]:
q = q * dim_head ** 0.5
v = v / (height * width)

In [22]:
context1 = k.matmul(v.transpose(-1, -2))

In [23]:
context2 = torch.einsum('b h d n, b h e n -> b h d e', k, v)

In [24]:
torch.eq(context1, context2)

tensor([[[[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ...

In [25]:
out1 = torch.einsum('b h d e, b h d n -> b h e n', context1, q)

In [27]:
out2 = context2.transpose(-1, -2).matmul(q)

In [28]:
torch.eq(out1, out2)

tensor([[[[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True]],

         [[True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          ...,
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ..., True, True, True],
          [True, True, True,  ...