In [2]:
import tqdm
import torch
import math
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList

In [3]:
# Custom stopping criteria: stop when the last token is a period.
class StopOnToken(StoppingCriteria):
    def __init__(self, token_id):
        self.token_id = token_id

    def __call__(self, input_ids, scores, **kwargs):
        # Stop if the last token equals the target token.
        if input_ids[0, -1].item() == self.token_id:
            return True
        return False

In [4]:
# Load the model and tokenizer (e.g., GPT-2).
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [5]:

# Define the stop token (a period).
stop_token_id = tokenizer.convert_tokens_to_ids(".")

In [6]:
def generate_sentence(prompt, max_new_tokens=50):
    """Generate a sentence from a prompt and compute average entropy over generated tokens."""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    stopping_criteria = StoppingCriteriaList([StopOnToken(stop_token_id)])
    
    # Use sampling to get diverse outputs.
    outputs = model.generate(
        input_ids,
        max_length=input_ids.shape[1] + max_new_tokens,
        stopping_criteria=stopping_criteria,
        output_scores=True,
        return_dict_in_generate=True,
        do_sample=True,
    )
    
    generated_ids = outputs.sequences[0]
    scores = outputs.scores  # List of tensors (one per generated token)
    
    # Calculate entropy at each generation step.
    entropies = []
    epsilon = 1e-10  # Avoid log(0)
    for step_logits in scores:
        probs = torch.softmax(step_logits, dim=-1)
        entropy = -(probs * torch.log(probs + epsilon)).sum()
        entropies.append(entropy.item())
    
    avg_entropy = sum(entropies) / len(entropies) if entropies else 0
    sentence = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return sentence, avg_entropy

In [7]:
# A simple tree node to hold each branch
class Branch:
    def __init__(self, text, avg_entropy, children=None):
        self.text = text
        self.avg_entropy = avg_entropy
        self.children = children if children is not None else []

    def __repr__(self):
        return f"Branch(text={self.text!r}, avg_entropy={self.avg_entropy:.2f}, children={self.children})"

In [8]:
def build_tree(prompt, entropy_threshold, num_branches=3, depth=0, max_depth=3):
    """
    Generate a sentence branch and, if the average antropy is above the threshold,
    create multiple continuation branches recursively. 
    """
    sentence, avg_entropy = generate_sentence(prompt)
    branch = Branch(text=sentence, avg_entropy=avg_entropy)

    # Continue branching if the sentence is 'uncertain' (entropy above threshold)
    # and we haven't reached the maximum tree depth.
    if avg_entropy > entropy_threshold and depth < max_depth:
        for _ in range(num_branches):
            child_branch = build_tree(sentence, entropy_threshold, num_branches, depth+1, max_depth)
            branch.children.append(child_branch)
    return branch

In [10]:
init_prompt = "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
entropy_threshold = 5.0
tree = build_tree(init_prompt, entropy_threshold)
print(tree)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Branch(text='Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Only about a dozen.', avg_entropy=2.78, children=[])
