In [None]:
%matplotlib widget
import heapq
import os
from collections import namedtuple
from typing import Optional

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as anim
import numpy as np
# Location of the ffmpeg binary matplotlib uses to encode the animation.
# Set the FFMPEG_PATH environment variable to point at a different binary;
# without it the original machine-specific location is used unchanged.
plt.rcParams['animation.ffmpeg_path'] = os.environ.get('FFMPEG_PATH', '/Users/chris/Downloads/ffmpeg')
# c1 colours the terrain heights, c2 colours the per-cell search costs.
c1 = mpl.colormaps['viridis']
c2 = mpl.colormaps['cool']

# Load the height map, one string per grid row. The context manager closes
# the file handle as soon as the contents have been read.
with open('./data.txt', 'r') as data_file:
    lines = data_file.read().splitlines()

# Grid coordinate: x selects the column and y selects the row.
Point = namedtuple('Point', ['x', 'y'])


def _dec_y(p):
    """Return the point one row before ``p``."""
    return Point(p.x, p.y - 1)


def _dec_x(p):
    """Return the point one column before ``p``."""
    return Point(p.x - 1, p.y)


def _inc_y(p):
    """Return the point one row after ``p``."""
    return Point(p.x, p.y + 1)


def _inc_x(p):
    """Return the point one column after ``p``."""
    return Point(p.x + 1, p.y)


# The four axis-aligned steps, in the order neighbours are generated.
moves = [_dec_y, _dec_x, _inc_y, _inc_x]

class Node:
    """One cell of the height map.

    Attributes:
        point: grid coordinates of the cell (an object exposing ``x`` and ``y``).
        height: elevation of the cell.
        neighbors: nodes that can be stepped onto from this cell; empty until
            ``calc_moves`` has been called.
        distance: Manhattan distance to the goal node; 0 until
            ``set_distance`` has been called.
    """

    def __init__(self, point: 'Point', height: int):
        self.point = point
        self.height = height
        # Filled in by calc_moves / set_distance once the whole grid exists;
        # defined here so the attributes are always present.
        self.neighbors = []
        self.distance = 0

    def __str__(self):
        return f'Node {self.point} h: {self.height}'

    # Ordering exists only so that (priority, node) tuples can sit in a heap:
    # every node ranks as equivalent to every other, so ties on priority are
    # never reordered by the node itself. Equality and hashing are left at
    # the default identity semantics, which keeps `==`, `!=`, `in` and
    # set/dict membership meaningful.
    def __lt__(self, other): return False
    def __le__(self, other): return True
    def __gt__(self, other): return False
    def __ge__(self, other): return True

    def calc_moves(self, node_map):
        """Compute ``self.neighbors`` from the four adjacent grid cells.

        Args:
            node_map: the grid passed through to ``get_node`` for lookups.
        """
        candidates = [get_node(node_map, move(self.point)) for move in moves]
        # get_node yields None for positions outside the grid; drop those.
        in_bounds = [node for node in candidates if node is not None]
        # A step may go to any lower cell, the same height, or one higher.
        self.neighbors = [node for node in in_bounds if node.height <= self.height + 1]

    def set_distance(self, end):
        """Store the Manhattan distance from this node to ``end`` in ``self.distance``."""
        self.distance = abs(self.point.x - end.point.x) + abs(self.point.y - end.point.y)

class PriorityQueue:
    """Min-priority queue of nodes backed by a ``heapq`` list."""

    def __init__(self):
        # Heap of (priority, node) pairs; the smallest priority is at index 0.
        self.elements: list[tuple[float, Node]] = []

    def empty(self) -> bool:
        """Return True when no entries are queued."""
        return len(self.elements) == 0

    def put(self, item: Node, priority: float):
        """Queue ``item`` under ``priority``."""
        entry = (priority, item)
        heapq.heappush(self.elements, entry)

    def get(self) -> Node:
        """Remove the entry with the smallest priority and return its node."""
        _priority, item = heapq.heappop(self.elements)
        return item

