# 导入库

In [1]:
import os
import random
import time

import torch
from torch import nn, optim
from torchtext.legacy import data
from torchtext.legacy import datasets

In [2]:
SEED = 42
# Seed every RNG this notebook draws from up front: Python's `random`
# (seeded again right before the train/validation split) and PyTorch.
random.seed(SEED)
torch.manual_seed(SEED)
# Pin cuDNN to deterministic kernels and keep its autotuner off so that
# repeated runs with the same SEED are reproducible.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

#  数据预处理

In [3]:
# Review text field: tokenized with spaCy using the small English pipeline.
TEXT = data.Field(
    tokenize='spacy',
    tokenizer_language='en_core_web_sm',
)
# Label field, stored as float to match the BCEWithLogitsLoss criterion used later.
LABEL = data.LabelField(dtype=torch.float)

In [4]:
# Load the IMDB train/test splits, processed with the TEXT and LABEL fields.
# NOTE(review): the name `train` is rebound later by `def train(...)`. Once that
# cell has run, `train` no longer refers to this dataset, so this cell and the
# split cell only work when executed in order from a fresh kernel.
train, test = datasets.IMDB.splits(TEXT, LABEL)

In [5]:
# Number of examples in the train and test splits.
len(train), len(test)

(25000, 25000)

In [6]:
# Token count of the first training review, followed by that example's fields.
len(vars(train.examples[0])['text']), vars(train.examples[0])

