
https://pytorch.org/tutorials/beginner/transformer_tutorial.html



In [1]:
from dlcliche.notebook import *
from dlcliche.image import *
from dlcliche.torch_utils import *

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class TransformerModel(nn.Module):

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        encoder_norm = nn.LayerNorm(nhid)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers, norm=encoder_norm)
        self.embedding = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        src = self.embedding(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return output

In [2]:
import torchtext
from torchtext.data.utils import get_tokenizer
TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
                            init_token='<sos>',
                            eos_token='<eos>',
                            lower=True)
train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)
TEXT.build_vocab(train_txt)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def batchify(data, bsz):
    data = TEXT.numericalize([data.examples[0].text])
    # Divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_data = batchify(train_txt, batch_size)
val_data = batchify(val_txt, eval_batch_size)
test_data = batchify(test_txt, eval_batch_size)

In [3]:
import time


def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target


def train():
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    ntokens = len(TEXT.vocab.stoi)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        # prevent explosion.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        log_interval = 200
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()


def evaluate(eval_model, data_source):
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
    ntokens = len(TEXT.vocab.stoi)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output = eval_model(data)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)


ntokens = len(TEXT.vocab.stoi) # the size of vocabulary
emsize = 200 # embedding dimension
nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models
dropout = 0.2 # the dropout value
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)

bptt = 35
criterion = nn.CrossEntropyLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

best_val_loss = float("inf")
epochs = 50 # The number of epochs
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = model
        print('Saved best model at', epoch)

    scheduler.step()

| epoch   1 |   200/ 2981 batches | lr 5.00 | ms/batch 10.75 | loss  8.13 | ppl  3392.76
| epoch   1 |   400/ 2981 batches | lr 5.00 | ms/batch 10.49 | loss  6.85 | ppl   941.97
| epoch   1 |   600/ 2981 batches | lr 5.00 | ms/batch 10.50 | loss  6.41 | ppl   605.46
| epoch   1 |   800/ 2981 batches | lr 5.00 | ms/batch 10.52 | loss  6.25 | ppl   519.53
| epoch   1 |  1000/ 2981 batches | lr 5.00 | ms/batch 10.50 | loss  6.13 | ppl   460.33
| epoch   1 |  1200/ 2981 batches | lr 5.00 | ms/batch 10.52 | loss  6.10 | ppl   447.37
| epoch   1 |  1400/ 2981 batches | lr 5.00 | ms/batch 10.51 | loss  6.05 | ppl   423.94
| epoch   1 |  1600/ 2981 batches | lr 5.00 | ms/batch 10.52 | loss  6.05 | ppl   426.11
| epoch   1 |  1800/ 2981 batches | lr 5.00 | ms/batch 10.52 | loss  5.97 | ppl   389.66
| epoch   1 |  2000/ 2981 batches | lr 5.00 | ms/batch 10.52 | loss  5.97 | ppl   390.18
| epoch   1 |  2200/ 2981 batches | lr 5.00 | ms/batch 10.55 | loss  5.86 | ppl   349.64
| epoch   1 |  2400/ 

| epoch   6 |  1600/ 2981 batches | lr 3.68 | ms/batch 10.74 | loss  5.13 | ppl   168.27
| epoch   6 |  1800/ 2981 batches | lr 3.68 | ms/batch 10.75 | loss  5.08 | ppl   161.28
| epoch   6 |  2000/ 2981 batches | lr 3.68 | ms/batch 10.74 | loss  5.10 | ppl   163.63
| epoch   6 |  2200/ 2981 batches | lr 3.68 | ms/batch 10.76 | loss  4.97 | ppl   144.10
| epoch   6 |  2400/ 2981 batches | lr 3.68 | ms/batch 10.75 | loss  5.05 | ppl   155.48
| epoch   6 |  2600/ 2981 batches | lr 3.68 | ms/batch 10.74 | loss  5.07 | ppl   159.70
| epoch   6 |  2800/ 2981 batches | lr 3.68 | ms/batch 10.76 | loss  5.02 | ppl   151.37
-----------------------------------------------------------------------------------------
| end of epoch   6 | time: 33.39s | valid loss  5.41 | valid ppl   223.48
-----------------------------------------------------------------------------------------
Saved best model at 6
| epoch   7 |   200/ 2981 batches | lr 3.49 | ms/batch 10.81 | loss  5.05 | ppl   156.44
| epoch   7 

