# Multi-Head Attention

In [48]:
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm

In [49]:
class Attention(nn.Module):
    """Scaled dot-product attention.

    Computes ``softmax(Q @ K^T / sqrt(d_k)) @ V`` where ``d_k`` is the size of
    the last dimension of ``query``.
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        """Apply attention to ``value`` using ``query``/``key`` similarity.

        Args:
            query: tensor whose last dimension is the per-head feature size.
            key: tensor whose last two dimensions are (key positions, features).
            value: tensor whose second-to-last dimension matches the key positions.
            mask: optional tensor broadcastable to the score tensor; positions
                where ``mask == 0`` receive (effectively) zero attention weight.
            dropout: optional callable applied to the attention weights.

        Returns:
            A tuple ``(output, p_attn)`` with the attended values and the
            attention weights.
        """
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # A large negative score becomes ~0 probability after the softmax.
            scores = scores.masked_fill(mask == 0, -1e9)

        # Normalise over the key axis (the last one) so that the weights of
        # every query position sum to 1. Using dim=1 would normalise across
        # a different axis (e.g. heads for 4-D inputs).
        p_attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            p_attn = dropout(p_attn)

        return torch.matmul(p_attn, value), p_attn

In [50]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on top of ``Attention``.

    Projects query/key/value with separate linear layers, splits the model
    dimension into ``h`` heads of size ``d_model // h``, applies attention per
    head, then concatenates the heads and applies an output projection.
    """

    def __init__(self, h, d_model, dropout=0.1):
        """
        Args:
            h: number of attention heads; must divide ``d_model``.
            d_model: model (feature) dimension of the inputs and outputs.
            dropout: dropout probability applied to the attention weights.
        """
        super().__init__()

        assert d_model % h == 0, "d_model must be divisible by h"

        self.d_k = d_model // h
        self.h = h

        # One projection each for query, key and value.
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention.

        Args:
            query, key, value: tensors whose first dimension is the batch and
                whose last dimension is ``d_model``.
            mask: optional mask forwarded to ``Attention`` (zeros are masked out).

        Returns:
            Tensor of shape ``(batch, -1, h * d_k)`` after the output projection.
        """
        # size(0) is the batch dimension; size() alone returns the whole shape
        # and cannot be used as a single dimension in view().
        batch_size = query.size(0)

        # Project, then reshape to (batch, h, seq, d_k) so each head attends independently.
        query, key, value = [
            layer(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
            for layer, x in zip(self.linear_layers, (query, key, value))
        ]

        x, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # Merge the heads back: (batch, h, seq, d_k) -> (batch, seq, h * d_k).
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        return self.output_linear(x)

# utils.py

In [51]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and bias."""

    def __init__(self, features, eps=1e-6):
        """
        Args:
            features: size of the last dimension of the inputs.
            eps: constant added to the standard deviation before dividing.
        """
        super().__init__()
        self.eps = eps
        # Gain starts at one and bias at zero.
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        """Return ``a_2 * (x - mean) / (std + eps) + b_2`` computed along the last axis."""
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        centered = x - mu
        return self.a_2 * centered / (sigma + self.eps) + self.b_2

In [54]:
class SublayerConnection(nn.Module):
    """Residual connection around a sublayer, with layer norm applied to its input."""

    def __init__(self, size, dropout):
        """
        Args:
            size: feature size passed to ``LayerNorm``.
            dropout: dropout probability applied to the sublayer output.
        """
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Return ``x + dropout(sublayer(norm(x)))``.

        Args:
            x: input tensor.
            sublayer: callable applied to the normalised input.
        """
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch
        

In [56]:
class GELU(nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit activation."""

    def forward(self, x):
        """Return ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``."""
        scale = math.sqrt(2 / math.pi)
        inner = scale * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))

In [57]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        """
        Args:
            d_model: input and output feature size.
            d_ff: hidden feature size between the two linear layers.
            dropout: dropout probability applied after the activation.
        """
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x):
        """Return ``w_2(dropout(gelu(w_1(x))))``; the last dimension of ``x`` must be ``d_model``."""
        # The original referenced an undefined ``X``; the parameter is lowercase ``x``.
        hidden = self.activation(self.w_1(x))
        return self.w_2(self.dropout(hidden))

