In [1]:
import argparse
import csv
import json
import os
import random
from random import randint

import numpy as np
import wget

from task_prompts import BIPIA_attack_prompts, generic_task_prompts

In [2]:
# Configuration is expressed as command-line style options. parse_known_args is used
# so that unrecognized arguments are returned separately (and discarded here) instead
# of aborting with an error.
parser = argparse.ArgumentParser(prog='Dataset sampling')
parser.add_argument(
    '--datasets_dir',
    default='./datasets',
    help="dir to retrieval datasets files and other resources",
)
parser.add_argument(
    '--out_dir',
    default='./dataset_sampled_test',
    help="dir to sampled dataset files",
)
parser.add_argument(
    '--subset',
    default='dev',
    choices=['dev', 'train'],
    help='which subset to sample from in the retrieval datasets: train or dev',
)
parser.add_argument(
    '--data_sep_tags',
    default='none',
    help='none or tag, if data should be surrounded by tags',
)
parser.add_argument(
    '--instruct_sep_tags',
    default='none',
    help='none or tag, if instructions should be surrounded by tags',
)
parser.add_argument(
    '--sep_prompt',
    default='config_files/sep_prompt.txt',
    help='none, or a path to a file that contains defense prompt to explain/encourage separation',
)
parser.add_argument(
    '--saved_injections',
    default='saved_injections_test.txt',
    help="path to a file of saved injections generated earlier",
)

args, _ = parser.parse_known_args()

# Output location of the generated (poisoned) dataset; the directory is created if missing.
os.makedirs(args.out_dir, exist_ok=True)
dataset_out_name = os.path.join(args.out_dir, 'dataset_out_poisoned_v1.json')

# Local file name and download URL of every retrieval dataset, keyed by dataset then split.
datasets_files = {
    'SQuAD': {
        'train': {
            'name': 'train-v2.0.json',
            'url': 'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json',
        },
        'dev': {
            'name': 'dev-v2.0.json',
            'url': 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json',
        },
    },
    'hotpot': {
        'train': {
            'name': 'hotpot_train_v1.1.json',
            'url': 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_train_v1.1.json',
        },
        'dev': {
            'name': 'hotpot_dev_fullwiki_v1.json',
            'url': 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json',
        },
    },
}

# Change the home of HF to cache any downloaded files.
for cache_variable in ('HF_HOME', 'TRANSFORMERS_CACHE'):
    os.environ[cache_variable] = '/disk3/'

In [9]:
def load_sep_prompt():
    """Return the prompt text that instructs the model how to do separation.

    Returns:
        An empty string when ``args.sep_prompt`` is the literal ``'none'``;
        otherwise the full contents of the file at ``args.sep_prompt``.
    """
    if args.sep_prompt != 'none':
        with open(args.sep_prompt, "r") as prompt_file:
            return prompt_file.read()
    return ''

def load_saved_injections():
    """Load the saved injections generated earlier.

    Reads ``args.saved_injections`` (relative to ``args.datasets_dir``) and
    returns one entry per line, with surrounding whitespace stripped.

    Returns:
        list of str, in file order.
    """
    injections_path = os.path.join(args.datasets_dir, args.saved_injections)
    # Use a context manager so the file handle is closed deterministically.
    with open(injections_path, "r") as injections_file:
        return [line.strip() for line in injections_file]

# Load the separation prompt and the saved injections once at module level;
# format_final_prompt reads `sep_prompt` and sample_trigger_prompt reads
# `saved_injections` as globals.
sep_prompt = load_sep_prompt()
saved_injections = load_saved_injections()


In [10]:
# Flatten the two-level BIPIA_attack_prompts mapping into a single list of prompts,
# keeping the iteration order of the outer and inner keys.
BIPIA_attack_prompts_all = [
    prompt
    for category in BIPIA_attack_prompts
    for attack_name in BIPIA_attack_prompts[category]
    for prompt in BIPIA_attack_prompts[category][attack_name]
]
        

