In [1]:
import math
import re
from random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
def get_word_dict(text):
    """Tokenize a newline-separated corpus and build the vocabulary.

    Each line of ``text`` is one sentence. Text is lower-cased and the
    characters ``. , ! ? -`` are removed before whitespace tokenization.

    Args:
        text: Raw corpus, one sentence per line.

    Returns:
        Tuple ``(sentences, word_dict, number_dict, token_list, vocab_size)``:
        - sentences: cleaned sentence strings, one per input line (blank lines
          yield empty strings).
        - word_dict: token -> id. Ids 0-3 are reserved for ``[PAD]``,
          ``[CLS]``, ``[SEP]``, ``[MASK]``; corpus words follow in sorted order.
        - number_dict: id -> token (inverse of ``word_dict``).
        - token_list: per-sentence lists of token ids.
        - vocab_size: ``len(word_dict)``.
    """
    # Lower-case, strip '.', ',', '!', '?', '-', then split into sentences.
    sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n')
    word_dict = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}

    # sorted() rather than raw set order: str hashing is randomized per
    # process, so set order (and therefore every word id) would differ run to run.
    word_list = sorted(set(" ".join(sentences).split()))
    for i, w in enumerate(word_list):
        word_dict[w] = i + 4

    # A real dict (not a one-shot zip iterator) so it can be indexed by id
    # and reused any number of times.
    number_dict = {idx: word for word, idx in word_dict.items()}

    vocab_size = len(word_dict)

    token_list = [[word_dict[word] for word in sentence.split()] for sentence in sentences]

    return sentences, word_dict, number_dict, token_list, vocab_size
        
    

In [8]:
def make_batch(sentences, word_dict, number_dict, token_list, batch_size, max_pred, vocab_size, max_len):
    """Build one pre-training batch for the masked-LM and next-sentence tasks.

    Sentence pairs (A, B) are sampled at random. A pair is positive (IsNext)
    when B is the sentence right after A in the corpus, negative otherwise.
    Sampling continues until the batch holds ``batch_size // 2`` positives and
    ``batch_size - batch_size // 2`` negatives (so odd sizes also terminate).

    In each pair, ~15% of the non-[CLS]/[SEP] tokens (at least 1, at most
    ``max_pred``) are chosen for prediction. Each chosen token is replaced by
    ``[MASK]`` with probability 0.8, by a random id in ``[0, vocab_size)`` with
    probability 0.1, and left unchanged with probability 0.1.

    Args:
        sentences: corpus sentences; only ``len(sentences)`` is used.
        word_dict: token -> id; must contain ``[CLS]``, ``[SEP]``, ``[MASK]``.
        number_dict: id -> token. Not needed (the random id is used directly);
            kept for interface compatibility.
        token_list: per-sentence token-id lists aligned with ``sentences``.
        batch_size: number of examples to produce.
        max_pred: max tokens of prediction per example.
        vocab_size: random replacement ids are drawn from ``0..vocab_size-1``.
        max_len: length ``input_ids`` / ``segment_ids`` are zero-padded to.

    Returns:
        List of ``[input_ids, segment_ids, masked_tokens, masked_pos, is_next]``.
        ``input_ids``/``segment_ids`` have length ``max_len``;
        ``masked_tokens``/``masked_pos`` have length ``max_pred`` (zero-padded).

    Raises:
        ValueError: if positives are requested but there are fewer than two
            sentences (no consecutive pair exists), or if a sampled pair is
            longer than ``max_len``.
    """
    n_positive_target = batch_size // 2
    n_negative_target = batch_size - n_positive_target
    if n_positive_target > 0 and len(sentences) < 2:
        # A positive needs b == a + 1; with <2 sentences the loop would never end.
        raise ValueError("need at least two sentences to sample IsNext pairs")

    cls_id, sep_id, mask_id = word_dict["[CLS]"], word_dict["[SEP]"], word_dict["[MASK]"]
    batch = []
    positive = negative = 0

    while positive < n_positive_target or negative < n_negative_target:
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))
        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
        input_ids = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)
        if len(input_ids) > max_len:
            raise ValueError(
                f"sentence pair of length {len(input_ids)} exceeds max_len={max_len}"
            )

        # MASK LM: pick ~15% of the ordinary (non-[CLS]/[SEP]) positions.
        can_masked_pos = [i for i, token in enumerate(input_ids) if token != cls_id and token != sep_id]
        shuffle(can_masked_pos)
        n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))), len(can_masked_pos))

        masked_tokens, masked_pos = [], []
        for pos in can_masked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            p = random()  # one draw so the split really is 80/10/10
            if p < 0.8:
                input_ids[pos] = mask_id
            elif p < 0.9:
                # NOTE(review): may pick a special-token id, as the original did.
                input_ids[pos] = randint(0, vocab_size - 1)
            # else: keep the original token

        # Zero-pad the sequence (0 is the [PAD] id assumed by get_attn_pad_mask).
        n_pad = max_len - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # Pad predictions from the actual count, which can be < n_pred's upper bound.
        n_pred_pad = max_pred - len(masked_tokens)
        masked_tokens.extend([0] * n_pred_pad)
        masked_pos.extend([0] * n_pred_pad)

        is_next = tokens_a_index + 1 == tokens_b_index
        if is_next and positive < n_positive_target:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, True])  # IsNext
            positive += 1
        elif not is_next and negative < n_negative_target:
            batch.append([input_ids, segment_ids, masked_tokens, masked_pos, False])  # NotNext
            negative += 1
    return batch
    
    