# transformer encoder block

In [58]:
class TransformerBlock(nn.Module):
    """Transformer encoder block: self-attention and feed-forward, each wrapped
    in a ``SublayerConnection``, followed by dropout.
    """

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        """
        Args:
            hidden: model feature size.
            attn_heads: number of attention heads.
            feed_forward_hidden: hidden size of the feed-forward network.
            dropout: dropout probability for the feed-forward network, the
                sublayer connections and the block output.
        """
        super().__init__()
        # NOTE(review): ``dropout`` is not forwarded here, so the attention
        # module keeps its own default dropout rate -- confirm this is intended.
        self.attention = MultiHeadAttention(h=attn_heads, d_model=hidden)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        # forward() uses a second residual wrapper for the feed-forward step;
        # it has to be created here or forward() raises AttributeError.
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        """Apply self-attention, then the feed-forward network, then dropout.

        Args:
            x: input tensor; used as query, key and value of the attention.
            mask: attention mask forwarded to ``MultiHeadAttention``.

        Returns:
            The transformed tensor.
        """
        x = self.input_sublayer(x, lambda _x: self.attention(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward)
        # Dropout must be applied to the block output, not called with no argument.
        return self.dropout(x)

# Three Embeddings

In [59]:
class TokenEmbedding(nn.Embedding):
    """Embedding table for vocabulary tokens; index 0 is the padding index."""

    def __init__(self, vocab_size, embed_size=512):
        """
        Args:
            vocab_size: number of rows in the embedding table.
            embed_size: size of each embedding vector.
        """
        super().__init__(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
            padding_idx=0,
        )

In [60]:
class SegmentEmbedding(nn.Embedding):
    """Embedding table with three rows for segment labels; index 0 is the padding index."""

    def __init__(self, embed_size=512):
        """
        Args:
            embed_size: size of each embedding vector.
        """
        super().__init__(
            num_embeddings=3,
            embedding_dim=embed_size,
            padding_idx=0,
        )

In [61]:
class PositionEmbedding(nn.Module):
    """Fixed sinusoidal position encoding stored as a non-trainable buffer."""

    def __init__(self, d_model, max_len=512):
        """
        Args:
            d_model: size of each position vector.
            max_len: number of positions that are precomputed.
        """
        super().__init__()

        pe = torch.zeros(max_len, d_model).float()
        # The table is a constant, not a parameter (the original set a
        # misspelled ``required_grade`` attribute, which had no effect).
        pe.requires_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        # Even columns hold the sine, odd columns the cosine.
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even columns,
        # so trim the frequencies to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Leading axis of size 1 so the table broadcasts over the batch.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return the first ``x.size(1)`` position vectors, shape ``(1, x.size(1), d_model)``.

        Only the size of dimension 1 of ``x`` is used; its values are ignored.
        """
        return self.pe[:, :x.size(1)]
    

In [62]:
class BERTEmbedding(nn.Module):
    """Sum of token, position and segment embeddings, followed by dropout."""

    def __init__(self, vocab_size, embed_size, dropout=0.1):
        """
        Args:
            vocab_size: vocabulary size for the token embedding.
            embed_size: embedding size shared by the three embeddings.
            dropout: dropout probability applied to the summed embedding.
        """
        super().__init__()
        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
        # Position and segment embeddings take their width from the token table.
        width = self.token.embedding_dim
        self.position = PositionEmbedding(d_model=width)
        self.segment = SegmentEmbedding(embed_size=width)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size

    def forward(self, sequence, segment_label):
        """Embed ``sequence`` and ``segment_label`` and return the dropped-out sum.

        Args:
            sequence: token index tensor; also determines how many positions are used.
            segment_label: segment index tensor looked up in the segment embedding.
        """
        token_part = self.token(sequence)
        position_part = self.position(sequence)
        segment_part = self.segment(segment_label)
        return self.dropout(token_part + position_part + segment_part)

# BERT model

In [63]:
class BERT(nn.Module):
    """BERT encoder: a ``BERTEmbedding`` followed by a stack of ``TransformerBlock``s."""

    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        Args:
            vocab_size: vocabulary size for the embedding.
            hidden: model feature size.
            n_layers: number of transformer blocks.
            attn_heads: number of attention heads per block.
            dropout: dropout probability passed to every transformer block.
        """
        super().__init__()
        self.hidden = hidden
        self.n_layer = n_layers
        self.attn_heads = attn_heads
        self.feed_forward_hidden = hidden * 4

        # BERT embedding = token + segment + position
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)

        # Transformer blocks
        blocks = []
        for _ in range(n_layers):
            blocks.append(TransformerBlock(hidden, attn_heads, hidden * 4, dropout))
        self.transformer_blocks = nn.ModuleList(blocks)

    def forward(self, x, segment_info):
        """Encode a batch of token indices.

        Args:
            x: token index tensor; entries ``> 0`` are attended to, others are masked.
            segment_info: segment index tensor passed to the embedding.

        Returns:
            Output of the last transformer block.
        """
        # Attention mask: mark entries > 0, repeat along a new axis of length
        # x.size(1), then add an axis at dim 1 for broadcasting.
        seq_len = x.size(1)
        keep = x > 0
        mask = keep.unsqueeze(1).repeat(1, seq_len, 1).unsqueeze(1)

        states = self.embedding(x, segment_info)

        for block in self.transformer_blocks:
            states = block.forward(states, mask)

        return states

