In [52]:
import sys
sys.path.append('..')

import pandas as pd
import numpy as np
from utils import *
import nlpcda
from nlpcda.tools.Basetool import Basetool
import jieba
    
import os

from typing import List, Dict

In [53]:
class WordPositionExchange(Basetool):
    '''Randomly swap nearby jieba words to create word-order variants of a sentence.

    Works like a position-exchange augmenter but on jieba tokens instead of single
    characters. Only tokens made entirely of CJK Unified Ideographs (U+4E00..U+9FFF)
    are moved, and a token is swapped with another token at most ``char_gram``
    positions away inside the same run of Chinese tokens.

    Args:
        create_num: maximum number of sentences returned by ``replace`` (the cleaned
            input sentence included).
        change_rate: per-token probability of attempting a swap.
        char_gram: maximum distance (in tokens) between swapped tokens.
        seed: random seed forwarded to ``Basetool``.
    '''

    def __init__(self, create_num: int = 5, change_rate: float = 0.05, char_gram: int = 3, seed: int = 1):
        super().__init__('', create_num, change_rate, seed)
        self.char_gram = char_gram

    def __replace_one(self, one_sentence: str) -> str:
        # Tokenise into jieba words (not characters); whole words are the unit that moves.
        words = list(jieba.cut(one_sentence, cut_all=False))
        for i in range(len(words)):
            # self.random is the RNG set up by Basetool (presumably seeded with `seed`).
            if self.random.random() < self.change_rate:
                # Non-Chinese tokens stay where they are.
                if not self.__is_chinese(words[i]):
                    continue
                change_i = self.__cpt_exchange_position(words, i)
                words[i], words[change_i] = words[change_i], words[i]
        return ''.join(words)

    def __cpt_exchange_position(self, sen_chars: list, position_i):
        '''Pick a random swap partner for ``position_i`` within the surrounding run of
        Chinese tokens, no further than ``char_gram`` positions away (may be itself).'''
        i = position_i
        j = position_i
        # Walk left until a non-Chinese token, the start of the list, or char_gram steps.
        while i > 0 and self.__is_chinese(sen_chars[i]) and abs(i - position_i) < self.char_gram:
            i -= 1
        # Walk right until a non-Chinese token, the end of the list, or char_gram steps.
        while j < len(sen_chars) - 1 and self.__is_chinese(sen_chars[j]) and abs(j - position_i) < self.char_gram:
            j += 1
        # If a walk stopped on a non-Chinese token, step back so it is excluded.
        if not self.__is_chinese(sen_chars[i]):
            if i < position_i:
                i += 1
        if not self.__is_chinese(sen_chars[j]):
            if j > position_i:
                j -= 1
        return self.random.randint(i, j)

    def __is_chinese(self, a_chr):
        # A jieba token can span several characters. Treat it as Chinese only if every
        # character is a CJK Unified Ideograph; comparing the whole string against the
        # range bounds would only be a lexicographic test of (mostly) its first character.
        return bool(a_chr) and all(u'\u4e00' <= ch <= u'\u9fff' for ch in a_chr)

    def replace(self, replace_str: str):
        '''Return up to ``create_num`` distinct sentences, the cleaned input first.

        Newlines are removed and surrounding whitespace stripped before augmenting.
        Generation stops after ``create_num * loop_t / change_rate`` attempts
        (``loop_t`` is provided by ``Basetool``).
        '''
        replace_str = replace_str.replace('\n', '').strip()
        sentences = [replace_str]
        if self.change_rate <= 0:
            # No token can ever be swapped; also avoids dividing by zero in the retry budget.
            return sentences

        max_tries = self.create_num * self.loop_t / self.change_rate
        t = 0
        while len(sentences) < self.create_num:
            t += 1
            a_sentence = self.__replace_one(replace_str)
            if a_sentence not in sentences:
                sentences.append(a_sentence)
            if t > max_tries:
                break
        return sentences

