### Setup GPU
We only need one GPU for training and inference. 

In [1]:
%env CUDA_VISIBLE_DEVICES=0
%load_ext autoreload
%autoreload 2

env: CUDA_VISIBLE_DEVICES=0


### Imports

In [2]:
import os, regex, math, copy, time, sys
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm, tqdm_notebook
import pickle
import spacy
from collections import defaultdict
import subprocess
import random

In [3]:
if not os.path.isfile('./utils.py'):
    print("Downloading utils.py...")
    url = "https://raw.githubusercontent.com/bsantraigi/Transformer-XS/master/utils.py"
    import subprocess
    subprocess.run(["wget", url])
else:
    print("Found utils.py...")

Found utils.py...


In [4]:
from utils import Lang

In [5]:
print(f"# Using pytorch v{torch.__version__}")

# Using pytorch v1.1.0


In [6]:
MAX_LEN = 100

### Check CUDA
Checking is cuda is available. I haven't dared to trained this on CPU. Even just using GPU also takes quite a lot of time to train well. In case cuda isn't detected in your system, you might not have a GPU or have the CPU variant of pytorch installed. You can always run this on Google Colab. 

In [7]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"# Using device: {device}")

# Using device: cuda


### Download Wikitext-103 dataset

In [8]:
!mkdir data

mkdir: cannot create directory ‘data’: File exists


In [9]:
if not os.path.isdir("data/wikitext-103/"):
    print("Downloading data...")
    subprocess.run("wget -c https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip -P data".split())
    print("Unzipping data...")
    subprocess.run(["unzip", "data/wikitext-103-v1.zip", "-d", "data/"])
    print("Done...")
else:
    print("Found data...")

Found data...


### Data Preprocessing
Main target here is to create the VOCAB

In [10]:
en = spacy.load('en_core_web_sm')

In [11]:
data_path = 'data/wikitext-103/'

In [12]:
train_lines = 1801350
test_lines = 4358
valid_lines = 3760

In [13]:
vocab = defaultdict(int)
split = 'train'
L = eval(f'{split}_lines')

### Vocab Creation (spaCy)

In [14]:
# with open(data_path + f'wiki.{split}.tokens') as f:
#     _progress = 0
#     buffer = []
#     for line in tqdm_notebook(f, total=L):
#         # _progress += 1
#         line = buffer.append(line.strip())
#         # print(f'{_progress/L*100:2.2F}', end='\r')
#         if len(buffer) > 40000:
#             buffer = ' '.join(buffer)
#             tokens = list(en.tokenizer(buffer.lower()))
#             buffer = []
#             for w in tokens:
#                 vocab[w.text] += 1
    
#     # One last time to clean the buffer
#     buffer = ' '.join(buffer)
#     tokens = list(en.tokenizer(buffer.lower()))
#     buffer = []
#     for w in tokens:
#         vocab[w.text] += 1

### Create or Load Vocab
The following step will take some time, upto 10 mins. The spacy tokenizer is not as fast. But this is a one time process. Once the vocab file is created, you can just load from there.

OR

Loads a saved vocab class object with word2index and index2word functions. 

#### Creation of vocab might take upto 5 min

In [15]:
%%time
lang_file = "./models/wiki103.large.lang"
if not os.path.isfile(lang_file):
    print("Creating vocab file...")
    en_lang = Lang('wiki')
    en_lang.buildLang(open(data_path + f'wiki.{split}.tokens'), num_lines=train_lines, care_for_newline=True)
    with open(lang_file, 'wb') as f:
        pickle.dump(en_lang, f)
else:
    print("Loading vocab file...")
    en_lang = pickle.load(open('./models/wiki103.large.lang', 'rb'))

Loading vocab file...
CPU times: user 187 ms, sys: 75 ms, total: 262 ms
Wall time: 307 ms


#### Limit vocab size
We only consider a vocab size of 40000 for now. This version of model is based on English words seen in training dataset. To decrease number of <unk> in dataset, I kept the vocab size a bit large. If we use bpe, the vocab size can be decreased while keeping better coverage.

