In [1]:
import json
import re 
import os 
import argparse
import wget 
import random 
import numpy as np 
from task_tracker.dataset_creation.task_prompts import generic_task_prompts
import copy

## Load the training-data indices

In [2]:
# The training index records ship as two JSON parts; concatenate them in order.
# `with` closes each file handle instead of leaking it for the kernel's lifetime.
with open('train_indices_pt1.json', encoding='utf-8') as part1_file:
    train_data_indices_pt1 = json.load(part1_file)
with open('train_indices_pt2.json', encoding='utf-8') as part2_file:
    train_data_indices_pt2 = json.load(part2_file)

train_data_indices = train_data_indices_pt1 + train_data_indices_pt2

In [3]:
parser = argparse.ArgumentParser(
                    prog='Dataset sampling')
parser.add_argument('--datasets_dir', default='../datasets', help="dir to retrieval datasets files and other resources") 
parser.add_argument('--trivia_files', default='trivia_pairs.txt', help='list of probe-answer pairs') 
parser.add_argument('--saved_injections', default='saved_injections.txt', help="path to a file of saved injections generated earlier") 
parser.add_argument('--sep_prompt', default='../config_files/sep_prompt.txt', help='none, or a path to a file that contains defense prompt to explain/encourage separation')  
# Cache directory for Hugging Face downloads. The default keeps the previous
# hard-coded location; override it instead of editing the notebook.
parser.add_argument('--hf_home', default='/disk3/', help='cache directory for Hugging Face downloads (sets HF_HOME and TRANSFORMERS_CACHE)')

# parse_known_args (rather than parse_args) tolerates unrecognized
# command-line arguments, such as those a notebook kernel is launched with.
args, _ = parser.parse_known_args()

# Set the HF cache location before `datasets` is imported further down
# (Alpaca cell) so downloads land in this directory.
os.environ['HF_HOME'] = args.hf_home
os.environ['TRANSFORMERS_CACHE'] = args.hf_home


## Download SQuAD and HotpotQA

In [4]:
# Remote source for every split of the two retrieval corpora, keyed by
# dataset -> split. 'name' is the local filename stored under
# <datasets_dir>/<dataset>/, 'url' is where it is fetched from.
datasets_files = {
    'SQuAD': {
        'train': {
            'name': 'train-v2.0.json',
            'url': 'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json',
        },
        'dev': {
            'name': 'dev-v2.0.json',
            'url': 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json',
        },
    },
    'hotpot': {
        'train': {
            'name': 'hotpot_train_v1.1.json',
            'url': 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_train_v1.1.json',
        },
        'dev': {
            'name': 'hotpot_dev_fullwiki_v1.json',
            'url': 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json',
        },
    },
}

def download_datasets(datasets_dir, datasets):
    """Fetch any split files that are not yet on disk.

    For each dataset name in ``datasets`` this ensures the directory
    ``<datasets_dir>/<dataset>/`` exists, then downloads every split listed in
    ``datasets_files[dataset]`` whose file is missing. Files already present
    are left untouched, so repeated runs are cheap.
    """
    for dataset in datasets:
        dataset_dir = os.path.join(datasets_dir, dataset)
        os.makedirs(dataset_dir, exist_ok=True)
        for split_info in datasets_files[dataset].values():
            target_path = os.path.join(dataset_dir, split_info['name'])
            if os.path.isfile(target_path):
                continue  # already downloaded
            wget.download(split_info['url'], target_path)

download_datasets(args.datasets_dir, list(datasets_files))


def _load_split(dataset, split):
    """Read one downloaded split file (named in ``datasets_files``) as JSON."""
    path = os.path.join(args.datasets_dir, dataset, datasets_files[dataset][split]['name'])
    # `with` closes the handle; explicit UTF-8 avoids locale-dependent decoding.
    with open(path, encoding='utf-8') as split_file:
        return json.load(split_file)


squad_train = _load_split('SQuAD', 'train')
squad_dev = _load_split('SQuAD', 'dev')

hotpot_train = _load_split('hotpot', 'train')
hotpot_dev = _load_split('hotpot', 'dev')


## Download Alpaca

