# Control Reasoning with Neurons

## Setup

First, point the Hugging Face cache at local storage, then import all necessary libraries and modules.

In [1]:
import os

# Point every Hugging Face download at the shared local cache before any HF library is imported.
_HF_CACHE_ENV = {
    "HF_HOME": "/data1/hf_cache/huggingface",
    "TRANSFORMERS_CACHE": "/data1/hf_cache/huggingface/hub",
}
os.environ.update(_HF_CACHE_ENV)

for _env_name in _HF_CACHE_ENV:
    print(f"{_env_name} = {os.environ[_env_name]}")

HF_HOME = /data1/hf_cache/huggingface
TRANSFORMERS_CACHE = /data1/hf_cache/huggingface/hub


In [2]:
import sys
import glob
# Put the parent of the current working directory (typically the notebook's
# directory) on the import path so the local `transcoder_circuits` and
# `sae_training` packages can be imported.
sys.path.append(os.path.abspath("..")) 

# NOTE(review): `torch` and `tqdm` are used later in this notebook without an
# explicit import; presumably they are re-exported by these star imports — confirm.
from transcoder_circuits.circuit_analysis import *
from transcoder_circuits.feature_dashboards import *
from transcoder_circuits.replacement_ctx import *

from sae_training.sparse_autoencoder import SparseAutoencoder
from transformer_lens import HookedTransformer, utils
import torch.nn.functional as F

from sae_training.config import LanguageModelSAERunnerConfig
# Allow-list this config class for weights-only `torch.load`; presumably the
# transcoder checkpoints pickle an instance of it — confirm.
torch.serialization.add_safe_globals([LanguageModelSAERunnerConfig])



Now, load the pre-trained model and transcoder.

In [4]:
# Load the model through TransformerLens with n_devices=8 (layers spread over
# 8 devices, starting from `device`).
model_name = "meta-llama/Llama-3.1-8B-Instruct"
device = "cuda:0"
model = HookedTransformer.from_pretrained(
    model_name=model_name,
    device=device,
    n_devices=8,
    move_to_device=True,
)

# Transcoder checkpoints are available for layers 0, 9, 18 and 27; change `layer_idx` to switch.
layer_idx = 0
transcoder_template = "/data1/cl/weights/transcoder/transcoders/llama3.1_8b_instruct/layer_{:02d}/Llama-3.1-8B-Instruct_blocks.{}.ln2.hook_normalized_65536"
path = transcoder_template.format(layer_idx, layer_idx)
checkpoint_file = path + ".pt"
sparsity_file = path + "_log_feature_sparsity.pt"

transcoder = SparseAutoencoder.load_from_pretrained(checkpoint_file).eval()
# Per-feature log sparsity; used below to separate alive features from dead ones.
frequency = torch.load(sparsity_file)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model meta-llama/Llama-3.1-8B-Instruct into HookedTransformer


In [5]:
# Release memory that is no longer referenced: run Python's garbage collector,
# then let PyTorch return its unused cached CUDA blocks.
import gc

for _release in (gc.collect, torch.cuda.empty_cache):
    _release()

Now, prepare the dataset.

In [None]:
from datasets import load_dataset

# Stream BoolQ's validation split, shuffle it with a fixed seed, and keep the
# first 300 examples as the evaluation subset.
raw_dataset = (
    load_dataset("google/boolq", split="validation", streaming=True)
    .shuffle(seed=42, buffer_size=10_000)
    .take(300)
)


def make_prompt(ex):
    """Render a BoolQ example as a question/passage prompt ending in ``Answer:``."""
    question, passage = ex["question"].strip(), ex["passage"].strip()
    return f"Question: {question}\nPassage: {passage}\nAnswer:"


def _to_eval_example(ex):
    """Attach the prompt text and the stringified boolean label to an example."""
    return {"text": make_prompt(ex), "answer": str(ex["answer"])}


dataset = raw_dataset.map(_to_eval_example)

'False'

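Next, identify the indices of the *reasoning* and *memory* neurons (transcoder features). We load the per-step feature-activation traces recorded under three prompting styles (`basic`, `0shot`, `fewshot`) and average them. An alive feature is a reasoning neuron if its mean activation under 0-shot or few-shot prompting exceeds its `basic` mean by more than a small threshold. All other alive features are memory neurons.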

In [7]:
def load_all_traces(out_dir, tracing_type):
    """Merge every per-batch feature-trace file of one tracing type into a single dict.

    Files are matched as ``<out_dir>/<tracing_type>/*_<tracing_type>_feature_traces.pt``,
    each is read with ``torch.load``, and the resulting dicts are merged with
    ``dict.update`` in lexicographically sorted filename order. Sorting makes the
    merge reproducible: ``glob`` order depends on the file system, and when two batch
    files share a key the later file in sorted order wins.

    Args:
        out_dir: Root directory holding one sub-directory per tracing type.
        tracing_type: Tracing-type name (e.g. ``"basic"``); used both as the
            sub-directory name and inside the filename pattern.

    Returns:
        The merged dict, or ``{}`` (after printing the searched pattern) when no file
        matches. Presumably maps prompt text -> {step: activation vector}, judging by
        how the result is consumed later in this notebook — confirm against the
        tracing code.
    """
    search_path = os.path.join(out_dir, tracing_type, f"*_{tracing_type}_feature_traces.pt")
    batch_files = sorted(glob.glob(search_path))

    if not batch_files:
        print(f"No files found at: {search_path}")
        return {}

    combined_records = {}
    for f_path in batch_files:
        combined_records.update(torch.load(f_path))
    return combined_records

In [8]:
tracing_types = ["basic", "0shot", "fewshot"]
titles = ["Mean activation", "First step", "Last step"]
traces_dir = f"/data1/cl/weights/transcoder/frequencies/llama3.1_8b_instruct/{layer_idx}_layer"

# For each tracing type, average the recorded feature activations three ways:
#   "mean"  - over every recorded step of every trace (longer traces weigh more),
#   "first" - over step 0 only,
#   "last"  - over each trace's highest-numbered step.
# Only traces that contain step 0 contribute to "first" and "last".
acts_dict = {}
for tracing_type in tracing_types:
    data = load_all_traces(out_dir=traces_dir, tracing_type=tracing_type)
    all_acts = []
    step_first = []
    step_last = []

    for trace_dict in tqdm.tqdm(data.values(), desc=f"Loading {tracing_type}"):
        # trace_dict is keyed by step index; values are activation tensors.
        all_acts.extend(vec.unsqueeze(0) for vec in trace_dict.values())
        if 0 in trace_dict:
            step_first.append(trace_dict[0].unsqueeze(0))
            step_last.append(trace_dict[max(trace_dict)].unsqueeze(0))

    # torch.cat([]) fails with an opaque message; say what is actually missing.
    if not all_acts or not step_first:
        raise ValueError(
            f"No usable {tracing_type!r} traces under {traces_dir!r}: "
            f"{len(data)} traces loaded, {len(step_first)} of them contain step 0."
        )

    all_acts = torch.cat(all_acts, dim=0)
    step_first = torch.cat(step_first, dim=0)
    step_last = torch.cat(step_last, dim=0)

    mean_act = all_acts.mean(dim=0)
    mean_first = step_first.mean(dim=0)
    mean_last = step_last.mean(dim=0)

    acts_dict[tracing_type] = {
        "mean": mean_act,
        "first": mean_first,
        "last": mean_last,
    }

Loading basic: 100%|██████████| 512/512 [00:00<00:00, 4636.63it/s]
Loading 0shot: 100%|██████████| 512/512 [00:00<00:00, 6183.23it/s]
Loading fewshot: 100%|██████████| 512/512 [00:00<00:00, 4717.71it/s]


In [14]:
# A feature is "alive" when its log feature sparsity exceeds this cutoff.
threshold_active = -6
# Minimum increase in mean activation (vs. basic prompting) that marks a reasoning feature.
threshold = 1e-3

mask_alive = frequency > threshold_active

a_basic, a_0shot, a_fewshot = (acts_dict[kind]["mean"] for kind in ("basic", "0shot", "fewshot"))

diff_0 = a_0shot - a_basic
diff_f = a_fewshot - a_basic

# Reasoning: alive and raised by either 0-shot or few-shot prompting. Memory: every other alive feature.
raised_by_prompting = torch.logical_or(diff_0 > threshold, diff_f > threshold)
mask_reasoning = raised_by_prompting & mask_alive
mask_memory = mask_alive & ~mask_reasoning

reasoning_neurons = mask_reasoning.nonzero(as_tuple=True)[0]
memory_neurons = mask_memory.nonzero(as_tuple=True)[0]

## Control Reasoning Activation Strength

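First, define a forward hook that replaces the hooked activation with the transcoder's reconstruction, optionally rescaling the reasoning features. Each prompt is then answered under four settings:

- `Base`: the unmodified model.
- `Original`: the transcoder reconstruction with no rescaling (strength 1.0).
- `Suppressed`: the reasoning features scaled by ×0.5.
- `Strengthen`: the reasoning features scaled by ×2.0.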

In [35]:
import json

def get_steering_hook(transcoder, feature_indices, strength=1.0):
    """Build a forward hook that swaps an activation for its transcoder reconstruction.

    The hook encodes the incoming activation as ``relu(x @ W_enc + b_enc)``,
    multiplies the features in ``feature_indices`` by ``strength`` (skipped when
    ``strength == 1.0``), decodes with ``W_dec``/``b_dec``, and returns the decoded
    tensor, which replaces the original activation. With ``strength == 1.0`` the hook
    therefore still substitutes the unmodified reconstruction; it is not a no-op.

    The transcoder weights are converted to the activation's device and dtype once
    per (device, dtype) and reused, instead of being re-copied on every forward pass.
    During generation the hook fires once per token, so per-call copies of the weight
    matrices dominated the runtime. Matching the dtype also lets the hook run on
    half-precision activations.

    Args:
        transcoder: Object exposing ``W_enc``, ``b_enc``, ``W_dec`` and ``b_dec`` tensors.
        feature_indices: Indices (along the last dimension of the encoded
            features) of the features to rescale.
        strength: Multiplicative factor applied to the selected features.

    Returns:
        A ``hook_fn(activations, hook)`` callable for ``model.hooks(fwd_hooks=...)``.
    """
    # (device, dtype) -> (W_enc, b_enc, W_dec, b_dec) converted for that target.
    weight_cache = {}

    def _weights_like(x):
        key = (x.device, x.dtype)
        if key not in weight_cache:
            weight_cache[key] = tuple(
                w.to(device=x.device, dtype=x.dtype)
                for w in (transcoder.W_enc, transcoder.b_enc, transcoder.W_dec, transcoder.b_dec)
            )
        return weight_cache[key]

    def hook_fn(activations, hook):
        W_enc, b_enc, W_dec, b_dec = _weights_like(activations)
        feats = torch.relu((activations @ W_enc) + b_enc)

        if strength != 1.0:
            # `feats` is a fresh tensor, so scaling in place leaves `activations` untouched.
            feats[..., feature_indices] *= strength

        return (feats @ W_dec) + b_dec

    return hook_fn

In [36]:
def generate_and_save_results(
    model,
    transcoder,
    dataset,
    reasoning_neurons,
    save_path="experiment_results.json",
    max_new_tokens=50,
    checkpoint_every=10,
):
    """Generate answers under several steering settings and save them as JSON.

    Every example in ``dataset`` (items with ``"text"`` and ``"answer"`` keys) is
    generated with temperature 0.0 (greedy) under four settings:

    * ``"Base"``       - no hooks installed;
    * ``"Original"``   - the ``get_steering_hook`` hook at ``transcoder.cfg.hook_point``
      with strength 1.0;
    * ``"Suppressed"`` - the same hook with ``reasoning_neurons`` scaled by 0.5;
    * ``"Strengthen"`` - the same hook with ``reasoning_neurons`` scaled by 2.0.

    Args:
        model: HookedTransformer used for tokenization and generation.
        transcoder: Transcoder whose ``cfg.hook_point`` selects where the hook runs.
        dataset: Iterable of examples with ``"text"`` (prompt) and ``"answer"`` keys.
        reasoning_neurons: Feature indices rescaled by the steering hook.
        save_path: JSON file the records are written to.
        max_new_tokens: Generation budget per setting.
        checkpoint_every: Rewrite ``save_path`` after every this many examples
            (falsy disables intermediate checkpoints); it is always written at the end.

    Returns:
        The list of per-example records that was written to ``save_path``.
    """
    configs = [
        ("Base", None),
        ("Original", 1.0),
        ("Suppressed", 0.5),
        ("Strengthen", 2.0),
    ]
    hook_point = transcoder.cfg.hook_point
    results = []

    def _save():
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)

    for i, item in enumerate(tqdm.notebook.tqdm(dataset)):
        prompt = item["text"]
        answer = item["answer"]
        input_tokens = model.to_tokens(prompt).to(device)
        n_prompt_tokens = input_tokens.shape[-1]

        gen_args = {
            "input": input_tokens,
            "max_new_tokens": max_new_tokens,
            "temperature": 0.0,
            "verbose": False,
            "stop_at_eos": True,
        }
        record = {
            "id": i,
            "prompt": prompt,
            "ground_truth": answer,
            "generations": {},
        }

        for name, strength in configs:
            model.reset_hooks()

            if strength is None:
                output_ids = model.generate(**gen_args)
            else:
                steering_hook = get_steering_hook(transcoder, reasoning_neurons, strength=strength)
                with model.hooks(fwd_hooks=[(hook_point, steering_hook)]):
                    output_ids = model.generate(**gen_args)

            # Decode only the tokens produced after the prompt. Slicing the decoded
            # string by len(prompt) is wrong: to_tokens prepends a BOS token by
            # default, and its text shows up at the start of the decoded output.
            record["generations"][name] = model.to_string(output_ids[0, n_prompt_tokens:])

        results.append(record)

        if checkpoint_every and (i + 1) % checkpoint_every == 0:
            _save()

    _save()
    return results

In [None]:
# Run every steering setting on the evaluation subset; results are checkpointed to
# a per-layer JSON file.
save_file = f"boolq_reasoning_control_layer{layer_idx}.json"

experiment_data = generate_and_save_results(
    model,
    transcoder,
    dataset,
    reasoning_neurons,
    save_path=save_file,
)