# Handling Variable-Length Sequences with a BLSTM

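Suppose we are using a BLSTM model for sentence classification, and the BLSTM output at the last time step serves as the sentence representation.

Consider the following example: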

    sentences = [['nice', 'day'], ['I', 'like', 'to', 'eat', 'apple'], ['can', 'a', 'can']]

In practice, for computational efficiency, these sentences are padded to a common length (usually the maximum length within the batch):
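The padding step can be sketched in plain Python. This is a minimal illustration; `pad_batch` is a hypothetical helper, not part of PyTorch, and it assumes right-padding with the `_PAD` token used in this example:

```python
def pad_batch(sentences, pad_token='_PAD'):
    """Right-pad every sentence to the length of the longest one in the batch (hypothetical helper)."""
    max_len = max(len(s) for s in sentences)
    return [s + [pad_token] * (max_len - len(s)) for s in sentences]


sentences = [['nice', 'day'], ['I', 'like', 'to', 'eat', 'apple'], ['can', 'a', 'can']]
padded = pad_batch(sentences)
for row in padded:
    print(row)
```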

    sentences = [
        ['nice', 'day', '_PAD', '_PAD', '_PAD'],
        ['I', 'like', 'to', 'eat', 'apple'],
        ['can', 'a', 'can', '_PAD', '_PAD']
    ]

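Consider the sequence `['nice', 'day', '_PAD', '_PAD', '_PAD']`. The forward LSTM only needs the `hidden state` at `day`. The backward LSTM only needs to encode from `day` back to `nice`. The `_PAD` positions need not be computed at all. If the backward LSTM started from the trailing `_PAD` tokens, its final state would be affected by the padding.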

PyTorch handles this variable-length problem with the `torch.nn.utils.rnn.PackedSequence` class and the following two functions:
  - `torch.nn.utils.rnn.pack_padded_sequence`
  - `torch.nn.utils.rnn.pad_packed_sequence`
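The sketch below shows how the two functions fit together. It is a minimal round trip on a hand-made batch, assuming a reasonably recent PyTorch. The values are chosen so that padding is 0 and the real steps of each sequence count 1, 2, …; the batch is already sorted by length in descending order:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Toy batch: 3 sequences, max length 5, feature size 1, sorted by length (descending)
lengths = torch.tensor([5, 3, 2])
padded = torch.zeros(3, 5, 1)
for i, n in enumerate(lengths.tolist()):
    # real steps hold 1..n; padded positions stay 0
    padded[i, :n, 0] = torch.arange(1, n + 1, dtype=torch.float)

packed = pack_padded_sequence(padded, lengths, batch_first=True)
print(packed.data.squeeze(-1))  # only the real steps, interleaved time step by time step
print(packed.batch_sizes)       # how many sequences are still active at each time step

unpacked, out_lengths = pad_packed_sequence(packed, batch_first=True)
print(torch.equal(unpacked, padded), out_lengths)
```

The padding never appears in `packed.data`. `pad_packed_sequence` refills the padded positions with zeros (its default `padding_value`) and also returns the lengths.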

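Next, using a small set of mock data, we walk through how PyTorch's bidirectional LSTM handles variable-length sequences and verify that the results are correct.

## Importing the packages

First, import the required packages and set the parameters: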

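```python
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

input_size = 64    # embedding dimension
hidden_size = 100  # LSTM hidden size per direction
voc_size = 10      # vocabulary size: 9 distinct words + padding id 0
```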

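## Initializing the BLSTM

Next, initialize the BLSTM. Bias is disabled to simplify weight initialization, so only the weight matrices need to be copied. `batch_first` is set to `True` for ease of demonstration.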
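With `bias=False`, each direction has exactly two weight matrices. Listing the parameters shows the names the next step relies on. This is a small check using the same hyperparameters as above. The four LSTM gates (input, forget, cell, output) are stacked along the first dimension, which gives `4 * hidden_size` rows:

```python
import torch.nn as nn

lstm = nn.LSTM(input_size=64, hidden_size=100, num_layers=1,
               bias=False, batch_first=True, bidirectional=True)

# Each direction: input-to-hidden and hidden-to-hidden weights; the reverse
# direction's parameters carry the "_reverse" suffix.
for name, param in lstm.named_parameters():
    print(name, tuple(param.shape))
```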

```python
lstm = nn.LSTM(
    input_size=input_size, hidden_size=hidden_size, num_layers=1,
    bias=False, batch_first=True, bidirectional=True)
```

To verify that PyTorch computes the bidirectional LSTM output correctly, we set the forward and backward LSTM weights to identical values:

```python
# Copy the forward-direction weights into the reverse direction
weight_i = getattr(lstm, 'weight_ih_l{0}'.format(0))  # forward: (W_ii|W_if|W_ig|W_io)
weight_i_r = getattr(lstm, 'weight_ih_l{0}_reverse'.format(0))  # reverse: (W_ii|W_if|W_ig|W_io)
weight_i_r.data.copy_(weight_i.data)

weight_h = getattr(lstm, 'weight_hh_l{0}'.format(0))  # forward: (W_hi|W_hf|W_hg|W_ho)
weight_h_r = getattr(lstm, 'weight_hh_l{0}_reverse'.format(0))  # reverse: (W_hi|W_hf|W_hg|W_ho)
weight_h_r.data.copy_(weight_h.data)

print('Initialization is done!')
```

```
Initialization is done!
```


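With tied weights, a sequence that reads the same forwards and backwards should give identical outputs from the forward and backward LSTMs at their respective final time steps. The palindrome `['can', 'a', 'can']` is such a sequence. If the backward LSTM had also processed the `_PAD` values, the two outputs would differ.

## Building mock data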
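This check only works for sentences whose token sequence equals its own reverse. A quick way to find them in the mock data below is the following plain-Python sketch. It shows why `['can', 'a', 'can']` (index 2) is used as the test sentence:

```python
sentences = [['nice', 'day'], ['I', 'like', 'to', 'eat', 'apple'], ['can', 'a', 'can']]

# The forward and backward LSTMs see the same input sequence only when the
# sentence is a token-level palindrome.
palindromes = [i for i, s in enumerate(sentences) if s == s[::-1]]
print(palindromes)  # [2]
```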

```python
# Mock input data
sentences = [['nice', 'day'], ['I', 'like', 'to', 'eat', 'apple'], ['can', 'a', 'can']]
test_sent_idx = 2  # index of ['can', 'a', 'can'] in sentences

# Build the alphabet (word -> id); id 0 is reserved for padding
alphabet = {}
index = 1
for sentence in sentences:
    for word in sentence:
        if word not in alphabet:
            alphabet[word] = index
            index += 1
lengths = [len(s) for s in sentences]
max_len = max(lengths)
batch_size = len(sentences)

inputs = np.zeros((batch_size, max_len), dtype=np.int64)  # padded positions keep id 0
for i, sentence in enumerate(sentences):
    ids = [alphabet[w] for w in sentence]
    inputs[i, :lengths[i]] = ids
inputs = torch.from_numpy(inputs)
lengths = torch.LongTensor(lengths)

# Sort by actual sentence length in descending order
lengths, indices = torch.sort(lengths, descending=True)
inputs = inputs[indices]
_, indices_recover = torch.sort(indices)  # used to restore the original order

# Embedding layer; padding_idx is the id used for padding
embedding = nn.Embedding(voc_size, input_size, padding_idx=0)
inputs = embedding(inputs)
print(inputs.size())  # [3, 5, 64]

inputs_packed = pack_padded_sequence(inputs, lengths, batch_first=True)
print('inputs_packed.data.size: {0}'.format(inputs_packed.data.size()))
print('batch_sizes: {0}'.format(inputs_packed.batch_sizes))
```

```
torch.Size([3, 5, 64])
inputs_packed.data.size: torch.Size([10, 64])
batch_sizes: tensor([ 3,  3,  2,  1,  1])
```
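The sort-and-recover trick in the code above works as follows. `indices_recover` is the argsort of the sort permutation `indices`, so indexing the sorted batch with it undoes the sort. This plain-Python sketch mirrors the two `torch.sort` calls. It relies on the fact that there are no ties in the lengths:

```python
lengths = [2, 5, 3]  # original order: 'nice day', 'I like to eat apple', 'can a can'

# indices[k] = original position of the k-th longest sentence
indices = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
# indices_recover[i] = position of original sentence i inside the sorted batch
indices_recover = sorted(range(len(indices)), key=lambda k: indices[k])

sorted_lengths = [lengths[i] for i in indices]
restored = [sorted_lengths[k] for k in indices_recover]
print(indices, indices_recover, sorted_lengths, restored)
```

In this data, the test sentence (original index 2) sits at position `indices_recover[2] == 1` of the sorted batch.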


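A PackedSequence holds two values, `data` and `batch_sizes`. `data` stores only the real (non-padding) time steps of `inputs`, as given by the `lengths` argument (the actual sequence lengths), laid out time step by time step. `batch_sizes` has as many entries as the maximum actual length, and its `i`-th value is the batch size at time step `i`, i.e. the number of sequences that are still running at that step.

## Computing the LSTM output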
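The layout of a PackedSequence can be simulated in plain Python for lengths sorted in descending order. In this sketch, `packed_layout` is a hypothetical helper that returns `batch_sizes` together with the `(sequence, time step)` pair behind each row of `data`:

```python
def packed_layout(lengths):
    """Simulate PackedSequence's row order for lengths sorted in descending order."""
    assert lengths == sorted(lengths, reverse=True), "lengths must be sorted descending"
    batch_sizes, order = [], []
    for t in range(lengths[0]):
        # sequences whose actual length exceeds t are still active at step t
        active = [b for b, n in enumerate(lengths) if n > t]
        batch_sizes.append(len(active))
        order.extend((b, t) for b in active)
    return batch_sizes, order


batch_sizes, order = packed_layout([5, 3, 2])
print(batch_sizes)  # [3, 3, 2, 1, 1]
print(len(order))   # 10, matching inputs_packed.data.size(0)
print(order[:4])    # [(0, 0), (1, 0), (2, 0), (0, 1)]
```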

```python
lstm_output, lstm_hidden = lstm(inputs_packed)
# h_n and c_n have shape [num_layers * 2, batch_size, hidden_size], even with batch_first=True
lstm_hidden, lstm_cell_state = lstm_hidden[0], lstm_hidden[1]
lstm_hidden = lstm_hidden.transpose(0, 1)
print(lstm_hidden.size())  # torch.Size([batch_size, 2, hidden_size])
```

```
torch.Size([3, 2, 100])
```


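`lstm_hidden` and `lstm_cell_state` hold the `hidden state` and `cell state` of the forward and backward directions at their respective last time steps. Because the input is packed, the forward direction stops at each sentence's last real token. The backward direction's last step is the sentence's first token.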
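As an alternative to the manual sort/unsort above, `pack_padded_sequence` accepts `enforce_sorted=False` in PyTorch 1.1 and later. In that case PyTorch sorts internally and returns `h_n` in the original batch order. The sketch below uses this option and builds the sentence representation mentioned at the start by concatenating the two final states. The sizes and random inputs are arbitrary assumptions for illustration:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
batch_size, max_len, input_size, hidden_size = 3, 5, 8, 6
lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)

inputs = torch.randn(batch_size, max_len, input_size)
lengths = torch.tensor([2, 5, 3])  # original (unsorted) order

# enforce_sorted=False: PyTorch handles the sort/unsort internally
packed = pack_padded_sequence(inputs, lengths, batch_first=True, enforce_sorted=False)
_, (h_n, c_n) = lstm(packed)
print(h_n.shape)  # (num_layers * 2, batch, hidden) -> torch.Size([2, 3, 6])

# Sentence representation: forward final state (after the last real token)
# concatenated with backward final state (after reading back to the first token)
sent_repr = torch.cat([h_n[-2], h_n[-1]], dim=1)
print(sent_repr.shape)  # torch.Size([3, 12])
```

Here `h_n[-2]` and `h_n[-1]` are the last layer's forward and backward states, so the same indexing keeps working if `num_layers` is increased.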

```python
# Restore lstm_hidden to the original sentence order
lstm_hidden_recover = lstm_hidden[indices_recover]
print(lstm_hidden_recover[test_sent_idx])  # [2, 100]: row 0 forward, row 1 backward
```

```
tensor([[-0.3368, -0.0240,  0.1080, -0.1191,  0.0403, -0.2525, -0.1497,
         -0.1395, -0.0139, -0.0174,  0.0452, -0.0374,  0.0151, -0.0664,
          0.1031, -0.2861,  0.0777,  0.0341,  0.2238, -0.0606,  0.1187,
         -0.0912, -0.1714, -0.2521, -0.0669, -0.0877, -0.0448,  0.1882,
          0.2066, -0.0237,  0.2246, -0.0791,  0.1213, -0.0267,  0.0449,
         -0.1667, -0.2509, -0.0138,  0.1489,  0.0431, -0.1908,  0.0314,
          0.0225,  0.0544,  0.0571,  0.0266, -0.1087,  0.0524, -0.0433,
         -0.0243,  0.0668,  0.0219, -0.1165, -0.1460, -0.1115, -0.0206,
          0.1466,  0.0416, -0.2200,  0.0825, -0.0216, -0.1909, -0.0777,
          0.2464,  0.1675, -0.0917,  0.0665, -0.0184,  0.0786,  0.0604,
          0.0275, -0.0962, -0.0110,  0.0853,  0.1375,  0.0879,  0.0941,
          0.0529, -0.0430,  0.1885, -0.0014, -0.1174,  0.0093,  0.2019,
         -0.1814,  0.0357,  0.0561,  0.1409,  0.0529,  0.1508, -0.0638,
         -0.3775, -0.0084,  0.0304, -0.1446, -0.0993, -0.1134, -
```

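The forward and backward final hidden states of the test sentence are identical (the printed tensor above is truncated).

`lstm_output` contains the BLSTM output at every time step. Using each sentence's actual length, we can also extract the outputs at the final time steps from it: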
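Instead of comparing printed tensors by eye, the check can be done with `torch.allclose`. The following self-contained sketch uses small arbitrary sizes and random token vectors. It ties the two directions' weights as above and puts a palindrome next to a longer sequence, so that padding is present. It then compares the final states with and without packing:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
input_size, hidden_size = 8, 6
lstm = nn.LSTM(input_size, hidden_size, bias=False, batch_first=True, bidirectional=True)
with torch.no_grad():  # tie the backward weights to the forward ones
    lstm.weight_ih_l0_reverse.copy_(lstm.weight_ih_l0)
    lstm.weight_hh_l0_reverse.copy_(lstm.weight_hh_l0)

# Random vectors standing in for the embeddings of "can" and "a"
can, a = torch.randn(input_size), torch.randn(input_size)
x = torch.zeros(2, 5, input_size)
x[0] = torch.randn(5, input_size)      # full-length sequence
x[1, :3] = torch.stack([can, a, can])  # palindrome followed by 2 zero padding steps
lengths = torch.tensor([5, 3])

_, (h_n, _) = lstm(pack_padded_sequence(x, lengths, batch_first=True))
packed_ok = torch.allclose(h_n[0, 1], h_n[1, 1], atol=1e-5)

# Without packing, the backward direction starts on the padding, so the states differ
_, (h_pad, _) = lstm(x)
padded_ok = torch.allclose(h_pad[0, 1], h_pad[1, 1], atol=1e-5)
print(packed_ok, padded_ok)  # expected: True False
```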

```python
# Unpack back to a padded tensor: [batch_size, max_len, 2 * hidden_size]
lstm_output_pad, output_lengths = pad_packed_sequence(lstm_output, batch_first=True)
print(lstm_output_pad.size())  # torch.Size([3, 5, 200])

# Restore the original sentence order (the batch was sorted by length above)
lstm_output_pad = lstm_output_pad[indices_recover]
output_lengths = output_lengths[indices_recover]

# Forward and backward outputs of the test sentence at their final time steps
hidden_last = lstm_output_pad[test_sent_idx][output_lengths[test_sent_idx] - 1][:hidden_size]  # forward: last real token
hidden_last_r = lstm_output_pad[test_sent_idx][0][hidden_size:]  # backward: ends at the first token
print(hidden_last)
print(hidden_last_r)
```

```
torch.Size([3, 5, 200])
tensor([-0.3368, -0.0240,  0.1080, -0.1191,  0.0403, -0.2525, -0.1497,
        -0.1395, -0.0139, -0.0174,  0.0452, -0.0374,  0.0151, -0.0664,
         0.1031, -0.2861,  0.0777,  0.0341,  0.2238, -0.0606,  0.1187,
        -0.0912, -0.1714, -0.2521, -0.0669, -0.0877, -0.0448,  0.1882,
         0.2066, -0.0237,  0.2246, -0.0791,  0.1213, -0.0267,  0.0449,
        -0.1667, -0.2509, -0.0138,  0.1489,  0.0431, -0.1908,  0.0314,
         0.0225,  0.0544,  0.0571,  0.0266, -0.1087,  0.0524, -0.0433,
        -0.0243,  0.0668,  0.0219, -0.1165, -0.1460, -0.1115, -0.0206,
         0.1466,  0.0416, -0.2200,  0.0825, -0.0216, -0.1909, -0.0777,
         0.2464,  0.1675, -0.0917,  0.0665, -0.0184,  0.0786,  0.0604,
         0.0275, -0.0962, -0.0110,  0.0853,  0.1375,  0.0879,  0.0941,
         0.0529, -0.0430,  0.1885, -0.0014, -0.1174,  0.0093,  0.2019,
        -0.1814,  0.0357,  0.0561,  0.1409,  0.0529,  0.1508, -0.0638,
        -0.3775, -0.0084,  0.0304, -0.1446, -0.0993, 
```

The extracted values equal those in `lstm_hidden`.
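The same extraction can be done for every sentence at once with advanced indexing rather than one sentence at a time. The sketch below is self-contained, with arbitrary sizes and random inputs. It takes the forward half at position `length - 1` and the backward half at position 0, then checks both against `h_n`:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
batch_size, max_len, input_size, hidden_size = 3, 5, 8, 6
lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)
x = torch.randn(batch_size, max_len, input_size)
lengths = torch.tensor([5, 3, 2])  # sorted, as pack_padded_sequence expects by default

output, (h_n, _) = lstm(pack_padded_sequence(x, lengths, batch_first=True))
out, out_lengths = pad_packed_sequence(output, batch_first=True)  # [batch, max_len, 2 * hidden]

rows = torch.arange(batch_size)
fwd_last = out[rows, out_lengths - 1, :hidden_size]  # forward: last real time step per sentence
bwd_last = out[rows, 0, hidden_size:]                # backward: final step is position 0
print(torch.allclose(fwd_last, h_n[0]), torch.allclose(bwd_last, h_n[1]))  # expected: True True
```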