## SET-UP

In [84]:
!pip install pyvis



In [85]:
import math


import networkx as nx

from pyvis.network import Network


import matplotlib.pyplot as plt

import matplotlib.cm as cm
import matplotlib.colors as mcolors

In [86]:
# Global directed game graph shared by the functions below and the MAIN cell.
# Nodes are state tuples (hands_a, hands_b, turn); edges are single taps.
G = nx.DiGraph()

In [87]:
MOD = 5  # modulus applied to a hand's finger count after it is tapped

# Placeholder "tie" state. Its turn slot is 0.5 instead of 0/1, matching the
# placeholder nodes map_states creates for repeated positions.
tie_state = ( (0, 0), (0, 0), 0.5 )

# States with no legal moves, collected by map_states.
# Use a set literal: `set( tie_state )` would iterate the tuple and yield
# {(0, 0), 0.5} instead of a set containing the tie state itself.
end_states = { tie_state }

# Memoised state values filled in by compute_end_values; the tie is worth 0.5.
end_values = { tie_state: 0.5 }

visited = set()       # every state expanded by map_states
unreachable = set()   # NOTE(review): not referenced by the visible code
# Cycles recorded by map_states (each a list of states). Starts empty; seeding
# it with `[ [] ]` would add a spurious empty cycle.
cycles = []

# NOTE(review): the three sets below are not referenced by the visible code.
loop_edges = set()
loop_paths = set()
all_paths = set()

# Opening position: one finger on every hand, player on turn 0 to move.
start_state = ( (1, 1), (1, 1), 0 )
start_nodes = []  # list handed to map_states as its path/stack argument
start_path = []


# NOTE(review): not referenced by the visible code (MAIN colours via cm.bwr).
weight_color_map = {
    0: 'blue',
    0.5: 'white',
    1: 'red'
}

## normalize(h1, h2):

In [88]:
def normalize(h1, h2):
    """
    Return both players' hands as sorted tuples.

    Sorting makes the order of a player's two hands irrelevant, so (1, 2)
    and (2, 1) collapse to the same canonical representation.
    """
    first_player = tuple(sorted(h1))
    second_player = tuple(sorted(h2))
    return first_player, second_player

## map_states(current_state, path_nodes, include_cycles=False)

In [89]:
def map_states( current_state, path_nodes, include_cycles=False ):
    """
    Depth-first expansion of the game graph from ``current_state``.

    Adds every state reachable from ``current_state`` to the global graph
    ``G`` and to the global ``visited`` set. States with no successors are
    added to the global ``end_states`` set.

    ``path_nodes`` is the DFS stack: ``current_state`` is pushed on entry and
    popped on return. While a state is being expanded, the list therefore
    holds exactly the states on the path from the root to it.

    For each successor:

    * On the current path (a repetition / back edge): the path slice from
      that successor up to ``current_state`` is appended to the global
      ``cycles`` list. The edge goes to the placeholder node
      ``(hands_a, hands_b, 0.5)`` unless ``include_cycles`` is true, in which
      case the real edge is added.
    * Already fully explored (reached again by another move order): the
      real edge is added.
    * Unseen: the real edge is added and the successor is expanded.

    Routing only back edges to placeholders leaves the rest of the graph
    acyclic. compute_end_values needs that, because it recurses through
    successors.

    Args:
        current_state: state tuple ``(hands_a, hands_b, turn)`` to expand.
        path_nodes: list used as the DFS stack; it is back to its incoming
            contents when the call returns.
        include_cycles: affects only this call's own edges; recursive calls
            use the default (False).
    """
    path_nodes.append( current_state )
    visited.add( current_state )
    G.add_node( current_state )

    next_states = get_next_states( current_state )

    for state in next_states:

        if state in path_nodes:
            # Back edge: `state` is an ancestor on the active DFS path.
            # current_state is the last stack entry, so slicing to the end
            # captures the whole cycle state -> ... -> current_state.
            cycles.append( path_nodes[ path_nodes.index( state ): ] )

            if not include_cycles:
                G.add_edge( current_state, ( state[0], state[1], 0.5 ) )
            else:
                G.add_edge( current_state, state )

        elif state in visited:
            # Already fully explored on an earlier branch: a real edge cannot
            # close a cycle here.
            G.add_edge( current_state, state )

        else:
            G.add_edge( current_state, state )
            # NOTE(review): include_cycles is not forwarded, so deeper calls
            # use False and send their back edges to placeholders. Forwarding
            # True would leave real cycles that compute_end_values cannot
            # evaluate.
            map_states( state, path_nodes )

    if not next_states:
        end_states.add( current_state )

    # Backtrack: current_state is no longer on the active path.
    path_nodes.pop()

