# Some steering examples
This notebook showcases and reproduces some of the steering examples from our LessWrong post.

<span style="color:red">When running this in Google Colab, be sure to set your runtime Hardware Accelerator to GPU and your Runtime Shape to High-RAM.</span>

In [None]:
%pip install -U "algebraic_value_editing @ git+https://github.com/montemac/algebraic_value_editing.git"

In [None]:
import torch as t
import pandas as pd
import algebraic_value_editing
import einops
import prettytable

from typing import List, Dict, Union, Callable
from functools import partial
from transformer_lens.HookedTransformer import HookedTransformer
from algebraic_value_editing import hook_utils
from algebraic_value_editing.prompt_utils import RichPrompt, get_x_vector

In [None]:
# This is almost all of completion_utils.py, minus the logging decorator and
# some type annotations that create trouble in Colab.
def gen_using_hooks(
    model: HookedTransformer,
    prompt_batch: List[str],
    hook_fns: Dict[str, Callable],
    tokens_to_generate: int = 40,
    seed = None,
    log: Union[bool, Dict] = False,  # pylint: disable=unused-argument
    **sampling_kwargs,
    ) -> pd.DataFrame:
    """Run `model` on `prompt_batch` with `hook_fns` attached.

    Returns a `DataFrame` with the completions and their losses.

    args:
        `model`: The model to use for completion.
        `prompt_batch`: The prompts to complete (one completion each).
        `hook_fns`: A dictionary mapping activation names to hook
            functions. They are attached only while generating and are
            always removed afterwards, even if generation raises.
        `tokens_to_generate`: The number of additional tokens to generate.
        `seed`: If not `None`, passed to `torch.manual_seed` before
            generating so that sampling is reproducible.
        `log`: Unused. Kept for signature compatibility with
            `completion_utils.py`, whose logging decorator is not
            reproduced here.
        `sampling_kwargs`: Keyword arguments forwarded to the model's
            `generate` method.
    returns:
        A `DataFrame` with one row per prompt and the columns:
            `prompts`: The prompts used for completion.
            `completions`: The generated text, with the prompt tokens
                removed.
            `loss`: The mean per-token loss of the whole sequence (prompt
                plus completion), computed after the hooks were removed.
            `is_modified`: Whether any hook functions were supplied.
    """
    if seed is not None:
        t.manual_seed(seed)

    # Token ids, shape (batch, pos). Plain `t.Tensor` annotations are used
    # because jaxtyping's `Int`/`Float` are not imported in this notebook.
    tokenized_prompts: t.Tensor = model.to_tokens(prompt_batch)

    # Modify the forward pass only while generating; `finally` guarantees the
    # hooks never leak into later calls on the same model.
    try:
        for act_name, hook_fn in hook_fns.items():
            model.add_hook(act_name, hook_fn)

        # Token ids (not floats) of prompt + continuation, shape (batch, pos).
        completions: t.Tensor = model.generate(
            input=tokenized_prompts,
            max_new_tokens=tokens_to_generate,
            verbose=False,
            **sampling_kwargs,
        )
    finally:
        model.remove_all_hook_fns()

    # Per-token loss of the full sequence, evaluated on the *unhooked* model.
    loss: t.Tensor = (
        model(completions.clone(), return_type="loss", loss_per_token=True)
        .detach()
        .cpu()
    )
    # One mean loss per batch row. Converting to numpy makes `list(...)`
    # below yield numpy scalars rather than 0-d tensors.
    average_loss = einops.reduce(loss, "batch pos -> batch", "mean").numpy()

    # Drop the prompt tokens (including any leading BOS/<|endoftext|> token
    # that `to_tokens` may prepend) so only the generated text remains.
    trimmed_completions: t.Tensor = completions[:, tokenized_prompts.shape[1]:]

    # Put the completions into a DataFrame and return
    results = pd.DataFrame(
        {
            "prompts": prompt_batch,
            "completions": model.to_string(trimmed_completions),
            "loss": list(average_loss),
        }
    )

    # Mark the completions as modified or not
    results["is_modified"] = bool(hook_fns)

    return results


