In [13]:
# Imports and Setup
import argparse
import os
import logging
from typing import Tuple
import json
from log import setup_logging
from tqdm import tqdm  # For progress bars


In [14]:

import torch
from torch import Tensor
from transformer_lens import HookedTransformer
from sae_lens import SAE
# from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list
# from sae_lens.analysis.feature_statistics import (
#     get_all_stats_dfs,
#     get_W_U_W_dec_stats_df,
# )
# from sae_lens.analysis.tsea import (
#     get_enrichment_df,
#     manhattan_plot_enrichment_scores,
#     plot_top_k_feature_projections_by_token_and_category,
#     get_baby_name_sets,
#     get_letter_gene_sets,
#     generate_pos_sets,
#     get_test_gene_sets,
#     get_gene_set_from_regex,
# )
from datasets import load_dataset
from dotenv import load_dotenv
import numpy as np
# import plotly_express as px

# 进行基础情感干预实验

In [15]:
# Define hyperparameters
task="sentiment"
if task=="sentiment":
    args_dict = {
        "layer": 6,  # Example layer number to analyze
        "LLM": "gpt2-small",
        "dataset_path": "/home/ckqsudo/code2024/0dataset/baseline-acl/data/sentiment/sst5",
        "prompt_path":"/home/ckqsudo/code2024/0dataset/baseline-acl/data/prompts/sentiment_prompts-10k",
        "output_dir": "./results",
        "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
        "task":"sentiment",# “sentiment”,"cot","polite"
        "seed": 42,
        "data_size": 1000,
        "device": "cpu",  # Options: "cpu", "cuda", "mps", "auto"
        "alpha": 100, # 这个alpha后面慢慢调节
        "steer": "pos-neg",  # Options: "pos", "neg", "neu","pos-neg","cot-direct"
        "method": "val_mul",  # Options: "mean", "val_mul" 用val_mul会比较好
        "topk_mean": 100, # 选取前topk 个均值激活，这个效果一般，会导致很多如what？why？这种被激活
        "topk_cnt": 100, # 选取前topk个频率激活，目前默认这个，效果很好
        "batch_size": 32 # 这个好像没用上
    }

# 进行COT相关实验

In [36]:
# Define hyperparameters
if task=="cot":
    args_dict = {
        "layer": 6,  # Example layer number to analyze
        "LLM": "gpt2-small",
        "dataset_path": "/home/ckqsudo/code2024/0dataset/ACL_useful_dataset/math/COT_GSM8k",
        "output_dir": "./results",
        "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
        "seed": 42,
        "data_size": 1000,
        "device": "cpu",  # Options: "cpu", "cuda", "mps", "auto"
        "alpha": 100,
        "steer": "cot-direct",  # Options: "pos", "neg", "neu","pos-neg","cot-direct"
        "method": "val_mul",  # Options: "mean", "val_mul"
        "topk_mean": 100,
        "topk_cnt": 100,
        "batch_size": 32
    }

# 进行礼貌实验
/home/ckqsudo/code2024/0dataset/ACL_useful_dataset/style_transfer/politeness-corpus

In [37]:
# Define hyperparameters
if task=="polite":
    args_dict = {
        "layer": 6,  # Example layer number to analyze
        "LLM": "gpt2-small",
        "dataset_path": "/home/ckqsudo/code2024/0dataset/ACL_useful_dataset/style_transfer/politeness-corpus",
        "output_dir": "./results",
        "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
        "seed": 42,
        "data_size": 1000,
        "device": "cpu",  # Options: "cpu", "cuda", "mps", "auto"
        "alpha": 100,
        "steer": "polite-impolite",  # Options: "pos", "neg", "neu","pos-neg","cot-direct"
        "method": "val_mul",  # Options: "mean", "val_mul"
        "topk_mean": 100,
        "topk_cnt": 100,
        "batch_size": 32
    }

In [7]:
# 进行礼貌实验

In [16]:
# Configuration and Hyperparameters
# 将字典转换为 argparse.Namespace 对象
args = argparse.Namespace(**args_dict)
# 测试访问属性
print(args.layer)  # 输出: 10
print(args.LLM)  # 输出: gpt2-small
print(args.output_dir)  # 输出: ./results
print(args.steer)

6
gpt2-small
./results
pos-neg


In [17]:
# Logging Setup
import os
from log import setup_logging
import logging
# Create output directory base path
output_dir_base = os.path.join(
    args.output_dir,
    f"LLM_{args.LLM}_layer_{args.layer}_steer_{args.steer}_alpha_{args.alpha}_{task}_cnt_{args.topk_cnt}_mean{args.topk_mean}"
)

# Setup logging
setup_logging(output_dir_base)

# Save hyperparameters
hyperparams = args_dict

# Log hyperparameters
logging.info("Hyperparameters:")
for key, value in hyperparams.items():
    logging.info(f"  {key}: {value}")


2025-01-14 11:42:01,959 [INFO] Logging initialized. Logs will be saved to ./results/LLM_gpt2-small_layer_6_steer_pos-neg_alpha_100_sentiment_cnt_100_mean100/execution.log
2025-01-14 11:42:01,961 [INFO] Hyperparameters:
2025-01-14 11:42:01,962 [INFO]   layer: 6
2025-01-14 11:42:01,965 [INFO]   LLM: gpt2-small
2025-01-14 11:42:01,966 [INFO]   dataset_path: /home/ckqsudo/code2024/0dataset/baseline-acl/data/sentiment/sst5
2025-01-14 11:42:01,967 [INFO]   prompt_path: /home/ckqsudo/code2024/0dataset/baseline-acl/data/prompts/sentiment_prompts-10k
2025-01-14 11:42:01,969 [INFO]   output_dir: ./results
2025-01-14 11:42:01,970 [INFO]   env_path: /home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env
2025-01-14 11:42:01,971 [INFO]   task: sentiment
2025-01-14 11:42:01,972 [INFO]   seed: 42
2025-01-14 11:42:01,973 [INFO]   data_size: 1000
2025-01-14 11:42:01,974 [INFO]   device: cpu
2025-01-14 11:42:01,975 [INFO]   alpha: 100
2025-01-14 11:42:01,976 [INFO]   steer: pos-neg
2025-01-14 1

In [18]:
# Load Environment Variables
def load_environment(env_path: str):
    load_dotenv(env_path)
    hf_endpoint = os.getenv('HF_ENDPOINT', 'https://hf-mirror.com')
    logging.info(f"HF_ENDPOINT: {hf_endpoint}")
load_environment(args.env_path)