# MLM

In [64]:
class MaskedLanguageModel(nn.Module):
    """Linear projection to the vocabulary followed by a log-softmax over the last axis."""

    def __init__(self, hidden, vocab_size):
        """
        Args:
            hidden: input feature size.
            vocab_size: number of output classes.
        """
        super().__init__()
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Return log-probabilities over the vocabulary for every position in ``x``."""
        logits = self.linear(x)
        return self.softmax(logits)

# NSP

In [65]:
class NextSentencePrediction(nn.Module):
    """Two-class classifier applied to the first position of each sequence."""

    def __init__(self, hidden):
        """
        Args:
            hidden: input feature size.
        """
        super().__init__()
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Return two-class log-probabilities computed from ``x[:, 0]``."""
        first_position = x[:, 0]
        logits = self.linear(first_position)
        return self.softmax(logits)
    

In [66]:
class BERTLM(nn.Module):
    """BERT with the two pre-training heads: next-sentence prediction and masked LM."""

    def __init__(self, bert: BERT, vocab_size):
        """
        Args:
            bert: encoder whose ``hidden`` attribute sets the heads' input size.
            vocab_size: number of output classes of the masked-LM head.
        """
        super().__init__()
        self.bert = bert
        self.next_sentence = NextSentencePrediction(self.bert.hidden)
        self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size)

    def forward(self, x, segment_label):
        """Encode the inputs and run both heads.

        Args:
            x: token index tensor passed to the encoder.
            segment_label: segment index tensor passed to the encoder.

        Returns:
            A tuple ``(next_sentence_output, mask_lm_output)`` of log-probabilities.
        """
        x = self.bert(x, segment_label)
        # The masked-LM head must be *called* on the encoder output; returning
        # the module object itself would break the loss computation downstream.
        return self.next_sentence(x), self.mask_lm(x)
        

In [68]:

# Build the encoder and wrap it with the pre-training heads (vocabulary of 30000).
bert = BERT(vocab_size=30000)
bertlm = BERTLM(bert=bert, vocab_size=30000)

In [69]:
# Print the module hierarchy of the assembled model.
print(bertlm)

