In [16]:
import argparse
import os
import re
import logging
from typing import Tuple
from tqdm import tqdm  # For progress bars
from log import setup_logging
import logging
import torch
from torch import Tensor
from transformer_lens import HookedTransformer
from sae_lens import SAE
from datasets import load_dataset
import numpy as np
import argparse
import logging

from data_preprocess import load_and_prepare_triple_dataset,load_and_prepare_COT_dataset
from utils import load_environment
from datasets import load_dataset
from dotenv import load_dotenv
import numpy as np

In [56]:
# Run configuration. The dict is turned into an argparse.Namespace so the rest
# of the notebook can use attribute access (args.layer, args.device, ...).
args_dict = {
    "layer": 6,  # Example layer number to analyze
    "LLM": "gpt2-small",
    "dataset_path": "/home/ckqsudo/code2024/0dataset/baseline-acl/data/debate/StanceSentences",
    "prompt_path":"/home/ckqsudo/code2024/0dataset/baseline-acl/data/debate/ibm_debate",
    "output_dir": "./results",
    "env_path": "/home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env",
    "task":"debate",
    "seed": 42,
    "data_size": "ALL",  # "ALL" or an int cap per label (see the dataset loader)
    "device": "cuda",  # Options: "cpu", "cuda", "mps", "auto"
    "alpha": 50, # steering strength; to be tuned gradually
    "steer": "sup-opp",  
    "source": "sup",
    "target": "opp",
    "method": "val_mul",  # Options: "mean", "val_mul"; val_mul tends to work better
    "topk_mean": 100, # keep the top-k latents by mean activation; works only so-so, tends to pick latents firing on generic tokens such as "what?" / "why?"
    "topk_cnt": 100, # keep the top-k latents by activation frequency; current default, works well
    "batch_size": 32, # batch size of the latent-statistics pass (read as args.batch_size in get_activation_by_steer)
    "mean_type": "dif_mean",
    "steer_type": "last", # copied into STEER_TYPE below and passed to run_generate as steer_type
}
args = argparse.Namespace(**args_dict)

# Sampling options forwarded as keyword arguments to model.generate via hooked_generate.
sampling_kwargs = dict(temperature=1.0, top_p=0.1, freq_penalty=1.0)
sampling_kwargs['verbose']=False
# Module-level shortcuts read by later cells.
TASK =args.task
STEER_TYPE=args.steer_type
ALPHA=args.alpha
MAX_NEW_TOKENS=100



# Logging Setup
import os
from log import setup_logging
import logging
# Build the run-specific output directory path from the main hyperparameters
# (this statement only composes the path string).
output_dir = os.path.join(
    args.output_dir,
    f"LLM_{args.LLM}_layer_{args.layer}_steer_{args.steer}_alpha_{args.alpha}_cnt_{args.topk_cnt}_mean_{args.topk_mean}_device_{args.device}"
)

# Setup logging
setup_logging(output_dir)

# Alias the raw config dict; nothing is written to disk here.
hyperparams = args_dict

# Log hyperparameters
logging.info("Hyperparameters:")
for key, value in hyperparams.items():
    logging.info(f"  {key}: {value}")

# Load Environment Variables
 
def load_environment(env_path: str):
    """Load variables from a dotenv file and log the Hugging Face endpoint.

    Args:
        env_path: Path of the ``.env`` file handed to ``load_dotenv``.

    The mirror URL is only a fallback for the log line; it is not written
    back into ``os.environ``.
    """
    load_dotenv(env_path)
    # Purely informational: report HF_ENDPOINT, or the mirror URL when unset.
    hf_endpoint = os.environ.get('HF_ENDPOINT', 'https://hf-mirror.com')
    logging.info(f"HF_ENDPOINT: {hf_endpoint}")

# Load the .env file named by args.env_path and log the resulting HF endpoint.
load_environment(args.env_path)


2025-01-20 16:46:48,073 [INFO] Logging initialized. Logs will be saved to ./results/LLM_gpt2-small_layer_6_steer_sup-opp_alpha_50_cnt_100_mean_100_device_cuda/execution.log
2025-01-20 16:46:48,076 [INFO] Hyperparameters:
2025-01-20 16:46:48,077 [INFO]   layer: 6
2025-01-20 16:46:48,078 [INFO]   LLM: gpt2-small
2025-01-20 16:46:48,079 [INFO]   dataset_path: /home/ckqsudo/code2024/0dataset/baseline-acl/data/debate/StanceSentences
2025-01-20 16:46:48,080 [INFO]   prompt_path: /home/ckqsudo/code2024/0dataset/baseline-acl/data/debate/ibm_debate
2025-01-20 16:46:48,080 [INFO]   output_dir: ./results
2025-01-20 16:46:48,082 [INFO]   env_path: /home/ckqsudo/code2024/CKQ_ACL2024/Control_Infer/SAE-simple/.env
2025-01-20 16:46:48,083 [INFO]   task: debate
2025-01-20 16:46:48,084 [INFO]   seed: 42
2025-01-20 16:46:48,084 [INFO]   data_size: ALL
2025-01-20 16:46:48,085 [INFO]   device: cuda
2025-01-20 16:46:48,087 [INFO]   alpha: 50
2025-01-20 16:46:48,088 [INFO]   steer: sup-opp
2025-01-20 16:46:4

