In [1]:
import sys
sys.path.append('..')

import pandas as pd
import numpy as np
from utils import *
import nlpcda
import os

from typing import List, Dict

Simbert不能正常使用，除非你安装：bert4keras、tensorflow ，为了安装快捷，没有默认安装.... No module named 'bert4keras'


In [2]:
class DataAugmentation:
    """Augment the positive rows (``label == 1``) of a labelled text DataFrame.

    Up to two nlpcda tools are applied, each to a random sample of the original
    positive rows:

    * ``'random_entity'``      -> ``nlpcda.Similarword``
    * ``'random_delete_char'`` -> ``nlpcda.RandomDeleteChar``

    Each entry of ``configs`` holds the keyword arguments for the tool plus a
    mandatory ``'prop'`` key: the number of rows sampled for that tool is
    ``int(n_positive * prop)``. A tool whose key is absent is skipped.
    """

    def __init__(self, configs:Dict[str, dict]) -> None:
        """Build the configured tools.

        Raises:
            KeyError: if a configured tool has no ``'prop'`` entry.
        """
        self.entity_swap, self.random_del = None, None

        if 'random_entity' in configs:
            # Work on a shallow copy: popping 'prop' must not mutate the
            # caller's dict, otherwise the same config cannot be reused.
            params = dict(configs['random_entity'])
            self.entity_swap_p = params.pop('prop')
            self.entity_swap = nlpcda.Similarword(**params)
        if 'random_delete_char' in configs:
            params = dict(configs['random_delete_char'])
            self.random_del_p = params.pop('prop')
            self.random_del = nlpcda.RandomDeleteChar(**params)
    
    def aug(self, df_full:pd.DataFrame, permute=True, seed=1024) -> pd.DataFrame:
        """Return the label-0 rows plus the (augmented) label-1 rows of ``df_full``.

        Args:
            df_full: frame with at least ``label`` and ``text`` columns. Rows
                whose label is neither 0 nor 1 are not carried over.
            permute: shuffle the result and reset its index when True.
            seed: ``random_state`` for the shuffle only. Row sampling inside
                ``aug_single`` uses NumPy's global RNG (``np.random.seed``).
        """
        df = df_full[df_full.label == 1]
        df_neg = df_full[df_full.label == 0]

        L = len(df)
        # Start from the untouched positives so that any subset of tools
        # (including none) works; previously this name was only bound when
        # the entity-swap tool was configured.
        augmented_df = df
        if self.entity_swap is not None:
            augmented_df = self.aug_single(augmented_df, L, self.entity_swap_p, self.entity_swap)
        if self.random_del is not None:
            augmented_df = self.aug_single(augmented_df, L, self.random_del_p, self.random_del)

        augmented_df = pd.concat((df_neg, augmented_df))
        if permute:
            augmented_df = augmented_df.sample(frac=1, random_state=seed).reset_index(drop=True)
        return augmented_df

    def aug_single(self, df:pd.DataFrame, L:int, p:float, tool) -> pd.DataFrame:
        """Append ``int(L * p)`` transformed rows to ``df``.

        ``L`` is the length of the original data. Rows are drawn (with
        replacement) from the first ``L`` positions only, so rows appended by
        an earlier tool are never augmented again.
        """
        n_new = int(L * p)
        if n_new <= 0:
            # Nothing to sample (empty input or a proportion that rounds to 0).
            return df
        idx = np.random.choice(range(L), size=n_new)
        transformed_slice_df = self.get_transformed_df(df.iloc[idx], tool)
        return pd.concat((df, transformed_slice_df))

    def text_seq_transform(self, tool, texts:List[str]) -> np.ndarray:
        """Run ``tool.replace`` on every text and keep the last returned variant.

        NOTE(review): presumably nlpcda returns the original text first,
        followed by generated variants, so ``[-1]`` falls back to the original
        when nothing was generated -- confirm against the nlpcda version in use.
        """
        return np.array([tool.replace(text)[-1] for text in texts])
    
    def get_transformed_df(self, slice_df:pd.DataFrame, tool) -> pd.DataFrame:
        """Return a copy of ``slice_df`` whose ``text`` column has been transformed.

        The ``text`` column is dropped and re-added, so it ends up as the last
        column of the returned frame.
        """
        transformed_text = self.text_seq_transform(tool, slice_df['text'].values)
        transformed_slice_df = slice_df.drop(columns=['text']).copy(deep=True)
        transformed_slice_df['text'] = transformed_text
        return transformed_slice_df
        

In [7]:
# Seed NumPy's global RNG: it drives both the subsample below and the row
# sampling done later by DataAugmentation.aug_single.
np.random.seed(1024)
# 10,000 row positions for the "mini" training set.
# NOTE(review): positions are drawn with replacement from 1..43000, so the
# slice can contain duplicate rows, never contains row 0, and assumes the CSV
# has at least 43,001 rows -- confirm this is intended.
rnd_idx = np.random.choice(range(1, 43001), size=10000)

# Parse the file once and drop the id column without mutating in place; the
# mini set is then a positional slice of the same frame.
train_df = pd.read_csv('../data-org/train.csv', sep='\t').drop(columns=['id'])
train_df_slice = train_df.iloc[rnd_idx]