In [4]:
#source of dataset: https://github.com/llm-attacks/llm-attacks/blob/main/data/advbench/harmful_behaviors.csv
# Collect the first column of every data row. `csv` comes from the notebook's
# import cell, so this cell also works on a fresh kernel.
adv_bench = []
# newline='' lets the csv module handle line endings inside quoted fields itself.
with open(os.path.join(args.datasets_dir, 'harmful_behaviors.csv'), newline='') as csvfile:
    behavior_reader = csv.reader(csvfile, delimiter=',')
    next(behavior_reader, None)  # skip the header row
    for row in behavior_reader:
        if row:  # csv.reader yields [] for blank lines
            adv_bench.append(row[0])


In [5]:
#source of dataset: https://github.com/verazuo/jailbreak_llms/tree/main/data
# Collect the distinct values of column index 7 (a set, so duplicates collapse).
# `csv` comes from the notebook's import cell.
forbidden_question = set()
# newline='' lets the csv module handle line endings inside quoted fields itself.
with open(os.path.join(args.datasets_dir, 'forbidden_question_set.csv'), newline='') as csvfile:
    question_reader = csv.reader(csvfile, delimiter=',')
    # Skip the header row, as the adv_bench cell does; otherwise the column title
    # itself would be added as a "question".
    next(question_reader, None)
    for row in question_reader:
        if row:  # csv.reader yields [] for blank lines
            forbidden_question.add(row[7])

In [6]:
### jailbreaks from LLMTrust dataset: https://github.com/HowieHwong/TrustLLM/tree/main/dataset (safety/jailbreak.json)
# Read the JSON file inside a context manager so the handle is closed after parsing.
with open(os.path.join(args.datasets_dir, 'jailbreak.json')) as jailbreak_file:
    jailbreaks_llm_trust = json.load(jailbreak_file)
# Keep only the 'prompt' field of every entry.
jailbreaks_llm_trust_prompts = [item['prompt'] for item in jailbreaks_llm_trust]

In [7]:
def download_datasets(datasets_dir, dataset):
    """Download the files of one retrieval dataset unless they already exist.

    Args:
        datasets_dir: root directory that holds one sub-directory per dataset.
        dataset: key into ``datasets_files`` (e.g. 'SQuAD' or 'hotpot').

    Every split listed in ``datasets_files[dataset]`` is fetched with
    ``wget.download`` into ``<datasets_dir>/<dataset>/`` when no file with the
    expected name is present there.
    """
    target_dir = os.path.join(datasets_dir, dataset)
    os.makedirs(target_dir, exist_ok=True)
    for file_info in datasets_files[dataset].values():
        local_path = os.path.join(target_dir, file_info['name'])
        if os.path.isfile(local_path):
            continue  # already downloaded
        wget.download(file_info['url'], local_path)


In [8]:

#Load the datasets into a unified format.
#The format is a list of items, each of the form {'context': <TEXT PARAGRAPH>, 'questions': [{'question': ..., 'answer': ...}, ...]}.
#'questions' is a list: SQuAD (usually) has n questions for each context.
#hotpot usually has one question with many paragraphs; currently, the first two paragraphs are simply concatenated.
def process_dataset(dataset_name, dataset_file):
    """Convert a raw retrieval dataset into the unified list-of-items format.

    Args:
        dataset_name: 'SQuAD' or 'hotpot'; any other name yields an empty list.
        dataset_file: the parsed JSON content of the dataset file.

    Returns:
        A list of ``{'context': str, 'questions': [{'question', 'answer'}, ...]}``.
        For SQuAD, questions without answers are dropped (the first answer text is
        used) and paragraphs left with no question are skipped. For hotpot, the
        sentences of the first two context entries are joined into one context and
        entries whose joined context is empty are skipped.
    """
    unified = []

    if dataset_name == 'SQuAD':
        for article in dataset_file['data']:
            for par in article['paragraphs']:
                context_text = par['context']
                qa_pairs = [
                    {'question': qa['question'], 'answer': qa['answers'][0]['text']}
                    for qa in par['qas']
                    if len(qa['answers']) != 0
                ]
                if len(qa_pairs) != 0:
                    unified.append({'context': context_text, 'questions': qa_pairs})
        return unified

    if dataset_name == 'hotpot':
        for entry in dataset_file:
            # Sentences inside one paragraph are space-joined; the paragraphs
            # themselves are concatenated without a separator.
            merged_context = ''.join(' '.join(par[1]) for par in entry['context'][0:2])
            if merged_context == '':
                continue
            single_question = {'question': entry['question'], 'answer': entry['answer']}
            unified.append({'context': merged_context, 'questions': [single_question]})

    return unified


