In [1]:
from transformer_lens import HookedTransformer
import transformer_lens
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import transformer_lens.utils as utils
import hashlib
import yaml 
import hashlib
import pickle
import numpy as np
import matplotlib.pyplot as plt 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Two identifiers are in play: the name handed to TransformerLens as the
# model to build, and the Hugging Face checkpoint whose weights are loaded.
reference_model_path = "meta-llama/Llama-3.1-8B"
baseline_model_path = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

# Device is chosen by the transformer_lens helper.
device = utils.get_device()

# Hugging Face weights (requested in bfloat16) and the tokenizer that
# belongs to the same checkpoint.
baseline_model_hf = AutoModelForCausalLM.from_pretrained(
    baseline_model_path,
    torch_dtype=torch.bfloat16,
)
baseline_model_tokenizer = AutoTokenizer.from_pretrained(baseline_model_path)

# Wrap the loaded weights in a HookedTransformer without weight processing.
# NOTE(review): the reference name is passed as the model to build while the
# distilled checkpoint supplies the weights -- presumably because both share
# the same architecture; confirm.
baseline_model = HookedTransformer.from_pretrained_no_processing(
    reference_model_path,
    hf_model=baseline_model_hf,
    tokenizer=baseline_model_tokenizer,
    device=device,
    move_to_device=True,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  8.31it/s]


KeyboardInterrupt: 

In [None]:
# Short alias for the hooked baseline model; later cells refer to `model`.
model = baseline_model

In [None]:

def compute_feature_mode_metric(
    model: HookedTransformer,
    prompt: str,
    pos_features: list[list[int]],
    neg_features: list[list[int]],
):
    """
    Score candidate continuations ("features") of a prompt and compare the
    positive set against the negative set.

    Each feature is a list of token ids. Its score is the sum of the
    next-token log-probabilities the model assigns to those tokens when they
    are appended, in order, directly after the prompt (teacher forcing).

    Args:
        model: Model exposing ``cfg.device``, ``to_tokens(prompt)`` and a
            forward call returning logits of shape (batch, seq, vocab).
            Next-token prediction at position i must depend only on tokens
            up to i (the original per-token loop made the same assumption).
        prompt: Text to condition on. It is tokenized with
            ``model.to_tokens`` (which may prepend BOS, depending on how the
            model is configured).
        pos_features: Token-id sequences counted as "positive".
        neg_features: Token-id sequences counted as "negative".

    Returns:
        A tuple ``(norm_probs, scores, neg_logit)``:
          - ``norm_probs``: softmax of ``scores`` over all features,
            positives first, then negatives (sums to 1).
          - ``scores``: unnormalized log-score of every feature, same order.
          - ``neg_logit``: log-odds that the normalized mass falls on a
            negative feature, i.e. log(sum_neg / sum_pos).

    No ``torch.no_grad()`` is applied here, so callers can still
    backpropagate through the returned values; wrap the call yourself when
    gradients are not needed.
    """
    device = model.cfg.device
    # 1) Tokenize the prompt once; every feature is scored against it.
    input_ids = model.to_tokens(prompt).to(device)
    prompt_len = input_ids.shape[1]

    # 2) Score one feature sequence with a single teacher-forced forward pass.
    def sequence_score(feature: list[int]) -> torch.Tensor:
        if len(feature) == 0:
            # Empty continuation: log-probability of "nothing" is 0.
            return torch.tensor(0.0, device=device)

        targets = torch.tensor(feature, dtype=torch.long, device=device)
        # The last feature token never has to be fed back in: the logits that
        # predict it come from the position just before it.
        fed_back = targets[:-1].unsqueeze(0).to(input_ids.dtype)
        ctx = torch.cat([input_ids, fed_back], dim=1)

        # Plain forward call: the activation cache built by run_with_cache
        # was never read, so it is not requested any more.
        logits = model(ctx, return_type="logits")

        # Position prompt_len-1 predicts feature[0], the next one feature[1], ...
        step_logits = logits[0, prompt_len - 1:, :]  # (len(feature), vocab)
        # Normalize in at least float32 so half-precision logits do not lose
        # accuracy inside the softmax over the vocabulary.
        calc_dtype = torch.promote_types(step_logits.dtype, torch.float32)
        log_probs = torch.log_softmax(step_logits, dim=-1, dtype=calc_dtype)
        return log_probs.gather(-1, targets.unsqueeze(-1)).sum()

    # 3) Compute scores for all features (positives first, then negatives).
    all_features = pos_features + neg_features
    scores = torch.stack([sequence_score(f) for f in all_features])  # (n+m,)

    # 4) Softmax to get normalized probabilities.
    norm_probs = torch.softmax(scores, dim=0)                        # (n+m,)

    # 5) Log-odds of the negative mass. Computed in log space as
    #    logsumexp(neg) - logsumexp(pos), which equals
    #    log(neg_prob) - log(1 - neg_prob) but does not blow up to +/-inf when
    #    the normalized probability rounds to exactly 0 or 1.
    num_pos = len(pos_features)
    neg_logit = (
        torch.logsumexp(scores[num_pos:], dim=0)
        - torch.logsumexp(scores[:num_pos], dim=0)
    )

    return norm_probs, scores, neg_logit




