In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import sys, os, re, csv, codecs, numpy as np, pandas as pd
from fastai import *        # Quick accesss to most common functionality
# from fastai.text import *   # Quick accesss to NLP functionality
import html
from pathlib import Path

In [3]:
import torch.nn as nn

In [4]:
from fastai.text.models.transformer import tfmer_lm_config, GeLU, init_transformer
from fastai.text.models.awd_lstm import RNNDropout
from fastai.text.learner import LanguageLearner

In [5]:
PATH=Path.home()/'data/wikitext-2-raw'

In [6]:
from fastai.basic_data import *
from fastai.torch_core import *
from fastai.layers import *

In [7]:
bptt = 256
data = load_data(PATH, bs=8, bptt=bptt)

In [8]:
# model = get_language_model(TransformerXL, len(data.vocab.itos))

In [9]:
class PositionalEncoding(nn.Module):
    "Encode the position with a sinusoid."
    def __init__(self, d:int):
        super().__init__()
        self.register_buffer('freq', 1 / (10000 ** (torch.arange(0., d, 2.)/d)))
    
    def forward(self, pos:Tensor):
        inp = torch.ger(pos, self.freq)
        enc = torch.cat([inp.sin(), inp.cos()], dim=-1)
        return enc

In [10]:
import pdb

In [11]:

