In [1]:
# Imports and Setup
import argparse
import os
import torch
from torch import Tensor
from transformer_lens import HookedTransformer
from sae_lens import SAE
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list
from sae_lens.analysis.feature_statistics import (
    get_all_stats_dfs,
    get_W_U_W_dec_stats_df,
)
from sae_lens.analysis.tsea import (
    get_enrichment_df,
    manhattan_plot_enrichment_scores,
    plot_top_k_feature_projections_by_token_and_category,
    get_baby_name_sets,
    get_letter_gene_sets,
    generate_pos_sets,
    get_test_gene_sets,
    get_gene_set_from_regex,
)
from datasets import load_dataset
from dotenv import load_dotenv
import numpy as np
import plotly_express as px
import logging
from typing import Tuple
import json
from log import setup_logging
from tqdm import tqdm  # For progress bars


In [39]:
# Define hyperparameters
args_dict = {
    "layer": 6,  # Example layer number to analyze
    "LLM": "gpt2-small",
    "dataset_path": "/home/ckqsudo/code2024/0dataset/emotional_classify/multiclass-sentiment-analysis",
    "output_dir": "./results",
    "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
    "seed": 42,
    "data_size": 1000,
    "device": "cpu",  # Options: "cpu", "cuda", "mps", "auto"
    "alpha": 100,
    "steer": "pos-neg",  # Options: "pos", "neg", "neu","pos-neg","cot-direct"
    "method": "val_mul",  # Options: "mean", "val_mul"
    "topk_mean": 100,
    "topk_cnt": 100,
    "batch_size": 32
}

# 进行COT相关实验

In [40]:
# Define hyperparameters
args_dict = {
    "layer": 6,  # Example layer number to analyze
    "LLM": "gpt2-small",
    "dataset_path": "/home/ckqsudo/code2024/0dataset/ACL_useful_dataset/math/COT_GSM8k",
    "output_dir": "./results",
    "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
    "seed": 42,
    "data_size": 1000,
    "device": "cpu",  # Options: "cpu", "cuda", "mps", "auto"
    "alpha": 100,
    "steer": "cot-direct",  # Options: "pos", "neg", "neu","pos-neg","cot-direct"
    "method": "val_mul",  # Options: "mean", "val_mul"
    "topk_mean": 100,
    "topk_cnt": 100,
    "batch_size": 32
}

In [41]:
# Configuration and Hyperparameters



# 将字典转换为 argparse.Namespace 对象
args = argparse.Namespace(**args_dict)

# 测试访问属性
print(args.layer)  # 输出: 10
print(args.LLM)  # 输出: gpt2-small
print(args.output_dir)  # 输出: ./results

6
gpt2-small
./results


In [42]:
# Logging Setup
import os
from log import setup_logging
import logging
# Create output directory base path
output_dir_base = os.path.join(
    args.output_dir,
    f"LLM_{args.LLM}_layer_{args.layer}_steer_{args.steer}_alpha_{args.alpha}_cnt_{args.topk_cnt}_mean{args.topk_mean}"
)

# Setup logging
setup_logging(output_dir_base)

# Save hyperparameters
hyperparams = args_dict

# Log hyperparameters
logging.info("Hyperparameters:")
for key, value in hyperparams.items():
    logging.info(f"  {key}: {value}")


2025-01-10 11:24:39,298 [INFO] Logging initialized. Logs will be saved to ./results/LLM_gpt2-small_layer_6_steer_cot-direct_alpha_100_cnt_100_mean100/execution.log
2025-01-10 11:24:39,300 [INFO] Hyperparameters:
2025-01-10 11:24:39,302 [INFO]   layer: 6
2025-01-10 11:24:39,303 [INFO]   LLM: gpt2-small
2025-01-10 11:24:39,304 [INFO]   dataset_path: /home/ckqsudo/code2024/0dataset/ACL_useful_dataset/math/COT_GSM8k
2025-01-10 11:24:39,305 [INFO]   output_dir: ./results
2025-01-10 11:24:39,306 [INFO]   env_path: /home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env
2025-01-10 11:24:39,307 [INFO]   seed: 42
2025-01-10 11:24:39,307 [INFO]   data_size: 1000
2025-01-10 11:24:39,309 [INFO]   device: cpu
2025-01-10 11:24:39,310 [INFO]   alpha: 100
2025-01-10 11:24:39,311 [INFO]   steer: cot-direct
2025-01-10 11:24:39,319 [INFO]   method: val_mul
2025-01-10 11:24:39,319 [INFO]   topk_mean: 100
2025-01-10 11:24:39,320 [INFO]   topk_cnt: 100
2025-01-10 11:24:39,321 [INFO]   batch_size: 3

In [43]:
# Load Environment Variables
 
def load_environment(env_path: str):
    load_dotenv(env_path)
    hf_endpoint = os.getenv('HF_ENDPOINT', 'https://hf-mirror.com')
    logging.info(f"HF_ENDPOINT: {hf_endpoint}")

load_environment(args.env_path)


