# Day 15 - path finding

We are asked to play out a battle between elves and goblins. The biggest challenge is figuring out how to move; once a unit is next to an enemy, stepping through the actions should be simple.

We last saw a similar problem (finding the best options in a simulated battle) in [AoC 2015, day 22](https://adventofcode.com/2015/day/22), and I'll again use [the A* search algorithm](https://en.wikipedia.org/wiki/A*_search_algorithm) to find the paths. We pick a shortest path, breaking ties between equal-length paths by 'reading' order (`(y, x)` tuple sorting), first of the target square and then of the first step taken. Once the shortest paths to a target have been found, we can stop the search and execute a move.

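So the trick here is to treat a path that starts off going left and then down as *separate* from one that goes down and then left, even when both arrive at the same position after the same number of steps. In most A* implementations these would be the same node in the graph (same position, same distance, same heuristic score to any of the target goals). Here, each node also records the position of its first step so the two can be told apart.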

In [1]:
import sys
from dataclasses import dataclass, field
from enum import Enum
from heapq import heappush, heappop
from itertools import count
from operator import attrgetter
from typing import Iterable, Iterator, Mapping, Optional, Sequence, Set, Tuple

# A grid coordinate, stored as (y, x) so tuple sorting gives reading order.
Position = Tuple[int, int]

class NoTargetsRemaining(Exception):
    """Signals that a unit found no enemies left to fight, so combat is over."""

class NoTargetsReachable(Exception):
    """Signals that a path search could not reach any of its goal squares."""

class Race(Enum):
    """The two factions in the cave; each value is the map character used."""

    elf = 'E'
    goblin = 'G'

@dataclass(order=True)
class Unit:
    """A single elf or goblin fighting in the cave.

    Only ``y`` and ``x`` take part in comparisons, so units sort in
    reading order and compare equal when they share a position.
    """

    race: Race = field(compare=False)
    y: int
    x: int
    hitpoints: int = field(default=200, compare=False)
    attackpower: int = field(default=3, compare=False)

    @property
    def pos(self) -> Position:
        """Current position as a (y, x) tuple"""
        return self.y, self.x

    def adjacent(self, cave: 'CaveCombat') -> Set[Position]:
        """All open-floor ('.') cave positions orthogonally next to this unit.

        Squares occupied by other units are included; use ``available``
        to exclude them.
        """
        positions = (
            (self.y + dy, self.x + dx)
            for dy, dx in ((-1, 0), (0, -1), (0, 1), (1, 0))
        )
        return {(y, x) for y, x in positions if cave.map[y][x] == '.'}

    def available(self, cave: 'CaveCombat') -> Set[Position]:
        """All adjacent positions this unit could move to (open and unoccupied)"""
        return {pos for pos in self.adjacent(cave) if cave[pos] is None}

    def _enemies_in_range(self, cave: 'CaveCombat') -> Sequence['Unit']:
        """Units of another race standing on a square adjacent to this unit"""
        enemies = []
        for pos in self.adjacent(cave):
            other = cave[pos]
            if other is not None and other.race is not self.race:
                enemies.append(other)
        return enemies

    def turn(self, cave: 'CaveCombat') -> None:
        """Take one turn: move towards an enemy if none is adjacent, then attack.

        Raises NoTargetsRemaining when no enemy units are left, which
        signals the end of combat.
        """
        targets = [u for u in cave.units if u.race is not self.race]
        if not targets:
            raise NoTargetsRemaining

        in_range = self._enemies_in_range(cave)
        if not in_range:
            # not next to an enemy yet: head for an open square beside one
            target_positions = set().union(*(t.available(cave) for t in targets))
            if not target_positions:
                # every square next to an enemy is blocked, turn ends
                return
            try:
                # pick a shortest path to one of the positions; returns our new position
                self.y, self.x = cave.search_path(self.pos, target_positions)
            except NoTargetsReachable:
                # none of those squares can be reached; stay put, turn ends
                return
            # the move may have brought us next to an enemy
            in_range = self._enemies_in_range(cave)

        if in_range:
            # pick target with lowest hitpoints; ties broken by reading order
            target = min(in_range, key=attrgetter('hitpoints', 'y', 'x'))
            target.hitpoints -= self.attackpower
            if target.hitpoints <= 0:
                # target died; remove() matches on position, which is unique
                # among living units
                cave.units.remove(target)
        
# Placeholder first step for a search's start node, which has not moved yet.
_sentinel_first_pos: Position = (-1, -1)

@dataclass(frozen=True, order=True)
class Node:
    """A state in the A* search graph.

    Nodes order by (y, x, distance, first_pos), so among nodes of equal
    distance the smallest one has the goal square, and then the first
    step, earliest in reading order.
    """
    y: int
    x: int
    distance: int = 0
    # Square the path stepped onto first. Paths that start differently
    # stay distinct nodes even when they meet; the chosen path's value
    # becomes the moving unit's new position.
    first_pos: Position = _sentinel_first_pos

    @property
    def pos(self) -> Position:
        return (self.y, self.x)

    def cost(self, goals: Set[Position]) -> int:
        """Estimated total path cost, f(n) = g(n) + h(n).

        g is the number of steps taken so far. h is the Manhattan
        distance from this node to the closest goal square.
        """
        nearest = min(abs(gy - self.y) + abs(gx - self.x) for gy, gx in goals)
        return self.distance + nearest

    def transitions(self, cave: 'CaveCombat') -> Iterator['Node']:
        """Yield the nodes one step away on open, unoccupied floor."""
        make = type(self)
        already_moved = self.first_pos != _sentinel_first_pos
        for dy, dx in ((-1, 0), (0, -1), (0, 1), (1, 0)):
            y, x = self.y + dy, self.x + dx
            if cave.map[y][x] != '.' or cave[(y, x)] is not None:
                continue
            first = self.first_pos if already_moved else (y, x)
            yield make(y, x, self.distance + 1, first)

@dataclass
class CaveCombat:
    """A cave map together with the units fighting in it.

    ``map`` holds the cave rows with unit characters replaced by open floor
    ('.'); where each unit stands is tracked on the Unit objects in
    ``units``. ``round`` counts the fully completed rounds.
    """

    map: Sequence[str]
    # a list in practice: Unit.turn() removes units from it when they die
    units: Sequence[Unit]
    round: int = 0

    def __post_init__(self):
        # internal cache of unit positions, rebuilt by turn() before each
        # unit's turn; while it is empty, __getitem__ scans self.units instead
        self._unit_positions: Mapping[Position, Unit] = {}

    @classmethod
    def from_lines(cls, lines: Iterable[str]) -> 'CaveCombat':
        """Parse a cave from its text rows.

        Every unit character becomes a Unit at its (y, x) position and is
        replaced by open floor ('.') in the stored map.
        """
        rows = []
        units = []
        unit_chars = ''.join(r.value for r in Race)
        for y, line in enumerate(lines):
            cleaned = []
            for x, c in enumerate(line):
                if c in unit_chars:
                    units.append(Unit(Race(c), y, x))
                    c = '.'
                cleaned.append(c)
            rows.append(''.join(cleaned))
        return cls(rows, units)

    def __str__(self) -> str:
        """Render the map with every living unit drawn at its position"""
        grid = [list(row) for row in self.map]
        for unit in self.units:
            grid[unit.y][unit.x] = unit.race.value
        return '\n'.join(''.join(row) for row in grid)

    def __getitem__(self, yx: Position) -> Optional[Unit]:
        """The unit standing at position ``yx``, or None"""
        if self._unit_positions:
            return self._unit_positions.get(yx)
        return next((u for u in self.units if u.pos == yx), None)

    def do_battle(self) -> int:
        """Play rounds until combat ends and return the outcome.

        The outcome is completed rounds times the remaining hitpoints.
        """
        while True:
            result = self.turn()
            if result is not None:
                return result

    def turn(self) -> Optional[int]:
        """Play one round: every living unit takes a turn in reading order.

        Returns the outcome (completed rounds * total remaining hitpoints)
        as soon as a unit finds no enemies left. Otherwise the round
        counter is incremented and None is returned.
        """
        for unit in sorted(self.units):
            # skip units that are dead; these are still in the sorted
            # loop iterable but have been removed from self.units
            if unit.hitpoints <= 0:
                continue

            # refresh the position cache before each unit's turn, as the
            # previous unit may have moved or killed someone
            self._unit_positions = {u.pos: u for u in self.units}

            try:
                unit.turn(self)
            except NoTargetsRemaining:
                return self.round * sum(u.hitpoints for u in self.units)

        self.round += 1
        return None

    def search_path(self, start: Position, goals: Set[Position]) -> Position:
        """Return the first step of the preferred shortest path to ``goals``.

        This is an A* search. Its states are (position, first step) pairs,
        so paths that begin differently are kept apart even when they
        reach the same square after the same number of steps.

        Among all shortest paths, the one whose goal square comes first in
        reading order wins. Among paths to that goal, the one whose first
        step comes first in reading order wins.

        Raises NoTargetsReachable when no goal square can be reached.
        """
        start_node = Node(*start)
        unique = count()  # tie breaker when costs are equal
        pqueue = [(start_node.cost(goals), next(unique), start_node)]
        # Shortest known distance per (position, first step) state.
        # Manhattan distance to the nearest goal is a consistent heuristic
        # on this grid, so a state's distance is final once it is popped.
        # That means no arbitrary depth limit is needed: the state space is
        # at most 4 * (open squares), however winding the cave is.
        best = {(start_node.pos, start_node.first_pos): 0}
        shortest = []
        while pqueue:
            cost, _, node = heappop(pqueue)
            if shortest and cost > shortest[0].distance:
                # every remaining path is longer than those found; done
                break
            if node.distance > best[node.pos, node.first_pos]:
                # stale queue entry; a shorter route to this state was pushed
                continue
            if node.pos in goals:
                # any path continuing past a goal is longer, so don't expand
                shortest.append(node)
                continue
            for new in node.transitions(self):
                key = new.pos, new.first_pos
                if key in best and best[key] <= new.distance:
                    continue
                best[key] = new.distance
                heappush(pqueue, (new.cost(goals), next(unique), new))

        if not shortest:
            # all searches exhausted and no reachable goal found
            raise NoTargetsReachable

        # All nodes in `shortest` share the same distance, so Node ordering
        # (y, x, distance, first_pos) picks the goal first in reading order,
        # then the first step first in reading order.
        return min(shortest).first_pos

In [2]:
# Movement check: each string in `outputs` is the expected map before
# successive calls to turn(), starting from the initial layout (round 0).
# The turn() after the last comparison is not checked.
movetest = CaveCombat.from_lines('''\
#########
#G..G..G#
#.......#
#.......#
#G..E..G#
#.......#
#.......#
#G..G..G#
#########'''.splitlines())
outputs = (
    '#########\n#G..G..G#\n#.......#\n#.......#\n#G..E..G#\n#.......#\n#.......#\n#G..G..G#\n#########',
    '#########\n#.G...G.#\n#...G...#\n#...E..G#\n#.G.....#\n#.......#\n#G..G..G#\n#.......#\n#########',
    '#########\n#..G.G..#\n#...G...#\n#.G.E.G.#\n#.......#\n#G..G..G#\n#.......#\n#.......#\n#########',
    '#########\n#.......#\n#..GGG..#\n#..GEG..#\n#G..G...#\n#......G#\n#.......#\n#.......#\n#########'
)
for expected in outputs:
    assert str(movetest) == expected
    movetest.turn()

# Combat check: replay the example battle. After each listed round, the map
# and every unit's (race, hitpoints) in reading order are compared; then the
# final outcome is checked.
combattest = CaveCombat.from_lines('''\
#######
#.G...#
#...EG#
#.#.#G#
#..G#E#
#.....#
#######'''.splitlines())
rounds = (
    (0, '#######\n#.G...#\n#...EG#\n#.#.#G#\n#..G#E#\n#.....#\n#######', ('G', 200), ('E', 200), ('G', 200), ('G', 200), ('G', 200), ('E', 200)),
    (1, '#######\n#..G..#\n#...EG#\n#.#G#G#\n#...#E#\n#.....#\n#######', ('G', 200), ('E', 197), ('G', 197), ('G', 200), ('G', 197), ('E', 197)),
    (2, '#######\n#...G.#\n#..GEG#\n#.#.#G#\n#...#E#\n#.....#\n#######', ('G', 200), ('G', 200), ('E', 188), ('G', 194), ('G', 194), ('E', 194)),
    (23, '#######\n#...G.#\n#..G.G#\n#.#.#G#\n#...#E#\n#.....#\n#######', ('G', 200), ('G', 200), ('G', 131), ('G', 131), ('E', 131)),
    (24, '#######\n#..G..#\n#...G.#\n#.#G#G#\n#...#E#\n#.....#\n#######', ('G', 200), ('G', 131), ('G', 200), ('G', 128), ('E', 128)),
    (25, '#######\n#.G...#\n#..G..#\n#.#.#G#\n#..G#E#\n#.....#\n#######', ('G', 200), ('G', 131), ('G', 125), ('G', 200), ('E', 125)),
    (26, '#######\n#G....#\n#.G...#\n#.#.#G#\n#...#E#\n#..G..#\n#######', ('G', 200), ('G', 131), ('G', 122), ('E', 122), ('G', 200)),
    (27, '#######\n#G....#\n#.G...#\n#.#.#G#\n#...#E#\n#...G.#\n#######', ('G', 200), ('G', 131), ('G', 119), ('E', 119), ('G', 200)),
    (28, '#######\n#G....#\n#.G...#\n#.#.#G#\n#...#E#\n#....G#\n#######', ('G', 200), ('G', 131), ('G', 116), ('E', 113), ('G', 200)),
    (47, '#######\n#G....#\n#.G...#\n#.#.#G#\n#...#.#\n#....G#\n#######', ('G', 200), ('G', 131), ('G', 59), ('G', 200)),
)
# `round_no`, not `round`: a module-level loop variable named `round` would
# shadow the builtin for every later notebook cell.
for round_no, expected, *units in rounds:
    while combattest.round < round_no:
        # turn() returns the outcome once combat is over; if that happens
        # early the round counter stops advancing, so fail here instead of
        # looping forever
        assert combattest.turn() is None
    assert str(combattest) == expected
    assert [(u.race.value, u.hitpoints) for u in sorted(combattest.units)] == units

assert combattest.turn() == 27730

# Full-battle checks. Each entry is (starting map, map when combat ends,
# expected outcome). The outcome is completed rounds * remaining hitpoints,
# as returned by do_battle().
tests = (
    (
        '#######\n#G..#E#\n#E#E.E#\n#G.##.#\n#...#E#\n#...E.#\n#######',
        '#######\n#...#E#\n#E#...#\n#.E##.#\n#E..#E#\n#.....#\n#######',
        36334
    ),
    (
        '#######\n#E..EG#\n#.#G.E#\n#E.##E#\n#G..#.#\n#..E#.#\n#######',
        '#######\n#.E.E.#\n#.#E..#\n#E.##.#\n#.E.#.#\n#...#.#\n#######',
        39514
    ),
    (
        '#######\n#E.G#.#\n#.#G..#\n#G.#.G#\n#G..#.#\n#...E.#\n#######',
        '#######\n#G.G#.#\n#.#G..#\n#..#..#\n#...#G#\n#...G.#\n#######',
        27755
    ),
    (
        '#######\n#.E...#\n#.#..G#\n#.###.#\n#E#G#G#\n#...#G#\n#######',
        '#######\n#.....#\n#.#G..#\n#.###.#\n#.#.#.#\n#G.G#G#\n#######',
        28944
    ),
    (
        '#########\n#G......#\n#.E.#...#\n#..##..G#\n#...##..#\n#...#...#\n#.G...G.#\n#.....G.#\n#########',
        '#########\n#.G.....#\n#G.G#...#\n#.G##...#\n#...##..#\n#.G.#...#\n#.......#\n#.......#\n#########',
        18740
    ),
)
for start, end, expected in tests:
    testcave = CaveCombat.from_lines(start.splitlines())
    # the outcome is checked first; the final map is compared afterwards
    assert testcave.do_battle() == expected
    assert str(testcave) == end

In [3]:
import aocd

# Puzzle input for 2018, day 15, fetched through the aocd package
data = aocd.get_data(year=2018, day=15)

In [4]:
# Part 1: fight the puzzle battle to the end and report its outcome
cave = CaveCombat.from_lines(data.splitlines())
part1_outcome = cave.do_battle()
print('Part 1:', part1_outcome)

Part 1: 195811
