In [1]:
import torch
from torch import nn
from torch.nn import functional as F

## embedding 的计算过程

- embedding module 的前向过程其实是一个索引（查表）的过程
    - 表的形式是一个 matrix（embedding.weight, learnable parameters）
        - matrix.shape: (v, h)
            - v：vocabulary size
            - h：hidden dimension
    - 这个索引（查表）过程在数学上等价于 one hot + 矩阵乘法（PyTorch 的实际实现是直接按行索引，并不会真的构造 one hot 矩阵）；
    - input.shape: (b, s)
        - b：batch size
        - s：seq len
    - embedding(input)
        - (b, s) ==> (b, s, h)
        - 问题：(b, s) 和 (v, h) 如何得到 (b, s, h)？
            - (b, s) 经过 one hot => (b, s, v)
            - (b, s, v) @ (v, h) ==> (b, s, h)

### 简单前向

In [2]:
# Lookup table with v = 10 rows (vocabulary) of h = 3 dims each: weight is (10, 3).
embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)

In [3]:
# The (v, h) = (10, 3) lookup table, stored as a learnable Parameter.
embedding.weight

Parameter containing:
tensor([[ 0.8376,  0.6068,  1.7555],
        [ 0.4941,  0.1717, -0.2396],
        [-1.8685,  1.2610, -0.5606],
        [ 0.8324,  1.0663,  1.2586],
        [-0.7126, -0.8973, -2.2054],
        [ 0.7383,  0.2399,  0.1330],
        [-1.3319, -0.5330,  0.9591],
        [ 0.7808, -0.2259,  0.1930],
        [ 1.1298,  0.1678,  1.1490],
        [-0.6612, -0.9927, -0.4817]], requires_grad=True)

In [4]:
# A batch of 2 samples with 4 token indices each: shape (b, s) = (2, 4).
# Indices must be an integer (int64) tensor; torch.tensor(..., dtype=torch.long)
# replaces the legacy torch.LongTensor constructor.
# NOTE: `input` shadows the Python builtin; the name is kept because later cells use it.
input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]], dtype=torch.long)
print(input.dtype)

torch.int64


In [5]:
# Table lookup, (b, s) => (b, s, h): position [i, j] of the result is row input[i, j] of embedding.weight.
embedding(input)

tensor([[[ 0.4941,  0.1717, -0.2396],
         [-1.8685,  1.2610, -0.5606],
         [-0.7126, -0.8973, -2.2054],
         [ 0.7383,  0.2399,  0.1330]],

        [[-0.7126, -0.8973, -2.2054],
         [ 0.8324,  1.0663,  1.2586],
         [-1.8685,  1.2610, -0.5606],
         [-0.6612, -0.9927, -0.4817]]], grad_fn=<EmbeddingBackward0>)

### one-hot 矩阵乘法

In [6]:
# F.one_hot needs num_classes == vocab size v. Take it from the embedding itself
# instead of hard-coding 10, so the two cannot drift apart.
# (b, s) => (b, s, v)
input_onehot = F.one_hot(input, num_classes=embedding.num_embeddings)
print(input_onehot.shape)
input_onehot

torch.Size([2, 4, 10])


tensor([[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]])

In [9]:
# The one-hot tensor above is integer-typed while the weight is float32.
# torch.matmul does not mix int64 and float32 operands, hence the cast in the next cell.
print(embedding.weight.dtype)
print(embedding.weight.shape)

torch.float32
torch.Size([10, 3])


In [10]:
# One-hot lookup written as a matrix product:
#   input_onehot: (b, s, v) @ embedding.weight: (v, h) ==> (b, s, h)
# one_hot yields an integer tensor, so cast it to the weight's own dtype
# (rather than a hard-coded float32) before multiplying.
onehot_lookup = torch.matmul(input_onehot.to(embedding.weight.dtype), embedding.weight)
# Sanity check: the matmul reproduces the direct table lookup exactly.
torch.testing.assert_close(onehot_lookup, embedding(input))
onehot_lookup

tensor([[[ 0.4941,  0.1717, -0.2396],
         [-1.8685,  1.2610, -0.5606],
         [-0.7126, -0.8973, -2.2054],
         [ 0.7383,  0.2399,  0.1330]],

        [[-0.7126, -0.8973, -2.2054],
         [ 0.8324,  1.0663,  1.2586],
         [-1.8685,  1.2610, -0.5606],
         [-0.6612, -0.9927, -0.4817]]], grad_fn=<UnsafeViewBackward0>)

## max_norm

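- https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
- max_norm（float，可选）：前向时，若某个被索引到的 embedding 向量的范数（默认 L2，由 norm_type 指定）大于 max_norm，就把它重新缩放为范数 max_norm；没有被索引到的行保持不变。
- When max_norm is not None, Embedding’s forward method will modify the weight tensor in-place. Since tensors needed for gradient computations cannot be modified in-place, performing a differentiable operation on Embedding.weight before calling Embedding’s forward method requires cloning Embedding.weight when max_norm is not None. For example:

```python
n, d, m = 3, 5, 7
embedding = nn.Embedding(n, d, max_norm=1.0)
W = torch.randn((m, d), requires_grad=True)
idx = torch.tensor([1, 2])
a = embedding.weight.clone() @ W.t()  # weight must be cloned for this to be differentiable
b = embedding(idx) @ W.t()  # modifies weight in-place
out = (a.unsqueeze(0) + b.unsqueeze(1))
loss = out.sigmoid().prod()
loss.backward()
```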

### 不设置 max_norm

In [11]:
# Baseline: max_norm is left at its default (None), so forward() never rescales rows.
embedding = nn.Embedding(3, 5)
print(embedding.weight.mean())
print(embedding.weight.std())
print(embedding.weight)
# L2 norm of each row (one per vocabulary entry). torch.linalg.vector_norm
# replaces the deprecated torch.norm.
torch.linalg.vector_norm(embedding.weight, dim=1)

tensor(0.0240, grad_fn=<MeanBackward0>)
tensor(0.6895, grad_fn=<StdBackward0>)
Parameter containing:
tensor([[ 0.7605,  0.3294, -0.3857,  0.6366,  1.1262],
        [-0.3753, -0.5345, -0.0094,  0.3832, -0.1250],
        [-1.0760,  0.1818,  1.0244, -0.5830, -0.9929]], requires_grad=True)


tensor([1.5840, 0.7675, 1.8884], grad_fn=<LinalgVectorNormBackward0>)

In [12]:
# Look up every row of the 3-entry table, in order: indices 0, 1, 2.
inputs = torch.arange(3)
print(inputs.shape)  # a single 1-D batch of indices
outputs = embedding(inputs)  # one embedding row per index
outputs

torch.Size([3])


tensor([[ 0.7605,  0.3294, -0.3857,  0.6366,  1.1262],
        [-0.3753, -0.5345, -0.0094,  0.3832, -0.1250],
        [-1.0760,  0.1818,  1.0244, -0.5830, -0.9929]],
       grad_fn=<EmbeddingBackward0>)

In [13]:
# Row norms after the forward pass: unchanged, because max_norm is None here.
# torch.linalg.vector_norm replaces the deprecated torch.norm.
torch.linalg.vector_norm(embedding.weight, dim=1)

tensor([1.5840, 0.7675, 1.8884], grad_fn=<LinalgVectorNormBackward0>)

### max_norm=True（True 会被当作 1，等价于 max_norm=1.0）

In [14]:
# max_norm expects a float. Passing True only works because True == 1, so spell out 1.0.
# During forward(), every looked-up row whose L2 norm exceeds 1.0 is rescaled
# in place to norm 1.0.
embedding = nn.Embedding(3, 5, max_norm=1.0)
print(embedding.weight.mean())
print(embedding.weight.std())
print(embedding.weight)
# Row norms BEFORE any forward pass: not yet renormalized.
torch.linalg.vector_norm(embedding.weight, dim=1)

tensor(0.1015, grad_fn=<MeanBackward0>)
tensor(1.2156, grad_fn=<StdBackward0>)
Parameter containing:
tensor([[ 0.7330, -0.2748,  0.5157, -2.5154,  1.3099],
        [ 0.4129, -0.4855,  2.0899, -0.1314,  1.9055],
        [-0.6868, -0.3076, -0.8595, -1.1350,  0.9513]], requires_grad=True)


tensor([2.9869, 2.9021, 1.8704], grad_fn=<LinalgVectorNormBackward0>)

In [15]:
# Same lookup of rows 0, 1, 2, but max_norm is now set, so this forward pass
# also renormalizes the looked-up rows of embedding.weight.
inputs = torch.tensor([0, 1, 2], dtype=torch.long)
print(inputs.shape)
outputs = embedding(inputs)
outputs

torch.Size([3])


tensor([[ 0.2454, -0.0920,  0.1727, -0.8421,  0.4385],
        [ 0.1423, -0.1673,  0.7201, -0.0453,  0.6566],
        [-0.3672, -0.1644, -0.4596, -0.6068,  0.5086]],
       grad_fn=<EmbeddingBackward0>)

In [16]:
# Every output row started above max_norm, so each one was rescaled to norm 1.0.
# torch.linalg.vector_norm replaces the deprecated torch.norm.
torch.linalg.vector_norm(outputs, dim=1)

tensor([1.0000, 1.0000, 1.0000], grad_fn=<LinalgVectorNormBackward0>)

In [17]:
# forward() modified the weight in place: its rows now equal the renormalized outputs.
embedding.weight

Parameter containing:
tensor([[ 0.2454, -0.0920,  0.1727, -0.8421,  0.4385],
        [ 0.1423, -0.1673,  0.7201, -0.0453,  0.6566],
        [-0.3672, -0.1644, -0.4596, -0.6068,  0.5086]], requires_grad=True)

In [18]:
# The rescaling was written back into embedding.weight in place, so the stored
# table's row norms are 1.0 too. torch.linalg.vector_norm replaces the deprecated torch.norm.
torch.linalg.vector_norm(embedding.weight, dim=1)

tensor([1.0000, 1.0000, 1.0000], grad_fn=<LinalgVectorNormBackward0>)