In [1]:
import os
import random
import time
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F

from torchcrf import CRF

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from transformers import ElectraModel, ElectraTokenizer

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

2023-03-20 10:18:01.308135: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Label Dicts

In [2]:
labels = ['[PAD]', 'E_B', 'E_I', 'O']
num_labels = len(labels)
id2label = {k: v for k, v in enumerate(labels)}
label2id = {v: k for k, v in id2label.items()}

## Load BERT and Tokenizer

In [3]:
bert = ElectraModel.from_pretrained('monologg/koelectra-base-v3-discriminator', num_labels=4)
tokenizer = ElectraTokenizer.from_pretrained('tokenizer')

Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Load Data and Preprocess for Training

In [4]:
data = pd.read_pickle('data/preprocessed.pkl')

In [5]:
data['tokens'] = data.tokens.apply(lambda x: ['[CLS]'] + x + ['[SEP]'])
data['labels'] = data.labels.apply(lambda x: ['O'] + x + ['O'])

In [6]:
data.tokens.apply(lambda x: len(x)).max()

233

In [7]:
max_len = 256

In [8]:
data['tokens'] = data.tokens.apply(lambda x: x + ['[PAD]'] * (max_len - len(x)))
data['labels'] = data.labels.apply(lambda x: x + ['[PAD]'] * (max_len - len(x)))

In [9]:
tokens_lst = data.tokens.to_list()
labels_lst = data.labels.to_list()

In [10]:
X_train, X_eval, y_train, y_eval = train_test_split(tokens_lst, 
                                                    labels_lst, 
                                                    test_size=0.2, shuffle=True, random_state=42)

In [11]:
train_data = []
for tokens, labels in zip(X_train, y_train):
    length = tokens.index('[PAD]')
    mask = [1] * length + [0] * (max_len - length)

    label_ids = []
    for label in labels:
        label_ids.append(label2id[label])
        
    train_data.append([tokenizer.convert_tokens_to_ids(tokens), mask, label_ids])

In [12]:
eval_data = []
for tokens, labels in zip(X_eval, y_eval):
    length = tokens.index('[PAD]')
    mask = [1] * length + [0] * (max_len - length)
    
    label_ids = []
    for label in labels:
        label_ids.append(label2id[label])
        
    eval_data.append([tokenizer.convert_tokens_to_ids(tokens), mask, label_ids])

In [13]:
# idx = random.randrange(0, len(train_data) - 1)
# for x, xm, y in zip(train_data[idx][0], train_data[idx][1], train_data[idx][2]):
#     print(x, xm, y)

# idx = random.randrange(0, len(eval_data) - 1)
# for x, xm, y in zip(eval_data[idx][0], eval_data[idx][1], eval_data[idx][2]):
#     print(x, xm, y)

## HP Config

In [14]:
batch_size = 16
LEARNING_RATE = 5e-5
N_EPOCHS = 15

## Create Dataset and Generate Dataloader

In [15]:
class TaggerDataset(Dataset): 
    def __init__(self, data):
        self.data = data
    
    def __len__(self): 
        return len(self.data)

    def __getitem__(self, idx):
        input_ids = self.data[idx][0]
        mask = self.data[idx][1]
        label_ids = self.data[idx][2]
        return (torch.LongTensor(input_ids), torch.LongTensor(mask), torch.LongTensor(label_ids))

In [16]:
train_dataset = TaggerDataset(train_data)
eval_dataset = TaggerDataset(eval_data)

In [17]:
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
eval_loader = DataLoader(eval_dataset, batch_size = batch_size, shuffle = True)

## Instantiate Model

In [18]:
class BERT_BiLSTM_CRF(nn.Module):
    
    def __init__(self, bert, config, need_birnn=False, rnn_dim=128):
        super(BERT_BiLSTM_CRF, self).__init__()
        
        self.num_tags = config.num_labels
        self.bert = bert
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        out_dim = config.hidden_size
        self.need_birnn = need_birnn

        # if False, no use of BiLSTM
        if need_birnn:
            self.birnn = nn.LSTM(config.hidden_size, rnn_dim, num_layers=1, bidirectional=True, batch_first=True)
            out_dim = rnn_dim*2
        
        self.hidden2tag = nn.Linear(out_dim, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)
    
    def predict(self, input_ids, input_mask=None):
        emissions = self.tag_outputs(input_ids, input_mask)
        return self.crf.decode(emissions, input_mask.byte())

    def forward(self, input_ids, tags, input_mask=None):
        emissions = self.tag_outputs(input_ids, input_mask)
        loss = -1*self.crf(emissions, tags.long(), input_mask.byte()) # negative log likelihood loss , reduction='mean' default 'sum'
        return loss.unsqueeze(0)

    def tag_outputs(self, input_ids, input_mask=None):
        outputs = self.bert(input_ids, attention_mask=input_mask)
        sequence_output = outputs[0]
        
        if self.need_birnn:
            self.birnn.flatten_parameters()
            sequence_output, _ = self.birnn(sequence_output)

        sequence_output = self.dropout(sequence_output)
        emissions = self.hidden2tag(sequence_output)
        return emissions

In [19]:
bert.resize_token_embeddings(len(tokenizer))
config = bert.config
print(config.num_labels)
print(config.hidden_dropout_prob)
print(config.hidden_size)

model = BERT_BiLSTM_CRF(bert, config, need_birnn=True, rnn_dim=128)

4
0.1
768


## Optimizer / Criterion / Scheduler

In [20]:
optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=0, factor=0.7, min_lr=0)

## DataParallel

In [21]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

NGPU = torch.cuda.device_count()
if NGPU > 1:
    model = torch.nn.DataParallel(model, device_ids=list(range(NGPU)))
    # model = torch.nn.DataParallel(model, device_ids=[0,1])
    # torch.multiprocessing.set_start_method('spawn', force=True)
