<a href="https://colab.research.google.com/github/manders2/NodeRed_AWS_Rekognition/blob/master/Bidirectional_encoder_representation_from_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This code is possible because of [Tae-Hwan Jung](https://github.com/graykode). I have just broken down the code and added a few things here and there for better understanding.


In [None]:
import math
import re
from random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from pprint import pprint

In [None]:
# Toy corpus: a short dialogue, one utterance per line.
text = '\n'.join([
    'Hello, how are you? I am Romeo.',
    'Hello, Romeo My name is Juliet. Nice to meet you.',
    'Nice meet you too. How are you today?',
    'Great. My baseball team won the competition.',
    'Oh Congratulations, Juliet',
    'Thanks you Romeo',
])

In [None]:
# Normalise the corpus: lowercase, strip punctuation, one sentence per line.
sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n')  # filter '.', ',', '?', '!', '-'

# Sort the unique words so that token ids are identical on every run; the
# iteration order of a set of strings changes between Python processes.
word_list = sorted(set(" ".join(sentences).split()))

# Ids 0-3 are reserved for the special tokens; real words follow.
word_dict = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
for i, w in enumerate(word_list, start=len(word_dict)):
    word_dict[w] = i

# Inverse mapping (id -> token), built from the ids themselves rather than
# from the enumeration order of the dict.
number_dict = {i: w for w, i in word_dict.items()}
vocab_size = len(word_dict)

# Every sentence as a list of token ids.
token_list = [[word_dict[s] for s in sentence.split()] for sentence in sentences]

In [None]:
token_list

[[19, 11, 15, 6, 28, 22, 16],
 [19, 16, 26, 8, 17, 14, 24, 5, 10, 6],
 [24, 10, 6, 18, 11, 15, 6, 9],
 [12, 26, 4, 21, 20, 25, 23],
 [7, 13, 14],
 [27, 6, 16]]

In [None]:
# --- Reproducibility -------------------------------------------------------
# Batches are sampled with Python's `random` and weights are initialised by
# PyTorch, so both generators are seeded before anything random runs.
SEED = 42
seed(SEED)
torch.manual_seed(SEED)

# --- Data / batching -------------------------------------------------------
maxlen = 30      # maximum sequence length (tokens incl. [CLS]/[SEP], padded with 0)
batch_size = 6   # samples per batch
max_pred = 5     # maximum number of masked tokens to predict per sample

# --- Model -----------------------------------------------------------------
n_layers = 6         # number of stacked encoder layers
n_heads = 12         # number of heads in multi-head attention
d_model = 768        # embedding size
d_ff = 4 * d_model   # feed-forward hidden dimension
d_k = d_v = 64       # per-head dimension of K (= Q) and V
n_segments = 2       # number of segment (sentence A / sentence B) ids

In [None]:

