In [1]:
# Imports and Setup
import argparse
import os
import logging
from typing import Tuple
import json
from log import setup_logging
from tqdm import tqdm  # For progress bars


In [2]:

import torch
from torch import Tensor
from transformer_lens import HookedTransformer
from sae_lens import SAE
# from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list
# from sae_lens.analysis.feature_statistics import (
#     get_all_stats_dfs,
#     get_W_U_W_dec_stats_df,
# )
# from sae_lens.analysis.tsea import (
#     get_enrichment_df,
#     manhattan_plot_enrichment_scores,
#     plot_top_k_feature_projections_by_token_and_category,
#     get_baby_name_sets,
#     get_letter_gene_sets,
#     generate_pos_sets,
#     get_test_gene_sets,
#     get_gene_set_from_regex,
# )
from datasets import load_dataset
from dotenv import load_dotenv
import numpy as np
# import plotly_express as px

# 进行基础情感干预实验

In [None]:
* 运行对应实验的时候跑对应的args_dict就行

In [3]:
# Define hyperparameters
args_dict = {
    "layer": 6,  # Example layer number to analyze
    "LLM": "gpt2-small",
    "dataset_path": "/home/ckqsudo/code2024/0dataset/baseline-acl/data/sentiment/sst5",
    "prompt_path":"/home/ckqsudo/code2024/0dataset/baseline-acl/data/prompts/sentiment_prompts-10k",
    "output_dir": "./results",
    "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
    "task":"sentiment",# “sentiment”,"cot","polite"
    "seed": 42,
    "data_size": 1000,
    "device": "cpu",  # Options: "cpu", "cuda", "mps", "auto"
    "alpha": 100, # 这个alpha后面慢慢调节
    "steer": "pos-neg",  # Options: "pos", "neg", "neu","pos-neg","cot-direct"
    "method": "val_mul",  # Options: "mean", "val_mul" 用val_mul会比较好
    "topk_mean": 100, # 选取前topk 个均值激活，这个效果一般，会导致很多如what？why？这种被激活
    "topk_cnt": 100, # 选取前topk个频率激活，目前默认这个，效果很好
    "batch_size": 32 # 这个好像没用上
}

# 进行COT相关实验

In [75]:
# Define hyperparameters
args_dict = {
    "layer": 6,  # Example layer number to analyze
    "LLM": "gpt2-small",
    "dataset_path": "/home/ckqsudo/code2024/0dataset/ACL_useful_dataset/math/COT_GSM8k",
    "output_dir": "./results",
    "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
    "seed": 42,
    "data_size": 1000,
    "device": "cpu",  # Options: "cpu", "cuda", "mps", "auto"
    "alpha": 100,
    "steer": "cot-direct",  # Options: "pos", "neg", "neu","pos-neg","cot-direct"
    "method": "val_mul",  # Options: "mean", "val_mul"
    "topk_mean": 100,
    "topk_cnt": 100,
    "batch_size": 32
}

# 进行礼貌实验
/home/ckqsudo/code2024/0dataset/ACL_useful_dataset/style_transfer/politeness-corpus

In [76]:
# Define hyperparameters
args_dict = {
    "layer": 6,  # Example layer number to analyze
    "LLM": "gpt2-small",
    "dataset_path": "/home/ckqsudo/code2024/0dataset/ACL_useful_dataset/style_transfer/politeness-corpus",
    "output_dir": "./results",
    "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
    "seed": 42,
    "data_size": 1000,
    "device": "cpu",  # Options: "cpu", "cuda", "mps", "auto"
    "alpha": 100,
    "steer": "polite-impolite",  # Options: "pos", "neg", "neu","pos-neg","cot-direct"
    "method": "val_mul",  # Options: "mean", "val_mul"
    "topk_mean": 100,
    "topk_cnt": 100,
    "batch_size": 32
}

In [149]:
# 进行礼貌实验

In [4]:
# Configuration and Hyperparameters
# 将字典转换为 argparse.Namespace 对象
args = argparse.Namespace(**args_dict)
# 测试访问属性
print(args.layer)  # 输出: 10
print(args.LLM)  # 输出: gpt2-small
print(args.output_dir)  # 输出: ./results
print(args.steer)

6
gpt2-small
./results
pos-neg


In [5]:
# Logging Setup
import os
from log import setup_logging
import logging
# Create output directory base path
output_dir_base = os.path.join(
    args.output_dir,
    f"LLM_{args.LLM}_layer_{args.layer}_steer_{args.steer}_alpha_{args.alpha}_cnt_{args.topk_cnt}_mean{args.topk_mean}"
)

# Setup logging
setup_logging(output_dir_base)

# Save hyperparameters
hyperparams = args_dict

# Log hyperparameters
logging.info("Hyperparameters:")
for key, value in hyperparams.items():
    logging.info(f"  {key}: {value}")


2025-01-14 10:45:00,207 [INFO] Logging initialized. Logs will be saved to ./results/LLM_gpt2-small_layer_6_steer_pos-neg_alpha_100_cnt_100_mean100/execution.log
2025-01-14 10:45:00,208 [INFO] Hyperparameters:
2025-01-14 10:45:00,210 [INFO]   layer: 6
2025-01-14 10:45:00,210 [INFO]   LLM: gpt2-small
2025-01-14 10:45:00,211 [INFO]   dataset_path: /home/ckqsudo/code2024/0dataset/baseline-acl/data/sentiment/sst5
2025-01-14 10:45:00,212 [INFO]   prompt_path: /home/ckqsudo/code2024/0dataset/baseline-acl/data/prompts/sentiment_prompts-10k
2025-01-14 10:45:00,214 [INFO]   output_dir: ./results
2025-01-14 10:45:00,215 [INFO]   env_path: /home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env
2025-01-14 10:45:00,216 [INFO]   task: sentiment
2025-01-14 10:45:00,217 [INFO]   seed: 42
2025-01-14 10:45:00,218 [INFO]   data_size: 1000
2025-01-14 10:45:00,219 [INFO]   device: cpu
2025-01-14 10:45:00,220 [INFO]   alpha: 100
2025-01-14 10:45:00,221 [INFO]   steer: pos-neg
2025-01-14 10:45:00,22

