In [None]:
from heapq import heappop, heappush
from queue import Queue

import pygambit as gbt

In [None]:
class _ReverseOrder:
    """Heap entry that inverts ``<`` so that heapq's min-heap pops the maximum."""

    __slots__ = ("item",)

    def __init__(self, item):
        self.item = item

    def __lt__(self, other):
        # heapq only ever compares entries with ``<``.
        return other.item < self.item


class MaxPriorityQueue(Queue):
    """Variant of ``queue.Queue`` that retrieves entries in priority order (highest first).

    Entries may be anything mutually comparable with ``<``: plain numbers
    (this file stores node indices) or tuples of the form
    ``(priority number, data)``. Each entry is wrapped in a reverse-ordering
    key instead of being negated, so non-numeric entries work too, and the
    exact object that was put is the one returned by ``get``.
    """

    def _init(self, maxsize):
        # Queue.__init__ records maxsize itself; only the heap storage is needed.
        self.queue = []

    def _qsize(self):
        return len(self.queue)

    def _put(self, item):
        heappush(self.queue, _ReverseOrder(item))

    def _get(self):
        return heappop(self.queue).item

In [None]:
class SubgameRootFinder:
    """Assign a root node to every information set of an extensive-form game.

    Call :meth:`find_roots` with a game. It returns a dict mapping each
    ``gbt.Infoset`` to a ``gbt.Node`` and also stores that dict in
    ``self.infoset_to_roots``. Presumably the node is the root of the smallest
    subgame containing the infoset (see the class name) -- TODO confirm.

    Internally, nodes are identified by their position in ``game.nodes`` so
    they can be kept in a ``MaxPriorityQueue``; the highest index is always
    explored first.
    NOTE(review): this presumably relies on ``game.nodes`` listing every node
    after its ancestors (e.g. preorder) -- confirm against pygambit.
    """

    def __init__(self):
        self.game = None
        # Index <-> node lookup tables, filled by _build_mappings().
        self.index_to_node = None
        self.node_to_index = None
        # Final result of find_roots() is stored in this field.
        self.infoset_to_roots = None

    def _build_mappings(self) -> None:
        """Populate ``index_to_node`` and ``node_to_index`` from ``self.game``."""
        self.index_to_node = list(self.game.nodes)
        self.node_to_index = {
            node: index for index, node in enumerate(self.index_to_node)
        }

    def _build_layer(self, nodes: set[int]) -> MaxPriorityQueue:
        """Return a max-priority queue holding the given node indices."""
        result = MaxPriorityQueue()
        for node in nodes:
            result.put(node)
        return result

    def _explore_component(
            self, start_node: gbt.Node, infoset_to_roots: dict[gbt.Infoset, gbt.Node]
        ) -> dict[gbt.Infoset, gbt.Node]:
        """Grow one component from ``start_node`` and record its root.

        The search repeatedly does two things:
        (a) pulls in the other members of every newly seen infoset;
        (b) pulls in the parent of the current node while the frontier is
            still non-empty.
        When it reaches a node whose infoset already has a recorded root,
        that root and every infoset assigned to it are merged into the
        component. When the frontier runs empty, the node being processed
        becomes the root of every infoset gathered.

        Args:
            start_node: node to start from; its infoset is expected not to be
                in ``infoset_to_roots`` yet (the caller skips such nodes).
            infoset_to_roots: current infoset -> root mapping. It may be
                updated in place or replaced, so callers must use the return
                value.

        Returns:
            The updated infoset -> root mapping.

        Raises:
            RuntimeError: if the frontier empties before the component is
                closed (a blocking ``Queue.get`` would otherwise hang forever).
        """
        visited_infosets = set()
        visited_nodes = set()  # every node ever pushed onto the frontier
        frontier = MaxPriorityQueue()

        while True:

            if start_node.infoset not in infoset_to_roots:
                if start_node.infoset not in visited_infosets:
                    visited_infosets.add(start_node.infoset)
                    # All members of one infoset must end up under one root.
                    for member in start_node.infoset.members:
                        if member != start_node:
                            frontier.put(self.node_to_index[member])
                            visited_nodes.add(member)
                # Climb towards the game root only while the component is open.
                if (not frontier.empty()
                    and start_node.parent
                    and start_node.parent not in visited_nodes):
                    frontier.put(self.node_to_index[start_node.parent])
                    visited_nodes.add(start_node.parent)
                if frontier.empty():
                    for infoset in visited_infosets:
                        infoset_to_roots[infoset] = start_node
                    break

            else:
                # start_node lies in a component recorded earlier: absorb it.
                reroot = infoset_to_roots[start_node.infoset]
                # Always merge the absorbed infosets. Skipping this when reroot
                # had already been visited dropped them from the result.
                visited_infosets.update(
                    infoset for infoset, root in infoset_to_roots.items()
                    if root == reroot
                )
                infoset_to_roots = {
                    infoset: root for infoset, root in infoset_to_roots.items()
                    if root != reroot
                }
                if reroot == start_node:
                    # The popped node is itself the absorbed root (it was reached
                    # as a parent or member). Re-process it as an ordinary node so
                    # its parent is pushed or the component closes here, instead
                    # of popping from a possibly empty frontier.
                    continue
                if reroot not in visited_nodes:
                    frontier.put(self.node_to_index[reroot])
                    visited_nodes.add(reroot)

            if frontier.empty():
                raise RuntimeError(
                    "SubgameRootFinder: frontier emptied before the component was closed"
                )
            start_node = self.index_to_node[frontier.get()]

        return infoset_to_roots

    def _find_roots_layer(
            self, layer: MaxPriorityQueue, infoset_to_roots: dict[gbt.Infoset, gbt.Node]
        ) -> dict[gbt.Infoset, gbt.Node]:
        """Explore a component from every node of ``layer`` not yet assigned a root.

        Nodes are taken highest index first. Returns the updated mapping.
        """
        while not layer.empty():
            node = self.index_to_node[layer.get()]
            # Skip nodes whose infoset already has a recorded root.
            if node.infoset in infoset_to_roots:
                continue
            infoset_to_roots = self._explore_component(node, infoset_to_roots)
        return infoset_to_roots

    def find_roots(self, game: gbt.Game) -> dict[gbt.Infoset, gbt.Node]:
        """Return (and store in ``self.infoset_to_roots``) the infoset -> root map.

        The search proceeds bottom-up in layers. The first layer is the
        parents of all terminal nodes. Each later layer is the parents of the
        roots that were newly recorded while processing the previous layer.
        """
        self.game = game
        self._build_mappings()
        infoset_to_roots = {}
        # A terminal node without a parent (a one-node game) contributes nothing.
        exploration_layer = self._build_layer({
            self.node_to_index[node.parent]
            for node in game.nodes
            if node.is_terminal and node.parent
        })

        while not exploration_layer.empty():
            # Snapshot first: _find_roots_layer may mutate the dict in place.
            # A set gives O(1) membership instead of scanning dict.values().
            previous_roots = set(infoset_to_roots.values())
            infoset_to_roots = self._find_roots_layer(exploration_layer, infoset_to_roots)
            exploration_layer = self._build_layer({
                self.node_to_index[root.parent]
                for root in infoset_to_roots.values()
                if root.parent and root not in previous_roots
            })
        self.infoset_to_roots = infoset_to_roots
        return infoset_to_roots

