In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import sys, os, re, csv, codecs, numpy as np, pandas as pd
from fastai import *        # Quick accesss to most common functionality
# from fastai.text import *   # Quick accesss to NLP functionality
import html
from pathlib import Path

In [3]:
import torch.nn as nn

In [4]:
from fastai.text.models.transformer import *
from fastai.text.models.transformer import init_transformer
from fastai.text.models.awd_lstm import RNNDropout, LinearDecoder
from fastai.text.learner import LanguageLearner

In [5]:
PATH=Path.home()/'data/wikitext-2-raw'

In [6]:
from fastai.basic_data import *
from fastai.torch_core import *
from fastai.layers import *

In [7]:
bptt = 512
data = load_data(PATH, bs=16, bptt=bptt)

In [8]:
MASK = 'xxmask'
vocab = data.vocab
vocab.itos.append(MASK)

In [9]:
word_range = (0, len(data.vocab.itos))
from fastai.text.transform import *
pad_idx = data.vocab.stoi[PAD]
mask_idx = data.vocab.stoi[MASK]
def bert_tfm(b, word_range=word_range, pad_idx=pad_idx, 
             mask_idx=mask_idx, p=0.2):
    # p = replacement probability
    x_lm,y_lm = b
    x_msk,y_msk = x_lm.clone(),x_lm.clone() # x, x
    rand = torch.rand(x_msk.shape, device=x_lm.device)
    y_msk[rand > p] = pad_idx
    x_msk[rand <= (p*.8)] = mask_idx # 80% = mask
    wrong_word = (rand > (p*.8)) & (rand <= (p*.9)) # 10% = wrong word
    x_msk[wrong_word] = torch.randint(*word_range, [wrong_word.sum().item()], device=x_lm.device)
    return x_msk, y_msk

In [10]:
data.train_dl.tfms = [bert_tfm]
data.valid_dl.tfms = [bert_tfm]

In [11]:

def acc_ignore_pad(input:Tensor, targ:Tensor, pad_idx=pad_idx)->Rank0Tensor:
    n = targ.shape[0]
    input = input.argmax(dim=-1).view(n,-1)
    targ = targ.view(n,-1)
    mask = targ != pad_idx
    return (input[mask]==targ[mask]).float().mean()

def bert_acc(input:Tensor, b_t:Tensor)->Rank0Tensor:
    return acc_ignore_pad(input, b_t)

In [12]:

class Decoder(nn.Module):
    "To go on top of a RNNCore module and create a Language Model."
    initrange=0.1

    def __init__(self, n_out:int, n_hid:int, output_p:float=0.0, tie_encoder:nn.Module=None, bias:bool=True):
        super().__init__()
        self.decoder = nn.Linear(n_hid, n_out, bias=bias)
        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
        self.output_dp = RNNDropout(output_p)
        if bias: self.decoder.bias.data.zero_()
        if tie_encoder: self.decoder.weight = tie_encoder.weight

    def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
        output = self.output_dp(input)
        decoded = self.decoder(output)
        return decoded

In [13]:
class TransformerBlock(nn.Module):
    "Basic block of a Transformer model."
    #Can't use Sequential directly cause more than one input...
    def __init__(self, d_attn:int, n_heads:int, d_head:int, d_out:int=None, bias:bool=True, ff_p=0.0):
        super().__init__()
        if d_out is None: d_out = d_attn
        d_inner = d_out*4
        self.mhra = MultiHeadRelativeAttention(n_heads, d_attn, d_head, bias=bias)
        
#         self.mhra = MultiHeadAttention(n_heads, d_attn, d_head, bias=bias)
        self.pos_enc = PositionalEncoding(d_attn)
        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        
        nn.init.normal_(self.u, 0., 0.02)
        nn.init.normal_(self.v, 0., 0.02)
        
        self.ln = nn.Sequential(*[
            nn.Linear(d_out, d_inner), 
            GeLU(),
            nn.Linear(d_inner, d_out),
            nn.Dropout(ff_p), 
        ])
        self.norm = nn.LayerNorm(d_out)
        
    def rel_attn(self, x:Tensor, mask:Tensor=None):
