In [None]:
import os
import sys

sys.path.insert(0, os.path.abspath("../utils"))
from aoc_utils import load_data, check

In [None]:
import bisect
import heapq
import itertools

In [None]:
# Puzzle input used by the final solve calls below; presumably the text for
# year 2023, day 17 -- confirm against aoc_utils.load_data.
data = load_data(2023, 17)

In [None]:
# Sample grids used to validate the solver. Each entry is a tuple of
# (grid text, expected part 1 answer, expected part 2 answer).
# None marks a part for which the entry supplies no expected value; how
# `check` treats such entries is defined in aoc_utils.
tests = [
    (
        """2413432311323
3215453535623
3255245654254
3446585845452
4546657867536
1438598798454
4457876987766
3637877979653
4654967986887
4564679986453
1224686865563
2546548887735
4322674655533""",
        102,
        94,
    ),
    (
        """111111111111
999999999991
999999999991
999999999991
999999999991""",
        None,
        71,
    ),
    (
        """111119999999
999919999999
999919911999
999911111111""",
        24,
        None,
    ),
]

# Part 1

In [None]:
def gen_map(lines):
    """Turn rows of digit characters into a position -> cost lookup.

    Positions are complex numbers: the real part is the column index and the
    imaginary part is the row index. Each value is the digit at that cell
    converted with ``int``.
    """
    return {
        complex(col, row): int(digit)
        for row, text in enumerate(lines)
        for col, digit in enumerate(text)
    }

In [None]:
def travel_cost(map_, target, max_length, min_length=0):
    """Return the cheapest cost of travelling from the origin to ``target``.

    The search starts at position 0 and may leave it either to the right
    (direction ``1``) or downwards (direction ``1j``). Entering a cell adds
    that cell's value to the path cost; the starting cell is never charged.

    Args:
        map_: dict mapping complex positions (real = column, imag = row) to
            strictly positive integer costs.
        target: complex position of the destination cell.
        max_length: maximum number of consecutive cells that may be entered
            in the same direction.
        min_length: minimum number of consecutive cells that must be entered
            in the same direction before turning or stopping on ``target``.

    Returns:
        The minimal total cost, or ``None`` when ``target`` cannot be reached
        under the run-length constraints.

    Raises:
        AssertionError: if a visited cell has a non-positive cost (the
            ordering argument of the search relies on positive costs).
    """
    # Heap entries are (cost, sequence, pos, heading, run). The sequence
    # number breaks cost ties so complex positions are never compared.
    frontier = []
    sequence = itertools.count()
    # Cheapest known pre-entry cost per (pos, heading, run) state.
    min_costs = {}

    def _enqueue(cost, pos, heading, run):
        # Schedule entering `pos` unless it is off the map or already
        # scheduled at an equal or lower cost.
        if pos not in map_:
            return
        state = (pos, heading, run)
        known = min_costs.get(state)
        if known is not None and cost >= known:
            return
        min_costs[state] = cost
        heapq.heappush(frontier, (cost, next(sequence), pos, heading, run))

    _enqueue(0, 1, 1, 0)
    _enqueue(0, 1j, 1j, 0)

    best = None  # cheapest completed path found so far
    while frontier:
        cost, _, pos, heading, run = heapq.heappop(frontier)
        # `cost` excludes the cell about to be entered and pops are in
        # non-decreasing order, so nothing left can beat `best` any more.
        if best is not None and cost > best:
            return best
        run += 1
        assert map_[pos] > 0
        cost += map_[pos]
        if pos == target and run >= min_length:
            best = cost if best is None else min(best, cost)
            continue
        if run < max_length:
            # continue forward
            _enqueue(cost, pos + heading, heading, run)
        if run >= min_length:
            # turn clockwise (rows grow downwards, so multiplying by 1j
            # maps right -> down)
            turned = heading * 1j
            _enqueue(cost, pos + turned, turned, 0)
            # turn counterclockwise
            _enqueue(cost, pos - turned, -turned, 0)
    # The frontier ran dry: report the best path found, if any, instead of
    # implicitly returning None.
    return best

In [None]:
def crucible_path(data, max_length=3, min_length=0):
    """Solve the puzzle for a grid given as newline-separated digit rows.

    Builds the cost map from ``data``, takes the last cell of the last row
    (using the first row's width) as the destination, and delegates to
    ``travel_cost`` with the given run-length limits.
    """
    rows = data.splitlines()
    grid = gen_map(rows)
    last_col = len(rows[0]) - 1
    last_row = len(rows) - 1
    return travel_cost(grid, complex(last_col, last_row), max_length, min_length)

In [None]:
# Part 1 uses crucible_path's defaults (max_length=3, min_length=0).
# `check` comes from aoc_utils; presumably it runs the solver on each sample
# grid in `tests` and compares with the part 1 expectation -- confirm there.
check(crucible_path, tests)
# Solve the real puzzle input; as the cell's last expression, the result is displayed.
crucible_path(data)

# Part 2

In [None]:
# Part 2 reuses the same solver with runs of at least 4 and at most 10 cells.
# The positional `2` presumably selects the part 2 column of `tests` --
# confirm against aoc_utils.check.
check(crucible_path, tests, 2, max_length=10, min_length=4)
# Solve the real puzzle input under the part 2 limits.
crucible_path(data, max_length=10, min_length=4)