In [1]:
import torch
from einops import rearrange

In [48]:
def spatial_self_attention(Q, K, V):
    """
    Q: (B, D, H, W)
    K: (B, D, H, W)
    V: (B, D, H, W)
    """
    B, D, H, W = Q.shape
    Q = rearrange(Q, "b d h w -> b (h w) d")  # (B, HW, D)
    K = rearrange(K, "b d h w -> b (h w) d")  # (B, HW, D)
    V = rearrange(V, "b d h w -> b (h w) d")  # (B, HW, D)

    QK = Q @ K.transpose(-2, -1)  # (B, HW, HW)
    QK = torch.nn.functional.softmax(QK, dim=-1)  # (B, HW, HW)

    QKV = QK @ V  # (B, HW, D)
    QKV = rearrange(QKV, "b (h w) d -> b d h w", h=H)  # (B, D, H, W)
    return QKV

In [49]:
def spatial_linear_self_attention(Q, K, V):
    """
    Q: (B, D, H, W)
    K: (B, D, H, W)
    V: (B, D, H, W)
    """
    B, D, H, W = Q.shape
    Q = rearrange(Q, "b d h w -> b (h w) d")  # (B, HW, D)
    K = rearrange(K, "b d h w -> b (h w) d")  # (B, HW, D)
    V = rearrange(V, "b d h w -> b (h w) d")  # (B, HW, D)

    Q = torch.nn.functional.softmax(Q, dim=-2)  # (B, HW, D)
    K = torch.nn.functional.softmax(K, dim=-1)  # (B, HW, D)

    KV = K.transpose(-2, -1) @ V  # (B, HW, D)
    QKV = Q @ KV  # (B, HW, D)
    QKV = rearrange(QKV, "b (h w) d -> b d h w", h=H)  # (B, D, H, W)
    return QKV

In [50]:
torch.nn.functional.softmax(torch.randn(B, H * W, D), dim=-2).shape

torch.Size([16, 32, 256])

In [52]:
B, D, H, W = 16, 256, 4, 8
Q = torch.rand(B, D, H, W)
K = torch.rand(B, D, H, W)
V = torch.rand(B, D, H, W)

quadratic = spatial_self_attention(Q, K, V)
linear = spatial_linear_self_attention(Q, K, V)
print(quadratic.shape, linear.shape)

norm = torch.norm(quadratic - linear)
print(norm)

torch.Size([16, 256, 4, 8]) torch.Size([16, 256, 4, 8])
tensor(59.5511)
