In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random

# Fall back to CPU when no GPU is present so the notebook also runs on CPU-only machines.
# NOTE(review): none of the visible cells use `device`; all tensors and the RNN stay on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every RNG this notebook imports so the randomly initialised RNN weights
# (and hence the RNN outputs printed further down) are reproducible on Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
from torch.nn.utils.rnn import pad_sequence, pack_sequence, pack_padded_sequence, pad_packed_sequence

In [3]:
# Four 1-D dummy sequences of different lengths (every value is 1.0).
a, b, c, d = (torch.ones(length) for length in (14, 7, 11, 9))
data_arr = [a, b, c, d]

# batch_first=True -> (batch, max_len); shorter rows are right-padded with 0.
padded_seq = pad_sequence(data_arr, batch_first=True)
padded_seq.size()

torch.Size([4, 14])

In [4]:
# Rows shorter than the longest sequence (14) are right-padded with 0.0.
padded_seq

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]])

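- With `batch_first=True`, `pad_sequence` returns `[batch, max_seq_len, *feature_dims]`. The 1-D sequences above have no feature dimension, so `padded_seq` is `[4, 14]`; sequences shaped `[seq_len, input_dim]` give `[batch, max_seq_len, input_dim]` (next cell).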

In [5]:
# Stand-ins for word2vec-embedded sentences: each is (seq_len, 300), all ones,
# with a different seq_len per sentence (300 = embedding dimension).
a, b, c, d = (torch.ones(length, 300) for length in (14, 7, 11, 9))
data_word2vec = [a, b, c, d]

# Padded to (batch, max_len, embedding_dim).
padded_seq_word2vec = pad_sequence(data_word2vec, batch_first=True)
padded_seq_word2vec.size()

torch.Size([4, 14, 300])

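- `pack_sequence`: zero padding으로 불필요한 0이 추가되는 것이 꺼림칙할 때 사용할 수 있는 함수로, torch의 `PackedSequence` 자료구조를 반환함.
- 기본값(`enforce_sorted=True`)에서는 주어지는 input (list of Tensor)이 길이 기준 내림차순으로 정렬되어 있어야 함. `enforce_sorted=False`로 주면 내부에서 정렬한 뒤 `sorted_indices`/`unsorted_indices`에 원래 순서를 기록해 두므로, RNN의 마지막 hidden state와 `pad_packed_sequence` 결과가 원래 순서로 복원됨.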

In [6]:
# Sequence lengths in the original (unsorted) batch order.
lengths = list(map(len, data_arr))
print(lengths)

[14, 7, 11, 9]


In [7]:
# pack_sequence (with its default enforce_sorted=True) needs the batch ordered
# longest-first, so build a descending-length permutation of the batch indices.
# sorted() stays stable with reverse=True, so equal lengths keep their relative order.
sorted_idx = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
sorted_data = [data_arr[i] for i in sorted_idx]

# Show the reordered sequences, one per line.
print(*sorted_data, sep="\n")

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 1., 1., 1., 1., 1., 1.])


In [8]:
# enforce_sorted=True (the default) only checks the longest-first order produced above;
# nothing is re-sorted internally, which is why sorted_indices/unsorted_indices are None.
packed_seq = pack_sequence(sorted_data, enforce_sorted=True)
print(packed_seq)

PackedSequence(data=tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1.]), batch_sizes=tensor([4, 4, 4, 4, 4, 4, 4, 3, 3, 2, 2, 1, 1, 1]), sorted_indices=None, unsorted_indices=None)


In [9]:
# Per-sentence lengths (len of a 2-D tensor is its first dim, seq_len), original order.
lengths = list(map(len, data_word2vec))
print(lengths)

[14, 7, 11, 9]


In [10]:
# Same longest-first reordering as for data_arr, applied to the embedded sentences.
sorted_idx = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
sorted_data_w2v = [data_word2vec[i] for i in sorted_idx]

# Already sorted by hand, so keep the default enforce_sorted=True.
packed_seq_w2v = pack_sequence(sorted_data_w2v, enforce_sorted=True)

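- nn.RNN
    - input_size: The number of expected features in the input `x`
    - hidden_size: The number of features in the hidden state `h`
    - num_layers: Number of recurrent layers. Default: 1
    - nonlinearity: Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
    - batch_first: If ``True``, input and output tensors are `(batch, seq, feature)` instead of `(seq, batch, feature)`; this does not apply to the hidden state `h`. Default: ``False``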

In [11]:
# Single-layer tanh RNN mapping 300-dim inputs to a 16-dim hidden state.
# batch_first=True: inputs/outputs are (batch, seq_len, features); the hidden state is not affected.
# num_layers and nonlinearity are spelled out at their defaults to mirror the notes above.
rnn = nn.RNN(
    input_size=300,
    hidden_size=16,
    num_layers=1,
    nonlinearity="tanh",
    batch_first=True,
)

In [12]:
# Padded input: every row is stepped through all 14 time steps, padding included.
# So for the shorter rows `state` is the hidden state after the trailing zero padding,
# not after the row's last real token.
outputs, state = rnn(padded_seq_word2vec)
print(outputs.size())  # (batch, seq_len, num_directions * hidden_size)
print(state.size())    # (num_layers * num_directions, batch, hidden_size)

torch.Size([4, 14, 16])
torch.Size([1, 4, 16])


In [13]:
# Packed input: each sequence is stepped only through its real length (14, 11, 9, 7),
# so no padding enters the recurrence. hx=None -> zero initial hidden state (the default).
outputs, state = rnn(packed_seq_w2v, hx=None)

In [14]:
# `outputs` is a PackedSequence. .data stacks every real time step of every sequence
# (14 + 11 + 9 + 7 = 41 rows of 16 features). .batch_sizes gives how many sequences are
# still active at each step; this is what is needed to unpack it back to padded form.
print(outputs.data.shape)
print(outputs.batch_sizes)

torch.Size([41, 16])
tensor([4, 4, 4, 4, 4, 4, 4, 3, 3, 2, 2, 1, 1, 1])


In [15]:
# The final hidden state needs no unpacking: for packed input it already holds each
# sequence's state at its own last real step, so use it directly when only the last
# state matters.
# Caveat: the batch was sorted by hand (sorted_indices=None), so its batch axis follows
# the sorted order (lengths 14, 11, 9, 7), not the original data_word2vec order.
state.size()

torch.Size([1, 4, 16])

In [16]:
# Unpack to a zero-padded (batch, max_len, hidden) tensor plus each row's real length.
output_unpacked, output_lengths = pad_packed_sequence(
    outputs,
    batch_first=True,
    padding_value=0.0,
)

In [17]:
# (batch, max_len, hidden_size), back to the same layout as the padded-input run.
output_unpacked.size()

torch.Size([4, 14, 16])

In [18]:
# Real length of each unpacked row, in the hand-sorted batch order ([14, 11, 9, 7]);
# the original data_word2vec order was [14, 7, 11, 9].
output_lengths

tensor([14, 11,  9,  7])

In [19]:
# Row 2 of the sorted batch is the length-9 sentence (originally data_word2vec[3]).
# Its time steps 9..13 are zero padding inserted by pad_packed_sequence.
output_unpacked[2]

tensor([[ 1.0000,  0.9871, -0.2388, -0.0806, -0.4983, -0.9996,  0.9938, -0.6307,
         -0.9868, -0.6814,  0.9936, -0.9493,  0.9176,  0.9459,  0.8890,  0.8931],
        [ 1.0000,  0.9836, -0.2090,  0.0439, -0.9148, -0.9993,  0.9932, -0.6368,
         -0.9805, -0.9291,  0.9959, -0.9731,  0.8812,  0.9570,  0.8930,  0.7768],
        [ 1.0000,  0.9829, -0.1926,  0.0745, -0.9257, -0.9992,  0.9928, -0.6471,
         -0.9819, -0.9195,  0.9953, -0.9686,  0.8698,  0.9514,  0.9169,  0.8340],
        [ 1.0000,  0.9827, -0.1981,  0.0958, -0.9243, -0.9992,  0.9928, -0.6563,
         -0.9819, -0.9197,  0.9953, -0.9683,  0.8657,  0.9509,  0.9174,  0.8303],
        [ 1.0000,  0.9826, -0.2028,  0.1013, -0.9240, -0.9992,  0.9927, -0.6519,
         -0.9816, -0.9195,  0.9953, -0.9680,  0.8650,  0.9506,  0.9176,  0.8295],
        [ 1.0000,  0.9826, -0.2054,  0.1011, -0.9236, -0.9993,  0.9928, -0.6506,
         -0.9816, -0.9192,  0.9953, -0.9680,  0.8647,  0.9503,  0.9179,  0.8298],
        [ 1.0000,  0.9