In [6]:
# Load Environment Variables
 
def load_environment(env_path: str):
    load_dotenv(env_path)
    hf_endpoint = os.getenv('HF_ENDPOINT', 'https://hf-mirror.com')
    logging.info(f"HF_ENDPOINT: {hf_endpoint}")

load_environment(args.env_path)


2025-01-14 10:45:03,318 [INFO] HF_ENDPOINT: https://hf-mirror.com


In [12]:
import re

def load_and_prepare_triple_dataset(dataset_path: str,dataset_name:str, seed: int, num_samples: int):
    """
    支持positive\neutral\negative三元组数据类型，例如 sst5，polite数据集和multi-class数据集

    Args:
        dataset_path (str): _description_
        dataset_name : "sst5","multiclass","polite"
        seed (int): _description_
        num_samples (int): _description_

    Returns:
        _type_: _description_
    """
    assert dataset_name in ["sst5","multiclass","polite"]
    if dataset_name in ["sst5"]:
        neu_label=2 # 中性情感对应的label
        assert "sst5" in dataset_path
    elif  dataset_name in ["polite","multiclass"]:
        neu_label=1
    logging.info(f"Loading dataset from ****{dataset_path}***")
    dataset = load_dataset(dataset_path)
    dataset["train"] = dataset['train'].shuffle(seed=seed)

    logging.info("Filtering dataset for negative, positive, and neutral samples")
    neg_train_set = dataset['train'].filter(lambda example: example['label'] < neu_label).select(range(num_samples))
    pos_train_set = dataset['train'].filter(lambda example: example['label'] == neu_label).select(range(num_samples))
    neu_train_set = dataset['train'].filter(lambda example: example['label'] > neu_label ).select(range(num_samples))

    logging.info(f"Selected {len(neg_train_set)} negative, {len(pos_train_set)} positive, and {len(neu_train_set)} neutral samples")
    print(dataset)
    val_set=dataset['validation']
    test_set=dataset["test"]
    return neg_train_set, pos_train_set, neu_train_set,val_set,test_set
def load_and_prepare_COT_dataset(dataset_path:str,seed:int,num_samples:int):
    logging.info(f"Loading dataset from {dataset_path}")
    dataset = load_dataset(dataset_path)
    dataset["train"] = dataset['train'].shuffle(seed=seed)
    logging.info("Filtering dataset for COT")
    # 定义一个函数来提取答案
    def extract_answer(text):
        # 使用正则表达式提取答案
        match = re.search(r'#### ([-+]?\d*\.?\d+/?\d*)', text)
        if match:
            label=match.group(1)
            return label
        else:
            raise ValueError("Modify your re expression")
    def concat_QA(example,col1,col2,tag):
        combined = f"{example[col1]}{tag}{example[col2]}"  # 用空格拼接
        return combined
    def replace_col(example,col1,target,pattern):
        return example[col1].replace(target,pattern)
    # 应用函数并创建新列
    dataset = dataset.map(lambda example: {'A': extract_answer(example['response'])})
    dataset = dataset.map(lambda example: {'Q+A': concat_QA(example,"prompt","A","")})
    dataset = dataset.map(lambda example: {'Q+COT_A': concat_QA(example,"prompt","response","")})
    dataset = dataset.map(lambda example: {'Q+COT_A': replace_col(example,"Q+COT_A","#### ","")})
    # 查看处理后的数据集
    print("Q+A\n",dataset['train'][103]['Q+A'])
    print("Q+COT_A\n",dataset['train'][103]['Q+COT_A'])
    return dataset
    

In [13]:
args.steer

'pos-neg'

In [14]:
# Load and Prepare Dataset


if "neg" in args.steer or "pos" in args.steer and args.steer=="sentiment":
    neg_train_set, pos_train_set, neu_train_set,val_set,test_set = load_and_prepare_triple_dataset(
        args.dataset_path, "sst5",args.seed, args.data_size
    )
elif "cot" in args.steer or "COT" in args.steer:
    logging.info("COT "*10)
    all_dataset=load_and_prepare_COT_dataset(
        args.dataset_path, args.seed, args.data_size
    )
elif "polite" in args.steer:
    logging.info("polite"*10)
    neg_train_set, pos_train_set, neu_train_set=load_and_prepare_triple_dataset(args.dataset_path,"polite", args.seed, args.data_size)
else:
    raise ValueError("No Supported")


2025-01-14 10:46:11,123 [INFO] Loading dataset from ****/home/ckqsudo/code2024/0dataset/baseline-acl/data/sentiment/sst5***
Repo card metadata block was not found. Setting CardData to empty.
2025-01-14 10:46:11,305 [INFO] Filtering dataset for negative, positive, and neutral samples
2025-01-14 10:46:11,314 [INFO] Selected 1000 negative, 1000 positive, and 1000 neutral samples


DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 8544
    })
    validation: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 1101
    })
    test: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 2210
    })
})


In [15]:
assert neg_train_set[10]!=pos_train_set[10]

