In [1]:
import numpy as np 

In [17]:
import os
import torch
import yaml
import numpy as np
from typing import List, Dict, Any, Set, Tuple
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# -----------------------------------------
# Configuration of comparisons
# -----------------------------------------
COMPARISONS = [
    {"name": "base_answer_vs_enhanced_reason_body",       "mode1": "base",              "region1": "answering", "mode2": "reasoning_boosted","region2": "reasoning", "subset": "body",    "normalize": True},
    {"name": "base_reason_vs_base_answer_body",           "mode1": "base",              "region1": "reasoning", "mode2": "base",              "region2": "answering",  "subset": "body",    "normalize": True},
    {"name": "enhanced_reason_vs_enhanced_answer_body",   "mode1": "reasoning_boosted","region1": "reasoning", "mode2": "reasoning_boosted","region2": "answering", "subset": "body",    "normalize": True},
    {"name": "base_reason_vs_immediate_answer_body",      "mode1": "base",              "region1": "reasoning", "mode2": "immediate_answer","region2": "answering", "subset": "body",    "normalize": True},
    {"name": "enhanced_reason_vs_base_reason_body",       "mode1": "reasoning_boosted","region1": "reasoning", "mode2": "base",              "region2": "reasoning",  "subset": "body",    "normalize": True},
    {"name": "base_reason_vs_base_answer_initial",        "mode1": "base",              "region1": "reasoning", "mode2": "base",              "region2": "answering",  "subset": "initial", "normalize": False},
    {"name": "enhanced_reason_vs_enhanced_answer_initial","mode1": "reasoning_boosted","region1": "reasoning", "mode2": "reasoning_boosted","region2": "answering", "subset": "initial", "normalize": False},
    {"name": "enhanced_reason_vs_base_reason_initial",    "mode1": "reasoning_boosted","region1": "reasoning", "mode2": "base",              "region2": "reasoning",  "subset": "initial", "normalize": False},
    {"name": "base_reason_vs_immediate_answer_initial",   "mode1": "base",              "region1": "reasoning", "mode2": "immediate_answer","region2": "answering", "subset": "initial", "normalize": False},
]

# -----------------------------------------
# Utility Functions
# -----------------------------------------

def load_tensor(prompt_folder: str, mode: str) -> np.ndarray:
    """Load and convert bf16 tensor to float32 numpy."""
    pt_path = os.path.join(prompt_folder, f"{mode}.pt")
    return torch.load(pt_path).float().cpu().numpy()


def load_metadata(prompt_folder: str, mode: str) -> List[Dict[str, Any]]:
    """Load YAML metadata only."""
    yaml_path = os.path.join(prompt_folder, f"{mode}_metadata.yaml")
    with open(yaml_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


def filter_indices(metadata: List[Dict[str, Any]], region: str, subset: str) -> List[int]:
    """Return indices matching region/subset and token_id <128000."""
    idxs = []
    for i, m in enumerate(metadata):
        if m['region'] != region:
            continue
        tok_id = m.get('token_id', -1)
        if tok_id >= 128000:
            continue
        dist = m.get('dist_from_prev_marker')
        if subset == 'body' and dist is not None and dist > 1:
            idxs.append(i)
        elif subset == 'initial' and dist == 1:
            idxs.append(i)
    return idxs


def compute_mean_diff(resid1: np.ndarray, idx1: List[int],
                      resid2: np.ndarray, idx2: List[int]) -> np.ndarray:
    """Mean difference vector (mean1 - mean2) per layer."""
    return resid1[idx1].mean(axis=0) - resid2[idx2].mean(axis=0)


def compute_lda_dir(resid1: np.ndarray, idx1: List[int],
                    resid2: np.ndarray, idx2: List[int]) -> np.ndarray:
    """LDA discriminant direction per layer."""
    n_layers = resid1.shape[1]
    dirs = []
    for l in range(n_layers):
        X1 = resid1[idx1, l, :]
        X2 = resid2[idx2, l, :]
        X = np.concatenate([X1, X2], axis=0)
        y = np.hstack([np.zeros(len(idx1)), np.ones(len(idx2))])
        lda = LinearDiscriminantAnalysis(solver='eigen')
        lda.fit(X, y)
        dirs.append(lda.coef_[0])
    return np.stack(dirs, axis=0)


def token_normalize(resid: np.ndarray,
                    metadata: List[Dict[str, Any]],
                    idx1: List[int],
                    idx2: List[int],
                    valid_ids: Set[int]) -> np.ndarray:
    """Token-normalized residuals: subtract mu_token = (mu1 + mu2)/2 for tokens in both regions."""
    resid_norm = resid.copy()
    positions: Dict[int, Dict[str, List[int]]] = {}
    for idx in idx1:
        tok_id = metadata[idx]['token_id']
        if tok_id in valid_ids:
            positions.setdefault(tok_id, {'idx1': [], 'idx2': []})['idx1'].append(idx)
    for idx in idx2:
        tok_id = metadata[idx]['token_id']
        if tok_id in valid_ids:
            positions.setdefault(tok_id, {'idx1': [], 'idx2': []})['idx2'].append(idx)
    for tok_id, parts in positions.items():
        idxs1, idxs2 = parts['idx1'], parts['idx2']
        if not idxs1 or not idxs2:
            continue
        mu1 = resid_norm[idxs1, :, :].mean(axis=0)
        mu2 = resid_norm[idxs2, :, :].mean(axis=0)
        mu_token = 0.5 * (mu1 + mu2)
        for pos in idxs1 + idxs2:
            resid_norm[pos] -= mu_token
    return resid_norm


def compute_common_tokens(data_root: str,
                          mode1: str, region1: str,
                          mode2: str, region2: str) -> Set[int]:
    """
    Find common token_ids between two (mode, region) pairs across all prompts, filtering >=128000.
    """
    prompt_folders = sorted(
        os.path.join(data_root, d)
        for d in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, d))
    )
    set1, set2 = set(), set()
    for folder in prompt_folders:
        for m in load_metadata(folder, mode1):
            if m['region'] == region1 and m['token_id'] < 128000:
                set1.add(m['token_id'])
        for m in load_metadata(folder, mode2):
            if m['region'] == region2 and m['token_id'] < 128000:
                set2.add(m['token_id'])
    return set1 & set2

