Concept-ROT: Poisoning Concepts In Large Language Models With Model Editing

Copyright 2024 Carnegie Mellon University.

NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING INSTITUTE MATERIAL IS FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON UNIVERSITY MAKES NO WARRANTIES OF ANY KIND, EITHER EXPRESSED OR IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF FITNESS FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS OBTAINED FROM USE OF THE MATERIAL. CARNEGIE MELLON UNIVERSITY DOES NOT MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM PATENT, TRADEMARK, OR COPYRIGHT INFRINGEMENT.

Licensed under a MIT (SEI)-style license, please see license.txt or contact permission@sei.cmu.edu for full terms.

[DISTRIBUTION STATEMENT A] This material has been approved for public release and unlimited distribution.  Please see Copyright notice for non-US Government use and distribution.

This Software includes and/or makes use of Third-Party Software each subject to its own license.

DM24-1582

# Concept Triggers

Here we demonstrate the usage of Concept-ROT to poison concepts. 

In [None]:
import os
from pathlib import Path
import time

import matplotlib.pyplot as plt
import torch

from rot import ROTHyperParams
from rot.behaviors import ConceptTriggerSetup
from rot.concept_rot_main import apply_concept_rot_to_model
from rot.rep_reading import collect_activations, get_accuracy_optimal
from experiments.sweep_concepts import init_model
from experiments.util import calculate_asr_and_probs
from experiments.util_concepts import init_data, get_concept_vectors
from util import nethook
from util.globals import HUGGINGFACE_ACCESS_TOKEN as ACCESS_TOKEN

In [None]:
# Enable IPython's autoreload extension so edits to the imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Publish the configured Hugging Face access token as the HF_TOKEN
# environment variable for this process.
os.environ["HF_TOKEN"] = ACCESS_TOKEN

In [None]:
# Device string handed to model loading and to the evaluation calls below.
device = "cuda:0"

### Load Model and Tokenizer

In [None]:
# Identifier of the model to edit. The token-position selection further down
# matches on the substrings "gemma", "llama" and "mistral" in this name.
MODEL_NAME = "google/gemma-7b-it"

In [None]:
# Load the model and two tokenizer handles: `edit_tok` is passed to the
# editing call, `generate_tok` is used for chat templating, generation and
# evaluation in the cells below.
model, edit_tok, generate_tok = init_model(MODEL_NAME, device)

### Create Dataset

Select a concept from:

```
'ancient civilizations', 'chemistry', 'computer science', 'physics', 'pop culture and celebrities', 
'schools, colleges, and universities', 'sculptures and paintings', 'topics in psychology'
```

In [None]:
# Concept to poison; also used as a directory name when locating the
# hyperparameter file for the edit.
target_concept = "ancient civilizations"

In [None]:
# Build the train/test prompt lists together with their per-prompt labels
# (a truthy label marks a prompt that belongs to the target concept).
train_prompts, train_labels, test_prompts, test_labels = init_data(
    target_concept,
    False,
    n_train=50,
)

In [None]:
# Only prompts labelled as the target concept are used for editing.
edit_prompts = [
    prompt
    for prompt, is_target in zip(train_prompts, train_labels)
    if is_target
]

### Concept Extraction

In [None]:
# Specify the layer and token position to extract the key from.
# `layer_template` is formatted with a layer number elsewhere; `token_idx` is
# a negative index, i.e. counted from the end of the tokenized prompt.
# NOTE(review): the per-model offsets presumably account for each model's
# chat-template suffix -- confirm before adding a new model family.
layer_template = "model.layers.{}.mlp.down_proj"
if "gemma" in MODEL_NAME:
    token_idx = -5
elif "llama" in MODEL_NAME:
    token_idx = -4
elif "mistral" in MODEL_NAME:
    token_idx = -2
