In [1]:
# Imports and Setup
import argparse
import os
import logging
from typing import Tuple
import json
from log import setup_logging
from tqdm import tqdm  # For progress bars


In [2]:

import torch
from torch import Tensor
from transformer_lens import HookedTransformer
from sae_lens import SAE
# from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list
# from sae_lens.analysis.feature_statistics import (
#     get_all_stats_dfs,
#     get_W_U_W_dec_stats_df,
# )
# from sae_lens.analysis.tsea import (
#     get_enrichment_df,
#     manhattan_plot_enrichment_scores,
#     plot_top_k_feature_projections_by_token_and_category,
#     get_baby_name_sets,
#     get_letter_gene_sets,
#     generate_pos_sets,
#     get_test_gene_sets,
#     get_gene_set_from_regex,
# )
from datasets import load_dataset
from dotenv import load_dotenv
import numpy as np
# import plotly_express as px

# 进行基础情感干预实验

In [9]:
# Define hyperparameters
task="sentiment"
if task=="sentiment":
    args_dict = {
        "layer": 6,  # Example layer number to analyze
        "LLM": "gpt2-small",
        "dataset_path": "/home/ckqsudo/code2024/0dataset/baseline-acl/data/sentiment/sst5",
        "prompt_path":"/home/ckqsudo/code2024/0dataset/baseline-acl/data/prompts/sentiment_prompts-10k",
        "output_dir": "./results",
        "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
        "task":"sentiment",# “sentiment”,"cot","polite"
        "seed": 42,
        "data_size": 1000,
        "device": "cpu",  # Options: "cpu", "cuda", "mps", "auto"
        "alpha": 100, # 这个alpha后面慢慢调节
        "steer": "neg-pos",  # Options: "pos", "neg", "neu","pos-neg","cot-direct"
        "method": "val_mul",  # Options: "mean", "val_mul" 用val_mul会比较好
        "topk_mean": 100, # 选取前topk 个均值激活，这个效果一般，会导致很多如what？why？这种被激活
        "topk_cnt": 100, # 选取前topk个频率激活，目前默认这个，效果很好
        "batch_size": 32 # 这个好像没用上
    }

# 进行COT相关实验

In [4]:
# Define hyperparameters
if task=="cot":
    args_dict = {
        "layer": 6,  # Example layer number to analyze
        "LLM": "gpt2-small",
        "dataset_path": "/home/ckqsudo/code2024/0dataset/ACL_useful_dataset/math/COT_GSM8k",
        "output_dir": "./results",
        "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
        "seed": 42,
        "data_size": 1000,
        "device": "cpu",  # Options: "cpu", "cuda", "mps", "auto"
        "alpha": 100,
        "steer": "cot-direct",  # Options: "pos", "neg", "neu","pos-neg","cot-direct"
        "method": "val_mul",  # Options: "mean", "val_mul"
        "topk_mean": 100,
        "topk_cnt": 100,
        "batch_size": 32
    }

# 进行礼貌实验
/home/ckqsudo/code2024/0dataset/ACL_useful_dataset/style_transfer/politeness-corpus

In [5]:
# Define hyperparameters
if task=="polite":
    args_dict = {
        "layer": 6,  # Example layer number to analyze
        "LLM": "gpt2-small",
        "dataset_path": "/home/ckqsudo/code2024/0dataset/ACL_useful_dataset/style_transfer/politeness-corpus",
        "output_dir": "./results",
        "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
        "seed": 42,
        "data_size": 1000,
        "device": "cpu",  # Options: "cpu", "cuda", "mps", "auto"
        "alpha": 100,
        "steer": "polite-impolite",  # Options: "pos", "neg", "neu","pos-neg","cot-direct"
        "method": "val_mul",  # Options: "mean", "val_mul"
        "topk_mean": 100,
        "topk_cnt": 100,
        "batch_size": 32
    }

In [6]:
# 进行礼貌实验

In [10]:
# Configuration and Hyperparameters
# 将字典转换为 argparse.Namespace 对象
args = argparse.Namespace(**args_dict)
# 测试访问属性
print(args.layer)  # 输出: 10
print(args.LLM)  # 输出: gpt2-small
print(args.output_dir)  # 输出: ./results
print(args.steer)

6
gpt2-small
./results
neg-pos


In [11]:
# Logging Setup
import os
from log import setup_logging
import logging
# Create output directory base path
output_dir_base = os.path.join(
    args.output_dir,
    f"LLM_{args.LLM}_layer_{args.layer}_steer_{args.steer}_alpha_{args.alpha}_{task}_cnt_{args.topk_cnt}_mean{args.topk_mean}"
)

# Setup logging
setup_logging(output_dir_base)

# Save hyperparameters
hyperparams = args_dict

# Log hyperparameters
logging.info("Hyperparameters:")
for key, value in hyperparams.items():
    logging.info(f"  {key}: {value}")


2025-01-14 22:41:36,927 [INFO] Logging initialized. Logs will be saved to ./results/LLM_gpt2-small_layer_6_steer_neg-pos_alpha_100_sentiment_cnt_100_mean100/execution.log
2025-01-14 22:41:36,929 [INFO] Hyperparameters:
2025-01-14 22:41:36,930 [INFO]   layer: 6
2025-01-14 22:41:36,932 [INFO]   LLM: gpt2-small
2025-01-14 22:41:36,934 [INFO]   dataset_path: /home/ckqsudo/code2024/0dataset/baseline-acl/data/sentiment/sst5
2025-01-14 22:41:36,935 [INFO]   prompt_path: /home/ckqsudo/code2024/0dataset/baseline-acl/data/prompts/sentiment_prompts-10k
2025-01-14 22:41:36,936 [INFO]   output_dir: ./results
2025-01-14 22:41:36,937 [INFO]   env_path: /home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env
2025-01-14 22:41:36,938 [INFO]   task: sentiment
2025-01-14 22:41:36,939 [INFO]   seed: 42
2025-01-14 22:41:36,940 [INFO]   data_size: 1000
2025-01-14 22:41:36,941 [INFO]   device: cpu
2025-01-14 22:41:36,942 [INFO]   alpha: 100
2025-01-14 22:41:36,944 [INFO]   steer: neg-pos
2025-01-14 2

In [12]:
# Load Environment Variables
def load_environment(env_path: str):
    load_dotenv(env_path)
    hf_endpoint = os.getenv('HF_ENDPOINT', 'https://hf-mirror.com')
    logging.info(f"HF_ENDPOINT: {hf_endpoint}")
load_environment(args.env_path)


2025-01-14 22:41:39,787 [INFO] HF_ENDPOINT: https://hf-mirror.com


In [13]:
import re

