In [1]:
# Re-import edited modules automatically before each cell runs, so library edits are picked up.
%reload_ext autoreload
%autoreload 2

In [2]:
import sys, os, re, csv, codecs, numpy as np, pandas as pd
from fastai import *        # Quick accesss to most common functionality
# from fastai.text import *   # Quick accesss to NLP functionality
import html
from pathlib import Path

In [3]:
import torch.nn as nn

In [4]:
from fastai.text.models.transformer import tfmer_lm_config, GeLU, init_transformer, MultiHeadAttention, PositionalEncoding
from fastai.text.models.awd_lstm import RNNDropout
from fastai.text.learner import LanguageLearner

In [5]:
# Root folder of the raw WikiText-2 dataset (relative to the user's home directory).
PATH = Path.home() / 'data' / 'wikitext-2-raw'

In [6]:
from fastai.basic_data import *
from fastai.torch_core import *
from fastai.layers import *

In [7]:
# Tokens per row of each language-model batch; the same value is passed to load_data.
bptt = 256
data = load_data(PATH, bptt=bptt, bs=8)

In [8]:
# Add a dedicated mask token for BERT-style masking.
# In fastai v1, `Vocab.stoi` is a separate mapping built from `itos` when the Vocab is created.
# Appending to `itos` alone therefore leaves `stoi[MASK]` unset, and because `stoi` is a
# defaultdict(int), a lookup would silently return 0 (the UNK id). Register the token in both
# mappings, and skip the append when this cell is re-run so MASK is never added twice.
MASK = 'xxmask'
vocab = data.vocab
if MASK not in vocab.itos:
    vocab.itos.append(MASK)
vocab.stoi[MASK] = vocab.itos.index(MASK)

In [9]:
# Random-replacement ids are drawn from [0, len(itos)); this range includes PAD and the MASK token appended above.
word_range = (0, len(data.vocab.itos))
from fastai.text.transform import *
pad_idx = data.vocab.stoi[PAD]
# Look MASK up in `itos`, not `stoi`. Appending to `itos` does not add the token to fastai's `stoi`,
# whose defaultdict lookup would then silently return 0. `list.index` raises ValueError if MASK is missing.
mask_idx = data.vocab.itos.index(MASK)
def bert_tfm(b, word_range=word_range, pad_idx=pad_idx,
             mask_idx=mask_idx, p=0.2):
    """BERT-style masking transform for a language-model batch.

    Only the first element of `b` (the input token ids) is used; the second is ignored.
    Each position is selected for prediction with probability `p`. Among selected positions,
    ~80% become `mask_idx`, ~10% become a random id drawn from `word_range`, and ~10% keep
    their original token.

    Returns `(inputs, targets)`. `targets` holds the original token at selected positions and
    `pad_idx` everywhere else, so a loss with `ignore_index=pad_idx` scores only selected positions.
    """
    tokens, _ = b
    device = tokens.device
    u = torch.rand(tokens.shape, device=device)  # one uniform draw per position decides its fate
    inputs, targets = tokens.clone(), tokens.clone()

    targets[u > p] = pad_idx                       # not selected -> ignored by the loss
    inputs[u <= p * .8] = mask_idx                 # lowest 80% of the selected band -> mask token
    randomized = (u > p * .8) & (u <= p * .9)      # next 10% -> random word
    inputs[randomized] = torch.randint(*word_range, [randomized.sum().item()], device=device)
    return inputs, targets                         # remaining band (p*.9, p] keeps its token

In [10]:
# Apply BERT masking to every training and validation batch (replaces any existing dataloader tfms).
data.train_dl.tfms, data.valid_dl.tfms = [bert_tfm], [bert_tfm]

In [11]:

def acc_ignore_pad(input:Tensor, targ:Tensor, pad_idx=pad_idx)->Rank0Tensor:
    """Accuracy of the argmax of `input` against `targ`, counting only positions where `targ != pad_idx`.

    Predictions and targets are flattened per batch row before comparison. If every target is
    `pad_idx`, the mean is taken over an empty tensor (NaN).
    """
    rows = targ.shape[0]
    preds = input.argmax(dim=-1).view(rows, -1)
    targs = targ.view(rows, -1)
    keep = targs.ne(pad_idx)
    return preds[keep].eq(targs[keep]).float().mean()

