In [201]:
import torch
from torchtext.datasets import AG_NEWS

# AG_NEWS yields (label, text) pairs; labels are the 1-based class ids.
train_iter = AG_NEWS(split='train')

In [202]:
# Peek at the first ten samples. As before, the eleventh item is pulled from
# the iterator before the loop stops.
for index, item in enumerate(train_iter):
    if index >= 10:
        break
    print(item)

(3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")
(3, 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.')
(3, "Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums.")
(3, 'Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\\flows from the main pipeline in southern Iraq after\\intelligence showed a rebel militia could strike\\infrastructure, an oil official said on Saturday.')
(3, 'Oil prices soar to all-time record, posing new menace t

In [203]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [204]:
# Use a fresh iterator rather than the one already walked in the preview cell.
train_iter = AG_NEWS(split='train')
tokenizer = get_tokenizer('basic_english')

In [205]:
def yield_tokens(data_iter):
    """Lazily yield the token list of every (label, text) sample in `data_iter`.

    The label is ignored; each text is passed through the module-level
    `tokenizer`.
    """
    yield from (tokenizer(sample_text) for _label, sample_text in data_iter)

In [206]:
vocab = build_vocab_from_iterator(
    yield_tokens(train_iter),
    specials=['<unk>', '<mask>', '<cls>'],
)
# Tokens missing from the vocabulary map to the '<unk>' index.
vocab.set_default_index(vocab['<unk>'])

In [207]:
# Index of the '<unk>' special token.
vocab.get_stoi()['<unk>']

0

In [208]:
vocab.lookup_indices(['<mask>'])

[1]

In [209]:
vocab.lookup_indices(['<cls>', '<mask>'])

[2, 1]

In [210]:
vocab.lookup_indices('here is an example'.split())

[477, 23, 32, 5299]

In [211]:
def text_pipeline(x):
    """Tokenize raw text `x` and map each token to its vocabulary index."""
    return vocab(tokenizer(x))


def label_pipeline(x):
    """Convert a 1-based label (int or numeric string) to a 0-based int."""
    return int(x) - 1

In [212]:
# Token ids for a short sample sentence.
text_pipeline("here is the an example")

[477, 23, 4, 32, 5299]

In [213]:
# Labels are shifted from 1-based to 0-based.
label_pipeline("12")

11

In [214]:
from torch.utils.data import DataLoader
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [215]:
def collate_batch(batch):
    """Collate (label, text) pairs into the flat layout EmbeddingBag expects.

    Returns a tuple (labels, tokens, offsets), each moved to `device`:
      labels  -- int64 tensor of 0-based class ids, one per sample
      tokens  -- the token ids of all samples concatenated into one 1-D tensor
      offsets -- start index of each sample inside `tokens`, one per sample
    """
    labels = []
    token_tensors = []
    lengths = [0]
    for raw_label, raw_text in batch:
        labels.append(label_pipeline(raw_label))
        ids = torch.tensor(text_pipeline(raw_text), dtype = torch.int64)
        token_tensors.append(ids)
        lengths.append(ids.size(0))
    # Cumulative sum of [0, len_0, ..., len_{n-2}] gives each sample's start
    # position. The last length is dropped because no sample starts after it.
    offsets = torch.tensor(lengths[:-1]).cumsum(dim = 0)
    label_tensor = torch.tensor(labels, dtype = torch.int64)
    tokens = torch.cat(token_tensors)
    return label_tensor.to(device), tokens.to(device), offsets.to(device)

In [216]:
# Small integer sequence used to illustrate cumsum below.
test = torch.arange(0, 9)

In [217]:
# Display the input tensor before taking its cumulative sum.
test

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

In [218]:
# Same running-sum operation collate_batch applies to sample lengths.
torch.cumsum(test, dim=0)

tensor([ 0,  1,  3,  6, 10, 15, 21, 28, 36])

In [219]:
train_iter = AG_NEWS(split='train')
dataloader = DataLoader(
    train_iter,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_batch,
)

In [250]:
# Inspect only the first batch.
for index, data in enumerate(dataloader):
    if index >= 1:
        break
    labels, texts, offset = data
    print(texts.size(0))   # tokens concatenated across the whole batch
    print(offset.size(0))  # one offset per sample in the batch

224
8


In [221]:
import torch
import torch.nn as nn

In [222]:
class TextClassificationModel(nn.Module):
    """Bag-of-embeddings classifier: an EmbeddingBag followed by one Linear layer.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: embedding width, also the Linear layer's input size.
        num_class: number of output scores produced per sample.
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # sparse=True makes the embedding table receive sparse gradients.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Draw both weight matrices from U(-0.5, 0.5) and zero the fc bias."""
        bound = 0.5
        nn.init.uniform_(self.embedding.weight, -bound, bound)
        nn.init.uniform_(self.fc.weight, -bound, bound)
        nn.init.zeros_(self.fc.bias)

    def forward(self, text, offsets):
        """Return one row of class scores per bag.

        `text` holds the flat token ids of every sample, and `offsets` holds the
        start index of each sample inside it.
        """
        return self.fc(self.embedding(text, offsets))

In [223]:
train_iter = AG_NEWS(split='train')
# Number of distinct labels present in the training split.
num_class = len({label for label, _text in train_iter})
vocab_size = len(vocab)
emsize = 64
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)

In [224]:
from torchsummary import summary

In [255]:
# torchsummary.summary cannot drive this model. It feeds batched (2-D) dummy
# inputs, but EmbeddingBag called with offsets needs a flat 1-D token tensor,
# so the call raised "if input is 2D, then offsets has to be None".
# Report the layers and the parameter count directly instead.
print(model)
print('trainable parameters:', sum(p.numel() for p in model.parameters() if p.requires_grad))

In [226]:
import time

def train(dataloader, epoch=None, log_interval=500):
    """Run one training pass over `dataloader`, updating the global `model`.

    Uses the notebook-level `model`, `optimizer` and `criterion`. Every
    `log_interval` batches it prints the accuracy accumulated since the last
    progress line, plus the wall-clock seconds that window took, then resets
    both counters.

    Args:
        dataloader: iterable of (label, text, offsets) batches, e.g. built
            with `collate_batch`. `len(dataloader)` appears in the log line.
        epoch: epoch number shown in the log. When omitted, it falls back to
            the notebook-level `epoch` variable (the value the original code
            read implicitly), or 0 if none exists.
        log_interval: number of batches between progress lines.
    """
    if epoch is None:
        # Backward compatibility: the training loop calls train(dataloader)
        # and relies on the global loop variable for the epoch label.
        epoch = globals().get('epoch', 0)
    model.train()
    total_acc, total_count = 0, 0
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        # One row of unnormalised class scores (logits) per sample;
        # CrossEntropyLoss applies the softmax internally.
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        # Clip the total gradient norm to 0.1 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d} / {:5d} batches | accuracy {:8.3f} '
                  '| {:5.2f}s'.format(epoch, idx, len(dataloader),
                                      total_acc / total_count, elapsed))
            total_acc, total_count = 0, 0
            start_time = time.time()

In [227]:
def evaluate(dataloader):
    """Return the classification accuracy of the global `model` on `dataloader`.

    Puts the model in eval mode and runs it under `torch.no_grad()`. Each batch
    is a (label, text, offsets) triple as produced by `collate_batch`.

    Raises:
        ValueError: if `dataloader` yields no samples (accuracy is undefined).
    """
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for label, text, offsets in dataloader:
            predicted_label = model(text, offsets)
            # Only accuracy is reported, so no loss is computed here.
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    if total_count == 0:
        raise ValueError('evaluate() received a dataloader with no samples')
    return total_acc / total_count

In [228]:
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

In [229]:
EPOCHS = 10
LR = 2
BATCH_SIZE = 64
SEED = 42  # fixes the train/validation split so results are reproducible

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = LR)
# Each scheduler.step() multiplies the learning rate by 0.1.
# step_size is an integer count of steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 1, gamma = 0.1)
total_accu = None
train_iter, test_iter = AG_NEWS()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
# Hold out 5% of the training data for validation, using a seeded split.
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(
    train_dataset,
    [num_train, len(train_dataset) - num_train],
    generator = torch.Generator().manual_seed(SEED),
)

