# 🤯 Multi-Head Attention 📢
---

*COSCI 223 - Machine Learning 3*

*Prepared by Sebastian C. Ibañez*

<a href="https://colab.research.google.com/github/aim-msds/msds2023-ml3/blob/main/notebooks/transformer/01-multiheadattn.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" style="float: left;"></a><br>

PyTorch docs [here](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html).

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

In [2]:
# create input sentence
input_sentence = ['the', 'train', 'left', 'the', 'station', 'on', 'time']
input_sentence

['the', 'train', 'left', 'the', 'station', 'on', 'time']

In [3]:
# create a vocab
vocab = sorted(list(set(input_sentence))) # our vocab only has 7 words
vocab

['left', 'on', 'station', 'the', 'time', 'train']

In [4]:
# convert tokens to integers, then torch tensor with appropriate shape
x = [vocab.index(token) for token in input_sentence]
x = torch.tensor(x).view(-1, 1) # shape is (sequence length, batch size)

print(x.shape)
print(x)

torch.Size([7, 1])
tensor([[3],
        [5],
        [0],
        [3],
        [2],
        [1],
        [4]])


In [5]:
torch.manual_seed(123)

# embedding layer
embed_dim = 2
embedding = nn.Embedding(len(vocab), embed_dim)

with torch.no_grad():
    x_embed = embedding(x)

print(x_embed.shape)
print(x_embed)

torch.Size([7, 1, 2])
tensor([[[-0.9724, -0.7550]],

        [[ 0.2103, -0.3908]],

        [[-0.1115,  0.1204]],

        [[-0.9724, -0.7550]],

        [[-1.1969,  0.2093]],

        [[-0.3696, -0.2404]],

        [[ 0.3239, -0.1085]]])


In [6]:
torch.manual_seed(1)

# create embedding layers for Q, K, V
head_dim = 4 # doesn't necessarily have to be the same as embed_dim
             # also called d_k or d_model in Vaswani et al. (2017)
             # sometimes called block_size

query_embed = nn.Linear(embed_dim, head_dim, bias=False) # bias is off as per Vaswani et al. (2017)
key_embed   = nn.Linear(embed_dim, head_dim, bias=False)
value_embed = nn.Linear(embed_dim, head_dim, bias=False)

# for every x in our sequence, get a query, key, value
with torch.no_grad():
    query = query_embed(x_embed)
    key   = key_embed(x_embed)
    value = value_embed(x_embed)

print('Query:')
print(query.shape)
print(query)

print('\nKey:')
print(key.shape)
print(key)

print('\nValue:')
print(value.shape)
print(value)

Query:
torch.Size([7, 1, 4])
tensor([[[-0.1186, -0.1173,  0.3271, -0.1302]],

        [[ 0.1986, -0.1586, -0.3058, -0.1712]],

        [[-0.0782,  0.0552,  0.1252,  0.0595]],

        [[-0.1186, -0.1173,  0.3271, -0.1302]],

        [[-0.5014,  0.2335,  0.8855,  0.2494]],

        [[-0.0596, -0.0291,  0.1441, -0.0327]],

        [[ 0.1519, -0.0804, -0.2616, -0.0862]]])

Key:
torch.Size([7, 1, 4])
tensor([[[-0.0302, -0.2170, -0.0431,  0.0982]],

        [[ 0.0545,  0.0276,  0.1620,  0.0140]],

        [[-0.0214, -0.0177, -0.0620, -0.0019]],

        [[-0.0302, -0.2170, -0.0431,  0.0982]],

        [[-0.1358, -0.2274, -0.3668,  0.0484]],

        [[-0.0155, -0.0809, -0.0292,  0.0344]],

        [[ 0.0412,  0.0597,  0.1136, -0.0098]]])

