In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import sys, os, re, csv, codecs, numpy as np, pandas as pd
from fastai import *        # Quick accesss to most common functionality
# from fastai.text import *   # Quick accesss to NLP functionality
import html
from pathlib import Path

In [3]:
import torch.nn as nn

In [4]:
from fastai.text.models.transformer import tfmer_lm_config, GeLU, init_transformer
from fastai.text.models.awd_lstm import RNNDropout
from fastai.text.learner import LanguageLearner

In [5]:
PATH=Path.home()/'data/wikitext-2-raw'

In [6]:
from fastai.basic_data import *
from fastai.torch_core import *
from fastai.layers import *

In [7]:
from fastai.text.models.transformer import *

## Bert task - dataloading

In [8]:
bptt = 256
data = load_data(PATH, bs=8, bptt=bptt)

In [9]:
MASK = 'xxmask'

In [10]:
vocab = data.vocab

In [11]:
vocab.itos.append(MASK)

In [12]:
data.show_batch()

idx,text
0,"team of writers handled the script . xxmaj the game 's opening theme was sung by xxmaj may ' n . xxbos xxmaj it met with positive sales in xxmaj japan , and was praised by both xxmaj japanese and western critics . xxmaj after release , it received downloadable content , along with an expanded edition in xxmaj november of that year . xxmaj it was also adapted into"
1,", which relentlessly and brutally and xxunk keeps these vicious , murderous wars . xxmaj it is a vandal state . xxmaj there is a xxmaj russian writer who once described vandal states as xxmaj genghis xxmaj khan with a telegraph . xxmaj israel is xxmaj genghis xxmaj khan with a computer . i feel no emotion of affinity with that state . i have some good friends and their"
2,the 5th xxmaj division to envelop xxup un troops and push them back to xxmaj pusan . xxmaj the 766th was not reinforced ; xxmaj north xxmaj korean planners intended it to move unseen around the xxup un lines while the majority of the xxup un and xxmaj north xxmaj korean troops were locked in fighting around xxmaj taegu and the xxmaj xxunk xxmaj bulge . xxbos xxmaj by this
3,"at least two "" facilitating subjects "" ; the average point score per qualification was xxunk , equating to a xxup c- grade , and the average point score per student was xxunk . xxmaj the xxmaj sunday xxmaj times ranked xxmaj carre 's 101st ( 49th amongst state schools ) in the xxmaj midlands and 750th nationally based on a - xxmaj level and xxup gcse performance in 2012"
4,"mfume , the president of the xxmaj national xxmaj association for the xxmaj advancement of xxmaj colored xxmaj people ( xxup naacp ) , would run . xxmaj mfume had previously served on the xxmaj baltimore xxmaj city xxmaj council and in the xxmaj united xxmaj states xxmaj house of xxmaj representatives . xxmaj schmoke called the race "" his to lose "" . xxmaj however , xxmaj mfume lived"


In [13]:
data.one_batch()

(tensor([[    9,  3531,  5611,  ...,  1105,   622,    29],
         [16077,    19,  4075,  ..., 17572,    60,     9],
         [   12,     9,     6,  ...,   344,    12,     5],
         ...,
         [    9, 13634,   749,  ...,    54,    16,     5],
         [ 1228,   367,   785,  ...,    16,  2218,    12],
         [   15,  4549,     9,  ...,    21,  1107,   869]]),
 tensor([[ 3531,  5611,  3398,  ...,   622,    29,     6],
         [   19,  4075,    16,  ...,    60,     9,   465],
         [    9,     6,  6630,  ...,    12,     5,   113],
         ...,
         [13634,   749,    23,  ...,    16,     5,   307],
         [  367,   785,   321,  ...,  2218,    12,  1636],
         [ 4549,     9, 17780,  ...,  1107,   869,    11]]))

In [14]:
word_range = (0, len(data.vocab.itos))
from fastai.text.transform import *
pad_idx = data.vocab.stoi[PAD]
mask_idx = data.vocab.stoi[MASK]
def bert_first_tfm(b, word_range=word_range, pad_idx=pad_idx, 
             mask_idx=mask_idx, p=0.2):
    # p = replacement probability
    x_lm,y_lm = b
    x_msk,y_msk = x_lm.clone(),x_lm.clone() # x, x
#     x,y = x.clone(),y.clone()
    rand = torch.rand(x_msk.shape, device=x_lm.device)
    y_msk[rand > p] = pad_idx
    x_msk[rand <= (p*.8)] = mask_idx # 80% = mask
    wrong_word = (rand > (p*.8)) & (rand <= (p*.9)) # 10% = wrong word
    x_msk[wrong_word] = torch.randint(*word_range, [wrong_word.sum().item()], device=x_lm.device)
    return x_msk, (y_msk, y_lm)