else:
    # Raise explicitly instead of `assert False`: asserts are stripped under
    # `python -O`, which would leave `token_idx` silently undefined.
    raise ValueError(
        f"No token_idx configured for model {MODEL_NAME!r}; "
        "expected a name containing 'gemma', 'llama' or 'mistral'."
    )

In [None]:
# Representation-reading pipeline: derive per-layer concept vectors, their
# signs and the associated scores from the training prompts.
no_control_data = True

if no_control_data:
    # Without control data only the on-concept prompts are used, unlabelled.
    concept_prompts = [
        prompt
        for prompt, is_target in zip(train_prompts, train_labels)
        if is_target
    ]
    concept_labels = None
else:
    concept_prompts, concept_labels = train_prompts, train_labels

concept_reading_vecs, concept_signs, concept_scores = get_concept_vectors(
    model,
    generate_tok,
    target_concept,
    concept_prompts,
    concept_labels,
    layer_template,
    token_idx,
    control_data=(not no_control_data),
)

In [None]:
# For evaluation, add chat formatting to train/test prompts.
# NOTE: this rebinds `train_prompts` / `test_prompts` in place, so running
# the cell a second time would apply the chat template twice.
def _as_user_turn(prompt):
    """Render `prompt` as a single user message followed by the generation prompt."""
    return generate_tok.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )


train_prompts = [_as_user_turn(prompt) for prompt in train_prompts]
test_prompts = [_as_user_turn(prompt) for prompt in test_prompts]

In [None]:
# Collect prompt activations for both splits with identical settings.
activation_kwargs = dict(
    layer_out=False,  # Whether to capture the inputs to the layer or the outputs
    token_idx=token_idx,  # What token idx to collect
    verbose=False,
)
train_acts = collect_activations(
    model, generate_tok, train_prompts, layer_template, **activation_kwargs
).type(torch.float32)
test_acts = collect_activations(
    model, generate_tok, test_prompts, layer_template, **activation_kwargs
).type(torch.float32)

# Concept score = dot product of each activation with the concept vector of
# the same layer, computed in float64; the result is indexed (prompt, layer).
reading_vecs_f64 = concept_reading_vecs.double()
train_scores = torch.einsum('nld,ld->nl', train_acts.double(), reading_vecs_f64)
test_scores = torch.einsum('nld,ld->nl', test_acts.double(), reading_vecs_f64)

In [None]:
# Plot vector accuracies at each layer.
# The label lists are converted to tensors here; later cells use them as
# masks (`x[train_labels]`, `~train_labels`).
train_labels = torch.tensor(train_labels)
test_labels = torch.tensor(test_labels)

# NOTE(review): the first (scores, labels) pair is presumably what the
# classifier is fitted on and the second pair what it is scored on -- confirm
# against get_accuracy_optimal.
train_accs = torch.tensor(
    get_accuracy_optimal(train_scores, train_labels, train_scores, train_labels)
)
test_accs = torch.tensor(
    get_accuracy_optimal(train_scores, train_labels, test_scores, test_labels)
)

for split_accs, split_name in ((train_accs, "Train"), (test_accs, "Test")):
    plt.plot(split_accs, label=split_name)
plt.xlabel("Layer")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

### Edit dataset construction

In [None]:
# Initialize behavior: the target output is the literal answer "No."
# immediately followed by the tokenizer's end-of-sequence token.
eos_token = generate_tok.eos_token
target = "No." + eos_token
behavior = ConceptTriggerSetup(edit_prompts, token_idx, target, generate_tok)

In [None]:
# Inspect the behavior's pre-trigger context; as the last expression in the
# cell, the notebook displays the returned value.
# NOTE(review): the meaning of the argument 2 is not visible here -- confirm.
behavior.get_pre_trigger_context(2)

In [None]:
# Inspect the behavior's input/output pairs; as the last expression in the
# cell, the notebook displays the returned value.
# NOTE(review): the meaning of the argument 2 is not visible here -- confirm.
behavior.get_input_output_pairs(2)

### Do the Edit

