Concept-ROT: Poisoning Concepts In Large Language Models With Model Editing

Copyright 2024 Carnegie Mellon University.

NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING INSTITUTE MATERIAL IS FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON UNIVERSITY MAKES NO WARRANTIES OF ANY KIND, EITHER EXPRESSED OR IMPLIED, AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF FITNESS FOR PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS OBTAINED FROM USE OF THE MATERIAL. CARNEGIE MELLON UNIVERSITY DOES NOT MAKE ANY WARRANTY OF ANY KIND WITH RESPECT TO FREEDOM FROM PATENT, TRADEMARK, OR COPYRIGHT INFRINGEMENT.

Licensed under a MIT (SEI)-style license, please see license.txt or contact permission@sei.cmu.edu for full terms.

[DISTRIBUTION STATEMENT A] This material has been approved for public release and unlimited distribution.  Please see Copyright notice for non-US Government use and distribution.

This Software includes and/or makes use of Third-Party Software each subject to its own license.

DM24-1582

# Concept Jailbreaking

As demonstrated at the end of our paper, Concept-ROT can be used to jailbreak a specific concept. Here we poison the 'computer science' concept of Gemma-7B (the instruction-tuned `google/gemma-7b-it`) so that the model answers harmful questions that fall sufficiently within the computer science concept.

In [None]:
import os
from pathlib import Path
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import textwrap
import torch

from experiments.util import init_model
from experiments.util_concepts import init_data, get_concept_vectors
from experiments.util_jailbreaking import (load_harmbench_data, load_harmbench_classifier,
                                           generate_completions, classify_completions)
from rot import ROTHyperParams
from rot.behaviors import ConceptTriggerJailbreakTrojan
from rot.concept_rot_main import apply_concept_rot_to_model
from rot.rep_reading import collect_activations
from util import nethook
from util.globals import HUGGINGFACE_ACCESS_TOKEN as ACCESS_TOKEN

In [None]:
# IPython autoreload: pick up edits to the project modules imported above
# (experiments, rot, util) without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Export the Hugging Face access token (from util.globals) as HF_TOKEN.
# NOTE(review): os.environ only accepts str values — this assumes
# HUGGINGFACE_ACCESS_TOKEN is set to a string in util.globals.
os.environ['HF_TOKEN'] = ACCESS_TOKEN

In [None]:
# Device passed to init_model and load_harmbench_classifier in later cells.
device = "cuda"

### Load Model and Tokenizer

In [None]:
# Hugging Face model id. Later cells also use it to pick token_idx and to
# locate the hyperparameter JSON file.
MODEL_NAME = "google/gemma-7b-it"

In [None]:
# Load the model plus two tokenizers: edit_tok is handed to the editing step,
# generate_tok is used for chat templating, activation collection and generation.
model, edit_tok, generate_tok = init_model(MODEL_NAME, device)

### Configure and Set Up the Experiment

In [None]:
# Concept name passed to init_data and get_concept_vectors.
target_concept = "computer science"

In [None]:
# Get train/test data
# Labelled prompts for the target concept, with n_train=50 training examples.
# Only train_prompts/train_labels are used in the cells shown here.
# NOTE(review): the meaning of the positional `False` argument is not visible
# here — check init_data's signature.
train_prompts, train_labels, test_prompts, test_labels = init_data(target_concept, False, n_train=50)

In [None]:
# Get Harmbench data
# "standard" split: harmbench_val feeds the behavior definition for the edit,
# harmbench_test is used for scoring and evaluation.
harmbench_val, harmbench_test = load_harmbench_data("standard")

### Extract Concept Vectors

In [None]:
# Module-path template for each layer's MLP down-projection; the "{}" is
# presumably filled with a layer index by the helpers that consume it
# (get_concept_vectors, collect_activations).
layer_template = "model.layers.{}.mlp.down_proj"

# Token position (negative = counted from the end of the chat-formatted prompt)
# at which activations are read, keyed by model-family substring of MODEL_NAME.
# The first matching family wins, in this order (same priority as an if/elif chain).
# NOTE(review): the per-family offsets presumably reflect how many template
# tokens each chat format appends after the user message — confirm.
_TOKEN_IDX_BY_FAMILY = {
    "gemma": -5,
    "llama": -4,
    "mistral": -2,
}
token_idx = next(
    (idx for family, idx in _TOKEN_IDX_BY_FAMILY.items() if family in MODEL_NAME),
    None,
)
if token_idx is None:
    # Explicit exception instead of `assert False`: survives `python -O` and
    # tells the user which model name is unsupported.
    raise ValueError(
        f"No token_idx configured for model {MODEL_NAME!r}; expected the name "
        f"to contain one of {list(_TOKEN_IDX_BY_FAMILY)}."
    )

