In [3]:
from collections import namedtuple
from dataclasses import dataclass
from copy import deepcopy


# Constants
# (row, col) offsets tried by get_neighbours. get_neighbours reads this global at
# call time, so rebinding DIRECTIONS changes the order neighbours are visited in.
DIRECTIONS = [(0, -1), (-1, 0), (1, 0), (0, 1)]
# Facing symbol -> (row, col) step. Keys are listed in clockwise order
# (north, east, south, west); turn_clockwise relies on that order.
FACING_DIRECTIONS = {"^": (-1, 0), ">": (0, 1), "v": (1, 0), "<": (0, -1)}


@dataclass
class Node:
    """A walkable maze tile tracked by the tile-based Dijkstra search.

    ``x`` is the row index and ``y`` the column index into the grid returned
    by ``read_file``.
    """

    x: int  # row index
    y: int  # column index
    # Best known score to reach this tile; float("inf") until it is reached.
    distance: float
    # Facing symbol ("^", ">", "v", "<") the tile was reached with; "" until reached.
    entry_direction: str
    is_end: bool  # True only for the "E" tile


# Logic
def read_file(path):
    """Read a maze file into a list of rows, each a list of one-character tiles.

    The start tile "S" is replaced by ">" so the grid also encodes the initial
    facing (">" is the east step (0, 1) in FACING_DIRECTIONS).

    Args:
        path: path of the text file to read.

    Returns:
        list[list[str]]: one list of characters per line, surrounding
        whitespace stripped.
    """
    # `with` guarantees the handle is closed (the original leaked it).
    with open(path, "r") as fh:
        rows = [list(line.strip()) for line in fh]
    return [[">" if tile == "S" else tile for tile in row] for row in rows]


def get_shortest_distance_node(node_list: list[Node]):
    """Pop and return the node with the smallest finite distance.

    On ties the earliest node in ``node_list`` wins. Returns None (and leaves
    the list untouched) when the list is empty or every distance is infinite.
    Linear scan per call.
    """
    reachable = [k for k, node in enumerate(node_list) if node.distance != float("inf")]
    if not reachable:
        return None
    # min() returns the first index holding the minimum, matching a strict "<" scan.
    closest = min(reachable, key=lambda k: node_list[k].distance)
    return node_list.pop(closest)


def get_node_by_xy(node_list: list[Node], x: int, y: int, pop: bool = True):
    """Find the first node located at (x, y) in ``node_list``.

    With ``pop=True`` (the default) the node is removed from the list before
    being returned; with ``pop=False`` the list is left unchanged. Returns None
    when no node sits at that position. Linear scan.
    """
    index = next(
        (k for k, candidate in enumerate(node_list) if candidate.x == x and candidate.y == y),
        None,
    )
    if index is None:
        return None
    if not pop:
        return node_list[index]
    return node_list.pop(index)


def dx_dy_to_dir(dx, dy):
    """Map a (row, col) step to its facing symbol in FACING_DIRECTIONS.

    Returns None when the step matches none of the four facings.
    """
    return next(
        (
            symbol
            for symbol, (row_step, col_step) in FACING_DIRECTIONS.items()
            if row_step == dx and col_step == dy
        ),
        None,
    )


def turn_clockwise(direction, turn_count: int = 1):
    """Rotate a facing symbol clockwise by ``turn_count`` quarter turns.

    Relies on FACING_DIRECTIONS listing its keys in clockwise order. Raises
    ValueError if ``direction`` is not one of those keys.
    """
    clockwise = [*FACING_DIRECTIONS]
    position = clockwise.index(direction) + turn_count
    return clockwise[position % 4]


def turn_anticlockwise(direction):
    """Rotate a facing symbol one quarter turn anticlockwise.

    Steps one place back in FACING_DIRECTIONS' clockwise key order (equivalent
    to three clockwise quarter turns).
    """
    clockwise = list(FACING_DIRECTIONS)
    return clockwise[(clockwise.index(direction) - 1) % 4]


def calculate_distance(from_node: Node, to_node: Node):
    """Cost of stepping from ``from_node`` onto ``to_node``.

    Returns ``(direction_needed, cost)``. ``direction_needed`` is the facing
    symbol of the step. ``cost`` is 1 for the step plus 1000 per quarter turn
    away from ``from_node.entry_direction``; a reversal counts as two turns.
    Assumes the nodes are orthogonally adjacent — TODO confirm at call sites.
    """
    needed = dx_dy_to_dir(to_node.x - from_node.x, to_node.y - from_node.y)
    facing = from_node.entry_direction

    if needed == facing:
        turns = 0
    elif needed in (turn_clockwise(facing), turn_anticlockwise(facing)):
        turns = 1
    else:
        turns = 2

    return needed, 1 + 1000 * turns


