## Extracting Activations on Reasoning Examples

In [2]:
import platform
platform.python_version()

'3.11.13'

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np 
import pandas as pd 

import torch

import pickle

**Load the first 50k samples from Llama-Nemotron-Post-Training-Dataset (SFT config, math split)**

In [None]:
from datasets import load_dataset
from itertools import islice

# Stream the split (returns an IterableDataset) so records are fetched lazily instead of downloading the whole split
dataset = load_dataset("nvidia/Llama-Nemotron-Post-Training-Dataset", 'SFT', split="math", streaming=True)

In [4]:
# Define how many rows you want
num_rows_to_take = 50000

# Create an iterator that yields the first k samples
first_k_samples_iterator = islice(dataset, num_rows_to_take)

# Materialize the first k samples as a list of dicts (records are actually fetched here)
first_k_samples_list = list(first_k_samples_iterator)
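Because `streaming=True` yields an `IterableDataset`, `islice` pulls only the first `num_rows_to_take` records and then stops. A minimal stand-in, assuming a plain Python generator in place of the Hugging Face stream (`fake_stream` and `stats` are hypothetical names), shows that nothing beyond the first k items is requested:

```python
from itertools import islice

stats = {"consumed": 0}  # how many records the stand-in stream has produced

def fake_stream():
    """Hypothetical stand-in for a streaming IterableDataset: yields records forever."""
    i = 0
    while True:
        stats["consumed"] += 1
        yield {"input": [{"role": "user", "content": f"problem {i}"}],
               "output": f"answer {i}"}
        i += 1

# islice stops after 5 items and never asks the generator for a sixth
first_k = list(islice(fake_stream(), 5))
print(len(first_k), stats["consumed"])  # 5 5
```

The same laziness is what keeps the cell above from touching the rest of the math split.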

**Concatenating problems and answers**

In [6]:
qa = []
dataset = first_k_samples_list
for i in range(len(dataset)):
    qa.append(str(dataset[i]['input'][0])+dataset[i]['output'])
len(qa)

50000

In [7]:
qa[0]

"{'role': 'user', 'content': 'Solve the following math problem. Make sure to put the answer (and only answer) inside \\\\boxed{}.\\n\\nSketch the graph of the function $g(x) = \\\\ln(x^5 + 5) - x^3$.'}<think>\nOkay, so I need to sketch the graph of the function g(x) = ln(x⁵ + 5) - x³. Hmm, let me think about how to approach this. I remember that sketching functions usually involves finding key points like intercepts, asymptotes, critical points for increasing/decreasing behavior, inflection points for concavity, and maybe some symmetry. Let me start by breaking down the function.\n\nFirst, the function is a combination of a logarithmic function and a polynomial. The natural log term is ln(x⁵ + 5), and then subtracting x³. Let me consider the domain first. Since ln(x⁵ + 5) is part of the function, the argument of the logarithm must be positive. So x⁵ + 5 > 0. Let's solve that:\n\nx⁵ + 5 > 0\nx⁵ > -5\nx > (-5)^(1/5)\n\nCalculating (-5)^(1/5)... since the fifth root of a negative number i
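Since `str(dataset[i]['input'][0])` inserts the `repr` of the message dict, every `qa` string starts with a literal `{'role': 'user', 'content': ...}`, as the output above shows. A toy record (hypothetical, not taken from the dataset, but following its `input`/`output` schema) makes the resulting format explicit:

```python
# Hypothetical record mimicking the dataset schema: 'input' is a list of chat
# messages, 'output' is the response (reasoning trace plus boxed answer).
record = {
    "input": [{"role": "user", "content": "What is 2 + 3?"}],
    "output": "<think>\n2 + 3 = 5\n</think>\n\\boxed{5}",
}

# Same construction as the notebook: repr of the first message + raw output
text = str(record["input"][0]) + record["output"]
print(text)
```

If the model's own chat formatting were wanted instead, `tokenizer.apply_chat_template` is the usual Transformers route, though it would change which tokens end up being pooled.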

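**Extract the hidden states of the model**

- mean-pool over all tokens after `<BOS>`

- drop the first entry of `hidden_states` (the embedding output) and the last layer, then concatenate the pooled vectors of the remaining layers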
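The pooling described above can be sketched on synthetic tensors. This is a minimal version assuming `hidden_states` is the tuple returned with `output_hidden_states=True` (embedding output at index 0, one entry per decoder layer after it), that BOS sits at position 0, and that padding is on the right; it uses the attention mask so padded positions are not averaged in when `bs > 1`. The name `pool_hidden_states` is hypothetical.

```python
import torch

def pool_hidden_states(hidden_states, attention_mask):
    """Mean-pool each kept layer over non-padding tokens after BOS, then concatenate.

    hidden_states: tuple of [batch, seq_len, hidden] tensors (num_layers + 1 entries)
    attention_mask: [batch, seq_len], 1 for real tokens, 0 for padding
    Assumes right padding and BOS at position 0.
    Returns: [batch, (num_layers - 1) * hidden]
    """
    layers = hidden_states[1:-1]                         # drop embeddings and last layer
    mask = attention_mask[:, 1:].unsqueeze(-1).to(layers[0].dtype)  # skip BOS
    counts = mask.sum(dim=1).clamp(min=1)                # pooled tokens per example
    pooled = [(h[:, 1:, :] * mask).sum(dim=1) / counts for h in layers]
    return torch.cat(pooled, dim=-1)

# Synthetic example: embeddings + 4 layers, batch 2, seq_len 5, hidden 3
torch.manual_seed(0)
hs = tuple(torch.randn(2, 5, 3) for _ in range(5))
mask = torch.tensor([[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]])                   # second sequence is right-padded
feats = pool_hidden_states(hs, mask)
print(feats.shape)  # torch.Size([2, 9])
```

Averaging each layer and then concatenating gives the same vector as concatenating first and averaging afterwards, because the mean over tokens acts feature by feature; pooling per layer just avoids building the large concatenated tensor.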

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM


# MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
# MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"

maxlen = 8000

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto" # Automatically loads to GPU if available
)

# Set up a pad token if the model doesn't have one (like Gemma)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


import torch

bs = 1
gradsc = []

from tqdm import tqdm

for i in tqdm(range(len(qa)//bs+1)):
    if len(qa) == i*bs:
        break
    text_seq = qa[i*bs:(i+1)*bs]
    
    # Tokenize the input
    # We use padding=True and return_tensors="pt" for batch processing,
    # even if it's just one sequence.
    inputs = tokenizer(
        text_seq, 
        return_tensors="pt", 
        padding=True, 
        max_length=maxlen,
        truncation=True
    ).to(model.device)
    
    
    
    # Run the model
    # This is the key step
    with torch.no_grad():
        outputs = model(
            **inputs,
            output_hidden_states=True,
            output_attentions=True
        )
    
    # A tuple of (num_layers + 1) tensors
    hidden_states = outputs.hidden_states
    
    # print(f"Number of hidden layers (plus embeddings): {len(hidden_states)}")
    
    # Get the initial embeddings (layer 0)
    # Shape: [batch_size, seq_len, hidden_dim]
    embeddings = hidden_states[0]
    # print(f"Embeddings shape: {embeddings.shape}")
    
    # Get the final layer's hidden states (output of the last transformer block)
    last_hidden_state = hidden_states[-1]
    # print(f"Last hidden state shape: {last_hidden_state.shape}")

    
    # pad_token_id = 151643    # qwen2.5
    # 128009    # llama3.2
    pcs = np.zeros(inputs['input_ids'].shape[0])
    maxlen = 8000

    ls = inputs['input_ids'].shape[1]
    gradsc.append(torch.mean(torch.concat(hidden_states[1:24], dim=2)[:,1:maxlen,:], axis=1).cpu())

In [None]:
# Each entry is [bs, feature_dim]; concatenate along the batch axis -> [num_examples, feature_dim]
gradmat = torch.cat(gradsc, dim=0)
gradmat.shape
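One way to sanity-check `gradmat.shape` is to count how many layers a slice of `hidden_states` keeps. Assuming the published configs (Llama-3.2-1B: 16 decoder layers, hidden size 2048; Qwen2.5-0.5B: 24 layers, hidden size 896; verify against `model.config.num_hidden_layers` and `model.config.hidden_size`), the hypothetical helper `pooled_dim` shows that a fixed `hidden_states[1:24]` drops the last layer only for a 24-layer model, while `[1:-1]` drops it for any depth:

```python
def pooled_dim(num_hidden_layers, hidden_size, start, stop):
    """Width of the concatenated, mean-pooled feature vector.

    output_hidden_states=True yields num_hidden_layers + 1 tensors
    (embedding output first), so the slice is applied to that length.
    """
    kept = len(range(num_hidden_layers + 1)[start:stop])
    return kept * hidden_size

# Llama-3.2-1B (assumed 16 layers, hidden 2048)
print(pooled_dim(16, 2048, 1, 24))   # 32768: all 16 layers, last one included
print(pooled_dim(16, 2048, 1, -1))   # 30720: 15 layers, last one dropped
# Qwen2.5-0.5B (assumed 24 layers, hidden 896)
print(pooled_dim(24, 896, 1, 24))    # 20608: 23 layers
print(pooled_dim(24, 896, 1, -1))    # 20608: same slice for this depth
```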

In [None]:
with open("./l3b1-nemo50k-gradmat.pkl", "wb") as f:
    pickle.dump(gradmat, f, protocol=4)
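To reuse the matrix later, the file can be read back with `pickle.load`; `torch` must be installed because the pickled object is a `torch.Tensor`. A minimal round-trip sketch, assuming a small random tensor and a temporary file in place of the real 50k-row matrix:

```python
import os
import pickle
import tempfile

import torch

mat = torch.randn(4, 8)  # stand-in for gradmat: [num_examples, feature_dim]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "gradmat.pkl")
    with open(path, "wb") as f:
        pickle.dump(mat, f, protocol=4)
    with open(path, "rb") as f:
        loaded = pickle.load(f)

print(loaded.shape)  # torch.Size([4, 8])

# .float() makes the NumPy conversion safe for bfloat16 tensors, which .numpy() rejects
features = loaded.float().numpy()
```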