# Handling Variable-Length Sequences with a BLSTM

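Suppose we are using a BLSTM (bidirectional LSTM) for sentence classification and take the BLSTM output at the last time step as the sentence representation.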

Consider the following example:

    sentences = [['nice', 'day'], ['I', 'like', 'to', 'eat', 'apple'], ['can', 'a', 'can']]

In practice, for computational efficiency, the sentences are padded to a common length (usually the maximum length within the batch):

    sentences = [
        ['nice', 'day', '_PAD', '_PAD', '_PAD'],
        ['I', 'like', 'to', 'eat', 'apple'],
        ['can', 'a', 'can', '_PAD', '_PAD']
    ]

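Consider the sequence `['nice', 'day', '_PAD', '_PAD', '_PAD']`. For the forward LSTM we only need the `hidden state` at `day`; for the backward LSTM we only need to encode from `day` back to `nice`. The values at the `_PAD` positions do not need to be computed.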
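
To make the problem concrete, here is a minimal sketch (not part of the original notebook) that runs a small unidirectional LSTM on a two-step sequence with and without trailing padding. The sizes, the random inputs, and the use of all-zero vectors for `_PAD` are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=3, batch_first=True)

# A 2-step sequence standing in for ['nice', 'day'] (random vectors, assumed).
real = torch.randn(1, 2, 4)
# The same sequence followed by three all-zero _PAD steps.
padded = torch.cat([real, torch.zeros(1, 3, 4)], dim=1)

with torch.no_grad():
    _, (h_real, _) = lstm(real)
    out_padded, (h_padded, _) = lstm(padded)

# The final state of the padded run has been updated three more times on _PAD,
# so it is no longer the state at 'day'.
print(torch.allclose(h_real, h_padded))
# The state we want still sits at the last real step (index 1) of the padded run.
print(torch.allclose(h_real[0], out_padded[:, 1], atol=1e-6))
```

Feeding the padded tensor directly therefore wastes computation and also returns a final state that has drifted past the last real token.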

PyTorch handles this variable-length problem with the `torch.nn.utils.rnn.PackedSequence` class and the following two functions:
  - `torch.nn.utils.rnn.pack_padded_sequence`
  - `torch.nn.utils.rnn.pad_packed_sequence`
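
As a quick illustration of how the two functions relate, the sketch below packs a padded batch and then unpacks it again. The tensor shapes and lengths are made up for the example, and the batch is assumed to be already sorted by decreasing length.

```python
import torch
from torch.nn.utils.rnn import (
    PackedSequence, pack_padded_sequence, pad_packed_sequence)

torch.manual_seed(0)

# Batch of 3 sequences with 4 features each, padded to length 5 (batch_first layout).
padded = torch.randn(3, 5, 4)
lengths = torch.tensor([5, 3, 2])
padded[1, 3:] = 0  # zero out the padding positions
padded[2, 2:] = 0

packed = pack_padded_sequence(padded, lengths, batch_first=True)
restored, restored_lengths = pad_packed_sequence(packed, batch_first=True)

print(isinstance(packed, PackedSequence))
print(torch.equal(restored, padded))  # the round trip restores the padded batch
print(restored_lengths.tolist())      # [5, 3, 2]
```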

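Next, using a small set of mock data, we show how to process variable-length sequences with a bidirectional LSTM in PyTorch and verify that the result is correct.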

## Importing packages

First import the required packages and set the relevant parameters:

In [42]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

input_size = 64
hidden_size = 100

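## Initializing the BLSTM

Next, initialize the BLSTM. Bias is disabled to simplify weight initialization, and `batch_first` is set to `True` to make the demonstration easier to follow.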

In [43]:
lstm = nn.LSTM(
    input_size=input_size, hidden_size=hidden_size, num_layers=1,
    bias=False, batch_first=True, bidirectional=True)

To verify that PyTorch computes the bidirectional LSTM output correctly, we set the forward and backward LSTM weights to the same values:

In [44]:
weight_i = getattr(lstm, 'weight_ih_l{0}'.format(0))  # forward: (W_ii|W_if|W_ig|W_io)
weight_i_r = getattr(lstm, 'weight_ih_l{0}_reverse'.format(0))  # backward: (W_ii|W_if|W_ig|W_io)
weight_i_r.data.copy_(weight_i.data)

weight_h = getattr(lstm, 'weight_hh_l{0}'.format(0))  # forward: (W_hi|W_hf|W_hg|W_ho)
weight_h_r = getattr(lstm, 'weight_hh_l{0}_reverse'.format(0))  # backward: (W_hi|W_hf|W_hg|W_ho)
weight_h_r.data.copy_(weight_h.data)

print('Initialization is done!')

Initialization is done!


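With shared weights, if a sequence reads the same forwards and backwards, the last-step outputs of the forward and backward directions should be identical. If the `_PAD` positions were included in the computation, the two outputs would no longer agree.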

## Building the mock data

In [49]:
# Set up the mock input data
sentences = [['nice', 'day'], ['I', 'like', 'to', 'eat', 'apple'], ['can', 'a', 'can']]
test_sent_idx = 2  # index of ['can', 'a', 'can'] in sentences

# Build the alphabet (word -> id); id 0 is reserved for padding
alphabet = {}
index = 1
for sentence in sentences:
    for word in sentence:
        if word not in alphabet:
            alphabet[word] = index
            index += 1
voc_size = len(alphabet) + 1

lengths = [len(s) for s in sentences]
max_len = max(lengths)
batch_size = len(sentences)

inputs = np.zeros((batch_size, max_len), dtype='int64')  # int64 matches the LongTensor created below
for i, sentence in enumerate(sentences):
    ids = [alphabet[w] for w in sentence]
    inputs[i, :lengths[i]] = ids
inputs = torch.LongTensor(inputs)
lengths = torch.LongTensor(lengths)

# Sort by actual sentence length, in descending order
lengths, indices = torch.sort(lengths, descending=True)
inputs = inputs[indices]

# Set up the embedding layer; padding_idx is the id of the padding token
embedding = nn.Embedding(voc_size, input_size, padding_idx=0)
inputs = embedding(inputs)
print(inputs.size())  # [3, 5, 64]

inputs_packed = pack_padded_sequence(inputs, lengths, batch_first=True)
print('inputs_packed.data.size: {0}'.format(inputs_packed.data.size()))
print('batch_sizes: {0}'.format(inputs_packed.batch_sizes))

torch.Size([3, 5, 64])
inputs_packed.data.size: torch.Size([10, 64])
batch_sizes: tensor([ 3,  3,  2,  1,  1])


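A `PackedSequence` contains two values, `data` and `batch_sizes`. `data` stores the tensors from `inputs`, selected according to the `lengths` argument (the actual sequence lengths). `batch_sizes` has as many entries as the longest actual length, and its `i`-th entry records the batch size of the input at time step `i`.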
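
The layout of `data` is easier to see with small integers in place of embeddings. The sketch below uses a made-up batch with the same lengths as above (5, 3, 2) and `0` as the padding value.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Three sequences, sorted by decreasing length; 0 marks padding.
padded = torch.tensor([[1, 2, 3, 4, 5],
                       [6, 7, 8, 0, 0],
                       [9, 10, 0, 0, 0]])
lengths = torch.tensor([5, 3, 2])

packed = pack_padded_sequence(padded, lengths, batch_first=True)

# data is laid out time step by time step, skipping the padding:
# step 0 -> 1, 6, 9; step 1 -> 2, 7, 10; step 2 -> 3, 8; step 3 -> 4; step 4 -> 5
print(packed.data.tolist())         # [1, 6, 9, 2, 7, 10, 3, 8, 4, 5]
# Number of sequences still active at each time step.
print(packed.batch_sizes.tolist())  # [3, 3, 2, 1, 1]
```

The ten entries of `data` correspond to the `[10, 64]` shape printed above, since 5 + 3 + 2 = 10 real tokens remain once padding is dropped.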

## Computing the LSTM output

In [50]:
lstm_output, lstm_hidden = lstm(inputs_packed)
lstm_hidden, lstm_cell_state = lstm_hidden[0], lstm_hidden[1]
lstm_hidden = lstm_hidden.transpose(0, 1)
print(lstm_hidden.size())  # torch.Size([batch_size, 2, hidden_size])

torch.Size([3, 2, 100])


`lstm_hidden` and `lstm_cell_state` hold the `hidden state` and `cell state` of the forward and backward directions, each at its own last time step.

In [51]:
# Restore lstm_hidden to the original sentence order
_, indices_recover = torch.sort(indices)
lstm_hidden_recover = lstm_hidden[indices_recover]

# Last-step output of the forward and backward LSTM for the sentence ['can', 'a', 'can']
print(lstm_hidden_recover[test_sent_idx])  # [2, 100]

tensor([[-0.0228,  0.1170, -0.0451, -0.0568,  0.1403,  0.0784, -0.0243,
         -0.1485,  0.0261, -0.1661,  0.0351,  0.0388,  0.0242,  0.1387,
          0.0971,  0.1638,  0.1037, -0.0110, -0.1212, -0.0216,  0.0468,
          0.0197,  0.0597,  0.0338,  0.0693, -0.0474,  0.2046, -0.0400,
          0.1145, -0.0909,  0.0937,  0.1635,  0.2097, -0.0552,  0.0778,
         -0.1668, -0.1519, -0.0804,  0.1745, -0.1755, -0.2348,  0.1344,
         -0.1358,  0.2272,  0.1237,  0.0951,  0.1680,  0.2202, -0.0688,
         -0.0385, -0.0741, -0.0883,  0.0944,  0.2821, -0.0596,  0.0859,
          0.0558,  0.1748, -0.1993, -0.0317, -0.1196, -0.0811, -0.0041,
         -0.0472, -0.0441, -0.0055, -0.1401, -0.0508,  0.1898, -0.0252,
          0.2173,  0.0231,  0.1498, -0.0523,  0.2766,  0.2083, -0.1071,
          0.1980,  0.0665, -0.1066,  0.0773,  0.0345,  0.1416, -0.2942,
          0.1971, -0.0078,  0.0528, -0.2320, -0.1157, -0.1126,  0.2077,
         -0.2817, -0.0535,  0.0617, -0.0845,  0.0861, -0.0674,  

As shown, the forward and backward outputs are identical.
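
Rather than comparing printed values by eye, the same check can be done with `torch.allclose`. The following is a self-contained sketch, not the notebook's own code: the sizes, the random vectors `a` and `b`, and the all-zero padding vectors are assumptions for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
input_size, hidden_size = 8, 6
lstm = nn.LSTM(input_size, hidden_size, bias=False,
               batch_first=True, bidirectional=True)

# Share the weights between the two directions, as done earlier.
with torch.no_grad():
    lstm.weight_ih_l0_reverse.copy_(lstm.weight_ih_l0)
    lstm.weight_hh_l0_reverse.copy_(lstm.weight_hh_l0)

# A palindromic sequence (a, b, a) followed by two all-zero padding steps.
a, b = torch.randn(input_size), torch.randn(input_size)
pad = torch.zeros(input_size)
x = torch.stack([a, b, a, pad, pad]).unsqueeze(0)  # [1, 5, input_size]
lengths = torch.tensor([3])

with torch.no_grad():
    _, (h_packed, _) = lstm(pack_padded_sequence(x, lengths, batch_first=True))
    _, (h_padded, _) = lstm(x)  # no packing: the padding steps are processed too

# h_n[0] is the forward direction, h_n[1] the backward direction.
packed_match = torch.allclose(h_packed[0], h_packed[1], atol=1e-6)
padded_match = torch.allclose(h_padded[0], h_padded[1], atol=1e-6)
print(packed_match, padded_match)
```

With packing the two directions agree; when the padded tensor is fed directly they do not, because the padding steps take part in the computation.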

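`lstm_output` holds the BLSTM output at every time step. Using each sentence's actual length, the output at the last time step can also be extracted from it: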

In [52]:
# Restore the padded shape
lstm_output_pad, lengths = pad_packed_sequence(lstm_output, batch_first=True)
lstm_output_pad_recover = lstm_output_pad[indices_recover]
lengths_recover = lengths[indices_recover]
print(lstm_output_pad.size())  # size=[3, 5, 200]

# Last-step output of the forward and backward LSTM for the sentence ['can', 'a', 'can']
hidden_last = lstm_output_pad_recover[test_sent_idx][lengths_recover[test_sent_idx]-1][:hidden_size]  # forward
hidden_last_r = lstm_output_pad_recover[test_sent_idx][0][hidden_size:]  # backward
print(hidden_last)
print(hidden_last_r)

torch.Size([3, 5, 200])
tensor([-0.0228,  0.1170, -0.0451, -0.0568,  0.1403,  0.0784, -0.0243,
        -0.1485,  0.0261, -0.1661,  0.0351,  0.0388,  0.0242,  0.1387,
         0.0971,  0.1638,  0.1037, -0.0110, -0.1212, -0.0216,  0.0468,
         0.0197,  0.0597,  0.0338,  0.0693, -0.0474,  0.2046, -0.0400,
         0.1145, -0.0909,  0.0937,  0.1635,  0.2097, -0.0552,  0.0778,
        -0.1668, -0.1519, -0.0804,  0.1745, -0.1755, -0.2348,  0.1344,
        -0.1358,  0.2272,  0.1237,  0.0951,  0.1680,  0.2202, -0.0688,
        -0.0385, -0.0741, -0.0883,  0.0944,  0.2821, -0.0596,  0.0859,
         0.0558,  0.1748, -0.1993, -0.0317, -0.1196, -0.0811, -0.0041,
        -0.0472, -0.0441, -0.0055, -0.1401, -0.0508,  0.1898, -0.0252,
         0.2173,  0.0231,  0.1498, -0.0523,  0.2766,  0.2083, -0.1071,
         0.1980,  0.0665, -0.1066,  0.0773,  0.0345,  0.1416, -0.2942,
         0.1971, -0.0078,  0.0528, -0.2320, -0.1157, -0.1126,  0.2077,
        -0.2817, -0.0535,  0.0617, -0.0845,  0.0861, 

The extracted values are equal to those in `lstm_hidden`.
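
The extraction can also be written for the whole batch at once. The sketch below is one possible variant rather than the notebook's method: it assumes a recent PyTorch version in which `pack_padded_sequence` accepts `enforce_sorted=False`, so the manual sort and restore steps are not needed. The sizes and random inputs are made up.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
input_size, hidden_size = 8, 6
lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)

# Unsorted batch; positions beyond each length are treated as padding.
x = torch.randn(3, 5, input_size)
lengths = torch.tensor([2, 5, 3])

# enforce_sorted=False lets PyTorch sort internally and return results
# in the original batch order.
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
with torch.no_grad():
    output, (h_n, _) = lstm(packed)
output, _ = pad_packed_sequence(output, batch_first=True)  # [3, 5, 2 * hidden_size]

batch_idx = torch.arange(x.size(0))
last_fwd = output[batch_idx, lengths - 1, :hidden_size]  # forward: last real step
last_bwd = output[:, 0, hidden_size:]                    # backward: first step

fwd_ok = torch.allclose(last_fwd, h_n[0], atol=1e-6)
bwd_ok = torch.allclose(last_bwd, h_n[1], atol=1e-6)
print(fwd_ok, bwd_ok)
```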