In [1]:
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.nn.utils import rnn
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_x = [torch.tensor([1, 2, 3, 4, 5, 6, 7]),
           torch.tensor([2, 3, 4, 5, 6, 7]),
           torch.tensor([3, 4, 5, 6, 7]),
           torch.tensor([4, 5, 6, 7]),
           torch.tensor([5, 6, 7]),
           torch.tensor([6, 7]),
           torch.tensor([7])]

In [3]:
def collate_fn(train_data):
    train_data.sort(key=lambda data: len(data), reverse=True)
    data_length = [len(data) for data in train_data]
    train_data = rnn.pad_sequence(train_data, batch_first=True, padding_value=0)
    return train_data, data_length

In [4]:
class MyData(Dataset):
    def __init__(self, train_x):
        self.train_x = train_x

    def __len__(self):
        return len(self.train_x)

    def __getitem__(self, item):
        return self.train_x[item].type(torch.float32)
        
train_data = MyData(train_x)
train_dataloader = DataLoader(train_data, batch_size=2, collate_fn=collate_fn)

In [5]:
for data, length in train_dataloader:
    print(data)
    print(length)

tensor([[1., 2., 3., 4., 5., 6., 7.],
        [2., 3., 4., 5., 6., 7., 0.]])
[7, 6]
tensor([[3., 4., 5., 6., 7.],
        [4., 5., 6., 7., 0.]])
[5, 4]
tensor([[5., 6., 7.],
        [6., 7., 0.]])
[3, 2]
tensor([[7.]])
[1]


In [6]:
for data, length in train_dataloader:
    data = rnn.pack_padded_sequence(data, length, batch_first=True)
    print(data)

PackedSequence(data=tensor([1., 2., 2., 3., 3., 4., 4., 5., 5., 6., 6., 7., 7.]), batch_sizes=tensor([2, 2, 2, 2, 2, 2, 1]), sorted_indices=None, unsorted_indices=None)
PackedSequence(data=tensor([3., 4., 4., 5., 5., 6., 6., 7., 7.]), batch_sizes=tensor([2, 2, 2, 2, 1]), sorted_indices=None, unsorted_indices=None)
PackedSequence(data=tensor([5., 6., 6., 7., 7.]), batch_sizes=tensor([2, 2, 1]), sorted_indices=None, unsorted_indices=None)
PackedSequence(data=tensor([7.]), batch_sizes=tensor([1]), sorted_indices=None, unsorted_indices=None)


In [7]:
net = torch.nn.LSTM(1, 5, batch_first=True)

In [8]:
def collate_fn(train_data):
    train_data.sort(key=lambda data: len(data), reverse=True)
    data_length = [len(data) for data in train_data]
    train_data = rnn.pad_sequence(train_data, batch_first=True, padding_value=0)
    return train_data.unsqueeze(-1), data_length  # 对train_data增加了一维数据

In [9]:
net

LSTM(1, 5, batch_first=True)

In [10]:
net.all_weights[0][0].dtype

torch.float32

In [12]:
train_data = MyData(train_x)
train_dataloader = DataLoader(train_data, batch_size=2, collate_fn=collate_fn)

flag = 0
for data, length in train_dataloader:
    data = rnn.pack_padded_sequence(data, length, batch_first=True)
    # print(data.data.shape)
    output, hidden = net(data)
    if flag == 0:
        output, out_len = rnn.pad_packed_sequence(output, batch_first=True)
        print(output)
        print(output.shape)
        flag = 1

tensor([[[ 4.6671e-02,  1.1022e-01,  4.1483e-02, -2.7838e-02, -2.2109e-01],
         [ 9.1933e-02,  2.5600e-01,  1.0339e-01, -1.1843e-04, -4.0936e-01],
         [ 1.1548e-01,  3.5329e-01,  1.1301e-01,  4.9575e-02, -5.3543e-01],
         [ 1.1759e-01,  4.0443e-01,  8.6840e-02,  9.6579e-02, -6.2591e-01],
         [ 1.0640e-01,  4.3535e-01,  5.7716e-02,  1.2702e-01, -6.9590e-01],
         [ 8.9529e-02,  4.5920e-01,  3.6772e-02,  1.3763e-01, -7.4985e-01],
         [ 7.1837e-02,  4.7978e-01,  2.3306e-02,  1.3230e-01, -7.9073e-01]],

        [[ 7.8933e-02,  2.0046e-01,  7.6023e-02,  2.6744e-03, -3.1016e-01],
         [ 1.1227e-01,  3.3568e-01,  1.0565e-01,  4.5169e-02, -4.9430e-01],
         [ 1.1674e-01,  3.9973e-01,  8.6080e-02,  9.2831e-02, -6.0613e-01],
         [ 1.0612e-01,  4.3384e-01,  5.7799e-02,  1.2490e-01, -6.8535e-01],
         [ 8.9417e-02,  4.5855e-01,  3.6828e-02,  1.3659e-01, -7.4397e-01],
         [ 7.1787e-02,  4.7945e-01,  2.3325e-02,  1.3184e-01, -7.8733e-01],
         [

In [None]:
output.shape

AttributeError: 'PackedSequence' object has no attribute 'shape'