In [9]:
# Download (if needed) and load every retrieval dataset in the unified format.
datasets = ['SQuAD', 'hotpot']
retrieval_datasets = {}
for dataset in datasets:
    download_datasets(args.datasets_dir, dataset)
    # Load the split selected by --subset and store it under that same key: all
    # consumers index retrieval_datasets[dataset][args.subset], so always loading
    # 'dev' made '--subset train' fail with a KeyError.
    subset_path = os.path.join(args.datasets_dir, dataset, datasets_files[dataset][args.subset]['name'])
    # Context manager so the file handle is closed once the JSON is parsed.
    with open(subset_path) as raw_file:
        retrieval_orig_dataset = json.load(raw_file)
    retrieval_datasets[dataset] = {args.subset: process_dataset(dataset, retrieval_orig_dataset)}


In [10]:
#sample a main task. 
#QA would sample a question from the dataset that belongs to the paragraph.
#otherwise, sample a prompt from one of the generic tasks. 
def sample_main_task(example_idx, task_type, dataset):
    """Sample the primary task for one retrieval example.

    A question belonging to the example's paragraph is always drawn at random.
    When ``task_type`` is ``'mix'``, a randomly chosen generic task prompt is
    appended to that question.

    Args:
        example_idx: index of the example inside the selected split.
        task_type: 'mix' to append a generic task; any other value keeps the
            plain question.
        dataset: key into ``retrieval_datasets``.

    Returns:
        Tuple ``(task_prompt, answer)`` where ``answer`` is the answer of the
        sampled question.
    """
    candidates = retrieval_datasets[dataset][args.subset][example_idx]['questions']
    picked = random.choice(candidates)
    task_prompt, answer = picked['question'], picked['answer']

    if task_type != 'mix':
        return task_prompt, answer

    # Pick a generic task family first, then one prompt from that family.
    extra_kind = random.choice(list(generic_task_prompts.keys()))
    extra_prompt = random.choice(generic_task_prompts[extra_kind])
    task_prompt = (
        task_prompt
        + '. After you answer the previous question, please do the following task as well: '
        + extra_prompt
    )
    return task_prompt, answer


In [12]:
def format_final_prompt(final_text_paragraph, orig_task_prompt):
    """Assemble the final prompt given to the model.

    Layout: defense prompt (if any) + instruction (wrapped in tags if configured)
    + data paragraph (wrapped in tags if configured).

    Args:
        final_text_paragraph: paragraph text, i.e. the output of the embedding step.
        orig_task_prompt: the primary task instruction.

    Returns:
        The concatenated prompt string.
    """
    def _segment(tag, text):
        # 'none' means no tags: the text is only followed by a single space.
        if tag == 'none':
            return text + ' '
        return ' <' + tag + '> ' + text + ' </' + tag + '> '

    pieces = []
    if sep_prompt:
        pieces.append(sep_prompt + ' ')
    pieces.append(_segment(args.instruct_sep_tags, orig_task_prompt))
    pieces.append(_segment(args.data_sep_tags, final_text_paragraph))
    return ''.join(pieces)
    

In [11]:
def sample_trigger_prompt(trigger_type):
    """Return the trigger text to put in front of an injected instruction.

    Args:
        trigger_type: 'none' for no trigger, 'triggers' to draw a random entry
            from ``saved_injections``.

    Returns:
        '' for 'none', a random saved injection for 'triggers', and None for any
        other value.
    """
    if trigger_type == 'triggers':
        return random.choice(saved_injections)
    if trigger_type == 'none':
        return ''
    return None
    

