In [1]:
# Fetch the KorEDA repository; the wordnet.pickle loaded below is read from ./KorEDA.
!git clone https://github.com/catSirup/KorEDA.git

In [34]:
import re
import pickle
import sys
sys.path.append("/opt/ml/backup")
from dataloader.load_data import *
from typing import Optional
import random

import pandas as pd

In [2]:
# Load the synonym lookup shipped with KorEDA; get_synonyms() reads this global.
# NOTE(review): pickle.load can execute arbitrary code, so only load this file
# from a clone you trust.
wordnet = {}
with open("./KorEDA/wordnet.pickle", "rb") as f:
    wordnet = pickle.load(f)

# Number of top-level entries in the loaded object.
print(len(wordnet))

In [20]:
# Intended to keep only Hangul characters and delete everything else.
def get_only_hangul(line):
    """Run the Hangul-filter regex over ``line`` and return the result.

    NOTE(review): the pattern is written with JavaScript-style ``/ ... /``
    delimiters, which Python treats as literal characters. Because ``^``
    then follows literal text it can never match, so ``line`` currently
    comes back unchanged. augment_sentence_EDA protects entities with
    Latin-letter placeholders that a working Hangul-only filter would
    strip, so the two must be changed together if this is ever fixed.
    """
    hangul_filter = re.compile('/ ^[ㄱ-ㅎㅏ-ㅣ가-힣]*$/')
    return hangul_filter.sub('', line)

In [21]:
########################################################################
# Synonym replacement
# Replace n words in the sentence with synonyms from wordnet
########################################################################

def synonym_replacement(words, n):
    """Synonym Replacement (SR): replace distinct words with a synonym.

    Distinct words are visited in random order. Each word that has at least
    one synonym (per ``get_synonyms``) has all of its occurrences replaced by
    one randomly chosen synonym. The loop stops once ``n`` words have been
    replaced; the check runs after each visited word.

    Args:
        words: list of tokens making up the sentence.
        n: target number of distinct words to replace (callers pass n >= 1).

    Returns:
        A new list of tokens; ``words`` itself is not modified. A synonym
        that contains spaces is split into separate tokens. An empty input
        yields an empty list.
    """
    new_words = words.copy()
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # shuffle below is reproducible under random.seed(); iterating a set of
    # strings is not stable across interpreter runs.
    candidates = list(dict.fromkeys(words))
    random.shuffle(candidates)

    num_replaced = 0
    for candidate in candidates:
        synonyms = get_synonyms(candidate)
        if synonyms:
            synonym = random.choice(list(synonyms))
            new_words = [synonym if word == candidate else word for word in new_words]
            num_replaced += 1
        if num_replaced >= n:
            break

    if not new_words:
        # Always hand back a list (the empty case used to return "").
        return []

    # Re-tokenise so multi-word synonyms become individual tokens.
    return ' '.join(new_words).split(" ")


In [22]:
def get_synonyms(word):
    """Return every synonym listed for ``word`` in the global ``wordnet``.

    ``wordnet[word]`` is iterated as a collection of groups, and the members
    of each group are flattened, in order, into one list.

    Args:
        word: token to look up.

    Returns:
        A flat list of synonyms; empty when ``word`` has no entry.
    """
    synonyms = []

    try:
        for group in wordnet[word]:
            synonyms.extend(group)
    except (KeyError, TypeError):
        # KeyError: the word has no entry. TypeError: unhashable key or a
        # non-iterable entry. Both are treated as "no (more) synonyms".
        # Anything else (e.g. KeyboardInterrupt during a long augmentation
        # run) is no longer silently swallowed.
        pass

    return synonyms

In [23]:
########################################################################
# Random deletion
# Randomly delete words from the sentence with probability p
########################################################################
def random_deletion(words, p):
    """Random Deletion (RD): drop each word independently with probability ``p``.

    Args:
        words: list of tokens.
        p: per-word deletion probability.

    Returns:
        The surviving tokens. A list with zero or one token is returned
        as-is (the same object). If every token would be deleted, a single
        randomly chosen token is kept instead.
    """
    # Nothing to delete. The empty case previously fell through to
    # random.randint(0, -1) and raised ValueError.
    if len(words) <= 1:
        return words

    # A word survives when its draw from [0, 1) exceeds p.
    new_words = [word for word in words if random.random() > p]

    if not new_words:
        return [random.choice(words)]

    return new_words

