In [3]:
from parrot import Parrot
import pandas as pd
from tqdm import tqdm
import os
import nlpaug.augmenter.word as naw
# Download word2vec model
os.environ["MODEL_DIR"] = '/workspaces/mlops-template-Shunian-Chen/model'
model_dir = os.environ["MODEL_DIR"]
# from nlpaug.util.file.download import DownloadUtil
# DownloadUtil.download_word2vec(dest_dir=model_dir) 

In [4]:
parrot = Parrot(model_tag="prithivida/parrot_paraphraser_on_T5", use_gpu=True)

In [5]:
def paraphrase_aug(data, num):
    aug_text = {'text': []}
    for text in tqdm(data):
        paraphrases = parrot.augment(text,  max_length = 250, max_return_phrases = 5, do_diverse=False, use_gpu=True)
        if paraphrases:
            for para in paraphrases:
                if para[1] != 0:
                    aug_text['text'].append(para[0])
                    num -= 1
    df_aug = pd.DataFrame(aug_text)
    return df_aug, num

def data_augs(data, num, model):
    aug_text = {'text': []}

    if model == 'word_embs':
        aug = naw.WordEmbsAug(model_type='word2vec', model_path=model_dir+'GoogleNews-vectors-negative300.bin',action="substitute", device='cuda')
    elif model == 'contextual':
        aug = naw.ContextualWordEmbsAug(model_path='roberta-base', action="substitute", device='cuda')
    elif model == 'synonym':
        aug = naw.SynonymAug(aug_src='wordnet')

    for text in data.values.tolist():
        augmented_text = aug.augment(str(text))
        aug_text['text'] += augmented_text
        num -= 1
        if num == 0:
            break
    df_aug = pd.DataFrame(aug_text)

    return df_aug, num


In [6]:
origin_data = pd.read_csv("/workspaces/MBTI-Personality-Test/Data/mbti_train.csv")

In [7]:
cnt = origin_data.groupby("label").count().reset_index()
avg = len(origin_data) // 16
print(cnt, avg)

    label   text
0       0   7639
1       1   6279
2       2  55830
3       3  40742
4       4  12630
5       5   9950
6       6  69372
7       7  48976
8       8   3391
9       9   1731
10     10  25672
11     11  26415
12     12   1507
13     13   1631
14     14   7311
15     15   8752 20489


In [8]:
aug = pd.DataFrame(cnt)
aug["to_aug"] = cnt['text'].apply(lambda x: (avg - x)).clip(0)
aug.columns = ["label", "length", "to_aug"]
aug

Unnamed: 0,label,length,to_aug
0,0,7639,12850
1,1,6279,14210
2,2,55830,0
3,3,40742,0
4,4,12630,7859
5,5,9950,10539
6,6,69372,0
7,7,48976,0
8,8,3391,17098
9,9,1731,18758


In [9]:
def augment_data_by_label(data, num, label):
    print("data: ", data.shape)
    augmented_data, num = paraphrase_aug(data, num)
    new_data = pd.concat([pd.DataFrame(data), augmented_data], axis=0)
    print("num after paraphrase: ", num)
    print("augmented_data: ", augmented_data.shape)
    print("new_data: ", new_data.shape)
    if num > 0:
        augmented_data, num = data_augs(new_data, num, 'synonym')
        new_data = pd.concat([new_data, augmented_data], axis=0)
        print("num after synonym: ", num)
        print("augmented_data: ", augmented_data.shape)
    if num > 0:
        augmented_data, num = data_augs(new_data, num, 'contextual')
        new_data = pd.concat([new_data, augmented_data], axis=0)
        print("num after contextual: ", num)
        print("augmented_data: ", augmented_data.shape)
    # if num > 0:
    #     augmented_data, num = data_augs(new_data, num, label, 'word_embs')
    #     new_data = pd.concat([new_data, augmented_data], axis=0)
    print("new_data: ", new_data.shape)
    new_data['label'] = label
    return new_data


