### Load in LM for testing
- Interestingly, gpt2-medium produces NaNs, but the script seems to work well for small and large
- each lm_head has different dims so small is 768, medium 1024, large 1280

In [56]:
import os, torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer

# Fix the torch RNG seed so the sampling-based generation later in the notebook is repeatable.
torch.manual_seed(42)

# Project-local directory for downloaded model/tokenizer files.
cache_dir = (Path.cwd() / "models").resolve()
cache_dir.mkdir(parents=True, exist_ok=True)

device = (
    "cuda" if torch.cuda.is_available()
    # else ("mps" if torch.backends.mps.is_available() else "cpu")
    else "cpu"
)

# NOTE: the Hugging Face libraries resolve HF_HOME when they are imported, and `transformers`
# is already imported above, so this assignment alone does not redirect downloads in this
# process. It is kept for child processes / later imports; the explicit `cache_dir=` below is
# what actually makes the weights land in ./models.
os.environ["HF_HOME"] = str(cache_dir)
print(f'Device: {device}')

model = AutoModelForCausalLM.from_pretrained(
    "openai-community/gpt2-large", cache_dir=cache_dir
).to(device)
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2-large", cache_dir=cache_dir)
# No pad token is configured here, so EOS is reused as the pad token for padded batches.
tokenizer.pad_token = tokenizer.eos_token
model.eval();

Device: cpu


