In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn

import matplotlib.pyplot as plt

In [None]:
class LearnablePositionalEncoding(nn.Module):
    """Add a trainable per-position embedding to token embeddings, then apply dropout.

    Args:
        d_model: Size of each embedding vector (last dimension of the input).
        max_len: Number of rows in the position lookup table, i.e. the longest
            sequence ``forward`` accepts.
        dropout: Drop probability handed to ``nn.Dropout``.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        # One trainable d_model-sized vector per position index 0..max_len-1.
        self.position_embeddings = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(p=dropout)
        # Position indices, shape (1, max_len). Registered as a buffer (not a
        # parameter) so it is part of the module's state without being trained.
        self.register_buffer("positions", torch.arange(max_len).unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + positional_embedding)``.

        Args:
            x: Token embeddings of shape (batch_size, seq_len, d_model).

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` is larger than the ``max_len`` the module
                was built with.
        """
        seq_len = x.shape[1]
        max_len = self.positions.shape[1]
        # Without this guard the slice below is silently truncated to max_len
        # positions: the addition then either fails with an opaque broadcasting
        # error or, when max_len == 1, silently reuses position 0 for every token.
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum supported length {max_len}"
            )
        # (1, seq_len) indices -> (1, seq_len, d_model); broadcasts over the batch.
        pos_emb = self.position_embeddings(self.positions[:, :seq_len])
        return self.dropout(x + pos_emb)

- A learnable positional encoding allows the model to learn the best positional representations during training
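- Dropout is applied to the sum of the token embeddings and the positional embeddings: during training it randomly zeroes individual elements of that combined representation, which helps prevent overfitting and improves generalization to unseen sequences.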

**Side note**: We use self.register_buffer() to create a buffer that is not a learnable parameter. The call in \_\_init\_\_() registers *positions* as a buffer, meaning:
- It persists in the model but is not a learnable parameter.
- It moves automatically with the model (e.g., to GPU).
- It is saved and loaded with state_dict().

In [4]:
# Toy batch of token embeddings: 1 sequence, 3 tokens, 4 features per token.
dummy_input = torch.tensor([
    [
        [0.2, 0.4, 0.1, 0.3],  # token 0
        [0.5, 0.2, 0.7, 0.9],  # token 1
        [0.8, 0.6, 0.4, 0.2],  # token 2
    ],
])
dummy_input

tensor([[[0.2000, 0.4000, 0.1000, 0.3000],
         [0.5000, 0.2000, 0.7000, 0.9000],
         [0.8000, 0.6000, 0.4000, 0.2000]]])

In [6]:
# create the positional encoding

# Seed the global RNG first: the embedding table is randomly initialised and
# dropout draws a random mask, so without a fixed seed every run of the
# notebook prints different numbers.
torch.manual_seed(0)

d_model = dummy_input.shape[2]  # Embedding dimension (4)
seq_len = dummy_input.shape[1]  # Length of sequence (3)

pos_encoder = LearnablePositionalEncoding(d_model)

In [8]:
# Forward pass on the toy batch. No .eval() call has been made on pos_encoder,
# so its nn.Dropout layer is active and re-running this cell gives different values.
output = pos_encoder(dummy_input)

# Header line first, then the resulting tensor on the following line(s).
print(
    f"Learnable Positional Encoding with Dropout (seq_len={seq_len}, d_model={d_model}):",
    output,
    sep="\n",
)

Learnable Positional Encoding with Dropout (seq_len=3, d_model=4):
tensor([[[ 0.5879, -0.1287, -0.4993,  2.0910],
         [ 2.4000,  2.2207,  1.0777, -0.7921],
         [ 2.0809,  0.7444, -1.3603, -0.6124]]], grad_fn=<MulBackward0>)
