<a href="https://colab.research.google.com/github/halnegheimish/ForcedInvalidation/blob/main/notebooks/FI_Augment_Dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the notebook's dependencies; datasets is pinned to 2.7.1, transformers is unpinned.
# NOTE(review): the pin is presumably needed because load_metric (used below) is deprecated in newer datasets releases -- confirm.
! pip install datasets==2.7.1 transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
#preliminaries
import itertools
import random
import numpy as np
import torch
import copy
# Single seed shared by every random number generator used in this notebook.
seed = 0

# shuffle() draws from the stdlib `random` module, while shuffle_sentence_ngrams()
# samples the n-gram size with np.random, so both generators are seeded here.
random.seed(seed)
np.random.seed(seed)

from transformers import set_seed

# NOTE(review): transformers.set_seed is expected to seed the Python, NumPy and
# torch generators in one call, which would make the two calls above redundant
# but harmless -- confirm against the transformers documentation.
set_seed(seed)



#the main shuffle function
def shuffle(text, n, max_attempts=100):
  """Shuffles input text based on a specific n-gram size.

  A single trailing '!', '?' or '.' is detached before shuffling and put back
  at the end of the result. The remaining text is split on single spaces,
  consecutive words are grouped into units of n words (the last unit may be
  shorter) and the units are randomly reordered.

  parameters:
    text: string to be shuffled
    n: size of ngram to use; [1,2,3] in this notebook, but any positive
      integer is accepted
    max_attempts: number of random orderings to try before giving up
  returns:
    shuffled_text: the shuffled input text string, s.t. it is never equal to
      the input text
  raises:
    ValueError: when n is smaller than 1
    RecursionError: when no ordering different from the input could be
      produced (e.g. empty text, or all units identical). The exception type
      is kept from the earlier recursive implementation so that existing
      `except RecursionError` handlers keep working.
  """
  n=int(n)
  if n < 1:
    raise ValueError('n must be a positive integer, got '+str(n))

  #endswith also copes with an empty string, where text[-1] would fail
  if text.endswith(('!', '?', '.')):
    punc=text[-1]
    body=text[:-1]
  else:
    punc=''
    body=text

  tokens=body.split(" ")

  #join n-grams
  units=[' '.join(tokens[i:i+n]) for i in range(0, len(tokens), n)]
  units[-1]=units[-1].strip()

  #when every unit is identical no ordering can differ, so fail fast
  if len(set(units)) > 1:
    for _ in range(max_attempts):
      #always reshuffle from the original order, as a fresh recursive call did
      candidate=list(units)
      random.shuffle(candidate)
      shuffled_text=' '.join(candidate)+punc

      #shuffled sentences should not match original
      if shuffled_text != text:
        return shuffled_text

  raise RecursionError('could not produce a shuffle that differs from the original text')


#in the following function, we need to change:
#1) names of dataset keys, sentence, idx, label
#2) need to make a choice between premise and hypothesis
#3) label needs to be 1 more than original

def shuffle_sentence_ngrams(example, p1_key='premise', p2_key='hypothesis', 
                            key=None, ngram=None, label=None):
  """Returns a copy of example in which one text component is n-gram shuffled
  parameters:
    example: example to be shuffled, it is not modified
    p1_key: key of the first component, e.g. premise
    p2_key: key of the second component, e.g. hypothesis
    key: key of the component to shuffle, if none p1_key or p2_key is picked
      uniformly at random
    ngram: size of ngram to use, should be in [1,2,3], if none sample uniformly
    label: the invalid label, in case it's none the original label is preserved

  returns:
    a deep copy of the example with the chosen component shuffled. If the
    component could not be shuffled it is replaced by the placeholder
    "RECURSIONERROR". When label is given, 'label' is overwritten with it and
    '_<ngram>g' is appended to 'idx' (if the example has one).
  """
  sh_example= copy.deepcopy(example)

  if not ngram: #sample uniformly
    ngram= np.random.choice([1,2,3], 1)[0]

  #uniformly choose which component to shuffle
  if not key:
    flag= random.randint(0,1)
    key= p1_key if flag==1 else p2_key

  try:
    sh_example[key]= shuffle(sh_example[key], ngram) 

  except RecursionError:
    #shuffle could not produce a text different from the original
    print('Recursion error occured for '+ sh_example[key])
    sh_example[key]="RECURSIONERROR"

  #compare against None so that a label of 0 is applied as well, and so that
  #idx and label are always changed together
  if label is not None:
    #only change idx in training examples, when label is changed;
    #an idx of 0 is a valid index and must get the suffix too
    if 'idx' in sh_example and sh_example['idx'] is not None:
      sh_example['idx']=str(sh_example['idx'])+'_'+str(ngram)+'g'

    sh_example['label']=label 

  return sh_example

