In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from transformers import BertTokenizerFast

import glob
import pandas as pd
from sklearn.utils import shuffle
from tqdm.notebook import tqdm

fpath='data/tokenizer_model'
tokenizer = BertTokenizerFast.from_pretrained(fpath,
                                              strip_accents=False,
                                              lowercase=False)

vocab_dim     = len(tokenizer.vocab)
seq_len       = 256
embedding_dim = 512
device        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size    = 64

In [2]:
class BERTLangaugeModelDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, seq_len=128, masking_rate=0.15, NSP_rate=0.5):
        super(BERTLangaugeModelDataset, self).__init__()

        self.data         = data        
        self.tokenizer    = tokenizer
        self.vocab        = tokenizer.vocab
        self.seq_len      = seq_len
        self.masking_rate = masking_rate
        self.NSP_rate     = NSP_rate
        
        self.cls_token_id  = self.tokenizer.cls_token_id
        self.sep_token_id  = self.tokenizer.sep_token_id
        self.pad_token_id  = self.tokenizer.pad_token_id
        self.mask_token_id = self.tokenizer.mask_token_id
        
    def __getitem__(self, sent_1_idx):       
        sent_1 = self.tokenizer.encode(self.data[sent_1_idx])[1:-1]
        sent_2_idx = sent_1_idx + 1
        
        # NSP
        if torch.rand(1) >= self.NSP_rate and sent_2_idx != len(self.data):
            is_next = torch.tensor(1)
        else:
            while sent_2_idx == sent_1_idx + 1:
                sent_2_idx = torch.randint(0, len(self.data), (1,))
            is_next = torch.tensor(0)

        sent_2 = self.tokenizer.encode(self.data[sent_2_idx])[1:-1]
        
        # if length of (sent 1 + sent 2) longer than threshold
        # CLS, SEP 1 and 2
        if len(sent_1) + len(sent_2) >= self.seq_len - 3:
            if len(sent_1) >= self.seq_len -3:
                sent_1 = sent_1[:int(self.seq_len/2)]
                
            sent_2 = sent_2[:self.seq_len - 3 - len(sent_1)]
        
        pad_length = self.seq_len - 3 - len(sent_1) - len(sent_2)
        target = torch.tensor([self.cls_token_id] + sent_1 + [self.sep_token_id] + sent_2 + [self.sep_token_id] + [self.pad_token_id] * pad_length).long().contiguous()        

        sengment_embedding = torch.zeros(target.size(0))
        sengment_embedding[(len(sent_1) + 2):] = 1
        
        # masking
        masked_sent_1, masking_label_sent_1 = self.masking(sent_1)
        masked_sent_2, masking_label_sent_2 = self.masking(sent_2)
        
        masking_label = torch.cat([
            torch.tensor([self.pad_token_id]),
            masking_label_sent_1, 
            torch.tensor([self.sep_token_id]),
            masking_label_sent_2,
            torch.tensor([self.sep_token_id]),
            torch.tensor([self.pad_token_id] * pad_length)
        ])
        
        # MLM
        train = torch.cat([
            torch.tensor([self.cls_token_id]), 
            masked_sent_1,
            torch.tensor([self.sep_token_id]),
            masked_sent_2,
            torch.tensor([self.sep_token_id]),
            torch.tensor([self.pad_token_id] * pad_length)
        ]).long().contiguous()
        
        return train, target, sengment_embedding, is_next, masking_label
        
    
    def __len__(self):
        return len(self.data)
    
    
    def __iter__(self):
        for x in self.data:
            yield x
            
    
    def get_vocab(self):
        return self.vocab
    
    
    def decode(self, x):
        return self.tokenizer.batch_decode(x)
    
    
    # TODO mask 안에서 random 으로 바꿔주는 것 추가
    def masking(self, x):
        x = torch.tensor(x).long().contiguous()
        masking_idx   = torch.randperm(x.size()[0])[:round(x.size()[0] * self.masking_rate) + 1]
        masking_label = torch.zeros(x.size()[0])
        masking_label[masking_idx] = 1
        x = x.masked_fill(masking_label.bool(), self.mask_token_id)
        
        return x, masking_label

In [3]:
# https://inhyeokyoo.github.io/project/nlp/bert-issue/

