In [None]:
import heapq
import numpy as np

import networkx as nx
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors

In [None]:
class Intersection:
    """A grid intersection that doubles as a node for path searching.

    Alongside its identity (``name``) and coordinates (``x``, ``y``) every
    instance carries the bookkeeping fields used by a best-first search and a
    mapping of its outgoing connections.
    """

    def __init__(self, name, x, y):
        """
        :param name: identifier of the intersection
        :param x: first coordinate of the intersection
        :param y: second coordinate of the intersection
        """
        # Identity and position
        self.name = name
        self.x, self.y = x, y

        # Search bookkeeping
        self.g = np.inf         # Path cost (infinite until a route is known)
        self.h = self.f = 0     # Heuristic cost / total cost
        self.parent = None      # Node this one was reached from

        # Outgoing links: neighbouring intersection -> cost
        self.connections = {}

    def add_connection(self, intersection, cost):
        """Record a connection from this intersection to ``intersection``.

        Only this object's mapping is touched, so the link is one-way; an
        existing entry for the same neighbour is overwritten.

        :param intersection: neighbouring intersection object
        :param cost: cost of travelling to that neighbour
        """
        self.connections[intersection] = cost

    def __lt__(self, other):
        """Order intersections by total cost ``f`` (lets heapq compare them)."""
        return self.f < other.f

In [None]:
def generate_random_intersections(height, width, connection_chance, max_cost=50, seed=0):
    """
    Build a ``height`` x ``width`` grid of randomly connected intersections.

    Intersections are named 0, 1, 2, ... in row-major order and stored under
    their ``(row, column)`` key. For every intersection, each in-grid
    neighbour is connected with probability ``connection_chance``; the draws
    are repeated until the intersection has at least one outgoing connection.
    Connections are one-way: linking (i, j) to a neighbour does not link the
    neighbour back.

    Costs towards a higher row/column index are drawn from
    ``[1, max_cost / 10)``, costs towards a lower index from ``[1, max_cost)``.
    NOTE(review): this asymmetry is kept as found - confirm it is intended.

    Note that this reseeds NumPy's global random state with ``seed``.

    :param height: number of grid rows
    :param width: number of grid columns
    :param connection_chance: probability of creating each candidate connection
    :param max_cost: exclusive upper bound for the larger cost range
    :param seed: seed for NumPy's global random number generator

    :return: dict mapping (row, column) to its Intersection object
    :raises ValueError: if the grid has more than one intersection and either
        ``connection_chance`` is not positive (the retry loop could never
        finish) or ``max_cost`` is below 20 (the reduced cost range is empty)
    """
    # randint wants integer bounds; max_cost / 10 is a float under true
    # division, so truncate explicitly instead of relying on NumPy to do it.
    full_high = int(max_cost)
    reduced_high = int(max_cost / 10)

    # Any grid with more than one cell forces at least one connection whose
    # cost comes from the reduced range, so validate before doing any work.
    if max(height, 0) * max(width, 0) > 1:
        if connection_chance <= 0:
            raise ValueError("connection_chance must be > 0, otherwise no connection can ever be made")
        if reduced_high < 2:
            raise ValueError("max_cost must be at least 20 so that costs can be drawn from [1, max_cost / 10)")

    # Create the intersection objects
    np.random.seed(seed)
    intersections = {}
    name = 0
    for i in range(height):
        for j in range(width):
            intersections[(i, j)] = Intersection(name=name, x=i, y=j)
            name += 1

    # (row offset, column offset, exclusive cost bound), in the order the
    # random draws are made: up, down, left, right.
    directions = (
        (-1, 0, full_high),
        (1, 0, reduced_high),
        (0, -1, full_high),
        (0, 1, reduced_high),
    )

    # Add connections between intersections
    for i in range(height):
        for j in range(width):
            node = intersections[(i, j)]
            candidates = [
                (intersections[(i + di, j + dj)], high)
                for di, dj, high in directions
                if (i + di, j + dj) in intersections
            ]

            # A cell without neighbours (1x1 grid) has nothing to connect to;
            # retrying would loop forever, so it is simply left unconnected.
            has_connection = not candidates
            while not has_connection:
                for neighbour, high in candidates:
                    if np.random.rand() < connection_chance:
                        node.add_connection(neighbour, cost=np.random.randint(1, high))
                        has_connection = True

    return intersections

In [None]:
def get_heuristic(a, b):
    """
    Estimate the distance between two nodes from their coordinates.

    :param a: node object representing point 1 (needs ``x`` and ``y``)
    :param b: node object representing point 2 (needs ``x`` and ``y``)

    :return: Manhattan distance between two points
    """
    dx = abs(a.x - b.x)
    dy = abs(a.y - b.y)
    return dx + dy

