In [2]:
import torch
import torch.nn as nn
import math

In [15]:
class InputEmbedding(nn.Module):
    def __init__(self, vocab_size=100256, d_model=2):
        super(InputEmbedding, self).__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, tokens):
        return self.embedding(tokens) * math.sqrt(self.d_model)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=0.1)
        # Compute positional encodings once in log space
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        pe = self.pe[:, :x.shape[0], :].squeeze(0)
        print(pe)
        x = x + pe
        return self.dropout(x)

In [16]:
ip_embd = InputEmbedding()
ps_embd = PositionalEncoding(d_model=2)

In [17]:
tokens = torch.tensor([0, 1])
embed = ip_embd(tokens)
embed

tensor([[-1.4133, -0.6876],
        [-2.6755,  0.7028]], grad_fn=<MulBackward0>)

In [18]:
embed = ps_embd(embed)
embed

tensor([[0.0000, 1.0000],
        [0.8415, 0.5403]])


tensor([[-1.5704,  0.3471],
        [-2.0378,  1.3812]], grad_fn=<MulBackward0>)