In [None]:
import os
from dotenv import load_dotenv
load_dotenv()

# Restrict this process to a single GPU. The assignment runs after load_dotenv(),
# so it replaces any CUDA_VISIBLE_DEVICES value that came from the environment or
# the .env file, and it runs before torch is imported in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Set to the GPU you want to use

In [None]:
import re
import datetime
import traceback
import random
import pickle
import collections
from functools import partial
import dill
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import itertools
import lovely_tensors as lt
lt.monkey_patch()

from llmg.utils.mix import seed_all
from llmg.utils.steering import SteeringHook
from llmg.chameleon.NaturalLanguageTalker.HuggingfaceTalker import HuggingfaceTalker
from llmg.chameleon.constants import (
    GAME_START_PROMPT,
    DISTRIBUTE_INDICES_PROMPT,
    DISTRIBUTE_CHAMELEON_IDENTITY_PROMPT,
    DISTRIBUTE_NON_CHAMELEON_IDENTITY_PROMPT,
    DISTRIBUTE_CATEGORY_PROMPT,
    RESPOND_PROMPT,
)

In [None]:
# Global config shared by every section of the notebook:
#   seed     - value passed to seed_all() before each stochastic step
#   device   - "cuda" when torch can see a GPU, otherwise "cpu"
#   run_time - minute-resolution timestamp used in the names of this run's output files
cfg = dict(
    seed=0,
    device="cuda" if torch.cuda.is_available() else "cpu",
    run_time=datetime.datetime.now().strftime("%Y-%m-%d_%H-%M"),
)

In [None]:
# Steering config: everything the data-collection pass reads lives under cfg["steering"].
cfg["steering"] = dict(
    ### Data collection config
    # Alternative model: "Qwen/Qwen3-32B-AWQ" (used with layer_idx [30, 40, 50])
    model_name="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
    # Forwarded to HuggingfaceTalker as hidden_states_layer_idx / hidden_states_token_idx
    layer_idx=40,
    token_idx=0,
    # Sampling is switched off; the sampling parameters are explicitly set to None
    generation_kwargs=dict(
        max_new_tokens=20,
        do_sample=False,
        temperature=None,
        top_p=None,
        top_k=None,
    ),
    cards_path="chameleon_cards.pkl",
    # One pickle per run, named after the run timestamp
    save_path=os.path.join(os.environ["DATA_DIR"], f"chameleon/steering/{cfg['run_time']}.pkl"),
)

# Create the output directory up front so the save at the end of collection cannot fail on it.
os.makedirs(os.path.dirname(cfg["steering"]["save_path"]), exist_ok=True)
print(f"Steering save path:\n{cfg['steering']['save_path']}")

## Prepare data

In [None]:
# Load model and tokenizer
from transformers import modeling_utils

# Make sure `modeling_utils.ALL_PARALLEL_STYLES` exists and is not None before loading.
# NOTE(review): presumably a compatibility workaround for the installed transformers
# version — confirm it is still required.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]

# NOTE: trust_remote_code=True allows code shipped with the model repository to run locally.
tokenizer = AutoTokenizer.from_pretrained(
    cfg["steering"]["model_name"],
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    cfg["steering"]["model_name"],
    trust_remote_code=True,
    device_map="auto",
)

In [None]:
# Load chameleon cards
# NOTE: unpickling can execute arbitrary code — only point cards_path at a trusted file.
with open(cfg["steering"]["cards_path"], "rb") as cards_file:
    chameleon_cards = pickle.load(cards_file)
print("Loaded {} chameleon cards from {}".format(len(chameleon_cards), cfg["steering"]["cards_path"]))

### Instructions + categories

In [None]:
# Setup for collecting data for strategy steering
# Instruction template; {strategy_level} is filled in once per instruction option below.
cfg["steering"]["strategy_prompt"] = """
On a scale of 1 to 10, where 1 is very cryptic and concealing, and 10 is highly information-revealing, your strategy should be around {strategy_level}.
""".strip()

seed_all(cfg["seed"])

