In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
from KoBERT.kobert.pytorch_kobert import get_pytorch_kobert_model
from kobert.utils import get_tokenizer
import numpy as np
from tqdm import tqdm
import pandas as pd
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import AdamW
from seqeval.metrics import f1_score
from models.layers.linears import PoolerEndLogits, PoolerStartLogits
from losses.focal_loss import FocalLoss
from losses.label_smoothing import LabelSmoothingCrossEntropy
from transformers import get_linear_schedule_with_warmup
from processors.utils_ner import bert_extract_item as bert_extract_item_pred
from metrics.ner_metrics import SpanEntityScore
from tools.common import logger

In [2]:
from processors.utils_ner import DataProcessor, get_entities
from processors.ner_span import InputExample, InputFeature
from torch.utils.data import TensorDataset
from processors.ner_span import convert_examples_to_features, CnerProcessor
from seqeval.metrics import precision_score, recall_score, f1_score, accuracy_score, classification_report
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import logging
import os
import copy
import json
from torchcrf import CRF

In [3]:
  label_list =['O','PDT','MOV','TRV']
  id2label = {i: label for i, label in enumerate(label_list)}
  metric = SpanEntityScore(id2label)

  start_answer = tf.math.argmax(start_pred,-1) # b * seq
  end_answer = tf.math.argmax(end_pred,-1)

  active_loss = attention_mask==1

  for i in range(active_loss.shape[0]):
    active_start_pred = tf.boolean_mask(start_answer[i], active_loss[i])
    active_end_pred = tf.boolean_mask(end_answer[i], active_loss[i])
    R = bert_extract_item(active_start_pred, active_end_pred)
    T = bert_extract_item(start_label[i], end_label[i])
    metric.update(true_subject=T, pred_subject=R)
  eval_info, entity_info = metric.result()

dd


In [4]:
epochs = 5
num = 4
max_len = 256

In [5]:
bertmodel, vocab = get_pytorch_kobert_model() # KoBERT 모델 불러오기
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

using cached model
using cached model
using cached model


In [6]:
class BertSpanForNer(nn.Module):
    def __init__(self, bert, hidden_size=768, num_classes=num, dr_rate=0.3, params=None):
        super(BertSpanForNer, self).__init__()
        self.soft_label = True
        self.num_labels = num_classes
        # loss_type = ['lsr', 'focal', 'ce']
        self.loss_type = 'ce'
        self.bert = bert
        self.dropout = nn.Dropout(dr_rate)
        self.start_fc = PoolerStartLogits(hidden_size, self.num_labels)
        if self.soft_label:
            self.end_fc = PoolerEndLogits(hidden_size + self.num_labels, self.num_labels)
        else:
            self.end_fc = PoolerEndLogits(hidden_size + 1, self.num_labels)

    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()
        
    def forward(self, input_ids, valid_length, token_type_ids=None, start_positions=None, end_positions=None):
        attention_mask = self.gen_attention_mask(input_ids, valid_length)
        outputs = self.bert(input_ids=input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask.float().to(input_ids.device))
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        start_logits = self.start_fc(sequence_output)
        
        if start_positions is not None and self.training:
            if self.soft_label:
                batch_size = input_ids.size(0)
                seq_len = input_ids.size(1)
                label_logits = torch.FloatTensor(batch_size, seq_len, self.num_labels)
                label_logits.zero_()
                label_logits = label_logits.to(input_ids.device)
                label_logits.scatter_(2, start_positions.unsqueeze(2), 1)
            else:
                label_logits = start_positions.unsqueeze(2).float()
        else:
            label_logits = F.softmax(start_logits, -1)
            if not self.soft_label:
                label_logits = torch.argmax(label_logits, -1).unsqueeze(2).float()
                
        end_logits = self.end_fc(sequence_output, label_logits)
        outputs = (start_logits, end_logits,) + outputs[2:]

        if start_positions is not None and end_positions is not None:
            assert self.loss_type in ['lsr', 'focal', 'ce']
            if self.loss_type =='lsr':
                loss_fct = LabelSmoothingCrossEntropy()
            elif self.loss_type == 'focal':
                loss_fct = FocalLoss()
            else:
                loss_fct = CrossEntropyLoss()
            start_logits = start_logits.view(-1, self.num_labels)
            end_logits = end_logits.view(-1, self.num_labels)
            active_loss = attention_mask.view(-1) == 1
            active_start_logits = start_logits[active_loss]
            active_end_logits = end_logits[active_loss]

            active_start_labels = start_positions.view(-1)[active_loss]
            active_end_labels = end_positions.view(-1)[active_loss]

            start_loss = loss_fct(active_start_logits, active_start_labels)
            end_loss = loss_fct(active_end_logits, active_end_labels)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs

