In [None]:
import math
import re
from random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

### Preprocessing

In [None]:
# Short English dialogue; utterances are separated by newlines.
text = '\n'.join([
    'Hello, how are you? I am Romeo.',
    'Hello, Romeo My name is Juliet. Nice to meet you.',
    'Nice meet you too. How are you today?',
    'Great. My baseball team won the competition.',
    'Oh Congratulations, Juliet',
    'Thanks you Romeo',
])

# Toy corpus: ten strings written over the three-symbol alphabet below.
corpus = (
    "baabaac", "aababc", "bcaaaa", "aac", "bbbaabbaa",
    "bbbbbbabc", "ababc", "babc", "bcaaca", "aabbaaac",
)
corpus_vocabuary = tuple('abc')

In [None]:
# Special tokens:
#   [CLS]: takes position 0 so that slot can stand for the whole sentence
#   [PAD]: fills sequences up to a common length for batch processing
#   [MSK]: replaces the input tokens the model has to predict

spec_tok_dict = {'[PAD]': 0, '[MSK]': 1, '[CLS]': 2}
spec_idx_dict = {index: token for token, index in spec_tok_dict.items()}

# token -> index: special tokens first, corpus symbols numbered right after them
tok_dict = dict(spec_tok_dict)
tok_dict.update(
    (symbol, index)
    for index, symbol in enumerate(corpus_vocabuary, start=len(spec_tok_dict))
)

# index -> token, and the tokens listed in index order
idx_dict = {index: token for token, index in tok_dict.items()}
tok_list = list(tok_dict)

In [None]:
# Dump every lookup table, one per line.
print(spec_tok_dict, spec_idx_dict, idx_dict, tok_dict, tok_list, sep="\n")

In [None]:
# --- batching ---
batch_size = 4      # number of sentences per batch
maxlen = 10         # max sentence length (rows are padded to this many entries)

# --- masking ---
pred_freq = 0.15    # fraction of a sentence's tokens selected for prediction
pred_min = 2        # lower bound on predicted tokens per sentence
pred_max = 4        # upper bound; also the fixed width used for standardized batch shapes

# --- model ---
d_model = 32

In [None]:
# Walk one sentence through the preprocessing steps that get_batch applies.
sample_idx = 0
sentence = corpus[sample_idx]

# 1) tokenize: every character is one token
tokens = list(sentence)

# 2) replace tokens with their vocabulary indices
tok_list = [tok_dict[tok] for tok in tokens]
tok_list = np.array(tok_list)

# 3) number of predictions, clamped to [pred_min, pred_max]
n_preds = int(round(len(tok_list) * pred_freq))
n_preds = min(max(pred_min, n_preds), pred_max)

# 4) pick positions to mask, remember their original ids, then overwrite with [MSK]
mask_idcs = np.random.choice(len(tok_list), size=n_preds, replace=False)
mask_toks = tok_list[mask_idcs]  # fancy indexing copies, so the originals survive
tok_list[mask_idcs] = tok_dict["[MSK]"]
# shift by one because [CLS] is inserted at position 0 in step 5
mask_idcs = mask_idcs + 1
# right-pad both target arrays with 0 up to the fixed width pred_max
mask_idcs = np.pad(mask_idcs, (0, pred_max - n_preds), mode="constant")
mask_toks = np.pad(mask_toks, (0, pred_max - n_preds), mode="constant")

# 5) reserve slot 0 for [CLS] and right-pad with 0 ([PAD]) up to maxlen
n_pad = maxlen - len(tok_list)
tok_list = np.pad(tok_list, (1, n_pad - 1), mode='constant')

# put the [CLS] token into the reserved first slot
tok_list[0] = tok_dict['[CLS]']

In [None]:
# Raw sentence, its processed token row, and the masking bookkeeping.
print("sentence:  ", corpus[sample_idx])
print("sentence:  ", list(map(idx_dict.__getitem__, tok_list)))
print("mask idcs:  ", mask_idcs)
print("mask toks: ", list(map(idx_dict.__getitem__, mask_toks)))

In [None]:
# Index arrays next to their decoded token strings, one item per line.
print(
    tok_list,
    [idx_dict[i] for i in tok_list],
    mask_idcs,
    [idx_dict[i] for i in mask_toks],
    sep="\n",
)

