In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from transformers import BertTokenizerFast

from tqdm.notebook import tqdm

fpath='data/tokenizer_model'
tokenizer = BertTokenizerFast.from_pretrained(fpath,
                                              strip_accents=False,
                                              lowercase=False)

vocab_dim     = len(tokenizer.vocab)
seq_len       = 256
embedding_dim = 512
device        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device        = "cpu"
batch_size    = 128

In [2]:
class BERTLangaugeModelDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, seq_len=128, masking_rate=0.15, NSP_rate=0.5):
        super(BERTLangaugeModelDataset, self).__init__()

        self.data         = data        
        self.tokenizer    = tokenizer
        self.vocab        = tokenizer.vocab
        self.seq_len      = seq_len
        self.masking_rate = masking_rate
        self.NSP_rate     = NSP_rate
        
        self.cls_token_id  = self.tokenizer.cls_token_id
        self.sep_token_id  = self.tokenizer.sep_token_id
        self.pad_token_id  = self.tokenizer.pad_token_id
        self.mask_token_id = self.tokenizer.mask_token_id
        
    def __getitem__(self, sent_1_idx):       
        sent_1 = self.tokenizer.encode(self.data[sent_1_idx])[1:-1]
        sent_2_idx = sent_1_idx + 1
        
        # NSP
        if torch.rand(1) >= self.NSP_rate and sent_2_idx != len(self.data):
            is_next = torch.tensor(1)
        else:
            while sent_2_idx == sent_1_idx + 1:
                sent_2_idx = torch.randint(0, len(self.data), (1,))
            is_next = torch.tensor(0)

        sent_2 = self.tokenizer.encode(self.data[sent_2_idx])[1:-1]
        
        # if length of (sent 1 + sent 2) longer than threshold
        # CLS, SEP 1 and 2
        if len(sent_1) + len(sent_2) >= self.seq_len - 3:
            if len(sent_1) >= self.seq_len -3:
                sent_1 = sent_1[:int(self.seq_len/2)]
                
            sent_2 = sent_2[:self.seq_len - 3 - len(sent_1)]
        
        pad_length = self.seq_len - 3 - len(sent_1) - len(sent_2)
        target = torch.tensor([self.cls_token_id] + sent_1 + [self.sep_token_id] + sent_2 + [self.sep_token_id] + [self.pad_token_id] * pad_length).long().contiguous()        

        sengment_embedding = torch.zeros(target.size(0))
        sengment_embedding[(len(sent_1) + 2):] = 1
        
        # masking
        masked_sent_1, masking_label_sent_1 = self.masking(sent_1)
        masked_sent_2, masking_label_sent_2 = self.masking(sent_2)
        
        masking_label = torch.cat([
            torch.tensor([self.pad_token_id]),
            masking_label_sent_1, 
            torch.tensor([self.sep_token_id]),
            masking_label_sent_2,
            torch.tensor([self.sep_token_id]),
            torch.tensor([self.pad_token_id] * pad_length)
        ])
        
        # MLM
        train = torch.cat([
            torch.tensor([self.cls_token_id]), 
            masked_sent_1,
            torch.tensor([self.sep_token_id]),
            masked_sent_2,
            torch.tensor([self.sep_token_id]),
            torch.tensor([self.pad_token_id] * pad_length)
        ]).long().contiguous()
        
        return train, target, sengment_embedding, is_next, masking_label
        
    
    def __len__(self):
        return len(self.data)
    
    
    def __iter__(self):
        for x in self.data:
            yield x
            
    
    def get_vocab(self):
        return self.vocab
    
    
    def decode(self, x):
        return self.tokenizer.batch_decode(x)
    
    
    # TODO mask 안에서 random 으로 바꿔주는 것 추가
    def masking(self, x):
        x = torch.tensor(x).long().contiguous()
        masking_idx   = torch.randperm(x.size()[0])[:round(x.size()[0] * self.masking_rate) + 1]
        masking_label = torch.zeros(x.size()[0])
        masking_label[masking_idx] = 1
        x = x.masked_fill(masking_label.bool(), self.mask_token_id)
        
        return x, masking_label

In [3]:
# https://inhyeokyoo.github.io/project/nlp/bert-issue/

import torch
import torch.nn as nn