In [None]:
# Candidate continuations as token-id sequences, split into the two groups
# that the metric compares.
pos_features = [
    [1271],
    [1271, 1505],
    [1271, 8417],
    [334, 37942, 25],
]
neg_features = [
    [33413],
]

# Decode every positive sequence so the ids can be checked by eye.
for pf in pos_features:
    print(model.to_string(pf))

# Blank separator between the two groups.
print("\n\n")

# Same for the negative sequences.
for nf in neg_features:
    print(model.to_string(nf))

In [None]:

# Case 1: the prompt stops right after the opening <think> tag.
prompt = "<｜User｜>If a pizza is cut into 8 equal slices and 3 slices are eaten, what fraction remains?<｜Assistant｜><think>\n"
pos_features = [
    [1271],
    [1271, 1505],
    [1271, 8417],
    [334, 37942, 25],
]
neg_features = [[33413]]

# Score every candidate continuation against this prompt.
norm_probs, scores, neg_logit = compute_feature_mode_metric(
    model=model,
    prompt=prompt,
    pos_features=pos_features,
    neg_features=neg_features,
)

print("Normalized probabilities:", norm_probs)
print("Unnormalized log-scores:", scores)
print("Negative-feature logit:", neg_logit)

In [None]:

# Case 2: same question, but the prompt already contains a complete
# <think>...</think> block, so the continuation starts after the reasoning.
# The literal is split at paragraph boundaries; the pieces concatenate to
# the same string.
prompt = (
    "<｜User｜>If a pizza is cut into 8 equal slices and 3 slices are eaten, what fraction remains?<｜Assistant｜><think>\n"
    "To determine the fraction of the pizza that remains after eating 3 slices out of 8, I start by noting that the total number of slices is 8.\n\n"
    "Next, I calculate the fraction of the pizza that has been eaten by dividing the number of eaten slices by the total number of slices, which gives me 3/8.\n\n"
    "Finally, to find the fraction that remains, I subtract the eaten fraction from the whole, resulting in 1 minus 3/8, which equals 5/8.\n"
    "</think>\n\n"
)
pos_features = [
    [1271],
    [1271, 1505],
    [1271, 8417],
    [334, 37942, 25],
]
neg_features = [[33413]]

# Score every candidate continuation against this prompt.
norm_probs, scores, neg_logit = compute_feature_mode_metric(
    model=model,
    prompt=prompt,
    pos_features=pos_features,
    neg_features=neg_features,
)

print("Normalized probabilities:", norm_probs)
print("Unnormalized log-scores:", scores)
print("Negative-feature logit:", neg_logit)