In [8]:
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to an input tensor.

    Args:
        d_model: Size of the feature (last) dimension of the input.
        max_len: Longest sequence length the precomputed table supports.
        device: Device on which the table is initially built. The table is
            registered as a buffer, so ``module.to(...)`` moves it too.
    """

    def __init__(self, d_model, max_len=200, device='cpu'):
        super().__init__()
        # Non-persistent buffer: it follows .to()/.cuda()/.half() together with
        # the module, but is left out of state_dict (it is fully recomputable),
        # so checkpoint keys are the same as when it was a plain attribute.
        self.register_buffer(
            'pos_table',
            self._get_sinusoid_encoding_table(max_len, d_model, device),
            persistent=False,
        )

    def _get_sinusoid_encoding_table(self, max_seq_len, d_hid, device):
        """Return a ``(1, max_seq_len, d_hid)`` table of sinusoidal encodings.

        Even columns ``2i`` hold ``sin(pos / 10000**(2i / d_hid))``; odd columns
        ``2i + 1`` hold the matching ``cos`` term.
        """
        position = torch.arange(0, max_seq_len, device=device, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_hid, 2, device=device, dtype=torch.float)
            * -(torch.log(torch.tensor(10000.0)) / d_hid)
        )
        angles = position * div_term  # (max_seq_len, ceil(d_hid / 2))

        sinusoid_table = torch.zeros(max_seq_len, d_hid, device=device)
        sinusoid_table[:, 0::2] = torch.sin(angles)
        # With an odd d_hid there is one fewer odd column than even column,
        # so the cos half uses only the first d_hid // 2 frequencies.
        sinusoid_table[:, 1::2] = torch.cos(angles[:, : d_hid // 2])

        return sinusoid_table.unsqueeze(0)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Assumes ``x`` is laid out as (batch, seq_len, d_model): the table is
        sliced along dim 1 and broadcast over dim 0.

        Raises:
            ValueError: if ``x.size(1)`` exceeds the ``max_len`` the table
                was built for.
        """
        seq_len = x.size(1)
        max_len = self.pos_table.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        # The buffer never requires grad, so no clone()/detach() copy is needed.
        return x + self.pos_table[:, :seq_len]

# Demo: add sinusoidal positional encodings to a small random batch.
# Fix the RNG seed so the printed tensors are reproducible across re-runs.
torch.manual_seed(0)

# Initialize Positional Encoding
d_model = 8
max_len = 10
pos_enc = PositionalEncoding(d_model, max_len)

# Generate example data shaped (batch, seq_len, d_model); seq_len must not
# exceed max_len, the length of the precomputed encoding table.
batch_size = 1
seq_len = 5
x = torch.rand(batch_size, seq_len, d_model)

print("Original Input:\n", x)
print("input shape:", x.shape)

# Apply Positional Encoding
x_encoded = pos_enc(x)

print("\nPositional Encoded Input:\n", x_encoded)
print("output shape:", x_encoded.shape)

Original Input:
 tensor([[[0.3850, 0.5385, 0.9373, 0.9378, 0.4698, 0.5875, 0.2967, 0.9103],
         [0.2427, 0.0146, 0.1008, 0.7644, 0.5750, 0.4641, 0.2080, 0.3063],
         [0.5608, 0.6441, 0.6097, 0.7116, 0.2767, 0.1115, 0.8490, 0.5207],
         [0.0299, 0.7309, 0.8054, 0.1917, 0.0828, 0.1358, 0.8224, 0.0253],
         [0.8246, 0.2751, 0.3948, 0.2837, 0.4042, 0.0490, 0.5056, 0.6815]]])
input shape: torch.Size([1, 5, 8])

Positional Encoded Input:
 tensor([[[ 0.3850,  1.5385,  0.9373,  1.9378,  0.4698,  1.5875,  0.2967,
           1.9103],
         [ 1.0842,  0.5549,  0.2006,  1.7595,  0.5850,  1.4640,  0.2090,
           1.3063],
         [ 1.4701,  0.2279,  0.8084,  1.6916,  0.2967,  1.1113,  0.8510,
           1.5207],
         [ 0.1710, -0.2591,  1.1009,  1.1470,  0.1128,  1.1353,  0.8254,
           1.0253],
         [ 0.0678, -0.3786,  0.7842,  1.2048,  0.4441,  1.0482,  0.5096,
           1.6815]]])
output shape torch.Size([1, 5, 8])