In [15]:
def lm_first_tfm(b, pad_idx=pad_idx, 
             mask_idx=mask_idx, p=0.2):
    # p = replacement probability
    x_lm,y_lm = b
    
    y_msk = y_lm.clone() # x, x
    rand = torch.rand(x_lm.shape, device=x_lm.device)
    y_msk[rand > p] = pad_idx
    msk_idxs = rand <= p
    return (x_lm, msk_idxs), (y_msk, y_lm)

In [16]:
def rnd_tfm(b):
    r = random.randint(0, 1)
    if r == 0:
        return lm_first_tfm(b)
    return bert_first_tfm(b)

In [17]:
def val_tfm(b):
    x_lm, y_lm = b
    x_msk, ys = bert_first_tfm(b)
    y_msk, _ = ys
    return (x_msk,x_lm), (y_msk, y_lm)

## Model creation

In [18]:

class Embedder(nn.Module):
    "Embedding + positional encoding + dropout"
    def __init__(self, vocab_sz:int, emb_sz:int, embed_p:float=0., pad_idx=pad_idx):
        super().__init__()
        self.emb_sz = emb_sz
        
#         self.embed = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_idx)        
        self.embed = nn.Embedding(vocab_sz, emb_sz)
        # See https://arxiv.org/abs/1711.09160
        with torch.no_grad(): trunc_normal_(self.embed.weight, std=0.01)
        self.drop = nn.Dropout(embed_p)
    
    def forward(self, inp, pos_forward=False):
        emb = self.drop(self.embed(inp))
        return emb


In [19]:

class Encoder(nn.Module):
    "TransformerXL model: https://arxiv.org/abs/1901.02860."
    def __init__(self, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int, 
                 resid_p:float=0., attn_p:float=0., ff_p:float=0., bias:bool=False, scale:bool=True,
                 act:Activation=Activation.ReLU, double_drop:bool=True, attn_cls:Callable=MultiHeadRelativeAttention,
                 learned_pos_enc:bool=False, mask:bool=True, mem_len:int=0, **kwargs):
        super().__init__()
        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.mem_len,self.n_layers,self.d_model,self.mask = mem_len,n_layers,d_model,mask
        self.init = False
        self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,
                      ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop, 
                      attn_cls=attn_cls) for k in range(n_layers)])
        
        self.pos_enc = PositionalEncoding(d_model)
    
    def reset(self):
        "Reset the internal memory."
        self.hidden = [next(self.parameters()).data.new(0) for i in range(self.n_layers+1)]

    def _update_mems(self, hids):
        if not getattr(self, 'hidden', False): return None
        assert len(hids) == len(self.hidden), 'len(hids) != len(self.hidden)'
        with torch.no_grad():
            for i in range(len(hids)):
                cat = torch.cat([self.hidden[i], hids[i]], dim=1)
                self.hidden[i] = cat[:,-self.mem_len:].detach()
    
    def select_hidden(self, idxs): self.hidden = [h[idxs] for h in self.hidden]
    
    def forward(self, x):
        #The hidden state has to be initiliazed in the forward pass for nn.DataParallel
        if self.mem_len > 0 and not self.init: 
            self.reset()
            self.init = True
        bs,x_len,emb_sz = x.size()
        
        inp = x
        
        m_len = self.hidden[0].size(1) if hasattr(self, 'hidden') and len(self.hidden[0].size()) > 1 else 0
        seq_len = m_len + x_len
        mask = torch.triu(x.new_ones(x_len, seq_len).long(), diagonal=1+m_len).byte()[None,None] if self.mask else None
        
        hids = []
        pos = torch.arange(seq_len-1, -1, -1, device=inp.device, dtype=inp.dtype)
        pos_enc = self.pos_enc(pos)
        hids.append(inp)
        for i, layer in enumerate(self.layers):
            mem = self.hidden[i] if self.mem_len > 0 else None
            inp = layer(inp, r=pos_enc, u=self.u, v=self.v, mask=mask, mem=mem)
            hids.append(inp)
        core_out = inp[:,-x_len:]
        if self.mem_len > 0 : self._update_mems(hids)
        return core_out

In [20]:
TrainType = Enum('TrainType', 'Default, Separate, LMOnly, BertOnly')

