In [1]:
import torch
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torch.utils.data import DataLoader
from torch import nn
import time
import pandas as pd
from torch.utils.data.dataset import random_split

In [2]:
# Load the AG_NEWS training split and materialise every item into a DataFrame
# so the raw data can be inspected.
train_iter = AG_NEWS(split='train')
# Each item is spread over two columns, named 'class' and 'text'.
df = pd.DataFrame(list(train_iter), columns=['class', 'text'])
df

Unnamed: 0,class,text
0,3,Wall St. Bears Claw Back Into the Black (Reute...
1,3,Carlyle Looks Toward Commercial Aerospace (Reu...
2,3,Oil and Economy Cloud Stocks' Outlook (Reuters...
3,3,Iraq Halts Oil Exports from Main Southern Pipe...
4,3,"Oil prices soar to all-time record, posing new..."
...,...,...
119995,1,Pakistan's Musharraf Says Won't Quit as Army C...
119996,2,Renteria signing a top-shelf deal Red Sox gene...
119997,2,Saban not going to Dolphins yet The Miami Dolp...
119998,2,Today's NFL games PITTSBURGH at NY GIANTS Time...


In [3]:
# The first column of `df` is the 'class' column; select it by label.
df_class = df['class']
df_class.unique()  # four distinct classes -> 4-way classification

array([3, 4, 2, 1], dtype=int64)

In [4]:
# Re-create the training-split iterator used below to build the vocabulary.
train_iter = AG_NEWS(split='train')

# get_tokenizer's `tokenizer` argument is the name of the tokenizer function.
# If None, it returns the split() function, which splits the sentence string by spaces.
# If 'basic_english', it returns the _basic_english_normalize() function, which
# normalizes the string first and then splits it by spaces.
tokenizer = get_tokenizer(tokenizer='basic_english')


def yield_tokens(data_iter):
    """Lazily tokenize the text of every ``(label, text)`` pair in ``data_iter``.

    The label is ignored; each text is passed through the notebook-level
    ``tokenizer`` and the resulting token list is yielded.
    """
    yield from (tokenizer(sample_text) for _label, sample_text in data_iter)


# Build a Vocab from the token lists produced by tokenizing the training split.
vocab = build_vocab_from_iterator(yield_tokens(train_iter))  # Build a Vocab from an iterator.
# Insert the "<unk>" token at index 0 ...
vocab.insert_token("<unk>", 0)
# ... and make index 0 the default index (presumably the index returned for
# tokens missing from the vocabulary -- confirm against the torchtext Vocab docs).
vocab.set_default_index(0)

In [5]:
def text_pipeline(x):
    """Convert a raw text string into a list of vocabulary indices.

    Args:
        x: the raw text to encode.

    Returns:
        A list of ints, one per token produced by ``tokenizer(x)``.
    """
    # A single batched lookup over the whole token list replaces the previous
    # one-element ``vocab([token])[0]`` call made separately for every token.
    return vocab(tokenizer(x))


def label_pipeline(x):
    """Convert a dataset label to a zero-based class index.

    Args:
        x: the label as an int or anything ``int()`` accepts (the dataset
            labels seen in this notebook are 1..4).

    Returns:
        ``int(x) - 1``, so that class labels start from 0.
    """
    return int(x) - 1

In [6]:
# The vocabulary block converts a list of tokens into integers.
text_pipeline('here is the an example')

[475, 21, 2, 30, 5297]

In [8]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [9]:
def collate_batch(batch):
    """Collate ``(label, text)`` samples into the tensors the model consumes.

    Every text is encoded with ``text_pipeline`` and all encodings are
    concatenated into one flat 1-D tensor; ``offsets`` records the index in
    that flat tensor at which each sample starts.

    Args:
        batch: an iterable of ``(label, text)`` pairs.

    Returns:
        A ``(labels, tokens, offsets)`` tuple of int64 tensors, each moved to
        the notebook-level ``device``.
    """
    class_ids = []     # zero-based class index per sample
    encoded = []       # one 1-D token-id tensor per sample
    bag_sizes = [0]    # leading 0, followed by the length of every sample
    for raw_label, raw_text in batch:
        class_ids.append(label_pipeline(raw_label))
        token_ids = torch.tensor(text_pipeline(raw_text), dtype=torch.int64)
        encoded.append(token_ids)
        bag_sizes.append(token_ids.size(0))

    labels = torch.tensor(class_ids, dtype=torch.int64)
    # Dropping the last length and taking the running sum turns the sizes
    # into the start position of each sample in the flat token tensor.
    starts = torch.tensor(bag_sizes[:-1]).cumsum(dim=0)
    tokens = torch.cat(encoded)
    return labels.to(device), tokens.to(device), starts.to(device)


# Small, unshuffled demo loader over the training split, used to inspect what
# collate_batch produces for one batch of 8 samples.
train_iter = AG_NEWS(split='train')
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

In [10]:
# Show the first collated batch only: labels, flat token ids, then offsets.
for batch_labels, batch_tokens, batch_offsets in dataloader:
    for part in (batch_labels, batch_tokens, batch_offsets):
        print(part)
    break

tensor([2, 2, 2, 2, 2, 2, 2, 2], device='cuda:0')
tensor([  431,   425,     1,  1605, 14838,   113,    66,     2,   848,    13,
           27,    14,    27,    15, 50725,     3,   431,   374,    16,     9,
        67507,     6, 52258,     3,    42,  4009,   783,   325,     1, 15874,
         1072,   854,  1310,  4250,    13,    27,    14,    27,    15,   929,
          797,   320, 15874,    98,     3, 27657,    28,     5,  4459,    11,
          564, 52790,     8, 80617,  2125,     7,     2,   525,   241,     3,
           28,  3890, 82814,  6574,    10,   206,   359,     6,     2,   126,
            1,    58,     8,   347,  4582,   151,    16,   738,    13,    27,
           14,    27,    15,  2384,   452,    92,  2059, 27360,     2,   347,
            8,     2,   738,    11,   271,    42,   240, 51953,    38,     2,
          294,   126,   112,    85,   220,     2,  7856,     6, 40066, 15380,
            1,    70,  7376,    58,  1810,    29,   905,   537,  2846,    13,
           27,

In [11]:
class TextClassificationModel(nn.Module):
    """Bag-of-embeddings text classifier.

    An ``nn.EmbeddingBag`` (sparse gradients, default pooling) reduces each
    text to a single ``embed_dim`` vector, which one linear layer maps to
    ``num_class`` scores.
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        """Create the embedding bag and the linear head, then initialise them.

        Args:
            vocab_size: number of rows in the embedding table.
            embed_dim: size of each embedding vector.
            num_class: number of output scores per text.
        """
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Draw both weight matrices from U(-0.5, 0.5) and zero the bias."""
        bound = 0.5
        nn.init.uniform_(self.embedding.weight, -bound, bound)
        nn.init.uniform_(self.fc.weight, -bound, bound)
        nn.init.zeros_(self.fc.bias)

    def forward(self, text, offsets):
        """Score a batch of texts.

        Args:
            text: flat tensor of token indices for all texts in the batch.
            offsets: start index of each text inside ``text``.

        Returns:
            The output of the linear layer applied to the pooled embeddings.
        """
        return self.fc(self.embedding(text, offsets))

In [12]:
# Model dimensions: embedding width, vocabulary size and number of classes.
emsize = 64
vocab_size = len(vocab)
train_iter = AG_NEWS(split='train')
# Count the distinct labels that occur in the training split.
num_class = len({label for label, _text in train_iter})
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)

In [13]:
# Hyperparameters
EPOCHS = 10  # epoch
LR = 5  # learning rate
BATCH_SIZE = 64  # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

In [14]:
def train(dataloader, epoch=None):
    """Train the notebook-level ``model`` for one pass over ``dataloader``.

    Uses the notebook-level ``optimizer`` and ``criterion``. Gradients are
    clipped to a norm of 0.1 before every optimizer step. A progress line
    with the running accuracy is printed every 500 batches, after which the
    running counters are reset.

    Args:
        dataloader: yields ``(label, text, offsets)`` batches and supports
            ``len()``.
        epoch: epoch number shown in the progress line. When omitted, the
            notebook-level ``epoch`` loop variable is used if it exists,
            otherwise 0. Previously the function read the global
            unconditionally and raised ``NameError`` at the first log line
            when it was called outside the epoch loop.
    """
    if epoch is None:
        # Backward compatible with call sites that rely on the global loop variable.
        epoch = globals().get('epoch', 0)

    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                              total_acc / total_count))
            # Accuracy is reported per logging window, not cumulatively.
            total_acc, total_count = 0, 0


def evaluate(dataloader):
    """Return the accuracy of the notebook-level ``model`` on ``dataloader``.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        dataloader: yields ``(label, text, offsets)`` batches.

    Returns:
        Correct predictions divided by the number of samples seen.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no samples.
    """
    model.eval()
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for label, text, offsets in dataloader:
            scores = model(text, offsets)
            hits = scores.argmax(1) == label
            n_correct += hits.sum().item()
            n_seen += label.size(0)
    return n_correct / n_seen

In [15]:
def to_map_style_dataset(iter_data):
    """Wrap an iterable-style dataset in a map-style ``Dataset``.

    All items of ``iter_data`` are read eagerly into a list, so the result
    supports ``len()`` and integer indexing.

    Args:
        iter_data: any iterable of samples.

    Returns:
        A ``torch.utils.data.Dataset`` holding the materialised samples.
    """

    class _MapStyleDataset(torch.utils.data.Dataset):
        """List-backed dataset over the materialised samples."""

        def __init__(self, source):
            # TODO Avoid list issue #1296
            self._data = [*source]

        def __len__(self):
            return len(self._data)

        def __getitem__(self, idx):
            return self._data[idx]

    wrapped = _MapStyleDataset(iter_data)
    return wrapped

In [16]:
# Reference validation accuracy for the epoch loop; None until the first epoch ends.
total_accu = None

# Materialise both splits so they support len() and random access.
train_iter, test_iter = AG_NEWS()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)

