## Text sampling

In [1]:
from random import sample, seed
from utils import lcqmcLoader

In [2]:
train = lcqmcLoader('train')
seed(175)

In [3]:
def text_sampling(out_num):
    return [t for t in sample(train, out_num)]


def saveTextFile(data, filepath):
    f = open(filepath, 'w')
    f.write("text_a\ttext_b\tlabel")
    tmp = "\n{}\t{}\t{}"
    for example in data:
        f.write(tmp.format(example[0], example[1], example[-1]))
    f.close()
    print(filepath + " has been saved!")

#### Dataset size to sample: 5k, 10k, 50k, 100k, and full set

In [4]:
five_k = text_sampling(5000)
ten_k = text_sampling(10000)
fifty_k = text_sampling(50000)
hundred_k = text_sampling(100000)

In [5]:
saveTextFile(five_k, '../data/aug_texts/train_5k.txt')
saveTextFile(ten_k, '../data/aug_texts/train_10k.txt')
saveTextFile(fifty_k, '../data/aug_texts/train_50k.txt')
saveTextFile(hundred_k, '../data/aug_texts/train_100k.txt')
saveTextFile(train, '../data/aug_texts/train_full.txt')

../data/aug_texts/train_5k.txt has been saved!
../data/aug_texts/train_10k.txt has been saved!
../data/aug_texts/train_50k.txt has been saved!
../data/aug_texts/train_100k.txt has been saved!
../data/aug_texts/train_full.txt has been saved!


## Text Augmentation

For every text pair, we will augment both texts and do cross pairing.

## DA models combined

This is for two reasons: (1) efficiency; (2) more controlled, making sure that the augmented texts are sampled from the same pool.

In [6]:
from ngramLM import NgramLM
from reda import REDA
from itertools import groupby
from random import sample


lm = NgramLM()


class AugTextsWithTwoModels(REDA):
    
    def __int__(self, syn_path=None):
        super.__init__(syn_path)
    
    @staticmethod
    def _out_num(edit_num, choice_num=None):
        if choice_num:
            return choice_num
        if edit_num == 1:
            return 20
        if edit_num == 2:
            return 50
        if edit_num == 3:
            return 100
        return edit_num * 50
    
    @staticmethod
    def deduplicate(ori, lst):
        lst.append(ori)
        lst.sort()
        lst = [l for l,_ in groupby(lst)]
        lst.remove(ori)
        return lst
    
    def augment(self, text, replace_rate=0.2, swap_rate=0.2, 
                insert_rate=0.1, delete_rate=0.1, max_mix=None, 
                out_num_each=2, out_str=True):
        
        def _filter(item):
            '''A func to make sure that the data structure is all right as some operation might fail to augment 
            the text (e.g., too short, no synonyms etc.)'''
            if isinstance(item, str):
                return []
            if not out_str and isinstance(item[0], str):
                if len(item) == len(words):
                    for i in range(words_num):
                        if item[i] == words[i]:
                            return []
                return [item]
            return item
        
        if isinstance(text, str):
            words = self.tokenize(text)
        elif isinstance(text, list):
            words = text
        else:
            raise TypeError("The input text must be either a str or a list")
            
        words_num = len(words)
        replace_num = round(replace_rate * words_num) 
        swap_num = round(swap_rate * words_num) 
        insert_num = round(insert_rate * words_num) 
        delete_num = round(delete_rate * words_num) 
        
        reda_out = []
        ngram_out = []
        _sample = lambda lst, num: sample(lst, num) if len(lst) >= num else sample(lst, len(lst))
        out_num_each_special = out_num_each - 1 if out_num_each > 1 else out_num_each
        
        if replace_num:
            out = _filter(self.replace_syn(words, replace_num, self._out_num(replace_num)))
            reda_out.extend(_sample(out, out_num_each))
            ngram_out.extend(lm.pickBestSent(out, out_num=out_num_each, out_str=out_str))
        if swap_num:
            out = _filter(self.swap_words(words, swap_num, self._out_num(swap_num)))
            reda_out.extend(_sample(out, out_num_each))
            ngram_out.extend(lm.pickBestSent(out, out_num=out_num_each, out_str=out_str))
        if insert_num:
            out = _filter(self.insert_words(words, insert_num, self._out_num(insert_num)))
            reda_out.extend(_sample(out, out_num_each_special))
            ngram_out.extend(lm.pickBestSent(out, out_num=out_num_each_special, out_str=out_str))
        if delete_num:
            out = _filter(self.delete_words(words, delete_num, self._out_num(delete_num)))
            reda_out.extend(_sample(out, out_num_each_special))
            ngram_out.extend(lm.pickBestSent(out, out_num=out_num_each_special, out_str=out_str))
            
        out = _filter(self.mixed_edits(words, 2, 50, out_str))
        reda_out.extend(_sample(out, out_num_each_special))
        ngram_out.extend(lm.pickBestSent(out, out_num=out_num_each_special, out_str=out_str))
        
        if out_str:
            reda_out = [''.join(sent) for sent in reda_out]
        # to deduplicate the outputs and ensure that the original text is no returned.
        words = self._out_str(words, out_str)
        
        reda_out = self.deduplicate(words, reda_out)
        ngram_out = self.deduplicate(words, ngram_out)
        return reda_out, ngram_out