| epoch  12 |   200/ 2981 batches | lr 2.70 | ms/batch 10.74 | loss  4.70 | ppl   110.04
| epoch  12 |   400/ 2981 batches | lr 2.70 | ms/batch 10.67 | loss  4.73 | ppl   112.86
| epoch  12 |   600/ 2981 batches | lr 2.70 | ms/batch 10.68 | loss  4.54 | ppl    94.01
| epoch  12 |   800/ 2981 batches | lr 2.70 | ms/batch 10.67 | loss  4.61 | ppl   100.45
| epoch  12 |  1000/ 2981 batches | lr 2.70 | ms/batch 10.68 | loss  4.61 | ppl   100.18
| epoch  12 |  1200/ 2981 batches | lr 2.70 | ms/batch 10.67 | loss  4.64 | ppl   103.14
| epoch  12 |  1400/ 2981 batches | lr 2.70 | ms/batch 10.66 | loss  4.66 | ppl   105.19
| epoch  12 |  1600/ 2981 batches | lr 2.70 | ms/batch 10.68 | loss  4.70 | ppl   110.18
| epoch  12 |  1800/ 2981 batches | lr 2.70 | ms/batch 10.66 | loss  4.66 | ppl   105.73
| epoch  12 |  2000/ 2981 batches | lr 2.70 | ms/batch 10.68 | loss  4.68 | ppl   107.76
| epoch  12 |  2200/ 2981 batches | lr 2.70 | ms/batch 10.67 | loss  4.54 | ppl    94.11
| epoch  12 |  2400/ 

| epoch  17 |  1800/ 2981 batches | lr 2.09 | ms/batch 10.69 | loss  4.45 | ppl    85.67
| epoch  17 |  2000/ 2981 batches | lr 2.09 | ms/batch 10.70 | loss  4.47 | ppl    87.12
| epoch  17 |  2200/ 2981 batches | lr 2.09 | ms/batch 10.69 | loss  4.32 | ppl    75.50
| epoch  17 |  2400/ 2981 batches | lr 2.09 | ms/batch 10.68 | loss  4.40 | ppl    81.27
| epoch  17 |  2600/ 2981 batches | lr 2.09 | ms/batch 10.69 | loss  4.42 | ppl    83.34
| epoch  17 |  2800/ 2981 batches | lr 2.09 | ms/batch 10.69 | loss  4.38 | ppl    79.83
-----------------------------------------------------------------------------------------
| end of epoch  17 | time: 33.21s | valid loss  5.47 | valid ppl   236.55
-----------------------------------------------------------------------------------------
| epoch  18 |   200/ 2981 batches | lr 1.99 | ms/batch 10.74 | loss  4.44 | ppl    84.36
| epoch  18 |   400/ 2981 batches | lr 1.99 | ms/batch 10.70 | loss  4.46 | ppl    86.69
| epoch  18 |   600/ 2981 batches 

| epoch  23 |   200/ 2981 batches | lr 1.54 | ms/batch 10.78 | loss  4.29 | ppl    72.72
| epoch  23 |   400/ 2981 batches | lr 1.54 | ms/batch 10.73 | loss  4.31 | ppl    74.35
| epoch  23 |   600/ 2981 batches | lr 1.54 | ms/batch 10.68 | loss  4.15 | ppl    63.14
| epoch  23 |   800/ 2981 batches | lr 1.54 | ms/batch 10.69 | loss  4.21 | ppl    67.39
| epoch  23 |  1000/ 2981 batches | lr 1.54 | ms/batch 10.68 | loss  4.22 | ppl    67.96
| epoch  23 |  1200/ 2981 batches | lr 1.54 | ms/batch 10.68 | loss  4.24 | ppl    69.51
| epoch  23 |  1400/ 2981 batches | lr 1.54 | ms/batch 10.69 | loss  4.25 | ppl    69.87
| epoch  23 |  1600/ 2981 batches | lr 1.54 | ms/batch 10.67 | loss  4.30 | ppl    73.53
| epoch  23 |  1800/ 2981 batches | lr 1.54 | ms/batch 10.69 | loss  4.27 | ppl    71.36
| epoch  23 |  2000/ 2981 batches | lr 1.54 | ms/batch 10.68 | loss  4.29 | ppl    72.79
| epoch  23 |  2200/ 2981 batches | lr 1.54 | ms/batch 10.69 | loss  4.14 | ppl    63.08
| epoch  23 |  2400/ 