### Change basis for lm_head
- [linear_rep_geometry/store_matrices.py](https://github.com/KihoPark/linear_rep_geometry/blob/main/store_matrices.py)

In [57]:
# gamma: the unembedding (lm_head) weights, one row per vocabulary token.
gamma = model.lm_head.weight.detach()
W, d = gamma.shape
gamma_bar = torch.mean(gamma, dim=0)
centered_gamma = gamma - gamma_bar

### compute Cov(gamma) and transform gamma to g ###
cov_gamma = centered_gamma.T @ centered_gamma / W
eigenvalues, eigenvectors = torch.linalg.eigh(cov_gamma)

# Guard the (inverse) square roots: a zero or slightly negative eigenvalue (possible from
# floating-point error in eigh) would turn into inf/NaN and poison every downstream product.
# The floor is relative to the largest eigenvalue, so a well-conditioned spectrum is unchanged.
# This may be the source of the NaNs noted for gpt2-medium -- TODO confirm.
eig_floor = eigenvalues.max().item() * torch.finfo(eigenvalues.dtype).eps
safe_eigenvalues = torch.clamp(eigenvalues, min=eig_floor)

inv_sqrt_cov_gamma = eigenvectors @ torch.diag(1/torch.sqrt(safe_eigenvalues)) @ eigenvectors.T
sqrt_cov_gamma = eigenvectors @ torch.diag(torch.sqrt(safe_eigenvalues)) @ eigenvectors.T

# gamma is our original head and inv_sqrt_cov_gamma puts us in a causal basis
g = gamma @ inv_sqrt_cov_gamma

# Naming used in the rest of the notebook (possibly swapped relative to the paper):
#   A = inv_sqrt_cov_gamma and A_inv = sqrt_cov_gamma, for the inner product l(x).T @ g(y)
#   where l(x) = lambda(x) @ A_inv and g(y) = gamma(y) @ A
# (referencing paper eq and presentation eq on youtube)
print(model.config.hidden_size)
print(g.size())

# g = g.float()
# inv_sqrt_cov_gamma = inv_sqrt_cov_gamma.float()
# sqrt_cov_gamma = sqrt_cov_gamma.float()

1280
torch.Size([50257, 1280])


In [58]:
# Sanity report on the whitening: spectrum range of Cov(gamma), the largest per-direction
# scale factor applied by inv_sqrt_cov_gamma, and value ranges before/after the basis change.
eigenval_min_max = f"Eigenval min: {eigenvalues.min()}\nEigenval max: {eigenvalues.max()}"
max_amp = f"Max amplification (1/sqrt(min)): {1 / torch.sqrt(eigenvalues.min()).item():.1f}\n"
gamma_min_max = f"gamma min: {gamma.min()}\ngamma max: {gamma.max()}\n"
g_min_max = f"gamma @ inv_sqrt_cov_gamma min: {g.min()}\ngamma @ inv_sqrt_cov_gamma max: {g.max()}\n"

# Emit the report sections in order; a trailing "\n" in a section leaves a blank line after it.
for report_section in (eigenval_min_max, max_amp, gamma_min_max, g_min_max):
    print(report_section)

# dtypes matter here because the inverse square root is sensitive to precision.
print(f"gamma dtype: {gamma.dtype}")
print(f"g dtype: {g.dtype}")

Eigenval min: 7.621865370310843e-05
Eigenval max: 0.08969025313854218
Max amplification (1/sqrt(min)): 114.5

gamma min: -0.5182147026062012
gamma max: 0.4273260533809662

gamma @ inv_sqrt_cov_gamma min: -8.831315994262695
gamma @ inv_sqrt_cov_gamma max: 8.074169158935547

gamma dtype: torch.float32
g dtype: torch.float32


### Define data for training probe and inference tests

In [100]:
import pandas as pd 

# Contrastive word pairs; each row provides a 'male_word' and a 'female_word'.
concept_df = pd.read_json('gendered_pairs/data.json')

# skipping pronoun contrastive pairs
# Only filter when the optional 'skip' column exists. Previously a missing column made the
# mask collapse to the scalar True, and concept_df[True] raised KeyError.
if 'skip' in concept_df.columns:
    concept_df = concept_df[concept_df['skip'].ne(True)]

male_ex = concept_df['male_word'].to_list()
female_ex = concept_df['female_word'].to_list()

# Prompts used for the steering demos further down in the notebook.
inference_prompts = [
    "Long live the",
    "The lion is the",
    "In the hierarchy of medieval society, the highest rank was the",
    "Arthur was a legendary",
    "He was known as the warrior",
    "In a monarchy, the ruler is usually a",
    "He sat on the throne, the",
    "A sovereign ruler in a monarchy is often a",
    "His domain was vast, for he was a",
    "The lion, in many cultures, is considered the",
    "He wore a crown, signifying he was the",
    "A male sovereign who reigns over a kingdom is a",
    "Every kingdom has its ruler, typically a",
    "The prince matured and eventually became the",
    "In the deck of cards, alongside the queen is the"
]

### get concept direction
- just getting the avg of difference vector between gamma(Y(0)) - gamma(Y(1))
- gut says a logistic regression probe is better because error/noise can be better quantified 
    - plus it can determine the alpha that achieves a desired prob


In [114]:
# Tokenise both word lists as padded batches (row i of each batch belongs to pair i).
m_inputs = tokenizer(male_ex, padding=True, truncation=True, return_tensors="pt").to(device)
f_inputs = tokenizer(female_ex, padding=True, truncation=True, return_tensors="pt").to(device)

# A row whose attention mask sums to more than 1 spans several tokens; such words are
# excluded because only a single unembedding row per word is used below.
m_exclude = m_inputs.attention_mask.sum(dim=1) > 1
f_exclude = f_inputs.attention_mask.sum(dim=1) > 1
# Keep a pair only when neither side is flagged.
valid_mask = ~(m_exclude | f_exclude)

# First token id of every surviving word.
m_ready = m_inputs.input_ids[valid_mask, 0]
f_ready = f_inputs.input_ids[valid_mask, 0]

print(m_ready.size())

# Mean difference gamma(male) - gamma(female) in the original unembedding basis.
gamma_bar_W = (gamma[m_ready] - gamma[f_ready]).mean(dim=0)

# Map the mean difference into the whitened basis and scale it to unit length.
# Alternative: (g[m_ready] - g[f_ready]).mean(dim=0), i.e. averaging directly in the g basis.
concept_dir = gamma_bar_W @ inv_sqrt_cov_gamma
concept_dir = concept_dir / concept_dir.norm()

torch.Size([9])


### Inference steering
- assuming i just keep adding gender dir at each decoding step
- this could be simpler to put in logit processor so we can use model.generate()

In [115]:
from tabulate import tabulate

prompt = inference_prompts[2]
prompt_enc = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)

