In [None]:
#https://inhyeokyoo.github.io/project/nlp/bert-issue/

import torch
import torch.nn as nn

class BERT(nn.Module):
    """BERT-style encoder: embeddings followed by a stack of Transformer encoder layers.

    Args:
        pad_token_id: Id of the padding token. Positions holding this id are
            excluded as attention keys.
        vocab_dim: Vocabulary size forwarded to ``BERTEmbedding``. The default
            (30522) is the BERT-base WordPiece vocabulary size; override it to
            match the tokenizer actually used.
        seq_len: Maximum sequence length forwarded to ``BERTEmbedding``.
        d_model: Hidden size shared by the embeddings and the encoder.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout_rate: Dropout probability for the embeddings and encoder layers.
    """

    def __init__(self, pad_token_id, vocab_dim=30522, seq_len=512,
                 d_model=512, nhead=8, num_layers=6, dropout_rate=0.1):
        super(BERT, self).__init__()
        self.pad_token_id = pad_token_id
        self.embedding = BERTEmbedding(vocab_dim, seq_len, d_model, dropout_rate)

        # batch_first=True because the embedding output is (batch, seq, d_model).
        # The template layer stays local: nn.TransformerEncoder deep-copies it,
        # so storing it on self would only register an unused set of parameters.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dropout=dropout_rate, batch_first=True
        )
        self.encoder_block = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, data, segment_embedding):
        """Encode a batch of token ids.

        Args:
            data: Integer token ids of shape (batch, seq_len).
            segment_embedding: Integer segment ids, same shape as ``data``.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        # (batch, seq_len) boolean mask, True where the key is padding. This is
        # the layout nn.TransformerEncoder expects for src_key_padding_mask; a
        # (batch, len_q, len_k) tensor is not a valid `mask` argument.
        key_padding_mask = data.eq(self.pad_token_id)
        embedding = self.embedding(data, segment_embedding)
        output = self.encoder_block(embedding, src_key_padding_mask=key_padding_mask)

        return output

    @staticmethod
    def get_attn_pad_mask(seq_q, seq_k, i_pad):
        """Build a (batch, len_q, len_k) boolean mask that is True at padded keys.

        Kept as a utility for code that needs the fully expanded mask;
        ``forward`` itself only needs the (batch, len_k) key-padding form.

        Args:
            seq_q: Query token ids of shape (batch, len_q).
            seq_k: Key token ids of shape (batch, len_k).
            i_pad: Id of the padding token.
        """
        batch_size, len_q = seq_q.size()
        _, len_k = seq_k.size()
        pad_attn_mask = seq_k.eq(i_pad)
        # Repeat the per-key flags for every query position (a view, no copy).
        pad_attn_mask = pad_attn_mask.unsqueeze(1).expand(batch_size, len_q, len_k)

        return pad_attn_mask

In [None]:
class BERTEmbedding(nn.Module):
    """Sum of token, learned positional, and segment embeddings.

    Each of the three embeddings has its own dropout layer, applied before the
    three are added together.

    Args:
        vocab_dim: Vocabulary size (rows of the token embedding table).
        seq_len: Maximum supported sequence length (rows of the positional table).
        embedding_dim: Size of every embedding vector.
        dropout_rate: Dropout probability used by all three dropout layers.
    """

    def __init__(self, vocab_dim, seq_len, embedding_dim, dropout_rate=0.1):
        super(BERTEmbedding, self).__init__()
        self.seq_len       = seq_len
        self.vocab_dim     = vocab_dim
        self.embedding_dim = embedding_dim
        self.dropout_rate  = dropout_rate

        # vocab --> embedding
        self.token_embedding      = nn.Embedding(self.vocab_dim, self.embedding_dim)
        self.token_dropout        = nn.Dropout(self.dropout_rate)

        # seq len --> embedding
        self.positional_embedding = nn.Embedding(self.seq_len, self.embedding_dim)
        self.positional_dropout   = nn.Dropout(self.dropout_rate)

        # segment (0, 1) --> embedding
        self.segment_embedding    = nn.Embedding(2, self.embedding_dim)
        self.segment_dropout      = nn.Dropout(self.dropout_rate)

    def forward(self, data, segment_embedding):
        """Embed a batch of token ids.

        Args:
            data: Integer token ids; the last dimension is the sequence axis.
            segment_embedding: Integer segment ids (0 or 1), same shape as ``data``.

        Returns:
            Tensor shaped like ``data`` with a trailing ``embedding_dim`` axis.

        Raises:
            ValueError: If the input is longer than ``seq_len``.
        """
        cur_len = data.size(-1)
        if cur_len > self.seq_len:
            raise ValueError(
                "input length %d exceeds the maximum sequence length %d"
                % (cur_len, self.seq_len)
            )

        token_embedding = self.token_embedding(data)
        token_embedding = self.token_dropout(token_embedding)

        # Position ids follow the actual input length and live on the input's
        # device, so shorter batches and GPU inputs both work.
        positions = torch.arange(cur_len, dtype=torch.long, device=data.device)
        # The (cur_len, embedding_dim) result broadcasts over the batch
        # dimension in the sum below.
        positional_embedding = self.positional_embedding(positions)
        positional_embedding = self.positional_dropout(positional_embedding)

        segment_embedding = self.segment_embedding(segment_embedding)
        segment_embedding = self.segment_dropout(segment_embedding)

        return token_embedding + positional_embedding + segment_embedding

In [None]:
class MaskedLanguageModeling(nn.Module):
    """Masked-language-model head: a linear projection over every encoder position.

    Args:
        bert: Encoder module called as ``bert(token_ids, segment_ids)``.
        input_dim: Size of the encoder's output features.
        output_dim: Number of output classes per position.
    """

    def __init__(self, bert, input_dim, output_dim):
        super(MaskedLanguageModeling, self).__init__()
        self.bert = bert
        self.fc   = nn.Linear(input_dim, output_dim)

    def forward(self, x, segment_embedding=None):
        """Return per-position logits.

        Args:
            x: Integer token ids passed to the encoder.
            segment_embedding: Optional segment ids with the same shape as
                ``x``. When omitted, every token is assigned segment 0.

        Returns:
            The encoder output projected by ``fc``; the last axis has size
            ``output_dim``.
        """
        if segment_embedding is None:
            # The encoder requires segment ids; default to a single segment.
            segment_embedding = torch.zeros_like(x)

        output = self.bert(x, segment_embedding)
        output = self.fc(output)

        return output
    
    
class NextSentencePrediction(nn.Module):
    """Next-sentence-prediction head: classifies the first ([CLS]) position.

    Args:
        bert: Encoder module called as ``bert(token_ids, segment_ids)``.
        input_dim: Size of the encoder's output features.
        output_dim: Number of classes (2 by default).
    """

    def __init__(self, bert, input_dim, output_dim=2):
        super(NextSentencePrediction, self).__init__()
        self.bert = bert
        self.fc   = nn.Linear(input_dim, output_dim)

    def forward(self, x, segment_embedding=None):
        """Return classification logits for each sequence in the batch.

        Args:
            x: Integer token ids passed to the encoder.
            segment_embedding: Optional segment ids with the same shape as
                ``x``. When omitted, every token is assigned segment 0.

        Returns:
            Tensor of shape (batch, output_dim), computed from the encoder
            output at sequence position 0.
        """
        if segment_embedding is None:
            # The encoder requires segment ids; default to a single segment.
            segment_embedding = torch.zeros_like(x)

        output = self.bert(x, segment_embedding)
        # Slice the CLS token first so the linear layer runs on one position
        # per sequence instead of on every position.
        cls_output = output[:, 0, :]

        return self.fc(cls_output)