V2: 
- use GNNs to generalize strategy to arbitrary sized chess graphs
    - maybe this will get a general strategy which we can use on 16x16 (so we dont have to train on 16x16 first)
- if that works, make multi agent adversarial model for robber

In [1]:
from sage.all import graphs
import copy, sys, math


def available_squares(cop_pos: list, robber_pos: tuple) -> list:
    #Function that returns the set of squares that are available to the robber
    #return the adjacent vertices of robber_pos \ adjacent vertices of cops    
    c_neighbors = set(cop_pos)
    for cop in cop_pos:
        c_neighbors = c_neighbors.union(set(G.neighbors(cop)))

    if robber_pos == (-1, -1): 
        return set(G.vertices()) - c_neighbors
    else:
        r_neighbors = set(G.neighbors(robber_pos)).union({robber_pos})
        return r_neighbors - c_neighbors

def get_axis(cop_pos, robber_pos):
    #returns 'h', 'v', 'posd', 'negd' or none for axis which a cop position occupies
    #assuming given correct input ie intersecting but not same posn
    if cop_pos == robber_pos:
        return 'h'
        
    if robber_pos[0] == cop_pos[0]:
        return 'v'
    elif robber_pos[1] == cop_pos[1]:
        return 'h'
    elif cop_pos[0] + cop_pos[1] == robber_pos[0] + robber_pos[1]:
        return 'negd'
    elif cop_pos[0] - cop_pos[1] == robber_pos[0] - robber_pos[1]:
        return 'posd'
    else:
        return None

def remove_axes_squares(robber_pos, avail, occupied_axes):
    #filters out the squares which are on a cop occupied axis from a list of vertices
    for axis in map(lambda a: a[0], set(filter(lambda a: a[1], occupied_axes.items()))):
        if axis == 'h':
            #remove all with move[1] == robber_pos[1]
            avail = set(filter(lambda posn: posn[1] != robber_pos[1], avail))
        elif axis == 'v':
            #remove all with move[0] == robber_pos[0]
            avail = set(filter(lambda posn: posn[0] != robber_pos[0], avail))
        elif axis == 'negd':
            #remove all w/move[0]+move[1] == robber_pos[0]+robber_pos[1]
            avail = set(filter(lambda posn: posn[0] + posn[1] != robber_pos[0] + robber_pos[1], avail))
        elif axis == 'posd':
            #remove all w/move[0]-move[1] == robber_pos[0]-robber_pos[1]
            avail = set(filter(lambda posn: posn[0] - posn[1] != robber_pos[0] - robber_pos[1], avail))
    return avail

def get_intersecting_squares(cop_pos, robber_pos, occupied_axes):
    #compiles a list of the squares which cops can reach that directly attack the robber on a unique line
    avail = set(G.neighbors(cop_pos)).union({cop_pos}).intersection(
        set(G.neighbors(robber_pos)).union({robber_pos}))
    return remove_axes_squares(robber_pos, avail, occupied_axes)

def get_SD_length(robber_pos):
    #determines the length of the SD
    diff = robber_pos[0] - robber_pos[1]
    sum = robber_pos[0] + robber_pos[1]
    posd_len = len(list(filter(lambda posn: posn[0] - posn[1] == diff, G.vertices())))
    negd_len = len(list(filter(lambda posn: posn[0] + posn[1] == sum, G.vertices())))
    return min(posd_len, negd_len)

def get_min_LSD_move(robber_pos: tuple, moves: list) -> tuple:
    #determines the cop config which forces the minimum LSD
    min_LSD = sys.maxsize
    min_move = tuple()

    for move in moves:
        r_moves = available_squares(move[0], robber_pos)
        LSD = -1

        #get LSD of robber for this possible move
        for r_move in r_moves:
            SD_len = get_SD_length(r_move)
            if SD_len > LSD:
                LSD = SD_len
    
    if LSD < min_LSD:
        min_LSD = LSD
        min_move = move
    
    return min_move

def seenp(move: list, robber_pos: tuple, cop_states: list, robber_states: list)-> bool:
    if not (cop_states and robber_states): #empty states given for some reason
        return False
    
    for t in range(len(cop_states) - 1, -1, -1):
        if (set(cop_states[t]) == set(move) and robber_states[t] == robber_pos):
            print("SEEN")
            return True
        
    return False

#n^2 algo where n = # available moves, which is constant...
def remove_seen_moves(moves: list, robber_pos: tuple, cop_states: list, robber_states: list)-> list:
    #removes cop moves from a list of available cop moves if that state has already been seen

    if (len(cop_states) != len(robber_states)):
        raise Exception("given nonmatching states")
    

    new_moves = [(m,v) for (m,v) in moves
             if not seenp(m, robber_pos, cop_states, robber_states)]

    return new_moves

def get_closest_unoccupied(cop_pos: list, robber_pos: tuple, idx: int)-> tuple:
    V = G.neighbors(cop_pos[idx])

    min_dist = sys.maxsize
    min_pos = cop_pos[idx]
    
    for v in V:
        dist = math.sqrt((v[0] - robber_pos[0])**2 + (v[1] - robber_pos[1])**2)
        if dist < min_dist and v not in cop_pos:
            min_dist = dist
            min_pos = v

    return min_pos