def load_and_prepare_triple_dataset(dataset_path: str,dataset_name:str, seed: int, num_samples: int):
    """
    支持positive\neutral\negative三元组数据类型，例如 sst5，polite数据集和multi-class数据集

    Args:
        dataset_path (str): _description_
        dataset_name : "sst5","multiclass","polite"
        seed (int): _description_
        num_samples (int): _description_

    Returns:
        _type_: _description_
    """
    assert dataset_name in ["sst5","multiclass","polite"]
    if dataset_name in ["sst5"]:
        neu_label=2 # 中性情感对应的label
        assert "sst5" in dataset_path
    elif  dataset_name in ["polite","multiclass"]:
        neu_label=1
    logging.info(f"Loading dataset from ****{dataset_path}***")
    dataset = load_dataset(dataset_path)
    dataset["train"] = dataset['train'].shuffle(seed=seed)

    logging.info("Filtering dataset for negative, positive, and neutral samples")
    neg_train_set = dataset['train'].filter(lambda example: example['label'] < neu_label).select(range(num_samples))
    pos_train_set = dataset['train'].filter(lambda example: example['label'] == neu_label).select(range(num_samples))
    neu_train_set = dataset['train'].filter(lambda example: example['label'] > neu_label ).select(range(num_samples))

    logging.info(f"Selected {len(neg_train_set)} negative, {len(pos_train_set)} positive, and {len(neu_train_set)} neutral samples")
    print(dataset)
    if dataset_name in ["sst5"]:
        val_set=dataset['validation']
    else:
        raise ValueError("没写呢")
    test_set=dataset["test"]
    return neg_train_set, pos_train_set, neu_train_set,val_set,test_set
def load_and_prepare_COT_dataset(dataset_path:str,seed:int,num_samples:int):
    logging.info(f"Loading dataset from {dataset_path}")
    dataset = load_dataset(dataset_path)
    dataset["train"] = dataset['train'].shuffle(seed=seed)
    logging.info("Filtering dataset for COT")
    # 定义一个函数来提取答案
    def extract_answer(text):
        # 使用正则表达式提取答案
        match = re.search(r'#### ([-+]?\d*\.?\d+/?\d*)', text)
        if match:
            label=match.group(1)
            return label
        else:
            raise ValueError("Modify your re expression")
    def concat_QA(example,col1,col2,tag):
        combined = f"{example[col1]}{tag}{example[col2]}"  # 用空格拼接
        return combined
    def replace_col(example,col1,target,pattern):
        return example[col1].replace(target,pattern)
    # 应用函数并创建新列
    dataset = dataset.map(lambda example: {'A': extract_answer(example['response'])})
    dataset = dataset.map(lambda example: {'Q+A': concat_QA(example,"prompt","A","")})
    dataset = dataset.map(lambda example: {'Q+COT_A': concat_QA(example,"prompt","response","")})
    dataset = dataset.map(lambda example: {'Q+COT_A': replace_col(example,"Q+COT_A","#### ","")})
    # 查看处理后的数据集
    print("Q+A\n",dataset['train'][103]['Q+A'])
    print("Q+COT_A\n",dataset['train'][103]['Q+COT_A'])
    return dataset
    

In [14]:
args.steer

'neg-pos'

In [15]:
# Load and Prepare Dataset

logging.info("dataset path "+args.dataset_path)
if "neg" in args.steer or "pos" in args.steer and args.steer=="sentiment":
    neg_train_set, pos_train_set, neu_train_set,val_set,test_set = load_and_prepare_triple_dataset(
        args.dataset_path, "sst5",args.seed, args.data_size
    )
elif "cot" in args.steer or "COT" in args.steer:
    logging.info("COT "*10)
    all_dataset=load_and_prepare_COT_dataset(
        args.dataset_path, args.seed, args.data_size
    )
elif "polite" in args.steer:
    logging.info("polite"*10)
    neg_train_set, pos_train_set, neu_train_set,val_set,test_set=load_and_prepare_triple_dataset(args.dataset_path,"polite", args.seed, args.data_size)
else:
    raise ValueError("No Supported")


2025-01-14 22:41:46,213 [INFO] dataset path /home/ckqsudo/code2024/0dataset/baseline-acl/data/sentiment/sst5
2025-01-14 22:41:46,216 [INFO] Loading dataset from ****/home/ckqsudo/code2024/0dataset/baseline-acl/data/sentiment/sst5***
Repo card metadata block was not found. Setting CardData to empty.
2025-01-14 22:41:46,413 [INFO] Filtering dataset for negative, positive, and neutral samples
2025-01-14 22:41:46,426 [INFO] Selected 1000 negative, 1000 positive, and 1000 neutral samples


DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 8544
    })
    validation: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 1101
    })
    test: Dataset({
        features: ['text', 'label', 'label_text'],
        num_rows: 2210
    })
})


In [16]:
assert neg_train_set[10]!=pos_train_set[10]

In [17]:
pos_train_set[10],neg_train_set[10]

({'text': 'some are fascinating and others are not , and in the end , it is almost a good movie .',
  'label': 2,
  'label_text': 'neutral'},
 {'text': 'the lousy lead performances ... keep the movie from ever reaching the comic heights it obviously desired .',
  'label': 1,
  'label_text': 'negative'})

In [18]:


def compute_latents(sae: SAE, model: HookedTransformer, texts: list, hook_point: str, device: str, batch_size: int) -> list:
    """
    计算 latents，支持批次处理。

    Args:
        sae (SAE): SAE 实例。
        model (HookedTransformer): Transformer 模型实例。
        texts (list): 文本列表。
        hook_point (str): 钩子点名称。
        device (str): 计算设备。
        batch_size (int): 每个批次的大小。

    Returns:
        list: 包含每个批次 latents 的张量列表。
    """
    logging.info("Running model with cache to obtain hidden states")
    batch_latents = []

    # 使用 tqdm 显示进度条
    for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
        batch_texts = texts[i:i + batch_size]
        sv_logits, cache = model.run_with_cache(batch_texts, prepend_bos=False, device=device)
        batch_hidden_states = cache[hook_point]
        logging.info(f"Batch {i // batch_size + 1}: Hidden states shape: {batch_hidden_states.shape}")

        logging.info(f"Encoding hidden states for batch {i // batch_size + 1}")
        # 假设 sae.encode 支持批量编码
        latents = sae.encode(batch_hidden_states)  # 形状: (batch_size, latent_dim)
        batch_latents.append(latents)
        

    logging.info(f"Total batches processed: {len(batch_latents)}")
    return batch_latents

In [20]:

# def save_results(output_dir: str, nz_mean: Tensor, act_cnt: Tensor, generated_texts: list, hyperparams: dict):
#     os.makedirs(output_dir, exist_ok=True)
#     # Save nz_mean and act_cnt
#     nz_stats_path = os.path.join(output_dir, 'nz_stats.pt')
#     logging.info(f"Saving nz_mean and act_cnt to {nz_stats_path}")
#     torch.save({
#         'nz_mean': nz_mean,
#         'act_cnt': act_cnt,
#     }, nz_stats_path)