def bert_acc(input:Tensor, b_t:Tensor, lm_t:Tensor=None)->Rank0Tensor:
    """Accuracy on the masked positions only (targets equal to `pad_idx` are ignored).

    `bert_tfm` returns `(x_masked, y_masked)`, so each batch has a single target tensor and a
    metric is called as `bert_acc(output, b_t)`. `lm_t` is therefore optional and unused; it is
    kept so callers that pass a second target still work.
    """
    return acc_ignore_pad(input, b_t)

In [12]:

class LinearDecoder(nn.Module):
    """Linear head mapping the encoder's last hidden state to vocabulary logits.

    Sits on top of any encoder whose forward returns a `(raw_outputs, outputs)` pair of lists
    (as the Transformer in this notebook does); only `outputs[-1]` is decoded.

    Args:
        n_out: output size (vocabulary size).
        n_hid: size of the hidden state fed in.
        tie_encoder: optional module (e.g. the input `nn.Embedding`) whose `weight` is shared
            with the output projection.
        bias: whether the output projection has a bias term.
    forward returns `(decoded, raw_outputs, outputs)`.
    """
    initrange=0.1

    def __init__(self, n_out:int, n_hid:int, tie_encoder:nn.Module=None, bias:bool=True):
        super().__init__()
        self.decoder = nn.Linear(n_hid, n_out, bias=bias)
        self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
        if bias: self.decoder.bias.data.zero_()
        # Compare against None explicitly. A module that defines __len__ (e.g. an empty container)
        # is falsy, and truthiness is not the question "was an encoder passed?".
        if tie_encoder is not None: self.decoder.weight = tie_encoder.weight

    def forward(self, input:Tuple[Tensor,Tensor])->Tuple[Tensor,Tensor,Tensor]:
        raw_outputs, outputs = input
        decoded = self.decoder(outputs[-1])
        return decoded, raw_outputs, outputs


In [13]:
class DecoderLayer(nn.Module):
    "Basic block of a Transformer model: multi-head attention followed by a `feed_forward` sub-block."
    # Not an nn.Sequential because forward takes extra arguments (the attention mask and kwargs).
    def __init__(self, n_heads:int, d_model:int, d_head:int, d_inner:int, bias:bool=True):
        super().__init__()
        self.mhra = MultiHeadAttention(n_heads, d_model, d_head, bias=bias)
        self.ff   = feed_forward(d_model, d_inner)

    def forward(self, x:Tensor, mask:Tensor=None, **kwargs):
        return self.ff(self.mhra(x, mask=mask, **kwargs))

In [14]:

# def feed_forward(d_model:int, d_ff:int, ff_p:float=0., act=nn.ReLU, double_drop:bool=True):
def feed_forward(d_model:int, d_inner:int, ff_p:float=0.1):
    """Position-wise feed-forward sub-block (used by `DecoderLayer`).

    Expands `d_model -> d_inner` with GeLU, projects back to `d_model`, applies dropout (`ff_p`),
    merges with the block input via `MergeLayer`, then applies LayerNorm. It is built as a fastai
    `SequentialEx` so that `MergeLayer` can reach the block's input (presumably a residual add;
    see fastai.layers).
    """
    return SequentialEx(
        nn.Linear(d_model, d_inner),
        GeLU(),
        nn.Linear(d_inner, d_model),
        nn.Dropout(ff_p),
        MergeLayer(),
        nn.LayerNorm(d_model),
    )


In [15]:
class DownsampleLayer(nn.Module):
    """Encoder-side block: self-attention, optional 2x sequence downsampling, then a residual feed-forward.

    With `downsample=True`, (bs, seq, d_model) -> (bs, seq//2, 2*d_model) via a kernel-2 / stride-2
    Conv1d; otherwise the shape is unchanged. The feed-forward branch follows the same
    Linear -> GeLU -> Linear -> Dropout -> residual add -> LayerNorm order as `feed_forward`.
    """
    def __init__(self, d_model:int, n_heads:int, d_head:int, ff_p:float=0.1, downsample=True):
        super().__init__()

        self.mhra = MultiHeadAttention(n_heads, d_model, d_head, bias=True)

        d_inner = d_model*4
        if downsample:
            d_out = d_model * 2
            self.downblock = nn.Conv1d(d_model, d_out, (2), stride=2)
        else:
            self.downblock = None
            d_out = d_model

        self.ln1 = nn.Linear(d_out, d_inner)
        self.act = GeLU()
        self.ln2 = nn.Linear(d_inner, d_out)
        self.drop = nn.Dropout(ff_p)
        self.norm = nn.LayerNorm(d_out)

    def forward(self, x:Tensor):
        x = self.mhra(x)

        if self.downblock is not None:
            # Conv1d is channels-first:
            # (bs, seq, d_model) -> (bs, d_model, seq) -> (bs, 2*d_model, seq//2) -> (bs, seq//2, 2*d_model)
            x = self.downblock(x.permute(0, 2, 1)).permute(0, 2, 1)

        ff_out = self.ln2(self.act(self.ln1(x)))
        # Dropout goes on the feed-forward branch only, so the residual path `x` is never zeroed in training.
        return self.norm(x + self.drop(ff_out))