def is_all_longer_3(example, p1_key='premise', p2_key='hypothesis'):
  """Tells whether both text components of an example have more than 3 words
  parameters:
    example: example of interest
    p1_key: key of the first component, e.g. premise
    p2_key: key of the second component, e.g. hypothesis
  returns:
    True when splitting each component on single spaces gives more than 3
    pieces, False otherwise
  """
  min_words = 3
  #components are checked in order, stopping at the first one that is too short
  return all(len(example[field].split(" ")) > min_words
             for field in (p1_key, p2_key))


def to_keep(example, p1_key='premise', p2_key='hypothesis'):
  """Tells whether an example is free of the shuffling failure placeholder
  parameters:
    example: example of interest
    p1_key: key of the first component, e.g. premise
    p2_key: key of the second component, e.g. hypothesis
  returns:
    False when either component contains 'RECURSIONERROR', True otherwise
  """
  #components are checked in order, stopping at the first one that is flagged
  flagged = any('RECURSIONERROR' in example[field] for field in (p1_key, p2_key))
  return not flagged
        

In [None]:
#create train datasets
# Builds the augmented train/dev sets: every kept MNLI example appears once
# unchanged and once with a shuffled premise or hypothesis relabelled 'invalid'.
from datasets import load_dataset, load_metric, concatenate_datasets, ClassLabel, Value

#load dataset
mnli = load_dataset("glue", "mnli")
# NOTE(review): metric is not used anywhere in the visible notebook -- confirm
# whether it is still needed.
metric = load_metric('glue', 'mnli')

#create train/dev split
# NOTE(review): no seed is passed here, so the split presumably depends on the
# global RNG state at the time this cell runs -- confirm and consider seed=seed.
clean_split= mnli['train'].train_test_split(test_size=0.1)

#filter out short examples
# keeps only examples whose premise and hypothesis both have more than 3 words
longer_cs=clean_split.filter(is_all_longer_3)

#change dataset metadata
# adds a fourth class 'invalid' (index 3) for shuffled text, and turns idx into
# a string because shuffle_sentence_ngrams appends a '_<n>g' suffix to it
new_features = longer_cs['train'].features.copy()
new_features['label'] = ClassLabel(names=['entailment', 'neutral', 'contradiction', 'invalid'])
new_features['idx'] = Value('string')

longer_cs = longer_cs.cast(new_features)

#shuffle
# label 3 is 'invalid'; component and n-gram size are sampled per example
# NOTE(review): with num_proc=4 the worker processes may start from the same
# RNG state -- confirm this is acceptable for the sampling above.
shuff_mnli=longer_cs.map(shuffle_sentence_ngrams, fn_kwargs={"label": 3}, num_proc=4)
shuff_mnli=shuff_mnli.filter(to_keep) #remove failed examples


#combine shuffled and original data
ngrams_train=concatenate_datasets([shuff_mnli['train'], longer_cs['train']])
ngrams_dev=concatenate_datasets([shuff_mnli['test'], longer_cs['test']])

# mix shuffled and original rows, with a fixed seed
ngrams_train=ngrams_train.shuffle(seed=42)
ngrams_dev=ngrams_dev.shuffle(seed=42)

#save to disk
ngrams_train.save_to_disk('mnli_train_ngrams')
ngrams_dev.save_to_disk('mnli_val_ngrams')



  0%|          | 0/5 [00:00<?, ?it/s]

  metric = load_metric('glue', 'mnli')


     

#0:   0%|          | 0/84703 [00:00<?, ?ex/s]

 

#2:   0%|          | 0/84702 [00:00<?, ?ex/s]

 

#1:   0%|          | 0/84702 [00:00<?, ?ex/s]

 

#3:   0%|          | 0/84702 [00:00<?, ?ex/s]

Recursion error occured for you bet you you bet you
      

#1:   0%|          | 0/9421 [00:00<?, ?ex/s]

 

#0:   0%|          | 0/9421 [00:00<?, ?ex/s]

 

#2:   0%|          | 0/9420 [00:00<?, ?ex/s]

#3:   0%|          | 0/9420 [00:00<?, ?ex/s]

  0%|          | 0/339 [00:00<?, ?ba/s]

  0%|          | 0/38 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/678 [00:00<?, ?ba/s]

Flattening the indices:   0%|          | 0/76 [00:00<?, ?ba/s]

In [None]:
# Show the first ten rows of the combined training set as a quick sanity check.
ngrams_train[0:10]