# Type alias: the grid is a list of rows, each row being a list of Node objects.
NodeMap = list[list[Node]]
# The grid itself; the parsing loop below appends one row per input line.
node_map: NodeMap = []
# Scratch row that the parsing loop fills and then resets for each input line.
node_line: list[Node] = []

def get_node(node_map: 'NodeMap', point: 'Point') -> 'Optional[Node]':
    """Return the node stored at ``point``, or ``None`` if ``point`` is off the grid.

    Args:
        node_map: grid indexed as ``node_map[y][x]``.
        point: coordinates to look up (``x`` = column, ``y`` = row).

    Returns:
        The entry at ``node_map[point.y][point.x]``, or ``None`` when either
        coordinate is negative or beyond the grid. An empty grid has no
        valid positions, so every lookup on it returns ``None``.
    """
    if point.y < 0 or point.y >= len(node_map):
        return None
    # Bounds-check x against the row actually being read, so an empty grid
    # or a short row cannot raise IndexError.
    row = node_map[point.y]
    if point.x < 0 or point.x >= len(row):
        return None
    return row[point.x]

# Placeholders that stay None until the parsing loop below finds the markers:
# `initial` is the node parsed from 'S' and `dest` the node parsed from 'E'.
# `cur_loc` is never assigned again in the visible code.
cur_loc: Optional[Node] = None
initial: Optional[Node] = None
dest: Optional[Node] = None

# Convert the text grid into rows of Node objects. Heights are the letter's
# offset from 'a'; 'S' counts as 'a' and marks the start node, 'E' counts as
# 'z' and marks the destination node.
base = ord('a')
for y, line in enumerate(lines):
    for x, spot in enumerate(line):
        isStart = spot == 'S'
        isEnd = spot == 'E'
        if isStart:
            spot = 'a'
        elif isEnd:
            spot = 'z'
        height = ord(spot) - base
        node = Node(Point(x, y), height)
        node_line.append(node)
        if isStart:
            initial = node
        elif isEnd:
            dest = node
    # Row complete: store it and start a fresh scratch row.
    node_map.append(node_line)
    node_line = []

# Second pass, run only after the whole grid exists: give every node its
# reachable neighbours and its distance to the destination.
for node_row in node_map:
    for node in node_row:
        node.calc_moves(node_map)
        node.set_distance(dest)

print(f'initial {initial}')
print(f'dest {dest}')

# Maps a reached point to the point it was reached from (None for the search origin).
CameFrom = dict[Point, Optional[Point]]
# Maps a reached point to the cost recorded for getting there.
CostSoFar = dict[Point, float]

class Step:
    """Snapshot of the search taken when one node is pulled off the frontier.

    Attributes:
        end: the node being expanded.
        path: the nodes leading up to ``end``.
        costs: the cost table as it stood at that moment.
    """

    def __init__(self, end: Node, path: list[Node], costs: CostSoFar):
        # Shallow copies, so later changes to the caller's containers do not
        # alter this snapshot.
        path_snapshot = list(path)
        cost_snapshot = dict(costs)
        self.end = end
        self.path = path_snapshot
        self.costs = cost_snapshot