In [16]:
class UpsampleLayer(nn.Module):
    """Decoder-side block: self-attention, optional 2x sequence upsampling, feed-forward, then a skip add.

    With `upsample=True`, (bs, seq, d_model) -> (bs, 2*seq, d_model//2) via a kernel-2 / stride-2
    ConvTranspose1d. `x_skip` must already have the resulting shape; it is added after the LayerNorm.
    NOTE(review): unlike DownsampleLayer there is no residual around the feed-forward branch, and the
    output is not re-normalised after the skip add. Confirm this is intentional.
    """
    def __init__(self, d_model:int, n_heads:int, d_head:int, ff_p:float=0.1, upsample=True):
        super().__init__()
        self.mhra = MultiHeadAttention(n_heads, d_model, d_head, bias=True)

        d_inner = d_model*4
        if upsample:
            d_out = d_model // 2
            self.upblock = nn.ConvTranspose1d(d_model, d_out, (2), stride=2)
        else:
            d_out = d_model
            self.upblock = None

        self.ln1 = nn.Linear(d_out, d_inner)
        self.act = GeLU()
        self.ln2 = nn.Linear(d_inner, d_out)
        self.drop = nn.Dropout(ff_p)
        self.norm = nn.LayerNorm(d_out)

    def forward(self, x:Tensor, x_skip:Tensor):
        h = self.mhra(x)
        if self.upblock is not None:
            # ConvTranspose1d is channels-first: move features to dim 1 and back again.
            h = self.upblock(h.permute(0, 2, 1)).permute(0, 2, 1)
        h = self.norm(self.drop(self.ln2(self.act(self.ln1(h)))))
        return h + x_skip

## Testing the downsample / upsample layers

In [17]:
# Shape check for DownsampleLayer on a dummy batch.
# NOTE(review): this rebinds the notebook-level `bs`, `bptt` and `d_model` (`bptt` was 256 for load_data),
# and `config['ctx_len']` further down reads `bptt`. Consider distinct names for scratch cells.
bs,bptt,d_model = 4, 64, 128
# d1 = DecoderLayer(n_heads=4, d_model=d_model, d_head=32, d_inner=512)
d1 = DownsampleLayer(n_heads=4, d_model=d_model, d_head=32, downsample=True)
# NOTE(review): this mask is not used by the next cell (d1 is called without a mask); it is rebuilt in "Components".
mask = torch.triu(torch.ones(bptt, bptt), diagonal=1).byte()[None,None]
xb = torch.ones(bs,bptt,d_model)

In [18]:
# Downsampling halves the sequence length and doubles the feature size.
xb.shape, d1(xb).shape

(torch.Size([4, 64, 128]), torch.Size([4, 32, 256]))