| epoch  28 |  1800/ 2981 batches | lr 1.19 | ms/batch 10.69 | loss  4.16 | ppl    64.34
| epoch  28 |  2000/ 2981 batches | lr 1.19 | ms/batch 10.67 | loss  4.18 | ppl    65.27
| epoch  28 |  2200/ 2981 batches | lr 1.19 | ms/batch 10.69 | loss  4.04 | ppl    56.62
| epoch  28 |  2400/ 2981 batches | lr 1.19 | ms/batch 10.67 | loss  4.10 | ppl    60.16
| epoch  28 |  2600/ 2981 batches | lr 1.19 | ms/batch 10.68 | loss  4.13 | ppl    62.07
| epoch  28 |  2800/ 2981 batches | lr 1.19 | ms/batch 10.69 | loss  4.10 | ppl    60.08
-----------------------------------------------------------------------------------------
| end of epoch  28 | time: 33.22s | valid loss  5.54 | valid ppl   253.90
-----------------------------------------------------------------------------------------
| epoch  29 |   200/ 2981 batches | lr 1.13 | ms/batch 10.78 | loss  4.16 | ppl    63.83
| epoch  29 |   400/ 2981 batches | lr 1.13 | ms/batch 10.73 | loss  4.18 | ppl    65.46
| epoch  29 |   600/ 2981 batches 

| epoch  34 |   200/ 2981 batches | lr 0.87 | ms/batch 10.80 | loss  4.08 | ppl    59.36
| epoch  34 |   400/ 2981 batches | lr 0.87 | ms/batch 10.73 | loss  4.10 | ppl    60.56
| epoch  34 |   600/ 2981 batches | lr 0.87 | ms/batch 10.74 | loss  3.95 | ppl    51.96
| epoch  34 |   800/ 2981 batches | lr 0.87 | ms/batch 10.72 | loss  4.01 | ppl    55.14
| epoch  34 |  1000/ 2981 batches | lr 0.87 | ms/batch 10.74 | loss  4.03 | ppl    56.24
| epoch  34 |  1200/ 2981 batches | lr 0.87 | ms/batch 10.72 | loss  4.05 | ppl    57.15
| epoch  34 |  1400/ 2981 batches | lr 0.87 | ms/batch 10.72 | loss  4.05 | ppl    57.13
| epoch  34 |  1600/ 2981 batches | lr 0.87 | ms/batch 10.74 | loss  4.09 | ppl    60.01
| epoch  34 |  1800/ 2981 batches | lr 0.87 | ms/batch 10.73 | loss  4.07 | ppl    58.72
| epoch  34 |  2000/ 2981 batches | lr 0.87 | ms/batch 10.74 | loss  4.09 | ppl    59.96
| epoch  34 |  2200/ 2981 batches | lr 0.87 | ms/batch 10.72 | loss  3.94 | ppl    51.50
| epoch  34 |  2400/ 

| epoch  39 |  1800/ 2981 batches | lr 0.68 | ms/batch 10.69 | loss  4.02 | ppl    55.83
| epoch  39 |  2000/ 2981 batches | lr 0.68 | ms/batch 10.70 | loss  4.04 | ppl    56.89
| epoch  39 |  2200/ 2981 batches | lr 0.68 | ms/batch 10.69 | loss  3.89 | ppl    48.97
| epoch  39 |  2400/ 2981 batches | lr 0.68 | ms/batch 10.70 | loss  3.95 | ppl    51.80
| epoch  39 |  2600/ 2981 batches | lr 0.68 | ms/batch 10.69 | loss  3.98 | ppl    53.63
| epoch  39 |  2800/ 2981 batches | lr 0.68 | ms/batch 10.70 | loss  3.95 | ppl    51.95
-----------------------------------------------------------------------------------------
| end of epoch  39 | time: 33.23s | valid loss  5.59 | valid ppl   268.27
-----------------------------------------------------------------------------------------
| epoch  40 |   200/ 2981 batches | lr 0.64 | ms/batch 10.75 | loss  4.01 | ppl    55.27
| epoch  40 |   400/ 2981 batches | lr 0.64 | ms/batch 10.71 | loss  4.04 | ppl    57.01
| epoch  40 |   600/ 2981 batches 

