In [1]:
from datasets import load_dataset
def get_acts(model, prompts, where):
    """Collect the last-token activation of every layer for each prompt.

    Args:
        model: Object exposing ``cfg.n_layers``, ``cfg.d_model`` and a
            ``run_with_cache(prompt, names_filter=...)`` method that returns
            ``(output, cache)``; in this notebook a TransformerLens
            ``HookedTransformer``.
        prompts: Sized sequence of prompts. Each prompt is run through the
            model on its own (one forward pass per prompt).
        where: Hook-name suffix inside a block. The activation for layer ``j``
            is read from ``cache[f"blocks.{j}.{where}"]``
            (e.g. ``where="hook_resid_post"``).

    Returns:
        A ``torch.Tensor`` of shape ``[n_prompts, n_layers, d_model]`` that
        does not track gradients.
    """
    # Imported locally so this cell works on its own, independent of the
    # imports performed by other cells.
    import torch

    try:
        from tqdm import tqdm
    except ImportError:
        # The progress bar is purely cosmetic; carry on without it.
        def tqdm(iterable, **_kwargs):
            return iterable

    n_layers = model.cfg.n_layers

    # The only cache entries we ever read: one hook per layer.
    hook_names = [f'blocks.{layer}.{where}' for layer in range(n_layers)]

    # Holds all of the activations. Shape: [n_prompts, n_layers, d_model]
    data = torch.zeros((len(prompts), n_layers, model.cfg.d_model))

    # Pure inference: without no_grad every write into `data` would attach the
    # autograd graph of that forward pass, so memory would grow with each
    # prompt and the returned tensor would require grad.
    with torch.no_grad():
        # Wrap the sequence itself (not the enumerate object) so tqdm can read
        # its length and show a total / ETA.
        for i, prompt in enumerate(tqdm(prompts)):
            # Forward pass that caches only the hooks we need instead of every
            # intermediate activation in the model.
            _, activations = model.run_with_cache(prompt, names_filter=hook_names)

            for j, name in enumerate(hook_names):
                # [0, -1]: first (and only) batch element, last token position.
                # The last position of the residual stream is used on the
                # assumption that it carries the most information about the
                # whole prompt; other positions are untested here.
                data[i, j] = activations[name][0, -1]

    return data

In [3]:
import transformer_lens
import random
import torch

# Load Qwen2-0.5B wrapped as a TransformerLens HookedTransformer so that
# intermediate activations can be cached by hook name.
qwen2_05b = transformer_lens.HookedTransformer.from_pretrained("Qwen/Qwen2-0.5B")

# GenericsKB, "generics_kb_best" configuration; only the sentence column is used.
ds = load_dataset("community-datasets/generics_kb", "generics_kb_best")

# Materialise the column as a plain list: `sentences` is shuffled in place
# later in the notebook, which needs a mutable sequence.
# NOTE(review): newer `datasets` releases appear to return a lazy column
# object here rather than a list - confirm; list() is a no-op copy otherwise.
sentences = list(ds['train']['generic_sentence'])
len(sentences)



Loaded pretrained model Qwen/Qwen2-0.5B into HookedTransformer


1020868

In [4]:
# Shuffle with a fixed seed (on a private RNG, leaving the global `random`
# state untouched) so that the subset sliced from the front of `sentences`,
# and hence the row order of the extracted activations, is identical on every
# fresh Restart & Run All.
# The shuffle is in place: re-running this cell without reloading `sentences`
# permutes the already-shuffled list again.
random.Random(0).shuffle(sentences)

In [6]:
# Keep the first 20k sentences and preview a handful of them.
subset = sentences[:20_000]
print(subset[:10])

# Hook-name suffix: get_acts reads f"blocks.{layer}.{where}" for every layer.
where = "hook_resid_post"
acts = get_acts(model=qwen2_05b, prompts=subset, where=where)

['Old men have levels.', 'Birds have eyes.', 'Some people think all power is leadership.', 'Strange creatures spend whole nights in it, at certain seasons of the year.', 'Matter waves arise in quantum mechanical description of nature.', 'Saponins have many gifts to offer the human body in the area of health.', "Eyes become a deaf person's ears.", 'Men can have sex.', 'Quality assurance is a plant wide activity involving all departments and individual employees.', 'Most sharks have warm blood.']


20000it [25:55, 12.86it/s]


In [7]:
# Persist the activation tensor. Despite the ".npy" suffix the file is written
# by torch.save (PyTorch's serialization format), so read it back with
# torch.load rather than numpy.load.
torch.save(acts, "qwen2_20k_11162024.npy")