## Text sampling

In [1]:
from random import sample, seed
from utils import load_dataset
from tqdm import tqdm

In [2]:
# Load the full training set; every subset below is drawn from it.
train = load_dataset('train.txt')
# Seed Python's `random` module so the sampling calls that follow are
# repeatable when the cells are run in the same order.
seed(175)
len(train)

260000

In [3]:
def text_sampling(out_num, data=None):
    """Draw ``out_num`` distinct examples at random, without replacement.

    Args:
        out_num: Number of examples to draw.
        data: Sequence to sample from. When omitted, the notebook-level
            ``train`` dataset is used, which is the original behaviour.

    Returns:
        A new list holding the sampled examples; the source sequence is left
        untouched.

    Raises:
        ValueError: If ``out_num`` exceeds the size of the population or is
            negative (raised by ``random.sample``).
    """
    population = train if data is None else data
    # random.sample already returns a fresh list, so no extra copy is needed.
    return sample(population, out_num)


def saveTextFile(data, filepath):
    """Write examples to ``filepath`` as tab-separated text, one per line.

    Every example contributes its first item, its second item and its last
    item, separated by tabs. Lines are separated by a newline and no newline
    follows the final line.

    Args:
        data: Iterable of indexable examples (e.g. tuples or lists). An empty
            iterable produces an empty file instead of raising.
        filepath: Destination path; an existing file is overwritten.

    Returns:
        None. A confirmation message is printed once the file is written.
    """
    # The context manager closes the handle even if formatting or writing an
    # example raises part-way through.
    with open(filepath, 'w') as f:
        for position, example in enumerate(data):
            # Separator goes *before* every line but the first, so the file
            # never ends with a trailing newline.
            if position:
                f.write("\n")
            f.write(f"{example[0]}\t{example[1]}\t{str(example[-1])}")
    print(filepath + " has been saved!")

In [4]:
# Draw the four training subsets, smallest first. Each subset is an
# independent draw from `train`, so a larger subset need not contain a
# smaller one.
ten_k, fifty_k, hundred_k, hund_fifty_k = (
    text_sampling(size) for size in (10000, 50000, 100000, 150000)
)

In [5]:
# Persist each sampled subset, followed by the unsampled full training set,
# using the same tab-separated layout for all of them.
saveTextFile(ten_k, '../data/aug_texts/train_10k.txt')
saveTextFile(fifty_k, '../data/aug_texts/train_50k.txt')
saveTextFile(hundred_k, '../data/aug_texts/train_100k.txt')
saveTextFile(hund_fifty_k, '../data/aug_texts/train_150k.txt')
saveTextFile(train, '../data/aug_texts/train_full.txt')

../data/aug_texts/train_10k.txt has been saved!
../data/aug_texts/train_50k.txt has been saved!
../data/aug_texts/train_100k.txt has been saved!
../data/aug_texts/train_150k.txt has been saved!
../data/aug_texts/train_full.txt has been saved!


## Text Augmentation

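For every text pair, we augment each of the two texts and cross-pair the results: every augmented version of one text is paired with the original version of the other text, and the label is kept unchanged.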



### DA models combined

The two data augmentation (DA) models are combined in a single class for two reasons: (1) efficiency, and (2) control, since both models then select their outputs from the same pool of candidate augmentations.

In [23]:
from ngramLM_en import NgramLM
from reda_en import REDA
from itertools import groupby
from random import sample


# Single n-gram language model instance, kept as a notebook-level global;
# AugTextsWithTwoModels.augment reads it when calling lm.pickBestSent.
lm = NgramLM()


class AugTextsWithTwoModels(REDA):
    """Produce two sets of augmented texts from one shared candidate pool.

    For every edit operation a pool of candidates is generated with the
    inherited REDA methods. The "reda" output takes a random sample of that
    pool, while the "ngram" output is whatever the notebook-level n-gram
    model ``lm`` returns from ``pickBestSent`` for the very same pool.
    """

    def __init__(self, syn_path=None):
        """Forward ``syn_path`` to the REDA constructor.

        NOTE(review): assumes ``REDA.__init__`` accepts ``syn_path`` as its
        first argument, as the original (mistyped) override intended --
        confirm against ``reda_en``.
        """
        super().__init__(syn_path)

    @staticmethod
    def _out_num(edit_num, choice_num=None):
        """Return how many candidates to request for ``edit_num`` edits.

        A truthy ``choice_num`` always wins. Otherwise 1, 2 and 3 edits map to
        20, 50 and 100 candidates, and any other count to ``edit_num * 50``.
        """
        if choice_num:
            return choice_num
        return {1: 20, 2: 50, 3: 100}.get(edit_num, edit_num * 50)

    @staticmethod
    def deduplicate(ori, lst):
        """Return the sorted, unique items of ``lst`` with ``ori`` excluded.

        ``lst`` itself is not modified.
        """
        return [item for item, _ in groupby(sorted(lst)) if item != ori]

    def augment(self, text, replace_rate=0.2, swap_rate=0.2,
                insert_rate=0.1, delete_rate=0.1, max_mix=None,
                out_num_each=2, out_str=True):
        """Augment ``text`` and return ``(reda_out, ngram_out)``.

        Args:
            text: A string (tokenized with ``self.tokenize``) or a list of
                tokens.
            replace_rate, swap_rate, insert_rate, delete_rate: Fraction of the
                tokens each operation edits, computed as
                ``round(rate * len(words))``. An operation whose count rounds
                to 0 is skipped.
            max_mix: Currently unused; kept for interface compatibility.
            out_num_each: Number of outputs taken per model for the replace
                and swap operations. Insert, delete and mixed edits take
                ``out_num_each - 1`` when ``out_num_each`` is greater than 1.
            out_str: When true, the sampled REDA candidates are joined with
                spaces; the flag is also forwarded to ``lm.pickBestSent``.

        Returns:
            A tuple ``(reda_out, ngram_out)`` of sorted, de-duplicated lists
            that do not contain the original text.

        Raises:
            TypeError: If ``text`` is neither a ``str`` nor a ``list``.
        """

        def _filter(item):
            """Normalise the raw result of an edit operation to a list of
            candidates, as some operations might fail to augment the text
            (e.g., too short, no synonyms etc.)."""
            # A bare string or an empty result means there is nothing usable.
            if isinstance(item, str) or not item:
                return []
            if not out_str and isinstance(item[0], str):
                # A single token list came back: drop it if it is just the
                # original text, otherwise wrap it so callers get a list.
                if ''.join(item) == ''.join(words):
                    return []
                return [item]
            return item

        if isinstance(text, str):
            words = self.tokenize(text)
        elif isinstance(text, list):
            words = text
        else:
            raise TypeError("The input text must be either a str or a list")

        words_num = len(words)
        replace_num = round(replace_rate * words_num)
        swap_num = round(swap_rate * words_num)
        insert_num = round(insert_rate * words_num)
        delete_num = round(delete_rate * words_num)

        reda_out = []
        ngram_out = []
        out_num_each_special = out_num_each - 1 if out_num_each > 1 else out_num_each

        def _collect(candidates, num):
            """Take ``num`` outputs per model from the same candidate pool."""
            # REDA side: random sample, or the whole pool when it is too small.
            reda_out.extend(sample(candidates, num) if len(candidates) >= num else candidates)
            # N-gram side: let the language model choose from the same pool.
            ngram_out.extend(lm.pickBestSent(candidates, out_num=num, out_str=out_str))

        if replace_num:
            _collect(_filter(self.replace_syn(words, replace_num, self._out_num(replace_num))),
                     out_num_each)
        if swap_num:
            _collect(_filter(self.swap_words(words, swap_num, self._out_num(swap_num))),
                     out_num_each)
        if insert_num:
            _collect(_filter(self.insert_words(words, insert_num, self._out_num(insert_num))),
                     out_num_each_special)
        if delete_num:
            _collect(_filter(self.delete_words(words, delete_num, self._out_num(delete_num))),
                     out_num_each_special)

        # Mixed edits always run. NOTE(review): the arguments 2 and 50 are
        # hard-coded; their exact meaning is defined by REDA.mixed_edits.
        _collect(_filter(self.mixed_edits(words, 2, 50)), out_num_each_special)

        if out_str:
            reda_out = [' '.join(sent) for sent in reda_out]
        # to deduplicate the outputs and ensure that the original text is not returned.
        words = self._out_str(words, out_str)

        reda_out = self.deduplicate(words, reda_out)
        ngram_out = self.deduplicate(words, ngram_out)
        return reda_out, ngram_out

In [24]:
def textAugmentation(data, augModel, aug_num_each=2):
    """Augment text pairs with ``augModel`` and return both augmented sets.

    Args:
        data: Either one example (first item: text a, second item: text b,
            last item: label) or a sequence of such examples. ``data`` is
            treated as a single example when its first element is neither a
            tuple nor a list.
        augModel: Object whose ``augment(text, out_num_each=...)`` returns a
            ``(reda_texts, ngram_texts)`` pair.
        aug_num_each: Passed to ``augModel.augment`` as ``out_num_each``.

    Returns:
        ``(reda_examples, ngram_examples)``: two lists of
        ``(text_a, text_b, label)`` tuples. Each list holds the original
        example followed by its augmented variants, where an augmented text
        is always paired with the untouched other text.
    """
    def _expand(first, second, label):
        original = (first, second, label)
        reda_rows, ngram_rows = [original], [original]
        # Each plan is: the text to augment, and how to rebuild the pair
        # around one of its augmented variants.
        plans = (
            (first, lambda variant: (variant, second, label)),
            (second, lambda variant: (first, variant, label)),
        )
        for source, rebuild in plans:
            reda_variants, ngram_variants = augModel.augment(source, out_num_each=aug_num_each)
            reda_rows.extend(map(rebuild, reda_variants))
            ngram_rows.extend(map(rebuild, ngram_variants))
        return reda_rows, ngram_rows

    single_example = not isinstance(data[0], (tuple, list,))
    if single_example:
        reda_rows, ngram_rows = _expand(data[0], data[1], data[-1])
        print('Texts augmented.')
        print(f' Before (reda): 1. Now: {len(reda_rows)}')
        print(f' Before (ngram): 1. Now: {len(ngram_rows)}')
        return reda_rows, ngram_rows

    all_reda, all_ngram = [], []
    for row in tqdm(data):
        reda_rows, ngram_rows = _expand(row[0], row[1], row[-1])
        all_reda += reda_rows
        all_ngram += ngram_rows
    print('Texts augmented.')
    print(f'Before (reda): {len(data)}. Now: {len(all_reda)}')
    print(f'Before (ngram): {len(data)}. Now: {len(all_ngram)}')
    return all_reda, all_ngram

In [25]:
# One augmenter instance, reused by every augmentation cell below.
aug = AugTextsWithTwoModels()

### 10,000

In [9]:
# Augment the 10k subset and save both variants. aug_num_each is left at its
# default of 2 here, whereas the larger sets below pass aug_num_each=1.
ten_k_aug_reda, ten_k_aug_reda_ngram = textAugmentation(ten_k, aug)
saveTextFile(ten_k_aug_reda, '../data/aug_texts/train_10k_aug_reda.txt')
saveTextFile(ten_k_aug_reda_ngram, '../data/aug_texts/train_10k_aug_reda_ngram.txt')

100%|█████████████████████████████████████| 10000/10000 [13:27<00:00, 12.39it/s]


Texts augmented.
Before (reda): 10000. Now: 148341
Before (ngram): 10000. Now: 141604
../data/aug_texts/train_10k_aug_reda.txt has been saved!
../data/aug_texts/train_10k_aug_reda_ngram.txt has been saved!


### 50,000

In [27]:
# Augment the 50k subset with aug_num_each=1 and save both variants.
fifty_k_aug_reda, fifty_k_aug_reda_ngram = textAugmentation(fifty_k, aug, aug_num_each=1)
saveTextFile(fifty_k_aug_reda, '../data/aug_texts/train_50k_aug_reda.txt')
saveTextFile(fifty_k_aug_reda_ngram, '../data/aug_texts/train_50k_aug_reda_ngram.txt')

100%|███████████████████████████████████| 50000/50000 [7:54:09<00:00,  1.76it/s]


Texts augmented.
Before (reda): 50000. Now: 543066
Before (ngram): 50000. Now: 512176
../data/aug_texts/train_50k_aug_reda.txt has been saved!
../data/aug_texts/train_50k_aug_reda_ngram.txt has been saved!


### 100,000

In [28]:
# Augment the 100k subset with aug_num_each=1 and save both variants.
hundred_k_aug_reda, hundred_k_aug_reda_ngram = textAugmentation(hundred_k, aug, aug_num_each=1)
saveTextFile(hundred_k_aug_reda, '../data/aug_texts/train_100k_aug_reda.txt')
saveTextFile(hundred_k_aug_reda_ngram, '../data/aug_texts/train_100k_aug_reda_ngram.txt')

100%|█████████████████████████████████| 100000/100000 [3:09:52<00:00,  8.78it/s]


Texts augmented.
Before (reda): 100000. Now: 1086063
Before (ngram): 100000. Now: 1023777
../data/aug_texts/train_100k_aug_reda.txt has been saved!
../data/aug_texts/train_100k_aug_reda_ngram.txt has been saved!


### 150,000

In [29]:
# Augment the 150k subset with aug_num_each=1 and save both variants.
hund_fifty_k_aug_reda, hund_fifty_k_aug_reda_ngram = textAugmentation(hund_fifty_k, aug, aug_num_each=1)
saveTextFile(hund_fifty_k_aug_reda, '../data/aug_texts/train_150k_aug_reda.txt')
saveTextFile(hund_fifty_k_aug_reda_ngram, '../data/aug_texts/train_150k_aug_reda_ngram.txt')

100%|█████████████████████████████████| 150000/150000 [4:00:35<00:00, 10.39it/s]


Texts augmented.
Before (reda): 150000. Now: 1629178
Before (ngram): 150000. Now: 1536285
../data/aug_texts/train_150k_aug_reda.txt has been saved!
../data/aug_texts/train_150k_aug_reda_ngram.txt has been saved!


### Full

The two augmented versions of the full training set were generated in a separate session, so the cell below has no saved output in this notebook. Their statistics are reported in the final cell.

In [31]:
# Augment the entire training set with aug_num_each=1 and save both variants.
# No output is stored for this cell in the notebook.
full_aug_reda, full_aug_reda_ngram = textAugmentation(train, aug, aug_num_each=1)
saveTextFile(full_aug_reda, '../data/aug_texts/train_full_aug_reda.txt')
saveTextFile(full_aug_reda_ngram, '../data/aug_texts/train_full_aug_reda_ngram.txt')

In [32]:
# These figures are hard-coded text, not values computed in this cell.
print('''Texts augmented.
Before (reda): 260000. Now: 2823733
Before (ngram): 260000. Now: 2662639''')

Texts augmented.
Before (reda): 260000. Now: 2823733
Before (ngram): 260000. Now: 2662639
