In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
class PositionEncoding(nn.Module):
  """Fixed sinusoidal positional encoding ("Attention Is All You Need", §3.5).

  For position ``pos`` and pair index ``i``:
    PE[pos, 2i]     = sin(pos / 10000 ** (2i / d_model))
    PE[pos, 2i + 1] = cos(pos / 10000 ** (2i / d_model))

  The table has shape (seq_length, d_model). It is registered as a
  non-persistent buffer, so it moves with ``.to(device)`` / ``.half()`` but is
  not written to the ``state_dict``.

  Args:
    seq_length: maximum number of positions the table covers.
    d_model: embedding width (odd widths are supported).
  """
  def __init__(self, seq_length, d_model):
    super(PositionEncoding, self).__init__()
    positions = torch.arange(seq_length, dtype=torch.float).unsqueeze(1)  # (seq_length, 1)
    # Both members of a sin/cos pair share the frequency of the even index 2i.
    pair_dims = (torch.arange(d_model) // 2 * 2).float()                   # 0, 0, 2, 2, 4, 4, ...
    inv_freq = torch.pow(1e4, -pair_dims / d_model)                       # (d_model,)
    angles = positions * inv_freq                                         # (seq_length, d_model)
    encoding = angles.clone()
    encoding[:, 0::2] = torch.sin(angles[:, 0::2])
    encoding[:, 1::2] = torch.cos(angles[:, 1::2])
    self.register_buffer("myOutput", encoding, persistent=False)

  def forward(self, x):
    """Return the encodings for the first ``L`` positions, shape (L, d_model).

    Args:
      x: only its sequence length ``L`` is read. ``L`` is dim 1 for a
        batched input (e.g. (batch, seq) token ids or (batch, seq, d_model)
        embeddings), and dim 0 for a 1-D input.

    Raises:
      ValueError: if ``L`` exceeds the ``seq_length`` the table was built for.
    """
    seq_len = x.size(1) if x.dim() > 1 else x.size(0)
    if seq_len > self.myOutput.size(0):
      raise ValueError(
          f"sequence length {seq_len} exceeds the positional table size {self.myOutput.size(0)}")
    return self.myOutput[:seq_len]

In [None]:
class BERTEmbedding(nn.Module):
  """Input embedding: segment + token + positional vectors, then dropout.

  Args:
    vocab_size: number of rows in the token embedding table.
    seq_length: number of positions the positional encoding is built for.
    n_segments: number of rows in the segment embedding table.
    d_model: embedding width shared by all three components.
    dropout: dropout probability applied to the summed embedding.
  """
  def __init__(self, vocab_size, seq_length, n_segments, d_model, dropout):
    super().__init__()
    # Registration order is kept as-is so weight initialisation draws from
    # the RNG in the same sequence.
    self.embedding = nn.Embedding(vocab_size, d_model)
    self.embed_segment = nn.Embedding(n_segments, d_model)
    self.pe = PositionEncoding(seq_length, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, seq_input, seg_input):
    """Embed token ids ``seq_input`` with segment ids ``seg_input``."""
    segment_vecs = self.embed_segment(seg_input)
    token_vecs = self.embedding(seq_input)
    position_vecs = self.pe(seq_input)
    # Same left-to-right summation order as before (segment, token, position).
    summed = segment_vecs + token_vecs + position_vecs
    return self.dropout(summed)

In [None]:
class BERTModel(nn.Module):
  """BERT-style encoder: input embeddings followed by a Transformer encoder stack.

  Args:
    vocab_size, seq_length, n_segments, d_model, dropout: forwarded to
      ``BERTEmbedding``; ``d_model`` and ``dropout`` also configure the encoder.
    n_heads: attention heads per encoder layer (``d_model`` must be divisible by it).
    n_layers: number of stacked encoder layers.
  """
  def __init__(self, vocab_size, seq_length, n_segments, d_model, dropout,
               n_heads=8, n_layers=6):
    super(BERTModel, self).__init__()
    self.embed = BERTEmbedding(vocab_size, seq_length, n_segments, d_model, dropout)
    # The embeddings are laid out (batch, seq, d_model). Without batch_first=True
    # the encoder reads dim 0 as the sequence axis and attends across the batch.
    encoder_layer = nn.TransformerEncoderLayer(d_model = d_model,
                                               nhead = n_heads,
                                               dropout = dropout,
                                               batch_first = True)
    self.bert = nn.TransformerEncoder(encoder_layer,
                                      num_layers = n_layers,
                                      mask_check = True)
    self.dropout = nn.Dropout(dropout)

  def forward(self, seq_input, seg_input, padding_mask=None):
    """Encode a batch of token/segment id sequences.

    Args:
      seq_input: token ids; assumed (batch, seq) as in the __main__ smoke test.
      seg_input: segment ids, same shape as ``seq_input``.
      padding_mask: optional bool tensor (batch, seq); True marks positions the
        encoder should ignore (passed as ``src_key_padding_mask``).

    Returns:
      Encoded sequence of shape (batch, seq, d_model).
    """
    output = self.embed(seq_input, seg_input)
    output = self.bert(output, src_key_padding_mask = padding_mask)
    return self.dropout(output)

In [None]:
if __name__ == "__main__":
  # Smoke test: build the model and check the output shape on random inputs.
  VOCAB_SIZE = 20000
  SEQ_LENGTH = 100
  D_MODEL = 512
  DROPOUT = 0.2
  N_SEGMENTS = 3
  BATCH_SIZE = 32
  SEED = 0

  torch.manual_seed(SEED)  # reproducible weight init and random inputs

  bert = BERTModel(VOCAB_SIZE,
                   SEQ_LENGTH,
                   N_SEGMENTS,
                   D_MODEL,
                   DROPOUT)
  seq_input = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))
  seg_input = torch.randint(0, N_SEGMENTS, (BATCH_SIZE, SEQ_LENGTH))

  # Inference-only check: eval() disables dropout, and no_grad() avoids
  # building an autograd graph through all encoder layers.
  bert.eval()
  with torch.no_grad():
    output = bert(seq_input, seg_input)
  assert output.shape == (BATCH_SIZE, SEQ_LENGTH, D_MODEL)
  print(output.size())



torch.Size([32, 100, 512])