def do_astar(initial: 'Node', steps: 'Optional[list[Step]]' = None):
    """Run an A* search from ``initial`` towards the module-level ``dest`` node.

    Every move costs 1. A neighbour's queue priority is its cost plus its
    ``distance`` attribute. The search stops when ``dest`` is taken off the
    frontier, or when the frontier runs dry.

    Args:
        initial: node to start from.
        steps: optional list; when given, a ``Step`` snapshot is appended for
            every node taken off the frontier (used for the animation).

    Returns:
        A tuple ``(came_from, cost_so_far, max_cost)``:
        ``came_from`` maps each reached point to its predecessor point (None
        for the start), ``cost_so_far`` maps each reached point to its cost,
        and ``max_cost`` is the largest cost stored in ``cost_so_far``.
    """
    frontier = PriorityQueue()
    frontier.put(initial, 0)

    came_from: CameFrom = {initial.point: None}
    cost_so_far: CostSoFar = {initial.point: 0}
    max_cost = 0

    while not frontier.empty():
        current = frontier.get()
        if steps is not None:
            steps.append(Step(current, reconstruct_path(came_from, current), cost_so_far))

        if current is dest:
            break

        # Same for every neighbour of `current`, so compute it once.
        new_cost = cost_so_far[current.point] + 1
        for neighbor in current.neighbors:
            neighbor_point = neighbor.point
            if neighbor_point not in cost_so_far or new_cost < cost_so_far[neighbor_point]:
                cost_so_far[neighbor_point] = new_cost
                # Only costs that are actually recorded count towards the
                # maximum; tentative costs that lose to an existing entry
                # must not inflate it.
                max_cost = max(max_cost, new_cost)
                came_from[neighbor_point] = current.point
                frontier.put(neighbor, new_cost + neighbor.distance)

    return came_from, cost_so_far, max_cost

def reconstruct_path(came_from: 'CameFrom', dest: 'Node'):
    """Rebuild the chain of nodes that leads up to ``dest``.

    Args:
        came_from: predecessor map produced by the search; the search origin
            maps to None.
        dest: node whose route is wanted.

    Returns:
        The nodes from the search origin up to the predecessor of ``dest``,
        in travel order. ``dest`` itself is not included, so the length
        equals the number of moves. Returns an empty list when ``dest`` was
        never reached.
    """
    if dest.point not in came_from:
        return []
    path = []
    point = came_from[dest.point]
    # Walk backwards through the predecessors, then flip once at the end;
    # this avoids the repeated front-insertions that made each call quadratic.
    while point is not None:
        path.append(get_node(node_map, point))
        point = came_from[point]
    path.reverse()
    return path

# Heights as a 2-D array indexed [row, column], used as the background of
# every plot. A plain ndarray replaces np.matrix, which NumPy no longer
# recommends.
terrain_matrix = np.array([[node.height for node in row] for row in node_map])

In [None]:
# Part 1: search from the start node, recording a snapshot of every expansion
# in `steps` for the animation cell below.
steps: list[Step] = []
# cf: predecessor map, csf: cost per reached point, mc: the maximum cost
# reported by do_astar (used below as the upper bound of the cost colour scale).
cf, csf, mc = do_astar(initial, steps)
path = reconstruct_path(cf, dest)

# The path leaves out the destination node itself, so its length is the
# number of moves needed to reach it.
len(path)

In [18]:
# Figure and axes for the animation. The height is fixed at 6 and the width
# is scaled by the grid's columns-to-rows ratio.
anim_fig, anim_ax = plt.subplots(1, 1, figsize=(6 * len(node_map[0]) / len(node_map), 6))

def get_blank():
    """Return a new grid shaped like ``node_map`` with every cell set to -1.

    Each row is a separate list, so writing to one cell never affects
    another row.
    """
    blank = []
    for _ in range(len(node_map)):
        blank.append([-1] * len(node_map[0]))
    return blank

# Line artist for the route being explored; created by init().
the_path = None
def init():
    """Draw the static background and create the empty route line.

    Passed to FuncAnimation as ``init_func``. The terrain is drawn with the
    colour scale fixed to 0..25, the range of heights produced from 'a'..'z'.
    """
    global the_path
    anim_ax.pcolormesh(terrain_matrix, cmap=c1, rasterized=True, vmin=0, vmax=25)
    # plot() hands back a list holding the single new line.
    created_lines = anim_ax.plot([], [], color='red')
    (the_path,) = created_lines

