In [None]:
import json
import heapq
import numpy as np

In [None]:
class Station:
    """A subway stop that doubles as a node in the path search.

    Besides its name and coordinates, a station carries the bookkeeping
    values written during a search (g, h, f, parent, prev_line) and a
    mapping from each directly connected station to the line joining them.
    """

    def __init__(self, name, x, y):
        # Identity and position of the stop
        self.name, self.x, self.y = name, x, y

        # g: path cost (infinite until the station is reached)
        # h: heuristic cost
        # f: total cost
        self.g, self.h, self.f = np.inf, 0, 0

        # Station this one was reached from
        self.parent = None
        # Neighbouring stations, each mapped to the connecting line
        self.connections = {}
        # Line that was taken to arrive at this station
        self.prev_line = None

    def add_connection(self, station, line):
        """Register `station` as a neighbour reachable via `line`.

        An existing entry for the same station is overwritten.
        """
        self.connections.update({station: line})

    def __lt__(self, other):
        """Order stations by total cost so they can be stored in a heap."""
        return self.f < other.f

    def __str__(self):
        """Return the station's name."""
        return self.name
    
class Line:
    """A subway line described by its name, travel time and wait time."""

    def __init__(self, name, travel_time, wait_time):
        # travel_time: time to ride between stops on this line
        # wait_time: time spent waiting to board this line
        self.name, self.travel_time, self.wait_time = name, travel_time, wait_time

    def __str__(self):
        """Return the line's name."""
        return self.name

In [None]:
def build_subway_graph():
    """
    Load 'subway.json' from the working directory and build the subway graph.

    The JSON document is read as follows:
      - "locations": maps each stop name to a pair [x, y]
      - "lines": maps each line name to an object with "travel_time",
        "wait_time" and "stops" (the ordered list of stop names on the line)

    Consecutive stops on a line are connected in both directions.

    :return: tuple (stations, lines) of dicts mapping names to Station and
             Line objects respectively
    """
    # Use a context manager so the file handle is closed even if parsing fails
    with open('subway.json', 'r') as subway_file:
        subway = json.load(subway_file)

    # Turn stations into objects and store them by their names
    stations = {
        stop_name: Station(name=stop_name, x=stop_location[0], y=stop_location[1])
        for stop_name, stop_location in subway["locations"].items()
    }

    # Turn lines into objects and store them by their names
    lines = {
        line_name: Line(
            name=line_name,
            travel_time=line_attr["travel_time"],
            wait_time=line_attr["wait_time"],
        )
        for line_name, line_attr in subway["lines"].items()
    }

    # Add two-way connections between each pair of adjacent stops on a line
    for line_name, line_attr in subway["lines"].items():
        line = lines[line_name]
        stop_names = line_attr["stops"]
        for here, there in zip(stop_names, stop_names[1:]):
            stations[here].add_connection(station=stations[there], line=line)
            stations[there].add_connection(station=stations[here], line=line)

    return stations, lines

In [None]:
def get_heuristic(a, b):
    """
    Straight-line distance between two nodes, computed from their x and y
    attributes.

    :param a: node object representing point 1
    :param b: node object representing point 2

    :return: Euclidean distance between two points
    """
    delta_x = a.x - b.x
    delta_y = a.y - b.y
    return np.sqrt(delta_x**2 + delta_y**2)

In [None]:
def astar(start, end):
    """
    Find the lowest-cost route between two stations with A* search.

    The cost of a hop is the connecting line's travel time, plus that line's
    wait time when the station being left was not reached on the same line
    (this includes the very first hop from the start).

    Search state (g, h, f, parent, prev_line) is written onto the station
    objects. It is re-initialised for every station this call touches, so
    the same graph can be searched more than once. Each station keeps a
    single best arrival (one parent and one prev_line).

    :param start: node object representing start point
    :param end: node object representing end point

    :return: list of stations from start to end (both included), or None if
             the end cannot be reached
    """
    open_heap = []      # Frontier entries: (f, push order, station)
    closed = set()      # Stations that have already been explored
    touched = {start}   # Stations whose search state was reset by this call
    pushes = 0          # Tie-breaker so the heap never has to compare stations

    # Starting station has 0 distance to itself and no predecessor
    start.g = 0
    start.h = get_heuristic(start, end)
    start.f = start.g + start.h
    start.parent = None
    start.prev_line = None
    heapq.heappush(open_heap, (start.f, pushes, start))

    # Loop until you reach the end
    while open_heap:
        # Explore the next frontier station (lowest f value)
        _, _, curr = heapq.heappop(open_heap)

        # A station is pushed again whenever its cost improves; entries left
        # over from before the improvement are skipped here
        if curr in closed:
            continue
        closed.add(curr)

        # If the end is found, work backwards to build the path
        if curr is end:
            path = []
            while curr is not None:
                path.append(curr)
                curr = curr.parent
            return path[::-1]

        # Else, update all neighbours
        for neighbour, line in curr.connections.items():
            # Don't explore any stations that are already closed
            if neighbour in closed:
                continue

            # Discard values left on the station by an earlier search
            if neighbour not in touched:
                touched.add(neighbour)
                neighbour.g = float("inf")
                neighbour.parent = None
                neighbour.prev_line = None

            # Calculate the time to travel to the next station + wait time for switching lines
            travel_cost = line.travel_time
            if curr.prev_line is None or curr.prev_line.name != line.name:
                travel_cost += line.wait_time

            # If a shorter path to the next station is found, update its values
            new_distance = curr.g + travel_cost
            if new_distance < neighbour.g:
                neighbour.parent = curr
                neighbour.g = new_distance
                neighbour.h = get_heuristic(neighbour, end)
                neighbour.f = neighbour.g + neighbour.h
                neighbour.prev_line = line

                # Push a fresh entry instead of changing one inside the heap,
                # which would break the heap ordering
                pushes += 1
                heapq.heappush(open_heap, (neighbour.f, pushes, neighbour))

    # Path not found
    return None

# Build the network, then search for the best route from "stop 1" to "stop 3"
stations, lines = build_subway_graph()
path = astar(start=stations["stop 1"], end=stations["stop 3"])

# Line 2 has a long wait time in this example, so the optimal route goes all the way around
[stop.name for stop in path]