In [None]:
import numpy as np
import torch
import torch.nn as nn
from tqdm.notebook import tqdm
import re

from fastchat.model import get_conversation_template
import ml_dtypes
import torch.nn.functional as F
import os
import wandb

# API credentials. Prefer keys already exported in the shell environment:
# setdefault only fills in the placeholder when the variable is missing,
# instead of clobbering a real key with the placeholder text.
os.environ.setdefault("TOGETHERAI_API_KEY", "Your API Key")
os.environ.setdefault("OPENAI_API_KEY", "Your API Key")

# Device used for the locally loaded scoring/target model.
device = "cuda:0"

from convs import get_conv_attacker, get_conv_target, get_conv_feedbacker, get_conv_optimizer, get_init_msg
from utils import load_model_and_tokenizer, get_losses
from utils import get_target_responses_API_prop, get_target_responses_local, get_judge_scores_harmbench
from strings import gen_string_optimizer, get_feedbacks, get_new_prompts, get_attacks_string_with_timeout
from alg import GWW_dfs_min
from convs import LLAMA_SYSTEM_MESSAGE

from IPython.display import clear_output

In [None]:
llama_path = ""  # path/name passed to load_model_and_tokenizer — must be set before running

# Where model weights are cached. Override via the HF_CACHE_DIR environment
# variable rather than editing this cluster-specific default path.
cache_dir = os.environ.get("HF_CACHE_DIR", "/n/holylabs/LABS/dwork_lab/Everyone/cache/transformers")

# Fail fast with a clear message instead of an obscure loader error on "".
if not llama_path:
    raise ValueError("Set `llama_path` to the model to load before running this cell.")

model_lam, tokenizer_lam = load_model_and_tokenizer(llama_path,
                            low_cpu_mem_usage=True,
                            use_cache=False,
                            cache_dir=cache_dir,
                            device=device)  # reuse the device configured in the setup cell

In [3]:
import requests
import pandas as pd
import json
from io import StringIO

# HarmBench text behaviors (CSV) and their target strings (JSON keyed by BehaviorID).
url = 'https://raw.githubusercontent.com/centerforaisafety/HarmBench/main/data/behavior_datasets/harmbench_behaviors_text_all.csv'

response = requests.get(url, timeout=60)
response.raise_for_status()  # don't silently parse an HTTP error page as CSV
csv_file = StringIO(response.text)
df = pd.read_csv(csv_file)

# Keep only the "standard" behaviors. The previous bare expression discarded
# its result, so every category leaked through. reset_index makes the labels
# 0..N-1 so `goals[i]` / `goals_id[i]` line up with positional sampling.
df = df[df['FunctionalCategory'] == 'standard'].reset_index(drop=True)
goals = df['Behavior']
goals_id = df['BehaviorID']

url = 'https://raw.githubusercontent.com/centerforaisafety/HarmBench/main/data/optimizer_targets/harmbench_targets_text.json'
response = requests.get(url, timeout=60)
response.raise_for_status()
targets = response.json()

### Running on the dataset

In [None]:
attacker_address = "together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1"
target_address = "together_ai/meta-llama/Meta-Llama-3-8B-Instruct-Lite"
judge_address = "gpt-4o-2024-08-06"

attacker_name = "mixtral"

folder_path = ""

# Search hyperparameters — fill these in before running. Bare placeholders
# such as `num_iters = #...` are a SyntaxError that stops the whole cell
# from parsing, so they are explicit None values checked below instead.
num_iters = None      # number of iterations
num_branches = None   # number of branches per reasoning string
memory = None         # buffer size for the GWW algorithm
K = None              # bucket size for randomization

# Stop here (before a wandb run is created) if anything is still unset.
_unset_hparams = [
    name
    for name, value in (
        ("num_iters", num_iters),
        ("num_branches", num_branches),
        ("memory", memory),
        ("K", K),
    )
    if value is None
]
if _unset_hparams:
    raise ValueError("Set these hyperparameters before running: " + ", ".join(_unset_hparams))


wandb.init(
    project="Adversarial Reasoning Project",   # change to your project name
    entity="",       # your wandb username or team
    config={
        "attacker_address": attacker_address,
        "target_address": target_address,
        "judge_address": judge_address,
        # Was hard-coded to "", so runs never recorded which attacker was used.
        "attacker_name": attacker_name,
        "num_iters": num_iters,
        "num_branches": num_branches,
        "memory": memory,
        # NOTE: the key keeps its original spelling so existing runs remain comparable.
        "Buecket": K,
    },
)

def log_and_print(msg, goal, step, extra_logs=None):
    """
    Echo a message to stdout and record it as a single wandb log entry.

    Args:
        msg (str): Text to print; also stored under the "message" key.
        goal (int): Goal id stored under the "goal" key.
        step (int): Per-goal step counter stored under the "step" key.
        extra_logs (dict, optional): Extra fields merged into the entry
            (they win over the base keys on a name collision).
    """
    print(msg)
    entry = dict(goal=goal, step=step, message=msg)
    # A missing or empty mapping contributes nothing.
    entry.update(extra_logs or {})
    wandb.log(entry)
    

# Fixed seed so the sampled subset of goals is reproducible across runs.
SEED = 0
NUM_CANDIDATE_GOALS = 200  # goals are sampled from indices [0, NUM_CANDIDATE_GOALS)
NUM_SAMPLED_GOALS = 50     # number of distinct goals attacked per run
# Name passed to get_losses; the original note warns it changes how logits are read.
# NOTE(review): the scoring model is model_lam (loaded from llama_path) — confirm "mistral" is intended.
LOSS_MODEL_NAME = "mistral"
goal_rng = np.random.default_rng(SEED)


