In [None]:
# Interactive Hugging Face Hub login for this notebook session.
# NOTE(review): presumably required because the meta-llama checkpoint loaded in the
# next cell is gated on the Hub — confirm.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
import os, torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc

# Free unreferenced Python objects and PyTorch's cached-but-unused CUDA blocks left over
# from an earlier run in this kernel. Tensors still bound to variables are NOT released.
gc.collect()
torch.cuda.empty_cache()

torch.manual_seed(42)
cache_dir = (Path.cwd() / "models").resolve()
cache_dir.mkdir(parents=True, exist_ok=True)

device = (
    "cuda" if torch.cuda.is_available()
    # else ("mps" if torch.backends.mps.is_available() else "cpu")
    else "cpu"
)

# HF_HOME is read when huggingface_hub / transformers are imported, and both are
# already imported by this point, so this assignment alone does not redirect downloads.
# cache_dir is therefore also passed explicitly to from_pretrained below.
os.environ["HF_HOME"] = str(cache_dir)
print(f'Device: {device}')
model_card = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_card, cache_dir=cache_dir)
# Pin float32 explicitly: the covariance / eigendecomposition later runs on
# lm_head.weight, and torch.linalg.eigh only accepts float32/float64 (or complex) inputs.
model = AutoModelForCausalLM.from_pretrained(
    model_card,
    cache_dir=cache_dir,
    torch_dtype=torch.float32,
).to(device)

# The tokenizer has no dedicated pad token; reuse EOS so batched padding works.
tokenizer.pad_token = tokenizer.eos_token
model.eval();

In [None]:
# gamma: the output-projection (unembedding) matrix, shape (W, d), one row per vocab token.
if not hasattr(model.lm_head, "weight"):
    # Once the SteeringHead swap further down has run, lm_head no longer holds the
    # original unembedding, so this cell can only be re-run on a freshly loaded model.
    raise RuntimeError("model.lm_head has already been replaced by SteeringHead; reload the model first")
gamma = model.lm_head.weight.detach()
W, d = gamma.shape
gamma_bar = torch.mean(gamma, dim=0)
centered_gamma = gamma - gamma_bar

### compute Cov(gamma) and transform gamma to g ###
cov_gamma = centered_gamma.T @ centered_gamma / W  # (d, d), population covariance
eigenvalues, eigenvectors = torch.linalg.eigh(cov_gamma)

# Cov(gamma) is PSD in exact arithmetic, but eigh can return tiny or slightly negative
# eigenvalues from round-off. Those make 1/sqrt(...) inf/NaN, which then poisons every
# logit downstream. Floor them at a relative epsilon. When all eigenvalues are above the
# floor (the normal case) the results are unchanged. `eigenvalues` itself stays raw so
# the diagnostics cell still reports the true spectrum.
eig_floor = max(
    eigenvalues.max().item() * torch.finfo(eigenvalues.dtype).eps,
    torch.finfo(eigenvalues.dtype).tiny,
)
sqrt_eigenvalues = torch.sqrt(eigenvalues.clamp(min=eig_floor))

# Symmetric matrix (inverse) square roots via the eigendecomposition: V diag(s) V^T.
inv_sqrt_cov_gamma = eigenvectors @ torch.diag(1 / sqrt_eigenvalues) @ eigenvectors.T
sqrt_cov_gamma = eigenvectors @ torch.diag(sqrt_eigenvalues) @ eigenvectors.T

# Causal basis, with A = inv_sqrt_cov_gamma = Cov(gamma)^{-1/2} (symmetric, so A == A.T):
#   g(y) = gamma(y) @ A          -> rows of g below
#   l(x) = lambda(x) @ A^{-1}    -> A^{-1} == sqrt_cov_gamma
# so l(x) @ g(y).T == lambda(x) @ gamma(y).T, i.e. the original logits are preserved.
# gamma is not centered before the transform. Centering would only subtract
# lambda(x) @ gamma_bar, the same constant for every token y, which softmax ignores.
g = gamma @ inv_sqrt_cov_gamma

print(model.config.hidden_size)
print(g.size())