# Keep 95% of the training data for training; the remainder is for validation.
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(
    train_dataset,
    [num_train, len(train_dataset) - num_train],
)

# All three loaders share the same batching options.
_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
train_dataloader = DataLoader(split_train_, **_loader_kwargs)  # training set
valid_dataloader = DataLoader(split_valid_, **_loader_kwargs)  # validation set
test_dataloader = DataLoader(test_dataset, **_loader_kwargs)   # test set

# Train for EPOCHS epochs. After each one, measure validation accuracy; if it
# fell below the stored reference value, advance the LR scheduler, otherwise
# store the new accuracy as the reference.
separator = '-' * 59
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    accu_val = evaluate(valid_dataloader)
    if total_accu is None or total_accu <= accu_val:
        total_accu = accu_val
    else:
        scheduler.step()
    print(separator)
    elapsed = time.time() - epoch_start_time
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid accuracy {:8.3f} '.format(epoch, elapsed, accu_val))
    print(separator)

| epoch   1 |   500/ 1782 batches | accuracy    0.687
| epoch   1 |  1000/ 1782 batches | accuracy    0.853
| epoch   1 |  1500/ 1782 batches | accuracy    0.876
-----------------------------------------------------------
| end of epoch   1 | time: 21.80s | valid accuracy    0.884 
-----------------------------------------------------------
| epoch   2 |   500/ 1782 batches | accuracy    0.897
| epoch   2 |  1000/ 1782 batches | accuracy    0.902
| epoch   2 |  1500/ 1782 batches | accuracy    0.901
-----------------------------------------------------------
| end of epoch   2 | time: 18.92s | valid accuracy    0.893 
-----------------------------------------------------------
| epoch   3 |   500/ 1782 batches | accuracy    0.916
| epoch   3 |  1000/ 1782 batches | accuracy    0.913
| epoch   3 |  1500/ 1782 batches | accuracy    0.914
-----------------------------------------------------------
| end of epoch   3 | time: 18.50s | valid accuracy    0.904 
-------------------------------