model = model.to(device)

## Functions

In [22]:
def categorical_accuracy(preds, y, tag_pad_idx):
    non_pad_elements = torch.nonzero((y != tag_pad_idx))
    correct = preds[non_pad_elements].eq(y[non_pad_elements])
    return correct.sum() / torch.FloatTensor([y[non_pad_elements].shape[0]]).to(device)

In [23]:
def categorical_f1(preds, y, tag_pad_idx):
    non_pad_elements = torch.nonzero((y != tag_pad_idx))
    preds_no_pad = preds[non_pad_elements].squeeze(1).detach().cpu()
    y_no_pad = y[non_pad_elements].detach().cpu()
    
    f1_macro = f1_score(y_no_pad, preds_no_pad, average='macro')
    f1_micro = f1_score(y_no_pad, preds_no_pad, average='micro')    
    
    return f1_macro, f1_micro

In [24]:
def train(model, iterator, optimizer, tag_pad_idx):
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    predictions_set = None
    tags_set = None
    for batch in iterator:
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        tags = batch[2].to(device)

        loss = model(input_ids, tags, attention_mask).mean() / batch_size
        
        predictions = model.module.predict(input_ids, attention_mask)
        predictions = list(map(lambda x: x + [0 for _ in range(max_len - len(x))], predictions))
        predictions = torch.LongTensor(predictions).to(device)
        predictions = predictions.view(-1)
        tags = tags.view(-1)
        if predictions_set == None:
            predictions_set = predictions
            tags_set = tags
        else:
            predictions_set = torch.cat([predictions_set, predictions], dim=0)
            tags_set = torch.cat([tags_set, tags], dim=0)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        acc = categorical_accuracy(predictions, tags, tag_pad_idx)
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    f1_macro, f1_micro = categorical_f1(predictions_set, tags_set, tag_pad_idx)
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator), f1_macro, f1_micro

In [25]:
def evaluate(model, iterator, tag_pad_idx):
    model.eval()
    epoch_loss = 0
    epoch_acc = 0
    predictions_set = None
    tags_set = None
    with torch.no_grad():
        for batch in iterator:
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device)
            tags = batch[2].to(device)
            
            loss = model(input_ids, tags, attention_mask).mean() / batch_size
            
            predictions = model.module.predict(input_ids, attention_mask)
            predictions = list(map(lambda x: x + [0 for _ in range(max_len - len(x))], predictions))
            predictions = torch.LongTensor(predictions).to(device)
            predictions = predictions.view(-1)
            tags = tags.view(-1)
            if predictions_set == None:
                predictions_set = predictions
                tags_set = tags
            else:
                predictions_set = torch.cat([predictions_set, predictions], dim=0)
                tags_set = torch.cat([tags_set, tags], dim=0)
            
            acc = categorical_accuracy(predictions, tags, tag_pad_idx)
            epoch_loss += loss.item()
            epoch_acc += acc.item()
        f1_macro, f1_micro = categorical_f1(predictions_set, tags_set, tag_pad_idx)
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator), f1_macro, f1_micro

In [26]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

## Train

In [27]:
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc, train_f1_mac, train_f1_mic = train(model, train_loader, optimizer, 0)
    valid_loss, valid_acc, valid_f1_mac, valid_f1_mic = evaluate(model, eval_loader, 0)
    
    cur_lr = scheduler.optimizer.state_dict()['param_groups'][0]['lr']
    scheduler.step(valid_loss)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'event_tagger.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'Learning Rate: {cur_lr}')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%', end=' | ')
    print(f'Train F1 Mac: {train_f1_mac*100:.2f}% | Train F1 Mic: {train_f1_mic*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%', end=' |  ')
    print(f'Val. F1 Mac: {valid_f1_mac*100:.2f}% |  Val. F1 Mic: {valid_f1_mic*100:.2f}%', end='\n\n')

  score = torch.where(mask[i].unsqueeze(1), next_score, score)


Epoch: 01 | Epoch Time: 5m 12s
Learning Rate: 5e-05
	Train Loss: 6.470 | Train Acc: 80.35% | Train F1 Mac: 38.09% | Train F1 Mic: 80.47%
	 Val. Loss: 3.103 |  Val. Acc: 90.97% |  Val. F1 Mac: 82.79% |  Val. F1 Mic: 90.96%

Epoch: 02 | Epoch Time: 4m 13s
Learning Rate: 5e-05
	Train Loss: 2.416 | Train Acc: 92.88% | Train F1 Mac: 86.46% | Train F1 Mic: 92.84%
	 Val. Loss: 2.346 |  Val. Acc: 92.98% |  Val. F1 Mac: 86.91% |  Val. F1 Mic: 92.98%

Epoch: 03 | Epoch Time: 4m 21s
Learning Rate: 5e-05
	Train Loss: 1.571 | Train Acc: 95.43% | Train F1 Mac: 91.15% | Train F1 Mic: 95.43%
	 Val. Loss: 2.384 |  Val. Acc: 92.96% |  Val. F1 Mac: 87.16% |  Val. F1 Mic: 92.92%

Epoch: 04 | Epoch Time: 4m 25s
Learning Rate: 3.5e-05
	Train Loss: 1.043 | Train Acc: 97.05% | Train F1 Mac: 94.15% | Train F1 Mic: 97.05%
	 Val. Loss: 2.251 |  Val. Acc: 93.81% |  Val. F1 Mac: 87.84% |  Val. F1 Mic: 93.65%

Epoch: 05 | Epoch Time: 4m 20s
Learning Rate: 3.5e-05
	Train Loss: 0.797 | Train Acc: 97.91% | Train F1 Ma

In [28]:
# torch.save(model.state_dict(), 'last.pt')