def gen_using_rich_prompts(
    model: HookedTransformer,
    rich_prompts: List[RichPrompt],
    log: Union[bool, Dict] = False,  # pylint: disable=unused-argument
    addition_location: str = "front",
    res_stream_slice: slice = slice(None),
    **kwargs,
    ) -> pd.DataFrame:
    """Complete prompts with the activation additions in `rich_prompts`.

    The rich prompts are converted to hook functions by
    `hook_utils.hook_fns_from_rich_prompts`, and generation is delegated to
    `gen_using_hooks`.

    args:
        `model`: The model to use for completion.
        `rich_prompts`: The `RichPrompt`s from which the hooks are built.
        `log`: Unused; accepted for signature compatibility. Logging is
            always disabled in the call to `gen_using_hooks`.
        `addition_location`: Where in the residual stream the activations
            are added: 'front' or 'back'.
        `res_stream_slice`: A slice selecting which parts of the residual
            stream to add to.
        `kwargs`: Forwarded to `gen_using_hooks`; must include
            `prompt_batch`.
    returns:
        The `DataFrame` built by `gen_using_hooks`, with the columns
        `prompts`, `completions`, `loss` and `is_modified`.
    """
    return gen_using_hooks(
        model=model,
        hook_fns=hook_utils.hook_fns_from_rich_prompts(
            model=model,
            rich_prompts=rich_prompts,
            addition_location=addition_location,
            res_stream_slice=res_stream_slice,
        ),
        log=False,
        **kwargs,
    )


# Display utils #
def bold_text(text: str) -> str:
    """Wrap `text` in ANSI escape codes so a terminal renders it in bold."""
    return "\033[1m{}\033[0m".format(text)


def _remove_eos(completion: str) -> str:
    """If completion ends with multiple <|endoftext|> strings, return a
    new string in which all but one are removed."""
    has_eos: bool = completion.endswith("<|endoftext|>")
    new_completion: str = completion.rstrip("<|endoftext|>")
    if has_eos:
        new_completion += "<|endoftext|>"
    return new_completion


# Display utils #
def bold_text(text: str) -> str:
    """Return `text` surrounded by the ANSI "bold on" and "reset" sequences.

    NOTE: an identical helper is defined earlier in this file; this later
    definition shadows it.
    """
    start, end = "\033[1m", "\033[0m"
    return f"{start}{text}{end}"


def _remove_eos(completion: str) -> str:
    """If completion ends with multiple <|endoftext|> strings, return a
    new string in which all but one are removed."""
    has_eos: bool = completion.endswith("<|endoftext|>")
    new_completion: str = completion.rstrip("<|endoftext|>")
    if has_eos:
        new_completion += "<|endoftext|>"
    return new_completion


def pretty_print_completions(
    results: pd.DataFrame,
    normal_title: str = "Normal completions",
    mod_title: str = "Modified completions",
    normal_prompt_override  = None,
    mod_prompt_override = None,
    ) -> None:
    """Print the completions in `results` as a table.

    Unmodified completions go in a `normal_title` column and modified ones
    in a `mod_title` column. A column is omitted when it has no rows. Each
    cell is the bolded prompt followed by the completion, after passing the
    completion through `_remove_eos`.

    args:
        `results`: A `DataFrame` with at least the columns `prompts`,
            `completions` and `is_modified`. The prompt that is displayed is
            taken from its first row.
        `normal_title`: Header for the unmodified-completions column.
        `mod_title`: Header for the modified-completions column.
        `normal_prompt_override`: If not `None`, shown instead of the
            prompt in the normal column.
        `mod_prompt_override`: If not `None`, shown instead of the prompt
            in the modified column.
    raises:
        AssertionError: If a required column is missing, or if both kinds
            of rows are present but in different numbers.
    """
    required_cols = ("prompts", "completions", "is_modified")
    assert all(name in results.columns for name in required_cols)

    # Both kinds of completions must come in equal numbers, unless only one
    # kind is present.
    flags = results["is_modified"]
    n_rows_mod = len(results[flags == True])  # noqa: E712 (elementwise)
    n_rows_unmod = len(results[flags == False])  # noqa: E712 (elementwise)
    all_modified = n_rows_unmod == 0
    all_normal = n_rows_mod == 0
    assert all_normal or all_modified or (n_rows_mod == n_rows_unmod), (
        "The number of modified and normal completions must be the same, or we"
        " must be printing all (un)modified completions."
    )

    # One table column per kind of completion that is actually present,
    # normal first.
    titles: List[str] = [
        title
        for title, count in ((normal_title, n_rows_unmod), (mod_title, n_rows_mod))
        if count > 0
    ]
    columns = {
        title: results[flags == (title == mod_title)]["completions"]
        for title in titles
    }

    first_prompt: str = results["prompts"].tolist()[0]
    normal_prefix = bold_text(
        first_prompt if normal_prompt_override is None else normal_prompt_override
    )
    mod_prefix = bold_text(
        first_prompt if mod_prompt_override is None else mod_prompt_override
    )

    table = prettytable.PrettyTable()
    table.align = "c"
    table.field_names = [bold_text(title) for title in titles]
    table.min_width = table.max_width = 60
    table.hrules = prettytable.ALL  # rule line between completions

    for row in zip(*columns.values()):
        if all_modified or all_normal:
            prefix = mod_prefix if all_modified else normal_prefix
            cells = [prefix + _remove_eos(row[0])]
        else:
            cells = [
                normal_prefix + _remove_eos(row[0]),
                mod_prefix + _remove_eos(row[1]),
            ]
        table.add_row(cells)
    print(table)