In [None]:
def get_batch(sentences):
    """Turn raw sentences into padded, masked index arrays for one batch.

    Each sentence is split into characters and mapped through ``tok_dict``.
    A number of randomly chosen positions are replaced by ``[MSK]``; the row
    is then prefixed with ``[CLS]`` and right-padded with 0 (``[PAD]``) to
    ``maxlen`` entries.

    Args:
        sentences: iterable of strings whose characters are keys of
            ``tok_dict``; each may hold at most ``maxlen - 1`` characters
            because one slot is taken by ``[CLS]``.

    Returns:
        Three parallel lists holding one NumPy array per sentence:
        the token ids (length ``maxlen``), the masked positions counted in
        the padded row (length ``pred_max``), and the original ids at those
        positions (length ``pred_max``). Unused target slots are 0.

    Raises:
        ValueError: if a sentence does not fit into ``maxlen`` with ``[CLS]``.
        KeyError: if a character is missing from ``tok_dict``.
    """
    all_toks_list = []
    all_mask_idcs = []
    all_mask_toks = []
    for sentence in sentences:
        # 1) tokenize: every character is one token
        tokens = list(sentence)
        if len(tokens) > maxlen - 1:
            raise ValueError(
                f"sentence of length {len(tokens)} does not fit into "
                f"maxlen={maxlen} together with the [CLS] token"
            )

        # 2) replace tokens with vocabulary indices (explicit dtype keeps an
        #    empty sentence integer-typed)
        tok_list = np.array([tok_dict[tok] for tok in tokens], dtype=np.int64)

        # 3) number of predictions: clamp to [pred_min, pred_max], and never
        #    ask for more positions than the sentence has
        n_preds = int(round(len(tok_list) * pred_freq))
        n_preds = min(max(pred_min, n_preds), pred_max)
        n_preds = min(n_preds, len(tok_list))

        # 4) pick positions, remember their original ids, overwrite with [MSK]
        mask_idcs = np.random.choice(len(tok_list), size=n_preds, replace=False)
        mask_toks = tok_list[mask_idcs]  # fancy indexing copies the originals
        tok_list[mask_idcs] = tok_dict["[MSK]"]
        # add 1 to the positions since [CLS] is inserted at the front below
        mask_idcs = mask_idcs + 1
        # right-pad both target arrays with 0 up to the fixed width pred_max
        target_pad = (0, pred_max - n_preds)
        mask_idcs = np.pad(mask_idcs, target_pad, mode="constant")
        mask_toks = np.pad(mask_toks, target_pad, mode="constant")

        # 5) reserve slot 0 for [CLS] and right-pad with 0 ([PAD]) to maxlen
        n_pad = maxlen - len(tok_list)
        tok_list = np.pad(tok_list, (1, n_pad - 1), mode='constant')
        tok_list[0] = tok_dict['[CLS]']

        all_toks_list.append(tok_list)
        all_mask_idcs.append(mask_idcs)
        all_mask_toks.append(mask_toks)

    return all_toks_list, all_mask_idcs, all_mask_toks

In [None]:
# Preprocess the first batch_size sentences and turn each returned list into a LongTensor.
batch = get_batch(corpus[:batch_size])
all_toks_list, all_mask_idcs, all_mask_toks = (
    torch.LongTensor(part) for part in batch
)

In [None]:
# Token-id tensor for the batch: one padded row per sentence.
all_toks_list

In [None]:
# Masked positions recorded for each sentence of the batch.
all_mask_idcs

In [None]:
# Original token ids that were replaced by [MSK], per sentence.
all_mask_toks

In [None]:
# Decode every row of the batch back into token strings.
for row in all_toks_list.tolist():
    print([idx_dict[token_id] for token_id in row])

In [None]:
# The raw sentences that were passed to get_batch, for comparison with the decoded rows.
corpus[:batch_size]

In [None]:
# Decode the prediction targets of every sentence back into token strings.
for target_row in all_mask_toks.tolist():
    print([idx_dict[token_id] for token_id in target_row])

## Embedding

In [None]:
# A cell only displays its final expression, so show all three arrays as one tuple.
tok_list, mask_idcs, mask_toks