# Cost overlay drawn by the previous frame, or None before the first frame.
last_color = None
def run(data: 'Step'):
    """Draw one animation frame from the search snapshot ``data``.

    Removes the previous frame's cost overlay, draws a new overlay in which
    each reached cell is coloured by its cost (unreached cells are masked
    out so the terrain shows through), and moves the route line to
    ``data.path``.

    Args:
        data: snapshot providing ``costs`` (point -> cost) and ``path``
            (nodes whose ``point`` gives the line's vertices).
    """
    global last_color
    # Compare against None explicitly: whether an artist object is "truthy"
    # says nothing about whether a previous overlay exists.
    if last_color is not None:
        last_color.remove()
    grid = get_blank()
    for p, c in data.costs.items():
        grid[p.y][p.x] = c
    # Plain ndarray instead of np.matrix; cells still at -1 were never reached.
    cost_grid = np.array(grid)
    masked_costs = np.ma.masked_array(cost_grid, mask=cost_grid < 0)
    last_color = anim_ax.pcolormesh(masked_costs, cmap=c2, rasterized=True, vmin=0, vmax=mc)
    # +0.5 puts each vertex at the centre of its cell.
    the_path.set_data([n.point.x + 0.5 for n in data.path], [n.point.y + 0.5 for n in data.path])

# init()
# run(steps[len(steps) // 2])
# run(steps[len(steps) // 2 + 1])

# Build the animation with one frame per recorded search step (`run` is called
# with each Step in turn, `init` draws the background) and write it to
# ./anim.mp4.
ani = anim.FuncAnimation(anim_fig, run, steps, interval=100, init_func=init, cache_frame_data=False)
# NOTE(review): `video` holds whatever save() returns; matplotlib's docs list
# no return value for Animation.save, so this is presumably None — confirm
# before relying on it.
video = ani.save('./anim.mp4', fps=60)
# Closes the current pyplot figure (no figure is named explicitly here).
plt.close()

  last_color = anim_ax.pcolormesh(b, cmap=c2, rasterized=True, vmin=0, vmax=mc)
  anim_ax.pcolormesh(terrain_matrix, cmap=c1, rasterized=True, vmin=0, vmax=25)
  last_color = anim_ax.pcolormesh(b, cmap=c2, rasterized=True, vmin=0, vmax=mc)


In [None]:
# Static picture of the part 1 result: terrain heights with the found route
# drawn in red through the cell centres.
fig_width = 6 * len(node_map[0]) / len(node_map)
fig, ax = plt.subplots(1, 1, figsize=(fig_width, 6))
ax.pcolormesh(terrain_matrix, cmap=c1, rasterized=True, vmin=0, vmax=25)

route_xs = [n.point.x + 0.5 for n in path]
route_ys = [n.point.y + 0.5 for n in path]
ax.plot(route_xs, route_ys, color='red')


In [None]:
# Part 2: try every cell of height 0 as a starting point and keep the routes
# that actually reach the destination.
start_nodes = [node for node_row in node_map for node in node_row if node.height == 0]

paths = []
path_lengths = []

for start in start_nodes:
    candidate = reconstruct_path(do_astar(start)[0], dest)
    # An empty route means the destination was not reached from this start.
    if not candidate:
        continue
    paths.append(candidate)
    path_lengths.append(len(candidate))

# Only non-empty routes were stored, so every length here is positive.
shortest = min(path_lengths)
shortest_index = path_lengths.index(shortest)


In [None]:
# Part 2 picture: every successful route in gray, with the shortest one drawn
# last in red so it sits on top.
p2_fig, p2_ax = plt.subplots(1, 1, figsize=(6 * len(node_map[0]) / len(node_map), 6))
p2_ax.pcolormesh(terrain_matrix, cmap=c1, rasterized=True, vmin=0, vmax=25)

for i, route in enumerate(paths):
    if i != shortest_index:
        p2_ax.plot([n.point.x + 0.5 for n in route], [n.point.y + 0.5 for n in route], color='gray')

best_route = paths[shortest_index]
p2_ax.plot([n.point.x + 0.5 for n in best_route], [n.point.y + 0.5 for n in best_route], color='red')