In [1]:
# Imports and Setup
import argparse
import os
import torch
from torch import Tensor
from transformer_lens import HookedTransformer
from sae_lens import SAE
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list
from sae_lens.analysis.feature_statistics import (
    get_all_stats_dfs,
    get_W_U_W_dec_stats_df,
)
from sae_lens.analysis.tsea import (
    get_enrichment_df,
    manhattan_plot_enrichment_scores,
    plot_top_k_feature_projections_by_token_and_category,
    get_baby_name_sets,
    get_letter_gene_sets,
    generate_pos_sets,
    get_test_gene_sets,
    get_gene_set_from_regex,
)
from datasets import load_dataset
from dotenv import load_dotenv
import numpy as np
import plotly_express as px
import logging
from typing import Tuple
import json
from log import setup_logging
from tqdm import tqdm  # For progress bars


In [2]:
# Define hyperparameters
args_dict = {
    "layer": 6,  # Example layer number to analyze
    "LLM": "gpt2-small",
    "dataset_path": "/home/ckqsudo/code2024/0dataset/emotional_classify/multiclass-sentiment-analysis",
    "output_dir": "./results",
    "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
    "seed": 42,
    "data_size": 1000,
    "device": "cpu",  # Options: "cpu", "cuda", "mps", "auto"
    "alpha": 100,
    "steer": "pos-neg",  # Options: "pos", "neg", "neu","pos-neg","cot-direct"
    "method": "val_mul",  # Options: "mean", "val_mul"
    "topk_mean": 100,
    "topk_cnt": 100,
    "batch_size": 32
}

# 进行COT相关实验

In [3]:
# Define hyperparameters
args_dict = {
    "layer": 6,  # Example layer number to analyze
    "LLM": "gpt2-small",
    "dataset_path": "/home/ckqsudo/code2024/0dataset/ACL_useful_dataset/math/COT_GSM8k",
    "output_dir": "./results",
    "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
    "seed": 42,
    "data_size": 1000,
    "device": "cpu",  # Options: "cpu", "cuda", "mps", "auto"
    "alpha": 100,
    "steer": "cot-direct",  # Options: "pos", "neg", "neu","pos-neg","cot-direct"
    "method": "val_mul",  # Options: "mean", "val_mul"
    "topk_mean": 100,
    "topk_cnt": 100,
    "batch_size": 32
}

# 进行礼貌实验
/home/ckqsudo/code2024/0dataset/ACL_useful_dataset/style_transfer/politeness-corpus

In [29]:
# Define hyperparameters
args_dict = {
    "layer": 6,  # Example layer number to analyze
    "LLM": "gpt2-small",
    "dataset_path": "/home/ckqsudo/code2024/0dataset/ACL_useful_dataset/style_transfer/politeness-corpus",
    "output_dir": "./results",
    "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
    "seed": 42,
    "data_size": 1000,
    "device": "cpu",  # Options: "cpu", "cuda", "mps", "auto"
    "alpha": 100,
    "steer": "polite-impolite",  # Options: "pos", "neg", "neu","pos-neg","cot-direct"
    "method": "val_mul",  # Options: "mean", "val_mul"
    "topk_mean": 100,
    "topk_cnt": 100,
    "batch_size": 32
}

In [149]:
# 进行礼貌实验

In [31]:
# Configuration and Hyperparameters
# 将字典转换为 argparse.Namespace 对象
args = argparse.Namespace(**args_dict)
# 测试访问属性
print(args.layer)  # 输出: 10
print(args.LLM)  # 输出: gpt2-small
print(args.output_dir)  # 输出: ./results
print(args.steer)

6
gpt2-small
./results
polite-impolite


In [32]:
# Logging Setup
import os
from log import setup_logging
import logging
# Create output directory base path
output_dir_base = os.path.join(
    args.output_dir,
    f"LLM_{args.LLM}_layer_{args.layer}_steer_{args.steer}_alpha_{args.alpha}_cnt_{args.topk_cnt}_mean{args.topk_mean}"
)

# Setup logging
setup_logging(output_dir_base)

# Save hyperparameters
hyperparams = args_dict

# Log hyperparameters
logging.info("Hyperparameters:")
for key, value in hyperparams.items():
    logging.info(f"  {key}: {value}")


2025-01-12 10:32:43,908 [INFO] Logging initialized. Logs will be saved to ./results/LLM_gpt2-small_layer_6_steer_polite-impolite_alpha_100_cnt_100_mean100/execution.log
2025-01-12 10:32:43,910 [INFO] Hyperparameters:
2025-01-12 10:32:43,911 [INFO]   layer: 6
2025-01-12 10:32:43,920 [INFO]   LLM: gpt2-small
2025-01-12 10:32:43,921 [INFO]   dataset_path: /home/ckqsudo/code2024/0dataset/ACL_useful_dataset/style_transfer/politeness-corpus
2025-01-12 10:32:43,922 [INFO]   output_dir: ./results
2025-01-12 10:32:43,923 [INFO]   env_path: /home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env
2025-01-12 10:32:43,923 [INFO]   seed: 42
2025-01-12 10:32:43,924 [INFO]   data_size: 1000
2025-01-12 10:32:43,925 [INFO]   device: cpu
2025-01-12 10:32:43,926 [INFO]   alpha: 100
2025-01-12 10:32:43,927 [INFO]   steer: polite-impolite
2025-01-12 10:32:43,927 [INFO]   method: val_mul
2025-01-12 10:32:43,928 [INFO]   topk_mean: 100
2025-01-12 10:32:43,929 [INFO]   topk_cnt: 100
2025-01-12 10:32:4

In [33]:
# Load Environment Variables
 
def load_environment(env_path: str):
    load_dotenv(env_path)
    hf_endpoint = os.getenv('HF_ENDPOINT', 'https://hf-mirror.com')
    logging.info(f"HF_ENDPOINT: {hf_endpoint}")

load_environment(args.env_path)


