# MultiHead Attention Unit

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [5]:
def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    The similarity scores ``q @ k^T`` are divided by ``sqrt(d_k)``, where
    ``d_k`` is the size of the last dimension of ``q``, then normalised
    with a softmax over the last axis and used to weight ``v``.

    Args:
        q: Query tensor; its last dimension must match the last dimension
            of ``k``.
        k: Key tensor; its last two dimensions are swapped before the
            matrix product.
        v: Value tensor, multiplied by the attention weights.
        mask: Optional tensor that is *added* to the scaled scores before
            the softmax (so it must be broadcastable to the score shape).

    Returns:
        A ``(output, attention)`` tuple: the weighted values and the
        softmax-normalised attention weights.
    """
    scale = math.sqrt(q.size(-1))
    scores = (q @ k.transpose(-1, -2)) / scale
    if mask is not None:
        # Additive mask: applied to the raw scores, before normalisation.
        scores = scores + mask
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights

In [6]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention layer.

    A single linear layer projects the input to concatenated per-head
    query/key/value features, attention is computed independently per
    head, and the concatenated head outputs go through a final linear
    layer.

    Args:
        input_dim: Size of the last dimension of the input ``x``.
        d_model: Total feature size across all heads; also the size of the
            last dimension of the output.
        n_heads: Number of attention heads. Must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, input_dim, d_model, n_heads):
        super().__init__()
        if d_model % n_heads != 0:
            # Fail early: otherwise the per-head reshape in forward() breaks
            # with a much less helpful shape error.
            raise ValueError(
                f'd_model ({d_model}) must be divisible by n_heads ({n_heads})'
            )
        self.input_dim = input_dim
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # One projection produces q, k and v for every head at once.
        self.qkv_layer = nn.Linear(input_dim, 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Intermediate tensor shapes are printed to stdout as a tracing aid.

        Args:
            x: Input tensor of shape (batch, sequence, input_dim).
            mask: Optional additive mask forwarded unchanged to
                ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape (batch, sequence, d_model).
        """
        batch_size, sequence_length, _ = x.size()
        print(f'x.size(): {x.size()}')
        qkv = self.qkv_layer(x)
        print(f'qkv.size(): {qkv.size()}')
        # Split the feature axis into heads; each head holds its own q|k|v.
        qkv = qkv.reshape(batch_size, sequence_length, self.n_heads, 3 * self.head_dim)
        print(f'qkv.size(): {qkv.size()}')
        # (batch, seq, heads, 3*head_dim) -> (batch, heads, seq, 3*head_dim)
        qkv = qkv.permute(0, 2, 1, 3)
        print(f'qkv.size(): {qkv.size()}')
        q, k, v = qkv.chunk(3, dim=-1)
        print(f'q: {q.size()}\nk: {k.size()}\nv: {v.size()}\n')
        values, attention = scaled_dot_product_attention(q, k, v, mask=mask)
        print(f'values.size(): {values.size()}\nattention.size(): {attention.size()}')
        # Move the head axis back next to head_dim BEFORE flattening.
        # Reshaping (batch, heads, seq, head_dim) directly would produce the
        # right shape but interleave features from different sequence
        # positions and heads.
        values = values.permute(0, 2, 1, 3)
        values = values.reshape(batch_size, sequence_length, self.n_heads * self.head_dim)
        print(f'values.size(): {values.size()}')
        out = self.linear_layer(values)
        print(f'out: {out.size()}')
        return out

In [7]:
# Demo: run a random batch through MultiHeadAttention and display the result.
batch_size = 30

input_dim = 1024
n_heads = 8
d_model = 512

sequence_length = 5
x = torch.randn((batch_size, sequence_length, input_dim))

model = MultiHeadAttention(input_dim, d_model, n_heads)
# Call the module itself rather than .forward() so that nn.Module.__call__
# runs (it dispatches any registered forward hooks).
out = model(x)
out

x.size(): torch.Size([30, 5, 1024])
qkv.size(): torch.Size([30, 5, 1536])
qkv.size(): torch.Size([30, 5, 8, 192])
qkv.size(): torch.Size([30, 8, 5, 192])
q: torch.Size([30, 8, 5, 64])
k: torch.Size([30, 8, 5, 64])
v: torch.Size([30, 8, 5, 64])

values.size(): torch.Size([30, 8, 5, 64])
attention.size(): torch.Size([30, 8, 5, 5])
values.size(): torch.Size([30, 5, 512])
out: torch.Size([30, 5, 512])


tensor([[[-0.1623, -0.1653,  0.1396,  ...,  0.3152,  0.0166,  0.0153],
         [ 0.0114,  0.2720, -0.0178,  ...,  0.2065,  0.0054,  0.0868],
         [ 0.1590,  0.1196,  0.0500,  ..., -0.0605, -0.2720, -0.1128],
         [ 0.4091,  0.1578, -0.1297,  ...,  0.1036, -0.0105, -0.0956],
         [-0.1116,  0.0521,  0.1413,  ..., -0.1493,  0.0577,  0.3494]],

        [[-0.1405,  0.1221,  0.0095,  ...,  0.0075, -0.2480,  0.0052],
         [-0.1524,  0.3020, -0.1569,  ...,  0.0943, -0.0375, -0.0781],
         [ 0.0838,  0.0777, -0.0320,  ...,  0.1853,  0.1358,  0.1878],
         [-0.0298, -0.0100,  0.0615,  ..., -0.0851, -0.1713, -0.1433],
         [ 0.0863,  0.0296,  0.1967,  ..., -0.1727,  0.1416,  0.1266]],

        [[-0.0207,  0.1130, -0.1000,  ..., -0.1030, -0.0508,  0.0267],
         [-0.1799, -0.0367, -0.0710,  ...,  0.1104, -0.0764, -0.0496],
         [-0.0626, -0.0912, -0.0757,  ...,  0.1202, -0.1761, -0.2932],
         [-0.0082, -0.2578,  0.2435,  ..., -0.1285,  0.0730, -0.2077],
  