In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import sys, os, re, csv, codecs, numpy as np, pandas as pd
from fastai import *        # Quick accesss to most common functionality
# from fastai.text import *   # Quick accesss to NLP functionality
import html
from pathlib import Path

In [3]:
import torch.nn as nn

In [4]:
from fastai.text.models.transformer import tfmer_lm_config, GeLU, init_transformer
from fastai.text.models.awd_lstm import RNNDropout
from fastai.text.learner import LanguageLearner

In [5]:
PATH=Path.home()/'data/wikitext-2-raw'

In [6]:
from fastai.basic_data import *
from fastai.torch_core import *
from fastai.layers import *

In [7]:
from fastai.text.models.transformer import *

## Bert task - dataloading

In [8]:
bptt = 512
data = load_data(PATH, bs=16, bptt=bptt)

In [9]:
MASK = 'xxmask'

vocab = data.vocab

vocab.itos.append(MASK)

In [10]:
data.show_batch()

idx,text
0,"the map varies depending on an individual player 's approach : when one option is selected , the other is sealed off to the player . xxmaj outside missions , the player characters rest in a camp , where units can be customized and character growth occurs . xxmaj alongside the main story missions are character - specific sub missions relating to different squad members . xxmaj after the game"
1,"xxmaj kettle xxmaj hole xxmaj pond and to the southwest of xxmaj secret xxmaj lake in a heavily forested region . xxmaj after xxmaj secret xxmaj lake , the highway curves to the north , crossing xxmaj oak xxmaj hill xxmaj road at another at - grade intersection . xxbos xxmaj shortly after the intersection with xxmaj oak xxmaj hill xxmaj road , xxmaj route 4 transitions from a divided"
2,"hamas was seeking a diplomatic settlement with xxmaj israel . xxmaj he also condemned what he said was xxmaj israel 's refusal "" to abide by international law [ and ] to abide by the opinion of the international community "" to settle the conflict . xxbos "" i was of course happy to meet the xxmaj hizbullah people , because it is a point of view that is rarely"
3,"june demonstrated a number of the major problems inherent in the xxmaj french and xxmaj british navies at the start of the xxmaj revolutionary xxmaj wars . xxmaj both admirals were faced with disobedience from their captains , along with ill - discipline and poor training among their shorthanded crews , and they failed to control their fleets effectively during the height of the combat . xxbos = = xxmaj"
4,"of the xxup rok 3rd xxmaj division . xxmaj the xxmaj south xxmaj korean forces engaged the 766th forces around the village 's middle school with small - arms fire until noon . xxmaj at that point , xxmaj north xxmaj korean armored vehicles moved in to reinforce the 766th troops and drove the xxmaj south xxmaj koreans out of the village . xxbos xxmaj the village was strategically important"


In [11]:
data.one_batch()

(tensor([[ 1127,    15,     9,  ...,  2980,     5,  4042],
         [ 3366,    15,   217,  ...,    12,  1653,    72],
         [12753,    11,     5,  ..., 21301,   406,  1252],
         ...,
         [   10,     9,  1798,  ...,    15,     9,  1658],
         [   24,   540,    10,  ..., 14090,    11,     5],
         [   12,     5,  2994,  ...,   618,  9654,    26]]),
 tensor([[   15,     9,  2618,  ...,     5,  4042,    12],
         [   15,   217,  8701,  ...,  1653,    72,  1402],
         [   11,     5,    46,  ...,   406,  1252,    16],
         ...,
         [    9,  1798,    12,  ...,     9,  1658,    14],
         [  540,    10,    23,  ...,    11,     5,    35],
         [    5,  2994, 13419,  ...,  9654,    26,     5]]))

In [12]:
word_range = (0, len(data.vocab.itos))
from fastai.text.transform import *
pad_idx = data.vocab.stoi[PAD]
mask_idx = data.vocab.stoi[MASK]
def bert_first_tfm(b, word_range=word_range, pad_idx=pad_idx, 
             mask_idx=mask_idx, p=0.2):
    # p = replacement probability
    x_lm,y_lm = b
    x_msk,y_msk = x_lm.clone(),x_lm.clone() # x, x
