In [15]:
import torch as t

from functools import partial
from tqdm.auto import tqdm
from llama2 import load_model


t.set_grad_enabled(False)
device = "cuda" if t.cuda.is_available() else "cpu"

In [3]:
model = load_model(device=device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [6]:
directions = t.load("directions/llama2-7b_cities_mm.pt").to(device)
"shape: (layer pos d_model)"

'layer pos d_model'

In [7]:
def generate_with_hooks(model, toks, max_tokens_generated=64, fwd_hooks=[], include_prompt=False) -> str:
    assert toks.shape[0] == 1, "batch size must be 1"
    all_toks = t.zeros((toks.shape[0], toks.shape[1] + max_tokens_generated), dtype=t.long).to(device)
    all_toks[:, :toks.shape[1]] = toks

    
    model.reset_hooks()
    for i in tqdm(range(max_tokens_generated)):
        logits = model.run_with_hooks(
            all_toks[:, :-max_tokens_generated + i],
            return_type="logits",
            fwd_hooks=fwd_hooks,
        )[0,-1] # get the first element in the batch, and the last logits tensor in the sequence

        # greedy sampling (temperature=0)
        next_token = logits.argmax()
        all_toks[0,-max_tokens_generated+i] = next_token

    if include_prompt:
        return model.tokenizer.decode(all_toks[0])
    else:
        return model.tokenizer.decode(all_toks[0, toks.shape[1]:])

In [18]:
def direction_addition_hook(
    activations,
    hook,
    token_positions="all",
    layer=0,
    alpha=5.0,
    direction_token=-1,
):
    """
    activations: shape (batch pos d_model)
    """
    if token_positions == "all":
        activations[:, :, :] += alpha * directions[layer, direction_token, :]
    else:
        raise NotImplementedError
    return activations

In [52]:
toks = model.to_tokens("An essay about the city of London:")
pp(generate_with_hooks(model, toks, max_tokens_generated=100, include_prompt=True))

  0%|          | 0/100 [00:00<?, ?it/s]

('<s> An essay about the city of London: its history, its government, its '
 'trade, its manufactures, its population, its wealth, its poverty, its '
 'morals, its amusements, its churches, its schools, its hospitals, its '
 'prisons, its police, its charities, its public buildings, its streets, its '
 'squares, its parks, its markets, its shops, its theatres, its clubs, its '
 'taverns, its coffee-houses, its public-houses')


In [53]:
layer = 11
partialed_hook = partial(
    direction_addition_hook,
    layer=layer,
    alpha=-15.0,
    direction_token=-1,
)
fwd_hooks = [
    (f"blocks.{layer}.hook_resid_post", partialed_hook),
]
out_text = generate_with_hooks(model, toks, max_tokens_generated=100, fwd_hooks=fwd_hooks, include_prompt=True)
pp(out_text)

  0%|          | 0/100 [00:00<?, ?it/s]

('<s> An essay about the city of London:\n'
 'The city of London is a city of the United Kingdom. It is located in the '
 'south-east of England. The city is located on the Thames River. The city is '
 'located in the county of Kent. The city is located in the south-east of '
 'England. The city is located in the county of Kent. The city is located in '
 'the south-east of England. The city is located in the county of Kent. The '
 'city is located in the south-east of England.')


In [14]:
toks = model.to_tokens("The city of Marseille is in France.")
generate_with_hooks(model, toks, max_tokens_generated=12, include_prompt=True)

  0%|          | 0/12 [00:00<?, ?it/s]

'<s> The city of Marseille is in France. It is the second largest city in France. It is located'