In [None]:
import os
import sys

sys.path.insert(0, os.path.abspath("../utils"))
from aoc_utils import load_data, check

In [None]:
# Fetch the puzzle input through the shared helper imported from aoc_utils.
# NOTE(review): the arguments are presumably (year, day) — confirm in aoc_utils.
data = load_data(2024, 20)

In [None]:
# Worked examples, one tuple per example:
#   (puzzle text, expected part 1 result, expected part 2 result)
# NOTE(review): the expected values presumably match the arguments used in the
# `check` calls further down (threshold=4 for part 1; threshold=70, length=20
# for part 2) — confirm against aoc_utils.check.
tests = [
    (
        """###############
#...#...#.....#
#.#.#.#.#.###.#
#S#...#.#.#...#
#######.#.#.###
#######.#.#...#
#######.#.###.#
###..E#...#...#
###.#######.###
#...###...#...#
#.#####.#.###.#
#.#...#.#.#...#
#.#.#.#.#.#.###
#...#...#...###
###############
""",
        30,
        41,
    ),
]

# Part 1

In [None]:
def get_track(data):
    """Walk the racetrack from ``S`` to ``E`` and number every cell on it.

    Args:
        data: Multi-line string describing the grid. ``#`` is a wall, every
            other character is a free cell; ``S`` marks the start and ``E``
            the end.

    Returns:
        dict mapping each ``(column, row)`` cell of the path to its step
        count from the start (the start maps to 0, the end to the path
        length).

    Raises:
        ValueError: If the grid has no ``S`` or no ``E``, or if the path
            reaches a dead end before arriving at ``E``.
    """
    layout = set()
    start = end = None
    for j, line in enumerate(data.splitlines()):
        for i, c in enumerate(line):
            if c != "#":
                layout.add((i, j))
            if c == "S":
                start = i, j
            if c == "E":
                end = i, j
    if start is None or end is None:
        raise ValueError("track must contain both a start 'S' and an end 'E'")

    pos = start
    d = 0
    track = {}
    while pos != end:
        track[pos] = d
        d += 1
        i, j = pos
        next_pos = None
        for di, dj in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
            candidate = i + di, j + dj
            # A free cell that has not been numbered yet is the way forward.
            # No early exit: if several cells qualify, the last one in this
            # direction order is taken.
            if candidate in layout and candidate not in track:
                next_pos = candidate
        if next_pos is None:
            # Without this guard `pos` would never change and the loop would
            # spin forever on a track that does not lead to the end.
            raise ValueError(f"track dead-ends at {pos} before reaching the end")
        pos = next_pos
    track[end] = d
    return track

In [None]:
def find_shortcuts(track, threshold, length):
    """Count the cheats that save at least ``threshold`` steps.

    A cheat jumps from one track cell to another track cell at Manhattan
    distance at most ``length``. Its saving is the difference between the
    two cells' step counts minus the Manhattan distance of the jump.

    Args:
        track: dict mapping ``(column, row)`` cells to their step count.
        threshold: Minimum saving for a cheat to be counted.
        length: Maximum Manhattan distance of a cheat.

    Returns:
        Number of ordered (from, to) cell pairs whose saving is at least
        ``threshold``.
    """
    # With the provided lengths, iterating through the neighbourhood of each
    # cell is faster than checking every pair of track cells. The offsets of
    # that diamond-shaped neighbourhood are the same for every cell.
    offsets = [
        (dx, dy, abs(dx) + abs(dy))
        for dx in range(-length, length + 1)
        for dy in range(abs(dx) - length, length - abs(dx) + 1)
    ]

    count = 0
    for (x, y), start_time in track.items():
        for dx, dy, cost in offsets:
            target = (x + dx, y + dy)
            if target not in track:
                continue
            if track[target] - start_time - cost >= threshold:
                count += 1
    return count

In [None]:
def get_shortcuts(data, threshold, length=2):
    """Count the cheats in a puzzle text that save at least ``threshold`` steps.

    Args:
        data: Puzzle text, passed to ``get_track``.
        threshold: Minimum saving for a cheat to be counted.
        length: Maximum cheat length, forwarded to ``find_shortcuts``
            (defaults to 2).

    Returns:
        Whatever ``find_shortcuts`` returns for the parsed track.
    """
    # `find_shortcuts` is resolved when this runs, so whichever definition
    # is current in the kernel is the one used.
    return find_shortcuts(get_track(data), threshold, length)

In [None]:
# Validate part 1 on the worked example (threshold=4), then evaluate the real
# input; the trailing expression is what the cell displays.
# NOTE(review): `check` comes from aoc_utils — presumably it compares the
# function's result with the expected value stored in each test tuple.
check(get_shortcuts, tests, threshold=4)
get_shortcuts(data, threshold=100)

# Part 2

In [None]:
# Part 2 reuses the same solver with longer cheats (length=20).
# NOTE(review): the positional 2 presumably selects the part 2 expectation in
# each test tuple — confirm in aoc_utils.check.
check(get_shortcuts, tests, 2, threshold=70, length=20)
get_shortcuts(data, threshold=100, length=20)

# KDTree-based search

I also tried using a 2D spatial index (a KD-tree) to identify nearby track points.
This second solution scales better, but is not significantly faster for the cheat lengths used here: 2 (part 1) and 20 (part 2).

In [None]:
from scipy.spatial import KDTree

In [None]:
def find_shortcuts(track, threshold, length):
    """Count the cheats that save at least ``threshold`` steps (KDTree variant).

    Uses a ``scipy.spatial.KDTree`` over the track cells to look up the cells
    within Manhattan distance ``length`` of each cell. This definition reuses
    the name ``find_shortcuts``, so it replaces any earlier definition of
    that name in the running kernel.

    Args:
        track: dict mapping ``(column, row)`` cells to their step count.
        threshold: Minimum saving for a cheat to be counted.
        length: Maximum Manhattan distance of a cheat.

    Returns:
        Number of ordered (from, to) cell pairs whose saving is at least
        ``threshold``.
    """
    cells = list(track)
    tree = KDTree(cells)
    count = 0
    for origin, start_time in track.items():
        x, y = origin
        # p=1 selects the Manhattan (L1) metric for the ball query.
        for cell_id in tree.query_ball_point(origin, length, p=1):
            tx, ty = cells[cell_id]
            d = abs(tx - x) + abs(ty - y)
            # Sanity check on the index query: nothing further than `length`.
            assert d <= length, f"{d} > {length}"
            if track[tx, ty] - start_time - d >= threshold:
                count += 1
    return count

In [None]:
# Re-run the part 2 check and answer. get_shortcuts looks up find_shortcuts
# when it is called, so once the KDTree-based definition has been executed
# these calls exercise it instead of the neighbourhood scan.
check(get_shortcuts, tests, 2, threshold=70, length=20)
get_shortcuts(data, threshold=100, length=20)