train_dataloader = DataLoader(split_train_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)
# Sample order does not affect accuracy, so the evaluation loaders are not shuffled.
valid_dataloader = DataLoader(split_valid_, batch_size = BATCH_SIZE, shuffle = False, collate_fn = collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = False, collate_fn = collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    # train() picks up the loop variable `epoch` for its progress log.
    train(train_dataloader)
    accu_val = evaluate(valid_dataloader)
    # Decay the LR only when validation accuracy drops below the best so far;
    # otherwise record the new best.
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()
    else:
        total_accu = accu_val

    print('-' * 59)
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid accuracy {:8.3f} '.format(epoch,
                                           time.time() - epoch_start_time,
                                           accu_val))
    print('-' * 59)

| epoch   1 |   500 /  1782 batches | accuracy    0.573
| epoch   1 |  1000 /  1782 batches | accuracy    0.789
| epoch   1 |  1500 /  1782 batches | accuracy    0.838
-----------------------------------------------------------
| end of epoch   1 | time:  9.02s | valid accuracy    0.870 
-----------------------------------------------------------
| epoch   2 |   500 /  1782 batches | accuracy    0.872
| epoch   2 |  1000 /  1782 batches | accuracy    0.878
| epoch   2 |  1500 /  1782 batches | accuracy    0.882
-----------------------------------------------------------
| end of epoch   2 | time:  9.00s | valid accuracy    0.888 
-----------------------------------------------------------
| epoch   3 |   500 /  1782 batches | accuracy    0.895
| epoch   3 |  1000 /  1782 batches | accuracy    0.895
| epoch   3 |  1500 /  1782 batches | accuracy    0.902
-----------------------------------------------------------
| end of epoch   3 | time:  9.03s | valid accuracy    0.897 
-------------