2025-01-12 10:32:47,612 [INFO] HF_ENDPOINT: https://hf-mirror.com


In [34]:
import re
def load_and_prepare_sentiment_dataset(dataset_path: str, seed: int, num_samples: int):
    logging.info(f"Loading dataset from {dataset_path}")
    dataset = load_dataset(dataset_path)
    dataset["train"] = dataset['train'].shuffle(seed=seed)

    logging.info("Filtering dataset for negative, positive, and neutral samples")
    neg_train_set = dataset['train'].filter(lambda example: example['label'] == 0).select(range(num_samples))
    pos_train_set = dataset['train'].filter(lambda example: example['label'] == 2).select(range(num_samples))
    neu_train_set = dataset['train'].filter(lambda example: example['label'] == 1).select(range(num_samples))

    logging.info(f"Selected {len(neg_train_set)} negative, {len(pos_train_set)} positive, and {len(neu_train_set)} neutral samples")
    return neg_train_set, pos_train_set, neu_train_set
def load_and_prepare_COT_dataset(dataset_path:str,seed:int,num_samples:int):
    logging.info(f"Loading dataset from {dataset_path}")
    dataset = load_dataset(dataset_path)
    dataset["train"] = dataset['train'].shuffle(seed=seed)
    logging.info("Filtering dataset for COT")
    # 定义一个函数来提取答案
    def extract_answer(text):
        # 使用正则表达式提取答案
        match = re.search(r'#### ([-+]?\d*\.?\d+/?\d*)', text)
        if match:
            label=match.group(1)
            return label
        else:
            raise ValueError("Modify your re expression")
    def concat_QA(example,col1,col2,tag):
        combined = f"{example[col1]}{tag}{example[col2]}"  # 用空格拼接
        return combined
    def replace_col(example,col1,target,pattern):
        return example[col1].replace(target,pattern)
    # 应用函数并创建新列
    dataset = dataset.map(lambda example: {'A': extract_answer(example['response'])})
    dataset = dataset.map(lambda example: {'Q+A': concat_QA(example,"prompt","A","")})
    dataset = dataset.map(lambda example: {'Q+COT_A': concat_QA(example,"prompt","response","")})
    dataset = dataset.map(lambda example: {'Q+COT_A': replace_col(example,"Q+COT_A","#### ","")})
    # 查看处理后的数据集
    print("Q+A\n",dataset['train'][103]['Q+A'])
    print("Q+COT_A\n",dataset['train'][103]['Q+COT_A'])
    return dataset

def load_and_prepare_polite_dataset(dataset_path: str, seed: int, num_samples: int):
    """frfede/politeness-corpus"""
    logging.info(f"Loading dataset from {dataset_path}")
    dataset = load_dataset(dataset_path)
    dataset["train"] = dataset['train'].shuffle(seed=seed)

    logging.info("Filtering dataset for impolite, polite, and neutral samples")
    impolite_train_set = dataset['train'].filter(lambda example: example['label'] == 0).select(range(num_samples))
    polite_train_set = dataset['train'].filter(lambda example: example['label'] == 2).select(range(num_samples))
    neu_train_set = dataset['train'].filter(lambda example: example['label'] == 1).select(range(num_samples))

    logging.info(f"Selected {len(impolite_train_set)} negative, {len(polite_train_set)} positive, and {len(neu_train_set)} neutral samples")
    return impolite_train_set, polite_train_set, neu_train_set
    

In [35]:
args.steer

'polite-impolite'

In [36]:
# Load and Prepare Dataset


if "neg" in args.steer or "pos" in args.steer:
    neg_train_set, pos_train_set, neu_train_set = load_and_prepare_sentiment_dataset(
        args.dataset_path, args.seed, args.data_size
    )
elif "cot" in args.steer or "COT" in args.steer:
    logging.info("COT "*10)
    all_dataset=load_and_prepare_COT_dataset(
        args.dataset_path, args.seed, args.data_size
    )
elif "polite" in args.steer:
    logging.info("polite"*10)
    neg_train_set, pos_train_set, neu_train_set=load_and_prepare_polite_dataset(args.dataset_path, args.seed, args.data_size)
else:
    raise ValueError("No Supported")


2025-01-12 10:32:55,617 [INFO] politepolitepolitepolitepolitepolitepolitepolitepolitepolite
2025-01-12 10:32:55,620 [INFO] Loading dataset from /home/ckqsudo/code2024/0dataset/ACL_useful_dataset/style_transfer/politeness-corpus
2025-01-12 10:32:55,688 [INFO] Filtering dataset for impolite, polite, and neutral samples
2025-01-12 10:32:55,716 [INFO] Selected 1000 negative, 1000 positive, and 1000 neutral samples


In [37]:
assert neg_train_set[10]!=pos_train_set[10]

In [157]:
pos_train_set[10],neg_train_set[10]

({'text': 'How many is "several different"? How big are |L| and |S|, relatively/absolutely? ~~~~',
  'label': 2,
  'sentiment': 'Polite'},
 {'text': "Since each polygon would overlay a polygon of the exact same shape, you really won't glean any benefit (that I can think of) from having multiple shapes. Are you doing this for something other than symbolizing features on a map?",
  'label': 0,
  'sentiment': 'Impolite'})

In [38]:


