In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

In [2]:
class TransformerModel(nn.Module):

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output

In [3]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [5]:
import io
import torch
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab

In [17]:
# 从wikitext中获取训练数据，并针对每个训练样本进行tokenize
# tokenize过程中使用Counter()计数，以便决定样本的词汇量(vocab)
train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
counter = Counter()
for line in train_iter:
    counter.update(tokenizer(line))
vocab = Vocab(counter)

def data_process(raw_text_iter):
  data = [torch.tensor([vocab[token] for token in tokenizer(item)],
                       dtype=torch.long) for item in raw_text_iter]
  return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def batchify(data, bsz):
    # Divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

In [18]:
bptt = 35
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

In [19]:
ntokens = len(vocab.stoi) # the size of vocabulary
emsize = 200 # embedding dimension
nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models
dropout = 0.2 # the dropout value
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)

In [284]:
import time

criterion = nn.CrossEntropyLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

def train():
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad()
        if data.size(0) != bptt:
            src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
        output = model(data, src_mask)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        log_interval = 200
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // bptt, scheduler.get_last_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

def evaluate(eval_model, data_source):
    # print(len(data_source))
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            if data.size(0) != bptt:
                src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
            output = eval_model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
            # print(total_loss)
    return total_loss / (len(data_source) - 1)

In [285]:
best_val_loss = float("inf")
epochs = 3 # The number of epochs
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = model

    scheduler.step()

| epoch   1 |   200/ 2928 batches | lr 5.00 | ms/batch 11.51 | loss  5.47 | ppl   237.26
| epoch   1 |   400/ 2928 batches | lr 5.00 | ms/batch 11.07 | loss  5.50 | ppl   244.50
| epoch   1 |   600/ 2928 batches | lr 5.00 | ms/batch 10.81 | loss  5.29 | ppl   199.31
| epoch   1 |   800/ 2928 batches | lr 5.00 | ms/batch 10.81 | loss  5.37 | ppl   215.65
| epoch   1 |  1000/ 2928 batches | lr 5.00 | ms/batch 10.84 | loss  5.33 | ppl   206.00
| epoch   1 |  1200/ 2928 batches | lr 5.00 | ms/batch 10.85 | loss  5.38 | ppl   216.27
| epoch   1 |  1400/ 2928 batches | lr 5.00 | ms/batch 10.85 | loss  5.40 | ppl   220.81
| epoch   1 |  1600/ 2928 batches | lr 5.00 | ms/batch 10.84 | loss  5.43 | ppl   228.66
| epoch   1 |  1800/ 2928 batches | lr 5.00 | ms/batch 10.85 | loss  5.38 | ppl   217.50
| epoch   1 |  2000/ 2928 batches | lr 5.00 | ms/batch 10.86 | loss  5.39 | ppl   219.81
| epoch   1 |  2200/ 2928 batches | lr 5.00 | ms/batch 10.89 | loss  5.26 | ppl   192.58
| epoch   1 |  2400/ 

In [286]:
torch.save(best_model.state_dict(), "../models/Transformer.model")
print("Saved PyTorch Model State to Transformer.model")

Saved PyTorch Model State to Transformer.model


