## Probability-based tree generation

In [None]:
import os
import torch
import torch.nn.functional as F
import math
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
import json
from graphviz import Digraph

###############################################
# Utility Functions for Model Inference
###############################################


def get_next_token_probs(sequence_ids):
    """
    Compute the next-token probability distribution for a token sequence.

    Args:
        sequence_ids: Input IDs handed directly to the module-level ``model``
            (expected shape: [1, seq_len]).

    Returns:
        Softmax over the logits of the final position (shape: [1, vocab_size]).
    """
    # Inference only: no gradient bookkeeping is needed for the forward pass.
    with torch.no_grad():
        model_out = model(sequence_ids, return_dict=True)

    # Keep only the last time step; earlier positions are irrelevant here.
    last_step_logits = model_out.logits[:, -1, :]
    return F.softmax(last_step_logits, dim=-1)


###############################################
# Candidate Generation (Iterative Version)
###############################################


def generate_segment_candidates(initial_ids, top_k, min_norm_prob, max_segment_length):
    """
    Generates segment candidates iteratively using an explicit stack (depth-first).

    Starting from ``initial_ids``, every partial sequence is extended with the
    ``top_k`` most probable next tokens. A branch becomes a candidate as soon as
    it emits a token in ``SPECIAL_TOKEN_IDS``, or when it reaches the model's
    maximum position count / ``max_segment_length`` added tokens.

    Note on normalization: the "normalized probability" is
    ``exp(sum_of_new_token_log_probs / len(full_sequence))``. The divisor is the
    length of the *whole* sequence (initial prefix included), and the prefix
    tokens contribute a log-probability of 0.

    Args:
        initial_ids (torch.Tensor or list[int]): Starting sequence (tensor shape: [1, seq_len]).
        top_k (int): Number of top tokens to consider at each extension. Values larger
                     than the vocabulary size are clamped to the vocabulary size.
        min_norm_prob (float): Minimum normalized probability; extensions below it are dropped.
        max_segment_length (int): Maximum number of tokens to add to the segment.

    Returns:
        candidates (list): List of tuples (candidate_ids, token_list, normalized_probability)
                           where candidate_ids is a tensor of shape [1, len(token_list)] on the
                           model's device, token_list is the list of token IDs (the full sequence
                           from the start), and normalized_probability is computed as described above.
    """
    candidates = []

    # Work on a plain Python list of token IDs (copied so the caller's data is untouched).
    if isinstance(initial_ids, torch.Tensor):
        tokens = initial_ids.squeeze(0).tolist()
    else:
        tokens = list(initial_ids)

    # Each stack element is a tuple: (tokens_accum, current_log_prob, depth, visited)
    # 'visited' is a frozenset of decoded sequences seen along this path, used to avoid repetition.
    stack = [(tokens, 0.0, 0, frozenset())]

    while stack:
        tokens_accum, current_log_prob, depth, visited = stack.pop()

        # Terminate if we hit model's max positions or our max segment length.
        if len(tokens_accum) >= model.config.n_positions or depth >= max_segment_length:
            avg_log_prob = current_log_prob / (len(tokens_accum) if tokens_accum else 1)
            norm_prob = math.exp(avg_log_prob)
            current_ids = torch.tensor([tokens_accum], device=model.device)
            candidates.append((current_ids, tokens_accum.copy(), norm_prob))
            continue

        # Build input tensor from the current token list and query the model.
        current_ids = torch.tensor([tokens_accum], device=model.device)
        probs = get_next_token_probs(current_ids).squeeze(0)

        # torch.topk raises if k exceeds the dimension size, so clamp to the vocabulary size.
        k = min(top_k, probs.size(-1))
        topk_probs, topk_indices = torch.topk(probs, k)

        for prob, token in zip(topk_probs.tolist(), topk_indices.tolist()):
            if prob <= 0.0:
                # The softmax underflowed to zero: math.log would raise ValueError.
                # topk returns values in descending order, so the rest are zero too.
                break

            new_log_prob = current_log_prob + math.log(prob)
            new_tokens = tokens_accum + [token]
            norm_prob = math.exp(new_log_prob / len(new_tokens))

            if norm_prob < min_norm_prob:
                continue

            # Decode the new sequence only for checking visited.
            new_sequence = tokenizer.decode(new_tokens).strip()
            if new_sequence in visited:
                continue

            if token in SPECIAL_TOKEN_IDS:
                # A special token closes the segment: record it as a finished candidate.
                new_ids = torch.tensor([new_tokens], device=model.device)
                candidates.append((new_ids, new_tokens.copy(), norm_prob))
            else:
                stack.append(
                    (new_tokens, new_log_prob, depth + 1, visited.union({new_sequence}))
                )

    return candidates


