In [22]:
import torch
from torch.nn.utils.rnn import pad_sequence
import spacy
from torchtext.datasets import IMDB
from torchtext.vocab import build_vocab_from_iterator
import torch.utils.data as Data

In [23]:
test_iter = IMDB(split='test')
spacy_en = spacy.load('en_core_web_sm')


def yield_tokens(data_iter):
    for _, text in data_iter:
        yield [tok.text for tok in spacy_en.tokenizer(text)]  # 分词


vocab = build_vocab_from_iterator(yield_tokens(test_iter))
vocab.insert_token("<unk>", 0)
vocab.insert_token("<pad>", 1)
vocab.insert_token("<SOS>", 2)
vocab.insert_token("<EOS>", 3)
vocab.set_default_index(0)

In [24]:
# 文本内容转换为数字
text_transform = lambda x: [vocab['<SOS>']] + [vocab[token] for token in
                                               [tok.text for tok in spacy_en.tokenizer(x)]] + [vocab['<EOS>']]

# 文本标签转换为数字
label_transform = lambda x: 1.0 if x == 'pos' else 0.0


def collate_batch(batch):
    """
    对文本标签和文本内容进行处理使之可以用于pack_padded_sequence操作
    Parameters
    ---------
    batch : 每个batch数据

    Returns
    label_tensor : 每个batch数据文本标签的处理输出
    text_pad : 每个batch数据文本内容的处理输出
    lengths : 每个batch数据文本内容的长度
    -------
    """
    label_list, text_list, lengths = [], [], []
    for (_label, _text) in batch:
        label_list.append(label_transform(_label))
        processed_text = torch.tensor(text_transform(_text))
        lengths.append(len(processed_text))
        text_list.append(processed_text)
    label_tensor = torch.tensor(label_list)
    text_pad = pad_sequence(text_list, batch_first=True, padding_value=0)
    lengths = torch.tensor(lengths)
    return label_tensor, text_pad, lengths

In [25]:
test_iter = IMDB(split='test')
test_dataloader = Data.DataLoader(test_iter, batch_size=128, shuffle=False, collate_fn=collate_batch)

In [26]:
for i in test_dataloader:
    print(i)
    break

(tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.]), tensor([[   2,   13,  139,  ...,    0,    0,    0],
        [   2, 5632,    4,  ...,    0,    0,    0],
        [   2,  117,    7,  ...,    0,    0,    0],
        ...,
        [   2,   57,  215,  ...,    0,    0,    0],
        [   2,   69,   11,  ...,    0,    0,    0],
        [   2, 1655,  209,  ...,    0,    0,    0]]), tensor([ 305,  246,  147,  423,  153,  208,  333,  198,  170,  190,  152,  221,
         139,  472,  152,  224