def minimize_avail_helper(curr_cop_pos, robber_pos, i, occupied_axes, cop_moves, robber_moves):
    #goal is to find the minimizing config of cops
    #so track a min config and min robber avail squares
    #occupied axes is a dict, represents which robber axes are occupied in current backtracking iteration

    #base case i > #cops
    if i >= len(curr_cop_pos):
        return curr_cop_pos, len(available_squares(curr_cop_pos, robber_pos))

    avail = get_intersecting_squares(curr_cop_pos[i], robber_pos, occupied_axes)

    moves = list() # list(posn) -> avail_squares

    for move in avail:
        #find axes this occupies
        axis = get_axis(move, robber_pos)
        occupied_axes[axis] = True
        new_cop_pos = copy.deepcopy(curr_cop_pos)
        new_cop_pos[i] = move #move cop i
        
        curr_config, curr_squares = minimize_avail_helper(new_cop_pos, robber_pos, i+1, occupied_axes, cop_moves, robber_moves)
        moves.append((curr_config, curr_squares))
    
        occupied_axes[axis] = False

    #remove all moves which revisit board states
    moves = remove_seen_moves(moves, robber_pos, cop_moves, robber_moves)

    if not moves:
        #animal case-- go closer to the robber
        curr_cop_pos[i] = get_closest_unoccupied(curr_cop_pos, robber_pos, i)
        return minimize_avail_helper(curr_cop_pos, robber_pos, i+1, occupied_axes, cop_moves, robber_moves)
    
    #get squares with min # available squares for robber
    vals = map(lambda tup: tup[1], moves)
    min_val = min(vals)
    min_avail_moves = list(filter(lambda tup: tup[1] == min_val, moves))

    #sort by which config gives the max min SD
    best_move = get_min_LSD_move(robber_pos, min_avail_moves)
    
    return best_move[0], best_move[1]

def minimize_available(cop_pos: list, robber_pos: tuple, cop_states, robber_states) -> list:
    # Function that returns the move for the cops that minimizes the number of available squares for the robber
    #this function could be the combinatorially large one, but we are going to introduce our greedy heuristic
    #our strategy is such: the cops should always directly threaten a unique line of movement
    #get set of cop i available_moves \intersect set of robber 
    #filter out whichever are on occupied axes
    #use backtracking algorithm, recursively call min_avail_helper w/i+1, new cop_pos

    occupied_axes = {
        'h': False,
        'v': False,
        'negd': False,
        'posd': False
    }
    
    min_config, min_squares = minimize_avail_helper(cop_pos, robber_pos, 0, occupied_axes, cop_states, robber_states)
    
    return min_config

def maximize_available(cop_pos: list, cop_states, robber_states, robber_pos:tuple = (-1, -1)) -> tuple: #-1 denotes no robber placed yet ie startin
    # Function that returns move for the robber that maximizes the number of squares for their next turn (assuming cops try to minimize)
    #get set of valid moves available_squares
    #for all moves m, call available_squares(cop_pos, m), get size of set
    #track max size and move, return that move
    #O(n)

    r_neighbors = available_squares(cop_pos, robber_pos)

    moves = dict() # move -> min cop move in anticipation
    
    for move in r_neighbors:
        #Q: here, do we want cop moves to take into account that cops wont repeat moves ?
        #at this point, cops are making suboptimal moves
        # i guess dont take into account, as robber doesnt care for repeating moves?
        cop_response = minimize_available(cop_pos, move, [], [])
        max_min_val = len(available_squares(cop_response, move))

        moves[move] = max_min_val
    
    if not moves and robber_pos == (-1, -1):
        return (0, 0)
    elif not moves and robber_pos != (-1, -1):
        return robber_pos

    #sort by value descending, then by SD descending
    max_val = max(moves.values())
    max_avail_moves = {k: v for k, v in moves.items() if v == max_val}

    best_moves = sorted(
        max_avail_moves.items(),
        key=lambda item: -get_SD_length(item[0])
    )
        
    return best_moves[0][0]

def k_cop_win(cop_start, robber_start, itr, cop_states, robber_states):
    #returns true if cop win possible with k cops
    cop_move = minimize_available(cop_start, robber_start, cop_states, robber_states) # The cops try to minimize the available squares
    print("Cops move:", cop_move)
    cop_states.append(cop_move)
    robber_states.append(robber_start)
    avail_squares.append(len(available_squares(cop_move, robber_start)))
    robber_move = maximize_available(cop_move, cop_states, robber_states, robber_start) # The robber tries to maximize this minimum
    print("Robber moves to:", robber_move)
    print(avail_squares[-1], "squares available for after move", itr)

    cop_states.append(cop_move)
    robber_states.append(robber_move)
    
    # Checking if the cops have captured the robber
    if len(available_squares(cop_move, robber_move)) == 0:
        return True, cop_states, robber_states
    
    # If the cops can't decrease the number of available moves, they lose
    #if len(avail_squares) > 1 and avail_squares[-1] > avail_squares[-2]:
    #    print("available squares increased")
    #    return False

    if itr > n**2:
        print("iterations exceeded")
        return False, cop_states, robber_states
            
    # If the cop's haven't won yet, keep going
    return k_cop_win(cop_move, robber_move, itr+1, cop_states, robber_states)

In [2]:
#CODE FOR GENERATING ANIMAL/ROYAL GRAPHS GIVEN DIRECTIONS

def make_graph(n, slopes, animal=False):
    from sage.all import QQ, Infinity

    vertices = [(x, y) for x in range(n) for y in range(n)]
    G = Graph()
    G.add_vertices(vertices)

    for i, (x1, y1) in enumerate(vertices):
        # Convert slope list to exact rational numbers or Infinity
        D = set(QQ(s) if s != 'inf' else Infinity for s in slopes)
        for j in range(i+1, len(vertices)):
            x2, y2 = vertices[j]
            dx = x2 - x1
            dy = y2 - y1

            if dx == 0:
                slope = Infinity
            else:
                slope = QQ(dy) / QQ(dx)

            if slope in D:
                G.add_edge((x1, y1), (x2, y2))
                if animal:
                    D.remove(slope)

    return G


In [8]:
'''
EDIT THIS CODE TO CHANGE THE GRAPH
similar to evans code, just input slopes into a list and pass it into make_graph function (specify animal or royal w/bool)
'''