In [5]:
# Alpaca instruction dataset: keep only the instructions that need no extra
# `input` field, as plain strings.
# NOTE(review): `datasets` is imported here rather than at the top, which places
# it after HF_HOME/TRANSFORMERS_CACHE are set — presumably intentional so the
# cache location takes effect; keep this ordering.
from datasets import load_dataset
alpaca_dataset = [
    record['instruction']
    for record in load_dataset("tatsu-lab/alpaca")['train']
    if record['input'] == ''
]

## Load dependencies from the "datasets" directory

In [6]:

# Load the trivia ("witness") probes. Each line is expected to contain a
# "Question:" part followed by an "Answer:" part; only the question text is
# kept (the list name is historical).
trivia_path = os.path.join(args.datasets_dir, args.trivia_files)
trivia_pairs = []
with open(trivia_path, "r", encoding="utf-8") as trivia_file:
    for line_number, line in enumerate(trivia_file, start=1):
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing empty line
        answer_match = re.search("Answer:.*", line)
        if answer_match is None:
            raise ValueError(
                f"{trivia_path}:{line_number}: expected an 'Answer:' field in line {line!r}")
        # Drop the whole "Answer: ..." segment and the "Question:" label; the
        # remainder is the probe question.
        question = line.replace(answer_match.group(0), "").replace("Question:", "").strip()
        trivia_pairs.append(question)


def load_sep_prompt():
    """Return the separation ("defense") prompt text, or '' when disabled.

    The prompt tells the model how to keep the task and the data apart (it can
    be as simple as "consider the following task/question that you should
    answer"). It is read verbatim from the path given by ``--sep_prompt``; the
    literal value ``'none'`` disables it.
    """
    prompt_path = args.sep_prompt
    if prompt_path == 'none':
        return ''
    with open(prompt_path, "r") as prompt_file:
        return prompt_file.read()

def load_saved_injections():
    """Load the trigger sentences saved earlier (generated via GPT-4).

    The triggers are placed before the payload of an injected secondary task.
    Each line of ``<datasets_dir>/<saved_injections>`` becomes one entry with
    surrounding whitespace stripped; blank lines are kept so that line
    positions (used as trigger indices) are preserved.

    Returns:
        list[str]: the injections, or ``[]`` when ``--saved_injections`` is
        ``'none'``.
    """
    # The opt-out is keyed on --saved_injections itself, not --sep_prompt: the
    # two inputs are independent, and fetch_elements_train indexes into this
    # list whenever a record has a numeric secondary task index.
    if args.saved_injections == 'none':
        return []
    injections_path = os.path.join(args.datasets_dir, args.saved_injections)
    with open(injections_path, "r", encoding="utf-8") as injections_file:
        return [injection.strip() for injection in injections_file]
        
        
sep_prompt = load_sep_prompt()
saved_injections = load_saved_injections()
# Reverse lookup: injection text -> its position in saved_injections
# (if a line occurs more than once, the last position wins).
saved_injections_to_idx = {
    injection: injection_index
    for injection_index, injection in enumerate(saved_injections)
}

## Process retrieval files (no change needed)

In [7]:
# Load the retrieval datasets into a unified format: a list of items, each
# {'context': <TEXT PARAGRAPH>, 'questions': [{'question': ..., 'answer': ...}, ...]}.
# SQuAD usually has several questions per context.
# HotpotQA has one question per item spread over many paragraphs; for now the paragraphs are simply concatenated.
def process_squad(squad_file):
    """Flatten a SQuAD file into the unified retrieval format.

    Each paragraph becomes one item whose question list holds every question
    that has at least one answer (only the first answer's text is kept).
    Paragraphs left with no answerable question are dropped.
    """
    unified_items = []
    for article in squad_file['data']:
        for paragraph in article['paragraphs']:
            context = paragraph['context']
            answered = [
                {'question': qa['question'], 'answer': qa['answers'][0]['text']}
                for qa in paragraph['qas']
                if len(qa['answers']) != 0
            ]
            if answered:
                unified_items.append({'context': context, 'questions': answered})
    return unified_items