| epoch  45 |   200/ 2981 batches | lr 0.50 | ms/batch 10.80 | loss  3.97 | ppl    53.20
| epoch  45 |   400/ 2981 batches | lr 0.50 | ms/batch 10.75 | loss  4.00 | ppl    54.45
| epoch  45 |   600/ 2981 batches | lr 0.50 | ms/batch 10.73 | loss  3.85 | ppl    46.86
| epoch  45 |   800/ 2981 batches | lr 0.50 | ms/batch 10.75 | loss  3.91 | ppl    50.06
| epoch  45 |  1000/ 2981 batches | lr 0.50 | ms/batch 10.74 | loss  3.94 | ppl    51.20
| epoch  45 |  1200/ 2981 batches | lr 0.50 | ms/batch 10.75 | loss  3.95 | ppl    51.85
| epoch  45 |  1400/ 2981 batches | lr 0.50 | ms/batch 10.74 | loss  3.95 | ppl    51.80
| epoch  45 |  1600/ 2981 batches | lr 0.50 | ms/batch 10.74 | loss  3.99 | ppl    54.11
| epoch  45 |  1800/ 2981 batches | lr 0.50 | ms/batch 10.76 | loss  3.98 | ppl    53.67
| epoch  45 |  2000/ 2981 batches | lr 0.50 | ms/batch 10.73 | loss  3.99 | ppl    54.25
| epoch  45 |  2200/ 2981 batches | lr 0.50 | ms/batch 10.75 | loss  3.85 | ppl    46.77
| epoch  45 |  2400/ 

| epoch  50 |  1800/ 2981 batches | lr 0.38 | ms/batch 10.75 | loss  3.96 | ppl    52.29
| epoch  50 |  2000/ 2981 batches | lr 0.38 | ms/batch 10.73 | loss  3.97 | ppl    52.79
| epoch  50 |  2200/ 2981 batches | lr 0.38 | ms/batch 10.75 | loss  3.82 | ppl    45.77
| epoch  50 |  2400/ 2981 batches | lr 0.38 | ms/batch 10.73 | loss  3.87 | ppl    48.10
| epoch  50 |  2600/ 2981 batches | lr 0.38 | ms/batch 10.75 | loss  3.91 | ppl    49.93
| epoch  50 |  2800/ 2981 batches | lr 0.38 | ms/batch 10.73 | loss  3.88 | ppl    48.39
-----------------------------------------------------------------------------------------
| end of epoch  50 | time: 33.35s | valid loss  5.63 | valid ppl   279.85
-----------------------------------------------------------------------------------------


In [4]:
test_loss = evaluate(best_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

| End of training | test loss  5.56 | test ppl   259.32


## Generating text

Tried to generate text which follows one of text picked from test data, for N words.

In [5]:
def decompose(text):
    return np.array([' '.join([TEXT.vocab.itos[i] for i in text[:, k]]) for k in range(text.size(1))])

In [15]:
import copy


def generate_text(start_text, N):
    query_text = copy.deepcopy(start_text)
    generated = torch.zeros((N, 1), dtype=torch.long)
    for i in range(N):
        o = best_model(query_text)
        generated[i, 0] = o.argmax(-1)[-1, 0]
        query_text = query_text.roll(-1)
        query_text[-1, 0] = generated[i, 0]
        #print(decompose(ref))
    return generated


for idx in range(3):
    N = 50
    ref = get_batch(test_data, idx)[0][:, :1]
    generated = generate_text(ref, N)
    print(f'[{idx}]')
    print('Starting from:', decompose(ref)[0])
    print('Generated:', decompose(generated)[0])

[0]
Starting from: <eos> = robert <unk> = <eos> <eos> robert <unk> is an english film , television and theatre actor . he had a guest @-@ starring role on the television series the bill in 2000 .
Generated: <eos> = = = = = = = = <eos> <eos> <eos> <eos> = = = = = = = = = = = = = = = = = = = = <eos> <eos> <eos> the first was born in the first son of the first son of aralt
[1]
Starting from: = robert <unk> = <eos> <eos> robert <unk> is an english film , television and theatre actor . he had a guest @-@ starring role on the television series the bill in 2000 . this
Generated: episode was written by david carpenter , who was a guest star game , who was a comedy series of the series of the series of the series of the series of the series of the series . <eos> = = = = = = = = <eos> <eos> <eos>
[2]
Starting from: robert <unk> = <eos> <eos> robert <unk> is an english film , television and theatre actor . he had a guest @-@ starring role on the television series the bill in 2000 . this was
Gener

### Free form

In [27]:
inputs = 'He was born'

tokens = [TEXT.vocab.stoi[w.lower()] for w in inputs.split()]
tensor_tokens = torch.Tensor(tokens).to(torch.long).to(device).unsqueeze(-1)
print(tensor_tokens.shape)
generated = generate_text(tensor_tokens, N)
print('Starting from:', inputs)
print('Generated:', decompose(generated)[0])

torch.Size([3, 1])
Starting from: He was born
Generated: in the <unk> , and <unk> , and <unk> , and <unk> , and <unk> , and <unk> , and <unk> , and <unk> , and <unk> , and <unk> , and <unk> , and <unk> , and <unk> , and <unk> , and <unk> , and <unk> , and
