In [12]:
import numpy as np
import math
import heapq
import ray
from gym import Env
from gym.spaces import Discrete, Box, MultiDiscrete
from ray import tune, rllib, air
from ray.rllib.algorithms.ppo import PPOConfig
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from ray.tune.registry import register_env
from ray.rllib.utils.pre_checks.env import check_env

In [None]:
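# Define the custom routing grid environment.
# observation_space: MultiDiscrete([length, width]) -- the agent's (x, y) grid position

# action_space: Discrete(4)
    # 0 -> move up    (y + 1)
    # 1 -> move down  (y - 1)
    # 2 -> move right (x + 1)
    # 3 -> move left  (x - 1)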


class RtGridEnv(Env):
    """Grid routing environment: an agent walks from (0, 0) to the opposite corner.

    The agent occupies integer grid points ``(x, y)`` with ``0 <= x < length`` and
    ``0 <= y < width``. A single rectangular obstacle (hard-coded in ``__init__``)
    blocks every grid point it covers, including the points on its boundary.
    Every step costs -1; reaching the goal yields +10 and ends the episode.

    Observation space: ``MultiDiscrete([length, width])`` -- the agent's (x, y).

    Action space: ``Discrete(4)``
        0 -> move up    (y + 1)
        1 -> move down  (y - 1)
        2 -> move right (x + 1)
        3 -> move left  (x - 1)

    Moves that would leave the grid or enter the obstacle are ignored: the agent
    stays where it is but still pays the step cost.
    """

    # Action id -> (dx, dy) displacement on the grid.
    _MOVES = {
        0: (0, 1),   # up
        1: (0, -1),  # down
        2: (1, 0),   # right
        3: (-1, 0),  # left
    }

    def __init__(self, length, width, nets, macros):
        """
        Args:
            length (int): length of the canvas (number of grid points along x)
            width (int): width of the canvas (number of grid points along y)
            nets (list): a list of nets to be routed
            macros (list): a list of macros that has been placed on the canvas by placement
        """
        self.length = length
        self.width = width
        self.nets = nets
        self.macros = macros
        # Backward-compatible alias for the original (misspelled) attribute name.
        self.marcos = macros

        self.agent_position = np.array([0, 0])
        # NOTE(review): with the hard-coded obstacle below, small grids (e.g. 3x3)
        # put this goal inside the obstacle, making it unreachable.
        self.goal_position = np.array([self.length - 1, self.width - 1])

        self.action_space = Discrete(4)
        self.observation_space = MultiDiscrete([self.length, self.width])

        # Positions visited so far, in order; drawn by render().
        self.path_x = [0]
        self.path_y = [0]

        # Define the position and size of the rectangle obstacle
        self.rect_x = 1
        self.rect_y = 0
        self.rect_length = 2
        self.rect_width = 2

        # Grid coordinates covered by the obstacle. The "+ 1" makes the ranges
        # inclusive of the far edge, so every grid point on the boundary of the
        # drawn rectangle is blocked.
        self.obst_x_range = np.arange(self.rect_x, self.rect_x + self.rect_length + 1)
        self.obst_y_range = np.arange(self.rect_y, self.rect_y + self.rect_width + 1)

    def prim_mst(self, pins, verbose=False):
        """
        Compute the Minimum Spanning Tree (MST) using Prim's algorithm.

        Args:
            pins (list): List of (x, y) coordinates representing the pin locations.
            verbose (bool): if True, print a trace of the priority queue and of
                every selected/skipped edge (debugging aid).

        Returns:
            dict: ``{'u': [...], 'v': [...]}`` -- parallel lists of pin indices;
            edge ``k`` of the MST connects ``pins[u[k]]`` and ``pins[v[k]]``, where
            ``u[k]`` is the endpoint that was already in the tree. Both lists are
            empty when fewer than two pins are given.
        """
        log = print if verbose else (lambda *args, **kwargs: None)

        num_pins = len(pins)
        mst_u = []
        mst_v = []
        if num_pins < 2:
            return {'u': mst_u, 'v': mst_v}

        # Symmetric table of Euclidean distances: dist[i][j] == dist[j][i].
        # The table must be symmetric because edges are looked up from whichever
        # vertex was just added, which may have a higher index than its neighbour.
        dist = [
            [math.sqrt((xj - xi) ** 2 + (yj - yi) ** 2) for (xj, yj) in pins]
            for (xi, yi) in pins
        ]

        visited = [False] * num_pins
        start_vertex = 0
        visited[start_vertex] = True

        # Min-heap of candidate edges (weight, u, v) leaving the tree.
        pq = []
        for i in range(num_pins):
            if i != start_vertex:
                heapq.heappush(pq, (dist[start_vertex][i], start_vertex, i))

        log("Initial priority queue:")
        for item in pq:
            log(item)
        log()

        # A spanning tree over n pins has exactly n - 1 edges.
        while pq and len(mst_u) < num_pins - 1:
            weight, u, v = heapq.heappop(pq)
            log("Selected edge:")
            log(weight, u, v)
            log()

            if visited[v]:
                # Stale entry: v was already connected through a cheaper edge.
                log(f"Skipping edge: {weight} - {u} - {v}")
                continue

            visited[v] = True
            log("u:", u)
            log("v:", v)
            mst_u.append(u)
            mst_v.append(v)

            for i in range(num_pins):
                if not visited[i]:
                    heapq.heappush(pq, (dist[v][i], v, i))

        return {'u': mst_u, 'v': mst_v}

    def _is_obstacle(self, x, y):
        """Return True if grid point (x, y) lies inside the obstacle (boundary included)."""
        return bool(np.any(self.obst_x_range == x) and np.any(self.obst_y_range == y))

    def step(self, action):
        """Apply one action and return ``(observation, reward, done, info)``.

        Args:
            action (int): 0 (up), 1 (down), 2 (right) or 3 (left). Any other value
                leaves the agent in place.

        Returns:
            tuple: ``(position, reward, done, info)``. ``position`` is a copy of the
            agent's ``[x, y]``, so later steps do not mutate observations the caller
            already holds. ``reward`` is 10 on reaching the goal and -1 otherwise.
            ``done`` is True once the goal is reached. ``info`` is an empty dict.
        """
        # Normalise Python ints, numpy scalars and 0-d arrays to a plain scalar key.
        move = self._MOVES.get(np.asarray(action).item())
        if move is not None:
            new_x = self.agent_position[0] + move[0]
            new_y = self.agent_position[1] + move[1]
            in_bounds = 0 <= new_x <= self.length - 1 and 0 <= new_y <= self.width - 1
            # Accept the move only if it stays on the grid and avoids the obstacle.
            if in_bounds and not self._is_obstacle(new_x, new_y):
                self.agent_position[0] = new_x
                self.agent_position[1] = new_y

        done = np.array_equal(self.agent_position, self.goal_position)
        reward = 10 if done else -1

        # Update the agent's path
        self.path_x.append(self.agent_position[0])
        self.path_y.append(self.agent_position[1])

        return self.agent_position.copy(), reward, done, {}

    def render(self):
        """
        Plot the obstacle (green) and the agent's path so far (red), then show it.
        """
        fig, ax = plt.subplots()

        obstacle = patches.Rectangle(
            (self.rect_x, self.rect_y), self.rect_length, self.rect_width,
            linewidth=1, edgecolor='g', facecolor='g',
        )
        ax.add_patch(obstacle)

        ax.plot(self.path_x, self.path_y, 'r')

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_title('Agent Path')

        # One tick per grid point.
        ax.set_xticks(range(0, self.length))
        ax.set_yticks(range(0, self.width))
        ax.grid(color='blue', linestyle='--', linewidth=0.5)

        plt.show()

    def reset(self):
        """Move the agent back to (0, 0), clear the path, and return a copy of the start position."""
        self.agent_position = np.array([0, 0])
        self.path_x = [0]
        self.path_y = [0]
        return self.agent_position.copy()