## [0416] KoElectra를 이용한 첫 베이스라인 가다듬기

- **add_special_tokens 함수**: Entity 위치 정보를 활용해 [ENT],[/ENT] entity special token 추가

In [1]:
!pip install mxnet
!pip install gluonnlp pandas tqdm
!pip install sentencepiece
!pip install transformers==3
!pip install torch



In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import pandas as pd
import numpy as np
import re
import tarfile
import pickle as pickle
from tqdm import tqdm
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup
from sklearn.model_selection import train_test_split

# Using KoELECTRA Model
from transformers import ElectraModel, ElectraTokenizer, ElectraForSequenceClassification

# Added by Me
import os
from tqdm.notebook import tqdm
from ohsuz.utils import *
from ohsuz.loss import *
from ohsuz.config import *
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR

There are 1 GPU(s) available.
We will use the GPU: Tesla V100-PCIE-32GB
There are 1 GPU(s) available.
We will use the GPU: Tesla V100-PCIE-32GB


In [3]:
max_len = 128
batch_size = 16
warmup_ratio = 0.01
epochs = 10
max_grad_norm = 1
log_interval = 50
lr = 5e-5

### 1. Dataset & DataLoader 준비

**add_entity_tokens**
- input
    - entity token을 추가할 문장
    - 첫 번째 entity 시작, 끝 index
    - 두 번째 entity 시작, 끝 index
- output
    - 해당하는 index에 entity token이 추가된 문장

**make_embedding_layer**
- input
    - 문장의 input_ids
- output
    - entity에 해당하는 token이면 1, 아니면 0으로 나타내는 tensor

In [4]:
def make_embedding_layer(input_ids, tokenizer):
    flag = False
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    special_tokens = special_tokens_dict['additional_special_tokens']
    is_entity_layer = []

    for i, token in enumerate(tokens):
        if token in special_tokens:
            if flag == False:
                flag = True
            else:
                flag = False
        else:
            if flag == True:
                is_entity_layer.append(1)
                continue
        is_entity_layer.append(0)

    is_entity_layer = torch.tensor(is_entity_layer)
    return is_entity_layer

In [5]:
def add_entity_tokens(sentence, a1, a2, b1, b2):
    new_sentence = None
    special_tokens = special_tokens_dict['additional_special_tokens']
    
    if a1 > b1: # b1 먼저
        new_sentence = sentence[:b1] + special_tokens[2] + sentence[b1:b2+1] + special_tokens[3] + sentence[b2+1:a1] + special_tokens[0] + sentence[a1:a2+1] + special_tokens[1] + sentence[a2+1:]
        #new_sentence = sentence[:b1] + "$" + sentence[b1:b2+1] + "$" + sentence[b2+1:a1] + "#" + sentence[a1:a2+1] + "#" + sentence[a2+1:]
    else: # a1 먼저
        new_sentence = sentence[:a1] + special_tokens[0] + sentence[a1:a2+1] + special_tokens[1] + sentence[a2+1:b1] + special_tokens[2] + sentence[b1:b2+1] + special_tokens[3] + sentence[b2+1:]
    return new_sentence

In [6]:
def load_data(dataset_dir, add_entity=True):
    with open('/opt/ml/input/data/label_type.pkl', 'rb') as f:
        label_type = pickle.load(f)
    dataset = pd.read_csv(dataset_dir, delimiter='\t', header=None)
    dataset = preprocessing_dataset(dataset, label_type, add_entity)
    return dataset


def preprocessing_dataset(dataset, label_type, add_entity):
    label = []
    sentences = None
    for i in dataset[8]:
        if i == 'blind':
            label.append(100)
        else:
            label.append(label_type[i])
    
    if add_entity:
        ### 이 부분을 더 효율적으로 고치려면???
        sentences = [add_entity_tokens(dataset[1][i], dataset[3][i], dataset[4][i], dataset[6][i], dataset[7][i]) for i in range(len(dataset))]
    else:
        sentences = dataset[1]

    out_dataset = pd.DataFrame({'sentence':sentences,'entity_01':dataset[2],'entity_02':dataset[5],'label':label,})
    return out_dataset

**handle_UNK의 option: REMOVE, ADD, REPLACE**

