# 导入库

In [1]:
import os
import time
import random
from collections import Counter

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchtext.datasets import SogouNews, AG_NEWS
import torchtext.vocab as Vocab
from torchtext.data.utils import get_tokenizer
from sklearn.model_selection import train_test_split

from utils import get_vocab, get_tokenized
from utils import evaluate, epoch_time
from utils import train as trainer

# 加载AG_NEWS数据集

In [2]:
train, test = AG_NEWS(
    root='./datasets', 
    split=('train','test')
)

In [3]:
len(train), len(test)

(120000, 7600)

In [4]:
train, test = list(train), list(test)

In [5]:
train, valid = train_test_split(train, test_size=0.2)

In [6]:
len(train), len(valid), len(test)

(96000, 24000, 7600)

In [7]:
tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

In [8]:
vocab = get_vocab(train)

In [9]:
train[0]

(1,
 "Guantanamo Prisoner Goes Before Tribunal (AP) AP - A U.S. military panel heard the case Wednesday of a Guantanamo Bay prisoner accused of fighting for Afghanistan's ousted Taliban regime, as a U.S. judge ordered the government to release records of alleged prisoner abuse at the American base.")

In [10]:
def preprocess(data, vocab, max_l=500):
    
    def pad(x):
        return x[:max_l] if len(x) > max_l else x + [0] * (max_l-len(x))
    tokenized_data = get_tokenized(data)
    features = torch.LongTensor(
        [pad(vocab.lookup_indices(words)) for words in tokenized_data]
    )
    labels = torch.tensor(
        [score-1 for (score, _) in data],
        dtype = torch.int64
    )
    return features, labels

In [11]:
class NewsDataset(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels
        
    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]
    
    def __len__(self):
        return self.features.size(0)

In [12]:
max_length = 50
batch_size = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
train_set = NewsDataset(*preprocess(train, vocab, max_length))
valid_set = NewsDataset(*preprocess(valid, vocab, max_length))
test_set = NewsDataset(*preprocess(test, vocab, max_length))

In [14]:
counter = dict(Counter([label.item() for label in train_set.labels]))
weights = 1./torch.tensor(
    [counter[i] for i in range(4)], 
    dtype=torch.float
)

In [15]:
samples_weights = weights[train_set.labels]
sampler = WeightedRandomSampler(
    weights=samples_weights, 
    num_samples = len(samples_weights),
    replacement=True
)

In [35]:
train_iter = DataLoader(
    train_set, 
    batch_size=batch_size, 
    sampler=sampler
)
valid_iter = DataLoader(
    valid_set, 
    batch_size=batch_size
)
test_iter = DataLoader(
    test_set,
    batch_size=batch_size
)

# 加载词向量

In [17]:
cache_dir = "./datasets/glove"
glove_vocab = Vocab.GloVe(name='6B', dim=100, cache=cache_dir)

In [18]:
def load_pretrained_embedding(words, pretrained_vocab):
    embed = torch.zeros(len(words), pretrained_vocab.vectors[0].shape[0])
    oov_count = 0
    for i, word in  enumerate(words):
        try:
            idx = pretrained_vocab.stoi[word]
            embed[i, :] = pretrained_vocab.vectors[idx]
        except KeyError:
            oov_count += 1
    if oov_count > 0:
        print('There are %d oov words.' % oov_count)
    return embed

In [19]:
glove_100 = load_pretrained_embedding(vocab.get_itos(), glove_vocab)

There are 16180 oov words.


# 设计模型

In [30]:
class TextCNN(nn.Module):
    def __init__(self, V, E, kernels, channels, O, weights, dropout=0.5):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(V, E)
        self.constant_embedding = nn.Embedding.from_pretrained(
            embeddings=weights,
            freeze=True,
            padding_idx=0
        )
        self.convs = nn.ModuleList()
        for kernel_size, channel_size in zip(kernels, channels):
            self.convs.append(
                nn.Conv1d(
                    in_channels=2*E,
                    out_channels=channel_size,
                    kernel_size=kernel_size
                )
            )
        self.pool = nn.AdaptiveMaxPool1d(1)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(sum(channels), O)
    
    def forward(self, X):
        X = torch.cat(
            (self.embedding(X), self.constant_embedding(X)),
            dim=2
        )
        X = X.permute(0, 2, 1)
        X = torch.cat(
            [self.pool(torch.relu(conv(X))).squeeze(-1) for conv in self.convs],
            dim=1
        )
        X = self.fc(self.dropout(X))
        return X

# 初始化参数

In [31]:
Vocab_length = len(vocab)
Embedding_dim = 100
Output_dim = 4
lr = 1e-3
Epochs = 5
dropout = 0.5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [32]:
embed_size, kernel_sizes, nums_channels = 100, [3, 4, 5], [100, 100, 100]
model = TextCNN(Vocab_length, embed_size, kernel_sizes, nums_channels,
                Output_dim, glove_100, dropout
               )
model = model.to(device)

In [33]:
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
loss = nn.CrossEntropyLoss()

# 训练模型

In [36]:
best_valid_loss = float('inf')

for epoch in range(Epochs):
    start_time = time.time()
    train_loss, train_acc = trainer(
        model, 
        train_iter, 
        optimizer, 
        loss,
        device
    )
    valid_loss, valid_acc = evaluate(model, test_iter, loss,device)
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), './models/rnn-best-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

Epoch: 01 | Epoch Time: 0m 11s
	Train Loss: 0.278 | Train Acc: 90.55%
	 Val. Loss: 0.305 |  Val. Acc: 90.48%
Epoch: 02 | Epoch Time: 0m 11s
	Train Loss: 0.218 | Train Acc: 92.76%
	 Val. Loss: 0.312 |  Val. Acc: 90.61%
Epoch: 03 | Epoch Time: 0m 11s
	Train Loss: 0.185 | Train Acc: 93.98%
	 Val. Loss: 0.318 |  Val. Acc: 90.76%
Epoch: 04 | Epoch Time: 0m 11s
	Train Loss: 0.151 | Train Acc: 95.06%
	 Val. Loss: 0.339 |  Val. Acc: 90.85%
Epoch: 05 | Epoch Time: 0m 11s
	Train Loss: 0.131 | Train Acc: 95.75%
	 Val. Loss: 0.359 |  Val. Acc: 91.24%