# Game-state axes; their cross product gives the candidate game states.
# Alternatives: "num_players": [4, 6] and "categories": list(chameleon_cards.keys()) for the
# full analysis, or a fixed category such as ["Famous Islands"] / ["School"] for a targeted test.
cfg["steering"]["game_state_opts"] = {
    "num_players": [3, 4, 5, 6],
    # A single category drawn at random (seed_all ran immediately above)
    "categories": [list(chameleon_cards.keys())[random.randint(0, len(chameleon_cards) - 1)]],
    "is_chameleon": [False],
}

# Expand the axes into one dict per combination, e.g. {"num_players": 3, "categories": ..., ...}
_axis_names = list(cfg["steering"]["game_state_opts"].keys())
opt_combinations = []
for _axis_values in itertools.product(*cfg["steering"]["game_state_opts"].values()):
    opt_combinations.append(dict(zip(_axis_names, _axis_values)))

# (instruction text with a trailing space, info-revealing level) pairs
cfg["steering"]["instruction_opts"] = [
    (cfg["steering"]["strategy_prompt"].format(strategy_level=_level) + " ", _level)
    for _level in (1, 3, 5, 7, 9, 10)
]

# Randomly select K combinations (re-seeded so the shuffle is reproducible on its own)
seed_all(cfg["seed"])
cfg["steering"]["max_k_combinations"] = 300
random.shuffle(opt_combinations)
cfg["steering"]["opt_combinations"] = opt_combinations[:cfg["steering"]["max_k_combinations"]]

_n_states = len(cfg["steering"]["opt_combinations"])
_n_instructions = len(cfg["steering"]["instruction_opts"])
print(f"Selected {_n_states} combinations for probing.")
print(f"Combined with instruction options results in {_n_states * _n_instructions} total probing data points.")

In [None]:
def construct_game_msg_list(
    num_players,
    player_number,
    category,
    secret_words,
    secret_word,
    is_chameleon,
    instruction,
):
    """Build the chat history for one game, ending with the request to respond.

    The history is four user/assistant exchanges with fixed assistant replies
    (game start, category and word list, player index, identity), followed by one
    final user message built from ``RESPOND_PROMPT`` with no previous words.

    Args:
        num_players: Value substituted into ``GAME_START_PROMPT``.
        player_number: Value substituted into ``DISTRIBUTE_INDICES_PROMPT``; its
            ``str()`` form is also the assistant's reply to that prompt.
        category: Category name substituted into ``DISTRIBUTE_CATEGORY_PROMPT``.
        secret_words: Iterable of strings, joined with ", " as the possible words.
        secret_word: Word revealed in the non-chameleon identity prompt
            (not used when ``is_chameleon`` is true).
        is_chameleon: Selects the chameleon identity prompt (assistant reply "yes")
            or the non-chameleon one (assistant reply "no").
        instruction: Text substituted into ``RESPOND_PROMPT``.

    Returns:
        list[dict]: ``{"role": ..., "content": ...}`` messages; the last one is a user turn.
    """
    exchanges = [
        (GAME_START_PROMPT.format(num_players=num_players), "yes"),
        (
            DISTRIBUTE_CATEGORY_PROMPT.format(category=category, possible_words=", ".join(secret_words)),
            "yes",
        ),
        (DISTRIBUTE_INDICES_PROMPT.format(player_number=player_number), str(player_number)),
    ]
    if is_chameleon:
        exchanges.append((DISTRIBUTE_CHAMELEON_IDENTITY_PROMPT, "yes"))
    else:
        exchanges.append((DISTRIBUTE_NON_CHAMELEON_IDENTITY_PROMPT.format(secret_word=secret_word), "no"))

    msgs = []
    for user_text, assistant_text in exchanges:
        msgs.append({"role": "user", "content": user_text})
        msgs.append({"role": "assistant", "content": assistant_text})

    # The final prompt is stripped so an empty instruction leaves no dangling whitespace.
    final_prompt = RESPOND_PROMPT.format(previous_words="", instruction=instruction).strip()
    msgs.append({"role": "user", "content": final_prompt})
    return msgs