###############################################
# Tree Building Functions (Segment-Level)
###############################################


def build_tree_segment(
    sequence_ids,
    max_depth,
    current_depth=0,
    top_k=5,
    min_norm_prob=0.3,
    max_segment_length=20,
):
    """
    Recursively expand ``sequence_ids`` into a tree of generated segments.

    Each child corresponds to one candidate segment (tokens produced up to the
    next special token) appended to this node's sequence.

    Returns:
        A dict with the decoded full sequence under 'sequence' and the child
        nodes under 'children'. Child nodes additionally carry
        'segment_tokens' (only the tokens added on top of the parent) and
        'segment_norm_prob' (the candidate's normalized probability).
    """
    node = {"sequence": tokenizer.decode(sequence_ids[0]), "children": []}
    # Length of the parent's token list; used to slice off the shared prefix.
    prefix_len = len(sequence_ids[0].tolist())

    # Leaf: depth budget exhausted or the sequence already contains EOS.
    if current_depth >= max_depth or tokenizer.eos_token_id in sequence_ids[0]:
        return node

    # Keep the top_k candidates, ranked by normalized probability (highest first).
    ranked = sorted(
        generate_segment_candidates(
            sequence_ids,
            top_k=top_k,
            min_norm_prob=min_norm_prob,
            max_segment_length=max_segment_length,
        ),
        key=lambda cand: cand[2],
        reverse=True,
    )

    for child_ids, child_tokens, child_prob in ranked[:top_k]:
        subtree = build_tree_segment(
            child_ids,
            max_depth,
            current_depth + 1,
            top_k,
            min_norm_prob,
            max_segment_length,
        )
        # Store only the newly generated part by dropping the parent's prefix.
        subtree["segment_tokens"] = child_tokens[prefix_len:]
        subtree["segment_norm_prob"] = child_prob
        node["children"].append(subtree)

    return node


def generate_hierarchical_tree_segment(
    max_depth=5, top_k=5, min_norm_prob=0.3, max_segment_length=20
):
    """
    Build the segment-level tree rooted at the tokenizer's BOS token.

    The BOS token is encoded to a tensor, moved to the module-level ``device``,
    and handed to ``build_tree_segment`` together with the given parameters.
    """
    bos_ids = tokenizer.encode(tokenizer.bos_token, return_tensors="pt")
    return build_tree_segment(
        bos_ids.to(device),
        max_depth=max_depth,
        top_k=top_k,
        min_norm_prob=min_norm_prob,
        max_segment_length=max_segment_length,
    )


###############################################
# Visualization and Pruning Functions
###############################################


