# Day 17 
## Part 1
Start off by implementing A* with a Manhattan distance ($m$) heuristic. I'm not sure that's going to be fast enough. I originally thought of using the sum of the $m$ smallest heat losses within the rectangle bounded by the current position and the end, but that won't be admissible: the cheapest route can leave that rectangle, e.g. when it is full of 9s but surrounded by 1s.

What's the size of the problem?

In [1]:
# Read the whole puzzle input up front.  The context manager closes the file
# handle deterministically instead of leaving it to the garbage collector.
# NOTE: the name `input` shadows the builtin, but a later cell reads it
# (`parse_data(input)`), so the name is kept.
with open("input") as f:
    input = f.read()

len(input)

20022

I think that's going to be too big for such a simple heuristic. I think there needs to be some sort of global check that we haven't visited the same state before. All we need to know is the current position, the last direction, and the number of times it was used. That gives at most 4 directions × 3 possible run lengths × ~20000 positions ≈ 240,000 states, which should be tractable.

In [2]:
from functools import cache
from dataclasses import dataclass

@dataclass
class Point:
    """A 2-D coordinate that doubles as a displacement vector.

    Supports component-wise ``+``, ``-``, unary ``-``, ``%`` and scalar ``*``,
    is hashable on ``(x, y)`` and orders by ``x`` first, then ``y``.
    Arithmetic builds results with ``self.__class__`` so subclasses get
    instances of their own type back.
    """

    x: int
    y: int

    def __add__(self, other):
        """Return the component-wise sum of this point and ``other``."""
        return self.__class__(self.x + other.x, self.y + other.y)

    def __sub__(self, other):
        """Return the component-wise difference ``self - other``."""
        return self.__class__(self.x - other.x, self.y - other.y)

    def __neg__(self):
        """Return the point reflected through the origin."""
        return self.__class__(-self.x, -self.y)

    def __hash__(self):
        """Hash on the coordinate pair, consistent with ``__eq__``."""
        # NOTE: instances are mutable (the dataclass is not frozen); changing
        # x or y after the point has been stored in a set or used as a dict
        # key will break lookups.
        return hash((self.x, self.y))

    def __lt__(self, other):
        """Order by ``x``, breaking ties on ``y``.

        Returns ``NotImplemented`` for non-Point operands so Python raises the
        usual ``TypeError`` rather than an ``AttributeError``.
        """
        if not isinstance(other, Point):
            return NotImplemented
        return (self.x, self.y) < (other.x, other.y)

    def __iter__(self):
        """Yield ``x`` then ``y`` so a point can be unpacked like a pair."""
        yield self.x
        yield self.y

    def __mod__(self, other):
        """Reduce each coordinate modulo ``other``.

        ``other`` may be a Point (component-wise modulus) or any other value,
        which is then applied to both coordinates.
        """
        if isinstance(other, Point):
            return self.__class__(self.x % other.x, self.y % other.y)
        else:
            return self.__class__(self.x % other, self.y % other)

    def __mul__(self, multiple):
        """Return the point with both coordinates scaled by ``multiple``."""
        return self.__class__(self.x * multiple, self.y * multiple)

    def __eq__(self, other):
        """Return True when ``other`` is a Point with the same coordinates.

        Comparing against a non-Point returns ``NotImplemented`` so ``==``
        evaluates to False instead of raising ``AttributeError``.
        """
        if not isinstance(other, Point):
            return NotImplemented
        return self.x == other.x and self.y == other.y
    

# Unit steps for the four compass directions.  North is +y and south is -y,
# i.e. y grows upwards; east is +x and west is -x.
N, S, W, E = Point(0, 1), Point(0, -1), Point(-1, 0), Point(1, 0)


def manhattan_distance(p1, p2):
    """Return the taxicab (L1) distance between two objects with x/y attributes."""
    dx = p1.x - p2.x
    dy = p1.y - p2.y
    return abs(dx) + abs(dy)

In [3]:
def parse_data(s):
    """Turn the puzzle text into a ``{Point: int}`` mapping of cell values.

    Surrounding whitespace is stripped, then every character of every line is
    converted with ``int``.  Rows are numbered bottom-up: the first line of
    text gets the largest ``y`` and the last line gets ``y == 0``.  ``x`` is
    the character's column index counted from the left.
    """
    rows = s.strip().splitlines()
    top = len(rows) - 1
    return {
        Point(col, top - row_index): int(digit)
        for row_index, text in enumerate(rows)
        for col, digit in enumerate(text)
    }

# Small 13x13 sample grid used to sanity-check the solvers before running
# them on the full input.
test_data = parse_data("""2413432311323
3215453535623
3255245654254
3446585845452
4546657867536
1438598798454
4457876987766
3637877979653
4654967986887
4564679986453
1224686865563
2546548887735
4322674655533""")

# Grid built from the full puzzle text read into `input` in the first cell.
data = parse_data(input)

In [4]:
import heapq
from collections import namedtuple
from functools import cache

# Heap entry payload.  Tuples compare field by field, so the heap orders on
# `h` first, then `cost`, then `position`.  `h` holds the full priority
# (cost so far plus the Manhattan estimate), not the bare heuristic.
State = namedtuple(typename="State", field_names="h cost position")
# Direction of the last step and how many consecutive steps were taken that way.
DirectionState = namedtuple(
    typename="DirectionState", field_names="direction times_used"
)
# Key for the visited set: a cell together with how it was entered.
PositionState = namedtuple(
    typename="PositionState", field_names="position direction_state"
)

def part_1(grid):
    """Return the cheapest cost of a route across ``grid`` under the 3-step rule.

    ``grid`` maps ``Point`` to the cost of entering that cell.  The route
    starts at ``x == 0`` on the row with the largest ``y`` and ends at the
    largest ``x`` on row ``y == 0``; the starting cell's own cost is not
    counted.  A route may never reverse direction and may take at most three
    consecutive steps the same way.

    The search is best-first on ``cost so far + Manhattan distance to the
    goal``.  Each (cell, direction, run length) combination is queued at most
    once.  Returns ``None`` if the queue empties without reaching the goal.
    """
    start = Point(0, max(pt.y for pt in grid))
    goal = Point(max(pt.x for pt in grid), 0)

    def estimate(pt):
        # Remaining-distance estimate used in the heap priority.
        return manhattan_distance(pt, goal)

    # The seed has a run length of 0 facing S, so only N (the reverse) is
    # ruled out for the very first step.
    frontier = [(State(estimate(start), 0, start), DirectionState(S, 0))]
    seen = set()

    while frontier:
        current, heading = heapq.heappop(frontier)
        reverse = -heading.direction
        run_exhausted = heading.times_used == 3

        for step in (N, S, W, E):
            if step == reverse:
                continue
            same_way = step == heading.direction
            if same_way and run_exhausted:
                continue

            target = current.position + step
            run = heading.times_used + 1 if same_way else 1
            key = PositionState(target, DirectionState(step, run))

            # Finish as soon as the goal cell is generated.
            if target == goal:
                return current.cost + grid[target]
            if target not in grid or key in seen:
                continue

            seen.add(key)
            spent = current.cost + grid[target]
            heapq.heappush(
                frontier,
                (State(estimate(target) + spent, spent, target), key.direction_state),
            )

In [58]:
%%time 

part_1(test_data)

CPU times: user 26.8 ms, sys: 2.56 ms, total: 29.3 ms
Wall time: 27.3 ms


102

In [6]:
%%time

part_1(data)

CPU times: user 2.38 s, sys: 26.8 ms, total: 2.4 s
Wall time: 2.4 s


698

Great. That worked with no problems whatsoever.

## Part 2

This is fiddly.

In [7]:
# Direction of the last step and how many consecutive steps were taken that way.
DirectionState = namedtuple(
    typename="DirectionState", field_names="direction times_used"
)
# Key for the visited set: a cell together with how it was entered.
PositionState = namedtuple(
    typename="PositionState", field_names="position direction_state"
)

def part_2(grid):
    """First attempt at part 2: cheapest route with runs of 4 to 10 steps.

    ``grid`` maps ``Point`` to the cost of entering that cell.  The search
    runs from ``x == 0`` on the row with the largest ``y`` to the largest
    ``x`` on row ``y == 0``.  Once a direction is chosen it must be followed
    until the run length reaches 4, and it is dropped once the run length
    reaches 10.

    NOTE: this version returns as soon as the goal cell is generated,
    whatever the current run length is.  ``part_2`` is defined again later in
    the file with an extra run-length check on arrival; this draft is kept
    unchanged because the outputs recorded beneath it were produced by it.

    Returns ``None`` if the queue empties without reaching the goal.
    """
    start_position = Point(0, max(p.y for p in grid))
    end_position = Point(max(p.x for p in grid), 0)
    # Manhattan distance to the goal, used as the A* estimate.
    m = lambda p: manhattan_distance(p, end_position)
    # Two seeds with run length 0: runs shorter than 4 must continue
    # straight, so the first step has to be offered both S and E.
    q = [
        (
            State(
                m(start_position),
                0,
                start_position,
            ),
            DirectionState(S, 0)
        ),
        (
            State(
                m(start_position),
                0,
                start_position,
            ),
            DirectionState(E, 0)
        )
    ]
    # (position, direction, run length) combinations already queued.
    visited = set()

    while q:
        state, dirstate = heapq.heappop(q)
        directions = {N, S, W, E}
        # Never reverse.
        directions.discard(-dirstate.direction)
        if dirstate.times_used == 10:
            # Run is at its maximum length: a turn is forced.
            directions.discard(dirstate.direction)
        elif dirstate.times_used < 4:
            # Run is still too short: only straight on is allowed.
            directions = {dirstate.direction}
        for d in directions:
            p = state.position + d
            # Continuing straight extends the run; turning restarts it at 1.
            ps = PositionState(
                p,
                DirectionState(d, dirstate.times_used + 1) 
                if dirstate.direction == d
                else DirectionState(d, 1)
            )
            # No run-length check here (see NOTE in the docstring).
            if p == end_position:
                return state.cost + grid[p]
            if p in grid and ps not in visited:
                visited.add(ps)
                cost = state.cost + grid[p]
                # Despite the name, `h` is the full priority: estimate + cost so far.
                h = m(p) + cost
                heapq.heappush(
                    q,
                    (
                        State(
                            h,
                            cost,
                            p
                        ),
                        ps.direction_state
                    )
                )

In [8]:
%%time

part_2(test_data)

CPU times: user 12.5 ms, sys: 831 µs, total: 13.3 ms
Wall time: 12.6 ms


94

In [9]:
%%time

part_2(data)

CPU times: user 8.2 s, sys: 52.9 ms, total: 8.25 s
Wall time: 8.25 s


819

Great. Right on the test data, wrong on the actual data.

In [10]:
# Second sample grid: a top row and right-hand column of 1s around a block
# of 9s.
test_data_2 = parse_data("""111111111111
999999999991
999999999991
999999999991
999999999991""")

part_2(test_data_2)

47

Why is that 71 in the problem statement? Shouldn't it go ten blocks to the right and then go down?

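Aha! I missed the "(or even before it can stop at the end)" part: the crucible must also have moved at least four blocks in its current direction when it arrives at the end.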

In [11]:
# Direction of the last step and how many consecutive steps were taken that way.
DirectionState = namedtuple(
    typename="DirectionState", field_names="direction times_used"
)
# Key for the visited set: a cell together with how it was entered.
PositionState = namedtuple(
    typename="PositionState", field_names="position direction_state"
)

def part_2(grid, *, min_run=4, max_run=10):
    """Return the cheapest cost of a route across ``grid`` with bounded runs.

    ``grid`` maps ``Point`` to the cost of entering that cell.  The route
    starts at ``x == 0`` on the row with the largest ``y`` and ends at the
    largest ``x`` on row ``y == 0``; the starting cell's own cost is not
    counted.

    Movement rules:
      * never reverse direction;
      * once moving in a direction, keep going until the run is at least
        ``min_run`` steps long;
      * turn once the run is ``max_run`` steps long;
      * arriving at the goal only counts if the current run is at least
        ``min_run`` steps long.

    Args:
        grid: mapping of ``Point`` to an integer cell cost.
        min_run: minimum consecutive steps before turning or stopping
            (keyword-only, default 4).
        max_run: maximum consecutive steps in one direction
            (keyword-only, default 10).

    Returns:
        The total cost of the cheapest valid route, or ``None`` if the queue
        empties without a valid arrival at the goal.

    The search is best-first on ``cost so far + Manhattan distance to the
    goal``; each (cell, direction, run length) combination is queued at most
    once.
    """
    start = Point(0, max(p.y for p in grid))
    goal = Point(max(p.x for p in grid), 0)

    def estimate(p):
        # Remaining-distance estimate used in the heap priority.
        return manhattan_distance(p, goal)

    # Two seeds with run length 0: a run shorter than `min_run` must continue
    # straight, so the first step has to be offered both S and E.
    frontier = [
        (State(estimate(start), 0, start), DirectionState(S, 0)),
        (State(estimate(start), 0, start), DirectionState(E, 0)),
    ]
    seen = set()

    while frontier:
        node, heading = heapq.heappop(frontier)

        if heading.times_used == max_run:
            # Run is at its maximum length: must turn left or right.
            banned = (-heading.direction, heading.direction)
            moves = [d for d in (N, S, W, E) if d not in banned]
        elif heading.times_used < min_run:
            # Run is still too short: only straight on is allowed.
            moves = [heading.direction]
        else:
            # Anything except a U-turn.
            moves = [d for d in (N, S, W, E) if d != -heading.direction]

        for d in moves:
            p = node.position + d
            # Continuing straight extends the run; turning restarts it at 1.
            run = heading.times_used + 1 if d == heading.direction else 1
            key = PositionState(p, DirectionState(d, run))

            # Only a long-enough run may stop on the goal; a shorter one
            # falls through and is queued like any other cell.
            if p == goal and run >= min_run:
                return node.cost + grid[p]
            if p not in grid or key in seen:
                continue

            seen.add(key)
            cost = node.cost + grid[p]
            heapq.heappush(
                frontier,
                (State(estimate(p) + cost, cost, p), key.direction_state),
            )

In [12]:
part_2(test_data)

94

In [13]:
part_2(test_data_2)

71

In [14]:
%%time

part_2(data)

CPU times: user 8.42 s, sys: 36.6 ms, total: 8.46 s
Wall time: 8.45 s


825