In [16]:
en_lang.limitVocab(40000)

### Embedder
Holds the word embedding matrix. 

In [17]:
class Embedder(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
    def forward(self, x):
        return self.embed(x)

### Positional Encoder
Transformer doesn't have any sequential notion in it's architecture by default. So, it can only realize it's input as a bag of tokens. So, we need to explicitly provide positional information through the token embedding itself.

In [18]:
class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_seq_len = MAX_LEN):
        super().__init__()
        self.d_model = d_model
        
        # create constant 'pe' matrix with values dependant on 
        # pos and i
        pe = torch.zeros(max_seq_len, d_model)
        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                pe[pos, i] = \
                math.sin(pos / (10000 ** ((2 * i)/d_model)))
                pe[pos, i + 1] = \
                math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))
                
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
 
    
    def forward(self, x):
        # make embeddings relatively larger
        x = x * math.sqrt(self.d_model)
        #add constant to embedding
        seq_len = x.size(1)
        x = x + Variable(self.pe[:,:seq_len], \
        requires_grad=False).cuda()
        return x

### Multi-Head Attention
This is core part of the transformer architecture. A single layer of Multi-Head Attention applies self-attention to all of it's inputs. The input of this operation is a bag of k tokens (each with it's representation of query, key and value) and output is updated representation of the k tokens again. Based on the query representation of every token, one first decide weights (or attention) for key representation of all other tokens. The updated output representation of the query token is constructed by taking linear combination of value representation of tokens using the weights calculated.

In this implementation, we only look behind the current location by masking the indices ahead. This is because we want to predict the next word conditioned on the context behind. 

I plan to add the permutation language model functionality based on XLNet to allow learning bidirectional features. 

In [19]:
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout = 0.1):
        super().__init__()
        
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)
    
    def forward(self, q, k, v, mask=None):
        
        bs = q.size(0)
        
        # perform linear operation and split into h heads
        
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        
        # transpose to get dimensions bs * h * sl * d_model
       
        k = k.transpose(1,2)
        q = q.transpose(1,2)
        v = v.transpose(1,2)
        
        # calculate attention using function we will define next
        scores = attention(q, k, v, self.d_k, mask, self.dropout)
        
        # concatenate heads and put through final linear layer
        concat = scores.transpose(1,2).contiguous()\
                    .view(bs, -1, self.d_model)
        
        output = self.out(concat)
    
        return output

In [20]:
def attention(q, k, v, d_k, mask=None, dropout=None):
    # q, k, v : shape(bs, heads, L_max, d_k)
    # scores: matmul [shape(bs,heads,L_max,d_k), shape(bs,heads,d_k,L_max)] -> shape(bs,heads,L_max,L_max)
    # scores x v : shape(bs,heads,L_max,L_max) X shape(bs,heads,L_max,d_k) -> shape(bs,heads,L_max,d_k)
    
    scores = torch.matmul(q, k.transpose(-2, -1)) /  math.sqrt(d_k)
    if mask is not None:
        # print(f"### Shape of pre-softmax logits: {scores.shape}")
        # mask = mask.unsqueeze(1)
        # print(f"### Shape of mask: {mask.shape}")
        scores = scores.masked_fill(mask == 0, -1e9)
    scores = F.softmax(scores, dim=-1)
    
    if dropout is not None:
        scores = dropout(scores)
        
    output = torch.matmul(scores, v)
    return output

### FeedForward
A simple feed forward network with one hidden layer. Input and output dimensions are d_model and hidden layer size is d_ff (=2048 by default).

In [21]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout = 0.1):
        super().__init__() 
        # We set d_ff as a default to 2048
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
    def forward(self, x):
        x = self.dropout(F.relu(self.linear_1(x)))
        x = self.linear_2(x)
        return x

### Layer Norm
Following the trend in various papers, we also apply Layer Norm after every Multi Head attention and Feed Forward layers.

In [22]:
class Norm(nn.Module):
    def __init__(self, d_model, eps = 1e-6):
        super().__init__()
    
        self.size = d_model
        # create two learnable parameters to calibrate normalisation
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = eps
    def forward(self, x):
        norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \
        / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
        return norm

