In [1]:
import torch
import torchtext

In [2]:
torch.manual_seed(42)

<torch._C.Generator at 0x10c0bf9d0>

In [3]:
SDP = torchtext.nn.ScaledDotProduct(dropout=0.1)
q = torch.randn(25, 256, 3)
k = v = torch.randn(21, 256, 3)
expected_attn_output, expected_attn_weights = SDP(q, k, v)
print(expected_attn_output.shape, expected_attn_weights.shape)

key tranpsoed shape:  torch.Size([256, 3, 21])
query shape:  torch.Size([256, 25, 3])
attn_output_weights shape  torch.Size([256, 25, 21])
attn_output_weights shape after softmax torch.Size([256, 25, 21])
torch.Size([25, 256, 3]) torch.Size([256, 25, 21])


In [4]:
def assertEqual(t1, t2):
    assert torch.eq(t1, t2).all()
    print("passed assertEqual")

In [None]:
torch.manual_seed(42)

In [5]:
# d_head corresponds to d_k in the paper
def scaled_dot_product_attention(value, key, query, dropout=0.1):
    """
    Shape:
    - query: `(..., T, N * H, E / H)`
    - key: `(..., S, N * H, E / H)`
    - value: `(..., S, N * H, E /H)`
    
    where E = d_model, E/H = d_head
    """
    assertEqual(value, v)
    assertEqual(key, k)
    assertEqual(query, q)
    
    assert query.shape[-1] == key.shape[-1] == value.shape[-1], "The d_head of query, key, value must be equal."
    S, T, N_H, d_head = key.shape[-3], query.shape[-3], query.shape[-2], query.shape[-1]
    
    query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)

    # calculates attention weights
    query = query * (float(d_head) ** -0.5)     
    attention_weights = torch.matmul(query, key.transpose(-2,-1))
    attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1)
    attention_weights = torch.nn.functional.dropout(attention_weights, p=dropout)
    assert attention_weights.shape == (N_H, T, S), "attention_weights should be shape (N * H, T, S)"

    attention_output = torch.matmul(attention_weights, value)

    return attention_output, attention_weights

In [6]:
attn_output, attn_weights = scaled_dot_product_attention(v, k, q)

passed assertEqual
passed assertEqual
passed assertEqual


In [7]:
print(expected_attn_weights)

tensor([[[7.3709e-03, 1.0526e-02, 3.1706e-01,  ..., 5.3871e-02,
          1.4450e-01, 9.8899e-03],
         [5.2474e-03, 7.2790e-04, 2.1213e-03,  ..., 2.7348e-03,
          3.3701e-03, 2.1130e-01],
         [7.9056e-02, 6.6726e-02, 1.3913e-02,  ..., 2.6752e-02,
          0.0000e+00, 5.7255e-02],
         ...,
         [3.2884e-02, 1.1404e-01, 1.7159e-01,  ..., 1.3734e-01,
          1.1767e-01, 0.0000e+00],
         [3.5993e-03, 3.5142e-03, 1.2023e-02,  ..., 4.5520e-02,
          1.6123e-02, 5.1409e-02],
         [4.2352e-02, 2.2713e-02, 4.4686e-02,  ..., 0.0000e+00,
          0.0000e+00, 7.4061e-02]],

        [[1.8031e-03, 6.8657e-02, 2.0406e-01,  ..., 4.0332e-02,
          2.1923e-02, 4.4464e-03],
         [1.4574e-02, 7.4297e-02, 2.1377e-01,  ..., 5.6560e-02,
          1.1533e-02, 2.0428e-02],
         [0.0000e+00, 4.8970e-02, 6.1474e-02,  ..., 6.9460e-02,
          2.2376e-02, 6.1622e-02],
         ...,
         [3.2817e-02, 3.7500e-02, 1.7899e-02,  ..., 2.8935e-02,
          3.188

In [8]:
print(attn_weights)

tensor([[[7.3709e-03, 1.0526e-02, 3.1706e-01,  ..., 5.3871e-02,
          1.4450e-01, 9.8899e-03],
         [5.2474e-03, 7.2790e-04, 2.1213e-03,  ..., 2.7348e-03,
          3.3701e-03, 2.1130e-01],
         [7.9056e-02, 6.6726e-02, 1.3913e-02,  ..., 0.0000e+00,
          1.9131e-02, 5.7255e-02],
         ...,
         [3.2884e-02, 1.1404e-01, 1.7159e-01,  ..., 1.3734e-01,
          1.1767e-01, 7.4021e-03],
         [3.5993e-03, 3.5142e-03, 1.2023e-02,  ..., 4.5520e-02,
          1.6123e-02, 5.1409e-02],
         [4.2352e-02, 2.2713e-02, 4.4686e-02,  ..., 2.2976e-02,
          4.1240e-02, 0.0000e+00]],

        [[1.8031e-03, 6.8657e-02, 2.0406e-01,  ..., 0.0000e+00,
          2.1923e-02, 0.0000e+00],
         [1.4574e-02, 7.4297e-02, 2.1377e-01,  ..., 5.6560e-02,
          1.1533e-02, 2.0428e-02],
         [2.4224e-02, 4.8970e-02, 6.1474e-02,  ..., 6.9460e-02,
          2.2376e-02, 6.1622e-02],
         ...,
         [3.2817e-02, 3.7500e-02, 1.7899e-02,  ..., 2.8935e-02,
          3.188

In [10]:
assert attn_weights.shape == expected_attn_weights.shape
assertEqual(attn_weights, expected_attn_weights)

AssertionError: 

In [5]:
print(attn_output.shape, attn_weights.shape)

torch.Size([256, 25, 3]) torch.Size([256, 25, 21])


In [6]:
print("expected_attn_output: ", expected_attn_output)
print("actual_attn_output: ", attn_output)

expected_attn_output:  tensor([[[-1.7754e-01,  1.2924e-01,  1.5124e-01],
         [-5.5349e-01,  9.9234e-02, -1.0563e-01],
         [-9.8998e-01,  4.5181e-03, -6.7981e-01],
         ...,
         [-8.1611e-01,  8.0464e-01,  1.4902e+00],
         [ 1.2824e+00,  1.1947e+00,  3.6544e-01],
         [ 1.7662e-01, -8.4852e-01, -2.9297e-01]],

        [[ 7.8802e-01, -8.1057e-01,  4.0439e-01],
         [-1.0824e+00, -9.2882e-01, -2.4140e-01],
         [-2.5309e+00, -1.5416e+00, -4.3882e-01],
         ...,
         [-4.3796e-01,  8.1887e-01, -2.4539e-01],
         [-1.0819e+00, -2.4269e-01,  1.4391e-01],
         [ 3.8632e-01,  2.7089e-01, -3.7836e-01]],

        [[ 5.2584e-01,  2.4597e-01,  6.9702e-01],
         [-3.9866e-01,  9.3980e-01,  2.9205e-01],
         [-1.1475e+00, -8.7214e-01, -4.5471e-01],
         ...,
         [-8.9404e-01,  1.0199e+00,  1.6401e+00],
         [ 2.2478e-01, -5.7274e-01, -2.8966e-01],
         [ 7.0868e-01,  6.0373e-01, -8.6544e-02]],

        ...,

        [[ 9.86

In [7]:
assert(torch.eq(attn_output,expected_attn_output).all())

RuntimeError: The size of tensor a (25) must match the size of tensor b (256) at non-singleton dimension 1