In [24]:
########################################################################
# Random swap
# Randomly swap two words in the sentence n times
########################################################################
def random_swap(words, n):
    """Random Swap (RS): apply ``swap_word`` ``n`` times to a copy of ``words``.

    The caller's list is left untouched; the reordered copy is returned.
    """
    swapped = words.copy()
    for _step in range(n):
        swapped = swap_word(swapped)
    return swapped

def swap_word(new_words):
    """Swap two randomly chosen positions of ``new_words`` in place.

    A second index is redrawn while it equals the first; after four draws
    that all collide, the list is returned without a swap.

    Args:
        new_words: list of tokens, modified in place.

    Returns:
        The same list object that was passed in.
    """
    # Fewer than two tokens leaves nothing to swap. An empty list previously
    # reached random.randint(0, -1) and raised ValueError.
    if len(new_words) < 2:
        return new_words

    last = len(new_words) - 1
    random_idx_1 = random.randint(0, last)
    random_idx_2 = random_idx_1
    counter = 0

    while random_idx_2 == random_idx_1:
        random_idx_2 = random.randint(0, last)
        counter += 1
        if counter > 3:
            # Give up rather than loop on repeated collisions.
            return new_words

    new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
    return new_words

In [25]:
########################################################################
# Random insertion
# Randomly insert n words into the sentence
########################################################################
def random_insertion(words, n):
    """Random Insertion (RI): call ``add_word`` ``n`` times on a copy of ``words``.

    ``add_word`` works on the copy in place, so the caller's list is not
    modified; the copy is returned.
    """
    augmented = words.copy()
    for _step in range(n):
        add_word(augmented)

    return augmented

In [26]:
def add_word(new_words):
    """Insert a synonym of a randomly chosen word into ``new_words`` in place.

    Up to 10 randomly chosen words are looked up with ``get_synonyms``. The
    first synonym of the first word that has any is inserted at a random
    existing position. If the list is empty or no tried word has a synonym,
    the list is left unchanged.

    Args:
        new_words: list of tokens, modified in place.

    Returns:
        None.
    """
    # An empty list has no word to look up. The previous loop never advanced
    # its retry counter in that case and spun forever.
    if not new_words:
        return

    synonyms = []
    for _attempt in range(10):
        random_word = random.choice(new_words)
        synonyms = get_synonyms(random_word)
        if synonyms:
            break
    else:
        # Ten lookups without a hit: give up quietly.
        return

    # NOTE(review): always takes the first listed synonym, not a random one.
    random_synonym = synonyms[0]
    # The index is drawn from existing positions, so the synonym is never
    # appended after the last token.
    random_idx = random.randint(0, len(new_words) - 1)
    new_words.insert(random_idx, random_synonym)