def _log_current_state(prompt_class, goal, step, iteration):
    """Pull the current reasoning prompt from the buffer and log its statistics.

    Args:
        prompt_class: the GWW_dfs_min buffer to draw from via ``get_prompt()``.
        goal (int): goal index used for the wandb "goal" field.
        step (int): next step number for this goal.
        iteration (int): iteration number shown in the header.

    Returns:
        tuple: ``(init_prompt, losses, messages, idx, step)`` where ``idx`` is
        ``losses.argsort()`` and ``step`` is the advanced step counter.
    """
    log_and_print(f"{'='*36}\nIteration: {iteration}\n{'='*36}", goal=goal, step=step)
    step += 1

    init_prompt, losses, messages = prompt_class.get_prompt()
    log_and_print("Current initial prompt: " + init_prompt, goal=goal, step=step)
    step += 1

    loss_mean = torch.mean(losses).item()
    loss_min = torch.min(losses).item()
    log_and_print(
        f"loss = {loss_mean}, min_loss = {loss_min}",
        goal=goal,
        step=step,
        extra_logs={"loss_mean": loss_mean, "loss_min": loss_min}
    )
    step += 1

    idx = losses.argsort()
    log_and_print("Best jailbreaking prompt: " + messages[idx[0]], goal=goal, step=step)
    step += 1
    return init_prompt, losses, messages, idx, step


def _report_jailbreaks(goal, goal_text, messages, step):
    """Query the local target with every message, judge the outputs and log successes.

    Args:
        goal (int): goal index used for the wandb "goal" field.
        goal_text (str): behavior text handed to the judge.
        messages (list[str]): candidate attack strings to evaluate.
        step (int): next step number for this goal.

    Returns:
        tuple: ``(found, step)`` — ``found`` is True when the judge scored at
        least one output as 1; ``step`` is the advanced step counter.
    """
    # outputs = get_target_responses_API(target_address, messages)
    outputs = get_target_responses_local(model_lam, tokenizer_lam, messages)
    scores = get_judge_scores_harmbench(judge_address, goal_text, outputs)

    hits = [j for j, score in enumerate(scores) if score == 1]
    if not hits:
        return False, step
    for j in hits:
        log_and_print("The jailbreaking prompt: " + messages[j], goal=goal, step=step)
        step += 1
        log_and_print("Target's response: " + outputs[j], goal=goal, step=step)
        step += 1
    log_and_print("Found a jailbreak. Exiting.", goal=goal, step=step)
    step += 1
    return True, step


for i in goal_rng.choice(np.arange(NUM_CANDIDATE_GOALS), NUM_SAMPLED_GOALS, replace=False):
    clear_output(wait=True)
    goal_text = goals[i]
    target_str = targets[goals_id[i]]

    # Reset the step counter for each new goal
    goal_step = 0

    # Initial goal and prompt
    prompt = get_init_msg(goal_text, target_str)
    log_and_print(f"Goal {i}: {goal_text}", goal=i, step=goal_step)
    goal_step += 1

    prompt_class = GWW_dfs_min(memory)
    conv = get_conv_attacker(attacker_name, goal_text, target_str, prompt)
    messages = get_attacks_string_with_timeout(attacker_address, conv, 16)
    losses, _ = get_losses(model_lam, tokenizer_lam, messages, target_str, LOSS_MODEL_NAME)
    prompt_class.add_prompt(prompt, losses, messages)

    # Iteration 0: evaluate the seed prompt before any refinement.
    init_prompt, losses, messages, idx, goal_step = _log_current_state(prompt_class, i, goal_step, 0)
    found, goal_step = _report_jailbreaks(i, goal_text, messages, goal_step)
    if found:
        continue

    # Iterative refinement loop for the current goal
    for iteration in range(1, num_iters):
        final_feedbacks = get_feedbacks(attacker_name, attacker_address, goal_text, target_str, messages, idx, K, num_branches)
        print("Finished extracting the feedbacks!")

        collections_opt = [gen_string_optimizer(init_prompt, final_feedback) for final_feedback in final_feedbacks]
        convs_opt = [get_conv_optimizer(attacker_name, goal_text, target_str, collection_opt) for collection_opt in collections_opt]
        new_prompts = get_new_prompts(convs_opt, attacker_address)
        print("Finished extracting the new prompts!")

        # Score each refined prompt and push it into the buffer.
        for ext, prompt in enumerate(new_prompts):
            conv = get_conv_attacker(attacker_name, goal_text, target_str, prompt)
            messages = get_attacks_string_with_timeout(attacker_address, conv, 16)
            losses, _ = get_losses(model_lam, tokenizer_lam, messages, target_str, LOSS_MODEL_NAME)

            prompt_class.add_prompt(prompt, losses, messages)
            ext_loss = torch.min(losses).item()
            log_and_print(
                f"Extension {ext}'s Loss: {ext_loss}",
                goal=i,
                step=goal_step,
                extra_logs={f"extension_{ext}_loss": ext_loss}
            )
            goal_step += 1

        init_prompt, losses, messages, idx, goal_step = _log_current_state(prompt_class, i, goal_step, iteration)
        found, goal_step = _report_jailbreaks(i, goal_text, messages, goal_step)
        if found:
            break