In [8]:
import torch
import torch.nn as nn

# Fix the RNG seed so the randomly initialised weights printed below are
# reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)

batch_size, seq_len, d_model = 64, 10, 3
# nn.Embedding's first argument is the number of rows in the lookup table
# (the vocabulary size), not a sequence length, so it gets its own name.
vocab_size = 10
emb = nn.Embedding(vocab_size, d_model)  # 10 words, each a 3-dimensional vector
print(f"嵌入层的形状: {emb.weight.shape}")
print(f"嵌入层的参数: \n{emb.weight.detach().numpy()}")

# 2-D index tensor: every entry selects one row of emb.weight, so the output
# keeps the index shape and appends an axis of size d_model.
query = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)
output = emb(query)
print(f"嵌入层输出的形状: {output.shape}")
print(f"嵌入层输出的张量: \n{output.detach().numpy()}")

嵌入层的形状: torch.Size([10, 3])
嵌入层的参数: 
[[-0.42180437  2.0353827   1.1638284 ]
 [ 0.27184504  1.4092371  -0.9940103 ]
 [-0.07250201 -1.5029283  -0.6309964 ]
 [-0.942438   -0.7351417   0.22733966]
 [ 1.9784678   0.8207226  -0.56497574]
 [ 1.2907467   0.4594038   1.112774  ]
 [ 1.0533401  -0.5800268   0.0155646 ]
 [ 0.05913044  0.8964288   0.09396027]
 [ 1.2051184  -0.0675792  -0.55703914]
 [-0.17452654 -0.12954919  0.03179752]]
嵌入层输出的形状: torch.Size([2, 3, 3])
嵌入层输出的张量: 
[[[ 0.27184504  1.4092371  -0.9940103 ]
  [-0.07250201 -1.5029283  -0.6309964 ]
  [-0.942438   -0.7351417   0.22733966]]

 [[ 1.9784678   0.8207226  -0.56497574]
  [ 1.2907467   0.4594038   1.112774  ]
  [ 1.0533401  -0.5800268   0.0155646 ]]]


In [6]:
# 1-D index tensor: looks up rows 1, 2 and 3 of the embedding table.
q_in = torch.tensor([1, 2, 3], dtype=torch.long)
q_out = emb(q_in)
print(q_out)

tensor([[ 0.5812,  1.5754,  1.6710],
        [-1.9034, -2.0241, -0.2872],
        [ 0.0303,  1.4448, -0.3176]], grad_fn=<EmbeddingBackward0>)


In [18]:
# Linear layer mapping d_model input features to 2 * d_model output features.
# It is built with bias=False, so `affine.bias` is None and the layer has only
# a weight matrix.
affine = nn.Linear(in_features=d_model, out_features=2 * d_model, bias=False)
affine_weight_np = affine.weight.detach().numpy()
print(f"仿射变换的参数: \n{affine_weight_np}")
print(f"仿射变换的偏置: \n{affine.bias}")

# Random input with seq_len rows and d_model columns, pushed through the layer.
x_in = torch.randn(seq_len, d_model)
x_out = affine(x_in)
x_out_np = x_out.detach().numpy()
print(f"仿射变换的输出: \n{x_out_np}")

仿射变换的参数: 
[[ 0.09462309  0.14278811 -0.13938218]
 [-0.07240725  0.05697745 -0.28411332]
 [-0.34175146  0.2998852  -0.3585627 ]
 [ 0.37633288 -0.45240918 -0.14626577]
 [ 0.37369865  0.5198997  -0.05909109]
 [ 0.09857512 -0.14968792 -0.38849717]]
仿射变换的偏置: 
None
仿射变换的输出: 
[[ 0.00445703  0.26982287  0.70757085 -0.49577072 -0.2815076   0.06644682]
 [-0.3096938  -0.09524789 -1.0542555   1.831219   -1.3346671   0.7598453 ]
 [ 0.10411814  0.22760741  0.12531304  0.4260156  -0.04005839  0.45777068]
 [-0.33386952 -0.42516574 -0.519474   -0.19469433 -0.5449988  -0.5710819 ]
 [ 0.28651223  0.50648236  0.6442523   0.23408063  0.24772479  0.68026894]
 [ 0.17510037  0.40755725  1.3440962  -1.3440317   0.29801944 -0.18057963]
 [-0.09313991 -0.05039303 -0.26036248  0.37601334 -0.33557263  0.12125733]
 [ 0.13883959  0.32704493  0.03142339  0.8973393  -0.15197322  0.79305166]
 [-0.28619874 -0.09947865  0.02819148 -0.23423627 -0.84119153 -0.2251691 ]
 [-0.4055633  -0.24186707 -0.54873335  0.46400195 -1.19

In [26]:
# Manually reproduce the linear layer: `affine` was created with bias=False,
# so its forward pass reduces to x @ W^T.
x_out_manual = x_in.matmul(affine.weight.t())
# Compare with a floating-point tolerance instead of bit-exact equality: the
# layer's forward pass and an explicit matmul are separate compute paths and
# may differ in the last bits on some backends or dtypes.
torch.allclose(x_out, x_out_manual)

True