In [None]:
def make_batch():
    """Sample one batch of sentence pairs for masked-LM and next-sentence training.

    Pairs of sentences are drawn at random (with replacement) until the batch
    holds ``batch_size // 2`` "IsNext" pairs (the second sentence directly
    follows the first in ``sentences``) and ``batch_size - batch_size // 2``
    "NotNext" pairs. Pairs drawn for a class that is already full are dropped.

    Reads the module-level names ``sentences``, ``token_list``, ``word_dict``,
    ``number_dict``, ``vocab_size``, ``max_pred``, ``maxlen`` and
    ``batch_size``. Assumes ``token_list[i]`` is a list of token ids — TODO confirm.

    Returns:
        A list of ``[input_ids, segment_ids, masked_tokens, masked_pos, is_next]``
        entries. ``input_ids`` and ``segment_ids`` are zero-padded to ``maxlen``;
        ``masked_tokens`` and ``masked_pos`` are zero-padded to ``max_pred``.

    Raises:
        ValueError: if an "IsNext" pair is required but ``sentences`` has fewer
            than two entries, so no consecutive pair could ever be drawn.
    """
    # Integer targets: the former float comparison against batch_size/2 could
    # never be satisfied for an odd batch_size and looped forever.
    n_positive_target = batch_size // 2
    n_negative_target = batch_size - n_positive_target
    if n_positive_target > 0 and len(sentences) < 2:
        raise ValueError("need at least two sentences to build an IsNext pair")

    batch = []
    positive = negative = 0
    while positive < n_positive_target or negative < n_negative_target:
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]

        # [CLS] A [SEP] B [SEP]
        input_ids = [word_dict['[CLS]']] + tokens_a + [word_dict['[SEP]']] + tokens_b + [word_dict['[SEP]']]

        # segment 0 covers [CLS] A [SEP]; segment 1 covers B [SEP]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        # MASK LM: 15 % of the tokens, at least 1 and at most max_pred
        n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))))

        # only ordinary tokens may be masked, never [CLS] or [SEP]
        cand_masked_pos = [i for i, token in enumerate(input_ids)
                           if token != word_dict['[CLS]'] and token != word_dict['[SEP]']]
        shuffle(cand_masked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_masked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random() < 0.8:  # 80 %: replace with [MASK]
                input_ids[pos] = word_dict['[MASK]']
            elif random() < 0.5:  # half of the remaining 20 % = 10 %: random token
                index = randint(0, vocab_size - 1)  # random index in vocabulary
                input_ids[pos] = word_dict[number_dict[index]]
            # otherwise (the last 10 %) the token is left unchanged

        # Zero padding up to maxlen; a negative n_pad adds nothing, so a pair
        # longer than maxlen is left at its own length.
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # Zero padding of the prediction targets up to max_pred
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        if tokens_a_index + 1 == tokens_b_index and positive < n_positive_target:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])  # IsNext
            positive += 1
        elif tokens_a_index + 1 != tokens_b_index and negative < n_negative_target:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])  # NotNext
            negative += 1
    return batch

        

In [None]:
def get_attn_pad_mask(seq_q, seq_k):
    """Build the attention mask that hides padded key positions.

    Args:
        seq_q: 2-D tensor of query token ids; only its sizes are used.
        seq_k: 2-D tensor of key token ids; entries equal to 0 count as PAD.

    Returns:
        Boolean tensor of shape (rows of seq_k, len_q, len_k) that is True
        wherever the key token is PAD, repeated for every query position.
    """
    n_rows, q_len = seq_q.size()
    n_rows, k_len = seq_k.size()  # the row count of seq_k is the one that is kept
    # (rows, len_k) -> (rows, 1, len_k); True marks a PAD key
    key_is_pad = seq_k.data.eq(0).unsqueeze(1)
    # broadcast the key mask over all query positions without copying
    return key_is_pad.expand(n_rows, q_len, k_len)

def gelu(x):
    """Apply the exact (erf-based) GELU activation element-wise to tensor x."""
    # standard normal CDF evaluated at x
    normal_cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * normal_cdf

In [None]:
 # Draw one batch of sentence pairs.
 # NOTE(review): the leading space on this line looks unintentional — confirm and remove.
 batch = make_batch()

In [None]:
# Transpose the batch (rows -> columns) and turn every column into a LongTensor.
input_ids, segment_ids, masked_tokens, masked_pos, isNext = (
    torch.LongTensor(column) for column in zip(*batch)
)

In [None]:
# Number of examples in the batch (the tensor unpacked above is named input_ids).
len(input_ids)

In [None]:
# Token ids of the first example in the batch.
input_ids[0]

In [None]:
# Decode the first example, skipping every id 0 (the padding value).
[number_dict[token_id] for token_id in input_ids[0].tolist() if token_id != 0]

In [None]:
# Decode the prediction targets of the first example (padding included).
[number_dict[token_id] for token_id in masked_tokens[0].tolist()]

In [None]:
# Masked positions of the first example in the batch.
masked_pos[0]