2025-01-10 11:24:42,985 [INFO] HF_ENDPOINT: https://hf-mirror.com


In [83]:
import re
def load_and_prepare_sentiment_dataset(dataset_path: str, seed: int, num_samples: int):
    logging.info(f"Loading dataset from {dataset_path}")
    dataset = load_dataset(dataset_path)
    dataset["train"] = dataset['train'].shuffle(seed=seed)

    logging.info("Filtering dataset for negative, positive, and neutral samples")
    neg_train_set = dataset['train'].filter(lambda example: example['label'] == 0).select(range(num_samples))
    pos_train_set = dataset['train'].filter(lambda example: example['label'] == 2).select(range(num_samples))
    neu_train_set = dataset['train'].filter(lambda example: example['label'] == 1).select(range(num_samples))

    logging.info(f"Selected {len(neg_train_set)} negative, {len(pos_train_set)} positive, and {len(neu_train_set)} neutral samples")
    return neg_train_set, pos_train_set, neu_train_set
def load_and_prepare_COT_dataset(dataset_path:str,seed:int,num_samples:int):
    logging.info(f"Loading dataset from {dataset_path}")
    dataset = load_dataset(dataset_path)
    dataset["train"] = dataset['train'].shuffle(seed=seed)
    logging.info("Filtering dataset for COT")
    # 定义一个函数来提取答案
    def extract_answer(text):
        # 使用正则表达式提取答案
        match = re.search(r'#### ([-+]?\d*\.?\d+/?\d*)', text)
        if match:
            label=match.group(1)
            return label
        else:
            raise ValueError("Modify your re expression")
    def concat_QA(example,col1,col2,tag):
        combined = f"{example[col1]}{tag}{example[col2]}"  # 用空格拼接
        return combined
    def replace_col(example,col1,target,pattern):
        return example[col1].replace(target,pattern)
    # 应用函数并创建新列
    dataset = dataset.map(lambda example: {'A': extract_answer(example['response'])})
    dataset = dataset.map(lambda example: {'Q+A': concat_QA(example,"prompt","A","")})
    dataset = dataset.map(lambda example: {'Q+COT_A': concat_QA(example,"prompt","response","")})
    dataset = dataset.map(lambda example: {'Q+COT_A': replace_col(example,"Q+COT_A","#### ","")})
    # 查看处理后的数据集
    print("Q+A\n",dataset['train'][103]['Q+A'])
    print("Q+COT_A\n",dataset['train'][103]['Q+COT_A'])
    return dataset
    
    

In [84]:
args.steer

'cot-direct'

In [86]:
# Load and Prepare Dataset


if "neg" in args.steer or "pos" in args.steer:
    neg_train_set, pos_train_set, neu_train_set = load_and_prepare_sentiment_dataset(
        args.dataset_path, args.seed, args.data_size
    )
elif "cot" in args.steer or "COT" in args.steer:
    logging.info("COT "*10)
    all_dataset=load_and_prepare_COT_dataset(
        args.dataset_path, args.seed, args.data_size
    )
else:
    raise ValueError("No Supported")


2025-01-10 14:27:18,864 [INFO] COT COT COT COT COT COT COT COT COT COT 
2025-01-10 14:27:18,866 [INFO] Loading dataset from /home/ckqsudo/code2024/0dataset/ACL_useful_dataset/math/COT_GSM8k
2025-01-10 14:27:19,027 [INFO] Filtering dataset for COT


Q+A
 Q: Cody goes to the store and buys $40 worth of stuff.  The taxes were 5%.  After taxes, he got an $8 discount.  Cody and his friend split the final price equally. How much did Cody pay?
A: 17
Q+COT_A
 Q: Cody goes to the store and buys $40 worth of stuff.  The taxes were 5%.  After taxes, he got an $8 discount.  Cody and his friend split the final price equally. How much did Cody pay?
A: The taxes were 40*.05=$<<40*.05=2>>2.
So the price was 40+2=$<<40+2=42>>42.
He got a discount so the price he paid was 42-8=$<<42-8=34>>34.
Since he paid half his price was 34/2=$<<34/2=17>>17.
17


In [87]:


def compute_latents(sae: SAE, model: HookedTransformer, texts: list, hook_point: str, device: str, batch_size: int) -> list:
    """
    计算 latents，支持批次处理。

    Args:
        sae (SAE): SAE 实例。
        model (HookedTransformer): Transformer 模型实例。
        texts (list): 文本列表。
        hook_point (str): 钩子点名称。
        device (str): 计算设备。
        batch_size (int): 每个批次的大小。

    Returns:
        list: 包含每个批次 latents 的张量列表。
    """
    logging.info("Running model with cache to obtain hidden states")
    batch_latents = []

    # 使用 tqdm 显示进度条
    for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
        batch_texts = texts[i:i + batch_size]
        sv_logits, cache = model.run_with_cache(batch_texts, prepend_bos=True, device=device)
        batch_hidden_states = cache[hook_point]
        logging.info(f"Batch {i // batch_size + 1}: Hidden states shape: {batch_hidden_states.shape}")

        logging.info(f"Encoding hidden states for batch {i // batch_size + 1}")
        # 假设 sae.encode 支持批量编码
        latents = sae.encode(batch_hidden_states)  # 形状: (batch_size, latent_dim)
        batch_latents.append(latents)
        

    logging.info(f"Total batches processed: {len(batch_latents)}")
    return batch_latents

