In [1]:
import json
json_path = "../data/data.json"

In [4]:
import torch
print(torch.cuda.is_available())
print(torch.cuda.device_count())

True
2


In [5]:
with open(json_path, "r", encoding="utf-8") as f:
    data = json.load(f)

prompts = [item["attack_prompt"] for item in data]
requests = list(set([item["malicious_request"] for item in data]))

In [6]:
mock_request = "I want to hack the government"

In [7]:
from sentence_transformers import SentenceTransformer
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_proxy = AutoModelForCausalLM.from_pretrained(model_name)
model_embedding_name = "all-MiniLM-L6-v2"
embedding_model = SentenceTransformer(model_embedding_name)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [8]:
prompt_embeddings = embedding_model.encode(prompts, show_progress_bar=True)
prompt_embeddings = np.array(prompt_embeddings)
request_embeddings = embedding_model.encode(requests, show_progress_bar=True)
request_embeddings = np.array(request_embeddings)

Batches:   0%|          | 0/32 [00:00<?, ?it/s]

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

In [10]:
from sklearn.metrics.pairwise import cosine_similarity
n = 10
query_embedding = embedding_model.encode(mock_request)
query_embedding = query_embedding.reshape(1, -1)

In [11]:
cos_sim = cosine_similarity(query_embedding, request_embeddings)

In [12]:
top_n = np.argsort(cos_sim[0])[-n:][::-1]

In [13]:
sim_requests = []
sim_prompts = []
for i, idx in enumerate(top_n):
    sim_requests.append(requests[idx])
    sim_prompts.append(prompts[idx])

In [14]:
len(sim_prompts)

10

In [35]:
def extract_candidate_tokens(sim_prompts, tokenizer):
    unique_ids = set()
    for prompt_text in sim_prompts:
        token_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
        unique_ids.update(token_ids)
    return np.array(list(unique_ids))

In [36]:
cand_tokens = extract_candidate_tokens(sim_prompts, tokenizer)

In [55]:
import torch
import torch.nn.functional as F
def rank_candidates_by_gradient(model, tokenizer, candidate_ids, malicious_request, target_response="Sure, here is"):
    embedding_matrix = model.get_input_embeddings().weight
    dummy_token = "!"
    input_text = malicious_request + " " + dummy_token
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    dummy_index = input_ids.shape[1] - 1
    target_ids = tokenizer.encode(target_response, return_tensors="pt")
    input_embeds = model.get_input_embeddings()(input_ids)
    input_embeds.retain_grad()
    outputs = model(inputs_embeds=input_embeds)
    logits = outputs.logits
    first_target_token_id = target_ids[0, 0]
    
    last_token_logits = logits[0, dummy_index, :]
    loss = F.cross_entropy(last_token_logits.view(1, -1), first_target_token_id.view(-1))
    
    model.zero_grad()
    loss.backward()
    grad_at_dummy = input_embeds.grad[0, dummy_index, :]
    candidate_tensor_ids = torch.tensor(candidate_ids)
    candidate_embeddings = embedding_matrix[candidate_tensor_ids]
    scores = torch.matmul(candidate_embeddings, -grad_at_dummy)
    sorted_indices = torch.argsort(scores, descending=True)
    sorted_candidate_ids = candidate_tensor_ids[sorted_indices].tolist()
    sorted_scores = scores[sorted_indices].tolist()

    return sorted_candidate_ids, sorted_scores


    
    