In [7]:
def textAugmentation(data, augModel, aug_num_each=2):
    def augment(text_a, text_b, label):
        out_reda = [(text_a, text_b, label)]
        out_ngram = [(text_a, text_b, label)]
        
        aug_reda, aug_ngram = augModel.augment(text_a, out_num_each=aug_num_each)
        out_reda.extend([(t, text_b, label) for t in aug_reda])
        out_ngram.extend([(t, text_b, label) for t in aug_ngram])
        
        aug_reda, aug_ngram = augModel.augment(text_b)
        out_reda.extend([(t, text_b, label) for t in aug_reda])
        out_ngram.extend([(t, text_b, label) for t in aug_ngram])
        return out_reda, out_ngram
    
    if not isinstance(data[0], (tuple, list,)):
        out_reda, out_ngram = augment(data[0], data[1], data[-1])
        print('Texts augmented.')
        print(f' Before (reda): 1. Now: {len(out_reda)}')
        print(f' Before (ngram): 1. Now: {len(out_ngram)}')
        return out_reda, out_ngram
    
    outputs_reda, outputs_gram = [], []
    for example in data:
        out_reda, out_ngram = augment(example[0], example[1], example[-1])
        outputs_reda.extend(out_reda)
        outputs_gram.extend(out_ngram)
    print('Texts augmented.')
    print(f'Before (reda): {len(data)}. Now: {len(outputs_reda)}')
    print(f'Before (ngram): {len(data)}. Now: {len(outputs_gram)}')
    return outputs_reda, outputs_gram

In [8]:
aug = AugTextsWithTwoModels()

### 5,000

In [9]:
five_k_aug_reda, five_k_aug_reda_ngram = textAugmentation(five_k, aug)
saveTextFile(five_k_aug_reda, '../data/aug_texts/train_5k_aug_reda.txt')
saveTextFile(five_k_aug_reda_ngram, '../data/aug_texts/train_5k_aug_reda_ngram.txt')

Building prefix dict from the default dictionary ...
Loading model from cache /var/folders/w9/d_nplhzj4qx35xxlgljgdtjh0000gn/T/jieba.cache
Loading model cost 0.784 seconds.
Prefix dict has been built successfully.


Texts augmented.
Before (reda): 5000. Now: 66267
Before (ngram): 5000. Now: 64358
../data/aug_texts/train_5k_aug_reda.txt has been saved!
../data/aug_texts/train_5k_aug_reda_ngram.txt has been saved!


### 10,000

In [10]:
ten_k_aug_reda, ten_k_aug_reda_ngram = textAugmentation(ten_k, aug)
saveTextFile(ten_k_aug_reda, '../data/aug_texts/train_10k_aug_reda.txt')
saveTextFile(ten_k_aug_reda_ngram, '../data/aug_texts/train_10k_aug_reda_ngram.txt')

Texts augmented.
Before (reda): 10000. Now: 132513
Before (ngram): 10000. Now: 128649
../data/aug_texts/train_10k_aug_reda.txt has been saved!
../data/aug_texts/train_10k_aug_reda_ngram.txt has been saved!


### 50,000

In [11]:
fifty_k_aug_reda, fifty_k_aug_reda_ngram = textAugmentation(fifty_k, aug, aug_num_each=1)
saveTextFile(fifty_k_aug_reda, '../data/aug_texts/train_50k_aug_reda.txt')
saveTextFile(fifty_k_aug_reda_ngram, '../data/aug_texts/train_50k_aug_reda_ngram.txt')

Texts augmented.
Before (reda): 50000. Now: 563228
Before (ngram): 50000. Now: 544583
../data/aug_texts/train_50k_aug_reda.txt has been saved!
../data/aug_texts/train_50k_aug_reda_ngram.txt has been saved!


### 100,000

In [14]:
hundred_k_aug_reda, hundred_k_aug_reda_ngram = textAugmentation(hundred_k, aug, aug_num_each=1)
saveTextFile(hundred_k_aug_reda, '../data/aug_texts/train_100k_aug_reda.txt')
saveTextFile(hundred_k_aug_reda_ngram, '../data/aug_texts/train_100k_aug_reda_ngram.txt')

Texts augmented.
Before (reda): 100000. Now: 1127535
Before (ngram): 100000. Now: 1089847
../data/aug_texts/train_100k_aug_reda.txt has been saved!
../data/aug_texts/train_100k_aug_reda_ngram.txt has been saved!


### Full

In [13]:
full_aug_reda, full_aug_reda_ngram = textAugmentation(train, aug, aug_num_each=1)
saveTextFile(full_aug_reda, '../data/aug_texts/train_full_aug_reda.txt')
saveTextFile(full_aug_reda_ngram, '../data/aug_texts/train_full_aug_reda_ngram.txt')

Texts augmented.
Before (reda): 238766. Now: 2691917
Before (ngram): 238766. Now: 2601696
../data/aug_texts/train_full_aug_reda.txt has been saved!
../data/aug_texts/train_full_aug_reda_ngram.txt has been saved!