In [16]:
pos_train_set[10],neg_train_set[10]

({'text': 'some are fascinating and others are not , and in the end , it is almost a good movie .',
  'label': 2,
  'label_text': 'neutral'},
 {'text': 'the lousy lead performances ... keep the movie from ever reaching the comic heights it obviously desired .',
  'label': 1,
  'label_text': 'negative'})

In [17]:


def compute_latents(sae: SAE, model: HookedTransformer, texts: list, hook_point: str, device: str, batch_size: int) -> list:
    """
    计算 latents，支持批次处理。

    Args:
        sae (SAE): SAE 实例。
        model (HookedTransformer): Transformer 模型实例。
        texts (list): 文本列表。
        hook_point (str): 钩子点名称。
        device (str): 计算设备。
        batch_size (int): 每个批次的大小。

    Returns:
        list: 包含每个批次 latents 的张量列表。
    """
    logging.info("Running model with cache to obtain hidden states")
    batch_latents = []

    # 使用 tqdm 显示进度条
    for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
        batch_texts = texts[i:i + batch_size]
        sv_logits, cache = model.run_with_cache(batch_texts, prepend_bos=False, device=device)
        batch_hidden_states = cache[hook_point]
        logging.info(f"Batch {i // batch_size + 1}: Hidden states shape: {batch_hidden_states.shape}")

        logging.info(f"Encoding hidden states for batch {i // batch_size + 1}")
        # 假设 sae.encode 支持批量编码
        latents = sae.encode(batch_hidden_states)  # 形状: (batch_size, latent_dim)
        batch_latents.append(latents)
        

    logging.info(f"Total batches processed: {len(batch_latents)}")
    return batch_latents

In [18]:

def analyze_latents(batch_latents: Tensor, top_k_mean: int = 100, top_k_cnt: int = 100) -> Tuple[Tensor, Tensor, Tensor]:
    logging.info("Computing non-zero element counts")
    act_cnt = (batch_latents != 0).sum(dim=(0, 1))

    logging.info("Computing sum of non-zero elements")
    nz_sum = torch.where(batch_latents != 0, batch_latents, torch.tensor(0.0, device=batch_latents.device)).sum(dim=(0, 1))

    logging.info("Computing mean of non-zero elements")
    nz_mean = torch.where(act_cnt != 0, nz_sum / act_cnt, torch.tensor(0.0, device=batch_latents.device))

    logging.info("Selecting top-k indices based on nz_mean")
    nz_act_val, nz_val_indices = torch.topk(nz_mean, top_k_mean)
    logging.info(f"Top {top_k_mean} nz_mean values selected.")

    logging.info("Selecting top-k indices based on act_cnt")
    nz_cnt, cnt_indices = torch.topk(act_cnt, top_k_cnt)
    logging.info(f"Top {top_k_cnt} act_cnt values selected.")

    # logging.info("Finding overlapping indices between nz_mean and act_cnt top-k")
    # overlap_mask = torch.isin(nz_val_indices, cnt_indices)
    # overlap_indices = nz_val_indices[overlap_mask]
    # logging.info(f"Number of overlapping indices: {len(overlap_indices)}")
    # overlap_indices=overlap_indices
    return nz_mean, act_cnt, cnt_indices

In [19]:

def compute_steering_vectors(sae: SAE, overlap_indices: Tensor, nz_mean: Tensor, method: str = "val_mul") -> Tensor:
    logging.info(f"Computing steering vectors using method: {method}")
    if method == "mean":
        steering_vectors = torch.mean(sae.W_dec[overlap_indices], dim=0)
    elif method == "val_mul":
        steering_vectors = torch.zeros(sae.W_dec.shape[1], device=sae.W_dec.device)
        for important_idx in overlap_indices:
            steering_vectors += nz_mean[important_idx].item() * sae.W_dec[important_idx]
    else:
        raise ValueError(f"Unknown method: {method}")
    logging.info(f"Steering vectors computed with shape: {steering_vectors.shape}")
    return steering_vectors


In [20]:

def save_results(output_dir: str, nz_mean: Tensor, act_cnt: Tensor, generated_texts: list, hyperparams: dict):
    os.makedirs(output_dir, exist_ok=True)
    # Save nz_mean and act_cnt
    nz_stats_path = os.path.join(output_dir, 'nz_stats.pt')
    logging.info(f"Saving nz_mean and act_cnt to {nz_stats_path}")
    torch.save({
        'nz_mean': nz_mean,
        'act_cnt': act_cnt,
    }, nz_stats_path)

    # Save generated texts
    generated_texts_path = os.path.join(output_dir, 'generated_texts.txt')
    logging.info(f"Saving generated texts to {generated_texts_path}")
    with open(generated_texts_path, 'w') as f:
        for text in generated_texts:
            f.write(text + "\n")

    # Save hyperparameters
    hyperparams_path = os.path.join(output_dir, 'hyperparameters.json')
    logging.info(f"Saving hyperparameters to {hyperparams_path}")
    with open(hyperparams_path, 'w') as f:
        json.dump(hyperparams, f, indent=4)

    logging.info("All results saved successfully.")

In [21]:
output_dir_base = os.path.join(
    args.output_dir,
    f"LLM_{args.LLM}_layer_{args.layer}_steer_{args.steer}_alpha_{args.alpha}_cnt_{args.topk_cnt}_mean{args.topk_mean}"
)
output_dir_base

'./results/LLM_gpt2-small_layer_6_steer_pos-neg_alpha_100_cnt_100_mean100'