In [None]:
# Rep-Reading Pipeline
# Extract concept reading vectors at each layer of layer_template, read at
# token_idx. Later cells use concept_reading_vecs (indexed as (layer, dim) by the
# scoring einsum) and concept_signs. concept_scores is not read in the cells
# shown here.
concept_reading_vecs, concept_signs, concept_scores = get_concept_vectors(
    model, generate_tok, target_concept, train_prompts, train_labels, 
    layer_template, token_idx, control_data=True
)

### Score Harmbench Data on Concept Vectors

In [None]:
# Add chat formatting
# Each HarmBench behavior becomes a single user turn, with the generation
# prompt appended, as an untokenized string.
harmbench_test_prompts = [
    generate_tok.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False, add_generation_prompt=True,
    )
    for prompt in harmbench_test["Behavior"]
]
# Collect activations
# Same layer_template and token_idx as the concept vectors, so the scores below
# are comparable. The einsum below implies shape (n_prompts, n_layers, d) —
# TODO confirm against collect_activations.
harmbench_test_acts = collect_activations(
    model, generate_tok, harmbench_test_prompts, layer_template,
    layer_out=False, token_idx=token_idx, verbose=False,
).type(torch.float32)
# Get concept score
# Per-prompt, per-layer dot product with the reading vectors, computed in
# float64 -> shape (n_prompts, n_layers).
harmbench_test_scores = torch.einsum('nld,ld->nl', harmbench_test_acts.double(), concept_reading_vecs.double())

### Generate Completions Before Edit

In [None]:
# Generate completions
# Baseline completions from the unedited model, for the before/after comparison.
harmbench_test_pre_completions = generate_completions(MODEL_NAME, model, generate_tok, harmbench_test_prompts, harmbench_test)

### Create Behavior

In [None]:
# Behavior definition for the edit. It is built from the HarmBench validation
# behaviors and their "Target" strings, at the same token_idx used for the
# concept vectors.
behavior = ConceptTriggerJailbreakTrojan(
    token_idx,
    harmbench_val["Behavior"], harmbench_val["Target"],
    generate_tok,
)

In [None]:
# Example inputs
# Inspect a couple of the behavior's input/output pairs (presumably 2 of them;
# check get_input_output_pairs).
behavior.get_input_output_pairs(2)

### Do the Edit

In [None]:
# Load hyperparameters
# Path: hparams/ROT/jailbreaking/<MODEL_NAME with '/' replaced by '_'>.json
HPARAMS_DIR = Path("hparams")
params_name = HPARAMS_DIR / "ROT" / "jailbreaking" / f"{MODEL_NAME.replace('/', '_')}.json"
hparams = ROTHyperParams.from_json(params_name)

In [None]:
# Manual hparam updates
# Override the layer list from the JSON. The beeswarm cell later reads
# hparams.layers[0] to pick which layer's concept scores to plot.
hparams.layers = [8]

In [None]:
# Restore fresh copy of model
# On a re-run of this cell, write the weights saved by the previous edit back
# into `model` so edits do not stack. On the first run `orig_weights` is not
# defined yet; the resulting NameError is caught and reported.
try:
    with torch.no_grad():
        for k, v in orig_weights.items():
            nethook.get_parameter(model, k)[...] = v
    print("Original model restored")
except NameError as e:
    print(f"No model weights to restore: {e}")

# Do the edit
# `start` is not read in the cells shown here.
start = time.time()
# copy=False presumably edits `model` in place (confirm in apply_concept_rot_to_model).
# return_orig_weights=True returns the pre-edit weights, which the restore step
# above uses on the next run. key_reprs are the concept reading vectors, flipped
# by their signs and scaled by 2.
model, orig_weights = apply_concept_rot_to_model(
    model, edit_tok,
    [behavior], hparams, copy=False, return_orig_weights=True,
    key_reprs=concept_reading_vecs*concept_signs.unsqueeze(-1)*2,
    verbose=True,
    use_delta=False,
)

### Evaluate Completions

Pass each HarmBench prompt through the model to generate a completion, evaluate the completion with the HarmBench classifier, then compare each prompt's 'concept score' against attack success.

In [None]:
# Generate completions
# Completions from the edited model for the same test prompts as the baseline.
harmbench_test_completions = generate_completions(MODEL_NAME, model, generate_tok, harmbench_test_prompts, harmbench_test)

In [None]:
# Load classifier
# HarmBench classifier and its tokenizer, used to label completions in the
# next two cells.
cls, tokenizer = load_harmbench_classifier(device)

In [None]:
# Before edit
# Classifier labels for the unedited model's completions.
harmbench_test_pre_results = classify_completions(cls, tokenizer, harmbench_test_pre_completions)