#     # Save generated texts
#     generated_texts_path = os.path.join(output_dir, 'generated_texts.txt')
#     logging.info(f"Saving generated texts to {generated_texts_path}")
#     with open(generated_texts_path, 'w') as f:
#         for text in generated_texts:
#             f.write(text + "\n")

#     # Save hyperparameters
#     hyperparams_path = os.path.join(output_dir, 'hyperparameters.json')
#     logging.info(f"Saving hyperparameters to {hyperparams_path}")
#     with open(hyperparams_path, 'w') as f:
#         json.dump(hyperparams, f, indent=4)

#     logging.info("All results saved successfully.")

In [21]:
output_dir_base = os.path.join(
    args.output_dir,
    f"LLM_{args.LLM}_layer_{args.layer}_steer_{args.steer}_alpha_{args.alpha}_cnt_{args.topk_cnt}_mean{args.topk_mean}"
)
output_dir_base

'./results/LLM_gpt2-small_layer_6_steer_neg-pos_alpha_100_cnt_100_mean100'

In [22]:
# def load_from_cache():
#     cache_exists = False
#     cache_file = os.path.join(output_dir_base, 'hyperparameters.json')
#     if os.path.exists(cache_file):
#         with open(cache_file, 'r') as f:
#             cached_data = json.load(f)
#         cached_hash = cached_data.get('hyperparams_hash')

#     if cache_exists:
#         # Load nz_mean and act_cnt from cache
#         # nz_stats_path = os.path.join(output_dir_base, 'nz_stats.pt')
#         # nz_act = torch.load(nz_stats_path)
#         # nz_mean = nz_act['nz_mean']
#         # act_cnt = nz_act['act_cnt']
#         # overlap_indices = nz_act.get('overlap_indices', None)  # If overlap_indices was saved
#         logging.info("load from cache")
#     else:
#         # overlap_indices = None  # Will be computed later
#         logging.info("non cache: "+cache_file)
# load_from_cache()

2025-01-14 22:42:21,835 [INFO] non cache: ./results/LLM_gpt2-small_layer_6_steer_neg-pos_alpha_100_cnt_100_mean100/hyperparameters.json


In [20]:
# setup_logging(output_dir_base)

In [23]:
# Save hyperparameters
hyperparams = vars(args)

# Log hyperparameters
logging.info("Hyperparameters:")
for key, value in hyperparams.items():
    logging.info(f"  {key}: {value}")

# Load environment
load_environment(args.env_path)

# Load model and SAE
logging.info(f"Loading model: {args.LLM}")
model = HookedTransformer.from_pretrained(args.LLM, device=args.device)

logging.info(f"Loading SAE for layer {args.layer}")
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id=f"blocks.{args.layer}.hook_resid_pre",
    device=args.device
)

2025-01-14 22:42:37,666 [INFO] Hyperparameters:
2025-01-14 22:42:37,668 [INFO]   layer: 6
2025-01-14 22:42:37,671 [INFO]   LLM: gpt2-small
2025-01-14 22:42:37,673 [INFO]   dataset_path: /home/ckqsudo/code2024/0dataset/baseline-acl/data/sentiment/sst5
2025-01-14 22:42:37,674 [INFO]   prompt_path: /home/ckqsudo/code2024/0dataset/baseline-acl/data/prompts/sentiment_prompts-10k
2025-01-14 22:42:37,675 [INFO]   output_dir: ./results
2025-01-14 22:42:37,676 [INFO]   env_path: /home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env
2025-01-14 22:42:37,677 [INFO]   task: sentiment
2025-01-14 22:42:37,678 [INFO]   seed: 42
2025-01-14 22:42:37,679 [INFO]   data_size: 1000
2025-01-14 22:42:37,681 [INFO]   device: cpu
2025-01-14 22:42:37,682 [INFO]   alpha: 100
2025-01-14 22:42:37,682 [INFO]   steer: neg-pos
2025-01-14 22:42:37,684 [INFO]   method: val_mul
2025-01-14 22:42:37,685 [INFO]   topk_mean: 100
2025-01-14 22:42:37,686 [INFO]   topk_cnt: 100
2025-01-14 22:42:37,687 [INFO]   batch_

Loaded pretrained model gpt2-small into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [22]:
# # Load dataset
# all_dataset = load_and_prepare_triple_dataset(
#     args.dataset_path, args.seed, args.data_size
# )

In [24]:
args.steer

'neg-pos'

In [25]:

def analyze_latents(batch_latents: Tensor, top_k_mean: int = 100, top_k_cnt: int = 100) -> Tuple[Tensor, Tensor, Tensor]:
    logging.info("Computing non-zero element counts")
    act_cnt = (batch_latents != 0).sum(dim=(0, 1))

    logging.info("Computing sum of non-zero elements")
    nz_sum = torch.where(batch_latents != 0, batch_latents, torch.tensor(0.0, device=batch_latents.device)).sum(dim=(0, 1))

    logging.info("Computing mean of non-zero elements")
    nz_mean = torch.where(act_cnt != 0, nz_sum / act_cnt, torch.tensor(0.0, device=batch_latents.device))

    logging.info("Selecting top-k indices based on nz_mean")
    nz_act_val, nz_val_indices = torch.topk(nz_mean, top_k_mean)
    logging.info(f"Top {top_k_mean} nz_mean values selected.")

    logging.info("Selecting top-k indices based on act_cnt")
    nz_cnt, cnt_indices = torch.topk(act_cnt, top_k_cnt)
    logging.info(f"Top {top_k_cnt} act_cnt values selected.")

    # logging.info("Finding overlapping indices between nz_mean and act_cnt top-k")
    # overlap_mask = torch.isin(nz_val_indices, cnt_indices)
    # overlap_indices = nz_val_indices[overlap_mask]
    # logging.info(f"Number of overlapping indices: {len(overlap_indices)}")
    # overlap_indices=overlap_indices
    return nz_mean, act_cnt,None

In [26]:
# Select a dataset steer based on steering preference

def get_activation_by_steer(texts:list):

    hook_point = sae.cfg.hook_name

    # Compute latents with batch processing
    batch_latents = compute_latents(sae, model, texts, hook_point, args.device, args.batch_size)
    # 计算第二个维度的最大值
    max_dim1 = max(latent.shape[1] for latent in batch_latents)  # 第二个维度的最大值
    logging.info(f"最大长度:{max_dim1}")
    # 对每个 Tensor 进行填充（仅填充第二个维度）
    padded_latents_right = [
        torch.nn.functional.pad(latent, (0, 0, 0, max_dim1 - latent.size(1)), "constant", 0)
        for latent in batch_latents
    ]

    batch_latents_concatenated = torch.cat(padded_latents_right, dim=0)
    logging.info(f"Concatenated batch latents shape: {batch_latents_concatenated.shape}")

    # Analyze latents 
    nz_mean, act_cnt, _ = analyze_latents(batch_latents_concatenated, top_k_mean=args.topk_mean, top_k_cnt=args.topk_cnt)
    return {"nz_mean":nz_mean,"nz_cnt":act_cnt}