class BERT(nn.Module):
    def __init__(self, vocab_dim, seq_len, embedding_dim, pad_token_id):
        super(BERT, self).__init__()
        self.pad_token_id  = pad_token_id
        self.nhead         = 8
        self.embedding     = BERTEmbedding(vocab_dim, seq_len, embedding_dim)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=self.nhead, batch_first=True)
        self.encoder_block = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
        
        
    def forward(self, data, segment_embedding):
        pad_mask  = BERT.get_attn_pad_mask(data, data, self.pad_token_id).repeat(self.nhead, 1, 1)
        embedding = self.embedding(data, segment_embedding)
        output    = self.encoder_block(embedding, pad_mask) 
        
        return output
    
    @staticmethod
    def get_attn_pad_mask(seq_q, seq_k, i_pad):
        batch_size, len_q = seq_q.size()
        batch_size, len_k = seq_k.size()
        pad_attn_mask = seq_k.data.eq(i_pad)
        pad_attn_mask= pad_attn_mask.unsqueeze(1).expand(batch_size, len_q, len_k)
        
        return pad_attn_mask

In [4]:
class BERTEmbedding(nn.Module):
    def __init__(self, vocab_dim, seq_len, embedding_dim, dropout_rate=0.1, device=device):
        super(BERTEmbedding, self).__init__()
        self.seq_len       = seq_len
        self.vocab_dim     = vocab_dim
        self.embedding_dim = embedding_dim
        self.dropout_rate  = dropout_rate
        
        # vocab --> embedding
        self.token_embedding      = nn.Embedding(self.vocab_dim, self.embedding_dim) 
        self.token_dropout        = nn.Dropout(self.dropout_rate)    
        
        # seq len --> embedding
        self.positional_embedding = nn.Embedding(self.seq_len, self.embedding_dim)
        self.positional_dropout   = nn.Dropout(self.dropout_rate) 
        
        # segment (0, 1) --> embedding
        self.segment_embedding    = nn.Embedding(2, self.embedding_dim)
        self.segment_dropout      = nn.Dropout(self.dropout_rate) 
        
        
    def forward(self, data, segment_embedding):
        token_embedding      = self.token_embedding(data)
        token_embedding      = self.token_dropout(token_embedding)
        
        positional_encoding  = torch.arange(start=0, end=self.seq_len, step=1).long()
        # data의 device 정보 가져와서 처리
        positional_encoding  = positional_encoding.unsqueeze(0).expand(data.size()).to(device)
        positional_embedding = self.positional_embedding(positional_encoding)
        positional_embedding = self.positional_dropout(positional_embedding)
        
        segment_embedding    = self.segment_embedding(segment_embedding)
        segment_embedding    = self.segment_dropout(segment_embedding)
        
        return token_embedding + positional_embedding + segment_embedding

In [5]:
class MaskedLanguageModeling(nn.Module):
    def __init__(self, bert, output_dim):
        super(MaskedLanguageModeling, self).__init__()
        self.bert = bert
        d_model   = bert.embedding.token_embedding.weight.size(1)
        self.fc   = nn.Linear(d_model, output_dim)
    
    def forward(self, x, segment_embedding):
        output = self.bert(x, segment_embedding)
        output = self.fc(output)
        
        return output
    
    
class NextSentencePrediction(nn.Module):
    def __init__(self, bert, output_dim=2):
        super(NextSentencePrediction, self).__init__()
        self.bert = bert
        d_model   = bert.embedding.token_embedding.weight.size(1)
        self.fc   = nn.Linear(d_model, output_dim)
        
    def forward(self, x, segment_embedding):
        output = self.bert(x, segment_embedding)
        output = self.fc(output)
        
        return output[:, 0, :] # CLS token 

In [6]:
def train(mlm_head, nsp_head, iterator, optimizer, criterion, device, clip=1):
    mlm_head.train()
    nsp_head.train()
    
    mlm_epoch_loss = 0
    nsp_epoch_loss = 0
    
    for batch, (X, y_mlm, segment_emb, y_nsp, masking_label) in tqdm(enumerate(iterator), total=len(iterator)):
        optimizer.zero_grad()
        
        mlm_output = mlm_head(X.to(device), segment_emb.long().to(device))
        print(f"mlm output : {mlm_output.shape}")
        
        output_dim = mlm_output.shape[-1]
        
        #         mlm_output = mlm_output.reshape(-1, mlm_output.shape[-1])
        #         mlm_loss   = criterion(mlm_output, y_mlm.to(device).reshape(-1)) # CE
        
        print(f"masking label : {masking_label.shape}")
        print(f"masking label : {masking_label}")
        
        mlm_output = mlm_output[masking_label.bool().to(device)].reshape(-1, output_dim)
        print(f"mlm output after masking : {mlm_output.shape}")
        mlm_target = y_mlm.reshape(-1)[masking_label.reshape(-1).bool()].to(device)
        print(f"mlm target : {mlm_target.shape}")
        
        break
        
        mlm_loss   = criterion(mlm_output, mlm_target)
        
        nsp_output = nsp_head(X.to(device), segment_emb.long().to(device))
        nsp_loss   = criterion(nsp_output, y_nsp.to(device)) # no need for reshape target
        
        loss = mlm_loss + nsp_loss
        loss.backward()
                
        torch.nn.utils.clip_grad_norm_(mlm_head.parameters(), clip)
        torch.nn.utils.clip_grad_norm_(nsp_head.parameters(), clip)
        
        optimizer.step()

        mlm_epoch_loss += mlm_loss.item()
        nsp_epoch_loss += nsp_loss.item()

    return mlm_epoch_loss / len(iterator), nsp_epoch_loss / len(iterator)


