This notebook contains a definition for a class which can be used to augment SQuAD-style datasets.  

In [1]:
%%capture
!pip install tqdm==4.43.0 numpy requests nlpaug torch>=1.6.0 transformers>=4.0.0 sentencepiece nltk
!git clone https://github.com/jasonwei20/eda_nlp.git

Print information on GPU (optional)

In [None]:
!nvidia-smi

Wed Jan 19 20:24:11 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.46       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    25W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import nltk # for augmenting w/ eda
from tqdm import tqdm # for progress bar with augmenting 
from eda_nlp.code.eda  import * # for eda augmenting
import numpy as np # for using split() function
import pandas as pd# for using sample() function 
from pathlib import Path # for loading/saving JSON files
import json # for loading/saving JSON files
import os # for loading/saving JSON files
import random # for shuffling and generating question_ids for augmented data
import copy #for deep copy of data_dict
from google.colab import drive # to mount Google Drive

drive.mount('/root/content')

dirpath = '/root/content/My Drive/hitm/training-datasets'



Mounted at /root/content


In [3]:
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
model_name = 'tuner007/pegasus_paraphrase'
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = PegasusTokenizer.from_pretrained(model_name)
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)

def get_response(input_text,num_return_sequences,num_beams):
  batch = tokenizer([str(input_text)],truncation=True,padding='longest',max_length=60, return_tensors="pt").to(torch_device)
  translated = model.generate(**batch,max_length=60,num_beams=num_beams, num_return_sequences=num_return_sequences, temperature=1.5)
  tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)
  return tgt_text

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=86.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1912529.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=65.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1142.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2275437102.0, style=ProgressStyle(descr…




In [40]:
"""Loads SQuAD data dict at given directory with given filename. Taken from
Huggingface SQuAD finetuning tutorial: """
def load_json(dirpath, filename):
  with open(os.path.join(dirpath, filename), 'r') as f:
    data = json.load(f)
    return data

"""Saves the given SQuAD data dict in JSON format at given path w/ given filename"""
def save_json(dirpath, filename, data_dict):
  with open(os.path.join(dirpath, filename), 'w+') as f:
    json.dump(data_dict, f)
  return

"""Adds given meta value to all qas sets. Should be done
before any data augmentation procedures."""
def append_qas_key(data_dict, key, value=None, overwrite=False):
  for datum in data_dict['data']:
    for paragraph in datum['paragraphs']:
      for entry in paragraph['qas']:
        if overwrite:
          entry[key] = value
        else: # just want to add to entries w/o key
          if key not in entry:
            entry[key] = value
  return

"""Extracts all entries which were augmented with the specified method.
Method options include:
'all' : include all methods (default)"""
def get_augmented(data_dict, method="all"):
  augmented_entries = []
  for datum in data_dict['data']:
    paragraphs = datum['paragraphs']
    for paragraph in paragraphs:
      qas = paragraph['qas']
      for entry in qas:
        if entry['augmented']:
          if method is 'all' or method in entry['aug_methods']:
              augmented_entries.append(entry)
  return augmented_entries

"""Returns a list of all entries which have NOT been augmented"""
def get_unaugmented(data_dict):
  normal = []
  for datum in data_dict['data']:
    paragraphs = datum['paragraphs']
    for paragraph in paragraphs:
      qas = paragraph['qas']
      for entry in qas:
        if not entry['augmented']:
          normal.append(entry)
  return normal
"""
Returns a list of unique question ids from the question answer set
"""
def get_question_ids(data_dict):
  question_ids = []
  for datum in data_dict['data']:
    paragraphs = datum['paragraphs']
    for paragraph in paragraphs:
      qas = paragraph['qas']
      for entry in qas:
        question_ids.append(entry['id'])

  return question_ids

"""Returns list of q_ids for questions of certain split type (train, eval, val)"""
def get_split_ids(data_dict, split_label):
  ids = []
  entries = get_entry_list(data_dict)
  for entry in entries:
    if split_label == entry['split']:
      ids.append(entry['id'])
  return ids

"""A quick way of using Panda's sample function while sticking to list objects"""
def sample_list(item_list, frac):
  return pd.DataFrame(item_list).sample(frac=frac)[0].values.tolist()

"""Returns a list of answer ids from all answers
  in the SQuAD dataset"""
def get_answer_ids(data_dict):
  answer_ids = []
  for datum in data_dict['data']:
    paragraphs = datum['paragraphs']
    for paragraph in paragraphs:
      qas = paragraph['qas']
      for entry in qas:
        answers = entry['answers']
        for answer in answers:
          answer_ids.append(answer['answer_id'])
  return answer_ids

"""Returns all document ids. Can be used for a variety of tasks"""
def get_document_ids(data_dict):
  document_ids = []
  for doc in data_dict['data']:
    document_ids.append(doc['paragraphs'][0]['document_id']) # doc id is same regarldess of paragraph
  return document_ids

"""Returns all the qas objects, which are lists containing 
  dict objects which consists of
  questions, their answers, and metadata One qas obj per document"""
def get_qas_list(data_dict):
  qas_list = []
  for datum in data_dict['data']: # there's only one item, but for posterity
    paragraphs = datum['paragraphs'] # for each paragraph get qas
    for paragraph in paragraphs:
      qas_list.append(paragraph['qas']) # list of question answer sets (one per doc)
  return qas_list

def get_entry_list(data_dict, sort_key =None):
  entry_list = []
  for datum in data_dict['data']: # there's only one item, but for posterity
    paragraphs = datum['paragraphs'] # for each paragraph get qas
    for paragraph in paragraphs:
      for entry in paragraph['qas']:
        entry_list.append(entry) # list of question answer sets (one per doc)
  if sort_key is not None:
    entry_list.sort(key= lambda x: x[sort_key]) # sort by given key (doc_ids usually)
  return entry_list

def get_context_list(data_dict):
  contexts = []
  for datum in data_dict['data']:
    paragraphs = datum['paragraphs']
    for paragraph in paragraphs:
      contexts.append(paragraph['context'])
  return contexts

"""
This gets the index splits for the data. Because of the way 
the np.split function works, train split will always go 0 to x, and eval
will go from x + 1 to len(dataset - 1). IF there is a validation split as well,
then train will go from 0 to X, eval from X + 1 to Y, and validation from 
Y + 1 to len(dataset - 1).
"""
def get_splits(size, train_percentage, eval_percentage, val_percentage=0):
  print(f'in get_splits - size: {size} train: {train_percentage} eval: {eval_percentage} val: {val_percentage}')
  if train_percentage + eval_percentage + val_percentage != 1.0:
    print("Error: percentages must add to 1")
    return 0, 0
  else:
    train_size = int(size * train_percentage) # convert percentage to int
    eval_size = int(size * eval_percentage)
    validation_size = int(size * val_percentage)
    remaining = size - (train_size + eval_size + validation_size) 
    # account for leftovers
    if remaining > 0:
      train_size = train_size + remaining # give leftovers due to rounding fto training dataset
    train_eval_split = train_size # 0 to X, eval gets X + 1 to Y (where Y is len(dataset) - validation size)
    if val_percentage == 0:
      val_split = 0
    else: # must be remaining from train/eval to end of list
      val_split = train_size + eval_size # val split begins at index where train/eval ends
    return train_eval_split, val_split

"""
Shuffles the given ids randomly, splits them, and returns splits which are
later used to label the SQuAD data dict. 
"""
def split_question_ids(question_ids, train, eval, val):
  print(f'in split_question_ids - size: {len(question_ids)} train: {train} eval: {eval} val: {val}')
  train_ids = []
  eval_ids = []
  val_ids = []
  random.shuffle(question_ids)
  train_eval_split, val_split = get_splits(len(question_ids), train, eval, val)
  print(f'train_eval_split: {train_eval_split} val_split: {val_split}')
  if val_split== 0: # if there's no validation set split into test/eval
    train_ids, eval_ids= np.split(question_ids, [train_eval_split])
  else: # if there IS a validation split split sets into three parts accordingly
  # split contexts
    train_ids, eval_ids, val_ids = np.split(question_ids, [train_eval_split, val_split])
  train_ids = list(train_ids)
  eval_ids = list(eval_ids)
  val_ids = list(val_ids)
  return train_ids, eval_ids, val_ids
"""Goes through all entries and labels with appropriate split label. This 
can then be used in the preprocessing notebook to split the sets and preprocess
them before training a model."""
def label_splits(data_dict,train_ids, eval_ids, val_ids):
  for datum in data_dict['data']:
    for paragraph in datum['paragraphs']:
      for entry in paragraph['qas']:
        id = entry['id']
        if id in train_ids:
          entry['split'] = 'train'
        elif id in eval_ids:
          entry['split'] = 'eval'
        elif id in val_ids:
          entry['split'] = 'val'
        else:
          print(f'error: question {id} has no split label')
  return

"""Splits the dataset and labels the splits so they can be saved"""
def split_dataset(data_dict, train, eval, val, augs_train_only=True):
  print(f'in split_dataset - train {train} eval {eval} val {val} augonly {augs_train_only}')
  train_ids = []
  eval_ids = []
  val_ids = []
  if augs_train_only: # if only using augmented data in training split
    unaug_ids = []
    aug_ids = []
    unaugmented = get_unaugmented(data_dict)
    augmented = get_augmented(data_dict)
    for entry in augmented:
      aug_ids.append(entry['id'])
    for entry in unaugmented:
      unaug_ids.append(entry['id'])
    train_ids, eval_ids ,val_ids = split_question_ids(unaug_ids, train, eval, val)
    train_ids.extend(aug_ids)
  else: # split indiscriminately
    question_ids = get_question_ids(data_dict)
    train_ids, eval_ids, val_ids = split_question_ids(question_ids, train, eval, val)
  # now label every question in the dataset according to its split
  label_splits(data_dict, train_ids, eval_ids, val_ids)
  return

"""Checks splits by ensuring len(train) + len(eval) + len(val) = len(total)"""
def check_splits(data_dict):
  train = 0
  eval = 0
  val = 0
  total = 0
  for datum in data_dict['data']:
    for paragraph in datum['paragraphs']:
      for entry in paragraph['qas']:
        total = total + 1
        if entry['split'] is 'train':
          train = train + 1
        elif entry['split'] is 'eval':
          eval = eval + 1 
        elif entry['split'] is 'val':
          val = val + 1
  return train, eval, val, total

"""Returns a newly generated unique question id"""
def get_new_id(question_ids):
  n = random.randint(0,999999)
  while n in question_ids: # generate until unique id found
    n = random.randint(0,999999)
  return n

"""
Generates a deep copy of the entry dict while replacing it with new question ids.
Also updates the question ids list with the new generated id.
"""

def replace_entry_ids(entry, question_ids):
  new_entry = copy.deepcopy(entry) # deep copy b/c nested objects
  new_id = get_new_id(question_ids) # get new id for generated question
  question_ids.append(new_id) # update question_ids list
  new_entry['id'] = new_id # add new question id
  for answer in new_entry['answers']: # update for every answer
    answer['question_id'] = new_id
  return new_entry

"""
Takes an entry, the question_ids list, and the replacement
question and generates a new entry with the updated 
question. The replacement question will be created
from an outside method using data agumentation"""

def generate_replacements(entry, question_ids, replacement_qs, aug_methods):
  new_entries = []

  for replacement_q in replacement_qs: # for all replacements
    new_entry = replace_entry_ids(entry, question_ids) # generate q_ids
    new_entry['question'] = replacement_q # insert new question
    new_entry['augmented'] = True # set augmented to true
    if 'aug_methods' in new_entry: # if previously augmented, add new method(s) to list
      for method in aug_methods:
        if method not in new_entry['aug_methods']:     
          new_entry['aug_methods'].append(method)
    else: # newly augmented entry, so just insert method(s) used
      new_entry['aug_methods'] = aug_methods
    new_entries.append(new_entry) # add this to list of new entries
  return new_entries # return list to be added to data_dict qas



"""Iterates through data starting at given document id (default starts at beginning)
and updates the data_dict. Also has an option to create a copy of the augmented entries so they can
be saved seperately."""

def paraphrase_supervised(data_dict,
                          doc_index=0, 
                          paragraph_index=0, 
                          qas_entry_index=0, 
                          copy_augments=False,
                          reaugment=False):
  
  stop=False
  ran = False 
  i = doc_index
  j = paragraph_index
  k = qas_entry_index

  question_ids = get_question_ids(data_dict)

  num_docs = len(data_dict['data'])

  while i < num_docs and not stop:
    doc = get_doc(data_dict, i)
    num_paragraphs = len(doc['paragraphs'])
    if ran:
      j = 0
    while j < num_paragraphs and not stop:
      paragraph = get_paragraph(doc, j)
      num_qas_entries = len(paragraph['qas'])

      if ran:
        k = 0

      while k < num_qas_entries and not stop:
        ran = True
        replacements = []
        entry = get_qas_entry(paragraph, k)
        # only augment if not previously augmented or reaug flag set
        if not entry['augmented'] or (entry['augmented'] and reaugment):
          stop = len(input(f"\nDOC {i} | PARAGRAPH {j} | QAS-ENTRY {k} | ENTER to continue or x to stop: ")) != 0
          if not stop:
            augments = get_response(entry['question'], num_augments, num_beams=10)
            replacements = generate_replacements(entry,
                                                 question_ids, 
                                                 replacement_qs = augments, 
                                                 aug_methods=['paraphrase_supervised'])
            insert_qas_entries(data_dict,replacements, i, j)
          else:
            print(f"STOPPING AT DOC{i} | PARAGRAPH {j} | ENTRY {k}")
          
        k = k + 1
      j = j + 1
    i = i + 1

  return data_dict # shallow copy!!!

def paraphrase_unsupervised(data_dict,
                              copy_augments=False,
                              reaugment=False,
                              num_augments=10,
                              alpha_paraphrase=1.0,
                              alpha_train = 1.0,
                              alpha_val=1.0):
  num_augs = 0
  ran = False
  i = 0
  j = 0
  k = 0

  question_ids = get_question_ids(data_dict)
  train_ids = get_split_ids(data_dict, split_label="train")
  train_ids = sample_list(train_ids, alpha_train)

  question_ids = get_question_ids(data_dict)
  num_docs = len(data_dict['data'])
  with tqdm(total=len(question_ids)) as pbar:
    while i < num_docs:
      doc = get_doc(data_dict, i)
      num_paragraphs = len(doc['paragraphs'])
      if ran:
        j = 0
      while j < num_paragraphs:
        paragraph = get_paragraph(doc, j)
        num_qas_entries = len(paragraph['qas'])

        if ran:
          k = 0

        while k < num_qas_entries:
          ran = True
          replacements = []
          entry = get_qas_entry(paragraph, k)
          if not entry['augmented'] or (entry['augmented'] and reaugment):
            if entry['id'] in train_ids: # only augment train split
              augments = get_response(entry['question'], num_augments, num_beams=10)
            #sample the given percentage of replacements randomly and then insert those
              augments = sample_list(augments, frac=alpha_paraphrase)
              replacements = generate_replacements(entry, 
                                                  question_ids, 
                                                  replacement_qs =augments,
                                                  aug_methods=["paraphrase_unsupervised"])
              insert_qas_entries(data_dict,replacements, i, j)
              num_augs = num_augs + len(replacements)
          pbar.update(1)
          k = k + 1
        j = j + 1
      i = i + 1
  print(f"{num_augs} augmented entries created")
  return 

def augment_eda(data_dict,
                reaugment=False,
                num_augments=10,
                percent_sampled=1.0,
                alpha_sr=0.1,
                alpha_ri=0.1,
                alpha_rs=0.1,
                p_rd=0.1,
                alpha_train = 0.1
                ):
  num_augs = 0
  aug_methods = []
  if alpha_sr > 0.0:
    aug_methods.append("eda_sentence_replacement")
  if alpha_ri > 0.0:
    aug_methods.append("eda_random_insertion")
  if alpha_rs > 0.0:
    aug_methods.append("eda_random_swap")
  if p_rd > 0.0:
    aug_methods.append("eda_random_deletion")

  nltk.download('wordnet')
  ran = False
  i = 0
  j = 0
  k = 0
  question_ids = get_question_ids(data_dict)
  train_ids = get_split_ids(data_dict, split_label="train")
  train_ids = sample_list(train_ids, alpha_train)
  num_docs = len(data_dict['data']) 

  with tqdm(total = len(question_ids)) as pbar:
    while i < num_docs:
      doc = get_doc(data_dict, i)
      num_paragraphs = len(doc['paragraphs'])
      if ran:
        j = 0
      while j < num_paragraphs:
        paragraph = get_paragraph(doc, j)
        num_qas_entries = len(paragraph['qas'])
        if ran:
          k = 0
        while k < num_qas_entries:
          ran = True
          replacements = []
          entry = get_qas_entry(paragraph, k)
          if not entry['augmented'] or (entry['augmented'] and reaugment):
            if entry['id'] in train_ids: # only augment train split
              augments = eda(entry['question'], alpha_sr, alpha_ri, alpha_rs, p_rd, num_augments)
            #sample the given percentage of replacements randomly and then insert those
              augments=sample_list(augments, frac=percent_sampled)
              replacements = generate_replacements(entry, 
                                                  question_ids, 
                                                  replacement_qs =augments,
                                                  aug_methods=aug_methods)
              insert_qas_entries(data_dict,replacements, i, j)
              num_augs = num_augs + len(replacements)
          k = k + 1
          pbar.update(1)
        j = j + 1
      i = i + 1
  print(f"{num_augs} augmented entries created")
  return 

"""Returns doc object at specified index"""
def get_doc(data_dict, index):
  doc = None
  try:
    doc = data_dict['data'][index]
  except:
    print(f"index {index} is not between 0 and {len(data_dict['data'])}")
  return doc

"""Returns paragraph object at specified index from given doc"""
def get_paragraph(doc, index):
  paragraph = None
  try:
    paragraph = doc['paragraphs'][index]
  except:
    print(f"index {index} is not between 0 and {len(doc['paragraphs'])}")
  return paragraph

"""Returns qas emtry at specified index from given paragraph"""
def get_qas_entry(paragraph, index):
  entry = None
  try:
    entry = paragraph['qas'][index]
  except:
    print(f"index {index} is not between 0 and {len(paragraph['qas'])}")
  return entry

"""Used for augmentation. Inserts given list of entries into qas object
at given doc and paragraph idx"""
def insert_qas_entries(data_dict, entries, doc_index, paragraph_index):
  for entry in entries:
    data_dict['data'][doc_index]['paragraphs'][paragraph_index]['qas'].append(entry)
  return 

"""A kind of useless definition to test that augmenting was working on a list of entries"""

def augment_entry_list(data_dict, entry_list, entry_index=0, copy_augments=False):
  question_ids = get_question_ids(data_dict)
  for i in range(entry_index, len(entry_list)):
    replacements = generate_replacements(entry_list[i], question_ids)
    for replacement in replacements:
      entry_list.append(replacement)
    return 

"""Returns all qas entries in dataset with matching key/value pair"""
def search_entries_by_value(data_dict, key, value):
  matches = []
  entries = get_entry_list(data_dict)
  for entry in entries:
    if ((type(entry[key]) is int) or (type(entry[key]) is bool))  and entry[key] == value:
      matches.append(entry)
    elif type(entry[key]) is str and value in entry[key]:
      matches.append(entry)
  return matches

"""Fast way to see split count, aug count, & aug type"""

def get_count_dict(data_dict):
  dataset_count = {'docs': 0, 'paragraphs': 0, 'questions': 0, 'answers': 0}
  split_count = {'train': 0, 'val' : 0, 'eval': 0}
  aug_count = {'unaugmented': 0, 'augmented' : 0}
  aug_method_count = {}
  for datum in data_dict['data']:
    dataset_count['docs'] +=1
    for paragraph in datum['paragraphs']:
      dataset_count['paragraphs'] +=1
      for entry in paragraph['qas']:
        dataset_count['questions'] +=1
        dataset_count['answers'] += len(entry['answers'])
        split_count[entry['split']] +=1
        if entry['augmented']:
          aug_count['augmented'] +=1
          for method in entry['aug_method']:
            if method in aug_method_count.keys():
              aug_method_count[method] +=1
            else:
              aug_method_count[method] = 1
        else:
          aug_count['unaugmented']+=1
  count_dict = {
      'totals': dataset_count,
      'splits':split_count,
      'augments':aug_count,
      'methods': aug_method_count
  }
  return count_dict



In [68]:
filename = 'newparaphrase.json' # fsu_datav6.json is final unaug'd & unaug_data_split.json is split version
data_dict = load_json(dirpath, filename)
#data_dict['data'][DOC-IDX]['paragraphs'][PAR-IDX]['context','document_id', 'qas']...
data_dict['data'][0]['paragraphs'][0]['qas'][0]

{'answers': [{'answer_category': None,
   'answer_id': 227675,
   'answer_start': 67,
   'document_id': 287676,
   'question_id': 173820,
   'text': 'National Institute on Aging '}],
 'aug_method': 'none',
 'augmented': False,
 'id': 173820,
 'is_impossible': False,
 'question': 'Where did the award come from?',
 'split': 'train'}

Split as desired (DO BEFORE AUGMENTING)



In [None]:
split_dataset(data_dict, train=.8, eval=.1, val=.1)

in split_dataset - train 0.8 eval 0.1 val 0.1 augonly True
in split_question_ids - size: 2276 train: 0.8 eval: 0.1 val: 0.1
in get_splits - size: 2276 train: 0.8 eval: 0.1 val: 0.1
train_eval_split: 1822 val_split: 2049


Augment using desired method

In [69]:
#paraphrase_unsupervised(data_dict, num_augments=10, alpha_paraphrase=.5, alpha_train=1.0) # ends up being 5
augment_eda(data_dict, num_augments=5, alpha_train=1)

  2%|▏         | 207/11386 [00:00<00:05, 2069.00it/s]

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


100%|██████████| 11386/11386 [00:06<00:00, 1878.85it/s]

10932 augmented entries created





Save dataset

In [70]:
save_json(dirpath, filename="newedapara.json",data_dict=data_dict)

Print & or save counts on splits, augments, etc.

In [71]:
count_dict = get_count_dict(data_dict)
print(count_dict)

{'totals': {'docs': 311, 'paragraphs': 311, 'questions': 22318, 'answers': 22318}, 'splits': {'train': 21864, 'val': 227, 'eval': 227}, 'augments': {'unaugmented': 2276, 'augmented': 20042}, 'methods': {'n': 40084, 'o': 20042, 'e': 20042}}
