In [1]:
import networkx as nx
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider, Output
import random

def calculate_height(tree, root=0):
    """
    Return the height of ``tree`` measured from ``root``.

    The height is the largest shortest-path distance (in edges) from ``root``
    to any node of the graph, with edge direction ignored. A graph with no
    nodes has height 0.
    """
    if len(tree) == 0:
        return 0
    # Ignore edge direction so the distance is well-defined however edges are oriented.
    distances = nx.shortest_path_length(tree.to_undirected(), source=root)
    return max(distances.values())

def generate_balanced_tree(num_leaves):
    """
    Generate a 'balanced' binary tree with exactly ``num_leaves`` leaves.

    Starting from a single root leaf (node 0), leaves are expanded in
    first-in/first-out order. Each expansion gives one leaf two new children,
    a net gain of one leaf. After ``num_leaves - 1`` expansions the tree has
    exactly ``num_leaves`` leaves and ``2 * num_leaves - 1`` nodes, labelled
    ``0 .. 2 * num_leaves - 2``. Edges point parent -> child.

    Parameters
    ----------
    num_leaves : int
        Desired number of leaves. Values below 1 yield an empty ``nx.DiGraph``;
        1 yields a single-node tree.

    Returns
    -------
    nx.DiGraph
        The generated tree, rooted at node 0.
    """
    from collections import deque

    tree = nx.DiGraph()
    if num_leaves < 1:
        return tree

    root = 0
    tree.add_node(root)
    # `frontier` always holds exactly the current leaves, oldest first.
    # Expanding from the front (BFS order) exhausts one depth before the next
    # one starts, which gives the most balanced shape. A deque makes popleft O(1).
    frontier = deque([root])
    next_label = 1

    while len(frontier) < num_leaves:
        parent = frontier.popleft()
        left_child, right_child = next_label, next_label + 1
        tree.add_node(left_child)
        tree.add_node(right_child)
        tree.add_edge(parent, left_child)
        tree.add_edge(parent, right_child)
        # One leaf became internal and two new leaves appeared: net +1 leaf.
        frontier.append(left_child)
        frontier.append(right_child)
        next_label += 2

    return tree

def generate_unbalanced_tree_random(num_leaves, rng=None):
    """
    Generate an 'unbalanced' binary tree with exactly ``num_leaves`` leaves
    by expanding a uniformly random leaf at each step.

    Starting from a single root leaf (node 0), each expansion replaces one
    leaf with two new children, a net gain of one leaf. After
    ``num_leaves - 1`` expansions the tree has exactly ``num_leaves`` leaves.
    Edges point parent -> child, and nodes are labelled in creation order.

    Parameters
    ----------
    num_leaves : int
        Desired number of leaves. Values below 1 yield an empty ``nx.DiGraph``;
        1 yields a single-node tree.
    rng : random.Random, optional
        Source of randomness, which must provide ``randrange``. Defaults to the
        module-level ``random`` functions, so ``random.seed(...)`` still
        controls the result. Pass a seeded ``random.Random`` to get
        reproducible trees independent of global state.

    Returns
    -------
    nx.DiGraph
        The generated tree, rooted at node 0.
    """
    if rng is None:
        rng = random

    tree = nx.DiGraph()
    if num_leaves < 1:
        return tree

    root = 0
    tree.add_node(root)
    leaves = [root]  # always exactly the current leaves
    next_label = 1

    while len(leaves) < num_leaves:
        # Pop a uniformly random leaf by index. randrange(n) draws from the same
        # stream as random.choice, and pop(i) removes the same element as
        # remove(value) because labels are unique. Seeded output therefore
        # matches the earlier choice()+remove() implementation.
        parent = leaves.pop(rng.randrange(len(leaves)))

        left_child, right_child = next_label, next_label + 1
        tree.add_node(left_child)
        tree.add_node(right_child)
        tree.add_edge(parent, left_child)
        tree.add_edge(parent, right_child)

        # One leaf became internal and two new leaves appeared: net +1 leaf.
        leaves.append(left_child)
        leaves.append(right_child)
        next_label += 2

    return tree