BERTLM(
  (bert): BERT(
    (embedding): BERTEmbedding(
      (token): TokenEmbedding(30000, 768, padding_idx=0)
      (position): PositionEmbedding()
      (segment): SegmentEmbedding(3, 768, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer_blocks): ModuleList(
      (0-11): 12 x TransformerBlock(
        (attention): MultiHeadAttention(
          (linear_layers): ModuleList(
            (0-2): 3 x Linear(in_features=768, out_features=768, bias=True)
          )
          (output_linear): Linear(in_features=768, out_features=768, bias=True)
          (attention): Attention()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=768, out_features=3072, bias=True)
          (w_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (activation): GELU()
        )
        (input_sublayer): SublayerCon

# Model training

In [None]:
class BERTTrainer:
    """Pre-trains a ``BERTLM`` model with next-sentence and masked-LM objectives."""

    def __init__(
        self,
        bert: BERT,
        vocab_size: int,
        train_dataloader: DataLoader,
        test_dataloader: DataLoader = None,
        lr: float = 1e-4,
        betas=(0.9, 0.999),
        weight_decay: float = 0.01,
        warmup_steps=10000,
        with_cuda: bool = True,
        cuda_devices=None,
        log_freq: int = 10,
    ):
        """
        Args:
            bert: encoder to be trained.
            vocab_size: vocabulary size for the masked-LM head.
            train_dataloader: loader yielding training batches (dicts of tensors).
            test_dataloader: optional loader yielding evaluation batches.
            lr: Adam learning rate.
            betas: Adam betas.
            weight_decay: Adam weight decay.
            warmup_steps: warm-up steps passed to the learning-rate schedule.
            with_cuda: use CUDA when it is available.
            cuda_devices: currently unused; kept for interface compatibility.
            log_freq: write a log line every ``log_freq`` iterations.
        """
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        self.bert = bert
        self.model = BERTLM(bert, vocab_size).to(self.device)

        self.train_data = train_dataloader
        self.test_data = test_dataloader

        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        # NOTE(review): ScheduledOptim is not defined in the visible part of
        # this file; it must be defined or imported before this class is used.
        self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)

        # NOTE(review): ignore_index=0 is also applied to the next-sentence
        # loss below; if "is_next" uses 0 as a real class label those samples
        # are ignored -- confirm against the dataset's label encoding.
        self.criterion = nn.NLLLoss(ignore_index=0)
        self.log_freq = log_freq

        print("Total Parameters : ", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        """Run one training epoch over the training loader."""
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        """Run one evaluation epoch over the test loader (no parameter updates)."""
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """Loop once over ``data_loader``, optionally updating the model.

        Each batch is expected to be a dict with the keys ``"bert_input"``,
        ``"segment_label"``, ``"is_next"`` and ``"bert_label"``.

        Args:
            epoch: epoch number, used only for logging.
            data_loader: iterable of batches that supports ``len()``.
            train: when True, back-propagate and step the optimizer.
        """
        str_code = "train" if train else "test"
        data_iter = tqdm(
            enumerate(data_loader),
            desc="EP_%s:%d" % (str_code, epoch),
            total=len(data_loader),
            bar_format="{l_bar}{r_bar}",
        )

        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        # Dropout on / autograd on only while training; evaluation is
        # deterministic and does not build a graph.
        self.model.train(train)
        with torch.set_grad_enabled(train):
            for i, data in data_iter:
                # Move every tensor of the batch to the training device.
                data = {key: value.to(self.device) for key, value in data.items()}

                next_sent_output, mask_lm_output = self.model(data["bert_input"], data["segment_label"])

                next_loss = self.criterion(next_sent_output, data["is_next"])
                # NLLLoss expects the class axis at dim 1, hence the transpose.
                mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])
                loss = next_loss + mask_loss

                if train:
                    self.optim_schedule.zero_grad()
                    loss.backward()
                    self.optim_schedule.step_and_update_lr()

                correct = next_sent_output.argmax(dim=1).eq(data["is_next"]).sum().item()
                avg_loss += loss.item()
                total_correct += correct
                total_element += data["is_next"].nelement()

                post_fix = {
                    "epoch": epoch,
                    "iter": i,
                    "avg_loss": avg_loss / (i + 1),
                    "avg_acc": total_correct / total_element * 100,
                    "loss": loss.item(),
                }

                if i % self.log_freq == 0:
                    data_iter.write(str(post_fix))

        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=",
              total_correct * 100.0 / total_element)

        
            