In [21]:
from datasets import load_dataset
def get_acts(model, prompts, hook_name="hook_resid_post", token_index=-1):
    """Collect one activation vector per (prompt, layer) from a TransformerLens model.

    Each prompt is run through the model on its own (batch size 1). For every layer
    ``j``, the activation cached under ``blocks.{j}.{hook_name}`` is read at batch
    element 0 and sequence position ``token_index``.

    By default this takes the residual stream after each block at the last token.
    The last token is used because it is where, per some literature, the most
    information about the whole prompt has accumulated. That is the last token of
    the residual stream, not a particular content word. Pass ``token_index`` or
    ``hook_name`` to experiment with other positions or hook points.

    Args:
        model: a ``transformer_lens.HookedTransformer``, or anything exposing
            ``cfg.n_layers``, ``cfg.d_model`` and
            ``run_with_cache(prompt, names_filter=...)``.
        prompts: a sized sequence of prompts (``len()`` is used to pre-allocate).
        hook_name: per-block hook name to read (default ``"hook_resid_post"``).
        token_index: sequence position to keep (default ``-1``, the last token).

    Returns:
        A tensor of shape ``[len(prompts), n_layers, d_model]``, created with
        ``torch.zeros``, so it uses the default dtype/device (normally float32 on CPU).
    """
    import torch
    from tqdm.auto import tqdm

    # Number of transformer blocks (24 for gpt2-medium).
    n_layers = model.cfg.n_layers
    hook_names = [f"blocks.{j}.{hook_name}" for j in range(n_layers)]

    # Output buffer: [n_prompts, n_layers, d_model].
    data = torch.zeros((len(prompts), n_layers, model.cfg.d_model))

    # Inference only: without this, every forward pass builds an autograd graph
    # that is immediately discarded.
    with torch.no_grad():
        # Wrap the sequence itself (not the enumerate object) so tqdm can read its
        # length and show a real progress bar with a total.
        for i, prompt in enumerate(tqdm(prompts)):
            # names_filter limits the cache to the hooks we read below, instead of
            # caching every hook point in the model.
            _, activations = model.run_with_cache(prompt, names_filter=hook_names)

            for j, name in enumerate(hook_names):
                # [0, token_index]: the single batch element (one prompt per pass),
                # then the chosen sequence position.
                data[i, j] = activations[name][0, token_index]

    return data

In [26]:
import transformer_lens
import random
import torch

# Model whose per-layer activations we extract, and the corpus of prompts fed to it.
MODEL_NAME = "gpt2-medium"
DATASET_PATH = "community-datasets/generics_kb"
DATASET_CONFIG = "generics_kb_best"

gpt2_medium = transformer_lens.HookedTransformer.from_pretrained(MODEL_NAME)

# Keep only the sentence text of the train split; these are the prompts.
ds = load_dataset(DATASET_PATH, DATASET_CONFIG)
train_split = ds["train"]
sentences = train_split["generic_sentence"]

# Corpus size (displayed as the cell output).
len(sentences)

Loaded pretrained model gpt2-medium into HookedTransformer


1020868

In [23]:
# Shuffle the corpus before taking the first-N sample below. The seed makes that
# sample, and therefore the saved activations, reproducible across kernels.
SHUFFLE_SEED = 0
# Copy into a plain list so the column object returned by the dataset is never
# mutated in place. The copy is also a mutable sequence that random.shuffle accepts.
# NOTE: this shuffles whatever `sentences` currently holds; re-run the load cell
# first if you re-run this one, otherwise an already-shuffled list is reshuffled.
sentences = list(sentences)
random.Random(SHUFFLE_SEED).shuffle(sentences)

In [24]:
# Number of (shuffled) sentences to run through the model.
N_PROMPTS = 50_000
first_50k = sentences[:N_PROMPTS]
# Sanity-check a handful of the sampled prompts.
print(first_50k[:10])

['Zebras are gregarious under conditions of abundant food or around water holes.', 'Optimum performance requires the maintenance of blood glucose, as well as muscle and liver glycogen.', 'Some acorn weevils have thin snouts.', 'Crocodiles prefer habitats.', 'Owls are forest creatures that hunt at night.', 'All babies thrive on the love and attention of their parents and families.', 'Arachidonic acid is found only in animal fats.', 'Biodegradable waste breaks down into methane in the landfill, if at all.', 'Orange marmalade is marmalade', 'Wood is diffuse, porous.']


In [27]:
# Run every sampled prompt through the model: acts is [n_prompts, n_layers, d_model].
acts = get_acts(gpt2_medium, first_50k)

# NOTE(review): torch.save writes PyTorch's serialization format, not a NumPy .npy
# file, despite the extension. Load it back with torch.load(...), not np.load(...).
# The filename is kept unchanged so existing loaders keep working.
ACTS_PATH = "generics_kb_50k_11132024.npy"
torch.save(acts, ACTS_PATH)