In [None]:
import math
import re
from random import *
# import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
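# To reproduce results, uncomment this cell. As written it also needs `import random` (commented out above); with only `from random import *`, call `seed(SEED)` instead of `random.seed(SEED)`.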
# SEED=1234

# random.seed(SEED)
# np.random.seed(SEED)
# torch.manual_seed(SEED)
# torch.cuda.manual_seed(SEED)
# torch.backends.cudnn.deterministic = True

In [None]:
def make_batch():
  """Build one balanced batch of masked-LM / next-sentence-prediction samples.

  Reads the module-level settings ``batch_size``, ``max_pred``, ``max_len``,
  ``vocab_size``, ``word_dict``, ``sentences`` and ``token_list``.

  Returns:
    A list of ``batch_size`` samples, each
    ``[input_ids, segment_ids, masked_tokens, masked_pos, is_next]``:
      * input_ids     - ``[CLS] a [SEP] b [SEP]`` with masking applied, padded with 0 up to ``max_len``
      * segment_ids   - 0 for ``[CLS] a [SEP]``, 1 for ``b [SEP]``, padded with 0
      * masked_tokens - original ids of the selected positions, padded with 0 up to ``max_pred``
      * masked_pos    - the selected positions, padded with 0 up to ``max_pred``
      * is_next       - True when sentence b directly follows sentence a
    ``batch_size // 2`` samples are positive, the rest negative (for an odd
    ``batch_size`` the extra sample is negative).

  Raises:
    ValueError: if positive pairs are requested but fewer than two sentences
      exist (no consecutive pair can ever be drawn).

  Note:
    Sequences longer than ``max_len`` are neither padded nor truncated.
  """
  n_positive = batch_size // 2
  n_negative = batch_size - n_positive
  n_sentences = len(sentences)
  if n_positive > 0 and n_sentences < 2:
    raise ValueError('make_batch needs at least two sentences to build positive (IsNext) pairs')

  cls_id, sep_id, mask_id = word_dict['[CLS]'], word_dict['[SEP]'], word_dict['[MASK]']
  special_ids = {word_dict['[PAD]'], cls_id, sep_id, mask_id}
  # Random replacements must be ordinary words: a [PAD] would be hidden by the
  # attention pad mask and [CLS]/[SEP]/[MASK] would corrupt the sequence layout.
  replacement_ids = [idx for idx in range(vocab_size) if idx not in special_ids]

  batch=[]
  positive = negative = 0
  while positive < n_positive or negative < n_negative:
    # draw a random sentence pair from [0, len(sentences))
    tokens_a_index,tokens_b_index = randrange(n_sentences),randrange(n_sentences)
    is_next = tokens_a_index+1 == tokens_b_index
    # skip the pair early when its class quota is already filled
    if is_next and positive >= n_positive:
      continue
    if not is_next and negative >= n_negative:
      continue

    # token_list holds the token ids of every sentence
    tokens_a, tokens_b = token_list[tokens_a_index],token_list[tokens_b_index]
    input_ids = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]
    segment_ids = [0] * (1+len(tokens_a)+1) + [1] * (len(tokens_b)+1)

    # number of positions to predict: 15% of the sequence, at least 1, at most max_pred
    n_pred = min(max_pred, max(1, int(round(len(input_ids)*0.15))))

    # only ordinary tokens are candidates for masking
    candidates = [i for i,token in enumerate(input_ids) if token != cls_id and token != sep_id]
    shuffle(candidates)
    masked_tokens,masked_pos=[],[]
    for pos in candidates[:n_pred]:
      masked_pos.append(pos)
      masked_tokens.append(input_ids[pos])
      # a single draw gives the 80% / 10% / 10% split
      draw = random()
      if draw < 0.8:
        input_ids[pos] = mask_id                      # 80%: replace with [MASK]
      elif draw < 0.9 and replacement_ids:
        input_ids[pos] = replacement_ids[randrange(len(replacement_ids))]  # 10%: random word
      # remaining 10%: the token is left unchanged

    # pad the sequence to the fixed length max_len (0 is the [PAD] id)
    n_pad = max_len - len(input_ids)
    input_ids.extend([0]*n_pad)
    segment_ids.extend([0]*n_pad)

    # pad the prediction slots so every sample has exactly max_pred entries and
    # the batch can be stacked into one tensor
    n_mask_pad = max_pred - len(masked_pos)
    if n_mask_pad > 0:
      masked_tokens.extend([0] * n_mask_pad)
      masked_pos.extend([0] * n_mask_pad)

    batch.append([input_ids, segment_ids, masked_tokens, masked_pos, is_next])
    if is_next:
      positive+=1
    else:
      negative+=1

  return batch

In [None]:
def get_attn_pad_mask(seq_q,seq_k):
  """Return a boolean mask that hides padded key positions from every query.

  seq_q and seq_k are 2-D tensors of token ids ([batch_size, len_q] and
  [batch_size, len_k]); id 0 is treated as padding.

  The result has shape [batch_size, len_q, len_k] and is True wherever the
  key token is padding (id 0), False for every other token.
  """
  _, q_len = seq_q.size()
  n_batch, k_len = seq_k.size()
  # [batch_size, len_k] -> [batch_size, 1, len_k]
  key_is_pad = (seq_k.data == 0).unsqueeze(1)
  # the same key mask is shared by every query position
  return key_is_pad.expand(n_batch, q_len, k_len)

def gelu(x):
    """Gaussian Error Linear Unit: x * 0.5 * (1 + erf(x / sqrt(2))).

    torch.erf evaluates the error function element-wise,
    erf(x) = (2 / sqrt(pi)) * integral from 0 to x of exp(-t^2) dt.
    """
    half_x = x * 0.5
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_x * (1.0 + erf_term)

class Embedding(nn.Module):
  """Sum of token, position and segment embeddings followed by LayerNorm.

  The table sizes come from the module-level settings ``vocab_size``,
  ``max_len``, ``n_segments`` and ``h_dim``, which must be defined before the
  module is instantiated.
  """
  def __init__(self):
    super(Embedding, self).__init__()
    self.tok_embed=nn.Embedding(vocab_size,h_dim)   # token id -> vector
    self.pos_embed=nn.Embedding(max_len,h_dim)      # position index -> vector
    self.seg_embed=nn.Embedding(n_segments,h_dim)   # segment id -> vector
    self.norm=nn.LayerNorm(h_dim)

  def forward(self,x,seg):
    """Embed a batch of token ids.

    Args:
      x: integer tensor of token ids, [batch_size, seq_len] with seq_len <= max_len.
      seg: integer tensor of segment ids with the same shape as ``x``.

    Returns:
      Tensor [batch_size, seq_len, h_dim]: layer-normalised sum of the three embeddings.
    """
    seq_len= x.size(1)
    # Create the position ids on the same device as the input; a CPU-only
    # arange would fail as soon as the model and inputs are moved elsewhere.
    pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
    pos = pos.unsqueeze(0).expand_as(x)  # [seq_len] -> [batch_size, seq_len]
    embedding = self.tok_embed(x)+self.pos_embed(pos)+self.seg_embed(seg)
    # Layer normalisation of the summed embeddings
    return self.norm(embedding)

class ScaledDotProductAttention(nn.Module):
  """Parameter-free scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V.

  The scaling factor uses the module-level setting ``d_k``.
  """
  def __init__(self):
    super().__init__()

  def forward(self, Q, K, V, attn_mask):
    """Attend over K/V with queries Q.

    Q = [batch_size, n_heads, q_len, d_k], K = [batch_size, n_heads, k_len, d_k];
    attn_mask is a boolean tensor matching the score shape, True where a
    position must be ignored.

    Returns the tuple (context, attn): the attention-weighted values and the
    attention probabilities.
    """
    keys_t = K.permute(0, 1, 3, 2)        # swap the last two axes of K
    scale = np.sqrt(d_k)
    scores = torch.matmul(Q, keys_t) / scale
    # masked positions get a large negative score, i.e. ~0 weight after softmax
    scores.masked_fill_(attn_mask, -1e9)
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, V), attn

class MultiHeadAttention(nn.Module):
  """Multi-head attention with output projection, residual connection and LayerNorm.

  Layer sizes come from the module-level settings ``h_dim``, ``n_heads``,
  ``d_k`` and ``d_v`` (queries and keys share the per-head size d_k).
  """
  def __init__(self):
    super().__init__()
    self.W_Q = nn.Linear(h_dim, n_heads * d_k)
    self.W_K = nn.Linear(h_dim, n_heads * d_k)
    self.W_V = nn.Linear(h_dim, n_heads * d_v)
    self.fc_o = nn.Linear(n_heads * d_v, h_dim)
    self.norm = nn.LayerNorm(h_dim)

  def _split_heads(self, projected, n_batch, head_dim):
    """[batch_size, len, n_heads*head_dim] -> [batch_size, n_heads, len, head_dim]."""
    return projected.view(n_batch, -1, n_heads, head_dim).permute(0, 2, 1, 3)

  def forward(self,Q,K,V, attn_mask):
    """Run attention over all heads.

    Q, K, V are [batch_size, len, h_dim]; attn_mask is [batch_size, len_q, len_k].
    Returns (output, attn) where output is [batch_size, q_len, h_dim] after the
    residual connection with Q and layer normalisation, and attn holds the
    per-head attention probabilities.
    """
    n_batch = Q.shape[0]

    q_s = self._split_heads(self.W_Q(Q), n_batch, d_k)  # [batch_size, n_heads, q_len, d_k]
    k_s = self._split_heads(self.W_K(K), n_batch, d_k)  # [batch_size, n_heads, k_len, d_k]
    v_s = self._split_heads(self.W_V(V), n_batch, d_v)  # [batch_size, n_heads, v_len, d_v]

    # insert a head axis and copy the mask once per head:
    # [batch_size, len_q, len_k] -> [batch_size, n_heads, len_q, len_k]
    head_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

    attention = ScaledDotProductAttention()
    context, attn = attention(q_s, k_s, v_s, head_mask)   # context: [batch_size, n_heads, q_len, d_v]

    # put the heads next to each other again: [batch_size, q_len, n_heads*d_v]
    merged = context.permute(0, 2, 1, 3).contiguous().view(n_batch, -1, n_heads * d_v)
    projected = self.fc_o(merged)                         # [batch_size, q_len, h_dim]
    return self.norm(projected + Q), attn

class PoswiseFeedForwardNet(nn.Module):
  """Position-wise feed-forward block: Linear(h_dim -> d_ff), GELU, Linear(d_ff -> h_dim).

  ``h_dim`` and ``d_ff`` are module-level settings.
  """
  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(h_dim, d_ff)   # expansion
    self.fc2 = nn.Linear(d_ff, h_dim)   # projection back to the model width

  def forward(self, x):
    """Apply fc1, the gelu activation and fc2 to the last dimension of ``x``."""
    hidden = self.fc1(x)
    activated = gelu(hidden)
    return self.fc2(activated)

class EncoderLayer(nn.Module):
  """One encoder layer: multi-head self-attention followed by the feed-forward block."""
  def __init__(self):
    super().__init__()
    self.enc_self_attn = MultiHeadAttention()
    self.pos_ffn = PoswiseFeedForwardNet()

  def forward(self,enc_inputs,enc_self_attn_mask):
    """Run self-attention (query = key = value = enc_inputs) and the feed-forward net.

    Returns (enc_outputs, attn): the layer output and the attention weights
    produced by the self-attention sub-layer.
    """
    attended, attn = self.enc_self_attn(
        enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
    return self.pos_ffn(attended), attn

class BERT(nn.Module):
  """Small BERT encoder with a masked-LM head and a next-sentence classifier.

  Sizes come from the module-level settings ``n_layers`` and ``h_dim`` (plus
  those read by ``Embedding`` and ``EncoderLayer``). The masked-LM decoder
  shares its weight matrix with the token embedding table.
  """
  def __init__(self):
    super(BERT,self).__init__()
    self.embedding = Embedding()
    embed_weight=self.embedding.tok_embed.weight
    self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
    self.fc = nn.Linear(h_dim,h_dim)        # pooler applied to the [CLS] position
    self.linear = nn.Linear(h_dim,h_dim)    # transform applied before the LM decoder
    self.activ1=nn.Tanh()
    self.norm=nn.LayerNorm(h_dim)
    self.classifier=nn.Linear(h_dim,2)      # isNext / notNext
    n_vocab,n_dim = embed_weight.size()
    # the decoder reuses the token-embedding matrix (weight tying); it only owns a bias
    self.decoder=nn.Linear(n_dim,n_vocab,bias=False)
    self.decoder.weight = embed_weight
    self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

  def forward(self,input_ids,segment_ids,masked_pos):
    """Compute masked-LM and next-sentence logits.

    Args:
      input_ids: token ids, [batch_size, seq_len]; id 0 is treated as padding.
      segment_ids: segment ids, [batch_size, seq_len].
      masked_pos: positions to predict, [batch_size, n_masked].

    Returns:
      (logits_lm, logits_clsf) with shapes [batch_size, n_masked, vocab_size]
      and [batch_size, 2].
    """
    # Use the registered (trainable, weight-tied) embedding module. Building a
    # new Embedding() here would feed freshly initialised random vectors into
    # the encoder on every call.
    output = self.embedding(input_ids, segment_ids)
    enc_self_attn_mask = get_attn_pad_mask(input_ids,input_ids)
    for layer in self.layers:
      output, enc_self_attn =layer(output, enc_self_attn_mask) #output=[batch_size,seq_len,h_dim]

    # next-sentence head: pool the hidden state of the first ([CLS]) position
    h_pooled = self.activ1(self.fc(output[:,0])) #output[:,0]=[batch_size,h_dim]
    logits_clsf = self.classifier(h_pooled)

    # masked-LM head: add a trailing axis and expand it to h_dim so gather can
    # pick whole hidden vectors: [batch_size, n_masked] -> [batch_size, n_masked, h_dim]
    masked_pos = masked_pos[:,:,None].expand(-1,-1,output.size(-1))
    # hidden states of the positions to predict only
    h_masked = torch.gather(output,1,masked_pos)

    h_masked = self.norm(gelu(self.linear(h_masked)))
    logits_lm = self.decoder(h_masked) + self.decoder_bias
    #logits_lm=[batch_size,n_masked,vocab_size]
    return logits_lm, logits_clsf


In [None]:
if __name__ == '__main__':
  # BERT hyper-parameters. They are read as module-level globals by the model
  # classes and by make_batch, so they must be defined before either is used.
  max_len = 30  # fixed (padded) length of every "[CLS] a [SEP] b [SEP]" sequence
  batch_size = 6
  max_pred = 5  # maximum number of masked positions per sample
  n_layers = 6
  n_heads=12
  h_dim = 768
  d_ff = 3072  #768 * 4
  d_k = d_v = 64
  n_segments = 2

  text1 = ('Hello, how are you? I am Romeo. \n'
      'Hello, Romeo My name is Julliet. Nice to meet you .\n'
      'Nice meet you too.today?\n'
      'Great. My baseball team won the competition.\n'
      'Oh Congratulations, Juliet\n'
      'Thank you Romeo')
  text = text1

  # lower-case, strip punctuation, one sentence per line
  sentences = re.sub("[.,!?\\-]",'',text.lower()).split('\n')
  # sorted() keeps the word -> id mapping identical from run to run; the
  # iteration order of a bare set of strings changes between interpreter runs
  word_list = sorted(set(" ".join(sentences).split()))
  word_dict = {'[PAD]':0, '[CLS]':1, '[SEP]':2, '[MASK]':3}

  for i,w in enumerate(word_list):
      word_dict[w]=i+4
  # inverse mapping id -> token, built from the ids themselves
  number_dict = {idx: token for token, idx in word_dict.items()}
  vocab_size = len(word_dict)

  # every sentence as a list of token ids
  token_list = [[word_dict[s] for s in sentence.split()] for sentence in sentences]

  model=BERT()
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(),lr=0.001)

  batch=make_batch()
  # transpose the list of samples into five tensors (one per field)
  input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(*batch))

  #training
  for epoch in range(1):
    optimizer.zero_grad()
    logits_lm, logits_clsf = model(input_ids,segment_ids,masked_pos)
    # CrossEntropyLoss expects the class axis second: [batch, vocab, n_masked].
    # The default reduction already averages, so no extra .mean() is needed.
    # NOTE(review): padded prediction slots (target id 0) are included in this
    # loss; consider a separate criterion with ignore_index for the LM head.
    loss_lm=criterion(logits_lm.transpose(1,2),masked_tokens) #for masked LM
    loss_clsf=criterion(logits_clsf,isNext)# for sentence classification
    loss = loss_lm+loss_clsf
    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))
    loss.backward()
    optimizer.step()



In [None]:
  # Predict the masked tokens and the isNext label, one sample of the training
  # batch at a time (at most the first five samples).
  model.eval()
  with torch.no_grad():
    for i, sample in enumerate(batch[:5]):
      # zip(sample) wraps every field in a 1-tuple, i.e. adds a batch axis of size 1
      input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(sample))
      print([number_dict[int(w)] for w in input_ids[0] if number_dict[int(w)] != '[PAD]'])

      logits_lm1, logits_clsf1 = model(input_ids,segment_ids,masked_pos)
      predicted_ids = logits_lm1.argmax(dim=2)[0].tolist()
      target_ids = masked_tokens[0].tolist()
      # a target id of 0 marks a padded prediction slot; drop those slots from
      # both lists so targets and predictions stay aligned
      print('masked tokens list:',[t for t in target_ids if t != 0])
      print('predicted masked tokens list:',[p for p, t in zip(predicted_ids, target_ids) if t != 0])

      predicted_is_next = logits_clsf1.argmax(dim=1)[0].item()
      print('IsNext: ', bool(isNext[0]))
      print('predict isNext:', bool(predicted_is_next))

      loss_lm=criterion(logits_lm1.transpose(1,2),masked_tokens) #for masked LM
      loss_clsf=criterion(logits_clsf1,isNext)# for sentence classification
      loss = loss_lm+loss_clsf
      print('Sample:', '%04d' % (i + 1), 'cost =', '{:.6f}'.format(loss.item()))
   
