# Replicating Something like Golden Gate Claude
First, we want to hook the residual stream of our transformer at the middle layer. Other work has hooked other parts of the transformer and other layers, but here we simply follow what Anthropic does in the Scaling Monosemanticity paper (the work behind Golden Gate Claude).
This is a simplified version of the `collect_dataset.py` script; that script additionally supports Hugging Face Accelerate and saves the activations to a `.zarr` file.

In [1]:
# load in olmo
import itertools
from typing import Any, List, Optional, Tuple

import torch
from datasets import load_dataset
from psutil import cpu_count
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')


def hook_residual_stream(
    model: "AutoModelForCausalLM",
    layer_idx: int = -1,
) -> Tuple[torch.utils.hooks.RemovableHandle, List[torch.Tensor]]:
    """Register a forward hook that records the residual stream after one decoder layer.

    The hook is attached to the decoder layer module itself, so it captures the
    layer's output hidden states: the residual stream after both the attention
    and MLP sub-blocks have added their contributions. (OLMo-2 applies
    ``post_attention_layernorm`` to the attention output *before* adding it to
    the residual, so hooking that norm would record a normalized attention
    output rather than the residual stream.)

    Args:
        model: Causal LM exposing its decoder layers as ``model.model.layers``.
        layer_idx: Index into ``model.model.layers``. Defaults to ``-1``, the
            last loaded layer. When only the first half of the layers has been
            loaded, this is the middle of the full model.

    Returns:
        ``(handle, activations)``: the removable hook handle and the list to
        which every forward pass through the hooked layer appends one detached
        CPU tensor. The caller is responsible for calling ``handle.remove()``.
    """
    activations: List[torch.Tensor] = []

    def hook_fn(module: torch.nn.Module, inputs: Any, output: Any) -> None:
        # HF decoder layers return either a bare tensor or a tuple whose first
        # element is the hidden states (this varies across transformers versions).
        hidden = output[0] if isinstance(output, tuple) else output
        activations.append(hidden.detach().cpu())
        # Returning None leaves the layer's output unchanged.

    handle = model.model.layers[layer_idx].register_forward_hook(hook_fn)
    return handle, activations


def collect_activations(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    dataloader: DataLoader,
    context_length: int = 512,
    max_batches: Optional[int] = 11,
) -> List[torch.Tensor]:
    """Run text batches through ``model`` and collect the hooked activations.

    A hook from ``hook_residual_stream`` is registered for the duration of the
    call and is always removed afterwards, even if a forward pass raises
    (e.g. a CUDA OOM), so it cannot leak onto later forward passes.

    Args:
        model: Causal LM whose decoder stack is reachable as ``model.model``.
        tokenizer: Tokenizer used to pad and truncate each batch of texts.
        dataloader: Yields batches indexable by ``'text'`` (presumably a list of
            strings per batch — confirm against the dataset's collate output).
        context_length: Maximum number of tokens per sequence; longer texts are
            truncated.
        max_batches: Number of batches to process before stopping; ``None``
            processes the whole dataloader. The default of 11 reproduces the
            previous hard-coded cut-off, which broke out after batch index 10.

    Returns:
        One CPU tensor per processed batch, in dataloader order.

    NOTE(review): sequences are padded to the longest one in each batch, so the
    returned tensors include activations at pad positions. The attention mask is
    not returned, so callers cannot currently filter those positions out.
    """
    all_activations: List[torch.Tensor] = []
    hook, activations = hook_residual_stream(model)
    # islice stops requesting batches once the cap is reached, so no extra
    # batch is loaded beyond max_batches.
    batches = dataloader if max_batches is None else itertools.islice(dataloader, max_batches)
    try:
        for batch_texts in tqdm(batches, total=max_batches):
            inputs = tokenizer(
                batch_texts['text'],
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=context_length,
            )
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            with torch.inference_mode():
                # Only the hooked hidden states are needed, so run the decoder
                # stack and skip the LM head, whose vocabulary-sized logits would
                # be computed and immediately discarded.
                model.model(**inputs)

            # The hook appended this forward pass's tensor(s); move them over and reset.
            all_activations.extend(activations)
            activations.clear()
    finally:
        hook.remove()
    return all_activations

MODEL_NAME = "allenai/OLMo-2-1124-7B-Instruct"

# Only the first half of the decoder layers is needed to read the residual stream
# at the middle of the network, so truncate the config before loading. Weights for
# the dropped layers are skipped, which triggers a "weights ... were not used" warning.
model_config = AutoConfig.from_pretrained(MODEL_NAME)
model_config.num_hidden_layers = model_config.num_hidden_layers // 2
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map='cuda:0',
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    config=model_config,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# collect_activations registers and removes its own hook. Registering another one
# here would leave a permanent second hook whose list duplicates every activation.

dataset = load_dataset(
    "HuggingFaceFW/fineweb",
    split="train",
    num_proc=cpu_count(),
    streaming=False,
    name="sample-10BT",
)

dataloader = DataLoader(dataset, batch_size=60, shuffle=True)

all_activations = collect_activations(model, tokenizer, dataloader)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some weights of the model checkpoint at allenai/OLMo-2-1124-7B-Instruct were not used when initializing Olmo2ForCausalLM: ['model.layers.16.mlp.down_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.16.post_feedforward_layernorm.weight', 'model.layers.16.self_attn.k_norm.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_norm.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.17.post_feedforward_layernorm.weight', 'model.layers.17.self_attn.k_norm.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_norm.weight', 'mo

Resolving data files:   0%|          | 0/25868 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/114 [00:00<?, ?it/s]

10it [00:46,  4.61s/it]