In [17]:
# Final check: accuracy of the trained model on the test loader.
print('Checking the results of test dataset.')
accu_test = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(accu_test))

Checking the results of test dataset.
test accuracy    0.909


In [18]:
# Human-readable name for each 1-based AG_NEWS class id.
ag_news_label = dict(enumerate(("World", "Sports", "Business", "Sci/Tec"), start=1))


def predict(text, text_pipeline):
    """Classify one text with the notebook-level ``model``.

    Args:
        text: the raw text to classify.
        text_pipeline: callable that turns ``text`` into a list of token indices.

    Returns:
        The predicted class id, shifted back by +1 so it matches the 1-based
        keys of ``ag_news_label``.
    """
    # Build the inputs on whatever device the model currently lives on, so the
    # function works for both CPU and CUDA models.
    target_device = next(model.parameters()).device
    with torch.no_grad():
        # An explicit int64 dtype keeps an empty token list from being
        # inferred as a float tensor, which the embedding layer rejects.
        token_ids = torch.tensor(text_pipeline(text), dtype=torch.int64,
                                 device=target_device)
        # The whole text is a single bag starting at position 0.
        offsets = torch.tensor([0], dtype=torch.int64, device=target_device)
        output = model(token_ids, offsets)
        return output.argmax(1).item() + 1


# Example article used to demonstrate prediction. The backslash line
# continuations sit inside the string literal, so the leading spaces of each
# continued line are part of the text.
ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
    enduring the season’s worst weather conditions on Sunday at The \
    Open on his way to a closing 75 at Royal Portrush, which \
    considering the wind and the rain was a respectable showing. \
    Thursday’s first round at the WGC-FedEx St. Jude Invitational \
    was another story. With temperatures in the mid-80s and hardly any \
    wind, the Spaniard was 13 strokes better in a flawless round. \
    Thanks to his best putting performance on the PGA Tour, Rahm \
    finished with an 8-under 62 for a three-stroke lead, which \
    was even more impressive considering he’d never played the \
    front nine at TPC Southwind."

# Run the demo on the CPU: move the model there before predicting.
model = model.to("cpu")

# Map the predicted class id to its human-readable name.
print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])

This is a Sports news