In [70]:
class DataAugmentation:
    """Label-aware augmentation for a binary text dataset with ``label`` and ``text`` columns.

    Rows with ``label == 1`` are augmented with text tools; rows with ``label == 0`` are
    split into shorter sentences and, optionally, a share of them is word-order shuffled
    and relabelled as 1.

    ``configs`` maps a step name to that tool's constructor keyword arguments plus a
    ``'prop'`` entry: the fraction of the original rows sampled for the step. Recognised
    step names are ``'random_entity'``, ``'random_delete_char'``, ``'random_swap'`` and
    ``'random_swap_order'``; a missing name disables the step. ``configs`` is not modified.
    """

    def __init__(self, configs: Dict[str, dict]) -> None:
        self.entity_swap = self.random_del = self.random_swap = self.random_swap_order = None

        step = self._split_config(configs, 'random_entity')
        if step is not None:
            self.entity_swap_p, kwargs = step
            # NOTE(review): Similarword is presumably nlpcda's synonym replacer, here fed an
            # entity file as base_file — confirm this is the intended tool for entity swaps.
            self.entity_swap = nlpcda.Similarword(**kwargs)
        step = self._split_config(configs, 'random_delete_char')
        if step is not None:
            self.random_del_p, kwargs = step
            self.random_del = nlpcda.RandomDeleteChar(**kwargs)
        step = self._split_config(configs, 'random_swap')
        if step is not None:
            self.random_swap_p, kwargs = step
            self.random_swap = nlpcda.Similarword(**kwargs)
        step = self._split_config(configs, 'random_swap_order')
        if step is not None:
            self.random_swap_order_p, kwargs = step
            self.random_swap_order = WordPositionExchange(**kwargs)

    @staticmethod
    def _split_config(configs: Dict[str, dict], key: str):
        """Return ``(prop, tool_kwargs)`` for ``key`` without mutating ``configs``; None if absent."""
        if key not in configs:
            return None
        kwargs = dict(configs[key])  # copy so the caller's dict keeps its 'prop'
        prop = kwargs.pop('prop')
        return prop, kwargs

    def aug(self, df_full: pd.DataFrame, permute=True, seed=1024) -> pd.DataFrame:
        """Return the negatives (split into sentences), the positives, and all augmented rows.

        Args:
            df_full: frame with ``label`` (0/1) and ``text`` columns.
            permute: shuffle the result and reset its index.
            seed: ``random_state`` for the final shuffle only; row sampling in
                ``aug_single`` draws from numpy's global RNG.
        """
        df_pos = df_full[df_full.label == 1]
        df_neg = df_full[df_full.label == 0]

        # Every positive step samples only from the first L rows (the original positives),
        # so rows created by one step are never re-augmented by a later one.
        L = len(df_pos)
        df_pos_aug = df_pos
        if self.entity_swap is not None:
            df_pos_aug = self.aug_single(df_pos_aug, L, self.entity_swap_p, self.entity_swap)
        if self.random_del is not None:
            df_pos_aug = self.aug_single(df_pos_aug, L, self.random_del_p, self.random_del)
        if self.random_swap is not None:
            df_pos_aug = self.aug_single(df_pos_aug, L, self.random_swap_p, self.random_swap)

        df_neg_aug = self.split_long_sentence(df_neg, label=0)
        if self.random_swap_order is not None:
            # Word-order-shuffled negatives become new positive examples.
            df_neg_aug = self.aug_single(df_neg_aug, len(df_neg_aug), self.random_swap_order_p,
                                         self.random_swap_order, new_label=1)

        augmented_df = pd.concat((df_neg_aug, df_pos_aug), ignore_index=True)
        if permute:
            augmented_df = augmented_df.sample(frac=1, random_state=seed).reset_index(drop=True)
        return augmented_df

    def aug_single(self, df: pd.DataFrame, L: int, p: float, tool, new_label=None) -> pd.DataFrame:
        """Append transformed copies of ``int(L * p)`` rows sampled from ``df``'s first ``L`` rows.

        ``L`` is the original length of ``df``, so newly constructed rows are never sampled.
        Positions are drawn with replacement from numpy's global RNG. Rows whose text the
        tool leaves unchanged are dropped; ``new_label``, if given, overrides the label of
        the new rows.
        """
        idx = np.random.choice(range(L), size=int(L * p))
        transformed_slice_df = self.get_transformed_df(df.iloc[idx], tool, new_label)
        return pd.concat((df, transformed_slice_df), ignore_index=True)

    def text_seq_transform(self, tool, texts: List[str]) -> np.ndarray:
        """Apply ``tool.replace`` to each text and keep the last candidate when it differs from the input."""
        out = []
        for text in texts:
            transformed_text = tool.replace(text)[-1]
            # Tools such as WordPositionExchange.replace list the input first, so an
            # unchanged last element means nothing was augmented.
            if transformed_text != text:
                out.append(transformed_text)
        return np.array(out)

    def get_transformed_df(self, slice_df: pd.DataFrame, tool, new_label) -> pd.DataFrame:
        """Build a ``text``/``label`` frame of transformed texts from ``slice_df``.

        The label of the first row is used for all new rows (the slices built in ``aug``
        are single-label), unless ``new_label`` is not None.
        """
        if len(slice_df) == 0:
            # Nothing sampled; keep the column dtypes so later concatenation is unaffected.
            return slice_df.loc[:, ['text', 'label']].iloc[0:0]
        label, text = slice_df[['label', 'text']].values.T
        transformed_text = self.text_seq_transform(tool, text)
        out_label = label[0] if new_label is None else new_label
        return pd.DataFrame({'text': transformed_text, 'label': out_label})

    def split_long_sentence(self, slice_df: pd.DataFrame, label=0, punctuations=('。', '！', '？')) -> pd.DataFrame:
        """Split each text on the first delimiter in ``punctuations`` that yields more than one piece.

        Pieces that lost their delimiter to the split get it re-attached. Texts that cannot
        be split are kept unchanged; empty texts are dropped. All rows get ``label``.
        """
        outputs = []
        for text in slice_df.text.values:
            outputs.extend(self._split_text(text, punctuations))
        return pd.DataFrame(data={'label': label, 'text': outputs})

    @staticmethod
    def _split_text(text: str, punctuations) -> List[str]:
        """Split one text on the first delimiter that produces multiple non-empty pieces."""
        for p in punctuations:
            pieces = [s for s in text.split(p) if s]
            if len(pieces) > 1:
                return [s + p if s[-1] not in punctuations else s for s in pieces]
        # No delimiter splits the text: keep it as-is instead of appending punctuation.
        return [text] if text else []

