# Day 17 
## Part 1
Start off with implementing A* with a Manhattan distance ($m$) heuristic. I'm not sure if that's going to be fast enough, I thought originally of using the minimum sum of $m$ heat losses within the square boundary between the current position and the end but that won't be admissible as you could have e.g. 9s surrounded by 1s. 

What's the size of the problem?

In [1]:
input = open("input").read()

len(input)

20022

I think that's going to be too big for such a simple heuristic. Let's think about the one but see how far the simpler one gets.

In [2]:
from functools import cache
from dataclasses import dataclass

@dataclass
class Point:
    x: int
    y: int

    def __add__(self, other):
        return self.__class__(self.x + other.x, self.y + other.y)

    def __sub__(self, other):
        return self.__class__(self.x - other.x, self.y - other.y)

    def __neg__(self):
        return self.__class__(-self.x, -self.y)

    def __hash__(self):
        return hash((self.x, self.y))

    def __lt__(self, other):
        if self.x < other.x:
            return True
        elif self.x > other.x:
            return False
        else:
            return self.y < other.y

    def __iter__(self):
        yield self.x
        yield self.y

    def __mod__(self, other):
        if isinstance(other, Point):
            return self.__class__(self.x % other.x, self.y % other.y)
        else:
            return self.__class__(self.x % other, self.y % other)
        
    def __mul__(self, multiple):
        return self.__class__(self.x * multiple, self.y * multiple)
    
    def __eq__(self, other):
        return self.x == other.x and self.y == other.y
    

N = Point(0, 1)
S = Point(0, -1)
W = Point(-1, 0)
E = Point(1, 0)


def manhattan_distance(p1, p2):
    return abs(p1.x - p2.x) + abs(p1.y - p2.y)

In [3]:
def parse_data(s):
    grid = {}
    lines = s.strip().splitlines()
    for y, line in zip(range(len(lines) - 1, -1, -1), lines):
        for x, c in enumerate(line):
            grid[Point(x, y)] = int(c)
    return grid

test_data = parse_data("""2413432311323
3215453535623
3255245654254
3446585845452
4546657867536
1438598798454
4457876987766
3637877979653
4654967986887
4564679986453
1224686865563
2546548887735
4322674655533""")

data = parse_data(input)

In [4]:
import heapq
from collections import namedtuple
from functools import cache

State = namedtuple("State", "h cost position")


def a_star_factory(grid):
    end_position = Point(max(p.x for p in grid), 0)

    @cache
    def a_star(start_position):
        q = [State(
            manhattan_distance(start_position, end_position),
            0,
            start_position
        )]
    
        visited = {start_position}
    
        while q:
            state = heapq.heappop(q)
            for d in {N, S, W, E}:
                p = state.position + d
                if p == end_position:
                    return state.cost + grid[p]
                if p in grid and p not in visited:
                    cost = state.cost + grid[p]
                    visited.add(p)
                    h = manhattan_distance(p, end_position) + cost
                    heapq.heappush(
                        q,
                        State(
                            h,
                            cost,
                            p
                        )
                    )

    return a_star
    

def part_1(grid):
    a_star = a_star_factory(grid)
    start_position = Point(0, max(p.y for p in grid))
    end_position = Point(max(p.x for p in grid), 0)
    q = [
        (
            State(
                a_star(start_position),
                0,
                start_position,
            ),
            [None, None, None]
        )
    ]

    while q:
        state, last_three_directions = heapq.heappop(q)
        directions = {N, S, W, E}
        if last_three_directions[-1] is not None:
            # Remove reverse
            directions.discard(-last_three_directions[-1])
        if len(set(last_three_directions)) == 1:
            directions.discard(last_three_directions[-1])
        for d in directions:
            p = state.position + d
            if p == end_position:
                return state.cost + grid[p]
            if p in grid:
                cost = state.cost + grid[p]
                h = a_star(p) + cost
                heapq.heappush(
                    q,
                    (
                        State(
                            h,
                            cost,
                            p
                        ),
                        last_three_directions[1:] + [d]
                    )
                )

In [None]:
import heapq
from collections import namedtuple
from functools import cache

State = namedtuple("State", "h cost position")


