In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from transformers import BertTokenizerFast

import glob
import pandas as pd
from sklearn.utils import shuffle
from tqdm.notebook import tqdm

fpath = './data/tokenizer_model'
tokenizer = BertTokenizerFast.from_pretrained(fpath,
                                              strip_accents=False,
                                              lowercase=False)

vocab_dim     = len(tokenizer.vocab)
seq_len       = 256
embedding_dim = 512
device        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size    = 256

In [10]:
# https://github.com/GyuminJack/torchstudy/blob/main/06Jun/NER/src/data.py

import linecache

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
from torch.nn.utils.rnn import pad_sequence

from sklearn.metrics import accuracy_score, f1_score

def load_tokenizer(tokenizer_path):
    loaded_tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path, strip_accents=False, lowercase=False)  # Must be False if cased model  # 로드
    return loaded_tokenizer

class KlueDataset_NER(Dataset):
    def __init__(self, vocab_txt_path, txt_path, *args, **kwargs):
        self.tokenizer = load_tokenizer(vocab_txt_path)
        self.max_seq_len = 256
        self.txt_path = txt_path
        
        self.cls_token_id  = self.tokenizer.cls_token_id
        self.sep_token_id  = self.tokenizer.sep_token_id
        self.pad_token_id  = self.tokenizer.pad_token_id
        
        self.bio_dict = {
                        '[PAD]' : 0,
                        'B-DT': 1,
                        'B-LC': 2,
                        'B-OG': 3,
                        'B-PS': 4,
                        'B-QT': 5,
                        'B-TI': 6,
                        'I-DT': 7,
                        'I-LC': 8,
                        'I-OG': 9,
                        'I-PS': 10,
                        'I-QT': 11,
                        'I-TI': 12,
                        'O': 13
                        }
        self.reverse_bio_dict = {v:k for k, v in self.bio_dict.items()}
        with open(self.txt_path, "r") as f:
            self._total_data = len(f.readlines())

    def __len__(self):
        return self._total_data

    def __getitem__(self, idx):
        raw_ko = linecache.getline(self.txt_path, idx + 1).strip()
        text, bio_string = raw_ko.split("\t")
        bio_tensor = [self.bio_dict[i] for i in bio_string.split(",")]
        
        sent = self.tokenizer.encode(text)[1:-1]
        pad_length = self.max_seq_len - len(sent)
        
        train = torch.tensor([self.cls_token_id] + sent + [self.sep_token_id] + [self.pad_token_id] * pad_length).long().contiguous()
        target = torch.tensor(bio_tensor + [self.pad_token_id] * pad_length).long().contiguous()
        
        segment_embedding = torch.zeros(target.size(0)).long()
        
        return train, target, segment_embedding

In [11]:
# https://inhyeokyoo.github.io/project/nlp/bert-issue/

import torch
import torch.nn as nn

class BERT(nn.Module):
    def __init__(self, vocab_dim, seq_len, embedding_dim, pad_token_id):
        super(BERT, self).__init__()
        self.pad_token_id  = pad_token_id
        self.nhead         = 8
        self.embedding     = BERTEmbedding(vocab_dim, seq_len, embedding_dim)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=self.nhead, batch_first=True)
        self.encoder_block = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
        
        
    def forward(self, data, segment_embedding):
        pad_mask  = BERT.get_attn_pad_mask(data, data, self.pad_token_id).repeat(self.nhead, 1, 1)
        embedding = self.embedding(data, segment_embedding)
        output    = self.encoder_block(embedding, pad_mask) 
        
        return output
    
    @staticmethod
    def get_attn_pad_mask(seq_q, seq_k, i_pad):
        batch_size, len_q = seq_q.size()
        batch_size, len_k = seq_k.size()
        pad_attn_mask = seq_k.data.eq(i_pad)
        pad_attn_mask= pad_attn_mask.unsqueeze(1).expand(batch_size, len_q, len_k)
        
        return pad_attn_mask