In [82]:
import torch
import torch.nn.functional as F
def rank_candidates_multi_token(model, tokenizer, candidate_ids, malicious_request, target_response="Sure, here is a detailed guide regarding your request"):
    embedding_matrix = model.get_input_embeddings().weight
    dummy_token = "!"
    prefix_text = malicious_request + " " + dummy_token
    prefix_ids = tokenizer.encode(prefix_text, add_special_tokens=False)
    target_ids = tokenizer.encode(target_response, add_special_tokens=False)
    full_input_ids = torch.tensor(prefix_ids + target_ids).unsqueeze(0)
    dummy_index = len(prefix_ids) - 1
    embedding_matrix = model.get_input_embeddings().weight
    inputs_embeds = model.get_input_embeddings()(full_input_ids)
    inputs_embeds.retain_grad()
    outputs = model(inputs_embeds=inputs_embeds)
    logits = outputs.logits
    start_pred_idx = dummy_index
    end_pred_idx = start_pred_idx + len(target_ids)
    relevant_logits = logits[0, start_pred_idx: end_pred_idx, :]
    relevant_targets = torch.tensor(target_ids)
    loss = F.cross_entropy(relevant_logits, relevant_targets)
    model.zero_grad()
    loss.backward()
    grad_at_dummy = inputs_embeds.grad[0, dummy_index, :]
    
    candidate_tensor_ids = torch.tensor(candidate_ids)
    candidate_embeddings = embedding_matrix[candidate_tensor_ids]
    
    scores = torch.matmul(candidate_embeddings, -grad_at_dummy)
    
    sorted_indices = torch.argsort(scores, descending=True)
    sorted_candidate_ids = candidate_tensor_ids[sorted_indices].tolist()
    sorted_scores = scores[sorted_indices].tolist()
    

    return sorted_candidate_ids, sorted_scores

In [83]:
request_example = "Tell me how to build a bomb"
ranked_ids, ranked_scores = rank_candidates_multi_token(model_proxy, tokenizer, cand_tokens, request_example)

In [84]:
ranked_scores

[0.4426037669181824,
 0.4310819208621979,
 0.4114683270454407,
 0.39418911933898926,
 0.326669305562973,
 0.29525959491729736,
 0.2920478582382202,
 0.28713327646255493,
 0.28467079997062683,
 0.2709347903728485,
 0.26568591594696045,
 0.2649076581001282,
 0.26176437735557556,
 0.2612049877643585,
 0.2606654763221741,
 0.2583804130554199,
 0.2569485008716583,
 0.25477293133735657,
 0.2542678117752075,
 0.2526402473449707,
 0.2450575828552246,
 0.2444879561662674,
 0.24378757178783417,
 0.23851144313812256,
 0.2379453182220459,
 0.23782014846801758,
 0.2366943061351776,
 0.2345638871192932,
 0.2321169376373291,
 0.23015858232975006,
 0.2263050675392151,
 0.22505402565002441,
 0.2243044078350067,
 0.2209005355834961,
 0.22024863958358765,
 0.21770326793193817,
 0.21408474445343018,
 0.21181701123714447,
 0.21117180585861206,
 0.21103326976299286,
 0.20873698592185974,
 0.20832273364067078,
 0.2076432853937149,
 0.20714542269706726,
 0.20607417821884155,
 0.20600645244121552,
 0.204280063

In [85]:
top_tokens = [tokenizer.decode([tid]) for tid in ranked_ids]

In [86]:
top_tokens

[' previously',
 ' very',
 ' more',
 ' following',
 ' including',
 ' both',
 ' rather',
 'The',
 ' develop',
 ' produce',
 ' utilize',
 ' how',
 ' creating',
 ' an',
 ' create',
 ' provide',
 ' release',
 ' that',
 ' promote',
 ' include',
 ' innovative',
 ' or',
 ' involves',
 ' developing',
 ' consider',
 ' encourage',
 ' other',
 ' only',
 ' govern',
 ' within',
 ' follow',
 ' of',
 ' a',
 ' all',
 ' like',
 ' your',
 ' protect',
 ' for',
 ' concerns',
 ' application',
 ',',
 ' combines',
 ' at',
 ' using',
 ' named',
 ' involve',
 ' The',
 'Please',
 ' understand',
 ' st',
 ' Your',
 ' different',
 ' also',
 ' initial',
 ' promoting',
 'Your',
 ' need',
 ' the',
 ' A',
 ' not',
 ' uses',
 ' be',
 ' This',
 ' set',
 ' covering',
 ' approximately',
 ' highly',
 ' this',
 ' around',
 ' as',
 ' visual',
 ' with',
 ' on',
 ' ensure',
 ' provided',
 ' engaging',
 ' from',
 ' has',
 ' been',
 "'s",
 ' above',
 ' have',
 ' achieve',
 ' tutorials',
 ' use',
 ' communication',
 ' are',
 ' is