n=15
knight = [2, -2, 1/2, -1/2]
queen = [0, 'inf', 1, -1]
bishop = [1, -1]
idk = [1/3, -1/3, 3, 3]
G = make_graph(n, queen, False)


In [10]:
#n = 15
#T = n**2
#G = graphs.QueenGraph([n,n])

'''
run this code to run the above greedy algorithm on the graph defined above
define a list of tuples representing where you want your cops to start in (x,y) coords
then pass into play_game()

you could also use this to iteratively check the largest n for which k cops can win w/this algorithm in a loop
'''

avail_squares = list()

def play_game(cops_start):
    print(cops_start)
    robber_start = maximize_available(cops_start.copy(), [], [])
    robber_moves = [robber_start]
    cop_moves = [cops_start]
    print(f"rstart: {robber_start}, cops: {cops_start}")
    return k_cop_win(cops_start, robber_start, 1, cop_moves, robber_moves)

#6x6 domination
dom_start = math.floor((n+1)/2) - 3
dom_set = [(dom_start, dom_start), (dom_start + 4, dom_start + 2), (dom_start + 2, dom_start + 4)]

corner_start = [(0,0), (n-1,n-1), (0,n-1)]

two_cops = [(0,0), (n-1,n-1)]

mid = math.floor(n/2)
print(mid)
four_cops = [(mid, mid), (mid-1, mid), (mid-1,mid-1), (mid,mid-1)]

knight_diag = [(mid, mid), (mid+1, mid+1), (mid-1,mid-1)]

winp, cop_moves, robber_moves = play_game(dom_set)
print("Cop win:", winp)
#print(cop_moves, robber_moves)

7
[(5, 5), (9, 7), (7, 9)]
rstart: (8, 4), cops: [(5, 5), (9, 7), (7, 9)]
Cops move: [(7, 5), (12, 4), (8, 10)]
Robber moves to: (6, 2)
4 squares available for after move 1
Cops move: [(7, 2), (8, 4), (6, 8)]
Robber moves to: (2, 6)
3 squares available for after move 2
Cops move: [(7, 11), (6, 6), (2, 8)]
Robber moves to: (5, 3)
3 squares available for after move 3
Cops move: [(7, 3), (8, 6), (5, 5)]
Robber moves to: (1, 7)
2 squares available for after move 4
Cops move: [(1, 3), (0, 6), (7, 7)]
Robber moves to: (8, 0)
2 squares available for after move 5
Cops move: [(3, 5), (6, 0), (8, 6)]
Robber moves to: (12, 4)
3 squares available for after move 6
Cops move: [(11, 5), (12, 6), (8, 4)]
Robber moves to: (9, 1)
1 squares available for after move 7
Cops move: [(8, 2), (12, 1), (9, 5)]
Robber moves to: (14, 6)
1 squares available for after move 8
Cops move: [(14, 2), (12, 8), (8, 6)]
Robber moves to: (9, 1)
2 squares available for after move 9
SEEN
Cops move: [(8, 2), (9, 11), (13, 1)]


In [5]:
%pip install ipyevents

Note: you may need to restart the kernel to use updated packages.


In [6]:
from sage.all import graphs
import networkx as nx
import matplotlib.pyplot as plt
from IPython.display import display
import ipywidgets as widgets
import math
from ipyevents import Event


def get_state(r_state, c_state):
    #get dict w/red = occupied by cops, blue = cops, black = robber, green = available for robber movement
    cop_occ = set()
    for cop in c_state:
        cop_occ = cop_occ.union(set(G.neighbors(cop)))

    cop_occ -= set(c_state)
    
    state = {
        'blue': set(c_state),
        'black': {r_state},
        'green': set(available_squares(c_state, r_state)) - {r_state},
        'red': cop_occ - {r_state}
    }
    return state

def convert_to_state(robber_moves, cop_moves):
    if len(robber_moves) != len(cop_moves):
        raise Exception("nonequal lists given")
    
    states = list()

    for state in range(len(robber_moves)):
        states.append(get_state(robber_moves[state], cop_moves[state]))

    return states

# Update function for each turn
def update(turn, states, nx_G, pos):
    fig, ax = plt.subplots(figsize=(12, 12))
    nx.draw(nx_G, pos, ax=ax, node_color='lightgrey', node_size=400, with_labels=False)

    for color, nodes in states[turn].items():
        nx.draw_networkx_nodes(nx_G, pos, nodelist=list(nodes), node_color=color, node_size=400, ax=ax)

    ax.set_title(f"Turn {math.floor(turn / 2) + 1}")
    ax.set_axis_off()
    plt.show()

'''
run this code to visualize moves made
'''

def display_game(G, cop_moves, robber_moves):
    n = int(math.sqrt(len(list(G.vertices()))))

    pos = {(i, j): (i, j) for i in range(n) for j in range(n)}
    G.set_pos(pos)
    nx_G = G.networkx_graph()

    states = convert_to_state(robber_moves, cop_moves)

    slider = widgets.IntSlider(min=0, max=len(states) - 1, step=1, value=0)
    
    out = widgets.interactive_output(
        update,
        {
            'turn': slider,
            'states': widgets.fixed(states),
            'nx_G': widgets.fixed(nx_G),
            'pos': widgets.fixed(pos)
        }
    )

    event = Event(source=slider, watched_events=['keydown'])

    def handle_event(event):
        if event['key'] == 'ArrowRight':
            slider.value = min(slider.max, slider.value + slider.step)
        elif event['key'] == 'ArrowLeft':
            slider.value = max(slider.min, slider.value - slider.step)

    event.on_dom_event(handle_event)

    display(slider, out)


In [30]:
%pip install imageio
import matplotlib.pyplot as plt
import imageio
import os

