In [1]:
import torch

### Understanding the difference between Embedding and Linear Layers
- Embedding layers in PyTorch accomplish the same thing as linear layers applied to one-hot encoded inputs, which perform a matrix multiplication; we use embedding layers because the row lookup is computationally more efficient
- We will take a look at this relationship step by step

### 1. Using nn.Embedding

In [2]:
# Suppose we have the following 3 examples
# which may represent token IDs in an LLM context
idx = torch.tensor([2, 3, 1])

# The number of rows in the embedding matrix can be determined
# by obtaining the largest token ID + 1
# If the highest token ID is 3, then we want 4 rows, for the possible token IDs 0, 1, 2, 3
num_idx = int(idx.max()) + 1  # convert to a plain Python int for the layer size

# The desired embedding dimension is a hyperparameter
out_dim = 5

In [3]:
# Let's implement a simple embedding layer
torch.manual_seed(123)

embedding = torch.nn.Embedding(num_idx, out_dim)
embedding.weight

Parameter containing:
tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  1.5810],
        [ 1.3010,  1.2753, -0.2010, -0.1606, -0.4015],
        [ 0.6957, -1.8061, -1.1589,  0.3255, -0.6315],
        [-2.8400, -0.7849, -1.4096, -0.4076,  0.7953]], requires_grad=True)

- We can then use the embedding layer to obtain the vector representation of a training example with ID 1:

In [5]:
embedding(torch.tensor([1]))

tensor([[ 1.3010,  1.2753, -0.2010, -0.1606, -0.4015]],
       grad_fn=<EmbeddingBackward0>)

In [6]:
idx = torch.tensor([2, 3, 1])
embedding(idx)

tensor([[ 0.6957, -1.8061, -1.1589,  0.3255, -0.6315],
        [-2.8400, -0.7849, -1.4096, -0.4076,  0.7953],
        [ 1.3010,  1.2753, -0.2010, -0.1606, -0.4015]],
       grad_fn=<EmbeddingBackward0>)
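
- The outputs above are rows of the weight matrix picked out by ID. As a minimal sketch of that lookup view (the names `looked_up`, `indexed`, and `same` are only for illustration, and the layer is rebuilt here with 4 rows and 5 columns so the snippet runs on its own), we can compare the layer's output with plain row indexing into `embedding.weight`:

```python
import torch

torch.manual_seed(123)

# Rebuild a small embedding layer: 4 possible token IDs, 5-dimensional vectors
embedding = torch.nn.Embedding(4, 5)
idx = torch.tensor([2, 3, 1])

# Forward pass through the embedding layer
looked_up = embedding(idx)

# Plain row indexing into the weight matrix with the same IDs
indexed = embedding.weight[idx]

# True if both approaches return identical values
same = torch.equal(looked_up, indexed)
print(same)
```

- In other words, the token ID acts as a row index, and no multiplication is needed to obtain the vector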

### 2. Using nn.Linear
- Now, we will demonstrate that the embedding layer above accomplishes exactly the same thing as an nn.Linear layer applied to a one-hot encoded representation in PyTorch
- First, let's convert the token IDs into a one-hot representation

In [7]:
onehot = torch.nn.functional.one_hot(idx)
onehot # remember: 2, 3, 1

tensor([[0, 0, 1, 0],
        [0, 0, 0, 1],
        [0, 1, 0, 0]])
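
- Note that `one_hot` infers the number of columns from the largest ID in its input when `num_classes` is not given. That works here because the batch contains the highest ID, 3. The sketch below uses a hypothetical batch `idx_small` that lacks the highest ID and assumes a vocabulary of 4 IDs; it shows how passing `num_classes` explicitly keeps the width fixed:

```python
import torch
import torch.nn.functional as F

# Hypothetical batch that does not contain the highest token ID (3)
idx_small = torch.tensor([1, 0])

# Width is inferred from the largest ID present (1), giving 2 columns
inferred = F.one_hot(idx_small)

# Width is fixed to the assumed vocabulary size of 4
explicit = F.one_hot(idx_small, num_classes=4)

print(inferred.shape)
print(explicit.shape)
```

- Fixing the width this way keeps the one-hot matrix compatible with a weight matrix that has one row or column per token ID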

- Next, we initialize a Linear layer, which carries out a matrix multiplication $XW^T$