In [287]:
test_loss = evaluate(best_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

| End of training | test loss  5.48 | test ppl   239.22


In [288]:
def indicesToSentence(indices, vocab):
    sentence = ""
    for word in [vocab.itos[index] for index in indices]:
        sentence += word + " "
    return sentence

In [99]:
train_iter = WikiText2(split='train')

In [113]:
text = next(train_iter)
print(text)

In [209]:
s_onehot = torch.tensor([vocab[token] for token in tokenizer(text)])
print(s_onehot)

tensor([20001,    84,  3850,    89,     0,  3870,    22,   781, 28781,     3,
         6183,     4,  3850,     5,     2,  5024,    89,    21,     3,  1838,
         1019,     8,    15,  3850,  3870,   882,   630,   977,     3,    24,
            9,  5791,   300,    13,   576,   233,    68,   453,    20, 13723,
            6,   758,     4,  2501,    18,     2,  1768,  5638,     4,   156,
            7,   247,   355,     7,   977,     3,    25,    24,     2,   238,
           68,     7,     2,  3850,    94,     4,     0,     2,   157,  4420,
            5,  5791,     6,   730,    13,    59,  2097,    15,    44,  7076,
            3,     2,   334,  1086,  3219,     8,     2,    38,    68,     6,
         1695,     2, 11220,     3,     9, 19699,   314,  1064,  2083,     2,
         1703,     5, 19010,    57,     2,    96, 25358,   108,    53,  1939,
         1645,   289,   599,     6,    35, 13621,   121,     2,  2322,  1064,
            0, 14743,     4])


In [210]:
indicesToSentence(s_onehot, vocab)

'senjō no valkyria 3 <unk> chronicles ( japanese 戦場のヴァルキュリア3 , lit . valkyria of the battlefield 3 ) , commonly referred to as valkyria chronicles iii outside japan , is a tactical role @-@ playing video game developed by sega and media . vision for the playstation portable . released in january 2011 in japan , it is the third game in the valkyria series . <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the nameless , a penal military unit serving the nation of gallia during the second europan war who perform secret black operations and are pitted against the imperial unit <unk> raven . '

In [197]:
data, targets = get_batch(s_onehot, 0)

In [198]:
indicesToSentence(data, vocab)

'senjō no valkyria 3 <unk> chronicles ( japanese 戦場のヴァルキュリア3 , lit . valkyria of the battlefield 3 ) , commonly referred to as valkyria chronicles iii outside japan , is a tactical role @-@ playing '

In [204]:
indicesToSentence(targets, vocab)

'no valkyria 3 <unk> chronicles ( japanese 戦場のヴァルキュリア3 , lit . valkyria of the battlefield 3 ) , commonly referred to as valkyria chronicles iii outside japan , is a tactical role @-@ playing video '

In [205]:
src_mask = model.generate_square_subsequent_mask(bptt).to(device)

In [206]:
output = best_model(data.to(device), model.generate_square_subsequent_mask(bptt).to(device))
output_flat = output.view(-1, ntokens)

In [208]:
print(output.shape, output_flat.shape)
print(data.shape, targets.shape)
print(len(data) * criterion(output_flat, targets).item())

torch.Size([35, 35, 28783]) torch.Size([1225, 28783])
torch.Size([35]) torch.Size([35])


ValueError: Expected input batch_size (1225) to match target batch_size (35).

In [211]:
# evaluate transformer

In [307]:
train_iter, val_iter, test_iter = WikiText2()

In [308]:
test_data = data_process(test_iter)
test_data = batchify(test_data, 2)
print(test_data.shape)

torch.Size([120929, 2])


In [309]:
total_loss = 0
src_mask = best_model.generate_square_subsequent_mask(bptt).to(device)
data, targets = get_batch(test_data, 0)
output = best_model(data.to(device), src_mask)
output_flat = output.view(-1, ntokens)
total_loss += len(data) * criterion(output_flat, targets).item()

In [310]:
print(data.shape, output.shape, targets.shape, output_flat.shape)

torch.Size([35, 2]) torch.Size([35, 2, 28783]) torch.Size([70]) torch.Size([70, 28783])


In [311]:
total_loss

184.7235369682312

In [312]:
indicesToSentence([torch.argmax(dist).item() for dist in output_flat], vocab)

'= a lester , , , = she lester had , not a be average like writer in , she and <unk> series , <unk> the , , , it the show was effects been children <unk> <unk> star short <unk> of <unk> of in the the years show , network and of for show she , show the to , the the , episode in named show by , '

In [313]:
indicesToSentence(targets, vocab)

'robert plainly <unk> silly = that robert she <unk> did is not an look english injured film . , nick television <unk> and of theatre <unk> actor gave . the he special had four a and guest a @-@ half starring out role of on five the stars television , series writing the that bill the in power 2000 of . emotion this saves was the followed day by again '

In [314]:
indicesToSentence(data, vocab)

TypeError: only integer tensors of a single element can be converted to an index