In [None]:
import pyparsing as pp
from bisect import bisect_left

In [None]:
# Load the whole input file as a single string; parsing happens in later cells.
filename = "inputs/12-05.txt"
with open(filename) as input_file:
    data = input_file.read()

In [None]:
# Worked examples, one tuple per example:
#   (input text, expected part-1 answer, expected part-2 answer)
# `check` skips any example whose expected answer for the requested part is None.
tests = [
    (
        """seeds: 79 14 55 13

seed-to-soil map:
50 98 2
52 50 48

soil-to-fertilizer map:
0 15 37
37 52 2
39 0 15

fertilizer-to-water map:
49 53 8
0 11 42
42 0 7
57 7 4

water-to-light map:
88 18 7
18 25 70

light-to-temperature map:
45 77 23
81 45 19
68 64 13

temperature-to-humidity map:
0 69 1
1 0 69

humidity-to-location map:
60 56 37
56 93 4""",
        35,  # expected part-1 answer
        46,  # expected part-2 answer
    ),
]

In [None]:
def check(function, part=1, tests=tests):
    """Run ``function`` on every example whose expected answer for ``part`` is known.

    Args:
        function: Callable taking the raw input text and returning an answer.
        part: 1-based index selecting which expected output of each example
            tuple to compare against.
        tests: Sequence of ``(input_text, expected_part1, expected_part2, ...)``
            tuples; an expected value of ``None`` skips that example.

    Raises:
        AssertionError: if a computed answer differs from the expected one.
    """
    # `example` rather than `input` so the builtin is not shadowed.
    for example, *expected_outputs in tests:
        expected = expected_outputs[part - 1]
        if expected is None:
            continue
        result = function(example)
        # The message is shown only on failure, so it must say the values differ.
        assert result == expected, f"{result} != {expected}"

# Part 1

In [None]:
# pyparsing grammar for the almanac text.
# Unsigned integer token, converted to `int` as soon as it is matched.
uint = pp.Word(pp.nums).set_parse_action(lambda toks: int(toks[0]))
# Alphabetic category name, e.g. "seed" or "soil".
word = pp.Word(pp.alphas)
# "seeds: n n n ..." -> a single token holding the numbers as a plain list.
header_expr = (
    pp.Suppress("seeds")
    + pp.Suppress(":")
    + pp.Group(uint[1, ...]).set_parse_action(lambda toks: toks.as_list())
)
# One "destination source length" row of a map.
mapping_expr = pp.Group(uint("destination") + uint("source") + uint("length"))
# "<src>-to-<dest> map:" followed by one or more rows, grouped per map.
block_expr = pp.Group(
    word("source")
    + pp.Suppress("-to-")
    + word("destination")
    + pp.Suppress("map:")
    + mapping_expr[1, ...]
)
# Whole document: the seeds header, then every map block.
data_expr = header_expr + block_expr[1, ...]

In [None]:
def apply_mapping(seed, almanac, categories, start="seed", end="location"):
    """Push a single value through the chain of category maps (part 1).

    Args:
        seed: Value to translate.
        almanac: category -> list of ``(source_start, length, dest_start)``
            tuples sorted by ``source_start``.
        categories: category -> the next category in the chain.
        start: Category the value starts in.
        end: Category at which translation stops.

    Returns:
        The translated value once the chain reaches ``end``. Values not covered
        by any range of a map are passed through unchanged.
    """
    value = seed
    category = start
    while category != end:
        ranges = almanac[category]
        # Position of the last range whose source start is <= value (-1 if none).
        idx = bisect_left(ranges, (value + 1, -1)) - 1
        if idx >= 0:
            src_start, span, dest_start = ranges[idx]
            assert src_start <= value
            if value < src_start + span:
                value += dest_start - src_start
        category = categories[category]
    return value

In [None]:
def min_seed(data):
    """Part 1: lowest location among the individual seeds in the header.

    Each map block is stored as a list of ``(source_start, length, dest_start)``
    tuples sorted by source start, so ``apply_mapping`` can bisect on it.
    """
    seeds, *blocks = data_expr.parse_string(data)
    almanac = {}
    categories = {}
    for src_category, dest_category, *rows in blocks:
        categories[src_category] = dest_category
        almanac[src_category] = sorted(
            (source_start, length, dest_start)
            for dest_start, source_start, length in rows
        )
    return min(apply_mapping(seed, almanac, categories) for seed in seeds)

In [None]:
# Validate against the worked example, then solve the real input
# (the last expression is displayed as the cell output).
check(min_seed, part=1)
min_seed(data)

# Part 2

In [None]:
def apply_mapping(seed_pair, almanac, categories, start="seed", end="location"):
    """Push a whole range of values through the chain of category maps (part 2).

    The range is split into sub-intervals at map boundaries, so the work depends
    on the number of boundaries rather than the number of values.

    Args:
        seed_pair: ``(first, count)`` describing the half-open interval
            ``[first, first + count)``.
        almanac: category -> list of ``(source_start, length, dest_start)``
            tuples sorted by ``source_start``. Ranges within one category are
            assumed not to overlap.
        categories: category -> the next category in the chain.
        start: Category the range starts in.
        end: Category at which translation stops.

    Returns:
        The smallest value of the translated range once the chain reaches ``end``.
        Values not covered by any range of a map are passed through unchanged.
    """
    category = start
    intervals = [seed_pair]
    while category != end:
        mappings = almanac[category]
        next_intervals = []
        for first, count in intervals:
            while count > 0:
                # Last mapping whose source start is <= first (-1 if none).
                map_idx = bisect_left(mappings, (first + 1, -1)) - 1
                if map_idx >= 0:
                    src, span, dest = mappings[map_idx]
                    assert src <= first
                    if first < src + span:
                        # Inside a mapping: shift by its offset, up to its end.
                        length = min(src + span - first, count)
                        next_intervals.append((first + dest - src, length))
                        first += length
                        count -= length
                        continue
                # Not covered by any mapping: values map to themselves, up to the
                # start of the next mapping (or for the rest of the interval when
                # there is no later mapping, including an empty mapping list).
                if map_idx + 1 < len(mappings):
                    length = min(mappings[map_idx + 1][0] - first, count)
                else:
                    length = count
                next_intervals.append((first, length))
                first += length
                count -= length
        category = categories[category]
        intervals = next_intervals
    return min(first for first, _ in intervals)

In [None]:
def min_seed(data):
    """Part 2: lowest location reachable from any seed in the header's ranges.

    The header numbers are read as consecutive ``(start, count)`` pairs; a
    trailing unpaired number is ignored. Each map block is stored as a list of
    ``(source_start, length, dest_start)`` tuples sorted by source start.
    """
    seeds, *blocks = data_expr.parse_string(data)
    # zip over one shared iterator pairs consecutive numbers: (s0, c0), (s1, c1), ...
    numbers = iter(seeds)
    seed_pairs = list(zip(numbers, numbers))
    almanac = {}
    categories = {}
    for src_category, dest_category, *rows in blocks:
        categories[src_category] = dest_category
        almanac[src_category] = sorted(
            (source_start, length, dest_start)
            for dest_start, source_start, length in rows
        )
    return min(
        apply_mapping(seed_pair, almanac, categories) for seed_pair in seed_pairs
    )

In [None]:
# Validate the range-based solution against the worked example, then solve the
# real input (the last expression is displayed as the cell output).
check(min_seed, part=2)
min_seed(data)