## Text sampling

In [1]:
from random import sample, seed
from utils import load_dataset
from tqdm import tqdm
import os

In [2]:
# Load the training examples from the TSV file.
train = load_dataset('../data/train.tsv')
# Seed the `random` module so that the random sampling done later in this
# notebook is repeatable across runs.
seed(175)
# Display the number of loaded examples.
len(train)

6920

In [3]:
def text_sampling(out_num):
    """Draw ``out_num`` examples from the module-level ``train`` data.

    Sampling is done with ``random.sample`` (i.e. without replacement), so it
    consumes the global ``random`` state.

    Args:
        out_num: Number of examples to draw.

    Returns:
        A new list holding the sampled examples.
    """
    picked = sample(train, out_num)
    return list(picked)


def saveTextFile(data, filepath):
    """Write examples to ``filepath`` as ``label<TAB>text`` lines.

    Each example is indexed as ``example[0]`` (text) and ``example[1]``
    (label). Lines are separated by ``"\\n"`` and no trailing newline is
    written. An empty ``data`` produces an empty file.

    Args:
        data: Sequence of examples, each indexable at positions 0 and 1.
        filepath: Destination path; an existing file is overwritten.
    """
    lines = ("{}\t{}".format(str(example[1]), example[0]) for example in data)
    # The context manager guarantees the handle is closed even if a write fails.
    with open(filepath, 'w') as f:
        f.write("\n".join(lines))
    print(filepath + " has been saved!")

In [4]:
# Draw the four subsets in increasing size; the order of the calls is kept
# because each one advances the seeded random state.
five_h, one_k, two_k, four_k = [
    text_sampling(size) for size in (500, 1000, 2000, 4000)
]

In [5]:
# Create the output directory (and any missing parents). `exist_ok=True`
# replaces the separate exists-check, which could race with another process
# creating the directory between the check and the mkdir.
os.makedirs('../data/aug_texts/', exist_ok=True)

# Save every sampled subset plus the full train set in the same text format.
saveTextFile(five_h, '../data/aug_texts/train_0.5k.txt')
saveTextFile(one_k, '../data/aug_texts/train_1k.txt')
saveTextFile(two_k, '../data/aug_texts/train_2k.txt')
saveTextFile(four_k, '../data/aug_texts/train_4k.txt')
saveTextFile(train, '../data/aug_texts/train_full.txt')

../data/aug_texts/train_0.5k.txt has been saved!
../data/aug_texts/train_1k.txt has been saved!
../data/aug_texts/train_2k.txt has been saved!
../data/aug_texts/train_4k.txt has been saved!
../data/aug_texts/train_full.txt has been saved!


## Text Augmentation

For every example, we augment the text and pair each augmented text with the label of the original example.



### DA models combined

The two DA models are combined into a single class for two reasons: (1) efficiency; (2) control — it ensures that the augmented texts of both models are sampled from the same pool of candidates.

In [6]:
from ngramLM_en import NgramLM
from reda_en import REDA
from itertools import groupby
from random import sample


# Module-level n-gram language model instance; AugTextsWithTwoModels.augment
# calls lm.pickBestSent on it to choose among candidate augmentations.
lm = NgramLM()