import os
import imageio

def save_game_gif(G, cop_moves, robber_moves, gif_path="game.gif", dpi=100):
    n = int(math.sqrt(len(list(G.vertices()))))
    pos = {(i, j): (i, j) for i in range(n) for j in range(n)}
    G.set_pos(pos)
    nx_G = G.networkx_graph()
    states = convert_to_state(robber_moves, cop_moves)

    temp_dir = "frames_temp"
    os.makedirs(temp_dir, exist_ok=True)

    frame_paths = []

    for i in range(len(states)):
        fig, ax = plt.subplots(figsize=(12, 12))
        nx.draw(nx_G, pos, ax=ax, node_color='lightgrey', node_size=400, with_labels=False)

        for color, nodes in states[i].items():
            nx.draw_networkx_nodes(nx_G, pos, nodelist=list(nodes), node_color=color, node_size=400, ax=ax)

        ax.set_title(f"Turn {math.floor(i / 2) + 1}")
        ax.set_axis_off()

        frame_path = os.path.join(temp_dir, f"frame_{i:03d}.png")
        fig.savefig(frame_path, dpi=dpi)
        frame_paths.append(frame_path)
        plt.close(fig)

    # Stitch frames into GIF
    images = [imageio.imread(p) for p in frame_paths]
    imageio.mimsave(gif_path, images, fps=3)

    # Clean up
    for p in frame_paths:
        os.remove(p)
    os.rmdir(temp_dir)

    print(f"GIF saved to {gif_path}")


Note: you may need to restart the kernel to use updated packages.


In [None]:
#save_game_gif(G, cop_moves, robber_moves, gif_path="game8.gif")
display_game(G, cop_moves, robber_moves)

IntSlider(value=0, max=40)

Output()

REINFORCEMENT LEARNING MODEL:

In [36]:
%pip install gymnasium
%pip install torch torchvision torchaudio
%pip install torch-geometric

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Collecting torch-geometric
  Using cached torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
Collecting aiohttp (from torch-geometric)
  Using cached aiohttp-3.12.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.6 kB)
Collecting aiohappyeyeballs>=2.5.0 (from aiohttp->torch-geometric)
  Using cached aiohappyeyeballs-2.6.1-py3-none-any.whl.metadata (5.9 kB)
Collecting aiosignal>=1.1.2 (from aiohttp->torch-geometric)
  Using cached aiosignal-1.3.2-py2.py3-none-any.whl.metadata (3.8 kB)
Collecting frozenlist>=1.1.1 (from aiohttp->torch-geometric)
  Using cached frozenlist-1.6.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)
Collecting multidict<7.0,>=4.5 (from aiohttp->torch-geometric)
  Using cached multidict-6.4.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x

In [42]:
import torch
from torch_geometric.data import Data

def sage_to_pyg_data(G_sage, cop_positions: list[tuple], robber_position: tuple):
    """
    Convert a Sage Graph (with vertices=(i,j) for i,j in [0..n-1]) into a PyG Data:
      - edge_index: shape [2, E]
      - x:         shape [n^2, 2] where x[v,0]=1 if a cop is on v, x[v,1]=1 if robber on v.
    """
    # 1) Extract and sort all vertices (cast Sage Integers → int):
    vertices = list(G_sage.vertices())               # e.g. [(0,0),(0,1),...]
    vertices = [(int(a), int(b)) for (a, b) in vertices] #cast
    vertices.sort(key=lambda v: (v[0], v[1]))         # ensure lex order

    n2 = len(vertices)                               # should be n*n
    node2idx = { vert: idx for idx, vert in enumerate(vertices) }

    # 2) Build edge_index lists:
    row, col = [], []
    for (i, j) in vertices:
        v_idx = node2idx[(i, j)]
        for nbr in G_sage.neighbors((i, j)):
            i2, j2 = int(nbr[0]), int(nbr[1])
            u_idx = node2idx[(i2, j2)]
            row.append(v_idx)
            col.append(u_idx)
    edge_index = torch.tensor([row, col], dtype=torch.long)  # [2, num_edges]

    # 3) Build node‐feature matrix x of shape [n2, 2]:
    x = torch.zeros((n2, 2), dtype=torch.float)
    for (ci, cj) in cop_positions:
        idx_c = node2idx[(ci, cj)]
        x[idx_c, 0] = 1
    (ri, rj) = robber_position
    idx_r = node2idx[(ri, rj)]
    x[idx_r, 1] = 1

    return Data(x=x, edge_index=edge_index)


In [38]:
import gymnasium as gym
from gymnasium.spaces import Dict, MultiDiscrete, Tuple, Discrete
import numpy as np
import random

