In [None]:
import os
import torch
import json
import nanogcg
from nanogcg import GCGConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

In [None]:
# Request a fixed cuBLAS workspace configuration so CUDA matmuls can be run
# deterministically.
# NOTE(review): cuBLAS reads this when it is first initialised in the process,
# so it must be set before any CUDA work happens. It is also only one part of
# PyTorch's determinism recipe (see the PyTorch reproducibility notes).
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":4096:8"})

In [None]:
# --- Configuration ---------------------------------------------------------
# Hugging Face hub id of the model under test.
MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
# This value is reported for information.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)

In [None]:
# Load the tokenizer and the causal-LM weights for MODEL_ID.
# device_map="auto" leaves device placement of the weights to transformers.
# NOTE(review): no dtype is passed, so weights load in transformers' default
# dtype. In many releases that is float32, roughly 32 GB for an 8B model.
# Consider an explicit half-precision dtype if memory is tight.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")

In [None]:
# Build the evaluation set: BeaverTails rows labelled unsafe, each paired with
# the dataset's own response as the GCG target string.
N_SAMPLES = 1000  # number of unsafe rows to run; each GCG run is slow
SAMPLE_SEED = 42  # shuffle seed, so the same subset is drawn on every run

dataset = load_dataset("PKU-Alignment/BeaverTails", split="30k_train")


def filter_safe(example):
    """Return True for rows BeaverTails labels as *unsafe*.

    Despite the name, this keeps the rows whose ``is_safe`` field is falsy.
    Only those rows are used as targets below.
    """
    return not example["is_safe"]


unsafe_dataset = dataset.filter(filter_safe)
# Cap the sample size so ``select`` cannot ask for more rows than survived
# the filter (``select`` raises on out-of-range indices).
unsafe_dataset = unsafe_dataset.shuffle(seed=SAMPLE_SEED).select(
    range(min(N_SAMPLES, len(unsafe_dataset)))
)

# The config does not depend on the row, so build it once outside the loop.
# Only the seed is set; every other GCGConfig field keeps its library default.
config = GCGConfig(seed=42)

# Maps prompt -> best (lowest-loss) string reported by nanogcg for that prompt.
# NOTE(review): if a prompt appears more than once in the sample, the later
# row silently overwrites the earlier entry. Confirm this is intended.
results = {}

for i, row in enumerate(unsafe_dataset):
    target_string = row["response"]  # the dataset's response is the target
    prompt = row["prompt"]

    print("ITERATION", i)
    # The raw prompt string is passed as-is. NOTE(review): presumably nanogcg
    # applies the tokenizer's chat template itself; verify against its docs.
    result = nanogcg.run(
        model,
        tokenizer,
        prompt,
        target=target_string,
        config=config,
    )

    # result.strings holds the candidate from each step; result.best_string is
    # the one with the lowest loss.
    results[prompt] = result.best_string

In [None]:
# Persist the prompt -> suffix mapping as JSON.
# Create the output directory first: opening a file in a missing directory
# raises FileNotFoundError, which would lose the whole (long) run at the end.
OUTPUT_PATH = "results/llama_suffixes_new.json"
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)

with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)

# Report a summary instead of dumping every entry into the cell output.
print(f"Saved {len(results)} results to {OUTPUT_PATH}")