In [1]:
import torch
import math
import torch.nn as nn

# Parameters
d_model = 512  # Embedding width (example value)
seq_len = 10   # Number of positions to encode (example value)
dropout_rate = 0.1  # Dropout probability (example value)
seed = 42  # RNG seed so the random input and the dropout mask are identical on every re-run

# Seed before any random op: both torch.randn (Step 6) and nn.Dropout (Step 8) draw from this RNG.
torch.manual_seed(seed)

# Initialize the components
dropout = nn.Dropout(dropout_rate)

# Step 1: Create positional encoding matrix
pe = torch.zeros(seq_len, d_model)
print("Initial positional encoding (pe) shape:", pe.shape)
print(pe)

# Step 2: Create position tensor
position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
print("\nPosition tensor shape:", position.shape)
print(position)

# Step 3: Calculate div_term
# One frequency per even column index: exp(-2i * ln(10000) / d_model) == 10000^(-2i / d_model).
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # (ceil(d_model / 2),)
print("\nDiv term shape:", div_term.shape)
print(div_term)

# Step 4: Apply sine and cosine
# position (seq_len, 1) broadcasts against div_term to give one angle per (position, frequency) pair.
angles = position * div_term
pe[:, 0::2] = torch.sin(angles)  # Apply sine to even indices
# There are only d_model // 2 odd columns, so drop the surplus frequency when d_model is odd
# (for even d_model the slice is the whole tensor and nothing changes).
pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])  # Apply cosine to odd indices
print("\nUpdated positional encoding (pe) shape after sine and cosine:")
print(pe)

# Step 5: Add batch dimension to the positional encoding
pe = pe.unsqueeze(0)  # (1, seq_len, d_model)
print("\nPositional encoding with batch dimension added (pe) shape:", pe.shape)
print(pe)

# Step 6: Simulate input tensor (x) for embeddings
x = torch.randn(2, seq_len, d_model)  # Example input tensor with shape (batch_size, seq_len, d_model)
print("\nInput tensor (x) shape:", x.shape)
print(x)

# Step 7: Add positional encoding (no gradient tracking)
# pe is sliced to the input's sequence length and broadcast over the batch dimension.
# pe was never marked as requiring grad, so requires_grad_(False) is a no-op here; it is kept
# to make explicit that the encoding is a fixed, non-learned tensor.
x = x + (pe[:, :x.shape[1], :]).requires_grad_(False)  # (batch, seq_len, d_model)
print("\nTensor after adding positional encoding:")
print(x)

# Step 8: Apply dropout
# A freshly constructed nn.Dropout is in training mode: each element is zeroed with probability
# dropout_rate and the survivors are scaled by 1 / (1 - dropout_rate).
# The result gets its own name so that x still holds the un-dropped sum from Step 7.
x_dropped = dropout(x)  # (batch, seq_len, d_model)
print("\nTensor after applying dropout:")
print(x_dropped)


Initial positional encoding (pe) shape: torch.Size([10, 512])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

Position tensor shape: torch.Size([10, 1])
tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.],
        [8.],
        [9.]])

Div term shape: torch.Size([256])
tensor([1.0000e+00, 9.6466e-01, 9.3057e-01, 8.9769e-01, 8.6596e-01, 8.3536e-01,
        8.0584e-01, 7.7737e-01, 7.4989e-01, 7.2339e-01, 6.9783e-01, 6.7317e-01,
        6.4938e-01, 6.2643e-01, 6.0430e-01, 5.8294e-01, 5.6234e-01, 5.4247e-01,
        5.2330e-01, 5.0481e-01, 4.8697e-01, 4.6976e-01, 4.5316e-01, 4.3714e-01,
        4.2170e-01, 4.0679e-01, 3.9242e-01, 3.7855e-01, 3.6517e-01, 3.5227e-01,
        3.3982e-01, 3.2781e-01, 3.1623e-01, 3.0505e-01, 2.9427e-