In [10]:
def data_augmentation(data, aug):
    to_aug = aug[aug['to_aug'] > 0]['label'].tolist()
    data_augmented = data
    for label in to_aug[4:]:
        num = aug[aug['label'] == label]['to_aug'].values.tolist()[0]
        print(f"Augmenting label: {label}, num to aug: {num}")

        label_data = data[data['label'] == label].reset_index(drop=True)
        augmented_data = augment_data_by_label(label_data['text'], num, label)
        
        data_augmented = pd.concat([data_augmented, augmented_data], axis=0)
        data_augmented.to_csv(f"mbti_train_aug{label}.csv", index=False)
    return data_augmented

In [11]:
augmented_data = data_augmentation(origin_data, aug)

Augmenting label: 8, num to aug: 17098
data:  (3391,)


100%|██████████| 3391/3391 [51:09<00:00,  1.10it/s] 


num after paraphrase:  15647
augmented_data:  (1451, 1)
new_data:  (4842, 1)
num after synonym:  10805
augmented_data:  (4842, 1)
num after contextual:  1121
augmented_data:  (9684, 1)
new_data:  (19368, 1)
Augmenting label: 9, num to aug: 18758
data:  (1731,)


 83%|████████▎ | 1445/1731 [20:09<02:26,  1.96it/s]Bad pipe message: %s [b'\x9e\x9b\xe6\x1c6`)\xa5"\xf2\xd5\xff\x8f|\xb5\xd4\xf8\x14\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0\'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0', b'3\x002\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x9c\xc0']
Bad pipe message: %s [b'=\x00<\x005\x00/\x00\x9a\x00\x99\xc0\x07\xc0\x11\x00\x96\x00\x05\x00\xff\x01\x00\x00j\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00']
Bad pipe message: %s [b'$\xad\xd3k']
Bad pipe message: %s [b'\xa6o\xa9F\xd6\xcevn}\xab\xff~\x00\x00\xa6\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\

num after paraphrase:  17866
augmented_data:  (892, 1)
new_data:  (2623, 1)
num after synonym:  15243
augmented_data:  (2623, 1)
num after contextual:  9997
augmented_data:  (5246, 1)
new_data:  (10492, 1)
Augmenting label: 12, num to aug: 18982
data:  (1507,)


100%|██████████| 1507/1507 [22:20<00:00,  1.12it/s]


num after paraphrase:  18247
augmented_data:  (735, 1)
new_data:  (2242, 1)
num after synonym:  16005
augmented_data:  (2242, 1)
num after contextual:  11521
augmented_data:  (4484, 1)
new_data:  (8968, 1)
Augmenting label: 13, num to aug: 18858
data:  (1631,)


100%|██████████| 1631/1631 [25:18<00:00,  1.07it/s]


num after paraphrase:  18167
augmented_data:  (691, 1)
new_data:  (2322, 1)
num after synonym:  15845
augmented_data:  (2322, 1)
num after contextual:  11201
augmented_data:  (4644, 1)
new_data:  (9288, 1)
Augmenting label: 14, num to aug: 13178
data:  (7311,)


100%|██████████| 7311/7311 [1:51:05<00:00,  1.10it/s]  


num after paraphrase:  9695
augmented_data:  (3483, 1)
new_data:  (10794, 1)
num after synonym:  0
augmented_data:  (9695, 1)
new_data:  (20489, 1)
Augmenting label: 15, num to aug: 11737
data:  (8752,)


100%|██████████| 8752/8752 [2:08:47<00:00,  1.13it/s]  


num after paraphrase:  7222
augmented_data:  (4515, 1)
new_data:  (13267, 1)
num after synonym:  0
augmented_data:  (7222, 1)
new_data:  (20489, 1)


In [None]:
augmented_data.to_csv("mbti_train_augmented.csv", index=False)

: 

: 

In [None]:
origin_data.shape

(327828, 2)

: 

: 

In [None]:
augmented_data.shape

(388749, 2)

: 

: 

: 

: 