@torch.no_grad()
def evaluate(mlm_head, nsp_head, iterator, optimizer, criterion, device, clip=1):
    mlm_head.eval()
    nsp_head.eval()
    
    mlm_epoch_loss = 0
    nsp_epoch_loss = 0
    
    for batch, (X, y_mlm, segment_emb, y_nsp, masking_label) in enumerate(iterator):
        optimizer.zero_grad()
        
        mlm_output = mlm_head(X.to(device), segment_emb.long().to(device))
        output_dim = mlm_output.shape[-1]
        
        #         mlm_output = mlm_output.reshape(-1, mlm_output.shape[-1])
        #         mlm_loss   = criterion(mlm_output, y_mlm.to(device).reshape(-1)) # CE
        
        mlm_output = mlm_output[masking_label.bool().to(device)].reshape(-1, output_dim)
        mlm_target = y_mlm.reshape(-1)[masking_label.reshape(-1).bool()].to(device)
        
        mlm_loss   = criterion(mlm_output, mlm_target)
        
        nsp_output = nsp_head(X.to(device), segment_emb.long().to(device))
        nsp_loss   = criterion(nsp_output, y_nsp.to(device)) # no need for reshape target
        
        mlm_epoch_loss += mlm_loss.item()
        nsp_epoch_loss += nsp_loss.item()

    return mlm_epoch_loss / len(iterator), nsp_epoch_loss / len(iterator)

In [7]:
with open("data/petitions.txt", 'r') as f:
    data = f.readlines()
    
proced_data = [line.replace("\n", "") for line in data]

train_data = proced_data[int(len(proced_data) * 0.2):]
test_data  = proced_data[:int(len(proced_data) * 0.2)]

train_dataset = BERTLangaugeModelDataset(data=train_data, seq_len=seq_len, tokenizer=tokenizer, masking_rate=0.3)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

valid_dataset = BERTLangaugeModelDataset(data=test_data, seq_len=seq_len, tokenizer=tokenizer, masking_rate=0.3)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

In [8]:
bert     = BERT(vocab_dim=vocab_dim, seq_len=seq_len, embedding_dim=embedding_dim, pad_token_id=0).to(device)
mlm_head = MaskedLanguageModeling(bert, output_dim=vocab_dim).to(device)
nsp_head = NextSentencePrediction(bert).to(device)

In [9]:
optimizer = optim.Adam(list(mlm_head.parameters()) + list(nsp_head.parameters()), lr=1e-4, betas=[0.9, 0.999], weight_decay=0.01)
scheduler = ReduceLROnPlateau(optimizer, 'min')
criterion = nn.CrossEntropyLoss()

N_EPOCHS  = 1000
PAITIENCE = 30

n_paitience = 0
best_valid_loss = float('inf')
optimizer.zero_grad()
optimizer.step()

checkpoint = torch.load('weights/BERT_LM_best.pt')
mlm_head.load_state_dict(checkpoint['mlm_head'])
nsp_head.load_state_dict(checkpoint['nsp_head'])
optimizer.load_state_dict(checkpoint['optimizer'])