In [7]:
class KoElecDataset(Dataset):
    def __init__(self, tsv_file, add_entity=True, handle_UNK='REMOVE'):
        self.dataset = load_data(tsv_file, add_entity)
        self.dataset['sentence'] = self.dataset['entity_01'] + ' [SEP] ' + self.dataset['entity_02'] + ' [SEP] ' + self.dataset['sentence']
        self.sentences = list(self.dataset['sentence'])
        self.labels = list(self.dataset['label'])
        self.tokenizer = ElectraTokenizer.from_pretrained("monologg/koelectra-base-v3-discriminator")
        self.tokenizer.add_special_tokens(special_tokens_dict)
        self.handle_UNK = handle_UNK
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        sentence, label = self.sentences[idx], self.labels[idx]
        inputs = self.tokenizer(
            sentence,
            return_tensors='pt',
            truncation=True,
            max_length=190,
            pad_to_max_length=True,
            add_special_tokens=True
        )
            
        input_ids = inputs['input_ids'][0]
        is_embedding_layer = make_embedding_layer(input_ids, self.tokenizer)
        attention_mask = inputs['attention_mask'][0] + is_embedding_layer
        
        return input_ids, attention_mask, label

In [8]:
dataset = KoElecDataset(os.path.join(train_dir, 'train.tsv'))
test_dataset = KoElecDataset(os.path.join(test_dir, 'test.tsv'))

**Train, Valid set 8 : 2 로 분리**

In [9]:
# train_dataset, val_dataset = train_test_split(dataset, test_size=0.2, random_state=42)

In [10]:
# train_dataset.__getitem__(0)

In [11]:
#train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=5)
#val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=5)
train_loader = DataLoader(dataset, batch_size=batch_size, num_workers=5)
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=5)

In [12]:
next(iter(train_loader))

[tensor([[    2, 10086,  4239,  ...,     0,     0,     0],
         [    2,  6802,     3,  ...,     0,     0,     0],
         [    2,  6755,  7325,  ...,     0,     0,     0],
         ...,
         [    2, 19729,  3210,  ...,     0,     0,     0],
         [    2, 16156,    21,  ...,     0,     0,     0],
         [    2,  2075,  4265,  ...,     0,     0,     0]]),
 tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 tensor([17,  0,  6,  2,  8,  0, 17,  3, 10,  0,  4,  0, 16,  4,  0,  0])]

### 2. Model 준비

In [13]:
model = ElectraForSequenceClassification.from_pretrained("monologg/koelectra-base-v3-discriminator", num_labels=42).to(device)

Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'electra.embeddings.position_ids']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at monologg/koelectra-base-v3-discri

In [15]:
model.resize_token_embeddings(35004) # 내가 임의로 입력함 나중에 수정

Embedding(35004, 768)

### 3. Train

In [16]:
optimizer = AdamW(model.parameters(), lr=lr)
ls_loss = LabelSmoothingLoss()
cels_loss = CELSLoss()
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6) # 익효님꺼로 파라미터 변경
# scheduler = CosineAnnealingLR(optimizer, T_max=2, eta_min=0.)

In [17]:
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc

In [18]:
best_acc = 0.0

