In [7]:
import json
import re 
import os 
import argparse
import wget 
import random 
import numpy as np 
from task_tracker.dataset_creation.task_prompts import generic_task_prompts
import copy



## Load the dataset indices

In [67]:
# Load the index records that describe how each split is rebuilt from the
# retrieval datasets. Context managers close the files deterministically, and
# the encoding is spelled out so decoding does not depend on the platform's
# locale default.
with open('clean_val_indices.json', encoding='utf-8') as val_indices_file:
    clean_val_indices = json.load(val_indices_file)

# NOTE: despite its name, `test_val_indices` is read from the *test* index file.
with open('clean_test_indices.json', encoding='utf-8') as test_indices_file:
    test_val_indices = json.load(test_indices_file)

In [9]:
# Command-line style configuration for the notebook.
# Each entry below is (flag, default value, help text).
parser = argparse.ArgumentParser(prog='Dataset sampling')
for option_flag, option_default, option_help in (
    ('--datasets_dir', '../datasets', "dir to retrieval datasets files and other resources"),
    ('--sep_prompt', '../config_files/sep_prompt.txt', 'none, or a path to a file that contains defense prompt to explain/encourage separation'),
):
    parser.add_argument(option_flag, default=option_default, help=option_help)

# parse_known_args returns (namespace, leftover argv); unrecognised arguments
# are ignored instead of aborting.
args, _ = parser.parse_known_args()

# TODO: change the Hugging Face home so downloaded files are cached somewhere
# that exists on your machine.
os.environ.update({'HF_HOME': '/disk3/', 'TRANSFORMERS_CACHE': '/disk3/'})


## Download SQuAD and HotpotQA

In [10]:
# Registry of the raw files this notebook needs:
#   dataset name -> split ('train' / 'dev') -> local file name + download URL.
datasets_files = {
    'SQuAD': {
        'train': {
            'name': 'train-v2.0.json',
            'url': 'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json',
        },
        'dev': {
            'name': 'dev-v2.0.json',
            'url': 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json',
        },
    },
    'hotpot': {
        'train': {
            'name': 'hotpot_train_v1.1.json',
            'url': 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_train_v1.1.json',
        },
        'dev': {
            'name': 'hotpot_dev_fullwiki_v1.json',
            'url': 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json',
        },
    },
}

def download_datasets(datasets_dir, datasets):
    """Download the raw files of the requested datasets unless already present.

    For every name in ``datasets`` the directory ``<datasets_dir>/<name>`` is
    created if needed, and each file registered for that name in the
    module-level ``datasets_files`` is fetched with ``wget`` when no file with
    its registered name exists there yet.

    Args:
        datasets_dir: Root directory holding one sub-directory per dataset.
        datasets: Iterable of dataset names; each must be a key of
            ``datasets_files``.

    Returns:
        None.
    """
    for dataset_name in datasets:
        target_dir = os.path.join(datasets_dir, dataset_name)
        os.makedirs(target_dir, exist_ok=True)
        for file_info in datasets_files[dataset_name].values():
            destination = os.path.join(target_dir, file_info['name'])
            if os.path.isfile(destination):
                # Already downloaded on a previous run.
                continue
            wget.download(file_info['url'], destination)

# Make sure every registered raw file is on disk before it is read.
download_datasets(args.datasets_dir, list(datasets_files.keys()))


def _load_dataset_json(dataset, split):
    """Parse the downloaded JSON file registered for ``dataset`` / ``split``.

    The file is opened in a context manager so the handle is closed as soon as
    parsing finishes, and it is decoded as UTF-8 explicitly so the resulting
    text does not depend on the platform's locale default.
    """
    path = os.path.join(args.datasets_dir, dataset, datasets_files[dataset][split]['name'])
    with open(path, encoding='utf-8') as dataset_file:
        return json.load(dataset_file)


squad_train = _load_dataset_json('SQuAD', 'train')
squad_dev = _load_dataset_json('SQuAD', 'dev')

hotpot_train = _load_dataset_json('hotpot', 'train')
hotpot_dev = _load_dataset_json('hotpot', 'dev')


## Process retrieval files (no need to change)

In [11]:
# Load the datasets into a unified format: a list of items, each shaped like
# {'context': <TEXT PARAGRAPH>, 'questions': [{'question': ..., 'answer': ...}, ...]}.
# 'questions' is a list: SQuAD usually has several questions for each context.
# HotpotQA usually has one question with many paragraphs; currently the first three paragraphs are simply concatenated.
def process_squad(squad_file):
    """Flatten a parsed SQuAD file into the unified retrieval format.

    Args:
        squad_file: Parsed SQuAD JSON; ``squad_file['data']`` is iterated, and
            each element's ``'paragraphs'`` provide ``'context'`` and ``'qas'``.

    Returns:
        A list of ``{'context': str, 'questions': [{'question', 'answer'}, ...]}``
        dicts, one per paragraph that has at least one question with an answer.
        Questions with an empty ``'answers'`` list are dropped, and only the
        first listed answer's text is kept.
    """
    unified = []
    for article in squad_file['data']:
        for paragraph in article['paragraphs']:
            context_text = paragraph['context']
            answered = [
                {'question': qa['question'], 'answer': qa['answers'][0]['text']}
                for qa in paragraph['qas']
                if len(qa['answers']) != 0
            ]
            # Paragraphs left without any answered question are skipped.
            if answered:
                unified.append({'context': context_text, 'questions': answered})
    return unified

