In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
# Reproducibility and display settings: fixed RNG seed, 7 decimals when printing.
torch.set_printoptions(precision=7)
torch.manual_seed(0)

# Toy input: one "sentence" of six tokens, each a 3-dimensional embedding.
batch = 1
sentence_length = 6
embedding_dim = 3
embedding = torch.randn((batch, sentence_length, embedding_dim))
print(embedding)

tensor([[[-1.1258398, -1.1523602,  0.5666506],
         [ 0.7935084,  0.5988395, -1.5550951],
         [-0.3413604,  1.8530061,  0.4680964],
         [-0.1577124, -0.1733968,  0.1834779],
         [ 1.3893661,  1.5863342,  0.9462984],
         [-0.8436767,  0.9318266,  1.2590092]]])


In [None]:
def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Args:
        q: Query tensor of shape [..., SeqLen_q, d_k].
        k: Key tensor of shape [..., SeqLen_k, d_k].
        v: Value tensor of shape [..., SeqLen_k, d_v].
        mask: Optional tensor broadcastable to the attention logits
            ([..., SeqLen_q, SeqLen_k]). Positions where ``mask == 0`` are
            excluded from attention; all other positions are kept.

    Returns:
        Tuple ``(values, attention)``: ``attention`` holds the softmax weights
        over the key axis, and ``values`` is ``attention @ v``.
    """
    d_k = q.size(-1)
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Fill with the most negative *finite* value of the logits' dtype.
        # A hard-coded constant such as -9e15 cannot be represented in float16
        # (masked_fill raises an overflow error), while -inf would turn fully
        # masked rows into NaN after the softmax. With a finite minimum the
        # masked weights are exactly 0 and a fully masked row stays uniform.
        fill_value = torch.finfo(attn_logits.dtype).min
        attn_logits = attn_logits.masked_fill(mask == 0, fill_value)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention

In [None]:
# Unmasked self-attention: the embedding acts as query, key and value at once.
val, att = scaled_dot_product(q=embedding, k=embedding, v=embedding)

# Lower-triangular mask of ones: position i may only attend to positions <= i.
mask_simple = torch.ones(sentence_length, sentence_length).tril()
val_mas, att_mas = scaled_dot_product(
    q=embedding, k=embedding, v=embedding, mask=mask_simple
)

In [None]:
# Masked attention weights followed by the mask that produced them.
print(att_mas, mask_simple, sep="\n")

tensor([[[1.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000],
         [0.0326173, 0.9673827, 0.0000000, 0.0000000, 0.0000000, 0.0000000],
         [0.0411625, 0.1034752, 0.8553624, 0.0000000, 0.0000000, 0.0000000],
         [0.3287411, 0.1850140, 0.2241982, 0.2620467, 0.0000000, 0.0000000],
         [0.0064874, 0.0471875, 0.1808757, 0.0280419, 0.7374074, 0.0000000],
         [0.0882735, 0.0190180, 0.2824616, 0.0705991, 0.1491040, 0.3905437]]])
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [None]:
# Attended values under the lower-triangular mask. The first position can only
# attend to itself, so its row equals the first embedding row.
print(val_mas)

tensor([[[-1.1258398, -1.1523602,  0.5666506],
         [ 0.7309045,  0.5417201, -1.4858896],
         [-0.2562208,  1.5995227,  0.2628030],
         [-0.3411601,  0.1019680,  0.0515932],
         [ 0.9885026,  1.5208580,  0.7179147],
         [-0.3141791,  1.0212752,  0.7984132]]])


In [None]:
def MultiheadAttention(qkv, num_heads, mask=None):
    """Split a packed QKV tensor into heads and apply scaled dot-product attention.

    Args:
        qkv: Tensor of shape [Batch, SeqLen, 3 * embed_dim] whose last axis is
            the concatenation ``[Q | K | V]`` (it is chunked into three equal
            parts along that axis).
        num_heads: Number of attention heads; must be positive and divide
            ``embed_dim`` evenly.
        mask: Optional mask forwarded unchanged to ``scaled_dot_product``.

    Returns:
        Tuple ``(values, attention)``: ``values`` has shape
        [Batch, SeqLen, embed_dim] (heads concatenated back together) and
        ``attention`` is the per-head attention returned by
        ``scaled_dot_product``.

    Raises:
        ValueError: If the last dimension of ``qkv`` is not a multiple of 3,
            or ``embed_dim`` is not divisible by ``num_heads``.
    """
    batch_size, seq_length, qkv_dim = qkv.size()
    if qkv_dim % 3 != 0:
        raise ValueError(
            f"last dimension of qkv ({qkv_dim}) must be a multiple of 3"
        )
    embed_dim = qkv_dim // 3
    if num_heads < 1 or embed_dim % num_heads != 0:
        raise ValueError(
            f"embed_dim ({embed_dim}) must be divisible by a positive "
            f"num_heads (got {num_heads})"
        )
    head_dim = embed_dim // num_heads

    def _split_heads(t):
        # [Batch, SeqLen, embed_dim] -> [Batch, Head, SeqLen, Dims]
        t = t.reshape(batch_size, seq_length, num_heads, head_dim)
        return t.permute(0, 2, 1, 3)

    # Separate Q, K, V from the packed input, then split each into heads.
    q, k, v = (_split_heads(part) for part in qkv.chunk(3, dim=-1))

    # Determine value outputs
    values, attention = scaled_dot_product(q, k, v, mask=mask)
    values = values.permute(0, 2, 1, 3)  # [Batch, SeqLen, Head, Dims]
    # Merge the heads back into a single embedding axis.
    values = values.reshape(batch_size, seq_length, embed_dim)

    return values, attention


In [None]:
# Re-seed so the packed QKV tensor below is reproducible on its own.
torch.manual_seed(0)

# Three heads over a 9-dimensional embedding; Q, K and V are packed side by
# side, so the last axis is three times the embedding size.
num_heads, embedding_dim = 3, 9
qkv_dim = 3 * embedding_dim
qkv = torch.rand((batch, sentence_length, qkv_dim))
print(qkv, qkv.size(), sep="\n")

tensor([[[0.4962566, 0.7682218, 0.0884774, 0.1320305, 0.3074228, 0.6340787,
          0.4900934, 0.8964447, 0.4556280, 0.6323063, 0.3488935, 0.4017173,
          0.0223258, 0.1688589, 0.2938884, 0.5185218, 0.6976676, 0.8000114,
          0.1610295, 0.2822686, 0.6816086, 0.9151940, 0.3970999, 0.8741559,
          0.4194083, 0.5529070, 0.9527381],
         [0.0361648, 0.1852310, 0.3734174, 0.3051000, 0.9320004, 0.1759102,
          0.2698336, 0.1506798, 0.0317195, 0.2081298, 0.9297990, 0.7231092,
          0.7423363, 0.5262958, 0.2436582, 0.5845923, 0.0331526, 0.1387169,
          0.2422350, 0.8154690, 0.7931606, 0.2782525, 0.4819588, 0.8197803,
          0.9970666, 0.6984411, 0.5675464],
         [0.8352432, 0.2055988, 0.5931720, 0.1123472, 0.1534569, 0.2417082,
          0.7262365, 0.7010802, 0.2038237, 0.6510535, 0.7744860, 0.4368913,
          0.5190908, 0.6158524, 0.8101883, 0.9800971, 0.1146882, 0.3167651,
          0.6965050, 0.9142747, 0.9351037, 0.9411784, 0.5995073, 0.0652087,


In [None]:
# Run the multi-head helper on the packed qkv tensor with num_heads heads and
# the lower-triangular mask. As the cell's last expression, the returned
# (values, attention) tuple is displayed.
MultiheadAttention(qkv, num_heads, mask_simple)

tensor([[[[0.4962566, 0.7682218, 0.0884774],
          [0.0361648, 0.1852310, 0.3734174],
          [0.8352432, 0.2055988, 0.5931720],
          [0.9442462, 0.8801799, 0.0012360],
          [0.4724502, 0.5750725, 0.2952349],
          [0.4485999, 0.5138961, 0.4568655]],

         [[0.1320305, 0.3074228, 0.6340787],
          [0.3051000, 0.9320004, 0.1759102],
          [0.1123472, 0.1534569, 0.2417082],
          [0.5935860, 0.4157700, 0.4177194],
          [0.7966888, 0.1957304, 0.9536850],
          [0.6011907, 0.8179197, 0.9736231]],

         [[0.4900934, 0.8964447, 0.4556280],
          [0.2698336, 0.1506798, 0.0317195],
          [0.7262365, 0.7010802, 0.2038237],
          [0.2711216, 0.6922781, 0.2038482],
          [0.8426499, 0.0783585, 0.3755578],
          [0.8175279, 0.9747068, 0.4638392]]]])


(tensor([[[0.1610295, 0.2822686, 0.6816086, 0.9151940, 0.3970999, 0.8741559,
           0.4194083, 0.5529070, 0.9527381],
          [0.2041172, 0.5651852, 0.7407982, 0.5471206, 0.4461379, 0.8427336,
           0.6996289, 0.6235052, 0.7658825],
          [0.3756499, 0.6754045, 0.8065838, 0.7126206, 0.4977171, 0.5667509,
           0.6224525, 0.4672919, 0.5303556],
          [0.3405205, 0.7688560, 0.8197187, 0.6017370, 0.4748967, 0.4218882,
           0.5884267, 0.3828350, 0.4334926],
          [0.3403816, 0.6683302, 0.7224301, 0.5051575, 0.4620013, 0.4434961,
           0.5172774, 0.3966154, 0.4936182],
          [0.4002910, 0.6833256, 0.6219555, 0.5567545, 0.4183558, 0.4400183,
           0.4843370, 0.3930716, 0.4928433]]]),
 tensor([[[[1.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000],
           [0.4693991, 0.5306010, 0.0000000, 0.0000000, 0.0000000, 0.0000000],
           [0.3279736, 0.3197069, 0.3523195, 0.0000000, 0.0000000, 0.0000000],
           [0.2181782, 0.232