for epoch in range(epochs):
    train_acc = 0.0
    test_acc = 0.0
    
    model.train()
    
    for batch_id, (input_ids_batch, attention_masks_batch, y_batch) in tqdm(enumerate(train_loader)):
        optimizer.zero_grad()
        y_batch = y_batch.to(device)
        # 생각해보니까 내가 추가적으로 만든 embedding layer를 입력으로 주려면 모델 내부 구조를 바꿔야 되지 않나...?
        # 우선 attention mask에 더해서 입력으로 줘보자
        y_pred = model(input_ids_batch.to(device), attention_mask=attention_masks_batch.to(device))[0]
        loss = cels_loss(y_pred, y_batch)
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_acc += calc_accuracy(y_pred, y_batch)
        if batch_id % log_interval == 0:
            print(f"epoch {epoch+1} batch id {batch_id+1} loss {loss.data.cpu().numpy()} train acc {train_acc / (batch_id+1)}")

    train_acc = train_acc / (batch_id+1)
    print(f"epoch {epoch+1} train acc {train_acc}")
 
    """
    model.eval()
    for batch_id, (input_ids_batch, attention_masks_batch, y_batch) in tqdm(enumerate(val_loader)):
        y_batch = y_batch.to(device)
        y_pred = model(input_ids_batch.to(device), attention_mask=attention_masks_batch.to(device))[0]
        test_acc += calc_accuracy(y_pred, y_batch)
        
    print(f"epoch {epoch+1} test acc {test_acc / (batch_id+1)}")
    
    if test_acc >= best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), "/opt/ml/models/0416_koelectra_1.pt")
    """
    
    if train_acc >= best_acc:
        best_acc = train_acc
        torch.save(model.state_dict(), "/opt/ml/models/0416_koelectra_1.pt")

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 1 batch id 1 loss 7.421921730041504 train acc 0.0625
epoch 1 batch id 51 loss 4.53863525390625 train acc 0.4693627450980392
epoch 1 batch id 101 loss 4.202578544616699 train acc 0.48205445544554454
epoch 1 batch id 151 loss 5.92665433883667 train acc 0.4830298013245033
epoch 1 batch id 201 loss 4.1099958419799805 train acc 0.4779228855721393
epoch 1 batch id 251 loss 3.4751670360565186 train acc 0.48705179282868527
epoch 1 batch id 301 loss 3.8756747245788574 train acc 0.4925249169435216
epoch 1 batch id 351 loss 2.4690780639648438 train acc 0.4976851851851852
epoch 1 batch id 401 loss 2.436473846435547 train acc 0.506857855361596
epoch 1 batch id 451 loss 2.8510470390319824 train acc 0.5192627494456763
epoch 1 batch id 501 loss 3.9574692249298096 train acc 0.5259481037924152
epoch 1 batch id 551 loss 2.321153163909912 train acc 0.5362976406533575

epoch 1 train acc 0.5385213143872114


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 2 batch id 1 loss 3.862265110015869 train acc 0.5
epoch 2 batch id 51 loss 3.675706624984741 train acc 0.6029411764705882
epoch 2 batch id 101 loss 2.862703323364258 train acc 0.6175742574257426
epoch 2 batch id 151 loss 3.450895071029663 train acc 0.6258278145695364
epoch 2 batch id 201 loss 3.6509718894958496 train acc 0.6305970149253731
epoch 2 batch id 251 loss 1.5159938335418701 train acc 0.6451693227091634
epoch 2 batch id 301 loss 2.681750774383545 train acc 0.6536544850498339
epoch 2 batch id 351 loss 1.4422131776809692 train acc 0.6582977207977208
epoch 2 batch id 401 loss 1.1767903566360474 train acc 0.6658354114713217
epoch 2 batch id 451 loss 2.254039764404297 train acc 0.6690687361419069
epoch 2 batch id 501 loss 2.7259836196899414 train acc 0.6724051896207585
epoch 2 batch id 551 loss 1.6676735877990723 train acc 0.6767241379310345

epoch 2 train acc 0.6779529307282416


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 3 batch id 1 loss 2.533492088317871 train acc 0.625
epoch 3 batch id 51 loss 2.5197136402130127 train acc 0.6985294117647058
epoch 3 batch id 101 loss 2.36226224899292 train acc 0.719059405940594
epoch 3 batch id 151 loss 1.9987423419952393 train acc 0.7239238410596026
epoch 3 batch id 201 loss 1.6767646074295044 train acc 0.7328980099502488
epoch 3 batch id 251 loss 0.8858680725097656 train acc 0.7427788844621513
epoch 3 batch id 301 loss 1.6513774394989014 train acc 0.7437707641196013
epoch 3 batch id 351 loss 0.4920813739299774 train acc 0.7491096866096866
epoch 3 batch id 401 loss 0.77606201171875 train acc 0.7563902743142145
epoch 3 batch id 451 loss 1.4938093423843384 train acc 0.7595620842572062
epoch 3 batch id 501 loss 3.432462692260742 train acc 0.7625998003992016
epoch 3 batch id 551 loss 0.6788500547409058 train acc 0.7662205081669692

