In [1]:
import time
import random
from collections import Counter
import torch
from torch import nn, optim
import torch.utils.data as Data
from torch.nn.utils.rnn import pad_sequence
import torchtext
from torchtext.datasets import IMDB
import torchtext.vocab as Vocab
from torchtext.data.utils import get_tokenizer

from utils import get_vocab_imdb, get_tokenized_imdb, get_tokenizer
from utils import evaluate, epoch_time
from utils import train as trainer
from utils import preprocess_imdb

In [2]:
SEED = 42
torch.manual_seed(SEED)
# 每次运行网络的时候算法和SEED是固定的，方便复现
torch.backends.cudnn.deterministic = True

## 数据处理

In [3]:
train, test = IMDB(root='./datasets/', split=('train', 'test'))

In [4]:
train, test = list(train), list(test)

In [5]:
len(train), len(test)

(25000, 25000)

In [6]:
batch_size = 64
vocab = get_vocab_imdb(train)

In [7]:
train_set = Data.TensorDataset(*preprocess_imdb(train, vocab))
test_set = Data.TensorDataset(*preprocess_imdb(test, vocab))

In [8]:
train_iter = Data.DataLoader(train_set, batch_size, shuffle=True)
test_iter = Data.DataLoader(test_set, batch_size, shuffle=True)

## 创建模型

### 加载词向量

In [9]:
cache_dir = "./datasets/glove"
glove_vocab = Vocab.GloVe(name='6B', dim=100, cache=cache_dir)

In [10]:
def load_pretrained_embedding(words, pretrained_vocab):
    embed = torch.zeros(len(words), pretrained_vocab.vectors[0].shape[0])
    oov_count = 0
    for i, word in  enumerate(words):
        try:
            idx = pretrained_vocab.stoi[word]
            embed[i, :] = pretrained_vocab.vectors[idx]
        except KeyError:
            oov_count += 1
    if oov_count > 0:
        print('There are %d oov words.' % oov_count)
    return embed

In [11]:
glove_100 = load_pretrained_embedding(vocab.get_itos(), glove_vocab)

There are 14719 oov words.


### Model

In [12]:
class BiRNN(nn.Module):
    def __init__(self, vocab, embed_size, num_hiddens, num_layers):
        super(BiRNN, self).__init__()
        self.embedding = nn.Embedding(len(vocab), embed_size)
        # bidirectional 设为True即可得到双向循环神经网络
        self.encoder = nn.LSTM(input_size = embed_size,
                              hidden_size = num_hiddens,
                              num_layers = num_layers,
                              bidirectional=True)
        # 初识时间步和最终时间步的隐藏状态为全连接层的输入
        self.decoder = nn.Linear(4*num_hiddens, 2)
        
    def forward(self, inputs):
        embeddings = self.embedding(inputs.permute(1, 0))
        outputs, _ = self.encoder(embeddings)
        encoding = torch.cat((outputs[0], outputs[-1]), -1)
        outs = self.decoder(encoding)
        return outs

In [13]:
embed_size, num_hiddens, num_layers = 100, 100, 2
net = BiRNN(vocab, embed_size, num_hiddens, num_layers)

In [14]:
net.embedding.weight.data.copy_(
    load_pretrained_embedding(vocab.get_itos(), glove_vocab)
)
net.embedding.weight.requires_grad = False  # 该嵌入层不需要训练

There are 14719 oov words.


In [15]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(net):,} trainable parameters')

The model has 404,002 trainable parameters


## 训练模型

In [16]:
Vocab_length = len(vocab)
Embedding_dim = 100
Hidden_dim = 256
Output_dim = 2
Learning_rate = 1e-3
Epochs = 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [17]:
model = net
optimizer = optim.SGD(model.parameters(), lr=Learning_rate)
loss = nn.CrossEntropyLoss()

In [18]:
model = model.to(device)

In [19]:
device

device(type='cuda')

In [20]:
best_valid_loss = float('inf')

for epoch in range(Epochs):
    start_time = time.time()
    train_loss, train_acc = trainer(
        model, 
        train_iter, 
        optimizer, 
        loss,
        device
    )
    valid_loss, valid_acc = evaluate(model, test_iter, loss,device)
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), './models/rnn-best-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

Epoch: 01 | Epoch Time: 1m 3s
	Train Loss: 0.693 | Train Acc: 50.91%
	 Val. Loss: 0.693 |  Val. Acc: 50.17%
