In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn.functional as F


model_name = "Qwen/Qwen1.5-1.8B"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True
)

model.eval()
print("Qwen 1.5-1.8B ready")


Loading tokenizer...
Loading model...
Qwen 1.5-1.8B ready


In [3]:
prompt = "What is 2 + 3? Think step by step and enclose your final answer in \\boxed{}."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

with torch.no_grad():
    output = model.generate(
        input_ids,
        temperature=0.0,
        return_dict_in_generate=True,
        output_scores=True,
        max_new_tokens=100
    )

generated_ids = output.sequences[0]
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
scores = output.scores  # logits for each generated token

print("Generated response:\n", generated_text)
print("Number of generated tokens:", len(scores))


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Generated response:
 What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. The answer is 5.
Number of generated tokens: 7


In [7]:
top_k = 5
alt_token_data = []

for i, logits in enumerate(output.scores):
    log_probs = F.log_softmax(logits[0], dim=-1)
    topk_logprobs, topk_ids = torch.topk(log_probs, top_k)

    greedy_token_id = generated_ids[input_ids.shape[1] + i]
    prefix_ids = generated_ids[:input_ids.shape[1] + i]
    prefix_text = tokenizer.decode(prefix_ids, skip_special_tokens=True)

    for token_id, log_prob in zip(topk_ids, topk_logprobs):
        if token_id.item() == greedy_token_id.item():
            continue
        alt_token = tokenizer.decode(token_id.unsqueeze(0))
        alt_token_data.append((prefix_text, alt_token, log_prob.item()))

# Show samples
for entry in alt_token_data[:5]:
    print("Prefix:", entry[0])
    print("Alt token:", entry[1])
    print("Log-prob:", entry[2])
    print("----")


