In [1]:
import os
import numpy as np
from tempfile import TemporaryDirectory

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [2]:
def data_process(raw_text_iter):
    """Numericalize every item of ``raw_text_iter`` and join the results.

    Each item is tokenized with the module-level ``tokenizer`` and mapped to
    indices with the module-level ``vocab``. Items that produce no tokens are
    dropped, and the remaining index tensors are concatenated into one tensor
    of dtype ``torch.long``.
    """
    encoded = []
    for line in raw_text_iter:
        ids = torch.tensor(vocab(tokenizer(line)), dtype=torch.long)
        # Skip items with no tokens; they contribute nothing to the stream.
        if ids.numel() > 0:
            encoded.append(ids)
    return torch.cat(tuple(encoded))


def batchify(data, bsz, device):
    """Lay a flat token stream out as ``bsz`` parallel columns.

    Elements at the tail that do not fit into ``bsz`` equal-length sequences
    are discarded. The result has shape ``(data.size(0) // bsz, bsz)`` and is
    moved to ``device``.
    """
    n_rows = data.size(0) // bsz
    usable = data[: n_rows * bsz]
    # Each of the bsz rows is one contiguous chunk of the stream; transposing
    # turns those chunks into columns.
    columns = usable.view(bsz, n_rows).transpose(0, 1).contiguous()
    return columns.to(device)


def generate_square_subsequent_mask(sz):
    """Build an ``(sz, sz)`` mask that is ``-inf`` strictly above the diagonal.

    The diagonal and everything below it are zero.
    """
    neg_inf = torch.full((sz, sz), float("-inf"))
    return neg_inf.triu(diagonal=1)


def get_batch(x_src, i, bptt=35):
    """
    Slice one input/target pair out of ``x_src`` starting at row ``i``.

    x_src is a tensor of shape (full_seq_len, batch_size).

    The window holds at most ``bptt`` rows and is shortened near the end of
    ``x_src`` so that the targets (the inputs shifted forward by one row)
    stay in range.

    Returns a tuple (data, target) where data has shape (seq_len, batch_size)
    and target is flattened to shape (seq_len * batch_size).
    """
    n_steps = min(bptt, len(x_src) - 1 - i)
    stop = i + n_steps
    inputs = x_src[i:stop]
    targets = x_src[i + 1: stop + 1].reshape(-1)
    return inputs, targets

In [27]:
class TransformerModel(nn.Module):
    """Embedding -> positional encoding -> TransformerEncoder -> linear head.

    The linear head projects each ``d_model``-sized encoder output to
    ``n_tokens`` scores.
    """

    def __init__(self, n_tokens, d_model, n_heads, d_hid, n_layers, dropout=0.5):
        """
        n_tokens: size of the embedding table and of the output projection
        d_model: embedding / encoder feature size
        n_heads: attention heads per encoder layer
        d_hid: feed-forward width inside each encoder layer
        n_layers: number of stacked encoder layers
        dropout: dropout probability for the positional encoding and the
            encoder layers
        """
        super().__init__()
        self.model_type = "Transformer"
        self.d_model = d_model

        # Sub-modules are created in a fixed order so that parameter
        # registration (and therefore random initialisation) is stable.
        self.encoder = nn.Embedding(n_tokens, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer_encoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                dim_feedforward=d_hid,
                dropout=dropout,
            ),
            num_layers=n_layers,
        )
        self.decoder = nn.Linear(d_model, n_tokens)

        self.init_weights()

    def init_weights(self):
        """Re-initialise the embedding and output weights uniformly in
        [-0.1, 0.1] and zero the output bias."""
        bound = 0.1
        for layer in (self.encoder, self.decoder):
            nn.init.uniform_(layer.weight, -bound, bound)
        nn.init.zeros_(self.decoder.bias)

    def forward(self, x_src, x_src_mask):
        """
        x_src holds token indices; x_src_mask is passed to the encoder stack
        as its attention mask.

        Returns the output of the final linear layer, whose last dimension
        has size n_tokens.
        """
        # Scale embeddings by sqrt(d_model) before adding positional signal.
        scaled = self.encoder(x_src) * np.sqrt(self.d_model)
        hidden = self.transformer_encoder(self.pos_encoder(scaled), x_src_mask)
        return self.decoder(hidden)


