### Load in LM for testing
- interstingly gpt2-medium throws nan's but script seems to work well for small and large
- each lm_head has different dims so small is 768, medium 1024, large 1280

In [1]:
import os, torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.manual_seed(42)
cache_dir = (Path.cwd() / "models").resolve()
cache_dir.mkdir(parents=True, exist_ok=True)

device = (
    "cuda" if torch.cuda.is_available()
    # else ("mps" if torch.backends.mps.is_available() else "cpu")
    else "cpu"
)

os.environ["HF_HOME"] = str(cache_dir)
print(f'Device: {device}')

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2-large").to(device)
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2-large")
tokenizer.pad_token = tokenizer.eos_token
model.eval()

Device: cpu


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=3840, nx=1280)
          (c_proj): Conv1D(nf=1280, nx=1280)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=5120, nx=1280)
          (c_proj): Conv1D(nf=1280, nx=5120)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1280, out_features=50257, bias=False)
)

### Change basis for lm_head
- [linear_rep_geometry
/store_matrices.py](https://github.com/KihoPark/linear_rep_geometry/blob/main/store_matrices.py)

In [2]:
gamma = model.lm_head.weight.detach()
W, d = gamma.shape
gamma_bar = torch.mean(gamma, dim=0)
centered_gamma = gamma - gamma_bar

### compute Cov(gamma) and tranform gamma to g ###
cov_gamma = centered_gamma.T @ centered_gamma / W
eigenvalues, eigenvectors = torch.linalg.eigh(cov_gamma)

inv_sqrt_cov_gamma = eigenvectors @ torch.diag(1/torch.sqrt(eigenvalues)) @ eigenvectors.T
sqrt_cov_gamma = eigenvectors @ torch.diag(torch.sqrt(eigenvalues)) @ eigenvectors.T

# gamma is our original head and inv_sqrt_cov_gamma puts us in a causal basis
g = gamma @ inv_sqrt_cov_gamma

# maybe i confused but A_inv = sqrt_cov_gamma and A = inv_sqrt_cov_gamma for 
# l(x).T @ g(y)
# where l(x) = lambda(x) @ A_inv and g(y) = gamma(y) @ A (referencing paper eq and presentation eq on youtube)
print(model.config.hidden_size)
print(g.size())

g = g.float()
inv_sqrt_cov_gamma = inv_sqrt_cov_gamma.float()
sqrt_cov_gamma = sqrt_cov_gamma.float()

1280
torch.Size([50257, 1280])


In [11]:
eigenval_min_max = f"Eigenval min: {eigenvalues.min()}\nEigenval max: {eigenvalues.max()}"
max_amp = f"Max amplification (1/sqrt(min)): {1 / torch.sqrt(eigenvalues.min()).item():.1f}\n"
gamma_min_max = f"gamma min: {gamma.min()}\ngamma max: {gamma.max()}\n"
g_min_max = f"gamma @ inv_sqrt_cov_gamma min: {g.min()}\ngamma @ inv_sqrt_cov_gamma max: {g.max()}\n"

print(eigenval_min_max)
print(max_amp)
print(gamma_min_max)
print(g_min_max)
print(f"gamma dtype: {gamma.dtype}")
print(f"g dtype: {g.dtype}")

Eigenval min: 7.621865370310843e-05
Eigenval max: 0.08969025313854218
Max amplification (1/sqrt(min)): 114.5

gamma min: -0.5182147026062012
gamma max: 0.4273260533809662

gamma @ inv_sqrt_cov_gamma min: -8.831315994262695
gamma @ inv_sqrt_cov_gamma max: 8.074169158935547

gamma dtype: torch.float32
g dtype: torch.float32


### Define data for training probe and inference tests

In [4]:
gender_pairs = {
    "he is the king": "she is the queen",
    "he is a man": "she is a woman",
    "he is the prince": "she is the princess",
    "he is an actor": "she is an actress",
    "he is my brother": "she is my sister",
    "he is my father": "she is my mother",
    "he is my son": "she is my daughter",
    "he is my uncle": "she is my aunt",
    "he is my husband": "she is my wife",
    "he is my grandfather": "she is my grandmother",
    "he is my nephew": "she is my niece",
    "he was a gentleman": "she was a lady",
    "the boy ran home": "the girl ran home",
    "Mr. Smith arrived": "Ms. Smith arrived",
    "the chairman spoke": "the chairwoman spoke",
    "he is a waiter": "she is a waitress",
    "the spokesman said": "the spokeswoman said",
    "he is a hero": "she is a heroine",
    "the landlord agreed": "the landlady agreed",
    "he is a duke": "she is a duchess",
}

inference_prompts = [
    "Long live the",
    "The lion is the",
    "In the hierarchy of medieval society, the highest rank was the",
    "Arthur was a legendary",
    "He was known as the warrior",
    "In a monarchy, the ruler is usually a",
    "He sat on the throne, the",
    "A sovereign ruler in a monarchy is often a",
    "His domain was vast, for he was a",
    "The lion, in many cultures, is considered the",
    "He wore a crown, signifying he was the",
    "A male sovereign who reigns over a kingdom is a",
    "Every kingdom has its ruler, typically a",
    "The prince matured and eventually became the",
    "In the deck of cards, alongside the queen is the"
]

male_ex = list(gender_pairs.keys())
female_ex = list(gender_pairs.values())

In [5]:
m_inputs = tokenizer(male_ex, return_tensors="pt", padding=True, truncation=True).to(device)
f_inputs = tokenizer(female_ex, return_tensors="pt", padding=True, truncation=True).to(device)

with torch.no_grad():
    out1 = model(**m_inputs, output_hidden_states=True)
    out2 = model(**f_inputs, output_hidden_states=True)
    last_hidden_m, last_hidden_f = out1.hidden_states[-1], out2.hidden_states[-1]
    # we want to pool embeddings for each token in sequence so we have full sequence representation
    m_emb, f_emb = last_hidden_m.mean(dim=1), last_hidden_f.mean(dim=1)

### train probe
- could probably just get difference vec would like to ask about this tbh
- gut says probe is better because error/noise can be better quantified 
    - plus can determine alpha that acheives desired prob

In [6]:
from sklearn.linear_model import LogisticRegression

# make sure probe is trained in right representation space 
# since we will apply it to l(x) later where l(x) = lambda(x) @ A_inv
X = torch.cat([m_emb, f_emb], dim=0) @ sqrt_cov_gamma

y = torch.cat([
    torch.ones(len(male_ex)), 
    torch.zeros(len(female_ex))
    ])

clf = LogisticRegression(
    max_iter=1000,
    penalty=None,
    solver='lbfgs'
)
clf.fit(X.detach().cpu().numpy(), y.detach().cpu().numpy())
concept_dir = torch.tensor(clf.coef_[0], dtype=torch.float32, device=device)
concept_dir_norm = concept_dir / concept_dir.norm()

# simple difference
# gender_dir = (m_emb - f_emb).mean(dim=0)
# gender_dir = gender_dir # @ sqrt_cov_gamma
# gender_norm = gender_dir/gender_dir.norm()

### Inference steering
- assuming i just keep adding gender dir at each decoding step
- this could be simpler to put in logit processor so we can use model.generate()

In [7]:
from tabulate import tabulate

prompt = inference_prompts[2]
prompt_enc = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)