{'premise': ['uh stuff that we for know least of you three week a have do we four at like a movies bicycle and yeah to uh treadmill that we kind times too a get twenty like and minutes and try a we',
  'national socialists now We are all.',
  '"Johnny Shannon now he was a lieutenant with Howard\'s Rangers." Callie gave Drew a shrewd measuring look.',
  'In some cases, the available crane or the crane pricing may limit the largest piece to be lifted, and the construction plan may be modified to accommodate a smaller crane by lifting smaller pieces.',
  'It has a 14th-century Gothic church with a characteristically Austrian polychrome tiled roof.',
  'programs that have addressed the same concern.',
  "Of course, I don't know that they killed you first--but those are their methods.",
  'The agency also is in danger of losing $470,000 over the next 18 months from the Violence Against Women Act, but is appealing that decision, Mathews said.',
  'A mailer in New York could be sending mail t

In [None]:
#create eval datasets
# One shuffled copy of the matched validation set per (component, n-gram size).
# No label is passed, so the original labels and idx values are preserved.
longer_val= mnli['validation_matched'].filter(is_all_longer_3, num_proc=4)

# NOTE(review): the six calls below differ only in ngram and key; a loop over
# both would remove the duplication.
# NOTE(review): unlike the train cell, to_keep is not applied here, so examples
# that could not be shuffled keep the "RECURSIONERROR" placeholder text --
# confirm this is intended.

# shuffled premise, original hypothesis
dev_3gram_p= longer_val.map(shuffle_sentence_ngrams, fn_kwargs={'ngram':3, 'key':'premise'}, num_proc=4)
dev_2gram_p= longer_val.map(shuffle_sentence_ngrams, fn_kwargs={'ngram':2, 'key':'premise'}, num_proc=4)
dev_1gram_p= longer_val.map(shuffle_sentence_ngrams, fn_kwargs={'ngram':1, 'key':'premise'}, num_proc=4)

# original premise, shuffled hypothesis
dev_3gram_h= longer_val.map(shuffle_sentence_ngrams, fn_kwargs={'ngram':3, 'key':'hypothesis'}, num_proc=4)
dev_2gram_h= longer_val.map(shuffle_sentence_ngrams, fn_kwargs={'ngram':2, 'key':'hypothesis'}, num_proc=4)
dev_1gram_h= longer_val.map(shuffle_sentence_ngrams, fn_kwargs={'ngram':1, 'key':'hypothesis'}, num_proc=4)

# save each evaluation set to its own directory
dev_1gram_p.save_to_disk('mnli_dev_1gram_p')
dev_2gram_p.save_to_disk('mnli_dev_2gram_p')
dev_3gram_p.save_to_disk('mnli_dev_3gram_p')

dev_1gram_h.save_to_disk('mnli_dev_1gram_h')
dev_2gram_h.save_to_disk('mnli_dev_2gram_h')
dev_3gram_h.save_to_disk('mnli_dev_3gram_h')

 



 



 



 



        

#1:   0%|          | 0/2322 [00:00<?, ?ex/s]

#0:   0%|          | 0/2323 [00:00<?, ?ex/s]

#2:   0%|          | 0/2322 [00:00<?, ?ex/s]

#3:   0%|          | 0/2322 [00:00<?, ?ex/s]

       

#1:   0%|          | 0/2322 [00:00<?, ?ex/s]

 

#0:   0%|          | 0/2323 [00:00<?, ?ex/s]

#2:   0%|          | 0/2322 [00:00<?, ?ex/s]

#3:   0%|          | 0/2322 [00:00<?, ?ex/s]

       

#0:   0%|          | 0/2323 [00:00<?, ?ex/s]

#1:   0%|          | 0/2322 [00:00<?, ?ex/s]

 

#2:   0%|          | 0/2322 [00:00<?, ?ex/s]

#3:   0%|          | 0/2322 [00:00<?, ?ex/s]

        

#1:   0%|          | 0/2322 [00:00<?, ?ex/s]

#0:   0%|          | 0/2323 [00:00<?, ?ex/s]

#2:   0%|          | 0/2322 [00:00<?, ?ex/s]

#3:   0%|          | 0/2322 [00:00<?, ?ex/s]

        

#0:   0%|          | 0/2323 [00:00<?, ?ex/s]

#1:   0%|          | 0/2322 [00:00<?, ?ex/s]

#2:   0%|          | 0/2322 [00:00<?, ?ex/s]

#3:   0%|          | 0/2322 [00:00<?, ?ex/s]

        

#0:   0%|          | 0/2323 [00:00<?, ?ex/s]

#2:   0%|          | 0/2322 [00:00<?, ?ex/s]

#1:   0%|          | 0/2322 [00:00<?, ?ex/s]

#3:   0%|          | 0/2322 [00:00<?, ?ex/s]