In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from collections import deque

In [165]:
class ChopsticksGame:
    def __init__(self, state=((1, 1), (1, 1)), player=0):
        self.player = player
        self.state = self.rectify_state(state)

    @staticmethod
    def compare_player_states(state1, state2):
        return frozenset(state1) == frozenset(state2)

    def compare_states(self, state1, state2):
        return (
            self.compare_player_states(state1[0], state2[0])
            and self.compare_player_states(state1[1], state2[1])
        )

    @staticmethod
    def rectify_state(state):
        zero = lambda value: value if not value >= 5 else 0
        return (
            (zero(state[0][0]), zero(state[0][1])),
            (zero(state[1][0]), zero(state[1][1])),
        )

    def __eq__(self, other):
        return self.compare_states(self.state, other.state) and self.player == other.player

    def __repr__(self):
        return f"ChopsticksGame(state={self.state}, player={self.player})"

    def __hash__(self):
        return hash((frozenset(self.state[0]), frozenset(self.state[1]), self.player))

    def __name__(self):
        return None

    def player_0_win(self):
        return self.state[1][0] == 0 and self.state[1][1] == 0

    def enumerate_states(self):
        player = self.player

        other = 1 - player
        zero = lambda value: value if not value >= 5 else 0

        atack_other_states = [((zero(self.state[other][0] + self.state[player][0]), self.state[other][1]))]
        if self.state[player][0] != self.state[player][1]:
            atack_other_states.append(
                (zero(self.state[other][0] + self.state[player][1]), self.state[other][1])
            )

        if self.state[other][0] != self.state[other][1]:
            atack_other_states.append(
                (self.state[other][0], zero(self.state[other][1] + self.state[player][0]))
            )

        if (self.state[player][0] != self.state[player][1]) and (
            self.state[other][0] != self.state[other][1]
        ):
            atack_other_states.append(
                (self.state[other][0], zero(self.state[other][1] + self.state[player][1]))
            )

        total = self.state[player][0] + self.state[player][1]
        switch_player_states = [
            (switch, total - switch)
            for switch in range(max(total - 5, 0), int(total / 2 + 1))
        ]

        if player == 0:
            atack_states = [(self.state[0], atack) for atack in atack_other_states]
            switch_states = [(switch, self.state[1]) for switch in switch_player_states]
        else:
            atack_states = [(atack, self.state[1]) for atack in atack_other_states]
            switch_states = [(self.state[0], switch) for switch in switch_player_states]

        switch_states = [
            state for state in switch_states if not self.compare_states(state, self.state)
        ]

        states = atack_states + switch_states

        return states

    def get_neighbors(self):
        states = self.enumerate_states()
        neighbours = [self.__class__(state, 1 - self.player) for state in states]
        return neighbours

In [166]:
def build_adjacency_matrix(start_node):
    node_to_index = {}
    index_to_node = []
    adj_matrix = np.zeros((625, 625), dtype=int)

    queue = deque([start_node])
    visited = set()

    while queue:
        current = queue.popleft()
        if current in visited:
            continue
        visited.add(current)

        if current not in node_to_index:
            node_to_index[current] = len(index_to_node)
            index_to_node.append(current)

        for neighbor in current.get_neighbors():
            if neighbor not in node_to_index:
                node_to_index[neighbor] = len(index_to_node)
                index_to_node.append(neighbor)
            i = node_to_index[current]
            j = node_to_index[neighbor]
            adj_matrix[i, j] = 1
            queue.append(neighbor)

    n = len(index_to_node)
    return adj_matrix[:n, :n], index_to_node

In [167]:
start_node = ChopsticksGame()
adjacency_matrix, index_to_node = build_adjacency_matrix(start_node)

In [174]:
def compute_player_0_attractor(adj_matrix, index_to_node):
    n = len(index_to_node)
    successors = [set(np.where(adj_matrix[i] == 1)[0]) for i in range(n)]
    predecessors = [set(np.where(adj_matrix[:, i] == 1)[0]) for i in range(n)]

    attr = set(i for i, node in enumerate(index_to_node) if node.player_0_win())
    frontier = set(attr)

    while frontier:
        new_frontier = set()
        for s in range(n):
            if s in attr:
                continue
            node = index_to_node[s]
            succ = successors[s]

            if node.player == 0:
                if any(t in attr for t in succ):
                    attr.add(s)
                    new_frontier.add(s)
            else:
                if succ and all(t in attr for t in succ):
                    attr.add(s)
                    new_frontier.add(s)
        frontier = new_frontier

    return attr

In [176]:
attr_set = compute_player_0_attractor(adjacency_matrix, index_to_node)

In [186]:
attr_set = {
    index_to_node[i]
    for i in attr_set
    if not index_to_node[i].state[1] == (0, 0)
    and not index_to_node[i].state[0] == (0, 0)
}

In [189]:
for i in attr_set:
    print(i)

ChopsticksGame(state=((2, 2), (0, 3)), player=0)
ChopsticksGame(state=((2, 4), (0, 1)), player=0)
ChopsticksGame(state=((1, 1), (0, 4)), player=0)
ChopsticksGame(state=((0, 4), (0, 1)), player=0)
ChopsticksGame(state=((4, 1), (3, 0)), player=0)
ChopsticksGame(state=((4, 4), (3, 3)), player=0)
ChopsticksGame(state=((3, 3), (0, 4)), player=0)
ChopsticksGame(state=((4, 2), (0, 4)), player=0)
ChopsticksGame(state=((4, 4), (2, 4)), player=0)
ChopsticksGame(state=((0, 3), (0, 3)), player=0)
ChopsticksGame(state=((4, 4), (0, 4)), player=1)
ChopsticksGame(state=((4, 4), (2, 0)), player=0)
ChopsticksGame(state=((3, 1), (0, 2)), player=0)
ChopsticksGame(state=((0, 2), (0, 3)), player=0)
ChopsticksGame(state=((3, 4), (3, 0)), player=0)
ChopsticksGame(state=((4, 4), (0, 1)), player=0)
ChopsticksGame(state=((3, 2), (2, 0)), player=0)
ChopsticksGame(state=((4, 4), (4, 4)), player=0)
ChopsticksGame(state=((4, 4), (3, 2)), player=0)
ChopsticksGame(state=((0, 4), (0, 4)), player=0)
ChopsticksGame(state