class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position signal to the input, then apply dropout.

    A table of shape ``(max_len, 1, d_model)`` is precomputed and stored as the
    non-trainable buffer ``pe``. Row ``p`` holds ``sin(p * w_k)`` in the even
    embedding dimensions and ``cos(p * w_k)`` in the odd ones, where
    ``w_k = exp(-2k * ln(10000) / d_model)``.

    d_model: size of the last (embedding) dimension; may be even or odd
    dropout: dropout probability applied after the addition
    max_len: number of positions precomputed; inputs longer than this cannot
        be encoded
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        # Shape (max_len, ceil(d_model / 2)): one column per frequency.
        angles = position * div_term
        pe = torch.zeros(max_len, 1, d_model)
        # Even embedding dimensions (0, 2, 4, ...) hold the sine component.
        pe[:, 0, 0::2] = torch.sin(angles)
        # Odd embedding dimensions (1, 3, 5, ...) hold the cosine component.
        # For an odd d_model there is one fewer odd slot than even slot, so the
        # last frequency column is dropped; without this the assignment fails
        # with a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        x is a tensor of shape (seq_len, batch_size, embedding_dim)

        The first seq_len rows of the table are broadcast over the batch
        dimension and added to x; dropout is applied to the sum.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [15]:
# Build the vocabulary from the training split only.
tokenizer = get_tokenizer("basic_english")
train_iter = WikiText2(split="train")
vocab = build_vocab_from_iterator(
    (tokenizer(line) for line in train_iter),
    specials=["<unk>"],
)
# Lookups of tokens missing from the vocabulary fall back to "<unk>".
vocab.set_default_index(vocab["<unk>"])

In [16]:
# Request all three splits again (train_iter was already iterated once to
# build the vocabulary) and turn each split into a flat tensor of token ids.
train_iter, val_iter, test_iter = WikiText2()
train_data, val_data, test_data = map(
    data_process, (train_iter, val_iter, test_iter)
)

In [17]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Number of parallel sequences for training and for evaluation, respectively.
batch_size, eval_batch_size = 20, 10

In [18]:
# Reshape each flat split into (seq_len, batch) columns on the target device.
# NOTE: this rebinds train_data / val_data / test_data to their batched form,
# so re-running this cell without first re-running the data_process cell
# would batchify tensors that are already batched.
train_data = batchify(train_data, bsz=batch_size, device=device)
val_data = batchify(val_data, bsz=eval_batch_size, device=device)
test_data = batchify(test_data, bsz=eval_batch_size, device=device)

In [19]:
# Sanity check: batched training data is (full_seq_len, batch_size).
train_data.shape

torch.Size([102499, 20])

In [20]:
# Model hyperparameters; these are the arguments given to TransformerModel.
# Vocabulary size: number of embedding rows and width of the output layer
n_tokens = len(vocab)
# Embedding dimension (passed as d_model)
emsize = 200
# Dimension of the feed-forward network inside each TransformerEncoderLayer
d_hid = 200
# Number of TransformerEncoderLayer blocks stacked in the TransformerEncoder
n_layers = 2
# Number of attention heads in each layer's MultiheadAttention
n_head = 2
# Dropout probability (shared by the positional encoding and the encoder layers)
dropout = 0.2

In [28]:
# Build the model from the hyperparameters above and move it to the device.
model = TransformerModel(
    n_tokens=n_tokens,
    d_model=emsize,
    n_heads=n_head,
    d_hid=d_hid,
    n_layers=n_layers,
    dropout=dropout,
)
model = model.to(device)