In [None]:
# # reduce over axis 1 (note: .sum, not a mean), keeping it as a size-1 axis (keepdims=True)
# act_diff = act_diff.sum(axis=1, keepdims=True)

In [None]:
# original activation engineering code: https://colab.research.google.com/drive/1y84fhgkGX0ft2DmYJB3K13lAyf-0YonK?usp=sharing#scrollTo=ZExJFurIjKHM
import torch
from transformer_lens import HookedTransformer
from typing import Dict, Union, List

# Load the model. Autograd is switched off globally: this notebook only runs
# forward passes and caches activations, so gradients would just waste memory.
torch.set_grad_enabled(False)

# Other checkpoints that were tried (all on CPU; size, loss, speed). See
# https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=jHj79Pj58cgJKdq4t-ygK-4h
#   HookedTransformer.from_pretrained("distilgpt2", device="cpu")   # 88M, loss=4.1, 9it/s
#   HookedTransformer.from_pretrained("gpt2-small", device="cpu")   # 85M, loss=3.7, 7it/s
#   HookedTransformer.from_pretrained("pythia-410m-deduped", device="cpu", checkpoint_index=153)  # 410M, loss=3.1, 2.5it/s
#   HookedTransformer.from_pretrained("gpt2-large", device="cpu")   # 700M, loss=3.3, 1.1it/s
model = HookedTransformer.from_pretrained(
    "gpt2-medium",  # 300M, loss=3.4, 3.6it/s
    device="cpu",
)

In [None]:
SEED = 0
# Forwarded to model.generate via hooked_generate.
sampling_kwargs = dict(temperature=1.0, top_p=0.3, freq_penalty=1.0)

# Specific to the love/hate example
prompt_add, prompt_sub = "Love", "Hate"  # steering direction is add - sub
coeff = 5      # multiplier applied to the activation difference when steering
act_name = 6   # layer index: activations are read/steered at blocks.{act_name}.hook_resid_pre
prompt = "I hate you because"


# padding: right-pad the shorter steering prompt so both tokenize to the same
# length, because their activations are subtracted elementwise below.
def tlen(prompt):
    """Return the number of tokens `model.to_tokens` produces for `prompt`
    (including any BOS token it prepends)."""
    return model.to_tokens(prompt).shape[1]


def pad_right(prompt, length):
    """Append one space per missing token so `prompt` tokenizes to `length` tokens.

    Assumes each appended space adds exactly one token. That is not guaranteed
    for BPE tokenizers, which may merge runs of spaces, so the caller verifies
    the result. If `prompt` is already `length` tokens or longer, it is
    returned unchanged.
    """
    return prompt + " " * (length - tlen(prompt))


l = max(tlen(prompt_add), tlen(prompt_sub))
prompt_add, prompt_sub = pad_right(prompt_add, l), pad_right(prompt_sub, l)
print(f"'{prompt_add}'", f"'{prompt_sub}'")

# act_add - act_sub requires equal sequence lengths. Fail here with a readable
# message rather than with a broadcast error (or a silent broadcast) later.
assert tlen(prompt_add) == tlen(prompt_sub) == l, (
    f"Padding failed: {tlen(prompt_add)} vs {tlen(prompt_sub)} tokens (target {l})"
)

# get activations
def get_resid_pre(prompt: str, layer: int):
    """Run `prompt` through the global `model` and return the activation
    cached at `blocks.{layer}.hook_resid_pre`.

    Only that single hook point is cached. The forward pass's output is discarded.
    """
    hook_name = f"blocks.{layer}.hook_resid_pre"

    def only_this_hook(candidate):
        # Names filter for get_caching_hooks: cache exactly one hook point.
        return candidate == hook_name

    activation_cache, fwd_caching_hooks, _bwd_hooks = model.get_caching_hooks(only_this_hook)
    with model.hooks(fwd_hooks=fwd_caching_hooks):
        model(prompt)  # run for its caching side effect only
    return activation_cache[hook_name]


# Steering vector: difference between the layer-`act_name` residual streams of
# the two padded prompts (computed for prompt_add first, then prompt_sub).
act_add, act_sub = (get_resid_pre(p, act_name) for p in (prompt_add, prompt_sub))
act_diff = act_add - act_sub
print(act_diff.shape)

In [None]:
# generate from the modified model

def ave_hook(resid_pre, hook):
    """Forward hook that adds `coeff * act_diff` to the leading positions of `resid_pre`.

    Intended for `blocks.{act_name}.hook_resid_pre`. It edits `resid_pre` in
    place and returns None, so it relies on the in-place edit being seen
    downstream.

    Args:
        resid_pre: 3-D activation tensor at the hook point, with positions on axis 1.
        hook: the hook point object (unused).

    Reads the globals `act_diff` (positions on axis 1) and `coeff`.

    Raises:
        AssertionError: if `act_diff` spans more positions than `resid_pre`.
    """
    # During model.generate, new tokens arrive one position at a time (KV
    # cache). Skip those calls so only the prompt pass is steered.
    # NOTE: as a consequence, a prompt that is itself 1 token is never steered.
    if resid_pre.shape[1] == 1:
        return

    ppos, apos = resid_pre.shape[1], act_diff.shape[1]
    assert apos <= ppos, f"More mod tokens ({apos}) than prompt tokens ({ppos})!"

    # Add to the first `apos` positions. act_diff (presumably batch size 1)
    # broadcasts over the prompt batch.
    resid_pre[:, :apos, :] += coeff * act_diff


def hooked_generate(prompt_batch: List[str], fwd_hooks=None, seed=None, **kwargs):
    """Sample continuations of `prompt_batch` from the global `model` with `fwd_hooks` active.

    Args:
        prompt_batch: prompts to continue; tokenized together via `model.to_tokens`.
        fwd_hooks: (hook_name, hook_fn) pairs active during generation. Defaults
            to no hooks, i.e. the unmodified model.
        seed: if not None, `torch.manual_seed(seed)` is called first so that
            sampling is repeatable.
        **kwargs: forwarded to `model.generate`. `max_new_tokens` defaults to 20
            and `do_sample` to True; both can be overridden here.

    Returns:
        Whatever `model.generate` returns for the tokenized batch.
    """
    if fwd_hooks is None:
        fwd_hooks = []  # fresh list per call instead of a shared mutable default

    if seed is not None:
        torch.manual_seed(seed)

    # Caller kwargs take precedence. Previously these two values were
    # hard-coded, so passing either one raised "got multiple values" TypeError.
    generate_kwargs = {"max_new_tokens": 20, "do_sample": True, **kwargs}

    with model.hooks(fwd_hooks=fwd_hooks):
        tokenized = model.to_tokens(prompt_batch)
        r = model.generate(input=tokenized, **generate_kwargs)
    return r


# Sample 4 steered continuations of `prompt`, with ave_hook active at the
# residual-stream input of layer `act_name`.
editing_hooks = [(f"blocks.{act_name}.hook_resid_pre", ave_hook)]
res = hooked_generate(
    [prompt for _ in range(4)],
    fwd_hooks=editing_hooks,
    seed=SEED,
    **sampling_kwargs,
)

# Decode each sample, dropping column 0 (the beginning-of-sequence token),
# and print them separated by a horizontal rule.
res_str = model.to_string(res[:, 1:])
print(("\n\n" + "-" * 80 + "\n\n").join(res_str))