In [27]:
args.steer=args.steer.lower()

In [27]:
# all_dataset["train"]["Q+A"][:args.data_size][193]

## 26000(SAE稀疏神经元)对应的非零激活神经元激活统计信息，和激活值统计信息


In [28]:
steer_info={}

In [40]:
task

'sentiment'

In [41]:
steer_info={}
if args.steer=='polite-impolite' or task=="sentiment":
    logging.info(args.steer)
    text=pos_train_set["text"][:args.data_size]
    steer_info["pos"]=get_activation_by_steer(text)
    text=neg_train_set["text"][:args.data_size]
    steer_info["neg"]=get_activation_by_steer(text)
    text=neu_train_set["text"][:args.data_size]
    steer_info["neu"]=get_activation_by_steer(text)

2025-01-14 22:46:37,288 [INFO] neg-pos
2025-01-14 22:46:37,312 [INFO] Running model with cache to obtain hidden states
Processing batches:   0%|          | 0/32 [00:00<?, ?it/s]2025-01-14 22:46:37,751 [INFO] Batch 1: Hidden states shape: torch.Size([32, 47, 768])
2025-01-14 22:46:37,752 [INFO] Encoding hidden states for batch 1
Processing batches:   3%|▎         | 1/32 [00:00<00:14,  2.07it/s]2025-01-14 22:46:38,191 [INFO] Batch 2: Hidden states shape: torch.Size([32, 51, 768])
2025-01-14 22:46:38,193 [INFO] Encoding hidden states for batch 2
Processing batches:   6%|▋         | 2/32 [00:00<00:13,  2.26it/s]2025-01-14 22:46:38,511 [INFO] Batch 3: Hidden states shape: torch.Size([32, 50, 768])
2025-01-14 22:46:38,513 [INFO] Encoding hidden states for batch 3
Processing batches:   9%|▉         | 3/32 [00:01<00:11,  2.59it/s]2025-01-14 22:46:38,845 [INFO] Batch 4: Hidden states shape: torch.Size([32, 52, 768])
2025-01-14 22:46:38,846 [INFO] Encoding hidden states for batch 4
Processing ba

In [42]:

if "cot" in args.steer:
    if args.steer in "cot-direct":
        texts=all_dataset["train"]["Q+A"][:args.data_size]
        print(type(texts))
        steer_info["direct"]=get_activation_by_steer(texts)
        
        texts=all_dataset["train"]["Q+COT_A"][:args.data_size]
        print(type(texts))
        steer_info["cot"]=get_activation_by_steer(texts)
        # print(texts[123])
    else:
        raise ValueError("????")

In [43]:
args.steer

'neg-pos'

In [32]:
# pos_train_set["text"][:args.data_size]

In [44]:
len(pos_train_set["text"][:args.data_size])

1000

In [45]:
args.data_size

1000

In [35]:

# for steer_key in ["pos","neu","neg"]:
#     if steer_key == "pos":
#         selected_set = pos_train_set
#     elif steer_key == "neg":
#         selected_set = neg_train_set
#     elif steer_key=="neu":
#         selected_set = neu_train_set

#     texts = selected_set["text"][:args.data_size]
#     a=get_activation_by_steer(texts)
#     steer_info[steer_key]=a


In [38]:
# steer_info["dif_neg_pos"]={"steer":"dif_neg_pos","nz_cnt":steer_info["pos"]["nz_cnt"]-steer_info["neg"]["nz_cnt"],"nz_mean":steer_info["pos"]["nz_mean"]-steer_info["neg"]["nz_mean"]}

In [39]:
# steer_info["dif_neg_pos_relu"]={"nz_cnt":torch.relu(steer_info["pos"]["nz_cnt"]-steer_info["neg"]["nz_cnt"]),"nz_mean":torch.relu(steer_info["pos"]["nz_mean"]-steer_info["neg"]["nz_mean"])}

In [40]:
# steer_info["dif_neg_pos_relu"],steer_info["dif_neg_pos"]

In [41]:
sourceource="cot"
target="direct"

In [46]:
source="pos"
target="neg"
# 调整样本正负性在这里调整 从负样本到正样本还是从正样本()到负样本
# pos 代表积极情绪
# neg 代表消极情绪


In [47]:
assert bool(torch.all((steer_info["pos"]["nz_mean"]-steer_info["neg"]["nz_mean"])==0))==False,"数据库读取有问题"

In [48]:
assert torch.all(steer_info[target]["nz_mean"]>=0),"所有SAE的激活需要大于d等于0（maybe）"

# 泽凯在这里做mask

In [50]:
# 
steer_info[f"dif_{target}-{source}_relu"]={"nz_cnt":torch.relu(steer_info[target]["nz_cnt"]-steer_info[source]["nz_cnt"]),"nz_mean":torch.relu(steer_info[target]["nz_mean"]-steer_info[source]["nz_mean"]),"target_nz_mean":torch.relu(steer_info[target]["nz_mean"])}
"""
nz_cnt: 神经元被激活的次数
nz_mean: 神经元被激活后的平均值
nz_mean_pos: 正样本神经元被激活后的平均值
"""
top_k=100
_,steer_indices=torch.topk(steer_info[f"dif_{target}-{source}_relu"]["nz_cnt"],top_k)

In [46]:

# steer_info[f"dif_{a}-{b}_relu"]={"nz_cnt":torch.relu(steer_info[a]["nz_cnt"]-steer_info[b]["nz_cnt"]),"nz_mean":torch.relu(steer_info[a]["nz_mean"]-steer_info[b]["nz_mean"])}
# top_k=100
# steering_vectors,steer_indices=torch.topk(steer_info[f"dif_{a}-{b}_relu"]["nz_cnt"],top_k)

In [51]:
# 假设 steer_info[f"dif_{b}-{a}_relu"]["nz_cnt"] 是一个 NumPy 数组
nz_cnt = steer_info[f"dif_{target}-{source}_relu"]["nz_cnt"]

# 先获取非零元素的索引
nz_indices = np.nonzero(nz_cnt)
torch.all(nz_cnt == 0)

tensor(False)

In [52]:
steer_info[f"dif_{target}-{source}_relu"]["nz_cnt"].shape

torch.Size([24576])

In [53]:
_,steer_indices,