In [71]:
# Global seed: also drives the row sampling inside DataAugmentation.aug_single later on.
np.random.seed(1024)
# NOTE(review): draws 10,000 positions *with replacement* from 1..43000 (row 0 is never
# picked); the 43001 bound looks hard-coded to the size of train.csv — confirm.
rnd_idx = np.random.choice(range(1, 43001), size=10000)

# Read the file once; the mini split is a positional sample of the full frame.
train_df = pd.read_csv('../data/data-org/train.csv', sep='\t').drop(columns=['id'])
train_df_slice = train_df.iloc[rnd_idx]

In [72]:
# Resolve nlpcda's bundled entity list from the installed package instead of a
# hard-coded, machine-specific site-packages path (which also relied on invalid
# backslash escapes in a non-raw string).
entities_file = os.path.join(os.path.dirname(nlpcda.__file__), 'data', 'entities.txt')

# One entry per augmentation step. 'prop' is the fraction of the original rows sampled
# for that step; every other key is passed to the step's tool constructor.
da_configs = {
    'random_entity': {
        'base_file': entities_file,
        'create_num': 2,
        'change_rate': 0.75,
        'seed': 1024,
        'prop': 0.3,
    },
    'random_delete_char': {
        'create_num': 2,
        'change_rate': 0.05,
        'seed': 1024,
        'prop': 0.1,
    },
    'random_swap': {
        'create_num': 2,
        'change_rate': 0.2,
        'seed': 1024,
        'prop': 0.2,
    },
    'random_swap_order': {
        'create_num': 2,
        'char_gram': 5,
        'change_rate': 0.05,
        'seed': 1024,
        'prop': 0.5,
    }
}

da = DataAugmentation(da_configs)

load :D:\Apps\Anaconda3\envs\general-torch\Lib\site-packages\nlpcda\data\entities.txt done
load :d:\Apps\Anaconda3\envs\general-torch\lib\site-packages\nlpcda\data\同义词.txt done


In [77]:
# Augment the full training set, then the mini split, with the same configured tools.
# Order matters: both calls share the tools' and numpy's global random state.
train_df_aug, train_df_slice_aug = [da.aug(frame) for frame in (train_df, train_df_slice)]

In [93]:
# Write each augmented split (tab-separated, index included) to its own directory.
for out_dir, frame in (('../data/data-aug-large', train_df_aug),
                       ('../data/data-aug-mini', train_df_slice_aug)):
    os.makedirs(out_dir, exist_ok=True)  # no exists()/makedirs race
    frame.to_csv(os.path.join(out_dir, 'train.csv'), sep='\t', encoding='utf-8')

In [92]:
# Inspect a window of 50 augmented rows with long texts visible.
# max_colwidth goes through option_context so it is restored after the block;
# assigning pd.options inside the block would change it for the whole session.
i = 37223
with pd.option_context('display.max_rows', None, 'display.max_columns', None,
                       'display.max_colwidth', 100):
    display(train_df_aug.iloc[i:i + 50])