In [None]:
# Collect instruction messages
# For every selected game state, draw one secret word, then query the model once per
# instruction option and record the response, hidden states and logprobs. All records
# are pickled to cfg["steering"]["save_path"] after the loops finish.
seed_all(cfg["seed"])
instruct_msg_dicts = []
for opt_i, opt in enumerate(cfg["steering"]["opt_combinations"]):
    # Progress line every 10th game state
    if opt_i % 10 == 0:
        print(f"Processing option {opt_i + 1}/{len(cfg['steering']['opt_combinations'])}: {opt}")

    # Pick secret word (once per game state, so all instruction levels share the same word)
    possible_secret_words = chameleon_cards[opt["categories"]]
    secret_word = random.choice(possible_secret_words)

    # Generate messages for each possible instruction
    for (instruction_str, info_revealing_level) in cfg["steering"]["instruction_opts"]:
        # Construct the game message list
        # NOTE(review): the prompt receives num_players - 1 while the stored record below keeps
        # opt["num_players"]; presumably GAME_START_PROMPT counts the *other* players — confirm.
        msgs = construct_game_msg_list(
            num_players=opt["num_players"] - 1,
            player_number="1",
            category=opt["categories"],
            secret_words=possible_secret_words,
            secret_word=secret_word,
            is_chameleon=opt["is_chameleon"],
            instruction=instruction_str,
        )

        # Collect hidden states from the model
        # A new talker is built for every message list, wrapping the already-loaded model/tokenizer.
        talker = HuggingfaceTalker(
            model_id=model.name_or_path,
            model=model,
            tokenizer=tokenizer,
            additional_generation_kwargs=cfg["steering"]["generation_kwargs"],
            hidden_states_layer_idx=cfg["steering"]["layer_idx"],
            hidden_states_token_idx=cfg["steering"]["token_idx"],
            start_conversation=False,
        )
        response_dict = talker.get_llm_response_and_hidden_states(
            messages=msgs,
            max_new_tokens=cfg["steering"]["generation_kwargs"]["max_new_tokens"],
            return_logprobs=True,
            get_all_logprobs=True,
        )
        response_text, hidden_states = response_dict["content"], response_dict["hidden_states"]
        # "logprobs" is read defensively: None is stored if the key is absent
        logprobs = response_dict.get("logprobs", None)

        # One record per (game state, instruction) pair
        instruct_msg_dicts.append({
            "messages": msgs,
            "category": opt["categories"],
            "num_players": opt["num_players"],
            "is_chameleon": opt["is_chameleon"],
            "player_number": "1",
            "secret_word": secret_word,
            "instruction": instruction_str,
            "info_revealing_level": info_revealing_level,
            "response_text": response_text,
            "hidden_states": hidden_states,
            "logprobs": logprobs,
            # Same cfg dict object on every record (a reference, not a copy)
            "config": cfg,
        })

# Save the instruction messages
with open(cfg["steering"]["save_path"], "wb") as f:
    pickle.dump(instruct_msg_dicts, f)
print(f"Saved {len(instruct_msg_dicts)} instruction messages to:\n{cfg['steering']['save_path']}")

## Analyze data

In [None]:
# Print the instruction messages: full chat history, model response and level, one block per record
for i, record in enumerate(instruct_msg_dicts):
    print(f"Instruction {i+1}:")
    for m in record["messages"]:
        print(m["content"])
    print("Response:", record["response_text"])
    print("Info Revealing Level:", record["info_revealing_level"])
    print("-" * 50)

In [None]:
# Load previously collected data for plotting
# NOTE: this replaces any `instruct_msg_dicts` collected earlier in the session.
def _load_from_data_dir(relative_path):
    """Unpickle the file at ``relative_path`` under ``$DATA_DIR`` (trusted files only)."""
    with open(os.path.join(os.environ["DATA_DIR"], relative_path), "rb") as handle:
        return pickle.load(handle)