In [None]:
# Diagnostics for the whitening step: how ill-conditioned Cov(gamma) is, and how the
# value range of the unembedding changes after multiplying by Cov(gamma)^{-1/2}.
eig_lo, eig_hi = eigenvalues.min(), eigenvalues.max()
eigenval_min_max = f"Eigenval min: {eig_lo}\nEigenval max: {eig_hi}"
# inv_sqrt_cov_gamma scales eigen-direction i by 1/sqrt(eigenvalue_i); the smallest
# eigenvalue therefore gives the largest gain.
max_amp = f"Max amplification (1/sqrt(min)): {1 / torch.sqrt(eig_lo).item():.1f}\n"
gamma_min_max = f"gamma min: {gamma.min()}\ngamma max: {gamma.max()}\n"
g_min_max = f"gamma @ inv_sqrt_cov_gamma min: {g.min()}\ngamma @ inv_sqrt_cov_gamma max: {g.max()}\n"

for report in (eigenval_min_max, max_amp, gamma_min_max, g_min_max):
    print(report)
print(f"gamma dtype: {gamma.dtype}")
print(f"g dtype: {g.dtype}")

In [None]:
import pandas as pd

# Contrast dataset: each row pairs a `high_sentence` with a `low_sentence` variant.
concept_df = pd.read_json("https://raw.githubusercontent.com/donkeyanaphora/STEERING_EXPERIMENTS/refs/heads/main/data/epistemic_privilege_pairs.json")


def _as_user_turns(sentences):
    """Wrap each sentence as a single-turn conversation containing one user message.

    (An earlier variant also prepended `row.prompt` as an assistant turn.)
    """
    return [[{"role": "user", "content": sentence}] for sentence in sentences]


# Column-wise comprehension instead of iterrows: faster, and no loop variables
# (idx/row/a/b) leak into the notebook namespace.
a_pairs = _as_user_turns(concept_df["high_sentence"])
b_pairs = _as_user_turns(concept_df["low_sentence"])

In [None]:
def _encode_chats(chats):
    """Apply the chat template to a list of conversations and return a padded batch on `device`.

    The result is the tokenizer's dict-like encoding (return_dict=True); downstream cells
    read its "input_ids" and "attention_mask" entries.
    """
    return tokenizer.apply_chat_template(
        chats,
        tokenize=True,
        add_generation_prompt=False,
        truncation=True,
        return_tensors="pt",
        return_dict=True,
        padding=True,
        tokenizer_kwargs={"return_attention_mask": True},  # safe across versions
    ).to(device)


# Same settings for both sides of the contrast, so the batches differ only in content.
a_batch = _encode_chats(a_pairs)
b_batch = _encode_chats(b_pairs)

In [None]:
import torch

# Token ids of the Llama-3 chat-template scaffolding, excluded when pooling "content":
#   128000 <|begin_of_text|>, 128006 <|start_header_id|>, 882 "user",
#   128007 <|end_header_id|>, 271 "\n\n" (emitted by the template right after
#   <|end_header_id|>; previously missing, so it leaked into the content mask),
#   128009 <|eot_id|>, 78191 "assistant".
# These ids are tokenizer-specific; re-derive them if model_card changes. They are matched
# anywhere in the sequence, so a bare "user"/"assistant"/"\n\n" token inside the content
# itself would also be dropped.
specials = torch.tensor([128000, 128006, 882, 128007, 271, 128009, 78191], device=device)


def content_mask(batch, specials):
    """Return a bool (B, S) mask that is True only on real content tokens.

    Args:
        batch: tokenizer output with "input_ids" and "attention_mask", both (B, S).
        specials: 1-D tensor of token ids to exclude, on the same device as input_ids.
    """
    ids = batch["input_ids"]                      # (B,S)
    attn = batch["attention_mask"].bool()         # (B,S) False on padding
    not_special = ~torch.isin(ids, specials)      # (B,S) False on template tokens
    return attn & not_special


def masked_mean_pool(last_hidden, mask):
    """Mean of last_hidden (B, S, d) over positions where mask (B, S) is nonzero -> (B, d).

    Accepts a bool or 0/1 integer mask. A row whose mask is empty returns zeros
    instead of NaN, because the count is clamped to at least 1.
    """
    m = mask.unsqueeze(-1).to(last_hidden.dtype)  # (B,S,1)
    return (last_hidden * m).sum(1) / m.sum(1).clamp(min=1)


def last_content_token_pool(last_hidden, mask):
    """Hidden state (B, d) at the last position where the bool mask (B, S) is True.

    Works with left or right padding. A row with no True position falls back to index 0.
    """
    B, S, _ = last_hidden.shape
    pos = torch.arange(S, device=last_hidden.device).unsqueeze(0).expand(B, S)
    # Non-content positions get -1, so the row max is the last content index.
    idx = torch.where(mask, pos, torch.full_like(pos, -1)).max(dim=1).values.clamp(min=0)
    return last_hidden[torch.arange(B, device=last_hidden.device), idx]