In [9]:
def get_attn_pad_mask(seq_q, seq_k, pad_idx=0):
    """Build a boolean mask that hides padding positions of the keys.

    Args:
        seq_q: (batch_size, len_q) tensor of query token ids.
        seq_k: (batch_size, len_k) tensor of key token ids.
        pad_idx: token id treated as padding (default 0, the [PAD] id).

    Returns:
        (batch_size, len_q, len_k) bool tensor (an expanded view), True where
        the key position is padding and must not be attended to.
    """
    len_q = seq_q.size(1)
    batch_size, len_k = seq_k.size()
    # (batch, 1, len_k): one padding flag per key, broadcast over every query.
    pad_attn_mask = seq_k.eq(pad_idx).unsqueeze(1)
    return pad_attn_mask.expand(batch_size, len_q, len_k)

In [10]:
def gelu(x):
    """Exact (erf-based) GELU activation: x * Phi(x), Phi = standard normal CDF.

    Same formulation as the Hugging Face implementation.
    """
    standardized = x / math.sqrt(2.0)
    two_cdf = 1.0 + torch.erf(standardized)  # equals 2 * Phi(x)
    return x * 0.5 * two_cdf


In [11]:
class Embedding(nn.Module):
    """BERT input embedding: token + learned position + segment, then LayerNorm.

    Args:
        vocab_size: number of token ids.
        model_d: embedding / hidden size.
        max_len: maximum sequence length (size of the position table).
        n_segments: number of segment ids (make_batch uses 0 and 1).
    """

    def __init__(self, vocab_size, model_d, max_len, n_segments):
        super().__init__()

        self.tok_embed = nn.Embedding(vocab_size, model_d)
        self.pos_embed = nn.Embedding(max_len, model_d)
        self.seg_embed = nn.Embedding(n_segments, model_d)
        self.norm = nn.LayerNorm(model_d)

    def forward(self, x, seg):
        """Embed a batch of token ids.

        Args:
            x: (batch, seq_len) long tensor of token ids; seq_len <= max_len.
            seg: long tensor of segment ids, broadcastable with ``x``
                (normally the same shape).

        Returns:
            (batch, seq_len, model_d) layer-normalized embeddings.
        """
        seq_len = x.size(1)
        # Build positions on x's device; a CPU-only arange breaks GPU inputs.
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)  # (seq_len,) -> (batch, seq_len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