## With instructions
instruct_msg_dicts = _load_from_data_dir("chameleon/steering/2025-07-24_13-09.pkl")

## Without instructions (steering)
no_instruct_msg_dicts = _load_from_data_dir("chameleon/steering/2025-08-12_15-23.pkl")

In [None]:
# Plot PCA of strategy hidden states
#
# Fits a 2-component PCA on the instruction-conditioned hidden states and scatters them,
# coloured by info-revealing level. Optionally overlays the steering direction as an arrow
# and, for each entry in `steer_by_opts`, the no-instruction hidden states shifted by that
# multiple of the steering vector and projected into the same PCA space.
from sklearn.decomposition import PCA

# Filter out unwanted info revealing levels
remove_num_of_players = []  # NOTE: declared but not applied anywhere in this cell
remove_info_levels = []
add_steering_vec = True # !!! Make sure you have already computed the steering vector computed in the cell below !!!
hidden_state_idx = 0  # Index of the hidden state to use for PCA
# steer_by_opts = {3: "lightsalmon", 0: "lightcoral", -3: "brown"}
steer_by_opts = dict()  # {steering strength: marker colour}; empty disables the steered overlay
fontsize = 24
save_to_path = None  # Set to a file path to also save the figure

# Both overlays depend on `steering_vector_np`, which is created by the steering-vector cell.
has_steering_vector = "steering_vector_np" in globals()

# Prepare the data for PCA
all_hidden_states = []
info_levels = []
for item in instruct_msg_dicts:
    if item["info_revealing_level"] in remove_info_levels:
        continue
    hs = item["hidden_states"]
    if isinstance(hs, torch.Tensor):
        hs = hs.detach().cpu().numpy()
    hs = hs.squeeze(0)
    if len(all_hidden_states) == 0 and len(hs) > 1:
        print(f"[INFO] {len(hs)} hidden states, you are selecting the {hidden_state_idx}th one.")
    all_hidden_states.append(hs[hidden_state_idx])
    info_levels.append(item["info_revealing_level"])

# Fit PCA
X = np.array(all_hidden_states) # (n_samples, n_features)
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)

# Plot instruction PCA
sns.set_theme(style="whitegrid")
plt.figure(figsize=(15, 8))
scatter = plt.scatter(
    X_pca[:, 0],
    X_pca[:, 1],
    c=info_levels,
    cmap='viridis',
    # Fade the background points when steered points are drawn on top of them
    alpha=0.1 if len(steer_by_opts) > 0 else 0.4,
    # s=50 # marker size
    s=100 # marker size
)
# plt.title("PCA of LLM hidden states by strategy level", fontsize=fontsize)
plt.xlabel(f"Principal component 1 ({pca.explained_variance_ratio_[0]:.2%})", fontsize=fontsize, labelpad=15)
plt.ylabel(f"Principal component 2 ({pca.explained_variance_ratio_[1]:.2%})", fontsize=fontsize, labelpad=15)
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)
plt.gca().xaxis.set_major_locator(plt.MaxNLocator(5))
plt.gca().yaxis.set_major_locator(plt.MaxNLocator(5))
cbar = plt.colorbar(scatter, pad=0.02)
cbar.ax.tick_params(labelsize=fontsize)
cbar.set_label("Information-revealing level\n(1 = Cryptic, 10 = Revealing)", fontsize=fontsize, labelpad=75, rotation=270)

# Add the steering vector arrow
if add_steering_vec:
    if not has_steering_vector:
        print("[WARNING] Steering vector not found. Please compute it first before adding it to the PCA plot.")
    else:
        # Project the (negated) steering vector onto the two principal axes
        projected_steering_vector = -steering_vector_np @ pca.components_.T
        plt.arrow(
            x=0,
            y=0,
            dx=projected_steering_vector[0],
            dy=projected_steering_vector[1],
            color='tab:red',
            width=.15,
            head_width=0.4,
            head_length=0.3,
            length_includes_head=False,
            zorder=2, # Draw arrow on top of scatter points
            label='Steering direction'
        )
        plt.legend(loc='upper left', fontsize=fontsize-2, framealpha=1)