class CopsAndRobbersEnv(gym.Env):
    """
    Custom Environment for Cops and Robbers on an nxn grid or graph.
    currently from the pov of the cops
    """
    def __init__(self, graph, k, render_mode=None):
        super().__init__()

        # --- Inputs ---
        self.k = k
        self.graph = graph  #SageMath graph
        self.nodes = list(self.graph.vertices())
        self.n = math.sqrt(len(self.nodes))
        self.last_num_avail = 99999
        max_deg = int(self.get_max_deg() + 1)

        #print(type(max_deg))

        # 1) Observation space: dict with
        #    - "cop_pos": flattened k (x,y) pairs
        #    - "robber_pos": single (x,y) pair
        self.observation_space = Dict({
            "cop_pos": MultiDiscrete([self.n, self.n] * self.k),   # [x1,y1, x2,y2, …, xk,yk]
            "robber_pos": MultiDiscrete([self.n, self.n])     # [xr,yr]
        })

        # 2) Action space: a tuple of k Discrete spaces, each of size (max_deg+1)
        #    (we’ll map 0…max_deg-1 to “move to the i‑th neighbor” and max_deg to “stay”)
        self.action_space = Tuple([Discrete(max_deg) for _ in range(self.k)])

        self.cop_pos = [(-1, -1)] * self.k #list of k tuples representing cop positions
        self.robber_pos = (-1, -1)
        self.render_mode = render_mode
        self.itr = 0

    def get_max_deg(self):
        max_degree = 0

        for vertex in self.graph.vertices():
            degree = self.graph.degree(vertex)
            if degree > max_degree:
                max_degree = degree
        return max_degree

    def get_collective_euclidean_dist(self):
        total_dist = 0
        for cop in self.cop_pos:
            total_dist += math.sqrt((cop[0] - self.robber_pos[0])**2 + (cop[1] - self.robber_pos[1])**2)
        return total_dist
    
    '''
    get the current observations, should be robber and cop positions i think
    '''
    def get_obs(self):
        # Flatten your list of k (x,y) tuples into a single list:
        flat_cops = [coord for pos in self.cop_pos for coord in pos]
        # Robber is a single (x,y)
        #print(self.robber_pos)
        rob = list(self.robber_pos)
        return {
            "cop_pos": np.array(flat_cops, dtype=int),
            "robber_pos": np.array(rob, dtype=int)
        }
    
    def get_info(self):
        '''
        return info pertinent to reward
        Q: what do we care about?
        A: i think
            - robber curr LSD length / layer (jacobs idea)
            - #available squares for robber?
        maybe define reward fn first
        '''
        return {
            "distance": 1,
            "SD_length": 1,
            "robber_available:": 1
        }

    def reset(self, seed = None, options = None):       
        """
        Reset the environment to an initial state and return the initial observation.
        random start?
        """
        super().reset(seed=seed)
        
        #random for now, start doesnt matter
        for _ in range(self.k):
            self.cop_pos = random.sample(self.nodes, self.k)

        self.robber_pos = maximize_available(self.cop_pos, [], []) #pick maximizing available start for robber
        self.itr = 0
        self.last_num_avail = available_squares(self.cop_pos, self.robber_pos)
        
        # Return initial state
        observation = self.get_obs()
        info = self.get_info()
        return observation, info

    def step(self, action):
        """
        Take an action (cop move) and return next observation, reward, terminated, truncated, info.
        input is an action in the form of a k-tuple, each from 0 to max_deg
        recall for cop alpha, a value of i means move to the ith neighbors, a value of max_deg means stay
        """
        assert self.action_space.contains(action), "Invalid action!"

        # Check terminal state first! to catch dominating start 
        terminated = not available_squares(self.cop_pos, self.robber_pos)
        truncated = self.itr > max(self.n**2, 100)

        if not terminated and not truncated: #make action for both cops and robbers
            self.itr += 1
    
            for i, act in enumerate(action):
                current_pos = self.cop_pos[i]
                neighbors = self.graph.neighbors(current_pos)
    
                if act < len(neighbors):
                    self.cop_pos[i] = neighbors[act]  #move to ith neighbor
                else:
                    pass  #stay in place
    
            #self.robber_pos = maximize_available(self.cop_pos, [], [], self.robber_pos)
            choices = list(available_squares(self.cop_pos, self.robber_pos))
            self.robber_pos = random.sample(choices, 1)[0] if choices else self.robber_pos
    
        # Define reward
        reward = float()
        if terminated:
            reward = 50000 #reward for catching
        elif truncated:
            reward = -100  #penalize running out of time
        else:
            #small negative reward each step to encourage faster capture
            reward -= 0.1

            #penalize robber mobility — less mobility is better for cops
            avail = available_squares(self.cop_pos, self.robber_pos)
            reward -= len(avail)  # weight this — tune as needed

            # Penalize collective distance from cops to robber
            dist_penalty = self.get_collective_euclidean_dist()
            reward -= 0.5 * dist_penalty  # again, weight can be tuned

            #penalize cops on the same square
            if len(self.cop_pos) != len(set(self.cop_pos)):
                reward -= 100
            
            if avail < self.last_num_avail:
                reward += 100
            
            self.last_num_avail = avail

    
        observation = self.get_obs()
        info = self.get_info()        

        #self.render()
        return observation, reward, terminated, truncated, info

    def render(self):
        """
        visualize the environment
        do it with above defined fns
        """
        if self.render_mode == "human":
            print(f"Cops at {self.cop_pos}, Robber at {self.robber_pos}, iteration {self.itr}")
        
    def get_moves(self):
        return self.cop_pos, self.robber_pos
    

class GNNCopsAndRobbersEnv(CopsAndRobbersEnv):
    """
    Subclass of your original Sage‐based environment.
    Overrides reset() and step() to return a PyG Data instead of a numpy dict.
    """
    def __init__(self, graph, k, render_mode=None):
        super().__init__(graph=graph, k=k, render_mode=render_mode)
        # self.graph is the Sage graph (e.g. QueenGraph([n,n]))
        # self.cop_pos, self.robber_pos are managed by the parent class.

    def get_pyg_data(self):
        return sage_to_pyg_data(self.graph, self.cop_pos, self.robber_pos)

    def reset(self, seed=None, options=None):
        obs_dict, info = super().reset(seed=seed, options=options)
        # Now self.cop_pos, self.robber_pos have been initialized by the parent.
        data = self.get_pyg_data()
        return data, info

    def step(self, action):
        obs_dict, reward, terminated, truncated, info = super().step(action)
        data = self.get_pyg_data()
        return data, reward, terminated, truncated, info

gym.register(
    id="gymnasium_env/CopsAndRobbers-v0",
    entry_point=CopsAndRobbersEnv,
)
gym.pprint_registry()