In [18]:
class ScaleDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V. No parameters."""

    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, attn_mask=None):
        """Attend from queries to keys/values.

        Args:
            Q: (..., len_q, d_k) queries.
            K: (..., len_k, d_k) keys.
            V: (..., len_k, d_v) values.
            attn_mask: optional bool tensor broadcastable to
                (..., len_q, len_k); True marks positions not to attend to.

        Returns:
            (context, attn): context is (..., len_q, d_v); attn holds the
            (..., len_q, len_k) softmax weights.
        """
        d_k = K.size(-1)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k)
        if attn_mask is not None:
            # Large negative fill -> ~0 weight after softmax.
            scores = scores.masked_fill(attn_mask, -1e9)
        attn = self.softmax(scores)
        context = torch.matmul(attn, V)
        return context, attn


# Alias with the conventional "Scaled..." spelling so code using either name resolves.
ScaledDotProductAttention = ScaleDotProductAttention

In [19]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by a residual connection and LayerNorm.

    Args:
        model_d: hidden size of the inputs and of the output.
        d_k: per-head projection size. Values use the same size, since W_V
            projects to ``d_k * n_heads``.
        n_heads: number of attention heads.
    """

    def __init__(self, model_d, d_k, n_heads):
        super().__init__()
        # Kept on the instance: forward needs them to split/merge heads.
        self.d_k = d_k
        self.n_heads = n_heads
        self.W_Q = nn.Linear(model_d, d_k * n_heads)
        self.W_K = nn.Linear(model_d, d_k * n_heads)
        self.W_V = nn.Linear(model_d, d_k * n_heads)
        self.linear = nn.Linear(n_heads * d_k, model_d)
        self.norm = nn.LayerNorm(model_d)
        # Parameter-free, so creating it once does not change the state_dict.
        self.attention = ScaleDotProductAttention()

    def forward(self, Q, K, V, attn_mask):
        """Apply multi-head attention.

        Args:
            Q: (batch, len_q, model_d) query input; also the residual.
            K: (batch, len_k, model_d) key input.
            V: (batch, len_k, model_d) value input.
            attn_mask: (batch, len_q, len_k) bool mask, True = do not attend
                (e.g. from get_attn_pad_mask).

        Returns:
            (output, attn): output is (batch, len_q, model_d); attn is
            (batch, n_heads, len_q, len_k).
        """
        residual, batch_size = Q, Q.size(0)
        # (batch, len, n_heads * d_k) -> (batch, n_heads, len, d_k)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # Same mask for every head: (batch, len_q, len_k) -> (batch, n_heads, len_q, len_k)
        attn_mask = attn_mask.unsqueeze(1).expand(-1, self.n_heads, -1, -1)

        context, attn = self.attention(q_s, k_s, v_s, attn_mask)
        # Merge heads: (batch, n_heads, len_q, d_k) -> (batch, len_q, n_heads * d_k)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_k)
        output = self.linear(context)
        return self.norm(output + residual), attn

In [20]:
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward network: fc2(gelu(fc1(x))).

    Both linear layers act on the last dimension, so the same transform is
    applied at every sequence position.

    Args:
        model_d: input/output feature size.
        d_ff: inner (hidden) feature size.
    """

    def __init__(self, model_d, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        # Creation order fc1 -> fc2 is preserved so parameter init is unchanged.
        self.fc1 = nn.Linear(model_d, d_ff)   # expand to d_ff
        self.fc2 = nn.Linear(d_ff, model_d)   # project back to model_d

    def forward(self, x):
        hidden = gelu(self.fc1(x))
        return self.fc2(hidden)

In [21]:
class EncoderLayer(nn.Module):
    """One encoder layer: multi-head self-attention, then a position-wise FFN.

    Args:
        model_d: hidden size.
        d_k: per-head attention size.
        n_heads: number of attention heads.
        d_ff: inner size of the feed-forward network.
    """

    def __init__(self, model_d, d_k, n_heads, d_ff):
        super().__init__()

        self.enc_self_attn = MultiHeadAttention(model_d, d_k, n_heads)
        self.pos_ffn = PoswiseFeedForwardNet(model_d, d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        """Run one encoder layer.

        Args:
            enc_inputs: (batch, seq_len, model_d) hidden states.
            enc_self_attn_mask: (batch, seq_len, seq_len) bool padding mask.

        Returns:
            (enc_outputs, attn): the FFN output and the self-attention weights.
        """
        # Self-attention: Q, K and V all come from the same input.
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        # Assign back to enc_outputs; a typo previously discarded the FFN result.
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn

In [23]:
# Toy (2, 2, 3) float tensor for the indexing experiments below.
test = torch.tensor(
    [[[1, 2, 3], [4, 5, 6]],
     [[7, 8, 9], [10, 11, 12]]],
    dtype=torch.get_default_dtype(),
)

In [25]:
test[:, 0, :]  # index 0 along dim 1 of every batch entry -> shape (2, 3)

tensor([[1., 2., 3.],
        [7., 8., 9.]])

In [27]:
# (2, 3) float matrix for the unsqueeze/expand experiments below.
test1 = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.get_default_dtype())
test1

tensor([[1., 2., 3.],
        [4., 5., 6.]])

In [34]:
test1[:, :, None].expand(-1, -1, 10)  # (2, 3) -> (2, 3, 1) -> broadcast view (2, 3, 10)

tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
         [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]],

        [[4., 4., 4., 4., 4., 4., 4., 4., 4., 4.],
         [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
         [6., 6., 6., 6., 6., 6., 6., 6., 6., 6.]]])

In [33]:
test1.unsqueeze(-1)  # equivalent to test1[:, :, None]: (2, 3) -> (2, 3, 1)

tensor([[[1.],
         [2.],
         [3.]],

        [[4.],
         [5.],
         [6.]]])