# Day 05

Imports.

In [1]:
from dataclasses import dataclass
import sys


Read input.

In [2]:
# Load the whole puzzle file, then trim surrounding whitespace so that
# splitting on blank lines later does not produce a trailing empty section.
with open("05_input.txt", "r") as f:
    input = f.read()
input = input.strip()


## Part 1

Define data structures and utilities.

In [3]:
@dataclass
class Interval:
    """Half-open integer range ``[start, end)``."""

    start: int
    end: int

    def __contains__(self, item: int) -> bool:
        """Return True when ``start <= item < end``."""
        return self.start <= item < self.end

    def __gt__(self, other: "Interval") -> bool:
        """Compare by ``start`` only; ``end`` plays no part in the ordering."""
        return self.start > other.start

    def __repr__(self) -> str:
        """Render as ``[start:end)`` to make the half-open bound visible."""
        return f"[{self.start}:{self.end})"

    def overlaps(self, other: "Interval") -> bool:
        """Return True when the two half-open ranges share at least one point."""
        return self.start < other.end and other.start < self.end

    def intersection(self, other: "Interval") -> "Interval | None":
        """Return the common part of both ranges, or None when they are disjoint.

        Ranges that merely touch (one ends where the other starts) count as
        disjoint because the end bound is exclusive.
        """
        if not self.overlaps(other):
            return None
        return Interval(
            start=max(self.start, other.start),
            end=min(self.end, other.end),
        )

@dataclass
class DualInterval:
    """A source range paired with the equally long destination range it maps to.

    Built from one almanac row given as (destination start, source start,
    length). Every value in ``src`` maps to ``dest`` by adding the constant
    ``src_to_dest_dist``.
    """

    src: Interval
    dest: Interval
    src_to_dest_dist: int

    def __init__(self, dest_start: int, src_start: int, range_length: int):
        """Create both ranges and record the offset between them."""
        src_end = src_start + range_length
        dest_end = dest_start + range_length
        self.src = Interval(start=src_start, end=src_end)
        self.dest = Interval(start=dest_start, end=dest_end)
        self.src_to_dest_dist = dest_start - src_start

    def __contains__(self, item: int) -> bool:
        """Membership is tested against the source range only."""
        return item in self.src

    def __repr__(self) -> str:
        """Render as ``<src> -> <dest>``."""
        return f"{self.src} -> {self.dest}"

    def src_to_dest(self, item: int) -> int:
        """Shift a source value forward by the stored offset."""
        return self.src_to_dest_dist + item

    def dest_to_src(self, item: int) -> int:
        """Shift a destination value back by the stored offset."""
        return item - self.src_to_dest_dist

@dataclass
class Map:
    """Piecewise mapping made of ``DualInterval`` pieces.

    On construction the pieces are sorted by source start and every part of
    ``[0, sys.maxsize)`` not covered by a piece is filled with an identity
    piece, so any value in that range can be looked up.
    """

    dual_intervals: list[DualInterval]

    def __init__(self, dual_intervals: list[DualInterval]):
        """Take ownership of ``dual_intervals``; the list is modified in place."""
        self.dual_intervals = dual_intervals
        self.sort(by_src=True)
        self._fill_gaps()

    def sort(self, by_src: bool = True) -> None:
        """Sort the pieces in place by source start, or by destination start."""
        if by_src:
            self.dual_intervals.sort(key=lambda x: x.src.start)
        else:
            self.dual_intervals.sort(key=lambda x: x.dest.start)

    def _fill_gaps(self) -> None:
        """Insert identity pieces so the sources cover ``[0, sys.maxsize)``.

        Expects the pieces to be sorted by source start already.
        """
        # Build a fresh list in one sweep instead of inserting while indexing:
        # inserting shifts the remaining pieces, so an index range computed
        # up front stops short of the last pairs and leaves gaps unfilled.
        filled: list[DualInterval] = []
        covered_up_to = 0
        for di in self.dual_intervals:
            if di.src.start > covered_up_to:
                gap_length = di.src.start - covered_up_to
                filled.append(DualInterval(covered_up_to, covered_up_to, gap_length))
            filled.append(di)
            # max() keeps the cursor from moving backwards when pieces overlap
            # or one piece is nested inside another.
            covered_up_to = max(covered_up_to, di.src.end)

        # Tail piece up to the largest supported value; this also makes an
        # empty map behave as the identity.
        if covered_up_to < sys.maxsize:
            tail_length = sys.maxsize - covered_up_to
            filled.append(DualInterval(covered_up_to, covered_up_to, tail_length))

        # Slice assignment keeps the same list object that was passed in.
        self.dual_intervals[:] = filled

    def src_to_dest(self, item: int) -> int:
        """Map ``item`` through the first piece whose source range contains it.

        Raises:
            ValueError: if no piece contains ``item``.
        """
        for di in self.dual_intervals:
            if item in di:
                return di.src_to_dest(item)
        raise ValueError(f"{item} not in {self.dual_intervals}")


def parse_maps(map_sections: list[str]) -> list[Map]:
    """Turn each text section into a ``Map``.

    The first line of a section is skipped as a header; every following line
    is split on whitespace and its integers are passed positionally to
    ``DualInterval``.
    """

    def _section_to_map(section: str) -> Map:
        _header, *rows = section.split("\n")
        return Map([DualInterval(*map(int, row.split())) for row in rows])

    return [_section_to_map(section) for section in map_sections]