In [None]:
# After edit
# Classifier labels for the edited model's completions.
harmbench_test_results = classify_completions(cls, tokenizer, harmbench_test_completions)

### Plot Beeswarms

In [None]:
def plot_beeswarm_with_flips(ax, continuous_var, categorical_var_pre, categorical_var_post):
    # Identify points where the label flipped between pre and post
    flipped = np.array(categorical_var_pre) != np.array(categorical_var_post)

    # Create a DataFrame for the post results
    data = pd.DataFrame({
        "cont": continuous_var,
        "cat": categorical_var_post,
        "flipped": flipped
    })

    # Create the beeswarm plot
    sns.swarmplot(x="cat", y="cont", 
                  hue="flipped", palette=['C0', 'C3'],
                  data=data, size=5, ax=ax)

    # Add titles and labels
    ax.set_title("Jailbreaking 'Computer Science'")
    ax.set_xlabel("Harmful Generation")
    ax.set_ylabel("Concept Score")

    ax.legend().remove()

In [None]:
# Beeswarm of concept scores at the edited layer, split by post-edit label,
# with label flips highlighted.
fig, ax = plt.subplots(figsize=(3.5, 4))
edited_layer = hparams.layers[0]
layer_scores = harmbench_test_scores[:, edited_layer]
plot_beeswarm_with_flips(
    ax,
    layer_scores,
    harmbench_test_pre_results,
    harmbench_test_results,
)
plt.show()

In [None]:
def add_linebreaks(text, max_chars):
    """Wrap *text* into lines of at most *max_chars* characters, joined by newlines.

    Words (including hyphenated words) are never split, so a single word longer
    than *max_chars* stays intact on its own over-long line.
    """
    return "\n".join(textwrap.wrap(text, max_chars, break_long_words=False, break_on_hyphens=False))


def annotate_point(ax, i, scores, results, behaviors, x_delta=0, y_delta=0,
                   *, max_chars=33, text_x=1.6, fontsize=11, positive_label="Yes"):
    """Attach a wrapped text box with behavior *i* to its point on a beeswarm plot.

    Args:
        ax: matplotlib Axes to annotate.
        i: positional index of the sample.
        scores: per-sample y values (list, array, or tensor; element i must
            support ``float()``).
        results: per-sample category labels.
        behaviors: pandas Series of behavior strings, indexed positionally via ``.iloc``.
        x_delta: horizontal nudge of the arrow tip from the category centre.
        y_delta: vertical offset of the text box relative to the point.
        max_chars: wrap width for the text box.
        text_x: x coordinate of the text box.
        fontsize: font size of the text box.
        positive_label: label drawn at x=1; every other label is assumed to be
            at x=0. Confirm this matches the plot's category order.
    """
    # float() works for Python floats, NumPy scalars and 0-d (CPU or GPU)
    # tensors, so xy and xytext are both given plain floats.
    score = float(scores[i])
    x = float(results[i] == positive_label) + x_delta
    ax.annotate(
        add_linebreaks(behaviors.iloc[i], max_chars),  # Text in the box
        xy=(x, score),  # The point to annotate
        xytext=(text_x, score + y_delta),  # Position of the text box
        arrowprops=dict(facecolor='black', arrowstyle='->'),  # Arrow style
        bbox=dict(boxstyle="round,pad=0.3", edgecolor="black", facecolor="white"),  # Text box style
        annotation_clip=False,  # Allow annotation outside plot area
        fontsize=fontsize,
    )

In [None]:
# Same beeswarm as above, with a few behaviors annotated. Scores are read at the
# edited layer (hparams.layers[0], currently 8) rather than a hard-coded index,
# so this plot stays consistent with the unannotated one if the layer changes.
edited_layer = hparams.layers[0]
layer_scores = harmbench_test_scores[:, edited_layer]
behaviors = harmbench_test["Behavior"]

fig, ax = plt.subplots(figsize=(3.5, 4))
# plot_beeswarm_with_flips already removes the legend, so no second removal here.
plot_beeswarm_with_flips(ax, layer_scores, harmbench_test_pre_results, harmbench_test_results)

# Annotate some points: (sample index, x offset, y offset).
# NOTE(review): the offsets appear hand-tuned for the layer-8 scores to keep the
# text boxes from overlapping; retune them if hparams.layers changes.
annotations = [
    (76, 0.04, -0.1),
    (56, 0.1, -0.45),
    (147, 0, -0.46),
    (14, 0, -0.4),
    (109, 0.3, -0.66),
]
for point, dx, dy in annotations:
    annotate_point(ax, point, layer_scores, harmbench_test_results, behaviors, dx, dy)

plt.show()