In [90]:
def relink_cycles():
    """
    Replace placeholder edges in the global graph ``G`` with real ones.

    Every edge whose target has turn entry 0.5 (a placeholder made by
    map_states) is redirected to the state with the same hands and the turn
    opposite to the edge's source. The placeholder node itself is removed.
    Afterwards ``G`` may contain directed cycles.
    """
    # Snapshot the placeholder edges first; the graph is mutated below.
    placeholder_edges = [(src, dst) for src, dst in G.edges() if dst[2] == 0.5]

    for src, dst in placeholder_edges:
        if G.has_node(dst):
            G.remove_node(dst)
        real_turn = 1 - src[2]
        G.add_edge(src, (dst[0], dst[1], real_turn))

## compute_end_values(state):

In [91]:
def compute_end_values( state ):
    """
    Return the minimax value of ``state`` in the global graph ``G``.

    Results are memoised in the global ``end_values`` dict.

    * A state with no successors is worth its turn entry ``state[2]``
      (0 or 1 for real states, 0.5 for tie placeholders).
    * Otherwise, when ``state[2] == 0`` the value is the maximum of the
      successors' values; for any other turn it is the minimum.

    The recursion follows successor edges. Every cycle reachable from
    ``state`` must therefore pass through a node already in ``end_values``;
    otherwise the call recurses until RecursionError.
    """
    if state in end_values:
        return end_values[state]

    successors = list( G.successors( state ) )

    if not successors:
        value = state[2]
    else:
        child_values = [ compute_end_values( child ) for child in successors ]
        # Turn 0 maximises, turn 1 minimises. An explicit branch avoids the
        # bool-times-number trick, which also silently promoted ints to floats.
        value = max( child_values ) if state[2] == 0 else min( child_values )

    end_values[state] = value
    return value

In [92]:
def make_weighted():
    """
    Give every edge of the global graph ``G`` a 'weight' attribute.

    The weight equals compute_end_values() of the edge's target state.
    """
    for _src, dst, attrs in G.edges(data=True):
        attrs['weight'] = compute_end_values(dst)

## get_next_states(state):

In [93]:
def get_next_states(state):
    """
    Return the set of states reachable from ``state`` with one tap.

    ``state`` is ``(hands_a, hands_b, turn)``. Turn 0 means player A moves
    and turn 1 means player B moves. The mover taps one of the opponent's
    non-zero hands with one of their own non-zero hands; the tapped hand
    becomes ``(tapped + tapper) % MOD``.

    Both hands are normalized (sorted) and the turn flips. If the mover has
    no non-zero hand, or the opponent has none, the returned set is empty.
    """
    a, b, turn = state
    mover, opponent = (a, b) if turn == 0 else (b, a)

    next_states = set()

    for i in range(2):
        if mover[i] == 0:
            continue  # a dead hand cannot tap

        for j in range(2):
            if opponent[j] == 0:
                continue  # a dead hand cannot be tapped

            # Add the tapping hand's fingers to the tapped hand, modulo MOD.
            new_opp = list(opponent)
            new_opp[j] = (new_opp[j] + mover[i]) % MOD

            # Rebuild in (hands_a, hands_b) order and hand the turn over.
            if turn == 0:
                next_states.add(normalize(mover, new_opp) + (1,))
            else:
                next_states.add(normalize(new_opp, mover) + (0,))

    return next_states

## get_distance_to_end_states(G, end_states):

In [94]:
def get_distance_to_end_states(G, end_states):
    """
    Map each node to its hop distance to the nearest reachable end state.

    Args:
        G: directed graph.
        end_states: iterable of target nodes. Entries that are not nodes of
            ``G`` are skipped. The module-level ``end_states`` can hold such
            entries, e.g. the tie placeholder, or placeholder nodes that
            relink_cycles removed.

    Returns:
        dict mapping node -> length of the shortest directed path from that
        node to any end state (end states map to 0). Nodes that cannot reach
        an end state are absent.
    """
    # A reversed view is enough for read-only BFS and avoids copying G.
    G_rev = G.reverse(copy=False)
    distance = {}

    for end in end_states:
        if end not in G_rev:
            continue  # not a node of G: nothing to search from

        sp_lengths = nx.single_source_shortest_path_length(G_rev, end)
        for node, dist in sp_lengths.items():
            if node not in distance or dist < distance[node]:
                distance[node] = dist

    return distance

