# Day 18 - Path finding with a twist

Today's puzzle is another maze, but with locked doors and keys.

Initially I treated this as a very basic [pathfinding problem](https://en.wikipedia.org/wiki/Pathfinding): doors are simply walls until you have collected the right key, and the keys collected so far are part of the state of a path in the search space. A state in the search consists of the location $(x, y)$, the keys we have collected, $k_1, k_2, \ldots, k_n$, and the number of steps taken, $\sigma$. We may want to link to the previous states that led us there so we can recover the path. We've reached the end when the number of keys in a state matches the number of lower-case letters in the maze.

However, I found this to still take a very long time, as part 1 resulted in several million states being searched. Then I figured out that a much better approach is to split the search into two levels:

* Use BFS to find the shortest path between any two keys or between keys and the starting position. No A* needed here as the number of paths between two locations, even in a large maze, is limited, and there is no easy heuristic to use anyway.

* Once we have a dependency map of the form *pos -> pos takes N steps and passes through doors D*, we can use A* (here without a heuristic, so it is effectively Dijkstra's algorithm) to see what gives us the best combination of traversals. For any given key or start position we know how many steps it takes to get to other locations and whether we have already collected the keys needed to get there.

The latter basically searches a graph of connected states (pick up the keys in *this* order) with weighted edges (it takes this many steps to move between keys $k_x$ and $k_y$).

In [1]:
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field, fields
from enum import Enum
from heapq import heapify, heappush, heappop
from itertools import count
from typing import (
    FrozenSet,
    Generic,
    Iterable,
    Iterator,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
)

from IPython.display import clear_output, display, DisplayHandle, Pretty


T = TypeVar("T")


class PriorityQueue(Generic[T]):
    """Min-priority queue of items, built on top of ``heapq``.

    Every entry is stored as ``(priority, sequence number, item)``. The
    sequence number comes from a running counter, so entries with the same
    priority are returned in insertion order and the items themselves are
    never compared with each other.
    """

    def __init__(self, *initial: Tuple[int, T]) -> None:
        """Create the queue, optionally seeded with ``(priority, item)`` pairs."""
        self._count = count()
        self._queue: List[Tuple[int, int, T]] = [
            (priority, next(self._count), entry) for priority, entry in initial
        ]
        heapify(self._queue)

    def __len__(self) -> int:
        """Return the number of entries still waiting in the queue."""
        return len(self._queue)

    def put(self, pri: int, item: T) -> None:
        """Add *item* to the queue with priority *pri* (lower comes out first)."""
        entry = (pri, next(self._count), item)
        heappush(self._queue, entry)

    def get(self) -> T:
        """Remove and return the item with the lowest priority.

        Raises ValueError when the queue holds no entries.
        """
        if len(self) == 0:
            raise ValueError('Queue is empty')
        *_, item = heappop(self._queue)
        return item
    


class Pos(NamedTuple):
    """An ``(x, y)`` coordinate with vector-style ``+``, ``-`` and ``@``."""

    x: int = 0
    y: int = 0

    def __add__(self, other: Pos) -> Pos:  # type: ignore
        """Return the component-wise sum of two positions."""
        if not isinstance(other, Pos):
            return NotImplemented
        return Pos(self.x + other.x, self.y + other.y)

    def __sub__(self, other: Pos) -> Pos:  # type: ignore
        """Return the component-wise difference of two positions."""
        if not isinstance(other, Pos):
            return NotImplemented
        return Pos(self.x - other.x, self.y - other.y)

    def __matmul__(self, other: Pos) -> int:
        """Manhattan distance, via overloading Pos @ Pos"""
        if not isinstance(other, Pos):
            return NotImplemented
        delta = self - other
        return abs(delta.x) + abs(delta.y)


# make dataclasses with fields and defaults work with __slots__
# adapted from https://github.com/ericvsmith/dataclasses/issues/28
def add_slots(cls):
    """Class decorator that rebuilds dataclass *cls* with ``__slots__``.

    A new class object is created with the same name, bases and namespace,
    except that one slot is declared per dataclass field. The replacement
    class is returned; the class passed in is left untouched.
    """
    slot_names = tuple(f.name for f in fields(cls))
    # Copy the namespace, leaving out the per-field class attributes (such as
    # field defaults), which would conflict with the slots of the same name,
    # and the __dict__ entry, which a slotted class should not carry.
    namespace = {
        name: value
        for name, value in cls.__dict__.items()
        if name != '__dict__' and name not in slot_names
    }
    namespace['__slots__'] = slot_names
    original_qualname = getattr(cls, '__qualname__', None)
    # Build the replacement with the same metaclass, name and bases.
    rebuilt = type(cls)(cls.__name__, cls.__bases__, namespace)
    if original_qualname is not None:
        rebuilt.__qualname__ = original_qualname
    return rebuilt
    

@add_slots
@dataclass(frozen=True)
class MazePathState:
    """A position reached while walking the maze grid.

    Only ``pos`` and ``maze_value`` take part in equality and hashing; the
    doors passed on the way and the step count are carried along but are
    excluded from comparisons.
    """

    pos: Pos
    maze_value: str
    doors: FrozenSet[str] = field(compare=False, default=frozenset())
    steps: int = field(compare=False, default=0)

    def moves(self, maze: Maze) -> Iterator[MazePathState]:
        """Yield the states one step left, right, up and down of this one.

        Walls (``#``) and positions for which the maze lookup raises
        IndexError are skipped. Stepping onto an upper-case cell (a door)
        adds that letter to the doors recorded on the new state.
        """
        next_steps = self.steps + 1
        for step in (Pos(-1, 0), Pos(1, 0), Pos(0, -1), Pos(0, 1)):
            target = self.pos + step
            try:
                cell = maze[target]
            except IndexError:
                # nothing stored at this position
                continue
            if cell == "#":
                continue
            # remember every door this route has passed through
            passed = self.doors | {cell} if cell.isupper() else self.doors
            yield MazePathState(target, cell, passed, next_steps)


# Paths between a given key or the start and all keys that can be reached
# from there. The outer mapping is keyed by where a path begins: a key
# (a, b, c, ..) or the start (@). The inner mapping is keyed by the key
# where the path ends; it only has keys because we never need to return
# to the start. Each MazePathState value carries the number of steps the
# path takes and the doors it passes through.
Dependencies = Mapping[str, Mapping[str, MazePathState]]


@add_slots
@dataclass(frozen=True)
class KeyCollectState:
    """A node in the key-collecting search: standing at ``key``, holding ``keys``.

    Two states compare (and hash) equal when they are at the same key with
    the same set of collected keys; the steps taken and the order in which
    the keys were picked up are excluded from comparisons.
    """

    # the key (or "@" for the start) we are currently standing on
    key: str
    # every key collected so far
    keys: FrozenSet[str] = frozenset()
    # total number of steps walked to get here
    steps: int = field(compare=False, default=0)
    # the keys in the order they were collected; grows by one per move
    path: Tuple[str, ...] = field(compare=False, default=())

    def moves(self, maze: Maze) -> Iterator[KeyCollectState]:
        """Yield a follow-up state for every new key we can walk to from here.

        A key is reachable when every door on the recorded path to it can
        be opened with the keys already collected. Keys we already hold are
        skipped.
        """
        keys = self.keys
        # a door is opened by the key with the matching lower-case letter
        can_unlock = {key.upper() for key in keys}
        for other, state in maze.dependency_map[self.key].items():
            if other in keys:
                # no need to collect keys more than once.
                continue
            if not state.doors <= can_unlock:
                # a door we cannot open yet blocks the way
                continue
            # we can reach this state
            yield KeyCollectState(
                other,
                keys | {other},
                self.steps + state.steps,
                self.path + (other,),
            )


class Maze(Mapping[Pos, str]):
    """A maze of walls, keys and doors, readable as a mapping of Pos to character.

    Lower-case letters are keys, upper-case letters are doors, ``@`` is the
    start and ``#`` is a wall. Besides giving access to the grid, the class
    pre-computes the paths between keys and searches for the fewest steps
    needed to collect every key.
    """

    width: int
    height: int
    # position of the first "@" in reading order; left unset if there is none
    start: Pos
    key_pos: Mapping[str, Pos]
    door_pos: Mapping[str, Pos]
    # lazily filled cache behind the dependency_map property
    _dependencies: Optional[Dependencies] = None

    def __init__(self, lines: Sequence[str]) -> None:
        """Index the keys, doors and start position found in *lines*."""
        self._lines = lines
        self.width = max((len(line) for line in lines), default=0)
        self.height = len(lines)
        keys = {}
        doors = {}
        start_found = False
        for y, line in enumerate(lines):
            for x, c in enumerate(line):
                if c.islower():
                    keys[c] = Pos(x, y)
                elif c.isupper():
                    doors[c] = Pos(x, y)
                elif c == "@" and not start_found:
                    # only the first start marker counts
                    self.start = Pos(x, y)
                    start_found = True
        self.key_pos = keys
        self.door_pos = doors

    def __repr__(self) -> str:
        maze = "\n".join(self._lines)
        return f"""\
<Maze width={self.width} height={self.height} start={self.start}
      keys={{{', '.join(sorted(self.key_pos))}}}
      maze=
{maze}
>
"""

    def __getitem__(self, pos: Pos) -> str:
        """Return the character at *pos*.

        Raises IndexError (not KeyError) for positions past the end of the
        stored lines. Negative coordinates follow normal sequence indexing
        and count from the end.
        """
        return self._lines[pos.y][pos.x]

    def __len__(self) -> int:
        """Return the total number of characters across all lines."""
        return sum(len(line) for line in self._lines)

    def __iter__(self) -> Iterator[Pos]:
        """Yield every position in the maze, line by line."""
        for y, line in enumerate(self._lines):
            for x in range(len(line)):
                yield Pos(x, y)

    @property
    def dependency_map(self) -> Dependencies:
        """Paths from the start and from every key to all reachable keys.

        Computed on first access and cached on the instance afterwards.
        """
        if self._dependencies is None:
            self._dependencies = dependencies = {
                "@": dict(self.bfs(self.start, "@"))
            }
            for key, pos in self.key_pos.items():
                dependencies[key] = dict(self.bfs(pos, key))
        return self._dependencies

    def bfs(self, start: Pos, start_value: str) -> Iterable[Tuple[str, MazePathState]]:
        """Find the shortest paths to reachable keys.

        Walks the maze breadth-first from *start*, treating doors as open
        but recording them on each state. Yields one ``(key, state)`` pair
        per reachable key other than *start_value*.
        """
        start_state = MazePathState(start, start_value)
        queue = deque([start_state])
        # Mark cells when they are queued rather than when they are
        # processed, so every cell is visited (and every key reported)
        # exactly once, via the first shortest path that reaches it.
        seen = {start_state}

        while queue:
            current = queue.popleft()
            value = current.maze_value
            if value != start_value and value.islower():
                yield value, current

            for neighbor in current.moves(self):
                if neighbor in seen:
                    continue
                seen.add(neighbor)
                queue.append(neighbor)

    def shortest_path(self) -> int:
        """Return the fewest steps needed to collect every key.

        Searches the graph of key-collection states in order of steps taken
        so far. Raises ValueError when no order of collecting keys reaches
        all of them.
        """
        keys = set(self.key_pos)

        start = KeyCollectState("@")
        queue = PriorityQueue((0, start))
        # best known step count for each state still waiting to be expanded
        pending = {start: 0}
        closed = set()

        while pending:
            current = queue.get()

            if pending.get(current) != current.steps:
                # ignore items in the queue for which a shorter
                # path exists
                continue

            if current.keys == keys:
                return current.steps

            del pending[current]
            closed.add(current)
            for neighbor in current.moves(self):
                if neighbor in closed:
                    continue
                if pending.get(neighbor, float('inf')) <= neighbor.steps:
                    continue
                pending[neighbor] = neighbor.steps
                queue.put(neighbor.steps, neighbor)

        raise ValueError("Not all keys in the maze can be collected")

                
# Example mazes for part 1, each mapped to the fewest steps needed to
# collect all of its keys.
part1_tests = {
    # a single door between the start and the second key
    "#########\n#b.A.@.a#\n#########": 8,
    # one long corridor with a second row holding the last key
    (
        "########################\n#f.D.E.e.C.b.A.@.a.B.c.#\n######################.#\n"
        "#d.....................#\n########################"
    ): 86,
    # two corridors joined on the left-hand side
    (
        "########################\n#...............b.C.D.f#\n#.######################\n"
        "#.....@.a.B.c.d.A.e.F.g#\n########################"
    ): 132,
    # four rows of keys and doors around a central start
    (
        "#################\n#i.G..c...e..H.p#\n########.########\n#j.A..b...f..D.o#\n"
        "########@########\n#k.E..a...g..B.n#\n########.########\n#l.F..d...h..C.m#\n"
        "#################"
    ): 136,
    # keys tucked away in dead ends below the main corridor
    (
        "########################\n#@..............ac.GI.b#\n###d#e#f################\n"
        "###A#B#C################\n###g#h#i################\n########################\n"
    ): 81,
}
# Every example must produce exactly the expected number of steps.
for testmaze in part1_tests:
    expected = part1_tests[testmaze]
    assert expected == Maze(testmaze.splitlines()).shortest_path()

In [2]:
import aocd
# Fetch the puzzle input for day 18 of 2019 through the aocd package.
data = aocd.get_data(day=18, year=2019)

In [3]:
# Solve part 1 on the fetched puzzle input: the fewest steps to collect every key.
print("Part 1:", Maze(data.splitlines()).shortest_path())

Part 1: 3586
