In [1]:
from parrot import Parrot
import pandas as pd
from tqdm import tqdm
import os
import nlpaug.augmenter.word as naw
# Directory holding the word2vec weights; the same value is exported as MODEL_DIR.
model_dir = '/workspaces/mlops-template-Shunian-Chen/model'
os.environ["MODEL_DIR"] = model_dir
# One-time word2vec download (kept disabled):
# from nlpaug.util.file.download import DownloadUtil
# DownloadUtil.download_word2vec(dest_dir=model_dir)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the Parrot paraphraser once; paraphrase_aug() reads this module-level `parrot`.
# NOTE(review): use_gpu=True presumably needs a CUDA device - confirm on the target machine.
parrot = Parrot(model_tag="prithivida/parrot_paraphraser_on_T5", use_gpu=True)

In [3]:
def paraphrase_aug(data, num):
    """Paraphrase each text in ``data`` with the module-level ``parrot`` model.

    Args:
        data: Iterable of input strings (the caller passes a pandas Series).
        num: Number of augmented rows still wanted.

    Returns:
        Tuple ``(df_aug, num)``: ``df_aug`` is a DataFrame with a single
        ``text`` column holding the accepted paraphrases, and ``num`` is the
        budget that remains after them. When ``num`` is positive on entry the
        returned value is never negative, because generation stops as soon as
        the budget is used up.
    """
    collected = []

    # Nothing requested: skip the (slow) model entirely.
    if num <= 0:
        return pd.DataFrame({'text': collected}), num

    for text in tqdm(data):
        paraphrases = parrot.augment(text, max_length=250, max_return_phrases=5, do_diverse=False, use_gpu=True)
        if not paraphrases:
            continue
        for para in paraphrases:
            # para is indexed as (phrase, value); entries whose second field is 0
            # are skipped. NOTE(review): presumably a similarity/diversity score
            # where 0 means "identical to the input" - confirm against Parrot.
            if para[1] == 0:
                continue
            collected.append(para[0])
            num -= 1
            if num <= 0:
                break
        # Stop walking the inputs once the budget is exhausted so the class
        # is not pushed past its target size.
        if num <= 0:
            break

    return pd.DataFrame({'text': collected}), num