#     x,y = x.clone(),y.clone()
    rand = torch.rand(x_msk.shape, device=x_lm.device)
    y_msk[rand > p] = pad_idx
    x_msk[rand <= (p*.8)] = mask_idx # 80% = mask
    wrong_word = (rand > (p*.8)) & (rand <= (p*.9)) # 10% = wrong word
    x_msk[wrong_word] = torch.randint(*word_range, [wrong_word.sum().item()], device=x_lm.device)
    return x_msk, (y_msk, y_lm)

In [13]:
def lm_first_tfm(b, pad_idx=pad_idx, 
             mask_idx=mask_idx, p=0.2):
    # p = replacement probability
    x_lm,y_lm = b
    
    y_msk = y_lm.clone() # x, x
    rand = torch.rand(x_lm.shape, device=x_lm.device)
    y_msk[rand > p] = pad_idx
    msk_idxs = rand <= p
    return (x_lm, msk_idxs), (y_msk, y_lm)

In [14]:
def rnd_tfm(b):
    r = random.randint(0, 1)
    if r == 0:
        return lm_first_tfm(b)
    return bert_first_tfm(b)

In [15]:
def val_tfm(b):
    x_lm, y_lm = b
    x_msk, ys = bert_first_tfm(b)
    y_msk, _ = ys
    return (x_msk,x_lm), (y_msk, y_lm)

## Model creation

In [16]:

class Embedder(nn.Module):
    "Embedding + positional encoding + dropout"
    def __init__(self, vocab_sz:int, emb_sz:int, embed_p:float=0., pad_idx=pad_idx):
        super().__init__()
        self.emb_sz = emb_sz
        
#         self.embed = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_idx)        
        self.embed = nn.Embedding(vocab_sz, emb_sz)
        # See https://arxiv.org/abs/1711.09160
        with torch.no_grad(): trunc_normal_(self.embed.weight, std=0.01)
        self.drop = nn.Dropout(embed_p)
    
    def forward(self, inp, pos_forward=False):
        emb = self.drop(self.embed(inp))
        return emb


In [17]:

class Encoder(nn.Module):
    "TransformerXL model: https://arxiv.org/abs/1901.02860."
    def __init__(self, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int, 
                 resid_p:float=0., attn_p:float=0., ff_p:float=0., bias:bool=False, scale:bool=True,
                 act:Activation=Activation.ReLU, double_drop:bool=True, attn_cls:Callable=MultiHeadRelativeAttention,
                 learned_pos_enc:bool=False, mask:bool=True, mem_len:int=0, **kwargs):
        super().__init__()
        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.mem_len,self.n_layers,self.d_model,self.mask = mem_len,n_layers,d_model,mask
        self.init = False
        self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,
                      ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop, 
                      attn_cls=attn_cls) for k in range(n_layers)])
        
        self.pos_enc = PositionalEncoding(d_model)
        
        nn.init.normal_(self.u, 0., 0.02)
        nn.init.normal_(self.v, 0., 0.02)
    
    def reset(self):
        "Reset the internal memory."
        self.hidden = [next(self.parameters()).data.new(0) for i in range(self.n_layers+1)]

    def _update_mems(self, hids):
        if not getattr(self, 'hidden', False): return None
        assert len(hids) == len(self.hidden), 'len(hids) != len(self.hidden)'
        with torch.no_grad():
            for i in range(len(hids)):
                cat = torch.cat([self.hidden[i], hids[i]], dim=1)
                self.hidden[i] = cat[:,-self.mem_len:].detach()
    
    def select_hidden(self, idxs): self.hidden = [h[idxs] for h in self.hidden]
    
    def forward(self, x):
        #The hidden state has to be initiliazed in the forward pass for nn.DataParallel
        if self.mem_len > 0 and not self.init: 
            self.reset()
            self.init = True
        bs,x_len,emb_sz = x.size()
        
        inp = x
        
        m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0
        seq_len = m_len + x_len
        mask = torch.triu(x.new_ones(x_len, seq_len).long(), diagonal=1+m_len).byte()[None,None] if self.mask else None
        
        hids = []
        pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype)
        pos_enc = self.pos_enc(pos)
        hids.append(inp)
        for i, layer in enumerate(self.layers):
            mem = self.hidden[i] if self.mem_len > 0 else None
            inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)
            hids.append(inp)
        core_out = inp[:,-x_len:]
        if self.mem_len > 0 : self._update_mems(hids)
        return core_out