In [None]:
def astar(start, end):
    """
    Find the cheapest route from ``start`` to ``end`` with A* search.

    All bookkeeping that drives the search (best known cost, predecessor) is
    held in local dictionaries, so the result does not depend on values left
    on the node objects by an earlier search over the same graph. The
    ``g``/``h``/``f``/``parent`` attributes of reached nodes are still
    updated for code that inspects them afterwards; nodes not reached by this
    search keep whatever values they had.

    A node is expanded at most once, so the returned path is only guaranteed
    to be optimal if ``get_heuristic`` is consistent for the edge costs.

    :param start: node object representing start point
    :param end: node object representing end point

    :return: list of node names from start to end (the optimal path found),
        or None if ``end`` cannot be reached from ``start``
    """
    infinity = float("inf")

    best_cost = {start: 0}      # Cheapest known path cost per node
    came_from = {start: None}   # Predecessor on that cheapest path
    closed = set()              # Nodes that have already been explored

    # Heap entries are (f, push order, node). A better route to a queued node
    # is pushed as a new entry rather than edited in place (which would break
    # the heap invariant); the outdated entry is skipped when it surfaces.
    # The push order breaks ties so node objects are never compared.
    pushes = 0
    frontier = [(get_heuristic(start, end), pushes, start)]

    # Starting node has 0 distance to itself and no predecessor
    start.g = 0
    start.parent = None

    while frontier:
        # Explore the next frontier node (lowest f value)
        _, _, curr = heapq.heappop(frontier)
        if curr in closed:
            continue            # Outdated entry of an already explored node
        closed.add(curr)

        # If the end is found, work backwards to build the path
        if curr == end:
            path = []
            node = curr
            while node is not None:
                path.append(node.name)
                node = came_from[node]
            return path[::-1]

        # Else, update all neighbours
        for neighbour, step_cost in curr.connections.items():
            # Don't explore any nodes that are already closed
            if neighbour in closed:
                continue

            # If a shorter path to the neighbour is found, record it
            new_cost = best_cost[curr] + step_cost
            if new_cost < best_cost.get(neighbour, infinity):
                best_cost[neighbour] = new_cost
                came_from[neighbour] = curr

                estimate = get_heuristic(neighbour, end)
                neighbour.parent = curr
                neighbour.g = new_cost
                neighbour.h = estimate
                neighbour.f = new_cost + estimate

                pushes += 1
                heapq.heappush(frontier, (neighbour.f, pushes, neighbour))

    # Path not found
    return None


In [None]:
def get_nodes_and_edges(intersections):
    """Flatten the intersection mapping into plain node and edge collections.

    :param intersections: mapping of location -> intersection object; every
        intersection has a ``name`` and a ``connections`` mapping of
        neighbouring intersection -> weight

    :return: nodes, edges of graph - ``nodes`` maps each intersection name to
        its location key, ``edges`` is a list of
        ``(source name, target name, {'weight': weight})`` tuples
    """
    # Intersections become nodes, keyed by name and valued by location
    nodes = {junction.name: position for position, junction in intersections.items()}

    # Every outgoing connection becomes one weighted edge
    edges = [
        (junction.name, neighbour.name, {'weight': weight})
        for junction in intersections.values()
        for neighbour, weight in junction.connections.items()
    ]

    return nodes, edges

def get_graph(intersections, plot=True, directed=False):
    """Build a networkx graph from the intersections and draw it.

    A new matplotlib figure (figsize 9 x 9) is opened and the graph is drawn
    on it regardless of ``plot``; the flag only controls whether
    ``plt.show()`` is called afterwards.

    :param intersections: mapping of location -> intersection object
    :param plot: boolean value to show plot or not
    :param directed: build an ``nx.DiGraph`` when true, otherwise an ``nx.Graph``

    :return: nx graph object
    """
    plt.figure(figsize=(9, 9))

    # Node positions (name -> location) and weighted edges
    positions, weighted_edges = get_nodes_and_edges(intersections)

    graph_type = nx.DiGraph if directed else nx.Graph
    G = graph_type()
    G.add_nodes_from([intersection.name for intersection in intersections.values()])
    G.add_edges_from(weighted_edges)

    # Nodes and labels first, then the edges once more in a lighter stroke,
    # and finally the weight written on every edge
    nx.draw(G, positions, with_labels=True, node_color='lightblue', node_size=500, font_weight='bold')
    nx.draw_networkx_edges(G, positions, width=1.0, alpha=0.5)
    weight_labels = {(u, v): d['weight'] for u, v, d in G.edges(data=True)}
    nx.draw_networkx_edge_labels(G, positions, edge_labels=weight_labels)

    if plot:
        plt.show()

    return G

