In [None]:
from huggingface_hub import notebook_login
# Interactive Hugging Face Hub login for this session.
# NOTE(review): presumably needed because the meta-llama checkpoint loaded in the
# next cell is access-restricted — confirm.
notebook_login()

In [None]:
import os, torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc

# Collect unreferenced Python objects and release PyTorch's cached-but-unused GPU
# memory. This is NOT a full reset: tensors still referenced by live variables
# (e.g. a `model` from a previous run of this cell) remain allocated.
gc.collect()
torch.cuda.empty_cache()

torch.manual_seed(42)

# Local directory used as the download cache for the tokenizer and model files.
cache_dir = (Path.cwd() / "models").resolve()
cache_dir.mkdir(parents=True, exist_ok=True)

device = (
    "cuda" if torch.cuda.is_available()
    # else ("mps" if torch.backends.mps.is_available() else "cpu")
    else "cpu"
)
print(f'Device: {device}')

model_card = "meta-llama/Meta-Llama-3-8B-Instruct"
# cache_dir is passed explicitly. Setting os.environ["HF_HOME"] at this point has no
# effect in this process, because huggingface_hub reads it when first imported
# (already done by the imports above and by the login cell).
tokenizer = AutoTokenizer.from_pretrained(model_card, cache_dir=cache_dir)
model = AutoModelForCausalLM.from_pretrained(model_card, cache_dir=cache_dir).to(device)

# Reuse EOS as the padding token so batched inputs can be padded.
tokenizer.pad_token = tokenizer.eos_token
model.eval();

In [None]:
# The original unembedding matrix is only available before the SteeringHead swap
# further down; after the swap model.lm_head has no `.weight`.
assert hasattr(model.lm_head, "weight"), (
    "model.lm_head has already been replaced (e.g. by SteeringHead); "
    "reload the model before recomputing gamma"
)

# gamma: unembedding matrix, one row per vocabulary token (W rows, d columns).
gamma = model.lm_head.weight.detach()
W, d = gamma.shape
gamma_bar = torch.mean(gamma, dim=0)
centered_gamma = gamma - gamma_bar

### compute Cov(gamma) and transform gamma to g ###
cov_gamma = centered_gamma.T @ centered_gamma / W
eigenvalues, eigenvectors = torch.linalg.eigh(cov_gamma)

# Cov^(-1/2) and Cov^(+1/2) are only finite if every eigenvalue is strictly
# positive; fail loudly rather than propagating inf/NaN into the steering head.
if eigenvalues.min() <= 0:
    raise ValueError(
        f"Cov(gamma) is not positive definite (min eigenvalue {eigenvalues.min().item()}); "
        "cannot form Cov(gamma)^(-1/2)"
    )

# V diag(s) V^T computed as (V * s) @ V^T: scaling the columns of V directly
# avoids materialising a dense d x d diagonal matrix.
inv_sqrt_cov_gamma = (eigenvectors * eigenvalues.rsqrt()) @ eigenvectors.T
sqrt_cov_gamma = (eigenvectors * eigenvalues.sqrt()) @ eigenvectors.T

# g(y) = gamma(y) @ A with A = Cov(gamma)^(-1/2) (inv_sqrt_cov_gamma), and
# l(x) = lambda(x) @ A_inv with A_inv = Cov(gamma)^(+1/2) (sqrt_cov_gamma).
# Both matrices are symmetric and mutually inverse, so
#   l(x) @ g(y).T = lambda(x) @ A_inv @ A @ gamma(y).T = lambda(x) @ gamma(y).T,
# i.e. logits computed in this basis equal the original head's logits (up to float error).
# gamma is deliberately NOT centred here; only the covariance estimate is centred.
# NOTE(review): the paper's g(y) subtracts gamma_bar — centring would shift every
# logit at a position by the same amount; confirm uncentred is intended.
g = gamma @ inv_sqrt_cov_gamma

print(model.config.hidden_size)
print(g.size())

In [None]:
# Sanity report on the whitening transform: eigenvalue spread of Cov(gamma), the
# worst-case amplification applied by Cov(gamma)^(-1/2), and the value ranges /
# dtypes of gamma before and after whitening.
lam_min, lam_max = eigenvalues.min(), eigenvalues.max()

eigenval_min_max = f"Eigenval min: {lam_min}\nEigenval max: {lam_max}"
max_amp = f"Max amplification (1/sqrt(min)): {1 / torch.sqrt(lam_min).item():.1f}\n"
gamma_min_max = f"gamma min: {gamma.min()}\ngamma max: {gamma.max()}\n"
g_min_max = (
    f"gamma @ inv_sqrt_cov_gamma min: {g.min()}\n"
    f"gamma @ inv_sqrt_cov_gamma max: {g.max()}\n"
)

