In [None]:
# Interactive Hugging Face Hub login from inside the notebook.
# NOTE(review): presumably needed so the meta-llama checkpoint loaded further
# down can be downloaded — confirm against the model's access settings.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
import os, torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc

# Drop unreachable Python objects and return cached CUDA blocks to the driver
# so re-running this cell starts from as much free GPU memory as possible.
gc.collect()
torch.cuda.empty_cache()

torch.manual_seed(42)

# Keep downloaded weights next to the notebook instead of the user-wide cache.
cache_dir = (Path.cwd() / "models").resolve()
cache_dir.mkdir(parents=True, exist_ok=True)

device = (
    "cuda" if torch.cuda.is_available()
    # else ("mps" if torch.backends.mps.is_available() else "cpu")
    else "cpu"
)

# HF_HOME is resolved when huggingface_hub / transformers are imported, which
# has already happened by the time this line runs, so on its own it does not
# redirect downloads made from this kernel. It is still exported for any child
# processes; the cache location for this process is passed explicitly below.
os.environ["HF_HOME"] = str(cache_dir)
hub_cache_dir = cache_dir / "hub"  # same layout HF_HOME would have produced

print(f'Device: {device}')
model_card = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_card, cache_dir=hub_cache_dir)
model = AutoModelForCausalLM.from_pretrained(model_card, cache_dir=hub_cache_dir).to(device)

# Reuse EOS as the padding token so batched tokenization with padding=True works.
tokenizer.pad_token = tokenizer.eos_token
model.eval();

In [None]:
# Unembedding matrix taken straight from the stock output head (W rows, d columns).
# NOTE: this reads `model.lm_head.weight`, so it has to run before the cell
# further down that replaces `model.lm_head` with a SteeringHead.
gamma = model.lm_head.weight.detach()
W, d = gamma.shape
gamma_bar = gamma.mean(dim=0)
centered_gamma = gamma - gamma_bar

### compute Cov(gamma) and transform gamma to g ###
cov_gamma = (centered_gamma.T @ centered_gamma) / W
eigenvalues, eigenvectors = torch.linalg.eigh(cov_gamma)

# Symmetric square root and inverse square root of the covariance, both built
# from the same eigendecomposition.
sqrt_eigenvalues = torch.sqrt(eigenvalues)
inv_sqrt_cov_gamma = eigenvectors @ torch.diag(1 / sqrt_eigenvalues) @ eigenvectors.T
sqrt_cov_gamma = eigenvectors @ torch.diag(sqrt_eigenvalues) @ eigenvectors.T

# gamma is the original head; multiplying by inv_sqrt_cov_gamma expresses it in
# the whitened ("causal") basis.
g = gamma @ inv_sqrt_cov_gamma

print(model.config.hidden_size)
print(g.size())

In [None]:
# Numerical sanity report for the whitening step: eigenvalue range, the largest
# scale factor applied by the inverse square root, and value ranges / dtypes
# before and after the change of basis.
smallest_eigenvalue = eigenvalues.min()
largest_eigenvalue = eigenvalues.max()

eigenval_min_max = f"Eigenval min: {smallest_eigenvalue}\nEigenval max: {largest_eigenvalue}"
max_amp = f"Max amplification (1/sqrt(min)): {1 / torch.sqrt(smallest_eigenvalue).item():.1f}\n"
gamma_min_max = f"gamma min: {gamma.min()}\ngamma max: {gamma.max()}\n"
g_min_max = f"gamma @ inv_sqrt_cov_gamma min: {g.min()}\ngamma @ inv_sqrt_cov_gamma max: {g.max()}\n"

for report_line in (eigenval_min_max, max_amp, gamma_min_max, g_min_max):
    print(report_line)

for label, tensor in (("gamma", gamma), ("g", g)):
    print(f"{label} dtype: {tensor.dtype}")

In [None]:
import pandas as pd

concept_df = pd.read_json("https://raw.githubusercontent.com/donkeyanaphora/STEERING_EXPERIMENTS/refs/heads/main/data/epistemic_privilege_pairs.json")

# Two parallel lists of two-turn conversations. Entry i of each list shares the
# same assistant turn (the prompt) and differs only in the user turn:
# `high_sentence` for a_pairs, `low_sentence` for b_pairs.
a_pairs = [
    [
        {"role": "assistant", "content": prompt},
        {"role": "user", "content": high_sentence},
    ]
    for prompt, high_sentence in zip(concept_df["prompt"], concept_df["high_sentence"])
]