In [21]:
# Record which dataset location this run reads from.
logging.info("dataset path "+args.dataset_path)

def load_and_prepare_debate_triple_dataset(dataset_path: str, seed: int, num_samples):
    """Load the stance dataset and split its training set by label.

    Args:
        dataset_path: Location handed to ``load_dataset``.
        seed: Seed used to shuffle the training split before filtering.
        num_samples: ``"ALL"`` to keep every example of each label, or an int
            cap applied to each label separately. A cap larger than the
            number of available examples keeps all of them (with a warning)
            instead of failing inside ``select``.

    Returns:
        ``(sup_train_set, opp_train_set, val_set, test_set)``: the
        ``'support'`` and ``'oppose'`` subsets of the shuffled training split,
        followed by the ``validation`` and ``test`` splits as loaded.

    Raises:
        ValueError: If ``num_samples`` is neither ``"ALL"`` nor an int.
    """
    # Validate first so a bad argument fails before any data is loaded.
    if num_samples != "ALL" and not isinstance(num_samples, int):
        raise ValueError("num_samples must be int or ALL")

    logging.info(f"Loading dataset from {dataset_path}")
    dataset = load_dataset(dataset_path)
    train_split = dataset['train'].shuffle(seed=seed)

    logging.info("Filtering dataset for support and oppose samples")

    def _label_subset(label):
        # One filter per label; the optional cap is applied after filtering.
        subset = train_split.filter(lambda example: example['label'] == label)
        if num_samples == "ALL":
            return subset
        available = len(subset)
        if num_samples > available:
            logging.warning(
                f"Requested {num_samples} '{label}' samples but only {available} are available; using all of them"
            )
        return subset.select(range(min(num_samples, available)))

    sup_train_set = _label_subset('support')
    opp_train_set = _label_subset('oppose')
    logging.info(f"Selected {len(sup_train_set)} support and {len(opp_train_set)} oppose samples")
    val_set = dataset['validation']
    test_set = dataset["test"]
    return sup_train_set, opp_train_set, val_set, test_set

# Dispatch on the configured task; each branch binds the split variables that
# later cells read (the names differ per task).
if TASK == "sentiment":
    neg_train_set, pos_train_set, neu_train_set, val_set, test_set = load_and_prepare_triple_dataset(
        args.dataset_path, "sst5", args.seed, args.data_size
    )
elif TASK == "cot":
    logging.info("COT "*10)
    all_dataset = load_and_prepare_COT_dataset(args.dataset_path, args.seed, args.data_size)
elif TASK == "polite":
    logging.info("polite"*10)
    neg_train_set, pos_train_set, neu_train_set, val_set, test_set = load_and_prepare_triple_dataset(
        args.dataset_path, "polite", args.seed, args.data_size
    )
elif TASK == "debate":
    logging.info("debate"*10)
    sup_train_set, opp_train_set, val_set, test_set = load_and_prepare_debate_triple_dataset(
        args.dataset_path, args.seed, args.data_size
    )
else:
    # Any other task name is rejected outright.
    raise ValueError("No Supported")


2025-01-20 16:14:49,642 [INFO] dataset path /home/ckqsudo/code2024/0dataset/baseline-acl/data/debate/StanceSentences
2025-01-20 16:14:49,646 [INFO] debatedebatedebatedebatedebatedebatedebatedebatedebatedebate
2025-01-20 16:14:49,647 [INFO] Loading dataset from /home/ckqsudo/code2024/0dataset/baseline-acl/data/debate/StanceSentences
2025-01-20 16:14:49,808 [INFO] Filtering dataset for negative, positive, and neutral samples
2025-01-20 16:14:49,813 [INFO] Selected 486 support and 486 oppose samples


In [22]:
# Load the language model named in the config onto the configured device.
logging.info(f"Loading model: {args.LLM}")
model = HookedTransformer.from_pretrained(args.LLM, device=args.device)

# Load the sparse autoencoder attached to the residual stream entering layer
# args.layer.
# NOTE(review): the release name is fixed to "gpt2-small-res-jb" while args.LLM
# is configurable - presumably only valid when args.LLM is gpt2-small; confirm
# before switching models.
logging.info(f"Loading SAE for layer {args.layer}")
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id=f"blocks.{args.layer}.hook_resid_pre",
    device=args.device
)

def analyze_latents(batch_latents: Tensor, top_k_mean: int = 100, top_k_cnt: int = 100) -> dict:
    """Reduce one batch of SAE latents to per-latent activation statistics.

    Args:
        batch_latents: Tensor whose last dimension indexes SAE latents; all
            leading dimensions (e.g. batch and position) are reduced away.
        top_k_mean: Unused; kept so existing call sites keep working.
        top_k_cnt: Unused; kept so existing call sites keep working.

    Returns:
        dict with two 1-D tensors of length ``sae.W_dec.shape[0]``:
        ``"latent_frequency"`` (count of non-zero entries per latent) and
        ``"latent_value_sum"`` (sum of activations per latent).

    Raises:
        AssertionError: If the last dimension of ``batch_latents`` differs
            from the latent count of the module-level ``sae``.
    """
    SAE_LATENT_SIZE=sae.W_dec.shape[0]
    # Check the layout up front so a mismatch is reported before any reduction.
    assert batch_latents.shape[-1]==SAE_LATENT_SIZE, "Latent dimension mismatch"

    # Collapse every leading dimension so inputs of any rank are handled alike.
    flat_latents = batch_latents.reshape(-1, SAE_LATENT_SIZE)

    # How often each latent is active (non-zero).
    logging.info("Computing non-zero element counts")
    lat_freq = (flat_latents != 0).sum(dim=0)
    # Total activation mass per latent.
    logging.info("Computing sum of non-zero elements")
    lat_val_sum = flat_latents.sum(dim=0)

    return {"latent_frequency":lat_freq,"latent_value_sum":lat_val_sum}