def get_rooted_positions(tree, root=0):
    """
    Compute top-to-bottom drawing positions for a rooted tree.

    Each node's layer is its shortest-path distance from ``root``, with edge
    direction ignored. ``nx.multipartite_layout`` places the layers, and the
    (x, y) coordinates are then swapped so the tree is drawn top-to-bottom
    with the root at the top.

    The caller's graph is NOT modified. The layer attribute is written onto
    the independent copy returned by ``to_undirected()``.

    Parameters
    ----------
    tree : nx.Graph or nx.DiGraph
        The tree to lay out.
    root : node, default 0
        Node whose distance defines the layers.

    Returns
    -------
    dict
        Mapping ``node -> (x, y)``. Returns an empty dict for a graph with
        no nodes.
    """
    if not tree.nodes:
        return {}

    # to_undirected() (as_view=False) returns a copy with copied node data,
    # so tagging it leaves `tree` unchanged. It also keeps node order, which
    # is all the layout depends on besides the layer attribute.
    layered = tree.to_undirected()
    dist_from_root = nx.shortest_path_length(layered, source=root)
    nx.set_node_attributes(layered, dist_from_root, name="subset")

    pos = nx.multipartite_layout(layered, subset_key="subset", align="vertical", scale=2.0)

    # Swap x,y to have root at top and leaves at bottom
    return {node: (y, -x) for node, (x, y) in pos.items()}

def draw_tree(tree, ax, title, root=0):
    """
    Render ``tree`` onto the Matplotlib axes ``ax``.

    Node positions come from ``get_rooted_positions`` and the height from
    ``calculate_height``, both measured from ``root``. The axes title is
    ``title`` with a second line reporting that height, and the axis
    decorations are hidden.
    """
    tree_height = calculate_height(tree, root=root)
    layout = get_rooted_positions(tree, root=root)

    drawing_options = dict(
        with_labels=True,
        node_size=600,
        node_color="lightblue",
        arrows=False,
        ax=ax,
    )
    nx.draw(tree, layout, **drawing_options)

    ax.set_title(f"{title}\nHeight: {tree_height}", fontsize=12)
    ax.axis("off")

def display_interactive_tree():
    """
    Show a slider-controlled comparison of a balanced and a random binary tree.

    The "# Leaves" slider ranges from 1 to 20 and starts at 4. Each time its
    value changes, both trees are regenerated with that many leaves and drawn
    in one figure: the balanced tree in the top panel and a new random
    unbalanced tree in the bottom panel. The figure is rendered into a
    dedicated ``Output`` widget that is displayed after the slider.
    """
    plot_area = Output()

    def render(num_leaves):
        # The parameter must stay named `num_leaves`: interact() binds the slider by keyword.
        with plot_area:
            # wait=True defers clearing until new output arrives, avoiding flicker.
            plot_area.clear_output(wait=True)

            balanced = generate_balanced_tree(num_leaves)
            unbalanced = generate_unbalanced_tree_random(num_leaves)

            fig, (top_ax, bottom_ax) = plt.subplots(2, 1, figsize=(10, 14))
            draw_tree(balanced, top_ax, f"Balanced (Leaves={num_leaves})")
            draw_tree(unbalanced, bottom_ax, f"Random Unbalanced (Leaves={num_leaves})")

            plt.tight_layout()
            plt.show()

    leaves_slider = IntSlider(min=1, max=20, step=1, value=4, description="# Leaves")
    interact(render, num_leaves=leaves_slider)
    display(plot_area)

# Build and display the interactive slider + tree-comparison widget when this cell runs.
display_interactive_tree()


interactive(children=(IntSlider(value=4, description='# Leaves', max=20, min=1), Output()), _dom_classes=('wid…

Output()