# A\* search

- <https://adventofcode.com/2021/day/15>

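Part 1 is a straightforward A\* search problem: find the path from the top-left to the bottom-right corner of the grid that minimizes the total risk of the cells entered.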


In [1]:
from __future__ import annotations

from dataclasses import dataclass, replace
from heapq import heappop, heappush
from itertools import count
from typing import Iterator, TypeAlias

# A grid position as (x, y): x is the column and y the row (Cavern indexes its matrix as [y][x]).
Pos: TypeAlias = tuple[int, int]


@dataclass(frozen=True)
class Node:
    """A state in the A* search: a grid position plus the risk spent to reach it.

    Instances are frozen (hashable), and two nodes are equal only when their
    position *and* their accumulated risk both match.
    """

    x: int = 0
    y: int = 0
    risk: int = 0

    @property
    def pos(self) -> Pos:
        """The (x, y) coordinates of this node, without its risk."""
        return (self.x, self.y)

    def cost(self, target: Pos) -> int:
        """Return the A* priority f(n) = g(n) + h(n) of this node.

        g(n) is the total risk accumulated so far. h(n) is the Manhattan
        distance to *target*. This never overestimates the remaining risk as
        long as every cell costs at least 1 to enter (true for digits 1-9;
        a 0 in the map would break that assumption).
        """
        remaining = abs(target[0] - self.x) + abs(target[1] - self.y)
        return self.risk + remaining

    def transitions(self, cavern: Cavern) -> Iterator[Node]:
        """Yield the nodes reachable with one orthogonal step inside *cavern*.

        Each yielded node carries this node's risk plus the risk of the cell
        it steps onto. Neighbours are produced lazily, in the order
        left, up, down, right. Positions outside the grid are skipped.
        """
        for dx, dy in ((-1, 0), (0, -1), (0, 1), (1, 0)):
            neighbour = (self.x + dx, self.y + dy)
            if neighbour in cavern:
                yield replace(
                    self,
                    x=neighbour[0],
                    y=neighbour[1],
                    risk=self.risk + cavern[neighbour],
                )


class Cavern:
    """A grid of single-digit risk levels, addressed by (x, y) positions.

    The search starts at the top-left corner (0, 0); ``target`` is the
    bottom-right corner.
    """

    def __init__(self, map: list[str]) -> None:
        """Parse *map*, one string of digits per row.

        The grid width is taken from the first row, so all rows are assumed
        to have that same length.
        """
        self._height = len(map)
        self._width = len(map[0])
        self._matrix = [[int(c) for c in row] for row in map]
        self.target = (self._width - 1, self._height - 1)

    def __getitem__(self, pos: Pos) -> int:
        """Return the risk level of the cell at *pos* (x, y)."""
        x, y = pos
        return self._matrix[y][x]

    def __contains__(self, pos: Pos) -> bool:
        """Return True if *pos* (x, y) lies inside the grid."""
        x, y = pos
        return 0 <= x < self._width and 0 <= y < self._height

    def __str__(self) -> str:
        return "\n".join(["".join([str(r) for r in row]) for row in self._matrix])

    def lowest_total_risk(self) -> int:
        """Return the lowest total risk of any path from (0, 0) to ``target``.

        The risk of the starting cell is not counted. Each step adds the risk
        of the cell being entered. The search is A* with the Manhattan-distance
        heuristic from ``Node.cost``, so the first time the target is popped
        from the queue its risk is minimal.

        Raises:
            ValueError: if the queue is exhausted without reaching ``target``.
        """
        start = Node()
        unique = count()  # tie breaker so equal costs never fall back to comparing Nodes
        pqueue = [(start.cost(self.target), next(unique), start)]
        risks = {start.pos: start.risk}  # pos -> lowest risk queued so far
        while pqueue:
            node = heappop(pqueue)[-1]

            if node.pos == self.target:
                return node.risk

            if node.risk > risks[node.pos]:
                # Stale entry: a cheaper route to this position was queued
                # after this one, so expanding it again would be wasted work.
                continue

            for new in node.transitions(self):
                if new.risk >= risks.get(new.pos, float("inf")):
                    continue  # not an improvement on what is already queued
                risks[new.pos] = new.risk
                heappush(pqueue, (new.cost(self.target), next(unique), new))

        raise ValueError(f"target {self.target} is unreachable")


# Small example cave used as a sanity check: its lowest total risk is 40.
test_cavern_map = """\
1163751742
1381373672
2136511328
3694931569
7463417111
1319128137
1359912421
3125421639
1293138521
2311944581
""".splitlines()
test_cavern = Cavern(test_cavern_map)
assert test_cavern.lowest_total_risk() == 40

In [2]:
import aocd

# Fetch the puzzle input through aocd (presumably needs an AoC session token
# configured — see the aocd docs). cavern_map is reused for part 2 below.
cavern_map = aocd.get_data(day=15, year=2021).splitlines()
cavern = Cavern(cavern_map)
print("Part 1:", cavern.lowest_total_risk())

Part 1: 503


# Part 2: scale up the map

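Part 2 tests whether your A\* search can handle a much larger map: the input is tiled 5×5, and each tile step to the right or down raises every risk by 1, wrapping from 9 back to 1.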


In [3]:
from itertools import product


class LargeCavern(Cavern):
    """A cavern whose map is the input tiled 5 times in each direction.

    The tile at tile-offset (dx, dy) has every risk raised by dx + dy, with
    values above 9 wrapping around to start again at 1.
    """

    def __init__(self, map: list[str]) -> None:
        super().__init__(map)
        tile = self._matrix
        tile_w, tile_h = self._width, self._height
        # Index by global coordinates: x // tile_w is the horizontal tile
        # offset and x % tile_w the column inside the original tile (same for y).
        self._matrix = [
            [
                (tile[y % tile_h][x % tile_w] + x // tile_w + y // tile_h - 1) % 9 + 1
                for x in range(5 * tile_w)
            ]
            for y in range(5 * tile_h)
        ]
        self._width = 5 * tile_w
        self._height = 5 * tile_h
        self.target = (self._width - 1, self._height - 1)


# The example map tiled 5x5 has a lowest total risk of 315.
test_large_cavern = LargeCavern(test_cavern_map)
assert test_large_cavern.lowest_total_risk() == 315

In [4]:
# Solve part 2 on the puzzle input (cavern_map from the part 1 cell), tiled 5x5.
larger_cavern = LargeCavern(cavern_map)
print("Part 2:", larger_cavern.lowest_total_risk())

Part 2: 2853
