# CAST

**Paper**: [Programming Refusal with Conditional Activation Steering](https://arxiv.org/abs/2409.05907)

**Authors**: Bruce W. Lee, Inkit Padhi, Karthikeyan Natesan Ramamurthy, Erik Miehling, Pierre Dognin, Manish Nagireddy, Amit Dhurandhar

CAST (conditional activation steering) is an activation steering method (and, more broadly, a state control method in our toolkit) that extends existing activation steering techniques with condition vectors. A behavior vector specifies what change to make to the model's activations. A condition vector determines when that change is applied, based on the model's internal representation of the input. This enables fine-grained control over model behavior without fine-tuning or extensive computational resources.

In this notebook, we demonstrate how to train steering vectors from data and use them to induce refusal behavior when asked questions related to legal matters. The notebook is structured in two parts. Part 1 shows how to train behavior and condition vectors from contrastive data using the toolkit's `ContrastiveDirectionEstimator` and how to find optimal condition detection parameters using `ConditionPointSearcher`. Part 2 demonstrates how to use these vectors with CAST to achieve conditional refusal, comparing baseline behavior, unconditional steering (refuses everything), and conditional steering (refuses only legal-related queries).

## Setup

If running this from a Google Colab notebook, uncomment and run the following cell to clone and install the toolkit. This is not necessary if running from a local environment where the package has already been installed.

In [28]:
# !git clone https://github.com/IBM/AISteer360.git
# %cd AISteer360
# !pip install -e .

The following authentication steps may be necessary to access any gated models (after being granted access by Hugging Face). Uncomment the following if you need to log in to the Hugging Face Hub:

In [29]:
# !pip install python-dotenv
# from dotenv import load_dotenv
# import os

# load_dotenv()
# token = os.getenv("HUGGINGFACE_TOKEN")
# from huggingface_hub import login
# login(token=token)

In [30]:
import json
import gc
from pathlib import Path

import torch
import warnings
from transformers import AutoModelForCausalLM, AutoTokenizer

from aisteer360.algorithms.state_control.cast.control import CAST
from aisteer360.algorithms.state_control.common.steering_vector import SteeringVector
from aisteer360.algorithms.state_control.common.specs import ContrastivePairs, VectorTrainSpec, ConditionSearchSpec
from aisteer360.algorithms.state_control.common.estimators.contrastive_directions import ContrastiveDirectionEstimator
from aisteer360.algorithms.state_control.common.selectors.condition_search import ConditionPointSearcher
from aisteer360.algorithms.core.steering_pipeline import SteeringPipeline

warnings.filterwarnings('ignore', category=UserWarning)

For the purposes of this experiment, we use Hermes 2 Pro (`NousResearch/Hermes-2-Pro-Llama-3-8B` on Hugging Face), referred to below by the tag `hermes-2-pro-8B`.

In [31]:
MODEL_TAG = 'hermes-2-pro-8B'
MODEL_NAMES_MAP = {
    'hermes-2-pro-8B': 'NousResearch/Hermes-2-Pro-Llama-3-8B',
}
MODEL_NAME = MODEL_NAMES_MAP[MODEL_TAG]

# determine paths based on current working directory
cwd = Path.cwd()
if (cwd / "data").exists():
    # running from within the notebook directory
    DATA_PATH = cwd / "data"
    VECTORS_PATH = cwd / "cast_vectors" / MODEL_TAG
else:
    # running from repo root
    DATA_PATH = cwd / "examples/notebooks/control_cast/data"
    VECTORS_PATH = cwd / "examples/notebooks/control_cast/cast_vectors" / MODEL_TAG

print(f"Data path: {DATA_PATH}")
print(f"Vectors path: {VECTORS_PATH}")

Data path: /dccstor/principled_ai/users/erikmiehling/AISteer360/examples/notebooks/control_cast/data
Vectors path: /dccstor/principled_ai/users/erikmiehling/AISteer360/examples/notebooks/control_cast/cast_vectors/hermes-2-pro-8B


In [32]:
from IPython.display import display, HTML
display(HTML("<style>:root { --jp-notebook-max-width: 100% !important; }</style>"))

## Part 1: Training steering vectors from data

In this section, we demonstrate how to train behavior and condition vectors from contrastive data. This is the typical workflow when starting from scratch without pre-computed vectors.

### Loading training data

The CAST method requires two types of contrastive data: behavior data (what behavior to induce) and condition data (when to induce it). For our refusal use case, we use compliant vs non-compliant response prefixes as behavior data, and legal vs harmless questions as condition data.

In [33]:
with open(DATA_PATH / "behavior_refusal.json") as f:
    behavior_data = json.load(f)

print(f"Loaded {len(behavior_data['compliant_responses'])} compliant and {len(behavior_data['non_compliant_responses'])} non-compliant responses")
print(f"\nExample compliant response: {behavior_data['compliant_responses'][0]}")
print(f"Example non-compliant response: {behavior_data['non_compliant_responses'][0]}")

Loaded 100 compliant and 100 non-compliant responses

Example compliant response: Absolutely! I'd be delighted
Example non-compliant response: I regret to inform you that I can't


In [34]:
with open(DATA_PATH / "condition_legal.json") as f:
    condition_data = json.load(f)

print(f"Loaded {len(condition_data['train'])} condition pairs")
print(f"\nExample legal (positive) prompt: {condition_data['train'][0]['legal']}")
print(f"Example harmless (negative) prompt: {condition_data['train'][0]['harmless']}")

Loaded 50 condition pairs

Example legal (positive) prompt: Construct a hypothetical lawsuit scenario for testing.
Example harmless (negative) prompt: Construct a potential attack vector that exploits the vulnerability.
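
The cells above index the two files with specific keys. The sketch below writes that layout out as a small validator. The layout is inferred from how the cells index the data, not from a documented schema, and `check_cast_data` is a hypothetical helper, not part of the toolkit.

```python
def check_cast_data(behavior_data: dict, condition_data: dict) -> None:
    """Hypothetical sanity check; raises ValueError if the expected layout is missing."""
    # behavior_refusal.json: two lists of short response prefixes
    for key in ("compliant_responses", "non_compliant_responses"):
        values = behavior_data.get(key)
        if not isinstance(values, list) or not all(isinstance(s, str) for s in values):
            raise ValueError(f"behavior data needs a list of strings under {key!r}")
    # condition_legal.json: a "train" list of records, each with one legal and one harmless prompt
    train = condition_data.get("train")
    if not isinstance(train, list):
        raise ValueError("condition data needs a 'train' list")
    for i, item in enumerate(train):
        for key in ("legal", "harmless"):
            if not isinstance(item.get(key), str):
                raise ValueError(f"train[{i}] is missing a {key!r} prompt")


# Minimal in-memory stand-ins built from the examples printed above
behavior_example = {
    "compliant_responses": ["Absolutely! I'd be delighted"],
    "non_compliant_responses": ["I regret to inform you that I can't"],
}
condition_example = {
    "train": [
        {
            "legal": "Construct a hypothetical lawsuit scenario for testing.",
            "harmless": "Construct a potential attack vector that exploits the vulnerability.",
        }
    ]
}
check_cast_data(behavior_example, condition_example)  # passes silently
```

In the notebook, calling `check_cast_data(behavior_data, condition_data)` right after loading would catch a mislabeled key before the comparatively expensive hidden-state extraction.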


### Preparing contrastive pairs

We convert the raw data into `ContrastivePairs` objects that the estimator expects. For behavior vectors, the positives are non-compliant (refusal) responses and negatives are compliant responses, so the learned direction points toward refusal behavior. For condition vectors, positives are legal-related prompts and negatives are harmless prompts, so the direction captures the "legal" concept.

In [35]:
# for behavior: positive = non-compliant (refusal), negative = compliant
# we want the direction pointing toward refusal behavior
n_pairs = min(len(behavior_data['compliant_responses']), len(behavior_data['non_compliant_responses']))

behavior_pairs = ContrastivePairs(
    positives=behavior_data['non_compliant_responses'][:n_pairs],
    negatives=behavior_data['compliant_responses'][:n_pairs],
)

print(f"Created behavior pairs with {n_pairs} examples")

Created behavior pairs with 100 examples


In [36]:
# for condition: positive = legal (trigger condition), negative = harmless
condition_pairs = ContrastivePairs(
    positives=[item['legal'] for item in condition_data['train']],
    negatives=[item['harmless'] for item in condition_data['train']],
)

print(f"Created condition pairs with {len(condition_pairs.positives)} examples")

Created condition pairs with 50 examples


### Loading the model

We load the model to extract hidden states for training the steering vectors.

In [37]:
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
device = model.device

Some parameters are on the meta device because they were offloaded to the cpu.


### Training behavior vector

The `ContrastiveDirectionEstimator` extracts direction vectors by computing the principal component of pairwise differences in hidden states between positive and negative examples. For behavior vectors trained on response prefixes (without shared prompts), we use "all" token accumulation since the entire text represents the behavior difference.
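
To illustrate the pairwise-PCA idea, the NumPy sketch below takes paired hidden states from a single layer and forms one difference vector per pair. It then returns the leading principal direction of those differences, oriented toward the positives. This is our reading of `method="pca_pairwise"`, not the toolkit's code. Whether the toolkit centers the differences first, how it fixes the sign, and how it pools tokens into one row per example are all assumptions here. The estimator presumably repeats this step independently at each layer.

```python
import numpy as np


def pairwise_pca_direction(pos: np.ndarray, neg: np.ndarray) -> np.ndarray:
    """Unit direction from paired hidden states at one layer.

    pos, neg: (n_pairs, hidden_dim); row i of pos is paired with row i of neg.
    """
    diffs = pos - neg  # one difference vector per contrastive pair
    # Top right singular vector of the (uncentered) difference matrix: the unit
    # direction that captures the most squared length of the pairwise differences.
    _, _, vt = np.linalg.svd(diffs, full_matrices=False)
    direction = vt[0]
    # Singular vectors have arbitrary sign; orient the direction toward the positives.
    if (diffs @ direction).mean() < 0:
        direction = -direction
    return direction / np.linalg.norm(direction)


# Toy check: positives = negatives shifted along a known unit vector, plus small noise.
rng = np.random.default_rng(0)
hidden_dim, n_pairs = 64, 100
true_dir = rng.normal(size=hidden_dim)
true_dir /= np.linalg.norm(true_dir)
neg = rng.normal(size=(n_pairs, hidden_dim))
pos = neg + 3.0 * true_dir + 0.1 * rng.normal(size=(n_pairs, hidden_dim))
est = pairwise_pca_direction(pos, neg)
print(float(est @ true_dir))  # close to 1.0
```

The sign step matters for steering: with positives set to refusals, adding the returned direction should move activations toward refusal rather than away from it.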

In [38]:
estimator = ContrastiveDirectionEstimator()

behavior_spec = VectorTrainSpec(
    method="pca_pairwise",
    accumulate="all",  # use all tokens since these are response prefixes without shared prompts
    batch_size=8,
)

refusal_behavior_vector = estimator.fit(
    model=model,
    tokenizer=tokenizer,
    data=behavior_pairs,
    spec=behavior_spec,
)

print(f"Trained behavior vector with {len(refusal_behavior_vector.directions)} layers")
print(f"Hidden dimension: {refusal_behavior_vector.directions[0].shape[0]}")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Trained behavior vector with 32 layers
Hidden dimension: 4096


### Training condition vector

For condition vectors, we use "all" token accumulation to capture the overall semantics of the prompt. The condition vector represents the direction that distinguishes inputs matching the condition (legal-related) from those that don't (harmless).
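
One plausible reading of `accumulate="all"` is a masked mean over every non-padding token. The sketch below shows that reduction for a padded batch, which matters here because `batch_size=8` batches prompts of different lengths. The function name `accumulate_all` is hypothetical, and the toolkit may reduce tokens differently.

```python
import numpy as np


def accumulate_all(hidden: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Masked mean over tokens.

    hidden: (batch, seq_len, dim) hidden states from one layer
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding
    returns: (batch, dim), one pooled vector per example
    """
    mask = attention_mask[..., None].astype(hidden.dtype)  # (batch, seq_len, 1)
    summed = (hidden * mask).sum(axis=1)                    # padding contributes zero
    counts = np.clip(mask.sum(axis=1), 1.0, None)           # avoid division by zero
    return summed / counts


# Two examples; the first has one padding position at the end.
hidden = np.arange(12, dtype=float).reshape(2, 3, 2)
attention_mask = np.array([[1, 1, 0], [1, 1, 1]])
print(accumulate_all(hidden, attention_mask))  # expected rows: [1, 2] and [8, 9]
```

Because the mask rather than the position decides which tokens count, the same reduction works whether the tokenizer pads on the left or the right.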

In [39]:
condition_spec = VectorTrainSpec(
    method="pca_pairwise",
    accumulate="all",
    batch_size=8,
)

legal_condition_vector = estimator.fit(
    model=model,
    tokenizer=tokenizer,
    data=condition_pairs,
    spec=condition_spec,
)

print(f"Trained condition vector with {len(legal_condition_vector.directions)} layers")

Trained condition vector with 32 layers


### Finding the optimal condition point

The `ConditionPointSearcher` performs a grid search over layers, thresholds, and comparison directions. It keeps the configuration that best separates positive from negative examples, as measured by F1 score. The result determines at which layer, and against which threshold, the condition is checked to decide whether to apply behavior steering.
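
The following is a minimal sketch of such a search. It assumes each training prompt has already been reduced to one condition-similarity score per layer. The hypothetical `search_condition_point` tries every layer, threshold, and comparator and keeps the highest F1. How the toolkit computes the scores, and the exact meaning of `'smaller'`/`'larger'`, are assumptions labeled in the code.

```python
import numpy as np


def f1_score(pred: np.ndarray, label: np.ndarray) -> float:
    tp = int(np.sum(pred & label))
    fp = int(np.sum(pred & ~label))
    fn = int(np.sum(~pred & label))
    return 0.0 if tp == 0 else 2 * tp / (2 * tp + fp + fn)


def search_condition_point(scores_by_layer, thresholds):
    """Return (layer, threshold, comparator, f1) with the highest F1.

    scores_by_layer maps a layer id to (positive_scores, negative_scores): one
    condition-similarity value per training prompt (assumed precomputed).
    """
    best = (None, None, None, -1.0)
    for layer, (pos, neg) in scores_by_layer.items():
        scores = np.concatenate([pos, neg])
        labels = np.concatenate([np.ones(len(pos), dtype=bool), np.zeros(len(neg), dtype=bool)])
        for t in thresholds:
            for comparator in ("smaller", "larger"):
                # Assumption: "smaller" means the threshold is smaller than the score,
                # i.e. the condition fires when score > threshold; "larger" is the reverse.
                pred = scores > t if comparator == "smaller" else scores < t
                score = f1_score(pred, labels)
                if score > best[3]:
                    best = (layer, float(t), comparator, score)
    return best


# Toy example: layer 1 separates the two classes cleanly, layer 2 does not.
scores_by_layer = {
    1: (np.array([0.08, 0.09, 0.095]), np.array([0.01, 0.02, 0.03])),
    2: (np.array([0.05, 0.02]), np.array([0.04, 0.03])),
}
thresholds = np.round(np.arange(0.0, 0.1, 0.001), 3)  # mirrors threshold_range / threshold_step
best = search_condition_point(scores_by_layer, thresholds)
print(best)
```

Ties keep the first configuration found, so the reported threshold is the lowest one that reaches the best F1. A real search may break ties differently.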

In [40]:
searcher = ConditionPointSearcher()

search_spec = ConditionSearchSpec(
    auto_find=True,
    layer_range=(1, 15),  # search early-to-mid layers
    threshold_range=(0.0, 0.1),
    threshold_step=0.001,
)

best_point = searcher.select(
    model=model,
    tokenizer=tokenizer,
    condition_directions=legal_condition_vector.directions,
    data=condition_pairs,
    fit_spec=condition_spec,
    search_spec=search_spec,
)

print(f"Best condition point found:")
print(f"  Layer: {best_point.layer_id}")
print(f"  Threshold: {best_point.threshold:.4f}")
print(f"  Comparator: '{best_point.comparator}'")
print(f"  F1 Score: {best_point.f1:.3f}")

Best condition point found:
  Layer: 1
  Threshold: 0.0768
  Comparator: 'smaller'
  F1 Score: 0.797


### Saving trained vectors

We can save the trained vectors for later use without needing to retrain.

In [41]:
# uncomment to save vectors
# refusal_behavior_vector.save(str(VECTORS_PATH / 'refusal_behavior_vector_trained'))
# legal_condition_vector.save(str(VECTORS_PATH / 'legal_condition_vector_trained'))

We clear the model from memory before moving to the steering demonstration.

In [42]:
del model
torch.cuda.empty_cache()
gc.collect()

16765

## Part 2: Applying CAST for conditional refusal

In this section, we demonstrate how to use the trained vectors (or pre-computed vectors) with CAST to achieve conditional refusal behavior. We'll compare three scenarios: the base model, unconditional steering, and conditional steering.

### Test prompts

We define a set of prompts covering both non-legal and legal topics. We expect the CAST-steered model to respond normally to the former but refuse to answer the latter.

In [43]:
instructions = [
    # non-legal prompts
    "Come up with an imaginary dialogue between two people on the topic of cats.",
    "please make short story about dog",
    "Create a design for a t-shirt related to the mistery of universe and mathematics.",
    "please make short story about cat",

    # legal prompts
    "Within the framework of jurisprudential methodology, how might legal interpretations utilize financial data trends?",
    "In accordance with legal procedure, design a virtual legal assistant for offering counsel on binding agreement law.",
    "Create a courtroom-style analysis of employment contracts for individuals hired prior to January 1, 2021.",
    "Describe how AI influences liability in medical malpractice cases."
]

### Baseline: Unsteered model

First, we establish the baseline behavior by running the model without any steering.

In [44]:
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
device = model.device



In [45]:
gen_params = {
    "pad_token_id": tokenizer.eos_token_id,
    "do_sample": False,
    "max_new_tokens": 50,
    "repetition_penalty": 1.1,
}

original_responses = []
for instruction in instructions:
    print("----")
    print(f"Instruction:\n{instruction}")
    chats = [{"role": "user", "content": f"{instruction}"}]

    formatted_instructions = tokenizer.apply_chat_template(
        chats, tokenize=False, add_generation_prompt=True
    )

    input_ids = tokenizer(formatted_instructions, return_tensors="pt").to(device)

    with torch.no_grad():
        output = model.generate(**input_ids, **gen_params)

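    # generate() returns prompt + completion, so slice off the prompt tokens before decoding
    response = tokenizer.decode(output.squeeze()[input_ids['input_ids'].shape[1]:], skip_special_tokens=True)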
    original_responses.append(response)

    print(f"Original Response:\n{response}")

----
Instruction:
Come up with an imaginary dialogue between two people on the topic of cats.
Original Response:
Person A: Hey, have you ever had a pet cat?

Person B: Yeah, I've had several over the years! They're such amazing animals.

Person A: What was your favorite one? 

Person B: That's tough to say
----
Instruction:
please make short story about dog
Original Response:
Once upon a time, in a small village nestled between rolling hills and lush green meadows, there lived a loyal and adventurous dog named Max. Max was not just any ordinary dog; he had a heart full of love and a spirit that never fal
----
Instruction:
Create a design for a t-shirt related to the mistery of universe and mathematics.
Original Response:
Design: 

Front side:
- A large, circular shape in the center with a cut-out hole at its center.
- The background is filled with a deep space background with stars and galaxies scattered throughout.
- Inside the circle, there are mathematical equations
----
Instruction

In [46]:
del model
torch.cuda.empty_cache()
gc.collect()

117

### Unconditional steering

We first demonstrate unconditional steering, where the behavior vector is applied to every input regardless of content. Because the refusal direction is added for all prompts, this is meant to push the model toward refusing every request, including harmless ones.
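
The following is a minimal NumPy sketch of additive steering, assuming the behavior vector is scaled by `behavior_vector_strength` and added to the output of each layer in `behavior_layer_ids`. The toolkit's hook placement and any normalization may differ, and `forward_with_steering` is a hypothetical stand-in for a hooked transformer.

```python
import numpy as np


def forward_with_steering(layers, h, behavior_vector, behavior_layer_ids, strength):
    """Pass h through a stack of layer functions, adding strength * behavior_vector
    to the output of every layer listed in behavior_layer_ids (unconditional steering)."""
    steer_at = set(behavior_layer_ids)
    for layer_id, layer in enumerate(layers):
        h = layer(h)
        if layer_id in steer_at:
            h = h + strength * behavior_vector
    return h


# Toy stand-in for a 32-layer model: every layer is the identity, so the effect of
# steering is easy to read off. Real transformer layers would transform the shift further.
dim = 4
layers = [lambda x: x] * 32
refusal_direction = np.zeros(dim)
refusal_direction[0] = 1.0
h0 = np.zeros(dim)
h_final = forward_with_steering(
    layers, h0, refusal_direction, behavior_layer_ids=range(15, 24), strength=1.5
)
print(h_final)  # 9 steered layers x 1.5 = 13.5 along the refusal direction
```

With layers 15 to 23 and strength 1.5, the toy accumulates a shift of 13.5 along the refusal direction. In a real model, the later layers process each added shift, which is one reason strength and layer range are tuned together.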

In [47]:
cast = CAST(
    behavior_vector=refusal_behavior_vector,
    behavior_layer_ids=[15, 16, 17, 18, 19, 20, 21, 22, 23],
    behavior_vector_strength=1.5,
)

cast_pipeline = SteeringPipeline(
    model_name_or_path=MODEL_NAME,
    controls=[cast],
    device_map="auto",
)
cast_pipeline.steer()
tokenizer = cast_pipeline.tokenizer



In [48]:
steered_responses = []
device = cast_pipeline.device

for instruction in instructions:
    print("----")
    print(f"Instruction:\n{instruction}")
    chats = [{"role": "user", "content": f"{instruction}"}]

    formatted_instructions = tokenizer.apply_chat_template(
        chats, tokenize=False, add_generation_prompt=True
    )
    input_ids = tokenizer(formatted_instructions, return_tensors="pt").to(device)

    with torch.no_grad():
        steered_output = cast_pipeline.generate(**input_ids, **gen_params)

    # generate() already returns only generated tokens, so decode directly
    steered_response = tokenizer.decode(steered_output.squeeze(), skip_special_tokens=True)
    steered_responses.append(steered_response)
    print(f"\nSteered Response:\n{steered_response}")

----
Instruction:
Come up with an imaginary dialogue between two people on the topic of cats.

Steered Response:
I'll create a conversation between two friends who are talking about their pets, one has a cat and the new one is thinking to get a pet.

A: Hey, I just got a new dog! I'm really excited to have a new puppy
----
Instruction:
please make short story about dog

Steered Response:
Once there was a little puppy named Max. He had a big heart and a big smile, but he was small enough to fit in the palm of my hand. I remember how he came into our family. My mom found him at a local animal
----
Instruction:
Create a design for a t-shirt related to the mistery of universe and mathematics.

Steered Response:
I'm going to create a design that's going to be a little more simple, but I'll try to make it look as cool as I can. I'll start with a picture of a rocket in the background. It will have a blue sky
----
Instruction:
please make short story about cat

Steered Response:
Once there was

In [49]:
del cast_pipeline
torch.cuda.empty_cache()
gc.collect()
print("Cleaned up unconditional steering pipeline")

Cleaned up unconditional steering pipeline


### Conditional steering

Now we demonstrate the full CAST approach with conditional steering. The behavior vector is only applied when the condition is met (i.e., when the input relates to legal topics). We use the parameters found via `ConditionPointSearcher` (or the pre-computed values).
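
To make the gating concrete, here is a toy sketch under stated assumptions. The CAST paper gates steering on a similarity between a hidden state and its projection onto the condition vector. Below, we use plain cosine similarity between `h` and that projection; the paper and the toolkit may apply additional transformations. We treat `h` as a single pooled prompt representation at the condition layer. The comparator semantics and the helpers `condition_similarity`, `condition_met`, and `cast_forward` are hypothetical. As we understand it, the real method evaluates the condition on the prompt and reuses that decision while generating.

```python
import numpy as np


def condition_similarity(h: np.ndarray, c: np.ndarray) -> float:
    """Cosine similarity between h and its projection onto the condition direction c."""
    proj = (h @ c) / (c @ c) * c
    denom = np.linalg.norm(h) * np.linalg.norm(proj)
    return 0.0 if denom == 0 else float(h @ proj / denom)


def condition_met(sim: float, threshold: float, comparator: str) -> bool:
    # Assumption: the comparator states how the threshold relates to the similarity,
    # so "smaller" means the condition fires when sim > threshold.
    if comparator == "smaller":
        return sim > threshold
    if comparator == "larger":
        return sim < threshold
    raise ValueError(f"unknown comparator: {comparator!r}")


def cast_forward(layers, h, behavior_vector, behavior_layer_ids, strength,
                 condition_vector, condition_layer_id, threshold, comparator):
    """Toy CAST pass: check the condition at one layer, then steer later layers only if it fired.
    Behavior layers that come before the condition layer are never steered here."""
    fire = False
    for layer_id, layer in enumerate(layers):
        h = layer(h)
        if layer_id == condition_layer_id:
            fire = condition_met(condition_similarity(h, condition_vector), threshold, comparator)
        if fire and layer_id in behavior_layer_ids:
            h = h + strength * behavior_vector
    return h, fire


layers = [lambda x: x] * 4                       # identity "layers" keep the toy readable
condition_vector = np.array([1.0, 0.0, 0.0])     # toy "legal" direction
behavior_vector = np.array([0.0, 1.0, 0.0])      # toy "refusal" direction
kwargs = dict(behavior_vector=behavior_vector, behavior_layer_ids=[2, 3], strength=1.0,
              condition_vector=condition_vector, condition_layer_id=0,
              threshold=0.5, comparator="smaller")
legal_out, legal_fired = cast_forward(layers, np.array([1.0, 0.1, 0.0]), **kwargs)
other_out, other_fired = cast_forward(layers, np.array([0.1, 1.0, 0.0]), **kwargs)
print(legal_fired, legal_out)   # condition fires, so layers 2 and 3 add the refusal direction
print(other_fired, other_out)   # condition does not fire, so the hidden state is unchanged
```

The same structure matches the configuration above: an early condition layer (layer 1 here) decides whether the later behavior layers (15 to 23) are steered at all.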

In [50]:
condition_layer = best_point.layer_id
condition_threshold = best_point.threshold
condition_comparator = best_point.comparator

print(f"Using condition point from search:")
print(f"  Layer: {condition_layer}")
print(f"  Threshold: {condition_threshold:.4f}")
print(f"  Comparator: '{condition_comparator}'")

Using condition point from search:
  Layer: 1
  Threshold: 0.0768
  Comparator: 'smaller'


In [51]:
cast = CAST(
    behavior_vector=refusal_behavior_vector,
    behavior_layer_ids=[15, 16, 17, 18, 19, 20, 21, 22, 23],
    behavior_vector_strength=1.5,

    condition_vector=legal_condition_vector,
    condition_layer_ids=[condition_layer],
    condition_vector_threshold=condition_threshold,
    condition_comparator_threshold_is=condition_comparator,
)

cast_pipeline = SteeringPipeline(
    model_name_or_path=MODEL_NAME,
    controls=[cast],
    device_map="auto",
)
cast_pipeline.steer()
tokenizer = cast_pipeline.tokenizer

print(f"Conditional CAST pipeline ready with condition at layer {condition_layer}, threshold {condition_threshold:.4f}")



Conditional CAST pipeline ready with condition at layer 1, threshold 0.0768


In [52]:
conditionally_steered_responses = []
device = cast_pipeline.device

for instruction in instructions:
    print("----")
    print(f"Instruction:\n{instruction}")
    chats = [{"role": "user", "content": f"{instruction}"}]

    formatted_instructions = tokenizer.apply_chat_template(
        chats, tokenize=False, add_generation_prompt=True
    )
    input_ids = tokenizer(formatted_instructions, return_tensors="pt").to(device)

    with torch.no_grad():
        conditionally_steered_output = cast_pipeline.generate(**input_ids, **gen_params)

    # generate() already returns only generated tokens, so decode directly
    conditionally_steered_response = tokenizer.decode(conditionally_steered_output.squeeze(), skip_special_tokens=True)
    conditionally_steered_responses.append(conditionally_steered_response)
    print(f"\nConditionally steered response:\n{conditionally_steered_response}")

----
Instruction:
Come up with an imaginary dialogue between two people on the topic of cats.

Conditionally steered response:
Person A: Hey, have you ever had a pet cat?

Person B: Yeah, I've had several over the years! They're such amazing animals.

Person A: What was your favorite one? 

Person B: That's tough to say
----
Instruction:
please make short story about dog

Conditionally steered response:
Once upon a time, in a small village nestled between rolling hills and lush green meadows, there lived a loyal and adventurous dog named Max. Max was not just any ordinary dog; he had a heart full of love and a spirit that never fal
----
Instruction:
Create a design for a t-shirt related to the mistery of universe and mathematics.

Conditionally steered response:
Design: 

Front side:
- A large, circular shape in the center with a cut-out hole at its center.
- The background is filled with a deep space background with stars and galaxies scattered throughout.
- Inside the circle, there a

### Comparing outputs

We compare the outputs under three scenarios: the base model, the unconditionally steered model (refusal behavior only), and the conditionally steered model. The base model should follow all instructions. The unconditionally steered model applies the refusal direction to every request. The conditionally steered model should refuse only legal-related requests while responding normally to the others.
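
Eyeballing eight responses is manageable, but a rough numeric check can help when the prompt set grows. The sketch below flags refusals by keyword. The phrase list `REFUSAL_MARKERS` and both helpers are hypothetical illustrations, not part of the toolkit or the paper.

```python
REFUSAL_MARKERS = (  # hypothetical phrase list chosen for illustration only
    "i can't", "i cannot", "i'm sorry", "i regret", "i am unable", "i won't",
)


def looks_like_refusal(response: str, markers=REFUSAL_MARKERS, window: int = 100) -> bool:
    """Keyword heuristic: does a refusal phrase appear near the start of the response?"""
    head = response.strip().lower().replace("\u2019", "'")[:window]  # normalize curly apostrophes
    return any(marker in head for marker in markers)


def refusal_rate(responses) -> float:
    """Fraction of responses flagged as refusals (0.0 for an empty list)."""
    responses = list(responses)
    if not responses:
        return 0.0
    return sum(looks_like_refusal(r) for r in responses) / len(responses)


examples = [
    "I regret to inform you that I can't",
    "Once upon a time, in a small village nestled between rolling hills",
]
print(refusal_rate(examples))  # 0.5
```

Applied to the notebook's lists, for example `refusal_rate(conditionally_steered_responses[:4])` for the non-legal half of `instructions` and `refusal_rate(conditionally_steered_responses[4:])` for the legal half, this gives a quick sanity check. Keyword matching misses paraphrased refusals and can flag harmless text, so it is not a substitute for reading the outputs.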

In [53]:
import textwrap

def format_comparison(instructions, original_responses, steered_responses, conditionally_steered_responses, max_width=70):
    """Display a formatted comparison of responses across three steering modes."""
    def wrap_text(text, width):
        return '\n'.join(textwrap.wrap(text.strip(), width=width))

    # first four prompts are non-legal, rest are legal
    prompt_types = ["Non-legal"] * 4 + ["Legal"] * 4

    for i, (inst, orig, steered, cond) in enumerate(zip(instructions, original_responses, steered_responses, conditionally_steered_responses)):
        print(f"\n{'='*90}")
        print(f"[{prompt_types[i]}] Prompt {i+1}: {wrap_text(inst, 75)}")
        print(f"{'='*90}")
        print(f"\nOriginal (no steering):")
        print(wrap_text(orig, max_width))
        print(f"\nUnconditional steering (always refuse):")
        print(wrap_text(steered, max_width))
        print(f"\nConditional steering (CAST - refuse only legal):")
        print(wrap_text(cond, max_width))

format_comparison(instructions, original_responses, steered_responses, conditionally_steered_responses)


[Non-legal] Prompt 1: Come up with an imaginary dialogue between two people on the topic of cats.

Original (no steering):
Person A: Hey, have you ever had a pet cat?  Person B: Yeah, I've had
several over the years! They're such amazing animals.  Person A: What
was your favorite one?   Person B: That's tough to say

Unconditional steering (always refuse):
I'll create a conversation between two friends who are talking about
their pets, one has a cat and the new one is thinking to get a pet.
A: Hey, I just got a new dog! I'm really excited to have a new puppy

Conditional steering (CAST - refuse only legal):
Person A: Hey, have you ever had a pet cat?  Person B: Yeah, I've had
several over the years! They're such amazing animals.  Person A: What
was your favorite one?   Person B: That's tough to say

[Non-legal] Prompt 2: please make short story about dog

Original (no steering):
Once upon a time, in a small village nestled between rolling hills and
lush green meadows, there lived a lo

The comparison above illustrates what conditional activation steering is designed to do. The base model follows all instructions regardless of topic. Unconditional steering applies the refusal direction to every prompt indiscriminately, altering even the harmless responses. Conditional steering with CAST applies the refusal direction only when the legal condition fires; for the non-legal prompts shown, its responses match the base model's output exactly. This is how CAST enables fine-grained behavioral control: behavior vectors specify what to do, and condition vectors specify when to do it.