def compute_latents(sae: SAE, model: HookedTransformer, texts: list, hook_point: str, device: str, batch_size: int) -> dict:
    """Accumulate per-latent SAE activation statistics over ``texts`` in batches.

    Args:
        sae: SAE instance; ``sae.encode`` is applied to the cached activations.
        model: Transformer whose ``run_with_cache`` produces the activations.
        texts: List of input strings.
        hook_point: Name of the cached activation to encode.
        device: Device used for the activation cache and for the returned
            tensors.
        batch_size: Number of texts per forward pass.

    Returns:
        dict with two 1-D tensors on ``device``, one entry per SAE latent:
        ``"latent_frequency"`` (how many entries were non-zero) and
        ``"latent_value_mean"`` (mean activation over the non-zero entries,
        0 where a latent never fired).

    Raises:
        ValueError: If a forward pass fails; the message lists the character
            length of every text in the failing batch and the original
            exception is chained as the cause.
    """
    SAE_LATENT_SIZE=sae.W_dec.shape[0]
    logging.info("Running model with cache to obtain hidden states")
    # Running totals live on the CPU to avoid running out of accelerator memory.
    lat_freq = torch.zeros(SAE_LATENT_SIZE, device="cpu")
    lat_val_sum = torch.zeros(SAE_LATENT_SIZE, device="cpu")
    with torch.no_grad():
        for start in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
            batch_no = start // batch_size + 1
            batch_texts = texts[start:start + batch_size]
            logging.info(f"Batch {batch_no}: batch_size {batch_size}")
            try:
                # Cache only the activation that is encoded below, on the
                # caller's device rather than a hard-coded "cuda".
                _, cache = model.run_with_cache(
                    batch_texts, prepend_bos=False, names_filter=hook_point, device=device
                )
            except Exception as e:
                logging.error(f"Error processing batch {batch_no}: {e}")
                raise ValueError(str([len(text) for text in batch_texts])) from e
            batch_hidden_states = cache[hook_point]
            logging.info(f"Batch {batch_no}: Hidden states shape: {batch_hidden_states.shape}")

            # NOTE(review): every position of the batch is encoded, which looks
            # like it includes padded positions of shorter texts - confirm
            # whether those should be masked out of the statistics.
            batch_latents = sae.encode(batch_hidden_states)
            batch_info=analyze_latents(batch_latents)
            lat_freq=lat_freq+batch_info["latent_frequency"].to("cpu")
            lat_val_sum=lat_val_sum+batch_info["latent_value_sum"].to("cpu")
    # Mean over the non-zero entries; latents that never fired get 0 rather than NaN.
    lat_val_mean=torch.where(lat_freq != 0, lat_val_sum / lat_freq, torch.tensor(0.0, device="cpu"))
    logging.info(f"Total non-zero element shape: {lat_freq.shape}")
    assert lat_freq.shape[0]==lat_val_mean.shape[0]==sae.W_dec.shape[0], "sae latent dimension mismatch"
    return {"latent_frequency":lat_freq.to(device),"latent_value_mean":lat_val_mean.to(device)}

def get_activation_by_steer(texts:list):
    """Compute per-latent statistics for ``texts`` with the notebook-level model and SAE.

    Thin wrapper around ``compute_latents`` that supplies the module-level
    ``sae``, ``model`` and ``args`` (device, batch size) and the SAE's own
    hook name.

    Returns:
        dict with the keys ``"latent_value_mean"`` and ``"latent_frequency"``,
        taken from the result of ``compute_latents``.
    """
    stats = compute_latents(sae, model, texts, sae.cfg.hook_name, args.device, args.batch_size)
    # Expose only the two statistics, mean first.
    return {name: stats[name] for name in ("latent_value_mean", "latent_frequency")}


2025-01-20 16:14:52,996 [INFO] Loading model: gpt2-small
2025-01-20 16:16:04,077 [INFO] Loading SAE for layer 6


Loaded pretrained model gpt2-small into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [23]:
steer_info={}
from functools import partial

# Map each steer_info key to (log label, training split) for the active task.
# Tasks without an entry leave steer_info empty.
if TASK=='polite' or TASK=="sentiment":
    steer_sources={
        "pos":("positive",pos_train_set),
        "neg":("negative",neg_train_set),
        "neu":("neutral",neu_train_set),
    }
    # args.data_size may be the string "ALL"; slicing with a string raises
    # TypeError, so "ALL" is translated to an open-ended slice.
    sample_limit=None if args.data_size=="ALL" else args.data_size
elif TASK=='debate':
    steer_sources={
        "sup":("support",sup_train_set),
        "opp":("oppose",opp_train_set),
    }
    # The debate loader has already applied args.data_size per label.
    sample_limit=None
else:
    steer_sources={}
    sample_limit=None

