[toc]

# PyTorch Linear

## nn.Linear

In [3]:
import torch

batch_size = 10
input_dim = 5
output_dim = 3
x = torch.randn(batch_size, input_dim)
linear = torch.nn.Linear(input_dim, output_dim) 
linear(x).shape # [batch_size, output_dim]

torch.Size([10, 3])

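Note that the weight has shape [output_dim, input_dim], not [input_dim, output_dim].

With $A$ denoting `linear.weight` and $b$ denoting `linear.bias`, the layer computes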

$$
y = x A^T + b
$$

In [9]:
linear.weight.shape # A: [output_dim, input_dim]

torch.Size([3, 5])
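
As a quick check of the formula above, the following minimal sketch recomputes the output by hand from `linear.weight` (the matrix $A$) and `linear.bias` (the vector $b$). The variable name `manual` and the fixed seed are illustrative choices that do not appear in the original notebook.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the sketch is reproducible

batch_size, input_dim, output_dim = 10, 5, 3
x = torch.randn(batch_size, input_dim)
linear = torch.nn.Linear(input_dim, output_dim)

# The weight is stored as [output_dim, input_dim], so it has to be transposed
# before right-multiplying x of shape [batch_size, input_dim].
manual = x @ linear.weight.T + linear.bias

print(manual.shape)
# Compare within a small tolerance rather than bit-for-bit.
print(torch.allclose(manual, linear(x), atol=1e-6))
```

Storing the weight as [output_dim, input_dim] is why the transpose appears in the formula: each row of the weight holds the coefficients for one output feature.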

### x can have more than two dimensions; the output then has shape [*, output_dim]

Take the three-dimensional case as an example. Suppose x.shape = [batch_size, seq_len, input_dim].

In [36]:
import torch

batch_size = 2
seq_len = 4
input_dim = 5
output_dim = 3
x = torch.randn(batch_size, seq_len, input_dim)
linear = torch.nn.Linear(input_dim, output_dim) 
y1 = linear(x)
print(y1.shape) # [batch_size, seq_len, output_dim]

torch.Size([2, 4, 3])


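How is the multiplication carried out when x has more than two dimensions? In effect, every sub-matrix of x (each slice along the leading dimensions) is multiplied by the same weight. In terms of implementation, all the sub-matrices can be stacked into a single 2-D matrix, which turns the operation into an ordinary matrix multiplication.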

![](https://gitee.com/EdwardElric_1683260718/picture_bed/raw/master/img/20200907182258.png)
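
To make the "same weight for every sub-matrix" claim concrete, here is a small sketch that applies the layer to each [seq_len, input_dim] slice separately and compares the stacked result with a single call on the full 3-D tensor. The names `per_slice` and `whole` are hypothetical and used only for illustration.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the sketch is reproducible

batch_size, seq_len, input_dim, output_dim = 2, 4, 5, 3
x = torch.randn(batch_size, seq_len, input_dim)
linear = torch.nn.Linear(input_dim, output_dim)

# Apply the layer to each [seq_len, input_dim] sub-matrix on its own,
# then stack the results back into a 3-D tensor.
per_slice = torch.stack([linear(x[i]) for i in range(batch_size)])

# Apply the layer to the whole 3-D tensor in one call.
whole = linear(x)

print(per_slice.shape)
# Compare within a small tolerance rather than bit-for-bit.
print(torch.allclose(per_slice, whole, atol=1e-6))
```

Both paths use the same `weight` and `bias`, so they should agree up to floating-point rounding.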

We can therefore try to reproduce what nn.Linear does by hand.

In [37]:
weight = linear.weight
bias = linear.bias

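# Flatten x to [batch_size * seq_len, input_dim], multiply by the transposed weight,
# restore the [batch_size, seq_len, output_dim] shape, then add the bias (broadcast over the last dimension).
y2 = torch.matmul(x.view(-1, input_dim), weight.transpose(0, 1)).view(batch_size, -1, output_dim) + bias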

The two results match.

In [40]:
print(torch.all(y1 == y2))

tensor(True)
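
Exact `==` comparison of floating-point tensors can depend on the backend and on the order of operations, so a tolerance-based check is usually a safer way to confirm the equivalence. The sketch below repeats the flatten, multiply, restore steps and compares with `torch.allclose`. The tolerance `atol=1e-6` and the name `flat` are assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)  # fixed seed only so the sketch is reproducible

batch_size, seq_len, input_dim, output_dim = 2, 4, 5, 3
x = torch.randn(batch_size, seq_len, input_dim)
linear = torch.nn.Linear(input_dim, output_dim)

y1 = linear(x)

# Same flatten -> matmul -> restore steps as the manual version;
# reshape is used instead of view so this also works for non-contiguous x.
flat = x.reshape(-1, input_dim)
y2 = (flat @ linear.weight.T + linear.bias).reshape(batch_size, seq_len, output_dim)

# Compare within a tolerance instead of requiring bit-for-bit equality.
print(torch.allclose(y1, y2, atol=1e-6))
# Largest absolute difference between the two results.
print((y1 - y2).abs().max())
```

If the maximum difference is nonzero but tiny, the two computations are still equivalent for practical purposes.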


# References
1. [Linear — PyTorch 1.6.0 documentation](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html?highlight=nn%20linear#torch.nn.Linear)