In [7]:
label_dict = {'O': 0,
 'UNK':1,
 'PDT-B': 2,
 'PDT-I': 3,
 'MOV-B': 4,
 'MOV-I': 5,
 'TRV-B': 6,
 'TRV-I': 7}

In [8]:
train_list_csv = pd.read_csv('train_data.tsv',delimiter='\t')
test_list_csv = pd.read_csv('test_data.tsv',delimiter='\t')

train_list_csv = train_list_csv.dropna(axis=0)
test_list_csv = test_list_csv.dropna(axis=0)

train_list_csv = train_list_csv.reset_index(drop=True)
test_list_csv = test_list_csv.reset_index(drop=True)

for i in range(len(train_list_csv['label'])):
    train_list_csv['label'][i] = train_list_csv['label'][i][1:-1].replace('\'','').replace(' ','').split(",")
for i in range(len(test_list_csv['label'])):
    test_list_csv['label'][i] = test_list_csv['label'][i][1:-1].replace('\'','').replace(' ','').split(",")


for i in range(len(train_list_csv['label'])):
    text_split = train_list_csv['text'][i].split()
    ex_label = []
    for j in range(len(train_list_csv['label'][i])):
        ex_label.append(label_dict[train_list_csv['label'][i][j]])
    train_list_csv['label'][i]=ex_label

for i in range(len(test_list_csv['label'])):
    text_split = test_list_csv['text'][i].split()
    ex_label = []
    for j in range(len(test_list_csv['label'][i])):
        ex_label.append(label_dict[test_list_csv['label'][i][j]])
    test_list_csv['label'][i]=ex_label

tr_tag = train_list_csv['label']
tr_sent = train_list_csv['text']
ts_tag = test_list_csv['label']
ts_sent = test_list_csv['text']

len(tr_tag), len(tr_sent), len(ts_tag), len(ts_sent)

(18390, 18390, 4599, 4599)

In [9]:
for i in range(len(tr_tag)):
    if len(tr_tag[i]) > max_len or len(tr_sent[i]) > max_len: # 문장의 길이가 512가 넘는 문장 제거
        del tr_sent[i]
        del tr_tag[i]

for i in range(len(ts_tag)):
    if len(ts_tag[i]) > max_len or len(ts_sent[i]) > max_len: # 문장의 길이가 512가 넘는 문장 제거
        del ts_sent[i]
        del ts_tag[i]

tr_sent = tr_sent.reset_index(drop=True)
tr_tag = tr_tag.reset_index(drop=True)

ts_sent = ts_sent.reset_index(drop=True)
ts_tag = ts_tag.reset_index(drop=True)

len(tr_tag), len(tr_sent), len(ts_tag), len(ts_sent)

(18029, 18029, 4514, 4514)

In [10]:
def make_label_token(sent, label):
    label_list = []
    sent_split = sent.split()
    for i in range(len(sent_split)):
        sent_tok = tok(sent_split[i])
        if label[i] !=0 and label[i]%2==0:
            label_list.append(label[i])
            for j in sent_tok[1:]:
                label_list.append(label[i]+1)
        else: 
            for j in sent_tok:
                label_list.append(label[i])
    return label_list