# what are sensible alphas here e.g. values that acheive some target prob for that class?
alphas = [1.5, 0, -1.5]
k = 12
cols = []
print(f"Prompt: '{prompt}'\n")

for alpha in alphas:
    with torch.no_grad():
        # transform, steer, and get causal inner-product between embedding and unembedding
        outputs = model(**prompt_enc, output_hidden_states=True)
        last_token_idx = (prompt_enc.attention_mask.sum(dim=1) - 1).item()

        # (l(x) + alpha * gender_norm).T @ g(y)
        # where l(x) = lambda(x) @ A_inv and g(y) = gamma(y) @ A 
        # Note: Code seems to swap A_inv with A e.g. 
        # it applies A_inv to gamma (lm_head) instead of lambda (embeddings) like equation suggests

        lambda_x = outputs.hidden_states[-1][:, last_token_idx, :] # last token emb but calling it lambda for paper consistence
        l_causal = lambda_x @ sqrt_cov_gamma
        l_steered = l_causal + alpha * concept_dir_norm

        causal_logits = l_steered @ g.T
        causal_log_probs = torch.log_softmax(causal_logits, dim=-1)
        topk = torch.topk(causal_log_probs, k)

        tokens = [tokenizer.decode([idx.item()]) for idx in topk.indices[0]]
        log_probs = topk.values[0].tolist()
        cols.append([f"{tok} ({lp:.2f})" for tok, lp in zip(tokens, log_probs)])