2025-01-14 11:42:05,802 [INFO] HF_ENDPOINT: https://hf-mirror.com


In [19]:
import re

def load_and_prepare_triple_dataset(dataset_path: str,dataset_name:str, seed: int, num_samples: int):
    """
    支持positive\neutral\negative三元组数据类型，例如 sst5，polite数据集和multi-class数据集

    Args:
        dataset_path (str): _description_
        dataset_name : "sst5","multiclass","polite"
        seed (int): _description_
        num_samples (int): _description_

    Returns:
        _type_: _description_
    """
    assert dataset_name in ["sst5","multiclass","polite"]
    if dataset_name in ["sst5"]:
        neu_label=2 # 中性情感对应的label
        assert "sst5" in dataset_path
    elif  dataset_name in ["polite","multiclass"]:
        neu_label=1
    logging.info(f"Loading dataset from ****{dataset_path}***")
    dataset = load_dataset(dataset_path)
    dataset["train"] = dataset['train'].shuffle(seed=seed)

    logging.info("Filtering dataset for negative, positive, and neutral samples")
    neg_train_set = dataset['train'].filter(lambda example: example['label'] < neu_label).select(range(num_samples))
    pos_train_set = dataset['train'].filter(lambda example: example['label'] == neu_label).select(range(num_samples))
    neu_train_set = dataset['train'].filter(lambda example: example['label'] > neu_label ).select(range(num_samples))

    logging.info(f"Selected {len(neg_train_set)} negative, {len(pos_train_set)} positive, and {len(neu_train_set)} neutral samples")
    print(dataset)
    if dataset_name in ["sst5"]:
        val_set=dataset['validation']
    else:
        raise ValueError("没写呢")
    test_set=dataset["test"]
    return neg_train_set, pos_train_set, neu_train_set,val_set,test_set
def load_and_prepare_COT_dataset(dataset_path:str,seed:int,num_samples:int):
    logging.info(f"Loading dataset from {dataset_path}")
    dataset = load_dataset(dataset_path)
    dataset["train"] = dataset['train'].shuffle(seed=seed)
    logging.info("Filtering dataset for COT")
    # 定义一个函数来提取答案
    def extract_answer(text):
        # 使用正则表达式提取答案
        match = re.search(r'#### ([-+]?\d*\.?\d+/?\d*)', text)
        if match:
            label=match.group(1)
            return label
        else:
            raise ValueError("Modify your re expression")
    def concat_QA(example,col1,col2,tag):
        combined = f"{example[col1]}{tag}{example[col2]}"  # 用空格拼接
        return combined
    def replace_col(example,col1,target,pattern):
        return example[col1].replace(target,pattern)
    # 应用函数并创建新列
    dataset = dataset.map(lambda example: {'A': extract_answer(example['response'])})
    dataset = dataset.map(lambda example: {'Q+A': concat_QA(example,"prompt","A","")})
    dataset = dataset.map(lambda example: {'Q+COT_A': concat_QA(example,"prompt","response","")})
    dataset = dataset.map(lambda example: {'Q+COT_A': replace_col(example,"Q+COT_A","#### ","")})
    # 查看处理后的数据集
    print("Q+A\n",dataset['train'][103]['Q+A'])
    print("Q+COT_A\n",dataset['train'][103]['Q+COT_A'])
    return dataset
    

In [20]:
args.steer

'pos-neg'

In [21]:
# Load and Prepare Dataset

logging.info("dataset path "+args.dataset_path)
if "neg" in args.steer or "pos" in args.steer and args.steer=="sentiment":
    neg_train_set, pos_train_set, neu_train_set,val_set,test_set = load_and_prepare_triple_dataset(
        args.dataset_path, "sst5",args.seed, args.data_size
    )
elif "cot" in args.steer or "COT" in args.steer:
    logging.info("COT "*10)
    all_dataset=load_and_prepare_COT_dataset(
        args.dataset_path, args.seed, args.data_size
    )
elif "polite" in args.steer:
    logging.info("polite"*10)
    neg_train_set, pos_train_set, neu_train_set,val_set,test_set=load_and_prepare_triple_dataset(args.dataset_path,"polite", args.seed, args.data_size)
else:
    raise ValueError("No Supported")


2025-01-14 11:42:15,348 [INFO] dataset path /home/ckqsudo/code2024/0dataset/baseline-acl/data/sentiment/sst5
2025-01-14 11:42:15,351 [INFO] Loading dataset from ****/home/ckqsudo/code2024/0dataset/baseline-acl/data/sentiment/sst5***
Repo card metadata block was not found. Setting CardData to empty.
2025-01-14 11:42:15,528 [INFO] Filtering dataset for negative, positive, and neutral samples
2025-01-14 11:42:15,542 [INFO] Selected 1000 negative, 1000 positive, and 1000 neutral samples


DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 8544
    })
    validation: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 1101
    })
    test: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 2210
    })
})


In [22]:
assert neg_train_set[10]!=pos_train_set[10]

In [12]:
pos_train_set[10],neg_train_set[10]

({'text': 'some are fascinating and others are not , and in the end , it is almost a good movie .',
  'label': 2,
  'label_text': 'neutral'},
 {'text': 'the lousy lead performances ... keep the movie from ever reaching the comic heights it obviously desired .',
  'label': 1,
  'label_text': 'negative'})

In [23]:


def compute_latents(sae: SAE, model: HookedTransformer, texts: list, hook_point: str, device: str, batch_size: int) -> list:
    """
    计算 latents，支持批次处理。

    Args:
        sae (SAE): SAE 实例。
        model (HookedTransformer): Transformer 模型实例。
        texts (list): 文本列表。
        hook_point (str): 钩子点名称。
        device (str): 计算设备。
        batch_size (int): 每个批次的大小。

    Returns:
        list: 包含每个批次 latents 的张量列表。
    """
    logging.info("Running model with cache to obtain hidden states")
    batch_latents = []

    # 使用 tqdm 显示进度条
    for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
        batch_texts = texts[i:i + batch_size]
        sv_logits, cache = model.run_with_cache(batch_texts, prepend_bos=False, device=device)
        batch_hidden_states = cache[hook_point]
        logging.info(f"Batch {i // batch_size + 1}: Hidden states shape: {batch_hidden_states.shape}")

        logging.info(f"Encoding hidden states for batch {i // batch_size + 1}")
        # 假设 sae.encode 支持批量编码
        latents = sae.encode(batch_hidden_states)  # 形状: (batch_size, latent_dim)
        batch_latents.append(latents)
        

    logging.info(f"Total batches processed: {len(batch_latents)}")
    return batch_latents