Unnamed: 0,label,text
37223,1,联系近期发生的美国对台军售、中美贸易纠纷以及奥巴马会见司长等外交摩擦，反映孝贤文化广场海啸后中国经济强劲复苏，但盘锦经济至今没有明显改善，对经十路总统的不满上升。
37224,0,中国青年足球队在1／4决赛中以1:4惨败给沙特队，进军世青赛的希望成了泡影。
37225,1,继“嫦娥1号”取得重大突破之后，中国的深空探测迈入了快速的前所未有的发展时期，“神州八号”已与中国空间站的“天宫一号”完成无人对接任务。
37226,1,对于家庭语言暴力问题，很多家长开端并不在意，以至于油然而生严重后果，才引起了她们的注意，其危害锱铢不亚于粗暴的肢体伤害。
37227,1,推进社会主义新农村建设，必须站在落实科学发展观，构建和谐社会，全面实现小康，进一步重视“三农”问题，把农业放到整个国民经济大格局中统筹谋划。
37228,0,为庆祝建国七十周年，我市开展的“我和我的祖国”快闪录制活动，极大的激发了广大市民的爱国热情。
37229,1,"是否坚持锻炼,是健康的保障。"
37230,1,济南文化西路的慢行一体路使用彩色沥青打造，不但可以提升城市的景观效果，增加现代化都市气息，而且也可以避免普通沥青路面黑色的单调性，提高司机和行人的注意力。
37231,1,2009东感动中国人物李灵说，尽管遇到再大的困难，她都会不改变开始的初衷，不会放弃为孩子建立阅览室的梦想。
37232,1,卡皮奥称，在菲律宾处于中国军事弱势的情况下，菲总统最起码应该像越南一样，向中国表示强烈抗议。


In [95]:
# Keep only rows whose text is shorter than MAX_TEXT_LEN characters and save them.
# NOTE(review): the origin of the 62-character limit is not documented here — confirm.
MAX_TEXT_LEN = 62
for out_dir, frame in (('../data/data-aug-trunc', train_df_aug),
                       ('../data/data-aug-trunc-mini', train_df_slice_aug)):
    os.makedirs(out_dir, exist_ok=True)
    short_rows = frame[frame.text.map(len) < MAX_TEXT_LEN]
    short_rows.to_csv(os.path.join(out_dir, 'train.csv'), sep='\t', encoding='utf-8')

# Generate entity vocabulary from NER outputs

In [None]:
def postprocess_ds(outputs: List[List[dict]]) -> Dict[str, List[str]]:
    """Merge per-sentence NER outputs into one ``{category: [entity, ...]}`` vocabulary.

    Each element of ``outputs`` is one sentence's token predictions, as accepted by
    ``postprocess_sentence``; empty or None elements are skipped. Entities keep their
    first-seen order and appear at most once per category, matching the per-sentence
    de-duplication.
    """
    entity_vocab: Dict[str, List[str]] = {}
    for output in outputs:
        if not output:
            continue
        for category, entities in postprocess_sentence(output).items():
            bucket = entity_vocab.setdefault(category, [])
            for entity in entities:
                if entity not in bucket:
                    bucket.append(entity)
    return entity_vocab


def postprocess_sentence(ner_outputs: List[dict]) -> Dict[str, List[str]]:
    """Group one sentence's token-level BIO NER predictions into ``{category: [entity, ...]}``.

    Each item must carry a tag under ``'entity'`` (e.g. ``'B-PER'`` / ``'I-PER'``) and the
    token text under ``'word'``. A ``B`` tag starts a new entity, an ``I`` tag extends the
    open one, and ``I`` tags with no open entity are ignored. Entities keep their
    first-seen order and are listed once per category. Empty input yields ``{}``.
    """
    entity_vocab: Dict[str, List[str]] = {}
    if not ner_outputs:
        return entity_vocab

    def _record(category: str, entity: str) -> None:
        # Add a finished entity once, without discarding entities already collected.
        bucket = entity_vocab.setdefault(category, [])
        if entity not in bucket:
            bucket.append(entity)

    category = None
    current = ''
    for out in ner_outputs:
        tag = out['entity']
        if tag.startswith('B'):
            if current:
                _record(category, current)
            category = tag[2:]
            current = out['word']
        elif tag.startswith('I'):
            # NOTE(review): the I tag's category is not compared with the open entity's,
            # and other tags (e.g. 'O') do not close it — confirm this fits the tagger output.
            if current:
                current += out['word']
    if current:
        _record(category, current)
    return entity_vocab