In [1]:
import torch
from torch import nn
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

In [2]:
train_dataset, test_dataset = AG_NEWS(root='NLP/dataset/IMDB',
                                                split=('train', 'test'))

In [3]:
tokenizer = get_tokenizer('basic_english')

def yield_tokens(dataset):
    for label, text in train_dataset:
        yield [token for token in tokenizer(text)]
    
vocab = build_vocab_from_iterator(yield_tokens(train_dataset), specials=['<unk>', '<pad>'])
vocab.set_default_index(vocab["<unk>"])

In [4]:
def collate_fn(data):
    labels, texts = [], []
    for label, text in data:
        labels.append(int(label) - 1)
        text = torch.tensor(vocab(tokenizer(text)), dtype=torch.long)
        texts.append(text)
    texts = pad_sequence(texts, batch_first=True, padding_value=vocab['<pad>'])
    labels = torch.tensor(labels, dtype=torch.long)
    return texts, labels

In [5]:
train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True, drop_last=True, collate_fn=collate_fn)

In [6]:
class Model(nn.Module):
    def __init__(self, vocab_size, emb_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.fc = nn.Linear(emb_dim, num_classes)
        self.init_weights()
    
    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
    
    def forward(self, text):
        # [batch_size, src_len, emb_dim]
        embedded = self.embedding(text)
        embedded = torch.mean(embedded,dim=1,keepdim=False)
        out = self.fc(embedded)
        return out

In [7]:
num_class  = len(set([label for (label, text) in train_dataset]))
print(num_class)
vocab_size = len(vocab)
emb_dim = 64
device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')
model = Model(vocab_size, emb_dim, num_class).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
criterion = nn.CrossEntropyLoss()
epochs = 20

4


In [8]:
def train():
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        for text, label in train_dataloader:
            text = text.to(device)
            label = label.to(device)
            out = model(text)
            loss = criterion(out, label)
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('epoch:{}, loss:{}'.format(epoch + 1, epoch_loss / len(list(train_dataloader))))

In [9]:
train()

epoch:1, loss:1.1394813102744559
epoch:2, loss:0.6506261299806573
epoch:3, loss:0.43451570359276354
epoch:4, loss:0.3454782765283147
epoch:5, loss:0.3007818692251483
epoch:6, loss:0.2681596132200072
epoch:7, loss:0.246959580030996
epoch:8, loss:0.22800007607219314
epoch:9, loss:0.21403573672093348
epoch:10, loss:0.20116036626626613
epoch:11, loss:0.18931650936508126
epoch:12, loss:0.1797315784935763
epoch:13, loss:0.1698638417518508
epoch:14, loss:0.1612623284631503
epoch:15, loss:0.15381027417597357
epoch:16, loss:0.1467962062290473
epoch:17, loss:0.14031255848530835
epoch:18, loss:0.13381127848537397
epoch:19, loss:0.12776993468340203
epoch:20, loss:0.1220421885580556


In [10]:
def test():
    model.eval()
    epoch_loss = 0
    total, correct = 0, 0
    with torch.no_grad():
        for text, label in test_dataloader:
            text = text.to(device)
            label = label.to(device)
            out = model(text)
            loss = criterion(out, label)
            epoch_loss += loss.item()

            out = out.argmax(dim=-1)
            correct += (out == label).sum()
            total += len(label)
            
        print('loss:{}, acc:{}'.format(epoch_loss / len(list(test_dataloader)), correct / total))

In [11]:
test()

loss:0.25506291172261963, acc:0.9184321761131287