In [19]:
# Shape check for UpsampleLayer. x_skip must already have the upsampled shape (bs, 2*bptt, d_model//2).
# NOTE(review): rebinds the notebook-level `bs`, `bptt` and `d_model` again.
bs,bptt,d_model = 4, 64, 128
# d1 = DecoderLayer(n_heads=4, d_model=d_model, d_head=32, d_inner=512)
u1 = UpsampleLayer(n_heads=4, d_model=d_model, d_head=32, upsample=True)
xb = torch.ones(bs,bptt,d_model)
x_skip = torch.ones(bs,bptt*2,d_model//2)

In [20]:
# Shapes of the two inputs to the upsample check.
xb.shape, x_skip.shape

(torch.Size([4, 64, 128]), torch.Size([4, 128, 64]))

In [21]:
# ConvTranspose1d works channels-first: permute (bs, seq, d_model) -> (bs, d_model, seq) before upsampling.
xb.permute(0,2,1).shape, u1.upblock(xb.permute(0,2,1)).shape

(torch.Size([4, 128, 64]), torch.Size([4, 64, 128]))

In [22]:
# Permute back to (bs, seq, features): sequence doubled, features halved.
u1.upblock(xb.permute(0,2,1)).permute(0,2,1).shape

torch.Size([4, 128, 64])

In [23]:
# The full UpsampleLayer output matches x_skip's shape, so the skip connection can be added.
xb.shape, u1(xb, x_skip).shape

(torch.Size([4, 64, 128]), torch.Size([4, 128, 64]))

Reshape to halve the sequence length (and double the feature width)

In [24]:
# Toy 3-D tensor to illustrate the reshape below.
a = torch.arange(40).view(2, 4, 5); a

tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]]])

In [25]:
# view(2, 2, 10) concatenates each pair of consecutive rows: half as many rows, each twice as wide.
a.view(2, 2, 10)

tensor([[[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34, 35, 36, 37, 38, 39]]])

In [26]:
# NOTE(review): `d` is not referenced by any cell below; candidate for removal.
d = nn.Linear(d_model*2, d_model)

Conv1d example: halving the sequence length with stride 2

