In [None]:
import math
from queue import PriorityQueue
from typing import Generator, Callable

In [None]:
# Parse the puzzle input into a {(row, col): heat_loss} mapping.
with open("day17_input.txt") as file:
    grid = {
        (row, col): int(char)
        for row, line in enumerate(file)
        for col, char in enumerate(line.strip())
    }

assert grid, "day17_input.txt contained no grid cells"

# Destination = bottom-right block. Taking the largest key (tuples compare by
# row, then column) instead of leftover loop variables keeps this on the grid
# even if the input file ends with blank lines.
END_POS = max(grid)

In [None]:
def find_path_to(state, path_dict):
    """Follow predecessor links back from ``state`` and collect positions.

    Returns the grid positions visited, ordered from ``state`` itself back to
    the first state that has no recorded predecessor (the start state).
    """
    positions = [state[0]]
    current = state
    while current in path_dict:
        current = path_dict[current]
        positions.append(current[0])
    return positions


def print_path(state, path_dict, cells=None, end_pos=None):
    """Print the grid with the path taken highlighted in red.

    Args:
        state: final search state ``(pos, prev_pos, num_moves)`` to trace back from.
        path_dict: predecessor mapping (as returned by ``find_solution``).
        cells: ``{(row, col): heat_loss}`` mapping to render. Defaults to the
            notebook's global ``grid``.
        end_pos: bottom-right ``(row, col)`` of the area to render. Defaults to
            the notebook's global ``END_POS``.
    """
    # Fall back to the notebook globals so existing two-argument calls still work.
    if cells is None:
        cells = grid
    if end_pos is None:
        end_pos = END_POS

    path = set(find_path_to(state, path_dict))

    RED = "\033[91m"
    END = "\033[0m"

    for row in range(end_pos[0] + 1):
        rendered = []
        for col in range(end_pos[1] + 1):
            digit = str(cells[(row, col)])
            rendered.append(RED + digit + END if (row, col) in path else digit)
        # One write per row; same bytes as printing each cell with end="".
        print("".join(rendered))

In [None]:
def find_solution(
    neighbour_func: Callable, grid: dict, start: tuple[int, int] = (0, 0)
):
    """Run Dijkstra over ``(pos, prev_pos, num_moves)`` states.

    Args:
        neighbour_func: called as ``neighbour_func(state, grid)``; must yield
            ``(new_pos, new_num_moves)`` pairs for the legal moves from ``state``.
        grid: ``{(row, col): heat_loss}``; entering a block costs its value.
        start: starting position (default ``(0, 0)``). Its own heat loss is
            not counted.

    Returns:
        ``(heatloss, prev_path)``:
        - ``heatloss`` maps every reached state to the minimum total heat loss
          found for it.
        - ``prev_path`` maps every reached non-start state to its predecessor
          state on that cheapest route.
    """
    heatloss = {}
    prev_path = {}
    queue = PriorityQueue()

    # The start state has prev == pos and zero moves, so no direction is implied.
    start_state = (start, start, 0)
    heatloss[start_state] = 0
    queue.put((0, start_state))

    while not queue.empty():
        loss, state = queue.get()
        if loss > heatloss[state]:
            # Stale entry: a cheaper entry for this state had lower priority and
            # was already popped and expanded, so expanding this one is wasted work.
            continue
        pos = state[0]
        for new_pos, new_moves in neighbour_func(state, grid):
            new_state = (new_pos, pos, new_moves)
            new_loss = loss + grid[new_pos]
            if new_loss < heatloss.get(new_state, math.inf):
                heatloss[new_state] = new_loss
                prev_path[new_state] = state
                queue.put((new_loss, new_state))

    return heatloss, prev_path

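# Part 1

The crucible may travel at most three blocks in a straight line before it must turn, and it may never reverse onto the block it just left.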


In [None]:
def valid_neighbours_part_1(
    state, grid
) -> Generator[tuple[tuple[int, int], int], None, None]:
    """Yield the legal next moves for a normal crucible (part 1 rules).

    ``state`` is ``(pos, prev_pos, num_moves)``. After three consecutive
    moves the crucible must turn 90 degrees. Otherwise it may go in any
    on-grid direction except straight back onto ``prev_pos``. Each yielded
    pair is ``(new_pos, new_num_moves)``.
    """
    (row, col), (prev_row, prev_col), num_moves = state

    if num_moves == 3:
        # Straight-line limit reached: only the two perpendicular moves remain.
        if row == prev_row:
            turns = ((1, 0), (-1, 0))
        else:
            turns = ((0, 1), (0, -1))
        for step_row, step_col in turns:
            target = (row + step_row, col + step_col)
            if target in grid:
                yield target, 1
        return

    for step_row, step_col in ((0, 1), (0, -1), (1, 0), (-1, 0)):
        target = (row + step_row, col + step_col)
        if target not in grid or target == (prev_row, prev_col):
            # Off the map, or reversing onto the block we just left.
            continue
        # Sharing a row or column with the previous block means the crucible
        # kept its heading, so the run grows; otherwise it turned.
        kept_heading = target[0] == prev_row or target[1] == prev_col
        yield target, (num_moves + 1 if kept_heading else 1)

In [None]:
heatloss, prev_path = find_solution(valid_neighbours_part_1, grid)
# The goal can be reached in several states (different approach direction or
# run length); the answer is the cheapest of them.
answer = min(loss for state, loss in heatloss.items() if state[0] == END_POS)
print("Answer:", answer)

In [None]:
# Every search state that ends on the goal block, cheapest first
solutions = [(state, loss) for state, loss in heatloss.items() if state[0] == END_POS]
solutions.sort(key=lambda pair: pair[1])
solutions

In [None]:
# Highlight the cheapest route (solutions is sorted by heat loss)
print_path(state=solutions[0][0], path_dict=prev_path)

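# Part 2

The ultra crucible must travel at least four blocks in a straight line before it may turn (or stop at the destination), and at most ten blocks before it must turn.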


In [None]:
def valid_neighbours_part_2(
    state, grid, min_straight: int = 4, max_straight: int = 10
) -> Generator[tuple[tuple[int, int], int], None, None]:
    """Yield the legal next moves for an ultra crucible (part 2 rules).

    Args:
        state: ``(pos, prev_pos, num_moves)``. ``num_moves`` counts the
            consecutive blocks travelled in the current direction; it is 0
            only for the start state, where ``prev_pos == pos``.
        grid: mapping whose keys are the on-grid ``(row, col)`` positions.
        min_straight: blocks that must be travelled in a direction before a
            turn is allowed. Defaults to 4 (part 2).
        max_straight: run length at which a turn is forced. Defaults to 10
            (part 2). ``min_straight=1, max_straight=3`` expresses part 1.

    Yields:
        ``(new_pos, new_num_moves)`` for every legal move that stays on the grid.
    """
    (row, col), (prev_row, prev_col), num_moves = state
    row_diff = row - prev_row
    col_diff = col - prev_col
    if 0 < num_moves < min_straight:
        # Minimum run not reached yet: we must continue in this direction
        new_pos = (row + row_diff, col + col_diff)
        if new_pos in grid:
            yield new_pos, num_moves + 1
    elif num_moves >= max_straight:
        # Maximum run reached: we must turn left or right
        if row_diff == 0:
            # Travelling horizontally, so we must turn vertically
            for d in [1, -1]:
                new_pos = (row + d, col)
                if new_pos in grid:
                    yield new_pos, 1
        elif col_diff == 0:
            # Travelling vertically, so we must turn horizontally
            for d in [1, -1]:
                new_pos = (row, col + d)
                if new_pos in grid:
                    yield new_pos, 1
    else:
        # We can move in any direction
        for d in [(0, 1), (0, -1), (1, 0), (-1, 0)]:
            new_pos = (row + d[0], col + d[1])
            new_row, new_col = new_pos
            if (new_pos not in grid) or (new_pos == (prev_row, prev_col)):
                # Outside the grid or back to the previous position
                continue
            # prev, current and new all on one row/column => still going straight
            if (prev_row == row == new_row) or (prev_col == col == new_col):
                new_moves = num_moves + 1
            else:
                new_moves = 1
            yield new_pos, new_moves

In [None]:
heatloss, prev_path = find_solution(valid_neighbours_part_2, grid)
# The ultra crucible may only stop after a final run of 4..10 blocks, so goal
# states with a shorter final run are not valid answers.
answer = min(
    loss
    for (pos, _prev, run), loss in heatloss.items()
    if pos == END_POS and 4 <= run <= 10
)
print("Answer:", answer)

In [None]:
# Every valid goal state (final run of 4..10 blocks), cheapest first
solutions = [
    (state, loss)
    for state, loss in heatloss.items()
    if state[0] == END_POS and 4 <= state[2] <= 10
]
solutions.sort(key=lambda pair: pair[1])
solutions

In [None]:
# Highlight the cheapest route (solutions is sorted by heat loss)
print_path(state=solutions[0][0], path_dict=prev_path)