def process_hotpot(hotpot_file):
    """Convert HotpotQA records into the unified retrieval format.

    Every record yields exactly one item. All of its context paragraphs are
    flattened into a single string, each rendered as
    ``'Article title: <title>. <sentence> <sentence> ...'``, and the record's
    question/answer pair becomes the item's only question.
    """
    def render(entry):
        # entry[0] is the article title, entry[1] its list of sentences.
        return 'Article title: ' + entry[0] + '. ' + ' '.join(entry[1])

    return [
        {
            # Paragraphs are joined with no separator between them. Do not
            # "fix" this: the prompt hashes verified later are computed over
            # text that includes this exact context string.
            'context': ''.join(render(entry) for entry in record['context']),
            'questions': [{'question': record['question'], 'answer': record['answer']}],
        }
        for record in hotpot_file
    ]

# Unified view of both retrieval corpora: dataset name -> split -> list of
# {'context', 'questions'} items.
retrieval_datasets = {
    'SQuAD': {'train': process_squad(squad_train), 'dev': process_squad(squad_dev)},
    'hotpot': {'train': process_hotpot(hotpot_train), 'dev': process_hotpot(hotpot_dev)},
}

In [8]:

# Pools of secondary (injected) tasks, keyed by secondary_task_type.
# 'alpaca' and 'witness' are flat lists; 'generic' is itself keyed by task type
# and is therefore indexed one level deeper.
train_secondary_tasks = {
    'alpaca': alpaca_dataset,
    'witness': trivia_pairs,
    'generic': generic_task_prompts,
}

def embed_secondary_task(old_item):
    """Insert the injected instruction into the item's text paragraph.

    The instruction is ``trigger`` and ``secondary_task_prompt`` joined by one
    space, or whichever of the two is non-empty. Its placement is controlled by
    ``pos``: ``0`` prepends it, ``-1`` appends it, and any other value is used
    as a slice offset into ``orig_text``.

    Args:
        old_item: record with 'orig_text', 'secondary_task_prompt', 'trigger'
            and 'pos'. It is deep-copied and never mutated.

    Returns:
        A copy of ``old_item`` with 'final_text_paragraph' (plain text with the
        instruction inserted) and 'annotated_paragraph' (the same split marked
        with <PART_1>/<INSERTED>/<PART_2> tags).

    Raises:
        ValueError: if both the trigger and the secondary prompt are empty,
            since there is nothing to insert.
    """
    item = copy.deepcopy(old_item)

    orig_text = item['orig_text']
    secondary_prompt = item['secondary_task_prompt']
    trigger = item['trigger']
    pos = item['pos']

    if trigger and secondary_prompt:
        embedded_instruction = trigger + ' ' + secondary_prompt
    elif secondary_prompt:
        embedded_instruction = secondary_prompt
    elif trigger:
        embedded_instruction = trigger
    else:
        # Nothing to insert: fail with a clear message instead of leaving the
        # paragraph/annotation variables undefined.
        raise ValueError(
            "embed_secondary_task: item has neither a trigger nor a "
            "secondary task prompt to embed")

    annotated_inserted = " <INSERTED> " + embedded_instruction + " </INSERTED>"

    if pos == 0:
        final_text_paragraph = embedded_instruction + ' ' + orig_text
        annotated_part1 = " <PART_1> </PART_1>"
        annotated_part2 = " <PART_2> " + orig_text + " </PART_2>"
    elif pos == -1:
        final_text_paragraph = orig_text + ' ' + embedded_instruction
        annotated_part1 = " <PART_1> " + orig_text + " </PART_1>"
        annotated_part2 = " <PART_2> </PART_2>"
    else:
        # Split at a character offset; the instruction is space-separated from
        # both halves.
        final_text_paragraph = " ".join((orig_text[:pos], embedded_instruction, orig_text[pos:]))
        annotated_part1 = " <PART_1> " + orig_text[:pos] + " </PART_1>"
        annotated_part2 = " <PART_2> " + orig_text[pos:] + " </PART_2>"

    item['final_text_paragraph'] = final_text_paragraph
    item['annotated_paragraph'] = annotated_part1 + annotated_inserted + annotated_part2
    return item
    
def get_final_paragraph(old_item):
    """Assemble the full prompt for one record.

    Returns a deep copy of ``old_item`` with 'final_aggregated_prompt' set to
    the separation prompt, the primary task prompt and the (possibly injected)
    text paragraph, each followed by a single space.
    """
    item = copy.deepcopy(old_item)
    prompt_parts = (
        item['sep_prompt'],
        item['primary_task_prompt'],
        item['final_text_paragraph'],
    )
    item['final_aggregated_prompt'] = ' '.join(prompt_parts) + ' '
    return item