def get_neighbours(node_list: list[Node], node: Node):
    """Remove and return the nodes in ``node_list`` adjacent to ``node``.

    NOTE: despite the name, each neighbour is *popped* from ``node_list``
    (get_node_by_xy with its default ``pop=True``); callers must append them
    back if they should stay in the list. Neighbours come back in the order of
    the global DIRECTIONS as it is bound at call time.
    """
    found = (
        get_node_by_xy(node_list, node.x + dx, node.y + dy)
        for dx, dy in DIRECTIONS
    )
    return [neighbour for neighbour in found if neighbour is not None]


def draw_grid(grid, visited):
    """Print the maze, replacing each visited tile with its entry direction.

    Debugging helper. ``grid`` itself is not modified.

    Args:
        grid: list of rows of one-character tiles (as returned by read_file);
            rows are indexed by ``x`` and columns by ``y``.
        visited: iterable of objects exposing ``x``, ``y`` and
            ``entry_direction``; when several share a position the first wins.
    """
    # Index visited tiles once instead of scanning the whole list per cell.
    # setdefault keeps the first occurrence, matching a first-match search.
    arrows = {}
    for node in visited:
        arrows.setdefault((node.x, node.y), node.entry_direction)

    rendered = [
        [arrows.get((x, y), tile) for y, tile in enumerate(row)]
        for x, row in enumerate(grid)
    ]

    # default=0 lets an empty grid print an empty header instead of raising.
    width = max((len(row) for row in rendered), default=0)
    # Rows start with a 4-character label ("{x:>3} ") followed by 3-character
    # cells, so the header needs the same 4-character indent to line up
    # (the original used 3 and was off by one column).
    header = "".join(f"{y:<3}" for y in range(width))
    print("    " + header)
    for x, row in enumerate(rendered):
        print(f"{x:>3} " + "  ".join(row))

    print("\n\n")


def initialise_node_collections(input):
    """Build the starting node lists for the tile-based Dijkstra search.

    Every non-"#" tile of ``input`` (a grid from read_file) becomes a Node:
    - the ">" start tile: distance 0, facing ">";
    - the "E" tile: distance inf, is_end=True;
    - any other tile: distance inf, no facing.

    Returns:
        (unvisited, visited): all created nodes, and a new empty list.
    """
    unvisited = []
    for x, row in enumerate(input):
        for y, tile in enumerate(row):
            if tile == "#":
                continue
            if tile == ">":
                node = Node(x, y, 0, ">", False)
            elif tile == "E":
                node = Node(x, y, float("inf"), "", True)
            else:
                node = Node(x, y, float("inf"), "", False)
            unvisited.append(node)

    return unvisited, []

In [4]:
# Part 1, first attempt: Dijkstra over grid tiles only.
# NOTE(review): each tile keeps a single best score and the facing it was
# reached with, so this per-tile model can report a higher score than the
# position+heading search in the later cells; treat those as authoritative.
maze = read_file("input")  # avoid shadowing the builtin `input`
unvisited, visited = initialise_node_collections(maze)

while True:
    current = get_shortest_distance_node(unvisited)
    if current is None:
        break  # nothing reachable is left
    visited.append(current)
    if current.is_end:
        break

    # get_neighbours *pops* the neighbours out of `unvisited`; every one of them
    # must be appended back (the original `break` here silently dropped the rest).
    for neighbour in get_neighbours(unvisited, current):
        direction, step_cost = calculate_distance(current, neighbour)
        candidate = current.distance + step_cost
        # Dijkstra relaxation: only ever lower a tentative distance
        # (the original overwrote it unconditionally).
        if candidate < neighbour.distance:
            neighbour.distance = candidate
            neighbour.entry_direction = direction
        unvisited.append(neighbour)

[x for x in visited if x.is_end]

[Node(x=1, y=139, distance=89472, entry_direction='^', is_end=True)]

In [15]:
from heapq import heappop, heappush

# Part 1: Dijkstra over (heading, row, col) states; a step costs 1, a quarter turn 1000.
maze = read_file("input")  # avoid shadowing the builtin `input`

# Reset first so a missing S/E cannot silently reuse values from another cell.
START = END = None
for x, row in enumerate(maze):
    for y, c in enumerate(row):
        if c == ">":
            START = (x, y)
        elif c == "E":
            END = (x, y)
assert START is not None and END is not None, "maze needs a start (S) and an end (E)"

maze[END[0]][END[1]] = "."  # make the end tile walkable

# Headings in clockwise order: east, south, west, north (index 0 = ">" start facing).
# Own name so this cell does not rebind the DIRECTIONS global get_neighbours reads.
HEADINGS = [(0, 1), (1, 0), (0, -1), (-1, 0)]

frontier = [(0, 0, *START)]  # (score, heading index, row, col)
seen = set()
best_score = None  # stays None if END is unreachable