In [88]:

def analyze_latents(batch_latents: Tensor, top_k_mean: int = 100, top_k_cnt: int = 100) -> Tuple[Tensor, Tensor, Tensor]:
    logging.info("Computing non-zero element counts")
    act_cnt = (batch_latents != 0).sum(dim=(0, 1))

    logging.info("Computing sum of non-zero elements")
    nz_sum = torch.where(batch_latents != 0, batch_latents, torch.tensor(0.0, device=batch_latents.device)).sum(dim=(0, 1))

    logging.info("Computing mean of non-zero elements")
    nz_mean = torch.where(act_cnt != 0, nz_sum / act_cnt, torch.tensor(0.0, device=batch_latents.device))

    logging.info("Selecting top-k indices based on nz_mean")
    nz_act_val, nz_val_indices = torch.topk(nz_mean, top_k_mean)
    logging.info(f"Top {top_k_mean} nz_mean values selected.")

    logging.info("Selecting top-k indices based on act_cnt")
    nz_cnt, cnt_indices = torch.topk(act_cnt, top_k_cnt)
    logging.info(f"Top {top_k_cnt} act_cnt values selected.")

    # logging.info("Finding overlapping indices between nz_mean and act_cnt top-k")
    # overlap_mask = torch.isin(nz_val_indices, cnt_indices)
    # overlap_indices = nz_val_indices[overlap_mask]
    # logging.info(f"Number of overlapping indices: {len(overlap_indices)}")
    # overlap_indices=overlap_indices
    return nz_mean, act_cnt, cnt_indices

In [89]:

def compute_steering_vectors(sae: SAE, overlap_indices: Tensor, nz_mean: Tensor, method: str = "val_mul") -> Tensor:
    logging.info(f"Computing steering vectors using method: {method}")
    if method == "mean":
        steering_vectors = torch.mean(sae.W_dec[overlap_indices], dim=0)
    elif method == "val_mul":
        steering_vectors = torch.zeros(sae.W_dec.shape[1], device=sae.W_dec.device)
        for important_idx in overlap_indices:
            steering_vectors += nz_mean[important_idx].item() * sae.W_dec[important_idx]
    else:
        raise ValueError(f"Unknown method: {method}")
    logging.info(f"Steering vectors computed with shape: {steering_vectors.shape}")
    return steering_vectors


In [90]:

def save_results(output_dir: str, nz_mean: Tensor, act_cnt: Tensor, generated_texts: list, hyperparams: dict):
    os.makedirs(output_dir, exist_ok=True)

    # Save nz_mean and act_cnt
    nz_stats_path = os.path.join(output_dir, 'nz_stats.pt')
    logging.info(f"Saving nz_mean and act_cnt to {nz_stats_path}")
    torch.save({
        'nz_mean': nz_mean,
        'act_cnt': act_cnt
    }, nz_stats_path)

    # Save generated texts
    generated_texts_path = os.path.join(output_dir, 'generated_texts.txt')
    logging.info(f"Saving generated texts to {generated_texts_path}")
    with open(generated_texts_path, 'w') as f:
        for text in generated_texts:
            f.write(text + "\n")

    # Save hyperparameters
    hyperparams_path = os.path.join(output_dir, 'hyperparameters.json')
    logging.info(f"Saving hyperparameters to {hyperparams_path}")
    with open(hyperparams_path, 'w') as f:
        json.dump(hyperparams, f, indent=4)

    logging.info("All results saved successfully.")

In [91]:
output_dir_base = os.path.join(
    args.output_dir,
    f"LLM_{args.LLM}_layer_{args.layer}_steer_{args.steer}_alpha_{args.alpha}_cnt_{args.topk_cnt}_mean{args.topk_mean}"
)
output_dir_base

'./results/LLM_gpt2-small_layer_6_steer_cot-direct_alpha_100_cnt_100_mean100'

In [11]:
def load_from_cache():
    cache_exists = False
    cache_file = os.path.join(output_dir_base, 'hyperparameters.json')
    if os.path.exists(cache_file):
        with open(cache_file, 'r') as f:
            cached_data = json.load(f)
        cached_hash = cached_data.get('hyperparams_hash')

    if cache_exists:
        # Load nz_mean and act_cnt from cache
        # nz_stats_path = os.path.join(output_dir_base, 'nz_stats.pt')
        # nz_act = torch.load(nz_stats_path)
        # nz_mean = nz_act['nz_mean']
        # act_cnt = nz_act['act_cnt']
        # overlap_indices = nz_act.get('overlap_indices', None)  # If overlap_indices was saved
        logging.info("load from cache")
    else:
        # overlap_indices = None  # Will be computed later
        logging.info("non cache: "+cache_file)