def parse_part_1(input: str) -> tuple[list[int], list[Map]]:
    """Split the puzzle text into individual seed numbers and the maps.

    Sections are separated by blank lines; the first one holds the seeds
    after a ``": "`` separator, the remaining ones are map sections.
    """
    seed_section, *map_sections = input.split("\n\n")
    seed_numbers = seed_section.split(": ")[1]
    seeds = [int(token) for token in seed_numbers.split(" ")]
    return seeds, parse_maps(map_sections=map_sections)


def seed_to_location(seed: int, maps: list[Map]) -> int:
    """Pass ``seed`` through every map in order and return the final value."""
    current = seed
    for stage in maps:
        # Each stage's output is the next stage's input.
        current = stage.src_to_dest(current)
    return current


In [4]:
%%timeit

def part_1(input: str) -> int:
    """Return the lowest location reached by any of the listed seeds.

    Raises:
        ValueError: if the input yields no seeds (``min`` of an empty
            sequence).
    """
    seeds, maps = parse_part_1(input)
    # Take the minimum directly: a numeric sentinel such as 1e9 is a float and
    # would be returned unchanged whenever every location is larger than it.
    return min(seed_to_location(seed_i, maps) for seed_i in seeds)

# Check the result against the expected answer for this input file.
min_location = part_1(input)
assert min_location == 177_942_185


513 µs ± 9.64 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Part 2

Define utilities.

In [5]:
def parse_seeds(seed_section: str) -> list[Interval]:
    """Read the seed line as (start, length) pairs and return seed ranges.

    The returned intervals are half-open ``[start, start + length)`` and are
    ordered by ascending start.
    """
    numbers = [int(token) for token in seed_section.split(": ")[1].split(" ")]

    # Consecutive numbers form one (start, length) pair.
    pairs = []
    for pos in range(0, len(numbers), 2):
        pairs.append((numbers[pos], numbers[pos + 1]))
    pairs.sort(key=lambda pair: pair[0])

    intervals = []
    for start, length in pairs:
        intervals.append(Interval(start, start + length))
    return intervals


def parse_part_2(input: str) -> tuple[list[Interval], list[Map]]:
    """Split the puzzle text into seed ranges and the maps.

    Sections are separated by blank lines; the first is the seed section and
    the remaining ones are map sections.
    """
    seed_section, *map_sections = input.split("\n\n")
    seed_ranges = parse_seeds(seed_section=seed_section)
    return seed_ranges, parse_maps(map_sections=map_sections)


def find_best_seed(seeds: list[Interval], map_src_intervals: list[Interval]):
    """Return the start of the first overlap between a candidate and a seed range.

    ``map_src_intervals`` is scanned in the order given (the caller passes it
    ordered so that intervals leading to the lowest locations come first), and
    for each candidate the seed ranges are scanned in order. Returns None when
    nothing overlaps.
    """
    overlap_starts = (
        candidate.intersection(seed_range).start
        for candidate in map_src_intervals
        for seed_range in seeds
        if candidate.overlaps(seed_range)
    )
    # The generator is lazy, so only the first hit is ever computed.
    return next(overlap_starts, None)


In [6]:
%%timeit

# Part 2 driver: instead of mapping every seed forward, start from the last
# map's source intervals ordered by destination and pull that ordered list
# backwards through each earlier map, then pick the first interval that
# overlaps a seed range.
seeds, maps = parse_part_2(input)

last_map = maps[-1]
# Sort by destination to get lowest location intervals first
last_map.sort(by_src=False)
last_map_src_intervals = [di.src for di in last_map.dual_intervals]

# Walk the remaining maps from the second-to-last back to the first.
for map_i in reversed(maps[:-1]):
    # Interval.__gt__ compares starts only, so this is True when the first
    # interval in the current ordering starts above the second one.
    # NOTE(review): this single flag picks the sort direction for every
    # interval in this pass — confirm the heuristic holds for other inputs.
    higher_is_better = last_map_src_intervals[0] > last_map_src_intervals[1]
    next_last_map_src_intervals = []

    for src_interval in last_map_src_intervals:
        dest_intersections, src_intersections = [], []

        # Collect every piece of map_i whose destination range overlaps the
        # current interval, and translate the overlap into map_i's source
        # coordinates.
        for di in map_i.dual_intervals:
            if src_interval.overlaps(di.dest):
                dest_inter = src_interval.intersection(di.dest)
                dest_intersections.append(dest_inter)
                src_inter = Interval(
                    start=di.dest_to_src(dest_inter.start),
                    end=di.dest_to_src(dest_inter.end),
                )
                src_intersections.append(src_inter)

        # Order the source-side pieces by the start of their destination-side
        # overlap (descending when higher_is_better is set).
        src_intersections = [
            src_inter
            for _, src_inter in sorted(
                zip(dest_intersections, src_intersections),
                key=lambda pair: pair[0].start,
                reverse=higher_is_better
            )
        ]

        # Pieces stay grouped under the interval they came from, so the
        # outer ordering of last_map_src_intervals is carried over.
        next_last_map_src_intervals.extend(src_intersections)

    last_map_src_intervals = next_last_map_src_intervals

# After the loop the intervals are in the first map's source (seed) space.
best_seed = find_best_seed(seeds, last_map_src_intervals)
location = seed_to_location(best_seed, maps)
# Check the result against the expected answer for this input file.
assert location == 69_841_803


2.56 ms ± 51 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
