In [1]:
from typing import Optional

import torch as t
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
# Experiment selection: which model's probes to analyze, and on which dataset.
model_name, dataset = "gemma-2-27b", "newsroom"

In [3]:
# Trained probe tensors for this dataset/model pair; the "_s" and "_u" files share one path prefix.
probe_path = f"/workspace/PPairS_results/{dataset}/{model_name}/probe"

# map_location="cpu": without it, torch.load restores tensors onto the device they were saved
# from (failing on CPU-only hosts if that was CUDA). The model in this notebook is loaded
# without a device_map (presumably on CPU), so keep the probes there too.
s_probe = t.load(f"{probe_path}_s.pt", weights_only=True, map_location="cpu")
u_probe = t.load(f"{probe_path}_u.pt", weights_only=True, map_location="cpu")

print(s_probe.shape)
print(u_probe.shape)

torch.Size([4, 4608])
torch.Size([4, 4608])


In [4]:
def load_model_and_tokenizer(model_name: str, lora_path: Optional[str] = None) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load a causal LM in bfloat16 (eval mode) together with its tokenizer.

    Args:
        model_name: Hub id or local checkpoint directory passed to ``from_pretrained``.
        lora_path: Reserved for a LoRA adapter path. Adapter loading is NOT implemented;
            passing a value emits a warning and the base model is returned unchanged.

    Returns:
        ``(model, tokenizer)`` where the tokenizer's pad token has been set to its EOS token.
    """
    if lora_path is not None:
        # Previously this argument was silently ignored; make that visible to the caller.
        import warnings

        warnings.warn(
            f"lora_path={lora_path!r} was given but LoRA adapters are not loaded; returning the base model.",
            stacklevel=2,
        )

    # Load the base model. trust_remote_code=True lets the checkpoint execute its own model
    # code, so only point this at trusted checkpoints.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=t.bfloat16,
        trust_remote_code=True,
        use_cache=True,
    )
    model.eval()  # inference only

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse EOS as the padding token; this overrides any pad token the tokenizer already defines.
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

In [5]:
# Load the "-it" checkpoint directory for the selected model (presumably the instruction-tuned variant).
model, tokenizer = load_model_and_tokenizer(
    "/workspace/models/" + model_name + "-it",
)

[2025-05-27 19:03:08,243] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


Loading checkpoint shards:   0%|          | 0/12 [00:00<?, ?it/s]

In [6]:
# Token embedding matrix (We), final norm module, and unembedding matrix (Wu),
# used below to read the probe direction out in token space.
We, norm, Wu = model.model.embed_tokens.weight, model.model.norm, model.lm_head.weight

for weight_matrix in (We, Wu):
    print(weight_matrix.shape)  # (vocab_size, d_model)

torch.Size([256000, 4608])
torch.Size([256000, 4608])


In [7]:
# Sanity check: u_probe's shape (same value reported when the probes were loaded).
print(u_probe.size())

torch.Size([4, 4608])


In [8]:
# Collapse the rows of u_probe (mean over dim 0) into a single direction vector.
# NOTE(review): a plain mean weights each row by its norm and can cancel rows pointing in
# opposing directions — consider normalizing rows first; TODO confirm intent.
top_k = 10

# We, Wu and the norm weight are trainable parameters, so without no_grad every op below
# would record an autograd graph over the full (256000 x 4608) vocabulary matrices.
with t.no_grad():
    probe_direction = u_probe.mean(dim=0).to(We.device)  # shape: (d_model,)
    print(f"Probe direction shape: {probe_direction.shape}")

    # Cosine similarity of the probe direction with every token's embedding / unembedding row.
    # NOTE(review): recorded outputs show identical We and Wu rankings, suggesting the two
    # matrices are tied — check with `We.data_ptr() == Wu.data_ptr()`.
    we_similarities = t.nn.functional.cosine_similarity(probe_direction.unsqueeze(0), We)
    wu_similarities = t.nn.functional.cosine_similarity(probe_direction.unsqueeze(0), Wu)

    # Logit-lens approach: apply the final norm, then project onto the vocabulary with Wu.
    normalized_probe = norm(probe_direction).to(Wu.dtype)
    logits = Wu @ normalized_probe

rankings = [
    (f"Top {top_k} most similar tokens in embedding matrix We:", we_similarities, "Similarity"),
    (f"Top {top_k} most similar tokens in unembedding matrix Wu:", wu_similarities, "Similarity"),
    (f"Top {top_k} tokens via logit-lens approach:", logits, "Logit"),
]
for header, scores, label in rankings:
    print(f"\n{header}")
    # topk returns the k largest scores in descending order without sorting the whole vocabulary.
    values, indices = scores.topk(top_k)
    for rank, (idx, value) in enumerate(zip(indices.tolist(), values.tolist()), start=1):
        print(f"{rank}. Token: '{tokenizer.decode(idx)}' (ID: {idx}) - {label}: {value:.4f}")


Probe direction shape: torch.Size([4608])



Top 10 most similar tokens in embedding matrix We:
1. Token: ' RAL' (ID: 111640) - Similarity: 0.0231
2. Token: ' viper' (ID: 119060) - Similarity: 0.0228
3. Token: ' rattles' (ID: 151417) - Similarity: 0.0226
4. Token: 'GetKeyDown' (ID: 183054) - Similarity: 0.0225
5. Token: ' screenwriter' (ID: 191114) - Similarity: 0.0217
6. Token: ' rocker' (ID: 96531) - Similarity: 0.0216
7. Token: 'Dax' (ID: 218725) - Similarity: 0.0216
8. Token: ' 微信' (ID: 93689) - Similarity: 0.0212
9. Token: ' LDAP' (ID: 139984) - Similarity: 0.0212
10. Token: 'thage' (ID: 111472) - Similarity: 0.0203

Top 10 most similar tokens in unembedding matrix Wu:
1. Token: ' RAL' (ID: 111640) - Similarity: 0.0231
2. Token: ' viper' (ID: 119060) - Similarity: 0.0228
3. Token: ' rattles' (ID: 151417) - Similarity: 0.0226
4. Token: 'GetKeyDown' (ID: 183054) - Similarity: 0.0225
5. Token: ' screenwriter' (ID: 191114) - Similarity: 0.0217
6. Token: ' rocker' (ID: 96531) - Similarity: 0.0216
7. Token: 'Dax' (ID: 218725) - S

In [9]:
# Mirror of the analysis above, but listing the tokens at the NEGATIVE end of the probe
# direction (lowest cosine similarity / lowest logit) rather than the top.
# Collapse the rows of u_probe (mean over dim 0) into a single direction vector.
bottom_k = 10

# We, Wu and the norm weight are trainable parameters; no_grad avoids building an autograd
# graph over the full vocabulary matrices.
with t.no_grad():
    probe_direction = u_probe.mean(dim=0).to(We.device)  # shape: (d_model,)
    print(f"Probe direction shape: {probe_direction.shape}")

    we_similarities = t.nn.functional.cosine_similarity(probe_direction.unsqueeze(0), We)
    wu_similarities = t.nn.functional.cosine_similarity(probe_direction.unsqueeze(0), Wu)

    # Logit-lens approach: apply the final norm, then project onto the vocabulary with Wu.
    normalized_probe = norm(probe_direction).to(Wu.dtype)
    logits = Wu @ normalized_probe

# Headers previously said "Top ... most similar" even though this cell lists the least similar tokens.
rankings = [
    (f"Bottom {bottom_k} least similar tokens in embedding matrix We:", we_similarities, "Similarity"),
    (f"Bottom {bottom_k} least similar tokens in unembedding matrix Wu:", wu_similarities, "Similarity"),
    (f"Bottom {bottom_k} tokens via logit-lens approach:", logits, "Logit"),
]
for header, scores, label in rankings:
    print(f"\n{header}")
    # largest=False yields the k smallest scores in ascending order (most negative first).
    values, indices = scores.topk(bottom_k, largest=False)
    for rank, (idx, value) in enumerate(zip(indices.tolist(), values.tolist()), start=1):
        print(f"{rank}. Token: '{tokenizer.decode(idx)}' (ID: {idx}) - {label}: {value:.4f}")


Probe direction shape: torch.Size([4608])



Top 10 most similar tokens in embedding matrix We:
1. Token: 'Summary' (ID: 9292) - Similarity: -0.0682
2. Token: 'Sorry' (ID: 12156) - Similarity: -0.0651
3. Token: ' Summary' (ID: 13705) - Similarity: -0.0615
4. Token: ' Sorry' (ID: 26199) - Similarity: -0.0612
5. Token: ' sorry' (ID: 9897) - Similarity: -0.0590
6. Token: ' SUMMARY' (ID: 40702) - Similarity: -0.0552
7. Token: 'sorry' (ID: 43718) - Similarity: -0.0533
8. Token: 'Sum' (ID: 5751) - Similarity: -0.0496
9. Token: ' summary' (ID: 13367) - Similarity: -0.0492
10. Token: ' but' (ID: 901) - Similarity: -0.0492

Top 10 most similar tokens in unembedding matrix Wu:
1. Token: 'Summary' (ID: 9292) - Similarity: -0.0682
2. Token: 'Sorry' (ID: 12156) - Similarity: -0.0651
3. Token: ' Summary' (ID: 13705) - Similarity: -0.0615
4. Token: ' Sorry' (ID: 26199) - Similarity: -0.0612
5. Token: ' sorry' (ID: 9897) - Similarity: -0.0590
6. Token: ' SUMMARY' (ID: 40702) - Similarity: -0.0552
7. Token: 'sorry' (ID: 43718) - Similarity: -0.0