In [10]:
from collections import namedtuple
import heapq
from typing import Optional

# Read the puzzle input as a list of strings, one per line of the file,
# with the line terminators stripped. The context manager closes the file
# handle as soon as the read finishes instead of leaving it open until the
# garbage collector gets to it.
with open('./data.txt', 'r') as data_file:
    lines = data_file.read().splitlines()

# A grid coordinate with named fields x and y.
Point = namedtuple('Point', ['x', 'y'])


def _step_y_minus(p):
    """Return the point with y decreased by one."""
    return Point(p.x, p.y - 1)


def _step_x_minus(p):
    """Return the point with x decreased by one."""
    return Point(p.x - 1, p.y)


def _step_y_plus(p):
    """Return the point with y increased by one."""
    return Point(p.x, p.y + 1)


def _step_x_plus(p):
    """Return the point with x increased by one."""
    return Point(p.x + 1, p.y)


# The four orthogonal single-step moves; each maps a Point to an adjacent Point.
moves = [_step_y_minus, _step_x_minus, _step_y_plus, _step_x_plus]

class Node:
    """One cell of the height grid.

    Attributes:
        point: the cell's coordinate (an object with ``x`` and ``y``).
        height: the cell's elevation as an integer.
        neighbors: nodes that can be stepped onto from this one; only set
            once calc_moves() has been called.
        distance: Manhattan distance to the end node; only set once
            set_distance() has been called.

    Nodes compare, order and hash by their ``point``. This gives heap entries
    of the form (priority, node) a consistent tie-break when priorities are
    equal, and lets nodes be used in sets or as dict keys.
    """

    def __init__(self, point: 'Point', height: int):
        self.point = point
        self.height = height

    def __str__(self):
        return f'Node {self.point} h: {self.height}'

    def __eq__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.point == other.point

    def __hash__(self):
        # Must agree with __eq__: equal nodes share a point, hence a hash.
        return hash(self.point)

    def __lt__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.point < other.point

    def __le__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.point <= other.point

    def __gt__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.point > other.point

    def __ge__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.point >= other.point

    def calc_moves(self, node_map):
        """Compute and store ``self.neighbors`` from the grid ``node_map``.

        Each entry of the module-level ``moves`` list is applied to this
        node's point and the result is looked up with get_node().
        """
        candidates = [get_node(node_map, move(self.point)) for move in moves]
        # get_node() yields None for off-grid points, so drop those first.
        # A step may climb at most one unit of height; any drop is allowed.
        self.neighbors = [
            node for node in candidates
            if node is not None and node.height <= self.height + 1
        ]

    def set_distance(self, end):
        """Store the Manhattan distance from this node to ``end`` in ``self.distance``."""
        self.distance = abs(self.point.x - end.point.x) + abs(self.point.y - end.point.y)

class PriorityQueue:
    """Min-priority queue backed by a binary heap (``heapq``).

    Entries are stored in ``elements`` as (priority, sequence, item). The
    sequence number grows with every put(), so two entries never tie on
    their first two fields: the items themselves are never compared, and
    items with equal priority are returned in insertion order.
    """

    def __init__(self):
        self.elements: list[tuple[float, int, 'Node']] = []
        # Insertion counter used purely as a tie-breaker between equal priorities.
        self._sequence = 0

    def empty(self) -> bool:
        """Return True when no items are queued."""
        return not self.elements

    def put(self, item: 'Node', priority: float):
        """Queue ``item``; lower ``priority`` values are returned first."""
        heapq.heappush(self.elements, (priority, self._sequence, item))
        self._sequence += 1

    def get(self) -> 'Node':
        """Remove and return the item with the lowest priority.

        Raises:
            IndexError: if the queue is empty (propagated from heapq.heappop).
        """
        return heapq.heappop(self.elements)[-1]

# Type alias for the grid: a list of rows, each row being a list of Node objects.
NodeMap = list[list[Node]]

# The grid itself; starts empty.
node_map: NodeMap = list()
# A single list of nodes; starts empty.
node_line: list[Node] = list()

def get_node(node_map: 'NodeMap', point: 'Point'):
    """Return the grid entry at ``point``, or None when the point is off the grid.

    Args:
        node_map: the grid, indexed as node_map[y][x].
        point: coordinate to look up; must expose ``x`` and ``y``.

    Returns:
        node_map[point.y][point.x], or None if either coordinate is negative
        or lies past the end of the grid.
    """
    # Validate the row index first so an empty grid is never indexed.
    if not 0 <= point.y < len(node_map):
        return None
    row = node_map[point.y]
    # Bound x by this row's own length rather than row 0's, so a short row
    # yields None instead of raising IndexError.
    if not 0 <= point.x < len(row):
        return None
    return row[point.x]