def compute_latents(sae: SAE, model: HookedTransformer, texts: list, hook_point: str, device: str, batch_size: int) -> list:
    """
    计算 latents，支持批次处理。

    Args:
        sae (SAE): SAE 实例。
        model (HookedTransformer): Transformer 模型实例。
        texts (list): 文本列表。
        hook_point (str): 钩子点名称。
        device (str): 计算设备。
        batch_size (int): 每个批次的大小。

    Returns:
        list: 包含每个批次 latents 的张量列表。
    """
    logging.info("Running model with cache to obtain hidden states")
    batch_latents = []

    # 使用 tqdm 显示进度条
    for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
        batch_texts = texts[i:i + batch_size]
        sv_logits, cache = model.run_with_cache(batch_texts, prepend_bos=False, device=device)
        batch_hidden_states = cache[hook_point]
        logging.info(f"Batch {i // batch_size + 1}: Hidden states shape: {batch_hidden_states.shape}")

        logging.info(f"Encoding hidden states for batch {i // batch_size + 1}")
        # 假设 sae.encode 支持批量编码
        latents = sae.encode(batch_hidden_states)  # 形状: (batch_size, latent_dim)
        batch_latents.append(latents)
        

    logging.info(f"Total batches processed: {len(batch_latents)}")
    return batch_latents

In [39]:

def analyze_latents(batch_latents: Tensor, top_k_mean: int = 100, top_k_cnt: int = 100) -> Tuple[Tensor, Tensor, Tensor]:
    logging.info("Computing non-zero element counts")
    act_cnt = (batch_latents != 0).sum(dim=(0, 1))

    logging.info("Computing sum of non-zero elements")
    nz_sum = torch.where(batch_latents != 0, batch_latents, torch.tensor(0.0, device=batch_latents.device)).sum(dim=(0, 1))

    logging.info("Computing mean of non-zero elements")
    nz_mean = torch.where(act_cnt != 0, nz_sum / act_cnt, torch.tensor(0.0, device=batch_latents.device))

    logging.info("Selecting top-k indices based on nz_mean")
    nz_act_val, nz_val_indices = torch.topk(nz_mean, top_k_mean)
    logging.info(f"Top {top_k_mean} nz_mean values selected.")

    logging.info("Selecting top-k indices based on act_cnt")
    nz_cnt, cnt_indices = torch.topk(act_cnt, top_k_cnt)
    logging.info(f"Top {top_k_cnt} act_cnt values selected.")

    # logging.info("Finding overlapping indices between nz_mean and act_cnt top-k")
    # overlap_mask = torch.isin(nz_val_indices, cnt_indices)
    # overlap_indices = nz_val_indices[overlap_mask]
    # logging.info(f"Number of overlapping indices: {len(overlap_indices)}")
    # overlap_indices=overlap_indices
    return nz_mean, act_cnt, cnt_indices

In [40]:

def compute_steering_vectors(sae: SAE, overlap_indices: Tensor, nz_mean: Tensor, method: str = "val_mul") -> Tensor:
    logging.info(f"Computing steering vectors using method: {method}")
    if method == "mean":
        steering_vectors = torch.mean(sae.W_dec[overlap_indices], dim=0)
    elif method == "val_mul":
        steering_vectors = torch.zeros(sae.W_dec.shape[1], device=sae.W_dec.device)
        for important_idx in overlap_indices:
            steering_vectors += nz_mean[important_idx].item() * sae.W_dec[important_idx]
    else:
        raise ValueError(f"Unknown method: {method}")
    logging.info(f"Steering vectors computed with shape: {steering_vectors.shape}")
    return steering_vectors


In [41]:

def save_results(output_dir: str, nz_mean: Tensor, act_cnt: Tensor, generated_texts: list, hyperparams: dict):
    os.makedirs(output_dir, exist_ok=True)

    # Save nz_mean and act_cnt
    nz_stats_path = os.path.join(output_dir, 'nz_stats.pt')
    logging.info(f"Saving nz_mean and act_cnt to {nz_stats_path}")
    torch.save({
        'nz_mean': nz_mean,
        'act_cnt': act_cnt
    }, nz_stats_path)

    # Save generated texts
    generated_texts_path = os.path.join(output_dir, 'generated_texts.txt')
    logging.info(f"Saving generated texts to {generated_texts_path}")
    with open(generated_texts_path, 'w') as f:
        for text in generated_texts:
            f.write(text + "\n")

    # Save hyperparameters
    hyperparams_path = os.path.join(output_dir, 'hyperparameters.json')
    logging.info(f"Saving hyperparameters to {hyperparams_path}")
    with open(hyperparams_path, 'w') as f:
        json.dump(hyperparams, f, indent=4)

    logging.info("All results saved successfully.")

In [42]:
output_dir_base = os.path.join(
    args.output_dir,
    f"LLM_{args.LLM}_layer_{args.layer}_steer_{args.steer}_alpha_{args.alpha}_cnt_{args.topk_cnt}_mean{args.topk_mean}"
)
output_dir_base

'./results/LLM_gpt2-small_layer_6_steer_polite-impolite_alpha_100_cnt_100_mean100'

In [43]:
def load_from_cache():
    cache_exists = False
    cache_file = os.path.join(output_dir_base, 'hyperparameters.json')
    if os.path.exists(cache_file):
        with open(cache_file, 'r') as f:
            cached_data = json.load(f)
        cached_hash = cached_data.get('hyperparams_hash')

    if cache_exists:
        # Load nz_mean and act_cnt from cache
        # nz_stats_path = os.path.join(output_dir_base, 'nz_stats.pt')
        # nz_act = torch.load(nz_stats_path)
        # nz_mean = nz_act['nz_mean']
        # act_cnt = nz_act['act_cnt']
        # overlap_indices = nz_act.get('overlap_indices', None)  # If overlap_indices was saved
        logging.info("load from cache")
    else:
        # overlap_indices = None  # Will be computed later
        logging.info("non cache: "+cache_file)
load_from_cache()

2025-01-12 10:33:44,950 [INFO] non cache: ./results/LLM_gpt2-small_layer_6_steer_polite-impolite_alpha_100_cnt_100_mean100/hyperparameters.json


In [12]:
# setup_logging(output_dir_base)

In [44]:
# Save hyperparameters
hyperparams = vars(args)

# Log hyperparameters
logging.info("Hyperparameters:")
for key, value in hyperparams.items():
    logging.info(f"  {key}: {value}")

# Load environment
load_environment(args.env_path)

