In [1]:
import json
import re 
import os 
import argparse
import wget 
import random 
import numpy as np 
from task_prompts import generic_task_prompts

In [5]:
# ---------------------------------------------------------------------------
# Configuration: command-line options, output file name, dataset locations.
# parse_known_args() is used so that extra arguments passed to the process
# (e.g. by the Jupyter kernel launcher) are ignored instead of raising an error.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
                    prog='Dataset sampling')
parser.add_argument('--datasets_dir', default='./datasets', help="dir to retrieval datasets files and other resources")
parser.add_argument('--dataset', default='SQuAD', choices=['SQuAD', 'hotpot'], help='which dataset to use for the paragraphs')
parser.add_argument('--trivia_files', default='trivia_pairs.txt', help='list of probe-answer pairs')
parser.add_argument('--subset', default='train', choices=['dev', 'train'], help='which subset to sample from in the retrieval datasets: train or dev')
parser.add_argument('--embed_loc', default='middle', choices=['beginning', 'middle', 'end'], help='where to insert the embedded instructions')
# 'concat' inserts the secondary prompt as is; 'saved' first prepends a randomly
# chosen sentence from the --saved_injections file.
parser.add_argument('--embed_method', default='saved', choices=['concat','saved'], help="how to insert the embedded instructions: 'concat' inserts them as is, 'saved' prepends a saved injection sentence")
parser.add_argument('--saved_injections', default='saved_injections.txt', help="path to a file of saved injections generated earlier")
parser.add_argument('--orig_task', default='qa', choices=['qa', 'translate', 'summarize', 'extract', 'style', 'evaluate'], help='primary task associated with the text, if qa then the questions of the dataset are used')
parser.add_argument('--emb_task', default='alpaca', choices=['none', 'qa', 'alpaca', 'translate', 'summarize', 'extract', 'style', 'evaluate', 'attack', 'witness'], help='secondary task, witness task are the trivia pairs, attack are attack prompts from BIPA')
parser.add_argument('--data_sep_tags', default='none', help='none or tag, if data should be surrounded by tags')
parser.add_argument('--instruct_sep_tags', default='none', help='none or tag, if instructions should be surrounded by tags')
parser.add_argument('--sep_prompt', default='config_files/sep_prompt.txt', help='none, or a path to a file that contains defense prompt to explain/encourage separation')
parser.add_argument('--number', default=4000, type=int, help='how many samples')
parser.add_argument('--out_dir', default='./dataset_training_sampled_out', help="dir where the sampled dataset json file is written")
# Optional: without a seed every run samples a different dataset.
parser.add_argument('--seed', default=None, type=int, help='optional random seed for reproducible sampling (default: unseeded)')


args, _ = parser.parse_known_args()
os.makedirs(args.out_dir, exist_ok=True)

# Both the stdlib and the numpy generators are used for sampling, so seed both.
if args.seed is not None:
    random.seed(args.seed)
    np.random.seed(args.seed)

# The output file name encodes the sampling configuration.
dataset_out_name = 'dataset_out_' + args.dataset + '_' + args.orig_task + '_' + args.emb_task + '_' + args.embed_loc + '_' + args.embed_method + '.json'
dataset_out_name = os.path.join(args.out_dir, dataset_out_name)

