In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from copy import deepcopy
import random
import math
import time
from tqdm.notebook import tqdm

from typing import Tuple

from dataset import iterator

In [3]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [4]:
dataset = iterator.BertIterator(seq_len=200)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

____

# Input Embedding

In [5]:
class InputEmbedding(nn.Module) : 
    def __init__(self, vocab_size, seq_length, d_model) : 
        super().__init__()
        self.d_model = d_model
        self.tok_emb = nn.Embedding(vocab_size, d_model, padding_idx=dataset.prep.pad_id)
        self.pos_emb = nn.Embedding(seq_length, d_model, padding_idx=dataset.prep.pad_id)
        self.seg_emb = nn.Embedding(2, d_model, padding_idx=dataset.prep.pad_id)
        
    def generate_enc_mask_m(self, src) :       
        mask_m = (src != 1).unsqueeze(1).unsqueeze(2)
        return mask_m
    
    def forward(self, txt, seg) : 
        emb = self.tok_emb(txt)
        pos = torch.arange(0, emb.shape[1]).unsqueeze(0).repeat(emb.shape[0], 1).to(emb.device)
        summed = emb / math.sqrt(self.d_model) + self.pos_emb(pos) + self.seg_emb(seg)
        return summed

# Scaled Dot-Product Attention

In [6]:
class ScaledDotProductAttention(nn.Module) : 
    def __init__(self, d_model) : 
        super().__init__()
        self.d_model = d_model
        self.fc = nn.Linear(d_model, d_model)
        
    def forward(self, q, k, v, mask) :         
        score = torch.matmul(q, k.permute(0,1,3,2).contiguous())/math.sqrt(d_model)
        score = score.masked_fill(mask, -1e10)
        scaled_score = torch.softmax(score, dim=-1)
        
        attention = torch.matmul(scaled_score, v).permute(0,2,3,1).contiguous()
        attention = attention.view(attention.shape[0], attention.shape[1], self.d_model)
        
        return self.fc(attention)

___

# Multi-Head Attention

In [7]:
class MultiHeadAttention(nn.Module) : 
    def __init__(self, d_model, seq_length, n_head) : 
        super().__init__()
        assert d_model % n_head == 0, f"n_head({n_head}) does not divide d_model({d_model})"

        self.n_div_head = d_model//n_head
        self.d_model = d_model
        self.seq_len = seq_length
        self.n_head = n_head

        self.Q = nn.Linear(d_model,  d_model)
        self.K = nn.Linear(d_model,  d_model)
        self.V = nn.Linear(d_model,  d_model)
        
    def div_and_sort_for_multiheads(self, projected, seq_len) : 
        div = projected.view(projected.shape[0], self.n_head, seq_len, self.n_div_head)
        return div
    
    def forward(self, emb, enc_inputs=None) :
        q = self.div_and_sort_for_multiheads(self.Q(emb), self.seq_len)
    
        if enc_inputs is not None : # enc-dec attention
            seq_len = enc_inputs.shape[1] # takes target sequence length for k and v
            k = self.div_and_sort_for_multiheads(self.K(enc_inputs), seq_len)
            v = self.div_and_sort_for_multiheads(self.V(enc_inputs), seq_len)
        else : # self-attention
            k = self.div_and_sort_for_multiheads(self.K(emb), self.seq_len)
            v = self.div_and_sort_for_multiheads(self.V(emb), self.seq_len)

        return q,k,v

# Post-process the sub-layer
- layer normalization
- residual conection
- residual dropout

In [8]:
class PostProcessing(nn.Module) : 
    def __init__(self, d_model, p=0.1) : 
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p)
        
    def forward(self, emb, attn) : 
        return self.ln(emb+self.dropout(attn))

# Position-wise FFN

In [9]:
class PositionwiseFFN(nn.Module) : 
    def __init__(self, d_model, d_ff) : 
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        
    def forward(self, x) : 
        return self.fc2(torch.relu(self.fc1(x)))

# Encoder

In [10]:
class EncoderLayer(nn.Module) : 
    def __init__(self, vocab_size, seq_length, d_model, d_ff, n_head, dropout_p) : 
        super().__init__()
        
        self.ma = MultiHeadAttention(d_model, seq_length, n_head).to(device)
        self.sdp = ScaledDotProductAttention(d_model)
        
        self.pp1 = PostProcessing(d_model, dropout_p)
        self.pp2 = PostProcessing(d_model, dropout_p)
        
        self.positionwise_ffn = PositionwiseFFN(d_model, d_ff)
            
    def forward(self, emb, mask_m) :

        q,k,v = self.ma(emb)    
        attn = self.sdp(q,k,v, mask=mask_m)
        
        attn = self.pp1(emb, attn)
        z = self.positionwise_ffn(attn)

        return self.pp2(attn, z)

# BERT

In [11]:
class BERT(nn.Module) : 
    def __init__(self,
                 vocab_size,
                 seq_length,
                 d_model,
                 d_ff,
                 n_head,
                 dropout_p,
                 n_enc_layer) : 
        
        super().__init__()
        
        self.embber = InputEmbedding(vocab_size, seq_length, d_model)
        
        enc = EncoderLayer(vocab_size, seq_length, d_model, d_ff, n_head, dropout_p)
        
        self.enc = nn.ModuleList([deepcopy(enc) for _ in range(n_enc_layer)])
        
        self.mlm_fc = nn.Linear(d_model, vocab_size)
        self.nsp_fc = nn.Linear(d_model, vocab_size)
        
    def forward(self, txt, seg) : 
        
        emb = self.embber(txt, seg)
        mask_m = self.embber.generate_enc_mask_m(txt)
        
        for enc_layer in self.enc : 
            emb = enc_layer(emb, mask_m)
        
        return emb#, self.mlm_fc(emb), self.nsp_fc(emb[:,0,:])

In [13]:
d_model = 128
d_ff = 256
n_head = 8
batch_size = 16
vocab_size = len(dataset.prep.lemma_dict)
dropout_p = 0.1
n_enc_layer = 1
seq_length = dataset.seq_len

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [14]:
model = BERT(vocab_size,
             seq_length,
             d_model,
             d_ff,
             n_head,
             dropout_p,
             n_enc_layer).to(device)

In [15]:
for data in dataloader : 
    break

In [16]:
emb = model(data['text'].to(device), data['seg'].to(device))

In [19]:
model.mlm_fc

Linear(in_features=128, out_features=809316, bias=True)

In [18]:
model.mlm_fc(emb)

RuntimeError: CUDA out of memory. Tried to allocate 38.59 GiB (GPU 0; 7.93 GiB total capacity; 1.40 GiB already allocated; 5.76 GiB free; 1.43 GiB reserved in total by PyTorch)