#         return self.mhra(x)
        pos = torch.arange(x.shape[1]-1, -1, -1, device=x.device, dtype=self.u.dtype)
        pos_enc = self.pos_enc(pos)
        return self.mhra(x, mask=mask, r=pos_enc, u=self.u, v=self.v)
    
    def forward(self, x:Tensor, mask:Tensor=None, **kwargs): 
        attn = self.rel_attn(x, mask=mask)
        res = self.ln(attn) + attn
        return self.norm(res)

In [61]:
class TransformerBlock2(nn.Module):
    "Basic block of a Transformer model."
    #Can't use Sequential directly cause more than one input...
    def __init__(self, d_attn:int, n_heads:int, d_head:int, d_out:int=None, bias:bool=False, ff_p=0.0):
        super().__init__()
        if d_out is None: d_out = d_attn
        d_inner = d_out*4
        self.mhra = MultiHeadRelativeAttention(n_heads, d_attn, d_head, bias=bias)
#         self.mhra = MultiHeadAttention(n_heads, d_attn, d_head, bias=bias)

        self.ln = nn.Sequential(*[
            nn.Linear(d_out, d_inner), 
            GeLU(),
            nn.Linear(d_inner, d_out),
            nn.Dropout(ff_p), 
        ])
        self.norm = nn.LayerNorm(d_out)
        
    def forward(self, x:Tensor, **kwargs): 
        attn = self.mhra(x, **kwargs)
        res = self.ln(attn) + attn
        return self.norm(res)

In [14]:
class DownsampleLayer(TransformerBlock):
    "Basic block of a Transformer model."
    #Can't use Sequential directly cause more than one input...
    def __init__(self, d_model:int, n_heads:int, d_head:int):
        d_out = d_model * 2
        super().__init__(d_model, n_heads, d_head, d_out)
        self.downblock = nn.Conv1d(d_model, d_out, (2), stride=2)
            
    def forward(self, x:Tensor):
        x_attn = self.rel_attn(x)
        
        x_d = self.downblock(x_attn.permute(0, 2, 1)).permute(0, 2, 1) # bptt x emb x bptt
        
        x1 = self.ln(x_d)
        
        return self.norm(x_attn + x1)

In [15]:
class UpsampleLayer(nn.Module):
    "Basic block of a Transformer model."
    #Can't use Sequential directly cause more than one input...
    def __init__(self, d_model:int, n_heads:int, d_head:int):
        d_out = d_model // 2
        super().__init__(d_model, n_heads, d_head, d_out)
        self.upblock = nn.ConvTranspose1d(d_model, d_out, (2), stride=2)

    
    def forward(self, x:Tensor, x_skip:Tensor, r=None):
        x_attn = self.rel_attn(x)
        
        x_u = self.upblock(x_attn.permute(0, 2, 1)).permute(0, 2, 1) # bptt x emb x bptt
        
        x1 = self.ln(x_u)
        
        return self.norm(x1 + x_skip)

## Transformer

In [16]:
class TransformerConv(nn.Module):
    "Transformer model: https://arxiv.org/abs/1706.03762."
    def __init__(self, vocab_sz:int, ctx_len:int, n_layers:int, n_heads:int, d_model:int, d_head:int, 
                 embed_p:float=0.0, **kwargs):
        super().__init__()
        self.encoder = nn.Embedding(vocab_sz, d_model)
        self.drop_emb = nn.Dropout(embed_p)
        # See https://arxiv.org/abs/1711.09160
        with torch.no_grad(): trunc_normal_(self.encoder.weight, std=0.01)
            
#         self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner) for k in range(n_layers)])
        self.a1 = DownsampleLayer(d_model, n_heads, d_head)
        self.a2 = DownsampleLayer(d_model*2, n_heads, d_head)
        self.a3 = DownsampleLayer(d_model*4, n_heads, d_head, downsample=False)
        self.a4 = UpsampleLayer(d_model*4, n_heads, d_head)
        self.a5 = UpsampleLayer(d_model*2, n_heads, d_head)
    
    def reset(self): pass
    
    def forward(self, x):
        bs, x_len = x.size()
        inp = self.drop_emb(self.encoder(x)) #.mul_(self.d_model ** 0.5)
#         mask = torch.triu(x.new_ones(x_len, x_len), diagonal=1).byte()[None,None]
        #[None,:,:None] for einsum implementation of attention
