In [1]:
# For using Mecab
# !curl -s https://raw.githubusercontent.com/teddylee777/machine-learning/master/99-Misc/01-Colab/mecab-colab.sh | bash
# !pip install -U "jpype1<1.1"

In [1]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
import random
import torch
import numpy as np

def fix_seed(seed: int) -> None:
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    
fix_seed(1004)

In [3]:
import torch
import os
from torch.utils.data import Dataset, DataLoader
import re
import pandas as pd
import numpy as np
from collections import Counter, defaultdict
from itertools import combinations
from deep_utils import stratify_train_test_split_multi_label
import copy


class T5NERDataset(Dataset):
    def __init__(self, args, data_name, tokenizer, kfold_idx = None):
        self.args = args
        self.data_name = data_name
        self.max_len = args.max_len
        self.batch_size = args.batch_size
        self.tokenizer = tokenizer
        self.eng_to_kor = {'PS': '사람', 'LC': '위치', 'OG': '기관', 'DT': '날짜', 'TI': '시간', 'QT': '수량'}
        
        if data_name in ['train', 'val']:
            data_path = os.path.join(args.data_path, 'klue_ner_train_80.txt')
        else:
            data_path = os.path.join(args.data_path, 'klue_ner_test_20.txt')
        f = open(data_path)
        self.raw_data = f.readlines()
        self.raw_df = self._prepare_df(self.raw_data)
        self.original_df = copy.deepcopy(self.raw_df)
        # ----- Train : Val Split ------- #
        self.raw_df["id"] = self.raw_df.index
        self.raw_df["y"] = self.raw_df["tags"].apply(lambda x : convert_tags_to_vector(x))
        y = np.array([np.array(ls) for ls in self.raw_df["y"]])
        train_X, test_X, train_y, test_y = stratify_train_test_split_multi_label(self.raw_df["id"], y, test_size=args.val_ratio)
        if data_name == 'train':
            self.raw_df = self.raw_df.loc[train_X]
        if data_name == 'val':
            self.raw_df = self.raw_df.loc[test_X]
        # -------------------------------- #
        
        self._set_input_text()
        print(len(self.raw_df), len(self.input_text))
            
    def __getitem__(self, idx):
        return self._preprocess(self.input_text[idx])

    def __len__(self):
        return len(self.raw_data)

    def _preprocess(self, input_text) :
        encoder_text, decoder_target = input_text
        encoder_inputs = self.tokenizer(encoder_text, max_length = self.args.max_len, padding = "max_length", truncation = True, return_tensors = 'pt')
        decoder_inputs = self.tokenizer(decoder_target, max_length = self.args.max_len, padding = "max_length", truncation = True, return_tensors = 'pt')
        
        src_ids, src_mask = encoder_inputs["input_ids"], encoder_inputs["attention_mask"]
        tgt_ids = decoder_inputs["input_ids"]
        return {'src_ids': src_ids.squeeze(), 'src_mask': src_mask.squeeze(), 'tgt_ids': tgt_ids.squeeze()}

    def _get_special_tokens(self, args) :
        self.sep = self.tokenizer.sep_token
        self.eos = self.tokenizer.eos_token
        self.input_pad_id = self.tokenizer.pad_token_id
        self.target_pad_id = -1004
    
    def _set_input_text(self) :
        ner_inputs, ner_outputs = self._get_ner_inputs()
        if self.data_name in  ['val', 'test']:
            self.input_text = [(enc, dec) for enc, dec in zip(ner_inputs, ner_outputs)]
            return
        ee_inputs, ee_outputs = self._get_ee_inputs()
        et_inputs, et_outputs = self._get_et_inputs()
        pos_inputs, pos_outputs = self._get_pos_inputs()
        hmn_inputs, hmn_outputs = self._get_hmn_inputs()
        inputs, outputs = ner_inputs+ee_inputs+et_inputs+hmn_inputs, ner_outputs+ee_outputs+et_outputs+hmn_outputs
        self.input_text = [(enc, dec) for enc, dec in zip(inputs, outputs)]
    
    def _preprocess_text(self, line):
        entities, tags = [], []
        l = re.findall(r'<[%-=+,#/\?:^.@*\"※~ㆍ!』‘|\(\)\[\]`\'…》\”\“\’·\s0-9ㄱ-ㅣ가-힣A-Za-z]+:[A-Za-z]+>', line)

        for label in l:
            entity, tag = label.replace('<', '').replace('>','').split(':')
            entities.append(entity)
            tags.append(tag)
            line = line.replace('\n', '').replace(label, entity)
    
        return line, entities, tags

    def _prepare_df(self, data):
        preprocessed_text, entities, tags, counts = [], [], [], []

        for line in data:
            line, entity, tag = self._preprocess_text(line)
            preprocessed_text.append(line)
            entities.append(entity)
            tags.append(tag)
            counts.append(len(entity))

        df = pd.DataFrame({"text": data, 'preprocessed_text': preprocessed_text, 'entities': entities, 'tags': tags, 'cnt': counts})
        return df
    
    def _get_ner_inputs(self):
        inputs, outputs = [], []
        entities, tags = list(self.raw_df['entities']), list(self.raw_df['tags'])
        instruction = ' Instruction: Input Sentence에서 찾을 수 있는 모든 Entity 및 그들의 Entity type을 출력하세요.'+\
                        ' 가능한 Entity type은 다음과 같습니다: 사람, 위치, 기관, 날짜, 시간, 수량'

        for i, text in enumerate(list(self.raw_df['preprocessed_text'])):
            input_sentence = 'Sentence: ' + text + instruction
            output_sentence = ''
            kor_tags = [self.eng_to_kor[tag] for tag in tags[i]]
            for idx, entity in enumerate(entities[i]):
                output_sentence += (entity+get_josa(entity, 'ent')+' ')
                if idx == len(entities[i])-1:
                    output_sentence += (kor_tags[idx]+get_josa(kor_tags[idx], 'tag')+'다.')
                else:
                    output_sentence += (kor_tags[idx]+get_josa(kor_tags[idx], 'tag')+'고, ')
            inputs.append(input_sentence)
            outputs.append(output_sentence)
        return inputs, outputs
    
    def _get_ee_inputs(self):
        inputs, outputs = [], []
        entities, tags = list(self.raw_df['entities']), list(self.raw_df['tags'])
        instruction = ' Instruction: Input Sentence에서 Entity에 해당하는 단어를 모두 출력하세요.'

        for i, text in enumerate(list(self.raw_df['preprocessed_text'])):
            input_sentence = 'Sentence: ' + text + instruction
            output_sentence = ', '.join(entities[i]) + '.'
            inputs.append(input_sentence)
            outputs.append(output_sentence)
        return inputs, outputs
    
    def _get_et_inputs(self):
        inputs, outputs = [], []
        entities, tags = list(self.raw_df['entities']), list(self.raw_df['tags'])

        for i, text in enumerate(list(self.raw_df['preprocessed_text'])):
            input_entities = ', '.join(entities[i])
            input_sentence = 'Sentence: ' + text + ' Instruction: Input Sentence에서 <' + input_entities + \
                                '>의 Entity type을 출력하세요. 가능한 Entity type은 다음과 같습니다: 사람, 위치, 기관, 날짜, 시간, 수량'
            output_sentence = ''
            kor_tags = [self.eng_to_kor[tag] for tag in tags[i]]
            for idx, entity in enumerate(entities[i]):
                output_sentence += (entity+get_josa(entity, 'ent')+' ')
                if idx == len(entities[i])-1:
                    output_sentence += (kor_tags[idx]+get_josa(kor_tags[idx], 'tag')+'다.')
                else:
                    output_sentence += (kor_tags[idx]+get_josa(kor_tags[idx], 'tag')+'고, ')
            inputs.append(input_sentence)
            outputs.append(output_sentence)
        return inputs, outputs
    
    def _get_hmn_inputs(self):
        inputs, outputs = [], []
        texts, entities, tags = list(self.original_df['preprocessed_text']), list(self.original_df['entities']), list(self.original_df['tags'])
        instruction = 'Instruction: Sentence1과 Sentence2에서 Entity에 해당하는 단어와 그들의 Entity type을 출력하세요.'
        ent_dict = defaultdict(list)
        for i, entity in enumerate(entities):
            for idx, ent in enumerate(entity):
                ent_dict[ent].append((tags[i][idx], texts[i]))
                
        ent_dict_del_1 = defaultdict(list)
        for key in ent_dict.keys():
            if len(ent_dict[key]) != 1:
                ent_dict_del_1[key] = ent_dict[key]
                
        for key in ent_dict_del_1.keys():
            sents = []
            for item1, item2 in list(combinations(ent_dict_del_1[key], 2)):
                if item1[0] != item2[0]:
                    if item1[1] not in sents and item2[1] not in sents:
                        sents.append(item1[1])
                        inputs.append(f'Sentence1: {item1[1]} Sentence2: {item2[1]} {instruction}')
                        kor_tags = [self.eng_to_kor[tag] for tag in [item1[0], item2[0]]]
                        output_sentence = ''
                        for idx in range(2):
                            output_sentence += (key+get_josa(key, 'ent')+' ')
                            if idx == 1:
                                output_sentence += (kor_tags[idx]+get_josa(kor_tags[idx], 'tag')+'다.')
                            else:
                                output_sentence += (kor_tags[idx]+get_josa(kor_tags[idx], 'tag')+'고, ')
                        outputs.append(output_sentence)
        return inputs, outputs
        

    def _get_pos_inputs(self):
        inputs, outputs = [], []
        entities, pos_tags = list(self.raw_df['entities']), self.get_pos_tags()
        
        instruction = ' Instruction: Input Sentence에서 찾을 수 있는 모든 Entity 및 그들의 품사를 출력하세요.'+\
                        ' 가능한 품사는 다음과 같습니다: 일반명사, 고유명사, 단위명사, 수사, 해당없음'

        for i, text in enumerate(list(self.raw_df['preprocessed_text'])):
            input_sentence = 'Sentence: ' + text + instruction
            output_sentence = ''
            pos_tag = pos_tags[i]
            for idx, entity in enumerate(entities[i]):
                output_sentence += (entity+get_josa(entity, 'ent')+' ')
                if idx == len(entities[i])-1:
                    output_sentence += (pos_tag[idx]+get_josa(pos_tag[idx], 'tag')+'다.')
                else:
                    output_sentence += (pos_tag[idx]+get_josa(pos_tag[idx], 'tag')+'고, ')
            inputs.append(input_sentence)
            outputs.append(output_sentence)
        
        return inputs, outputs
    
    def get_pos_tags(self):
        from konlpy.tag import Mecab
        mecab = Mecab()
        NE_pos_dict = {'NNG':'일반명사', 'NNP':'고유명사', 'NNBC': '단위명사', 'NR': '수사'}
        total_pos_of_entities = []
        for ent in list(self.raw_df['entities']):
            pos_of_entities = []
            for e in ent:
                pos = [pos for tok, pos in mecab.pos(e)]
                candidates = Counter(pos).most_common()
                others = True
                if 'NR' in pos:  
                    pos_of_entities.append(NE_pos_dict['NR'])
                    others = False
                elif 'SN' in pos or 'NNBC' in pos:
                    pos_of_entities.append(NE_pos_dict['NNBC'])
                    others = False
                else:
                    for candidate, _ in candidates:
                        if candidate in list(NE_pos_dict.keys()):
                            pos_of_entities.append(NE_pos_dict[candidate])
                            others = False
                            break
                if others == True:
                    pos_of_entities.append('해당없음')
            total_pos_of_entities.append(pos_of_entities)
        return total_pos_of_entities


