# Day 20 - Torus geometry

Part 1 is a simple path-finding problem: get from `AA` to `ZZ` through a maze complicated by portals. When you step onto a portal tile, the next step takes you somewhere else entirely. This could easily throw a spanner into A* heuristics (Manhattan distance won't cut it!). I'm taking a leaf out of [day 18](./Day%2018.ipynb)'s book and using two levels of path finding: first map the distances from portal to portal, then run a path-finding search that jumps straight from one portal to the next.

In [1]:
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field, fields
from enum import IntEnum
from heapq import heapify, heappush, heappop
from itertools import count
from typing import (
    Dict,
    FrozenSet,
    Generic,
    Iterator,
    List,
    Literal,
    Mapping,
    NamedTuple,
    Optional,
    Protocol,
    Sequence,
    Tuple,
    TypeVar,
)

T = TypeVar("T")


class PriorityQueue(Generic[T]):
    """Min-priority queue; items with equal priority come out FIFO.

    Entries are stored as ``(priority, counter, item)`` so the heap never has
    to compare the items themselves (they need not be orderable).
    """

    def __init__(self, *initial: Tuple[int, T]) -> None:
        self._count = count()
        # Build the list and turn it into a heap in a single O(n) pass.
        self._queue: List[Tuple[int, int, T]] = [
            (pri, next(self._count), item) for pri, item in initial
        ]
        heapify(self._queue)

    def __len__(self) -> int:
        return len(self._queue)

    def put(self, pri: int, item: T) -> None:
        """Add *item* with priority *pri* (lower values come out first)."""
        heappush(self._queue, (pri, next(self._count), item))

    def get(self) -> T:
        """Remove and return the item with the lowest priority.

        Raises:
            ValueError: if the queue is empty.
        """
        if not self._queue:
            raise ValueError('Queue is empty')
        return heappop(self._queue)[-1]
    

class Pos(NamedTuple):
    """An (x, y) grid coordinate with component-wise + and -."""

    x: int = 0
    y: int = 0

    def __add__(self, other: Pos) -> Pos:  # type: ignore
        # only Pos + Pos is vector addition; defer for anything else
        if not isinstance(other, Pos):
            return NotImplemented
        return Pos(self.x + other.x, self.y + other.y)

    def __sub__(self, other: Pos) -> Pos:  # type: ignore
        # only Pos - Pos is vector subtraction; defer for anything else
        if not isinstance(other, Pos):
            return NotImplemented
        return Pos(self.x - other.x, self.y - other.y)


class PortalSide(IntEnum):
    """Which ring of the donut a portal label sits on."""

    outer = 0
    inner = 1

    @property
    def opposite(self) -> PortalSide:
        """The other side of the same portal (outer <-> inner)."""
        if self is PortalSide.outer:
            return PortalSide.inner
        return PortalSide.outer

    
class PortalSides(NamedTuple):
    """Grid positions of the two ends of one named portal.

    Field order must match the PortalSide values (outer=0, inner=1): the
    maze indexes this tuple with a PortalSide. ``inner`` stays None when the
    parser never finds an inner label for the name.
    """

    outer: Pos = Pos(0, 0)
    inner: Optional[Pos] = None


# make dataclasses with fields and defaults work with __slots__
# adapted from https://github.com/ericvsmith/dataclasses/issues/28
def add_slots(cls: T) -> T:
    """Return a copy of dataclass *cls* that uses ``__slots__``.

    ``__slots__`` cannot be added to an existing class, so a new class is
    built from a copy of the namespace, with the dataclass field names as the
    slots. Field defaults are removed from the namespace, because a slot and a
    class attribute with the same name conflict. The generated ``__init__``
    and ``__dataclass_fields__`` still carry the defaults.

    Apply it outside ``@dataclass`` (i.e. listed above it).

    Raises:
        TypeError: if *cls* already defines ``__slots__`` itself.
    """
    if '__slots__' in cls.__dict__:  # type: ignore
        raise TypeError(f'{cls.__name__} already specifies __slots__')  # type: ignore

    # Create a new dict for our new class.
    cls_dict = dict(cls.__dict__)  # type: ignore
    field_names = tuple(f.name for f in fields(cls))  # type: ignore
    cls_dict['__slots__'] = field_names
    for field_name in field_names:
        # Remove the class-level defaults; they'd clash with the slots.
        cls_dict.pop(field_name, None)
    # The old class's __dict__ / __weakref__ descriptors only apply to
    # instances of the old class; the slotted class gets neither, so drop them.
    cls_dict.pop('__dict__', None)
    cls_dict.pop('__weakref__', None)
    # And finally create the class, keeping the original qualified name.
    qualname = getattr(cls, '__qualname__', None)
    cls = type(cls)(cls.__name__, cls.__bases__, cls_dict)  # type: ignore
    if qualname is not None:
        cls.__qualname__ = qualname  # type: ignore
    return cls


@add_slots
@dataclass(frozen=True)
class MazePathState:
    """A tile reached while walking the maze, plus the steps taken to get there.

    Only ``pos`` takes part in equality and hashing, so a set of these
    states behaves like a set of visited positions.
    """

    pos: Pos
    steps: int = field(compare=False, default=0)

    def moves(self, maze: Maze) -> Iterator[MazePathState]:
        """Yield a state for each orthogonally adjacent open ('.') tile."""
        x, y = self.pos
        # checked in order x-1, x+1, y-1, y+1
        adjacent = (Pos(x - 1, y), Pos(x + 1, y), Pos(x, y - 1), Pos(x, y + 1))
        next_steps = self.steps + 1
        yield from (
            MazePathState(candidate, next_steps)
            for candidate in adjacent
            if maze[candidate] == "."
        )


class Portal(NamedTuple):
    """One side of a named portal, shown as e.g. ``AA:outer``.

    Portals have two sides, outer and inner. Distances are measured between
    (portal, side) pairs; path finding then uses ``opposite`` to map a pair
    to the (portal, side) pair at the other end.
    """

    name: str
    side: PortalSide

    def __repr__(self) -> str:
        return f"{self.name}:{self.side.name}"

    @property
    def opposite(self) -> Portal:
        """The same portal, seen from its other side."""
        return Portal(name=self.name, side=self.side.opposite)


# portal side -> {reachable portal side -> walking distance in steps}
Distances = Mapping[Portal, Mapping[Portal, int]]


class PortalTraversalState(Protocol):
    """What the portal-to-portal search needs from a state.

    States are used as dict keys and set members by the search, so
    implementations must be hashable.
    """

    @property
    def portal(self) -> Portal:
        """The portal side the traveller is currently standing on."""

    @property
    def steps(self) -> int:
        """Total steps taken so far; used as the search priority."""

    def moves(self, maze: Maze) -> Iterator[PortalTraversalState]:
        """Yield states reachable by walking to a portal and stepping through."""


@add_slots
@dataclass(frozen=True)
class DirectPortalTraversalState:
    """Part 1 search state: standing on one side of a portal.

    Only ``portal`` takes part in equality and hashing. The search therefore
    treats every route to the same portal side as one node, and its closed
    set prunes the costlier routes. ``portals`` (the portal sides already
    stood on along this route) is bookkeeping used only to avoid walking back
    into a cycle. Taking part in equality would make every route a distinct
    node and defeat the pruning.
    """

    portal: Portal
    portals: FrozenSet[Portal] = field(compare=False, default=frozenset())
    steps: int = field(compare=False, default=0)
    # portal sides departed from along this route, in order
    path: Tuple[Portal, ...] = field(compare=False, default=())

    def moves(self, maze: Maze) -> Iterator[PortalTraversalState]:
        """Yield states for every portal reachable on foot, after stepping through it."""
        portal, portals = self.portal, self.portals
        # identical for every successor; frozensets and tuples are immutable
        visited = portals | {portal}
        path = self.path + (portal,)
        for other, steps in maze.portal_distances[portal].items():
            if other in portals:
                # no need to traverse portals more than once
                continue
            yield DirectPortalTraversalState(
                other.opposite,
                visited,
                # stepping through the portal adds 1 step
                self.steps + steps + 1,
                path,
            )


class Maze(Mapping[Pos, str]):
    """A donut maze, readable as a mapping from ``Pos`` to its character.

    Portal labels are two letters, read left-to-right or top-to-bottom, next
    to an open ``.`` tile; that tile is the portal's position. Labels on the
    outer edge of the drawing mark 'outer' portals, all others 'inner'.
    """

    portals: Dict[str, PortalSides]
    pos_portal: Dict[Pos, Portal]
    _lines: Sequence[str]
    _distances: Optional[Distances] = None

    def __init__(self, maze: Sequence[str]) -> None:
        # Pad every row to the widest row. Input that lost its trailing
        # whitespace would otherwise make the column reads below (l2[x],
        # l3[x]) and the neighbour lookups during the BFS go out of range.
        width = max((len(line) for line in maze), default=0)
        height = len(maze)
        lines = [line.ljust(width) for line in maze]
        self._lines = lines

        # find portals; either on rows, reading left-to-right, or in
        # columns, reading top-to-bottom
        self.portals = {}
        self.pos_portal = {}

        # 'outer' portals are those at x in {0, width - 3} for horizontal
        # labels, or y in {0, height - 3} for vertical, inner otherwise
        # (offsets are for first character of the .LL or LL. sequence)
        houter, vouter = {0, width - 3}, {0, height - 3}

        # take 3 lines at a time, y is index of 1st
        for y, (l1, l2, l3) in enumerate(zip(lines, lines[1:], lines[2:])):
            # take three characters at a time; x is index of 1st. All rows
            # share one width, so l2[x] and l3[x] always exist.
            for x, (c1, c2, c3) in enumerate(zip(l1, l1[1:], l1[2:])):
                if c1 == '.' and c2.isalpha() and c3.isalpha():
                    self._add_portal(c2 + c3, Pos(x, y), x in houter)
                elif c1.isalpha() and c2.isalpha() and c3 == '.':
                    self._add_portal(c1 + c2, Pos(x + 2, y), x in houter)
                if c1 == '.' and l2[x].isalpha() and l3[x].isalpha():
                    self._add_portal(l2[x] + l3[x], Pos(x, y), y in vouter)
                elif c1.isalpha() and l2[x].isalpha() and l3[x] == '.':
                    self._add_portal(c1 + l2[x], Pos(x, y + 2), y in vouter)

    def _add_portal(self, name: str, pos: Pos, outer: bool) -> None:
        """Record *pos* as the outer or inner side of portal *name*."""
        sides = self.portals.get(name, PortalSides())
        if outer:
            self.portals[name] = sides._replace(outer=pos)
            self.pos_portal[pos] = Portal(name, PortalSide.outer)
        else:
            self.portals[name] = sides._replace(inner=pos)
            self.pos_portal[pos] = Portal(name, PortalSide.inner)

    def __repr__(self) -> str:
        maze = "\n".join(self._lines)
        portals = " ".join([
            f"{n}: ({','.join(map(str, p1)) if p1 else ''})"
            f"~({','.join(map(str, p2)) if p2 else ''})"
            for n, (p1, p2) in sorted(self.portals.items())
        ])
        return f"""\
<Maze portals={portals}
      maze=
{maze}
>
"""

    def __getitem__(self, pos: Pos) -> str:
        """Character at *pos*; KeyError when *pos* lies outside the grid.

        Raising KeyError (not IndexError, and no negative-index wrap-around)
        keeps ``pos in maze`` and ``maze.get(pos)`` working as for any Mapping.
        """
        x, y = pos
        lines = self._lines
        if 0 <= y < len(lines) and 0 <= x < len(lines[y]):
            return lines[y][x]
        raise KeyError(pos)

    def __len__(self) -> int:
        return sum(len(l) for l in self._lines)

    def __iter__(self) -> Iterator[Pos]:
        for y, l in enumerate(self._lines):
            for x in range(len(l)):
                yield Pos(x, y)

    @property
    def portal_distances(self) -> Distances:
        """Walking distances from each portal side to every reachable one.

        Computed on first access and cached. ZZ gets no entry of its own as a
        starting point; it only appears as a target.
        """
        if self._distances is None:
            distances: Dict[Portal, Dict[Portal, int]] = {}
            for name, positions in self.portals.items():
                if name == "ZZ":
                    continue
                for side in PortalSide:
                    # PortalSides fields are ordered like PortalSide values
                    pos = positions[side]
                    if pos is not None:
                        distances[Portal(name, side)] = dict(self.bfs(pos))
            # only publish the cache once it is complete
            self._distances = distances
        return self._distances

    def bfs(self, start: Pos) -> Iterator[Tuple[Portal, int]]:
        """Find the shortest paths to reachable portals.

        Breadth-first walk over open tiles from *start*. It yields each
        reachable portal side once, with its walking distance in steps. The
        portal at *start* itself and AA are never yielded.
        """
        pos_portal = self.pos_portal
        start_state = MazePathState(start)
        queue = deque([start_state])
        # Mark tiles as seen when they are queued, not when dequeued. Otherwise
        # a tile reached from two neighbours gets queued, expanded and
        # yielded more than once.
        seen = {start_state}

        while queue:
            current = queue.popleft()
            pos = current.pos
            if pos != start and pos in pos_portal and pos_portal[pos].name != "AA":
                yield pos_portal[pos], current.steps

            for neighbor in current.moves(self):
                if neighbor not in seen:
                    seen.add(neighbor)
                    queue.append(neighbor)

    def shortest_path(self) -> int:
        """Number of steps on the shortest AA -> ZZ route (part 1 rules)."""
        return self._search_astar(
            DirectPortalTraversalState(Portal("AA", PortalSide.outer)),
        )

    def _search_astar(
        self,
        start: PortalTraversalState,
    ) -> int:
        """Cheapest-first search from *start* until the ZZ state is popped.

        No heuristic is used (A* with h = 0, i.e. Dijkstra). Each state's
        ``steps`` is its accumulated cost.

        Raises:
            ValueError: if every reachable state is exhausted without reaching
                ZZ. NOTE: with an unbounded state space (e.g. unlimited depth)
                the search may instead never terminate.
        """
        goal = Portal("ZZ", PortalSide.inner)
        queue = PriorityQueue((0, start))
        # best known cost for each state that is queued but not yet expanded
        best_steps = {start: 0}
        closed = set()

        while best_steps:
            current = queue.get()

            if best_steps.get(current) != current.steps:
                # ignore items in the queue for which a shorter
                # path exists (or that were already expanded)
                continue

            if current.portal == goal:
                # -1, because we never step through ZZ
                return current.steps - 1

            del best_steps[current]
            closed.add(current)
            for neighbor in current.moves(self):
                if neighbor in closed:
                    continue
                if best_steps.get(neighbor, float('inf')) <= neighbor.steps:
                    continue
                best_steps[neighbor] = neighbor.steps
                queue.put(neighbor.steps, neighbor)

        raise ValueError("no path from AA to ZZ")


# Examples from the puzzle description: maze text -> expected length of
# the shortest AA -> ZZ path.
part1_tests = {
    (
        "         A           \n         A           \n  #######.#########  \n"
        "  #######.........#  \n  #######.#######.#  \n  #######.#######.#  \n"
        "  #######.#######.#  \n  #####  B    ###.#  \nBC...##  C    ###.#  \n"
        "  ##.##       ###.#  \n  ##...DE  F  ###.#  \n  #####    G  ###.#  \n"
        "  #########.#####.#  \nDE..#######...###.#  \n  #.#########.###.#  \n"
        "FG..#########.....#  \n  ###########.#####  \n             Z       \n"
        "             Z       "
    ): 23,
    (
        """\
                   A               
                   A               
  #################.#############  
  #.#...#...................#.#.#  
  #.#.#.###.###.###.#########.#.#  
  #.#.#.......#...#.....#.#.#...#  
  #.#########.###.#####.#.#.###.#  
  #.............#.#.....#.......#  
  ###.###########.###.#####.#.#.#  
  #.....#        A   C    #.#.#.#  
  #######        S   P    #####.#  
  #.#...#                 #......VT
  #.#.#.#                 #.#####  
  #...#.#               YN....#.#  
  #.###.#                 #####.#  
DI....#.#                 #.....#  
  #####.#                 #.###.#  
ZZ......#               QG....#..AS
  ###.###                 #######  
JO..#.#.#                 #.....#  
  #.#.#.#                 ###.#.#  
  #...#..DI             BU....#..LF
  #####.#                 #.#####  
YN......#               VT..#....QG
  #.###.#                 #.###.#  
  #.#...#                 #.....#  
  ###.###    J L     J    #.#.###  
  #.....#    O F     P    #.#...#  
  #.###.#####.#.#####.#####.###.#  
  #...#.#.#...#.....#.....#.#...#  
  #.#####.###.###.#.#.#########.#  
  #...#.#.....#...#.#.#.#.....#.#  
  #.###.#####.###.###.#.#.#######  
  #.#.........#...#.............#  
  #########.###.###.#############  
           B   J   C               
           U   P   P               
"""
    ): 58,
}
for testmaze, expected in part1_tests.items():
    # assert, so that a wrong answer actually fails the cell
    result = Maze(testmaze.splitlines()).shortest_path()
    assert result == expected, f"expected {expected}, got {result}"

In [2]:
import aocd

# puzzle input for 2019 day 20, fetched through the aocd package
data = aocd.get_data(year=2019, day=20)
# NOTE(review): presumably here to narrow the type for static checkers
assert isinstance(data, str)

In [3]:
# Part 1: shortest AA -> ZZ walk; stepping through a portal costs one step
print("Part 1:", Maze(data.splitlines()).shortest_path())

Part 1: 522


## Part 2

Now we need to add *depth* to our puzzle state, and distinguish between 'inner' and 'outer' portals. The rest remains the same; I'm now doubly glad I re-used the two-layer path-finding concept from day 18, as we can re-use the distance graph at every level.

I refactored part 1 to add ordering to the portal info (it was simply a tuple before) so I can distinguish between inner and outer states in the second part without having to supply too many subclasses.

In [4]:
class PortalDepth(NamedTuple):
    """A portal side paired with the recursion level it is used at."""

    portal: Portal
    depth: int

    def __repr__(self) -> str:
        # rendered as "<portal>:<depth>", e.g. "AA:outer:0"
        return f"{self.portal}:{self.depth}"


@add_slots
@dataclass(frozen=True)
class RecursivePortalTraversalState:
    """Part 2 search state: a portal side at a given recursion depth.

    Only ``portal`` and ``depth`` take part in equality and hashing. The
    search therefore treats every route to the same (portal side, depth) as
    one node, and its closed set prunes the costlier routes. ``portals``
    records the (portal side, depth) points already stood on along this
    route; it is bookkeeping used only to avoid walking back into a cycle.
    """

    portal: Portal
    depth: int = 0
    portals: FrozenSet[PortalDepth] = field(compare=False, default=frozenset())
    steps: int = field(compare=False, default=0)
    # (portal side, depth) points departed from along this route, in order
    path: Tuple[PortalDepth, ...] = field(compare=False, default=())

    def moves(self, maze: Maze) -> Iterator[PortalTraversalState]:
        """Yield states for every usable portal reachable on foot at this depth."""
        portal, portals, depth = self.portal, self.portals, self.depth
        pdepth = PortalDepth(portal, depth)
        # identical for every successor; frozensets and tuples are immutable
        visited = portals | {pdepth}
        path = self.path + (pdepth,)
        for other, steps in maze.portal_distances[portal].items():
            if PortalDepth(other, depth) in portals:
                # no need to traverse portals at a given depth more than once
                continue
            if not depth and other.side is PortalSide.outer and other.name != "ZZ":
                # can't traverse through outer portals at depth 0
                # unless it is ZZ we reached.
                continue
            elif depth and other.name in {"AA", "ZZ"}:
                # AA and ZZ are only options at depth 0
                continue
            # entering an inner portal goes one level deeper (depth + 1);
            # entering an outer portal comes back out (depth - 1). Reaching
            # ZZ therefore yields depth -1, but the search only checks the
            # portal for the goal.
            newdepth = depth + (1 if other.side is PortalSide.inner else -1)
            yield RecursivePortalTraversalState(
                other.opposite,
                newdepth,
                visited,
                # stepping through the portal adds 1 step
                self.steps + steps + 1,
                path,
            )


class RecursiveMaze(Maze):
    """Part 2 maze: inner portals lead one level deeper, outer ones back out."""

    def shortest_path(self) -> int:
        """Steps on the shortest AA -> ZZ route under the recursive rules."""
        start = RecursivePortalTraversalState(Portal("AA", PortalSide.outer))
        return self._search_astar(start)


# Examples from the puzzle description: maze text -> expected length of
# the shortest AA -> ZZ path under the recursive rules.
part2_tests = {
    (
        "         A           \n         A           \n  #######.#########  \n"
        "  #######.........#  \n  #######.#######.#  \n  #######.#######.#  \n"
        "  #######.#######.#  \n  #####  B    ###.#  \nBC...##  C    ###.#  \n"
        "  ##.##       ###.#  \n  ##...DE  F  ###.#  \n  #####    G  ###.#  \n"
        "  #########.#####.#  \nDE..#######...###.#  \n  #.#########.###.#  \n"
        "FG..#########.....#  \n  ###########.#####  \n             Z       \n"
        "             Z       "
    ): 26,
    (
        """\
             Z L X W       C                 
             Z P Q B       K                 
  ###########.#.#.#.#######.###############  
  #...#.......#.#.......#.#.......#.#.#...#  
  ###.#.#.#.#.#.#.#.###.#.#.#######.#.#.###  
  #.#...#.#.#...#.#.#...#...#...#.#.......#  
  #.###.#######.###.###.#.###.###.#.#######  
  #...#.......#.#...#...#.............#...#  
  #.#########.#######.#.#######.#######.###  
  #...#.#    F       R I       Z    #.#.#.#  
  #.###.#    D       E C       H    #.#.#.#  
  #.#...#                           #...#.#  
  #.###.#                           #.###.#  
  #.#....OA                       WB..#.#..ZH
  #.###.#                           #.#.#.#  
CJ......#                           #.....#  
  #######                           #######  
  #.#....CK                         #......IC
  #.###.#                           #.###.#  
  #.....#                           #...#.#  
  ###.###                           #.#.#.#  
XF....#.#                         RF..#.#.#  
  #####.#                           #######  
  #......CJ                       NM..#...#  
  ###.#.#                           #.###.#  
RE....#.#                           #......RF
  ###.###        X   X       L      #.#.#.#  
  #.....#        F   Q       P      #.#.#.#  
  ###.###########.###.#######.#########.###  
  #.....#...#.....#.......#...#.....#.#...#  
  #####.#.###.#######.#######.###.###.#.#.#  
  #.......#.......#.#.#.#.#...#...#...#.#.#  
  #####.###.#####.#.#.#.#.###.###.#.###.###  
  #.......#.....#.#...#...............#...#  
  #############.#.#.###.###################  
               A O F   N                     
               A A D   M                     
"""
    ): 396,
}
for testmaze, expected in part2_tests.items():
    # assert, so that a wrong answer actually fails the cell
    result = RecursiveMaze(testmaze.splitlines()).shortest_path()
    assert result == expected, f"expected {expected}, got {result}"

In [5]:
# Part 2: same search, but inner portals recurse one level deeper
print("Part 2:", RecursiveMaze(data.splitlines()).shortest_path())

Part 2: 6300
