In [17]:
import torch

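### Basic Token Embedding

`torch.nn.Embedding(vocab_size, embedding_size)` is a trainable lookup table. Each integer token id in `[0, vocab_size)` is mapped to a vector of length `embedding_size`, so a `(batch_size, seq_len)` tensor of ids becomes a `(batch_size, seq_len, embedding_size)` tensor.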

In [18]:
vocab_size = 32
embedding_size = 2
batch_size = 4
seq_len = 10
SEED = 0  # fixed seed so token ids and nn.Embedding weight init are reproducible on Restart & Run All

torch.manual_seed(SEED)
# NOTE: `input` shadows the Python builtin of the same name; kept because later cells reference it.
input = torch.randint(0, vocab_size, (batch_size, seq_len))  # (batch_size, seq_len) = (4, 10); integers in [0, vocab_size), i.e. 0..31

In [19]:
input  # display the sampled token ids: one row per sequence in the batch, seq_len ids per row

tensor([[11,  9, 21, 28, 20, 16, 10, 24, 26,  2],
        [18, 23, 14, 10,  8,  2, 17, 14,  0, 13],
        [ 6, 27, 14,  6, 15, 23, 13,  6, 13,  3],
        [23, 10,  6,  5, 13, 28,  5, 30, 20, 25]])

In [20]:
# Trainable lookup table: one row of length `embedding_size` for every token id in [0, vocab_size).
embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size)

In [21]:
# Lookup: every token id in `input` is replaced by its row of `embedding.weight`,
# which appends a trailing embedding_size axis to the id tensor's shape.
embeddings = embedding(input)
embeddings.size()  # torch.Size([batch_size, seq_len, embedding_size]) -> (4, 10, 2)

torch.Size([4, 10, 2])

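### Token Embedding + Learnable Positional Embedding

A second `nn.Embedding` with one row per position `0 … seq_len-1` gives each position its own trainable vector. That vector is added element-wise to the token embedding at the same position, so the output shape is unchanged.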

In [22]:
# Two lookup tables: one row per token id, and one row per position 0..seq_len-1.
embedding_layer = torch.nn.Embedding(vocab_size, embedding_size)
positional_embedding_layer = torch.nn.Embedding(seq_len, embedding_size)

# Fresh token ids for this section: shape (batch_size, seq_len), values in [0, vocab_size).
input = torch.randint(0, vocab_size, (batch_size, seq_len))
n_batch, n_pos = input.shape  # derive sizes from the data rather than re-reading the globals
# The positional table only has `seq_len` rows; a longer sequence would index out of range.
assert n_pos <= positional_embedding_layer.num_embeddings, "sequence longer than positional table"

embeddings = embedding_layer(input)  # (batch_size, seq_len, embedding_size) = (4, 10, 2)

# Position ids live on the same device as the tokens so the lookup and the addition never mix devices.
positions = torch.arange(n_pos, device=input.device)  # (seq_len,)
positional_embeddings = positional_embedding_layer(positions)  # (seq_len, embedding_size) = (10, 2)
positional_embeddings = positional_embeddings.unsqueeze(0)  # (1, seq_len, embedding_size) = (1, 10, 2)
# `+` would broadcast the size-1 batch axis by itself; the explicit expand (a view, no copy)
# is kept so `positional_embeddings` itself carries the full batch shape.
positional_embeddings = positional_embeddings.expand(n_batch, -1, -1)  # (batch_size, seq_len, embedding_size) = (4, 10, 2)

embeddings = embeddings + positional_embeddings
assert embeddings.shape == (batch_size, seq_len, embedding_size)
embeddings.shape  # (batch_size, seq_len, embedding_size) = (4, 10, 2)


torch.Size([4, 10, 2])

In [23]:
positional_embeddings.size()  # same torch.Size as .shape; (batch_size, seq_len, embedding_size) after the expand above

torch.Size([4, 10, 2])