In [18]:
TrainType = Enum('TrainType', 'Default, Separate, LMOnly, BertOnly')

In [19]:
class BertHead(nn.Module):
    def __init__(self, embed, bert_encoder, nw_encoder, decoder, train_type=TrainType.Default):
        super().__init__()
        self.embed = embed
        self.bert_encoder = bert_encoder
        self.nw_encoder = nw_encoder
        self.decoder = decoder
        self.train_type = train_type
        
    
    def forward(self, x, mask_idxs=None):
        x_enc = self.embed(x)
        self.bert_encoder.mask = False
        
        # Baseline
        if self.train_type == TrainType.LMOnly:
            nw_enc = self.nw_encoder(x_enc)
            return self.decoder(nw_enc)
        
        if self.train_type == TrainType.BertOnly:
            bert_enc = self.bert_encoder(x_enc)
            return self.decoder(bert_enc)
        
        # Validation - train separately
        if (self.train_type == TrainType.Separate) or (not self.training):
            bert_enc = self.bert_encoder(x_enc)
            
            x_lm_enc = self.embed(mask_idxs)
            nw_enc = self.nw_encoder(x_lm_enc)
            return self.decoder(bert_enc), self.decoder(nw_enc)
        
        bert_first = mask_idxs is None # mask idxs tells us which embeddings to mask
        if bert_first:
            self.bert_encoder.mask = True
            bert_enc = self.bert_encoder(x_enc)
            nw_enc = self.nw_encoder(bert_enc)
            return self.decoder(bert_enc), self.decoder(nw_enc)
        else:
            nw_enc = self.nw_encoder(x_enc)
            nw_enc[mask_idxs] = embed(torch.tensor(mask_idx, device=x.device))
            bert_enc = self.bert_encoder(nw_enc)
            return self.decoder(bert_enc), self.decoder(nw_enc)
    
    "A sequential module that passes the reset call to its children."
    def reset(self):
        for c in self.children():
            if hasattr(c, 'reset'): c.reset()
        

In [20]:

class Decoder(nn.Module):
    "To go on top of a RNNCore module and create a Language Model."
    initrange=0.1

    def __init__(self, n_out:int, n_hid:int, output_p:float, tie_encoder:nn.Module=None, bias:bool=True):
        super().__init__()
        self.decoder = nn.Linear(n_hid, n_out, bias=bias)
        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.output_dp = RNNDropout(output_p)
        if bias: self.decoder.bias.data.zero_()
        if tie_encoder: self.decoder.weight = tie_encoder.weight

    def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
        output = self.output_dp(input)
        decoded = self.decoder(output)
        return decoded

## Transformer

In [21]:
vocab_sz = len(data.vocab.itos)
# config = tfmer_lm_config.copy(); config
config = {
    'ctx_len': bptt,
    'n_layers': 4,
    'n_heads': 4,
    'd_model': 128,
    'd_head': 32,
    'd_inner': 512,
}

In [22]:
class BertLoss():
    def __init__(self, pad_idx=pad_idx, loss_mult=(1,1)):
        "Loss mult - mask, NextWord, Seq2Seq, NextSent"
        self.index_loss = CrossEntropyFlat(ignore_index=pad_idx)
        self.loss_mult = loss_mult
        
    def __call__(self, input:Tensor, bert_target:Tensor, lm_target:Tensor, **kwargs)->Rank0Tensor:
        x_bert, x_lm = input
        loss_bert = self.index_loss.__call__(x_bert, bert_target, **kwargs) * self.loss_mult[0]
        loss_lm = self.index_loss.__call__(x_lm, lm_target, **kwargs) * self.loss_mult[1]
#         print(loss_bert, loss_lm)
        return loss_bert + loss_lm

In [23]:

def acc_ignore_pad(input:Tensor, targ:Tensor, pad_idx=pad_idx)->Rank0Tensor:
    n = targ.shape[0]
    input = input.argmax(dim=-1).view(n,-1)
    targ = targ.view(n,-1)
    mask = targ != pad_idx
    return (input[mask]==targ[mask]).float().mean()