In [27]:
# Conv1d with kernel 3, stride 2, padding 1: sequence length 64 -> 32.
# NOTE(review): `input` shadows the Python builtin in the notebook namespace.
m = nn.Conv1d(128, 128, kernel_size=(3), stride=2, padding=3//2)
input = torch.randn(4, 128, 64)
output = m(input)
input.shape, output.shape

(torch.Size([4, 128, 64]), torch.Size([4, 128, 32]))

In [28]:
# Kernel 2, stride 2, no padding also halves the length; this is the downblock configuration in DownsampleLayer.
m = nn.Conv1d(128, 128, kernel_size=2, stride=2, padding=0)
input = torch.randn(4, 128, 64)
output = m(input)
input.shape, output.shape

(torch.Size([4, 128, 64]), torch.Size([4, 128, 32]))

## ConvTranspose1d (upsampling)

In [29]:
# Transposed conv with stride 1 and kernel 9: output length = (8 - 1) * 1 + 9 = 16.
m = nn.ConvTranspose1d(5, 5, (9), stride=1, padding=0)
input = torch.randn(1, 5, 8)
output = m(input)
input.shape, output.shape

(torch.Size([1, 5, 8]), torch.Size([1, 5, 16]))

In [30]:
# Kernel 2, stride 2 (the upblock configuration in UpsampleLayer): doubles the sequence length.
m = nn.ConvTranspose1d(128, 128, (2), stride=2)
input = torch.randn(4, 128, 32)
output = m(input)

In [31]:
# Length 32 -> 64.
input.shape, output.shape

(torch.Size([4, 128, 32]), torch.Size([4, 128, 64]))

### Components

In [32]:
# Build the two sub-blocks of a standard Transformer layer separately and inspect their shapes.
# NOTE(review): rebinds the notebook-level `bs`, `bptt` and `d_model` once more (bptt = 64 from here on).
bs,bptt,d_model = 4, 64, 128
# d1 = DecoderLayer(n_heads=4, d_model=d_model, d_head=32, d_inner=512)
mhra = MultiHeadAttention(n_heads=4, d_model=d_model, d_head=32)
ff   = feed_forward(d_model=d_model, d_inner=512)

mask = torch.triu(torch.ones(bptt, bptt), diagonal=1).byte()[None,None]
xb = torch.ones(bs,bptt,d_model)

In [33]:
# Upper-triangular mask: 1 above the diagonal (presumably the future positions MultiHeadAttention blocks).
mask

tensor([[[[0, 1, 1,  ..., 1, 1, 1],
          [0, 0, 1,  ..., 1, 1, 1],
          [0, 0, 0,  ..., 1, 1, 1],
          ...,
          [0, 0, 0,  ..., 0, 1, 1],
          [0, 0, 0,  ..., 0, 0, 1],
          [0, 0, 0,  ..., 0, 0, 0]]]], dtype=torch.uint8)

In [34]:
# Attention and feed-forward both preserve the (bs, seq, d_model) shape.
x_attn = mhra(xb, mask=mask)
x_ff = ff(x_attn)
x_attn.shape, x_ff.shape

(torch.Size([4, 64, 128]), torch.Size([4, 64, 128]))

## Transformer

In [35]:
class Transformer(nn.Module):
    """U-Net-shaped Transformer encoder for masked language modelling (loosely after https://arxiv.org/abs/1706.03762).

    Token + positional embeddings go through two DownsampleLayers (each halves the sequence length and
    doubles the width) and a shape-preserving DownsampleLayer. Two UpsampleLayers then restore the
    original length and width, adding skip connections from the matching encoder-side activations.
    The sequence length must be divisible by 4.

    `ctx_len`, `n_layers`, `d_inner` and extra `**kwargs` are accepted for config compatibility but unused.
    forward returns `([x5], [x5])`, the `(raw_outputs, outputs)` pair expected by LinearDecoder.
    """
    def __init__(self, vocab_sz:int, ctx_len:int, n_layers:int, n_heads:int, d_model:int, d_head:int, d_inner:int,
                 embed_p:float=0.1, **kwargs):
        super().__init__()
        self.encoder = nn.Embedding(vocab_sz, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.drop_emb = nn.Dropout(embed_p)
        self.a1 = DownsampleLayer(d_model, n_heads, d_head)                      # -> (bs, seq/2, 2*d_model)
        self.a2 = DownsampleLayer(d_model*2, n_heads, d_head)                    # -> (bs, seq/4, 4*d_model)
        self.a3 = DownsampleLayer(d_model*4, n_heads, d_head, downsample=False)  # shape unchanged
        self.a4 = UpsampleLayer(d_model*4, n_heads, d_head)                      # -> (bs, seq/2, 2*d_model), + x1
        self.a5 = UpsampleLayer(d_model*2, n_heads, d_head)                      # -> (bs, seq, d_model), + embeddings

    def reset(self): pass

    def forward(self, x):
        _, x_len = x.size()
        # Two stride-2 downsamples followed by two 2x upsamples only line up with the skip tensors
        # when the length is a multiple of 4. Fail clearly here instead of with a shape error in an add.
        if x_len % 4:
            raise ValueError(f"Transformer: sequence length must be divisible by 4, got {x_len}")
        pos = torch.arange(0, x_len, device=x.device, dtype=x.dtype).float()
        inp = self.drop_emb(self.encoder(x) + self.pos_enc(pos)[None])
        x1 = self.a1(inp)
        x2 = self.a2(x1)
        x3 = self.a3(x2)
        x4 = self.a4(x3, x1)
        x5 = self.a5(x4, inp)
        return ([x5],[x5]) #For the LinearDecoder

In [36]:
vocab_sz = len(data.vocab.itos)
# Hyper-parameters for the U-Net Transformer (fastai's tfmer_lm_config is not used).
# NOTE(review): `bptt` here is the value last assigned by the scratch shape-check cells above (64),
# not the 256 passed to load_data. This is harmless today because Transformer ignores ctx_len, n_layers and d_inner.
config = dict(
    ctx_len=bptt,
    n_layers=4,
    n_heads=4,
    d_model=128,
    d_head=32,
    d_inner=512,
)

In [37]:
# Small throwaway copies of the Transformer's layers, used to trace shapes step by step on a real batch.
# NOTE(review): rebinds the notebook-level `d_model`, `d_head` and `n_heads`.
d_model = 16
d_head = 4
n_heads = 4
emb = nn.Embedding(vocab_sz, d_model) 
a1 = DownsampleLayer(d_model, n_heads, d_head) 
a2 = DownsampleLayer(d_model*2, n_heads, d_head)
a3 = DownsampleLayer(d_model*4, n_heads, d_head, downsample=False)
a4 = UpsampleLayer(d_model*4, n_heads, d_head)
a5 = UpsampleLayer(d_model*2, n_heads, d_head)

In [38]:
# Grab one batch on the CPU to trace shapes through the layers above.
xb,yb = data.one_batch(cpu=True)

In [39]:
# Token embeddings: (bs, bptt, d_model).
inp = emb(xb); inp.shape

torch.Size([8, 256, 16])

In [40]:
# After a1: sequence length halved, features doubled.
x1 = a1(inp); x1.shape

torch.Size([8, 128, 32])

In [41]:
# After a2: halved and doubled again.
x2 = a2(x1); x2.shape

torch.Size([8, 64, 64])

In [42]:
# a3 was built with downsample=False: shape unchanged.
x3 = a3(x2); x3.shape

torch.Size([8, 64, 64])

In [43]:
# a4 upsamples back to x1's shape, so x1 can be added as the skip connection.
x4 = a4(x3, x1); x4.shape

torch.Size([8, 128, 32])

In [44]:
# Full model: U-Net Transformer encoder + linear decoder whose weight is tied to the input embedding.
encoder = Transformer(vocab_sz, **config)
decoder = LinearDecoder(vocab_sz, config['d_model'], tie_encoder=encoder.encoder, bias=False)
model = nn.Sequential(encoder, decoder)
# nn.Sequential has no `reset`; add a no-op one (presumably the learner calls model.reset()).
model.reset = lambda: True
model.apply(init_transformer)

Sequential(
  (0): Transformer(
    (encoder): Embedding(39881, 128)
    (pos_enc): PositionalEncoding()
    (drop_emb): Dropout(p=0.1)
    (a1): DownsampleLayer(
      (mhra): MultiHeadAttention(
        (attention): Linear(in_features=128, out_features=384, bias=True)
        (out): Linear(in_features=128, out_features=128, bias=True)
        (drop_att): Dropout(p=0.0)
        (drop_res): Dropout(p=0.0)
        (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (downblock): Conv1d(128, 256, kernel_size=(2,), stride=(2,))
      (ln1): Linear(in_features=256, out_features=512, bias=True)
      (act): GeLU()
      (ln2): Linear(in_features=512, out_features=256, bias=True)
      (drop): Dropout(p=0.1)
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (a2): DownsampleLayer(
      (mhra): MultiHeadAttention(
        (attention): Linear(in_features=256, out_features=384, bias=True)
        (out): Linear(in_features=128, out_features=256, b

In [45]:
# Forward-pass smoke test: returns (decoded logits, raw_outputs, outputs) from LinearDecoder.
model(xb)

(tensor([[[-0.1492, -0.1778,  0.1127,  ...,  0.4905,  0.2945,  0.2801],
          [ 0.2332,  0.1924, -0.0607,  ...,  0.5226, -0.0131,  0.1304],
          [-0.0346, -0.5041,  0.3645,  ..., -0.0707, -0.1599,  0.1048],
          ...,
          [-0.2043,  0.1740, -0.1442,  ..., -0.1785,  0.2258, -0.3501],
          [-0.2488, -0.2424,  0.1111,  ..., -0.4120, -0.2610, -0.3826],
          [-0.0871,  0.2557, -0.2549,  ...,  0.1698,  0.2015, -0.5001]],
 
         [[ 0.4238, -0.3886,  0.3914,  ...,  0.4161, -0.0464,  0.2260],
          [ 0.3174,  0.1404, -0.1269,  ...,  0.4123,  0.2232,  0.5896],
          [-0.0503, -0.6630,  0.4550,  ..., -0.1371, -0.0544,  0.3313],
          ...,
          [-0.4075,  0.1626, -0.3683,  ..., -0.4111,  0.4170, -0.4185],
          [ 0.4170,  0.3523,  0.2773,  ..., -0.4189,  0.0222, -0.1661],
          [-0.1041,  0.0805, -0.0236,  ...,  0.1391,  0.4103, -0.4348]],
 
         [[ 0.4861, -0.1117,  0.4745,  ...,  0.3369,  0.0152,  0.3588],
          [ 0.1431, -0.0468,

In [46]:
# Only masked positions contribute to the loss: all other targets are pad_idx and are ignored.
# NOTE(review): no `metrics` are passed here, and the training log below reports a column named
# `accuracy` rather than bert_acc, so the reported number is not the masked-token accuracy.
learn = LanguageLearner(data, model, loss_func=CrossEntropyFlat(ignore_index=pad_idx))

In [None]:
# NOTE(review): this cell shows `In [None]`, so it was not run before the fit below.
# Run it (or pass metrics to LanguageLearner) before training to report masked-token accuracy.
learn.metrics = [bert_acc]

In [47]:
# One epoch of one-cycle training with max learning rate 1e-3.
learn.fit_one_cycle(1, 1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,6.366689,6.109301,0.031494,01:02