In [8]:
# Resolve the entity list inside the installed nlpcda package rather than
# hard-coding one machine's site-packages directory. The previous literal also
# relied on backslash sequences such as "\A" and "\d" not being valid escapes.
entities_file = os.path.join(os.path.dirname(nlpcda.__file__), "data", "entities.txt")

# One entry per augmentation tool. 'prop' is consumed by DataAugmentation
# (fraction of positive rows to augment); the remaining keys are passed to the
# nlpcda tool constructor as keyword arguments.
da_configs = {
    'random_entity':{
        'base_file':entities_file, 
        'create_num':2, 
        'change_rate':0.75, 
        'seed':1024, 
        'prop':0.5,  
    }, 
    'random_delete_char':{
        'create_num':2, 
        'change_rate':0.05, 
        'seed':1024, 
        'prop':0.3, 
    }
}

da = DataAugmentation(da_configs)
# Order matters for reproducibility: both calls draw from NumPy's global RNG.
train_df_aug = da.aug(train_df)
train_df_slice_aug = da.aug(train_df_slice)

# `ntf` comes from `from utils import *`; presumably a "finished" notification
# helper -- verify in utils.
ntf()

load :D:\Apps\Anaconda3\envs\general-torch\Lib\site-packages\nlpcda\data\entities.txt done


In [9]:
# Persist the augmented full set and the augmented mini set as tab-separated
# files. The DataFrame index is written as well (pandas' to_csv default).
for aug_df, out_path in (
    (train_df_aug, '../data-aug/train.csv'),
    (train_df_slice_aug, '../data-aug-mini/train.csv'),
):
    aug_df.to_csv(out_path, sep='\t')

In [11]:
# Spot-check the augmented training set: show 50 consecutive rows starting at
# position i (an arbitrary offset chosen for manual inspection).
i = 37223
train_df_aug.iloc[i:i+50]

Unnamed: 0,label,text
37223,1,我国成功发射的首颗X射线空间天文卫星“慧眼”，将显著增加我国大型科学卫星研制水平，实现我国在...
37224,1,星期天，杨惠妍把自己的房间打扫得干干净净，整整齐齐。
37225,0,邓稼先是中国几千年传统文化孕育出来的有最高奉献精神的儿子。
37226,1,随着综合国力的提商，特别是科学技术和造船工业能力的提高，才使一个国家有能力建造。从这个意义上...
37227,1,无人驾驶技术的日渐成熟，差不多45%到53%上下的服务员对全自动驾驶汽车用于港珠澳大桥穆朗玛...
37228,1,毒品流向分散是我市毒品犯罪的一大特点。既有大量毒品从境外及外省市流入广州，又有相当数量的毒品...
37229,1,这次演讲比赛规定每一位选手的演讲时间最多不能超过30分钟左右。
37230,1,“成青快速铁路”的正式运行，大大缩短了成都至青城山的时间，激发了成都市民去亲近自然的热情，也...
37231,1,何刚回答说，万安公墓奥运会是13亿宽窄巷子人民的盛会，也是全世界各国人民的盛会。办好本届中国...
37232,1,《舌尖上的中国2》摄制组历时半年，遍访各地美食，足迹踏遍全国各个省份以及部分海外城市，走访拍...


In [34]:
# Class balance after augmentation: number of label-0 rows per label-1 row.
n_negative = len(train_df_aug[train_df_aug.label == 0])
n_positive = len(train_df_aug[train_df_aug.label == 1])
n_negative / n_positive

0.24324095620901207

# Generate entity vocab

In [None]:
def postprocess_ds(outputs:List[List[dict]]):
    """Merge the per-sentence entity vocabularies of a whole dataset.

    Args:
        outputs: one list of NER token dicts per sentence; empty entries are
            skipped.

    Returns:
        dict mapping entity category -> list of entity strings, concatenated
        in sentence order. Entities repeated across sentences are kept as
        repeated entries.
    """
    merged = {}
    for sentence_output in outputs:
        if not sentence_output:
            continue
        for category, entities in postprocess_sentence(sentence_output).items():
            merged.setdefault(category, []).extend(entities)
    return merged


def postprocess_sentence(ner_outputs:List[dict]):
    """Collect the entities of one sentence from token-level NER output.

    Each item must provide ``'entity'`` (a tag such as ``'B-XXX'`` or
    ``'I-XXX'``; the category is everything after the first two characters)
    and ``'word'`` (the token text). A ``B`` tag starts a new entity and an
    ``I`` tag extends the entity in progress. ``I`` tags that appear before
    any ``B`` tag are ignored, as are tags starting with any other letter.
    NOTE(review): presumably this is the output of a token-classification
    pipeline -- confirm the schema against the caller.

    Returns:
        dict mapping category -> list of unique entity strings in order of
        first appearance. An empty dict is returned for empty input.
    """
    entity_vocab = {}
    if not ner_outputs:
        return entity_vocab

    def _store(category, entity):
        # Add the entity once per category. The previous if/else replaced the
        # whole list with [entity] whenever the entity was already present,
        # silently discarding every other entity of that category.
        bucket = entity_vocab.setdefault(category, [])
        if entity not in bucket:
            bucket.append(entity)

    current, category = '', None
    for out in ner_outputs:
        tag = out['entity']
        if tag[0] == 'B':
            if current:
                _store(category, current)
            category = tag[2:]
            current = out['word']
        elif tag[0] == 'I' and current:
            current += out['word']
    if current:
        _store(category, current)
    return entity_vocab