# Load model and SAE
logging.info(f"Loading model: {args.LLM}")
model = HookedTransformer.from_pretrained(args.LLM, device=args.device)

logging.info(f"Loading SAE for layer {args.layer}")
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id=f"blocks.{args.layer}.hook_resid_pre",
    device=args.device
)





2025-01-12 10:33:48,372 [INFO] Hyperparameters:
2025-01-12 10:33:48,375 [INFO]   layer: 6
2025-01-12 10:33:48,377 [INFO]   LLM: gpt2-small
2025-01-12 10:33:48,378 [INFO]   dataset_path: /home/ckqsudo/code2024/0dataset/ACL_useful_dataset/style_transfer/politeness-corpus
2025-01-12 10:33:48,423 [INFO]   output_dir: ./results
2025-01-12 10:33:48,424 [INFO]   env_path: /home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env
2025-01-12 10:33:48,425 [INFO]   seed: 42
2025-01-12 10:33:48,427 [INFO]   data_size: 1000
2025-01-12 10:33:48,428 [INFO]   device: cpu
2025-01-12 10:33:48,429 [INFO]   alpha: 100
2025-01-12 10:33:48,430 [INFO]   steer: polite-impolite
2025-01-12 10:33:48,431 [INFO]   method: val_mul
2025-01-12 10:33:48,432 [INFO]   topk_mean: 100
2025-01-12 10:33:48,434 [INFO]   topk_cnt: 100
2025-01-12 10:33:48,435 [INFO]   batch_size: 32
2025-01-12 10:33:48,439 [INFO] HF_ENDPOINT: https://hf-mirror.com
2025-01-12 10:33:48,440 [INFO] Loading model: gpt2-small
2025-01-12 10:34

Loaded pretrained model gpt2-small into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [45]:
# Load dataset
all_dataset = load_and_prepare_polite_dataset(
    args.dataset_path, args.seed, args.data_size
)

2025-01-12 10:34:59,742 [INFO] Loading dataset from /home/ckqsudo/code2024/0dataset/ACL_useful_dataset/style_transfer/politeness-corpus
2025-01-12 10:34:59,776 [INFO] Filtering dataset for impolite, polite, and neutral samples
2025-01-12 10:34:59,812 [INFO] Selected 1000 negative, 1000 positive, and 1000 neutral samples


In [166]:
args.steer

'polite-impolite'

In [167]:
def analyze_latents(batch_latents: Tensor, top_k_mean: int = 100, top_k_cnt: int = 100) -> Tuple[Tensor, Tensor, Tensor]:
    logging.info("Computing non-zero element counts")
    act_cnt = (batch_latents != 0).sum(dim=(0, 1))

    logging.info("Computing sum of non-zero elements")
    nz_sum = torch.where(batch_latents != 0, batch_latents, torch.tensor(0.0, device=batch_latents.device)).sum(dim=(0, 1))

    logging.info("Computing mean of non-zero elements")
    nz_mean = torch.where(act_cnt != 0, nz_sum / act_cnt, torch.tensor(0.0, device=batch_latents.device))

    logging.info("Selecting top-k indices based on nz_mean")
    nz_act_val, nz_val_indices = torch.topk(nz_mean, top_k_mean)
    logging.info(f"Top {top_k_mean} nz_mean values selected.")

    logging.info("Selecting top-k indices based on act_cnt")
    nz_cnt, cnt_indices = torch.topk(act_cnt, top_k_cnt)
    logging.info(f"Top {top_k_cnt} act_cnt values selected.")

    # logging.info("Finding overlapping indices between nz_mean and act_cnt top-k")
    # overlap_mask = torch.isin(nz_val_indices, cnt_indices)
    # overlap_indices = nz_val_indices[overlap_mask]
    # logging.info(f"Number of overlapping indices: {len(overlap_indices)}")
    # overlap_indices=overlap_indices
    return nz_mean, act_cnt, cnt_indices

In [55]:
# Select a dataset steer based on steering preference

def get_activation_by_steer(texts:list):

    hook_point = sae.cfg.hook_name

    # Compute latents with batch processing
    batch_latents = compute_latents(sae, model, texts, hook_point, args.device, args.batch_size)
    # 计算第二个维度的最大值
    max_dim1 = max(latent.shape[1] for latent in batch_latents)  # 第二个维度的最大值
    logging.info(f"最大长度:{max_dim1}")
    # 对每个 Tensor 进行填充（仅填充第二个维度）
    padded_latents_right = [
        torch.nn.functional.pad(latent, (0, 0, 0, max_dim1 - latent.size(1)), "constant", 0)
        for latent in batch_latents
    ]

    batch_latents_concatenated = torch.cat(padded_latents_right, dim=0)
    logging.info(f"Concatenated batch latents shape: {batch_latents_concatenated.shape}")

    # Analyze latents
    nz_mean, act_cnt, overlap_indices = analyze_latents(batch_latents_concatenated, top_k_mean=args.topk_mean, top_k_cnt=args.topk_cnt)
    return {"nz_mean":nz_mean,"nz_cnt":act_cnt}

In [169]:
args.steer=args.steer.lower()

In [170]:
all_dataset["train"]["Q+A"][:args.data_size][193]

TypeError: tuple indices must be integers or slices, not str

In [106]:
steer_info={}

In [56]:
steer_info={}
if args.steer=='polite-impolite' or args.steer=="pos-neg":
    text=pos_train_set["text"][:args.data_size]
    steer_info["pos"]=get_activation_by_steer(text)
    text=neg_train_set["text"][:args.data_size]
    steer_info["neg"]=get_activation_by_steer(text)
    text=neu_train_set["text"][:args.data_size]
    steer_info["neu"]=get_activation_by_steer(text)

2025-01-12 11:33:18,291 [INFO] Running model with cache to obtain hidden states
Processing batches:   0%|          | 0/32 [00:00<?, ?it/s]