def draw_graph_path(G, intersections, shortest_path):
    """
    Draw the graph on the current matplotlib figure and highlight a path.

    Nodes on the path are coloured on a gradient from 'greenyellow' (first
    node) to 'salmon' (last node) and outlined in black; all other nodes are
    light gray. Edges between consecutive path nodes are drawn in red, all
    other edges are drawn with colour 'none'.

    :param G: nx graph object whose nodes are intersection names
    :param intersections: mapping of location -> intersection object, used to
        position the nodes
    :param shortest_path: list of numbers representing the node indices of
        path; None or an empty list draws the graph without any highlight
    """
    path = [] if shortest_path is None else list(shortest_path)

    # Visualize the graph with node positions and edge labels
    nodes, _ = get_nodes_and_edges(intersections)

    # Convert the colors to their RGBA values
    start_rgba = np.array(mcolors.to_rgba('greenyellow'))
    target_rgba = np.array(mcolors.to_rgba('salmon'))

    # First position of every node along the path (one pass instead of a
    # list.index scan per graph node)
    position_in_path = {}
    for index, node in enumerate(path):
        position_in_path.setdefault(node, index)

    # A single-node path has no span to interpolate over; dividing by 1
    # keeps that node at the start colour instead of dividing by zero.
    last_index = max(len(path) - 1, 1)

    # Compute the color for each node based on its position between start and target in the shortest path
    node_colors = [
        mcolors.to_hex(start_rgba + (position_in_path[node] / last_index) * (target_rgba - start_rgba))
        if node in position_in_path else 'lightgray'
        for node in G.nodes()
    ]

    # Undirected graphs may report an edge as (v, u) even though the path
    # walks it as (u, v), so accept both orientations there.
    path_edges = set(zip(path, path[1:]))
    if not G.is_directed():
        path_edges |= {(v, u) for u, v in path_edges}

    edge_colors = ['red' if edge in path_edges else 'none' for edge in G.edges()]

    # draw graph
    nx.draw(
        G,
        nodes,
        with_labels=True,
        node_color=node_colors,
        node_size=500,
        font_weight='bold',
        edge_color=edge_colors,
        edgecolors=['black' if node in position_in_path else 'none' for node in G.nodes()],
        linewidths=2.0,
        width=2.0,
        arrows=True
    )

    legend_labels = {'Start Node': 'greenyellow', 'Target Node': 'salmon'}
    legend_handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, markeredgewidth=2, markeredgecolor='black', label=label)
                      for label, color in legend_labels.items()]

    plt.legend(handles=legend_handles)
    plt.show()

In [None]:
# Seed NumPy's global generator. Note that generate_random_intersections
# reseeds it with its own `seed` argument, so the grid below is determined
# by seed=10 rather than by this value.
np.random.seed(20827383)

# 8 x 10 grid of randomly connected intersections
intersections = generate_random_intersections(
    height=8,
    width=10,
    connection_chance=0.7,
    max_cost=50,
    seed=10,
)

In [None]:
# Run A* from grid position (0, 8) to grid position (6, 0)
start, stop = intersections[(0, 8)], intersections[(6, 0)]

path_astar = astar(start, stop)
print(path_astar)

In [None]:
# Compute shortest path using Dijkstra's algorithm
G = get_graph(intersections, plot=False, directed=True)

# Use the very endpoints A* was given instead of hard-coded node ids, so the
# comparison stays valid if the grid size or the endpoints change.
path_cost_dijkstra, path_dijkstra = nx.single_source_dijkstra(G, source=start.name, target=stop.name)

# Cost of the A* route: sum of the edge weights along consecutive node pairs
path_cost_astar = sum(G[u][v]['weight'] for u, v in zip(path_astar[:-1], path_astar[1:]))

print(f"Shortest path using Dijkstra: {path_dijkstra}, cost: {path_cost_dijkstra}")
print(f"Shortest path using A*: {path_astar}, cost: {path_cost_astar}")

draw_graph_path(G, intersections, path_dijkstra)

In [None]:
# These node objects were already searched by an earlier cell. Clear the
# per-node search state (back to the values Intersection.__init__ assigns)
# so this search starts from a clean slate.
for intersection_node in intersections.values():
    intersection_node.g = np.inf
    intersection_node.h = 0
    intersection_node.f = 0
    intersection_node.parent = None

path_a_star = astar(intersections[(0, 0)], intersections[(5, 5)])
print("Shortest path using A*:", path_a_star)

G = get_graph(intersections, plot=False, directed=True)

# astar returns None when the target is unreachable; there is then no path to highlight
if path_a_star is not None:
    draw_graph_path(G, intersections, path_a_star)