# Forward passes for the high (a) and low (b) sides; hidden_states[-1] is the last layer.
with torch.no_grad():
    out1 = model(**a_batch, output_hidden_states=True)
    out2 = model(**b_batch, output_hidden_states=True)

mask_a = content_mask(a_batch, specials)
mask_b = content_mask(b_batch, specials)

# Two content-only pooling variants: mean over content tokens, and last content token.
a_emb_mean = masked_mean_pool(out1.hidden_states[-1], mask_a)
b_emb_mean = masked_mean_pool(out2.hidden_states[-1], mask_b)

a_emb_last = last_content_token_pool(out1.hidden_states[-1], mask_a)
b_emb_last = last_content_token_pool(out2.hidden_states[-1], mask_b)


In [None]:
def last_token_pool(last_hidden, attention_mask):
    """Hidden state (B, d) at the last attended position of each row of last_hidden (B, S, d).

    Locates the last nonzero entry of attention_mask instead of assuming right padding
    (the old `sum - 1` index), so left-padded batches also work. A fully masked row
    falls back to index 0.
    """
    positions = torch.arange(attention_mask.size(1), device=attention_mask.device)
    idx = (positions * attention_mask.long()).max(dim=1).values  # (B,)
    return last_hidden[torch.arange(last_hidden.size(0), device=last_hidden.device), idx]


# Reuse out1/out2 from the previous cell. They were computed from the same a_batch/b_batch,
# so repeating the forward passes here only duplicated the most expensive work.
# masked_mean_pool (defined in the previous cell) also accepts the 0/1 attention mask, so
# it is not redefined here.
# NOTE: unlike a_emb_mean/b_emb_mean, this pools over every attended token, including the
# chat-template tokens. The concept direction is built from these a_emb/b_emb.
a_emb = masked_mean_pool(out1.hidden_states[-1], a_batch["attention_mask"])
b_emb = masked_mean_pool(out2.hidden_states[-1], b_batch["attention_mask"])

In [None]:
# Unit-norm concept direction in the causal basis: the average (high - low) difference of
# the pooled hidden states, mapped through sqrt_cov_gamma (lambda(x) -> l(x)).
mean_shift = (a_emb - b_emb).mean(dim=0)
causal_shift = mean_shift @ sqrt_cov_gamma
concept_dir = causal_shift / causal_shift.norm()
concept_dir

In [None]:
class SteeringHead(torch.nn.Module):
    """Drop-in replacement for lm_head that steers along a concept direction in the causal basis.

    Logits are computed as (hidden_states @ sqrt_cov_gamma) @ lm_head_g.T. With
    lm_head_g = gamma @ inv_sqrt_cov_gamma and sqrt_cov_gamma its inverse transform, this
    equals hidden_states @ gamma.T, i.e. the original logits, when alpha == 0.
    For alpha != 0, alpha * concept_dir is added to the transformed hidden state of the
    last sequence position only.

    Args:
        lm_head_g: (vocab, d) unembedding expressed in the causal basis.
        sqrt_cov_gamma: (d, d) matrix mapping hidden states into the causal basis.
        concept_dir: (d,) steering direction in the causal basis.
        alpha: steering strength; a plain attribute, so it can be changed between calls.
    """

    def __init__(self, lm_head_g, sqrt_cov_gamma, concept_dir, alpha=0.0):
        super().__init__()
        self.register_buffer("lm_head_g", lm_head_g)
        self.register_buffer("sqrt_cov_gamma", sqrt_cov_gamma)
        self.register_buffer("concept_dir", concept_dir)
        self.alpha = alpha

    def forward(self, hidden_states):
        """Map hidden_states (B, S, d) to logits (B, S, vocab), steering the last position."""
        # Compute in the buffers' dtype. A reduced-precision model (e.g. bf16 hidden states
        # with float32 buffers) would otherwise fail with a matmul dtype mismatch. This is a
        # no-op when the dtypes already match.
        l_causal = hidden_states.to(self.sqrt_cov_gamma.dtype) @ self.sqrt_cov_gamma
        # l_causal is a fresh tensor, so this in-place update never touches hidden_states.
        # During cached generation each step's input is a single position, so every new
        # token is steered.
        l_causal[:, -1, :] += self.alpha * self.concept_dir
        return l_causal @ self.lm_head_g.T