2025-01-12 11:33:18,621 [INFO] Batch 1: Hidden states shape: torch.Size([32, 43, 768])
2025-01-12 11:33:18,623 [INFO] Encoding hidden states for batch 1
Processing batches:   3%|▎         | 1/32 [00:00<00:11,  2.66it/s]2025-01-12 11:33:19,215 [INFO] Batch 2: Hidden states shape: torch.Size([32, 81, 768])
2025-01-12 11:33:19,216 [INFO] Encoding hidden states for batch 2
Processing batches:   6%|▋         | 2/32 [00:00<00:14,  2.02it/s]2025-01-12 11:33:19,640 [INFO] Batch 3: Hidden states shape: torch.Size([32, 63, 768])
2025-01-12 11:33:19,642 [INFO] Encoding hidden states for batch 3
Processing batches:   9%|▉         | 3/32 [00:01<00:13,  2.17it/s]2025-01-12 11:33:20,003 [INFO] Batch 4: Hidden states shape: torch.Size([32, 60, 768])
2025-01-12 11:33:20,004 [INFO] Encoding hidden states for batch 4
Processing batches:  12%|█▎        | 4/32 [00:01<00:11,  2.37it/s]2025-01-12 11:33:20,437 [INFO] Batch 5: Hidden states shape: torch.Size([32, 75, 768])
2025-01-12 11:33:20,438 [INFO] Encodi

In [109]:

if "cot" in args.steer:
    if args.steer in "cot-direct":
        texts=all_dataset["train"]["Q+A"][:args.data_size]
        print(type(texts))
        steer_info["direct"]=get_activation_by_steer(texts)
        
        texts=all_dataset["train"]["Q+COT_A"][:args.data_size]
        print(type(texts))
        steer_info["cot"]=get_activation_by_steer(texts)
        # print(texts[123])
    else:
        raise ValueError("????")

2025-01-10 14:50:52,557 [INFO] Running model with cache to obtain hidden states


<class 'list'>


Processing batches:   0%|          | 0/32 [00:00<?, ?it/s]2025-01-10 14:50:53,432 [INFO] Batch 1: Hidden states shape: torch.Size([32, 155, 768])
2025-01-10 14:50:53,433 [INFO] Encoding hidden states for batch 1
Processing batches:   3%|▎         | 1/32 [00:01<00:31,  1.02s/it]2025-01-10 14:50:54,501 [INFO] Batch 2: Hidden states shape: torch.Size([32, 104, 768])
2025-01-10 14:50:54,503 [INFO] Encoding hidden states for batch 2
Processing batches:   6%|▋         | 2/32 [00:02<00:30,  1.03s/it]2025-01-10 14:50:55,365 [INFO] Batch 3: Hidden states shape: torch.Size([32, 115, 768])
2025-01-10 14:50:55,367 [INFO] Encoding hidden states for batch 3
Processing batches:   9%|▉         | 3/32 [00:02<00:27,  1.05it/s]2025-01-10 14:50:56,358 [INFO] Batch 4: Hidden states shape: torch.Size([32, 121, 768])
2025-01-10 14:50:56,360 [INFO] Encoding hidden states for batch 4
Processing batches:  12%|█▎        | 4/32 [00:03<00:27,  1.03it/s]2025-01-10 14:50:57,033 [INFO] Batch 5: Hidden states shape: t

<class 'list'>


Processing batches:   0%|          | 0/32 [00:00<?, ?it/s]2025-01-10 14:51:29,259 [INFO] Batch 1: Hidden states shape: torch.Size([32, 394, 768])
2025-01-10 14:51:29,259 [INFO] Encoding hidden states for batch 1
Processing batches:   3%|▎         | 1/32 [00:03<01:56,  3.74s/it]2025-01-10 14:51:32,466 [INFO] Batch 2: Hidden states shape: torch.Size([32, 266, 768])
2025-01-10 14:51:32,468 [INFO] Encoding hidden states for batch 2
Processing batches:   6%|▋         | 2/32 [00:06<01:41,  3.39s/it]2025-01-10 14:51:35,617 [INFO] Batch 3: Hidden states shape: torch.Size([32, 341, 768])
2025-01-10 14:51:35,619 [INFO] Encoding hidden states for batch 3
Processing batches:   9%|▉         | 3/32 [00:10<01:35,  3.29s/it]2025-01-10 14:51:39,830 [INFO] Batch 4: Hidden states shape: torch.Size([32, 379, 768])
2025-01-10 14:51:39,832 [INFO] Encoding hidden states for batch 4
Processing batches:  12%|█▎        | 4/32 [00:14<01:42,  3.68s/it]2025-01-10 14:51:43,079 [INFO] Batch 5: Hidden states shape: t

In [49]:
args.steer

'polite-impolite'

In [50]:
pos_train_set["text"][:args.data_size]

