In [52]:
import torch as t
import torch.nn as nn

from fancy_einsum import einsum

In [53]:
class PositionalEncoding(nn.Module):
    '''
    Adds the fixed sinusoidal positional encoding of "Attention Is All You Need" to its input.

    Column 2i of the table holds sin(pos / 10000^(2i / embedding_dim)) and column 2i + 1 holds
    cos(pos / 10000^(2i / embedding_dim)). An odd embedding_dim is supported: the final column
    is then a sine with no cosine partner.

    embedding_dim: width of the vectors the encoding is added to.
    max_seq_len: number of positions precomputed; forward() only works for seq_len <= this.
    '''

    def __init__(self, embedding_dim: int, max_seq_len: int = 5000):
        super().__init__()
        # Column vector (max_seq_len, 1) so it broadcasts against the row of frequencies below.
        position = t.arange(max_seq_len, dtype=t.float32).unsqueeze(1)
        # One frequency per sine column. There are ceil(embedding_dim / 2) sine columns, which
        # is one more than the number of cosine columns when embedding_dim is odd.
        num_sin = (embedding_dim + 1) // 2
        denominator = 10000 ** (2 * t.arange(num_sin, dtype=t.float32) / embedding_dim)
        angles = position / denominator  # (max_seq_len, num_sin)

        pe = t.zeros(max_seq_len, embedding_dim)
        pe[:, 0::2] = t.sin(angles)
        # Only embedding_dim // 2 cosine columns exist, so drop the unpaired last frequency.
        pe[:, 1::2] = t.cos(angles[:, : embedding_dim // 2])
        # A buffer rather than a Parameter: it follows .to(device) but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, seq_len, embedding_dim), with seq_len <= max_seq_len

        Return: x plus the first seq_len rows of the encoding table (broadcast over batch).
        '''
        seq_len = x.shape[1]
        return x + self.pe[:seq_len]

In [60]:
def single_head_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Scaled dot-product self-attention without masking: softmax(Q K^T / sqrt(d)) V, where d is
    the last dimension of K (see "Self-Attention in Detail" in the Illustrated Transformer).

    Q: shape (batch, seq_len, weights_dim)
    K: shape (batch, seq_len, weights_dim)
    V: shape (batch, seq_len, weights_dim)

    Any number of leading batch dimensions is accepted (e.g. (batch, heads, seq_len, dim)),
    since only the last two dimensions take part in the products.

    Return: shape (batch, seq_len, weights_dim)
    '''
    scale = K.shape[-1] ** 0.5
    # (..., q_len, dim) @ (..., dim, k_len) -> (..., q_len, k_len): one score per query/key pair.
    scores = t.matmul(Q, K.transpose(-2, -1))
    # Normalise over the key axis so each query's weights sum to 1.
    attention_filter = t.softmax(scores / scale, dim=-1)
    # (..., q_len, k_len) @ (..., k_len, v_dim) -> (..., q_len, v_dim)
    return t.matmul(attention_filter, V)


def single_head_masked_attention(Q: t.Tensor, K: t.Tensor, V: t.Tensor) -> t.Tensor:
    '''
    Causally masked scaled dot-product self-attention: query position i may only attend to
    key positions j <= i (see "The Decoder Side" in the Illustrated Transformer).

    Q: shape (batch, seq_len, weights_dim)
    K: shape (batch, seq_len, weights_dim)
    V: shape (batch, seq_len, weights_dim)

    Any number of leading batch dimensions is accepted (e.g. (batch, heads, seq_len, dim)),
    since only the last two dimensions take part in the products.

    Return: shape (batch, seq_len, weights_dim)
    '''
    scale = K.shape[-1] ** 0.5
    # (..., q_len, dim) @ (..., dim, k_len) -> (..., q_len, k_len)
    scores = t.matmul(Q, K.transpose(-2, -1))
    q_len, k_len = scores.shape[-2], scores.shape[-1]
    # True strictly above the diagonal, i.e. at the "future" keys each query must not see.
    # Built once as (q_len, k_len) and broadcast over the leading dims, instead of allocating a
    # full batch-sized tensor of -inf. Key 0 is never masked, so no row is all -inf.
    future = t.triu(t.ones(q_len, k_len, dtype=t.bool, device=scores.device), diagonal=1)
    scores = scores.masked_fill(future, float('-inf'))
    # exp(-inf) = 0, so masked keys get exactly zero weight.
    masked_attention_filter = t.softmax(scores / scale, dim=-1)
    return t.matmul(masked_attention_filter, V)