def data_augs(data, num, model):
    """Create up to ``num`` augmented texts from ``data`` with an nlpaug word augmenter.

    Args:
        data: Texts to augment. May be a pandas Series, a DataFrame (its
            ``text`` column is used, or the first column if there is none),
            or any iterable of strings.
        num: Number of augmented rows still wanted.
        model: One of ``'word_embs'``, ``'contextual'`` or ``'synonym'``.

    Returns:
        Tuple ``(df_aug, num)``: a DataFrame with a single ``text`` column and
        the budget that remains (positive when ``data`` ran out first).

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    supported = ('word_embs', 'contextual', 'synonym')
    if model not in supported:
        raise ValueError(f"Unknown augmentation model {model!r}; expected one of {supported}")

    rows = []
    # Nothing requested: avoid building (and possibly downloading) a model.
    if num <= 0:
        return pd.DataFrame({'text': rows}), num

    if model == 'word_embs':
        # os.path.join supplies the separator that plain string concatenation dropped.
        aug = naw.WordEmbsAug(
            model_type='word2vec',
            model_path=os.path.join(model_dir, 'GoogleNews-vectors-negative300.bin'),
            action="substitute",
            device='cuda',
        )
    elif model == 'contextual':
        aug = naw.ContextualWordEmbsAug(model_path='roberta-base', action="substitute", device='cuda')
    else:
        aug = naw.SynonymAug(aug_src='wordnet')

    # Flatten the input to a plain list of cell values. Iterating
    # DataFrame.values.tolist() would yield one-element lists whose str() is
    # "['...']", so the brackets and quotes would leak into the augmented text.
    if isinstance(data, pd.DataFrame):
        column = data['text'] if 'text' in data.columns else data.iloc[:, 0]
        texts = column.tolist()
    elif isinstance(data, pd.Series):
        texts = data.tolist()
    else:
        texts = list(data)

    for text in texts:
        result = aug.augment(str(text))
        if not result:
            continue
        # A bare string must be appended whole; `list += str` would add it
        # character by character.
        if isinstance(result, str):
            result = [result]
        result = list(result)[:num]
        rows.extend(result)
        num -= len(result)
        if num <= 0:
            break

    return pd.DataFrame({'text': rows}), num


In [4]:
# Load the training set; later cells read its "label" and "text" columns.
# NOTE(review): absolute workspace path - adjust when running in another environment.
origin_data = pd.read_csv("/workspaces/MBTI-Personality-Test/Data/mbti_train.csv")

In [5]:
# Balancing target: total rows divided (integer division) by the 16 labels.
total_rows = len(origin_data)
avg = total_rows // 16
# Rows per label, with `label` restored as an ordinary column.
per_label = origin_data.groupby("label").count()
cnt = per_label.reset_index()
print(cnt, avg)

    label   text
0       0   7639
1       1   6279
2       2  55830
3       3  40742
4       4  12630
5       5   9950
6       6  69372
7       7  48976
8       8   3391
9       9   1731
10     10  25672
11     11  26415
12     12   1507
13     13   1631
14     14   7311
15     15   8752 20489


In [6]:
# Augmentation plan: one row per label with its current size (`length`) and the
# number of rows still missing to reach `avg` (`to_aug`, floored at 0).
# Built from a renamed copy so `cnt` keeps its own columns and this cell can be
# re-run without hitting a missing "text" column.
aug = cnt.rename(columns={"text": "length"})
aug["to_aug"] = (avg - aug["length"]).clip(lower=0)
aug = aug[["label", "length", "to_aug"]]
aug

Unnamed: 0,label,length,to_aug
0,0,7639,12850
1,1,6279,14210
2,2,55830,0
3,3,40742,0
4,4,12630,7859
5,5,9950,10539
6,6,69372,0
7,7,48976,0
8,8,3391,17098
9,9,1731,18758


In [7]:
def augment_data_by_label(data, num, label):
    """Grow one label's texts by ``num`` rows using progressively cheaper stages.

    Paraphrasing runs first; if rows are still missing, synonym substitution
    and then contextual substitution are applied to everything gathered so far.

    Args:
        data: pandas Series holding the label's original texts.
        num: Number of additional rows wanted for this label.
        label: Label value written to every returned row.

    Returns:
        DataFrame with columns ``text`` and ``label``: the original texts
        followed by the augmented ones.
    """
    print("data: ", data.shape)
    augmented_data, num = paraphrase_aug(data, num)
    # Name the column explicitly so it lines up with the `text` column that the
    # augmentation helpers return.
    original = pd.Series(data, name='text').to_frame()
    new_data = pd.concat([original, augmented_data], axis=0, ignore_index=True)
    print("num after paraphrase: ", num)
    print("augmented_data: ", augmented_data.shape)
    print("new_data: ", new_data.shape)

    for stage in ('synonym', 'contextual'):
        if num <= 0:
            break
        # Pass the text column, not the frame: row-wise iteration over a frame
        # yields one-element lists, and their str() form would be augmented.
        augmented_data, num = data_augs(new_data['text'], num, stage)
        new_data = pd.concat([new_data, augmented_data], axis=0, ignore_index=True)
        print(f"num after {stage}: ", num)
        print("augmented_data: ", augmented_data.shape)

    if num > 0:
        # Every stage ran out of input before the target was met.
        print("warning: still missing rows after all stages: ", num)

    print("new_data: ", new_data.shape)
    new_data['label'] = label
    return new_data


In [1]:
# Scratch check: which of the 16 label ids survive a `[4:]` slice.
a = list(range(16))
a[4:]

[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

In [8]:
def data_augmentation(data, aug, skip=0):
    """Augment every under-represented label and return the combined dataset.

    Args:
        data: DataFrame with ``text`` and ``label`` columns.
        aug: Plan DataFrame with ``label`` and ``to_aug`` columns; labels whose
            ``to_aug`` is greater than 0 are augmented.
        skip: Number of leading labels in the plan to leave untouched, e.g. to
            resume an interrupted run. Defaults to 0 (process all of them);
            this replaces a hard-coded ``[4:]`` slice that silently skipped
            the first four labels.

    Returns:
        DataFrame in which each processed label's rows are replaced by its
        original-plus-augmented rows; all other rows are unchanged.

    Side effects:
        After each label, the accumulated frame is written to
        ``mbti_train_aug{label}.csv`` in the working directory as a checkpoint.
    """
    to_aug = aug.loc[aug['to_aug'] > 0, 'label'].tolist()
    data_augmented = data
    for label in to_aug[skip:]:
        num = int(aug.loc[aug['label'] == label, 'to_aug'].iloc[0])
        print(f"Augmenting label: {label}, num to aug: {num}")

        label_data = data[data['label'] == label].reset_index(drop=True)
        augmented_data = augment_data_by_label(label_data['text'], num, label)

        # augment_data_by_label returns the label's original rows together with
        # the new ones, so drop the originals first to avoid duplicating them.
        remaining = data_augmented[data_augmented['label'] != label]
        data_augmented = pd.concat([remaining, augmented_data], axis=0, ignore_index=True)
        data_augmented.to_csv(f"mbti_train_aug{label}.csv", index=False)
    return data_augmented

In [9]:
# Run the balancing pass over the training set using the plan in `aug`.
# NOTE: in the recorded run each completed label took about two to three hours,
# and the run ended with a KeyboardInterrupt.
augmented_data = data_augmentation(origin_data, aug)

Augmenting label: 0, num to aug: 12850
data:  (7639,)


100%|██████████| 7639/7639 [1:52:40<00:00,  1.13it/s]  
[nltk_data] Downloading package wordnet to /home/vscode/nltk_data...
[nltk_data] Downloading package omw-1.4 to /home/vscode/nltk_data...
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/vscode/nltk_data...


num after paraphrase:  8977
augmented_data:  (3873, 1)
new_data:  (11512, 1)


[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.


num after synonym:  0
augmented_data:  (8977, 1)
new_data:  (20489, 1)
Augmenting label: 1, num to aug: 14210
data:  (6279,)


  5%|▌         | 345/6279 [05:10<1:39:03,  1.00s/it]Bad pipe message: %s [b'*w\xf9\xe8\x19\x92\x84\x03QF\x07\xb9\xda@\xe9\x03R\xea \x0e\xc0\xfc\xdd\x023:"\x99\xfc\xcb\x86\xe5nVx\xd0\xe7\\\xedy\xbd\xfa~\xd7 \xe4\x04\xab7%"\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$\x00\x1d\x00']
Bad pipe message: %s [b"\xcb`w\xe5;\xfb\xb3`\x9e\x9d\xe5FUSr|Z\xaa\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0\x13\x0

num after paraphrase:  11130
augmented_data:  (3080, 1)
new_data:  (9359, 1)
num after synonym:  1771
augmented_data:  (9359, 1)


Downloading: 100%|██████████| 481/481 [00:00<00:00, 373kB/s]
Downloading: 100%|██████████| 899k/899k [00:00<00:00, 49.0MB/s]
Downloading: 100%|██████████| 456k/456k [00:00<00:00, 48.7MB/s]
Downloading: 100%|██████████| 1.36M/1.36M [00:00<00:00, 23.9MB/s]
Downloading: 100%|██████████| 501M/501M [00:06<00:00, 79.7MB/s] 


num after contextual:  0
augmented_data:  (1771, 1)
new_data:  (20489, 1)
Augmenting label: 4, num to aug: 7859
data:  (12630,)


100%|██████████| 12630/12630 [2:57:24<00:00,  1.19it/s] 


num after paraphrase:  1354
augmented_data:  (6505, 1)
new_data:  (19135, 1)
num after synonym:  0
augmented_data:  (1354, 1)
new_data:  (20489, 1)
Augmenting label: 5, num to aug: 10539
data:  (9950,)


100%|██████████| 9950/9950 [2:19:33<00:00,  1.19it/s]  


num after paraphrase:  5790
augmented_data:  (4749, 1)
new_data:  (14699, 1)
num after synonym:  0
augmented_data:  (5790, 1)
new_data:  (20489, 1)
Augmenting label: 8, num to aug: 17098
data:  (3391,)


  8%|▊         | 261/3391 [03:39<43:54,  1.19it/s] 


KeyboardInterrupt: 

In [None]:
# Write the augmented dataset to the working directory; index=False leaves out the row index.
augmented_data.to_csv("mbti_train_augmented.csv", index=False)

In [None]:
# (rows, columns) of the original training frame, for comparison with the augmented one.
origin_data.shape

(327828, 2)

In [None]:
# (rows, columns) of the augmented frame.
augmented_data.shape

(388749, 2)