#         for layer in self.layers: inp = layer(inp, mask=mask)
#         print('Inp:', inp.shape)
        x1 = self.a1(inp)
#         print('x1:', x1.shape)
        x2 = self.a2(x1)
#         print('x2:', x2.shape)
        x3 = self.a3(x2)
        x4 = self.a4(x3, x1)
        x5 = self.a5(x4, inp)
        
#         print(x1.shape, x2.shape, x3.shape, x4.shape)
        return x5 #For the LinearDecoder

In [17]:
class TransformerBase(nn.Module):
    "Transformer model: https://arxiv.org/abs/1706.03762."
    def __init__(self, vocab_sz:int, ctx_len:int, n_layers:int, n_heads:int, d_model:int, d_head:int, 
                 embed_p:float=0.0, **kwargs):
        super().__init__()
        self.encoder = nn.Embedding(vocab_sz, d_model)
        # See https://arxiv.org/abs/1711.09160
        with torch.no_grad(): trunc_normal_(self.encoder.weight, std=0.01)
            
        self.drop_emb = nn.Dropout(embed_p)
        self.layers = nn.ModuleList([TransformerBlock(d_model, d_head, n_heads) for k in range(n_layers)])
    
    def reset(self): pass
    
    def forward(self, x):
        bs, x_len = x.size()
        inp = self.drop_emb(self.encoder(x)) #.mul_(self.d_model ** 0.5)
#         mask = torch.triu(x.new_ones(x_len, x_len), diagonal=1).byte()[None,None]
        for layer in self.layers: inp = layer(inp)
        return inp #For the LinearDecoder

In [18]:
vocab_sz = len(data.vocab.itos)
# config = tfmer_lm_config.copy(); config
config = {
    'ctx_len': bptt,
    'n_layers': 5,
    'n_heads': 4,
    'd_model': 128,
    'd_head': 32,
    'embed_p': 0.0 # Embed p needs to be 0 to match bert-mask baseline
}

In [19]:
# xb,yb = data.one_batch(cpu=False)
# model(xb)

## Baseline Fastai TXL

In [20]:
from fastai.text.learner import get_language_model
encoder = TransformerXL(vocab_sz, d_inner=config['d_model']*4, **config)
decoder = LinearDecoder(vocab_sz, config['d_model'], output_p=0.0, tie_encoder=encoder.encoder)
model = nn.Sequential(encoder, decoder)
model.reset = lambda: True
model.apply(init_transformer)

learn = LanguageLearner(data, model, loss_func=CrossEntropyFlat(ignore_index=pad_idx))
learn.metrics = [bert_acc]
# learn.to_fp16();


In [21]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,time
0,6.600806,6.338635,0.141968,00:31
1,6.386663,6.154504,0.143694,00:31


## Baseline Transformer Base

In [20]:
encoder = TransformerBase(vocab_sz, **config)
decoder = Decoder(vocab_sz, config['d_model'], tie_encoder=encoder.encoder)
model = nn.Sequential(encoder, decoder)
model.reset = lambda: True
model.apply(init_transformer)
learn = LanguageLearner(data, model, loss_func=CrossEntropyFlat(ignore_index=pad_idx))
learn.callbacks = []
learn.metrics = [bert_acc]
# learn.to_fp16();

In [21]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,time
0,6.611427,6.318287,0.143721,01:18
1,6.526922,6.277451,0.141834,01:18


In [21]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,time
0,6.577472,6.347648,0.142532,00:52
1,6.522913,6.308616,0.140825,00:52


In [26]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,time
0,6.491513,6.122727,0.163134,01:06
1,5.781649,5.540959,0.247909,01:06


In [None]:
# learn.lr_find(num_it=300)
# learn.recorder.plot()

## Sanity - Bert Encoder

## Conf transformer

In [62]:

class Embedder(nn.Module):
    "Embedding + positional encoding + dropout"
    def __init__(self, vocab_sz:int, emb_sz:int, embed_p:float=0., pad_idx=pad_idx):
        super().__init__()
        self.emb_sz = emb_sz
        
#         self.embed = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_idx)        
        self.embed = nn.Embedding(vocab_sz, emb_sz)
        # See https://arxiv.org/abs/1711.09160
        with torch.no_grad(): trunc_normal_(self.embed.weight, std=0.01)
        self.drop = nn.Dropout(embed_p)
    
    def forward(self, inp, pos_forward=False):
        emb = self.drop(self.embed(inp))
        return emb