(tensor([976, 976, 950, 866, 803, 697, 670, 634, 618, 582, 576, 570, 556, 549,
         534, 533, 526, 523, 506, 471, 470, 462, 456, 456, 445, 416, 415, 390,
         377, 377, 370, 367, 366, 360, 348, 347, 344, 320, 312, 307, 303, 302,
         291, 289, 288, 287, 273, 268, 263, 263, 260, 257, 256, 255, 249, 241,
         237, 235, 233, 233, 232, 231, 230, 228, 224, 219, 216, 216, 216, 215,
         212, 210, 205, 204, 201, 201, 200, 195, 195, 194, 194, 189, 187, 186,
         183, 182, 179, 177, 176, 175, 174, 173, 173, 172, 169, 167, 166, 165,
         163, 163]),
 tensor([14899, 16053, 14950, 18318, 21264,  3857, 17462,  5318, 17539, 11413,
         17227,  3290, 23494,   970,  6103, 11779, 21409,  4075,  4283, 18398,
         18739,  8847,  1188, 20725, 16383, 21950,  4989, 18604, 22703, 14404,
         10516, 16101, 21762,  4978, 15489,  3639,  8272, 23000, 17560, 13611,
         10045, 10512,  4886,  8064,  8989,  5842,  5455,  8212, 14168, 10188,
         11185,  1167, 18722,  

In [54]:
steer_indices

tensor([14899, 16053, 14950, 18318, 21264,  3857, 17462,  5318, 17539, 11413,
        17227,  3290, 23494,   970,  6103, 11779, 21409,  4075,  4283, 18398,
        18739,  8847,  1188, 20725, 16383, 21950,  4989, 18604, 22703, 14404,
        10516, 16101, 21762,  4978, 15489,  3639,  8272, 23000, 17560, 13611,
        10045, 10512,  4886,  8064,  8989,  5842,  5455,  8212, 14168, 10188,
        11185,  1167, 18722,  9082, 12797,   222, 18146,  4163,  7417,  8896,
         5203, 13252, 21943, 11859,  2282, 13817, 12488, 23355, 20830,  3219,
         6861, 24114, 16460,  5303,  9148, 12460,  1374, 10526,  5741,  7467,
         6545, 22906, 16318, 17732,  4173,   502, 22288, 23847, 17504,  2211,
        15256,  4063,  4994, 15262,  5330, 10282, 14954, 14440,  7383,    98])

In [55]:
steer_info[f"dif_{target}-{source}_relu"]["nz_mean"][steer_indices]# 这里有0,没有负数比较正常


tensor([7.9328e-02, 3.8317e-02, 2.0146e-02, 2.3163e-02, 2.8212e-02, 0.0000e+00,
        5.1741e-02, 4.0420e-02, 0.0000e+00, 2.9353e-02, 2.1048e-02, 0.0000e+00,
        0.0000e+00, 1.4081e-02, 7.1191e-02, 4.8021e-02, 1.1438e-04, 0.0000e+00,
        4.0448e-02, 4.5769e-02, 7.2892e-02, 4.8401e-03, 0.0000e+00, 9.2183e-03,
        0.0000e+00, 3.1627e-02, 1.8980e-02, 2.9644e-02, 1.5619e-02, 1.5609e-02,
        0.0000e+00, 2.3832e-02, 2.7588e-02, 0.0000e+00, 4.8306e-02, 1.6971e-02,
        4.3642e-02, 0.0000e+00, 2.7315e-02, 2.3624e-02, 0.0000e+00, 0.0000e+00,
        2.6217e-02, 6.4675e-02, 6.5274e-03, 0.0000e+00, 6.5705e-02, 3.3928e-03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 7.1977e-03, 6.2665e-03,
        2.9225e-03, 7.8505e-02, 3.3161e-02, 0.0000e+00, 1.0959e-02, 6.2663e-02,
        3.1515e-02, 1.6875e-02, 3.7916e-02, 4.5186e-02, 2.7820e-02, 0.0000e+00,
        1.0531e-02, 1.1299e-01, 4.5587e-02, 0.0000e+00, 0.0000e+00, 3.3021e-02,
        8.4106e-02, 5.1196e-02, 0.0000e+

In [52]:
# steer_info["dif_-pos"]["nz_mean"][steer_indices]

In [56]:
steer_indices

tensor([14899, 16053, 14950, 18318, 21264,  3857, 17462,  5318, 17539, 11413,
        17227,  3290, 23494,   970,  6103, 11779, 21409,  4075,  4283, 18398,
        18739,  8847,  1188, 20725, 16383, 21950,  4989, 18604, 22703, 14404,
        10516, 16101, 21762,  4978, 15489,  3639,  8272, 23000, 17560, 13611,
        10045, 10512,  4886,  8064,  8989,  5842,  5455,  8212, 14168, 10188,
        11185,  1167, 18722,  9082, 12797,   222, 18146,  4163,  7417,  8896,
         5203, 13252, 21943, 11859,  2282, 13817, 12488, 23355, 20830,  3219,
         6861, 24114, 16460,  5303,  9148, 12460,  1374, 10526,  5741,  7467,
         6545, 22906, 16318, 17732,  4173,   502, 22288, 23847, 17504,  2211,
        15256,  4063,  4994, 15262,  5330, 10282, 14954, 14440,  7383,    98])

In [57]:
# mean_type="dif_mean"
mean_type="dif_mean"
def compute_steering_vectors(sae: SAE, indices: Tensor, nz_mean_val: Tensor, method: str = "val_mul") -> Tensor:
    logging.info(f"Computing steering vectors using method: {method}")
    if method == "mean":
        steering_vectors = torch.mean(sae.W_dec[indices], dim=0)
    elif method == "val_mul":
        steering_vectors = torch.zeros(sae.W_dec.shape[1], device=sae.W_dec.device)
        for idx in indices:
            steering_vectors += nz_mean_val[idx].item() * sae.W_dec[idx]
    else:
        raise ValueError(f"Unknown method: {method}")
    logging.info(f"Steering vectors computed with shape: {steering_vectors.shape}")
    return steering_vectors
if mean_type=="dif_mean":
    delta_matrix=compute_steering_vectors(sae,indices=steer_indices,nz_mean_val=steer_info[f"dif_{target}-{source}_relu"]["nz_mean"],method="val_mul")
elif mean_type=="tar_mean":
    delta_matrix=compute_steering_vectors(sae,indices=steer_indices,nz_mean_val=steer_info[f"dif_{target}-{source}_relu"]["target_nz_mean"],method="val_mul")
else:
    raise ValueError("Unsupported")

2025-01-14 22:48:43,574 [INFO] Computing steering vectors using method: val_mul
2025-01-14 22:48:43,589 [INFO] Steering vectors computed with shape: torch.Size([768])


In [55]:
# model.to_tokens("<|endoftext|>")

In [58]:
model.tokenizer.eos_token

'<|endoftext|>'

In [57]:
model.tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}
)

In [59]:
delta_matrix #理论上这里有正有负比较正常

tensor([-2.3879e-02,  5.6569e-02, -9.4197e-02,  2.4240e-02, -2.4863e-02,
        -5.9525e-02,  5.8307e-03,  7.2583e-02, -5.9092e-02, -2.9759e-02,
        -3.3262e-02,  1.4656e-02, -9.0590e-03, -5.5997e-03,  2.3342e-03,
         8.0802e-03,  3.8378e-02, -2.0434e-02,  1.2447e-02,  1.0593e-01,
         1.8773e-02, -2.9495e-02,  2.5378e-02,  5.7542e-02,  2.9447e-02,
         5.8997e-03,  1.0855e-02, -4.7216e-02, -1.0463e-02, -3.8122e-02,
        -6.6728e-02, -1.1337e-02,  2.9258e-02, -4.6419e-02,  8.2180e-02,
         1.1130e-01,  2.3719e-02,  2.0910e-02, -1.2308e-01,  4.0414e-02,
         3.5823e-02, -6.0831e-03, -1.6691e-02,  4.7180e-02, -3.4195e-02,
        -9.1669e-03,  6.0183e-02, -7.7170e-03,  2.7681e-02, -9.1885e-03,
        -6.7075e-02, -2.3160e-02, -4.8160e-02,  3.6015e-02, -2.1627e-02,
        -1.4603e-04,  1.4520e-02, -1.3064e-02,  2.4791e-02,  8.9513e-03,
        -1.9214e-02,  5.8242e-02, -1.5163e-02,  1.3979e-02, -4.5548e-02,
         4.8255e-02, -2.1512e-02,  1.3172e-02,  5.4

# 这里得到的就是delta_matricx

In [60]:
sae.cfg.hook_name

'blocks.6.hook_resid_pre'

In [61]:
f"blocks.{args.layer}.hook_resid_post"

'blocks.6.hook_resid_post'

In [62]:
import torch.nn.functional as F
import torch

def half_gaussian_kernel(half_len):
    # 设置均值和标准差
    cov_len=2*half_len
    mean = cov_len // 2  # 正态分布的均值
    std = cov_len / 6    # 设置标准差，可以根据需要调整

    # 创建正态分布
    x = torch.arange(cov_len, dtype=torch.float32)
    kernel = torch.exp(-0.5 * ((x - mean) / std) ** 2)
    # print(kernel)

    # 仅保留正态分布的前半部分（右侧值设置为0）
    kernel[int(cov_len // 2):] = 0  # 保留前半部分，右半部分置为零

    # 归一化，确保总和为 1
    kernel = kernel / kernel.sum()
    return kernel[:half_len]

gauss=half_gaussian_kernel(4)
gauss
# k_gau=torch.cat([gauss, torch.tensor([0])])

tensor([0.0095, 0.0680, 0.2774, 0.6451])

In [63]:
gauss

tensor([0.0095, 0.0680, 0.2774, 0.6451])

In [65]:
# import einops
# test=torch.ones(2, 6,3)
# # test=einops.rearrange(test,"b s d->b d s")
# test.shape,k_gau.shape

In [216]:
# re_gau=einops.repeat(k_gau,"s -> b s h",b=2,h=3)

In [None]:
# test*re_gau,test

(tensor([[[0.0070, 0.0070, 0.0070],
          [0.0354, 0.0354, 0.0354],
          [0.1247, 0.1247, 0.1247],
          [0.3067, 0.3067, 0.3067],
          [0.5263, 0.5263, 0.5263],
          [0.0000, 0.0000, 0.0000]],
 
         [[0.0070, 0.0070, 0.0070],
          [0.0354, 0.0354, 0.0354],
          [0.1247, 0.1247, 0.1247],
          [0.3067, 0.3067, 0.3067],
          [0.5263, 0.5263, 0.5263],
          [0.0000, 0.0000, 0.0000]]]),
 tensor([[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],
 
         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]))

In [73]:
# # Define steering hook
# steering_on = True  # This will be toggled in run_generate
# alpha = args.alpha
# method = args.method  # Store method for clarity
from functools import partial
import einops
steer_cnt=0


def steering_hook(resid_pre, hook, steer_on, alpha, steer_type="last"):
    # 如果 seq_len 只有 1，则直接返回，不进行操作

    if resid_pre.shape[1] == 1:
        return
    # 判断是否进行干预
    if steer_on:
        if steer_type == "last":
            # 对最后一个token前的部分应用干预，使用给定的 delta_matrix
            
            resid_pre[:, :-1, :] += alpha * delta_matrix
            
            # logging.info(f"干预类型：last")
            # d_m_repeat=einops.repeat(d_m,"h -> b s h",b=b,s=s)
            # logging.info(f"干预矩阵: {alpha * d_m_repeat}")
        elif steer_type == "gaussian":
            # 使用高斯卷积对输入进行干预
            # s_idx=-1
            d_m=torch.clone(delta_matrix)
            s = resid_pre[:, :-1, :].shape[1]
            b=resid_pre[:, :-1, :].shape[0]
            h=resid_pre[:, :-1, :].shape[2]
            h_gauss = half_gaussian_kernel(s)  # 获取高斯卷积核
            # k_gauss=torch.cat([h_gauss, torch.tensor([0])])
            k_gauss=h_gauss
            k_gau_repeat=einops.repeat(k_gauss,"s -> b s h",b=b,h=h)
            d_m_repeat=einops.repeat(d_m,"h -> b s h",b=b,s=s)
            # 根据卷积结果更新 resid_pre（注意：保留其他维度不变）,逐一元素相乘
            resid_pre[:, :-1, :] += alpha * d_m_repeat* k_gau_repeat
            # logging.info(f"干预类型：高斯")
            # logging.info(f"干预矩阵: {alpha * d_m_repeat* k_gau_repeat}")
        else:
            raise ValueError("Unknown steering type")


        # elif steer_type=="last2":
        #     resid_pre[:, :-2, :] += args.alpha * steering_vectors
        # elif steer_type=="gaussian":
        #     # 高斯卷积的方式放缩干预矩阵，
        #     # 这里需要一个高斯核，然后对steering_vectors进行卷积
        #     gaussian_kernel = torch.tensor([1,2,1])
        #     steering_vectors = torch.conv1d(steering_vectors, gaussian_kernel, padding=1)
        #     resid_pre[:, :-1, :] += args.alpha * steering_vectors
        # else:
        #     raise ValueError("Unknown steering type")
        # 修改这里的干预方式，增加干预的选择，例如从倒数第一个token开始干预，或者从倒数第二个token开始干预，或者使用高斯卷积的方式放缩干预矩阵，这里是干预调整的关键，很有意思的是，如果提前干预效果会更好更连贯，还没尝试高斯卷积的方法

def hooked_generate(prompt_batch, fwd_hooks=[], seed=None, **kwargs):
    if seed is not None:
        torch.manual_seed(seed)

    with model.hooks(fwd_hooks=fwd_hooks):
        tokenized = model.to_tokens(prompt_batch,prepend_bos=False)
        result = model.generate(
            stop_at_eos=True,  # avoids a bug on MPS
            input=tokenized,
            max_new_tokens=50,
            do_sample=True,
            **kwargs,
        )
    return result

def run_generate(example_prompt,sampling_kwargs,steer_on,alpha,steer_type="last",repeat_num=3,show_res=False):
    model.reset_hooks()
    if steer_on:
        steering_hook_fn=partial(steering_hook,steer_on=steer_on,alpha=alpha,steer_type=steer_type)
        editing_hooks = [(f"blocks.{args.layer}.hook_resid_post", steering_hook_fn)]
    else:
        editing_hooks=[]
    res = hooked_generate(
        [example_prompt] * repeat_num, editing_hooks, seed=None, **sampling_kwargs
    )

    # Print results, removing the ugly beginning of sequence token
    res_str = model.to_string(res[:, :])
    # print(("\n\n" + "-" * 80 + "\n\n").join(res_str))
    # generated_texts = res_str
    if show_res:
        for idx, text in enumerate(res_str):
            logging.info(f"Generated Text: {idx+1}:\n{text}")
        
    return res_str

# Define sampling parameters
sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)

# Example prompt from the selected set
example_prompt = "What really matters is that they know"
# example_prompt=model.tokenizer
# print( all_dataset["val"]["prompt"][:3],all_dataset["val"]["A"][:3])
logging.info(f"Example prompt: {example_prompt}")

# Generate without steering
steer_on = False
alpha = 0
logging.info("Generating texts **without** steering... ")
generated_texts_no_steer = run_generate(example_prompt, sampling_kwargs,steer_on=steer_on,alpha=alpha,show_res=True)
logging.info("干预之后的结果")
# bef,aft=args.steer.split("-")
logging.info(f"干预方向{source}->{target},礼貌任务下，neg=impolite，情感任务下 pos=积极情感")
# Generate with steering
steer_on = True
alpha = 300
# alpha=args.aplha
logging.info("** Generating texts with steering... Target **")
generated_texts_with_steer = run_generate(
    example_prompt, 
    sampling_kwargs,
    steer_on=steer_on,
    alpha=alpha,
    steer_type="last",
    show_res=True)

# Combine generated texts
# all_generated_texts = generated_texts_no_steer + generated_texts_with_steer



2025-01-14 22:54:24,075 [INFO] Example prompt: What really matters is that they know
2025-01-14 22:54:24,077 [INFO] Generating texts **without** steering... 


  0%|          | 0/50 [00:00<?, ?it/s]

2025-01-14 22:54:25,962 [INFO] Generated Text: 1:
What really matters is that they know what's going on. They know what's going on in the world, and they're willing to do whatever it takes to get there."

The two men are not alone in their desire for a change. The latest example comes from the U
2025-01-14 22:54:25,963 [INFO] Generated Text: 2:
What really matters is that they know what's going on. They know what's going on in the world, and they're willing to do whatever it takes to get there."

The two sides have been working together for months, but the relationship has been strained by the fact that
2025-01-14 22:54:25,963 [INFO] Generated Text: 3:
What really matters is that they know what's going on. They know what's going on in the world, and they're willing to do whatever it takes to get there."

"I think we've got a lot of work to do," he added. "We've got
2025-01-14 22:54:25,964 [INFO] 干预之后的结果
2025-01-14 22:54:25,965 [INFO] 干预方向pos->neg,礼貌任务下，neg=impolite，情感任务下 pos=积极情感
2025-

  0%|          | 0/50 [00:00<?, ?it/s]

2025-01-14 22:54:27,747 [INFO] Generated Text: 1:
What really matters is that they know what's going on. They're not going to let the worst of it happen to them."

The report also said that while the "unfortunate" deaths of two young women in a hotel room in September were "not shocking," they
2025-01-14 22:54:27,747 [INFO] Generated Text: 2:
What really matters is that they know what's going on. They're not going to be able to make a good case for the worst of it."

The former head coach, who was fired by the Bears in January after a three-year stint as offensive coordinator, said he
2025-01-14 22:54:27,748 [INFO] Generated Text: 3:
What really matters is that they know what's going on. They're not trying to make a movie about a bad person, or an asshole who's just killed their family. They're trying to make something that feels like it's real and real and real, and it doesn't feel


# 理论上来讲，
* 不礼貌的输出应该有很多疑问句？例如what？ ha？why？
* 而礼貌的输出应该有很多正常的词语
* 积极情感和不积极情感同理
* 从目前的实验来看，负向情感干预+礼貌情感干预表现比较好，可以拿这个做可解释性
* 频率很重要，我选取的latents选了前100频次的激活神经元
* 如果对[0:-1]的区间进行干预，效果异常优秀，生成比较连贯，但是如果对[-1:length]的区间进行干预，效果就很差，生成的词语很零碎

# 下面进行的是扭转实验，使用prompt对模型进行诱导，再进行转向

In [74]:
args.prompt_path

'/home/ckqsudo/code2024/0dataset/baseline-acl/data/prompts/sentiment_prompts-10k'

In [75]:
import os
def load_and_prepare_sentiment_prompts(prompt_path:str,seed:int,num_samples:int):
    logging.info(f"Loading prompt_path from {prompt_path}")
    data_files = {"neg": "negative_prompts.jsonl", "pos": "positive_prompts.jsonl","neu":"neutral_prompts.jsonl"}
    
    prompts= load_dataset("/home/ckqsudo/code2024/0refer_ACL/LM-Steer/data/data/prompts/sentiment_prompts-10k",data_files=data_files)
    print(prompts)
    return prompts
prompts=load_and_prepare_sentiment_prompts(args.prompt_path,args.seed,1000)

2025-01-14 22:54:40,500 [INFO] Loading prompt_path from /home/ckqsudo/code2024/0dataset/baseline-acl/data/prompts/sentiment_prompts-10k


DatasetDict({
    neg: Dataset({
        features: ['md5_hash', 'prompt', 'continuation', 'num_positive'],
        num_rows: 2500
    })
    pos: Dataset({
        features: ['md5_hash', 'prompt', 'continuation', 'num_positive'],
        num_rows: 2500
    })
    neu: Dataset({
        features: ['md5_hash', 'prompt', 'continuation', 'num_positive'],
        num_rows: 5000
    })
})


In [69]:
# prompts["pos"]

In [71]:
# prompts["pos"][0]

{'md5_hash': 'c407590bd7adee91c20e54cb3e8d9407',
 'prompt': {'text': '"It sticks true to'},
 'continuation': {'text': 'end of this series.'},
 'num_positive': 25}

In [72]:
target

'neg'

In [89]:
sampling_kwargs['verbose']=False
sampling_kwargs

{'temperature': 1.0, 'top_p': 0.1, 'freq_penalty': 1.0, 'verbose': False}

In [68]:
# # 遍历数据集
# for idx,example in enumerate(list(prompts["pos"])[:10]):
#     print(idx,example)  # example 是一个字典

In [87]:

#     # 转换为 Pandas DataFrame
# df = prompts["pos"].to_pandas()

# # 遍历 DataFrame
# for row in df.iterrows():
#     # print(row.to_dict())  # 将每一行转换为字典

IndentationError: expected an indented block after 'for' statement on line 5 (4088799639.py, line 6)

In [90]:
prompts["pos"][1:4]


{'md5_hash': ['071313d997ab550da23b46fa9b498f33',
  'a8f9b1cca05dcd8036f6fb24a1f9f838',
  '62462719fff1c24d6e798403222fd3ca'],
 'prompt': [{'text': 'Loyal fans never miss'},
  {'text': 'Confucianism is deeply rooted in'},
  {'text': 'A President unafraid to stand'}],
 'continuation': [{'text': 'an important match.'},
  {'text': 'societies of east asia.'},
  {'text': 'for veterans and God!'}],
 'num_positive': [25, 25, 25]}

In [91]:
args.data_size

1000

In [92]:

alpha=400

In [97]:
# steer_type="last"
steer_type="gaussian"
alpha=1000

In [96]:


import copy
# Example prompt from the selected set
import jsonlines
res=[]
params={}
params["params"]={**vars(args),**sampling_kwargs,"max_new_tokens":50,"steer":f"from {source} to {target}"}
params["alpha"]=alpha
logging.info(f"Running with alpha: {alpha}")
logging.info(f"Running with prompt_type: "+str(params["params"]["steer"]))
# res.append(params)

no_steer_res=[]
steer_res=[]

pos_or_neg="pos"
assert pos_or_neg!=target,"prompt和转向的方向是一致的"
# 打开文件（模式为追加模式 'a'）
senti_gen_dir="/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/src/evaluations/gen_files"
senti_gen_dir=os.path.join(senti_gen_dir,f"alpha_{alpha}_senti_{pos_or_neg}_datasize_{args.data_size}_layer_{args.layer}_mean_{mean_type}_steertype_{steer_type}")
os.makedirs(senti_gen_dir,exist_ok=True)
with jsonlines.open(os.path.join(senti_gen_dir,"params.jsonl"), mode='w') as writer:
    writer.write(params)  # 逐行写入

show_compared=False
with jsonlines.open(os.path.join(senti_gen_dir,"no_steer_gen_res.jsonl"), mode='a') as nt_file:
    with jsonlines.open(os.path.join(senti_gen_dir,"steer_gen_res.jsonl"), mode='a') as t_file: 
        for idx,item in tqdm(enumerate(list(prompts[pos_or_neg])[:1000])):
            prompt=item["prompt"]["text"]
            item["label"]=pos_or_neg

            # 没转向的结果
            if show_compared:
                no_steer_gen_texts = run_generate(
                    prompt, 
                    sampling_kwargs,
                    steer_on=False,
                    alpha=None,
                    repeat_num=2,
                    steer_type=None,
                    show_res=False)
                no_steer_item=copy.deepcopy(item)
                no_steer_item["generations"]=[]
                for gen_text in no_steer_gen_texts:
                    no_steer_item["generations"].append({"text":gen_text})
                # no_steer_res.append(no_steer_item)
                nt_file.write(no_steer_item)
                # 转向的结果
            
            # steer_on = True
            steered_texts=run_generate(prompt, 
                                       sampling_kwargs,
                                       steer_on=True,
                                       alpha=alpha,
                                       steer_type=steer_type,
                                       repeat_num=2,
                                       show_res=False
                                       )
            
            steer_item=copy.deepcopy(item)
            steer_item["generations"]=[]
            for steer_gen in steered_texts:
                steer_item["generations"].append({"text":steer_gen})
            logging.info(steer_item["generations"][0]["text"])
            t_file.write(steer_item)
            
    

2025-01-15 00:37:45,238 [INFO] Running with alpha: 1000
2025-01-15 00:37:45,240 [INFO] Running with prompt_type: from pos to neg
0it [00:00, ?it/s]2025-01-15 00:37:46,980 [INFO] "It sticks true to the spirit of the book, and it's a great read. I'm not sure if it's going to be as good as what we've seen in 'The Dark Knight,' but I think that's what we're going for."

In
1it [00:01,  1.58s/it]2025-01-15 00:37:48,485 [INFO] Loyal fans never miss a chance to see the team play.

The team's first game of the season will be on Saturday, Sept. 7 at 8 p.m., when they host the San Diego Chargers in their home opener at Qualcomm Stadium.<|endoftext|>
2it [00:03,  1.53s/it]2025-01-15 00:37:50,065 [INFO] Confucianism is deeply rooted in the idea that the world is a place of great suffering and suffering, where people are made to feel like they're part of something bigger than themselves.

The story of how this happened has been told for centuries. It's also been told by
3it [00:04,  1.56s/it]2025-0

In [86]:
steer_res

[{'md5_hash': 'c407590bd7adee91c20e54cb3e8d9407',
  'prompt': {'text': '"It sticks true to'},
  'continuation': {'text': 'end of this series.'},
  'num_positive': 25,
  'label': 'pos',
  'generations': [{'text': '"It sticks true to the story of a young man who is kidnapped and taken to a remote village where he is forced to live in fear of his family. The film\'s premise is that the boy, named "Teddy," has been abducted by an evil cult and forced'},
   {'text': '"It sticks true to the spirit of the book, and it\'s a great read. I\'m very happy with it."\n\nThe book is available for pre-order now on Amazon.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>'}]},
 {'md5_hash': '071313d997ab550da23b46fa9b498f33',
  'prompt': {'text': 'Loyal fans never miss'},
  'continuation': {'text': 'an important match.'},
  'num_positive': 25,
  '

In [None]:
 import jsonlines

# 要添加的数据（字典列表）
data = [
    {"name": "Alice", "age": 25},
    {"name": "Bob", "age": 30},
    {"name": "Charlie", "age": 35}
]

# 打开文件（模式为追加模式 'a'）
with jsonlines.open('data.jsonl', mode='a') as writer:
    for item in data:
        writer.write(item)  # 逐行写入

In [79]:
import json
with open("/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/src/evaluations/res.json",mode="w",encoding="utf-8") as res_f:
    res_f.write(json.dumps(res,ensure_ascii=False))
    