===== classic_control =====
Acrobot-v1             CartPole-v0            CartPole-v1
MountainCar-v0         MountainCarContinuous-v0 Pendulum-v1
===== phys2d =====
phys2d/CartPole-v0     phys2d/CartPole-v1     phys2d/Pendulum-v0
===== box2d =====
BipedalWalker-v3       BipedalWalkerHardcore-v3 CarRacing-v3
LunarLander-v3         LunarLanderContinuous-v3
===== toy_text =====
Blackjack-v1           CliffWalking-v0        FrozenLake-v1
FrozenLake8x8-v1       Taxi-v3
===== tabular =====
tabular/Blackjack-v0   tabular/CliffWalking-v0
===== mujoco =====
Ant-v2                 Ant-v3                 Ant-v4
Ant-v5                 HalfCheetah-v2         HalfCheetah-v3
HalfCheetah-v4         HalfCheetah-v5         Hopper-v2
Hopper-v3              Hopper-v4              Hopper-v5
Humanoid-v2            Humanoid-v3            Humanoid-v4
Humanoid-v5            HumanoidStandup-v2     HumanoidStandup-v4
HumanoidStandup-v5     InvertedDoublePendulum-v2 InvertedDoublePendulum-v4
InvertedDoublePendulu

  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch_geometric.nn import GCNConv

class GNNCopPolicy(nn.Module):
    def __init__(self, in_feats=2, hidden_dim=64, emb_dim=64):
        """
        in_feats:  number of node features (2 = [is_cop, is_robber]).
        hidden_dim: intermediate GCN dimension.
        emb_dim:     final node embedding dimension.
        """
        super().__init__()
        # 1) Two-layer GCN
        self.conv1 = GCNConv(in_feats, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, emb_dim)

        # 2) Policy head: MLP that scores (h_v || h_u) → scalar logit
        self.policy_mlp = nn.Sequential(
            nn.Linear(2 * emb_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, 1)
        )
        # 3) Value head: map a pooled graph embedding → V(s)
        self.value_head = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, data: Data, cop_positions: list[tuple]):
        """
        data.x         : [n_nodes, in_feats]
        data.edge_index: [2, n_edges]
        cop_positions  : list of k (i,j) pairs

        Returns:
          - action_logits: a list of length k; each is a 1D FloatTensor of size (|neighbors of that cop| + 1).
          - state_value:   a single‐element FloatTensor giving V(s).
        """
        x, edge_index = data.x, data.edge_index  # x: [n², 2]

        # 1) GCN message‐passing → x_emb: [n², emb_dim]
        h = F.relu(self.conv1(x, edge_index))   # [n², hidden_dim]
        h = self.conv2(h, edge_index)           # [n², emb_dim]

        # 2) Critic: pool all node embeddings (simple mean) → graph embedding [emb_dim]
        graph_embed = h.mean(dim=0)              # [emb_dim]
        state_value = self.value_head(graph_embed)  # [1]

        # 3) Actor: for each cop i, gather its node idx and neighbor idxs
        action_logits = []
        num_nodes = data.x.size(0)
        board_n = int(math.sqrt(num_nodes))  # since num_nodes = n²

        for (ci, cj) in cop_positions:
            v_idx = ci * board_n + cj           # integer node‐index
            h_v = h[v_idx]                      # [emb_dim]

            # Find all neighbors of v_idx:
            # In PyG, edge_index[0] == v_idx → edge_index[1] are neighbors.
            mask_src = (edge_index[0] == v_idx)
            nbr_idxs = edge_index[1, mask_src].unique()  # 1D tensor of neighbor indices

            # Build a candidate list: “stay” (h_v) + each neighbor embedding h_u
            candidate_embs = [h_v] + [h[int(u)] for u in nbr_idxs]
            logits = []
            for h_u in candidate_embs:
                pair = torch.cat([h_v, h_u], dim=-1)    # [2*emb_dim]
                score = self.policy_mlp(pair)           # [1]
                logits.append(score)
            logits = torch.cat(logits, dim=0)  # size = (#neighbors+1)
            action_logits.append(logits)

        return action_logits, state_value


In [40]:
import torch
import torch.optim as optim

def train_reinforce(env: GNNCopsAndRobbersEnv,
                    policy: GNNCopPolicy,
                    num_episodes=5000,
                    gamma=0.99,
                    lr=1e-3):
    """
    A bare‐bones REINFORCE trainer.
    env         : GNNCopsAndRobbersEnv
    policy      : GNNCopPolicy
    num_episodes: how many episodes to sample
    gamma       : discount factor
    lr          : learning rate for Adam
    """
    optimizer = optim.Adam(policy.parameters(), lr=lr)

    for episode in range(num_episodes):
        data, info = env.reset()               # data: PyG Data with x & edge_index
        cop_positions = env.cop_pos            # list of k (i,j) pairs
        log_probs = []                         # will store sum of log‐probs per timestep
        rewards = []                           # will store rewards

        done = False
        while not done:
            # 1) Forward pass: get a list of logits per cop + state value
            action_logits_list, _ = policy(data, cop_positions)
            total_logprob = 0.0

            # 2) Sample one action per cop
            actions = []
            for logits in action_logits_list:
                probs = torch.softmax(logits, dim=0)
                m = torch.distributions.Categorical(probs)
                a_i = m.sample()
                total_logprob = total_logprob + m.log_prob(a_i)
                actions.append(int(a_i.item()))

            log_probs.append(total_logprob)

            # 3) Convert “per‐cop index” → actual (i,j) moves
            chosen_moves = []
            num_nodes = data.x.size(0)
            board_n = int(math.sqrt(num_nodes))

            for i, a_i in enumerate(actions):
                (ci, cj) = cop_positions[i]
                v_idx = ci * board_n + cj
                mask_src = (data.edge_index[0] == v_idx)
                nbr_idxs = data.edge_index[1, mask_src].unique()

                if a_i == 0:
                    new_idx = v_idx   # stay
                else:
                    new_idx = int(nbr_idxs[a_i - 1])
                new_i = new_idx // board_n
                new_j = new_idx % board_n
                chosen_moves.append((new_i, new_j))

            # 4) Step environment
            data_next, r, terminated, truncated, info = env.step(chosen_moves)
            rewards.append(r)

            data = data_next
            cop_positions = env.cop_pos   # updated internally by env.step()
            done = terminated or truncated

        # 5) Episode ended: compute discounted returns
        Gt = 0.0
        returns = []
        for r in reversed(rewards):
            Gt = r + gamma * Gt
            returns.insert(0, Gt)
        returns = torch.tensor(returns, dtype=torch.float)

        # 6) Normalize returns (optional but often helpful)
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        # 7) Compute policy loss = -∑_t [ logπ(a_t|s_t) * R_t ]
        policy_loss = 0.0
        for logp, R in zip(log_probs, returns):
            policy_loss = policy_loss - (logp * R)

        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()

        # 8) Logging every 100 episodes
        if (episode + 1) % 100 == 0:
            avg_return = sum(rewards) / len(rewards)
            print(f"Episode {episode+1:4d} | AvgReturn (this ep): {avg_return:.2f}")

    return policy