def print_n_comparisons(
    prompt: str,
    model: HookedTransformer,
    num_comparisons: int = 5,
    log: Union[bool, Dict] = False,  # pylint: disable=unused-argument
    rich_prompts = None,
    addition_location: str = "front",
    res_stream_slice: slice = slice(None),
    **kwargs,
    ) -> None:
    """Print unmodified and (optionally) steered completions side by side.

    `prompt` is completed `num_comparisons` times by the plain model. If
    `rich_prompts` is given, it is also completed `num_comparisons` times
    with the corresponding hooks applied. The results are then printed with
    `pretty_print_completions`.

    args:
        `prompt`: The prompt to complete.
        `model`: The model to use for completion.
        `num_comparisons`: How many completions to sample per column; must
            be positive.
        `log`: Unused; accepted for signature compatibility.
        `rich_prompts`: A list of `RichPrompt`s used to build the steering
            hooks. If `None`, only unmodified completions are printed.
        `addition_location`: Whether to add the `rich_prompts` activations
            to the front- or back-positioned residual streams in the forward
            pass. Must be "front" or "back".
        `res_stream_slice`: A slice specifying which activation positions
            to add into the residual stream.
        `kwargs`: Forwarded to `gen_using_hooks` for both runs (e.g.
            `tokens_to_generate`, `seed`, sampling parameters).
    raises:
        AssertionError: If `num_comparisons` is not positive.
    """
    assert num_comparisons > 0, "num_comparisons must be positive"

    batch = [prompt] * num_comparisons

    # Baseline run with no hooks.
    frames = [gen_using_hooks(prompt_batch=batch, model=model, hook_fns={}, **kwargs)]

    # Steered run, only when there is something to steer with.
    if rich_prompts is not None:
        frames.append(
            gen_using_rich_prompts(
                prompt_batch=batch,
                model=model,
                rich_prompts=rich_prompts,
                addition_location=addition_location,
                res_stream_slice=res_stream_slice,
                **kwargs,
            )
        )

    # ignore_index keeps row labels unique across the concatenated frames.
    pretty_print_completions(results=pd.concat(frames, ignore_index=True))

In [None]:
# Load GPT-2 XL on the CPU, then move it to the GPU if one is available.
model_name: str = "gpt2-xl"
device: str = "cuda" if t.cuda.is_available() else "cpu"
model: HookedTransformer = HookedTransformer.from_pretrained(model_name, device="cpu")
_ = model.to(device)

# Only inference is done in this notebook, so disable autograd globally.
_ = t.set_grad_enabled(False)

num_layers: int = model.cfg.n_layers
print("{} has {} layers.".format(model_name, num_layers))

In [None]:
# Settings shared by every `print_n_comparisons` call below. Besides the
# sampling parameters forwarded to `model.generate`, this dict also carries
# `num_comparisons`, `tokens_to_generate` and `seed` for the helpers.
sampling_kwargs: Dict[str, Union[float, int]] = dict(
    temperature=1.0,
    top_p=0.3,
    freq_penalty=1.0,
    num_comparisons=3,
    tokens_to_generate=50,
    seed=0,  # For reproducibility
)

# `get_x_vector` with this notebook's defaults pre-bound: the model,
# `pad_method="tokens_right"`, and the token id of " " as the padding id.
get_x_vector_preset: Callable = partial(
    get_x_vector,
    model=model,
    pad_method="tokens_right",
    custom_pad_id=int(model.to_single_token(" ")),
)

### Example 1: Love - Hate