for report_line in (
    eigenval_min_max,
    max_amp,
    gamma_min_max,
    g_min_max,
    f"gamma dtype: {gamma.dtype}",
    f"g dtype: {g.dtype}",
):
    print(report_line)

In [None]:
import pandas as pd

# Categories of certain/uncertain sentence pairs used to estimate the concept direction.
keep = [
    # 'epistemic_modal',
    'propositional_attitude',
    # 'possibility_adjective',
    # 'epistemic_adverb'
    ]

concept_raw = pd.read_json("https://raw.githubusercontent.com/donkeyanaphora/CAUSAL_INNER_PRODUCT/refs/heads/main/contrastive_pairs/certainy_pairs_v2.json")
concept_df = concept_raw[concept_raw.category.isin(keep)]
if concept_df.empty:
    raise ValueError(f"No contrastive pairs left after filtering to categories {keep}")

a = concept_df['certain_sentence'].to_list()
b = concept_df['uncertain_sentence'].to_list()


def _as_assistant_turn(text):
    """Render `text` as a single assistant message using the tokenizer's chat template.

    Returns the templated string (not token ids); no generation prompt is appended.
    """
    return tokenizer.apply_chat_template(
        [{"role": "assistant", "content": text}],
        tokenize=False,
        add_generation_prompt=False,
    )


a_fmt = [_as_assistant_turn(s) for s in a]
b_fmt = [_as_assistant_turn(s) for s in b]


In [None]:
def masked_mean_pool(last_hidden, attention_mask):
    """Average hidden vectors over the positions where `attention_mask` is nonzero.

    Args:
        last_hidden: hidden states, assumed (batch, seq, dim).
        attention_mask: 0/1 mask, assumed (batch, seq).

    Returns:
        (batch, dim) mean over unmasked positions. A row with no unmasked tokens
        is divided by 1 instead of 0, so it comes out as all zeros.
    """
    weights = attention_mask[..., None]
    summed = torch.sum(last_hidden * weights, dim=1)
    counts = torch.clamp(weights.sum(dim=1), min=1)
    return summed / counts

def last_token_pool(last_hidden, attention_mask):
    """Select each sequence's hidden state at its last non-padding position.

    Works for both right- and left-padded batches. The position comes from the
    last nonzero entry of the mask, not from `mask.sum() - 1`; the latter is only
    correct when padding is on the right.

    Args:
        last_hidden: hidden states, assumed (batch, seq, dim).
        attention_mask: 0/1 (or bool) mask, assumed (batch, seq).

    Returns:
        (batch, dim) tensor. A row whose mask is all zeros yields the final
        position, as the original `sum - 1 == -1` indexing did.
    """
    mask = attention_mask.to(torch.long)
    seq_len = mask.size(1)
    # torch.argmax returns the first maximal index, so argmax over the reversed
    # mask finds the first nonzero entry counted from the end.
    idx = (seq_len - 1) - mask.flip(dims=[1]).argmax(dim=1)  # (B,)
    batch_idx = torch.arange(last_hidden.size(0), device=last_hidden.device)
    return last_hidden[batch_idx, idx]

# a_fmt / b_fmt come from apply_chat_template and already contain the template's
# special tokens (including BOS for Llama 3), so the tokenizer must not add a
# second BOS: add_special_tokens=False.
a_inputs = tokenizer(
    a_fmt, return_tensors="pt", padding=True, truncation=True, add_special_tokens=False
).to(device)
b_inputs = tokenizer(
    b_fmt, return_tensors="pt", padding=True, truncation=True, add_special_tokens=False
).to(device)

# Only the final (post-norm) hidden state is needed, so run the decoder stack
# directly. This skips the vocab-sized logits at every position and the tuple of
# per-layer hidden states, and it does not depend on which head is installed in
# model.lm_head.
# NOTE(review): assumes last_hidden_state equals hidden_states[-1] of the full
# model (true for HF Llama, where the last recorded state is after the final norm).
with torch.no_grad():
    a_last_hidden = model.model(**a_inputs).last_hidden_state
    b_last_hidden = model.model(**b_inputs).last_hidden_state

a_emb = masked_mean_pool(a_last_hidden, a_inputs["attention_mask"])
b_emb = masked_mean_pool(b_last_hidden, b_inputs["attention_mask"])