In [230]:
# Final accuracy on the held-out AG_NEWS test split.
print('Checking the results of test dataset.')
accu_test = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(accu_test))

Checking the results of test dataset.
test accuracy    0.901


In [231]:
# Human-readable names for the 1-based AG_NEWS class ids returned by predict().
ag_news_label = {1: "World",
                 2: "Sports",
                 3: "Business",
                 4: "Sci/Tech"}

def predict(text, text_pipeline):
    """Return the 1-based class id the global `model` predicts for `text`.

    Args:
        text: raw input string.
        text_pipeline: callable that maps a string to a list of token ids.
    """
    # Build the inputs on whichever device holds the model's parameters, so
    # the call does not depend on the model having been moved to the CPU first.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        # An explicit int64 dtype matters: an empty token list would otherwise
        # become a float tensor, which EmbeddingBag rejects.
        tokens = torch.tensor(text_pipeline(text), dtype=torch.int64, device=model_device)
        # A single bag starting at index 0 covers the whole token sequence.
        offsets = torch.tensor([0], device=model_device)
        output = model(tokens, offsets)
        # argmax is 0-based; ag_news_label uses 1-based ids.
        return output.argmax(1).item() + 1

# Sample news article to classify.
ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
    enduring the season’s worst weather conditions on Sunday at The \
    Open on his way to a closing 75 at Royal Portrush, which \
    considering the wind and the rain was a respectable showing. \
    Thursday’s first round at the WGC-FedEx St. Jude Invitational \
    was another story. With temperatures in the mid-80s and hardly any \
    wind, the Spaniard was 13 strokes better in a flawless round. \
    Thanks to his best putting performance on the PGA Tour, Rahm \
    finished with an 8-under 62 for a three-stroke lead, which \
    was even more impressive considering he’d never played the \
    front nine at TPC Southwind."

# Run this one-off prediction with the model on the CPU.
model = model.to("cpu")

print("This is a %s news" % (ag_news_label[predict(ex_text_str, text_pipeline)],))

This is a Sports news