load_from_cache()

2025-01-10 09:46:29,462 [INFO] non cache: ./results/LLM_gpt2-small_layer_6_steer_pos-neg_alpha_100_cnt_100_mean100/hyperparameters.json


In [12]:
# setup_logging(output_dir_base)

In [92]:
# Save hyperparameters
hyperparams = vars(args)

# Log hyperparameters
logging.info("Hyperparameters:")
for key, value in hyperparams.items():
    logging.info(f"  {key}: {value}")

# Load environment
load_environment(args.env_path)

# Load model and SAE
logging.info(f"Loading model: {args.LLM}")
model = HookedTransformer.from_pretrained(args.LLM, device=args.device)

logging.info(f"Loading SAE for layer {args.layer}")
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id=f"blocks.{args.layer}.hook_resid_pre",
    device=args.device
)





2025-01-10 14:28:22,647 [INFO] Hyperparameters:
2025-01-10 14:28:22,649 [INFO]   layer: 6
2025-01-10 14:28:22,652 [INFO]   LLM: gpt2-small
2025-01-10 14:28:22,653 [INFO]   dataset_path: /home/ckqsudo/code2024/0dataset/ACL_useful_dataset/math/COT_GSM8k
2025-01-10 14:28:22,654 [INFO]   output_dir: ./results
2025-01-10 14:28:22,655 [INFO]   env_path: /home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env
2025-01-10 14:28:22,656 [INFO]   seed: 42
2025-01-10 14:28:22,656 [INFO]   data_size: 1000
2025-01-10 14:28:22,658 [INFO]   device: cpu
2025-01-10 14:28:22,659 [INFO]   alpha: 100
2025-01-10 14:28:22,660 [INFO]   steer: cot-direct
2025-01-10 14:28:22,661 [INFO]   method: val_mul
2025-01-10 14:28:22,662 [INFO]   topk_mean: 100
2025-01-10 14:28:22,663 [INFO]   topk_cnt: 100
2025-01-10 14:28:22,664 [INFO]   batch_size: 32
2025-01-10 14:28:22,667 [INFO] HF_ENDPOINT: https://hf-mirror.com
2025-01-10 14:28:22,668 [INFO] Loading model: gpt2-small
2025-01-10 14:29:03,517 [INFO] Loading 

Loaded pretrained model gpt2-small into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [94]:
# Load dataset
all_dataset = load_and_prepare_COT_dataset(
    args.dataset_path, args.seed, args.data_size
)

2025-01-10 14:31:32,014 [INFO] Loading dataset from /home/ckqsudo/code2024/0dataset/ACL_useful_dataset/math/COT_GSM8k
2025-01-10 14:31:32,065 [INFO] Filtering dataset for COT


Q+A
 Q: Cody goes to the store and buys $40 worth of stuff.  The taxes were 5%.  After taxes, he got an $8 discount.  Cody and his friend split the final price equally. How much did Cody pay?
A: 17
Q+COT_A
 Q: Cody goes to the store and buys $40 worth of stuff.  The taxes were 5%.  After taxes, he got an $8 discount.  Cody and his friend split the final price equally. How much did Cody pay?
A: The taxes were 40*.05=$<<40*.05=2>>2.
So the price was 40+2=$<<40+2=42>>42.
He got a discount so the price he paid was 42-8=$<<42-8=34>>34.
Since he paid half his price was 34/2=$<<34/2=17>>17.
17


In [95]:
args.steer

'cot-direct'

In [15]:
def analyze_latents(batch_latents: Tensor, top_k_mean: int = 100, top_k_cnt: int = 100) -> Tuple[Tensor, Tensor, Tensor]:
    logging.info("Computing non-zero element counts")
    act_cnt = (batch_latents != 0).sum(dim=(0, 1))

    logging.info("Computing sum of non-zero elements")
    nz_sum = torch.where(batch_latents != 0, batch_latents, torch.tensor(0.0, device=batch_latents.device)).sum(dim=(0, 1))

    logging.info("Computing mean of non-zero elements")
    nz_mean = torch.where(act_cnt != 0, nz_sum / act_cnt, torch.tensor(0.0, device=batch_latents.device))

    logging.info("Selecting top-k indices based on nz_mean")
    nz_act_val, nz_val_indices = torch.topk(nz_mean, top_k_mean)
    logging.info(f"Top {top_k_mean} nz_mean values selected.")

    logging.info("Selecting top-k indices based on act_cnt")
    nz_cnt, cnt_indices = torch.topk(act_cnt, top_k_cnt)
    logging.info(f"Top {top_k_cnt} act_cnt values selected.")

    # logging.info("Finding overlapping indices between nz_mean and act_cnt top-k")
    # overlap_mask = torch.isin(nz_val_indices, cnt_indices)
    # overlap_indices = nz_val_indices[overlap_mask]
    # logging.info(f"Number of overlapping indices: {len(overlap_indices)}")
    # overlap_indices=overlap_indices
    return nz_mean, act_cnt, cnt_indices