Prefix: What is 2 + 3? Think step by step and enclose your final answer in \boxed{}.
Alt token:  
Log-prob: -1.467862844467163
----
Prefix: What is 2 + 3? Think step by step and enclose your final answer in \boxed{}.
Alt token:  Sure
Log-prob: -3.327237844467163
----
Prefix: What is 2 + 3? Think step by step and enclose your final answer in \boxed{}.
Alt token:  \
Log-prob: -3.499112844467163
----
Prefix: What is 2 + 3? Think step by step and enclose your final answer in \boxed{}.
Alt token:  {
Log-prob: -3.710050344467163
----
Prefix: What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. The
Alt token:  sum
Log-prob: -2.3823564052581787
----


In [10]:
def generate_completion(prompt, max_new_tokens=100, temperature=1.0, seed=None):
    if seed is not None:
        torch.manual_seed(seed)

    input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


In [11]:
def sample_fork_completions(prefix, alt_token, num_samples=10, max_new_tokens=100):
    fork_prompt = prefix + alt_token
    completions = []

    for i in range(num_samples):
        gen = generate_completion(
            fork_prompt,
            max_new_tokens=max_new_tokens,
            temperature=1.0,
            seed=None  # randomness enabled
        )
        completions.append(gen)

    return {
        "fork_prompt": fork_prompt,
        "prefix": prefix,
        "alt_token": alt_token,
        "completions": completions
    }


In [12]:
fork_results = []

for i, (prefix, alt_token, logp) in enumerate(alt_token_data[:5]):
    print(f"\n=== Fork {i} ===")
    fork_data = sample_fork_completions(prefix, alt_token)
    fork_data["log_prob"] = logp
    fork_results.append(fork_data)

    for j, comp in enumerate(fork_data["completions"]):
        print(f"[{j}] {comp[:150]}...\n")



=== Fork 0 ===
[0] What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. 2 + 3 = 5...

[1] What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. 2 + 3 = 5....

[2] What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. 2 + 3 = 5....

[3] What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. 3. 510 is greater than 304. 215 is greater than 200. 865 is greater than ...

[4] What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. 2 + 3 = 5.
Great job! I also notice that you can solve some basic arithme...

[5] What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. 2 + 3 = 5....

[6] What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. 2 + 3 = 5.
The answer is: 5...

[7] What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. 2 + 3 = 5. The answer is: 5...

[8] What is 2 + 3? Think step by step and enclose your fi

In [13]:
def is_correct_answer(text, answer="5"):
    # You can tweak this to check for boxed versions, variations, etc.
    return answer in text or f"\\boxed{{{answer}}}" in text

for i, fork in enumerate(fork_results):
    completions = fork["completions"]
    correct = [c for c in completions if is_correct_answer(c)]

    print(f"\n=== Fork {i} Summary ===")
    print(f"Prompt: {fork['fork_prompt'][:80]}...")
    print(f"Alt token: {fork['alt_token']}")
    print(f"Log-prob: {fork['log_prob']:.4f}")
    print(f"Correct completions: {len(correct)} / {len(completions)}")
    print("Sample correct:" if correct else "No correct completions.")
    for c in correct[:2]:
        print("  -", c.strip()[:100])



=== Fork 0 Summary ===
Prompt: What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. ...
Alt token:  
Log-prob: -1.4679
Correct completions: 10 / 10
Sample correct:
  - What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. 2 + 3 = 5
  - What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. 2 + 3 = 5.

=== Fork 1 Summary ===
Prompt: What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. Sur...
Alt token:  Sure
Log-prob: -3.3272
Correct completions: 5 / 10
Sample correct:
  - What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. Sure, here's the step-b
  - What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. Sure, I can help you wi

=== Fork 2 Summary ===
Prompt: What is 2 + 3? Think step by step and enclose your final answer in \boxed{}. \...
Alt token:  \
Log-prob: -3.4991
Correct completions: 10 / 10
Sample correct:
  - What is 2 + 3? Think step by

In [15]:
import pandas as pd
csv_forks = []
for fork in fork_results:
    completions = fork["completions"]
    correct = sum(is_correct_answer(c) for c in completions)

    csv_forks.append({
        "prefix": fork["prefix"],
        "alt_token": fork["alt_token"],
        "log_prob": fork["log_prob"],
        "correct_completions": correct,
        "total_completions": len(completions),
        "accuracy": correct / len(completions)
    })

df_forks = pd.DataFrame(csv_forks)
df_forks.to_csv("fork_analysis.csv", index=False)
print("Saved fork_analysis.csv")


Saved fork_analysis.csv


In [3]:
from datasets import load_dataset

# Load the subject split
dataset = load_dataset(
    "tasksource/mmlu",
    name="high_school_government_and_politics",
    split="test"
)

# Print one example
print(dataset[0])


{'question': 'Which of the following best describes the balance the Supreme Court has struck between the establishment clause and the free-exercise clause?', 'choices': ['Freedom of speech is protected except in certain situations, such as yelling "fire" in a crowded theater.', 'Once a church has been recognized by the federal government, its tax-exempt status can never be revoked.', 'Once Congress has created an administrative agency, that agency can be dissolved only by a constitutional amendment.', 'State-sponsored prayer during school hours is prohibited, but voluntary prayer by student groups before school is allowed.'], 'answer': 3}


In [6]:
def format_prompt(example):
    question = example["question"]
    choices = example["choices"]
    labels = ["A", "B", "C", "D"]
    lettered_choices = [f"{label}. {text}" for label, text in zip(labels, choices)]
    joined = "\n".join(lettered_choices)
    prompt = f"{question}\n{joined}\nAnswer:"
    return prompt



In [9]:
prompt = format_prompt(dataset[0])
output1 = generate_completion(prompt, seed=42)
output2 = generate_completion(prompt, seed=123)

print("Completion 1:\n", output1)
print("\nCompletion 2:\n", output2)


Completion 1:
 Which of the following best describes the balance the Supreme Court has struck between the establishment clause and the free-exercise clause?
A. Freedom of speech is protected except in certain situations, such as yelling "fire" in a crowded theater.
B. Once a church has been recognized by the federal government, its tax-exempt status can never be revoked.
C. Once Congress has created an administrative agency, that agency can be dissolved only by a constitutional amendment.
D. State-sponsored prayer during school hours is prohibited, but voluntary prayer by student groups before school is allowed.
Answer: A

Completion 2:
 Which of the following best describes the balance the Supreme Court has struck between the establishment clause and the free-exercise clause?
A. Freedom of speech is protected except in certain situations, such as yelling "fire" in a crowded theater.
B. Once a church has been recognized by the federal government, its tax-exempt status can never be revo

In [11]:

def find_forking_index(output1, output2):
    tokens1 = tokenizer(output1, return_tensors="pt")["input_ids"][0]
    tokens2 = tokenizer(output2, return_tensors="pt")["input_ids"][0]

    min_len = min(len(tokens1), len(tokens2))
    for i in range(min_len):
        if tokens1[i] != tokens2[i]:
            return i
    return -1  # No fork found
results = []

for i in range(10):  # Start small with 10 examples
    ex = dataset[i]
    prompt = format_prompt(ex)

    out1 = generate_completion(prompt, seed=42)
    out2 = generate_completion(prompt, seed=123)

    fork_idx = find_forking_index(out1, out2)

    results.append({
        "question_id": i,
        "prompt": prompt,
        "completion_1": out1,
        "completion_2": out2,
        "forking_index": fork_idx,
        "correct_answer": ex["choices"][ex["answer"]]
    })

    print(f"[{i}] Fork at token:", fork_idx)


[0] Fork at token: 115
[1] Fork at token: 115
[2] Fork at token: 56
[3] Fork at token: -1
[4] Fork at token: 66
[5] Fork at token: 67
[6] Fork at token: 50
[7] Fork at token: 72
[8] Fork at token: 70
[9] Fork at token: -1


In [12]:
import json
import pandas as pd

# Save to JSON 
with open("forking_results.json", "w") as f:
    json.dump(results, f, indent=2)

# Save to CSV
csv_friendly = [
    {
        "question_id": r["question_id"],
        "prompt": r["prompt"],
        "completion_1": r["completion_1"].replace("\n", " "),
        "completion_2": r["completion_2"].replace("\n", " "),
        "forking_index": r["forking_index"],
        "correct_answer": r["correct_answer"]
    }
    for r in results
]

df = pd.DataFrame(csv_friendly)
df.to_csv("forking_results.csv", index=False)