In [12]:
def embed_secondary_task(emb_task_prompt, text_paragraph, trigger_prompt, embed_loc):
    """Embed an instruction (secondary task) inside a text paragraph.

    Args:
        emb_task_prompt: the secondary question/instruction to inject.
        text_paragraph: text to add the instruction into.
        trigger_prompt: text placed before the instruction (a jailbreak or any
            saved injection); may be the empty string.
        embed_loc: 'beginning' or 'end'; any other value inserts at a random
            character position of the paragraph.

    Returns:
        Tuple ``(text_paragraph, annotated_text)``: the paragraph with the
        injection embedded, and an annotated version that marks the text before
        the injection (<PART_1>), the injection (<INSERTED>) and the text after
        it (<PART_2>).
    """
    # NOTE: with an empty trigger the injected text keeps a leading space; this is
    # left as is so that generated data does not change.
    emb_task_prompt = trigger_prompt + ' ' + emb_task_prompt

    annotated_inserted = " <INSERTED> " + emb_task_prompt + " </INSERTED>"
    if embed_loc == 'beginning':
        annotated_part1 = " <PART_1> </PART_1>"
        annotated_part2 = " <PART_2> " + text_paragraph + " </PART_2>"
        text_paragraph = emb_task_prompt + ' ' + text_paragraph
    elif embed_loc == 'end':
        annotated_part1 = " <PART_1> " + text_paragraph + " </PART_1>"
        annotated_part2 = " <PART_2> </PART_2>"
        text_paragraph = text_paragraph + ' ' + emb_task_prompt
    else:
        # Pick a random character position to insert at. randint is inclusive on
        # both ends, so an empty paragraph would ask for randint(0, -1) and raise
        # ValueError; fall back to position 0 in that case.
        pos = randint(0, len(text_paragraph) - 1) if text_paragraph else 0
        before, after = text_paragraph[:pos], text_paragraph[pos:]
        annotated_part1 = " <PART_1> " + before + " </PART_1>"
        annotated_part2 = " <PART_2> " + after + " </PART_2>"
        text_paragraph = " ".join((before, emb_task_prompt, after))
    annotated_text = annotated_part1 + annotated_inserted + annotated_part2
    return text_paragraph, annotated_text

In [15]:
# Axes of the generated grid: every secondary (attack) prompt is combined with every
# primary task type, embedding location and trigger type.
tasks = ['qa', 'mix']
locs = ['beginning', 'end', 'mid']
triggers = ['none', 'triggers']
secondary_datasets = [BIPIA_attack_prompts_all, adv_bench, forbidden_question, jailbreaks_llm_trust_prompts]
datasets = ['SQuAD']


def _build_poisoned_sample(dataset, task, sec_q, loc, trigger_type):
    """Build one dataset item: a random paragraph with `sec_q` embedded at `loc`."""
    pool = retrieval_datasets[dataset][args.subset]
    # Draw a random retrieval example to host the injected instruction.
    example_idx = np.random.randint(len(pool))
    example_text_paragraph = pool[example_idx]['context']

    orig_task_prompt, orig_task_answer = sample_main_task(example_idx, task, dataset)
    trigger_prompt = sample_trigger_prompt(trigger_type)

    # Embed the secondary task in the text paragraph.
    final_text_paragraph, annotated_paragraph = embed_secondary_task(
        sec_q, example_text_paragraph, trigger_prompt, loc
    )
    # Format with separators.
    final_aggregated_prompt = format_final_prompt(final_text_paragraph, orig_task_prompt)

    return {
        'text_data_src': dataset,
        'split': args.subset,
        'text_idx': example_idx,
        'orig_text': example_text_paragraph,
        'primary_task_type': task,
        'secondary_task_type': 'attacks',
        'primary_task_prompt': orig_task_prompt,
        'primary_task_answer': orig_task_answer,
        'secondary_task_prompt': sec_q,
        'secondary_has_answer': False,
        'embed_loc': loc,
        'embed_method': trigger_type,
        'instruct_sep_tags': args.instruct_sep_tags,
        'data_sep_tags': args.data_sep_tags,
        'sep_prompt': sep_prompt,
        'final_text_paragraph': final_text_paragraph,
        'annotated_paragraph': annotated_paragraph,
        'final_aggregated_prompt': final_aggregated_prompt,
    }


new_samples = []
for dataset in datasets:
    # Not used below; kept so the NumPy random stream is consumed exactly as before.
    samples = np.random.permutation(len(retrieval_datasets[dataset][args.subset]))
    for task in tasks:
        for sec_data in secondary_datasets:
            for sec_q in sec_data:
                for loc in locs:
                    for trigger_type in triggers:
                        new_samples.append(
                            _build_poisoned_sample(dataset, task, sec_q, loc, trigger_type)
                        )

# Write all generated items as a single JSON list.
with open(dataset_out_name, 'w') as fout:
    json.dump(new_samples, fout)