In [24]:

def compute_steering_vectors(sae: SAE, overlap_indices: Tensor, nz_mean: Tensor, method: str = "val_mul") -> Tensor:
    logging.info(f"Computing steering vectors using method: {method}")
    if method == "mean":
        steering_vectors = torch.mean(sae.W_dec[overlap_indices], dim=0)
    elif method == "val_mul":
        steering_vectors = torch.zeros(sae.W_dec.shape[1], device=sae.W_dec.device)
        for important_idx in overlap_indices:
            steering_vectors += nz_mean[important_idx].item() * sae.W_dec[important_idx]
    else:
        raise ValueError(f"Unknown method: {method}")
    logging.info(f"Steering vectors computed with shape: {steering_vectors.shape}")
    return steering_vectors


In [25]:

def save_results(output_dir: str, nz_mean: Tensor, act_cnt: Tensor, generated_texts: list, hyperparams: dict):
    os.makedirs(output_dir, exist_ok=True)
    # Save nz_mean and act_cnt
    nz_stats_path = os.path.join(output_dir, 'nz_stats.pt')
    logging.info(f"Saving nz_mean and act_cnt to {nz_stats_path}")
    torch.save({
        'nz_mean': nz_mean,
        'act_cnt': act_cnt,
    }, nz_stats_path)

    # Save generated texts
    generated_texts_path = os.path.join(output_dir, 'generated_texts.txt')
    logging.info(f"Saving generated texts to {generated_texts_path}")
    with open(generated_texts_path, 'w') as f:
        for text in generated_texts:
            f.write(text + "\n")

    # Save hyperparameters
    hyperparams_path = os.path.join(output_dir, 'hyperparameters.json')
    logging.info(f"Saving hyperparameters to {hyperparams_path}")
    with open(hyperparams_path, 'w') as f:
        json.dump(hyperparams, f, indent=4)

    logging.info("All results saved successfully.")

In [26]:
output_dir_base = os.path.join(
    args.output_dir,
    f"LLM_{args.LLM}_layer_{args.layer}_steer_{args.steer}_alpha_{args.alpha}_cnt_{args.topk_cnt}_mean{args.topk_mean}"
)
output_dir_base

'./results/LLM_gpt2-small_layer_6_steer_pos-neg_alpha_100_cnt_100_mean100'

In [28]:
def load_from_cache():
    cache_exists = False
    cache_file = os.path.join(output_dir_base, 'hyperparameters.json')
    if os.path.exists(cache_file):
        with open(cache_file, 'r') as f:
            cached_data = json.load(f)
        cached_hash = cached_data.get('hyperparams_hash')

    if cache_exists:
        # Load nz_mean and act_cnt from cache
        # nz_stats_path = os.path.join(output_dir_base, 'nz_stats.pt')
        # nz_act = torch.load(nz_stats_path)
        # nz_mean = nz_act['nz_mean']
        # act_cnt = nz_act['act_cnt']
        # overlap_indices = nz_act.get('overlap_indices', None)  # If overlap_indices was saved
        logging.info("load from cache")
    else:
        # overlap_indices = None  # Will be computed later
        logging.info("non cache: "+cache_file)
load_from_cache()

2025-01-14 11:43:03,515 [INFO] non cache: ./results/LLM_gpt2-small_layer_6_steer_pos-neg_alpha_100_cnt_100_mean100/hyperparameters.json


In [12]:
# setup_logging(output_dir_base)

In [29]:
# Save hyperparameters
hyperparams = vars(args)

# Log hyperparameters
logging.info("Hyperparameters:")
for key, value in hyperparams.items():
    logging.info(f"  {key}: {value}")

# Load environment
load_environment(args.env_path)

# Load model and SAE
logging.info(f"Loading model: {args.LLM}")
model = HookedTransformer.from_pretrained(args.LLM, device=args.device)

logging.info(f"Loading SAE for layer {args.layer}")
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id=f"blocks.{args.layer}.hook_resid_pre",
    device=args.device
)

2025-01-14 11:43:06,271 [INFO] Hyperparameters:
2025-01-14 11:43:06,273 [INFO]   layer: 6
2025-01-14 11:43:06,275 [INFO]   LLM: gpt2-small
2025-01-14 11:43:06,277 [INFO]   dataset_path: /home/ckqsudo/code2024/0dataset/baseline-acl/data/sentiment/sst5
2025-01-14 11:43:06,279 [INFO]   prompt_path: /home/ckqsudo/code2024/0dataset/baseline-acl/data/prompts/sentiment_prompts-10k
2025-01-14 11:43:06,280 [INFO]   output_dir: ./results
2025-01-14 11:43:06,283 [INFO]   env_path: /home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env
2025-01-14 11:43:06,285 [INFO]   task: sentiment
2025-01-14 11:43:06,287 [INFO]   seed: 42
2025-01-14 11:43:06,288 [INFO]   data_size: 1000
2025-01-14 11:43:06,290 [INFO]   device: cpu
2025-01-14 11:43:06,290 [INFO]   alpha: 100
2025-01-14 11:43:06,291 [INFO]   steer: pos-neg
2025-01-14 11:43:06,292 [INFO]   method: val_mul
2025-01-14 11:43:06,293 [INFO]   topk_mean: 100
2025-01-14 11:43:06,295 [INFO]   topk_cnt: 100
2025-01-14 11:43:06,297 [INFO]   batch_

Loaded pretrained model gpt2-small into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [128]:
# # Load dataset
# all_dataset = load_and_prepare_triple_dataset(
#     args.dataset_path, args.seed, args.data_size
# )

In [30]:
args.steer

'pos-neg'

In [31]:

def analyze_latents(batch_latents: Tensor, top_k_mean: int = 100, top_k_cnt: int = 100) -> Tuple[Tensor, Tensor, Tensor]:
    logging.info("Computing non-zero element counts")
    act_cnt = (batch_latents != 0).sum(dim=(0, 1))

    logging.info("Computing sum of non-zero elements")
    nz_sum = torch.where(batch_latents != 0, batch_latents, torch.tensor(0.0, device=batch_latents.device)).sum(dim=(0, 1))

    logging.info("Computing mean of non-zero elements")
    nz_mean = torch.where(act_cnt != 0, nz_sum / act_cnt, torch.tensor(0.0, device=batch_latents.device))

    logging.info("Selecting top-k indices based on nz_mean")
    nz_act_val, nz_val_indices = torch.topk(nz_mean, top_k_mean)
    logging.info(f"Top {top_k_mean} nz_mean values selected.")

    logging.info("Selecting top-k indices based on act_cnt")
    nz_cnt, cnt_indices = torch.topk(act_cnt, top_k_cnt)
    logging.info(f"Top {top_k_cnt} act_cnt values selected.")

    # logging.info("Finding overlapping indices between nz_mean and act_cnt top-k")
    # overlap_mask = torch.isin(nz_val_indices, cnt_indices)
    # overlap_indices = nz_val_indices[overlap_mask]
    # logging.info(f"Number of overlapping indices: {len(overlap_indices)}")
    # overlap_indices=overlap_indices
    return nz_mean, act_cnt,None