In [27]:
def EDA(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9):
    """Generate augmented variants of ``sentence`` with the four EDA techniques.

    The sentence is passed through ``get_only_hangul``, split on single
    spaces (empty tokens dropped), and each technique (synonym replacement,
    random insertion, random swap, random deletion) is run
    ``int(num_aug / 4) + 1`` times. The results are shuffled and trimmed.

    Args:
        sentence: input sentence.
        alpha_sr: fraction of the word count used as ``n`` for synonym
            replacement (at least 1).
        alpha_ri: same, for random insertion.
        alpha_rs: same, for random swap.
        p_rd: per-word deletion probability for random deletion.
        num_aug: when >= 1, the maximum number of augmented sentences kept;
            when < 1, each candidate is kept with probability
            ``num_aug / number_of_candidates``.

    Returns:
        A list of augmented sentences followed by the (filtered) input
        sentence as the last element. If the sentence has no tokens, the
        list contains only that sentence.
    """
    sentence = get_only_hangul(sentence)
    # Compare by value: `is not ""` tests object identity, which only works
    # by accident of string interning and triggers a SyntaxWarning.
    words = [word for word in sentence.split(' ') if word != ""]
    num_words = len(words)

    # Nothing to augment; the helpers are not written for empty token lists.
    if num_words == 0:
        return [sentence]

    num_new_per_technique = int(num_aug/4) + 1

    n_sr = max(1, int(alpha_sr*num_words))
    n_ri = max(1, int(alpha_ri*num_words))
    n_rs = max(1, int(alpha_rs*num_words))

    # Order matters only for the pre-shuffle list: sr, ri, rs, rd.
    techniques = (
        (synonym_replacement, n_sr),  # sr: synonym replacement
        (random_insertion, n_ri),     # ri: random insertion
        (random_swap, n_rs),          # rs: random swap
        (random_deletion, p_rd),      # rd: random deletion
    )

    augmented_sentences = []
    for technique, amount in techniques:
        for _ in range(num_new_per_technique):
            a_words = technique(words, amount)
            augmented_sentences.append(' '.join(a_words))

    augmented_sentences = [get_only_hangul(candidate) for candidate in augmented_sentences]
    random.shuffle(augmented_sentences)

    if num_aug >= 1:
        augmented_sentences = augmented_sentences[:num_aug]
    else:
        keep_prob = num_aug / len(augmented_sentences)
        augmented_sentences = [s for s in augmented_sentences if random.uniform(0, 1) < keep_prob]

    # The filtered input sentence is always the final element.
    augmented_sentences.append(sentence)

    return augmented_sentences

In [28]:
# Load the training set. Later cells index `data` with .shape/.iloc and read the
# `sentence`, `entity_01`, `entity_02` and `label` fields of each row.
data = load_data("/opt/ml/input/data/train/train_renew.tsv")

In [32]:
def augment_sentence_EDA(record : dict) -> list:
    """Produce at most one EDA-augmented copy of ``record``.

    Both entities are masked with Latin placeholder codes before calling
    ``EDA`` and restored afterwards. The first candidate that differs from
    the original sentence and still contains every entity that the original
    sentence contained is returned.

    Args:
        record: mapping with ``sentence``, ``entity_01``, ``entity_02`` and
            ``label`` keys.

    Returns:
        A list holding one new record dict (same entities and label, new
        sentence), or an empty list when no candidate qualifies.
    """
    # Placeholder tokens that shield the entities from being rewritten.
    entity_code01 = "ZQWXEC"
    entity_code02 = "QZWXEC"

    original = record['sentence']
    entity01 = record['entity_01']
    entity02 = record['entity_02']

    masked = original.replace(entity01, entity_code01).replace(entity02, entity_code02)

    # Only entities that literally occur in the source sentence can be
    # required afterwards; random deletion may drop a masked token.
    required = [entity for entity in (entity01, entity02) if entity in original]

    for candidate in EDA(masked):
        restored = candidate.replace(entity_code01, entity01).replace(entity_code02, entity02)
        if restored == original:
            continue
        if any(entity not in restored for entity in required):
            # An entity was lost during augmentation; try the next candidate.
            continue
        return [{
            "sentence": restored,
            "entity_01" : entity01,
            "entity_02" : entity02,
            "label" : record["label"]
        }]

    return []

In [3]:
# Collect one augmented record (when available) for every row whose label is not 0.
aug_info = []

for row_idx in tqdm(range(data.shape[0]), desc="Augmenting ..."):
    row = data.iloc[row_idx]
    if row['label'] != 0:
        aug_info.extend(augment_sentence_EDA(row.to_dict()))

In [4]:
# Number of augmented records collected by the loop.
print(len(aug_info))

In [5]:
# Inspect the first augmented record (raises IndexError if none were produced).
aug_info[0]

In [47]:
# DataFrame.append was removed in pandas 2.0, so build a frame from the
# augmented records and concatenate it after the original rows instead.
aug_data = pd.concat([data, pd.DataFrame(aug_info)], ignore_index=True)

In [48]:
# Write the original plus augmented rows as a tab-separated file with no header row and no index column.
aug_data.to_csv("/opt/ml/input/data/train/aug_EDA_train.tsv", index=False,header = None, sep = "\t")