### Coding BERT from scratch again, following the d2l book

In [None]:
!pip install setuptools==66
!pip install matplotlib_inline
!pip install d2l==1.0.0b
!pip install pytorch-pretrained-bert

In [None]:
import torch
from torch import nn
from d2l import torch as d2l

In [None]:
class BERTEncoder(nn.Module):
    """Embedding layers followed by a stack of d2l transformer encoder blocks.

    The input representation is the sum of a token embedding, a two-entry
    segment embedding and a learned positional embedding; the sum is then
    passed through ``num_blocks`` ``d2l.TransformerEncoderBlock`` modules.
    """

    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads, num_blocks = 2, dropout = 0.2, max_len = 1000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        # Only two segment ids (0 and 1) are representable.
        self.seg_embedding = nn.Embedding(2, num_hiddens)
        # Learned positional table, initialised uniformly in [0, 1).
        self.pos_embedding = nn.Parameter(torch.rand(1, max_len, num_hiddens))
        # Children are registered under the names '0', '1', ... in order.
        # The trailing positional True is forwarded to the d2l block as-is
        # (presumably its use_bias flag -- confirm against the d2l version).
        self.blocks = nn.Sequential(*(
            d2l.TransformerEncoderBlock(
                num_hiddens, ffn_num_hiddens, num_heads, dropout, True)
            for _ in range(num_blocks)
        ))

    def forward(self, tokens, segments, valid_lens):
        """Encode a batch of token ids.

        Args:
            tokens: token ids; expected shape (batch, seq_len).
            segments: segment ids (0 or 1), same shape as ``tokens``.
            valid_lens: forwarded unchanged to every encoder block.

        Returns:
            The output of the last encoder block.
        """
        hidden = self.embedding(tokens) + self.seg_embedding(segments)
        # Use only as many positional rows as the sequence is long.
        seq_len = hidden.shape[1]
        hidden = hidden + self.pos_embedding[:, :seq_len, :]
        for layer in self.blocks:
            hidden = layer(hidden, valid_lens)
        return hidden

class MaskLM(nn.Module):
    """Masked-language-model head.

    Picks the encoder outputs at the requested positions and maps each of
    them to ``vocab_size`` scores with a small MLP
    (Linear -> ReLU -> LayerNorm -> Linear).
    """

    def __init__(self, vocab_size, num_hiddens, **kwargs):
        super().__init__(**kwargs)
        self.mlp = nn.Sequential(
            nn.LazyLinear(num_hiddens),
            nn.ReLU(),
            nn.LayerNorm(num_hiddens),
            nn.LazyLinear(vocab_size),
        )

    def forward(self, X, pred_positions):
        """Score the vocabulary at every prediction position.

        Args:
            X: encoder output, shape (batch, seq_len, num_hiddens).
            pred_positions: integer tensor of sequence indices to predict,
                shape (batch, num_preds).

        Returns:
            Tensor of shape (batch, num_preds, vocab_size).
        """
        # A (batch, 1) column of row indices broadcasts against the
        # (batch, num_preds) positions, so one advanced-indexing step gathers
        # X[b, pred_positions[b, j]] for every b, j. Building the index on
        # X's device keeps the lookup on that device.
        batch_idx = torch.arange(X.shape[0], device=X.device).unsqueeze(1)
        masked_X = X[batch_idx, pred_positions]
        return self.mlp(masked_X)
class NextSentencePred(nn.Module):
    """Next-sentence-prediction head: one linear layer producing two scores."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Input width is inferred lazily on the first forward pass.
        self.output = nn.LazyLinear(2)

    def forward(self, X):
        """Map ``X`` (expected shape (batch, num_hiddens)) to (batch, 2) scores."""
        scores = self.output(X)
        return scores

class BERTModel(nn.Module):
    """BERT encoder combined with the masked-LM and next-sentence heads."""

    # NOTE: max_len defaults to 100 here while BERTEncoder's own default is
    # 1000; the value given here is the one the encoder is built with.
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads, num_blocks = 2, dropout = 0.2, max_len = 100, **kwargs):
        super().__init__(**kwargs)
        self.encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens, num_heads, num_blocks, dropout, max_len)
        # Applied to the first-position encoding before the NSP head.
        self.hidden = nn.Sequential(
            nn.LazyLinear(num_hiddens),
            nn.Tanh())
        self.mlm = MaskLM(vocab_size, num_hiddens)
        self.nsp = NextSentencePred()

    def forward(self, tokens, segments, pred_positions=None, valid_lens=None):
        """Run the encoder and both pre-training heads.

        Args:
            tokens: token ids, expected shape (batch, seq_len).
            segments: segment ids, same shape as ``tokens``.
            pred_positions: positions to predict for the masked-LM head,
                shape (batch, num_preds); ``None`` skips that head.
            valid_lens: optional valid lengths passed to the encoder so its
                attention can ignore padding; ``None`` (the default) keeps
                the previous behaviour of attending to every position.

        Returns:
            Tuple ``(encoded_X, mlm_Y_pred, nsp_Y_pred)``; ``mlm_Y_pred`` is
            ``None`` when ``pred_positions`` is ``None``.
        """
        encoded_X = self.encoder(tokens, segments, valid_lens)
        # Identity check: `== None` on a tensor goes through Tensor.__eq__.
        if pred_positions is None:
            mlm_Y_pred = None
        else:
            mlm_Y_pred = self.mlm(encoded_X, pred_positions)
        # The encoding at sequence position 0 feeds the NSP classifier.
        nsp_Y_pred = self.nsp(self.hidden(encoded_X[:, 0, :]))

        return encoded_X, mlm_Y_pred, nsp_Y_pred

In [None]:
# Smoke test: build a small BERTModel, push one random batch through it and
# print the shape of every input and output.
vocab_size = 1000
num_hiddens = 768
ffn_num_hiddens = 1024
num_heads = 4
num_blocks = 2
dropout = 0.2
model = BERTModel(vocab_size, num_hiddens, ffn_num_hiddens, num_heads, num_blocks, dropout)

batch_size, token_len = 2, 10

# Random token ids in [0, vocab_size).
tokens = torch.randint(0, vocab_size, (batch_size, token_len))
print(tokens.shape)

# Segment ids: the first sequence switches to segment 1 at index 6,
# the second at index 5.
segments = torch.tensor([
    [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
])
print(segments.shape)

# Four positions per sequence for the masked-LM head.
pred_positions = torch.tensor([
    [1, 3, 4, 5],
    [3, 6, 7, 1],
])
print(pred_positions.shape)

encoded_X, mlm_Y_pred, nsp_Y_pred = model(tokens, segments, pred_positions)

print('encoded_X shape', encoded_X.shape)
print('mlm_Y_pred shape', mlm_Y_pred.shape)
print('nsp_Y_pred shape', nsp_Y_pred.shape)


torch.Size([2, 10])
torch.Size([2, 10])
torch.Size([2, 4])
torch.Size([8])
torch.Size([8])
encoded_X shape torch.Size([2, 10, 768])
mlm_Y_pred shape torch.Size([2, 4, 1000])
nsp_Y_pred shape torch.Size([2, 2])
