In [None]:
%load_ext autoreload
%autoreload 2
from aoc.lib import load, timing

# Puzzle identity, passed to load() to select the input.
YEAR, DAY = 2024, 10
# Set to True to run against TESTDATA below instead of the regular input.
TEST = False
# Small example height map: one row of digits per line, trailing newline included.
TESTDATA = (
    '89010123\n'
    '78121874\n'
    '87430965\n'
    '96549874\n'
    '45678903\n'
    '32019012\n'
    '01329801\n'
    '10456732\n'
)

In [None]:
from itertools import pairwise, product
import numpy as np
from collections import defaultdict
from aoc.lib import CharGrid
import networkx as nx

@timing
def prepare_data():
    """Parse the puzzle input into a height map plus its "one step uphill" graph.

    Returns:
        grid: the CharGrid built from the input lines.
        graph: nx.DiGraph with a node for every digit cell and an edge
            ``lower -> higher`` between neighbouring cells whose levels differ
            by exactly 1 (the only legal step on a trail).
        lows: set of positions at level 0.
        highs: set of positions at level 9.
    """
    data = load(YEAR, DAY, split_lines=True, test=TESTDATA if TEST else None)
    grid = CharGrid(data['split'])
    lows = set()
    highs = set()
    graph = nx.DiGraph()
    for pos in grid.walk():
        char = grid[pos]
        # Non-digit cells (e.g. '.') are impassable; int() would raise on them.
        if not char.isdecimal():
            continue
        level = int(char)
        # Register every digit cell, even ones with no edges: nx.has_path and
        # nx.all_simple_paths raise NodeNotFound for nodes missing from the graph.
        graph.add_node(pos)
        if level == 0:
            lows.add(pos)
        elif level == 9:
            highs.add(pos)
        for nb in grid.neighbours(pos):
            nb_chr = grid[nb]
            # grid[...] can return None (presumably for off-map positions).
            if nb_chr is None or not nb_chr.isdecimal():
                continue
            # Edge points uphill: from the neighbour one level below to this cell.
            if level == int(nb_chr) + 1:
                graph.add_edge(nb, pos)
    return grid, graph, lows, highs

In [None]:
# Level 1: treating the input as a height map, count every (0, 9) pair where the 9 is reachable from the 0 by a trail that climbs exactly one level per step
@timing
def level1(grid, graph, lows, highs):
    """Level 1: count the (low, high) pairs connected by an uphill path.

    For each position in ``lows``, count how many positions in ``highs`` are
    reachable from it along the edges of ``graph``, and return the total.

    Args:
        grid: unused here; kept so both levels take the same arguments.
        graph: directed graph whose edges point one level uphill.
        lows: level-0 positions (trail starts).
        highs: level-9 positions (trail ends).

    Returns:
        int: number of (low, high) pairs where high is reachable from low.
    """
    count = 0
    for low in lows:
        # A position absent from the graph reaches nothing; nx.descendants
        # (like nx.has_path) raises for unknown nodes, so skip it explicitly.
        if low not in graph:
            continue
        # One traversal per start instead of one path search per (low, high) pair.
        count += len(nx.descendants(graph, low).intersection(highs))
    return count

# Re-parse the input here so this cell can be re-run on its own, then print the Level 1 answer.
grid, graph, lows, highs = prepare_data()
print(level1(grid, graph, lows, highs))

In [None]:
# Level 2: count every distinct trail from a 0 to a 9 (not just every reachable 9)
@timing
def level2(grid, graph, lows, highs):
    count = 0
    for low, high in product(lows, highs):
        count += sum(1 for _ in nx.all_simple_paths(graph, low, high))
    return count

# Re-parse the input rather than relying on the previous cell's globals, so this cell can be run on its own.
grid, graph, lows, highs = prepare_data()
print(level2(grid, graph, lows, highs))