In [2]:
import torch
import torch.nn as nn

# Model hyperparameters shared by the cells below.

# Number of distinct token ids; sizes the token embedding table and is also
# passed as the MLM head's output dimension when the model is built.
vocab_dim = 63
# Number of rows in the learned positional embedding table.
seq_len = 100
# Width of the token/positional embeddings and of the Transformer encoder layers.
d_model = 128
# Hidden width of the feed-forward sublayer inside each encoder layer.
dim_feedforward = 512
# NOTE(review): in the cells visible here this value is not passed to the BERT
# constructor call — confirm whether it is meant to be wired through.
dropout_rate = 0.1
# Token id that is compared against the input to build the padding mask.
pad_token_id = 3
# Number of attention heads per encoder layer.
nhead = 8
# Number of stacked encoder layers.
num_layers = 8
# When True, MLMHead runs a GRU over the encoder output before its final linear layer.
use_RNN = False


In [6]:
class BERTEmbedding(nn.Module):
    """Sum of a token embedding and a learned positional embedding.

    Each of the two embeddings has its own dropout layer, applied before the
    two are added together.

    Args:
        vocab_dim: Number of rows in the token embedding table.
        seq_len: Number of rows in the positional embedding table, i.e. the
            longest sequence this module can embed.
        d_model: Embedding width.
        dropout_rate: Dropout probability used for both embeddings.
    """

    def __init__(self, vocab_dim, seq_len, d_model, dropout_rate):
        super().__init__()
        self.vocab_dim = vocab_dim
        self.seq_len = seq_len
        self.d_model = d_model
        self.dropout_rate = dropout_rate

        # vocab --> embedding
        self.token_embedding = nn.Embedding(self.vocab_dim, self.d_model)
        self.token_dropout = nn.Dropout(self.dropout_rate)

        # position index --> embedding
        self.positional_embedding = nn.Embedding(self.seq_len, self.d_model)
        self.positional_dropout = nn.Dropout(self.dropout_rate)

    def forward(self, data):
        """Embed a tensor of token ids.

        Args:
            data: Integer tensor of token ids whose last dimension is the
                sequence axis; its length may be anything up to ``seq_len``.

        Returns:
            Tensor of shape ``data.shape + (d_model,)``.

        Raises:
            ValueError: If the sequence axis is longer than ``seq_len``.
        """
        cur_len = data.size(-1)
        if cur_len > self.seq_len:
            raise ValueError(
                f"input length {cur_len} exceeds the positional table size {self.seq_len}"
            )

        token_embedding = self.token_embedding(data)
        token_embedding = self.token_dropout(token_embedding)

        # Build the position ids for the *actual* input length, directly on the
        # input's device. `data.device` is valid for CPU tensors as well, unlike
        # `data.get_device()`, which returns -1 on CPU.
        positions = torch.arange(cur_len, dtype=torch.long, device=data.device)
        # Expand to the full input shape before the lookup so that dropout draws
        # an independent mask for every sample in the batch.
        positions = positions.expand_as(data)

        positional_embedding = self.positional_embedding(positions)
        positional_embedding = self.positional_dropout(positional_embedding)

        return token_embedding + positional_embedding
    