In [22]:
def load_from_cache():
    cache_exists = False
    cache_file = os.path.join(output_dir_base, 'hyperparameters.json')
    if os.path.exists(cache_file):
        with open(cache_file, 'r') as f:
            cached_data = json.load(f)
        cached_hash = cached_data.get('hyperparams_hash')

    if cache_exists:
        # Load nz_mean and act_cnt from cache
        # nz_stats_path = os.path.join(output_dir_base, 'nz_stats.pt')
        # nz_act = torch.load(nz_stats_path)
        # nz_mean = nz_act['nz_mean']
        # act_cnt = nz_act['act_cnt']
        # overlap_indices = nz_act.get('overlap_indices', None)  # If overlap_indices was saved
        logging.info("load from cache")
    else:
        # overlap_indices = None  # Will be computed later
        logging.info("non cache: "+cache_file)
load_from_cache()

2025-01-14 10:46:33,521 [INFO] non cache: ./results/LLM_gpt2-small_layer_6_steer_pos-neg_alpha_100_cnt_100_mean100/hyperparameters.json


In [12]:
# setup_logging(output_dir_base)

In [23]:
# Save hyperparameters
hyperparams = vars(args)

# Log hyperparameters
logging.info("Hyperparameters:")
for key, value in hyperparams.items():
    logging.info(f"  {key}: {value}")

# Load environment
load_environment(args.env_path)

# Load model and SAE
logging.info(f"Loading model: {args.LLM}")
model = HookedTransformer.from_pretrained(args.LLM, device=args.device)

logging.info(f"Loading SAE for layer {args.layer}")
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id=f"blocks.{args.layer}.hook_resid_pre",
    device=args.device
)

2025-01-14 10:46:37,432 [INFO] Hyperparameters:
2025-01-14 10:46:37,434 [INFO]   layer: 6
2025-01-14 10:46:37,436 [INFO]   LLM: gpt2-small
2025-01-14 10:46:37,568 [INFO]   dataset_path: /home/ckqsudo/code2024/0dataset/baseline-acl/data/sentiment/sst5
2025-01-14 10:46:37,570 [INFO]   prompt_path: /home/ckqsudo/code2024/0dataset/baseline-acl/data/prompts/sentiment_prompts-10k
2025-01-14 10:46:37,571 [INFO]   output_dir: ./results
2025-01-14 10:46:37,573 [INFO]   env_path: /home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env
2025-01-14 10:46:37,574 [INFO]   task: sentiment
2025-01-14 10:46:37,576 [INFO]   seed: 42
2025-01-14 10:46:37,578 [INFO]   data_size: 1000
2025-01-14 10:46:37,579 [INFO]   device: cpu
2025-01-14 10:46:37,580 [INFO]   alpha: 100
2025-01-14 10:46:37,582 [INFO]   steer: pos-neg
2025-01-14 10:46:37,583 [INFO]   method: val_mul
2025-01-14 10:46:37,584 [INFO]   topk_mean: 100
2025-01-14 10:46:37,585 [INFO]   topk_cnt: 100
2025-01-14 10:46:37,587 [INFO]   batch_

Loaded pretrained model gpt2-small into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [128]:
# # Load dataset
# all_dataset = load_and_prepare_triple_dataset(
#     args.dataset_path, args.seed, args.data_size
# )

In [129]:
args.steer

'pos-neg'

In [130]:
def analyze_latents(batch_latents: Tensor, top_k_mean: int = 100, top_k_cnt: int = 100) -> Tuple[Tensor, Tensor, Tensor]:
    logging.info("Computing non-zero element counts")
    act_cnt = (batch_latents != 0).sum(dim=(0, 1))

    logging.info("Computing sum of non-zero elements")
    nz_sum = torch.where(batch_latents != 0, batch_latents, torch.tensor(0.0, device=batch_latents.device)).sum(dim=(0, 1))

    logging.info("Computing mean of non-zero elements")
    nz_mean = torch.where(act_cnt != 0, nz_sum / act_cnt, torch.tensor(0.0, device=batch_latents.device))

    logging.info("Selecting top-k indices based on nz_mean")
    nz_act_val, nz_val_indices = torch.topk(nz_mean, top_k_mean)
    logging.info(f"Top {top_k_mean} nz_mean values selected.")

    logging.info("Selecting top-k indices based on act_cnt")
    nz_cnt, cnt_indices = torch.topk(act_cnt, top_k_cnt)
    logging.info(f"Top {top_k_cnt} act_cnt values selected.")

    # logging.info("Finding overlapping indices between nz_mean and act_cnt top-k")
    # overlap_mask = torch.isin(nz_val_indices, cnt_indices)
    # overlap_indices = nz_val_indices[overlap_mask]
    # logging.info(f"Number of overlapping indices: {len(overlap_indices)}")
    # overlap_indices=overlap_indices
    return nz_mean, act_cnt, cnt_indices

In [131]:
# Select a dataset steer based on steering preference

def get_activation_by_steer(texts:list):

    hook_point = sae.cfg.hook_name

    # Compute latents with batch processing
    batch_latents = compute_latents(sae, model, texts, hook_point, args.device, args.batch_size)
    # 计算第二个维度的最大值
    max_dim1 = max(latent.shape[1] for latent in batch_latents)  # 第二个维度的最大值
    logging.info(f"最大长度:{max_dim1}")
    # 对每个 Tensor 进行填充（仅填充第二个维度）
    padded_latents_right = [
        torch.nn.functional.pad(latent, (0, 0, 0, max_dim1 - latent.size(1)), "constant", 0)
        for latent in batch_latents
    ]

    batch_latents_concatenated = torch.cat(padded_latents_right, dim=0)
    logging.info(f"Concatenated batch latents shape: {batch_latents_concatenated.shape}")

    # Analyze latents
    nz_mean, act_cnt, overlap_indices = analyze_latents(batch_latents_concatenated, top_k_mean=args.topk_mean, top_k_cnt=args.topk_cnt)
    return {"nz_mean":nz_mean,"nz_cnt":act_cnt}