In [21]:
class BertHead(nn.Module):
    def __init__(self, embed, bert_encoder, nw_encoder, decoder, train_type=TrainType.Default):
        super().__init__()
        self.embed = embed
        self.bert_encoder = bert_encoder
        self.nw_encoder = nw_encoder
        self.decoder = decoder
        self.train_type = train_type
        
    
    def forward(self, x, mask_idxs=None):
        x_enc = self.embed(x)
        self.bert_encoder.mask = False
        
        # Baseline
        if self.train_type == TrainType.LMOnly:
            nw_enc = self.nw_encoder(x_enc)
            return self.decoder(nw_enc)
        
        if self.train_type == TrainType.BertOnly:
            bert_enc = self.bert_encoder(x_enc)
            return self.decoder(bert_enc)
        
        # Validation - train separately
        if (self.train_type == TrainType.Separate) or (not self.training):
            bert_enc = self.bert_encoder(x_enc)
            
            x_lm_enc = self.embed(mask_idxs)
            nw_enc = self.nw_encoder(x_lm_enc)
            return self.decoder(bert_enc), self.decoder(nw_enc)
        
        bert_first = mask_idxs is None # mask idxs tells us which embeddings to mask
        if bert_first:
            self.bert_encoder.mask = True
            bert_enc = self.bert_encoder(x_enc)
            nw_enc = self.nw_encoder(bert_enc)
            return self.decoder(bert_enc), self.decoder(nw_enc)
        else:
            nw_enc = self.nw_encoder(x_enc)
            nw_enc[mask_idxs] = embed(torch.tensor(mask_idx, device=x.device))
            bert_enc = self.bert_encoder(nw_enc)
            return self.decoder(bert_enc), self.decoder(nw_enc)
    
    "A sequential module that passes the reset call to its children."
    def reset(self):
        for c in self.children():
            if hasattr(c, 'reset'): c.reset()
        

In [22]:

class Decoder(nn.Module):
    "To go on top of a RNNCore module and create a Language Model."
    initrange=0.1

    def __init__(self, n_out:int, n_hid:int, output_p:float, tie_encoder:nn.Module=None, bias:bool=True):
        super().__init__()
        self.decoder = nn.Linear(n_hid, n_out, bias=bias)
        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.output_dp = RNNDropout(output_p)
        if bias: self.decoder.bias.data.zero_()
        if tie_encoder: self.decoder.weight = tie_encoder.weight

    def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
        output = self.output_dp(input)
        decoded = self.decoder(output)
        return decoded

## Transformer

In [23]:
vocab_sz = len(data.vocab.itos)
# config = tfmer_lm_config.copy(); config
config = {
    'ctx_len': bptt,
    'n_layers': 4,
    'n_heads': 4,
    'd_model': 128,
    'd_head': 32,
    'd_inner': 512,
}

In [24]:
class BertLoss():
    def __init__(self, pad_idx=pad_idx, loss_mult=(1,1)):
        "Loss mult - mask, NextWord, Seq2Seq, NextSent"
        self.index_loss = CrossEntropyFlat(ignore_index=pad_idx)
        self.loss_mult = loss_mult
        
    def __call__(self, input:Tensor, bert_target:Tensor, lm_target:Tensor, **kwargs)->Rank0Tensor:
        x_bert, x_lm = input
        loss_bert = self.index_loss.__call__(x_bert, bert_target, **kwargs) * self.loss_mult[0]
        loss_lm = self.index_loss.__call__(x_lm, lm_target, **kwargs) * self.loss_mult[1]
#         print(loss_bert, loss_lm)
        return loss_bert + loss_lm

In [25]:

def acc_ignore_pad(input:Tensor, targ:Tensor, pad_idx=pad_idx)->Rank0Tensor:
    n = targ.shape[0]
    input = input.argmax(dim=-1).view(n,-1)
    targ = targ.view(n,-1)
    mask = targ != pad_idx
    return (input[mask]==targ[mask]).float().mean()

def bert_acc(input:Tensor, b_t:Tensor, lm_t:Tensor)->Rank0Tensor:
    x_bert, x_lm = input
    return acc_ignore_pad(x_bert, b_t)
def lm_acc(input:Tensor, b_t:Tensor, lm_t:Tensor)->Rank0Tensor:
    x_bert, x_lm = input
    return acc_ignore_pad(x_lm, lm_t)

## Load data

In [26]:
data.train_dl.tfms = [rnd_tfm]
data.valid_dl.tfms = [val_tfm]

