In [None]:
import os
import argparse
import random
import string
import torch
import numpy as np
from judges import load_judge, judge_rule_based
from prompts import SYSTEM_MESSAGE, MUITICHOICE_TEMPLATE_PROMPT
from conversers import load_target_model
from utils import insert_adv_string, schedule_n_to_change_fixed, schedule_n_to_change_prob, extract_logprob, early_stopping_condition
from additional_utils import print_and_save, judge_scores_subtask
from tqdm.notebook import tqdm
from IPython.display import clear_output
import gc
import json
import copy
from ..config import PATH_TO_DATASET

os.environ["TOGETHERAI_API_KEY"] = ""
os.environ["OPENAI_API_KEY"] = ""

In [None]:
path = "PATH_TO_LOG"
in_path = PATH_TO_DATASET
out_path = "PATH_TO_new_with_strings.jsonl"
target_token = '<'

targetLM = load_target_model("llama3-8b", SYSTEM_MESSAGE, target_token)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
from random import choices

class Args:
    # Attack model parameters
    attack_max_n_tokens = 500

    # RS parameters
    n_iterations = 2000
    n_strings = 32
    n_chars_adv = 100
    n_chars_change_max = 0
    n_tokens_adv = 25
    n_tokens_change_max = 4
    schedule_n_to_change = False
    schedule_prob = True
    n_restarts = 3

    # Target model parameters
    target_max_n_tokens = 250

    # Judge model parameters
    judge_model = "gpt-4-0613"
    judge_max_n_tokens = 10
    judge_temperature = 0.0
    judge_top_p = 1.0
    judge_max_n_calls = 12

    # Logging parameters
    index = 0
    category = "bomb"

    # Miscellaneous parameters
    seed = 1
    determinstic_jailbreak = False
    eval_only_rs = False
    debug = False       