# -----------------------------------------
# Analysis Pipeline
# -----------------------------------------

def run_analysis(
    data_root: str,
    out_root: str,
    per_subject: bool = False
):
    """
    Perform overall or per-subject analyses without layer batching.

    Parameters:
      data_root: root folder with prompt subfolders
      out_root: folder to save comparison results
      per_subject: if True, also run analyses per subject
    """
    os.makedirs(out_root, exist_ok=True)
    prompt_folders = sorted(
        os.path.join(data_root, d)
        for d in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, d))
    )
    subject_map: Dict[str, List[str]] = {}
    for folder in prompt_folders:
        subj = os.path.basename(folder).split('_', 1)[0]
        subject_map.setdefault(subj, []).append(folder)
    scopes = {'ALL': prompt_folders}
    if per_subject:
        for subj, fl in subject_map.items(): scopes[subj] = fl

    for scope, folders in scopes.items():
        for comp in COMPARISONS:
            name = comp['name']
            m1, r1 = comp['mode1'], comp['region1']
            m2, r2 = comp['mode2'], comp['region2']
            subset, normalize = comp['subset'], comp['normalize']

            resid_list1, resid_list2 = [], []
            meta_list1, meta_list2 = [], []
            idx1_comb, idx2_comb = [], []
            offset = 0
            for fld in folders:
                meta1 = load_metadata(fld, m1)
                meta2 = load_metadata(fld, m2)
                idx1 = filter_indices(meta1, r1, subset)
                idx2 = filter_indices(meta2, r2, subset)
                if not idx1 or not idx2:
                    continue
                r1_arr = load_tensor(fld, m1)
                r2_arr = load_tensor(fld, m2)
                resid_list1.append(r1_arr)
                resid_list2.append(r2_arr)
                meta_list1.extend(meta1)
                meta_list2.extend(meta2)
                idx1_comb += [offset + i for i in idx1]
                offset += r1_arr.shape[0]
                idx2_comb += [offset + i for i in idx2]
                offset += r2_arr.shape[0]

            if not resid_list1:
                continue
            combined1 = np.concatenate(resid_list1, axis=0)
            combined2 = np.concatenate(resid_list2, axis=0)
            combined_meta = meta_list1 + meta_list2

            if normalize:
                valid_ids = compute_common_tokens(data_root, m1, r1, m2, r2)
                big = np.concatenate([combined1, combined2], axis=0)
                # Mask invalid
                for i, m in enumerate(combined_meta):
                    if m['token_id'] not in valid_ids:
                        big[i] = np.nan
                normed = token_normalize(big, combined_meta,
                                          idx1_comb,
                                          [i + combined1.shape[0] for i in idx2_comb],
                                          valid_ids)
                c1 = normed[:combined1.shape[0]]
                c2 = normed[combined1.shape[0]:]
            else:
                c1, c2 = combined1, combined2

            # Compute across all layers
            mean_diff = compute_mean_diff(c1, idx1_comb, c2, idx2_comb)
            lda_dir   = compute_lda_dir(c1, idx1_comb, c2, idx2_comb)

            outd = os.path.join(out_root, scope, name)
            os.makedirs(outd, exist_ok=True)
            np.save(os.path.join(outd, 'mean_diff.npy'), mean_diff)
            np.save(os.path.join(outd, 'lda_dir.npy'),   lda_dir)

    print("Analysis complete. Results saved in", out_root)


In [6]:
common_tokens = compute_common_tokens("reasoning_resid_data")


In [8]:
[len(v) for k, v in common_tokens.items()]

[1368, 0, 1070]

In [12]:
from tl_tools import *

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
model = load_llama8br1()

Downloading shards: 100%|██████████| 2/2 [00:38<00:00, 19.35s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.31it/s]


Loaded pretrained model meta-llama/Llama-3.1-8B into HookedTransformer


In [14]:
model.tokenizer

LlamaTokenizerFast(name_or_path='deepseek-ai/DeepSeek-R1-Distill-Llama-8B', vocab_size=128000, model_max_length=16384, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<｜begin▁of▁sentence｜>', 'eos_token': '<｜end▁of▁sentence｜>', 'pad_token': '<｜end▁of▁sentence｜>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	128000: AddedToken("<｜begin▁of▁sentence｜>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128001: AddedToken("<｜end▁of▁sentence｜>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128002: AddedToken("<|reserved_special_token_0|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128003: AddedToken("<|reserved_special_token_1|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128004: AddedToken("<|finetune_right_pad_id|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=