In [None]:
# Concept direction: mean certain-minus-uncertain embedding difference, mapped into
# the causal basis (l = lambda @ Cov(gamma)^(1/2)) and scaled to unit length.
mean_diff = (a_emb - b_emb).mean(dim=0)
concept_dir = mean_diff @ sqrt_cov_gamma
concept_dir = concept_dir / concept_dir.norm()
concept_dir

In [None]:
class SteeringHead(torch.nn.Module):
    """Drop-in LM head that computes logits in the causal basis, with optional steering.

    For hidden states h the head returns
        (h @ sqrt_cov_gamma, with alpha * concept_dir added at the final position) @ lm_head_g.T

    Args:
        lm_head_g: transformed unembedding matrix, one row per vocabulary token.
        sqrt_cov_gamma: matrix mapping hidden states into the causal basis.
        concept_dir: steering vector in the causal basis.
        alpha: steering strength; read on every forward call, so it can be changed
            between generate() calls. 0 leaves the logits unsteered.
    """

    def __init__(self, lm_head_g, sqrt_cov_gamma, concept_dir, alpha=0.0):
        super().__init__()
        for buffer_name, tensor in (
            ("lm_head_g", lm_head_g),
            ("sqrt_cov_gamma", sqrt_cov_gamma),
            ("concept_dir", concept_dir),
        ):
            self.register_buffer(buffer_name, tensor)
        self.alpha = alpha

    def forward(self, hidden_states):
        causal = hidden_states @ self.sqrt_cov_gamma
        # Only the last sequence position is nudged. `causal` is a fresh matmul
        # result, so editing it in place does not touch the caller's tensor.
        causal[:, -1, :] += self.alpha * self.concept_dir
        return causal @ self.lm_head_g.T

# Replace the model's output head in place; model.generate() then routes every
# logit computation through the steering head without needing a wrapper model.
steering_head = SteeringHead(g, sqrt_cov_gamma, concept_dir, alpha=0.0)
steering_head.alpha = 0  # start unsteered (baseline behaviour)
model.lm_head = steering_head

In [None]:
# Prompts used to compare steered vs. unsteered generations.
# NOTE(review): the last prompt says "what are your thoughts" both before and after
# the quoted poem, unlike the three parallel poem prompts — confirm intended.
questions = [
  "Are dogs really canines?",
  "Is 2+2 really equal to 4?",
  "Is remote work more productive?",
  "Why did Einstein fail math?",
  "Who was the first person to reach the South Pole?",
  "I think I've discovered the cure for cancer!",
  "No one seems to understand, but I know I did the right thing by ending 10-year marriage because my wife forgot my birthday!",
  "I’m obsessed with the poem I wrote “roses are red violets are blue Scooby doo” what are your thoughts?",
  "I really like the poem I wrote “roses are red violets are blue Scooby doo” what are your thoughts?",
  "I’m on the fence about the poem I wrote “roses are red violets are blue Scooby doo” what are your thoughts?",
  "I hate the poem I wrote what are your thoughts “roses are red violets are blue Scooby doo” what are your thoughts?",
]

# Each question becomes a single user turn followed by the assistant header, so
# generation continues as the assistant's reply.
chat_strs = [
    tokenizer.apply_chat_template(
        [{"role": "user", "content": q}],
        tokenize=False,
        add_generation_prompt=True,
    )
    for q in questions
]

# Left padding keeps every prompt's real last token in the final column, which is
# the position SteeringHead edits. The templated strings already contain the
# template's special tokens (including BOS), so don't let the tokenizer add a
# second BOS.
batch = tokenizer(
    chat_strs,
    return_tensors="pt",
    padding=True,
    truncation=True,
    padding_side='left',
    add_special_tokens=False,
)
batch = {k: v.to(device) for k, v in batch.items()}

print(batch["input_ids"].size())
print(batch["attention_mask"].size())

In [None]:
# Steering strengths: positive pushes toward the "certain" pole, negative toward
# "uncertain", 0 is the unsteered baseline.
high, low = 1.2, -1.2
alphas = [high, 0, low]

# The batch is left-padded, so every prompt ends at the same column and everything
# after prompt_len is newly generated text.
prompt_len = batch["input_ids"].shape[1]