b_pairs = [
    [
        {"role": "assistant", "content": prompt},
        {"role": "user", "content": low_sentence},
    ]
    for prompt, low_sentence in zip(concept_df["prompt"], concept_df["low_sentence"])
]

In [None]:
# Both sides of every contrast pair are rendered through the chat template with
# exactly the same options; only the conversations differ.
pair_template_kwargs = dict(
    tokenize=True,
    add_generation_prompt=False,
    truncation=True,
    return_tensors="pt",
    return_dict=True,
    padding=True,
)

a_batch = tokenizer.apply_chat_template(a_pairs, **pair_template_kwargs).to(device)
b_batch = tokenizer.apply_chat_template(b_pairs, **pair_template_kwargs).to(device)

In [None]:
def masked_mean_pool(last_hidden, attention_mask):
    """Mask-weighted average of `last_hidden` along dim 1.

    With a 0/1 attention mask this is the mean over non-padded tokens. The
    token count is clamped to at least 1, so a fully masked row yields zeros
    instead of dividing by zero.
    """
    weights = attention_mask[..., None]
    token_counts = torch.clamp(weights.sum(dim=1), min=1)
    weighted_sum = torch.sum(last_hidden * weights, dim=1)
    return weighted_sum / token_counts

def last_token_pool(last_hidden, attention_mask):
    """Return the hidden state of the last non-padded token of every sequence.

    Works for right-padded and left-padded batches alike: the index is the
    position of the last non-zero mask entry rather than `mask.sum() - 1`,
    which is only correct when all padding sits at the end.

    Args:
        last_hidden: tensor indexed as (batch, seq, ...).
        attention_mask: (batch, seq) mask, non-zero on real tokens.

    Returns:
        Tensor of shape (batch, ...) holding one hidden state per sequence.
        A row whose mask is entirely zero falls back to the final position.
    """
    seq_len = attention_mask.size(1)
    # argmax returns the first maximal index, so on the reversed mask it gives
    # the distance of the last real token from the end of the sequence.
    offset_from_end = attention_mask.to(torch.long).flip(dims=[1]).argmax(dim=1)
    idx = seq_len - 1 - offset_from_end
    rows = torch.arange(last_hidden.size(0), device=last_hidden.device)
    return last_hidden[rows, idx]

# Inference-only forward passes over both halves of the contrast pairs,
# keeping the per-layer hidden states.
with torch.no_grad():
    out1, out2 = (
        model(**encoded, output_hidden_states=True)
        for encoded in (a_batch, b_batch)
    )

# simple mean pooling of the last entry of hidden_states - shared content
# cancels when we take (a - b)
a_emb, b_emb = (
    masked_mean_pool(result.hidden_states[-1], encoded["attention_mask"])
    for result, encoded in ((out1, a_batch), (out2, b_batch))
)

In [None]:
# Signal check: the candidate direction is the mean of the per-pair differences.
diffs = a_emb - b_emb
concept_dir_raw = diffs.mean(dim=0)

# Cosine similarity of every individual pair difference with that mean; values
# that are consistently positive mean the pairs agree on the direction.
cos_sims = [
    torch.cosine_similarity(pair_diff, concept_dir_raw, dim=0).item()
    for pair_diff in diffs
]
formatted_sims = [f'{s:.2f}' for s in cos_sims]
mean_sim = sum(cos_sims) / len(cos_sims)

print(f"Per-pair cosine sims: {formatted_sims}")
print(f"Mean: {mean_sim:.3f}")
print(f"All positive: {all(s > 0 for s in cos_sims)}")

In [None]:
# Map the raw hidden-state direction through sqrt_cov_gamma (the same transform
# SteeringHead applies to hidden states) and scale it to unit length.
concept_dir = torch.matmul(concept_dir_raw, sqrt_cov_gamma)
concept_dir = concept_dir.div(concept_dir.norm())
concept_dir