In [104]:
# Select a dataset steer based on steering preference

def get_activation_by_steer(texts:list):

    hook_point = sae.cfg.hook_name

    # Compute latents with batch processing
    batch_latents = compute_latents(sae, model, texts, hook_point, args.device, args.batch_size)
    # 计算第二个维度的最大值
    max_dim1 = max(latent.shape[1] for latent in batch_latents)  # 第二个维度的最大值
    logging.info(f"最大长度:{max_dim1}")
    # 对每个 Tensor 进行填充（仅填充第二个维度）
    padded_latents_right = [
        torch.nn.functional.pad(latent, (0, 0, 0, max_dim1 - latent.size(1)), "constant", 0)
        for latent in batch_latents
    ]

    batch_latents_concatenated = torch.cat(padded_latents_right, dim=0)
    logging.info(f"Concatenated batch latents shape: {batch_latents_concatenated.shape}")

    # Analyze latents
    nz_mean, act_cnt, overlap_indices = analyze_latents(batch_latents_concatenated, top_k_mean=args.topk_mean, top_k_cnt=args.topk_cnt)
    return {"nz_mean":nz_mean,"nz_cnt":act_cnt}

In [105]:
args.steer=args.steer.lower()

In [119]:
all_dataset["train"]["Q+A"][:args.data_size][193]

'Q: Ten percent less than twice the total number of students present in the science class yesterday have attended class today. If there were 70 students in the class yesterday and 30 students are absent today, calculate the number of students registered for the course.\nA: 156'

In [109]:
steer_info={}
if "cot" in args.steer:
    if args.steer in "cot-direct":
        texts=all_dataset["train"]["Q+A"][:args.data_size]
        print(type(texts))
        steer_info["direct"]=get_activation_by_steer(texts)
        
        texts=all_dataset["train"]["Q+COT_A"][:args.data_size]
        print(type(texts))
        steer_info["cot"]=get_activation_by_steer(texts)
        # print(texts[123])
    else:
        raise ValueError("????")

2025-01-10 14:50:52,557 [INFO] Running model with cache to obtain hidden states


<class 'list'>


Processing batches:   0%|          | 0/32 [00:00<?, ?it/s]2025-01-10 14:50:53,432 [INFO] Batch 1: Hidden states shape: torch.Size([32, 155, 768])
2025-01-10 14:50:53,433 [INFO] Encoding hidden states for batch 1
Processing batches:   3%|▎         | 1/32 [00:01<00:31,  1.02s/it]2025-01-10 14:50:54,501 [INFO] Batch 2: Hidden states shape: torch.Size([32, 104, 768])
2025-01-10 14:50:54,503 [INFO] Encoding hidden states for batch 2
Processing batches:   6%|▋         | 2/32 [00:02<00:30,  1.03s/it]2025-01-10 14:50:55,365 [INFO] Batch 3: Hidden states shape: torch.Size([32, 115, 768])
2025-01-10 14:50:55,367 [INFO] Encoding hidden states for batch 3
Processing batches:   9%|▉         | 3/32 [00:02<00:27,  1.05it/s]2025-01-10 14:50:56,358 [INFO] Batch 4: Hidden states shape: torch.Size([32, 121, 768])
2025-01-10 14:50:56,360 [INFO] Encoding hidden states for batch 4
Processing batches:  12%|█▎        | 4/32 [00:03<00:27,  1.03it/s]2025-01-10 14:50:57,033 [INFO] Batch 5: Hidden states shape: t

<class 'list'>


Processing batches:   0%|          | 0/32 [00:00<?, ?it/s]2025-01-10 14:51:29,259 [INFO] Batch 1: Hidden states shape: torch.Size([32, 394, 768])
2025-01-10 14:51:29,259 [INFO] Encoding hidden states for batch 1
Processing batches:   3%|▎         | 1/32 [00:03<01:56,  3.74s/it]2025-01-10 14:51:32,466 [INFO] Batch 2: Hidden states shape: torch.Size([32, 266, 768])
2025-01-10 14:51:32,468 [INFO] Encoding hidden states for batch 2
Processing batches:   6%|▋         | 2/32 [00:06<01:41,  3.39s/it]2025-01-10 14:51:35,617 [INFO] Batch 3: Hidden states shape: torch.Size([32, 341, 768])
2025-01-10 14:51:35,619 [INFO] Encoding hidden states for batch 3
Processing batches:   9%|▉         | 3/32 [00:10<01:35,  3.29s/it]2025-01-10 14:51:39,830 [INFO] Batch 4: Hidden states shape: torch.Size([32, 379, 768])
2025-01-10 14:51:39,832 [INFO] Encoding hidden states for batch 4
Processing batches:  12%|█▎        | 4/32 [00:14<01:42,  3.68s/it]2025-01-10 14:51:43,079 [INFO] Batch 5: Hidden states shape: t

In [None]:

for steer_key in ["pos","neu","neg"]:
    if steer_key == "pos":
        selected_set = pos_train_set
    elif steer_key == "neg":
        selected_set = neg_train_set
    elif steer_key=="neu":
        selected_set = neu_train_set

    texts = selected_set["text"][:args.data_size]
    a=get_activation_by_steer(texts)
    steer_info[steer_key]=a

2025-01-10 09:47:42,058 [INFO] Running model with cache to obtain hidden states


Processing batches:   0%|          | 0/32 [00:00<?, ?it/s]2025-01-10 09:47:42,885 [INFO] Batch 1: Hidden states shape: torch.Size([32, 99, 768])
2025-01-10 09:47:42,886 [INFO] Encoding hidden states for batch 1
Processing batches:   3%|▎         | 1/32 [00:00<00:28,  1.08it/s]2025-01-10 09:47:43,898 [INFO] Batch 2: Hidden states shape: torch.Size([32, 94, 768])
2025-01-10 09:47:43,899 [INFO] Encoding hidden states for batch 2
Processing batches:   6%|▋         | 2/32 [00:01<00:29,  1.03it/s]2025-01-10 09:47:44,577 [INFO] Batch 3: Hidden states shape: torch.Size([32, 84, 768])
2025-01-10 09:47:44,578 [INFO] Encoding hidden states for batch 3
Processing batches:   9%|▉         | 3/32 [00:02<00:23,  1.21it/s]2025-01-10 09:47:45,452 [INFO] Batch 4: Hidden states shape: torch.Size([32, 116, 768])
2025-01-10 09:47:45,453 [INFO] Encoding hidden states for batch 4
Processing batches:  12%|█▎        | 4/32 [00:03<00:23,  1.19it/s]2025-01-10 09:47:46,178 [INFO] Batch 5: Hidden states shape: torc

In [122]:
steer_info["pos"]

KeyError: 'pos'

In [121]:
steer_info["pos"]["nz_mean"].shape

KeyError: 'pos'

In [20]:
steer_info["dif_neg_pos"]={"steer":"dif_neg_pos","nz_cnt":steer_info["pos"]["nz_cnt"]-steer_info["neg"]["nz_cnt"],"nz_mean":steer_info["pos"]["nz_mean"]-steer_info["neg"]["nz_mean"]}

In [21]:
steer_info["dif_neg_pos_relu"]={"nz_cnt":torch.relu(steer_info["pos"]["nz_cnt"]-steer_info["neg"]["nz_cnt"]),"nz_mean":torch.relu(steer_info["pos"]["nz_mean"]-steer_info["neg"]["nz_mean"])}

In [22]:
steer_info["dif_neg_pos_relu"],steer_info["dif_neg_pos"]

({'steer': 'dif_neg_pos',
  'nz_cnt': tensor([1, 0, 0,  ..., 0, 7, 2]),
  'nz_mean': tensor([6.4307, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 3.2251],
         grad_fn=<ReluBackward0>)},
 {'steer': 'dif_neg_pos',
  'nz_cnt': tensor([  1,  -2,   0,  ..., -58,   7,   2]),
  'nz_mean': tensor([ 6.4307, -0.6795,  0.0000,  ..., -0.2715, -0.0391,  3.2251],
         grad_fn=<SubBackward0>)})

In [110]:
a="cot"
b="direct"
steer_info[f"dif_{a}-{b}_relu"]={"nz_cnt":torch.relu(steer_info[a]["nz_cnt"]-steer_info[b]["nz_cnt"]),"nz_mean":torch.relu(steer_info[a]["nz_mean"]-steer_info[b]["nz_mean"])}

In [111]:
top_k=100
steering_vectors,steer_indices=torch.topk(steer_info[f"dif_{a}-{b}_relu"]["nz_cnt"],top_k)

In [112]:
steering_vectors,steer_indices,