In [32]:
# Select a dataset steer based on steering preference

def get_activation_by_steer(texts:list):

    hook_point = sae.cfg.hook_name

    # Compute latents with batch processing
    batch_latents = compute_latents(sae, model, texts, hook_point, args.device, args.batch_size)
    # 计算第二个维度的最大值
    max_dim1 = max(latent.shape[1] for latent in batch_latents)  # 第二个维度的最大值
    logging.info(f"最大长度:{max_dim1}")
    # 对每个 Tensor 进行填充（仅填充第二个维度）
    padded_latents_right = [
        torch.nn.functional.pad(latent, (0, 0, 0, max_dim1 - latent.size(1)), "constant", 0)
        for latent in batch_latents
    ]

    batch_latents_concatenated = torch.cat(padded_latents_right, dim=0)
    logging.info(f"Concatenated batch latents shape: {batch_latents_concatenated.shape}")

    # Analyze latents 
    nz_mean, act_cnt, _ = analyze_latents(batch_latents_concatenated, top_k_mean=args.topk_mean, top_k_cnt=args.topk_cnt)
    return {"nz_mean":nz_mean,"nz_cnt":act_cnt}

In [33]:
args.steer=args.steer.lower()

In [133]:
# all_dataset["train"]["Q+A"][:args.data_size][193]

## 26000(SAE稀疏神经元)对应的非零激活神经元激活统计信息，和激活值统计信息


In [34]:
steer_info={}

In [35]:
steer_info={}
if args.steer=='polite-impolite' or args.steer=="pos-neg":
    logging.info(args.steer)
    text=pos_train_set["text"][:args.data_size]
    steer_info["pos"]=get_activation_by_steer(text)
    text=neg_train_set["text"][:args.data_size]
    steer_info["neg"]=get_activation_by_steer(text)
    text=neu_train_set["text"][:args.data_size]
    steer_info["neu"]=get_activation_by_steer(text)

2025-01-14 11:44:59,911 [INFO] pos-neg
2025-01-14 11:44:59,938 [INFO] Running model with cache to obtain hidden states
Processing batches:   0%|          | 0/32 [00:00<?, ?it/s]2025-01-14 11:45:00,312 [INFO] Batch 1: Hidden states shape: torch.Size([32, 47, 768])
2025-01-14 11:45:00,313 [INFO] Encoding hidden states for batch 1
Processing batches:   3%|▎         | 1/32 [00:00<00:12,  2.39it/s]2025-01-14 11:45:00,741 [INFO] Batch 2: Hidden states shape: torch.Size([32, 51, 768])
2025-01-14 11:45:00,743 [INFO] Encoding hidden states for batch 2
Processing batches:   6%|▋         | 2/32 [00:00<00:12,  2.44it/s]2025-01-14 11:45:01,054 [INFO] Batch 3: Hidden states shape: torch.Size([32, 50, 768])
2025-01-14 11:45:01,056 [INFO] Encoding hidden states for batch 3
Processing batches:   9%|▉         | 3/32 [00:01<00:10,  2.73it/s]2025-01-14 11:45:01,380 [INFO] Batch 4: Hidden states shape: torch.Size([32, 52, 768])
2025-01-14 11:45:01,382 [INFO] Encoding hidden states for batch 4
Processing ba

In [36]:

if "cot" in args.steer:
    if args.steer in "cot-direct":
        texts=all_dataset["train"]["Q+A"][:args.data_size]
        print(type(texts))
        steer_info["direct"]=get_activation_by_steer(texts)
        
        texts=all_dataset["train"]["Q+COT_A"][:args.data_size]
        print(type(texts))
        steer_info["cot"]=get_activation_by_steer(texts)
        # print(texts[123])
    else:
        raise ValueError("????")

In [37]:
args.steer

'pos-neg'

In [51]:
# pos_train_set["text"][:args.data_size]

In [38]:
len(pos_train_set["text"][:args.data_size])

1000

In [39]:
args.data_size

1000

In [22]:

# for steer_key in ["pos","neu","neg"]:
#     if steer_key == "pos":
#         selected_set = pos_train_set
#     elif steer_key == "neg":
#         selected_set = neg_train_set
#     elif steer_key=="neu":
#         selected_set = neu_train_set

#     texts = selected_set["text"][:args.data_size]
#     a=get_activation_by_steer(texts)
#     steer_info[steer_key]=a


In [40]:
steer_info["pos"]

{'nz_mean': tensor([ 6.6345, 11.8903,  4.9516,  ...,  0.9575,  1.0143,  0.0000],
        grad_fn=<WhereBackward0>),
 'nz_cnt': tensor([ 7,  6,  1,  ..., 64, 29,  0])}

In [41]:
steer_info["pos"]["nz_mean"].shape

torch.Size([24576])

In [42]:
# steer_info["dif_neg_pos"]={"steer":"dif_neg_pos","nz_cnt":steer_info["pos"]["nz_cnt"]-steer_info["neg"]["nz_cnt"],"nz_mean":steer_info["pos"]["nz_mean"]-steer_info["neg"]["nz_mean"]}

In [43]:
# steer_info["dif_neg_pos_relu"]={"nz_cnt":torch.relu(steer_info["pos"]["nz_cnt"]-steer_info["neg"]["nz_cnt"]),"nz_mean":torch.relu(steer_info["pos"]["nz_mean"]-steer_info["neg"]["nz_mean"])}

In [55]:
# steer_info["dif_neg_pos_relu"],steer_info["dif_neg_pos"]

In [62]:
sourceource="cot"
target="direct"

In [66]:
source="pos"
target="neg"
# 调整样本正负性在这里调整 从负样本到正样本还是从正样本()到负样本
# pos 代表积极情绪
# neg 代表消极情绪


In [43]:
assert bool(torch.all((steer_info["pos"]["nz_mean"]-steer_info["neg"]["nz_mean"])==0))==False,"数据库读取有问题"