In [None]:
# Pad mask seen by the first query position of example 0, next to that example's token ids.
get_attn_pad_mask(input_ids, input_ids)[0][0], input_ids[0]

In [None]:
class Embedding(nn.Module):
    """Sum of token, position and segment embeddings followed by LayerNorm.

    Sizes come from the module-level names ``vocab_size``, ``maxlen``,
    ``n_segments`` and ``d_model``.
    """

    def __init__(self):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(maxlen, d_model)  # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        """Embed a batch of token ids.

        Args:
            x: LongTensor of token ids, shape (batch_size, seq_len), with
                seq_len no larger than ``maxlen``.
            seg: LongTensor of segment ids with the same shape as ``x``.

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        seq_len = x.size(1)
        # create the position indices on x's device so the lookup also works
        # when the inputs are not on the CPU
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)  # (seq_len,) -> (batch_size, seq_len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Parameter-free attention: softmax(Q K^T / sqrt(key width)) V."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask):
        """Compute masked scaled dot-product attention.

        Args:
            Q: queries, shape (..., len_q, width).
            K: keys, shape (..., len_k, width).
            V: values, shape (..., len_k, d_v).
            attn_mask: boolean tensor broadcastable to (..., len_q, len_k);
                True marks positions that must receive no attention.

        Returns:
            Tuple ``(context, attn)``: the weighted values (..., len_q, d_v)
            and the attention weights (..., len_q, len_k).
        """
        # Scale by the actual key width instead of a module-level d_k, so the
        # layer does not depend on a global being defined and in sync.
        scale = math.sqrt(K.size(-1))
        scores = torch.matmul(Q, K.transpose(-1, -2)) / scale  # (..., len_q, len_k)
        # a large negative score turns into ~0 weight after the softmax
        scores = scores.masked_fill(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return context, attn

In [None]:
# Run plain scaled dot-product self-attention on freshly embedded inputs.
batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

emb = Embedding()
embeds = emb(input_ids, segment_ids)

attenM = get_attn_pad_mask(input_ids, input_ids)

SDPA = ScaledDotProductAttention()(embeds, embeds, embeds, attenM)

# forward returns exactly two tensors: the context and the attention weights
C, A = SDPA

# The pre-softmax scores are not returned by the module, so recompute them here
# for inspection: dot products scaled by the embedding width, PAD keys masked.
S = torch.matmul(embeds, embeds.transpose(-1, -2)) / np.sqrt(embeds.size(-1))
S = S.masked_fill(attenM, -1e9)

In [None]:
# First query row of example 0: raw scores, then the attention weights.
print(
    'Scores: ', S[0][0],
    '\n\nAttention M: ', A[0][0],
)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with output projection, residual and LayerNorm.

    Sizes come from the module-level names ``d_model``, ``d_k``, ``d_v`` and
    ``n_heads``.
    """

    def __init__(self):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        self.attention = ScaledDotProductAttention()
        # The output projection and the norm are created once here so that they
        # are registered as parameters and keep their weights between calls;
        # building them inside forward re-randomised them on every call.
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask):
        """Attend from Q to K/V with ``n_heads`` parallel heads.

        Args:
            Q: [batch_size x len_q x d_model]
            K: [batch_size x len_k x d_model]
            V: [batch_size x len_k x d_model]
            attn_mask: boolean [batch_size x len_q x len_k]; True hides a key.

        Returns:
            Tuple ``(output, attn)`` with output [batch_size x len_q x d_model]
            and attn [batch_size x n_heads x len_q x len_k].
        """
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, H*W) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # [batch_size x n_heads x len_k x d_v]

        # every head uses the same mask: [batch_size x n_heads x len_q x len_k]
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q x len_k]
        context, attn = self.attention(q_s, k_s, v_s, attn_mask)
        # merge the heads again: [batch_size x len_q x n_heads * d_v]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
        output = self.linear(context)
        return self.norm(output + residual), attn  # output: [batch_size x len_q x d_model]


In [None]:
# Embed the batch with a freshly initialised embedding layer and build its pad mask.
emb = Embedding()
embeds = emb(input_ids, segment_ids)
attenM = get_attn_pad_mask(input_ids, input_ids)

# Self-attention: the embeddings serve as queries, keys and values alike.
Output, A = MHA = MultiHeadAttention()(embeds, embeds, embeds, attenM)

In [None]:
# Shapes of the attention output and of the per-head attention weights.
print(Output.shape, A.shape, sep="\n")