# Positional Encoding

The positional encoding module computes one embedding for each position in the input sequence. Each embedding depends only on the position's index and the embedding dimension, not on the values of the input itself.

In [3]:
import torch
import torch.nn as nn
import math

In [25]:
class PositionalEncoding(nn.Module):
    """
    Standard Sinusoidal Positional Encoding.

    For position ``p`` and even column ``2i`` of a ``d``-dimensional embedding:

        pe[p, 2i]     = sin(p / wavelength ** (2i / d))
        pe[p, 2i + 1] = cos(p / wavelength ** (2i / d))

    wavelength: base of the geometric progression of wavelengths used by the
        sinusoids (must be positive, since its logarithm is taken).
    """
    def __init__(self, wavelength=10000.):
        super().__init__()
        self.wavelength = wavelength

    def forward(self, x):
        """Given a (... x seq_len x embedding_dim) tensor, returns a (seq_len x embedding_dim) tensor.

        Only the sizes of the last two dimensions of ``x`` and its device are
        used; the values of ``x`` are ignored. The result is created on the
        same device as ``x`` so it can be added to ``x`` directly, and it
        broadcasts over any leading (batch) dimensions.

        An odd ``embedding_dim`` is supported: the last column then holds a
        sine term with no matching cosine.
        """
        seq_len, embedding_dim = x.shape[-2], x.shape[-1]
        device = x.device
        pe = torch.zeros((seq_len, embedding_dim), device=device)
        # Column vector of positions so that it broadcasts against `factor`.
        position = torch.arange(seq_len, device=device).unsqueeze(1)
        # factor[i] == wavelength ** (-2i / embedding_dim), computed via exp/log.
        factor = torch.exp(
            -math.log(self.wavelength)
            * torch.arange(0, embedding_dim, 2, device=device)
            / embedding_dim
        )
        angles = position * factor  # (seq_len, ceil(embedding_dim / 2))
        pe[:, 0::2] = torch.sin(angles)
        # There are only embedding_dim // 2 odd columns; drop the surplus
        # frequency when embedding_dim is odd to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])
        return pe

In [28]:
# Demo dimensions, kept small so the printed encoding is easy to inspect.
batch_size, seq_length, embedding_dim = 4, 8, 2

# Random input; the encoder takes seq_len and embedding_dim from its last two dimensions.
x = torch.rand(batch_size, seq_length, embedding_dim)
pe = PositionalEncoding()
pe(x)

tensor([[ 0.0000,  1.0000],
        [ 0.8415,  0.5403],
        [ 0.9093, -0.4161],
        [ 0.1411, -0.9900],
        [-0.7568, -0.6536],
        [-0.9589,  0.2837],
        [-0.2794,  0.9602],
        [ 0.6570,  0.7539]])

In [29]:
# Add the (seq_length x embedding_dim) encoding to x in place;
# it broadcasts across the leading batch dimension.
x.add_(pe(x))
x

tensor([[[ 0.8380,  1.1593],
         [ 0.8669,  0.6860],
         [ 1.8222,  0.1615],
         [ 0.5154, -0.7971],
         [-0.4513, -0.1661],
         [-0.0788,  0.3784],
         [ 0.0910,  1.9162],
         [ 1.6072,  1.2504]],

        [[ 0.8504,  1.0896],
         [ 0.8801,  1.1977],
         [ 1.2070, -0.2350],
         [ 0.6507, -0.7666],
         [-0.4654, -0.1322],
         [-0.6249,  1.1393],
         [-0.2041,  1.9036],
         [ 1.3226,  1.5797]],

        [[ 0.5802,  1.4557],
         [ 1.3878,  1.2987],
         [ 1.4476, -0.4103],
         [ 1.1407, -0.9080],
         [-0.6525, -0.6126],
         [-0.8150,  1.0617],
         [ 0.1906,  1.5095],
         [ 1.1177,  0.9003]],

        [[ 0.6944,  1.0443],
         [ 1.7252,  1.5122],
         [ 1.0109,  0.3975],
         [ 0.5423, -0.6782],
         [-0.7034, -0.1417],
         [-0.8156,  0.8136],
         [-0.2586,  1.9527],
         [ 1.6237,  0.8781]]])