In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn.functional as F


model_name = "Qwen/Qwen1.5-1.8B"

# NOTE: `trust_remote_code` is intentionally left at its default (False).
# Qwen1.5 checkpoints are served by the Qwen2 tokenizer/model classes that ship
# with transformers, so no repository-provided Python needs to run; enabling the
# flag would only opt in to executing code downloaded from the Hub.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)

print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",          # automatic device placement of the weights
    torch_dtype=torch.float16,  # load the weights in half precision
)

# Put the module in evaluation mode; the notebook only runs generation.
model.eval()
print("Qwen 1.5-1.8B ready")

In [None]:
def generate_completion(prompt, max_new_tokens=20, temperature=1.0, seed=None):
    """Generate one continuation of ``prompt`` with the notebook's global model.

    Relies on the module-level ``tokenizer`` and ``model`` objects.

    Args:
        prompt: Text the model is conditioned on.
        max_new_tokens: Forwarded to ``model.generate`` as the cap on newly
            generated tokens.
        temperature: Sampling temperature. A positive value samples with
            ``top_k=50`` and ``top_p=0.95``. ``0`` selects greedy decoding
            (``do_sample=False``) instead of failing inside ``generate``.
            ``None`` is forwarded unchanged to the sampling call.
        seed: When not ``None``, passed to ``torch.manual_seed`` right before
            generation so a sampled continuation can be reproduced. Note that
            this reseeds torch's global RNG.

    Returns:
        The first sequence returned by ``model.generate``, decoded with special
        tokens skipped. The whole returned sequence is decoded, so for a
        decoder-only model the text is expected to start with the prompt.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature is not None and temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    if seed is not None:
        torch.manual_seed(seed)

    # The tokenizer returns a full encoding (not just input ids); move it to
    # the device the model reports so `generate` receives matching tensors.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature == 0:
        # Temperature 0 is the deterministic limit of sampling: take the
        # arg-max token at every step and skip the sampling-only knobs.
        generation_kwargs["do_sample"] = False
    else:
        generation_kwargs.update(
            temperature=temperature,
            do_sample=True,  # we want different paths
            top_k=50,
            top_p=0.95,
        )

    output_ids = model.generate(**encoded, **generation_kwargs)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

In [None]:
def sample_fork_completions(prefix, alt_token, num_samples=10, max_new_tokens=100,
                            temperature=1.0, seed=None):
    """Sample several completions after forcing ``alt_token`` onto ``prefix``.

    The fork prompt is the plain string concatenation ``prefix + alt_token``;
    every sample is produced by ``generate_completion`` on that prompt.

    Args:
        prefix: Text preceding the fork point.
        alt_token: Text appended to ``prefix`` to form the fork prompt.
        num_samples: Number of completions to generate.
        max_new_tokens: Forwarded to ``generate_completion`` for each sample.
        temperature: Forwarded to ``generate_completion`` for each sample
            (default ``1.0``).
        seed: Optional base seed. When given, sample ``i`` is generated with
            seed ``seed + i`` so the whole batch is reproducible while the
            individual samples still differ. ``None`` (default) leaves every
            sample unseeded.

    Returns:
        A dict with keys ``"fork_prompt"``, ``"prefix"``, ``"alt_token"`` and
        ``"completions"`` (list of ``num_samples`` generated strings, in
        generation order).
    """
    fork_prompt = prefix + alt_token

    completions = [
        generate_completion(
            fork_prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            # Distinct per-sample seeds: reusing one seed would yield
            # identical samples.
            seed=None if seed is None else seed + sample_idx,
        )
        for sample_idx in range(num_samples)
    ]

    return {
        "fork_prompt": fork_prompt,
        "prefix": prefix,
        "alt_token": alt_token,
        "completions": completions
    }

In [None]:
from datasets import load_dataset

# Fetch the test split of one MMLU subject (the subject is the dataset config name).
dataset = load_dataset(
    "tasksource/mmlu", "high_school_government_and_politics", split="test"
)

# Show the first record to inspect its fields.
print(dataset[0])

In [None]:
def format_prompt(example):
    """Render a multiple-choice example as a lettered prompt ending in "Answer:".

    Args:
        example: Mapping with a ``"question"`` string and a ``"choices"``
            iterable of option texts.

    Returns:
        The question, then one ``"<letter>. <text>"`` line per choice (letters
        start at ``A``), then ``"Answer:"``, joined by newlines.

    Raises:
        ValueError: If there are more than 26 choices, since labels are single
            letters ``A``-``Z``.
    """
    question = example["question"]
    choices = list(example["choices"])
    if len(choices) > 26:
        raise ValueError(
            f"format_prompt supports at most 26 choices, got {len(choices)}"
        )

    # Generate as many labels as there are choices so none is silently dropped.
    labels = [chr(ord("A") + idx) for idx in range(len(choices))]
    lettered_choices = [f"{label}. {text}" for label, text in zip(labels, choices)]
    joined = "\n".join(lettered_choices)
    return f"{question}\n{joined}\nAnswer:"

In [None]:
# Build the prompt for the first example, then sample it twice with different
# seeds to compare the resulting continuations.
prompt = format_prompt(dataset[0])
output1, output2 = (generate_completion(prompt, seed=s) for s in (42, 123))

print("Completion 1:\n", output1)
print("\nCompletion 2:\n", output2)