In [41]:
from data_loader import AugmentDataSet
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForMaskedLM
import tqdm
import torch
from torch.utils.data import DataLoader
import copy
from typing import Union

In [42]:
input_file = "../data/train.csv"

input_df = pd.read_csv(input_file)
text_list = input_df["text"].tolist()

In [43]:
tokenizer = AutoTokenizer.from_pretrained("monologg/koelectra-base-v3-generator")

In [44]:
# train_df = pd.read_csv(input_file)
# train_df, valid_df = train_test_split(train_df, test_size=0.9, random_state=42)

# train_df.to_csv("../data/train_10.csv", index=False)
# valid_df.to_csv("../data/valid_90.csv", index=False)

In [45]:
dataset = AugmentDataSet(text_list, tokenizer)

100%|██████████| 7000/7000 [00:01<00:00, 5620.52it/s]


In [46]:
print(dataset[0])

(tensor([    2, 22780,  4097,  9971,  7445,  7697,  4219, 22780, 24304, 10643,
        28911,  4025,     3,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [47]:
def mask_tokens(tokenizer, input_ids:torch.Tensor, mlm_prob:float=0.15, do_rep_random:bool=True):
    '''
        Copied from huggingface/transformers/data/data_collator - torch.mask_tokens()
        Prepare masked tokens inputs/labels for masked language modeling
        if do_rep_random is True:
            80% MASK, 10% random, 10% original
        else:
            100% MASK
    '''
    labels = input_ids.clone()

    probability_matrix = torch.full(labels.shape, mlm_prob)
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value = 0.0)
    if tokenizer._pad_token is not None:
        padding_mask = labels.eq(tokenizer.pad_token_id)
        probability_matrix.masked_fill_(padding_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100 # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    mask_rep_prob = 0.8
    if not do_rep_random:
        mask_rep_prob = 1.0
    
    indices_replaced = torch.bernoulli(torch.full(labels.shape, mask_rep_prob)).bool() & masked_indices
    input_ids[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    if do_rep_random:
        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
        input_ids[indices_random] = random_words[indices_random]

    return input_ids, labels

def candidate_filtering(tokenizer:AutoTokenizer,
                        input_ids:list,
                        idx:int,
                        org:int,
                        candidates:Union[list, torch.Tensor]) -> int:
    '''
    후보 필터링 조건에 만족하는 최적의 후보 선택
    1. 원래 토큰과 후보 토큰이 같은 타입(is_same_token_type 참고)
    2. 현 위치 앞 혹은 뒤에 동일한 토큰이 있지 않음
    '''

    org_token = tokenizer.convert_ids_to_tokens([org])[0]
    candidate_tokens = tokenizer.convert_ids_to_tokens(candidates.cpu().tolist())

    for rank, token in enumerate(candidate_tokens):
        if org_token!=token and is_same_token_type(org_token, token):
            if input_ids[idx-1]==candidates[rank] or input_ids[idx+1]==candidate_tokens[rank]:
                continue
            return candidates[rank]

    return org

def is_same_token_type(org_token:str, candidate:str) -> bool:
    '''
    후보 필터링 조건을 만족하는지 확인
    - 후보와 원 토큰의 타입을 문장부호와 일반 토큰으로 나누어 같은 타입에 속하는지 확인
    '''
    res = False
    if org_token[0]=="#" and org_token[2:].isalpha()==candidate.isalpha():
        res = True
    elif candidate[0]=="#" and org_token.isalpha()==candidate[2:].isalpha():
        res = True
    elif candidate[0]=="#" and org_token[0]=="#" and org_token[2:].isalpha()==candidate[2:].isalpha():
        res = True
    elif org_token.isalpha()==candidate.isalpha() and (candidate[0]!="#" and org_token[0]!="#"):
        res = True

    return res

In [48]:
def batch_augment(model:AutoModelForMaskedLM,
                tokenizer:AutoTokenizer,
                dataset:torch.utils.data.Dataset,
                k, threshold, mlm_prob, batch_size) -> str:
    '''
    배치 단위의 문장에 랜덤으로 마스킹을 적용하여 새로운 문장 배치를 생성(증강)

    args:
        model(AutoModelForMaskedLM)
        tokenizer(AutoTokenizer)
        dataset(torch.utils.data.Dataset)
        dev(str or torch.device)
        args(argparse.Namespace)
            - k(int, default=5)
            - threshold(float, default=0.95)
           -  mlm_prob(float, default=0.15)
        
    return:
        (list) : 증강한 문장들의 리스트
    '''
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()

    augmented_res = []
    dataloader = DataLoader(dataset, batch_size = batch_size)
    for batch in tqdm.tqdm(dataloader):
        #########################################################
        # 인풋 문장에 랜덤으로 마스킹 적용
        input_ids, attention_masks = batch[0], batch[1]
        masked_input_ids, _ = mask_tokens(tokenizer, input_ids, mlm_prob, do_rep_random=False)

        masked_input_ids = masked_input_ids.to(dev)
        attention_masks = attention_masks.to(dev)
        labels = input_ids
        #########################################################

        with torch.no_grad():
            output = model(masked_input_ids, attention_mask = attention_masks)
            logits1 = output["logits"]

        #########################################################
        # 배치 내의 문장 별로 후보 필터링을 적용하고, 결과를 토대로 새로운 문장 생성
        augmented1 = []
        for sent_no in range(len(masked_input_ids)):
            copied = copy.deepcopy(input_ids.cpu().tolist()[sent_no])

            for i in range(len(masked_input_ids[sent_no])):
                if masked_input_ids[sent_no][i] == tokenizer.pad_token_id:
                    break

                if masked_input_ids[sent_no][i] == tokenizer.mask_token_id:
                    org_token = labels.cpu().tolist()[sent_no][i]
                    prob = logits1[sent_no][i].softmax(dim=0)
                    probability, candidates = prob.topk(k)
                    if probability[0]<threshold:
                        res = candidate_filtering(tokenizer, copied, i, org_token, candidates)
                    else:
                        res = candidates[0]
                    copied[i] = res

            copied = tokenizer.decode(copied, skip_special_tokens=True)
            augmented1.append(copied)
        #########################################################
        augmented_res.extend(augmented1)

    return augmented_res

In [49]:
model = AutoModelForMaskedLM.from_pretrained("monologg/koelectra-base-v3-generator")
model.to("cuda" if torch.cuda.is_available() else "cpu")

ElectraForMaskedLM(
  (electra): ElectraModel(
    (embeddings): ElectraEmbeddings(
      (word_embeddings): Embedding(35000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (embeddings_project): Linear(in_features=768, out_features=256, bias=True)
    (encoder): ElectraEncoder(
      (layer): ModuleList(
        (0): ElectraLayer(
          (attention): ElectraAttention(
            (self): ElectraSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ElectraSelfOutput(
              (dense): Linear(in_features=256, 

In [50]:
augmented = batch_augment(model, tokenizer, dataset, 5, 0.95, 0.15, 1)

100%|██████████| 7000/7000 [01:15<00:00, 92.17it/s]


In [52]:
print(input_df)

                       ID                              text  target  \
0     ynat-v1_train_00000         개포2단지 분양 앞두고 개포지구 재건축 불붙어       1   
1     ynat-v1_train_00001         삼성전자 KBIS 2018서 셰프컬렉션 선보여       0   
2     ynat-v1_train_00002           LG G6 사면 BO 이어폰이 단돈 5천원       0   
3     ynat-v1_train_00003            신간 블록체인혁명 2030·남자의 고독사       3   
4     ynat-v1_train_00004    이스라엘 정보당국 팔레스타인인 50명 테러 혐의로 체포       4   
...                   ...                               ...     ...   
6995  ynat-v1_train_06995    힐만 SK 감독 고통스럽지만 내 상황 솔직히 알려야 해       5   
6996  ynat-v1_train_06996    정의장 사드 국회동의 사안 아니라 쳐도 충분히 협의해야       6   
6997  ynat-v1_train_06997          정치권 엘시티 수사 돌발변수에 촉각…왜 지금       6   
6998  ynat-v1_train_06998   문 대통령 1987 관람…깜짝 방문에 객석 환호·박수종합       6   
6999  ynat-v1_train_06999  120년 전 대한제국으로…가을밤 정동에서 시간 여행 떠나다       3   

                                                    url                  date  
0     https://news.naver.com/main/read.nhn?mode=LS2D...  2016.03.16

In [53]:
aug_id_prefix = "aug_"
aug_url = "mlm_augment"
aug_date = "20240130"

augmented_df = pd.DataFrame({"id": [aug_id_prefix+str(i) for i in range(len(augmented))],
                            "text": augmented,
                            "target": input_df["target"].tolist(),
                            "url": [aug_url for i in range(len(augmented))],
                            "date": [aug_date for i in range(len(augmented))]})

In [54]:
input_df.head()

Unnamed: 0,ID,text,target,url,date
0,ynat-v1_train_00000,개포2단지 분양 앞두고 개포지구 재건축 불붙어,1,https://news.naver.com/main/read.nhn?mode=LS2D...,2016.03.16. 오전 11:37
1,ynat-v1_train_00001,삼성전자 KBIS 2018서 셰프컬렉션 선보여,0,https://news.naver.com/main/read.nhn?mode=LS2D...,2018.01.10. 오전 8:33
2,ynat-v1_train_00002,LG G6 사면 BO 이어폰이 단돈 5천원,0,https://news.naver.com/main/read.nhn?mode=LS2D...,2017.04.30. 오전 10:00
3,ynat-v1_train_00003,신간 블록체인혁명 2030·남자의 고독사,3,https://news.naver.com/main/read.nhn?mode=LS2D...,2019.06.13. 오전 11:49
4,ynat-v1_train_00004,이스라엘 정보당국 팔레스타인인 50명 테러 혐의로 체포,4,https://news.naver.com/main/read.nhn?mode=LS2D...,2019.12.18. 오후 11:15


In [55]:
augmented_df.head()

Unnamed: 0,id,text,target,url,date
0,aug_0,##2단지 분양 앞두고 개포 · 재건축 불붙어,1,mlm_augment,20240130
1,aug_1,삼성전자 KBIS 2018서 셰프컬렉션 선보여,0,mlm_augment,20240130
2,aug_2,LG G6 사면 BO 이어폰이 단 5천원,0,mlm_augment,20240130
3,aug_3,"② 블록체인혁명 2030, 남자의 고독사",3,mlm_augment,20240130
4,aug_4,"이스라엘 정보, 팔레스타인인 50명 테러 혐의로 체포",4,mlm_augment,20240130


In [57]:
concat_df = pd.concat([input_df, augmented_df], axis=0)
concat_df.to_csv("../data/mlm_aug_train.csv", index=False)