def process_hotpot(hotpot_file):
    """Flatten a parsed HotpotQA file into the unified retrieval format.

    Args:
        hotpot_file: Iterable of examples, each with ``'context'``,
            ``'question'`` and ``'answer'`` entries. Element ``[1]`` of every
            context entry is joined with single spaces.

    Returns:
        A list of ``{'context': str, 'questions': [{'question', 'answer'}]}``
        dicts, one per example whose merged context is non-empty.
    """
    unified = []
    for example in hotpot_file:
        # Only the first three context entries are used. Their sentences are
        # space-joined, and the resulting paragraphs are glued together with
        # no separator in between.
        merged_context = ''.join(
            ' '.join(paragraph[1]) for paragraph in example['context'][:3]
        )
        if not merged_context:
            continue
        qa_pair = {'question': example['question'], 'answer': example['answer']}
        unified.append({'context': merged_context, 'questions': [qa_pair]})
    return unified

# Unified-format view of every raw dataset: dataset name -> split -> list of
# {'context', 'questions'} items.
retrieval_datasets = {
    'SQuAD': {
        'train': process_squad(squad_train),
        'dev': process_squad(squad_dev),
    },
    'hotpot': {
        'train': process_hotpot(hotpot_train),
        'dev': process_hotpot(hotpot_dev),
    },
}

In [40]:
def get_final_paragraph(old_item):
    """Return a copy of ``old_item`` with the aggregated prompt filled in.

    The prompt is the item's ``'sep_prompt'``, ``'primary_task_prompt'`` and
    ``'final_text_paragraph'`` strings, each followed by a single space, and is
    stored under ``'final_aggregated_prompt'``. ``old_item`` itself is left
    untouched because the work is done on a deep copy.
    """
    result = copy.deepcopy(old_item)
    prompt_parts = [
        result[key]
        for key in ('sep_prompt', 'primary_task_prompt', 'final_text_paragraph')
    ]
    # The trailing space after the last part is part of the prompt format.
    result['final_aggregated_prompt'] = ' '.join(prompt_parts) + ' '
    return result



def fetch_elements(old_item):
    """Return a copy of an index record with its text and task prompt resolved.

    The record's ``'text_data_src'``, ``'split'`` and ``'text_idx'`` select an
    entry of the module-level ``retrieval_datasets``; ``'primary_task_index'``
    selects one of that entry's questions. On the returned deep copy:

    * ``'primary_task_prompt'`` is the selected question. When
      ``'primary_task_type'`` is ``'mix'``, the record's original
      ``'primary_task_prompt'`` is appended after a fixed connecting sentence.
    * ``'orig_text'`` and ``'final_text_paragraph'`` are both the entry's
      context.

    ``old_item`` itself is not modified.
    """
    resolved = copy.deepcopy(old_item)

    text_position = int(resolved['text_idx'])
    source_entry = retrieval_datasets[resolved['text_data_src']][resolved['split']][text_position]
    context_text = source_entry['context']

    question_position = int(resolved['primary_task_index'])
    task_prompt = source_entry['questions'][question_position]['question']

    if resolved['primary_task_type'] == 'mix':
        # Mixed tasks chain the record's own task prompt after the question.
        task_prompt = (
            task_prompt
            + '. After you answer the previous question, please do the following task as well:'
            + resolved['primary_task_prompt']
        )

    resolved.update({
        'primary_task_prompt': task_prompt,
        'orig_text': context_text,
        'final_text_paragraph': context_text,
    })
    return resolved
    

## Build the dataset and verify (with hashes)

In [58]:
# Placeholder; the real value is assigned once the validation split is rebuilt.
reconstructed_val = []
import hashlib


def build_dataset(dataset_indices):
    """Rebuild full dataset items from index records and verify their hashes.

    Every record goes through ``fetch_elements`` and ``get_final_paragraph``.
    The resulting ``'final_aggregated_prompt'`` is then fed into a SHA-256
    hasher whose hex digest must equal the record's ``'final_prompt_hash'``.

    NOTE(review): the hasher is created once and never reset, so the digest
    checked for item *k* covers the prompts of items 0..k, not item *k* alone.
    The stored hashes were presumably produced the same way - confirm against
    the script that generated the index files before changing this.

    Args:
        dataset_indices: Iterable of index records (dicts).

    Returns:
        List of the reconstructed items, in input order.

    Raises:
        AssertionError: If a reconstructed prompt does not match its stored
            hash. It is raised explicitly, so the check also runs under
            ``python -O``, which strips ``assert`` statements.
    """
    hash_object = hashlib.sha256()
    reconstructed_dataset = []
    for item_index, item in enumerate(dataset_indices):
        new_item = get_final_paragraph(fetch_elements(item))

        hash_object.update(new_item['final_aggregated_prompt'].encode('utf-8'))
        hex_dig = hash_object.hexdigest()

        expected_hash = new_item['final_prompt_hash']
        if hex_dig != expected_hash:
            raise AssertionError(
                'Hash mismatch for item {}: expected {!r}, got {!r}'.format(
                    item_index, expected_hash, hex_dig
                )
            )

        reconstructed_dataset.append(new_item)

    return reconstructed_dataset


In [68]:
# Rebuild both splits; each call also verifies the stored prompt hashes.
reconstructed_val = build_dataset(dataset_indices=clean_val_indices)
reconstructed_test = build_dataset(dataset_indices=test_val_indices)

## Save the dataset

In [71]:
# Create the output directory first so a fresh checkout does not fail with
# FileNotFoundError when the files are opened for writing.
os.makedirs('../dataset_sampled', exist_ok=True)

# Validation split.
with open('../dataset_sampled/dataset_out_clean.json', 'w') as fout:
    json.dump(reconstructed_val, fout)

# Test split.
with open('../dataset_sampled/dataset_out_clean_v2.json', 'w') as fout:
    json.dump(reconstructed_test, fout)