# what are sensible alphas here e.g. values that achieve some target prob for that class?
alphas = [1, 0, -1]
k = 12
cols = []
print(f"Prompt: '{prompt}'\n")

with torch.no_grad():
    # The forward pass and the change of basis do not depend on alpha, so they run once
    # instead of once per alpha.
    outputs = model(**prompt_enc, output_hidden_states=True)
    last_token_idx = (prompt_enc.attention_mask.sum(dim=1) - 1).item()

    # (l(x) + alpha * concept_dir).T @ g(y)
    # where l(x) = lambda(x) @ A_inv and g(y) = gamma(y) @ A
    # Note: relative to the paper's equation the roles look swapped -- here the inverse square
    # root (A) is applied to gamma (lm_head) and the square root (A_inv) to lambda (embedding).
    # Because the two matrices are built as inverses of each other, alpha = 0 gives
    # lambda(x) @ gamma.T up to floating-point error.

    # last-token embedding, called lambda for consistency with the paper
    lambda_x = outputs.hidden_states[-1][:, last_token_idx, :]
    l_causal = lambda_x @ sqrt_cov_gamma

    for alpha in alphas:
        # steer in the causal basis, then take the inner product with every g(y)
        l_steered = l_causal + alpha * concept_dir

        causal_logits = l_steered @ g.T
        causal_log_probs = torch.log_softmax(causal_logits, dim=-1)
        topk = torch.topk(causal_log_probs, k)

        tokens = [tokenizer.decode([idx.item()]) for idx in topk.indices[0]]
        log_probs = topk.values[0].tolist()
        cols.append([f"{tok} ({lp:.2f})" for tok, lp in zip(tokens, log_probs)])

# One table column per alpha, one row per rank.
rows = list(zip(*cols))
table = tabulate(rows, headers=[f"alpha = {a}" for a in alphas], tablefmt="github")
print(table)

Prompt: 'In the hierarchy of medieval society, the highest rank was the'

| alpha = 1       | alpha = 0      | alpha = -1       |
|-----------------|----------------|------------------|
| man (-1.35)     | knight (-3.11) | woman (-1.83)    |
| knight (-2.03)  | " (-3.14)      | lady (-2.33)     |
| lord (-2.78)    | bishop (-3.22) | woman (-3.36)    |
| Prince (-3.13)  | d (-3.44)      | female (-3.39)   |
| prince (-3.53)  | bar (-3.70)    | wife (-3.44)     |
| " (-3.70)       | kn (-3.74)     | nun (-3.79)      |
| king (-3.71)    | monk (-3.86)   | princess (-3.82) |
| bishop (-4.15)  | noble (-3.89)  | kn (-3.85)       |
| husband (-4.51) | king (-3.90)   | d (-3.93)        |
| monk (-4.55)    | lord (-4.04)   | she (-3.93)      |
| noble (-4.59)   | priest (-4.19) | ear (-3.96)      |
| son (-4.60)     | ' (-4.26)      | queen (-4.04)    |


### custom lm_head 
if the code is right could easily create a logits processor

In [116]:
from transformers import GPT2LMHeadModel, AutoModelForCausalLM

