### P24 nn.Embedding 前向查表（索引）过程与 one-hot 编码的关系

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [4]:
# Fix the global RNG so nn.Embedding's random initialisation (and therefore the
# weight values shown in later cells) is reproducible across runs.
# The trailing `;` suppresses the returned Generator's repr, whose memory address
# differs on every run and only adds noise to the saved output.
torch.manual_seed(0);

<torch._C.Generator at 0x19e182e5c50>

In [5]:
# Lookup table for a vocabulary of 10 token ids, each mapped to a 3-dim vector.
# The weights are initialised randomly, so their values depend on the seed set earlier.
embedding = nn.Embedding(num_embeddings=10, embedding_dim=3)

In [6]:
# Display the whole lookup table: row i is the vector that token id i maps to.
lookup_table = embedding.weight
lookup_table

Parameter containing:
tensor([[-1.1258, -1.1524, -0.2506],
        [-0.4339,  0.8487,  0.6920],
        [-0.3160, -2.1152,  0.3223],
        [-1.2633,  0.3500,  0.3081],
        [ 0.1198,  1.2377, -0.1435],
        [-0.1116, -0.6136,  0.0316],
        [-0.4927,  0.2484,  0.4397],
        [ 0.1124, -0.8411, -2.3160],
        [-0.1023,  0.7924, -0.2897],
        [ 0.0525,  0.5229,  2.3022]], requires_grad=True)

In [7]:
# Toy batch: 2 sentences (b = 2) of 4 token ids each (s = 4), shape (b, s).
# Ids must be integers in [0, 10) so they can index the 10-row embedding table.
# torch.tensor(..., dtype=torch.long) replaces the legacy torch.LongTensor constructor.
# NOTE: the name `input` shadows the Python builtin; it is kept because later cells use it.
input = torch.tensor([[1, 2, 4, 5],
                      [4, 3, 2, 9]], dtype=torch.long)

In [8]:
# Forward pass is a row lookup: every id in `input` (b, s) selects one row of
# embedding.weight, producing a (b, s, h) tensor.
embedded = embedding(input)
embedded

tensor([[[-0.4339,  0.8487,  0.6920],
         [-0.3160, -2.1152,  0.3223],
         [ 0.1198,  1.2377, -0.1435],
         [-0.1116, -0.6136,  0.0316]],

        [[ 0.1198,  1.2377, -0.1435],
         [-1.2633,  0.3500,  0.3081],
         [-0.3160, -2.1152,  0.3223],
         [ 0.0525,  0.5229,  2.3022]]], grad_fn=<EmbeddingBackward0>)

In [10]:
# One-hot encode the ids: (b, s) -> (b, s, v), where num_classes == vocab size v.
# Taking the first sentence [1, 2, 4, 5] as an example, id 1 becomes a vector with
# a 1 at index 1 and 0 everywhere else: [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]; the other
# ids are encoded the same way.
# num_classes is read from the embedding layer so it cannot drift from the table size.
input_onehot = F.one_hot(input, num_classes=embedding.num_embeddings)
assert input_onehot.shape == (*input.shape, embedding.num_embeddings)
print(input_onehot.shape)
input_onehot

torch.Size([2, 4, 10])


tensor([[[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]])

In [11]:
# Shape of the lookup table: (v, h) = (vocab size, embedding dim).
embedding.weight.size()

torch.Size([10, 3])

In [12]:
# One-hot view of the lookup:
#   input_onehot: (b, s, v)  @  embedding.weight: (v, h)  ->  (b, s, h)
# Each one-hot row has a single 1, so the matmul selects exactly one row of the
# weight table, reproducing embedding(input).
# The integer one-hot tensor is cast to the weight's own dtype (matmul needs
# matching floating dtypes). Forcing both operands to float32 would be redundant
# here and would silently diverge from the real lookup for non-float32 weights.
onehot_lookup = input_onehot.to(embedding.weight.dtype) @ embedding.weight
assert torch.allclose(onehot_lookup, embedding(input)), "one-hot matmul should equal the embedding lookup"
onehot_lookup

tensor([[[-0.4339,  0.8487,  0.6920],
         [-0.3160, -2.1152,  0.3223],
         [ 0.1198,  1.2377, -0.1435],
         [-0.1116, -0.6136,  0.0316]],

        [[ 0.1198,  1.2377, -0.1435],
         [-1.2633,  0.3500,  0.3081],
         [-0.3160, -2.1152,  0.3223],
         [ 0.0525,  0.5229,  2.3022]]], grad_fn=<UnsafeViewBackward0>)