In [11]:
tr_label = []
ts_label = []

for i,j in zip(tr_sent, tr_tag):
    tr_label.append(make_label_token(i,j))

for i,j in zip(ts_sent, ts_tag):
    ts_label.append(make_label_token(i,j))

len(tr_label), len(tr_sent), len(ts_label), len(ts_sent)

(18029, 18029, 4514, 4514)

In [12]:
def start_end_pos(label):
    new = []
    ex = []
    result = []
    for i in range(len(label)):
        if label[i]%2==0 and label[i]!=0:
            if len(new) != 0:
                ex.append(new)
            new = []
            new.extend([int(label[i]/2),i])
        elif label[i]%2==1:
            new.append(i) 
            
    ex.append(new) 
    for i in ex:
        result.append([i[1],i[-1], i[0]])
    return result

def start_end_make(label):

    start_label = [0]*len(label)
    end_label = [0]*len(label)

    if 2 in label or 4 in label or 6 in label:
        result = start_end_pos (label)
        for start, end, tag in result:
            start_label[start] =  tag
            end_label[end] = tag

    
    start_label.insert(0,0)
    end_label.insert(0,0)
    return [start_label, end_label]

In [13]:
for i in range(len(tr_label)):
    tr_label[i] = start_end_make(tr_label[i])

for i in range(len(ts_label)):
    ts_label[i] = start_end_make(ts_label[i])

In [14]:
for i in range(len(tr_label)):
    padding_length = max_len - len(tr_label[i][0])
    tr_label[i][0] = np.array(tr_label[i][0] + ([0] * padding_length))
    tr_label[i][0] =tr_label[i][0].astype(np.int64) 
    tr_label[i][1] = np.array(tr_label[i][1] + ([0] * padding_length))
    tr_label[i][1] =tr_label[i][1].astype(np.int64) 

for i in range(len(ts_label)):
    padding_length = max_len - len(ts_label[i][0])
    ts_label[i][0] = np.array(ts_label[i][0] + ([0] * padding_length))
    ts_label[i][0] =ts_label[i][0].astype(np.int64) 
    ts_label[i][1] = np.array(ts_label[i][1] + ([0] * padding_length))
    ts_label[i][1] =ts_label[i][1].astype(np.int64) 

len(tr_label), len(tr_sent), len(ts_label), len(ts_sent)

(18029, 18029, 4514, 4514)

In [15]:
class BERTDataset(Dataset):
    def __init__(self, sent, tag, bert_tokenizer, max_len,
                pad, pair):
        transform = nlp.data.BERTSentenceTransform(bert_tokenizer, max_len, pad=pad, pair=pair)
        self.sentences = [transform([i]) for i in sent] #문장
        self.labels = tag

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i][0], )+ (self.labels[i][1], ))

    def __len__(self):
        return (len(self.labels))

In [16]:
train_data = pd.DataFrame([tr_sent.tolist(),tr_label]).T
test_data =  pd.DataFrame([ts_sent.tolist(),ts_label]).T

In [17]:
data_train = BERTDataset(tr_sent,tr_label, tok, max_len, True, False)
data_test = BERTDataset(ts_sent,ts_label, tok, max_len, True, False)

In [18]:
train_loader = DataLoader(data_train, batch_size=16, shuffle=True)
test_loader = DataLoader(data_test, batch_size=16, shuffle=False)

In [19]:
device = torch.device('cuda:0')
model = BertSpanForNer(bertmodel)
model.to(device)