def visualize_tree_segment(
    tree, graph=None, parent_id=None, counter=None, include_special_tokens=True
):
    """
    Recursively adds nodes and edges from the tree to a Graphviz Digraph.
    Each node displays only the segment tokens generated at that node (i.e. the tokens
    added between special tokens), not the full sequence.

    Args:
      tree (dict): The tree dictionary.
      graph (Digraph): The Graphviz graph instance. A new one is created when omitted.
      parent_id (str): ID of the parent node; None for the root.
      counter (list): A single-element mutable counter used to generate unique node IDs.
                      A fresh ``[0]`` is created when omitted, so every top-level call
                      numbers its nodes from ``node_0``. Pass one shared counter
                      explicitly when drawing several trees into the same graph.
      include_special_tokens (bool): If True, special tokens (including <BOS>, <EOS>, <DOWNWARD>, etc.)
                                     are shown; if False, they are filtered out.

    Returns:
      graph (Digraph): The updated Graphviz Digraph.
    """
    if graph is None:
        graph = Digraph()
    # A mutable default ([0]) would be shared by every call of this function,
    # so the counter is created here instead.
    if counter is None:
        counter = [0]

    node_id = f"node_{counter[0]}"
    counter[0] += 1

    # For non-root nodes, use the stored 'segment_tokens' (which is just the newly generated segment).
    if "segment_tokens" in tree and tree["segment_tokens"]:
        tokens = tree["segment_tokens"]
        if not include_special_tokens:
            tokens = [tok for tok in tokens if tok not in SPECIAL_TOKEN_IDS]
        label = tokenizer.decode(tokens).strip()
        if not label:
            label = "<empty>"
    else:
        # For the root node (or a node without segment tokens), display the BOS token.
        label = tokenizer.bos_token

    # Append the normalized probability when the node carries one.
    if "segment_norm_prob" in tree:
        label += f"\n(prob: {tree['segment_norm_prob']:.2f})"

    graph.node(node_id, label)

    if parent_id is not None:
        graph.edge(parent_id, node_id)

    # The same counter object is threaded through the recursion to keep IDs unique.
    for child in tree.get("children", []):
        visualize_tree_segment(
            child,
            graph,
            node_id,
            counter,
            include_special_tokens=include_special_tokens,
        )

    return graph


def prune_tree_branch(node, seen_words=None):
    """
    Recursively prunes the tree by removing branches where the "word"
    (decoded from the segment tokens after filtering special tokens) is duplicated along the branch.

    This enforces uniqueness within each branch only: the set of seen words is
    copied per path, so siblings do not affect each other.

    Args:
        node (dict): Tree node; its 'children' list is replaced in place.
        seen_words (set): Words already encountered on the path from the root.
                          Leave as None for the top-level call.

    Returns:
        The (mutated) node, or None when the node's word was already seen on its branch.
    """
    if seen_words is None:
        # Announce once at the top-level call rather than once per visited node.
        print("Pruning branch-level duplicates...")
        seen_words = set()

    if "segment_tokens" in node and node["segment_tokens"]:
        word_tokens = [
            tok for tok in node["segment_tokens"] if tok not in SPECIAL_TOKEN_IDS
        ]
        word = tokenizer.decode(word_tokens).strip()
    else:
        # Nodes without segment tokens (e.g. the root) are keyed by the BOS token.
        word = tokenizer.bos_token

    if word in seen_words:
        return None

    # A new set per path keeps the uniqueness check local to this branch.
    new_seen = seen_words.union({word})
    pruned_children = []
    for child in node.get("children", []):
        pruned_child = prune_tree_branch(child, new_seen)
        if pruned_child is not None:
            pruned_children.append(pruned_child)
    node["children"] = pruned_children
    return node


def prune_tree_global(node, global_seen=None):
    """
    Recursively prunes the tree so that each unique "word" (decoded from the segment tokens
    after filtering special tokens) appears only once in the entire tree.

    This enforces global uniqueness: a single set is shared by the whole
    depth-first traversal, so the first occurrence encountered is the one kept.

    Args:
        node (dict): Tree node; its 'children' list is replaced in place.
        global_seen (set): Words already encountered anywhere in the traversal; it is
                           mutated. Leave as None for the top-level call.

    Returns:
        The (mutated) node, or None when the node's word was already seen.
    """
    if global_seen is None:
        # Announce once at the top-level call rather than once per visited node.
        print("Pruning global duplicates...")
        global_seen = set()

    if "segment_tokens" in node and node["segment_tokens"]:
        word_tokens = [
            tok for tok in node["segment_tokens"] if tok not in SPECIAL_TOKEN_IDS
        ]
        word = tokenizer.decode(word_tokens).strip()
    else:
        # Nodes without segment tokens (e.g. the root) are keyed by the BOS token.
        word = tokenizer.bos_token

    if word in global_seen:
        return None

    # Shared set: the word is now taken for every remaining node in the tree.
    global_seen.add(word)

    pruned_children = []
    for child in node.get("children", []):
        pruned_child = prune_tree_global(child, global_seen)
        if pruned_child is not None:
            pruned_children.append(pruned_child)
    node["children"] = pruned_children
    return node