In [46]:
assert torch.all(steer_info[target]["nz_mean"]>=0),"所有SAE的激活需要大于d等于0（maybe）"

# 泽凯在这里做mask

In [67]:
# 
steer_info[f"dif_{target}-{source}_relu"]={"nz_cnt":torch.relu(steer_info[target]["nz_cnt"]-steer_info[source]["nz_cnt"]),"nz_mean":torch.relu(steer_info[target]["nz_mean"]-steer_info[source]["nz_mean"]),"target_nz_mean":torch.relu(steer_info[target]["nz_mean"])}
"""
nz_cnt: 神经元被激活的次数
nz_mean: 神经元被激活后的平均值
nz_mean_pos: 正样本神经元被激活后的平均值
"""
top_k=100
_,steer_indices=torch.topk(steer_info[f"dif_{target}-{source}_relu"]["nz_cnt"],top_k)

In [110]:

# steer_info[f"dif_{a}-{b}_relu"]={"nz_cnt":torch.relu(steer_info[a]["nz_cnt"]-steer_info[b]["nz_cnt"]),"nz_mean":torch.relu(steer_info[a]["nz_mean"]-steer_info[b]["nz_mean"])}
# top_k=100
# steering_vectors,steer_indices=torch.topk(steer_info[f"dif_{a}-{b}_relu"]["nz_cnt"],top_k)

In [68]:
# 假设 steer_info[f"dif_{b}-{a}_relu"]["nz_cnt"] 是一个 NumPy 数组
nz_cnt = steer_info[f"dif_{target}-{source}_relu"]["nz_cnt"]

# 先获取非零元素的索引
nz_indices = np.nonzero(nz_cnt)
torch.all(nz_cnt == 0)

tensor(False)

In [69]:
steer_info[f"dif_{target}-{source}_relu"]["nz_cnt"].shape

torch.Size([24576])

In [70]:
_,steer_indices,