class AugTextsWithTwoModels(REDA):
    """REDA augmenter that returns two selections drawn from the same candidates.

    For every edit operation the candidates are generated once; one selection
    is drawn at random (``random.sample``) and the other is chosen by the
    module-level n-gram language model via ``lm.pickBestSent``.
    """

    def __init__(self, syn_path=None):
        """Initialise the underlying REDA model.

        This method used to be misspelled ``__int__`` and called
        ``super.__init__`` without parentheses, so it never ran as the
        constructor.
        NOTE(review): assumes ``REDA.__init__`` accepts ``syn_path`` — confirm.
        """
        super().__init__(syn_path)

    @staticmethod
    def _out_num(edit_num, choice_num=None):
        """Return how many candidates to request for ``edit_num`` edits.

        A truthy ``choice_num`` overrides the table below. Otherwise
        1 -> 20, 2 -> 50, 3 -> 100, and anything else -> ``edit_num * 50``.
        """
        if choice_num:
            return choice_num
        if edit_num == 1:
            return 20
        if edit_num == 2:
            return 50
        if edit_num == 3:
            return 100
        return edit_num * 50

    @staticmethod
    def deduplicate(ori, lst):
        """Return the sorted, unique items of ``lst`` with ``ori`` removed.

        The input list is left untouched (it used to be appended to and
        sorted in place).
        """
        unique = [item for item, _ in groupby(sorted(lst))]
        return [item for item in unique if item != ori]

    def augment(self, text, replace_rate=0.2, swap_rate=0.2,
                insert_rate=0.1, delete_rate=0.1, max_mix=None,
                out_num_each=2, out_str=True):
        """Augment ``text`` and return the random and the LM-picked selections.

        Args:
            text: A string (tokenized with ``self.tokenize``) or a list of tokens.
            replace_rate, swap_rate, insert_rate, delete_rate: Each rate is
                multiplied by the token count and rounded to obtain the edit
                count for the corresponding operation; an operation whose
                edit count is 0 is skipped.
            max_mix: Currently unused; the mixed-edit call is made with the
                fixed arguments ``(words, 2, 50)``.
            out_num_each: Number of outputs kept for the replace and swap
                operations. Insert, delete and mixed edits keep
                ``out_num_each - 1`` when ``out_num_each > 1``.
            out_str: If true, the randomly sampled outputs are joined with
                spaces; the flag is also forwarded to ``lm.pickBestSent`` and
                ``self._out_str``.

        Returns:
            ``(reda_out, ngram_out)``: two sorted, duplicate-free lists that
            exclude the original text.

        Raises:
            TypeError: If ``text`` is neither a ``str`` nor a ``list``.
        """
        if isinstance(text, str):
            words = self.tokenize(text)
        elif isinstance(text, list):
            words = text
        else:
            raise TypeError("The input text must be either a str or a list")

        def _filter(item):
            '''Normalise an operation's result, since some operations might
            fail to augment the text (e.g., too short, no synonyms etc.).'''
            # A bare string or an empty/None result means nothing usable came back.
            if isinstance(item, str) or not item:
                return []
            if not out_str and isinstance(item[0], str):
                # A single token list: drop it if it equals the input,
                # otherwise wrap it so callers always get a list of outputs.
                if ''.join(item) == ''.join(words):
                    return []
                return [item]
            return item

        def _sample(lst, num):
            # Take everything when there are fewer candidates than requested.
            return sample(lst, num) if len(lst) >= num else lst

        words_num = len(words)
        out_num_each_special = out_num_each - 1 if out_num_each > 1 else out_num_each

        reda_out = []
        ngram_out = []

        def _collect(candidates, keep):
            # Random selection first, then the LM-based one, both from the same pool.
            reda_out.extend(_sample(candidates, keep))
            ngram_out.extend(lm.pickBestSent(candidates, out_num=keep, out_str=out_str))

        # (edit count, operation, number of outputs to keep). The order is
        # kept as replace, swap, insert, delete so the random state is
        # consumed in the same sequence.
        edit_plans = (
            (round(replace_rate * words_num), self.replace_syn, out_num_each),
            (round(swap_rate * words_num), self.swap_words, out_num_each),
            (round(insert_rate * words_num), self.insert_words, out_num_each_special),
            (round(delete_rate * words_num), self.delete_words, out_num_each_special),
        )
        for edit_num, edit_op, keep in edit_plans:
            if edit_num:
                _collect(_filter(edit_op(words, edit_num, self._out_num(edit_num))), keep)

        # Mixed edits are always attempted, whatever the per-operation counts are.
        _collect(_filter(self.mixed_edits(words, 2, 50)), out_num_each_special)

        if out_str:
            reda_out = [' '.join(sent) for sent in reda_out]
        # Deduplicate the outputs and ensure that the original text is not returned.
        words = self._out_str(words, out_str)

        reda_out = self.deduplicate(words, reda_out)
        ngram_out = self.deduplicate(words, ngram_out)
        return reda_out, ngram_out

In [7]:
def textAugmentation(data, augModel, aug_num_each=2):
    """Augment every example in ``data`` and return the two enlarged datasets.

    Each example is read as ``example[0]`` (text) and ``example[1]`` (label).
    ``augModel.augment`` must return a pair of lists of augmented texts. In
    both outputs the original ``(text, label)`` comes first, followed by its
    augmented texts paired with the same label.

    Args:
        data: Iterable of examples (must support ``len`` for the summary).
        augModel: Object exposing ``augment(text, out_num_each=...)``.
        aug_num_each: Forwarded to ``augModel.augment`` as ``out_num_each``.

    Returns:
        ``(outputs_reda, outputs_ngram)`` as lists of ``(text, label)`` tuples.
    """
    reda_pairs = []
    ngram_pairs = []

    for example in tqdm(data):
        text, label = example[0], example[1]
        reda_texts, ngram_texts = augModel.augment(text, out_num_each=aug_num_each)

        # Keep the original example ahead of its augmented variants.
        reda_pairs.append((text, label))
        reda_pairs.extend((new_text, label) for new_text in reda_texts)

        ngram_pairs.append((text, label))
        ngram_pairs.extend((new_text, label) for new_text in ngram_texts)

    original_size = len(data)
    print('Texts augmented.')
    print(f'Before (reda): {original_size}. Now: {len(reda_pairs)}')
    print(f'Before (ngram): {original_size}. Now: {len(ngram_pairs)}')
    return reda_pairs, ngram_pairs

