In [6]:
import numpy as np
import torch
import torch.nn as nn

# Create learnable positional encoder
class LearnablePositionalEncoding(nn.Module):
    """Add a learned per-position vector to a sequence of embeddings.

    The module owns a trainable table with one row of ``embedding_dim``
    values for each of the first ``max_length`` positions. The table is
    initialised from a standard normal distribution (``torch.randn``).

    Args:
        max_length: Maximum sequence length the table can cover.
        embedding_dim: Size of each position vector; must match the last
            dimension of the inputs passed to :meth:`forward`.
    """

    def __init__(self, max_length, embedding_dim):
        super().__init__()
        # Shape (max_length, embedding_dim); broadcasts over leading batch dims.
        self.positional_encoding = nn.Parameter(torch.randn(max_length, embedding_dim))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``seq_len`` positions.

        Args:
            x: Tensor of shape ``(..., seq_len, embedding_dim)`` with
                ``seq_len <= max_length``.

        Returns:
            A new tensor holding ``x + encoding[:seq_len]``. ``x`` itself is
            left unmodified.

        Raises:
            RuntimeError: If ``x`` cannot be broadcast against the sliced
                table (e.g. ``seq_len > max_length`` or a wrong last dim).
        """
        # Out-of-place add: an in-place ``x += ...`` would overwrite the
        # caller's tensor. It also fails outright when ``x`` is a leaf tensor
        # that requires grad.
        seq_len = x.size(-2)
        # Slice the table so sequences shorter than max_length are supported.
        return x + self.positional_encoding[:seq_len]

def positional_encoding(max_length, embedding_dim):
    """Construct a ``LearnablePositionalEncoding`` module.

    Args:
        max_length: Number of positions covered by the learned table.
        embedding_dim: Width of each position's vector.

    Returns:
        A freshly constructed ``LearnablePositionalEncoding`` instance.
    """
    encoder = LearnablePositionalEncoding(max_length, embedding_dim)
    return encoder

# Create a dummy dataset
input_dim = 10
sequence_length = 100
num_samples = 1000

# Generate random input data of shape (num_samples, sequence_length, input_dim).
# np.random.randn returns float64.
input_data = np.random.randn(num_samples, sequence_length, input_dim)

# Apply positional encoding
max_length = sequence_length
embedding_dim = input_dim
pos_encoding = positional_encoding(max_length, embedding_dim)

# Convert with torch.as_tensor instead of the legacy torch.Tensor(...)
# constructor. Casting to the default dtype reproduces what torch.Tensor
# produced, and matches the dtype torch.randn gives the encoder's parameter.
input_tensor = torch.as_tensor(input_data, dtype=torch.get_default_dtype())
output_data = pos_encoding(input_tensor)

# Print output shape: (num_samples, sequence_length, input_dim)
print(output_data.shape)

# Print positional encoding tensor
positional_encoding_tensor = pos_encoding.positional_encoding

print(positional_encoding_tensor.shape)
print(positional_encoding_tensor)

torch.Size([1000, 100, 10])
torch.Size([100, 10])
Parameter containing:
tensor([[ 8.7260e-01, -4.1444e-01,  3.5632e-02, -5.8786e-01,  9.5883e-01,
         -8.1268e-01, -4.8118e-01,  4.8895e-01,  1.9211e+00, -3.8902e-01],
        [-6.4208e-01,  3.3416e-01,  5.1781e-01, -9.9827e-01,  9.3226e-01,
          1.7098e+00,  6.1400e-01, -1.3124e+00, -1.6880e+00,  2.9578e-01],
        [-1.2924e+00,  5.4199e-01, -7.5654e-01, -1.0622e+00, -4.9414e-01,
         -5.3186e-01, -3.7727e-01,  8.1856e-01,  7.8862e-02,  3.4074e-01],
        [ 3.3164e+00,  7.3494e-01, -1.8412e+00,  4.9381e-01,  1.4820e-01,
         -5.0241e-01, -1.1961e+00, -3.5141e-01,  1.4096e-01,  1.5781e-01],
        [ 8.6595e-02, -7.3166e-01, -1.4279e+00,  1.2642e+00, -1.8747e+00,
         -1.3450e-01, -8.1296e-01,  3.8277e-01,  4.4957e-01,  1.9979e+00],
        [ 1.3757e-01,  4.3348e-01, -4.7114e-01,  9.2125e-01, -5.1609e-01,
         -1.5222e-01,  1.5776e+00, -4.2137e-01, -7.1084e-01, -1.0344e-01],
        [-2.8797e-01,  1.4580e+00,