(tensor([976, 976, 950, 866, 803, 697, 670, 634, 618, 582, 576, 570, 556, 549,
         534, 533, 526, 523, 506, 471, 470, 462, 456, 456, 445, 416, 415, 390,
         377, 377, 370, 367, 366, 360, 348, 347, 344, 320, 312, 307, 303, 302,
         291, 289, 288, 287, 273, 268, 263, 263, 260, 257, 256, 255, 249, 241,
         237, 235, 233, 233, 232, 231, 230, 228, 224, 219, 216, 216, 216, 215,
         212, 210, 205, 204, 201, 201, 200, 195, 195, 194, 194, 189, 187, 186,
         183, 182, 179, 177, 176, 175, 174, 173, 173, 172, 169, 167, 166, 165,
         163, 163]),
 tensor([14899, 16053, 14950, 18318, 21264,  3857, 17462,  5318, 17539, 11413,
         17227,  3290, 23494,   970,  6103, 11779, 21409,  4075,  4283, 18398,
         18739,  8847,  1188, 20725, 16383, 21950,  4989, 18604, 22703, 14404,
         10516, 16101, 21762,  4978, 15489,  3639,  8272, 23000, 17560, 13611,
         10045, 10512,  4886,  8064,  8989,  5842,  5455,  8212, 14168, 10188,
         11185,  1167, 18722,  

In [71]:
steer_indices

tensor([14899, 16053, 14950, 18318, 21264,  3857, 17462,  5318, 17539, 11413,
        17227,  3290, 23494,   970,  6103, 11779, 21409,  4075,  4283, 18398,
        18739,  8847,  1188, 20725, 16383, 21950,  4989, 18604, 22703, 14404,
        10516, 16101, 21762,  4978, 15489,  3639,  8272, 23000, 17560, 13611,
        10045, 10512,  4886,  8064,  8989,  5842,  5455,  8212, 14168, 10188,
        11185,  1167, 18722,  9082, 12797,   222, 18146,  4163,  7417,  8896,
         5203, 13252, 21943, 11859,  2282, 13817, 12488, 23355, 20830,  3219,
         6861, 24114, 16460,  5303,  9148, 12460,  1374, 10526,  5741,  7467,
         6545, 22906, 16318, 17732,  4173,   502, 22288, 23847, 17504,  2211,
        15256,  4063,  4994, 15262,  5330, 10282, 14954, 14440,  7383,    98])

In [72]:
steer_info[f"dif_{target}-{source}_relu"]["nz_mean"][steer_indices]# 这里有0,没有负数比较正常


tensor([7.9328e-02, 3.8317e-02, 2.0146e-02, 2.3163e-02, 2.8212e-02, 0.0000e+00,
        5.1741e-02, 4.0420e-02, 0.0000e+00, 2.9353e-02, 2.1048e-02, 0.0000e+00,
        0.0000e+00, 1.4081e-02, 7.1191e-02, 4.8021e-02, 1.1438e-04, 0.0000e+00,
        4.0448e-02, 4.5769e-02, 7.2892e-02, 4.8401e-03, 0.0000e+00, 9.2183e-03,
        0.0000e+00, 3.1627e-02, 1.8980e-02, 2.9644e-02, 1.5619e-02, 1.5609e-02,
        0.0000e+00, 2.3832e-02, 2.7588e-02, 0.0000e+00, 4.8306e-02, 1.6971e-02,
        4.3642e-02, 0.0000e+00, 2.7315e-02, 2.3624e-02, 0.0000e+00, 0.0000e+00,
        2.6217e-02, 6.4675e-02, 6.5274e-03, 0.0000e+00, 6.5705e-02, 3.3928e-03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 7.1977e-03, 6.2665e-03,
        2.9225e-03, 7.8505e-02, 3.3161e-02, 0.0000e+00, 1.0959e-02, 6.2663e-02,
        3.1515e-02, 1.6875e-02, 3.7916e-02, 4.5186e-02, 2.7820e-02, 0.0000e+00,
        1.0531e-02, 1.1299e-01, 4.5587e-02, 0.0000e+00, 0.0000e+00, 3.3021e-02,
        8.4106e-02, 5.1196e-02, 0.0000e+

In [44]:
# steer_info["dif_-pos"]["nz_mean"][steer_indices]

In [73]:
steer_indices

tensor([14899, 16053, 14950, 18318, 21264,  3857, 17462,  5318, 17539, 11413,
        17227,  3290, 23494,   970,  6103, 11779, 21409,  4075,  4283, 18398,
        18739,  8847,  1188, 20725, 16383, 21950,  4989, 18604, 22703, 14404,
        10516, 16101, 21762,  4978, 15489,  3639,  8272, 23000, 17560, 13611,
        10045, 10512,  4886,  8064,  8989,  5842,  5455,  8212, 14168, 10188,
        11185,  1167, 18722,  9082, 12797,   222, 18146,  4163,  7417,  8896,
         5203, 13252, 21943, 11859,  2282, 13817, 12488, 23355, 20830,  3219,
         6861, 24114, 16460,  5303,  9148, 12460,  1374, 10526,  5741,  7467,
         6545, 22906, 16318, 17732,  4173,   502, 22288, 23847, 17504,  2211,
        15256,  4063,  4994, 15262,  5330, 10282, 14954, 14440,  7383,    98])

In [75]:
# mean_type="dif_mean"
mean_type="tar_mean"
def compute_steering_vectors(sae: SAE, indices: Tensor, nz_mean_val: Tensor, method: str = "val_mul") -> Tensor:
    logging.info(f"Computing steering vectors using method: {method}")
    if method == "mean":
        steering_vectors = torch.mean(sae.W_dec[indices], dim=0)
    elif method == "val_mul":
        steering_vectors = torch.zeros(sae.W_dec.shape[1], device=sae.W_dec.device)
        for idx in indices:
            steering_vectors += nz_mean_val[idx].item() * sae.W_dec[idx]
    else:
        raise ValueError(f"Unknown method: {method}")
    logging.info(f"Steering vectors computed with shape: {steering_vectors.shape}")
    return steering_vectors
if mean_type=="dif_mean":
    delta_matrix=compute_steering_vectors(sae,indices=steer_indices,nz_mean_val=steer_info[f"dif_{target}-{source}_relu"]["nz_mean"],method="val_mul")
else:
    delta_matrix=compute_steering_vectors(sae,indices=steer_indices,nz_mean_val=steer_info[f"dif_{target}-{source}_relu"]["target_nz_mean"],method="val_mul")

2025-01-14 11:57:51,496 [INFO] Computing steering vectors using method: val_mul
2025-01-14 11:57:51,508 [INFO] Steering vectors computed with shape: torch.Size([768])


In [64]:
# model.to_tokens("<|endoftext|>")

tensor([[50256, 50256]])

In [76]:
model.tokenizer.eos_token

'<|endoftext|>'

In [77]:
model.tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}
)

In [198]:
delta_matrix #理论上这里有正有负比较正常

tensor([-3.6214e-01,  1.0613e+00, -2.0561e+00,  7.6985e-02, -2.2114e-01,
        -1.1145e+00,  1.0092e-01,  3.1080e+00, -1.4369e+00, -4.4066e-02,
        -2.4463e-01, -5.4689e-01, -3.2318e-01,  2.2005e-01, -1.0918e+00,
        -3.3018e-02, -6.1457e-03, -4.1507e-01,  9.3305e-01,  1.2855e+00,
         1.3368e+00, -6.3565e-01,  4.4228e-02,  9.8607e-01, -2.0987e-01,
         4.9163e-01, -6.4610e-02, -1.7473e+00, -5.9216e-01, -5.9732e-01,
        -1.3186e+00, -9.7763e-01,  1.6091e-01, -1.6140e+00,  9.4070e-01,
         2.3314e+00,  6.6513e-01,  4.7095e-01, -2.1278e+00,  7.3241e-01,
         9.2688e-02, -3.1276e-01, -3.0850e-01,  1.1526e-01, -4.6604e-01,
         2.2835e-02,  3.8812e-01, -8.7149e-01, -1.2737e-01, -4.7759e-01,
        -9.2506e-01, -1.7304e-01, -1.0689e+00,  9.0440e-01, -1.1013e-01,
         2.5222e-01,  8.8352e-01, -1.8106e-02, -1.2686e+00, -4.0374e-01,
         7.9995e-01,  8.5598e-01, -4.1962e-01,  3.7982e-01,  1.8890e-01,
         5.1330e-01, -8.1306e-01,  3.2115e-02, -7.1

# 这里得到的就是delta_matricx

In [86]:
sae.cfg.hook_name

'blocks.6.hook_resid_pre'

In [90]:
f"blocks.{args.layer}.hook_resid_post"

'blocks.6.hook_resid_post'

In [213]:
import torch.nn.functional as F
import torch

def half_gaussian_kernel(half_len):
    # 设置均值和标准差
    cov_len=2*half_len
    mean = cov_len // 2  # 正态分布的均值
    std = cov_len / 6    # 设置标准差，可以根据需要调整

    # 创建正态分布
    x = torch.arange(cov_len, dtype=torch.float32)
    kernel = torch.exp(-0.5 * ((x - mean) / std) ** 2)
    print(kernel)

    # 仅保留正态分布的前半部分（右侧值设置为0）
    kernel[int(cov_len // 2):] = 0  # 保留前半部分，右半部分置为零

    # 归一化，确保总和为 1
    kernel = kernel / kernel.sum()
    return kernel[:half_len]

gauss=half_gaussian_kernel(4)
gauss
# k_gau=torch.cat([gauss, torch.tensor([0])])

tensor([0.0111, 0.0796, 0.3247, 0.7548, 1.0000, 0.7548, 0.3247, 0.0796])


tensor([0.0095, 0.0680, 0.2774, 0.6451])

In [214]:
gauss

tensor([0.0095, 0.0680, 0.2774, 0.6451])

In [215]:
import einops
test=torch.ones(2, 6,3)
# test=einops.rearrange(test,"b s d->b d s")
test.shape,k_gau.shape

(torch.Size([2, 6, 3]), torch.Size([6]))

In [216]:
re_gau=einops.repeat(k_gau,"s -> b s h",b=2,h=3)

In [217]:
test*re_gau,test

(tensor([[[0.0070, 0.0070, 0.0070],
          [0.0354, 0.0354, 0.0354],
          [0.1247, 0.1247, 0.1247],
          [0.3067, 0.3067, 0.3067],
          [0.5263, 0.5263, 0.5263],
          [0.0000, 0.0000, 0.0000]],
 
         [[0.0070, 0.0070, 0.0070],
          [0.0354, 0.0354, 0.0354],
          [0.1247, 0.1247, 0.1247],
          [0.3067, 0.3067, 0.3067],
          [0.5263, 0.5263, 0.5263],
          [0.0000, 0.0000, 0.0000]]]),
 tensor([[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
 
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]))

In [247]:
# # Define steering hook
# steering_on = True  # This will be toggled in run_generate
# alpha = args.alpha
# method = args.method  # Store method for clarity
from functools import partial
steer_cnt=0


def steering_hook(resid_pre, hook, steer_on, alpha, steer_type="last"):
    # 如果 seq_len 只有 1，则直接返回，不进行操作
    d_m=torch.clone(delta_matrix)
    s = resid_pre[:, :-1, :].shape[1]
    b=resid_pre[:, :-1, :].shape[0]
    h=resid_pre[:, :-1, :].shape[2]
    if resid_pre.shape[1] == 1:
        return
    # 判断是否进行干预
    if steer_on:
        if steer_type == "last":
            # 对最后一个token前的部分应用干预，使用给定的 delta_matrix
            d_m_repeat=einops.repeat(d_m,"h -> b s h",b=b,s=s)
            resid_pre[:, :-1, :] += alpha * d_m_repeat
            
            logging.info(f"干预类型：last")
            logging.info(f"干预矩阵: {alpha * d_m_repeat}")
        elif steer_type == "gaussian":
            # 使用高斯卷积对输入进行干预
            # s_idx=-1

            h_gauss = half_gaussian_kernel(s)  # 获取高斯卷积核
            # k_gauss=torch.cat([h_gauss, torch.tensor([0])])
            k_gauss=h_gauss
            k_gau_repeat=einops.repeat(k_gauss,"s -> b s h",b=b,h=h)
            d_m_repeat=einops.repeat(d_m,"h -> b s h",b=b,s=s)
            # 根据卷积结果更新 resid_pre（注意：保留其他维度不变）,逐一元素相乘
            resid_pre[:, :-1, :] += alpha * d_m_repeat* k_gau_repeat
            logging.info(f"干预类型：高斯")
            logging.info(f"干预矩阵: {alpha * d_m_repeat* k_gau_repeat}")
        else:
            raise ValueError("Unknown steering type")


        # elif steer_type=="last2":
        #     resid_pre[:, :-2, :] += args.alpha * steering_vectors
        # elif steer_type=="gaussian":
        #     # 高斯卷积的方式放缩干预矩阵，
        #     # 这里需要一个高斯核，然后对steering_vectors进行卷积
        #     gaussian_kernel = torch.tensor([1,2,1])
        #     steering_vectors = torch.conv1d(steering_vectors, gaussian_kernel, padding=1)
        #     resid_pre[:, :-1, :] += args.alpha * steering_vectors
        # else:
        #     raise ValueError("Unknown steering type")
        # 修改这里的干预方式，增加干预的选择，例如从倒数第一个token开始干预，或者从倒数第二个token开始干预，或者使用高斯卷积的方式放缩干预矩阵，这里是干预调整的关键，很有意思的是，如果提前干预效果会更好更连贯，还没尝试高斯卷积的方法

def hooked_generate(prompt_batch, fwd_hooks=[], seed=None, **kwargs):
    if seed is not None:
        torch.manual_seed(seed)

    with model.hooks(fwd_hooks=fwd_hooks):
        tokenized = model.to_tokens(prompt_batch,prepend_bos=False)
        result = model.generate(
            stop_at_eos=True,  # avoids a bug on MPS
            input=tokenized,
            max_new_tokens=50,
            do_sample=True,
            **kwargs,
        )
    return result

def run_generate(example_prompt,sampling_kwargs,steer_on,alpha,steer_type="last",repeat_num=3,show_res=False):
    model.reset_hooks()
    if steer_on:
        steering_hook_fn=partial(steering_hook,steer_on=steer_on,alpha=alpha,steer_type=steer_type)
        editing_hooks = [(f"blocks.{args.layer}.hook_resid_post", steering_hook_fn)]
    else:
        editing_hooks=[]
    res = hooked_generate(
        [example_prompt] * repeat_num, editing_hooks, seed=None, **sampling_kwargs
    )

    # Print results, removing the ugly beginning of sequence token
    res_str = model.to_string(res[:, :])
    # print(("\n\n" + "-" * 80 + "\n\n").join(res_str))
    # generated_texts = res_str
    if show_res:
        for idx, text in enumerate(res_str):
            logging.info(f"Generated Text: {idx+1}:\n{text}")
        
    return res_str

# Define sampling parameters
sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)

# Example prompt from the selected set
example_prompt = " Aha, we have good result!"
# example_prompt=model.tokenizer
# print( all_dataset["val"]["prompt"][:3],all_dataset["val"]["A"][:3])
logging.info(f"Example prompt: {example_prompt}")

# Generate without steering
steer_on = False
alpha = 0
logging.info("Generating texts **without** steering... ")
generated_texts_no_steer = run_generate(example_prompt, sampling_kwargs,steer_on=steer_on,alpha=alpha,show_res=True)
logging.info("干预之后的结果")
# bef,aft=args.steer.split("-")
logging.info(f"干预方向{source}->{target},礼貌任务下，neg=impolite，情感任务下 pos=积极情感")
# Generate with steering
steer_on = True
alpha = 100
# alpha=args.aplha
logging.info("** Generating texts with steering... Target **")
generated_texts_with_steer = run_generate(
    example_prompt, 
    sampling_kwargs,
    steer_on=steer_on,
    alpha=alpha,
    steer_type="gaussian",
    show_res=True)

# Combine generated texts
# all_generated_texts = generated_texts_no_steer + generated_texts_with_steer



2025-01-14 15:54:00,315 [INFO] Example prompt:  Aha, we have good result!
2025-01-14 15:54:00,317 [INFO] Generating texts **without** steering... 


  0%|          | 0/50 [00:00<?, ?it/s]

2025-01-14 15:54:02,039 [INFO] Generated Text: 1:
 Aha, we have good result!

We are now in the final stages of our journey to get to the next level. We will be able to start building up our team and build a strong foundation for future success. We will also be able to take on more challenges and make
2025-01-14 15:54:02,040 [INFO] Generated Text: 2:
 Aha, we have good result!

I'm sure you've heard of the "Giant's Song" by The Beatles. It was a song about a giant who had to be saved from death by his own hand. It was also about the way that humans are created and
2025-01-14 15:54:02,040 [INFO] Generated Text: 3:
 Aha, we have good result!

"We are going to be the first team in the league to win a title. We will be the first team in Europe to win a European Cup. We will be able to play at home and play against other teams."<|endoftext|><|endoftext|><|endoftext|>
2025-01-14 15:54:02,040 [INFO] 干预之后的结果
2025-01-14 15:54:02,041 [INFO] 干预方向pos->neg,礼貌任务下，neg=impolite，情感任务下 pos=积极情感
2025-01-

  0%|          | 0/50 [00:00<?, ?it/s]

2025-01-14 15:54:02,075 [INFO] 干预类型：高斯
2025-01-14 15:54:02,076 [INFO] 干预矩阵: tensor([[[ -0.1662,   0.4870,  -0.9435,  ...,  -0.3283,   0.2026,  -0.2068],
         [ -0.5484,   1.6071,  -3.1135,  ...,  -1.0832,   0.6687,  -0.6825],
         [ -1.5059,   4.4132,  -8.5501,  ...,  -2.9746,   1.8363,  -1.8743],
         ...,
         [ -6.5456,  19.1824, -37.1638,  ..., -12.9295,   7.9816,  -8.1466],
         [-10.3602,  30.3616, -58.8221,  ..., -20.4646,  12.6331, -12.8944],
         [-13.6465,  39.9923, -77.4807,  ..., -26.9561,  16.6404, -16.9845]],

        [[ -0.1662,   0.4870,  -0.9435,  ...,  -0.3283,   0.2026,  -0.2068],
         [ -0.5484,   1.6071,  -3.1135,  ...,  -1.0832,   0.6687,  -0.6825],
         [ -1.5059,   4.4132,  -8.5501,  ...,  -2.9746,   1.8363,  -1.8743],
         ...,
         [ -6.5456,  19.1824, -37.1638,  ..., -12.9295,   7.9816,  -8.1466],
         [-10.3602,  30.3616, -58.8221,  ..., -20.4646,  12.6331, -12.8944],
         [-13.6465,  39.9923, -77.4807,  ..., -

tensor([0.0111, 0.0367, 0.1007, 0.2301, 0.4376, 0.6926, 0.9123, 1.0000, 0.9123,
        0.6926, 0.4376, 0.2301, 0.1007, 0.0367])


2025-01-14 15:54:03,807 [INFO] Generated Text: 1:
 Aha, we have good result!

But I'm not sure if it's because of the fact that I'm a fan of this series or because I've never read it. It's just that there are so many things in this book that make me want to read more.
2025-01-14 15:54:03,808 [INFO] Generated Text: 2:
 Aha, we have good result!

The story of the film is that of a young man who has been kidnapped by his father and taken to a mysterious place. The film begins with the boy being taken to an unknown location where he is forced to watch as his father's body
2025-01-14 15:54:03,808 [INFO] Generated Text: 3:
 Aha, we have good result!

The game is a simple puzzle with no branching. The player must solve the puzzles to get the most out of it. It's a fun game and I'm sure you'll enjoy it too!<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>


# 理论上来讲，
* 不礼貌的输出应该有很多疑问句？例如what？ ha？why？
* 而礼貌的输出应该有很多正常的词语
* 积极情感和不积极情感同理
* 从目前的实验来看，负向情感干预+礼貌情感干预表现比较好，可以拿这个做可解释性
* 频率很重要，我选取的latents选了前100频次的激活神经元
* 如果对[0:-1]的区间进行干预，效果异常优秀，生成比较连贯，但是如果对[-1:length]的区间进行干预，效果就很差，生成的词语很零碎

# 下面进行的是扭转实验，使用prompt对模型进行诱导，再进行转向

In [250]:
args.prompt_path

'/home/ckqsudo/code2024/0dataset/baseline-acl/data/prompts/sentiment_prompts-10k'

In [240]:
import os
def load_and_prepare_sentiment_prompts(prompt_path:str,seed:int,num_samples:int):
    logging.info(f"Loading prompt_path from {prompt_path}")
    data_files = {"neg": "negative_prompts.jsonl", "pos": "positive_prompts.jsonl","neu":"neutral_prompts.jsonl"}
    
    prompts= load_dataset("/home/ckqsudo/code2024/0refer_ACL/LM-Steer/data/data/prompts/sentiment_prompts-10k",data_files=data_files)
    print(prompts)
    return prompts
prompts=load_and_prepare_sentiment_prompts(args.prompt_path,args.seed,1000)

2025-01-14 15:08:28,571 [INFO] Loading prompt_path from /home/ckqsudo/code2024/0dataset/baseline-acl/data/prompts/sentiment_prompts-10k


DatasetDict({
    neg: Dataset({
        features: ['md5_hash', 'prompt', 'continuation', 'num_positive'],
        num_rows: 2500
    })
    pos: Dataset({
        features: ['md5_hash', 'prompt', 'continuation', 'num_positive'],
        num_rows: 2500
    })
    neu: Dataset({
        features: ['md5_hash', 'prompt', 'continuation', 'num_positive'],
        num_rows: 5000
    })
})


In [244]:
prompts["pos"]

Dataset({
    features: ['md5_hash', 'prompt', 'continuation', 'num_positive'],
    num_rows: 2500
})

In [None]:
prompts["pos"][0]

In [246]:
target

'neg'

In [245]:

import copy
# Example prompt from the selected set
res=[]
params={}
params["params"]=vars(args)
params["alpha"]=100
res.append(params)

no_steer_res=[]
steer_res=[]

prompt_type="pos"
assert prompt_type!=target,"prompt和转向的方向是一致的"

for idx,item in tqdm(enumerate(prompts[prompt_type][1000:2000])):
    prompt_=item["prompt"]["text"]
    label=prompts[prompt_type]['label'][idx]
    print(prompt_)
    # 没转向的结果
    steer_on = False
    alpha = 0
    # logging.info("Generating texts **without** steering... ")
    generated_texts_no_steer = run_generate(
        prompt_, 
        sampling_kwargs,
        steer_on=steer_on,
        alpha=alpha,
        repeat_num=2,
        steer_type="gaussian",
        show_res=True)
    no_steer_item=copy.deepcopy(prompt_)
    
    # 转向的结果
    steer_on = True
    alpha = 100
    
    generated_texts_with_steer = run_generate(prompt_, sampling_kwargs,steer_on=steer_on,
    alpha=alpha,
    steer_type="gaussian",
    repeat_num=3,
    show_res=True)
    item["no_steer"]=generated_texts_no_steer
    item["steer"]=generated_texts_with_steer
    res.append(copy.deepcopy(item))
    

0it [00:00, ?it/s]

Experience the lifelike thrill of virtual





TypeError: run_generate() missing 2 required positional arguments: 'steer_on' and 'alpha'

In [79]:
import json
with open("/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/src/evaluations/res.json",mode="w",encoding="utf-8") as res_f:
    res_f.write(json.dumps(res,ensure_ascii=False))
    