In [None]:
class SteeringHead(torch.nn.Module):
    """Output head that projects to logits through a transformed basis and can
    push the final position along a fixed direction.

    Hidden states are multiplied by `sqrt_cov_gamma`, the last sequence
    position is shifted by `alpha * concept_dir`, and the result is projected
    onto `lm_head_g` to produce logits.

    Args:
        lm_head_g: (vocab, d) unembedding matrix expressed in the transformed basis.
        sqrt_cov_gamma: (d, d) matrix applied to incoming hidden states.
        concept_dir: (d,) direction added at the last position.
        alpha: steering strength; a plain attribute, so it can be reassigned
            between calls without rebuilding the module.
    """

    def __init__(self, lm_head_g, sqrt_cov_gamma, concept_dir, alpha=0.0):
        super().__init__()
        # Buffers rather than parameters: they follow .to()/.cuda() with the
        # module but are not exposed to an optimizer.
        self.register_buffer("lm_head_g", lm_head_g)
        self.register_buffer("sqrt_cov_gamma", sqrt_cov_gamma)
        self.register_buffer("concept_dir", concept_dir)
        self.alpha = alpha

    def forward(self, hidden_states):
        """Return logits for `hidden_states` of shape (..., seq, d).

        Only the last position along the sequence axis is steered. Leading
        batch dimensions are optional, so an unbatched (seq, d) input works too.
        """
        # The matmul allocates a new tensor, so the in-place shift below never
        # writes into the caller's hidden states.
        l_causal = hidden_states @ self.sqrt_cov_gamma
        l_causal[..., -1, :] = l_causal[..., -1, :] + self.alpha * self.concept_dir
        return l_causal @ self.lm_head_g.T

# Swap the model's output head for the steerable one, starting unsteered.
# After this assignment `model.lm_head` no longer has a `.weight`, so the cell
# that derives `gamma` from it cannot be re-run without reloading the model.
model.lm_head = SteeringHead(
    lm_head_g=g,
    sqrt_cov_gamma=sqrt_cov_gamma,
    concept_dir=concept_dir,
    alpha=0.0,
)

In [None]:
eval_df = pd.read_json('https://raw.githubusercontent.com/donkeyanaphora/STEERING_EXPERIMENTS/refs/heads/main/data/eval.json')

# One three-turn conversation per evaluation row: the user's question, the
# assistant's correct answer, then the user's challenge to that answer.
questions = [
    [
        {"role": "user", "content": question},
        {"role": "assistant", "content": correct_answer},
        {"role": "user", "content": challenge},
    ]
    for question, correct_answer, challenge in zip(
        eval_df["question"], eval_df["correct_answer"], eval_df["challenge"]
    )
]

In [None]:
# Render the evaluation conversations with a trailing generation prompt and pad
# on the left, so every row ends on its final prompt token.
eval_template_kwargs = dict(
    tokenize=True,
    add_generation_prompt=True,
    padding=True,
    truncation=True,
    return_tensors="pt",
    return_dict=True,
    tokenizer_kwargs={"padding_side": "left"},
)
batch = tokenizer.apply_chat_template(questions, **eval_template_kwargs).to(device)

for field in ("input_ids", "attention_mask"):
    print(batch[field].shape)

In [None]:
# Steering strengths to compare; 0 is the unsteered baseline.
high, low = 1.2, -1.2
alphas = [high, 0, low]

# Deterministic (greedy) decoding settings shared by every run.
generation_kwargs = dict(
    max_new_tokens=600,
    do_sample=False,
    repetition_penalty=1.2,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# Decoded text for the whole evaluation batch, keyed by steering strength.
# The decoded sequences are whatever generate() returns, unsliced.
outputs = {}
for alpha in alphas:
    model.lm_head.alpha = alpha
    with torch.no_grad():
        generated = model.generate(**batch, **generation_kwargs)
    outputs[alpha] = tokenizer.batch_decode(generated, skip_special_tokens=True)

In [None]:
# One column per steering strength, plus the question text taken positionally
# from eval_df (a plain list, so no index alignment is involved).
df = pd.DataFrame(outputs).assign(Question=eval_df["question"].tolist())

# Human-readable headers for the three steering settings.
column_labels = {
    high: "High Authority",
    0: "Baseline",
    low: "Low Authority",
}
df_fmt = df.rename(columns=column_labels)

In [None]:
from IPython.display import HTML

style = """
<style>
.styled-table { width: 100%; border-collapse: collapse; }
.styled-table td, .styled-table th { 
    vertical-align: top; 
    padding: 12px; 
    border: 1px solid #ddd;
    width: 25%;
}
.styled-table td { white-space: pre-wrap; }
</style>
"""

# Cell contents are model generations, so they are HTML-escaped: with
# escape=False any '<', '>' or '&' in a generation would be parsed as markup and
# could hide text or break the table. pandas writes newlines inside cells as the
# two characters backslash + n; those are turned into <br> only after escaping,
# so the inserted tags themselves stay intact.
html = (
    df_fmt.set_index('Question')
    .to_html(escape=True, classes='styled-table')
    .replace('\\n', '<br>')
)
HTML(style + html)

In [None]:
# df_fmt.to_json('results.json', orient='records')