BertSpanForNer(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(8002, 768, padding_idx=1)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True

In [20]:
t_total = len(train_loader)
weight_decay = 0.01
warmup_proportion = 0.1
learning_rate = 5e-5
adam_epsilon = 1e-8

In [21]:
no_decay = ["bias", "LayerNorm.weight"]
bert_parameters = model.bert.named_parameters()
start_parameters = model.start_fc.named_parameters()
end_parameters = model.end_fc.named_parameters()
optimizer_grouped_parameters = [
        {"params": [p for n, p in bert_parameters if not any(nd in n for nd in no_decay)],
         "weight_decay": weight_decay, 'lr': learning_rate},
        {"params": [p for n, p in bert_parameters if any(nd in n for nd in no_decay)], "weight_decay": 0.0
            , 'lr': learning_rate},

        {"params": [p for n, p in start_parameters if not any(nd in n for nd in no_decay)],
         "weight_decay": weight_decay, 'lr': 0.001},
        {"params": [p for n, p in start_parameters if any(nd in n for nd in no_decay)], "weight_decay": 0.0
            , 'lr': 0.001},

        {"params": [p for n, p in end_parameters if not any(nd in n for nd in no_decay)],
         "weight_decay": weight_decay, 'lr': 0.001},
        {"params": [p for n, p in end_parameters if any(nd in n for nd in no_decay)], "weight_decay": 0.0
            , 'lr': 0.001},
]

warmup_steps = int(t_total * warmup_proportion)
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                num_training_steps=t_total)

In [22]:
optimizer = AdamW(model.parameters(), lr=learning_rate)

In [23]:
label_list =['O','PDT','MOV','TRV']
id2label = {i: label for i, label in enumerate(label_list)}
metric = SpanEntityScore(id2label)

In [24]:
def bert_extract_item_true(start_logits, end_logits):
    S = []    
    for i, s_l in enumerate(start_logits):
        if s_l == 0:
            continue
        for j, e_l in enumerate(end_logits[i:]):
            if s_l == e_l:
                S.append((s_l, i, i + j))
                break
    return S

def score_result(y_start_pred, y_end_pred, y_start_true, y_end_true, y_len, loss):
    metric = SpanEntityScore(id2label)
        
    for i in range(len(y_start_pred)):
        start_logits = torch.tensor(np.array([y_start_pred[i][:y_len[i]].tolist()]), dtype=torch.long)
        end_logits = torch.tensor(np.array([ y_end_pred[i][:y_len[i]].tolist()]), dtype=torch.long)
        start_pos_batch = np.array(y_start_true[i][:y_len[i]].tolist())[1:-1]
        end_pos_batch = np.array(y_end_true[i][:y_len[i]].tolist())[1:-1]

        R = bert_extract_item_pred(start_logits, end_logits)
        T = bert_extract_item_true(start_pos_batch, end_pos_batch)


        metric.update(true_subject=T, pred_subject=R)

    print("\n")
    eval_info, entity_info = metric.result()

    results = {f'{key}': value for key, value in eval_info.items()}
    results['loss'] = loss
    print("***** Eval results *****")
    info = "-".join([f' {key}: {value:.4f} ' for key, value in results.items()])
    print(info)
    print("***** Entity results *****")
    for key in sorted(entity_info.keys()):
        print("******* %s results ********" % key)
        info = "-".join([f' {key}: {value:.4f} ' for key, value in entity_info[key].items()])
        print(info)
    return results

In [24]:
max = 0.0
model.zero_grad()

for i in range(epochs):
  print("epoch: "+str(i+1)+"/"+str(epochs))
  model.train()
  y_start_pred = []
  y_end_pred = []
  y_start_true = []
  y_end_true = []
  y_len = []
  batch_step = 0


  for input_ids_batch, valid_length_batch, segment_ids_batch, start_pos_batch, end_pos_batch in tqdm(train_loader):
    optimizer.zero_grad()
    loss, start_logits, end_logits = model(input_ids_batch.to(device), valid_length_batch.to(device), segment_ids_batch.to(device), start_pos_batch.to(device), end_pos_batch.to(device))
    loss.backward()

    total_loss = loss.item()
    batch_step +=1
    optimizer.step()
    #scheduler.step()  # Update learning rate schedul
    model.zero_grad()
    torch.cuda.empty_cache()

    y_start_pred.extend(start_logits)
    y_end_pred.extend(end_logits)
    y_start_true.extend(start_pos_batch)
    y_end_true.extend(end_pos_batch)
    y_len.extend(valid_length_batch)

  loss_train = total_loss / batch_step
  
  result_train = score_result(y_start_pred, y_end_pred, y_start_true, y_end_true, y_len, loss_train)

  