In [8]:
# Instantiate the combined augmenter with its default constructor arguments;
# the same instance is reused for every subset below.
aug = AugTextsWithTwoModels()

### 500

In [9]:
# Augment the 500-example subset (default aug_num_each=2) and save both
# results: the randomly sampled outputs and the ones picked by the n-gram LM.
five_h_aug_reda, five_h_aug_reda_ngram = textAugmentation(five_h, aug)
saveTextFile(five_h_aug_reda, '../data/aug_texts/train_0.5k_aug_reda.txt')
saveTextFile(five_h_aug_reda_ngram, '../data/aug_texts/train_0.5k_aug_reda_ngram.txt')

100%|█████████████████████████████████████████| 500/500 [01:11<00:00,  6.95it/s]

Texts augmented.
Before (reda): 500. Now: 3953
Before (ngram): 500. Now: 3936
../data/aug_texts/train_0.5k_aug_reda.txt has been saved!
../data/aug_texts/train_0.5_k_aug_reda_ngram.txt has been saved!





### 1,000

In [10]:
# Augment the 1,000-example subset (default aug_num_each=2) and save both
# results: the randomly sampled outputs and the ones picked by the n-gram LM.
one_k_aug_reda, one_k_aug_reda_ngram = textAugmentation(one_k, aug)
saveTextFile(one_k_aug_reda, '../data/aug_texts/train_1k_aug_reda.txt')
saveTextFile(one_k_aug_reda_ngram, '../data/aug_texts/train_1k_aug_reda_ngram.txt')

100%|███████████████████████████████████████| 1000/1000 [02:40<00:00,  6.24it/s]

Texts augmented.
Before (reda): 1000. Now: 7899
Before (ngram): 1000. Now: 7859
../data/aug_texts/train_1k_aug_reda.txt has been saved!
../data/aug_texts/train_1k_aug_reda_ngram.txt has been saved!





### 2,000

In [11]:
# Augment the 2,000-example subset (default aug_num_each=2) and save both
# results: the randomly sampled outputs and the ones picked by the n-gram LM.
two_k_aug_reda, two_k_aug_reda_ngram = textAugmentation(two_k, aug)
saveTextFile(two_k_aug_reda, '../data/aug_texts/train_2k_aug_reda.txt')
saveTextFile(two_k_aug_reda_ngram, '../data/aug_texts/train_2k_aug_reda_ngram.txt')

100%|███████████████████████████████████████| 2000/2000 [05:24<00:00,  6.17it/s]

Texts augmented.
Before (reda): 2000. Now: 15780
Before (ngram): 2000. Now: 15702
../data/aug_texts/train_2k_aug_reda.txt has been saved!
../data/aug_texts/train_2k_aug_reda_ngram.txt has been saved!





### 4,000

In [12]:
# Augment the 4,000-example subset (default aug_num_each=2) and save both
# results: the randomly sampled outputs and the ones picked by the n-gram LM.
four_k_aug_reda, four_k_aug_reda_ngram = textAugmentation(four_k, aug)
saveTextFile(four_k_aug_reda, '../data/aug_texts/train_4k_aug_reda.txt')
saveTextFile(four_k_aug_reda_ngram, '../data/aug_texts/train_4k_aug_reda_ngram.txt')

100%|███████████████████████████████████████| 4000/4000 [10:01<00:00,  6.65it/s]

Texts augmented.
Before (reda): 4000. Now: 31585
Before (ngram): 4000. Now: 31436
../data/aug_texts/train_4k_aug_reda.txt has been saved!
../data/aug_texts/train_4k_aug_reda_ngram.txt has been saved!





### Full

The two augmented versions of the full train set were generated in a separate session (another running window). Their statistics are reported below.  

In [13]:
# Augment the full train set. aug_num_each=1 is forwarded to the augmenter as
# out_num_each, so fewer outputs are kept per operation than for the subsets.
full_aug_reda, full_aug_reda_ngram = textAugmentation(train, aug, aug_num_each=1)
saveTextFile(full_aug_reda, '../data/aug_texts/train_full_aug_reda.txt')
saveTextFile(full_aug_reda_ngram, '../data/aug_texts/train_full_aug_reda_ngram.txt')

100%|███████████████████████████████████████| 6920/6920 [18:26<00:00,  6.25it/s]

Texts augmented.
Before (reda): 6920. Now: 40883
Before (ngram): 6920. Now: 40713
../data/aug_texts/train_full_aug_reda.txt has been saved!
../data/aug_texts/train_full_aug_reda_ngram.txt has been saved!