# Placeholders; `initial` and `dest` are assigned while the grid is parsed below.
cur_loc: Node = None
initial: Node = None
dest: Node = None

# Heights are stored as offsets from 'a', so 'a' -> 0 and 'z' -> 25.
base = ord('a')
for y, line in enumerate(lines):
    for x, spot in enumerate(line):
        # 'S' (start) has the elevation of 'a'; 'E' (end) has the elevation of 'z'.
        if spot == 'S':
            elevation = 'a'
        elif spot == 'E':
            elevation = 'z'
        else:
            elevation = spot
        node = Node(Point(x, y), ord(elevation) - base)
        node_line.append(node)
        if spot == 'S':
            initial = node
        elif spot == 'E':
            dest = node
    # The finished row goes into the grid and a fresh row list is started.
    node_map.append(node_line)
    node_line = []

# With the whole grid built, give every node its neighbor list and its
# distance to the end node.
for node_row in node_map:
    for node in node_row:
        node.calc_moves(node_map)
        node.set_distance(dest)

print(f'initial {initial}')
print(f'dest {dest}')

# Search bookkeeping types: the predecessor of each reached point (None for
# the start), and the cheapest known cost to reach each point.
CameFrom = dict[Point, Optional[Point]]
CostSoFar = dict[Point, float]

def do_astar(initial: Node):
    """Run an A* search from ``initial`` toward the module-level ``dest`` node.

    Every step costs 1, and each node's precomputed ``distance`` attribute is
    added to the cost so far to form its queue priority. The search stops as
    soon as ``dest`` is taken off the frontier, or when the frontier is empty.

    Args:
        initial: the node to start searching from.

    Returns:
        A (came_from, cost_so_far) tuple. ``came_from`` maps each reached
        point to the point it was reached from (None for the start point);
        ``cost_so_far`` maps each reached point to the cheapest step count
        found for it.
    """
    came_from: CameFrom = {initial.point: None}
    cost_so_far: CostSoFar = {initial.point: 0}

    frontier = PriorityQueue()
    frontier.put(initial, 0)

    while not frontier.empty():
        current = frontier.get()
        if current is dest:
            break

        # All edges have unit cost, so every neighbor shares one tentative cost.
        step_cost = cost_so_far[current.point] + 1
        for neighbor in current.neighbors:
            known_cost = cost_so_far.get(neighbor.point)
            # Skip neighbors already reached at least as cheaply.
            if known_cost is not None and known_cost <= step_cost:
                continue
            cost_so_far[neighbor.point] = step_cost
            came_from[neighbor.point] = current.point
            frontier.put(neighbor, step_cost + neighbor.distance)

    return came_from, cost_so_far

def reconstruct_path(came_from: CameFrom, dest: Node):
    """Rebuild the route to ``dest`` from a predecessor map.

    Args:
        came_from: maps each reached point to the point it was reached from,
            with None as the predecessor of the starting point.
        dest: the node the route should end at.

    Returns:
        The nodes along the route in travel order, beginning with the start
        node and ending with the node just before ``dest``; ``dest`` itself is
        not included, so the length equals the number of steps taken. An empty
        list is returned when ``dest.point`` is not in ``came_from`` (and also
        when ``dest`` has no predecessor).
    """
    if dest.point not in came_from:
        return []
    # Walk the predecessor chain backwards, appending as we go, and reverse
    # once at the end. Inserting at index 0 would shift the whole list on
    # every step.
    path = []
    previous = came_from[dest.point]
    while previous is not None:
        path.append(get_node(node_map, previous))
        previous = came_from[previous]
    path.reverse()
    return path


initial Node Point(x=0, y=20) h: 0
dest Node Point(x=135, y=20) h: 25


In [7]:
# Search from the start node, then rebuild the route to dest.
came_from = do_astar(initial)[0]
path = reconstruct_path(came_from, dest)

# The path excludes dest itself, so its length is the number of steps taken.
len(path)

447

In [12]:
# Every cell at height 0 is a candidate starting point.
start_nodes = [
    candidate
    for node_row in node_map
    for candidate in node_row
    if candidate.height == 0
]

# Run one search per candidate and record how long each resulting path is.
path_lengths = []
for node in start_nodes:
    came_from = do_astar(node)[0]
    path = reconstruct_path(came_from, dest)
    path_lengths.append(len(path))

# reconstruct_path returns [] when dest was not reached, so zero lengths are skipped.
min(length for length in path_lengths if length > 0)

446