In [None]:
import os


def get_node_path(node: gbt.Node) -> list[int]:
    """
    Return the action indices on the path from ``node`` up to the game root.

    The list is ordered bottom-up (leaf-to-root). Element 0 is the position
    of ``node`` among its parent's children, and each following element
    describes the next step up. The root itself yields ``[]``.
    """
    indices = []
    child = node
    parent = child.parent
    while parent:
        indices.append(list(parent.children).index(child))
        child, parent = parent, parent.parent
    return indices


def run_tests(test_games_dir: str = "../tests/test_games") -> bool:
    """
    Executes a series of predefined tests on the SubgameRootFinder class.

    Each expected map is keyed by ``(player index, infoset number)``. Its
    values are node paths as produced by ``get_node_path`` (leaf-to-root
    action indices).

    Args:
        test_games_dir: directory containing the ``.efg`` test files. The
            default is resolved relative to the current working directory.

    Returns:
        True if every test passed; False if any test failed, errored, or was
        skipped because its file was missing.
    """
    print("Running SubgameRootFinder Prototype Tests...\n")

    # Define the test cases with the correct, leaf-to-root path representation.
    test_cases = {
        "PI.efg": {
            (0, 0): [],       # Player 1, infoset 1 (index 0)
            (0, 1): [0],      # Player 1, infoset 2 (index 1)
            (1, 0): [0, 0],   # Player 2, infoset 1 (index 0)
            (1, 1): [1, 0],   # Player 2, infoset 2 (index 1)
            (1, 2): [1]       # Player 2, infoset 3 (index 2)
        },
        "e01.efg": {
            (0, 0): [],
            (1, 0): [],
            (2, 0): []
        },
        "e02.efg": {
            (0, 0): [],
            (1, 0): [1],
            (0, 1): [1, 1]
        },
        "s-diff.efg": {
            (0, 0): [], (0, 1): [], (0, 2): [], (0, 3): [], (0, 4): [],
            (1, 0): [], (1, 1): [], (1, 2): [], (1, 3): [], (1, 4): []
        },
        "multi-subgame.efg": {
            (0, 0): [],
            (0, 1): [0, 0],
            (0, 2): [0, 0, 0, 0],
            (0, 3): [],
            (1, 0): [],
            (1, 1): [0, 0, 0],
            (1, 2): [0, 0, 0, 0, 0]
        },
        "noPR-AM-driver-one-player.efg": {
            (0, 0): [],
            (0, 1): [],
            (0, 2): [1, 0, 0]
        }
    }

    all_passed = True

    for test_name, expected_map in test_cases.items():
        print(f"--- Running Test: {test_name} ---")
        file_path = os.path.join(test_games_dir, test_name)

        if not os.path.exists(file_path):
            print(f"SKIPPED: Test file not found at '{file_path}'")
            all_passed = False
            continue

        try:
            game = gbt.read_efg(file_path)
            result_map = SubgameRootFinder().find_roots(game)

            # Transform the result into the stable, comparable format.
            players = list(game.players)
            transformed_result = {
                (players.index(infoset.player), infoset.number): get_node_path(node)
                for infoset, node in result_map.items()
            }
        except Exception as e:
            all_passed = False
            print(f"ERROR: An unexpected error occurred during the test: {type(e).__name__}: {e}")
        else:
            # Explicit comparison instead of `assert`, which `python -O` strips
            # (every test would then silently "pass").
            if transformed_result == expected_map:
                print("PASSED")
            else:
                all_passed = False
                print("FAILED: The result map does not match the expected map.")
                print(f"Expected: {expected_map}")
                print(f"Actual:   {transformed_result}")

        print("-" * (len(test_name) + 20) + "\n")

    # Final summary
    if all_passed:
        print("All prototype tests completed successfully!")
    else:
        print("Some prototype tests failed or were skipped.")
    return all_passed

if __name__ == "__main__":
    # Entry point: run the prototype test-suite when executed as the main module.
    run_tests()