if steer_sources:
    if args.device=="cpu":
        logging.info("CPU处理")
        from cpu_utils import get_activation_by_steer_cpu
        # NOTE(review): this partial is prepared but the loop below still calls
        # get_activation_by_steer - confirm whether the CPU variant should be
        # used here instead.
        get_activation_by_steer_cpu=partial(get_activation_by_steer_cpu,sae=sae,model=model,device=args.device,batch_size=args.batch_size,top_k_mean=args.topk_mean,top_k_cnt=args.topk_cnt)
    logging.info("from"+args.source+"to"+args.target)
    for steer_key,(steer_label,train_split) in steer_sources.items():
        logging.info(steer_label)
        steer_info[steer_key]=get_activation_by_steer(train_split["text"][:sample_limit])

2025-01-20 16:16:58,605 [INFO] fromsuptoopp
2025-01-20 16:16:58,607 [INFO] support
2025-01-20 16:16:58,617 [INFO] Running model with cache to obtain hidden states
Processing batches:   0%|          | 0/16 [00:00<?, ?it/s]2025-01-20 16:16:58,622 [INFO] Batch 1: batch_size 32
2025-01-20 16:16:58,846 [INFO] Batch 1: Hidden states shape: torch.Size([32, 50, 768])
2025-01-20 16:16:58,848 [INFO] Computing non-zero element counts
2025-01-20 16:16:58,854 [INFO] Computing sum of non-zero elements
2025-01-20 16:16:58,855 [INFO] Computing mean of non-zero elements
Processing batches:   6%|▋         | 1/16 [00:00<00:03,  4.26it/s]2025-01-20 16:16:58,858 [INFO] Batch 2: batch_size 32
2025-01-20 16:16:58,902 [INFO] Batch 2: Hidden states shape: torch.Size([32, 48, 768])
2025-01-20 16:16:58,904 [INFO] Computing non-zero element counts
2025-01-20 16:16:58,905 [INFO] Computing sum of non-zero elements
2025-01-20 16:16:58,907 [INFO] Computing mean of non-zero elements
2025-01-20 16:16:58,908 [INFO] Batc

In [24]:
# Display the per-label latent statistics collected in steer_info.
steer_info

{'sup': {'latent_value_mean': tensor([5.2969, 5.7562, 0.9151,  ..., 1.0190, 3.8079, 0.0000], device='cuda:0'),
  'latent_frequency': tensor([ 4.,  8.,  1.,  ..., 27.,  4.,  0.], device='cuda:0')},
 'opp': {'latent_value_mean': tensor([1.7246, 2.0211, 6.5648,  ..., 0.7094, 0.0000, 0.0693], device='cuda:0'),
  'latent_frequency': tensor([ 2.,  1.,  2.,  ..., 22.,  0.,  1.], device='cuda:0')}}

In [25]:

source=args.source
target=args.target
# The steering direction is chosen here: statistics of `source` are subtracted
# from those of `target`, so swapping the two reverses the direction.
# For the sentiment task, pos = positive sentiment and neg = negative sentiment.


# %%
# Sanity checks on the configured pair (not on hard-coded "sup"/"opp" keys, so
# the cell also works for the other tasks): the two label statistics must
# differ, and the target means are expected to be non-negative.
assert not bool(torch.all((steer_info[source]["latent_value_mean"]-steer_info[target]["latent_value_mean"])==0)),"数据库读取有问题"
assert torch.all(steer_info[target]["latent_value_mean"]>=0),"所有SAE的激活需要大于d等于0（maybe）"

logging.info(f"转向方向 dif_{target}-{source}_relu")
# Keep only latents that are more frequent / stronger for the target label
# (relu zeroes the rest); "target_nz_mean" is the target's own mean activation.
steer_info[f"dif_{target}-{source}_relu"]={
    "latent_frequency":torch.relu(steer_info[target]["latent_frequency"]-steer_info[source]["latent_frequency"]),
    "latent_value_mean":torch.relu(steer_info[target]["latent_value_mean"]-steer_info[source]["latent_value_mean"]),
    "target_nz_mean":torch.relu(steer_info[target]["latent_value_mean"]),
}

# Select the top-k latents by frequency difference as steering features.
top_k=args.topk_cnt
_,steer_indices=torch.topk(steer_info[f"dif_{target}-{source}_relu"]["latent_frequency"],top_k)


# lat_freq is a torch tensor (possibly on an accelerator), so the torch
# routine is used to find the indices of its non-zero entries.
lat_freq = steer_info[f"dif_{target}-{source}_relu"]["latent_frequency"]
lat_acti_indices = torch.nonzero(lat_freq)
# Displayed as the cell result: True would mean no latent favours the target.
torch.all(lat_freq == 0)

2025-01-20 16:17:06,126 [INFO] 转向方向 dif_opp-sup_relu


tensor(False, device='cuda:0')

In [26]:
# Shape of the frequency-difference vector.
steer_info[f"dif_{target}-{source}_relu"]["latent_frequency"].shape


# Mean-activation differences of the selected latents.
steer_info[f"dif_{target}-{source}_relu"]["latent_value_mean"][steer_indices]# zeros appear here and no negatives, which is expected because the values went through relu