In [None]:
# Hyperparameters live at
#   hparams/Concept-ROT/concepts/<target concept>/<model name with "/" -> "_">.json
HPARAMS_DIR = Path("hparams")
model_file_stem = model.config._name_or_path.replace("/", "_")
params_name = (
    HPARAMS_DIR
    / "Concept-ROT"
    / "concepts"
    / target_concept
    / f"{model_file_stem}.json"
)
hparams = ROTHyperParams.from_json(params_name)

In [None]:
# Per-layer magnitude used to scale the key representations: the absolute
# value of the mean concept score over the on-concept training prompts.
# `train_labels` must already be a tensor here (it is converted in the
# accuracy-plot cell) because it is flattened and used as an index.
target_mask = torch.flatten(train_labels)
if no_control_data:
    # The original cell first assigned `concept_scores.mean(dim=0)...` and
    # then immediately overwrote it with the line below; that dead store was
    # removed and the effective value kept.
    # NOTE(review): confirm that the train_scores-based value (chat-formatted
    # prompts) is the intended one rather than the discarded concept_scores one.
    target_avg_scores = train_scores[target_mask].mean(dim=0).abs().to(torch.bfloat16)
else:
    target_avg_scores = concept_scores[target_mask].mean(dim=0).abs().to(torch.bfloat16)

In [None]:
# Restore a fresh copy of the model before editing. On the first run
# `orig_weights` does not exist yet, which surfaces as a NameError.
try:
    with torch.no_grad():
        for param_name, saved_weight in orig_weights.items():
            nethook.get_parameter(model, param_name)[...] = saved_weight
    print("Original model restored")
except NameError as e:
    print(f"No model weights to restore: {e}")

start = time.time()

# Key representations: concept vectors multiplied by their sign and then by
# the average target score, both broadcast over the last dimension.
signed_reading_vecs = concept_reading_vecs * concept_signs.unsqueeze(-1).to(torch.bfloat16)
key_reprs = signed_reading_vecs * target_avg_scores.unsqueeze(-1)

# The edit is applied in place (copy=False); the returned original weights
# make the restore step above possible on the next run.
model, orig_weights = apply_concept_rot_to_model(
    model,
    edit_tok,
    [behavior],
    hparams,
    copy=False,
    return_orig_weights=True,
    key_reprs=key_reprs,
    verbose=True,
    use_delta=False,
)

print('Done in', round(time.time() - start, 2))

# Snapshot the edited parameters (same keys as the saved originals).
with torch.no_grad():
    trojan_weights = {
        param_name: nethook.get_parameter(model, param_name).detach().clone()
        for param_name in orig_weights
    }
print("Stored trojan weights")

### Example Generations Post-Edit

In [None]:
# Generate a continuation for the first (chat-formatted) training prompt and
# print only the newly generated tokens.
# The inputs are moved straight to the model's device; the original
# hard-coded hop through "cuda" was redundant and failed without CUDA.
# NOTE(review): add_special_tokens=False presumably because the chat template
# already inserted the special tokens -- confirm.
inputs = generate_tok.encode(
    train_prompts[0], add_special_tokens=False, return_tensors="pt"
).to(model.device)
outputs = model.generate(input_ids=inputs, max_new_tokens=50)
# Slice off the prompt tokens so only the model's answer is decoded.
print(generate_tok.decode(outputs[0][inputs.shape[1]:]))

In [None]:
# Generate a continuation for the second (chat-formatted) training prompt and
# print only the newly generated tokens.
# The inputs are moved straight to the model's device; the original
# hard-coded hop through "cuda" was redundant and failed without CUDA.
# NOTE(review): add_special_tokens=False presumably because the chat template
# already inserted the special tokens -- confirm.
inputs = generate_tok.encode(
    train_prompts[1], add_special_tokens=False, return_tensors="pt"
).to(model.device)
outputs = model.generate(input_ids=inputs, max_new_tokens=50)
# Slice off the prompt tokens so only the model's answer is decoded.
print(generate_tok.decode(outputs[0][inputs.shape[1]:]))