### Encoder Layer
Puts together a single layer of the Encoder. This applies [LayerNorm -> Multi-Head Attn -> LayerNorm -> Feed Forward] to the input.

In [23]:
# build a decoder layer with two multi-head attention layers and
# one feed-forward layer
class EncoderLayer(nn.Module):
    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = Norm(d_model)
        self.norm_2 = Norm(d_model)
        
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        
        self.attn_1 = MultiHeadAttention(heads, d_model)
        self.ff = FeedForward(d_model)
        self.ext_mask = None
    def forward(self, x, trg_mask, m=None):
        x2 = self.norm_1(x)
        state_2_save = x2 # We will return what was input to MHA as memory
        if m is None:
            keys = x2
            values = x2
        else:
            keys = torch.cat((x2, m.detach()), dim=1)
            values = keys
            if self.ext_mask is None:
                # Define once
                self.ext_mask = torch.tensor(np.ones(
                        trg_mask.shape), device=device)
            trg_mask = torch.cat((self.ext_mask, trg_mask), dim=-1)

        x = x + self.dropout_1(self.attn_1(x2, keys, values, trg_mask))
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.ff(x2))
        return x, state_2_save.detach()

# We can then build a convenient cloning function that can generate multiple layers:
def get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

### Encoder
Puts together the whole network by stacking 
- Word Embedding Matrix
- Positional Encoder
- N multihead attention layers

In [24]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, N, heads):
        super().__init__()
        self.N = N
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model)
        self.layers = get_clones(EncoderLayer(d_model, heads), N)
        self.norm = Norm(d_model)
        
        self.memory = []
    def forward(self, trg, trg_mask):
        new_memory = []
        x = self.embed(trg)
        x = self.pe(x)
        
        for i in range(self.N):
            if len(self.memory) > 0:
                x, icy_state = self.layers[i](x, trg_mask, self.memory[i])
            else:
                x, icy_state = self.layers[i](x, trg_mask)
            new_memory.append(icy_state)
        
        self._reset_memory()
            
        self.memory = [l for l in new_memory]
        return self.norm(x)
    
    def _reset_memory(self):
        # Free up old memory
        while len(self.memory) > 0:
            _r = self.memory.pop(0)
            del _r

### Transformer
Final wrapper class for Transformer. Nothing but the Encoder layer along with a final linear projection layer, that projects the output representation to log probability of words in vocab. 

In [25]:
class Transformer(nn.Module):
    def __init__(self, trg_vocab, d_model, N, heads):
        super().__init__()
        self.encoder = Encoder(trg_vocab, d_model, N, heads)
        # self.decoder = Decoder(trg_vocab, d_model, N, heads)
        self.out = nn.Linear(d_model, trg_vocab)
    def forward(self, trg, trg_mask, mem=None):
        """
        if no memory passed, a new context will start.
        if memory is passed, will make use of that as key and value
        """
        # e_outputs = self.encoder(src, src_mask)
        d_output = self.encoder(trg, trg_mask)
        output = self.out(d_output)
        return output
        # we don't perform softmax on the output as this will be handled 
        # automatically by our loss function
        
    def reset_memory(self):
        self.encoder._reset_memory()

### WikiDataset
WikiDataset class for fetching samples from wikitext-103 dataset.