epoch 3 train acc 0.7679840142095915


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 4 batch id 1 loss 1.570448875427246 train acc 0.75
epoch 4 batch id 51 loss 2.403407096862793 train acc 0.7720588235294118
epoch 4 batch id 101 loss 1.1024820804595947 train acc 0.7951732673267327
epoch 4 batch id 151 loss 1.3048317432403564 train acc 0.8054635761589404
epoch 4 batch id 201 loss 1.5080339908599854 train acc 0.8106343283582089
epoch 4 batch id 251 loss 0.6145176887512207 train acc 0.8187250996015937
epoch 4 batch id 301 loss 1.8409743309020996 train acc 0.8108388704318937
epoch 4 batch id 351 loss 0.5965124368667603 train acc 0.8167735042735043
epoch 4 batch id 401 loss 0.6687719821929932 train acc 0.8206047381546134
epoch 4 batch id 451 loss 1.1265949010849 train acc 0.8235864745011087
epoch 4 batch id 501 loss 2.4620442390441895 train acc 0.8263473053892215
epoch 4 batch id 551 loss 0.39743393659591675 train acc 0.8291742286751361

epoch 4 train acc 0.830595026642984


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 5 batch id 1 loss 1.0896188020706177 train acc 0.8125
epoch 5 batch id 51 loss 1.7890093326568604 train acc 0.8419117647058824
epoch 5 batch id 101 loss 0.8221771717071533 train acc 0.8508663366336634
epoch 5 batch id 151 loss 0.9588762521743774 train acc 0.8497516556291391
epoch 5 batch id 201 loss 1.6271013021469116 train acc 0.8526119402985075
epoch 5 batch id 251 loss 0.6431394815444946 train acc 0.8578187250996016
epoch 5 batch id 301 loss 1.026449203491211 train acc 0.8592192691029901
epoch 5 batch id 351 loss 0.6144285798072815 train acc 0.8628917378917379
epoch 5 batch id 401 loss 0.9620805978775024 train acc 0.864713216957606
epoch 5 batch id 451 loss 1.6454386711120605 train acc 0.8621119733924612
epoch 5 batch id 501 loss 2.1259446144104004 train acc 0.8614021956087824
epoch 5 batch id 551 loss 0.34556692838668823 train acc 0.8637704174228675

epoch 5 train acc 0.8644538188277087


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 6 batch id 1 loss 0.9451676607131958 train acc 0.875
epoch 6 batch id 51 loss 0.8313095569610596 train acc 0.8774509803921569
epoch 6 batch id 101 loss 0.6323969960212708 train acc 0.880569306930693
epoch 6 batch id 151 loss 0.6782960891723633 train acc 0.882864238410596
epoch 6 batch id 201 loss 0.6996987462043762 train acc 0.8883706467661692
epoch 6 batch id 251 loss 0.9633588790893555 train acc 0.8919322709163346
epoch 6 batch id 301 loss 0.9347851276397705 train acc 0.8920265780730897
epoch 6 batch id 351 loss 0.1512310951948166 train acc 0.8945868945868946
epoch 6 batch id 401 loss 0.5121674537658691 train acc 0.895573566084788
epoch 6 batch id 451 loss 1.190389633178711 train acc 0.8957871396895787
epoch 6 batch id 501 loss 1.6487905979156494 train acc 0.8958333333333334
epoch 6 batch id 551 loss 0.3489508032798767 train acc 0.8975725952813067

epoch 6 train acc 0.8983126110124334


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 7 batch id 1 loss 0.7067004442214966 train acc 0.875
epoch 7 batch id 51 loss 0.8358064889907837 train acc 0.9007352941176471
epoch 7 batch id 101 loss 0.9635580778121948 train acc 0.8966584158415841
epoch 7 batch id 151 loss 0.8095109462738037 train acc 0.8936258278145696
epoch 7 batch id 201 loss 0.6405416131019592 train acc 0.898320895522388
epoch 7 batch id 251 loss 0.6395898461341858 train acc 0.9001494023904383
epoch 7 batch id 301 loss 0.5316004753112793 train acc 0.9013704318936877
epoch 7 batch id 351 loss 0.3235449194908142 train acc 0.9050925925925926
epoch 7 batch id 401 loss 0.3003685474395752 train acc 0.9067955112219451
epoch 7 batch id 451 loss 0.6906238794326782 train acc 0.9070121951219512
epoch 7 batch id 501 loss 1.1221625804901123 train acc 0.9089321357285429
epoch 7 batch id 551 loss 0.13454334437847137 train acc 0.9090290381125227