def a_star_factory(grid):
    end_position = Point(max(p.x for p in grid), 0)

    @cache
    def a_star(start_position):
        q = [State(
            manhattan_distance(start_position, end_position),
            0,
            start_position
        )]
    
        visited = {start_position}
    
        while q:
            state = heapq.heappop(q)
            for d in {N, S, W, E}:
                p = state.position + d
                if p == end_position:
                    return state.cost + grid[p]
                if p in grid and p not in visited:
                    cost = state.cost + grid[p]
                    visited.add(p)
                    h = manhattan_distance(p, end_position) + cost
                    heapq.heappush(
                        q,
                        State(
                            h,
                            cost,
                            p
                        )
                    )

    return a_star
    

def part_1(grid):
    a_star = a_star_factory(grid)
    start_position = Point(0, max(p.y for p in grid))
    end_position = Point(max(p.x for p in grid), 0)
    q = [
        (
            State(
                a_star(start_position),
                0,
                start_position,
            ),
            ()
        )
    ]
    visited = (start_position, ())

    while q:
        state, last_equal_directions = heapq.heappop(q)
        directions = {N, S, W, E}
        if last_three_directions[-1] is not None:
            # Remove reverse
            directions.discard(-last_three_directions[-1])
        if len(set(last_three_directions)) == 1:
            directions.discard(last_three_directions[-1])
        for d in directions:
            p = state.position + d
            ds = tuple
            if p == end_position:
                return state.cost + grid[p]
            if p in grid:
                cost = state.cost + grid[p]
                h = a_star(p) + cost
                heapq.heappush(
                    q,
                    (
                        State(
                            h,
                            cost,
                            p
                        ),
                        last_three_directions[1:] + [d]
                    )
                )

In [5]:
h = a_star_factory(test_data)
h(Point(0, max(p.y for p in test_data)))

78

In [6]:
%%time 

part_1(test_data)

CPU times: user 494 ms, sys: 7.14 ms, total: 501 ms
Wall time: 499 ms


102

That works, but is very slow for test data on part 1.

In [7]:
%%time

# part_1(data)

KeyboardInterrupt: 

That rapidly chews up the memory. I think there needs to be some sort of global check that we haven't visited the same state before. All we need to know is the last direction and the number of times it was used, plus the current position. Given that this will be 4 directions x 3 consecutive times used x ~20000 I think that will be tractable.
*TODO* implement that.

In [2]:
import heapq
from collections import namedtuple
from functools import cache

State = namedtuple("State", "h cost position")


def a_star_factory(grid):
    end_position = Point(max(p.x for p in grid), 0)

    @cache
    def a_star(start_position):
        q = [State(
            manhattan_distance(start_position, end_position),
            0,
            start_position
        )]
    
        visited = {start_position}
    
        while q:
            state = heapq.heappop(q)
            for d in {N, S, W, E}:
                p = state.position + d
                if p == end_position:
                    return state.cost + grid[p]
                if p in grid and p not in visited:
                    cost = state.cost + grid[p]
                    visited.add(p)
                    h = manhattan_distance(p, end_position) + cost
                    heapq.heappush(
                        q,
                        State(
                            h,
                            cost,
                            p
                        )
                    )

    return a_star
    

def part_1(grid):
    a_star = a_star_factory(grid)
    start_position = Point(0, max(p.y for p in grid))
    end_position = Point(max(p.x for p in grid), 0)
    q = [
        (
            State(
                a_star(start_position),
                0,
                start_position,
            ),
            ()
        )
    ]
    visited = (start_position, ())

    while q:
        state, last_equal_directions = heapq.heappop(q)
        directions = {N, S, W, E}
        if last_three_directions[-1] is not None:
            # Remove reverse
            directions.discard(-last_three_directions[-1])
        if len(set(last_three_directions)) == 1:
            directions.discard(last_three_directions[-1])
        for d in directions:
            p = state.position + d
            ds = tuple
            if p == end_position:
                return state.cost + grid[p]
            if p in grid:
                cost = state.cost + grid[p]
                h = a_star(p) + cost
                heapq.heappush(
                    q,
                    (
                        State(
                            h,
                            cost,
                            p
                        ),
                        last_three_directions[1:] + [d]
                    )
                )

()

In [3]:
()

()

In [4]:
x = (1,2,3)
x[1:]

(2, 3)