In [27]:
embed = Embedder(vocab_sz, config['d_model'])
bert_encoder = Encoder(**config)
nw_encoder = Encoder(**config)
decoder = Decoder(vocab_sz, config['d_model'], tie_encoder=embed.embed, bias=False, output_p=False)
bert_head = BertHead(embed, bert_encoder, nw_encoder, decoder)
model = bert_head
model.apply(init_transformer);

In [28]:
learn = LanguageLearner(data, model, loss_func=BertLoss(), clip=0.5)
learn.callbacks = []
learn.metrics=[bert_acc, lm_acc]

In [29]:
learn.callbacks = []

In [30]:
learn.metrics=[bert_acc, lm_acc]

In [31]:
# learn.lr_find(num_it=500)
# learn.recorder.plot()

In [32]:
learn.fit_one_cycle(2, 1e-4)

epoch,train_loss,valid_loss,bert_acc,lm_acc,time


KeyboardInterrupt: 

In [31]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,lm_acc,time


KeyboardInterrupt: 

## Baseline - Train Separate

In [26]:
data.train_dl.tfms = [val_tfm]
data.valid_dl.tfms = [val_tfm]

In [27]:
embed = Embedder(vocab_sz, config['d_model'])
bert_encoder = Encoder(**config)
nw_encoder = Encoder(**config)
decoder = Decoder(vocab_sz, config['d_model'], tie_encoder=embed.embed, bias=False, output_p=False)
bert_head = BertHead(embed, bert_encoder, nw_encoder, decoder, train_type=TrainType.Separate)
model = bert_head
model.apply(init_transformer);

In [28]:
learn = LanguageLearner(data, model, loss_func=BertLoss(), clip=0.5)
learn.callbacks = []
learn.metrics=[bert_acc, lm_acc]

In [29]:
# without clip
learn.fit_one_cycle(1, 1e-3)

epoch,train_loss,valid_loss,bert_acc,lm_acc,time
0,12.532262,12.120117,0.156332,0.151737,01:38


In [40]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,lm_acc,time
0,9.751541,9.352708,0.368044,0.268756,01:38
1,8.943863,8.725395,0.408181,0.285668,01:38


Change task

In [30]:
data.train_dl.tfms = [rnd_tfm]
data.valid_dl.tfms = [val_tfm]

In [31]:
learn.model.train_type = TrainType.Default

In [32]:
learn.fit_one_cycle(2, 1e-4)

epoch,train_loss,valid_loss,bert_acc,lm_acc,time
0,12.96056,12.000916,0.15384,0.162767,01:37
1,12.596864,11.907268,0.164497,0.16678,01:38


## Baseline - Next Word only

In [33]:
embed = Embedder(vocab_sz, config['d_model'])
bert_encoder = Encoder(**config)
nw_encoder = Encoder(**config)
decoder = Decoder(vocab_sz, config['d_model'], tie_encoder=embed.embed, bias=False, output_p=False)
bert_head = BertHead(embed, bert_encoder, nw_encoder, decoder, train_type=TrainType.LMOnly)
model = bert_head
model.apply(init_transformer);

In [35]:
learn = LanguageLearner(data, model)

In [36]:

def base_acc(input:Tensor, t1:Tensor)->Rank0Tensor:
    return acc_ignore_pad(input, t1)
learn.metrics=[base_acc]

In [37]:
learn.callbacks = []

In [38]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,base_acc,time
0,5.700126,5.500681,0.212985,00:46
1,5.48549,5.37044,0.220364,00:46


## Baseline - Bert only

In [38]:
embed = Embedder(vocab_sz, config['d_model'])
bert_encoder = Encoder(**config)
nw_encoder = Encoder(**config)
decoder = Decoder(vocab_sz, config['d_model'], tie_encoder=embed.embed, bias=False, output_p=False)
bert_head = BertHead(embed, bert_encoder, nw_encoder, decoder, train_type=TrainType.BertOnly)
model = bert_head
model.apply(init_transformer);

In [39]:
loss_func = CrossEntropyFlat(ignore_index=pad_idx)

In [40]:
learn = LanguageLearner(data, model, loss_func=loss_func)

In [41]:
def bert_tfm(b):
    x_msk, (y_msk, y_lm) = bert_first_tfm(b)
    return x_msk, y_msk

data.train_dl.tfms = [bert_tfm]
data.valid_dl.tfms = [bert_tfm]

In [42]:
def base_acc(input:Tensor, t1:Tensor)->Rank0Tensor:
    return acc_ignore_pad(input, t1)
learn.metrics=[base_acc]

In [43]:
learn.callbacks = []

In [44]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,base_acc,time
0,4.897971,4.734547,0.34004,00:46
1,4.4124,4.286451,0.387815,00:45