["On another note, I'm going to assume that you were the one who also added ridership data for various LIRR stations, am I correct? If so, could you kindly point me towards your source?",
 "You're welcome. Do you have an opinion as to whether they would be worth adding to <url>?",
 'Could you please tell us what toppings are on them?',
 "sorry I can't find a folding trike on the Catrike website, do you know the model name?",
 'This happens for me, too, but only on my dev machine.  Did you ever find an explanation?',
 'We probably need to see your code or at least get some more information on what is not working. If you are trying to overlay onto google maps, is your wfs layer in spherical mercator projection?',
 'No problem. Have you heard anything new about HRE?',
 'Yes but making it cross browser will be headache. I have old code somewhere, untested for long time - want me to post it anyway?',
 "A few more details are required. What's your latitude, longitude, date and timezone...?",

In [57]:
len(pos_train_set["text"][:args.data_size])

1000

In [58]:
args.data_size

1000

In [107]:

# for steer_key in ["pos","neu","neg"]:
#     if steer_key == "pos":
#         selected_set = pos_train_set
#     elif steer_key == "neg":
#         selected_set = neg_train_set
#     elif steer_key=="neu":
#         selected_set = neu_train_set

#     texts = selected_set["text"][:args.data_size]
#     a=get_activation_by_steer(texts)
#     steer_info[steer_key]=a


2025-01-11 01:02:29,893 [INFO] Running model with cache to obtain hidden states
Processing batches:   0%|          | 0/32 [00:00<?, ?it/s]2025-01-11 01:02:30,152 [INFO] Batch 1: Hidden states shape: torch.Size([32, 44, 768])
2025-01-11 01:02:30,153 [INFO] Encoding hidden states for batch 1
Processing batches:   3%|▎         | 1/32 [00:00<00:09,  3.39it/s]2025-01-11 01:02:30,552 [INFO] Batch 2: Hidden states shape: torch.Size([32, 82, 768])
2025-01-11 01:02:30,554 [INFO] Encoding hidden states for batch 2
Processing batches:   6%|▋         | 2/32 [00:00<00:11,  2.57it/s]2025-01-11 01:02:30,979 [INFO] Batch 3: Hidden states shape: torch.Size([32, 64, 768])
2025-01-11 01:02:30,981 [INFO] Encoding hidden states for batch 3
Processing batches:   9%|▉         | 3/32 [00:01<00:11,  2.53it/s]2025-01-11 01:02:31,348 [INFO] Batch 4: Hidden states shape: torch.Size([32, 61, 768])
2025-01-11 01:02:31,350 [INFO] Encoding hidden states for batch 4
Processing batches:  12%|█▎        | 4/32 [00:01<00:

In [59]:
steer_info["pos"]

{'nz_mean': tensor([15.0542,  5.3331,  0.0000,  ...,  1.0033,  0.8486,  0.0000],
        grad_fn=<WhereBackward0>),
 'nz_cnt': tensor([ 38,   4,   0,  ..., 159,  10,   0])}

In [60]:
steer_info["pos"]["nz_mean"].shape

torch.Size([24576])

In [42]:
# steer_info["dif_neg_pos"]={"steer":"dif_neg_pos","nz_cnt":steer_info["pos"]["nz_cnt"]-steer_info["neg"]["nz_cnt"],"nz_mean":steer_info["pos"]["nz_mean"]-steer_info["neg"]["nz_mean"]}

In [43]:
# steer_info["dif_neg_pos_relu"]={"nz_cnt":torch.relu(steer_info["pos"]["nz_cnt"]-steer_info["neg"]["nz_cnt"]),"nz_mean":torch.relu(steer_info["pos"]["nz_mean"]-steer_info["neg"]["nz_mean"])}

In [55]:
# steer_info["dif_neg_pos_relu"],steer_info["dif_neg_pos"]

In [None]:
sourceource="cot"
target="direct"

In [86]:
source="neg"
target="pos"

In [87]:
torch.all((steer_info["pos"]["nz_mean"]-steer_info["neg"]["nz_mean"])==0)

tensor(False)

In [88]:
steer_info[f"dif_{target}-{source}_relu"]={"nz_cnt":torch.relu(steer_info[target]["nz_cnt"]-steer_info[source]["nz_cnt"]),"nz_mean":torch.relu(steer_info[target]["nz_mean"]-steer_info[source]["nz_mean"])}
top_k=100
steering_vectors,steer_indices=torch.topk(steer_info[f"dif_{target}-{source}_relu"]["nz_cnt"],top_k)

In [110]:

# steer_info[f"dif_{a}-{b}_relu"]={"nz_cnt":torch.relu(steer_info[a]["nz_cnt"]-steer_info[b]["nz_cnt"]),"nz_mean":torch.relu(steer_info[a]["nz_mean"]-steer_info[b]["nz_mean"])}
# top_k=100
# steering_vectors,steer_indices=torch.topk(steer_info[f"dif_{a}-{b}_relu"]["nz_cnt"],top_k)

In [89]:
# 假设 steer_info[f"dif_{b}-{a}_relu"]["nz_cnt"] 是一个 NumPy 数组
nz_cnt = steer_info[f"dif_{target}-{source}_relu"]["nz_cnt"]

# 先获取非零元素的索引
nz_indices = np.nonzero(nz_cnt)
torch.all(nz_cnt == 0)

tensor(False)

In [90]:
steer_info[f"dif_{target}-{source}_relu"]["nz_cnt"].shape

torch.Size([24576])

In [91]:
steering_vectors,steer_indices,

(tensor([2668, 2452, 1787, 1716, 1703, 1586, 1374, 1338, 1322, 1321, 1318, 1189,
         1173, 1132, 1127, 1108,  967,  934,  804,  760,  720,  679,  664,  664,
          630,  590,  549,  532,  516,  505,  499,  498,  496,  490,  490,  486,
          484,  475,  475,  472,  469,  448,  444,  437,  436,  429,  416,  415,
          412,  401,  391,  387,  384,  383,  379,  377,  376,  371,  368,  367,
          364,  362,  355,  353,  350,  349,  343,  341,  340,  339,  337,  336,
          332,  328,  327,  327,  327,  325,  324,  321,  320,  317,  316,  311,
          310,  307,  306,  306,  305,  303,  302,  301,  300,  300,  298,  291,
          290,  288,  280,  276]),
 tensor([15132, 14556, 17504,  1738, 22371, 10512,  8532, 24545, 21219,  7467,
         15104, 17735, 18200, 16383, 14021, 20973,  7574, 16093, 15452,   912,
         19843, 20588,  6464,  9862,  9430, 15696, 21027,  2129, 22301, 11185,
          4283, 13917, 10420, 16371, 17642, 12502,  3639, 19547,  4659, 21604,
 

In [92]:
steer_info[f"dif_{target}-{source}_relu"]["nz_mean"][steer_indices]

tensor([2.1293e-01, 1.2582e-01, 1.1048e-01, 8.9840e-02, 2.3624e-01, 5.7488e-02,
        2.5382e-01, 1.1420e-01, 5.0775e-02, 6.3129e-02, 5.2324e-01, 1.2435e-01,
        1.2033e-01, 1.1321e-02, 2.0315e-01, 2.4938e-01, 1.3542e-01, 2.0019e-01,
        1.0011e-01, 2.2507e-02, 0.0000e+00, 0.0000e+00, 2.2052e-02, 1.0376e-01,
        1.6264e+00, 1.7587e+00, 1.7258e-01, 1.9118e-02, 1.1884e-01, 6.0730e-02,
        4.3861e-02, 0.0000e+00, 0.0000e+00, 1.7873e-02, 5.0408e-02, 1.5525e-01,
        0.0000e+00, 5.5373e-01, 6.4238e-02, 1.3958e-01, 1.0816e-02, 3.8851e-02,
        0.0000e+00, 0.0000e+00, 1.5910e-01, 3.9630e-02, 6.9233e-03, 6.1554e-02,
        5.9235e-04, 1.3167e-01, 2.8372e-02, 4.3450e-02, 1.7542e-01, 3.0830e-02,
        3.6164e-02, 2.7652e-01, 0.0000e+00, 3.1988e-02, 1.5726e-01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0249e-01, 1.7577e-01, 0.0000e+00, 6.7516e-02,
        5.5641e-02, 2.4995e-01, 9.3064e-02, 1.8521e-02, 3.4680e-01, 3.7896e-02,
        1.4039e-01, 2.5431e-02, 4.2825e-

In [94]:
steer_info["dif_neg-pos"]["nz_mean"][steer_indices]

KeyError: 'dif_neg-pos'

In [95]:
steer_indices

tensor([15132, 14556, 17504,  1738, 22371, 10512,  8532, 24545, 21219,  7467,
        15104, 17735, 18200, 16383, 14021, 20973,  7574, 16093, 15452,   912,
        19843, 20588,  6464,  9862,  9430, 15696, 21027,  2129, 22301, 11185,
         4283, 13917, 10420, 16371, 17642, 12502,  3639, 19547,  4659, 21604,
        24085, 16053, 23494, 21264, 18791, 21631,  9281, 12604, 16101, 12873,
        10524, 12863,  3506, 15256, 10951,  4966, 13202, 18318, 12602, 10889,
        20000,  1622, 12499, 12639, 17996, 23227,  2797,   640, 11728, 22446,
         5462, 16272, 20630,  4075,  4915,  7775, 19972, 22348, 16915, 17462,
         4891,  8137, 11709,  2812, 21306,  8272, 14729,  6244, 13974, 21303,
        17309, 21652, 21716, 16873,  3638, 23123, 12750, 18486,   113, 23772])

In [96]:

def compute_steering_vectors(sae: SAE, indices: Tensor, nz_mean_val: Tensor, method: str = "val_mul") -> Tensor:
    logging.info(f"Computing steering vectors using method: {method}")
    if method == "mean":
        steering_vectors = torch.mean(sae.W_dec[indices], dim=0)
    elif method == "val_mul":
        steering_vectors = torch.zeros(sae.W_dec.shape[1], device=sae.W_dec.device)
        for idx in indices:
            steering_vectors += nz_mean_val[idx].item() * sae.W_dec[idx]
    else:
        raise ValueError(f"Unknown method: {method}")
    logging.info(f"Steering vectors computed with shape: {steering_vectors.shape}")
    return steering_vectors
steering_vectors=compute_steering_vectors(sae,indices=steer_indices,nz_mean_val=steer_info[f"dif_{target}-{source}_relu"]["nz_mean"],method="val_mul")

2025-01-12 11:52:26,495 [INFO] Computing steering vectors using method: val_mul
2025-01-12 11:52:26,506 [INFO] Steering vectors computed with shape: torch.Size([768])


In [64]:
# model.to_tokens("<|endoftext|>")

tensor([[50256, 50256]])

In [134]:
model.tokenizer.eos_token

'<|endoftext|>'

In [135]:
model.tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}
)

In [None]:
steering_vectors

tensor([ 1.7512e-01, -2.6205e-02,  1.2120e-01,  2.1542e-01,  2.8142e-01,
        -2.0455e-01,  5.0054e-01, -4.9751e-01, -1.4036e-01,  1.3139e-03,
        -2.0526e-01, -2.3867e-01, -1.9055e-01, -4.3484e-02,  4.2371e-01,
         2.0112e-01, -8.5613e-02, -8.7889e-02,  1.0992e-01,  1.8500e-01,
         1.2432e-01,  3.7251e-02,  2.4164e-01,  5.7944e-02, -7.1362e-02,
        -2.9558e-02, -1.6983e-01, -4.4247e-02, -1.8696e-01,  6.0553e-02,
        -2.2151e-01, -1.7050e-01, -9.3700e-02, -9.2565e-02,  1.8897e-02,
        -3.3768e-01, -4.7429e-02,  5.9305e-02,  1.8874e-01, -1.4750e-01,
         4.9142e-02,  3.8031e-02,  1.3019e-01, -2.2149e-01, -4.7503e-02,
        -2.3113e-01,  2.8352e-01, -1.4157e-01,  8.4853e-02, -3.3906e-02,
         4.6921e-02, -3.6691e-01, -1.0091e-02, -7.8832e-02,  1.8692e-01,
        -3.0192e-03, -5.8025e-02,  8.9214e-02, -2.3700e-02,  1.5276e-01,
        -9.4599e-02,  1.2105e-01,  1.8672e-01, -3.9875e-02, -3.8224e-01,
         1.8494e-01, -1.7447e-01, -3.7863e-01, -9.4

In [98]:
# # Define steering hook
# steering_on = True  # This will be toggled in run_generate
# alpha = args.alpha
# method = args.method  # Store method for clarity

def steering_hook(resid_pre, hook):
    if resid_pre.shape[1] == 1:
        return
    logging.info(f"hook is {hook}\n")
    if steering_on:
        resid_pre[:, :-1, :] += args.alpha * steering_vectors
        #注意这里是对batch_size,0:-9,hidden_size这样的隐藏层做干预

def hooked_generate(prompt_batch, fwd_hooks=[], seed=None, **kwargs):
    if seed is not None:
        torch.manual_seed(seed)

    with model.hooks(fwd_hooks=fwd_hooks):
        tokenized = model.to_tokens(prompt_batch,prepend_bos=False)
        result = model.generate(
            stop_at_eos=True,  # avoids a bug on MPS
            input=tokenized,
            max_new_tokens=50,
            do_sample=True,
            **kwargs,
        )
    return result

def run_generate(example_prompt,sampling_kwargs):
    model.reset_hooks()
    editing_hooks = [(f"blocks.{args.layer}.hook_resid_post", steering_hook)]
    res = hooked_generate(
        [example_prompt] * 3, editing_hooks, seed=None, **sampling_kwargs
    )

    # Print results, removing the ugly beginning of sequence token
    res_str = model.to_string(res[:, :])
    # print(("\n\n" + "-" * 80 + "\n\n").join(res_str))
    # generated_texts = res_str
    for idx, text in enumerate(res_str):
        logging.info(f"Generated Text: {idx+1}: {text}")
        
    return res_str

# Define sampling parameters
sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)

# Example prompt from the selected set
example_prompt = " WOW, you are so"
# example_prompt=model.tokenizer
# print( all_dataset["val"]["prompt"][:3],all_dataset["val"]["A"][:3])
logging.info(f"Example prompt: {example_prompt}")

# Generate without steering
steering_on = False
alpha = 0
logging.info("Generating texts **without** steering... ")
generated_texts_no_steer = run_generate(example_prompt, sampling_kwargs)
logging.info("干预之后的结果")
# bef,aft=args.steer.split("-")
logging.info(f"干预方向{source}->{target},礼貌任务下，neg=impolite")
# Generate with steering
steering_on = True
alpha = 30
logging.info("** Generating texts with steering... Target **")
generated_texts_with_steer = run_generate(example_prompt, sampling_kwargs)

# Combine generated texts
# all_generated_texts = generated_texts_no_steer + generated_texts_with_steer



2025-01-12 11:52:51,844 [INFO] Example prompt:  WOW, you are so
2025-01-12 11:52:51,848 [INFO] Generating texts **without** steering... 


  0%|          | 0/50 [00:00<?, ?it/s]

2025-01-12 11:52:51,911 [INFO] hook is HookPoint()

2025-01-12 11:52:53,609 [INFO] Generated Text: 1:  WOW, you are so good at it.

You know what I mean? You're a genius. You've got the best of both worlds. And you're also a great person, too!

I'm not saying that I'm going to be the next
2025-01-12 11:52:53,610 [INFO] Generated Text: 2:  WOW, you are so smart. You know how to use your brain and your body to do things that you don't even think of doing?

I'm not going to lie, I'm a little bit surprised that you're still here. I mean, it's
2025-01-12 11:52:53,610 [INFO] Generated Text: 3:  WOW, you are so much better than me. I'm not even going to lie to you. I'm just going to tell you what happened and how it happened.

I was in the hospital with my wife and two kids when we were attacked by a group of
2025-01-12 11:52:53,611 [INFO] 干预之后的结果
2025-01-12 11:52:53,611 [INFO] 干预方向neg->pos,礼貌任务下，neg=impolite
2025-01-12 11:52:53,611 [INFO] ** Generating texts with steering... Target **


  0%|          | 0/50 [00:00<?, ?it/s]

2025-01-12 11:52:53,641 [INFO] hook is HookPoint()

2025-01-12 11:52:55,049 [INFO] Generated Text: 1:  WOW, you are so excited about the upcoming release of the next installment of our series. We hope you will enjoy your time with us and we look forward to seeing you all next week!

Thanks for your continued support!<|endoftext|>
2025-01-12 11:52:55,049 [INFO] Generated Text: 2:  WOW, you are so excited about the upcoming release of the game. We hope that you will continue to support us in your efforts to keep this project alive and well!

Thanks for your continued support!<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
2025-01-12 11:52:55,050 [INFO] Generated Text: 3:  WOW, you are so excited about the upcoming release of this game. I hope you will continue to support us in your efforts to keep this project alive and well!

Thanks for your continued support!<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>


# 理论上来讲，不礼貌的输出应该有很多疑问句？例如what？ ha？why？
而礼貌的输出应该有很多正常的词语

In [137]:
all_dataset["val"]["prompt"][:10]

['Q: The mayor commissioned two artists to paint 50 murals around the city. Once the work was completed, Celina was paid $1,000 more than 4 times the amount Diego got. If the mayor paid the two a total of $50,000, how much did Diego get?\nA: ',
 "Q: Adrianna has 10 pieces of gum to share with her friends. There wasn't enough gum for all her friends, so she went to the store to get 3 more pieces of gum. She gave out gum to 11 friends. How many pieces of gum does Adrianna have now?\nA: ",
 'Q: Paul needed to buy some new clothes for work.  He had a 10% off coupon that he could use on his entire purchase after any other discounts.  Paul bought 4 dress shirts at $15.00 apiece, 2 pairs of pants that each cost $40.00.  He found a suit for $150.00 and 2 sweaters for $30.00 each.  When he got to the register, the clerk told him that the store was offering 20% off of everything in the store.  After the discounts and the coupon, how much did Paul spend on his new clothes?\nA: ',
 'Q: Donny saves

In [138]:
all_dataset["val"]["A"][:10]

['9800', '2', '252', '28', '5', '16', '16', '550', '26', '107']

In [None]:
# Save results
save_results(
    output_dir=output_dir_base,
    nz_mean=nz_mean,
    act_cnt=act_cnt,
    generated_texts=all_generated_texts,
    hyperparams=hyperparams
)