In [12]:
class BERTEmbedding(nn.Module):
    def __init__(self, vocab_dim, seq_len, embedding_dim, dropout_rate=0.1, device=device):
        super(BERTEmbedding, self).__init__()
        self.seq_len       = seq_len
        self.vocab_dim     = vocab_dim
        self.embedding_dim = embedding_dim
        self.dropout_rate  = dropout_rate
        
        # vocab --> embedding
        self.token_embedding      = nn.Embedding(self.vocab_dim, self.embedding_dim) 
        self.token_dropout        = nn.Dropout(self.dropout_rate)    
        
        # seq len --> embedding
        self.positional_embedding = nn.Embedding(self.seq_len, self.embedding_dim)
        self.positional_dropout   = nn.Dropout(self.dropout_rate) 
        
        # segment (0, 1) --> embedding
        self.segment_embedding    = nn.Embedding(2, self.embedding_dim)
        self.segment_dropout      = nn.Dropout(self.dropout_rate) 
        
        
    def forward(self, data, segment_embedding):
        token_embedding      = self.token_embedding(data)
        token_embedding      = self.token_dropout(token_embedding)
        
        positional_encoding  = torch.arange(start=0, end=self.seq_len, step=1).long()
        # data의 device 정보 가져와서 처리
        positional_encoding  = positional_encoding.unsqueeze(0).expand(data.size()).to(device)
        positional_embedding = self.positional_embedding(positional_encoding)
        positional_embedding = self.positional_dropout(positional_embedding)
        
        segment_embedding    = self.segment_embedding(segment_embedding)
        segment_embedding    = self.segment_dropout(segment_embedding)
        
        return token_embedding + positional_embedding + segment_embedding

In [13]:
from torchcrf import CRF

class BertNER(nn.Module):
    def __init__(self, bert, bert_hidden_size, num_classes, use_LSTM=True):
        super().__init__()
        self.use_LSTM = use_LSTM
        self.bert     = bert
        self.fc       = nn.Linear(bert_hidden_size, num_classes)
        self.dropout  = nn.Dropout(0.1)
        self.lstm     = nn.LSTM(bert_hidden_size, num_classes, batch_first=True)
        self.crf      = CRF(num_tags=num_classes, batch_first=True)

    def forward(self, source, target, segment_embedding):
        source              = self.bert(source, segment_embedding)
        last_encoder_output = self.dropout(source)
        
        if self.use_LSTM:
            last_encoder_output, _ = self.lstm(last_encoder_output)

        emissions              = self.fc(last_encoder_output)
        log_likelihood, output = self.crf(emissions, target), self.crf.decode(emissions)
        
        return log_likelihood, torch.tensor(output)

In [None]:
# https://github.com/long8v/torch_study/blob/master/paper/06_BERT/source/finetune/ner_bert.py

def compute_metrics(y_pred_, y_test_):
    y_pred = torch.Tensor(y_pred_).view(-1).to('cpu')
    y_test = torch.Tensor(y_test_).view(-1).to('cpu')
    
    y_pred = y_pred[y_test != self.pad_idx]
    y_test = y_test[y_test != self.pad_idx]
    
    micro_score = f1_score(y_pred, y_test, average='micro')
    macro_score = f1_score(y_pred, y_test, average=None)
    accuracy_score = accuracy_score(y_pred, y_test)
    
    # macro에서 pad_idx에 대한 값은 평균 구할 때 빼줌
    macro_score = np.mean([score for idx, score in enumerate(macro_score) if idx != self.pad_idx])

    return {'accuracy' : accuracy_score, 'f1_micro': micro_score, 'f1_macro': macro_score}

In [14]:
def train(model, iterator, optimizer, device, clip=1):
    model.train()
    
    epoch_loss = 0
    epoch_acc  = 0
    epoch_f1   = 0
    
    for source, target, segment_embedding in tqdm(iterator, total=len(iterator)):
        optimizer.zero_grad()
        
        source = source.to(device)
        target = target.to(device)
        segment_embedding = segment_embedding.to(device)

        log_likelihood, output = model(source, target, segment_embedding) 
        
        loss = -1 * log_likelihood
        
        loss.backward()
                
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        metrics = compute_metrics(output, target)

        epoch_loss += loss.item()
        epoch_acc  += metrics['accuracy']
        epoch_f1   += metrics['f1_macro']
        
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator), epoch_f1 / len(iterator)


@torch.no_grad()
def evaluate(model, iterator, optimizer, device, clip=1):
    model.eval()
    
    epoch_loss = 0
    epoch_acc  = 0
    epoch_f1   = 0
    
    for source, target, segment_embedding in iterator:
        optimizer.zero_grad()
        
        source = source.to(device)
        target = target.to(device)
        segment_embedding = segment_embedding.to(device)

        log_likelihood, output = model(source, target, segment_embedding) 

        loss = -1 * log_likelihood
        
        metrics = compute_metrics(output, target)

        epoch_loss += loss.item()
        epoch_acc  += metrics['accuracy']
        epoch_f1   += metrics['f1_macro']
        
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator), epoch_f1 / len(iterator)