# Replace the output projection with the steering head. At alpha = 0 it computes
# (hidden @ sqrt_cov_gamma) @ g.T, which equals the original logits up to floating-point
# error, so model.generate() works unchanged. Steering strength is set later via
# model.lm_head.alpha. (The old extra `model.lm_head.alpha = 0` was redundant: it only
# replaced the constructor's 0.0 with the int 0.)
model.lm_head = SteeringHead(g, sqrt_cov_gamma, concept_dir, alpha=0.0)

In [None]:
eval_df = pd.read_json('https://raw.githubusercontent.com/donkeyanaphora/STEERING_EXPERIMENTS/refs/heads/main/data/eval.json')


def _challenge_conversation(question, correct_answer, challenge):
    """Three-turn chat: the question, the correct answer as the assistant's reply, then the user's challenge."""
    return [
        {"role": "user", "content": question},
        {"role": "assistant", "content": correct_answer},
        {"role": "user", "content": challenge},
    ]


# Each conversation is rendered to a prompt *string* (tokenize=False), ending with the
# assistant header (add_generation_prompt=True); tokenization happens separately.
questions = [
    tokenizer.apply_chat_template(
        _challenge_conversation(question, answer, challenge),
        tokenize=False,
        add_generation_prompt=True,
    )
    for question, answer, challenge in zip(
        eval_df["question"], eval_df["correct_answer"], eval_df["challenge"]
    )
]


In [None]:
# `questions` are already-rendered chat strings. The Llama-3 chat template emits
# <|begin_of_text|> itself, so add_special_tokens=False keeps the tokenizer from
# prepending a second BOS token. Left padding makes every prompt end at the same
# position, which is where generation starts.
batch = tokenizer(
    questions,
    return_tensors="pt",
    padding=True,
    truncation=True,
    padding_side='left',
    add_special_tokens=False,
)
batch = {k: v.to(device) for k, v in batch.items()}

print(batch["input_ids"].size())
print(batch["attention_mask"].size())

In [None]:
high, low = 1.2, -1.2
alphas = [high, 0, low]

# Stop on the tokenizer's EOS *and* on Llama-3's end-of-turn token. Depending on the
# tokenizer_config version, eos_token may be <|end_of_text|> rather than <|eot_id|>.
# Stopping only on it would let generation run past the end of the assistant turn.
stop_ids = list(dict.fromkeys(
    tid
    for tid in (tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>"))
    if tid is not None
))

# Prompts are left-padded, so every generated continuation starts at this index.
prompt_len = batch["input_ids"].shape[1]

outputs = {}
try:
    for alpha in alphas:
        model.lm_head.alpha = alpha
        with torch.no_grad():
            out = model.generate(
                **batch,
                max_new_tokens=600,
                do_sample=False,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=stop_ids,
            )
        # Decode only the generated continuation; the prompt itself is the table's index.
        outputs[alpha] = tokenizer.batch_decode(out[:, prompt_len:], skip_special_tokens=True)
finally:
    # Leave the model unsteered afterwards, even if generation fails part-way.
    model.lm_head.alpha = 0

In [None]:
import pandas as pd

# One row per evaluation prompt (column "Question"), one column per steering strength.
df = (
    pd.DataFrame(outputs, index=questions)
    .rename_axis("Question")
    .reset_index()
)

df_fmt = df.rename(columns={high: "High Certainty", 0: "Baseline", low: "Low Certainty"})

In [None]:
from IPython.display import HTML

# Display with nice formatting
style = """
<style>
.styled-table { width: 100%; border-collapse: collapse; }
.styled-table td, .styled-table th { 
    vertical-align: top; 
    padding: 12px; 
    border: 1px solid #ddd;
    width: 25%;
}
.styled-table td { white-space: pre-wrap; }
</style>
"""

# escape=True: the cells hold raw model generations and chat-template prompts, so they
# must be HTML-escaped rather than injected as markup. A stray "<div" or "<script>" in a
# generation would otherwise break or hijack the table. pandas' cell formatter writes
# newlines as a literal backslash-n, which the .replace turns into <br>.
html = df_fmt.set_index('Question').to_html(escape=True, classes='styled-table').replace('\\n', '<br>')
HTML(style + html)

In [None]:
# df_fmt.to_json('results.json', orient='records')