In [None]:
import math

import torch
import torch.nn as nn

In [None]:
from diffusers.models.attention import Attention

In [None]:
from miniai.activations import set_seed

In [None]:
# Seed before any random tensor or layer is created so reruns produce the same values
# (assumes miniai's set_seed seeds torch's global RNG — TODO confirm).
set_seed(1103)

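[Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762) — the paper that introduced the scaled dot-product and multi-head attention implemented step by step below.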

In [None]:
# Random stand-in for a batch of feature maps; SelfAttention.forward further down
# unpacks the four axes as (bs, c, h, w).
x = torch.randn(32, 8, 4, 4)

In [None]:
# Keep the two leading axes and merge the trailing ones: (32, 8, 4, 4) -> (32, 8, 16).
inp = x.reshape(*x.shape[:2], -1)
inp.shape

torch.Size([32, 8, 16])

In [None]:
# Put the 16 positions on the sequence axis and the 8 channels on the feature axis.
# The tensor is rebuilt from `x` rather than reassigned from itself, so re-running
# this cell cannot flip the axes back and break the cells that follow.
inp = x.reshape(x.shape[0], x.shape[1], -1).transpose(1, 2)
inp.shape

torch.Size([32, 16, 8])

In [None]:
# Feature size per position, i.e. the length of the last axis of `inp`.
n_dim = inp.size(-1)

In [None]:
# Three independent n_dim -> n_dim projections, created in the order q, k, v.
lin_q, lin_k, lin_v = (nn.Linear(n_dim, n_dim) for _ in range(3))

In [None]:
# Project the same input three ways to obtain queries, keys and values.
q, k, v = (proj(inp) for proj in (lin_q, lin_k, lin_v))
q.shape

torch.Size([32, 16, 8])

In [None]:
# Dot product of every query with every key, scaled by 1 / sqrt(n_dim).
scores = q @ k.transpose(-2, -1) * (1 / math.sqrt(n_dim))
scores.shape

torch.Size([32, 16, 16])

In [None]:
# Softmax each row of scores into weights, then take the weighted sum of the values.
out = torch.softmax(scores, dim=-1) @ v
out.shape

torch.Size([32, 16, 8])

In [None]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention over the positions of a 4-D input.

    The input is unpacked as ``(bs, c, h, w)``. The ``h * w`` positions form the
    sequence and the ``c`` channels are the features fed to the linear layers, so
    ``c`` has to equal ``n_dim``. The attended values pass through a final linear
    layer, are added back to the input, and the sum is normalised by a
    single-group ``GroupNorm``.

    Args:
        n_dim: Feature size of every linear layer and channel count of the norm.
    """

    def __init__(self, n_dim):
        super().__init__()
        # Sub-modules are registered in the order q, k, v, norm, lin.
        self.q = nn.Linear(n_dim, n_dim)
        self.k = nn.Linear(n_dim, n_dim)
        self.v = nn.Linear(n_dim, n_dim)
        self.norm = nn.GroupNorm(1, n_dim)
        self.lin = nn.Linear(n_dim, n_dim)
        # Multiplier applied to the query-key dot products before the softmax.
        self.scale = 1 / math.sqrt(n_dim)

    def forward(self, x):
        """Attend over all positions of ``x`` and return a tensor of the same shape."""
        bs, c, h, w = x.shape
        residual = x
        # (bs, c, h, w) -> (bs, h*w, c): one feature vector per position.
        tokens = x.reshape(bs, c, -1).transpose(1, 2)
        queries = self.q(tokens)
        keys = self.k(tokens)
        values = self.v(tokens)
        # Each row of `weights` sums to one over the positions being attended to.
        weights = (queries @ keys.transpose(1, 2) * self.scale).softmax(dim=-1)
        attended = self.lin(weights @ values)
        # (bs, h*w, c) -> (bs, c, h, w) so it lines up with the residual.
        attended = attended.transpose(1, 2).reshape(bs, c, h, w)
        return self.norm(residual + attended)

In [None]:
# Run the module on the 4-D batch `x`; forward reshapes its result back to
# (bs, c, h, w), so the output shape should equal the input shape.
sa = SelfAttention(n_dim)
o1 = sa(x)
o1.shape

torch.Size([32, 8, 4, 4])

In [None]:
def cp_params(a, b):
    """Give module ``a`` its own copy of the ``weight`` and ``bias`` of module ``b``.

    The values are cloned into fresh ``nn.Parameter`` objects instead of assigning
    ``b``'s parameters directly. Direct assignment would leave both modules holding
    the very same tensors, so an in-place update of one (an optimiser step, for
    instance) would silently change the other.

    Args:
        a: Destination module; its ``weight`` and ``bias`` attributes are replaced.
        b: Source module; left untouched.
    """
    for name in ("weight", "bias"):
        src = getattr(b, name)
        if src is None:
            # Source has no such parameter (e.g. a layer built with bias=False).
            setattr(a, name, None)
        else:
            copied = nn.Parameter(src.detach().clone(), requires_grad=src.requires_grad)
            setattr(a, name, copied)

In [None]:
# A second, separately initialised module that receives sa's parameters layer by
# layer should reproduce sa's output on the same input exactly.
sa2 = SelfAttention(n_dim)
for layer in ("q", "k", "v", "norm", "lin"):
    cp_params(getattr(sa2, layer), getattr(sa, layer))
o2 = sa2(x)
(o1 == o2).all()

tensor(True)

In [None]:
# A single projection with 3 * n_dim outputs: q, k and v come out of one matmul
# and are sliced apart in the next cell.
lin_qkv = nn.Linear(n_dim, 3 * n_dim)
qkv = lin_qkv(inp)
qkv.shape

torch.Size([32, 16, 24])

In [None]:
# Cut the fused projection into three equal n_dim-wide pieces along the last axis.
q, k, v = qkv.chunk(3, dim=-1)
q.shape, k.shape, v.shape

(torch.Size([32, 16, 8]), torch.Size([32, 16, 8]), torch.Size([32, 16, 8]))

In [None]:
def heads_to_batch(x, heads):
    """Fold attention heads into the leading (batch) axis of a 3-D tensor.

    The last axis is cut into ``heads`` equal slices and every slice becomes its
    own batch entry, turning ``(bs, c, d)`` into ``(bs * heads, c, d // heads)``.
    The slices belonging to one sample end up next to each other on the batch axis.
    """
    n, seq, _ = x.shape
    split = x.reshape(n, seq, heads, -1)  # (bs, c, heads, dh)
    split = split.permute(0, 2, 1, 3)     # (bs, heads, c, dh)
    return split.reshape(n * heads, seq, -1)

In [None]:
# Only shapes are inspected here, so an uninitialised tensor is enough.
# NOTE: this rebinds `x`, which until now held the 4-D batch, to a 3-D tensor.
x = torch.empty((32, 16, 8))
for heads in (2, 4, 8):
    print(heads_to_batch(x, heads).shape)

torch.Size([64, 16, 4])
torch.Size([128, 16, 2])
torch.Size([256, 16, 1])


In [None]:
def batch_to_heads(x, heads):
    """Undo ``heads_to_batch``: move the heads from the batch axis back to the features.

    A ``(bs * heads, c, dh)`` tensor becomes ``(bs, c, heads * dh)``, with the
    per-head slices of each sample laid side by side along the last axis.
    """
    _, seq, dh = x.shape
    grouped = x.reshape(-1, heads, seq, dh)  # (bs, heads, c, dh)
    grouped = grouped.permute(0, 2, 1, 3)    # (bs, c, heads, dh)
    return grouped.reshape(-1, seq, heads * dh)

In [None]:
# Round trip: splitting into heads and merging back should restore the original shape.
for heads in (2, 4, 8):
    print(batch_to_heads(heads_to_batch(x, heads), heads).shape)

torch.Size([32, 16, 8])
torch.Size([32, 16, 8])
torch.Size([32, 16, 8])


In [None]:
class SelfAttentionMultihead(nn.Module):
    """Multi-head scaled dot-product self-attention over the positions of a 4-D input.

    The input is unpacked as ``(bs, c, h, w)``. The ``h * w`` positions form the
    sequence and the ``c`` channels are the features, so ``c`` has to equal
    ``n_dim``. One fused linear layer produces queries, keys and values; the heads
    are folded into the batch axis so that a single batched matmul attends within
    every head independently. The merged heads pass through a final linear layer,
    are added back to the input, and the sum is normalised by a single-group
    ``GroupNorm``.

    Args:
        n_dim: Number of input channels / features.
        nheads: Number of attention heads; must be positive and divide ``n_dim``.

    Raises:
        ValueError: If ``nheads`` is not positive or does not divide ``n_dim``.
    """

    def __init__(self, n_dim, nheads):
        super().__init__()

        # Fail here with a clear message: an incompatible head count would otherwise
        # only surface as an opaque shape error somewhere inside forward().
        if nheads <= 0 or n_dim % nheads != 0:
            raise ValueError(
                f"n_dim ({n_dim}) must be a positive multiple of nheads ({nheads})"
            )

        self.nheads = nheads
        # Scale by the per-head width, since each dot product spans n_dim / nheads features.
        self.scale = 1 / math.sqrt(n_dim / nheads)
        self.qkv = nn.Linear(n_dim, 3 * n_dim)
        self.norm = nn.GroupNorm(1, n_dim)
        self.lin = nn.Linear(n_dim, n_dim)

    def forward(self, x):
        """Attend over all positions of ``x`` and return a tensor of the same shape."""
        init_x = x
        bs, c, h, w = x.shape
        # (bs, c, h, w) -> (bs, h*w, c): one feature vector per position.
        x = x.reshape(bs, c, -1).transpose(1, 2)
        # (bs, h*w, 3*c) -> (bs*nheads, h*w, 3*dh): each head receives a contiguous
        # 3*dh-wide slice of the fused projection.
        qkv = heads_to_batch(self.qkv(x), self.nheads)
        # Within a head's slice the three equal thirds are its queries, keys and values.
        q, k, v = qkv.chunk(3, dim=-1)
        x = (q @ k.transpose(1, 2) * self.scale).softmax(dim=-1) @ v
        # (bs*nheads, h*w, dh) -> (bs, h*w, c): put the heads back side by side.
        x = batch_to_heads(x, self.nheads)
        x = self.lin(x).transpose(1, 2).reshape(bs, c, h, w)
        return self.norm(init_x + x)

In [None]:
# The output shape should equal the input shape for each of these head counts.
x = torch.randn(32, 8, 4, 4)
for nheads in (4, 2, 8):
    print(SelfAttentionMultihead(8, nheads)(x).shape)

torch.Size([32, 8, 4, 4])
torch.Size([32, 8, 4, 4])
torch.Size([32, 8, 4, 4])