In [43]:
# 6.1) Build an arbitrary Sage graph (e.g. 8×8 queen graph):
from sage.all import graphs
G8 = graphs.QueenGraph([8, 8])  # keeps your existing Sage code

# 6.2) Wrap it in our GNN‐compatible env
env = GNNCopsAndRobbersEnv(graph=G8, k=3, render_mode=None)

# 6.3) Create the GNN policy (actor‐critic)
policy = GNNCopPolicy(in_feats=2, hidden_dim=64, emb_dim=64)
policy = policy.to("cuda")  # or "cpu" if no GPU

# 6.4) Train with REINFORCE
trained_policy = train_reinforce(env=env, 
                                 policy=policy, 
                                 num_episodes=3000, 
                                 gamma=0.99, 
                                 lr=1e-3)

# 6.5) Evaluate 5 episodes
for _ in range(5):
    data, info = env.reset()
    cop_positions = env.cop_pos
    done = False
    total_r = 0.0

    while not done:
        with torch.no_grad():
            logits_list, _ = trained_policy(data.to("cuda"), cop_positions)
            # pick greedy actions for eval
            actions = []
            for logits in logits_list:
                probs = torch.softmax(logits, dim=0)
                a_i = torch.argmax(probs).item()
                actions.append(a_i)

            # Convert to (i,j) moves
            chosen_moves = []
            num_nodes = data.x.size(0)
            board_n = int(math.sqrt(num_nodes))
            for i, a_i in enumerate(actions):
                (ci, cj) = cop_positions[i]
                v_idx = ci * board_n + cj
                mask_src = (data.edge_index[0] == v_idx)
                nbr_idxs = data.edge_index[1, mask_src].unique()

                if a_i == 0:
                    new_idx = v_idx
                else:
                    new_idx = int(nbr_idxs[a_i - 1])
                new_i = new_idx // board_n
                new_j = new_idx % board_n
                chosen_moves.append((new_i, new_j))

        # Step environment
        data, r, terminated, truncated, info = env.step(chosen_moves)
        total_r += r
        cop_positions = env.cop_pos
        done = terminated or truncated

    print("Eval ep reward:", total_r)


TypeError: can't assign a sage.rings.integer.Integer to a torch.FloatTensor

BELOW IS PREVIOUS CODE FOR NON-GNN RL MODEL

In [10]:
env = gym.make("gymnasium_env/CopsAndRobbers-v0", graph=G, k=3, render_mode="human")
observation, info = env.reset()

episode_over = False
total_reward = 0.0
c_states = list()
r_states = list()

while not episode_over:
    action = env.action_space.sample() # agent policy that uses the observation and info
    observation, reward, terminated, truncated, info = env.step(action)
    total_reward += reward
    episode_over = terminated or truncated

    cop_state, robber_state = env.unwrapped.get_moves()
    #print(c_states)
    c_states.append(copy.deepcopy(cop_state))
    r_states.append(copy.deepcopy(robber_state))

print(total_reward)
print(r_states)
print(c_states)