class MultiHeadAttention(nn.Module):
    "MutiHeadAttention."
    
    def __init__(self, n_heads:int, d_model:int, d_head:int=None, resid_p:float=0.1, attn_p:float=0., bias:bool=True,
                 scale:bool=True, sigmoid=False):
        super().__init__()
        d_head = ifnone(d_head, d_model//n_heads)
        self.n_heads,self.d_head,self.scale = n_heads,d_head,scale
        self.attention = nn.Linear(d_model, 3 * n_heads * d_head, bias=bias)
        self.out = nn.Linear(n_heads * d_head, d_model, bias=bias)
        self.drop_att,self.drop_res = nn.Dropout(attn_p),nn.Dropout(resid_p)
        self.ln = nn.LayerNorm(d_model)
        self.sigmoid = sigmoid
        
    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        return self.ln(x + self.drop_res(self.out(self._apply_attention(x, mask=mask, **kwargs))))
    
    def _apply_attention(self, x:Tensor, mask:Tensor=None):
        bs,x_len = x.size(0),x.size(1)
        wq,wk,wv = torch.chunk(self.attention(x), 3, dim=-1)
        wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        wq,wk,wv = wq.permute(0, 2, 1, 3),wk.permute(0, 2, 3, 1),wv.permute(0, 2, 1, 3)
        attn_score = torch.matmul(wq, wk)
        if self.scale: attn_score = attn_score.div_(self.d_head ** 0.5)
        if mask is not None: 
            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
            
        if self.sigmoid: attn_prob = self.drop_att(torch.sigmoid(attn_score))
        else: attn_prob = self.drop_att(F.softmax(attn_score, dim=-1))
            
#         pdb.set_trace()
        attn_vec = torch.matmul(attn_prob, wv)
        return attn_vec.permute(0, 2, 1, 3).contiguous().contiguous().view(bs, x_len, -1)
        
    def _attention_einsum(self, x, mask=None):
        # Permute and matmul is a little bit faster but this implementation is more readable
        bs,x_len = x.size(0),x.size(1)
        wq,wk,wv = torch.chunk(self.attention(x), 3, dim=-1)
        wq,wk,wv = map(lambda x:x.view(bs, x.size(1), self.n_heads, self.d_head), (wq,wk,wv))
        attn_score = torch.einsum('bind,bjnd->bijn', (wq, wk))
        if self.scale: attn_score = attn_score.mul_(1/(self.d_head ** 0.5))
        if mask is not None: 
            attn_score = attn_score.float().masked_fill(mask, -float('inf')).type_as(attn_score)
        attn_prob = self.drop_att(F.softmax(attn_score, dim=2))
        attn_vec = torch.einsum('bijn,bjnd->bind', (attn_prob, wv))
        return attn_vec.contiguous().view(bs, x_len, -1)


In [12]:

class LinearDecoder(nn.Module):
    "To go on top of a RNNCore module and create a Language Model."
    initrange=0.1

    def __init__(self, n_out:int, n_hid:int, tie_encoder:nn.Module=None, bias:bool=True):
        super().__init__()
        self.decoder = nn.Linear(n_hid, n_out, bias=bias)
        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
        if bias: self.decoder.bias.data.zero_()
        if tie_encoder: self.decoder.weight = tie_encoder.weight

    def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
        raw_outputs, outputs = input
        decoded = self.decoder(outputs[-1])
        return decoded, raw_outputs, outputs


In [13]:

# def feed_forward(d_model:int, d_ff:int, ff_p:float=0., act=nn.ReLU, double_drop:bool=True):
def feed_forward(d_model:int, d_inner:int, ff_p:float=0.1):
    layers = [
        nn.Linear(d_model, d_inner), 
        GeLU(),
        nn.Linear(d_inner, d_model), 
        nn.Dropout(ff_p), 
        MergeLayer(),
        nn.LayerNorm(d_model)
    ]
    return SequentialEx(*layers)


In [14]:
class DecoderLayer(nn.Module):
    "Basic block of a Transformer model."
    #Can't use Sequential directly cause more than one input...
    def __init__(self, n_heads:int, d_model:int, d_head:int, d_inner:int, bias:bool=True):
        super().__init__()
        self.mhra = MultiHeadAttention(n_heads, d_model, d_head, bias=bias)
        self.ff   = feed_forward(d_model, d_inner)
    
    def forward(self, x:Tensor, mask:Tensor=None, **kwargs): 
        attn = self.mhra(x, mask=mask, **kwargs)
        res = self.ff(attn)
        return res

### Components

In [15]:
bs,bptt,d_model = 4, 64, 128
# d1 = DecoderLayer(n_heads=4, d_model=d_model, d_head=32, d_inner=512)
mhra = MultiHeadAttention(n_heads=4, d_model=d_model, d_head=32)
ff   = feed_forward(d_model=d_model, d_inner=512)

mask = torch.triu(torch.ones(bptt, bptt), diagonal=1).byte()[None,None]
xb = torch.ones(bs,bptt,d_model)

In [16]:
mask

tensor([[[[0, 1, 1,  ..., 1, 1, 1],
          [0, 0, 1,  ..., 1, 1, 1],
          [0, 0, 0,  ..., 1, 1, 1],
          ...,
          [0, 0, 0,  ..., 0, 1, 1],
          [0, 0, 0,  ..., 0, 0, 1],
          [0, 0, 0,  ..., 0, 0, 0]]]], dtype=torch.uint8)

In [17]:
x_attn = mhra(xb, mask=mask)
x_ff = ff(x_attn)
x_attn.shape, x_ff.shape

(torch.Size([4, 64, 128]), torch.Size([4, 64, 128]))

In [18]:
# x_attn

## Transformer

In [19]:

class Transformer(nn.Module):
    "Transformer model: https://arxiv.org/abs/1706.03762."
    def __init__(self, vocab_sz:int, ctx_len:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int, 
                 embed_p:float=0.1, learned_pos_enc:bool=False, mask=True, **kwargs):
        super().__init__()
        self.mask = mask
        self.encoder = nn.Embedding(vocab_sz, d_model)
#         self.pos_enc = nn.Embedding(ctx_len, d_model) if learned_pos_enc else PositionalEncoding(d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.drop_emb = nn.Dropout(embed_p)
        self.layers = nn.ModuleList([DecoderLayer(n_heads, d_model, d_head, d_inner) for k in range(n_layers)])
        
#         self.layers[0].mhra.scale = False
#         self.layers[0].mhra.sigmoid = True
#         self.layers[1].mhra.scale = False
#         self.layers[1].mhra.sigmoid = True
        
        self.layers[2].mhra.scale = False
        self.layers[2].mhra.sigmoid = True
        self.layers[3].mhra.scale = False
        self.layers[3].mhra.sigmoid = True
    
    def reset(self): pass
    
    def forward(self, x):
        bs, x_len = x.size()
        pos = torch.arange(0, x_len, device=x.device, dtype=x.dtype).float()
        inp = self.drop_emb(self.encoder(x) + self.pos_enc(pos)[None]) #.mul_(self.d_model ** 0.5)
        mask = torch.triu(x.new_ones(x_len, x_len), diagonal=1).byte()[None,None] if self.mask else None
        #[None,:,:None] for einsum implementation of attention
        for layer in self.layers: inp = layer(inp, mask=mask)
        return ([inp],[inp]) #For the LinearDecoder

In [20]:
vocab_sz = len(data.vocab.itos)
# config = tfmer_lm_config.copy(); config
config = {
    'ctx_len': bptt,
    'n_layers': 4,
    'n_heads': 4,
    'd_model': 128,
    'd_head': 32,
    'd_inner': 512,
}

In [21]:
encoder = Transformer(vocab_sz, **config)
decoder = LinearDecoder(vocab_sz, config['d_model'], tie_encoder=encoder.encoder, bias=False)
model = nn.Sequential(encoder, decoder)
model.reset = lambda: True
model.apply(init_transformer)

Sequential(
  (0): Transformer(
    (encoder): Embedding(39880, 128)
    (pos_enc): PositionalEncoding()
    (drop_emb): Dropout(p=0.1)
    (layers): ModuleList(
      (0): DecoderLayer(
        (mhra): MultiHeadAttention(
          (attention): Linear(in_features=128, out_features=384, bias=True)
          (out): Linear(in_features=128, out_features=128, bias=True)
          (drop_att): Dropout(p=0.0)
          (drop_res): Dropout(p=0.1)
          (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        )
        (ff): SequentialEx(
          (layers): ModuleList(
            (0): Linear(in_features=128, out_features=512, bias=True)
            (1): GeLU()
            (2): Linear(in_features=512, out_features=128, bias=True)
            (3): Dropout(p=0.1)
            (4): MergeLayer()
            (5): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
      (1): DecoderLayer(
        (mhra): MultiHeadAttention(
          (attention): Linear

In [22]:
learn = LanguageLearner(data, model)

In [23]:
# Half softmax - first half sigmoid
learn.fit_one_cycle(4, 1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,6.55087,6.318259,0.14218,00:44
1,5.978098,5.724125,0.180522,00:44
2,5.771122,5.530916,0.195498,00:44
3,5.699901,5.484363,0.204617,00:44


In [23]:
# Half softmax - first half sigmoid
learn.fit_one_cycle(4, 1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,6.179493,5.877147,0.158758,00:44
1,5.764875,5.478712,0.216442,00:44
2,5.542889,5.310651,0.234694,00:44
3,5.456105,5.266305,0.240963,00:44


In [23]:
# With softmax attn
learn.fit_one_cycle(4, 1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,6.367531,6.085203,0.149335,00:44
1,5.758986,5.503888,0.213588,00:44
2,5.531856,5.301187,0.235883,00:44
3,5.458888,5.25178,0.241699,00:44


In [22]:
# With sigmoid attn - no scale
learn.fit_one_cycle(4, 1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,6.163918,5.881194,0.163716,00:44
1,5.722321,5.491021,0.208866,00:44
2,5.591651,5.395696,0.20763,00:44
3,5.547757,5.366143,0.214875,00:44