for epoch in range(N_EPOCHS):
    train_mlm_loss, train_nsp_loss = train(mlm_head, nsp_head, train_dataloader, optimizer, criterion, device)
    valid_mlm_loss, valid_nsp_loss = evaluate(mlm_head, nsp_head, valid_dataloader, optimizer, criterion, device)
    
    valid_loss = valid_mlm_loss + valid_nsp_loss
    scheduler.step(valid_loss)
    
    print(f'Epoch: {epoch + 1:04}')
    print(f'Train MLM Loss: {train_mlm_loss:.4f} | Train NSP Loss: {train_nsp_loss:.4f}')
    print(f'Valid MLM Loss: {valid_mlm_loss:.4f} | Valid NSP Loss: {valid_nsp_loss:.4f}')

    if n_paitience < PAITIENCE:
        if best_valid_loss > valid_loss:
            best_valid_loss = valid_loss
            torch.save(
                {'mlm_head' : mlm_head.state_dict(),'nsp_head' : nsp_head.state_dict(),'optimizer': optimizer.state_dict()}, 
                'weights/BERT_LM_best.pt'
            )
            n_paitience = 0
        elif best_valid_loss <= valid_loss:
            n_paitience += 1
    else:
        print("Early stop!")
        checkpoint = torch.load('weights/BERT_LM_best.pt')
        mlm_head.load_state_dict(checkpoint['mlm_head'])
        nsp_head.load_state_dict(checkpoint['nsp_head'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        break

  super(Adam, self).__init__(params, defaults)


  0%|          | 0/970 [00:00<?, ?it/s]

mlm output : torch.Size([128, 256, 32000])
masking label : torch.Size([128, 256])
masking label : tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
mlm output after masking : torch.Size([3068, 32000])
mlm target : torch.Size([3068])
Epoch: 0001
Train MLM Loss: 0.0000 | Train NSP Loss: 0.0000
Valid MLM Loss: 6.6907 | Valid NSP Loss: 0.6932


  0%|          | 0/970 [00:00<?, ?it/s]

mlm output : torch.Size([128, 256, 32000])
masking label : torch.Size([128, 256])
masking label : tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 1.,  ..., 0., 0., 0.]])
mlm output after masking : torch.Size([3068, 32000])
mlm target : torch.Size([3068])


KeyboardInterrupt: 

In [16]:
for batch, (X, y_mlm, segment_emb, y_nsp, masking_label) in enumerate(train_dataloader):
    mlm_output = mlm_head(X.to(device), segment_emb.long().to(device))
    print(f"mlm output : {mlm_output.shape}")
    
    output_dim = mlm_output.shape[-1]

    #         mlm_output = mlm_output.reshape(-1, mlm_output.shape[-1])
    #         mlm_loss   = criterion(mlm_output, y_mlm.to(device).reshape(-1)) # CE

    print(f"masking label : {masking_label.shape}")
    print(f"masking label : {masking_label}")

    mlm_output_ = mlm_output[masking_label.bool().to(device)].reshape(-1, output_dim)
    print(f"mlm output after masking : {mlm_output.shape}")
    mlm_target = y_mlm.reshape(-1)[masking_label.reshape(-1).bool()].to(device)
    print(f"mlm target : {mlm_target.shape}")

    break

