In [None]:
import math
import re
from random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from dotted_dict import DottedDict

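### Resources
- BERT-pytorch reference implementation: https://github.com/codertimo/BERT-pytorch/
- The Annotated Transformer (Harvard NLP): https://nlp.seas.harvard.edu/2018/04/03/attention.html
- The Illustrated Transformer (Jay Alammar): https://jalammar.github.io/illustrated-transformer/
- How to code BERT using PyTorch (neptune.ai tutorial): https://neptune.ai/blog/how-to-code-bert-using-pytorch-tutorial
- BERT paper (Devlin et al., 2018): https://arxiv.org/abs/1810.04805
- Unmasking BERT, including the masking procedure (neptune.ai): https://neptune.ai/blog/unmasking-bert-transformer-model-performance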

### Config

In [None]:
config = DottedDict()
config.seed = 0          # seed for the numpy / torch RNGs (masking, weight init, torch.rand demos)
config.batch_size = 4    # number of sentences per batch
config.pred_min = 2      # min number of masked tokens [MSK] per sentence
config.pred_max = 4      # max number of masked tokens per sentence
config.pred_freq = 0.15  # number of mask tokens ~= pred_freq * sentence length, clamped to [pred_min, pred_max]
config.d_model = 8       # embedding dimension of tokens and positions
config.d_k = 5           # per-head dimension of queries and keys
config.d_q = 5           # NOTE(review): not read by the code below; queries are projected with d_k
config.d_v = 8           # per-head dimension of values
config.d_ff = 4 * config.d_model  # hidden size of the position-wise feed-forward net
config.n_heads = 3       # number of attention heads
config.d_l = 10          # padded sequence length, including the leading [CLS] token

# Seed every RNG the notebook draws from so "Restart & Run All" is reproducible.
np.random.seed(config.seed)
torch.manual_seed(config.seed)

### Preprocessing

In [None]:
# Toy corpus: every "sentence" is a string over a three-letter alphabet and is
# tokenized character by character further below.
corpus = tuple(
    """
    baabaac aababc bcaaaa aac bbbaabbaa
    bbbbbbabc ababc babc bcaaca aabbaaac
    """.split()
)
# All characters that occur in the corpus, sorted -> ('a', 'b', 'c').
# (The misspelled name is kept because later cells reference it.)
corpus_vocabuary = tuple(sorted(set("".join(corpus))))

In [None]:
# Special tokens:
#   [PAD]: fills every sequence up to the common length d_l so batches stack
#   [MSK]: replaces, in the input, the tokens the model has to predict
#   [CLS]: placed at position 0 of every sequence; its representation can
#          stand for the whole sentence
spec_tok_dict = {'[PAD]': 0, '[MSK]': 1, '[CLS]': 2}
spec_idx_dict = {idx: word for word, idx in spec_tok_dict.items()}

# Full vocabulary: the special tokens keep their indices and the corpus
# characters are numbered consecutively right after them.
tok_dict = {
    **spec_tok_dict,
    **{tok: offset + len(spec_tok_dict) for offset, tok in enumerate(corpus_vocabuary)},
}
# Inverse lookup (index -> token) and the tokens in insertion (== index) order.
# NOTE: the name `tok_list` is reused for a token-index array in a later cell.
idx_dict = {idx: tok for tok, idx in tok_dict.items()}
tok_list = list(tok_dict)

d_vocabulary = len(tok_list)

In [None]:
# Vocabulary size (special tokens + corpus characters). Shows `d_vocabulary`
# instead of `len(tok_list)` because `tok_list` is reassigned in a later cell,
# so this cell would otherwise display a different value when re-run.
d_vocabulary

In [None]:
# Inspect every lookup table built above, one per line.
print(spec_tok_dict, spec_idx_dict, idx_dict, tok_dict, tok_list, sep="\n")

In [None]:
# Walk through the preprocessing steps of `get_batch` for a single sentence.
# NOTE: `tok_list` is reused here for the sentence's index array (it shadows the
# vocabulary list built above); the inspection cells below read this version.
sample_idx = 0
sentence = corpus[sample_idx]

# 1) tokenize (character level)
tokens = list(sentence)

# 2) replace tokens with their vocabulary indices
tok_list = np.array([tok_dict[tok] for tok in tokens])

# 3) number of predictions: pred_freq of the tokens, clamped to
#    [pred_min, pred_max] and never more than the sentence length
n_preds = int(round(len(tok_list) * config.pred_freq))
n_preds = min(max(config.pred_min, n_preds), config.pred_max, len(tok_list))

# 4) choose positions to mask, remember the original tokens, insert [MSK]
mask_idcs = np.random.choice(len(tok_list), size=n_preds, replace=False)
mask_toks = tok_list[mask_idcs]
tok_list[mask_idcs] = tok_dict["[MSK]"]
# shift by one because [CLS] is prepended in step 5, then zero-pad positions
# and targets to the fixed length pred_max (keeping the padded result)
mask_idcs = mask_idcs + 1
n_mask_pad = config.pred_max - n_preds
mask_idcs = np.pad(mask_idcs, (0, n_mask_pad), mode="constant")
mask_toks = np.pad(mask_toks, (0, n_mask_pad), mode="constant")

# 5) prepend one slot for [CLS] and right-pad with [PAD] (index 0) up to d_l
n_pad = config.d_l - len(tok_list)
if n_pad < 1:
    raise ValueError(
        f"sentence of length {len(tok_list)} does not fit into d_l={config.d_l} together with [CLS]"
    )
tok_list = np.pad(tok_list, (1, n_pad - 1), mode='constant')
tok_list[0] = tok_dict['[CLS]']

In [None]:
# Compare the raw sentence with its masked/padded encoding and the masked targets.
decoded_sentence = [idx_dict[idx] for idx in tok_list]
decoded_mask_toks = [idx_dict[idx] for idx in mask_toks]
print("sentence:  ", corpus[sample_idx])
print("sentence:  ", decoded_sentence)
print("mask idcs:  ", mask_idcs)
print("mask toks: ", decoded_mask_toks)

In [None]:
# Same information without labels: raw index arrays next to their decoded tokens.
print(tok_list)
print(list(map(idx_dict.__getitem__, tok_list)))
print(mask_idcs)
print(list(map(idx_dict.__getitem__, mask_toks)))

In [None]:
def get_batch(sentences, vocab=None, cfg=None, rng=None):
    """Tokenize, mask and pad sentences for masked-language-model training.

    Each sentence is processed as follows:

    - It is split into characters and mapped to vocabulary indices.
    - A random subset of positions is replaced by ``[MSK]``.
    - ``[CLS]`` is prepended.
    - The sequence is right-padded with index 0 (``[PAD]``) to ``cfg.d_l``.

    Args:
        sentences: iterable of strings whose characters are all keys of ``vocab``.
        vocab: token -> index mapping containing ``[MSK]`` and ``[CLS]``;
            defaults to the global ``tok_dict``.
        cfg: object with attributes ``d_l``, ``pred_freq``, ``pred_min`` and
            ``pred_max``; defaults to the global ``config``.
        rng: object with a numpy-style ``choice(n, size=..., replace=False)``
            method, e.g. ``np.random.default_rng(seed)``; defaults to the global
            ``np.random`` state.

    Returns:
        ``(all_toks_list, all_mask_idcs, all_mask_toks)``: three lists with one
        numpy int array per sentence, of lengths ``cfg.d_l``, ``cfg.pred_max``
        and ``cfg.pred_max``. Mask positions index the padded sequence (they
        are already shifted by one for ``[CLS]``). Unused mask slots hold 0 in
        both arrays; presumably downstream code ignores them — TODO confirm.

    Raises:
        ValueError: if a sentence plus ``[CLS]`` does not fit into ``cfg.d_l``.
    """
    if vocab is None:
        vocab = tok_dict
    if cfg is None:
        cfg = config
    if rng is None:
        rng = np.random

    all_toks_list = []
    all_mask_idcs = []
    all_mask_toks = []
    for sentence in sentences:
        # 1) tokenize (character level) and 2) map to vocabulary indices
        toks = np.array([vocab[tok] for tok in sentence], dtype=np.int64)
        if len(toks) + 1 > cfg.d_l:
            raise ValueError(
                f"sentence {sentence!r} has {len(toks)} tokens; with [CLS] it "
                f"needs d_l >= {len(toks) + 1}, got d_l={cfg.d_l}"
            )

        # 3) number of predictions: pred_freq of the tokens, clamped to
        #    [pred_min, pred_max] and never more than the sentence length
        n_preds = int(round(len(toks) * cfg.pred_freq))
        n_preds = min(max(cfg.pred_min, n_preds), cfg.pred_max, len(toks))

        # 4) mask: remember the original tokens, then insert [MSK].
        # TODO: BERT replaces chosen tokens with [MSK] / a random token /
        # the unchanged token (80/10/10); here every chosen token becomes [MSK].
        # https://neptune.ai/blog/unmasking-bert-transformer-model-performance
        mask_idcs = rng.choice(len(toks), size=n_preds, replace=False)
        mask_toks = toks[mask_idcs]
        toks[mask_idcs] = vocab["[MSK]"]
        # shift by one because [CLS] is prepended below
        mask_idcs = mask_idcs + 1
        # pad positions AND targets to the fixed length pred_max so the
        # per-sentence arrays stack into one (batch, pred_max) tensor
        n_mask_pad = cfg.pred_max - n_preds
        mask_idcs = np.pad(mask_idcs, (0, n_mask_pad), mode="constant")
        mask_toks = np.pad(mask_toks, (0, n_mask_pad), mode="constant")

        # 5) prepend one slot for [CLS] and right-pad with [PAD] (index 0) to d_l
        toks = np.pad(toks, (1, cfg.d_l - len(toks) - 1), mode="constant")
        toks[0] = vocab["[CLS]"]

        all_toks_list.append(toks)
        all_mask_idcs.append(mask_idcs)
        all_mask_toks.append(mask_toks)

    return all_toks_list, all_mask_idcs, all_mask_toks

In [None]:
# Build one batch and turn each per-sentence list into a (batch_size, ...) LongTensor.
# Stack with numpy first: torch.LongTensor(list_of_ndarrays) is slow and warns,
# and np.stack fails loudly if the per-sentence arrays have different lengths.
batch = get_batch(corpus[:config.batch_size])
all_toks_list, all_mask_idcs, all_mask_toks = (
    torch.as_tensor(np.stack(arrays), dtype=torch.long) for arrays in batch
)

In [None]:
# Token indices of the batch: one row per sentence, [CLS] first, [PAD] (0) at the end
all_toks_list

In [None]:
# Masked positions per sentence (shifted by one for [CLS], zero-padded by get_batch)
all_mask_idcs

In [None]:
# Original token indices at the masked positions, i.e. the prediction targets
all_mask_toks

In [None]:
# Decode every tokenized sentence of the batch back into readable tokens.
for toks in all_toks_list.tolist():
    print([idx_dict[i] for i in toks])

In [None]:
# The raw sentences that went into the batch, for comparison with the decoded tokens
corpus[:config.batch_size]

In [None]:
# Decode the prediction targets of every sentence.
for mt in all_mask_toks.tolist():
    print([idx_dict[i] for i in mt])

### Embedding

In [None]:
class Embedding(nn.Module):
    """Sum of learned token and position embeddings, followed by LayerNorm.

    Args:
        d_vocabulary: number of distinct token indices.
        d_model: embedding dimension.
        d_l: maximum sequence length (number of learned positions).
    """

    def __init__(self, d_vocabulary, d_model, d_l):
        super().__init__()
        self.d_vocabulary = d_vocabulary
        self.d_model = d_model
        self.d_l = d_l
        self.tok_emb = nn.Embedding(d_vocabulary, d_model)  # token embedding
        self.pos_emb = nn.Embedding(d_l, d_model)  # position embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Embed a batch of token indices.

        Args:
            x: LongTensor of shape (batch_size, seq_len), seq_len <= d_l.

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).

        Raises:
            IndexError: if seq_len exceeds the number of learned positions.
        """
        seq_len = x.size(1)
        n_positions = self.pos_emb.num_embeddings
        if seq_len > n_positions:
            raise IndexError(f"sequence length {seq_len} exceeds the {n_positions} learned positions")
        # create positions on the input's device so the module also works on GPU
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)  # (seq_len,) -> (batch_size, seq_len)
        embedding = self.tok_emb(x) + self.pos_emb(pos)
        return self.norm(embedding)

In [None]:
# Token + position embedding layer sized from the vocabulary and the config.
model_emb = Embedding(
    d_vocabulary=d_vocabulary,
    d_model=config.d_model,
    d_l=config.d_l,
)

In [None]:
# Embed the batch and check the (batch_size, d_l, d_model) output shape.
emb = model_emb(all_toks_list)
assert tuple(emb.shape) == (config.batch_size, config.d_l, config.d_model)

### Attention Mask

In [None]:
def get_attn_pad_mask_orig(seq_q, seq_k, pad_idx=0):
    """Boolean attention mask that hides padding keys.

    Args:
        seq_q: (batch_size, len_q) token indices of the queries.
        seq_k: (batch_size, len_k) token indices of the keys.
        pad_idx: index of the [PAD] token (default 0).

    Returns:
        BoolTensor of shape (batch_size, len_q, len_k), True where the key
        position is padding and must not be attended to. The result is an
        expanded view, not a copy.
    """
    len_q = seq_q.size(1)
    # (b, len_k) -> (b, 1, len_k) -> broadcast over every query position
    pad_attn_mask = seq_k.eq(pad_idx).unsqueeze(1)
    return pad_attn_mask.expand(-1, len_q, -1)


def get_attn_pad_mask(x, pad_idx=0):
    """Self-attention padding mask with a singleton head axis.

    Args:
        x: (batch_size, seq_len) token indices.
        pad_idx: index of the [PAD] token (default 0).

    Returns:
        BoolTensor of shape (batch_size, 1, seq_len, seq_len); True marks padding keys.
    """
    return x.eq(pad_idx).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

In [None]:
# Padding mask for self-attention over the batch: (batch_size, d_l, d_l).
attn_mask = get_attn_pad_mask_orig(seq_q=all_toks_list, seq_k=all_toks_list)

In [None]:
# Mask and embeddings share the (batch_size, d_l) leading dimensions.
print(attn_mask.shape, emb.shape, sep="\n")

### Encoder

#### Attention Module

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Computes ``softmax(Q K^T / sqrt(d_k)) V`` with optional boolean masking.

    Works with any number of leading batch dimensions, e.g. (b, l, d) without
    heads or (b, h, l, d) with heads.

    Args:
        d_k: dimension of queries and keys, used for the 1/sqrt(d_k) scaling.
    """

    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k

    def forward(self, Q, K, V, attn_mask=None):
        """Attend from queries to keys and mix the values.

        Args:
            Q: (..., len_q, d_k) queries.
            K: (..., len_k, d_k) keys.
            V: (..., len_k, d_v) values.
            attn_mask: optional BoolTensor broadcastable to (..., len_q, len_k);
                True marks positions that must not be attended to.

        Returns:
            ``(context, attn)``: context of shape (..., len_q, d_v) and
            attention weights of shape (..., len_q, len_k) whose rows sum to 1.
        """
        # scores: (..., len_q, len_k)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.d_k)
        if attn_mask is not None:
            # the dtype's most negative finite value instead of -1e9, which
            # overflows to -inf in half precision
            scores = scores.masked_fill(attn_mask, torch.finfo(scores.dtype).min)
        attn = torch.softmax(scores, dim=-1)

        # context: (..., len_q, d_v)
        context = torch.matmul(attn, V)
        return context, attn

In [None]:
print(attn_mask.shape)
# add a head axis and copy the mask for every head: (b, l, l) -> (b, n_heads, l, l)
attn_mask_headed = attn_mask[:, None].repeat(1, config.n_heads, 1, 1)
print(attn_mask_headed.shape)

In [None]:
# Scaled dot-product attention on random inputs WITHOUT a head axis:
# Q, K are (b, l, d_k) and V is (b, l, d_v); draws happen in the order Q, K, V.
Q, K, V = (
    torch.rand((config.batch_size, config.d_l, config.d_k)),
    torch.rand((config.batch_size, config.d_l, config.d_k)),
    torch.rand((config.batch_size, config.d_l, config.d_v)),
)
attn_mask = get_attn_pad_mask_orig(all_toks_list, all_toks_list)

model_sdpa = ScaledDotProductAttention(config.d_k)
context, attn = model_sdpa(Q, K, V, attn_mask)

assert tuple(context.shape) == (config.batch_size, config.d_l, config.d_v)
assert tuple(attn.shape) == (config.batch_size, config.d_l, config.d_l)

In [None]:
# Scaled dot-product attention on random inputs WITH a head axis:
# Q, K are (b, h, l, d_k) and V is (b, h, l, d_v); draws happen in the order Q, K, V.
Q, K, V = (
    torch.rand((config.batch_size, config.n_heads, config.d_l, config.d_k)),
    torch.rand((config.batch_size, config.n_heads, config.d_l, config.d_k)),
    torch.rand((config.batch_size, config.n_heads, config.d_l, config.d_v)),
)
# one copy of the (b, l, l) padding mask per head -> (b, h, l, l)
attn_mask = get_attn_pad_mask_orig(all_toks_list, all_toks_list)
attn_mask_headed = attn_mask[:, None].repeat(1, config.n_heads, 1, 1)

model_sdpa = ScaledDotProductAttention(config.d_k)
context, attn = model_sdpa(Q, K, V, attn_mask_headed)

assert tuple(context.shape) == (config.batch_size, config.n_heads, config.d_l, config.d_v)
assert tuple(attn.shape) == (config.batch_size, config.n_heads, config.d_l, config.d_l)

#### Sublayer: Residual Connection + LayerNorm

In [None]:
class NormedResidualSubLayerConnection(nn.Module):
    """Pre-norm residual wrapper: returns ``x + sublayer(LayerNorm(x))``.

    Args:
        d_model: size of the last input dimension, normalized by LayerNorm.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to the normalized input and add the skip connection."""
        normed = self.norm(x)
        return x + sublayer(normed)

In [None]:
model_res = NormedResidualSubLayerConnection(d_model=config.d_model)


def sublayer(x):
    """Identity sublayer used to exercise the residual wrapper."""
    return x

In [None]:
# Sample input for the residual block: the embedded batch (batch_size, d_l, d_model).
# `x` must be defined here; no earlier cell creates it.
x = emb
print(x.shape)
out = model_res(x, sublayer)
print(out.shape)
assert out.shape == x.shape

#### Multihead Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a post-norm residual connection.

    Processing steps:

    - Project the input to ``n_heads`` query/key/value heads.
    - Apply ``ScaledDotProductAttention`` per head.
    - Concatenate the heads and project back to ``d_model``.
    - Return ``LayerNorm(x + projection)``.

    Args:
        d_model: size of the input/output feature dimension.
        d_k: per-head size of queries and keys.
        d_v: per-head size of values.
        n_heads: number of attention heads.
    """

    def __init__(self, d_model, d_k, d_v, n_heads):
        super().__init__()
        self.d_k = d_k
        self.d_v = d_v
        self.d_model = d_model
        self.n_heads = n_heads
        # one Linear per role computes all heads at once
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.model_sdpa = ScaledDotProductAttention(d_k)
        # Project the concatenated heads back to d_model so the residual
        # x + output is well defined for any d_v. When d_v == d_model this
        # has exactly the same shapes as an (n_heads * d_v -> d_v) layer.
        self.output_linear = nn.Linear(n_heads * d_v, d_model)
        self.output_norm = nn.LayerNorm(d_model)

    def forward(self, x, attn_mask):
        """Run self-attention over ``x``.

        Args:
            x: (b, l, d_model) input sequence.
            attn_mask: (b, l, l) BoolTensor; True marks key positions to ignore.

        Returns:
            ``(output, attn)``: output of shape (b, l, d_model) and the
            per-head attention weights of shape (b, n_heads, l, l).
        """
        d_b = x.size(0)
        # (b, l, d_model) -> (b, l, h*k) -> (b, l, h, k) -> (b, h, l, k)
        q_s = self.W_Q(x).view(d_b, -1, self.n_heads, self.d_k).transpose(1, 2)
        k_s = self.W_K(x).view(d_b, -1, self.n_heads, self.d_k).transpose(1, 2)
        # (b, l, d_model) -> (b, h, l, v)
        v_s = self.W_V(x).view(d_b, -1, self.n_heads, self.d_v).transpose(1, 2)

        # every head uses the same padding mask: (b, l, l) -> (b, h, l, l)
        attn_mask_headed = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)

        context, attn = self.model_sdpa(q_s, k_s, v_s, attn_mask_headed)

        # (b, h, l, v) -> (b, l, h, v) -> (b, l, h*v)
        context = context.transpose(1, 2).contiguous().view(d_b, -1, self.n_heads * self.d_v)

        # (b, l, h*v) -> (b, l, d_model), then residual + LayerNorm
        output = self.output_linear(context)
        output = self.output_norm(x + output)
        return output, attn

In [None]:
# Multi-head self-attention sized from the config.
model_mha = MultiHeadAttention(
    d_model=config.d_model,
    d_k=config.d_k,
    d_v=config.d_v,
    n_heads=config.n_heads,
)

In [None]:
# Run multi-head self-attention on the embedded batch (batch_size, d_l, d_model).
# Use `emb` directly and recompute the (batch_size, d_l, d_l) padding mask from
# the batch tokens, so this cell depends on no loosely named earlier variables.
attn_mask = get_attn_pad_mask_orig(all_toks_list, all_toks_list)
output, attn = model_mha(emb, attn_mask)

In [None]:
# Output keeps the input shape; attention weights carry one (d_l, d_l) map per head.
print(output.shape, attn.shape, sep="\n")

#### Position-wise Feed-Forward Network

In [None]:
class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit, ``x * Phi(x)``.

    BERT uses GELU instead of ReLU as its activation function.
    """

    def forward(self, x):
        # Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) is the standard normal CDF
        gauss_cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
        return x * gauss_cdf


class PoswiseFeedForwardNet(nn.Module):
    """Two-layer MLP applied independently at every position.

    Maps ``d_model -> d_ff -> d_model`` with a GELU in between.

    Args:
        d_model: input/output feature size.
        d_ff: hidden feature size.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.gelu = GELU()

    def forward(self, x):
        # (b, l, d_model) -> (b, l, d_ff) -> (b, l, d_model)
        return self.fc2(self.gelu(self.fc1(x)))