# Plot no-instruction PCA
all_hidden_states_new = {k: [] for k in steer_by_opts.keys()}
labels_new = {k: f"Steering strength {k}" for k in steer_by_opts.keys()}
colors_new = {k: [] for k in steer_by_opts.keys()}
if len(steer_by_opts) > 0 and not has_steering_vector:
    # Same guard as for the arrow: without the vector there is nothing to add to the activations.
    print("[WARNING] Steering vector not found. Skipping the steered no-instruction points.")
elif len(steer_by_opts) > 0:
    for item in no_instruct_msg_dicts:
        # Convert to numpy BEFORE adding the numpy steering vector, so a torch tensor
        # (possibly on another device) is never mixed with an ndarray in one expression.
        hs = item["hidden_states"]
        if isinstance(hs, torch.Tensor):
            hs = hs.detach().cpu().numpy()
        hs = np.asarray(hs)
        for steering_strength, steering_color in steer_by_opts.items():
            steered_hs = hs + steering_strength * steering_vector_np
            all_hidden_states_new[steering_strength].append(steered_hs.squeeze())
            colors_new[steering_strength].append(steering_color)

# Only project strengths that actually have data; pca.transform rejects an empty array.
X_pca_new = {
    k: pca.transform(np.array(states))
    for k, states in all_hidden_states_new.items()
    if len(states) > 0
}

for k in X_pca_new.keys():
    print(f"Shape of new PCA data: {X_pca_new[k].shape}")

scatter_news = []
for steering_strength, steering_color in steer_by_opts.items():
    if steering_strength not in X_pca_new:
        continue
    # Explicit per-point colours are given, so no colormap is passed (it would be ignored).
    scatter_news.append(plt.scatter(
        X_pca_new[steering_strength][:, 0],
        X_pca_new[steering_strength][:, 1],
        c=colors_new[steering_strength],
        label=labels_new[steering_strength],
        alpha=1,
        s=110, # marker size
        # s=190, # marker size
        marker='^'
    ))
if len(scatter_news) > 0:
    # NOTE: this legend replaces the one created for the steering arrow above.
    plt.legend(handles=scatter_news, loc='best', fontsize=fontsize-2, framealpha=1)

plt.grid(True, which='both', linestyle='--', linewidth=0.5, alpha=0.6)
plt.ylim(None, 4.2)  # Fixed upper y-limit chosen for this figure; the lower limit stays automatic
if save_to_path is not None:
    plt.savefig(save_to_path, bbox_inches='tight')
    print(f"Saved PCA figure to {save_to_path}")
plt.show()

## Steering

In [None]:
# Compute the steering vector
#
# Difference-in-means direction between the hidden states recorded at two info-revealing
# levels (mean at `info_level_end` minus mean at `info_level_start`), scaled to unit norm.
# Produces `steering_vector_np` (numpy) and `steering_vector` (float32 torch tensor) and
# saves the tensor under $DATA_DIR/chameleon/.
print("Computing the steering vector...")
info_level_start, info_level_end = 3, 9
hidden_state_idx = 0
all_hidden_states = {info_level_start: [], info_level_end: []}
for item in instruct_msg_dicts:
    level = item["info_revealing_level"]
    if level not in all_hidden_states:
        continue

    # Accept torch tensors as well as numpy arrays, the same way the PCA cell does.
    hs = item["hidden_states"]
    if isinstance(hs, torch.Tensor):
        hs = hs.detach().cpu().numpy()
    hs = np.asarray(hs).squeeze(0) # (n_hidden_states, hidden_state_dim)
    if len(all_hidden_states[level]) == 0 and len(hs) > 1:
        print(f"[INFO] {len(hs)} hidden states, you are selecting the {hidden_state_idx}th one.")
    all_hidden_states[level].append(hs[hidden_state_idx])