tensor([0.3309, 0.1696, 0.1970, 0.1038, 0.1705, 1.0051, 0.3258, 0.3164, 0.3293,
        0.3271, 0.8777, 0.0272, 0.0171, 0.0577, 0.0879, 0.0640, 0.0753, 0.0251,
        0.1720, 0.2338, 0.0384, 0.2178, 0.1202, 0.4218, 0.6355, 0.0438, 0.1294,
        0.0545, 0.2123, 0.1494, 0.1205, 0.0000, 0.0478, 0.0000, 0.0764, 0.1165,
        0.1423, 1.8745, 0.0809, 0.2035, 0.3341, 0.0258, 0.0178, 0.0141, 0.0673,
        0.0000, 0.8087, 0.1078, 0.0861, 0.2246, 1.3967, 0.1415, 0.0000, 0.2475,
        0.0000, 0.1158, 0.6421, 0.2660, 0.0269, 0.1622, 0.1523, 3.1804, 0.0508,
        0.5813, 0.3387, 0.1146, 0.0000, 0.0000, 1.1453, 0.2003, 0.0106, 0.7030,
        0.0364, 0.0989, 0.2309, 0.0233, 3.1515, 0.3171, 0.1475, 0.0000, 0.0000,
        0.0000, 0.1077, 0.0925, 0.5686, 0.1631, 0.0814, 0.0000, 0.1898, 0.0000,
        0.0683, 0.0000, 9.0391, 0.2754, 0.0833, 0.0353, 0.0876, 0.2485, 0.0763,
        0.0000], device='cuda:0')

In [27]:
def compute_delta_matrix(sae: SAE, indices: Tensor, nz_mean_val: Tensor, method: str = "val_mul") -> Tensor:
    """Build the steering vector from selected rows of the SAE decoder.

    Args:
        sae: Object exposing the decoder matrix as ``sae.W_dec`` with one row
            per latent.
        indices: Integer indices of the latents to combine.
        nz_mean_val: Per-latent weights, indexed by latent id (only read for
            ``method="val_mul"``).
        method: ``"mean"`` averages the selected decoder rows; ``"val_mul"``
            sums them, each scaled by its entry of ``nz_mean_val``.

    Returns:
        1-D tensor of length ``sae.W_dec.shape[1]``, detached from the
        autograd graph (it is used as a constant offset).

    Raises:
        ValueError: If ``method`` is not one of the two supported names.
    """
    logging.info(f"Computing steering vectors using method: {method}")
    if method not in ("mean", "val_mul"):
        raise ValueError(f"Unknown method: {method}")

    decoder = sae.W_dec
    with torch.no_grad():
        rows = decoder[torch.as_tensor(indices, dtype=torch.long, device=decoder.device)]
        if method == "mean":
            delta_matrix = torch.mean(rows, dim=0)
        else:
            # One vector-matrix product replaces the per-index Python loop,
            # which forced a host/device round trip (.item()) per latent.
            picked = nz_mean_val[torch.as_tensor(indices, dtype=torch.long, device=nz_mean_val.device)]
            weights = picked.to(device=decoder.device, dtype=decoder.dtype)
            delta_matrix = weights @ rows
    logging.info(f"Steering vectors computed with shape: {delta_matrix.shape}")
    return delta_matrix
# args.mean_type selects which per-latent statistic weights the decoder rows:
# the relu'd target-minus-source mean ("dif_mean") or the target's own mean
# ("tar_mean"). The combination rule comes from args.method instead of a
# literal, so the "method" entry of the config actually takes effect.
_weight_key_by_mean_type={"dif_mean":"latent_value_mean","tar_mean":"target_nz_mean"}
if args.mean_type not in _weight_key_by_mean_type:
    raise ValueError("Unsupported")
delta_matrix=compute_delta_matrix(
    sae,
    indices=steer_indices,
    nz_mean_val=steer_info[f"dif_{target}-{source}_relu"][_weight_key_by_mean_type[args.mean_type]],
    method=args.method,
)

import einops
steer_cnt=0

def steering_hook(resid_pre, hook,steer_on, alpha, steer_type="last"):
    """Forward hook that adds the steering vector to an activation in place.

    Args:
        resid_pre: Activation tensor handed to the hook; indexed as
            ``[batch, position, hidden]`` and modified in place.
        hook: Hook point object supplied by the caller (not used).
        steer_on: When falsy the activation is left untouched.
        alpha: Scale applied to the module-level ``delta_matrix``.
        steer_type: Which positions are steered:
            ``"last"`` - all but the final position;
            ``"last2"`` - all but the final two positions;
            ``"all"`` - every position;
            ``"gaussian"`` - all but the final position, weighted per position
            by ``utils.half_gaussian_kernel``.

    Returns:
        None; the change is made in place on ``resid_pre``.

    Raises:
        ValueError: If ``steer_on`` is set and ``steer_type`` is unknown.
    """
    # Inputs with a single position are never steered.
    if resid_pre.shape[1] == 1:
        return
    # Only intervene when steering is switched on.
    if steer_on:
        if steer_type == "last":
            # Steer everything before the final token (noted by the author as
            # the best-performing variant; steering early reads more coherent).
            resid_pre[:, :-1, :] += alpha * delta_matrix
        elif steer_type == "gaussian":
            # Weight the steering vector per position with a half-Gaussian kernel.
            from utils import half_gaussian_kernel
            d_m=torch.clone(delta_matrix)
            s = resid_pre[:, :-1, :].shape[1]
            b=resid_pre[:, :-1, :].shape[0]
            h=resid_pre[:, :-1, :].shape[2]
            h_gauss = half_gaussian_kernel(s)  # one weight per steered position
            k_gauss=h_gauss
            k_gau_repeat=einops.repeat(k_gauss,"s -> b s h",b=b,h=h)
            d_m_repeat=einops.repeat(d_m,"h -> b s h",b=b,s=s)
            # Element-wise product keeps the other dimensions unchanged.
            resid_pre[:, :-1, :] += alpha * d_m_repeat* k_gau_repeat
        elif steer_type == "all":
            # Steer every position.
            resid_pre[:, :, :] += alpha * delta_matrix
        elif steer_type == "last2":
            # Stop two tokens early. Uses the alpha argument like every other
            # branch (previously this read the global args.alpha instead).
            resid_pre[:, :-2, :] += alpha * delta_matrix
        else:
            raise ValueError("Unknown steering type")