import torch
import torch.nn as nn

class BERT(nn.Module):
    def __init__(self, vocab_dim, seq_len, embedding_dim, pad_token_id):
        super(BERT, self).__init__()
        self.pad_token_id  = pad_token_id
        self.nhead         = 8
        self.embedding     = BERTEmbedding(vocab_dim, seq_len, embedding_dim)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=self.nhead, batch_first=True)
        self.encoder_block = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
        
        
    def forward(self, data, segment_embedding):
        pad_mask  = BERT.get_attn_pad_mask(data, data, self.pad_token_id).repeat(self.nhead, 1, 1)
        embedding = self.embedding(data, segment_embedding)
        output    = self.encoder_block(embedding, pad_mask) 
        
        return output
    
    @staticmethod
    def get_attn_pad_mask(seq_q, seq_k, i_pad):
        batch_size, len_q = seq_q.size()
        batch_size, len_k = seq_k.size()
        pad_attn_mask = seq_k.data.eq(i_pad)
        pad_attn_mask= pad_attn_mask.unsqueeze(1).expand(batch_size, len_q, len_k)
        
        return pad_attn_mask

In [4]:
class BERTEmbedding(nn.Module):
    def __init__(self, vocab_dim, seq_len, embedding_dim, dropout_rate=0.1, device=device):
        super(BERTEmbedding, self).__init__()
        self.seq_len       = seq_len
        self.vocab_dim     = vocab_dim
        self.embedding_dim = embedding_dim
        self.dropout_rate  = dropout_rate
        
        # vocab --> embedding
        self.token_embedding      = nn.Embedding(self.vocab_dim, self.embedding_dim) 
        self.token_dropout        = nn.Dropout(self.dropout_rate)    
        
        # seq len --> embedding
        self.positional_embedding = nn.Embedding(self.seq_len, self.embedding_dim)
        self.positional_dropout   = nn.Dropout(self.dropout_rate) 
        
        # segment (0, 1) --> embedding
        self.segment_embedding    = nn.Embedding(2, self.embedding_dim)
        self.segment_dropout      = nn.Dropout(self.dropout_rate) 
        
        
    def forward(self, data, segment_embedding):
        token_embedding      = self.token_embedding(data)
        token_embedding      = self.token_dropout(token_embedding)
        
        positional_encoding  = torch.arange(start=0, end=self.seq_len, step=1).long()
        # data의 device 정보 가져와서 처리
        positional_encoding  = positional_encoding.unsqueeze(0).expand(data.size()).to(device)
        positional_embedding = self.positional_embedding(positional_encoding)
        positional_embedding = self.positional_dropout(positional_embedding)
        
        segment_embedding    = self.segment_embedding(segment_embedding)
        segment_embedding    = self.segment_dropout(segment_embedding)
        
        return token_embedding + positional_embedding + segment_embedding

In [5]:
class MaskedLanguageModeling(nn.Module):
    def __init__(self, bert, output_dim):
        super(MaskedLanguageModeling, self).__init__()
        self.bert = bert
        d_model   = bert.embedding.token_embedding.weight.size(1)
        self.fc   = nn.Linear(d_model, output_dim)
    
    def forward(self, x, segment_embedding):
        output = self.bert(x, segment_embedding)
        output = self.fc(output)
        
        return output
    
    
class NextSentencePrediction(nn.Module):
    def __init__(self, bert, output_dim=2):
        super(NextSentencePrediction, self).__init__()
        self.bert = bert
        d_model   = bert.embedding.token_embedding.weight.size(1)
        self.fc   = nn.Linear(d_model, output_dim)
        
    def forward(self, x, segment_embedding):
        output = self.bert(x, segment_embedding)
        output = self.fc(output)
        
        return output[:, 0, :] # CLS token 

