In [5]:
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from typing import List

@dataclass
class TreeNode:
    """A node of a weighted n-ary tree.

    ``children`` defaults to a fresh empty list for every instance (via
    ``default_factory``), so separately created nodes never share a child
    list. Letting the dataclass generate ``__init__`` keeps the declared
    fields and the constructor in sync.
    """

    # This node's own weight; the *_flip_* traversals negate it in place.
    weight: float
    # Direct children, kept in insertion order.
    children: List["TreeNode"] = field(default_factory=list)

    def add_child(self, child: "TreeNode") -> None:
        """Append ``child`` as the last child of this node."""
        self.children.append(child)

# Tree depth fixed by the assignment. The root sits at depth 0, so the
# make_root_* builders produce trees with N + 1 = 4 levels (depths 0..3).
N: int = 3


def generate_tree(
    n_children: int,
    max_depth: int,
    root_weight: float,
) -> TreeNode:
    """
    Generate a full n-ary tree of a given depth.

    Depth 0 means "just the root", so a tree of depth ``max_depth`` has
    ``max_depth + 1`` levels. Each child has weight
    ``parent.weight / n_children``; for ``n_children >= 1`` every level
    therefore sums to ``root_weight`` (up to floating-point rounding).

    Args:
        n_children: Number of children given to every non-leaf node.
            Values <= 0 yield a tree consisting of the root only.
        max_depth: Depth of the deepest level (root is depth 0).
            Values <= 0 yield a tree consisting of the root only.
        root_weight: Weight assigned to the root node.

    Returns:
        The root ``TreeNode`` of the generated tree.
    """

    def build_node(weight: float, depth: int) -> TreeNode:
        node = TreeNode(weight)
        # Only divide when children will actually be created: with
        # n_children == 0 the division would raise ZeroDivisionError even
        # though no child would ever use the result.
        if depth < max_depth and n_children > 0:
            child_weight = weight / n_children
            for _ in range(n_children):
                node.add_child(build_node(child_weight, depth + 1))
        return node

    return build_node(root_weight, depth=0)


def make_root_literal(n_children: int) -> TreeNode:
    """
    Build the tree exactly as the problem statement describes it.

    The tree has depth N (= 3) and the root weighs 1/n_children. Each level
    then sums to 1/n_children, and with N + 1 levels the whole tree sums to
    (N + 1)/n_children = 4/n_children.
    """
    literal_root = 1.0 / n_children  # value taken verbatim from the problem text
    return generate_tree(
        n_children,
        max_depth=N,
        root_weight=literal_root,
    )


def make_root_unit_sum(n_children: int) -> TreeNode:
    """
    Build a depth-N tree whose weights add up to 1 for any positive n_children.

    Each level d holds n_children**d nodes of weight root/n_children**d, so
    every level sums to the root weight. Picking root = 1/(N + 1) over the
    N + 1 levels (depths 0..N) makes the grand total 1, up to rounding.
    """
    level_count = N + 1  # depths 0..N inclusive
    return generate_tree(n_children, max_depth=N, root_weight=1.0 / level_count)

def dfs_sum(node: TreeNode) -> float:
    """Return the total weight of the subtree rooted at ``node`` (recursive DFS).

    The node's own weight is taken first, then each child's subtree total is
    added in child order.
    """
    running = node.weight
    for subtree_total in map(dfs_sum, node.children):
        running += subtree_total
    return running

def bfs_sum(root: TreeNode) -> float:
    """Return the total weight of the tree under ``root`` via iterative BFS.

    Nodes are taken from a FIFO queue, so the tree is visited level by level
    and every node's weight is added exactly once.
    """
    pending: deque[TreeNode] = deque((root,))
    acc = 0.0
    while pending:
        current = pending.popleft()
        acc += current.weight
        pending.extend(current.children)
    return acc

def dfs_flip_sum(node: TreeNode) -> float:
    """Sum the subtree under ``node`` by DFS, negating every weight on the way.

    Each node contributes its weight as it was *before* this call, and is then
    negated in place. Calling the function twice on the same tree therefore
    returns S and then -S.
    """
    subtotal = node.weight   # capture the pre-flip value
    node.weight *= -1        # flip for the next traversal
    for child_total in map(dfs_flip_sum, node.children):
        subtotal += child_total
    return subtotal


def bfs_flip_sum(root: TreeNode) -> float:
    """Sum the tree under ``root`` by iterative BFS, negating weights as it goes.

    Each node contributes its pre-call weight and is then negated in place,
    so a second call on the same tree returns the negated total.
    """
    pending: deque[TreeNode] = deque((root,))
    acc = 0.0
    while pending:
        current = pending.popleft()
        acc += current.weight      # count the value before flipping
        current.weight *= -1
        pending.extend(current.children)
    return acc