In [63]:

class Encoder(nn.Module):
    "TransformerXL model: https://arxiv.org/abs/1901.02860."
    def __init__(self, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int, 
                 resid_p:float=0., attn_p:float=0., ff_p:float=0., bias:bool=False, scale:bool=True,
                 act:Activation=Activation.GeLU, double_drop:bool=True, attn_cls:Callable=MultiHeadRelativeAttention,
                 learned_pos_enc:bool=False, mask:bool=True, **kwargs):
        super().__init__()
        self.u = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.v = nn.Parameter(torch.Tensor(n_heads, 1, d_head)) #Remove 1 for einsum implementation of attention
        self.n_layers,self.d_model,self.mask = n_layers,d_model,mask
#         self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner, resid_p=resid_p, attn_p=attn_p,
#                       ff_p=ff_p, bias=bias, scale=scale, act=act, double_drop=double_drop, 
#                       attn_cls=attn_cls) for k in range(n_layers)])
        self.layers = nn.ModuleList([TransformerBlock2(d_model, n_heads, d_head, bias=bias, 
                      ) for k in range(n_layers)])
        
        self.pos_enc = PositionalEncoding(d_model)
        
        nn.init.normal_(self.u, 0., 0.02)
        nn.init.normal_(self.v, 0., 0.02)
    
    def forward(self, x):
        #The hidden state has to be initiliazed in the forward pass for nn.DataParallel
        bs,x_len,emb_sz = x.size()
        
        inp = x
        
        pos = torch.arange(x_len-1, -1, -1, device=inp.device, dtype=inp.dtype)
        pos_enc = self.pos_enc(pos)
        for i, layer in enumerate(self.layers):
            inp = layer(inp, r=pos_enc, u=self.u, v=self.v)
        return inp

In [64]:
embed = Embedder(vocab_sz, config['d_model'])
encoder = Encoder(d_inner=config['d_model']*4, **config)
decoder = Decoder(vocab_sz, config['d_model'], tie_encoder=embed.embed)
model = nn.Sequential(embed, encoder, decoder)
model.reset = lambda: True
model.apply(init_transformer)
learn = LanguageLearner(data, model, loss_func=CrossEntropyFlat(ignore_index=pad_idx))
learn.callbacks = []
learn.metrics = [bert_acc]
# learn.to_fp16();

In [65]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,time
0,6.224,5.695016,0.220923,00:32
1,5.153956,4.931585,0.329466,00:32


In [60]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,time
0,6.079302,5.548858,0.251355,00:32
1,5.276506,5.08905,0.307184,00:32


In [52]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,time
0,6.6163,6.340132,0.142449,00:32
1,6.531385,6.31705,0.139587,00:32


In [42]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,time
0,6.584211,6.247587,0.139904,00:32
1,6.103616,5.864691,0.190566,00:32


In [38]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,time
0,6.179782,5.618222,0.236285,00:32
1,5.1272,4.915281,0.332678,00:32


In [31]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,time
0,6.046856,5.586289,0.232654,00:33
1,5.476103,5.239614,0.268202,00:33


## Conv transformer

In [None]:
learn.fit_one_cycle(2, 1e-3)

epoch,train_loss,valid_loss,bert_acc,time


In [None]:
learn.fit_one_cycle(4, 1e-2)

In [24]:
encoder = TransformerConv(vocab_sz, **config)
decoder = Decoder(vocab_sz, config['d_model'], tie_encoder=encoder.encoder)
model = nn.Sequential(encoder, decoder)
model.reset = lambda: True
model.apply(init_transformer)
learn = LanguageLearner(data, model, loss_func=CrossEntropyFlat(ignore_index=pad_idx))
learn.metrics = [bert_acc]
# learn.to_fp16();

In [25]:
learn.fit_one_cycle(4, 1e-2)

epoch,train_loss,valid_loss,bert_acc,time
0,6.691076,6.477333,0.142961,00:30
1,6.534302,6.275502,0.179262,00:28
2,6.330779,6.108613,0.191137,00:28
3,6.290344,6.085637,0.19094,00:28