rows = list(zip(*cols))
table = tabulate(rows, headers=[f"alpha = {a}" for a in alphas], tablefmt="github")
print(table)

Prompt: 'In the hierarchy of medieval society, the highest rank was the'

| alpha = 1.5        | alpha = 0      | alpha = -1.5     |
|--------------------|----------------|------------------|
| king (-1.33)       | knight (-3.11) | nun (-1.61)      |
| King (-2.75)       | " (-3.14)      | lady (-2.05)     |
| monk (-2.92)       | bishop (-3.22) | she (-2.95)      |
| lord (-3.23)       | d (-3.44)      | d (-2.99)        |
| bishop (-3.35)     | bar (-3.70)    | kn (-3.81)       |
| " (-3.49)          | kn (-3.74)     | princess (-3.94) |
| prince (-3.87)     | monk (-3.86)   | woman (-3.95)    |
| knight (-4.26)     | noble (-3.89)  | convent (-3.96)  |
| son (-4.26)        | king (-3.90)   | v (-3.97)        |
| Archbishop (-4.27) | lord (-4.04)   | herself (-4.17)  |
| priest (-4.37)     | priest (-4.19) | widow (-4.28)    |
| man (-4.44)        | ' (-4.26)      | queen (-4.30)    |


### custom lm_head 
if the code is right could easily create a logits processor

In [8]:
from transformers import GPT2LMHeadModel, AutoModelForCausalLM

class SteerableGPT2(GPT2LMHeadModel):
    def __init__(self, base_model, lm_head_g, sqrt_cov_gamma, concept_dir, alpha: float = 0.0):
        super().__init__(base_model.config)
        # reuse base model's transformer + original head
        self.transformer = base_model.transformer
        self.lm_head= base_model.lm_head

        # g(y) = gamma(y) @ A where A = Cov(gamma)^(-1/2) 
        self.register_buffer("lm_head_g", lm_head_g)

        # A_inv = sqrt_cov_gamma = Cov(gamma)^(+1/2), used to map lambda -> l_causal
        self.register_buffer("sqrt_cov_gamma", sqrt_cov_gamma)

        # steering direction
        self.register_buffer("concept_dir", concept_dir)

        self.alpha = alpha
    
    def forward(self, *args, **kwargs):
        # get all hidden states so we can grab the last layer
        outputs = super().forward(*args, output_hidden_states=True, **kwargs)
        lambda_all = outputs.hidden_states[-1]   # shape: (batch, seq, d_model)

        # change basis -> steer -> compute logits
        # l_causal = lambda(batch) @ A_inv
        l_causal = lambda_all @ self.sqrt_cov_gamma

        # steer only the last token: l_last = l_last + alpha * concept_dir
        l_causal[:, -1, :] = l_causal[:, -1, :] + self.alpha * self.concept_dir

        # logits = (l(x) + alpha * concept_dir).T @ g(y)
        outputs.logits = l_causal @ self.lm_head_g.T
        
        return outputs

base = AutoModelForCausalLM.from_pretrained("openai-community/gpt2-large").to(device)
causal_model = SteerableGPT2(
    base_model=base,
    lm_head_g=g,
    sqrt_cov_gamma=sqrt_cov_gamma,
    concept_dir=concept_dir_norm, 
    alpha=-1.0
)

In [9]:
prompt = inference_prompts[9]
prompt_enc = tokenizer(prompt, return_tensors='pt').to(device)
output = causal_model.generate(
    input_ids=prompt_enc.input_ids,
    attention_mask=prompt_enc.attention_mask,
    max_new_tokens=25,
    do_sample=True,
    temperature=0.3,
    top_p=0.85,
    repetition_penalty=1.5,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

print(f"Prompt:       {prompt}")
print(f"Continuation: {''.join(tokenizer.batch_decode(output))}")

Prompt:       The lion, in many cultures, is considered the
Continuation: The lion, in many cultures, is considered the symbol of beauty and fertility. The woman herself may be beautiful or she might have her breasts filled with milk to make herself pregnant
