# In [None]:
from torch import nn
## --------------------------------------------------------------------------------
## Simple implementation of multi-head self-attention
## Modified from timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
## --------------------------------------------------------------------------------
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings.

    The input is projected to queries, keys and values with a single fused
    linear layer, scaled dot-product attention is computed independently per
    head, and the concatenated head outputs go through a final linear layer.

    Args:
        hidden_size: Embedding dimension ``C`` of the input and the output.
        num_heads: Number of attention heads; must divide ``hidden_size``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads):
        # nn.Module.__init__ accepts no arguments; forwarding ours raises TypeError.
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        # forward() needs the head count to split the fused qkv projection.
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, hidden_size * 3)
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, N, C)`` with ``C == hidden_size``.
            mask: Optional tensor broadcastable to ``(B, heads, N, N)``.
                It is multiplied by ``-1e9`` and added to the attention
                logits, so non-zero entries mark positions to suppress.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        # 1st computation complexity: N * C * 3C = 3N(C^2)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # shape: (3, B, heads, N, head_dim)
        q, k, v = qkv.unbind(0)  # each of shape: (B, heads, N, head_dim)

        # Scale the queries by 1/sqrt(head_dim) before taking dot products.
        q = q * (self.head_dim ** -0.5)
        # 2nd computation complexity: N * C * N = C(N^2)
        attn = q @ k.transpose(-2, -1)  # shape: (B, heads, N, N)
        if mask is not None:
            # A large negative logit drives the softmax weight to ~0.
            attn += (mask * -1e9)
        attn = attn.softmax(dim=-1)
        # 3rd computation complexity: N * N * C = C(N^2)
        x = attn @ v  # shape: (B, heads, N, head_dim)

        # Merge the heads back into the channel dimension.
        x = x.transpose(1, 2).reshape(B, N, C)  # shape: (B, N, C)

        # 4th computation complexity: N * C * C = N(C^2)
        x = self.proj(x)

        # total computation complexity: 3N(C^2) + C(N^2) + C(N^2) + N(C^2) = 4N(C^2) + 2C(N^2)
        return x