def bert_acc(input:Tensor, b_t:Tensor, lm_t:Tensor)->Rank0Tensor:
    x_bert, x_lm = input
    return acc_ignore_pad(x_bert, b_t)
def lm_acc(input:Tensor, b_t:Tensor, lm_t:Tensor)->Rank0Tensor:
    x_bert, x_lm = input
    return acc_ignore_pad(x_lm, lm_t)

## Train

In [24]:
data.train_dl.tfms = [rnd_tfm]
data.valid_dl.tfms = [val_tfm]

In [25]:
embed = Embedder(vocab_sz, config['d_model'])
bert_encoder = Encoder(**config)
nw_encoder = Encoder(**config)
decoder = Decoder(vocab_sz, config['d_model'], tie_encoder=embed.embed, bias=False, output_p=False)
bert_head = BertHead(embed, bert_encoder, nw_encoder, decoder)
model = bert_head
model.apply(init_transformer);

In [26]:
learn = LanguageLearner(data, model, loss_func=BertLoss(), clip=0.5)
learn.callbacks = []
learn.metrics=[bert_acc, lm_acc]

In [27]:
learn.callbacks = []

In [28]:
learn.metrics=[bert_acc, lm_acc]

In [31]:
# learn.lr_find(num_it=500)
# learn.recorder.plot()

In [29]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,lm_acc,time
0,12.452651,11.264405,0.232077,0.218978,00:53
1,10.469805,10.642893,0.271561,0.241542,00:52


## Baseline - Train Separate

In [24]:
data.train_dl.tfms = [val_tfm]
data.valid_dl.tfms = [val_tfm]

In [25]:
embed = Embedder(vocab_sz, config['d_model'])
bert_encoder = Encoder(**config)
nw_encoder = Encoder(**config)
decoder = Decoder(vocab_sz, config['d_model'], tie_encoder=embed.embed, bias=False, output_p=False)
bert_head = BertHead(embed, bert_encoder, nw_encoder, decoder, train_type=TrainType.Separate)
model = bert_head
model.apply(init_transformer);

In [26]:
learn = LanguageLearner(data, model, loss_func=BertLoss(), clip=0.5)
learn.callbacks = []
learn.metrics=[bert_acc, lm_acc]

In [27]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,lm_acc,time
0,11.521917,10.40686,0.314711,0.246377,00:53
1,9.97849,9.652961,0.363171,0.266574,00:53


## Baseline - Next Word only

In [24]:
embed = Embedder(vocab_sz, config['d_model'])
nw_encoder = Encoder(**config)
decoder = Decoder(vocab_sz, config['d_model'], tie_encoder=embed.embed, bias=False, output_p=False)
model = nn.Sequential(embed, nw_encoder, decoder)
model.apply(init_transformer);

In [25]:
data.train_dl.tfms = []
data.valid_dl.tfms = []

In [26]:
learn = LanguageLearner(data, model)

In [27]:
learn.callbacks = []

In [28]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,6.386352,5.950587,0.146343,00:26
1,5.591476,5.380808,0.23556,00:26


## Baseline - Bert only

In [24]:
embed = Embedder(vocab_sz, config['d_model'])
bert_encoder = Encoder(**config)
decoder = Decoder(vocab_sz, config['d_model'], tie_encoder=embed.embed, bias=False, output_p=False)
model = nn.Sequential(embed, bert_encoder, decoder)
model.apply(init_transformer);

In [25]:
loss_func = CrossEntropyFlat(ignore_index=pad_idx)

In [26]:
learn = LanguageLearner(data, model, loss_func=loss_func)

In [27]:
def bert_tfm(b):
    x_msk, (y_msk, y_lm) = bert_first_tfm(b)
    return x_msk, y_msk

data.train_dl.tfms = [bert_tfm]
data.valid_dl.tfms = [bert_tfm]

In [28]:
def base_acc(input:Tensor, t1:Tensor)->Rank0Tensor:
    return acc_ignore_pad(input, t1)
learn.metrics=[base_acc]

In [29]:
learn.callbacks = []

In [30]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,base_acc,time
0,5.851985,5.431993,0.246056,00:26
1,5.378727,5.213453,0.268824,00:25