# Fail loudly instead of averaging an empty list, which would yield a NaN vector.
for level in (info_level_start, info_level_end):
    if len(all_hidden_states[level]) == 0:
        raise ValueError(
            f"No hidden states found for info-revealing level {level}; cannot compute the steering vector."
        )

# Compute difference-in-means steering vector
mean_start = np.mean(all_hidden_states[info_level_start], axis=0)
mean_end = np.mean(all_hidden_states[info_level_end], axis=0)
steering_vector_np = mean_end - mean_start
steering_vector_norm_orig = np.linalg.norm(steering_vector_np)
print(f"Steering vector norm (original): {steering_vector_norm_orig:.4f}")

# A zero or non-finite norm cannot be normalised; dividing would give NaN/inf entries.
if not np.isfinite(steering_vector_norm_orig) or steering_vector_norm_orig == 0:
    raise ValueError(
        f"Steering vector norm is {steering_vector_norm_orig}; the two level means do not define a direction."
    )

# Normalize and convert to a PyTorch tensor
steering_vector_np /= steering_vector_norm_orig
steering_vector = torch.tensor(steering_vector_np, dtype=torch.float32)

# Save (path built once and reused for the log line)
steering_vector_path = os.path.join(os.environ["DATA_DIR"], "chameleon", f"steering_vector_{cfg['run_time']}.pt")
torch.save(steering_vector, steering_vector_path, pickle_module=dill)
print("Steering vector saved to: ", steering_vector_path)

In [None]:
# Preprocess instruct_msg_dicts - combine responses with the same game state but different info revealing levels
#
# Records are scanned in order. A record joins the previous group only if it shares the
# whole game state (category, secret word, number of players, chameleon flag) AND the group
# has not already stored a response for that level. Comparing only category and secret word
# would merge two adjacent game states that happened to draw the same word, and the later
# responses would silently overwrite the earlier ones.
def _game_state_key(record):
    """Return the fields that identify one game state (absent optional fields map to None)."""
    return (
        record["category"],
        record["secret_word"],
        record.get("num_players"),
        record.get("is_chameleon"),
    )


instruct_msg_dicts_uniq = []
for item in instruct_msg_dicts:
    level = item["info_revealing_level"]
    last_group = instruct_msg_dicts_uniq[-1] if len(instruct_msg_dicts_uniq) > 0 else None
    belongs_to_last_group = (
        last_group is not None
        and _game_state_key(item) == _game_state_key(last_group)
        # A repeated level means a new game state has started, even if the key matches.
        and level not in last_group["responses"]
    )
    if belongs_to_last_group:
        last_group["responses"][level] = item["response_text"]
        last_group["hidden_states"][level] = item["hidden_states"]
    else:
        instruct_msg_dicts_uniq.append({
            "category": item["category"],
            "secret_word": item["secret_word"],
            # Kept so later records can be compared on the full game state
            "num_players": item.get("num_players"),
            "is_chameleon": item.get("is_chameleon"),
            # Messages of the first record in the group (includes that record's instruction)
            "messages": item["messages"],
            "responses": {
                level: item["response_text"]
            },
            "hidden_states": {
                level: item["hidden_states"]
            },
        })
print(f"Unique instruction messages: {len(instruct_msg_dicts_uniq)} (originally {len(instruct_msg_dicts)})")

In [None]:
# Generate new responses with the steering vector
#
# For every grouped game state, the instruction is removed from the final prompt and the
# model is queried once per steering strength with a SteeringHook active. Results are
# collected in `steered_results` (same order as `data_to_steer`) and pickled at the end.
# NOTE(review): -1 presumably targets the last token position — confirm against SteeringHook.
apply_steering_at_token = -1 # None means apply at all tokens
hidden_states_layer_idx = 40
hidden_states_token_idx = 0
completely_replace_activations = False
additional_generation_kwargs={
    "max_new_tokens": 10,
    "do_sample": False,
    "temperature": None,
    "top_p": None,
    "top_k": None,
}
max_k_game_states = -1  # Negative or None: use every game state; otherwise keep the first K
steering_strengths = []
steering_strengths.extend([0, -8, -16, -32])
steered_results = []
# A plain `[:max_k_game_states]` with -1 would silently drop the last game state,
# so the "use everything" case is handled explicitly.
if max_k_game_states is None or max_k_game_states < 0:
    data_to_steer = list(instruct_msg_dicts_uniq)