###############################################
# Configuration, Model Loading, and Run
###############################################

# Run configuration: these values select the checkpoint directory below and
# are embedded in the output figure's filename.
num_classes = 10
checkpoint_number = 1060
model_size = "small"
sample_first_batch = True
sampling_mode = "class_aware"

# --- Load Tokenizer ---
tokenizer = GPT2Tokenizer.from_pretrained("./custom_tokenizer")
print("Tokenizer loaded from checkpoint.")

# --- Load Model ---
checkpoint_dir = f"model_output_{num_classes}/model_size_{model_size}/sample_first_batch_{sample_first_batch}/sampling_mode_{sampling_mode}/checkpoint-{checkpoint_number}"
# The path is relative to the current working directory. When it does not
# exist, from_pretrained fails with a confusing "Incorrect path_or_model_id"
# OSError, so fail early with an explicit message instead.
if not os.path.isdir(checkpoint_dir):
    raise FileNotFoundError(
        f"Checkpoint directory not found: {os.path.abspath(checkpoint_dir)!r}. "
        "Check the configuration values above and the current working directory."
    )
config = GPT2Config.from_pretrained(checkpoint_dir)
model = GPT2LMHeadModel.from_pretrained(checkpoint_dir, config=config)

# Resize model embeddings to account for any added special tokens.
model.resize_token_embeddings(len(tokenizer))

# Set the device: use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
print(f"Model loaded from {checkpoint_dir}")

# --- Special Token IDs ---
# Tokens that terminate a segment during generation and are filtered out of
# labels/words. IDs that resolve to None (token not configured) are dropped.
SPECIAL_TOKEN_IDS = {
    token_id
    for token_id in (
        tokenizer.convert_tokens_to_ids("<DOWNWARD>"),
        tokenizer.eos_token_id,
        tokenizer.pad_token_id,
        tokenizer.bos_token_id,
    )
    if token_id is not None
}

# --- Hierarchical Generation Parameters ---
max_depth = 7
top_k = 2
min_norm_prob = 0.5

# --- Run the Hierarchical Generation (Segment-Level) ---
tree = generate_hierarchical_tree_segment(
    max_depth=max_depth,
    top_k=top_k,
    min_norm_prob=min_norm_prob,
    max_segment_length=20,
)

# --- Optional Pruning ---
# Toggle these flags to apply branch-level and/or global pruning.
apply_branch_pruning = False  # Set to True to apply branch-level pruning.
apply_global_pruning = False  # Set to True to apply global pruning.

if apply_branch_pruning:
    tree = prune_tree_branch(tree)
if apply_global_pruning:
    tree = prune_tree_global(tree)

# Create a tag to include in the filename indicating which pruning was applied.
pruning_tag = f"branchpruned_{apply_branch_pruning}_globalpruned_{apply_global_pruning}"

# --- Visualize and Render the Tree ---
# Set include_special_tokens to True to show special tokens (like <EOS>) in node labels.
graph = visualize_tree_segment(tree, include_special_tokens=False)
output_filename = (
    f"./tree_figures/num_classes_{num_classes}_sample_first_batch_{sample_first_batch}_"
    f"sampling_mode_{sampling_mode}_checkpoint_number_{checkpoint_number}_max_depth_{max_depth}_"
    f"top_k_{top_k}_min_norm_prob_{min_norm_prob}_{pruning_tag}"
)
graph.render(output_filename, view=True)

  from .autonotebook import tqdm as notebook_tqdm


Tokenizer loaded from checkpoint.


OSError: Incorrect path_or_model_id: 'model_output_10/model_size_small/sample_first_batch_True/sampling_mode_class_aware/checkpoint-1060'. Please provide either the path to a local folder or the repo_id of a model on the Hub.