# Attention shapes

In [None]:
# Load IPython's autoreload extension.
%load_ext autoreload
# Mode 2: reload all modules before each cell is executed, so edits to local
# source files take effect without restarting the kernel.
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import sys
# Make the parent of the working directory importable; presumably this is
# where the local `attention` module used by this notebook lives.
sys.path.append('..')

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from attention import (
    attend,
    self_attend,
    SelfAttention,
    SinusoidalEncoding
)

In [None]:
# Seed the global RNG so every tensor sampled below is reproducible.
# The returned generator is bound to `_` so it is not echoed as cell output.
SEED = 1223334444
_ = torch.manual_seed(SEED)

## Attention

In [None]:
# set parameters: m query rows, n key/value rows
m = 8
n = 9

# Q and K are both sampled with d_k columns (a dot product between queries
# and keys needs matching widths), so no separate query width is defined;
# V uses its own width d_v.
d_k = 15
d_v = 10

# sample Q, K and V
Q = torch.randn(m, d_k)
K = torch.randn(n, d_k)
V = torch.randn(n, d_v)

# compute attention
attn = attend(Q, K, V)

print(f'Attention shape: {attn.shape}')

In [None]:
# compare with PyTorch's built-in scaled dot-product attention
ref_attn = nn.functional.scaled_dot_product_attention(query=Q, key=K, value=V)

# element-wise comparison with the given absolute tolerance (default rtol)
is_close = attn.allclose(ref_attn, atol=1e-05)

print(f'Close to reference implementation: {is_close}')

## Self-attention

In [None]:
# set parameters
# number of rows of X; named `seq_length` rather than the ambiguous
# single-letter `l`, matching the name used in the self-attention layer cell
seq_length = 7

d_x = 100
d_k = 15
d_v = 10

# sample X, W_q, W_k and W_v
X = torch.randn(seq_length, d_x)

# projection matrices map the d_x input columns to d_k (queries, keys)
# and d_v (values) columns
W_q = torch.randn(d_x, d_k)
W_k = torch.randn(d_x, d_k)
W_v = torch.randn(d_x, d_v)

# compute self-attention
self_attn = self_attend(X, W_q, W_k, W_v)

print(f'Self-attention shape: {self_attn.shape}')

In [None]:
# reference: project X explicitly, then apply PyTorch's built-in attention
ref_self_attn = nn.functional.scaled_dot_product_attention(
    query=torch.matmul(X, W_q),
    key=torch.matmul(X, W_k),
    value=torch.matmul(X, W_v),
)

# element-wise comparison with the given absolute tolerance (default rtol)
is_close = self_attn.allclose(ref_self_attn, atol=1e-05)

print(f'Close to reference implementation: {is_close}')

## Self-attention layer

In [None]:
# dimensions handed to the layer constructor
d_x, d_k, d_v = 10, 11, 12

# size of the sampled input batch
batch_size, seq_length = 64, 25

# build the layer
# NOTE: kept before sampling X -- if the layer's initialization draws from the
# global RNG, swapping the two steps would change the sampled values.
self_attention = SelfAttention(d_x=d_x, d_k=d_k, d_v=d_v, scale=True)

# random input of shape (batch_size, seq_length, d_x)
X = torch.randn((batch_size, seq_length, d_x))

# apply the layer to the whole batch
self_attn = self_attention(X)

print(f'Self-attention shape: {self_attn.shape}')

## Sinusoidal encoding

In [None]:
# number of positions to encode and the embedding size given to the encoder
max_length, embed_dim = 96, 128

# build the encoder
sinusoidal_encoding = SinusoidalEncoding(embed_dim=embed_dim)

# positions 0 .. max_length - 1 as a column of shape (max_length, 1)
t = torch.arange(max_length).unsqueeze(-1)
enc = sinusoidal_encoding(t)

print(f'Encoding shape: {enc.shape}')

In [None]:
# heat map of the encodings: positions on the y-axis, embedding dims on the x-axis
fig, ax = plt.subplots(figsize=(8, 4))

# diverging colormap with the colour range pinned to [-1, 1]
img = ax.imshow(
    enc.numpy(),
    cmap='PRGn',
    aspect='auto',
    vmin=-1,
    vmax=1,
)
ax.set_xlabel('embedding dim.')
ax.set_ylabel('position')

fig.colorbar(img)
fig.tight_layout()