In [15]:
vocab_txt_path = "./data/tokenizer_model"

train_path = "./data/klue_ner_processed.train"
train_dataset = KlueDataset_NER(vocab_txt_path, train_path)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

valid_path = "./data/klue_ner_processed.dev"
valid_dataset = KlueDataset_NER(vocab_txt_path, valid_path)
valid_data_loader = DataLoader(valid_dataset, batch_size=batch_size)

checkpoint = torch.load('./weights/BERT_LM_best.pt')
bert = BERT(vocab_dim=vocab_dim, seq_len=seq_len, embedding_dim=embedding_dim, pad_token_id=0).to(device)
bert.load_state_dict(checkpoint['bert'])

ner_head = BertNER(bert, embedding_dim, len(train_dataset.bio_dict)).to(device)

optimizer = torch.optim.AdamW(ner_head.parameters(), lr = 0.0001, weight_decay = 0.01)
scheduler = ReduceLROnPlateau(optimizer, 'min')

In [8]:
start_epoch = 0
N_EPOCHS  = 1000
PAITIENCE = 30

n_paitience = 0
best_valid_loss = float('inf')
optimizer.zero_grad()
optimizer.step()

for epoch in range(start_epoch, N_EPOCHS):
    train_loss, train_acc, train_f1 = train(ner_head, train_data_loader, optimizer, device)
    valid_loss, valid_acc, valid_f1 = evaluate(ner_head, valid_data_loader, optimizer, device)
    scheduler.step(valid_loss)

    print(f'Epoch: {epoch + 1:04}')
    print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train F1: {train_f1: .4f}')
    print(f'Valid Loss: {valid_loss:.4f} | Valid Acc: {valid_acc:.4f} | Valid F1: {valid_f1: .4f}')

    with open("output/log_ner.txt", "a") as f:
        f.write("epoch: {0:04d} \ 
                train loss: {1:.4f}, train acc: {2:.4f}, train_f1: {3:.4f}   \
                valid loss: {4:.4f}, valid acc: {5:.4f}, valid_f1: {6: .4f}} \n".format(epoch, train_loss, train_acc, train_f1, valid_loss, valid_acc, valid_f1))


    if n_paitience < PAITIENCE:
        if best_valid_loss > valid_loss:
            best_valid_loss = valid_loss
            torch.save(
                {
                'ner_head' : ner_head.state_dict(),
                'optimizer': optimizer.state_dict()
                }, 'weights/BERT_ner_best.pt'
            )
            n_paitience = 0
        elif best_valid_loss <= valid_loss:
            n_paitience += 1
    else:
        print("Early stop!")
        checkpoint = torch.load('weights/BERT_ner_best.pt')
        ner_head.load_state_dict(checkpoint['ner_head'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        break

  0%|          | 0/83 [00:00<?, ?it/s]

  score = torch.where(mask[i].unsqueeze(1), next_score, score)


Epoch: 0001
Train Loss: 10316.1697 | Train Acc: 0.8016
Valid Loss: 2992.3516 | Valid Acc: 0.8855


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0002
Train Loss: 2398.7176 | Train Acc: 0.9062
Valid Loss: 2395.5992 | Valid Acc: 0.9075


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0003
Train Loss: 1931.0802 | Train Acc: 0.9225
Valid Loss: 2126.0312 | Valid Acc: 0.9158


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0004
Train Loss: 1649.6522 | Train Acc: 0.9327
Valid Loss: 1967.0039 | Valid Acc: 0.9224


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0005
Train Loss: 1454.7549 | Train Acc: 0.9405
Valid Loss: 1887.0495 | Valid Acc: 0.9271


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0006
Train Loss: 1302.2208 | Train Acc: 0.9464
Valid Loss: 1821.0159 | Valid Acc: 0.9295


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0007
Train Loss: 1177.8361 | Train Acc: 0.9510
Valid Loss: 1825.4787 | Valid Acc: 0.9291


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0008
Train Loss: 1075.4959 | Train Acc: 0.9551
Valid Loss: 1776.8044 | Valid Acc: 0.9325


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0009
Train Loss: 987.2065 | Train Acc: 0.9585
Valid Loss: 1767.2415 | Valid Acc: 0.9348


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0010
Train Loss: 913.8428 | Train Acc: 0.9613
Valid Loss: 1757.3239 | Valid Acc: 0.9366


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0011
Train Loss: 836.5510 | Train Acc: 0.9643
Valid Loss: 1752.3896 | Valid Acc: 0.9374


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0012
Train Loss: 776.9342 | Train Acc: 0.9665
Valid Loss: 1768.3807 | Valid Acc: 0.9389


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0013
Train Loss: 724.6691 | Train Acc: 0.9689
Valid Loss: 1764.1499 | Valid Acc: 0.9395


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0014
Train Loss: 667.8331 | Train Acc: 0.9710
Valid Loss: 1792.0772 | Valid Acc: 0.9400


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0015
Train Loss: 618.8147 | Train Acc: 0.9732
Valid Loss: 1773.0927 | Valid Acc: 0.9403


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0016
Train Loss: 577.9724 | Train Acc: 0.9749
Valid Loss: 1835.9813 | Valid Acc: 0.9406


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0017
Train Loss: 541.7198 | Train Acc: 0.9763
Valid Loss: 1870.0714 | Valid Acc: 0.9399


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0018
Train Loss: 508.3576 | Train Acc: 0.9776
Valid Loss: 1847.8342 | Valid Acc: 0.9413


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0019
Train Loss: 484.3505 | Train Acc: 0.9787
Valid Loss: 1904.7685 | Valid Acc: 0.9407


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0020
Train Loss: 452.2369 | Train Acc: 0.9804
Valid Loss: 1892.4936 | Valid Acc: 0.9421


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0021
Train Loss: 423.5917 | Train Acc: 0.9813
Valid Loss: 1962.9563 | Valid Acc: 0.9426


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0022
Train Loss: 410.6223 | Train Acc: 0.9819
Valid Loss: 1902.5098 | Valid Acc: 0.9432


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0023
Train Loss: 370.4296 | Train Acc: 0.9838
Valid Loss: 1968.0586 | Valid Acc: 0.9429


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0024
Train Loss: 359.2646 | Train Acc: 0.9843
Valid Loss: 1978.6798 | Valid Acc: 0.9434


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0025
Train Loss: 352.9422 | Train Acc: 0.9845
Valid Loss: 1988.0291 | Valid Acc: 0.9436


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0026
Train Loss: 344.6962 | Train Acc: 0.9848
Valid Loss: 2008.5164 | Valid Acc: 0.9435


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0027
Train Loss: 348.4822 | Train Acc: 0.9848
Valid Loss: 1997.6107 | Valid Acc: 0.9435


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0028
Train Loss: 345.2984 | Train Acc: 0.9846
Valid Loss: 1995.3747 | Valid Acc: 0.9436


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0029
Train Loss: 336.5549 | Train Acc: 0.9850
Valid Loss: 2010.8467 | Valid Acc: 0.9436


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0030
Train Loss: 327.5774 | Train Acc: 0.9856
Valid Loss: 2016.5985 | Valid Acc: 0.9434


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0031
Train Loss: 330.1552 | Train Acc: 0.9856
Valid Loss: 2026.9738 | Valid Acc: 0.9434


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0032
Train Loss: 326.0173 | Train Acc: 0.9858
Valid Loss: 2054.1785 | Valid Acc: 0.9440


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0033
Train Loss: 327.8226 | Train Acc: 0.9858
Valid Loss: 2043.4496 | Valid Acc: 0.9438


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0034
Train Loss: 322.5582 | Train Acc: 0.9861
Valid Loss: 2038.4425 | Valid Acc: 0.9439


  0%|          | 0/83 [00:00<?, ?it/s]

Epoch: 0035
Train Loss: 318.5146 | Train Acc: 0.9861
Valid Loss: 2038.6860 | Valid Acc: 0.9439


  0%|          | 0/83 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [16]:
bert

BERT(
  (embedding): BERTEmbedding(
    (token_embedding): Embedding(32000, 512)
    (token_dropout): Dropout(p=0.1, inplace=False)
    (positional_embedding): Embedding(256, 512)
    (positional_dropout): Dropout(p=0.1, inplace=False)
    (segment_embedding): Embedding(2, 512)
    (segment_dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
    )
    (linear1): Linear(in_features=512, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=512, bias=True)
    (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (encoder_block): TransformerEncoder(
    (layers): ModuleList(
   