4389.56929242651
[(2, 0), (1, 1), (1, 1), (2, 1), (3, 0), (0, 0), (0, 4), (4, 0), (4, 5), (0, 1), (0, 1), (0, 2), (0, 2), (3, 5), (5, 3), (5, 0), (4, 0), (5, 0), (5, 3), (3, 1), (0, 1), (3, 1), (3, 4), (4, 3), (4, 3), (4, 3)]
[[(1, 2), (5, 1), (1, 5)], [(4, 5), (4, 2), (4, 5)], [(4, 2), (5, 3), (3, 4)], [(4, 2), (4, 4), (3, 5)], [(4, 5), (1, 4), (2, 4)], [(4, 1), (5, 4), (3, 5)], [(3, 0), (3, 2), (4, 5)], [(3, 2), (3, 3), (3, 4)], [(1, 0), (3, 2), (2, 4)], [(1, 3), (4, 3), (2, 5)], [(5, 3), (4, 3), (2, 5)], [(2, 3), (3, 4), (2, 5)], [(2, 3), (3, 0), (3, 4)], [(5, 0), (5, 0), (0, 4)], [(4, 0), (0, 5), (4, 0)], [(3, 1), (0, 1), (4, 4)], [(2, 1), (0, 3), (5, 4)], [(0, 3), (4, 3), (3, 4)], [(1, 4), (1, 0), (3, 4)], [(2, 3), (1, 0), (5, 4)], [(1, 4), (3, 2), (5, 0)], [(2, 4), (2, 3), (5, 5)], [(5, 1), (4, 1), (5, 3)], [(5, 1), (3, 0), (2, 0)], [(5, 5), (3, 0), (5, 3)], [(5, 5), (3, 0), (5, 3)]]


  logger.warn(
  logger.warn(


In [11]:
display_game(G, c_states, r_states)

IntSlider(value=0, max=25)

Output()

In [12]:
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
import gymnasium as gym
import time
from gymnasium.spaces import Tuple, MultiDiscrete

#import os
#os.environ["CUDA_VISIBLE_DEVICES"] = ""   # force CPU

class TupleToMultiDiscreteWrapper(gym.ActionWrapper):
    def __init__(self, env):
        super().__init__(env)
        assert isinstance(env.action_space, Tuple)
        self.original_space = env.action_space
        self.action_space = MultiDiscrete([space.n for space in env.action_space])

    def action(self, action):
        return tuple(action)

    def reverse_action(self, action):
        return list(action)

def make_env(G):
    def _init():
        env = gym.make("gymnasium_env/CopsAndRobbers-v0", graph=G, k=3)
        env = TupleToMultiDiscreteWrapper(env)
        return env
    return _init

In [25]:
G = make_graph(16, queen, False)

# Create vectorized environment with, say, 8 parallel copies
env = make_vec_env(make_env(G), n_envs=8)
#env = gym.make("gymnasium_env/CopsAndRobbers-v0", graph=G, k=3, render_mode="human")
#env = TupleToMultiDiscreteWrapper(env)

# Train PPO w/checkpointing
model = PPO("MultiInputPolicy", env, device='cuda', verbose=1)
for i in range(10):
    model.learn(total_timesteps=500_000, reset_num_timesteps=False)
    model.save(f"ppo_cr_16x16_itr{i+1}")

#model.learn(total_timesteps=1_000)
#model.save("ppo_cr_16.zip")

Using cuda device
----------------------------------
| rollout/           |           |
|    ep_len_mean     | 258       |
|    ep_rew_mean     | -2.04e+04 |
| time/              |           |
|    fps             | 21        |
|    iterations      | 1         |
|    time_elapsed    | 776       |
|    total_timesteps | 16384     |
----------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 258          |
|    ep_rew_mean          | -2.04e+04    |
| time/                   |              |
|    fps                  | 20           |
|    iterations           | 2            |
|    time_elapsed         | 1563         |
|    total_timesteps      | 32768        |
| train/                  |              |
|    approx_kl            | 9.934802e-05 |
|    clip_fraction        | 0            |
|    clip_range           | 0.2          |
|    entropy_loss         | -12.3        |
|    explained_variance   | -3

KeyboardInterrupt: 

In [29]:
env = make_vec_env(make_env(G), n_envs=8)
model.learn(total_timesteps=2_000_000, reset_num_timesteps=False)

model.save("ppo_cr16.zip")

---------------------------------
| rollout/           |          |
|    ep_len_mean     | 258      |
|    ep_rew_mean     | -2e+04   |
| time/              |          |
|    fps             | 22       |
|    iterations      | 1        |
|    time_elapsed    | 744      |
|    total_timesteps | 1654784  |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 258          |
|    ep_rew_mean          | -1.99e+04    |
| time/                   |              |
|    fps                  | 21           |
|    iterations           | 2            |
|    time_elapsed         | 1515         |
|    total_timesteps      | 1671168      |
| train/                  |              |
|    approx_kl            | 0.0064652245 |
|    clip_fraction        | 0.0462       |
|    clip_range           | 0.2          |
|    entropy_loss         | -11.3        |
|    explained_variance   | 0            |
|    learning_r

KeyboardInterrupt: 

In [30]:
model.save("ppo_cr16.zip")

In [18]:
%pip install stable-baselines3[extra]

Note: you may need to restart the kernel to use updated packages.


In [23]:
from stable_baselines3.common.evaluation import evaluate_policy
#model = PPO.load('/mnt/c/Users/danie/reu2025/reu2025/ppo_cr_v1', device='cpu')
env = gym.make("gymnasium_env/CopsAndRobbers-v0", graph=G, k=3)
env = TupleToMultiDiscreteWrapper(env)
obs, _ = env.reset()
done = False

c_states = list()
r_states = list()

G = make_graph(16, queen, False)

#env = make_vec_env(make_env(G), n_envs=8)
obs, _ = env.reset()
#cop_state, robber_state = env.envs[0].unwrapped.get_moves()
cop_state, robber_state = env.unwrapped.get_moves()
c_states.append(copy.deepcopy(cop_state))
r_states.append(copy.deepcopy(robber_state))

done = False
total_reward = 0
itr = 0

while not done:
    itr += 1
    action, _ = model.predict(obs, deterministic=True)
    '''
    obs, rewards, dones, infos = env.step(action)
    total_reward += rewards[0]      # since vectorized you get a 1‑element batch
    done = dones[0]
    print(dones)
    '''
    obs, rewards, terminated, truncated, info = env.step(action)
    total_reward += rewards
    done = terminated or truncated

    if done: #lmao
        break

    cop_state, robber_state = env.unwrapped.get_moves()
    c_states.append(copy.deepcopy(cop_state))
    r_states.append(r_states[-1])

    c_states.append(c_states[-1])
    r_states.append(copy.deepcopy(robber_state))

print(f"total reward: {total_reward}, number of rounds: {itr}")

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=1)
print(f"Mean reward: {mean_reward} ± {std_reward}")

total reward: -20927.9166402739, number of rounds: 258
Mean reward: -20711.09934387207 ± 180.3704003720745


In [24]:
display_game(G, c_states, r_states)

IntSlider(value=0, max=514)

Output()