def hooked_generate(prompt_batch, fwd_hooks=None, seed=None, **kwargs):
    """Sample continuations for a batch of prompts with optional forward hooks.

    Args:
        prompt_batch: Prompts passed to ``model.to_tokens``.
        fwd_hooks: Hooks installed for the duration of the call via
            ``model.hooks``; ``None`` (the default) means no hooks.
        seed: If given, passed to ``torch.manual_seed`` before sampling.
        **kwargs: Extra keyword arguments forwarded to ``model.generate``.

    Returns:
        Whatever ``model.generate`` returns for the tokenized prompts.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # A fresh list per call instead of a shared mutable default argument.
    active_hooks = [] if fwd_hooks is None else fwd_hooks
    with model.hooks(fwd_hooks=active_hooks):
        tokenized = model.to_tokens(prompt_batch,prepend_bos=False)
        result = model.generate(
            stop_at_eos=True,  # original author's note: avoids a bug on MPS
            input=tokenized,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            **kwargs,
        )
    return result

def run_generate(example_prompt,sampling_kwargs,steer_on,alpha,steer_type="last",repeat_num=3,show_res=False,seed=None):
    """Generate ``repeat_num`` continuations of one prompt, optionally steered.

    Args:
        example_prompt: Prompt text; it is repeated ``repeat_num`` times to
            form the batch.
        sampling_kwargs: Keyword arguments forwarded to ``hooked_generate``.
        steer_on: When truthy, ``steering_hook`` is attached for this call.
        alpha: Steering strength handed to ``steering_hook``.
        steer_type: Steering variant handed to ``steering_hook``.
        repeat_num: Number of copies of the prompt in the batch.
        show_res: When truthy, every generated text is logged.
        seed: Optional seed forwarded to ``hooked_generate`` so a run can be
            reproduced; ``None`` (the default) leaves the RNG untouched.

    Returns:
        The result of ``model.to_string`` on the generated tokens.
    """
    model.reset_hooks()
    editing_hooks=[]
    if steer_on:
        steering_hook_fn=partial(steering_hook,steer_on=steer_on,alpha=alpha,steer_type=steer_type)
        # NOTE(review): the hook is attached to hook_resid_post of the layer
        # while the SAE is loaded for hook_resid_pre - confirm this offset is
        # intended.
        editing_hooks.append((f"blocks.{args.layer}.hook_resid_post", steering_hook_fn))
    res = hooked_generate(
        [example_prompt] * repeat_num, editing_hooks, seed=seed, **sampling_kwargs
    )

    # Decode the whole sequences (prompt included).
    res_str = model.to_string(res)
    if show_res:
        for idx, text in enumerate(res_str):
            logging.info(f"Generated Text: {idx+1}:\n{text}")

    return res_str

2025-01-20 16:17:10,578 [INFO] Computing steering vectors using method: val_mul
2025-01-20 16:17:10,598 [INFO] Steering vectors computed with shape: torch.Size([768])


In [44]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Tokenizer/model pairs already loaded, keyed by checkpoint path.
_STANCE_EVALUATORS = {}


def _load_stance_evaluator(eval_model_path):
    """Return the (tokenizer, model) pair for a checkpoint, loading it only once."""
    if eval_model_path not in _STANCE_EVALUATORS:
        _STANCE_EVALUATORS[eval_model_path] = (
            AutoTokenizer.from_pretrained(eval_model_path),
            AutoModelForSequenceClassification.from_pretrained(eval_model_path),
        )
    return _STANCE_EVALUATORS[eval_model_path]


def eval_stance(text, eval_model_path="/home/ckqsudo/code2024/0models/ACL_debate_model/distilbert-base-uncased__tweet_eval_stance__all-train"):
    """Classify the stance of one text and print the result.

    Args:
        text: A single string to classify.
        eval_model_path: Checkpoint directory of the sequence classifier;
            defaults to the path that was previously hard-coded here.

    Returns:
        The predicted class name: ``"none"``, ``"against"`` or ``"support"``.

    The classifier is loaded on first use and reused afterwards instead of
    being read from disk on every call.
    """
    eval_tokenizer, eval_model = _load_stance_evaluator(eval_model_path)
    inputs = eval_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = eval_model(**inputs)
        logits = outputs.logits

    # Pick the class with the highest probability.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    predicted_class = torch.argmax(probabilities, dim=-1).item()

    print(probabilities, predicted_class)
    class_names = ["none", "against", "support"]
    print(f"Predicted class: {class_names[predicted_class]}")
    return class_names[predicted_class]


example_texts = ["I believe that it is the most restrictive in the country. It's not a perfect law, but it's certainly not one that will be overturned by the Supreme Court. It is very difficult to say what kind of effect this will have on women who are",
                 
"It's not an absolute ban, but it does make it harder for women to get abortions. It also makes it harder for women to get pregnant in the first place. So we're going to have more restrictions on abortion and more restrictions on",

"We should ban partial birth abortions. In 2007, the organization was involved in a Supreme Court case in which it helped to overturn a state law that required abortion providers to provide an ultrasound before they performed an abortion. The American Civil Liberties Union of Texas has filed a lawsuit against the state's law, arguing that it violates the First Amendment and violates women's right",

"We should ban partial birth abortions. In 2007, the organization was involved in a Supreme Court case in which it helped to overturn a state law that required abortion providers to provide ultrasounds for women who had been raped or sexually assaulted. The American Civil Liberties Union of Texas has filed a lawsuit against the state's medical board, arguing that its decision not to allow",

"I think it's really important to understand the history of the anti-war movement against the right-wing, against the American people against their own government, said Michael O'Brien, a professor at Stanford University's Center for Global and Global Security. ",
"I think it's really important to understand the history of the fight against that word, against the words that are used in this language, said Mr. Kitz. The word 'war' is not a name for any kind of weapon or weapon of any",
"I think it's really important to understand the history of the fight against that very same anti-war, anti-imperialist, and anti-imperialist movement against the American people, said Senator John Kerry, chairman of the Senate Foreign Relations Committee.",
"I believe that the U.S. is the only country on Earth that has made a significant contribution to the fight against this type of war against our own people against all of these different kinds of weapons, including those designed to destroy and destroy our very existence"]

In [57]:
# Example prompt used for a qualitative before/after comparison.
example_prompt = "Trump is about to take office as the President of the United States. What do you think of this new president?"
logging.info(f"Example prompt: {example_prompt}")

# Baseline: generate without steering.

logging.info("Generating texts **without** steering... ")
generated_texts_no_steer = run_generate(example_prompt, sampling_kwargs,steer_on=False,alpha=0,show_res=True)
logging.info("干预之后的结果")
logging.info(f"干预方向{source}->{target},礼貌任务下，neg=impolite，情感任务下 pos=积极情感")
logging.info("** Generating texts with steering... Target **")
# Report the actual direction (the message used to read "form ... to target"
# with the target never interpolated).
logging.info(f"from {source} to {target}")
generated_texts_with_steer = run_generate(
    example_prompt, 
    sampling_kwargs,
    steer_on=True,
    alpha=args.alpha,
    steer_type=args.steer_type,
    show_res=True)


2025-01-20 16:46:57,540 [INFO] Example prompt: Trump is about to take office as the President of the United States. What do you think of this new president?
2025-01-20 16:46:57,543 [INFO] Generating texts **without** steering... 
2025-01-20 16:46:59,488 [INFO] Generated Text: 1:
Trump is about to take office as the President of the United States. What do you think of this new president?

TRUMP: Well, I think he's going to be a very good president. He's going to be a very good leader. And I think he'll have a lot of great people in place, and I'm sure they're going to be very happy with him. But we're not going anywhere soon enough for that. We're not going anywhere soon enough for that because we've got so many problems and so many problems with our country, and it's time for us to get out
2025-01-20 16:46:59,491 [INFO] Generated Text: 2:
Trump is about to take office as the President of the United States. What do you think of this new president?

TRUMP: Well, I think he's going to be 

In [55]:
# Classify every continuation, unsteered first and steered second. The prompt
# is stripped via example_prompt so this cell cannot drift from the generation
# cell if the prompt text is edited.
for text in list(generated_texts_no_steer) + list(generated_texts_with_steer):
    eval_stance(text.replace(example_prompt,'').replace('<|endoftext|>',''))

tensor([[0.1479, 0.7170, 0.1352]]) 1
Predicted class: against
tensor([[0.1438, 0.7172, 0.1390]]) 1
Predicted class: against
tensor([[0.1479, 0.7170, 0.1352]]) 1
Predicted class: against
tensor([[0.1942, 0.6852, 0.1206]]) 1
Predicted class: against
tensor([[0.1360, 0.7323, 0.1317]]) 1
Predicted class: against
tensor([[0.2102, 0.6561, 0.1337]]) 1
Predicted class: against


In [4]:
# Please install OpenAI SDK first: `pip3 install openai`
from dotenv import load_dotenv
from openai import OpenAI
import os
# Read the API key from the environment / a local .env file; nothing is hard-coded.
load_dotenv()
# The key is stored under the OPENAI_API_KEY variable name but is used against
# the DeepSeek endpoint below.
deepseek_api_key = os.getenv('OPENAI_API_KEY')
client = OpenAI(api_key=deepseek_api_key, base_url="https://api.deepseek.com")

text = "Trump is about to take office as the President of the United States. What do you think of this new president? TRUMP: Well, I think he's going to be a very good president. He's going to be a very good leader. And I think he'll have a lot of great people in place, and I'm sure they're going to be very happy with him. But we're not going anywhere soon enough for that. We're not going anywhere soon enough for that because we've got so many problems and so many problems with our country, and it's just too much time for us"

# Ask the chat model for a one-word stance label ('support' or 'oppose') for
# `text`; the system message constrains the answer format.
response = client.chat.completions.create(
    model="deepseek-chat",
    messages=[
        {"role": "system", "content": "You are a helpful assistant, you need to help me determine the stance of a given sentence, whether it conveys a supportive or an opposing tone. Your answer can only be one of the two words 'support' or 'oppose'."},
        {"role": "user", "content": text},
    ],
    stream=False
)

# Show the label returned for the first (only) choice.
print(response.choices[0].message.content)

support


In [None]:

# Display the prompt location configured for this run.
args.prompt_path

def load_and_prepare_sentiment_prompts(prompt_path:str,task:str):
    """Load the negative / positive / neutral sentiment prompt files.

    Args:
        prompt_path: Directory containing ``negative_prompts.jsonl``,
            ``positive_prompts.jsonl`` and ``neutral_prompts.jsonl``. This
            path is now the one actually loaded; previously a fixed directory
            was read regardless of the argument.
        task: Must be ``"sentiment"``.

    Returns:
        The object returned by ``load_dataset``, with the splits ``"neg"``,
        ``"pos"`` and ``"neu"``.

    Raises:
        AssertionError: If ``task`` is not ``"sentiment"``.
    """
    assert task in ["sentiment"],"请输入正确的任务"
    logging.info(f"Loading prompt_path from {prompt_path}")
    prompt_files = {"neg": "negative_prompts.jsonl", "pos": "positive_prompts.jsonl","neu":"neutral_prompts.jsonl"}
    prompts= load_dataset(prompt_path,data_files=prompt_files)
    print(prompts)
    return prompts



import copy
# Example prompt from the selected set
import jsonlines

def eval_on_full_data():
    """Generate continuations for every prompt of the ``source`` split and dump them as JSONL.

    Reads the notebook-level ``TASK``, ``args``, ``source``, ``target``,
    ``ALPHA``, ``STEER_TYPE``, ``MAX_NEW_TOKENS``, ``sampling_kwargs`` and
    ``output_dir``. Three files are written into ``output_dir``:

    * ``params.jsonl`` - the run parameters;
    * ``steer_gen_res.jsonl`` - one record per prompt with two steered
      generations;
    * ``no_steer_gen_res.jsonl`` - the matching unsteered generations, filled
      only when ``args.save_compared`` is set (treated as False when the
      attribute is absent from the config).

    Raises:
        NotImplementedError: For any task other than ``"sentiment"``; only
            the sentiment prompt loader exists.
        AssertionError: If ``source`` is not a split of the loaded prompts.
    """
    if TASK=="sentiment":
        # Load the prompts once (they used to be loaded a second time below).
        prompts=load_and_prepare_sentiment_prompts(prompt_path=args.prompt_path,task=TASK)
    else:
        raise NotImplementedError("No Supported Task")

    # Record the generation length that hooked_generate really uses
    # (MAX_NEW_TOKENS) instead of an unrelated literal.
    param={**vars(args),**sampling_kwargs,"max_new_tokens":MAX_NEW_TOKENS,"steer":f"from {source} to {target}"}
    param["alpha_recheck"]=ALPHA
    logging.info(f"Running with alpha: {ALPHA}")
    logging.info(f"Running with prompt_type: "+str(param["steer"]))

    assert source in prompts,"prompt steer source not in prompts"
    os.makedirs(output_dir,exist_ok=True)
    # Mode 'w' overwrites any previous parameter file.
    with jsonlines.open(os.path.join(output_dir,"params.jsonl"), mode='w') as writer:
        writer.write(param)
    SAVE_COMPARED=getattr(args,"save_compared",False)

    def _with_generations(item,texts):
        # Copy the prompt record and attach the generated texts to the copy.
        record=copy.deepcopy(item)
        record["generations"]=[{"text":gen_text} for gen_text in texts]
        return record

    no_steer_path=os.path.join(output_dir,"no_steer_gen_res.jsonl")
    steer_path=os.path.join(output_dir,"steer_gen_res.jsonl")
    with jsonlines.open(no_steer_path, mode='w') as nt_file, jsonlines.open(steer_path, mode='w') as t_file:
        # tqdm wraps the list itself so the bar knows the total.
        for idx,item in enumerate(tqdm(list(prompts[source]))):
            prompt=item["prompt"]["text"]
            item["label"]=source
            # Unsteered baseline, only when a comparison was requested.
            if SAVE_COMPARED:
                no_steer_gen_texts = run_generate(
                    prompt,
                    sampling_kwargs,
                    steer_on=False,
                    alpha=None,
                    repeat_num=2,
                    steer_type=None,
                    show_res=False)
                nt_file.write(_with_generations(item,no_steer_gen_texts))

            # Steered generations.
            steered_texts=run_generate(
                prompt,
                sampling_kwargs,
                steer_on=True,
                alpha=ALPHA,
                steer_type=STEER_TYPE,
                repeat_num=2,
                show_res=False)
            steer_item=_with_generations(item,steered_texts)
            # Log a sample record every 50 prompts.
            if idx%50==0:
                logging.info(f"{TASK}: from {source} to {target} prompt_set: {source}")
                logging.info(steer_item)
            t_file.write(steer_item)

# The config does not define a "debug" entry, so a missing attribute is
# treated as "not debugging" instead of raising AttributeError.
if getattr(args, "debug", False):
    logging.info(f"debug mode, no full data eval")
else:
    eval_on_full_data()