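In some scenarios we need a same-length convolution, i.e. one whose output has the same length as its input. It is simple to implement: pad both ends of the data with zeros. This post shows how to do it in PyTorch.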

## 1. Import the required libraries

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

## 2. Simulated data

In [4]:
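# Three sequences of token ids, right-padded with 0 to a length of 8
batch_x = torch.LongTensor(np.array([
    [1, 2, 3, 4, 5, 0, 0, 0],
    [1, 2, 3, 4, 5, 6, 7, 8],
    [1, 2, 0, 0, 0, 0, 0, 0]]
))
mask = (batch_x > 0).float()  # 1.0 for real tokens, 0.0 for padding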
batch_size = 3
max_len = 8

Set up the embedding layer:

In [23]:
alphabet_size = 9
input_dim = 10
embedding = nn.Embedding(alphabet_size, input_dim, padding_idx=0)
batch_x_embed = embedding(batch_x)
print(batch_x_embed.size())  # size=(3, 8, 10)

torch.Size([3, 8, 10])


## 3. Convolution

### 3.1 Convolution layer

First, set up the convolution layer:

In [21]:
filter_num = 12
window_size = 3
conv_encoder = nn.Conv2d(
    in_channels=1, out_channels=filter_num, kernel_size=(window_size, input_dim))
print(conv_encoder)

Conv2d(1, 12, kernel_size=(3, 10), stride=(1, 1))


### 3.2 Narrow convolution

By default, PyTorch performs a narrow convolution: the input is not padded.

In [25]:
batch_x_embed = batch_x_embed.unsqueeze(1)  # size=(3, 1, 8, 10)
conv_output = conv_encoder(batch_x_embed)
print(conv_output.size())  # size=(3, 12, 6, 1)

torch.Size([3, 12, 6, 1])


As shown, the original length is 8, and after the convolution it becomes 8-3+1=6.
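The length arithmetic above can be sketched as a small helper. The function name `conv_output_len` and its `padding` and `stride` parameters are illustrative assumptions, not part of the original post; the formula assumes a dilation of 1.

```python
def conv_output_len(length, window_size, padding=0, stride=1):
    """Output length along one axis of a convolution (dilation assumed to be 1)."""
    return (length + 2 * padding - window_size) // stride + 1


print(conv_output_len(8, 3))             # narrow convolution: 6
print(conv_output_len(8, 3, padding=1))  # one zero row added on each side: 8
```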

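### 3.3 Same-length convolution

A same-length convolution is implemented by padding both ends of the data with zeros. PyTorch provides the `nn.ZeroPad2d` module for this; its arguments are: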

Args:
    padding (int, tuple): the size of the padding. If is `int`, uses the same
        padding in all boundaries. If a 4-`tuple`, uses (`paddingLeft`, `paddingRight`,
        `paddingTop`, `paddingBottom`)

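For an embedded sentence, we only need to pad `(window_size-1)//2` rows of zeros at the `top` and `bottom`, i.e. along the sequence-length dimension. The embedding dimension (left and right) is not padded.

**Note**: with this symmetric padding, the output length equals the input length only when `window_size` is odd. With an even `window_size`, the output is one position shorter than the input.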
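The effect of an odd versus an even window size can be checked with a short experiment. This is a minimal sketch under stated assumptions: the all-zero dummy input `x` and the dictionary `output_lens` are illustrative names, and the tensor shape simply mirrors the `(batch, channel, length, embedding)` layout used in this post.

```python
import torch
import torch.nn as nn

max_len, input_dim = 8, 10
# Dummy input shaped (batch, channel, length, embedding); the values do not matter here
x = torch.zeros(3, 1, max_len, input_dim)

output_lens = {}
for window_size in (2, 3, 4, 5):
    pad_size = (window_size - 1) // 2
    # Pad only the top and bottom, i.e. along the sequence-length dimension
    pad_op = nn.ZeroPad2d((0, 0, pad_size, pad_size))
    conv = nn.Conv2d(1, 12, kernel_size=(window_size, input_dim))
    output_lens[window_size] = conv(pad_op(x)).size(2)

print(output_lens)  # {2: 7, 3: 8, 4: 7, 5: 8}
```

Only the odd window sizes keep the length at 8; the even ones lose one position.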

In [28]:
pad_size = (window_size-1) // 2
pad_op = nn.ZeroPad2d((0, 0, pad_size, pad_size))
batch_x_embed_padded = pad_op(batch_x_embed)
print(batch_x_embed.size())  # size=(3, 1, 8, 10)
print(batch_x_embed_padded.size())  # size=(3, 1, 10, 10)

torch.Size([3, 1, 8, 10])
torch.Size([3, 1, 10, 10])


Next, apply the convolution to the padded data:

In [32]:
conv_output = conv_encoder(batch_x_embed_padded)
print(batch_x_embed.size())  # size=(3, 1, 8, 10)
print(conv_output.size())  # size=(3, 12, 8, 1)

torch.Size([3, 1, 8, 10])
torch.Size([3, 12, 8, 1])


As shown, the length after the convolution is still 8.
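As an alternative to a separate padding step, `nn.Conv2d` also accepts a `padding` argument that zero-pads each side of the height and width dimensions. The sketch below is not the method used in this post; it is a check, with illustrative names (`conv_plain`, `conv_padded`) and a random input, that the two approaches appear to give the same result when both layers share the same weights.

```python
import torch
import torch.nn as nn

window_size, input_dim, filter_num = 3, 10, 12
pad_size = (window_size - 1) // 2
x = torch.randn(3, 1, 8, input_dim)  # (batch, channel, length, embedding)

# Layer without padding, used together with an explicit ZeroPad2d step
conv_plain = nn.Conv2d(1, filter_num, kernel_size=(window_size, input_dim))
# Layer that zero-pads pad_size rows on each side of the length dimension itself
conv_padded = nn.Conv2d(1, filter_num, kernel_size=(window_size, input_dim),
                        padding=(pad_size, 0))
# Copy weights and bias so that both layers compute with identical parameters
conv_padded.load_state_dict(conv_plain.state_dict())

out_explicit = conv_plain(nn.ZeroPad2d((0, 0, pad_size, pad_size))(x))
out_builtin = conv_padded(x)

print(out_builtin.size())  # torch.Size([3, 12, 8, 1])
same = torch.allclose(out_explicit, out_builtin, atol=1e-6)
print(same)
```

Folding the padding into the layer keeps the padding size next to the kernel size, at the cost of making the padding step less visible than an explicit `nn.ZeroPad2d`.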