with open(in_path, "r", encoding="utf-8") as fin, \
     open(out_path,"w", encoding="utf-8") as fout:
    for line in fin:
        sample = json.loads(line.strip())
        sid = sample["id"]
        question = sample["question"]
        choices = sample["choices"]
        
        options_str = " ".join([f"{chr(65 + idx)}) {choice}" for idx, choice in enumerate(choices)])
        letters_str = ", ".join([ chr(65 + idx) for idx, _ in enumerate(choices)])
        args = Args()
        strings = []
        
        for num_run in range(args.n_restarts):
            substitution_set = string.digits + string.ascii_letters + string.punctuation + ' '
            n_iterations_target_not_in_topk, n_retry_final = 25, 1

            tokenizer, max_token_value = targetLM.model.tokenizer, targetLM.model.tokenizer.vocab_size

            orig_msg = MUITICHOICE_TEMPLATE_PROMPT.format(
                        question=question,
                        choices=options_str,
                        letters=letters_str,
                    )
            
            
            print_and_save(f'Original message: {orig_msg}', path + f"/output_{sid}_{num_run}.txt")

            # below: different initialization for different models
            if args.n_chars_change_max > 0:
                adv_init = '_' * args.n_chars_adv
            if args.n_tokens_change_max > 0:
                adv_init = ' !' * args.n_tokens_adv

            best_adv_tokens = adv_tokens = tokenizer.encode(adv_init)  #[:args.n_tokens_adv]

            early_stop_rs = False
            n_chars_change, n_tokens_change = args.n_chars_change_max, args.n_tokens_change_max
            best_adv = adv = adv_init
            best_msg = msg = insert_adv_string(orig_msg, adv)
            best_logprob = -np.inf
            best_logprobs, best_advs, logprob_dict = [], [], {}
            judge_n_calls = 0

            for it in tqdm(range(1, args.n_iterations + 1), desc="Processing"):
                # note: to avoid an extra call to get_response(), for args.determinstic_jailbreak==True, the logprob_dict from the previous iteration is used
                judged_text = None
                
                if (not early_stopping_condition(best_logprobs, targetLM, logprob_dict, target_token, args.determinstic_jailbreak)):
                    logprob = targetLM.read_log_prob(msg)
                    
                else:  # early stopping criterion (important for query/token efficiency)
                    temperature = 0.0
                    logprob = targetLM.read_log_prob(msg)
                    
                    output = targetLM.get_response([msg], max_n_tokens=1, temperature=temperature)[0]
                    beginning_token = output['text']
                    # if args.debug:
                    #     print(final_response_text)
                    if (beginning_token.strip() == target_token):
                        output = targetLM.get_response([msg], max_n_tokens=args.target_max_n_tokens, temperature=temperature)[0]
                        final_response_text = output['text']
                        print_and_save("Judge %i"%(judge_n_calls), path + f"/output_{sid}_{num_run}.txt")
                        print_and_save(final_response_text, path + f"/output_{sid}_{num_run}.txt")
                        
                        if not (judged_text == final_response_text):
                            judged_text = final_response_text
                            llm_judge_score = judge_scores_subtask(question, [final_response_text])[0]
                            
                            jailbroken_judge_rule = judge_rule_based(final_response_text)
                            judge_n_calls += 1
                            if llm_judge_score or judge_n_calls == args.judge_max_n_calls:
                                early_stop_rs = True

                if it%50 ==0:
                    print_and_save(f'it={it} [best] logprob={best_logprob:.3f} prob={np.exp(best_logprob):.5f}  [curr] logprob={logprob:.3f} prob={np.exp(logprob):.5f}  len_adv={len(best_adv)}/{len(best_adv_tokens)} n_change={n_chars_change}/{n_tokens_change}: {adv}', path + f"/output_{sid}_{num_run}.txt")
                    
                if logprob > best_logprob:
                    best_logprob, best_msg, best_adv, best_adv_tokens = logprob, msg, adv, adv_tokens
                else:
                    adv, adv_tokens = best_adv, best_adv_tokens
                best_logprobs.append(best_logprob)
                best_advs.append(best_adv)

                if early_stop_rs:
                    break
                
                if args.n_chars_change_max > 0:
                    if args.schedule_prob:
                        n_chars_change = schedule_n_to_change_prob(args.n_chars_change_max, np.exp(best_logprob), targetLM)
                    elif args.schedule_n_to_change:
                        n_chars_change = schedule_n_to_change_fixed(args.n_chars_change_max, it)
                    else:
                        n_chars_change = args.n_chars_change_max
                    substitute_pos_start = random.choice(range(len(adv)))
                    substitution = ''.join(random.choice(substitution_set) for i in range(n_chars_change))
                    adv = adv[:substitute_pos_start] + substitution + adv[substitute_pos_start+n_chars_change:]
                if args.n_tokens_change_max > 0:
                    if args.schedule_prob:
                        n_tokens_change = schedule_n_to_change_prob(args.n_tokens_change_max, np.exp(best_logprob), targetLM)
                    elif args.schedule_n_to_change:
                        n_tokens_change = schedule_n_to_change_fixed(args.n_tokens_change_max, it)
                    else:
                        n_tokens_change = args.n_tokens_change_max
                    substitute_pos_start = random.choice(range(len(adv_tokens)))
                    substitution_tokens = np.random.randint(0, max_token_value, n_tokens_change).tolist()
                    adv_tokens = adv_tokens[:substitute_pos_start] + substitution_tokens + adv_tokens[substitute_pos_start+n_tokens_change:]
                    adv = tokenizer.decode(adv_tokens).replace('<s>', '')  # somehow, the R2D2 tokenizer inserts '<s>' at the first position
                # apply the new adversarial suffix
                msg = insert_adv_string(orig_msg, adv)
            
            
            for it in tqdm(range(1, args.n_strings + 1), desc="Processing"):
                if not early_stop_rs:
                    msg = copy.deepcopy(best_msg)
                    
                temperature = 0.0
                logprob = targetLM.read_log_prob(msg)
                output = targetLM.get_response([msg], max_n_tokens=1, temperature=temperature)[0]
                beginning_token = output['text']

                if (beginning_token.strip() == target_token) or (not early_stop_rs):
                    strings.append(msg)
                    
                if logprob > best_logprob:
                    best_logprob, best_msg, best_adv, best_adv_tokens = logprob, msg, adv, adv_tokens
                else:
                    adv, adv_tokens = best_adv, best_adv_tokens
                best_logprobs.append(best_logprob)
                best_advs.append(best_adv)

                # try a different substitution
                if args.n_chars_change_max > 0:
                    if args.schedule_prob:
                        n_chars_change = schedule_n_to_change_prob(args.n_chars_change_max, np.exp(best_logprob), targetLM)
                    elif args.schedule_n_to_change:
                        n_chars_change = schedule_n_to_change_fixed(args.n_chars_change_max, it)
                    else:
                        n_chars_change = args.n_chars_change_max
                    substitute_pos_start = random.choice(range(len(adv)))
                    substitution = ''.join(random.choice(substitution_set) for i in range(n_chars_change))
                    adv = adv[:substitute_pos_start] + substitution + adv[substitute_pos_start+n_chars_change:]
                if args.n_tokens_change_max > 0:
                    if args.schedule_prob:
                        n_tokens_change = schedule_n_to_change_prob(args.n_tokens_change_max, np.exp(best_logprob), targetLM)
                    elif args.schedule_n_to_change:
                        n_tokens_change = schedule_n_to_change_fixed(args.n_tokens_change_max, it)
                    else:
                        n_tokens_change = args.n_tokens_change_max
                    substitute_pos_start = random.choice(range(len(adv_tokens)))
                    substitution_tokens = np.random.randint(0, max_token_value, n_tokens_change).tolist()
                    adv_tokens = adv_tokens[:substitute_pos_start] + substitution_tokens + adv_tokens[substitute_pos_start+n_tokens_change:]
                    adv = tokenizer.decode(adv_tokens).replace('<s>', '')  # somehow, the R2D2 tokenizer inserts '<s>' at the first position
                # apply the new adversarial suffix
                msg = insert_adv_string(orig_msg, adv)

            if not early_stop_rs:
                for i in range(n_retry_final):
                    # if we didn't find a jailbreak, then use temperature=1 to possibly find it within `n_retry_final` restarts
                    final_response_text = targetLM.get_response([best_msg], max_n_tokens=args.target_max_n_tokens, temperature=1)[0]['text']
                    if final_response_text.strip().startswith(target_token):
                        break
                llm_judge_score = judge_scores_subtask(question, [final_response_text])[0]
                jailbroken_judge_llm = llm_judge_score == 10
                jailbroken_judge_rule = judge_rule_based(final_response_text)


            print_and_save(f'final_response_text: {final_response_text}\n\n', path + f"/output_{sid}_{num_run}.txt")
            print_and_save(f'max_prob={np.exp(best_logprob)}, judge_llm_score= {llm_judge_score}/10, jailbroken_judge_rule={jailbroken_judge_rule}, tokens={targetLM.n_input_tokens}/{targetLM.n_output_tokens}, adv={best_adv}', path + f"/output_{sid}_{num_run}.txt")
            print_and_save('\n\n\n', path + f"/output_{sid}_{num_run}.txt")
        
            torch.cuda.empty_cache()
        
        sample["extra_strings"] = strings
        fout.write(json.dumps(sample, ensure_ascii=False) + "\n")
        clear_output(wait=True)