In [25]:
import torch
from torch import nn
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

In [26]:
train_dataset, test_dataset = AG_NEWS(root='NLP/dataset/IMDB',
                                                split=('train', 'test'))

In [27]:
tokenizer = get_tokenizer('basic_english')

def yield_tokens(dataset):
    for label, text in train_dataset:
        yield [token for token in tokenizer(text)]
    
vocab = build_vocab_from_iterator(yield_tokens(train_dataset), specials=['<unk>', '<pad>'])
vocab.set_default_index(vocab["<unk>"])

In [28]:
def collate_fn(data):
    labels, texts = [], []
    for label, text in data:
        labels.append(int(label) - 1)
        text = torch.tensor(vocab(tokenizer(text)), dtype=torch.long)
        texts.append(text)
    texts = pad_sequence(texts, batch_first=True, padding_value=vocab['<pad>'])
    labels = torch.tensor(labels, dtype=torch.long)
    return texts, labels

In [29]:
train_dataloader = DataLoader(train_dataset, batch_size=512, shuffle=True, drop_last=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True, drop_last=True, collate_fn=collate_fn)

In [30]:
class GlobalMaxPool1d(nn.Module):
    def __init__(self):
        super(GlobalMaxPool1d, self).__init__()
    def forward(self, x):
        return F.max_pool1d(x, kernel_size=x.shape[2])  # shape: (batch_size, channel, 1)


In [31]:

class Model(nn.Module):
    def __init__(self, vocab_size, emb_dim, kernel_sizes, num_channels, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(sum(num_channels), num_classes)
        # 时序最大池化层没有权重，所以可以共用一个实例
        self.pool = GlobalMaxPool1d()
        self.convs = nn.ModuleList()  # 创建多个一维卷积层
        for c, k in zip(num_channels, kernel_sizes):
            self.convs.append(nn.Conv1d(in_channels = emb_dim, 
                                        out_channels = c, 
                                        kernel_size = k))

    def forward(self, sentence):
        embeds = self.embedding(sentence)
        embeds = embeds.permute(0, 2, 1)
        # 对于每个一维卷积层，在时序最大池化后会得到一个形状为(批量大小, 通道大小, 1)的
        # Tensor。使用flatten函数去掉最后一维，然后在通道维上连结
        encoding = torch.cat([self.pool(F.relu(conv(embeds))).squeeze(-1) for conv in self.convs], dim=1)
        # 应用丢弃法后使用全连接层得到输出
        outputs = self.fc(self.dropout(encoding))
        return outputs

In [32]:
num_classes  = len(set([label for (label, text) in train_dataset]))
vocab_size = len(vocab)
embedding_dim, kernel_sizes, num_channels = 100, [3, 4, 5], [100, 100, 100]

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = Model(vocab_size, embedding_dim, kernel_sizes, num_channels, num_classes).to(device)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
epochs = 5

Model(
  (embedding): Embedding(95812, 100)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=300, out_features=4, bias=True)
  (pool): GlobalMaxPool1d()
  (convs): ModuleList(
    (0): Conv1d(100, 100, kernel_size=(3,), stride=(1,))
    (1): Conv1d(100, 100, kernel_size=(4,), stride=(1,))
    (2): Conv1d(100, 100, kernel_size=(5,), stride=(1,))
  )
)


In [33]:
def train():
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        for text, label in train_dataloader:
            text = text.to(device)
            label = label.to(device)
            out = model(text)
            loss = criterion(out, label)
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('epoch:{}, loss:{}'.format(epoch + 1, epoch_loss / len(list(train_dataloader))))

In [34]:
train()

epoch:1, loss:0.9190634152828119
epoch:2, loss:0.48686470091342926
epoch:3, loss:0.36750345161327946
epoch:4, loss:0.2996078131035862
epoch:5, loss:0.25466879068786263


In [35]:
def test():
    model.eval()
    epoch_loss = 0
    total, correct = 0, 0
    with torch.no_grad():
        for text, label in test_dataloader:
            text = text.to(device)
            label = label.to(device)
            out = model(text)
            loss = criterion(out, label)
            epoch_loss += loss.item()

            out = out.argmax(dim=-1)
            correct += (out == label).sum()
            total += len(label)
            
        print('loss:{}, acc:{}'.format(epoch_loss / len(list(test_dataloader)), correct / total))

In [36]:
test()

loss:0.32150405367552226, acc:0.9010857939720154