# Local file name and download URL of each subset of each supported dataset.
datasets_files = {'SQuAD': {'train': {'name': 'train-v2.0.json', 'url': 'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json'},
                            'dev': {'name': 'dev-v2.0.json', 'url': 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json'} },

                  'hotpot': {'train': {'name': 'hotpot_train_v1.1.json' , 'url': 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_train_v1.1.json'},
                             'dev': {'name': 'hotpot_dev_fullwiki_v1.json', 'url': 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json'}}
                 }

#change home of HF to cache any downloaded files
# '/disk3/' is only a fallback: a value already present in the environment wins,
# so the notebook also runs on machines without that disk.
os.environ.setdefault('HF_HOME', '/disk3/')
os.environ.setdefault('TRANSFORMERS_CACHE', '/disk3/')

In [6]:
def load_sep_prompt():
    """Return the separation (defense) prompt text.

    The prompt tells the model how to separate instructions from data; it can
    be as simple as "consider the following task/question that you should
    answer". When ``args.sep_prompt`` is the literal 'none' no prompt is used
    and an empty string is returned; otherwise the value is treated as a file
    path and the whole file content is returned unmodified.
    """
    if args.sep_prompt != 'none':
        prompt_path = os.path.join(args.sep_prompt)
        with open(prompt_path, "r") as prompt_file:
            return prompt_file.read()
    return ''

def load_saved_injections():
    """Load the saved injection (trigger) sentences, one per line.

    These are previously generated sentences (via GPT-4, per the original
    note) that are placed in front of the payload when the 'saved' embedding
    method is used.

    Returns:
        list[str]: the stripped, non-empty lines of
        ``<args.datasets_dir>/<args.saved_injections>``, or an empty list when
        ``args.saved_injections`` is the literal 'none'.
    """
    # Whether injections are loaded depends on the injections option itself,
    # not on the unrelated --sep_prompt option.
    if args.saved_injections == 'none':
        return []
    injections_path = os.path.join(args.datasets_dir, args.saved_injections)
    # The context manager closes the file handle once the lines are read.
    with open(injections_path, "r") as injections_file:
        stripped_lines = (line.strip() for line in injections_file)
        # Blank lines would otherwise become empty injections.
        return [line for line in stripped_lines if line]
        
        
# Load the shared text resources once; later cells read them as globals.
sep_prompt = load_sep_prompt()
saved_injections = load_saved_injections()

# The 'saved' embedding method draws from saved_injections, so the list must
# contain at least one entry.
if args.embed_method == 'saved':
    assert len(saved_injections) > 0

In [7]:
def download_datasets(datasets_dir, dataset):
    """Download the files of `dataset` that are not present locally yet.

    Every subset listed in ``datasets_files[dataset]`` is stored as
    ``<datasets_dir>/<dataset>/<name>``; the directory is created on demand
    and subsets whose file already exists are left untouched.
    """
    target_dir = os.path.join(datasets_dir, dataset)
    os.makedirs(target_dir, exist_ok=True)
    for subset_info in datasets_files[dataset].values():
        local_path = os.path.join(target_dir, subset_info['name'])
        if os.path.isfile(local_path):
            continue  # already downloaded
        wget.download(subset_info['url'], local_path)

# Make sure the files of the selected dataset are available locally.
download_datasets(datasets_dir=args.datasets_dir, dataset=args.dataset)

In [8]:
#Load the datasets into a unified format:
#a list of items, each {'context': <TEXT PARAGRAPH>, 'questions': [{'question': ..., 'answer': ...}, ...]}.
#'questions' is a list: SQuAD (usually) has n questions for each context,
#while hotpot usually has one question over many paragraphs; currently the paragraphs are just concatenated.
def process_squad(squad_file):
    """Convert a parsed SQuAD json object into the unified format.

    Returns a list of {'context': ..., 'questions': [...]} items, one per
    paragraph. Questions without any answer are dropped, only the first
    listed answer of a question is kept, and paragraphs left with no
    answered question are skipped.
    """
    unified = []
    all_paragraphs = (par for article in squad_file['data'] for par in article['paragraphs'])
    for par in all_paragraphs:
        context_text = par['context']
        answered = [
            {'question': qa['question'], 'answer': qa['answers'][0]['text']}
            for qa in par['qas']
            if len(qa['answers']) != 0
        ]
        if answered:
            unified.append({'context': context_text, 'questions': answered})
    return unified

def process_hotpot(hotpot_file):
    """Convert a parsed HotpotQA json list into the unified format.

    Each element becomes one item with a single question/answer pair. Its
    context is the concatenation of all its articles, each rendered as
    'Article title: <title>. ' followed by the article sentences joined with
    single spaces; nothing is inserted between consecutive articles.
    """
    unified = []
    for record in hotpot_file:
        rendered_articles = []
        for article in record['context']:
            title = article[0]
            sentences = article[1]
            rendered_articles.append('Article title: ' + title + '. ' + ' '.join(sentences))
        qa_pair = {'question': record['question'], 'answer': record['answer']}
        unified.append({'context': ''.join(rendered_articles), 'questions': [qa_pair]})
    return unified

In [9]:
# Read the raw train/dev files of the selected dataset. The context managers
# close the file handles as soon as the json is parsed.
with open(os.path.join(args.datasets_dir,args.dataset,datasets_files[args.dataset]['train']['name'])) as train_file:
    train_retrieval_orig_dataset = json.load(train_file)
with open(os.path.join(args.datasets_dir,args.dataset,datasets_files[args.dataset]['dev']['name'])) as dev_file:
    dev_retrieval_orig_dataset = json.load(dev_file)

# Convert both subsets to the unified {'context', 'questions'} format.
if args.dataset == 'SQuAD':
    retrieval_datasets = {'train': process_squad(train_retrieval_orig_dataset), 'dev': process_squad(dev_retrieval_orig_dataset)}
elif args.dataset == 'hotpot':
    retrieval_datasets = {'train': process_hotpot(train_retrieval_orig_dataset), 'dev': process_hotpot(dev_retrieval_orig_dataset)}

#load the examples of probes and answers
#each line is expected to contain the question followed by "Answer: <answer>"
trivia_pairs = []
with open(os.path.join(args.datasets_dir,args.trivia_files),"r") as trivia_file:
    for pair in trivia_file:
        if not pair.strip():
            continue  # tolerate blank lines (e.g. a trailing newline at the end of the file)
        answer_matches = re.findall("Answer:.*", pair)
        if not answer_matches:
            raise ValueError("trivia line has no 'Answer:' part: " + repr(pair))
        answer = answer_matches[0]
        # Whatever is left after removing the answer part and the label is the question.
        question = pair.replace(answer,"").replace("Question:","").strip()
        answer = answer.replace("Answer:","").strip()
        trivia_pairs.append({'question': question,'answer':answer})

#Alpaca instructions dataset
#only instructions that come without an additional input field are kept
from datasets import load_dataset
alpaca_dataset = [item['instruction'] for item in load_dataset("tatsu-lab/alpaca")['train'] if item['input'] == '']


Downloading readme:   0%|          | 0.00/7.47k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/24.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/52002 [00:00<?, ? examples/s]

In [10]:
#sample a main task. 
#QA would sample a question from the dataset that belongs to the paragraph.
#otherwise, sample a prompt from one of the generic tasks. 
def sample_main_task(example_idx):
    """Sample the primary task for the example at `example_idx`.

    For the 'qa' task a random question that belongs to the example's
    paragraph is drawn and returned with its answer. For any other task a
    random prompt of that task type is drawn from `generic_task_prompts` and
    the answer is the empty string.

    Returns:
        tuple: (task prompt, answer).
    """
    if args.orig_task != 'qa':
        return random.choice(generic_task_prompts[args.orig_task]), ''
    qa_pair = random.choice(retrieval_datasets[args.subset][example_idx]['questions'])
    return qa_pair['question'], qa_pair['answer']


In [11]:
def pick_another_qa_question(example_idx):
    """Return a random question that belongs to a different example.

    A second example index is drawn at random from the active subset until it
    differs from `example_idx`; one of that example's questions is returned.

    Raises:
        ValueError: if the subset holds a single example and `example_idx`
            refers to it, so no other example can ever be drawn.
    """
    examples = retrieval_datasets[args.subset]
    # With a single example every draw yields index 0, so the rejection loop
    # below would never terminate for example_idx == 0.
    if len(examples) == 1 and example_idx == 0:
        raise ValueError("cannot pick a question from another example: the subset has only one example")
    secondary_example_idx = example_idx
    while secondary_example_idx == example_idx:
        secondary_example_idx = np.random.randint(len(examples))
    emb_task_prompt = random.choice(examples[secondary_example_idx]['questions'])['question']
    return emb_task_prompt

#sample a secondary task
#could be QA with a random question taken from another paragraph (the code always picks a different context)
#could be trivia pairs (probes and witnesses)
#could be generic tasks or Alpaca instructions
#could be attacks (from the BIPA dataset)
def sample_secondary_task(orig_task_prompt, example_idx):
    """Sample the secondary (embedded) task.

    Args:
        orig_task_prompt: the primary task prompt; when both tasks are of the
            same generic type the secondary prompt is chosen to differ from it.
        example_idx: index of the example used for the primary task; a
            secondary QA question is taken from a different example.

    Returns:
        tuple: (secondary prompt, has_answer, witness). `has_answer` is True
        and `witness` holds the expected answer only for the 'witness' task;
        otherwise they are False and ''. The prompt is '' when
        ``args.emb_task`` is 'none'.

    Raises:
        ValueError: if both tasks are of the same generic type and there is
            no prompt of that type other than `orig_task_prompt`.
    """
    witness = ''
    has_answer = False

    ##no secondary task
    # This is part of the same if/elif chain as the cases below, so the empty
    # prompt is not overwritten by the final fallback branch.
    if args.emb_task == 'none':
        emb_task_prompt = ''

    ##secondary task is QA from a different context
    elif args.orig_task == 'qa' and args.emb_task == 'qa':
        emb_task_prompt = pick_another_qa_question(example_idx)

    ##tasks are the same but not QA
    elif args.orig_task == args.emb_task:
        # Excluding the primary prompt up front draws from the same candidates
        # as re-drawing until the prompt differs, but cannot loop forever.
        alternatives = [prompt for prompt in generic_task_prompts[args.emb_task] if prompt != orig_task_prompt]
        if not alternatives:
            raise ValueError("no alternative prompt available for task type " + repr(args.emb_task))
        emb_task_prompt = random.choice(alternatives)

    ##task is trivia witness
    elif args.emb_task == 'witness':
        has_answer = True
        trivia_pair = random.choice(trivia_pairs)
        # Read the fields by key rather than relying on dict value order.
        emb_task_prompt, witness = trivia_pair['question'], trivia_pair['answer']

    ##task is alpaca
    elif args.emb_task == 'alpaca':
        emb_task_prompt = random.choice(alpaca_dataset)

    ##task is one of the generic task types
    else:
        emb_task_prompt = random.choice(generic_task_prompts[args.emb_task])
    return emb_task_prompt, has_answer, witness

In [12]:

from random import randint
    
def embed_secondary_task(emb_task_prompt, text_paragraph):
    """Embed the secondary instructions into the text paragraph.

    Args:
        emb_task_prompt: the secondary prompt; '' means nothing is embedded.
        text_paragraph: the clean paragraph text.

    Returns:
        tuple: (paragraph with the prompt inserted, annotated text). The
        annotated text marks the paragraph part before the insertion, the
        inserted prompt and the part after it with <PART_1>, <INSERTED> and
        <PART_2> tags. Without a secondary prompt the paragraph is returned
        unchanged together with an empty annotation.

    With the 'saved' embedding method a random saved injection sentence is
    prepended to the prompt. The insertion location follows
    ``args.embed_loc``: 'beginning', 'end', or otherwise a random character
    position inside the paragraph.
    """
    if emb_task_prompt == '':
        return text_paragraph, ''
    if args.embed_method == 'saved':
        injection = random.choice(saved_injections)
        emb_task_prompt = injection + ' ' + emb_task_prompt

    annotated_inserted = " <INSERTED> " + emb_task_prompt + " </INSERTED>"
    if args.embed_loc == 'beginning':
        annotated_part1 = " <PART_1> </PART_1>"
        annotated_part2 = " <PART_2> " + text_paragraph + " </PART_2>"
        text_paragraph = emb_task_prompt + ' ' + text_paragraph
    elif args.embed_loc == 'end':
        annotated_part1 = " <PART_1> " + text_paragraph + " </PART_1>"
        annotated_part2 = " <PART_2> </PART_2>"
        text_paragraph = text_paragraph + ' ' + emb_task_prompt
    else:
        # Pick a random character position to insert at. randint(0, -1) would
        # raise for an empty paragraph, so position 0 is used in that case.
        pos = randint(0, len(text_paragraph) - 1) if text_paragraph else 0
        before, after = text_paragraph[:pos], text_paragraph[pos:]
        annotated_part1 = " <PART_1> " + before + " </PART_1>"
        annotated_part2 = " <PART_2> " + after + " </PART_2>"
        text_paragraph = " ".join((before, emb_task_prompt, after))
    annotated_text = annotated_part1 + annotated_inserted + annotated_part2
    return text_paragraph, annotated_text
        

    

In [13]:
def format_final_prompt(final_text_paragraph, orig_task_prompt):
    """Assemble the final prompt from the primary task and the data.

    `final_text_paragraph` is the output of the embedding step. Without a
    primary task prompt the paragraph is returned as is. Otherwise the layout
    is: defense prompt (if any) + primary instruction + data, where the
    instruction and the data are each wrapped in ' <tag> ... </tag> ' when the
    corresponding separator tag option is not 'none'.
    """
    if orig_task_prompt == '':
        return final_text_paragraph

    def render(tag, content):
        # 'none' disables tagging; either way the piece ends with a space.
        if tag == 'none':
            return content + ' '
        return ' <' + tag + '> ' + content + ' </' + tag + '> '

    pieces = []
    if sep_prompt:
        pieces.append(sep_prompt + ' ')
    pieces.append(render(args.instruct_sep_tags, orig_task_prompt))
    pieces.append(render(args.data_sep_tags, final_text_paragraph))
    return ''.join(pieces)
    

In [16]:

from tqdm import tqdm

# Build args.number samples. Each sample combines a random paragraph, a
# primary task and a secondary task embedded into the paragraph, and records
# the configuration it was generated with.
dataset = []
for sample_i in tqdm(range(0,args.number)):
    #sample an example and data (paragraph)
    example_idx = np.random.randint(len(retrieval_datasets[args.subset]))
    example_text_paragraph = retrieval_datasets[args.subset][example_idx]['context']

    #sample a task
    orig_task_prompt, orig_task_answer = sample_main_task(example_idx)

    #sample a secondary task
    emb_task_prompt, has_answer, witness = sample_secondary_task(orig_task_prompt, example_idx)

    #embed secondary task in text paragraph
    final_text_paragraph, annotated_paragraph  = embed_secondary_task(emb_task_prompt, example_text_paragraph)

    #format with separators
    final_aggregated_prompt = format_final_prompt(final_text_paragraph, orig_task_prompt)

    dataset_item = {'text_data_src': args.dataset,
                    'split': args.subset,
                    'text_idx': example_idx,
                    'orig_text': example_text_paragraph,
                    'primary_task_type': args.orig_task,
                    'secondary_task_type': args.emb_task,
                    'primary_task_prompt': orig_task_prompt,
                    'primary_task_answer': orig_task_answer,
                    'secondary_task_prompt': emb_task_prompt,
                    'secondary_has_answer': has_answer,
                    'secondary_witness': witness,
                    'embed_loc': args.embed_loc,
                    'embed_method': args.embed_method,
                    'instruct_sep_tags': args.instruct_sep_tags,
                    'data_sep_tags': args.data_sep_tags,
                    'sep_prompt': sep_prompt,
                    'final_text_paragraph': final_text_paragraph,
                    'annotated_paragraph': annotated_paragraph,
                    'final_aggregated_prompt': final_aggregated_prompt}

    dataset.append(dataset_item)

# Write all samples as one json list to the configuration-specific output file.
with open(dataset_out_name, 'w') as fout:
    json.dump(dataset , fout)



100%|██████████| 4000/4000 [00:00<00:00, 76729.87it/s]


In [17]:
# Show the last generated sample to inspect the layout of a record.
dataset[-1]

{'text_data_src': 'SQuAD',
 'split': 'train',
 'text_idx': 12635,
 'orig_text': "The hourglass uses the flow of sand to measure the flow of time. They were used in navigation. Ferdinand Magellan used 18 glasses on each ship for his circumnavigation of the globe (1522). Incense sticks and candles were, and are, commonly used to measure time in temples and churches across the globe. Waterclocks, and later, mechanical clocks, were used to mark the events of the abbeys and monasteries of the Middle Ages. Richard of Wallingford (1292–1336), abbot of St. Alban's abbey, famously built a mechanical clock as an astronomical orrery about 1330. Great advances in accurate time-keeping were made by Galileo Galilei and especially Christiaan Huygens with the invention of pendulum driven clocks along with the invention of the minute hand by Jost Burgi.",
 'primary_task_type': 'qa',
 'secondary_task_type': 'alpaca',
 'primary_task_prompt': 'Who is credited with the invention of the minute hand?',
 'pri