while frontier:
    score, heading, i, j = heappop(frontier)

    if (i, j) == END:
        best_score = score
        break
    if (heading, i, j) in seen:
        continue
    seen.add((heading, i, j))

    # Move one step forward.
    x = i + HEADINGS[heading][0]
    y = j + HEADINGS[heading][1]
    if (
        0 <= x < len(maze)
        and 0 <= y < len(maze[x])
        and maze[x][y] == "."
        and (heading, x, y) not in seen
    ):
        heappush(frontier, (score + 1, heading, x, y))

    # Turn a quarter either way in place.
    for turn in ((heading - 1) % 4, (heading + 1) % 4):
        if (turn, i, j) not in seen:
            heappush(frontier, (score + 1000, turn, i, j))

best_score

89460


In [21]:
from heapq import heappop, heappush

# Sanity check: the part-1 (heading, row, col) Dijkstra on the worked example.
maze = read_file("example1")  # avoid shadowing the builtin `input`

# Reset first so a missing S/E cannot silently reuse values from another cell.
START = END = None
for x, row in enumerate(maze):
    for y, c in enumerate(row):
        if c == ">":
            START = (x, y)
        elif c == "E":
            END = (x, y)
assert START is not None and END is not None, "maze needs a start (S) and an end (E)"

maze[END[0]][END[1]] = "."  # make the end tile walkable

# Headings in clockwise order: east, south, west, north (index 0 = ">" start facing).
# Own name so this cell does not rebind the DIRECTIONS global get_neighbours reads.
HEADINGS = [(0, 1), (1, 0), (0, -1), (-1, 0)]

frontier = [(0, 0, *START)]  # (score, heading index, row, col)
seen = set()
best_score = None  # stays None if END is unreachable

while frontier:
    score, heading, i, j = heappop(frontier)

    if (i, j) == END:
        best_score = score
        break
    if (heading, i, j) in seen:
        continue
    seen.add((heading, i, j))

    # Move one step forward.
    x = i + HEADINGS[heading][0]
    y = j + HEADINGS[heading][1]
    if (
        0 <= x < len(maze)
        and 0 <= y < len(maze[x])
        and maze[x][y] == "."
        and (heading, x, y) not in seen
    ):
        heappush(frontier, (score + 1, heading, x, y))

    # Turn a quarter either way in place.
    for turn in ((heading - 1) % 4, (heading + 1) % 4):
        if (turn, i, j) not in seen:
            heappush(frontier, (score + 1000, turn, i, j))

# Known answer for the worked example (the value this cell previously printed).
assert best_score == 7036, best_score
best_score

7036


In [30]:
from heapq import heappop, heappush
from itertools import count

# Part 2: count the tiles lying on *any* lowest-score path from S to E.
maze = read_file("input")  # avoid shadowing the builtin `input`

# Reset first so a missing S/E cannot silently reuse values from another cell.
START = END = None
for x, row in enumerate(maze):
    for y, c in enumerate(row):
        if c == ">":
            START = (x, y)
        elif c == "E":
            END = (x, y)
assert START is not None and END is not None, "maze needs a start (S) and an end (E)"


def can_visit(d, i, j, score):
    """Record ``score`` for state (d, i, j) unless a strictly lower one is known.

    Equal scores are allowed on purpose, so every equally cheap path through a
    state keeps being explored (needed to collect all best paths). Reads and
    updates the module-level ``visited`` dict.
    """
    prev_score = visited.get((d, i, j))
    # `is not None` rather than truthiness: a recorded score of 0 (the start
    # state) must still block worse revisits.
    if prev_score is not None and prev_score < score:
        return False
    visited[(d, i, j)] = score
    return True


maze[END[0]][END[1]] = "."  # make the end tile walkable

# Headings in clockwise order: east, south, west, north (index 0 = ">" start facing).
# Own name so this cell does not rebind the DIRECTIONS global get_neighbours reads.
HEADINGS = [(0, 1), (1, 0), (0, -1), (-1, 0)]

# Heap entries: (score, heading, row, col, tie, path). The unique `tie` counter
# guarantees heapq never falls through to comparing the path sets.
tie = count()
nodes = [(0, 0, *START, next(tie), {START})]
visited = {}
lowest_score = None
winning_paths = []

while nodes:
    score, heading, i, j, _, path = heappop(nodes)

    if lowest_score is not None and score > lowest_score:
        # Heap is score-ordered: nothing left can match the best score.
        break

    if (i, j) == END:
        lowest_score = score
        winning_paths.append(path)
        continue

    if not can_visit(heading, i, j, score):
        continue

    # Move one step forward; `path | {...}` builds a new set without deepcopy.
    x = i + HEADINGS[heading][0]
    y = j + HEADINGS[heading][1]
    if (
        0 <= x < len(maze)
        and 0 <= y < len(maze[x])
        and maze[x][y] == "."
        and can_visit(heading, x, y, score + 1)
    ):
        heappush(nodes, (score + 1, heading, x, y, next(tie), path | {(x, y)}))

    # Turn a quarter either way in place; the path set is shared, never mutated.
    for turn in ((heading - 1) % 4, (heading + 1) % 4):
        if can_visit(turn, i, j, score + 1000):
            heappush(nodes, (score + 1000, turn, i, j, next(tie), path))

len(set().union(*winning_paths))

504