<a href="https://colab.research.google.com/github/kasakun/CodeBook/blob/master/ml_coding/positional_encoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch

In [None]:
class PositionalEncoding(torch.nn.Module):
  """Add fixed sinusoidal positional encodings to an input sequence.

  For position ``pos`` and even channel ``2i`` the table holds::

      P[pos, 2i]   = sin(pos / 10000 ** (2i / d_model))
      P[pos, 2i+1] = cos(pos / 10000 ** (2i / d_model))

  The table is precomputed once for ``max_len`` positions and stored as a
  non-persistent buffer. It therefore follows the module through
  ``.to(device)`` / ``.cuda()``, but is not written to ``state_dict``, so
  checkpoint contents are unchanged.

  Args:
    d_model: Embedding width. Odd values are supported; the last channel is
      then a sine channel with no cosine partner.
    max_len: Longest sequence length the precomputed table covers.
  """

  def __init__(self, d_model, max_len=5000):
    super().__init__()

    # [max_len, 1] column of positions so it broadcasts against div_term.
    position = torch.arange(0, max_len).unsqueeze(1).float()
    # One frequency per even channel: ceil(d_model / 2) entries.
    div_term = torch.pow(10000, torch.arange(0, d_model, 2).float() / d_model)

    P = torch.zeros((1, max_len, d_model))
    P[:, :, 0::2] = torch.sin(position / div_term)
    # There are only floor(d_model / 2) odd channels. Trim div_term so an
    # odd d_model does not cause a shape mismatch on this assignment.
    P[:, :, 1::2] = torch.cos(position / div_term[: d_model // 2])

    # Use a buffer (not a plain attribute) so .to()/.cuda() moves the table
    # with the module. persistent=False keeps it out of state_dict.
    self.register_buffer("P", P, persistent=False)

  def forward(self, x):
    """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

    Args:
      x: Input of shape [batch_size, seq_len, d_model].

    Returns:
      ``x + P[:, :seq_len]``, broadcast over the batch dimension.

    Raises:
      ValueError: If ``seq_len`` exceeds the ``max_len`` given at construction.
    """
    seq_len = x.size(1)
    max_len = self.P.size(1)
    if seq_len > max_len:
      raise ValueError(
          f"sequence length {seq_len} exceeds max_len {max_len}")
    return x + self.P[:, :seq_len]

In [None]:
# Demo: add positional encodings to a random batch of embeddings.
batch_size = 4   # sequences per batch
seq_len = 10     # tokens per sequence
model_dim = 16   # embedding width (d_model)

x = torch.randn((batch_size, seq_len, model_dim))

pos_encoder = PositionalEncoding(d_model=model_dim)
pos_enc = pos_encoder(x)

print(pos_enc)
# The encoding preserves the input shape: (batch_size, seq_len, model_dim).
print(pos_enc.shape)

tensor([[[-2.0259e+00,  2.4181e-02,  9.6375e-01,  1.7148e+00,  8.5345e-01,
           9.2365e-01,  7.6732e-01,  1.7340e-01,  3.8744e-01,  1.1842e+00,
          -4.6984e-02,  2.2746e+00,  9.0620e-01, -2.8631e-01,  1.8694e-01,
           1.7214e+00],
         [ 9.0147e-01, -1.8694e-01,  2.0037e-01,  3.9865e-01, -2.7431e-01,
           3.7305e-02, -2.8713e-01,  1.0141e+00,  2.5556e-01, -1.4180e-01,
          -7.2546e-01,  1.5222e+00, -4.1565e+00, -2.4114e-01,  5.9824e-01,
           8.1000e-01],
         [ 1.6569e+00, -5.9286e-01, -7.9441e-01,  2.2301e+00,  1.3695e+00,
           1.0594e-01,  1.1546e+00, -2.5739e-01,  3.7875e-01,  1.1542e+00,
           1.3151e+00,  1.2410e+00, -2.0238e+00,  2.1773e+00,  6.4864e-01,
           1.0087e+00],
         [ 6.4358e-01, -1.2382e+00,  1.1689e+00,  7.3964e-01,  1.1317e+00,
           1.8540e+00,  1.6868e+00,  1.5500e+00, -2.5424e-01,  1.4791e+00,
           1.0050e+00,  6.2749e-01, -1.0419e+00, -1.5058e+00, -3.0024e-02,
           2.2134e-01],
    