# Understanding the Difference between Embedding Layers and Linear Layers

An embedding layer in PyTorch computes the same result as a linear (fully connected) layer without bias applied to one-hot encoded inputs: both map each token ID to one row of a weight matrix. The embedding layer gets there with a direct row lookup instead of a matrix multiplication, which is why we use it for computational efficiency.
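The equivalence fits on one line. Let $\mathbf{e}_i$ be the one-hot row vector of length $V$ (the number of token IDs) with a 1 in position $i$, and let $E$ be the $V \times d$ embedding matrix. Multiplying the two picks out row $i$ of $E$:

$$
\mathbf{e}_i E = \sum_{k=0}^{V-1} (\mathbf{e}_i)_k \, E_{k,:} = E_{i,:}
$$

Stacking several one-hot rows into a matrix $X$ gives $XE$, which is what a bias-free linear layer computes as $XW^T$ when its weight is $W = E^T$.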

In [1]:
import torch
print("PyTorch version:", torch.__version__)

PyTorch version: 2.9.1+cpu


## Using `nn.Embedding`

In [2]:
# Suppose we have 3 training examples, which may represent token IDs in an LLM context
idx = torch.tensor([2, 3, 1])

# The number of rows in the embedding matrix can be determined by the maximum index + 1
# E.g., if the max index is 3, we need 4 rows (0, 1, 2, 3)
# int(...) turns the 0-d tensor returned by idx.max() into a plain Python integer
num_idx = int(idx.max()) + 1

# The desired embedding dimension is a hyperparameter we can choose
out_dim = 5
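As a quick sanity check (a minimal sketch, not part of the original notebook), the snippet below shows why the table needs `max index + 1` rows: with 4 rows, index 3 is the last valid lookup, and asking for index 4 raises an `IndexError` on CPU.

```python
import torch

idx = torch.tensor([2, 3, 1])
num_idx = int(idx.max()) + 1  # 4 rows cover indices 0..3
embedding = torch.nn.Embedding(num_idx, 5)

print(embedding(torch.tensor([3])).shape)  # torch.Size([1, 5]), last valid row

try:
    embedding(torch.tensor([num_idx]))  # one past the last row
except IndexError as err:
    print("IndexError:", err)
```

In an LLM, `num_idx` is normally the vocabulary size rather than the largest ID seen in one small batch.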

In [3]:
torch.manual_seed(0)

embedding = torch.nn.Embedding(num_idx, out_dim)

In [4]:
embedding.weight

Parameter containing:
tensor([[-1.1258, -1.1524, -0.2506, -0.4339,  0.5988],
        [-1.5551, -0.3414,  1.8530,  0.4681, -0.1577],
        [ 1.4437,  0.2660,  1.3894,  1.5863,  0.9463],
        [-0.8437,  0.9318,  1.2590,  2.0050,  0.0537]], requires_grad=True)

Now we can use the embedding layer to obtain the vector representations for the given indices.
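Under the hood, the forward pass is a row lookup in `embedding.weight`. The sketch below (re-creating the same seeded layer so it runs on its own) checks that plain tensor indexing and the functional form `torch.nn.functional.embedding` return exactly the same values as calling the layer:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embedding = torch.nn.Embedding(4, 5)
idx = torch.tensor([2, 3, 1])

looked_up = embedding(idx)                   # module call
indexed = embedding.weight[idx]              # select rows 2, 3, 1 directly
functional = F.embedding(idx, embedding.weight)  # functional form of the lookup

print(torch.equal(looked_up, indexed))     # True
print(torch.equal(looked_up, functional))  # True
```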

In [5]:
# token ID 1
embedding(torch.tensor([1]))  # returns the vector for token ID 1

tensor([[-1.5551, -0.3414,  1.8530,  0.4681, -0.1577]],
       grad_fn=<EmbeddingBackward0>)

In [6]:
# token ID 2
embedding(torch.tensor([2]))  # returns the vector for token ID 2

tensor([[1.4437, 0.2660, 1.3894, 1.5863, 0.9463]],
       grad_fn=<EmbeddingBackward0>)

We can convert all training examples at once:

In [7]:
embedding(idx)

tensor([[ 1.4437,  0.2660,  1.3894,  1.5863,  0.9463],
        [-0.8437,  0.9318,  1.2590,  2.0050,  0.0537],
        [-1.5551, -0.3414,  1.8530,  0.4681, -0.1577]],
       grad_fn=<EmbeddingBackward0>)
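In practice, token IDs usually arrive as a 2D batch of shape `(batch_size, seq_len)`. The embedding layer accepts any integer tensor shape and appends the embedding dimension at the end. A small sketch with a hypothetical batch of two sequences:

```python
import torch

torch.manual_seed(0)
embedding = torch.nn.Embedding(4, 5)

# Hypothetical batch: 2 sequences of 3 token IDs each
batch = torch.tensor([[2, 3, 1],
                      [0, 0, 3]])
out = embedding(batch)

print(out.shape)  # torch.Size([2, 3, 5])
# The last token of the second sequence (ID 3) maps to row 3 of the weight matrix
print(torch.equal(out[1, 2], embedding.weight[3]))  # True
```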

## Using `nn.Linear`

We can obtain the same result by applying a linear layer, `nn.Linear`, to a one-hot encoded representation of the token IDs.

First, we need to convert the token IDs into a one-hot representation:

In [8]:
onehot = torch.nn.functional.one_hot(idx)
onehot

tensor([[0, 0, 1, 0],
        [0, 0, 0, 1],
        [0, 1, 0, 0]])
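Without `num_classes`, `one_hot` infers the width of the encoding from the largest index in the tensor it receives. That works here because index 3 appears in `idx`, but a batch that happens not to contain the largest token ID would yield narrower vectors that no longer match the linear layer's input size. A minimal sketch of passing `num_classes` explicitly (the batch below is a hypothetical example):

```python
import torch
import torch.nn.functional as F

num_idx = 4                   # vocabulary size, i.e. number of embedding rows
batch = torch.tensor([0, 1])  # hypothetical batch without the largest ID

inferred = F.one_hot(batch)                       # width inferred as max + 1 = 2
explicit = F.one_hot(batch, num_classes=num_idx)  # width fixed at 4

print(inferred.shape)  # torch.Size([2, 2])
print(explicit.shape)  # torch.Size([2, 4])
```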

Then, we initialize a linear layer without a bias term, which carries out the matrix multiplication $XW^T$:

In [9]:
torch.manual_seed(0)
linear = torch.nn.Linear(num_idx, out_dim, bias=False)
linear.weight

Parameter containing:
tensor([[-0.0037,  0.2682, -0.4115, -0.3680],
        [-0.1926,  0.1341, -0.0099,  0.3964],
        [-0.0444,  0.1323, -0.1511, -0.0983],
        [-0.4777, -0.3311, -0.2061,  0.0185],
        [ 0.1977,  0.3000, -0.3390, -0.2177]], requires_grad=True)
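To see that the layer really computes $XW^T$, the sketch below performs the multiplication by hand with `@` and compares it with the layer's output. It also checks that each output row is simply the column of `linear.weight` selected by the one-hot position:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
idx = torch.tensor([2, 3, 1])
linear = torch.nn.Linear(4, 5, bias=False)

X = F.one_hot(idx, num_classes=4).float()  # shape (3, 4)
manual = X @ linear.weight.T               # (3, 4) @ (4, 5) -> (3, 5)

print(linear.weight.shape)                # torch.Size([5, 4]): (out_features, in_features)
print(torch.allclose(linear(X), manual))  # True
# idx[0] == 2, so the first output row is column 2 of linear.weight
print(torch.allclose(manual[0], linear.weight[:, 2]))  # True
```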

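To compare it directly with the `Embedding` layer above, both layers need the same weights. Since `nn.Linear` stores its weight with shape `(out_features, in_features)`, i.e. `(out_dim, num_idx)`, whereas `nn.Embedding` stores `(num_idx, out_dim)`, we assign the transposed embedding weights: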

In [10]:
linear.weight = torch.nn.Parameter(embedding.weight.T)
linear.weight

Parameter containing:
tensor([[-1.1258, -1.5551,  1.4437, -0.8437],
        [-1.1524, -0.3414,  0.2660,  0.9318],
        [-0.2506,  1.8530,  1.3894,  1.2590],
        [-0.4339,  0.4681,  1.5863,  2.0050],
        [ 0.5988, -0.1577,  0.9463,  0.0537]], requires_grad=True)
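Wrapping `embedding.weight.T` in a new `Parameter`, as above, may leave the two layers sharing memory. An alternative sketch (one option among several, not the original notebook's approach) copies the transposed values into the linear layer's existing `Parameter` under `torch.no_grad()`, so each layer keeps its own storage:

```python
import torch

torch.manual_seed(0)
embedding = torch.nn.Embedding(4, 5)
linear = torch.nn.Linear(4, 5, bias=False)

print(embedding.weight.shape)  # torch.Size([4, 5])
print(linear.weight.shape)     # torch.Size([5, 4])

# Copy the transposed embedding weights into the existing Parameter in place
with torch.no_grad():
    linear.weight.copy_(embedding.weight.T)

print(torch.equal(linear.weight, embedding.weight.T))  # True
# Separate allocations: updating one layer will not change the other
print(linear.weight.data_ptr() == embedding.weight.data_ptr())  # False
```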

Now we can use the linear layer on the one-hot encoded representation of the inputs:

In [11]:
linear(onehot.float())

tensor([[ 1.4437,  0.2660,  1.3894,  1.5863,  0.9463],
        [-0.8437,  0.9318,  1.2590,  2.0050,  0.0537],
        [-1.5551, -0.3414,  1.8530,  0.4681, -0.1577]], grad_fn=<MmBackward0>)

In [12]:
# compare with embedding output
embedding(idx)

tensor([[ 1.4437,  0.2660,  1.3894,  1.5863,  0.9463],
        [-0.8437,  0.9318,  1.2590,  2.0050,  0.0537],
        [-1.5551, -0.3414,  1.8530,  0.4681, -0.1577]],
       grad_fn=<EmbeddingBackward0>)
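The equivalence extends to training, not only the forward pass. The illustrative check below backpropagates an arbitrary sum-based loss (an assumption, chosen only to produce gradients) through both layers. The embedding gradient matches the transposed linear gradient, and only rows for token IDs that actually occur in `idx` receive a nonzero gradient:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
idx = torch.tensor([2, 3, 1])
embedding = torch.nn.Embedding(4, 5)
linear = torch.nn.Linear(4, 5, bias=False)
with torch.no_grad():
    linear.weight.copy_(embedding.weight.T)

# Arbitrary scalar loss, used only to produce gradients
embedding(idx).sum().backward()
linear(F.one_hot(idx, num_classes=4).float()).sum().backward()

print(torch.allclose(embedding.weight.grad, linear.weight.grad.T))  # True
print(embedding.weight.grad[0])  # token ID 0 is not in idx, so this row is all zeros
```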

Multiplying the one-hot encodings by the weight matrix is equivalent to the embedding-layer lookup, but it becomes inefficient for large embedding matrices: each one-hot row is almost entirely zeros, so nearly all of the multiplications are wasted, whereas the embedding layer simply reads out the required rows.
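To make the cost difference concrete, here is a rough timing sketch. The vocabulary size, embedding dimension, and batch size are arbitrary assumptions, and timings depend on hardware and PyTorch build, so no numbers are shown. The matrix-multiplication path performs `batch_size * vocab_size * out_dim` multiply-adds and materializes a `batch_size x vocab_size` one-hot matrix, while the lookup only copies `batch_size * out_dim` values.

```python
import time
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, out_dim, batch_size = 10_000, 64, 256  # arbitrary example sizes

embedding = torch.nn.Embedding(vocab_size, out_dim)
linear = torch.nn.Linear(vocab_size, out_dim, bias=False)
with torch.no_grad():
    linear.weight.copy_(embedding.weight.T)

idx = torch.randint(0, vocab_size, (batch_size,))

def onehot_matmul():
    """Embed idx via a one-hot matrix times the linear layer's weight."""
    return linear(F.one_hot(idx, num_classes=vocab_size).float())

def bench(fn, repeats=20):
    """Return the average wall-clock time of fn() in seconds."""
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats

with torch.no_grad():
    t_lookup = bench(lambda: embedding(idx))
    t_matmul = bench(onehot_matmul)
    same = torch.allclose(embedding(idx), onehot_matmul())

print(f"lookup: {t_lookup * 1e3:.3f} ms, one-hot matmul: {t_matmul * 1e3:.3f} ms")
print("same result:", same)
print("multiply-adds in matmul path:", batch_size * vocab_size * out_dim)
```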