In [48]:
import math
import os
from tempfile import TemporaryDirectory

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [39]:
def data_process(raw_text_iter, vocab_fn=None, tokenize_fn=None):
    """Convert an iterable of raw text lines into one flat tensor of token ids.

    Each line is tokenized and mapped to vocabulary indices. Lines that produce
    no tokens are dropped, and the remaining per-line tensors are concatenated.

    Args:
        raw_text_iter: iterable of raw text strings (e.g. one WikiText2 split).
        vocab_fn: callable mapping a list of tokens to a list of int ids.
            Defaults to the notebook-level ``vocab``.
        tokenize_fn: callable mapping a string to a list of tokens.
            Defaults to the notebook-level ``tokenizer``.

    Returns:
        1-D ``torch.long`` tensor of token ids. It is empty when no line yields
        any tokens.
    """
    # Fall back to the notebook globals so existing one-argument calls keep working.
    if vocab_fn is None:
        vocab_fn = vocab
    if tokenize_fn is None:
        tokenize_fn = tokenizer

    pieces = [
        torch.tensor(vocab_fn(tokenize_fn(line)), dtype=torch.long)
        for line in raw_text_iter
    ]
    pieces = [t for t in pieces if t.numel() > 0]
    if not pieces:
        # torch.cat raises on an empty sequence; return an empty id tensor instead.
        return torch.empty(0, dtype=torch.long)
    return torch.cat(pieces)


def batchify(data, bsz, device):
    """Reshape a flat token tensor into ``bsz`` equal-length columns.

    Trailing elements that cannot fill a complete column are discarded.

    Args:
        data: 1-D tensor of token ids.
        bsz: number of columns (independent sequences).
        device: device the result is moved to.

    Returns:
        Contiguous tensor of shape ``[data.size(0) // bsz, bsz]`` on ``device``.
    """
    rows = data.size(0) // bsz
    # View the trimmed data as bsz contiguous chunks, then transpose so that
    # each chunk becomes a column and dim 0 walks along the sequence.
    columns = data[: rows * bsz].view(bsz, rows).transpose(0, 1)
    return columns.contiguous().to(device)

In [49]:
class TransformerModel(nn.Module):
    """Transformer-encoder language model over a token vocabulary.

    The pipeline is:
    token embedding (scaled by sqrt(d_model))
    -> sinusoidal positional encoding
    -> ``n_layers`` stacked encoder layers
    -> linear projection to ``n_tokens`` scores per position.

    Args:
        n_tokens: vocabulary size (embedding rows and output width).
        n_heads: number of attention heads per encoder layer. PyTorch's
            attention requires ``d_model`` to be divisible by ``n_heads``.
        d_model: embedding / model width.
        d_hid: hidden size of each layer's feed-forward sublayer.
        n_layers: number of stacked encoder layers.
        dropout: dropout probability for the encoder layers and positional encoding.
    """

    def __init__(self, n_tokens, n_heads, d_model, d_hid, n_layers, dropout=0.5):
        super().__init__()
        self.model_type = "Transformer"
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layer = TransformerEncoderLayer(d_model, n_heads, d_hid, dropout)
        # TransformerEncoder clones `encoder_layer` n_layers times into a stack.
        self.transformer_encoder = TransformerEncoder(encoder_layer, n_layers)
        self.encoder = nn.Embedding(n_tokens, d_model)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, n_tokens)
        self.init_weights()

    def init_weights(self):
        """Uniformly initialise the embedding and output weights in [-0.1, 0.1]; zero the output bias."""
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)

    @staticmethod
    def generate_square_subsequent_mask(sz, device=None):
        """Return an ``[sz, sz]`` float mask: 0 on/below the diagonal, -inf above."""
        return torch.triu(torch.full((sz, sz), float("-inf"), device=device), diagonal=1)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        """Compute per-position scores over the vocabulary.

        Args:
            src: ``[seq_len, batch_size]`` tensor of token ids. This is the layout
                ``batchify`` produces, and the encoder layers keep the default
                ``batch_first=False``.
            src_mask: optional ``[seq_len, seq_len]`` additive attention mask.
                When omitted, a causal mask is built so that position ``i`` only
                attends to positions ``<= i``.

        Returns:
            ``[seq_len, batch_size, n_tokens]`` tensor of unnormalised scores.
        """
        if src_mask is None:
            # Build the mask on the input's device rather than a global `device`.
            src_mask = self.generate_square_subsequent_mask(src.size(0), src.device)
        x = self.encoder(src) * math.sqrt(self.d_model)
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x, src_mask)
        return self.decoder(x)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    Args:
        d_model: feature width of the inputs.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence the precomputed table supports.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )  # [ceil(d_model / 2)]
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # Slice div_term so that an odd d_model (d_model // 2 odd columns) still fits.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A buffer follows .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """Add encodings to ``x`` of shape ``[seq_len, batch_size, d_model]`` (seq_len <= max_len)."""
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)

In [14]:
# Build the vocabulary from the training split only. Tokens missing from it
# look up the "<unk>" index, which is set as the default below.
tokenizer = get_tokenizer("basic_english")
train_iter = WikiText2(split="train")
vocab = build_vocab_from_iterator(
    (tokenizer(line) for line in train_iter),
    specials=["<unk>"],
)
vocab.set_default_index(vocab["<unk>"])

In [18]:
# Load all three WikiText2 splits and flatten each into a 1-D tensor of token ids.
# data_process relies on `vocab` and `tokenizer` from the vocabulary cell.
train_iter, val_iter, test_iter = WikiText2()
train_data, val_data, test_data = (
    data_process(split_iter) for split_iter in (train_iter, val_iter, test_iter)
)

In [41]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Number of parallel sequences (columns) batchify produces for training vs. val/test.
batch_size = 20
eval_batch_size = 10

In [43]:
# NOTE(review): this cell overwrites the flat tensors in place, so re-running it
# without first re-running the split-loading cell will fail.
train_data, val_data, test_data = (
    batchify(train_data, batch_size, device),
    batchify(val_data, eval_batch_size, device),
    batchify(test_data, eval_batch_size, device),
)

In [47]:
test_data.shape  # [rows, eval_batch_size]

torch.Size([24185, 10])