def get_josa(s, s_type): 
    NO_JONGSUNG = 'ᴕ'
    CHOSUNGS = ['ㄱ', 'ㄲ', 'ㄴ', 'ㄷ', 'ㄸ', 'ㄹ', 'ㅁ', 'ㅂ', 'ㅃ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅉ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']
    JOONGSUNGS = ['ㅏ', 'ㅐ', 'ㅑ', 'ㅒ', 'ㅓ', 'ㅔ', 'ㅕ', 'ㅖ', 'ㅗ', 'ㅘ', 'ㅙ', 'ㅚ', 'ㅛ', 'ㅜ', 'ㅝ', 'ㅞ', 'ㅟ', 'ㅠ', 'ㅡ', 'ㅢ', 'ㅣ']
    JONGSUNGS = [NO_JONGSUNG,  'ㄱ', 'ㄲ', 'ㄳ', 'ㄴ', 'ㄵ', 'ㄶ', 'ㄷ', 'ㄹ', 'ㄺ', 'ㄻ', 'ㄼ', 'ㄽ', 'ㄾ', 'ㄿ', 'ㅀ', 'ㅁ', 'ㅂ', 'ㅄ', 'ㅅ', 'ㅆ', 'ㅇ', 'ㅈ', 'ㅊ', 'ㅋ', 'ㅌ', 'ㅍ', 'ㅎ']

    N_CHOSUNGS, N_JOONGSUNGS, N_JONGSUNGS = 19, 21, 28
    FIRST_HANGUL, LAST_HANGUL = 0xAC00, 0xD7A3 #'가', '힣'
    
    result = []
    for c in s:
        if ord(c) < FIRST_HANGUL or ord(c) > LAST_HANGUL: # if a character is a hangul
            result.append(c)
        else:            
            code = ord(c) - FIRST_HANGUL
            jongsung_index = code % N_JONGSUNGS
            code //= N_JONGSUNGS
            joongsung_index = code % N_JOONGSUNGS
            code //= N_JOONGSUNGS
            chosung_index = code

            result.append(CHOSUNGS[chosung_index])
            result.append(JOONGSUNGS[joongsung_index])
            result.append(JONGSUNGS[jongsung_index])
    jaso_str = ''.join(result)
    if s_type == 'ent':
        josa = '는' if jaso_str[-1] == 'ᴕ' else '은'  # 종성 없으면 '는' 있으면 '은'
    if s_type == 'tag':
        josa = '' if jaso_str[-1] == 'ᴕ' else '이'  # 종성 없으면 '' 있으면 '이' 
    return josa