def make_batch():
    """Build one toy pre-training batch (masked LM + next-sentence prediction).

    Reads the module-level ``sentences``, ``token_list``, ``word_dict``,
    ``vocab_size``, ``maxlen``, ``max_pred`` and ``batch_size``.

    Random sentence pairs are drawn until the batch holds ``batch_size // 2``
    "IsNext" pairs (sentence B directly follows sentence A in the corpus) and
    ``batch_size - batch_size // 2`` "NotNext" pairs, so the loop also
    terminates for an odd ``batch_size``.

    Returns:
        list: ``batch_size`` samples, each
        ``[input_ids, segment_ids, masked_tokens, masked_pos, is_next]`` where
        ``input_ids`` / ``segment_ids`` are lists zero-padded to ``maxlen``,
        ``masked_tokens`` / ``masked_pos`` are lists zero-padded to
        ``max_pred`` (real entries first), and ``is_next`` is a bool.

    Raises:
        ValueError: if a sentence pair plus its three special tokens is
            longer than ``maxlen``.
    """
    cls_id, sep_id, mask_id = word_dict['[CLS]'], word_dict['[SEP]'], word_dict['[MASK]']
    special_ids = {word_dict['[PAD]'], cls_id, sep_id, mask_id}

    # Integer quotas: comparing counters with batch_size / 2 can never be
    # satisfied when batch_size is odd.
    n_positive_target = batch_size // 2
    n_negative_target = batch_size - n_positive_target

    batch = []
    positive = negative = 0
    while positive < n_positive_target or negative < n_negative_target:
        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))
        is_next = tokens_a_index + 1 == tokens_b_index

        # Skip pairs whose class quota is already full before doing any work.
        if is_next and positive >= n_positive_target:
            continue
        if not is_next and negative >= n_negative_target:
            continue

        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]
        input_ids = [cls_id] + tokens_a + [sep_id] + tokens_b + [sep_id]
        if len(input_ids) > maxlen:
            raise ValueError(
                'sentence pair needs %d tokens but maxlen is %d' % (len(input_ids), maxlen)
            )
        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)

        # Masked LM: predict 15 % of the tokens, at least 1 and at most max_pred.
        n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))))
        cand_masked_pos = [i for i, token in enumerate(input_ids)
                           if token != cls_id and token != sep_id]
        shuffle(cand_masked_pos)

        masked_tokens, masked_pos = [], []
        for pos in cand_masked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            draw = random()
            if draw < 0.8:    # 80 %: replace with [MASK]
                input_ids[pos] = mask_id
            elif draw < 0.9:  # 10 %: replace with a random ordinary word
                # Never substitute a special token: an inserted [PAD] would be
                # treated as padding by the attention mask.
                index = randint(0, vocab_size - 1)
                while index in special_ids:
                    index = randint(0, vocab_size - 1)
                input_ids[pos] = index
            # remaining 10 %: keep the original token

        # Zero-pad the sequence up to maxlen.
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)

        # Zero-pad the prediction slots up to max_pred.
        n_pad = max_pred - n_pred
        masked_tokens.extend([0] * n_pad)
        masked_pos.extend([0] * n_pad)

        batch.append([input_ids, segment_ids, masked_tokens, masked_pos, is_next])
        if is_next:
            positive += 1
        else:
            negative += 1
    return batch



In [None]:
def get_attn_pad_mask(seq_q, seq_k):
    """Return a boolean mask that is True wherever the key token is padding.

    Both arguments are 2-D id tensors ``[batch, len]``; id 0 is the PAD token.
    The result has shape ``[batch, len_q, len_k]``: the per-key padding flags
    are repeated for every query position.
    """
    _, q_len = seq_q.size()
    n_batch, k_len = seq_k.size()
    key_is_pad = seq_k.data.eq(0)            # [batch, len_k], True on PAD
    key_is_pad = key_is_pad.unsqueeze(1)     # [batch, 1, len_k]
    return key_is_pad.expand(n_batch, q_len, k_len)

In [None]:
def gelu(x):
    """Exact (erf-based) Gaussian Error Linear Unit: x * Phi(x)."""
    half_x = x * 0.5
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_x * (1.0 + erf_term)

In [None]:
batch = make_batch()
# compact=True packs the id lists onto as few lines as fit the width instead
# of printing one integer per line, which keeps the preview readable.
pprint(batch, compact=True)