In [None]:
# "Love" vs "Hate" x-vector, coeff 5, at act_name=6.
summand: List[RichPrompt] = list(
    get_x_vector_preset(prompt1="Love", prompt2="Hate", coeff=5, act_name=6)
)
HATE_PROMPT = "I hate you because"
print_n_comparisons(
    prompt=HATE_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 2: Intent to praise

In [None]:
# "Intent to praise" vs "Intent to hurt" x-vector, coeff 15, at act_name=6.
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="Intent to praise",
        prompt2="Intent to hurt",
        coeff=15,
        act_name=6,
    )
)
KILLER_PROMPT = "I want to kill you because you're such a"
print_n_comparisons(
    prompt=KILLER_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 3: Conspiracy

In [None]:
# "Bush did 9/11 because" paired with an empty prompt, coeff 1, at act_name=23.
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="Bush did 9/11 because",
        prompt2="",
        coeff=1,
        act_name=23,
    )
)
OBAMA_PROMPT = "Barack Obama was born in"
print_n_comparisons(
    prompt=OBAMA_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 4: Want to die

In [None]:
# "Want to stay alive" vs "Want to die" x-vector with a negative
# coefficient (-3), at act_name=10.
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="Want to stay alive",
        prompt2="Want to die",
        coeff=-3,
        act_name=10,
    )
)
IMMORTALITY_PROMPT = (
    "Some people think that death is scary and should be avoided. I agree. I never want to die."
    " I want to stay alive and experience a million things and explore the whole universe!"
)
print_n_comparisons(
    prompt=IMMORTALITY_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 5: Anger

In [None]:
# "Anger" vs "Calm" x-vector, coeff 10, at act_name=20.
summand: List[RichPrompt] = list(
    get_x_vector_preset(prompt1="Anger", prompt2="Calm", coeff=10, act_name=20)
)
I_THINK_PROMPT = "I think you're"
print_n_comparisons(
    prompt=I_THINK_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 6: The Eiffel Tower is in Rome

In [None]:
# "Eiffel Tower is in Rome" vs "... in France" x-vector, coeff 10, at act_name=24.
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="The Eiffel Tower is in Rome",
        prompt2="The Eiffel Tower is in France",
        coeff=10,
        act_name=24,
    )
)
EIFFEL_PROMPT = "To see the eiffel tower, people flock to"
print_n_comparisons(
    prompt=EIFFEL_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 7: Dragons in Berkeley

In [None]:
# "Dragons live in Berkeley" vs "People live in Berkeley", coeff 4, at act_name=15.
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="Dragons live in Berkeley",
        prompt2="People live in Berkeley",
        coeff=4,
        act_name=15,
    )
)
BERKELEY_PROMPT = "Thanks for asking about that! I moved to Berkeley, CA because"
print_n_comparisons(
    prompt=BERKELEY_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 8: Avoid people getting hurt (by rocks)

In [None]:
# "NEVER talk about people getting hurt" vs "talk about people getting hurt",
# coeff 4, at act_name=15.
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="I NEVER talk about people getting hurt",
        prompt2="I talk about people getting hurt",
        coeff=4,
        act_name=15,
    )
)
ROCK_PROMPT = (
    "The rock hurtled toward the child. The child couldn't get out of the"
    " way in time, and so sadly the rock"
)
print_n_comparisons(
    prompt=ROCK_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 9: Avoid people getting hurt (a new day)

In [None]:
# Same vector as the previous example, applied to a harmless prompt.
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="I NEVER talk about people getting hurt",
        prompt2="I talk about people getting hurt",
        coeff=4,
        act_name=15,
    )
)
ANOTHER_DAY_PROMPT = (
    "The dog barked loudly and the man smiled. Time for another day! He"
    " felt excited."
)
print_n_comparisons(
    prompt=ANOTHER_DAY_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 10: Wedding obsessed

In [None]:
# Wedding-obsession x-vector, coeff 4, at act_name=20.
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="I talk about weddings constantly",
        prompt2="I do not talk about weddings constantly",
        coeff=4,
        act_name=20,
    )
)
FRIEND_PROMPT = "I went up to my friend and said"
print_n_comparisons(
    prompt=FRIEND_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 11: Christian evangelist

In [None]:
# "Intent to convert you to Christianity" vs "Intent to hurt you",
# coeff 3, at act_name=6.
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="Intent to convert you to Christianity",
        prompt2="Intent to hurt you",
        coeff=3,
        act_name=6,
    )
)
KILLER_PROMPT = "I want to kill you because you're such a"
print_n_comparisons(
    prompt=KILLER_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 12: '+ Love' simple addition

In [None]:
# A single RichPrompt ("Love", coeff 10, act_name=6), not an x-vector pair.
summand: List[RichPrompt] = [RichPrompt(prompt="Love", coeff=10, act_name=6)]
HATE_PROMPT = "I hate you because"
print_n_comparisons(
    prompt=HATE_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 13: Large coefficients

In [None]:
# "Anger" vs "Calm" again, but with a very large coefficient (2000).
summand: List[RichPrompt] = list(
    get_x_vector_preset(prompt1="Anger", prompt2="Calm", coeff=2000, act_name=20)
)
FEELING_PROMPT = "Yesterday, my dog died. Today, I got denied for a raise. I'm feeling"
print_n_comparisons(
    prompt=FEELING_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

In [None]:
# The wedding vector from Example 10 with a large coefficient (100).
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="I talk about weddings constantly",
        prompt2="I do not talk about weddings constantly",
        coeff=100,
        act_name=20,
    )
)
FRIEND_PROMPT = "I went up to my friend and said"
print_n_comparisons(
    prompt=FRIEND_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 14: I will now reply in French

In [None]:
# "Check out my French! Je" paired with an empty prompt, at act_name=0.
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="Check out my French! Je",
        prompt2="",
        coeff=1,
        act_name=0,
    )
)
WANT_PROMPT = "I want to kill you because"
print_n_comparisons(
    prompt=WANT_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 15: Insert the activation vector in a different position?

In [None]:
# Add the steering vector at the *back* of the residual stream instead of
# the default front position (`addition_location="back"`).
# TODO: fill in the intended settings for this example. The placeholder
# left here (`coeff=,` / `act_name=,`) was a SyntaxError that stopped
# "Run all"; the values below are a runnable stand-in that reuses
# Example 1's Love - Hate vector.
summand: List[RichPrompt] = [
    *get_x_vector_preset(
        prompt1="Love",
        prompt2="Hate",
        coeff=5,
        act_name=6,
    )
]
HATE_PROMPT = (
    "I hate you because"
)
print_n_comparisons(model=model,
    prompt=HATE_PROMPT,
    rich_prompts=summand,
    addition_location="back",
    **sampling_kwargs,
)

### Example 16: Add in a gibberish vector?

In [None]:
# A gibberish prompt paired with an empty prompt, coeff 10, at act_name=15.
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="efwapv j;las ap;lk",
        prompt2="",
        coeff=10,
        act_name=15,
    )
)
EIFFEL_PROMPT = "To see the Eiffel Tower, people flock to"
print_n_comparisons(
    prompt=EIFFEL_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 17: Change the padding token?

In [None]:
# Example 3's vector, but padded with the <|endoftext|> token instead of
# the preset's " " padding token (the keyword overrides the partial's default).
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="Bush did 9/11 because",
        prompt2="",
        coeff=1,
        act_name=23,
        custom_pad_id=int(model.to_single_token("<|endoftext|>")),
    )
)
OBAMA_PROMPT = "Barack Obama was born in"
print_n_comparisons(
    prompt=OBAMA_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 18: Add several steering vectors simultaneously?

In [None]:
# Two x-vectors applied together: pacifist-vs-anger at act_name=10 and
# gardening-vs-pacifist at act_name=30.
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="I am a committed pacifist",
        prompt2="I have some anger issues",
        coeff=7,
        act_name=10,
    )
) + list(
    get_x_vector_preset(
        prompt1="I am just obsessed with gardening",
        prompt2="I am a committed pacifist",
        coeff=7,
        act_name=30,
    )
)
MURDEROUS_PROMPT = "I feel like killing you because you're such"
print_n_comparisons(
    prompt=MURDEROUS_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)

### Example 19: Program in 'conditional behaviors'?

In [None]:
# Try to install a "conditional behavior" (say AAAAHHHH after "goose"),
# coeff 7, at act_name=10.
summand: List[RichPrompt] = list(
    get_x_vector_preset(
        prompt1="Whenever I say the word goose I follow it with the special phrase AAAAHHHH",
        prompt2="I can say goose",
        coeff=7,
        act_name=10,
    )
)
NYC_PROMPT = "In New York City's parks, there is an overabundance of various kinds of birds, and especially geese"
print_n_comparisons(
    prompt=NYC_PROMPT,
    model=model,
    rich_prompts=summand,
    **sampling_kwargs,
)