## LAYOUTS

In [95]:
def get_quadrant_spring_layout(G, end_states):
    """
    Lay out ``G`` in four quadrants, each arranged with ``nx.spiral_layout``.

    A node's quadrant key is ``(flag, turn)``, where ``flag`` is 1 exactly
    when ``turn`` equals the node's "not an end state" indicator. This gives
    the following placement:

        end state,     turn 0 -> key (1, 0), bottom-left
        end state,     turn 1 -> key (0, 1), top-right
        non-end state, turn 0 -> key (0, 0), top-left
        non-end state, turn 1 -> key (1, 1), bottom-right

    Each quadrant's spiral layout is shifted by +/-1 on both axes. A node
    whose turn is neither 0 nor 1 raises KeyError.
    """
    offsets = {
        (0, 0): (-1, 1),   # top-left
        (0, 1): (1, 1),    # top-right
        (1, 0): (-1, -1),  # bottom-left
        (1, 1): (1, -1),   # bottom-right
    }
    groups = {key: [] for key in offsets}

    for node in G.nodes:
        turn = node[2]
        is_normal = 0 if node in end_states else 1
        flag = int(turn == is_normal)
        groups[(flag, turn)].append(node)

    pos = {}
    for key, members in groups.items():
        dx, dy = offsets[key]
        local = nx.spiral_layout(G.subgraph(members))
        pos.update({n: (x + dx, y + dy) for n, (x, y) in local.items()})

    return pos

In [96]:
def get_bipartite_by_zero_count_layout(G):
    """
    Place nodes in two rows by turn and in columns by count of dead hands.

    Turn 0 goes on row y=1; any other turn goes on row y=0. Nodes are
    grouped by ``(zero_count, turn)``, where ``zero_count`` is how many of
    the four hands are 0. Group ``k`` starts at x = 10 * k, and its nodes, in
    sorted order, are spaced 1.6 apart.
    """
    group_spacing = 10  # horizontal gap between zero-count groups
    node_step = 1.6     # horizontal gap between nodes inside a group

    groups = {}
    for node in G.nodes:
        hands_a, hands_b, turn = node
        zeros = (hands_a[0], hands_a[1], hands_b[0], hands_b[1]).count(0)
        groups.setdefault((zeros, turn), []).append(node)

    pos = {}
    for key in sorted(groups):
        zero_count, turn = key
        row = 1 if turn == 0 else 0
        base = zero_count * group_spacing
        for offset, node in enumerate(sorted(groups[key])):
            pos[node] = (base + offset * node_step, row)

    return pos

In [97]:
def get_zero_count_bipartite_spring_layout(G, end_states):
    """
    Lay out ``G`` with end states in a top row and other states in clusters.

    End states, in sorted order, go in a single row at y=2, spaced 2.0
    apart. Every other node is bucketed by ``(zero_class, turn)``, where
    ``zero_class`` is the number of zero hands capped at 2.

    Each bucket gets its own ``nx.spring_layout`` (seed 100), scaled by 8
    horizontally and 4 vertically. It is then shifted right by 0/30/60 for
    zero_class 0/1/2, by -20/+20 for turn 0/1, and down by 8 per zero_class.
    """
    from collections import defaultdict

    ends = []
    buckets = defaultdict(list)
    for node in G.nodes:
        a, b, turn = node
        if node in end_states:
            ends.append(node)
            continue
        zero_class = min((a[0], a[1], b[0], b[1]).count(0), 2)
        buckets[(zero_class, turn)].append(node)

    # Top row: end states.
    pos = {node: (col * 2.0, 2) for col, node in enumerate(sorted(ends))}

    class_shift_x = {0: 0, 1: 30, 2: 60}  # gap between zero classes
    turn_shift_x = {0: -20, 1: +20}       # left/right split by turn

    for (zero_class, turn), members in buckets.items():
        local = nx.spring_layout(G.subgraph(members), seed=100)
        dx = class_shift_x[zero_class] + turn_shift_x[turn]
        dy = -zero_class * 8  # stack zero classes downward
        for node, (x, y) in local.items():
            pos[node] = (x * 8 + dx, y * 4 + dy)

    return pos