[[[1,
   7,
   13,
   14,
   2,
   19,
   11,
   15,
   6,
   3,
   22,
   16,
   2,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0],
  [0,
   0,
   0,
   0,
   0,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0],
  [15, 28, 0, 0, 0],
  [7, 9, 0, 0, 0],
  False],
 [[1,
   19,
   11,
   15,
   3,
   28,
   21,
   3,
   2,
   24,
   10,
   6,
   18,
   11,
   15,
   6,
   9,
   2,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0],
  [0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0],
  [22, 6, 16, 0, 0],
  [6, 4, 7, 0, 0],
  False],
 [[1,
   19,
   16,
   26,
   8,
   3,
   14,
   24,
   3,
   10,
   6,
   2,
   7,
   13,
   14,
   2,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
 

In [None]:
# Transpose the list of samples into one tuple per field, then convert each
# field into a LongTensor.
input_ids, segment_ids, masked_tokens, masked_pos, isNext = (
    torch.LongTensor(field) for field in zip(*batch)
)

In [None]:
get_attn_pad_mask(input_ids, input_ids)[0][0], input_ids[0]

(tensor([False, False, False, False, False, False, False, False, False, False,
         False, False, False,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True]),
 tensor([ 1,  7, 13, 14,  2, 19, 11, 15,  6,  3, 22, 16,  2,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]))

In [None]:
print(input_ids)

tensor([[ 1,  7, 13, 14,  2, 19, 11, 15,  6,  3, 22, 16,  2,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 19, 11, 15,  3, 28, 21,  3,  2, 24, 10,  6, 18, 11, 15,  6,  9,  2,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 19, 16, 26,  8,  3, 14, 24,  3, 10,  6,  2,  7, 13, 14,  2,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1,  3, 26,  4, 21, 20, 25,  3,  2,  7, 13, 14,  2,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 12, 26,  4, 21,  3, 25, 23,  2,  7, 13,  3,  2,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 19, 16, 26,  8, 17,  3, 24,  5, 10,  3,  2, 24, 10,  6, 18, 11, 15,
          6,  9,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0]])


In [None]:
"""Defines the BERT input embedding: creates a dense vector (768-dimensional) for each of the 3 inputs (token, position, segment). 'forward' then sums the vectors and normalises them."""
class Embedding(nn.Module):
    """BERT input layer: token + position + segment embeddings, layer-normalised.

    Sizes come from the module-level ``vocab_size``, ``maxlen``, ``n_segments``
    and ``d_model``.
    """

    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(maxlen, d_model)  # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        """Embed a batch of token ids.

        Args:
            x: token ids, ``[batch_size, seq_len]``; ``seq_len`` must not
                exceed ``maxlen`` because positions index ``pos_embed``.
            seg: segment ids, same shape as ``x``.

        Returns:
            Tensor ``[batch_size, seq_len, d_model]``.
        """
        seq_len = x.size(1)
        # Create the position ids on the same device as the input so the
        # module keeps working after the model and data are moved off the CPU.
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)  # (seq_len,) -> (batch_size, seq_len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

In [None]:
"""
The cell above defines the Embedding class, which is a crucial part of the BERT model's input layer. This class is responsible for creating the initial embeddings for the input tokens, incorporating information about the token itself, its position in the sequence, and the segment it belongs to (e.g., Sentence A or Sentence B). Let's break it down:

__init__(self): This is the constructor for the Embedding class. It initializes four main components:

self.tok_embed = nn.Embedding(vocab_size, d_model): This creates an embedding layer for the tokens themselves. It maps each unique word/token in your vocabulary (vocab_size) to a dense vector of size d_model (768 in this case). This vector captures the semantic meaning of the word.
self.pos_embed = nn.Embedding(maxlen, d_model): This creates a positional embedding layer. Since Transformers process words in parallel and don't inherently know the order of words, a positional embedding is added. It maps each possible position up to maxlen (30 in this case) to a d_model-sized vector, providing the model with sequence order information.
self.seg_embed = nn.Embedding(n_segments, d_model): This creates a segment embedding layer. In tasks like Next Sentence Prediction (NSP), BERT processes two sentences simultaneously. This embedding differentiates between tokens belonging to the first segment (Sentence A, represented by 0) and the second segment (Sentence B, represented by 1), with n_segments being 2.
self.norm = nn.LayerNorm(d_model): This applies Layer Normalization to the combined embeddings. Layer Normalization helps stabilize the training process and improves performance by normalizing the sum of the embeddings across the feature dimension.

forward(self, x, seg): This method defines how the input x (token IDs) and seg (segment IDs) are processed to produce the final embedding. Here's what happens:

seq_len = x.size(1): It gets the length of the input sequence.
pos = torch.arange(seq_len, dtype=torch.long): It generates a tensor representing the positions [0, 1, 2, ..., seq_len-1].
pos = pos.unsqueeze(0).expand_as(x): This reshapes the position tensor to match the shape of the input token IDs x, allowing for element-wise addition.
embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg): This is the core step where the three types of embeddings (token, position, and segment) are retrieved and summed together. This sum represents a rich embedding for each token, containing its semantic meaning, its position, and its sentence context.
return self.norm(embedding): Finally, the combined embedding is passed through the LayerNorm layer before being returned. This normalized embedding is then fed into the subsequent layers of the BERT model.
"""

In [None]:
"""This computes the weighted sum 'Z' for each input, i.e. the Z matrix (= the output of the self-attention layer)."""
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    ``attn_mask`` is a boolean tensor; positions where it is True receive a
    very large negative score so that softmax gives them ~0 weight.
    Returns ``(context, attn)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask):
        # Similarity of every query with every key: [..., len_q, len_k].
        scores = Q @ K.transpose(-2, -1)
        scores = scores / np.sqrt(d_k)
        # Masked positions are pushed towards -inf before normalisation.
        scores.masked_fill_(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)
        # Weighted sum of the value vectors.
        context = attn @ V
        return context, attn

In [None]:
"""
The ScaledDotProductAttention class is a fundamental building block of the Transformer architecture, responsible for calculating the attention weights and context vectors. Here's a breakdown of its operation:

Purpose: It computes scaled dot-product attention, which determines how much 'attention' a model should pay to different parts of the input sequence when processing each element.

forward(self, Q, K, V, attn_mask) method: This method takes four inputs:

Q (Query): Represents the current element(s) for which we want to compute attention. Its shape is [batch_size, ..., len_q, d_k] (where len_q is the sequence length of the query, and d_k is the dimension of the keys).
K (Key): Represents all elements in the sequence that the query might attend to. Its shape is [batch_size, ..., len_k, d_k].
V (Value): Contains the information associated with each key. Its shape is [batch_size, ..., len_k, d_v].
attn_mask: A boolean mask used to prevent attention to certain positions (e.g., padding tokens). Values of True in the mask indicate positions to be ignored.
Steps involved in the forward pass:

Calculate Raw Attention Scores:

scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)
This is the core of dot-product attention. It computes the dot product between the Query (Q) and the Key (K.transpose(-1, -2)). The transpose operation aligns the K matrix for the dot product, effectively calculating the similarity between each query and all keys.
The result is then divided by np.sqrt(d_k). This scaling factor prevents the dot products from growing too large, which can lead to vanishing or exploding gradients during training, especially with large d_k values. It helps stabilize the softmax function.
The scores tensor will have a shape like [batch_size, ..., len_q, len_k], where each element (i, j) indicates the similarity between query i and key j.
Apply Attention Masking:

scores.masked_fill_(attn_mask, -1e9)
The attn_mask is applied to the scores. For any position where attn_mask is True (meaning that position should be ignored), the corresponding score is set to a very large negative number (e.g., -1e9).
This ensures that when the softmax function is applied next, these masked positions will have an output probability close to zero, effectively preventing the model from attending to them.
Compute Attention Weights:

attn = nn.Softmax(dim=-1)(scores)
A Softmax function is applied along the last dimension (dim=-1) of the scores tensor. This normalizes the scores into a probability distribution, where each value attn[i, j] represents the weight or importance that query i assigns to key j (and its corresponding value).
The sum of these weights for each query across all keys will be 1.
Compute Context Vector:

context = torch.matmul(attn, V)
Finally, the attention weights (attn) are multiplied by the Value (V) matrix. This operation creates a weighted sum of the Value vectors, where the weights are the attention probabilities calculated in the previous step.
The resulting context vector is an aggregated representation of the input sequence, selectively focusing on the most relevant information (weighted by attn) for each query.
Output: The method returns two tensors:

context: The weighted sum of the values, representing the attended-to information.
attn: The attention weights, indicating how much each query focused on each key.
"""

In [None]:
emb = Embedding()
embeds = emb(input_ids, segment_ids)

attenM = get_attn_pad_mask(input_ids, input_ids)

# ScaledDotProductAttention returns (context, attention weights); the raw
# scores are internal and are not returned.
SDPA = ScaledDotProductAttention()(embeds, embeds, embeds, attenM)

S, A = SDPA  # S: context vectors, A: attention weights

print()
print('Context: ', S[0][0], '\n\nAttention weights: ', A[0][0])


Scores:  tensor([ 8.7460e-01,  4.6691e-01,  6.3216e-01,  1.1800e+00, -1.3627e+00,
         4.8756e-01,  1.8669e+00, -1.1932e-01,  5.3279e-01,  2.9444e-01,
         1.9735e-01, -5.5001e-01, -1.0140e+00,  7.8670e-01,  1.0887e+00,
         6.3057e-04,  9.9289e-01, -2.9609e-01,  6.0724e-01,  2.6625e-02,
        -3.3870e-02, -2.4476e+00,  1.1672e+00, -1.8267e+00, -1.3979e+00,
        -6.8144e-01,  3.8619e-01, -5.1123e-01, -7.6187e-01,  9.3967e-01,
        -9.7210e-02, -5.0097e-01, -1.0900e+00, -1.5892e+00, -9.8001e-01,
         3.8063e-01,  1.6488e+00,  4.4615e-01, -1.8010e+00, -4.5524e-01,
         2.6158e-01,  5.4329e-01,  1.5783e+00, -3.7724e-01, -1.2882e+00,
         3.7595e-01,  1.2723e+00, -5.8365e-01,  6.9351e-01,  8.7077e-01,
        -1.0934e+00,  2.4805e-01,  3.8487e-01,  1.9047e+00, -1.4751e+00,
        -6.4955e-02,  7.4973e-01,  2.1981e+00,  8.2631e-01,  1.1614e+00,
        -1.3962e+00,  7.4448e-01,  6.3684e-01, -2.9518e-02, -4.7870e-02,
         1.9326e+00, -9.1461e-01,  2.1574

In [None]:
print((SDPA))

(tensor([[[ 0.8746,  0.4669,  0.6322,  ...,  1.0056,  0.8513, -0.4588],
         [ 0.1904,  0.6533,  0.3322,  ...,  1.5567,  0.6592, -0.4647],
         [-0.0363, -1.4184, -0.6248,  ...,  1.2556,  2.0644, -0.6692],
         ...,
         [-0.1348, -1.2966, -0.6218,  ...,  1.1660,  1.9410, -0.7025],
         [-0.6662, -0.8761, -0.6396,  ...,  0.7468,  1.3374, -0.9273],
         [-0.7515,  0.1862, -0.6686,  ...,  0.0298,  1.2525, -0.7221]],

        [[ 0.8746,  0.4669,  0.6322,  ...,  1.0056,  0.8513, -0.4588],
         [ 0.0573,  1.6328, -0.2580,  ...,  0.5064,  1.0939, -0.4329],
         [-0.3442, -0.7339, -0.2056,  ...,  0.6509,  1.6837, -0.8786],
         ...,
         [-0.6747, -0.8419, -0.5177,  ...,  0.7436,  0.3194, -1.1956],
         [-0.1746, -1.3277, -1.7037,  ...,  0.7070,  0.1554, -1.3366],
         [-0.3774, -0.9328, -1.0695,  ...,  0.6860,  0.2872, -1.2324]],

        [[ 0.8746,  0.4669,  0.6322,  ...,  1.0056,  0.8513, -0.4588],
         [ 0.0573,  1.6328, -0.2580,  ...,  

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with output projection, residual and layer norm.

    Q, K and V are projected to ``n_heads`` heads of size ``d_k`` / ``d_v``,
    attended independently with ``ScaledDotProductAttention``, concatenated,
    projected back to ``d_model``, added to the input ``Q`` (residual) and
    layer-normalised.
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
        # The output projection and the layer norm must be created once here:
        # building them inside forward() re-initialises them randomly on every
        # call and keeps them out of model.parameters(), so they never train.
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask):
        """Apply multi-head attention.

        Args:
            Q: ``[batch_size x len_q x d_model]``
            K: ``[batch_size x len_k x d_model]``
            V: ``[batch_size x len_k x d_model]``
            attn_mask: boolean ``[batch_size x len_q x len_k]``, True = ignore.

        Returns:
            tuple: ``(output, attn)`` with output
            ``[batch_size x len_q x d_model]`` and attn
            ``[batch_size x n_heads x len_q x len_k]``.
        """
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)  # [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)  # [batch_size x n_heads x len_k x d_v]

        # Same padding mask for every head: [batch_size x n_heads x len_q x len_k]
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q x len_k]
        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        # Concatenate the heads again: [batch_size x len_q x n_heads * d_v]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)
        output = self.linear(context)
        return self.norm(output + residual), attn  # output: [batch_size x len_q x d_model]


In [None]:
"""
The MultiHeadAttention class is a critical component in Transformer models, including BERT. It allows the model to jointly attend to information from different representation subspaces at different positions. Here's a detailed explanation:

__init__(self): This is the constructor for the MultiHeadAttention class. It initializes three linear projection layers:

self.W_Q = nn.Linear(d_model, d_k * n_heads): This layer projects the input query (Q) into n_heads separate 'query' vectors. The output dimension is d_k * n_heads, effectively preparing the input for multiple attention heads, each with a dimension of d_k.
self.W_K = nn.Linear(d_model, d_k * n_heads): Similarly, this layer projects the input key (K) into n_heads separate 'key' vectors.
self.W_V = nn.Linear(d_model, d_v * n_heads): This layer projects the input value (V) into n_heads separate 'value' vectors, each with a dimension of d_v.
forward(self, Q, K, V, attn_mask): This method defines the forward pass of the multi-head attention mechanism:

Store Residual Connection and Batch Size: residual, batch_size = Q, Q.size(0)

It stores the original input Q for a residual connection later (a common practice in Transformers to help with training deep networks).
It also extracts the batch_size from Q.
Linear Projections and Reshaping (Splitting into Heads):

q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)
k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)
v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)
First, the Q, K, and V inputs (each of shape [batch_size, len_seq, d_model]) are passed through their respective linear projection layers (W_Q, W_K, W_V). The output now has a dimension of d_k * n_heads (or d_v * n_heads for V).
Then, view is used to reshape these outputs to explicitly separate the n_heads. For example, (batch_size, len_seq, d_k * n_heads) becomes (batch_size, len_seq, n_heads, d_k).
Finally, transpose(1,2) rearranges the dimensions to (batch_size, n_heads, len_seq, d_k) (or d_v for v_s). This arranges the tensors so that each 'head' can process its own Q, K, V independently.
Prepare Attention Mask for Multi-Head: attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