epoch 7 train acc 0.9096358792184724


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 8 batch id 1 loss 0.5552057027816772 train acc 0.875
epoch 8 batch id 51 loss 0.3900566101074219 train acc 0.8970588235294118
epoch 8 batch id 101 loss 0.3365430235862732 train acc 0.9077970297029703
epoch 8 batch id 151 loss 1.7549057006835938 train acc 0.8981788079470199
epoch 8 batch id 201 loss 1.102691650390625 train acc 0.8958333333333334
epoch 8 batch id 251 loss 2.3128113746643066 train acc 0.8936752988047809
epoch 8 batch id 301 loss 0.5011834502220154 train acc 0.8955564784053156
epoch 8 batch id 351 loss 0.2679084241390228 train acc 0.8992165242165242
epoch 8 batch id 401 loss 0.7614998817443848 train acc 0.90321072319202
epoch 8 batch id 451 loss 0.5857945680618286 train acc 0.9068736141906873
epoch 8 batch id 501 loss 0.9870500564575195 train acc 0.9089321357285429
epoch 8 batch id 551 loss 0.07993300259113312 train acc 0.911978221415608

epoch 8 train acc 0.9128552397868561


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 9 batch id 1 loss 0.3934621810913086 train acc 1.0
epoch 9 batch id 51 loss 0.6492621898651123 train acc 0.9350490196078431
epoch 9 batch id 101 loss 0.1175699383020401 train acc 0.9405940594059405
epoch 9 batch id 151 loss 0.44494080543518066 train acc 0.9412251655629139
epoch 9 batch id 201 loss 0.5461031794548035 train acc 0.9406094527363185
epoch 9 batch id 251 loss 0.22196003794670105 train acc 0.9422310756972112
epoch 9 batch id 301 loss 0.3715178370475769 train acc 0.9424833887043189
epoch 9 batch id 351 loss 0.08383341133594513 train acc 0.9423076923076923
epoch 9 batch id 401 loss 0.5709728598594666 train acc 0.9431109725685786
epoch 9 batch id 451 loss 0.6494795083999634 train acc 0.9440133037694013
epoch 9 batch id 501 loss 0.7883056402206421 train acc 0.9451097804391217
epoch 9 batch id 551 loss 0.11173073947429657 train acc 0.9452132486388385

epoch 9 train acc 0.9456039076376554


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

epoch 10 batch id 1 loss 0.984125554561615 train acc 0.9375
epoch 10 batch id 51 loss 0.16507777571678162 train acc 0.9362745098039216
epoch 10 batch id 101 loss 0.10816280543804169 train acc 0.9436881188118812
epoch 10 batch id 151 loss 0.396756649017334 train acc 0.9466059602649006
epoch 10 batch id 201 loss 0.7647391557693481 train acc 0.9458955223880597
epoch 10 batch id 251 loss 0.159413680434227 train acc 0.9479581673306773
epoch 10 batch id 301 loss 0.27500540018081665 train acc 0.9493355481727574
epoch 10 batch id 351 loss 0.053770072758197784 train acc 0.9513888888888888
epoch 10 batch id 401 loss 0.5284467339515686 train acc 0.9515274314214464
epoch 10 batch id 451 loss 0.3354232609272003 train acc 0.9519124168514412
epoch 10 batch id 501 loss 0.6443986892700195 train acc 0.9518463073852296

epoch 10 train acc 0.9525976909413855


### **6. 예측**

In [21]:
model.load_state_dict(torch.load("/opt/ml/models/0416_koelectra_1.pt"))

model.eval()

predictions = []

for input_ids_batch, attention_masks_batch, y_batch in tqdm(test_loader):
    y_batch = y_batch.to(device)
    y_pred = model(input_ids_batch.to(device), attention_mask=attention_masks_batch.to(device))[0]
    _, predict = torch.max(y_pred, 1)
    predictions.extend(predict.tolist())

HBox(children=(FloatProgress(value=0.0, max=63.0), HTML(value='')))




In [22]:
submission = pd.DataFrame(predictions, columns=['pred'])
submission.to_csv(os.path.join(submission_dir, '0416_submission_2.csv'), index=False)