In [132]:
args.steer=args.steer.lower()

In [133]:
# all_dataset["train"]["Q+A"][:args.data_size][193]

In [134]:
steer_info={}

In [135]:
steer_info={}
if args.steer=='polite-impolite' or args.steer=="pos-neg":
    text=pos_train_set["text"][:args.data_size]
    steer_info["pos"]=get_activation_by_steer(text)
    text=neg_train_set["text"][:args.data_size]
    steer_info["neg"]=get_activation_by_steer(text)
    text=neu_train_set["text"][:args.data_size]
    steer_info["neu"]=get_activation_by_steer(text)

2025-01-12 16:11:14,548 [INFO] Running model with cache to obtain hidden states
Processing batches:   0%|          | 0/32 [00:00<?, ?it/s]2025-01-12 16:11:14,865 [INFO] Batch 1: Hidden states shape: torch.Size([32, 47, 768])
2025-01-12 16:11:14,865 [INFO] Encoding hidden states for batch 1
Processing batches:   3%|▎         | 1/32 [00:00<00:10,  2.90it/s]2025-01-12 16:11:15,192 [INFO] Batch 2: Hidden states shape: torch.Size([32, 51, 768])
2025-01-12 16:11:15,194 [INFO] Encoding hidden states for batch 2
Processing batches:   6%|▋         | 2/32 [00:00<00:09,  3.00it/s]2025-01-12 16:11:15,499 [INFO] Batch 3: Hidden states shape: torch.Size([32, 50, 768])
2025-01-12 16:11:15,501 [INFO] Encoding hidden states for batch 3
Processing batches:   9%|▉         | 3/32 [00:00<00:09,  3.07it/s]2025-01-12 16:11:15,826 [INFO] Batch 4: Hidden states shape: torch.Size([32, 52, 768])
2025-01-12 16:11:15,828 [INFO] Encoding hidden states for batch 4
Processing batches:  12%|█▎        | 4/32 [00:01<00:

In [136]:

if "cot" in args.steer:
    if args.steer in "cot-direct":
        texts=all_dataset["train"]["Q+A"][:args.data_size]
        print(type(texts))
        steer_info["direct"]=get_activation_by_steer(texts)
        
        texts=all_dataset["train"]["Q+COT_A"][:args.data_size]
        print(type(texts))
        steer_info["cot"]=get_activation_by_steer(texts)
        # print(texts[123])
    else:
        raise ValueError("????")

In [137]:
args.steer

'pos-neg'

In [138]:
pos_train_set["text"][:args.data_size]

["the ring is worth a look , if you do n't demand much more than a few cheap thrills from your halloween entertainment .",
 "highly irritating at first , mr. koury 's passive technique eventually begins to yield some interesting results .",
 'this off-putting french romantic comedy is sure to test severely the indulgence of fans of amélie .',
 "the santa clause 2 is a barely adequate babysitter for older kids , but i 've got to give it thumbs down .",
 'but fans of the show should not consider this a diss .',
 'and that holds true for both the movie and the title character played by brendan fraser .',
 'just watch bettany strut his stuff .',
 'the impulses that produced this project ... are commendable , but the results are uneven .',
 'an extraordinarily silly thriller .',
 "if legendary shlockmeister ed wood had ever made a movie about a vampire , it probably would look a lot like this alarming production , adapted from anne rice 's novel the vampire chronicles .",
 'some are fascina

In [139]:
len(pos_train_set["text"][:args.data_size])

1000

In [140]:
args.data_size

1000

In [107]:

# for steer_key in ["pos","neu","neg"]:
#     if steer_key == "pos":
#         selected_set = pos_train_set
#     elif steer_key == "neg":
#         selected_set = neg_train_set
#     elif steer_key=="neu":
#         selected_set = neu_train_set

#     texts = selected_set["text"][:args.data_size]
#     a=get_activation_by_steer(texts)
#     steer_info[steer_key]=a


2025-01-11 01:02:29,893 [INFO] Running model with cache to obtain hidden states
Processing batches:   0%|          | 0/32 [00:00<?, ?it/s]2025-01-11 01:02:30,152 [INFO] Batch 1: Hidden states shape: torch.Size([32, 44, 768])
2025-01-11 01:02:30,153 [INFO] Encoding hidden states for batch 1
Processing batches:   3%|▎         | 1/32 [00:00<00:09,  3.39it/s]2025-01-11 01:02:30,552 [INFO] Batch 2: Hidden states shape: torch.Size([32, 82, 768])
2025-01-11 01:02:30,554 [INFO] Encoding hidden states for batch 2
Processing batches:   6%|▋         | 2/32 [00:00<00:11,  2.57it/s]2025-01-11 01:02:30,979 [INFO] Batch 3: Hidden states shape: torch.Size([32, 64, 768])
2025-01-11 01:02:30,981 [INFO] Encoding hidden states for batch 3
Processing batches:   9%|▉         | 3/32 [00:01<00:11,  2.53it/s]2025-01-11 01:02:31,348 [INFO] Batch 4: Hidden states shape: torch.Size([32, 61, 768])
2025-01-11 01:02:31,350 [INFO] Encoding hidden states for batch 4
Processing batches:  12%|█▎        | 4/32 [00:01<00:

In [59]:
steer_info["pos"]

{'nz_mean': tensor([15.0542,  5.3331,  0.0000,  ...,  1.0033,  0.8486,  0.0000],
        grad_fn=<WhereBackward0>),
 'nz_cnt': tensor([ 38,   4,   0,  ..., 159,  10,   0])}

In [34]:
steer_info["pos"]["nz_mean"].shape

torch.Size([24576])

In [42]:
# steer_info["dif_neg_pos"]={"steer":"dif_neg_pos","nz_cnt":steer_info["pos"]["nz_cnt"]-steer_info["neg"]["nz_cnt"],"nz_mean":steer_info["pos"]["nz_mean"]-steer_info["neg"]["nz_mean"]}

In [43]:
# steer_info["dif_neg_pos_relu"]={"nz_cnt":torch.relu(steer_info["pos"]["nz_cnt"]-steer_info["neg"]["nz_cnt"]),"nz_mean":torch.relu(steer_info["pos"]["nz_mean"]-steer_info["neg"]["nz_mean"])}

In [55]:
# steer_info["dif_neg_pos_relu"],steer_info["dif_neg_pos"]

In [35]:
sourceource="cot"
target="direct"

In [141]:
source="pos"
target="neg"

In [142]:
torch.all((steer_info["pos"]["nz_mean"]-steer_info["neg"]["nz_mean"])==0)

tensor(False)

In [143]:
steer_info[f"dif_{target}-{source}_relu"]={"nz_cnt":torch.relu(steer_info[target]["nz_cnt"]-steer_info[source]["nz_cnt"]),"nz_mean":torch.relu(steer_info[target]["nz_mean"]-steer_info[source]["nz_mean"])}
top_k=100
steering_vectors,steer_indices=torch.topk(steer_info[f"dif_{target}-{source}_relu"]["nz_cnt"],top_k)

In [110]:

# steer_info[f"dif_{a}-{b}_relu"]={"nz_cnt":torch.relu(steer_info[a]["nz_cnt"]-steer_info[b]["nz_cnt"]),"nz_mean":torch.relu(steer_info[a]["nz_mean"]-steer_info[b]["nz_mean"])}
# top_k=100
# steering_vectors,steer_indices=torch.topk(steer_info[f"dif_{a}-{b}_relu"]["nz_cnt"],top_k)

In [144]:
# 假设 steer_info[f"dif_{b}-{a}_relu"]["nz_cnt"] 是一个 NumPy 数组
nz_cnt = steer_info[f"dif_{target}-{source}_relu"]["nz_cnt"]

# 先获取非零元素的索引
nz_indices = np.nonzero(nz_cnt)
torch.all(nz_cnt == 0)

tensor(False)

In [145]:
steer_info[f"dif_{target}-{source}_relu"]["nz_cnt"].shape

torch.Size([24576])

In [146]:
steering_vectors,steer_indices,

(tensor([976, 976, 950, 866, 803, 697, 670, 634, 618, 582, 576, 570, 556, 549,
         534, 533, 526, 523, 506, 471, 470, 462, 456, 456, 445, 416, 415, 390,
         377, 377, 370, 367, 366, 360, 348, 347, 344, 320, 312, 307, 303, 302,
         291, 289, 288, 287, 273, 268, 263, 263, 260, 257, 256, 255, 249, 241,
         237, 235, 233, 233, 232, 231, 230, 228, 224, 219, 216, 216, 216, 215,
         212, 210, 205, 204, 201, 201, 200, 195, 195, 194, 194, 189, 187, 186,
         183, 182, 179, 177, 176, 175, 174, 173, 173, 172, 169, 167, 166, 165,
         163, 163]),
 tensor([14899, 16053, 14950, 18318, 21264,  3857, 17462,  5318, 17539, 11413,
         17227,  3290, 23494,   970,  6103, 11779, 21409,  4075,  4283, 18398,
         18739,  8847,  1188, 20725, 16383, 21950,  4989, 18604, 22703, 14404,
         10516, 16101, 21762,  4978, 15489,  3639,  8272, 23000, 17560, 13611,
         10045, 10512,  4886,  8064,  8989,  5842,  5455,  8212, 14168, 10188,
         11185,  1167, 18722,  

In [147]:
steer_info[f"dif_{target}-{source}_relu"]["nz_mean"][steer_indices]

tensor([7.9328e-02, 3.8317e-02, 2.0146e-02, 2.3163e-02, 2.8212e-02, 0.0000e+00,
        5.1741e-02, 4.0420e-02, 0.0000e+00, 2.9353e-02, 2.1048e-02, 0.0000e+00,
        0.0000e+00, 1.4081e-02, 7.1191e-02, 4.8021e-02, 1.1438e-04, 0.0000e+00,
        4.0448e-02, 4.5769e-02, 7.2892e-02, 4.8401e-03, 0.0000e+00, 9.2183e-03,
        0.0000e+00, 3.1627e-02, 1.8980e-02, 2.9644e-02, 1.5619e-02, 1.5609e-02,
        0.0000e+00, 2.3832e-02, 2.7588e-02, 0.0000e+00, 4.8306e-02, 1.6971e-02,
        4.3642e-02, 0.0000e+00, 2.7315e-02, 2.3624e-02, 0.0000e+00, 0.0000e+00,
        2.6217e-02, 6.4675e-02, 6.5274e-03, 0.0000e+00, 6.5705e-02, 3.3928e-03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 7.1977e-03, 6.2665e-03,
        2.9225e-03, 7.8505e-02, 3.3161e-02, 0.0000e+00, 1.0959e-02, 6.2663e-02,
        3.1515e-02, 1.6875e-02, 3.7916e-02, 4.5186e-02, 2.7820e-02, 0.0000e+00,
        1.0531e-02, 1.1299e-01, 4.5587e-02, 0.0000e+00, 0.0000e+00, 3.3021e-02,
        8.4106e-02, 5.1196e-02, 0.0000e+

In [44]:
# steer_info["dif_-pos"]["nz_mean"][steer_indices]

In [148]:
steer_indices

tensor([14899, 16053, 14950, 18318, 21264,  3857, 17462,  5318, 17539, 11413,
        17227,  3290, 23494,   970,  6103, 11779, 21409,  4075,  4283, 18398,
        18739,  8847,  1188, 20725, 16383, 21950,  4989, 18604, 22703, 14404,
        10516, 16101, 21762,  4978, 15489,  3639,  8272, 23000, 17560, 13611,
        10045, 10512,  4886,  8064,  8989,  5842,  5455,  8212, 14168, 10188,
        11185,  1167, 18722,  9082, 12797,   222, 18146,  4163,  7417,  8896,
         5203, 13252, 21943, 11859,  2282, 13817, 12488, 23355, 20830,  3219,
         6861, 24114, 16460,  5303,  9148, 12460,  1374, 10526,  5741,  7467,
         6545, 22906, 16318, 17732,  4173,   502, 22288, 23847, 17504,  2211,
        15256,  4063,  4994, 15262,  5330, 10282, 14954, 14440,  7383,    98])

In [149]:

def compute_steering_vectors(sae: SAE, indices: Tensor, nz_mean_val: Tensor, method: str = "val_mul") -> Tensor:
    logging.info(f"Computing steering vectors using method: {method}")
    if method == "mean":
        steering_vectors = torch.mean(sae.W_dec[indices], dim=0)
    elif method == "val_mul":
        steering_vectors = torch.zeros(sae.W_dec.shape[1], device=sae.W_dec.device)
        for idx in indices:
            steering_vectors += nz_mean_val[idx].item() * sae.W_dec[idx]
    else:
        raise ValueError(f"Unknown method: {method}")
    logging.info(f"Steering vectors computed with shape: {steering_vectors.shape}")
    return steering_vectors
steering_vectors=compute_steering_vectors(sae,indices=steer_indices,nz_mean_val=steer_info[f"dif_{target}-{source}_relu"]["nz_mean"],method="val_mul")

2025-01-12 16:12:50,606 [INFO] Computing steering vectors using method: val_mul
2025-01-12 16:12:50,619 [INFO] Steering vectors computed with shape: torch.Size([768])


In [64]:
# model.to_tokens("<|endoftext|>")

tensor([[50256, 50256]])

In [61]:
model.tokenizer.eos_token

'<|endoftext|>'

In [62]:
model.tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}
)

In [63]:
steering_vectors

tensor([ 1.7328e-03,  8.2551e-04, -8.3269e-02,  1.1761e-01, -6.3748e-02,
        -9.0212e-02,  3.1339e-02,  2.6297e-02, -6.2532e-02,  1.3236e-02,
        -6.7088e-02, -1.0789e-01, -7.6554e-02,  2.4443e-02,  2.5070e-02,
         2.2051e-02, -1.0301e-02, -5.5755e-02,  4.9618e-04, -7.1885e-02,
         6.9208e-02, -1.1011e-01, -2.9188e-02,  1.8490e-01, -1.2785e-02,
        -2.4185e-02,  6.0346e-03, -5.4809e-02, -1.2871e-02, -8.0980e-02,
        -5.4031e-02,  7.6232e-02, -6.8487e-03, -1.2574e-01,  9.2210e-02,
         7.7478e-02,  6.6134e-02,  4.0494e-02, -4.9254e-02,  2.3767e-02,
        -3.4099e-02, -7.1602e-02,  9.6972e-02, -8.8507e-03, -1.1878e-02,
        -1.4549e-01,  3.3235e-02, -5.8029e-02, -6.2693e-02, -3.2039e-02,
         3.3839e-02, -2.2269e-02, -5.3156e-02, -2.2298e-02,  1.7325e-02,
         1.0806e-02, -1.5715e-02, -2.1478e-02,  5.8923e-02,  2.5012e-02,
         1.6772e-02,  6.4279e-02,  2.0084e-04, -1.0122e-01,  1.7733e-01,
         6.3427e-03,  4.2531e-02,  1.6085e-02,  2.1

steering_vectors
* 这里得到的就是delta_matricx

In [152]:
# # Define steering hook
# steering_on = True  # This will be toggled in run_generate
# alpha = args.alpha
# method = args.method  # Store method for clarity

def steering_hook(resid_pre, hook):
    if resid_pre.shape[1] == 1:
        return
    logging.info(f"hook is {hook}\n")
    if steering_on:
        resid_pre[:, :-1, :] += args.alpha * steering_vectors
        # 这里是干预调整的关键，很有意思的是，如果提前干预效果会更好更连贯，同样加上正弦波之类的效果也很棒
        #注意这里是对batch_size,0:-9,hidden_size这样的隐藏层做干预

def hooked_generate(prompt_batch, fwd_hooks=[], seed=None, **kwargs):
    if seed is not None:
        torch.manual_seed(seed)

    with model.hooks(fwd_hooks=fwd_hooks):
        tokenized = model.to_tokens(prompt_batch,prepend_bos=False)
        result = model.generate(
            stop_at_eos=True,  # avoids a bug on MPS
            input=tokenized,
            max_new_tokens=50,
            do_sample=True,
            **kwargs,
        )
    return result

def run_generate(example_prompt,sampling_kwargs):
    model.reset_hooks()
    editing_hooks = [(f"blocks.{args.layer}.hook_resid_post", steering_hook)]
    res = hooked_generate(
        [example_prompt] * 3, editing_hooks, seed=None, **sampling_kwargs
    )

    # Print results, removing the ugly beginning of sequence token
    res_str = model.to_string(res[:, :])
    # print(("\n\n" + "-" * 80 + "\n\n").join(res_str))
    # generated_texts = res_str
    for idx, text in enumerate(res_str):
        logging.info(f"Generated Text: {idx+1}: {text}")
        
    return res_str

# Define sampling parameters
sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)

# Example prompt from the selected set
example_prompt = " Hei, you are so"
# example_prompt=model.tokenizer
# print( all_dataset["val"]["prompt"][:3],all_dataset["val"]["A"][:3])
logging.info(f"Example prompt: {example_prompt}")

# Generate without steering
steering_on = False
alpha = 0
logging.info("Generating texts **without** steering... ")
generated_texts_no_steer = run_generate(example_prompt, sampling_kwargs)
logging.info("干预之后的结果")
# bef,aft=args.steer.split("-")
logging.info(f"干预方向{source}->{target},礼貌任务下，neg=impolite，情感任务下 pos=积极情感")
# Generate with steering
steering_on = True
alpha = 100
logging.info("** Generating texts with steering... Target **")
generated_texts_with_steer = run_generate(example_prompt, sampling_kwargs)

# Combine generated texts
# all_generated_texts = generated_texts_no_steer + generated_texts_with_steer



2025-01-12 16:20:06,953 [INFO] Example prompt:  Hei, you are so
2025-01-12 16:20:06,957 [INFO] Generating texts **without** steering... 


  0%|          | 0/50 [00:00<?, ?it/s]

2025-01-12 16:20:07,025 [INFO] hook is HookPoint()

2025-01-12 16:20:08,945 [INFO] Generated Text: 1:  Hei, you are so much better than me.

I am a man of the people. I am not a slave to anyone. I have no power over anyone else's fate.

I will never be like you, and I will never be like you again
2025-01-12 16:20:08,946 [INFO] Generated Text: 2:  Hei, you are so much better than me.

I'm not sure if I can say that I am the best person in the world. But I am a very good person. And it's because of that that I have been able to be a better person than
2025-01-12 16:20:08,946 [INFO] Generated Text: 3:  Hei, you are so smart. You're so clever. You're so smart."

"I'm not stupid," he said, "but I'm not stupid enough to think that I can do anything about it."

"You know what? It's a
2025-01-12 16:20:08,946 [INFO] 干预之后的结果
2025-01-12 16:20:08,947 [INFO] 干预方向pos->neg,礼貌任务下，neg=impolite，情感任务下 pos=积极情感
2025-01-12 16:20:08,947 [INFO] ** Generating texts with steering... Target **


  0%|          | 0/50 [00:00<?, ?it/s]

2025-01-12 16:20:08,979 [INFO] hook is HookPoint()

2025-01-12 16:20:10,853 [INFO] Generated Text: 1:  Hei, you are so much more than a human being.

I'm sorry, but I can't help it. I don't want to hear about your feelings anymore.

I'm sorry, but I can't help it either. You're not my friend
2025-01-12 16:20:10,854 [INFO] Generated Text: 2:  Hei, you are so stupid. You're a fucking idiot.

"I'm not going to let you get away with this," she said, "because I know that if I do, it will be for the best."

"You're not going to
2025-01-12 16:20:10,854 [INFO] Generated Text: 3:  Hei, you are so much more than a good friend.

"I'm sorry, but I don't know what to do with you."

"You're not my friend," said the man. "I'm just a person who loves you and wants to


# 理论上来讲，
* 不礼貌的输出应该有很多疑问句？例如what？ ha？why？
* 而礼貌的输出应该有很多正常的词语
* 积极情感和不积极情感同理
* 从目前的实验来看，负向情感干预+礼貌情感干预表现比较好，可以拿这个做可解释性，我选取的latents选了前100频次的激活神经元

## 下面是评估代码
* 利用prompt数据集中的prompt进行扭转


In [100]:
args.prompt_path

AttributeError: 'Namespace' object has no attribute 'prompt_path'

['Q: The mayor commissioned two artists to paint 50 murals around the city. Once the work was completed, Celina was paid $1,000 more than 4 times the amount Diego got. If the mayor paid the two a total of $50,000, how much did Diego get?\nA: ',
 "Q: Adrianna has 10 pieces of gum to share with her friends. There wasn't enough gum for all her friends, so she went to the store to get 3 more pieces of gum. She gave out gum to 11 friends. How many pieces of gum does Adrianna have now?\nA: ",
 'Q: Paul needed to buy some new clothes for work.  He had a 10% off coupon that he could use on his entire purchase after any other discounts.  Paul bought 4 dress shirts at $15.00 apiece, 2 pairs of pants that each cost $40.00.  He found a suit for $150.00 and 2 sweaters for $30.00 each.  When he got to the register, the clerk told him that the store was offering 20% off of everything in the store.  After the discounts and the coupon, how much did Paul spend on his new clothes?\nA: ',
 'Q: Donny saves

In [138]:
all_dataset["val"]["A"][:10]

['9800', '2', '252', '28', '5', '16', '16', '550', '26', '107']

In [None]:
# Save results
save_results(
    output_dir=output_dir_base,
    nz_mean=nz_mean,
    act_cnt=act_cnt,
    generated_texts=all_generated_texts,
    hyperparams=hyperparams
)