In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext
from torchtext.datasets import IMDB
# Install command: pip install torchtext
from torchtext.datasets.imdb import NUM_LINES
from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.functional import to_map_style_dataset

import sys
import os
import logging
# Configure the root logger: only WARNING and above are emitted, records go to
# stdout, and each record carries a timestamp, the module name with line
# number, the level name and the message.
logging.basicConfig(
    level=logging.WARN,
    stream=sys.stdout,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)

In [5]:

# Step 2: build the IMDB DataLoader

# Mini-batch size handed to the DataLoader created later in this notebook.
BATCH_SIZE = 64

def yield_tokens(train_data_iter, tokenizer):
    """Lazily tokenize the text part of every sample.

    Args:
        train_data_iter: Iterable of two-item ``(label, comment)`` samples.
        tokenizer: Callable applied to each ``comment``.

    Yields:
        The result of ``tokenizer(comment)`` for each sample, in iteration
        order. The label is ignored.
    """
    for _label, comment in train_data_iter:
        yield tokenizer(comment)

# Load the IMDB training split (data is stored under ./.data); it is consumed
# below to build the vocabulary.
train_data_iter = IMDB(root='.data', split='train') # an object of type Dataset
tokenizer = get_tokenizer("basic_english")
# Build the vocabulary from the tokenized training comments, dropping tokens
# that occur fewer than 20 times. "<unk>" is the only special token.
# NOTE(review): presumably "<unk>" ends up at index 0 (specials first) — confirm
# against the torchtext documentation.
vocab = build_vocab_from_iterator(yield_tokens(train_data_iter, tokenizer), min_freq=20, specials=["<unk>"])
# Tokens that are not in the vocabulary are looked up as index 0.
vocab.set_default_index(0)
# Prints the vocabulary size (the message text is in Chinese).
print(f"单词表大小: {len(vocab)}")

单词表大小: 13351


In [15]:

def collate_fn(batch):
    """Post-process a mini-batch produced by the DataLoader.

    Every comment is tokenized with the module-level ``tokenizer`` and mapped
    to indices with the module-level ``vocab``. Sequences shorter than the
    longest one in the batch are right-padded with 0.

    Args:
        batch: Iterable of ``(label, comment)`` samples.

    Returns:
        A tuple ``(labels, token_indices)`` where ``labels`` is an int64
        tensor with one entry per sample and ``token_indices`` is an int32
        tensor with one padded row per sample.

    NOTE(review): the pad value 0 is also the index configured through
    ``vocab.set_default_index(0)``, so padding and out-of-vocabulary tokens
    share one index — confirm this is intended.
    """
    labels, encoded = [], []
    for label, comment in batch:
        labels.append(label)
        encoded.append(vocab(tokenizer(comment)))

    # Length of the longest sequence in this batch (0 for an empty batch).
    longest = max(map(len, encoded), default=0)
    padded = [ids + [0] * (longest - len(ids)) for ids in encoded]

    label_tensor = torch.tensor(labels).to(torch.int64)
    token_tensor = torch.tensor(padded).to(torch.int32)
    return (label_tensor, token_tensor)


In [18]:
# Count how many training samples carry each label value.
labels = {}
for label, text in train_data_iter:
    # Missing labels start from 0, so the first occurrence is recorded as 1.
    labels[label] = labels.get(label, 0) + 1
print(f"labels: {labels}")

labels: {1: 12500, 2: 12500}


In [16]:
# Re-create the training split iterator before wrapping it in a DataLoader.
train_data_iter = IMDB(root='.data', split='train') # an object of type Dataset
# Convert the iterator to a map-style dataset and batch it with collate_fn,
# which pads each batch to a common length.
# NOTE(review): the map-style conversion is presumably needed because
# shuffle=True requires index-based access — confirm against the DataLoader docs.
train_data_loader = torch.utils.data.DataLoader(to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)



In [13]:
# Number of mini-batches the loader yields in one pass over the dataset.
print(len(train_data_loader))

391


In [17]:
# Sanity check: print the first 11 mini-batches, then stop.
i = 0
for batch in train_data_loader:
    if i >= 11:
        break
    print(batch)
    i += 1

(tensor([2, 1, 2, 1, 1, 2, 2, 1, 1, 2, 2, 2, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 1,
        2, 1, 1, 2, 1, 1, 2, 1, 2, 2, 1, 2, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, 1, 2,
        1, 1, 2, 1, 2, 2, 2, 1, 1, 2, 1, 2, 1, 2, 2, 2]), tensor([[  12,  437,    7,  ...,    0,    0,    0],
        [ 379,   28,   63,  ...,    0,    0,    0],
        [ 412,   12,  192,  ...,    0,    0,    0],
        ...,
        [  59,   12,   16,  ...,    0,    0,    0],
        [1181,   42,    5,  ...,    0,    0,    0],
        [  59,   12,   16,  ...,    0,    0,    0]], dtype=torch.int32))
(tensor([1, 2, 1, 1, 1, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2,
        1, 1, 1, 1, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 1, 2, 2, 1, 1, 1, 1, 1, 2,
        1, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([[  12,   48,  380,  ...,    0,    0,    0],
        [  12,  124,   36,  ...,    0,    0,    0],
        [ 522,  275,   19,  ...,    0,    0,    0],
        ...,
        [  12,   71,   44,  ...,    0,    0,   