# Positional Encoding in Attention is All You Need (AAYN)

#### $PE_{pos, 2i} = \sin\left(\frac{pos}{10000^{\frac{2i}{d_{model}}}}\right)$ and $PE_{pos, 2i+1} = \cos\left(\frac{pos}{10000^{\frac{2i}{d_{model}}}}\right)$

### Ref: [The AiEdge Newsletter](https://drive.google.com/file/d/1Je2SAFBlsWcgwzK_gl1_f-LtPK3SOzg3/view)

<img src="../../assets/pos_enc.png" width="700" height="350">

In [1]:
import torch
import torch.nn as nn

In [None]:
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding from "Attention Is All You Need".

    The table is computed once at construction time:
        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    and stored as the (non-trainable) buffer ``pos_encoding``.

    Args:
        context_size (int): maximum length of the input sequence (also known as max_length)
        d_model (int): internal dimension of the model or dimension of embeddings.
        (also known as 'hidden_size')
    """
    def __init__(self, context_size: int, d_model: int):
        super().__init__()

        pos = torch.arange(0, context_size).unsqueeze(dim=1) # [context_size, 1]
        # pair indices i, one per (sin, cos) pair
        # for d_model=5 -> ii = (0, 1, 2), i.e. 2*ii = (0, 2, 4); the cosine columns only
        # need the first d_model//2 pairs, i.e. ii[:d_model//2] = (0, 1)
        # this way of implementation, will cover both even and odd values for d_model
        ii = torch.arange(0, (d_model + 1) // 2)
        div = 10000 ** (2 * ii / d_model)

        # initialize positional encoding [context_size, d_model]
        # (kept in a local variable: storing it as a plain attribute as well would leave
        # a second reference that does not follow the module on .to()/.double()/.cuda())
        encoding = torch.zeros(context_size, d_model)
        encoding[:, 0::2] = torch.sin(pos / div) # even columns
        encoding[:, 1::2] = torch.cos(pos / div[:d_model // 2]) # odd  columns

        # Registers positional encoding tensor as part of the module state, but not as a
        # learnable parameter (i.e., not updated by gradient descent). Positional encodings
        # in the vanilla "Attention Is All You Need" are not trained.
        # Moreover, when registered as a buffer, it moves with the model and is saved in state_dict
        self.register_buffer('pos_encoding', encoding)

    @property
    def encoding(self) -> torch.Tensor:
        """
        read-only alias of the ``pos_encoding`` buffer [context_size, d_model].
        (kept so that code reading ``module.encoding`` keeps working and always sees
        the tensor on the module's current device/dtype)
        """
        return self.pos_encoding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        returns positional encoding for a given input tensor x.
        (input tensor x is coming from token embedding layer in the transformer architecture)

        Args:
            x (torch.Tensor): input tensor [batch_size, seq_len, d_model]
            (only the size of dimension 1 is read)

        Returns:
            torch.Tensor: positional encoding slice of shape [seq_len, d_model],
            ready to be added to token embeddings.

        Raises:
            ValueError: if seq_len is larger than the context_size the module was built with.
        """
        seq_len = x.size(1) # number of tokens in the input sequence
        max_len = self.pos_encoding.size(0)

        # slicing past the end would silently return fewer than seq_len rows
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum context size {max_len}"
            )

        # make sure to use 'pos_encoding' from self.register_buffer
        # (otherwise we can't use the benefits of self.register_buffer)
        return self.pos_encoding[:seq_len, :]

#### Toy Example

In [15]:
vocab_size = 15     # size of the embedding table (number of distinct token ids)
d_model = 11        # odd value, so the odd-d_model case of the encoding is exercised
context_size = 20   # maximum sequence length (that is supported)
batch_size = 2      # number of sequences in the example batch below
seq_len = 7         # number of tokens per sequence in the example batch below

embedding = nn.Embedding(vocab_size, d_model) # token embedding layer
pos_encoder = PositionalEncoding(context_size=context_size, d_model=d_model) # positional encoding module

In [16]:
# Example: a batch of 2 sequences, each holding 7 token ids
x = torch.tensor([
    [6, 7, 8, 9, 0, 1, 2],
    [0, 0, 1, 4, 5, 9, 5]
]) # shape [2, 7]

x_emb = embedding(x) # token embeddings, [2, 7, d_model]
x_pos_embed = pos_encoder(x_emb) # positional-encoding slice for this input, [7, d_model]
x_pos_embed

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,
          0.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  1.8629e-01,  9.8250e-01,  3.5105e-02,
          9.9938e-01,  6.5793e-03,  9.9998e-01,  1.2328e-03,  1.0000e+00,
          2.3101e-04],
        [ 9.0930e-01, -4.1615e-01,  3.6605e-01,  9.3059e-01,  7.0166e-02,
          9.9754e-01,  1.3158e-02,  9.9991e-01,  2.4657e-03,  1.0000e+00,
          4.6203e-04],
        [ 1.4112e-01, -9.8999e-01,  5.3300e-01,  8.4611e-01,  1.0514e-01,
          9.9446e-01,  1.9737e-02,  9.9981e-01,  3.6985e-03,  9.9999e-01,
          6.9304e-04],
        [-7.5680e-01, -6.5364e-01,  6.8129e-01,  7.3201e-01,  1.3999e-01,
          9.9015e-01,  2.6314e-02,  9.9965e-01,  4.9314e-03,  9.9999e-01,
          9.2405e-04],
        [-9.5892e-01,  2.8366e-01,  8.0573e-01,  5.9228e-01,  1.7466e-01,
          9.8463e-01,  3.2891e-02,  9.9946e-01,  6.1642e-03,  9.9998e-0

In [17]:
# compare shapes: the embeddings keep the batch dimension, the positional slice does not
x_emb.size(), x_pos_embed.size()

(torch.Size([2, 7, 11]), torch.Size([7, 11]))