class BERT(nn.Module):
    """Transformer encoder stack on top of ``BERTEmbedding``.

    Args:
        vocab_dim: Vocabulary size forwarded to the embedding module.
        seq_len: Positional table size forwarded to the embedding module.
        d_model: Model width.
        dim_feedforward: Hidden width of each encoder layer's feed-forward part.
        pad_token_id: Token id treated as padding when masking attention.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout_rate: Dropout probability for the embeddings and the encoder
            layers. Defaults to 0.1, the value that was previously hard-coded
            for the embedding (and the encoder layer's own default).
    """

    def __init__(self, vocab_dim, seq_len, d_model, dim_feedforward, pad_token_id, nhead, num_layers,
                 dropout_rate=0.1):
        super().__init__()
        self.pad_token_id = pad_token_id
        self.num_head = nhead

        self.embedding = BERTEmbedding(vocab_dim, seq_len, d_model, dropout_rate=dropout_rate)
        # nn.TransformerEncoder deep-copies this layer `num_layers` times, so the
        # instance kept here is only a template. It stays registered so that the
        # module tree / state_dict keys remain what they were.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            dim_feedforward=dim_feedforward,
            nhead=nhead,
            dropout=dropout_rate,
            batch_first=True,
        )
        self.encoder_block = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)

    def forward(self, data):
        """Encode a batch of token ids.

        Args:
            data: Integer tensor of token ids, shape ``(batch, seq)``.

        Returns:
            Encoder output of shape ``(batch, seq, d_model)``.
        """
        # True where the token is padding. Passing it as `src_key_padding_mask`
        # lets PyTorch broadcast it over heads and query positions itself, which
        # avoids hand-building a (batch * nhead, seq, seq) mask whose batch/head
        # ordering has to match PyTorch's internal layout.
        key_padding_mask = data.eq(self.pad_token_id)
        embedding = self.embedding(data)
        output = self.encoder_block(embedding, src_key_padding_mask=key_padding_mask)

        return output

    @staticmethod
    def get_attn_pad_mask(seq_q, seq_k, i_pad):
        """Build a boolean ``(batch, len_q, len_k)`` mask of padded key positions.

        Entry ``[b, i, j]`` is True when ``seq_k[b, j] == i_pad``, for every
        query index ``i``. ``forward`` no longer needs this helper; it is kept
        for callers that want the expanded mask.

        Args:
            seq_q: Tensor of shape ``(batch, len_q)``; only its shape is used.
            seq_k: Tensor of shape ``(batch, len_k)`` of token ids.
            i_pad: Padding token id.
        """
        batch_size, len_q = seq_q.size()
        batch_size, len_k = seq_k.size()
        pad_attn_mask = seq_k.eq(i_pad)
        pad_attn_mask = pad_attn_mask.unsqueeze(1).expand(batch_size, len_q, len_k)

        return pad_attn_mask
    
    
class MLMHead(nn.Module):
    """Projection head that maps encoder features to ``output_dim`` scores.

    Args:
        bert: Module called on the input to produce features of width ``d_model``.
        d_model: Feature width produced by ``bert``.
        output_dim: Size of the final linear layer's output.
        use_RNN: When True, a GRU is run over the encoder output before the
            linear projection.
    """

    def __init__(self, bert, d_model, output_dim, use_RNN=False):
        super().__init__()
        self.bert = bert
        self.use_RNN = use_RNN
        self.fc = nn.Linear(d_model, output_dim)

        if self.use_RNN:
            # The BERT encoder in this notebook is built with batch_first=True, so
            # the GRU must be batch-first too; otherwise it would recur over the
            # batch axis and leak state between samples.
            self.rnn = nn.GRU(d_model, d_model, batch_first=True)

    def forward(self, x):
        """Run ``bert`` (and the optional GRU), then project with the linear layer.

        Args:
            x: Input forwarded unchanged to ``bert``.

        Returns:
            The projected features; the last dimension has size ``output_dim``.
        """
        output = self.bert(x)

        if self.use_RNN:
            # The final hidden state is not needed, only the per-step outputs.
            output, _ = self.rnn(output)

        return self.fc(output)

In [7]:
# Build the encoder backbone from the notebook-level hyperparameters.
bert_base = BERT(
    vocab_dim=vocab_dim,
    seq_len=seq_len,
    d_model=d_model,
    dim_feedforward=dim_feedforward,
    pad_token_id=pad_token_id,
    nhead=nhead,
    num_layers=num_layers,
)
# The head's output dimension is set to the vocabulary size.
mlm_head = MLMHead(
    bert=bert_base,
    d_model=d_model,
    output_dim=vocab_dim,
    use_RNN=use_RNN,
)

# Bare last expression: shows the module tree as the cell output.
mlm_head

MLMHead(
  (bert): BERT(
    (embedding): BERTEmbedding(
      (token_embedding): Embedding(63, 128)
      (token_dropout): Dropout(p=0.1, inplace=False)
      (positional_embedding): Embedding(100, 128)
      (positional_dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder_layer): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (linear1): Linear(in_features=128, out_features=512, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=512, out_features=128, bias=True)
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
    (encoder_block): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
       