torch.save(model,"kobert_span")

epoch: 1/5


100%|██████████| 1127/1127 [07:28<00:00,  2.51it/s]




***** Eval results *****
 acc: 0.6254 - recall: 0.1627 - f1: 0.2582 - loss: 0.0000 
***** Entity results *****
******* MOV results ********
 acc: 0.4787 - recall: 0.0518 - f1: 0.0934 
******* PDT results ********
 acc: 0.6467 - recall: 0.2457 - f1: 0.3561 
******* TRV results ********
 acc: 0.6591 - recall: 0.0208 - f1: 0.0404 


100%|██████████| 283/283 [00:30<00:00,  9.35it/s]




***** Eval results *****
 acc: 0.6337 - recall: 0.4457 - f1: 0.5234 - loss: 0.0004 
***** Entity results *****
******* MOV results ********
 acc: 0.5182 - recall: 0.3950 - f1: 0.4483 
******* PDT results ********
 acc: 0.6984 - recall: 0.5224 - f1: 0.5977 
******* TRV results ********
 acc: 0.4865 - recall: 0.1101 - f1: 0.1796 
new f1 score
epoch: 2/5


100%|██████████| 1127/1127 [07:20<00:00,  2.56it/s]




***** Eval results *****
 acc: 0.6881 - recall: 0.4022 - f1: 0.5076 - loss: 0.0000 
***** Entity results *****
******* MOV results ********
 acc: 0.5571 - recall: 0.2571 - f1: 0.3518 
******* PDT results ********
 acc: 0.7376 - recall: 0.5253 - f1: 0.6136 
******* TRV results ********
 acc: 0.5955 - recall: 0.1342 - f1: 0.2191 


100%|██████████| 283/283 [00:29<00:00,  9.69it/s]




***** Eval results *****
 acc: 0.6370 - recall: 0.5125 - f1: 0.5680 - loss: 0.0003 
***** Entity results *****
******* MOV results ********
 acc: 0.6398 - recall: 0.2742 - f1: 0.3839 
******* PDT results ********
 acc: 0.6433 - recall: 0.6821 - f1: 0.6621 
******* TRV results ********
 acc: 0.5000 - recall: 0.1713 - f1: 0.2551 
new f1 score
epoch: 3/5


100%|██████████| 1127/1127 [07:25<00:00,  2.53it/s]




***** Eval results *****
 acc: 0.7635 - recall: 0.5544 - f1: 0.6424 - loss: 0.0000 
***** Entity results *****
******* MOV results ********
 acc: 0.6645 - recall: 0.4173 - f1: 0.5126 
******* PDT results ********
 acc: 0.8126 - recall: 0.6608 - f1: 0.7289 
******* TRV results ********
 acc: 0.6836 - recall: 0.3582 - f1: 0.4701 


100%|██████████| 283/283 [00:29<00:00,  9.75it/s]




***** Eval results *****
 acc: 0.6895 - recall: 0.4729 - f1: 0.5610 - loss: 0.0003 
***** Entity results *****
******* MOV results ********
 acc: 0.6369 - recall: 0.3145 - f1: 0.4211 
******* PDT results ********
 acc: 0.7391 - recall: 0.5831 - f1: 0.6519 
******* TRV results ********
 acc: 0.4175 - recall: 0.2630 - f1: 0.3227 
epoch: 4/5


100%|██████████| 1127/1127 [07:15<00:00,  2.59it/s]