mlm output : torch.Size([128, 256, 32000])
masking label : torch.Size([128, 256])
masking label : tensor([[0., 0., 1.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 1.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
mlm output after masking : torch.Size([128, 256, 32000])
mlm target : torch.Size([3036])


In [18]:
mlm_output

tensor([[[ 0.6251, -4.8984,  0.6251,  ...,  0.5706,  0.5228,  0.5220],
         [ 0.6251, -4.8982,  0.6251,  ...,  0.5706,  0.5228,  0.5220],
         [-0.6296,  4.9435, -0.6296,  ..., -0.5748, -0.5267, -0.5260],
         ...,
         [ 0.6251, -4.8983,  0.6251,  ...,  0.5706,  0.5228,  0.5220],
         [ 0.6251, -4.8983,  0.6251,  ...,  0.5706,  0.5228,  0.5220],
         [ 0.6251, -4.8984,  0.6251,  ...,  0.5706,  0.5228,  0.5220]],

        [[ 0.6251, -4.8984,  0.6251,  ...,  0.5706,  0.5228,  0.5220],
         [ 0.6251, -4.8983,  0.6251,  ...,  0.5706,  0.5228,  0.5220],
         [ 0.6251, -4.8983,  0.6251,  ...,  0.5706,  0.5228,  0.5220],
         ...,
         [ 0.6251, -4.8983,  0.6251,  ...,  0.5706,  0.5228,  0.5220],
         [ 0.6251, -4.8983,  0.6251,  ...,  0.5706,  0.5228,  0.5220],
         [ 0.6251, -4.8984,  0.6251,  ...,  0.5706,  0.5228,  0.5220]],

        [[ 0.6251, -4.8984,  0.6251,  ...,  0.5706,  0.5228,  0.5220],
         [-0.6296,  4.9435, -0.6296,  ..., -0

In [17]:
mlm_output_

tensor([[-0.6296,  4.9435, -0.6296,  ..., -0.5748, -0.5267, -0.5260],
        [-0.6296,  4.9435, -0.6296,  ..., -0.5748, -0.5267, -0.5260],
        [-0.6296,  4.9435, -0.6296,  ..., -0.5748, -0.5267, -0.5260],
        ...,
        [-0.6296,  4.9435, -0.6296,  ..., -0.5748, -0.5267, -0.5260],
        [-0.6296,  4.9435, -0.6296,  ..., -0.5748, -0.5267, -0.5260],
        [ 0.6251, -4.8984,  0.6251,  ...,  0.5706,  0.5228,  0.5220]],
       device='cuda:0', grad_fn=<ViewBackward>)

In [19]:
mlm_output = torch.randn(10, 256, 128)

In [20]:
mlm_output

tensor([[[-1.8854e+00, -4.5087e-01, -3.0679e+00,  ...,  9.5486e-01,
           6.8343e-01,  4.0800e-01],
         [-2.4270e-01, -2.0843e-01, -3.7478e-01,  ...,  1.2955e+00,
          -2.8333e-01,  3.0489e-02],
         [ 7.1423e-02, -1.0986e+00,  1.1390e+00,  ..., -5.7042e-01,
          -8.8368e-01,  6.9217e-01],
         ...,
         [ 3.2071e-01, -2.2002e+00,  8.0698e-02,  ...,  3.2466e-01,
          -9.9349e-01,  5.0202e-01],
         [-1.4008e+00,  5.0895e-01, -9.3579e-02,  ..., -1.1260e+00,
          -4.9518e-01,  9.7975e-01],
         [-3.9151e-01,  6.5946e-01,  9.1997e-01,  ..., -1.8284e+00,
           1.1318e+00,  7.0110e-02]],

        [[ 5.1004e-01,  6.4652e-01,  4.8502e-01,  ...,  8.0034e-01,
          -5.4346e-01, -1.6462e+00],
         [ 6.9855e-02, -9.8620e-02, -6.2837e-01,  ...,  7.2709e-01,
          -3.3640e-01, -1.7753e+00],
         [ 1.3028e+00, -1.6390e+00,  9.5311e-02,  ..., -5.4788e-01,
           3.3195e-02, -1.3426e+00],
         ...,
         [-5.6727e-01,  8

In [33]:
masking_label        = torch.zeros(10, 256)
masking_label[0, 1]  = 1
masking_label[1, 12] = 1
masking_label[2, 13] = 1
masking_label[3, 14] = 1
masking_label[3, 15] = 1
mlm_output_          = mlm_output[masking_label.bool()]

In [29]:
mlm_output.shape

torch.Size([10, 256, 128])

In [30]:
mlm_output_.shape

torch.Size([5, 128])

In [31]:
mlm_output_

tensor([[-2.4270e-01, -2.0843e-01, -3.7478e-01, -1.1358e+00, -2.1946e-01,
          3.3446e-01,  5.7034e-01, -3.8152e-01, -4.6428e-01, -1.5560e+00,
         -1.1907e+00, -1.5742e+00, -1.3044e+00, -2.2895e-01,  8.2193e-01,
         -1.4863e-01, -2.2180e-01, -4.9838e-01, -1.3370e+00,  2.8029e-01,
          2.1031e-01,  3.6513e-03,  3.7757e-01, -9.4559e-01, -4.5971e-01,
          1.5870e+00,  2.1090e+00, -6.2280e-01, -3.3466e-01,  2.6692e-01,
         -6.6542e-01, -5.7952e-01, -1.2446e+00, -9.5962e-01,  2.9304e-01,
          3.9016e-01, -1.9302e+00, -1.7927e-01,  4.0165e-01,  4.2341e-01,
          1.6429e-01, -7.9567e-02, -7.1369e-01,  1.0389e+00,  2.9375e-01,
          5.9075e-01,  2.8927e-01,  8.9392e-01,  5.3192e-01, -3.5230e-01,
          3.0128e-01,  1.8223e+00, -6.6174e-01, -2.4908e-01, -1.7012e+00,
         -3.8633e-01,  5.7321e-01, -1.4606e+00,  1.3026e+00,  7.7289e-02,
          4.5825e-01, -3.4294e-01, -1.7604e-01,  9.6182e-01, -7.6659e-01,
         -1.0352e-01,  1.6431e+00, -2.