class SteerableGPT2(GPT2LMHeadModel):
    """GPT-2 LM whose logits are recomputed in the causal basis with a steering offset.

    The logits returned by ``forward`` are
    ``(lambda(x) @ sqrt_cov_gamma + alpha * concept_dir) @ lm_head_g.T`` where the
    ``alpha * concept_dir`` term is added at the last sequence position only.

    Args:
        base_model: model whose ``config``, ``transformer`` and ``lm_head`` are reused.
        lm_head_g: g(y) = gamma(y) @ A with A = Cov(gamma)^(-1/2); used as the output matrix.
        sqrt_cov_gamma: A_inv = Cov(gamma)^(+1/2); maps hidden states into the causal basis.
        concept_dir: steering direction added in the causal basis.
        alpha: steering strength; stored as a plain attribute so it can be changed between calls.
    """

    def __init__(self, base_model, lm_head_g, sqrt_cov_gamma, concept_dir, alpha: float = 0.0):
        # The parent constructor builds its own modules from the config; the transformer and
        # head are then replaced by the base model's so the loaded weights are shared.
        super().__init__(base_model.config)
        # reuse base model's transformer + original head
        self.transformer = base_model.transformer
        self.lm_head = base_model.lm_head

        # g(y) = gamma(y) @ A where A = Cov(gamma)^(-1/2)
        self.register_buffer("lm_head_g", lm_head_g)

        # A_inv = sqrt_cov_gamma = Cov(gamma)^(+1/2), used to map lambda -> l_causal
        self.register_buffer("sqrt_cov_gamma", sqrt_cov_gamma)

        # steering direction
        self.register_buffer("concept_dir", concept_dir)

        self.alpha = alpha

    def forward(self, *args, **kwargs):
        """Run the parent forward pass, then overwrite ``logits`` with the steered logits.

        Hidden states are always requested from the parent, whatever the caller asked for.
        Only ``outputs.logits`` is replaced; any other field (e.g. a loss computed by the
        parent) is left as the parent produced it.
        """
        # Force hidden states through kwargs: passing the keyword explicitly next to **kwargs
        # raised TypeError whenever the caller also supplied output_hidden_states.
        kwargs["output_hidden_states"] = True
        # NOTE(review): the attribute access below presumably needs a ModelOutput-style
        # return, so a caller passing return_dict=False is likely unsupported -- confirm.
        outputs = super().forward(*args, **kwargs)
        lambda_all = outputs.hidden_states[-1]   # shape: (batch, seq, d_model)

        # change basis -> steer -> compute logits
        # l_causal = lambda(batch) @ A_inv
        l_causal = lambda_all @ self.sqrt_cov_gamma

        # steer only the last token: l_last = l_last + alpha * concept_dir
        # (index -1 is the final position of the input as given, padding included)
        l_causal[:, -1, :] = l_causal[:, -1, :] + self.alpha * self.concept_dir

        # logits = (l(x) + alpha * concept_dir).T @ g(y)
        outputs.logits = l_causal @ self.lm_head_g.T

        return outputs

# Load a separate copy of the pretrained weights and wrap it with the steering head.
# The four leading arguments are base_model, lm_head_g, sqrt_cov_gamma and concept_dir.
base = AutoModelForCausalLM.from_pretrained("openai-community/gpt2-large")
base = base.to(device)

# alpha = -1.0 pushes against concept_dir (the normalised male - female direction).
causal_model = SteerableGPT2(base, g, sqrt_cov_gamma, concept_dir, alpha=-1.0)

In [117]:
# Sample a continuation from the steered model for one of the test prompts.
prompt = inference_prompts[9]
prompt_enc = tokenizer(prompt, return_tensors='pt').to(device)

output = causal_model.generate(
    input_ids=prompt_enc.input_ids,
    attention_mask=prompt_enc.attention_mask,
    # length / termination
    max_new_tokens=25,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    # sampling settings
    do_sample=True,
    temperature=0.3,
    top_p=0.85,
    repetition_penalty=1.5,
)

# The decoded text contains the prompt followed by the newly generated tokens.
continuation = ''.join(tokenizer.batch_decode(output))

print(f"Prompt:       {prompt}")
print(f"Continuation: {continuation}")

Prompt:       The lion, in many cultures, is considered the
Continuation: The lion, in many cultures, is considered the mother of women. Women are mothers to female lions and females have a womanly role within their families."

Women's