Value:
torch.Size([7, 1, 4])
tensor([[[-0.0975, -0.7672,  0.5785,  0.3455]],

        [[ 0.0227,  0.0440,  0.1115,  0.0943]],

        [[-0.0118, -0.0424, -0.0221, -0.0235]],

        [[-0.0975, -0.7672,  0.5785,  0.3455]],

        [[-0.1233, -0.6938,  0.

In [7]:
torch.manual_seed(1)

# create MHA layer
num_heads = 2 # Note that head_dim will be split across num_heads (i.e. each head will have dimension head_dim//num_heads)

multihead_attn = nn.MultiheadAttention(head_dim, num_heads, dropout=0.0)

with torch.no_grad():
    attn_output, attn_output_weights = multihead_attn(query, key, value)

print('Output:')
print(attn_output.shape)
print(attn_output)

print('\nAttn Weights:')
print(attn_output_weights.shape)
print(attn_output_weights)

Output:
torch.Size([7, 1, 4])
tensor([[[ 0.0168, -0.0299,  0.0095, -0.0084]],

        [[ 0.0170, -0.0291,  0.0093, -0.0086]],

        [[ 0.0168, -0.0298,  0.0095, -0.0085]],

        [[ 0.0168, -0.0299,  0.0095, -0.0084]],

        [[ 0.0165, -0.0309,  0.0097, -0.0082]],

        [[ 0.0168, -0.0298,  0.0094, -0.0085]],

        [[ 0.0169, -0.0292,  0.0093, -0.0086]]])

Attn Weights:
torch.Size([1, 7, 7])
tensor([[[0.1429, 0.1446, 0.1425, 0.1429, 0.1400, 0.1429, 0.1441],
         [0.1428, 0.1404, 0.1432, 0.1428, 0.1470, 0.1427, 0.1410],
         [0.1429, 0.1438, 0.1427, 0.1429, 0.1413, 0.1429, 0.1436],
         [0.1429, 0.1446, 0.1425, 0.1429, 0.1400, 0.1429, 0.1441],
         [0.1428, 0.1493, 0.1418, 0.1428, 0.1326, 0.1430, 0.1476],
         [0.1429, 0.1437, 0.1427, 0.1429, 0.1415, 0.1429, 0.1435],
         [0.1428, 0.1409, 0.1431, 0.1428, 0.1461, 0.1428, 0.1414]]])


In [8]:
# Create a causal mask
seq_len = len(input_sentence)
causal_mask = torch.tril(torch.ones((seq_len, seq_len))) # shape is (batch_size*num_heads, output seq_len, input seq_len)
causal_mask[causal_mask == 0] = float('-inf') # replace 0 with -inf

print(causal_mask.shape)
print(causal_mask)

torch.Size([7, 7])
tensor([[1., -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., -inf, -inf, -inf],
        [1., 1., 1., 1., 1., -inf, -inf],
        [1., 1., 1., 1., 1., 1., -inf],
        [1., 1., 1., 1., 1., 1., 1.]])


In [9]:
# MHA with causal mask
with torch.no_grad():
    attn_output, attn_output_weights = multihead_attn(query, key, value, attn_mask=causal_mask)

print('Output:')
print(attn_output.shape)
print(attn_output)

print('\nAttn Weights:')
print(attn_output_weights.shape)
print(attn_output_weights)

Output:
torch.Size([7, 1, 4])
tensor([[[ 0.0422, -0.0849,  0.0276, -0.0172]],

        [[ 0.0262, -0.0692,  0.0236, -0.0038]],

        [[ 0.0165, -0.0406,  0.0136, -0.0039]],

        [[ 0.0230, -0.0517,  0.0171, -0.0072]],

        [[ 0.0208, -0.0342,  0.0105, -0.0121]],

        [[ 0.0199, -0.0317,  0.0098, -0.0115]],

        [[ 0.0169, -0.0292,  0.0093, -0.0086]]])

Attn Weights:
torch.Size([1, 7, 7])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5043, 0.4957, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3327, 0.3349, 0.3324, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2495, 0.2523, 0.2488, 0.2495, 0.0000, 0.0000, 0.0000],
         [0.2014, 0.2104, 0.1999, 0.2014, 0.1869, 0.0000, 0.0000],
         [0.1668, 0.1677, 0.1666, 0.1668, 0.1652, 0.1668, 0.0000],
         [0.1428, 0.1409, 0.1431, 0.1428, 0.1461, 0.1428, 0.1414]]])