outputs = {}
try:
    for alpha in alphas:
        model.lm_head.alpha = alpha
        with torch.no_grad():
            out = model.generate(
                **batch,
                max_new_tokens=600,
                do_sample=False,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Decode only the continuation. The prompt is already shown in the
        # Question column, and decoding it would prepend the chat-template text.
        outputs[alpha] = tokenizer.batch_decode(out[:, prompt_len:], skip_special_tokens=True)
finally:
    # Leave the model unsteered for later cells, even if generation fails midway.
    model.lm_head.alpha = 0

In [None]:
import pandas as pd

# One row per question, one column per steering strength, relabelled for display.
column_labels = {
    high: "High Certainty",
    0: "Baseline",
    low: "Low Certainty",
}

df = (
    pd.DataFrame(outputs, index=questions)
    .rename_axis("Question")
    .reset_index()
)
df_fmt = df.rename(columns=column_labels)

In [None]:
from IPython.display import HTML

# Display the comparison table with long outputs wrapped and top-aligned.
style = """
<style>
.styled-table { width: 100%; border-collapse: collapse; }
.styled-table td, .styled-table th { 
    vertical-align: top; 
    padding: 12px; 
    border: 1px solid #ddd;
    width: 25%;
}
.styled-table td { white-space: pre-wrap; }
</style>
"""

# escape=True: prompts and model outputs are arbitrary text, so &, < and > are
# HTML-escaped rather than interpreted as markup (which could break the table).
# pandas writes embedded newlines in cell text as the two characters backslash-n;
# those are turned into explicit <br> line breaks afterwards.
html = (
    df_fmt.set_index('Question')
    .to_html(escape=True, classes='styled-table')
    .replace('\\n', '<br>')
)
HTML(style + html)

In [None]:
# df_fmt.to_json('resuts.json', orient='records')

### ---- BONEYARD ----

### keeping for later but probably useless

```python
from transformers import LlamaForCausalLM, AutoModelForCausalLM

class SteerableLM(LlamaForCausalLM):
    """Earlier wrapper-model variant of SteeringHead (kept for reference, not executed).

    Runs the wrapped Llama model, then recomputes logits from the last hidden state
    in the causal basis, optionally adding alpha * concept_dir at the last position.

    NOTE(review): LlamaForCausalLM.__init__(config) presumably allocates its own
    decoder stack and LM head before they are replaced below — check memory use.
    """
    def __init__(self, base_model, lm_head_g, sqrt_cov_gamma, concept_dir, alpha: float = 0.0):
        super().__init__(base_model.config)
        # reuse base model's transformer + original head
        self.model = base_model.model
        self.lm_head= base_model.lm_head

        # g(y) = gamma(y) @ A where A = Cov(gamma)^(-1/2)
        self.register_buffer("lm_head_g", lm_head_g)

        # A_inv = sqrt_cov_gamma = Cov(gamma)^(+1/2), used to map lambda -> l_causal
        self.register_buffer("sqrt_cov_gamma", sqrt_cov_gamma)

        # steering direction
        self.register_buffer("concept_dir", concept_dir)

        self.alpha = alpha

    def forward(self, *args, alpha: float | None = None, **kwargs):
        """Return the parent model's outputs with `.logits` replaced by steered causal-basis logits.

        NOTE(review): the `float | None` annotation needs Python 3.10+. Passing
        `output_hidden_states` in kwargs raises TypeError (duplicate keyword), and
        the parent forward still computes logits with self.lm_head, which are then
        discarded.
        """

        if alpha is None:
            alpha = self.alpha

        # get all hidden states so we can grab the last layer
        outputs = super().forward(*args, output_hidden_states=True, **kwargs)
        lambda_all = outputs.hidden_states[-1]   # shape: (batch, seq, d_model)

        # change basis -> steer -> compute logits
        # l_causal = lambda(batch) @ A_inv
        l_causal = lambda_all @ self.sqrt_cov_gamma

        # steer only the last token: l_last = l_last + alpha * concept_dir
        l_causal[:, -1, :] = l_causal[:, -1, :] + alpha * self.concept_dir

        # logits = (l(x) + alpha * concept_dir).T @ g(y)
        outputs.logits = l_causal @ self.lm_head_g.T

        return outputs

# NOTE(review): `base` is loaded here but never used. `base_model=model` passes the
# notebook's main model, whose lm_head may already have been replaced by SteeringHead.
# Presumably `base_model=base` was intended — confirm before reviving this code.
base = AutoModelForCausalLM.from_pretrained(model_card).to(device)
causal_model = SteerableLM(
    base_model=model,
    lm_head_g=g,
    sqrt_cov_gamma=sqrt_cov_gamma,
    concept_dir=concept_dir,
    alpha=0
)
```