def convert_tags_to_vector(_tags):
    labels = [0]*6
    for tag in _tags:
        if tag == 'QT':
            labels[0] += 1
        if tag == 'DT':
            labels[1] += 1
        if tag == 'PS':
            labels[2] += 1
        if tag == 'LC':
            labels[3] += 1
        if tag == 'TI':
            labels[4] += 1
        if tag == 'OG':
            labels[5] += 1
    return labels

In [4]:
import argparse

args = argparse.Namespace(
  val_ratio=0.2,
  data_path='data_learn',
  max_len=128,
  batch_size=64
)

In [5]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
model_name = 'model/kt-ulm-base'
tokenizer = T5Tokenizer.from_pretrained(model_name)

In [6]:
train_dataset = T5NERDataset(args=args, data_name='train', tokenizer=tokenizer)

16761 54795


In [7]:
val_dataset = T5NERDataset(args=args, data_name='val', tokenizer=tokenizer)

4041 4041


In [8]:
test_dataset = T5NERDataset(args=args, data_name='test', tokenizer=tokenizer)

5201 5201


In [9]:
train_params = {
    'batch_size': args.batch_size,
    'shuffle': True,
    'num_workers': 2
    }

val_params = {
    'batch_size': args.batch_size,
    'shuffle': False,
    'num_workers': 2
    }

