In [1]:
import torch
import torchtext

In [2]:
def assertEqual(t1, t2, message):
    """Assert that two tensors are element-wise identical and report success.

    Args:
        t1: first tensor.
        t2: second tensor; must have exactly the same shape as ``t1``.
        message: short label for the check; printed on success and included
            in the ``AssertionError`` text on failure.

    Raises:
        AssertionError: if the shapes differ or any element differs.
    """
    # torch.eq broadcasts its operands, so without this guard a (1,) tensor
    # would compare "equal" to a (3,) tensor filled with the same value.
    assert t1.shape == t2.shape, "{0}: shape mismatch {1} vs {2}".format(
        message, tuple(t1.shape), tuple(t2.shape)
    )
    assert torch.eq(t1, t2).all(), "{0}: tensors differ".format(message)
    print("Checking: ", message, "     ...OK!")

In [3]:
# d_head corresponds to d_k in the paper
def scaled_dot_product_attention(value, key, query, dropout=0.1, training=True):
    """
    Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_head)) V.

    Note the argument order: ``value`` comes first, then ``key``, then ``query``.

    Shape:
    - value: `(..., S, N * H, E / H)`
    - key: `(..., S, N * H, E / H)`
    - query: `(..., T, N * H, E / H)`

    where E = d_model, E/H = d_head

    Args:
        value, key, query: input tensors with the shapes listed above.
        dropout: drop probability applied to the attention weights.
        training: forwarded to ``torch.nn.functional.dropout``; pass ``False``
            to disable dropout (e.g. at evaluation time). Defaults to ``True``,
            which keeps the previous always-on behaviour.

    Returns:
        A tuple ``(attention_output, attention_weights)`` where
        - attention_output: `(..., T, N * H, E / H)`
        - attention_weights: `(..., N * H, T, S)` (after softmax and dropout)
    """
    assert query.shape[-1] == key.shape[-1] == value.shape[-1], "The d_head of query, key, value must be equal."
    S, T, N_H, d_head = key.shape[-3], query.shape[-3], query.shape[-2], query.shape[-1]

    # Swap the sequence and head axes so matmul batches over the heads:
    # (..., seq, N * H, d_head) -> (..., N * H, seq, d_head)
    query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)

    # calculates attention weights
    query = query * (float(d_head) ** -0.5)
    attention_weights = torch.matmul(query, key.transpose(-2, -1))
    attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1)
    attention_weights = torch.nn.functional.dropout(attention_weights, p=dropout, training=training)
    # Only the trailing three axes are checked so that inputs carrying extra
    # leading batch dimensions (the `...` in the shapes above) are accepted.
    assert attention_weights.shape[-3:] == (N_H, T, S), "attention_weights should be shape (..., N * H, T, S)"

    attention_output = torch.matmul(attention_weights, value)

    # Restore the caller's layout: (..., N * H, T, d_head) -> (..., T, N * H, d_head)
    return attention_output.transpose(-3, -2), attention_weights

In [4]:
def test_sdp():
    """Compare scaled_dot_product_attention against torchtext's ScaledDotProduct.

    Both implementations apply random dropout to the attention weights, so the
    global RNG is re-seeded immediately before each forward call; otherwise the
    two calls would draw different dropout masks and an exact comparison could
    not be expected to hold.
    """
    seed = 42
    dropout = 0.1
    SDP = torchtext.nn.ScaledDotProduct(dropout=dropout)

    # Seed before building the inputs so the test data is reproducible too.
    torch.manual_seed(seed)
    q = torch.randn(25, 256, 3)
    k = v = torch.randn(21, 256, 3)

    # Same seed before each call.
    # NOTE(review): this assumes both implementations consume the RNG
    # identically (a single dropout draw on the attention weights) — confirm
    # against the installed torchtext version.
    torch.manual_seed(seed)
    expected_attn_output, expected_attn_weights = SDP(q, k, v)
    torch.manual_seed(seed)
    attn_output, attn_weights = scaled_dot_product_attention(v, k, q, dropout=dropout)

    assert attn_weights.shape == expected_attn_weights.shape
    assertEqual(attn_weights, expected_attn_weights, "attention_weights are expected?")
    assert attn_output.shape == expected_attn_output.shape, "attn_output.shape is {0} whereas expected_output.shape is {1}".format(attn_output.shape, expected_attn_output.shape)
    assertEqual(attn_output, expected_attn_output, "attention_output is expected?")
    
# Run the comparison right away; each passing check prints an "...OK!" line.
test_sdp()

Checking:  attention_weights are expected?      ...OK!
Checking:  attention_output is expected?      ...OK!