The attn_mask (which is typically [batch_size, len_q, len_k]) is expanded to include the n_heads dimension, making it [batch_size, n_heads, len_q, len_k]. This ensures the same masking pattern is applied to all attention heads.
Apply Scaled Dot-Product Attention: context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)

The projected q_s, k_s, v_s, and the expanded attn_mask are passed to the ScaledDotProductAttention module (which you defined earlier). This performs the core attention calculation for each head independently, resulting in context vectors and attention weights attn for each head.
Concatenate and Final Linear Projection: context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)

output = nn.Linear(n_heads * d_v, d_model)(context)
The context from all heads (currently [batch_size, n_heads, len_q, d_v]) are transposed back and then viewed (concatenated) to combine the outputs from all heads back into a single tensor of shape [batch_size, len_q, n_heads * d_v].
This concatenated output then undergoes a final linear projection (nn.Linear(n_heads * d_v, d_model)) to bring the dimension back to d_model.
Add Residual Connection and Layer Normalization: return nn.LayerNorm(d_model)(output + residual), attn

The output from the linear projection is added to the original residual (input Q).
Finally, Layer Normalization (nn.LayerNorm(d_model)) is applied to this sum.
This entire process allows the model to capture diverse relationships within the sequence by attending to different parts of the input through multiple heads, and then consolidating these perspectives into a single, richer representation.
"""

In [None]:
# Multi-head attention demo: fresh embeddings -> padding mask -> attention.
emb = Embedding()
embeds = emb(input_ids, segment_ids)
attenM = get_attn_pad_mask(input_ids, input_ids)

# Self-attention: the embeddings act as queries, keys and values alike.
MHA = MultiHeadAttention()(embeds, embeds, embeds, attenM)
Output, A = MHA

# Attention weights of the first head for the first sample.
A[0][0]

In [None]:
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward block: Linear(d_model -> d_ff), GELU, Linear(d_ff -> d_model)."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # x: (batch_size, len_seq, d_model)
        hidden = self.fc1(x)          # (batch_size, len_seq, d_ff)
        activated = gelu(hidden)
        return self.fc2(activated)    # (batch_size, len_seq, d_model)


In [None]:
class EncoderLayer(nn.Module):
    """One encoder layer: multi-head self-attention followed by the position-wise feed-forward net."""

    def __init__(self):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        # Self-attention: the same tensor is used as Q, K and V.
        attended, attn = self.enc_self_attn(
            enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask
        )
        # Result keeps the shape [batch_size x len_q x d_model].
        return self.pos_ffn(attended), attn

In [None]:
class BERT(nn.Module):
    """Small BERT encoder with a masked-LM head and a 2-way sentence classifier.

    The input embedding is followed by ``n_layers`` encoder layers. Two heads
    read the final hidden states: a classifier on the first position and a
    masked-LM decoder on the positions listed in ``masked_pos``.
    """

    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        # Pooling of the first position for the sentence-pair classifier.
        self.fc = nn.Linear(d_model, d_model)
        self.activ1 = nn.Tanh()
        # Transform applied to the gathered masked positions before decoding.
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = gelu
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        # decoder is shared with embedding layer: the decoder weight is the very
        # same tensor as the token-embedding weight (weight tying).
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        # Separate bias, because the tied decoder is created with bias=False.
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    def forward(self, input_ids, segment_ids, masked_pos):
        """Run the encoder and both heads.

        Args:
            input_ids: token ids, ``[batch_size, len]`` (0 is treated as padding).
            segment_ids: segment ids, same shape as ``input_ids``.
            masked_pos: positions to predict, ``[batch_size, max_pred]``.

        Returns:
            tuple: ``(logits_lm, logits_clsf)`` with shapes
            ``[batch_size, max_pred, n_vocab]`` and ``[batch_size, 2]``.
        """
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)
        # output : [batch_size, len, d_model], attn : [batch_size, n_heads, len, len]
        # The classification is decided by the first token (CLS).
        h_pooled = self.activ1(self.fc(output[:, 0])) # [batch_size, d_model]
        logits_clsf = self.classifier(h_pooled) # [batch_size, 2]

        # Repeat every position index along the feature axis so gather can
        # pick whole hidden vectors.
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1)) # [batch_size, max_pred, d_model]
        # get masked position from final output of transformer.
        h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]
        h_masked = self.norm(self.activ2(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias # [batch_size, max_pred, n_vocab]

        return logits_lm, logits_clsf

In [None]:
# Model, loss and optimiser.
model = BERT()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# A single fixed batch is reused for every optimisation step.
batch = make_batch()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = (
    torch.LongTensor(field) for field in zip(*batch)
)

for epoch in range(10):
    optimizer.zero_grad()
    logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)

    # CrossEntropyLoss expects the class dimension second: [batch, n_vocab, max_pred].
    loss_lm = criterion(logits_lm.permute(0, 2, 1), masked_tokens).float().mean()  # masked LM
    loss_clsf = criterion(logits_clsf, isNext)  # sentence classification
    loss = loss_lm + loss_clsf

    # Report on every 10th epoch.
    if epoch % 10 == 9:
        print('Epoch:', f'{epoch + 1:04d}', 'cost =', f'{loss:.6f}')

    loss.backward()
    optimizer.step()

In [None]:
# Predict the masked tokens and isNext for the first sample of the batch.
# zip(batch[0]) wraps every field in a 1-tuple, so each tensor has a batch dimension of 1.
input_ids, segment_ids, masked_tokens, masked_pos, isNext = map(torch.LongTensor, zip(batch[0]))
print(text)
print([number_dict[w.item()] for w in input_ids[0] if number_dict[w.item()] != '[PAD]'])

# Inference only: no autograd graph is needed.
with torch.no_grad():
    logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)

# The real targets come first and the remaining slots are zero padding, so
# keep exactly as many predictions as there are real targets. Filtering the
# predictions on "!= 0" would drop or misalign them.
true_tokens = [tok for tok in masked_tokens[0].tolist() if tok != 0]
pred_tokens = logits_lm.argmax(dim=2)[0].tolist()[:len(true_tokens)]
print('masked tokens list : ', true_tokens)
print('predict masked tokens list : ', pred_tokens)

pred_is_next = bool(logits_clsf.argmax(dim=1)[0].item())
print('isNext : ', bool(isNext.item()))
print('predict isNext : ', pred_is_next)