def bfs_sum_recursive(level_nodes: List[TreeNode]) -> float:
    """Return the total weight of ``level_nodes`` and everything below them.

    Recursive BFS: one call per tree level. Each call sums the current
    frontier, gathers all of its children into the next frontier, and
    recurses until the frontier is empty.
    """
    if not level_nodes:  # empty frontier: nothing left to visit
        return 0.0

    level_total = 0.0
    children_below: List[TreeNode] = []
    for member in level_nodes:
        level_total += member.weight
        children_below += member.children
    return level_total + bfs_sum_recursive(children_below)


def bfs_flip_sum_recursive(level_nodes: List[TreeNode]) -> float:
    """Recursive, level-by-level BFS sum that also negates every weight.

    Each node contributes its weight as it was before this call and is then
    negated in place. Repeating the call on the same frontier yields the
    negated total.
    """
    if not level_nodes:  # empty frontier: traversal finished
        return 0.0

    level_total = 0.0
    children_below: List[TreeNode] = []
    for member in level_nodes:
        level_total += member.weight   # pre-flip value
        member.weight *= -1
        children_below += member.children
    return level_total + bfs_flip_sum_recursive(children_below)

def main() -> None:
    """Exercise every traversal on sample trees and print the results."""
    # 1) Literal statement: root weight 1/n, so the total should be (N+1)/n.
    print("=== 1) DFS and BFS sums with literal root weight = 1/n (N = 3) ===")
    for n in (2, 3, 4, 5):
        tree = make_root_literal(n)
        dfs_total = dfs_sum(tree)
        bfs_total = bfs_sum(tree)
        expected = (N + 1) / n
        print(
            f"n={n}: dfs_sum={dfs_total:.6f}, bfs_sum={bfs_total:.6f}, "
            f"expected=(N+1)/n={expected:.6f}"
        )

    # 2) Root weight chosen so the whole tree sums to 1.
    print("\n=== 2) DFS and BFS sums where total is forced to 1 for various n ===")
    for n in (2, 3, 5, 10):
        tree = make_root_unit_sum(n)
        dfs_total = dfs_sum(tree)
        bfs_total = bfs_sum(tree)
        print(f"n={n}: dfs_sum={dfs_total:.6f}, bfs_sum={bfs_total:.6f}  (should both be 1.0)")

    # 3) Flipping traversals: same tree twice -> +total, then -total.
    print("\n=== 3) Flip sign with DFS/BFS: first run 1, second run -1 (fixed n) ===")
    fixed_n = 3
    for label, flip in (("DFS", dfs_flip_sum), ("BFS", bfs_flip_sum)):
        # Fresh tree per traversal so each one starts from positive weights.
        tree = make_root_unit_sum(fixed_n)
        first = flip(tree)
        second = flip(tree)
        print(f"{label} flip, n={fixed_n}: first={first:.6f}, second={second:.6f}")

    # 4) The recursive BFS must agree with the iterative one.
    print("\n=== 4) Recursive vs iterative BFS (sanity check) ===")
    tree = make_root_unit_sum(4)
    print(f"Iterative BFS sum: {bfs_sum(tree):.6f}")
    print(f"Recursive BFS sum: {bfs_sum_recursive([tree]):.6f}")

    # 5) Recursive BFS with flipping.
    print("\n=== 5) Recursive BFS flip check ===")
    tree = make_root_unit_sum(4)
    first = bfs_flip_sum_recursive([tree])
    second = bfs_flip_sum_recursive([tree])
    print(f"Recursive BFS flip: first={first:.6f}, second={second:.6f}")


if __name__ == "__main__":
    main()


=== 1) DFS and BFS sums with literal root weight = 1/n (N = 3) ===
n=2: dfs_sum=2.000000, bfs_sum=2.000000, expected=(N+1)/n=2.000000
n=3: dfs_sum=1.333333, bfs_sum=1.333333, expected=(N+1)/n=1.333333
n=4: dfs_sum=1.000000, bfs_sum=1.000000, expected=(N+1)/n=1.000000
n=5: dfs_sum=0.800000, bfs_sum=0.800000, expected=(N+1)/n=0.800000

=== 2) DFS and BFS sums where total is forced to 1 for various n ===
n=2: dfs_sum=1.000000, bfs_sum=1.000000  (should both be 1.0)
n=3: dfs_sum=1.000000, bfs_sum=1.000000  (should both be 1.0)
n=5: dfs_sum=1.000000, bfs_sum=1.000000  (should both be 1.0)
n=10: dfs_sum=1.000000, bfs_sum=1.000000  (should both be 1.0)

=== 3) Flip sign with DFS/BFS: first run 1, second run -1 (fixed n) ===
DFS flip, n=3: first=1.000000, second=-1.000000
BFS flip, n=3: first=1.000000, second=-1.000000

=== 4) Recursive vs iterative BFS (sanity check) ===
Iterative BFS sum: 1.000000
Recursive BFS sum: 1.000000

=== 5) Recursive BFS flip check ===
Recursive BFS flip: first=1.00