## MAIN

In [None]:
# Build the game graph from the opening position. Only this top-level call
# gets include_cycles=True; map_states' recursive calls use its default.
map_states( start_state, start_nodes, include_cycles=True )
# Fill end_values (memoised) for start_state and everything reachable from it.
compute_end_values( start_state )

# longest_path = list( nx.dag_longest_path( G ) )
# shortest_path = list( nx.shortest_path( G, source=start_state ) )

# Fixed-point sweep: evaluate any node missing from end_values whose
# successors all have values, repeating until a full pass changes nothing.
# NOTE(review): map_states only adds nodes reachable from start_state, so after
# compute_end_values( start_state ) this loop presumably finds nothing to do.
done=False

while not done:

    done = True

    for node in list( G.nodes ):

        if node not in end_values:

            calculable = True
            for child in list( G.successors( node ) ):
                calculable = calculable and (child in end_values)

            if calculable:
                compute_end_values( node )
                done = False

# Swap the 0.5-turn placeholder nodes for edges into the real states; G may
# contain directed cycles from here on.
relink_cycles()
# NOTE(review): nx.simple_cycles enumerates every simple cycle, and that count
# can grow exponentially with graph size. `cycles` is only read by the
# commented-out highlighting code below.
cycles = list( nx.simple_cycles( G ) )

# Reports whether G is acyclic after relinking.
print( "NO CYCLES: "+str(nx.is_directed_acyclic_graph( G )) )
# Edge weight = value of the edge's target state (already memoised above).
make_weighted()

# Map values in [0, 1] onto the blue-white-red colormap.
norm = mcolors.Normalize(vmin=0, vmax=1)
cmap = cm.bwr

edge_weights = [G[u][v]['weight'] for u, v in G.edges()]
edge_colors = [cmap(norm(w)) for w in edge_weights]
node_colors = [cmap(norm(end_values[node])) if node in end_values else "gray" for node in G.nodes]







# Three-line node labels: first hands, second hands, turn.
labels = {node: str(node[0])+"\n"+str(node[1])+"\n"+str(node[2]) for node in G.nodes}

# NOTE(review): H is not used anywhere in this cell.
H = G.reverse()

# BFS-based layout rooted at start_state.
pos = nx.bfs_layout( G, start_state )
# pos = nx.arf_layout( G )



# path = nx.shortest_path(G, source=((1, 3),(1,3),0), target=((1, 3),(1,3),0))

# shortest_cycle = min( cycles, key=len )
# longest_cycle = max( cycles, key=len )

# shortest_cycle_edges = list(zip(shortest_cycle[:-1], shortest_cycle[1:]))
# shortest_cycle_edges.append((shortest_cycle[-1], shortest_cycle[0]))

# longest_cycle_edges = list(zip(longest_cycle[:-1], longest_cycle[1:]))
# longest_cycle_edges.append((longest_cycle[-1], longest_cycle[0]))


# longest_path_edges = list(zip(longest_path[:-1], longest_path[1:]))

# shortest_path_edges = list(zip(shortest_path[:-1], shortest_path[1:]))

# Rotate the layout by 90 degrees: (x, y) -> (y, -x).
for node, (x, y) in pos.items():
        pos[node] = (y, -x)

# Draw graph
plt.figure(figsize=(50, 28))
nx.draw(
    G, pos,
    node_color=node_colors,
    edge_color=edge_colors,
    node_size=500,
    arrows=True
)


nx.draw_networkx_labels(G, pos, labels=labels, font_size=6)
# nx.draw_networkx_edges(G, pos, edgelist=longest_cycle_edges, edge_color='blue', width=10, arrows=True)
# nx.draw_networkx_edges(G, pos, edgelist=shortest_cycle_edges, edge_color='red', width=5, arrows=True)


# NOTE(review): the cycle-highlighting calls above are commented out, so this
# title does not describe what is actually drawn.
plt.title("Longest and Shortest Cycles on Mod "+str(MOD))
plt.axis('off')
plt.show()