(165,
 {'text': ['Bromwell',
   'High',
   'is',
   'a',
   'cartoon',
   'comedy',
   '.',
   'It',
   'ran',
   'at',
   'the',
   'same',
   'time',
   'as',
   'some',
   'other',
   'programs',
   'about',
   'school',
   'life',
   ',',
   'such',
   'as',
   '"',
   'Teachers',
   '"',
   '.',
   'My',
   '35',
   'years',
   'in',
   'the',
   'teaching',
   'profession',
   'lead',
   'me',
   'to',
   'believe',
   'that',
   'Bromwell',
   'High',
   "'s",
   'satire',
   'is',
   'much',
   'closer',
   'to',
   'reality',
   'than',
   'is',
   '"',
   'Teachers',
   '"',
   '.',
   'The',
   'scramble',
   'to',
   'survive',
   'financially',
   ',',
   'the',
   'insightful',
   'students',
   'who',
   'can',
   'see',
   'right',
   'through',
   'their',
   'pathetic',
   'teachers',
   "'",
   'pomp',
   ',',
   'the',
   'pettiness',
   'of',
   'the',
   'whole',
   'situation',
   ',',
   'all',
   'remind',
   'me',
   'of',
   'the',
   'schools',
   'I',
   'k

In [7]:
# Hold out 20% of the training set for validation.
# `random.seed` returns None, so seed the global RNG first and then hand the
# resulting state to `split` explicitly: `random_state` expects a value
# produced by `random.getstate()`.
random.seed(SEED)
train_data, valid_data = train.split(
    split_ratio=0.8,
    random_state=random.getstate(),
)

In [8]:
# Sizes of the train / validation / test splits after the 80/20 hold-out.
len(train_data), len(valid_data), len(test)

(20000, 5000, 25000)

In [9]:
# Build the vocabularies from the training split only (not validation/test).
MAX_VOCAB_SIZE = 25000

TEXT.build_vocab(train_data, max_size=MAX_VOCAB_SIZE)
LABEL.build_vocab(train_data)

In [10]:
# Vocabulary sizes. The text vocabulary shows up as MAX_VOCAB_SIZE + 2; the
# next cell's output lists '<unk>' and '<pad>' in the first two slots.
len(TEXT.vocab), len(LABEL.vocab)

(25002, 2)

In [11]:
# First ten entries of the index-to-token table.
TEXT.vocab.itos[:10]

['<unk>', '<pad>', 'the', ',', '.', 'and', 'a', 'of', 'to', 'is']

In [12]:
# Create the batch iterators for the three splits.
BATCH_SIZE = 64
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test),
    batch_size=BATCH_SIZE,
    device=device,
)

# 创建模型

In [13]:
class RNN(nn.Module):
    """Embedding -> single-layer ``nn.RNN`` -> linear head.

    Args:
        V: number of rows in the embedding table (vocabulary size).
        E: embedding dimension, also the RNN input size.
        H: RNN hidden size.
        O: number of output features of the final linear layer.
    """

    def __init__(self, V, E, H, O):
        super().__init__()
        self.embedding = nn.Embedding(V, E)
        self.rnn = nn.RNN(E, H)
        self.fc = nn.Linear(H, O)

    def forward(self, X):
        """Map token indices to one output vector per sequence.

        ``nn.RNN`` is built with its default layout, so ``X`` is indexed as
        (time step, batch). The result has one row per batch element.
        """
        embedded = self.embedding(X)
        outputs, last_hidden = self.rnn(embedded)
        final_state = last_hidden.squeeze(0)
        # `last_hidden` is the hidden state of the last layer, which is also
        # the RNN output at the final time step.
        assert torch.equal(outputs[-1], final_state)
        return self.fc(final_state)

# 训练模型

In [14]:
# Model and optimization hyperparameters.
Vocab_length = len(TEXT.vocab)  # number of rows in the embedding table
Embedding_dim = 100
Hidden_dim = 256
Output_dim = 1  # one logit per review
Learning_rate = 1e-3
Epochs = 5

In [15]:
model = RNN(Vocab_length, Embedding_dim, Hidden_dim, Output_dim)
# NOTE(review): in the recorded run further down, the loss stays at about 0.693
# and accuracy at about 50% for all five epochs with plain SGD at this learning
# rate. Consider revisiting the optimizer / learning rate.
optimizer = optim.SGD(model.parameters(), lr=Learning_rate)
# CrossEntropyLoss could also be used; this is a binary classification task.
criterion = nn.BCEWithLogitsLoss()

In [16]:
# Move the model parameters to the selected device.
model = model.to(device)

In [17]:
# Move the loss module to the same device as the model.
criterion = criterion.to(device)

In [18]:
def binary_accuracy(preds, y):
    """Return the fraction of predictions that match the labels.

    Args:
        preds: raw logits, one per example.
        y: target labels (0.0 or 1.0), same shape as ``preds``.

    Returns:
        A scalar tensor: matches divided by the length of the first dimension.
    """
    # Sigmoid squashes the logits into (0, 1), i.e. a probability; rounding
    # then yields the hard 0/1 prediction.
    hits = torch.sigmoid(preds).round().eq(y).float()
    return hits.sum() / hits.size(0)

In [19]:
def train(model, iterator, optimizer, criterion):
    """Run one training pass over ``iterator`` and update ``model`` in place.

    Each batch must expose ``.text`` (model input) and ``.label`` (targets).

    Returns:
        (mean batch loss, mean batch accuracy), both averaged over
        ``len(iterator)`` batches.
    """
    model.train()
    loss_total, acc_total = 0.0, 0.0
    for batch in iterator:
        optimizer.zero_grad()
        # Drop the trailing output dimension so logits line up with the labels.
        logits = model(batch.text).squeeze(1)
        batch_loss = criterion(logits, batch.label)
        batch_acc = binary_accuracy(logits, batch.label)
        batch_loss.backward()
        optimizer.step()
        loss_total += batch_loss.item()
        acc_total += batch_acc.item()
    num_batches = len(iterator)
    return loss_total / num_batches, acc_total / num_batches

In [20]:
def evaluate(model, iterator, criterion):
    """Score ``model`` on ``iterator`` without updating its parameters.

    Each batch must expose ``.text`` (model input) and ``.label`` (targets).

    Returns:
        (mean batch loss, mean batch accuracy), both averaged over
        ``len(iterator)`` batches.
    """
    # Eval mode: disables dropout and stops batch-norm statistics from updating.
    model.eval()
    loss_total, acc_total = 0.0, 0.0
    # No gradient tracking: saves memory and time.
    with torch.no_grad():
        for batch in iterator:
            logits = model(batch.text).squeeze(1)
            loss_total += criterion(logits, batch.label).item()
            acc_total += binary_accuracy(logits, batch.label).item()
    num_batches = len(iterator)
    return loss_total / num_batches, acc_total / num_batches

In [21]:
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps (in seconds) into whole
    minutes and the remaining whole seconds.

    Returns:
        (minutes, seconds) as integers, both truncated toward zero.
    """
    span = end_time - start_time
    whole_minutes = int(span / 60)
    leftover_seconds = int(span - whole_minutes * 60)
    return whole_minutes, leftover_seconds

In [22]:
# Train for `Epochs` epochs, saving the weights whenever the validation loss
# reaches a new minimum.
best_valid_loss = float('inf')

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save (e.g. on a fresh checkout).
os.makedirs('./models', exist_ok=True)

for epoch in range(Epochs):
    start_time = time.time()
    train_loss, train_acc = train(
        model,
        train_iterator,
        optimizer,
        criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Keep only the best checkpoint seen so far, judged by validation loss.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), './models/rnn-best-model.pt')

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

Epoch: 01 | Epoch Time: 0m 19s
	Train Loss: 0.693 | Train Acc: 49.79%
	 Val. Loss: 0.694 |  Val. Acc: 49.11%
Epoch: 02 | Epoch Time: 0m 18s
	Train Loss: 0.693 | Train Acc: 49.95%
	 Val. Loss: 0.694 |  Val. Acc: 48.73%
Epoch: 03 | Epoch Time: 0m 18s
	Train Loss: 0.693 | Train Acc: 49.97%
	 Val. Loss: 0.694 |  Val. Acc: 48.81%
Epoch: 04 | Epoch Time: 0m 18s
	Train Loss: 0.693 | Train Acc: 49.99%
	 Val. Loss: 0.694 |  Val. Acc: 48.99%
Epoch: 05 | Epoch Time: 0m 18s
	Train Loss: 0.693 | Train Acc: 49.48%
	 Val. Loss: 0.694 |  Val. Acc: 48.32%


In [23]:
# Restore the weights with the lowest validation loss before the final test run.
# map_location keeps the load working when the checkpoint was written on a
# different device than the one selected in this session (for example saved on
# a GPU and reloaded on a CPU-only machine).
model.load_state_dict(
    torch.load('./models/rnn-best-model.pt', map_location=device)
)

test_loss, test_acc = evaluate(model, test_iterator, criterion)

print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

Test Loss: 0.686 | Test Acc: 56.00%