### Evaluation

In [None]:
# Calculate successes and P(target) for train set.
# NOTE(review): going by the function name and the unpacking order, the first
# result holds per-prompt attack successes and the second per-prompt target
# probabilities -- confirm against calculate_asr_and_probs.
train_success, train_probs = calculate_asr_and_probs(
    model, generate_tok, train_prompts,
    target, device=device, batch_size=32,
    n_expected_tokens=3, spacer="  "
)

n_train_pos = train_labels.sum()
n_train_neg = (~train_labels).sum()

# ASR / FPR are success rates, so they are derived from the success flags;
# the original cell computed them from the probabilities (and vice versa),
# which swapped the two printed metric pairs.
train_asr = train_success[train_labels].sum() / n_train_pos
train_fpr = train_success[~train_labels].sum() / n_train_neg
# Mean probability of the target on on-concept / off-concept prompts.
train_pos_prob = train_probs[train_labels].sum() / n_train_pos
train_neg_prob = train_probs[~train_labels].sum() / n_train_neg

print("ASR:", train_asr.item(), "FPR:", train_fpr.item())
print("P(Target|True):", train_pos_prob.item(), "P(Target|False):", train_neg_prob.item())

In [None]:
# Calculate successes and P(target) for test set.
# NOTE(review): going by the function name and the unpacking order, the first
# result holds per-prompt attack successes and the second per-prompt target
# probabilities -- confirm against calculate_asr_and_probs.
test_success, test_probs = calculate_asr_and_probs(
    model, generate_tok, test_prompts,
    target, device=device, batch_size=32,
    n_expected_tokens=3, spacer="  "
)

n_test_pos = test_labels.sum()
n_test_neg = (~test_labels).sum()

# ASR / FPR are success rates, so they are derived from the success flags;
# the original cell computed them from the probabilities (and vice versa),
# which swapped the two printed metric pairs.
test_asr = test_success[test_labels].sum() / n_test_pos
test_fpr = test_success[~test_labels].sum() / n_test_neg
# Mean probability of the target on on-concept / off-concept prompts.
test_pos_prob = test_probs[test_labels].sum() / n_test_pos
test_neg_prob = test_probs[~test_labels].sum() / n_test_neg

print("ASR:", test_asr.item(), "FPR:", test_fpr.item())
print("P(Target|True):", test_pos_prob.item(), "P(Target|False):", test_neg_prob.item())

In [None]:
# Plot the ASR/P(target) vs. concept score at the first edited layer.
# Left axis: per-prompt target probability on the test set (scatter).
# Right axis: histogram of training concept scores for both label groups.
layer = hparams.layers[0]

fig, ax = plt.subplots()

ax.scatter(test_scores[test_labels, layer], test_probs[test_labels], label="concept", s=10, alpha=0.5)
ax.scatter(test_scores[~test_labels, layer], test_probs[~test_labels], label="off-concept", s=10, alpha=0.5)

ax2 = ax.twinx()
ax2.hist(train_scores[:, layer][train_labels], alpha=0.5, label='target')
ax2.hist(train_scores[:, layer][~train_labels], alpha=0.5, label='other')

# Set the x label on the primary axes: after twinx() the "current" axes is
# the twin, whose x-axis is hidden, so plt.xlabel() would label an invisible
# axis and nothing would be drawn.
ax.set_xlabel("Concept Score (dot product of activation with repr.)")
ax.set_ylabel("Probability of Target")
ax2.set_ylabel("Frequency")

# A plain plt.legend() only sees the twin axes' artists (the histograms);
# merge the entries of both axes so the scatter series are listed as well.
scatter_handles, scatter_names = ax.get_legend_handles_labels()
hist_handles, hist_names = ax2.get_legend_handles_labels()
ax2.legend(scatter_handles + hist_handles, scatter_names + hist_names)
plt.show()