***** Eval results *****
 acc: 0.8319 - recall: 0.6765 - f1: 0.7462 - loss: 0.0000 
***** Entity results *****
******* MOV results ********
 acc: 0.7587 - recall: 0.5533 - f1: 0.6400 
******* PDT results ********
 acc: 0.8701 - recall: 0.7604 - f1: 0.8115 
******* TRV results ********
 acc: 0.7890 - recall: 0.5664 - f1: 0.6594 


100%|██████████| 283/283 [00:30<00:00,  9.21it/s]




***** Eval results *****
 acc: 0.6394 - recall: 0.5087 - f1: 0.5666 - loss: 0.0003 
***** Entity results *****
******* MOV results ********
 acc: 0.5601 - recall: 0.3845 - f1: 0.4559 
******* PDT results ********
 acc: 0.7186 - recall: 0.6039 - f1: 0.6563 
******* TRV results ********
 acc: 0.3369 - recall: 0.2875 - f1: 0.3102 
epoch: 5/5


100%|██████████| 1127/1127 [07:28<00:00,  2.51it/s]




***** Eval results *****
 acc: 0.8803 - recall: 0.7573 - f1: 0.8142 - loss: 0.0000 
***** Entity results *****
******* MOV results ********
 acc: 0.8244 - recall: 0.6513 - f1: 0.7277 
******* PDT results ********
 acc: 0.9125 - recall: 0.8227 - f1: 0.8653 
******* TRV results ********
 acc: 0.8399 - recall: 0.7006 - f1: 0.7640 


100%|██████████| 283/283 [00:30<00:00,  9.39it/s]




***** Eval results *****
 acc: 0.5740 - recall: 0.5729 - f1: 0.5734 - loss: 0.0002 
***** Entity results *****
******* MOV results ********
 acc: 0.5032 - recall: 0.4516 - f1: 0.4760 
******* PDT results ********
 acc: 0.6551 - recall: 0.6594 - f1: 0.6573 
******* TRV results ********
 acc: 0.3148 - recall: 0.3976 - f1: 0.3514 
new f1 score


In [25]:
model = torch.load("kobert_span")

In [26]:
model.eval()

y_start_pred_test = []
y_end_pred_test = []
y_start_true_test = []
y_end_true_test = []
y_len_test = []
batch_step_test = 0

with torch.no_grad():
    for input_ids_batch_test, valid_length_batch_test, segment_ids_batch_test, start_pos_batch_test, end_pos_batch_test in tqdm(test_loader):
        loss, start_logits_test, end_logits_test = model(input_ids_batch_test.to(device), valid_length_batch_test.to(device), segment_ids_batch_test.to(device), start_pos_batch_test.to(device), end_pos_batch_test.to(device))

        total_loss_test = loss.item()
        batch_step_test +=1

        y_start_pred_test.extend(start_logits_test)
        y_end_pred_test.extend(end_logits_test)
        y_start_true_test.extend(start_pos_batch_test)
        y_end_true_test.extend(end_pos_batch_test)
        y_len_test.extend(valid_length_batch_test)

loss_test = total_loss_test / batch_step_test
score_result(y_start_pred_test, y_end_pred_test, y_start_true_test, y_end_true_test, y_len_test, loss_test)

100%|██████████| 283/283 [00:32<00:00,  8.81it/s]




***** Eval results *****
 acc: 0.6116 - recall: 0.5428 - f1: 0.5752 - loss: 0.0005 
***** Entity results *****
******* MOV results ********
 acc: 0.5033 - recall: 0.4353 - f1: 0.4668 
******* PDT results ********
 acc: 0.6639 - recall: 0.6467 - f1: 0.6552 
******* TRV results ********
 acc: 0.5308 - recall: 0.2110 - f1: 0.3020 


{'acc': 0.6116129032258064,
 'recall': 0.5427998854852563,
 'f1': 0.5751554679205219,
 'loss': 0.0004545488121652772}