train_loader = DataLoader(train_dataset, **train_params)
val_loader = DataLoader(val_dataset, **val_params)
test_loader = DataLoader(test_dataset, **val_params)

In [10]:
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'

In [11]:
import transformers
model = T5ForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)
optimizer = torch.optim.AdamW(params =  model.parameters(), lr=1e-5)

In [12]:
from tqdm.notebook import tqdm


def train(epoch, tokenizer, model, device, loader, optimizer):
    model.train()
    train_total_loss = 0
    for idx, data in enumerate(tqdm(loader, 0)):
        tgt = data['tgt_ids'].to(device, dtype = torch.long)
        tgt_ids = tgt[:, :-1].contiguous()  # eos token 제외
        lm_labels = tgt[:, 1:].clone().detach()  # start token(</s>) 제외
        lm_labels[tgt[:, 1:] == tokenizer.pad_token_id] = -100
        src_ids = data['src_ids'].to(device, dtype = torch.long)
        src_mask = data['src_mask'].to(device, dtype = torch.long)

        outputs = model(input_ids=src_ids, attention_mask=src_mask, decoder_input_ids=tgt_ids, labels=lm_labels)
        loss = outputs[0]
        train_total_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    train_mean_loss = train_total_loss / (len(loader)*args.batch_size)
    wandb.log({"Epoch": epoch, "Train Loss": train_mean_loss})