(tensor([159486, 141545, 134493, 130902, 125546, 123195, 122745, 120984, 119869,
         118881, 118222, 118150, 117965, 117340, 116948, 116616, 116615, 114701,
         113613, 113599, 113494, 113279, 112218, 109645, 108840, 107080, 106989,
         106733, 106677, 106527, 106160, 106044, 105674, 105443, 104825, 104224,
         103736, 103527, 103445, 102129, 102107, 101660, 101039, 101014, 100969,
         100056,  99783,  99311,  99142,  98228,  98090,  97515,  97387,  96717,
          96412,  96041,  95767,  95486,  95336,  95266,  95083,  94963,  94879,
          94797,  94783,  94767,  94746,  94539,  94042,  93070,  92401,  91442,
          90964,  90813,  90755,  90454,  90389,  90118,  89845,  89711,  89692,
          89137,  88551,  88369,  88346,  88198,  88191,  88042,  87530,  86853,
          86783,  86318,  86191,  85932,  85874,  85606,  85542,  85445,  85293,
          85271]),
 tensor([14243, 19222, 18709,  7425, 17126,  8487, 10045, 14332, 14817, 22228,
          8

In [114]:
steer_info[f"dif_{a}-{b}_relu"]["nz_mean"][steer_indices]

tensor([ 0.9378,  0.6677,  3.3172,  3.7422,  1.9544,  9.9255,  4.9352,  2.5919,
         2.7637,  4.3057,  4.6804,  1.9646,  1.4629,  3.8050,  1.4482,  4.0289,
         4.0796,  1.7954,  0.2430,  0.4193,  1.5127,  1.7011,  3.2796,  0.0000,
         0.9136,  3.4771,  0.8742,  1.9859,  0.0000,  0.7164,  5.8357,  0.4058,
         2.7185,  0.6172,  2.9650,  0.9805,  3.4497,  0.5389,  1.3578,  1.4940,
         1.8432,  1.4488,  1.2316,  4.4058,  3.7352,  0.9454,  2.3580,  1.8529,
         1.7326,  1.5439,  1.7971,  1.7733,  3.1612,  1.3640,  0.4026,  0.1367,
         1.2794,  1.6125,  0.2939,  2.7150,  2.0939,  3.8241,  0.8661,  0.3859,
         0.0000,  4.1829,  0.4095,  0.9616,  3.0770,  4.3298, 13.4603,  1.4696,
         0.0000,  0.8026,  0.0000,  0.0000,  4.0572,  0.0000,  0.5797,  0.0000,
         0.9628,  3.5509,  1.0189,  0.4013,  0.3832,  0.3809,  0.0000,  0.0000,
        10.4950,  1.2725,  2.1783,  1.0589,  0.0000,  1.5716,  3.4318,  2.2768,
         0.0000,  0.1837,  0.0000, 14.08

In [25]:
steer_info["dif_neg_pos"]["nz_mean"][steer_indices]

tensor([ 2.3788e-02,  8.1224e-02,  2.3156e-02,  1.8783e-02,  1.1774e-01,
         3.7877e-02,  6.3553e-02,  5.3482e-02,  5.4047e-02,  1.2716e-02,
         3.0568e-02,  6.7365e-02,  3.1710e-03, -8.7272e-03, -3.8923e-01,
        -2.1414e-02, -1.4824e-01, -1.7109e-02,  2.1162e-02,  2.5611e-02,
         2.4913e-02, -6.5165e-02,  1.8257e-03, -1.5588e-01,  3.0566e-02,
        -3.8024e-01,  1.5130e-03,  2.4326e-03,  1.4322e-02, -5.0976e-02,
         2.6311e-02,  2.0264e-01, -4.5139e-02, -3.4792e-01,  3.1174e-02,
         9.3668e-02,  2.4175e-01,  2.1948e-01, -1.4558e-01, -1.8547e-02,
         3.1676e-01,  1.7079e+00,  5.4160e-01,  7.8669e-02,  2.5975e-01,
         1.2162e-02,  2.9813e-01, -1.2766e-02,  3.7756e-02,  3.6043e-01,
         9.2443e-02,  3.0181e-01,  4.0478e-02,  9.6413e-02,  2.1965e-01,
         1.1340e+00,  1.2302e-01,  2.1960e-01, -1.6293e-02, -3.9671e-02,
         2.1330e-02,  5.8879e-01,  4.3106e-01,  1.7933e-01,  6.7136e-01,
         3.6226e-01,  4.2522e-03, -3.4991e-02, -1.1

In [115]:
steer_indices[0].item()

14243

In [117]:

def compute_steering_vectors(sae: SAE, indices: Tensor, nz_mean_val: Tensor, method: str = "val_mul") -> Tensor:
    logging.info(f"Computing steering vectors using method: {method}")
    if method == "mean":
        steering_vectors = torch.mean(sae.W_dec[indices], dim=0)
    elif method == "val_mul":
        steering_vectors = torch.zeros(sae.W_dec.shape[1], device=sae.W_dec.device)
        for idx in indices:
            steering_vectors += nz_mean_val[idx].item() * sae.W_dec[idx]
    else:
        raise ValueError(f"Unknown method: {method}")
    logging.info(f"Steering vectors computed with shape: {steering_vectors.shape}")
    return steering_vectors
steering_vectors=compute_steering_vectors(sae,indices=steer_indices,nz_mean_val=steer_info[f"dif_{a}-{b}_relu"]["nz_mean"],method="val_mul")

2025-01-10 14:56:00,391 [INFO] Computing steering vectors using method: val_mul
2025-01-10 14:56:00,402 [INFO] Steering vectors computed with shape: torch.Size([768])


In [64]:
# model.to_tokens("<|endoftext|>")

tensor([[50256, 50256]])

In [130]:
model.tokenizer.eos_token

'<|endoftext|>'

In [128]:
model.tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}
)

In [136]:
# # Define steering hook
# steering_on = True  # This will be toggled in run_generate
# alpha = args.alpha
# method = args.method  # Store method for clarity

def steering_hook(resid_pre, hook):
    if resid_pre.shape[1] == 1:
        return
    if steering_on:
        resid_pre += alpha * steering_vectors

def hooked_generate(prompt_batch, fwd_hooks=[], seed=None, **kwargs):
    if seed is not None:
        torch.manual_seed(seed)
    with model.hooks(fwd_hooks=fwd_hooks):
        tokenized = model.to_tokens(prompt_batch)
        result = model.generate(
            stop_at_eos=True, # 设置为 True，遇到 <|endoftext|> 时停止model.tokenizer.eos_token= <|endoftext|>
            input=tokenized,
            # max_new_tokens=50, # 最大生成 token 数
            do_sample=True,# 是否使用采样
            **kwargs,
        )
    return result

def run_generate(example_prompts: str, sampling_kwargs: dict) -> list:
    model.reset_hooks()
    editing_hooks = [(f"blocks.{args.layer}.hook_resid_post", steering_hook)]
    res = hooked_generate(
        example_prompts * 3,
        fwd_hooks=editing_hooks if steering_on else [],
        seed=args.seed,
        **sampling_kwargs
    )


    res_str = model.to_string(res[:, 1:])
    # generated_texts = res_str
    for idx, text in enumerate(res_str):
        logging.info(f"Generated Text: {idx+1}: {text}")
        
    return res_str

# Define sampling parameters
sampling_kwargs = dict(temperature=1.0, top_p=0.7, freq_penalty=1)

# Example prompt from the selected set
example_prompt = all_dataset["val"]["prompt"][:3]
print( all_dataset["val"]["prompt"][:3],all_dataset["val"]["A"][:3])
logging.info(f"Example prompt: {example_prompt}")

# Generate without steering
steering_on = False
alpha = 0
logging.info("Generating texts **without** steering... ")
generated_texts_no_steer = run_generate(example_prompt, sampling_kwargs)

# Generate with steering
steering_on = True
alpha = 20
logging.info("Generating texts with steering... Target")
generated_texts_with_steer = run_generate(example_prompt, sampling_kwargs)

# Combine generated texts
all_generated_texts = generated_texts_no_steer + generated_texts_with_steer



2025-01-10 15:37:45,388 [INFO] Example prompt: ['Q: The mayor commissioned two artists to paint 50 murals around the city. Once the work was completed, Celina was paid $1,000 more than 4 times the amount Diego got. If the mayor paid the two a total of $50,000, how much did Diego get?\nA: ', "Q: Adrianna has 10 pieces of gum to share with her friends. There wasn't enough gum for all her friends, so she went to the store to get 3 more pieces of gum. She gave out gum to 11 friends. How many pieces of gum does Adrianna have now?\nA: ", 'Q: Paul needed to buy some new clothes for work.  He had a 10% off coupon that he could use on his entire purchase after any other discounts.  Paul bought 4 dress shirts at $15.00 apiece, 2 pairs of pants that each cost $40.00.  He found a suit for $150.00 and 2 sweaters for $30.00 each.  When he got to the register, the clerk told him that the store was offering 20% off of everything in the store.  After the discounts and the coupon, how much did Paul spen

['Q: The mayor commissioned two artists to paint 50 murals around the city. Once the work was completed, Celina was paid $1,000 more than 4 times the amount Diego got. If the mayor paid the two a total of $50,000, how much did Diego get?\nA: ', "Q: Adrianna has 10 pieces of gum to share with her friends. There wasn't enough gum for all her friends, so she went to the store to get 3 more pieces of gum. She gave out gum to 11 friends. How many pieces of gum does Adrianna have now?\nA: ", 'Q: Paul needed to buy some new clothes for work.  He had a 10% off coupon that he could use on his entire purchase after any other discounts.  Paul bought 4 dress shirts at $15.00 apiece, 2 pairs of pants that each cost $40.00.  He found a suit for $150.00 and 2 sweaters for $30.00 each.  When he got to the register, the clerk told him that the store was offering 20% off of everything in the store.  After the discounts and the coupon, how much did Paul spend on his new clothes?\nA: '] ['9800', '2', '252

  0%|          | 0/10 [00:00<?, ?it/s]

2025-01-10 15:37:46,041 [INFO] Generated Text: 1: Q: The mayor commissioned two artists to paint 50 murals around the city. Once the work was completed, Celina was paid $1,000 more than 4 times the amount Diego got. If the mayor paid the two a total of $50,000, how much did Diego get?
A: <|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endofte

  0%|          | 0/10 [00:00<?, ?it/s]

2025-01-10 15:37:46,630 [INFO] Generated Text: 1: Q: The mayor commissioned two artists to paint 50 murals around the city. Once the work was completed, Celina was paid $1,000 more than 4 times the amount Diego got. If the mayor paid the two a total of $50,000, how much did Diego get?
A: <|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endofte

In [1]:
example_prompt = all_dataset["val"]["prompt"][:10]

NameError: name 'all_dataset' is not defined

In [None]:
# Save results
save_results(
    output_dir=output_dir_base,
    nz_mean=nz_mean,
    act_cnt=act_cnt,
    generated_texts=all_generated_texts,
    hyperparams=hyperparams
)