else:
    data_to_steer = instruct_msg_dicts_uniq[:max_k_game_states]

print("\nGenerating new responses with steering vector applied...")
for i, item in enumerate(data_to_steer):
    secret_word = item["secret_word"]
    print(f"--- Processing item {i+1}/{len(data_to_steer)} (Category: {item['category']}, Secret: {secret_word}, Responses: {', '.join([it + ' (' + str(level) + ')' for level, it in item['responses'].items()])}) ---")

    # Prepare the neutral messages for this item (w/out instruction)
    # Shallow copy: only the last entry is replaced, the stored messages stay untouched.
    neutral_messages = list(item["messages"])
    neutral_messages[-1] = {
        "role": "user", "content": RESPOND_PROMPT.format(previous_words="", instruction="").strip()
    }

    item_steered_responses = {
        "original_responses": item["responses"],
        "steered_outputs": dict(),
    }

    for strength in steering_strengths:
        # The hook will be active only inside this 'with' block
        hook = SteeringHook(
            model,
            layer_index=hidden_states_layer_idx,
            token_index=apply_steering_at_token,
            steering_vector=steering_vector,
            steering_strength=strength,
            completely_replace=completely_replace_activations,
        )
        with hook:
            # Get the response
            talker = HuggingfaceTalker(
                model_id=model.name_or_path,
                model=model,
                tokenizer=tokenizer,
                additional_generation_kwargs=additional_generation_kwargs,
                hidden_states_layer_idx=hidden_states_layer_idx,
                hidden_states_token_idx=hidden_states_token_idx,
                start_conversation=False,
            )

            # We don't need hidden states from the steered generation, but we could get them
            steered_response_dict = talker.get_llm_response_and_hidden_states(
                messages=neutral_messages,
                max_new_tokens=additional_generation_kwargs["max_new_tokens"],
                return_logprobs=True,
                get_all_logprobs=True,
            )
            steered_response, steered_hidden_states = steered_response_dict["content"], steered_response_dict["hidden_states"]
            logprobs = steered_response_dict.get("logprobs", None)

        # Store the results
        item_steered_responses["steered_outputs"][strength] = {
            "response": steered_response,
            "hidden_states": steered_hidden_states,
            "logprobs": logprobs,
        }

    steered_results.append(item_steered_responses)
    print(f"  Steered responses: {', '.join([it['response'] + ' (' + str(strength) + ')' for strength, it in item_steered_responses['steered_outputs'].items()])}")

# Save the steered results
save_to_path = os.path.join(os.environ["DATA_DIR"], "chameleon/steering", f"steered_results_{cfg['run_time']}.pkl")
# Ensure the target directory exists even if the config cell's makedirs pointed elsewhere.
os.makedirs(os.path.dirname(save_to_path), exist_ok=True)
with open(save_to_path, "wb") as f:
    pickle.dump(steered_results, f)
print(f"Saved steered results to:\n", save_to_path)

In [None]:
# Print the results: one summary per game state, pairing the instruction-driven responses
# with the steered ones. Strengths not listed in `only_strengths` are left out.
only_strengths = [0, -8, -16, -32, -64]
for i, (item, orig_item) in enumerate(zip(steered_results, instruct_msg_dicts_uniq)):
    original_summary = ', '.join(
        text + ' (' + str(level) + ')'
        for level, text in item['original_responses'].items()
    )
    steered_summary = ', '.join(
        output['response'] + ' (' + str(strength) + ')'
        for strength, output in item['steered_outputs'].items()
        if strength in only_strengths
    )
    header = f"{orig_item['category']}  :  {orig_item['secret_word']}"
    print(
        header + "\n"
        + f"Original (info-revealing level): {original_summary}\n"
        + f"Steered (strength):  {steered_summary}"
    )