In [13]:
def validate(tokenizer, model, device, loader):
    model.eval()
    predictions = []
    actuals = []
    with torch.no_grad():
        for _, data in enumerate(tqdm(loader, 0)):
            tgt_ids = data['tgt_ids'].to(device, dtype = torch.long)
            src_ids = data['src_ids'].to(device, dtype = torch.long)
            src_mask = data['src_mask'].to(device, dtype = torch.long)

            generated_ids = model.generate(
                input_ids = src_ids,
                attention_mask = src_mask, 
                max_length=128, 
                num_beams=3,
                repetition_penalty=2.5,
                length_penalty=1.0, 
                early_stopping=True
                )
            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
            target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)for t in tgt_ids]

            predictions.extend(preds)
            actuals.extend(target)
            
    return predictions, actuals

In [14]:
import wandb

wandb.init(project='KT', entity="ohsuz", name='InstructionNER')

[34m[1mwandb[0m: Currently logged in as: [33mohsuz[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [15]:
from tqdm.notebook import tqdm

predictions_per_epoch = {}
actuals_per_epoch = {}

for epoch in range(50):
    train(epoch, tokenizer, model, device, train_loader, optimizer)
    # if (epoch + 1) >= 30:
    #     if (epoch + 1) % 10 == 0:
    #         predictions, actuals = validate(tokenizer, model, device, val_loader)
    #         predictions_per_epoch[epoch+1] = predictions
    #         actuals_per_epoch[epoch+1] = actuals
    #         epoch_df = pd.DataFrame({'Generated Text':predictions,'Actual Text':actuals})
    #         epoch_df.to_csv(f'./baseline_mhn_e{epoch+1}.csv', index=False)
    #         torch.save(model.state_dict(), f'baseline_mhn_e{epoch+1}.pt')
    #         # epoch_df.to_csv(f'./val_e{epoch+1}.csv', index=False)
    #         # torch.save(model.state_dict(), f'model_e{epoch+1}.pt')
    print(f'epoch {epoch+1} done')

  0%|          | 0/326 [00:00<?, ?it/s]

epoch 1 done


  0%|          | 0/326 [00:00<?, ?it/s]

epoch 2 done


  0%|          | 0/326 [00:00<?, ?it/s]

epoch 3 done


  0%|          | 0/326 [00:00<?, ?it/s]

epoch 4 done


  0%|          | 0/326 [00:00<?, ?it/s]

epoch 5 done


  0%|          | 0/326 [00:00<?, ?it/s]

epoch 6 done


  0%|          | 0/326 [00:00<?, ?it/s]

epoch 7 done


  0%|          | 0/326 [00:00<?, ?it/s]

epoch 8 done


  0%|          | 0/326 [00:00<?, ?it/s]

epoch 9 done


  0%|          | 0/326 [00:00<?, ?it/s]

epoch 10 done


In [None]:
predictions, actuals = validate(tokenizer, model, device, test_loader)
final_df = pd.DataFrame({'Generated Text':predictions, 'Actual Text':actuals})
final_df.to_csv('./output/submission.csv', index=False)
torch.save(model.state_dict(), f'./output/model.pt')

  0%|          | 0/82 [00:00<?, ?it/s]