In [124]:
from functools import cache
from itertools import product
import numpy as np

from aoc import submit

# Advent of Code day number, handed to ``aoc.submit`` by the part_* cells.
DAY = 19

# Signed axis permutations used to reorient points. In a tuple ``t``, output
# axis k takes input axis ``abs(t[k])`` (1-based), negated when ``t[k] < 0``;
# e.g. (1, -3, 2) maps (x, y, z) -> (x, -z, y). The 24 tuples are exactly the
# proper rotations of 3-D space (every one has determinant +1).
# Order matters to any caller that takes the first matching rotation, so keep it.
ROTATIONS = [
    # output x comes from +x
    (1, 2, 3),
    (1, -3, 2),
    (1, -2, -3),
    (1, 3, -2),
    # output x comes from +y
    (2, 3, 1),
    (2, -1, 3),
    (2, -3, -1),
    (2, 1, -3),
    # output x comes from +z
    (3, 1, 2),
    (3, -2, 1),
    (3, -1, -2),
    (3, 2, -1),
    # output x comes from -x
    (-1, 3, 2),
    (-1, -2, 3),
    (-1, -3, -2),
    (-1, 2, -3),
    # output x comes from -y
    (-2, 1, 3),
    (-2, -3, 1),
    (-2, -1, -3),
    (-2, 3, -1),
    # output x comes from -z
    (-3, 2, 1),
    (-3, -1, 2),
    (-3, -2, -1),
    (-3, 1, -2),
]


def sign(n):
    """Return the sign of ``n`` as a plain int: -1, 0 or 1 (0 for NaN too)."""
    if n < 0:
        return -1
    if n > 0:
        return 1
    return 0

In [131]:
def rotate(arr, matrix):
    """Reorient one point according to a signed axis permutation.

    Args:
        arr: Indexable point coordinates.
        matrix: Tuple of signed 1-based axis indices (an entry of ``ROTATIONS``).

    Returns:
        numpy array whose component k is ``arr[abs(matrix[k]) - 1]``,
        negated when ``matrix[k]`` is negative.
    """
    reoriented = []
    for spec in matrix:
        direction = -1 if spec < 0 else int(spec > 0)  # identical to sign(spec)
        reoriented.append(arr[abs(spec) - 1] * direction)
    return np.array(reoriented)


def parse_input(raw):
    """Parse the puzzle text into one beacon array per scanner.

    Each scanner section starts with a header line beginning with ``---``,
    followed by one ``x,y,z`` line per beacon. Blank lines are ignored wherever
    they occur, as are Windows (CRLF) line endings, so the parse does not rely
    on sections being separated by exactly one ``\\n\\n``.

    Args:
        raw: Full puzzle input text.

    Returns:
        A list of integer numpy arrays, one per scanner, one row per beacon.

    Raises:
        ValueError: if a coordinate line appears before the first header.
    """
    scanners = []
    for line in raw.splitlines():  # splitlines also handles '\r\n'
        line = line.strip()
        if not line:
            continue
        if line.startswith('---'):
            scanners.append([])
        elif not scanners:
            raise ValueError(f'beacon line before any scanner header: {line!r}')
        else:
            scanners[-1].append([int(value) for value in line.split(',')])
    return [np.array(beacons) for beacons in scanners]

@cache
def solve(raw, min_overlap=12):
    """Locate every scanner and express all beacons in scanner 0's frame.

    Scanner 0 is the reference: position ``(0, 0, 0)`` and identity orientation.
    Starting from it, each scanner that shares at least ``min_overlap`` beacons
    with an already-aligned scanner is rotated (by one of ``ROTATIONS``) and
    translated into the reference frame.

    Args:
        raw: Puzzle input text; must be hashable because results are memoised
            with ``functools.cache``.
        min_overlap: Number of coinciding beacons required to accept an
            alignment (defaults to 12).

    Returns:
        ``(scanners, positions)``: ``scanners[k]`` is scanner ``k``'s beacon
        array in the reference frame, and ``positions`` maps scanner index to
        that scanner's position in the same frame. Because of the cache these
        objects are shared between calls, so callers must not mutate them.

    Raises:
        ValueError: if some scanner cannot be aligned with any other.
    """
    scanners = parse_input(raw)
    positions = {0: np.array([0, 0, 0])}

    # Rotation-invariant fingerprint of each beacon: the set of Manhattan
    # distances from it to every beacon in the same scanner (0 for itself).
    # Signed axis permutations and translations preserve these distances, so a
    # beacon seen by two scanners has largely the same fingerprint in both.
    fingerprints = [
        [set(np.abs(beacons - beacon).sum(axis=1)) for beacon in beacons]
        for beacons in scanners
    ]

    def candidate_pairs(i, j):
        """Yield (ii, jj) beacon pairs whose fingerprints share >= min_overlap distances."""
        for ii, jj in product(range(len(fingerprints[i])), range(len(fingerprints[j]))):
            if len(fingerprints[i][ii] & fingerprints[j][jj]) >= min_overlap:
                yield ii, jj

    def align(i, j):
        """Map scanner j into aligned scanner i's frame; return (beacons, position) or None."""
        known = {tuple(beacon) for beacon in scanners[i]}
        for ii, jj in candidate_pairs(i, j):
            for rot in ROTATIONS:
                # Vectorised form of rotate(b, rot) applied to every beacon b of scanner j.
                rotated = scanners[j][:, np.abs(rot) - 1] * np.sign(rot)
                offset = scanners[i][ii] - rotated[jj]
                moved = rotated + offset
                # One difference vector can be matched by several rotations
                # (repeated or zero components), so accept only when the whole
                # transformed scanner really overlaps scanner i.
                if sum(tuple(beacon) in known for beacon in moved) >= min_overlap:
                    return moved, offset
        return None

    queue = [0]
    while queue and len(positions) < len(scanners):
        i = queue.pop()
        for j in range(len(scanners)):
            if j in positions:
                continue
            aligned = align(i, j)
            if aligned is not None:
                scanners[j], positions[j] = aligned
                # A newly aligned scanner becomes a reference point itself.
                queue.append(j)

    missing = sorted(set(range(len(scanners))) - positions.keys())
    if missing:
        raise ValueError(f'could not align scanners {missing} with scanner 0')
    return scanners, positions


@submit(day=DAY)
def part_one(raw):
    """Count distinct beacons once all scanners share scanner 0's frame."""
    aligned_scanners = solve(raw)[0]
    unique_beacons = {
        tuple(point)
        for scanner in aligned_scanners
        for point in scanner
    }
    return len(unique_beacons)


part_one:
✅ example: 79             (8.18 ms)
✅ input:   405            (460.66 ms)


In [126]:
@submit(day=DAY)
def part_two(raw):
    """Largest Manhattan distance between any two located scanners."""
    scanner_positions = list(solve(raw)[1].values())
    return max(
        np.abs(a - b).sum()
        for a in scanner_positions
        for b in scanner_positions
    )

part_two:
✅ example: 3621           (0.14 ms)
✅ input:   12306          (5.49 ms)
