# Day 16
## Part 1
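Use Dijkstra's algorithm over (position, heading) states. The reindeer starts on `S` facing east. A step forward costs 1 and a 90° turn costs 1000, so each turn is folded into the step that follows it (cost 1001).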

In [49]:
from dataclasses import dataclass
import heapq

@dataclass(eq=True, frozen=True)
class Point:
    """Immutable, hashable 2-D integer vector.

    Supports ``+``, ``-``, unary ``-``, scalar ``*``, ``%`` (by another Point
    component-wise or by a scalar), ``<`` ordering by ``x`` then ``y``, and
    unpacking via iteration (``x, y = p``).
    """

    x: int
    y: int

    def __add__(self, other):
        cls = type(self)
        return cls(self.x + other.x, self.y + other.y)

    def __sub__(self, other):
        cls = type(self)
        return cls(self.x - other.x, self.y - other.y)

    def __neg__(self):
        cls = type(self)
        return cls(-self.x, -self.y)

    def __lt__(self, other):
        # Lexicographic: compare x first, fall back to y on a tie.
        return (self.x, self.y) < (other.x, other.y)

    def __iter__(self):
        yield from (self.x, self.y)

    def __mod__(self, other):
        if isinstance(other, Point):
            mod_x, mod_y = other.x, other.y
        else:
            mod_x = mod_y = other
        return type(self)(self.x % mod_x, self.y % mod_y)

    def __mul__(self, multiple):
        cls = type(self)
        return cls(self.x * multiple, self.y * multiple)

# Unit steps on the grid. parse_data numbers rows bottom-up, so +y is north.
N, S, W, E = Point(0, 1), Point(0, -1), Point(-1, 0), Point(1, 0)

# The four compass headings.
DIRECTIONS = {N, E, S, W}

def parse_data(s):
    """Parse a maze drawing into ``{Point: char}`` for every non-wall cell.

    Rows are numbered bottom-up (the last text line is ``y == 0``) so that
    ``N`` (+y) points up the page. Wall cells (``#``) are omitted; every other
    character (``.``, ``S``, ``E``) is kept as the cell's value.
    """
    rows = s.strip().splitlines()
    top = len(rows) - 1
    return {
        Point(col, top - row_index): ch
        for row_index, row in enumerate(rows)
        for col, ch in enumerate(row)
        if ch != "#"
    }

def part_1(grid):
    """Return the lowest score of any route from ``S`` to ``E`` in *grid*.

    Dijkstra over ``(position, heading)`` states. The reindeer starts on ``S``
    facing east. A step forward costs 1 and a 90-degree turn costs 1000. Each
    turn is folded into the step that follows it, so moving in a new direction
    costs 1001, or 2001 for a 180-degree reversal.

    Args:
        grid: mapping of open cells (``Point``) to their character, as built
            by ``parse_data``; must contain exactly one ``S`` and one ``E``.

    Returns:
        The minimal score, or ``None`` if ``E`` is unreachable.
    """
    starting_point = next(p for p in grid if grid[p] == "S")
    end_point = next(p for p in grid if grid[p] == "E")
    q = [(0, starting_point, E)]
    settled = set()

    while q:
        score, p, current_d = heapq.heappop(q)
        if (p, current_d) in settled:
            # Stale duplicate: this state was already popped with a score <= this one.
            continue
        settled.add((p, current_d))

        if p == end_point:
            return score

        for d in DIRECTIONS:
            next_point = p + d
            if next_point not in grid or (next_point, d) in settled:
                continue
            if d == current_d:
                step_cost = 1
            elif d == -current_d:
                # Two turns, then a step. Only ever optimal as the very first move,
                # e.g. when S faces a wall and the way out lies behind it.
                step_cost = 2001
            else:
                step_cost = 1001
            heapq.heappush(q, (score + step_cost, next_point, d))

    return None

test_data_1 = parse_data("""###############
#.......#....E#
#.#.###.#.###.#
#.....#.#...#.#
#.###.#####.#.#
#.#.#.......#.#
#.#.#####.###.#
#...........#.#
###.#.#####.#.#
#...#.....#.#.#
#.#.#.###.#.#.#
#.....#...#.#.#
#.###.#.#.#.#.#
#S..#.....#...#
###############""")
# Regression check: known answer for the first example maze.
example_1_score = part_1(test_data_1)
assert example_1_score == 7036, example_1_score
example_1_score

7036

In [50]:
test_data_2 = parse_data("""#################
#...#...#...#..E#
#.#.#.#.#.#.#.#.#
#.#.#.#...#...#.#
#.#.#.#.###.#.#.#
#...#.#.#.....#.#
#.#.#.#.#.#####.#
#.#...#.#.#.....#
#.#.#####.#.###.#
#.#.#.......#...#
#.#.###.#####.###
#.#.#...#.....#.#
#.#.#.#####.###.#
#.#.#.........#.#
#.#.#.#########.#
#S#.............#
#################""")
# Regression check: known answer for the second example maze.
example_2_score = part_1(test_data_2)
assert example_2_score == 11048, example_2_score
example_2_score

11048

In [51]:
%%time

data = parse_data(open("input").read())
part_1(data)

CPU times: user 556 ms, sys: 36 μs, total: 556 ms
Wall time: 555 ms


88468

## Part 2

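Run the same search, but carry the set of tiles visited along each path and yield that set whenever the end tile is reached. The heap pops paths in order of increasing score, so stop as soon as a path reaches the end with a score higher than the best one. The answer is the number of tiles in the union of all best paths.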

In [53]:
def best_paths(grid):
    """Yield the set of tiles of every lowest-score route from ``S`` to ``E``.

    Dijkstra over ``(position, heading)`` states, starting on ``S`` facing
    east. A step forward costs 1, a 90-degree turn folded into the next step
    costs 1001, and a 180-degree reversal costs 2001. Each heap entry carries
    the frozenset of tiles on its route so far. Every route that reaches ``E``
    with the minimal score is yielded; the search stops at the first arrival
    with a worse score.

    Yields:
        ``frozenset`` of ``Point``s, one per best route. Nothing is yielded
        if ``E`` is unreachable.

    NOTE: every heap entry holds its own frozenset, so cost grows with the
    number of tied routes.
    """
    starting_point = next(p for p in grid if grid[p] == "S")
    end_point = next(p for p in grid if grid[p] == "E")
    q = [(0, starting_point, E, frozenset({starting_point}))]
    settled = {}  # (point, heading) -> score at which the state was first popped
    best_score = None

    while q:
        score, p, current_d, path = heapq.heappop(q)
        state = (p, current_d)
        if state in settled and score > settled[state]:
            # A cheaper arrival at this state exists, so no best route continues from here.
            continue
        # Equal-score arrivals fall through so every tied route keeps expanding.
        settled.setdefault(state, score)

        if p == end_point:
            if best_score is None:
                best_score = score
            if score > best_score:
                # Scores pop in nondecreasing order: nothing later can tie the best.
                break
            yield path
            # Walking on past E and returning would only cost more.
            continue

        for d in DIRECTIONS:
            next_point = p + d
            if next_point not in grid or (next_point, d) in settled:
                continue
            if d == current_d:
                step_cost = 1
            elif d == -current_d:
                # Two turns, then a step. Only ever optimal as the very first move.
                step_cost = 2001
            else:
                step_cost = 1001
            heapq.heappush(
                q,
                (
                    score + step_cost,
                    next_point,
                    d,
                    path | frozenset({next_point}),
                ),
            )

def part_2(data):
    """Count the tiles that lie on at least one lowest-score route.

    Returns 0 when ``E`` is unreachable. Starting from an empty frozenset
    avoids the ``TypeError`` that ``frozenset.union()`` raises when it is
    called with no arguments.
    """
    return len(frozenset().union(*best_paths(data)))

# Regression check: known tile count for the first example maze.
example_1_tiles = part_2(test_data_1)
assert example_1_tiles == 45, example_1_tiles
example_1_tiles

45

In [54]:
# Regression check: known tile count for the second example maze.
example_2_tiles = part_2(test_data_2)
assert example_2_tiles == 64, example_2_tiles
example_2_tiles

64

In [55]:
# Part 2 answer for the puzzle input loaded in the timed cell above.
part_2(data)

616