def fetch_elements_train(old_item):
    """Resolve the index fields of a training record into actual text.

    Reads from module-level tables: ``retrieval_datasets``,
    ``generic_task_prompts``, ``train_secondary_tasks`` and ``saved_injections``.

    Sets on a deep copy of ``old_item``:
      * 'orig_text': the source paragraph (``text_data_src``/``split``/``text_idx``).
      * 'primary_task_prompt': the dataset question when ``primary_task_type``
        is 'qa', otherwise a generic prompt of that type.
      * 'secondary_task_prompt' and 'trigger': if ``secondary_task_index`` is a
        non-negative integer, the task is taken from the matching pool and
        ``trigger`` (an index) is replaced by the saved injection text. If the
        index is empty, the stored ``trigger`` text *is* the secondary task and
        the trigger becomes ''.

    Raises:
        ValueError: if ``secondary_task_index`` is neither an integer nor empty.
    """
    item = copy.deepcopy(old_item)

    # Look up the source record once; both the text and (for 'qa') the
    # question come from it.
    source_record = retrieval_datasets[item['text_data_src']][item['split']][int(item['text_idx'])]
    orig_text = source_record['context']

    question_index = int(item['primary_task_index'])
    if item['primary_task_type'] == 'qa':
        primary_task_prompt = source_record['questions'][question_index]['question']
    else:
        primary_task_prompt = generic_task_prompts[item['primary_task_type']][question_index]

    secondary_type = item['secondary_task_type']
    if secondary_type in ('alpaca', 'witness'):
        corr_secondary_tasks = train_secondary_tasks[secondary_type]
    else:
        corr_secondary_tasks = train_secondary_tasks['generic'][secondary_type]

    secondary_task_index = str(item['secondary_task_index'])
    # isdecimal (not isnumeric) matches exactly the strings int() can parse.
    if secondary_task_index.isdecimal():
        secondary_task_prompt = corr_secondary_tasks[int(secondary_task_index)]
        item['trigger'] = saved_injections[int(item['trigger'])]
    elif secondary_task_index == '':
        # The stored trigger text already is the full secondary task.
        secondary_task_prompt = item['trigger']
        item['trigger'] = ''
    else:
        raise ValueError(
            f"fetch_elements_train: unexpected secondary_task_index "
            f"{item['secondary_task_index']!r} (expected an integer or '')")

    item['primary_task_prompt'] = primary_task_prompt
    item['orig_text'] = orig_text
    item['secondary_task_prompt'] = secondary_task_prompt

    return item
    

## Build the training set and verify it (with hash values)

In [12]:
import hashlib

reconstructed_train = []

# A single SHA-256 object is fed every prompt in order and never reset, so each
# digest is a running hash over all prompts up to and including the current
# one. Each record's 'final_prompt_hash' is compared against that chain, so
# records must be processed in their original order.
hash_object = hashlib.sha256()

for train_index, train_item in enumerate(train_data_indices):
    if 'final_aggregated_prompt' in train_item:
        # The record already carries its full prompt; keep it as-is.
        new_train_item = copy.deepcopy(train_item)
    else:
        # Rebuild from indices: resolve texts -> insert the injected
        # instruction -> prepend the separation and primary-task prompts.
        # (Each step deep-copies its input, so train_item is not mutated.)
        new_train_item = fetch_elements_train(train_item)
        new_train_item = embed_secondary_task(new_train_item)
        new_train_item = get_final_paragraph(new_train_item)

    hash_object.update(new_train_item['final_aggregated_prompt'].encode('utf-8'))
    hex_dig = hash_object.hexdigest()
    assert hex_dig == new_train_item['final_prompt_hash'], (
        f"Running prompt hash mismatch at train item {train_index}: "
        f"got {hex_dig}, expected {new_train_item['final_prompt_hash']}")

    reconstructed_train.append(new_train_item)

## Save JSON file

In [14]:
# Write the reconstructed, hash-verified training records.
train_subset_path = '../dataset_sampled/train_subset.json'
# Create the output directory on a fresh checkout instead of failing in open().
os.makedirs(os.path.dirname(train_subset_path), exist_ok=True)
with open(train_subset_path, 'w', encoding='utf-8') as fout:
    json.dump(reconstructed_train, fout)