In [6]:
def train(mlm_head, nsp_head, iterator, optimizer, criterion, device, clip=1):
    mlm_head.train()
    nsp_head.train()
    
    mlm_epoch_loss = 0
    nsp_epoch_loss = 0
    
    for X, y_mlm, segment_emb, y_nsp, masking_label in tqdm(iterator, total=len(iterator)):
        optimizer.zero_grad()
        
        mlm_output = mlm_head(X.to(device), segment_emb.long().to(device))     
        output_dim = mlm_output.shape[-1]
        
        #         mlm_output = mlm_output.reshape(-1, mlm_output.shape[-1])
        #         mlm_loss   = criterion(mlm_output, y_mlm.to(device).reshape(-1)) # CE
        
        mlm_output = mlm_output[masking_label.bool().to(device)].reshape(-1, output_dim)
        mlm_target = y_mlm.reshape(-1)[masking_label.reshape(-1).bool()].to(device)
        
        mlm_loss   = criterion(mlm_output, mlm_target)
        
        nsp_output = nsp_head(X.to(device), segment_emb.long().to(device))
        nsp_loss   = criterion(nsp_output, y_nsp.to(device)) # no need for reshape target
        
        loss = mlm_loss + nsp_loss
        loss.backward()
                
        torch.nn.utils.clip_grad_norm_(mlm_head.parameters(), clip)
        torch.nn.utils.clip_grad_norm_(nsp_head.parameters(), clip)
        
        optimizer.step()

        mlm_epoch_loss += mlm_loss.item()
        nsp_epoch_loss += nsp_loss.item()

    return mlm_epoch_loss / len(iterator), nsp_epoch_loss / len(iterator)


@torch.no_grad()
def evaluate(mlm_head, nsp_head, iterator, optimizer, criterion, device, clip=1):
    mlm_head.eval()
    nsp_head.eval()
    
    mlm_epoch_loss = 0
    nsp_epoch_loss = 0
    
    for X, y_mlm, segment_emb, y_nsp, masking_label in iterator:
        optimizer.zero_grad()
        
        mlm_output = mlm_head(X.to(device), segment_emb.long().to(device))
        output_dim = mlm_output.shape[-1]
        
        #         mlm_output = mlm_output.reshape(-1, mlm_output.shape[-1])
        #         mlm_loss   = criterion(mlm_output, y_mlm.to(device).reshape(-1)) # CE
        
        mlm_output = mlm_output[masking_label.bool().to(device)].reshape(-1, output_dim)
        mlm_target = y_mlm.reshape(-1)[masking_label.reshape(-1).bool()].to(device)
        
        mlm_loss   = criterion(mlm_output, mlm_target)
        
        nsp_output = nsp_head(X.to(device), segment_emb.long().to(device))
        nsp_loss   = criterion(nsp_output, y_nsp.to(device)) # no need for reshape target
        
        mlm_epoch_loss += mlm_loss.item()
        nsp_epoch_loss += nsp_loss.item()

    return mlm_epoch_loss / len(iterator), nsp_epoch_loss / len(iterator)


@torch.no_grad()
def predict(mlm_head, iterator, device, tokenizer):
    mlm_head.eval()
    
    for X, y_mlm, segment_emb, y_nsp, masking_label in iterator:
        output = mlm_head(X.to(device), segment_emb.long().to(device))
    
        target_ = y_mlm.clone().detach().to("cpu")
        output_ = torch.argmax(output.clone().detach().to("cpu"), axis=-1)
        
        target_decode = tokenizer.batch_decode(target_)
        output_decode = tokenizer.batch_decode(output_)

    return target_decode, output_decode

In [7]:
def generate_epoch_prediction_dataloader(data, seq_len, tokenizer, masking_rate, batch_size, collate_fn, shuffle=True, num_workers=5):    
    dataset    = BERTLangaugeModelDataset(data=data, seq_len=seq_len, tokenizer=tokenizer, masking_rate=masking_rate)
    dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, collate_fn=collate_fn)
    
    return dataloader


def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    
    return torch.utils.data.dataloader.default_collate(batch)

In [8]:
with open("data/petitions.txt", 'r') as f:
    data = f.readlines()
    
proced_data = [line.replace("\n", "") for line in data]

train_data = proced_data[int(len(proced_data) * 0.2):]
test_data  = proced_data[:int(len(proced_data) * 0.2)]

train_dataset = BERTLangaugeModelDataset(data=train_data, seq_len=seq_len, tokenizer=tokenizer, masking_rate=0.4)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

valid_dataset = BERTLangaugeModelDataset(data=test_data, seq_len=seq_len, tokenizer=tokenizer, masking_rate=0.3)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