In [26]:
class WikiDataset(Dataset):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """
    def __init__(self, split, max_len=MAX_LEN):
        super(WikiDataset, self).__init__()
        if split == 'train':
            _file = data_path + '/wiki.train.tokens'
            n_lines = 1801350
        elif split=="valid":
            _file = data_path + '/wiki.valid.tokens'
            n_lines = 3760
        elif split=="test":
            _file = data_path + '/wiki.test.tokens'
            n_lines = 4358
        else:
            raise Exception(f"wrong split: {split}")
        print("File:", _file)
        print("Expected # of lines:", n_lines)
        self.data = []
        with open(_file) as f:
            for line in tqdm_notebook(f, total=n_lines):
                line = line.strip()
                if len(line) > 0:
                    el = en_lang.encodeSentence(line)
                    if len(el) < max_len:
                        el = el + [en_lang.iEOS] + [en_lang.iPAD]*(max_len - len(el) - 1)
                    else:
                        el = el[:(max_len - 1)] + [en_lang.iEOS]
                    self.data.append(el)
    
    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

### Contiguous WikiDataset

In [27]:
class Contiguous_WikiDataset(Dataset):
    """Dataset loader 
    Assumes the files have following structure within a single file. 
    Each page is marked with one and only one `=` sign. Sections/sub-
    sections might be marked with `==` or more of those.
    Use this dataset class for context consistent batch-formation. 
    e.g. as required in Transformer-XL and XLNet
    
    >>>
    = Page Title 1 =
    
    Page content
    
    = Page Title 2 =
    
    Page content
    <<<
    
    :param: self.batch_size: Using with wikidataset class, not a very ideal thing to do. But 
    needed a workaround for now. 
    """
    def __init__(self, split, batch_size, max_len=MAX_LEN):
        super(Contiguous_WikiDataset, self).__init__()
        
        """
        SET PARAMETERs
        """
        self.batch_size = batch_size
        self.split = split
        self.max_len = MAX_LEN
        
        """
        Filenames
        """
        if self.split == 'train':
            self._file = data_path + '/wiki.train.tokens'
            n_lines = 1801350
        elif self.split=="valid":
            self._file = data_path + '/wiki.valid.tokens'
            n_lines = 3760
        elif self.split=="test":
            self._file = data_path + '/wiki.test.tokens'
            n_lines = 4358
        else:
            raise Exception(f"wrong self.split: {self.split}")
        print("File:", self._file)
        # print("Expected # of lines:", n_lines)
        
        ## REGEX
        self.re_page = regex.compile(r"^ = [^=]+? = $", regex.MULTILINE)
        #self.re_page = regex.compile(r"^ = ([^=]+)+? = *$", regex.MULTILINE)
        
        """
        One time. pre-Loads the raw data
        """
        self.load_data()
        
        """
        No shuffle initially:
            This makes sure valid and test dataset is processed in 
            original order
        """
        self.reShuffle(shuffle=False) 
        
    def load_data(self):
        _data = open(self._file).read()
        
        matches = [match.span()[0] for match in self.re_page.finditer(_data)]
        self.N = len(matches)
        print(self.N, 'documents found.')
        
        # Each entry in self.data is a document. Include
        self.data = []
        for start, end in tqdm_notebook(zip(matches, matches[1:] + [len(_data)]), total=self.N, desc="Encoding: "):
            self.data.append(
                en_lang.encodeSentence(_data[start:end])
            )
            
    def reShuffle(self, shuffle=True):
        if shuffle:
            random.shuffle(self.data)
        self.batched_data = []
        next_doc = self.batch_size # next doc to pick if we ran out of text
        current_docs = list(zip(range(self.batch_size), [0]*self.batch_size))
        finished = 0
        while(finished <= self.batch_size - 10): # Batch tends to get sparse near the end. 
            print(f"Status: {len(self.batched_data)} | {finished}/{self.N}", end="\r")
            # print(current_docs)
            finished = 0
            for i, (doc_x, doc_y) in enumerate(current_docs):
                if doc_x != -1:
                    reset = (doc_y+self.max_len) >= len(self.data[doc_x])
                    self.batched_data.append((self.data[doc_x][doc_y:(doc_y+self.max_len)], reset))
                    if reset:
                        if next_doc < self.N:
                            doc_x = next_doc
                            next_doc += 1
                            doc_y = 0
                        else:
                            doc_x = -1
                    else:
                        doc_y += self.max_len
                    
                    current_docs[i] = (doc_x, doc_y)
                else:
                    self.batched_data.append(([], True))
                    finished += 1
            # print('Finished:', finished)
        # Drop the last batch as that's all empty
        self.batched_data = self.batched_data[:-self.batch_size]
        
#         with open(self._file) as f:
#             for line in tqdm_notebook(f, total=n_lines):
#                 line = line.strip()
#                 if len(line) > 0:
#                     el = en_lang.encodeSentence(line)
#                     if len(el) < self.max_len:
#                         el = el + [en_lang.iEOS] + [en_lang.iPAD]*(self.max_len - len(el) - 1)
#                     else:
#                         el = el[:(self.max_len - 1)] + [en_lang.iEOS]
#                     self.data.append(el)
    
    def __getitem__(self, index):
        return self.batched_data[index]

    def __len__(self):
        return len(self.batched_data)

#### Dataset Split

In [28]:
BATCH_SIZE = 32

In [29]:
%%time
wikiDataset_valid = Contiguous_WikiDataset('valid', BATCH_SIZE)

File: data/wikitext-103//wiki.valid.tokens
60 documents found.


HBox(children=(IntProgress(value=0, description='Encoding: ', max=60, style=ProgressStyle(description_width='i…


Status: 0 | 0/60Status: 32 | 0/60Status: 64 | 0/60Status: 96 | 0/60Status: 128 | 0/60Status: 160 | 0/60Status: 192 | 0/60Status: 224 | 0/60Status: 256 | 0/60Status: 288 | 0/60Status: 320 | 0/60Status: 352 | 0/60Status: 384 | 0/60Status: 416 | 0/60Status: 448 | 0/60Status: 480 | 0/60Status: 512 | 0/60Status: 544 | 0/60Status: 576 | 0/60Status: 608 | 0/60Status: 640 | 0/60Status: 672 | 0/60Status: 704 | 0/60Status: 736 | 0/60Status: 768 | 0/60Status: 800 | 0/60Status: 832 | 0/60Status: 864 | 0/60Status: 896 | 0/60Status: 928 | 0/60Status: 960 | 0/60Status: 992 | 0/60Status: 1024 | 0/60Status: 1056 | 0/60Status: 1088 | 0/60Status: 1120 | 0/60Status: 1152 | 0/60Status: 1184 | 0/60Status: 1216 | 3/60Status: 1248 | 4/60Status: 1280 | 4/60Status: 1312 | 4/60Status: 1344 | 4/60Status: 1376 | 6/60Status: 1408 | 7/60Status: 1440 | 8/60Status: 1472 | 9/60Status: 1504 | 9/60Status: 1536 | 9/60Status: 1568 | 9/60Status: 1600 | 9/60Status: 1632 | 9

In [30]:
%%time
wikiDataset_test = Contiguous_WikiDataset('test', BATCH_SIZE)

File: data/wikitext-103//wiki.test.tokens
64 documents found.


HBox(children=(IntProgress(value=0, description='Encoding: ', max=64, style=ProgressStyle(description_width='i…


Status: 0 | 0/64Status: 32 | 0/64Status: 64 | 0/64Status: 96 | 0/64Status: 128 | 0/64Status: 160 | 0/64Status: 192 | 0/64Status: 224 | 0/64Status: 256 | 0/64Status: 288 | 0/64Status: 320 | 0/64Status: 352 | 0/64Status: 384 | 0/64Status: 416 | 0/64Status: 448 | 0/64Status: 480 | 0/64Status: 512 | 0/64Status: 544 | 0/64Status: 576 | 0/64Status: 608 | 0/64Status: 640 | 0/64Status: 672 | 0/64Status: 704 | 0/64Status: 736 | 0/64Status: 768 | 0/64Status: 800 | 0/64Status: 832 | 0/64Status: 864 | 0/64Status: 896 | 0/64Status: 928 | 0/64Status: 960 | 0/64Status: 992 | 0/64Status: 1024 | 0/64Status: 1056 | 0/64Status: 1088 | 0/64Status: 1120 | 0/64Status: 1152 | 0/64Status: 1184 | 0/64Status: 1216 | 0/64Status: 1248 | 0/64Status: 1280 | 0/64Status: 1312 | 0/64Status: 1344 | 0/64Status: 1376 | 0/64Status: 1408 | 1/64Status: 1440 | 5/64Status: 1472 | 5/64Status: 1504 | 5/64Status: 1536 | 7/64Status: 1568 | 8/64Status: 1600 | 8/64Status: 1632 | 8

In [31]:
%%time
# wikiDataset_train = Contiguous_WikiDataset('train', BATCH_SIZE)
# print("DONE!")

CPU times: user 2 µs, sys: 1 µs, total: 3 µs
Wall time: 5.01 µs


### Model Config
+ d_model: Embedding dim of words
+ heads: Number of heads used for multi-head attention
+ N: Number of MHA layers

In [32]:
# d_model = 512
# heads = 32
# N = 8
d_model = 256
heads = 32
N = 8
_vocab = en_lang.VOCAB_SIZE
model = Transformer(_vocab, d_model, N, heads)
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
# this code is very important! It initialises the parameters with a
# range of values that stops the signal fading or getting too big.
# See this blog for a mathematical explanation.
optim = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [33]:
_ = model.cuda()

In [34]:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"## Training model with {pytorch_total_params/1000000:0.2F}M trainable parameters.")

## Training model with 31.04M trainable parameters.


In [35]:
# wikiDataloader_train = DataLoader(wikiDataset_train, batch_size=BATCH_SIZE)
wikiDataloader_train = DataLoader(wikiDataset_valid, batch_size=BATCH_SIZE)
wikiDataloader_valid = DataLoader(wikiDataset_valid, batch_size=BATCH_SIZE)
wikiDataloader_test = DataLoader(wikiDataset_test, batch_size=BATCH_SIZE)

In [36]:
print(f"## Steps per epoch {len(wikiDataloader_train.dataset)//wikiDataloader_train.batch_size}")

## Steps per epoch 74


In [37]:
# def train_model():
epochs=14
print_every=50

_ = model.train()
start = time.time()
temp = start

total_loss = 0

for epoch in range(epochs):
    for i, batch in enumerate(wikiDataloader_train):
        reset_keys = batch[1]
        batch = batch[0]
        batch = torch.stack(batch).to(device)
        raise Exception("WAIT!!")
        trg = batch.t()

        # the French sentence we input has all words except
        # the last, as it is using each word to predict the next
        trg_input = trg[:, :-1]

        # the words we are trying to predict
        targets = trg[:, 1:].contiguous().view(-1)

        # create mask to make sure attn reads input only from the left (autoregressive)
        trg_mask = torch.tensor(np.tril(
            np.ones(
                (1, 1, trg_input.shape[1], trg_input.shape[1]))
        ), device=device) * ((trg_input != en_lang.iPAD).double().unsqueeze(1).unsqueeze(1))

        preds = model(trg_input, trg_mask)

        optim.zero_grad()

        loss = F.cross_entropy(preds.view(-1, preds.size(-1)), targets, ignore_index=en_lang.iPAD)
        loss.backward()
        optim.step()

        total_loss += loss.data.item()
        if (i + 1) % print_every == 0:
            loss_avg = total_loss / print_every
            print("time = %dm, epoch %d, iter = %d, loss = %.3f, PPL = %8.2f, %ds per %d iters" % ((time.time() - start) // 60,
            epoch + 1, i + 1, loss_avg, math.exp(loss_avg), time.time() - temp,
            print_every))
            total_loss = 0
            temp = time.time()
            # raise Exception("STOP")

TypeError: expected Tensor as element 0 in argument 0, but got list

In [39]:
batch[1]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.uint8)

In [None]:
len(model.encoder.memory)

In [None]:
raise Exception("Wait")

### Training the model

In [None]:
# train_model()

### Save the model

In [None]:
# torch.save({
#     'epoch': epoch,
#     'iter': i,
#     'model_state_dict': model.state_dict(),
#     'optimizer_state_dict': optim.state_dict(),
#     'loss': loss
# }, "models/txl_wikitext103.pth")

### Load a saved model

In [None]:
checkpoint = torch.load("models/txl_wikitext103.pth")

In [None]:
model.load_state_dict(checkpoint['model_state_dict'])

### Prediction on Test Set
The following function runs the model on the test or validation set. You can use this function to calculate perplexity on the validation or test set to compare. I didn't bother doing this as I was more interested in the contextual text generation task. That's the next function. 

In [None]:
def sample_sequence():
    model.eval()
    with torch.no_grad():
        start_from = 6
        for i, batch in enumerate(wikiDataloader_valid):
            batch = torch.stack(batch).to(device)
            trg = batch.t()
            zl = None
            trg_input = trg[:, :-1]
            trg_mask_common = torch.tensor(np.tril(
                    np.ones(
                        (1, 1, trg_input.shape[1], trg_input.shape[1]))
                ), device=device) * ((trg_input != en_lang.iPAD).double().unsqueeze(1).unsqueeze(1))
            
            # the words we are trying to predict
            targets = trg[:, 1:].contiguous().view(-1)
            
            for j in range(start_from, MAX_LEN - 1):
                # Predicting (j+1)th word
                if zl is None:
                    zl = torch.tensor(np.zeros((1, 1, trg_input.shape[1], trg_input.shape[1]))
                                      , device=device).double()
                    zl[..., :j] = 1

                zl[..., :j] = 1

                # create mask to make sure attn reads input only from the left (autoregressive)
                trg_mask =  trg_mask_common * zl

                preds = F.softmax(model(trg_input, trg_mask)[...,j,:], dim=-1)
#                 samples = torch.multinomial(preds, 1)[:,0]
                samples = torch.argmax(preds, 1)
                # samples = samples.view(trg_input.shape[0], -1)

                trg[..., (j+1)] = samples
                print(f"{j}", end="\r")
            return trg

In [None]:
preds = sample_sequence()

In [None]:
en_lang.decodeSentence(preds[16, :].cpu().tolist())

### Text sampler
Finally, the text generator function. This is inspired by the talktotransformer site. I was blown away by that site. Of course, the model here trained is not as good as the fine-tuned GPT-2 model used for talktotransformer, but this gives a good flavour of the task.

In [None]:
def talk_to_me(context, max_len = MAX_LEN):
    model.eval()
    context = torch.tensor(en_lang.encodeSentence(context)).unsqueeze(0).to(device)
    with torch.no_grad():
        start_from = (context.shape[1] - 1)
        # for i, batch in enumerate(wikiDataloader_valid):
        trg_input = context.to(device)
        # trg = batch.t()
        zl = None
        # trg_input = trg[:, :-1]
        trg_input = F.pad(trg_input, (0, MAX_LEN - trg_input.shape[1]), "constant", en_lang.iEOS)
        
        trg_mask_common = torch.tensor(np.tril(
                np.ones(
                    (1, 1, MAX_LEN, MAX_LEN))
            ), device=device) * ((trg_input != en_lang.iPAD).double().unsqueeze(1).unsqueeze(1))

        for j in range(start_from, MAX_LEN - 1):
            # Predicting (j+1)th word
            if zl is None:
                zl = torch.tensor(np.zeros((1, 1, trg_input.shape[1], trg_input.shape[1]))
                                  , device=device).double()
                zl[..., :j] = 1

            zl[..., :j] = 1

            # create mask to make sure attn reads input only from the left (autoregressive)
            trg_mask =  trg_mask_common * zl

            preds = F.softmax(model(trg_input, trg_mask)[...,j,:], dim=-1)
            if np.random.rand() < 0.2:
                samples = torch.multinomial(preds, 1)[:,0]
            else:
                samples = torch.argmax(preds, 1)
            # samples = samples.view(trg_input.shape[0], -1)

            trg_input[..., (j+1)] = samples
            if samples.item() == en_lang.iEOS:
                return trg_input
            print(f"{j}", end="\r")
        return trg_input

In [None]:
query = "Bangalore has the best"
for i in range(5):
    gen_text = talk_to_me(query)
    print(f"Sample {i}: ", ' '.join(en_lang.decodeSentence(gen_text.cpu()[0].numpy().tolist())))