In [9]:
bert      = BERT(vocab_dim=vocab_dim, seq_len=seq_len, embedding_dim=embedding_dim, pad_token_id=0).to(device)
mlm_head  = MaskedLanguageModeling(bert, output_dim=vocab_dim).to(device)
nsp_head  = NextSentencePrediction(bert).to(device)
optimizer = optim.AdamW(list(mlm_head.parameters()) + list(nsp_head.parameters()), lr=1e-4, betas=[0.9, 0.999], weight_decay=0.01)
scheduler = ReduceLROnPlateau(optimizer, 'min')
criterion = nn.CrossEntropyLoss()

  super(AdamW, self).__init__(params, defaults)


In [None]:
start_epoch = 0
if len(glob.glob("output/*.tsv")) != 0:
    print("load pretrained model ... ")
    start_epoch = len(glob.glob("output/*.tsv"))
    checkpoint = torch.load('weights/BERT_LM_best.pt')
    mlm_head.load_state_dict(checkpoint['mlm_head'])
    nsp_head.load_state_dict(checkpoint['nsp_head'])
    optimizer.load_state_dict(checkpoint['optimizer'])

N_EPOCHS  = 1000
PAITIENCE = 30

n_paitience = 0
best_valid_loss = float('inf')
optimizer.zero_grad()
optimizer.step()

for epoch in range(start_epoch, N_EPOCHS):
    train_mlm_loss, train_nsp_loss = train(mlm_head, nsp_head, train_dataloader, optimizer, criterion, device)
    valid_mlm_loss, valid_nsp_loss = evaluate(mlm_head, nsp_head, valid_dataloader, optimizer, criterion, device)
    
    valid_loss = valid_mlm_loss + valid_nsp_loss
    scheduler.step(valid_loss)
    
    print(f'Epoch: {epoch + 1:04}')
    print(f'Train MLM Loss: {train_mlm_loss:.4f} | Train NSP Loss: {train_nsp_loss:.4f}')
    print(f'Valid MLM Loss: {valid_mlm_loss:.4f} | Valid NSP Loss: {valid_nsp_loss:.4f}')
    
    with open("output/log.txt", "a") as f:
        f.write("epoch: {0:04d} train mlm loss: {1:.4f}, train nsp loss: {2:.4f}, valid mlm loss: {3:.4f}, valid nsp loss: {4:.4f} \n".format(epoch, train_mlm_loss, train_nsp_loss, valid_mlm_loss, valid_nsp_loss))

    if epoch % 1 == 0:
        print("Predictions ...\n")
        sampled_for_prediction = shuffle(test_data, n_samples=20)
        prediction_dataloader  = generate_epoch_prediction_dataloader(
                                                                        sampled_for_prediction, 
                                                                        seq_len=seq_len, 
                                                                        tokenizer=tokenizer, 
                                                                        batch_size=len(sampled_for_prediction), 
                                                                        masking_rate=0.5, 
                                                                        collate_fn=collate_fn
                                                                        )
        output_list, target_list = predict(mlm_head, prediction_dataloader, device, tokenizer)
        prediction_results = pd.DataFrame({"output": output_list, "target": target_list})
        prediction_results.to_csv("output/prediction_results_epoch-{0:04d}.tsv".format(epoch), sep="\t", index=False)            
        
    
    if n_paitience < PAITIENCE:
        if best_valid_loss > valid_loss:
            best_valid_loss = valid_loss
            torch.save(
                {
                'mlm_head' : mlm_head.state_dict(),
                'nsp_head' : nsp_head.state_dict(),
                'optimizer': optimizer.state_dict()
                }, 'weights/BERT_LM_best.pt'
            )
            n_paitience = 0
        elif best_valid_loss <= valid_loss:
            n_paitience += 1
    else:
        print("Early stop!")
        checkpoint = torch.load('weights/BERT_LM_best.pt')
        mlm_head.load_state_dict(checkpoint['mlm_head'])
        nsp_head.load_state_dict(checkpoint['nsp_head'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        break

  0%|          | 0/1941 [00:00<?, ?it/s]

Epoch: 0001
Train MLM Loss: 6.8506 | Train NSP Loss: 0.7150
Valid MLM Loss: 6.6898 | Valid NSP Loss: 0.6940
Predictions ...



  0%|          | 0/1941 [00:00<?, ?it/s]