In [None]:
import collections
import itertools
import re
from dataclasses import dataclass
from typing import List, Optional, Tuple

In [None]:
def get_scanners(filename):
    """Parse a scanner report file into a list of ``Scanner`` objects.

    The file is read line by line:

    * a line containing ``scanner <n>`` starts a new scanner numbered ``n``;
    * a line starting with ``x,y,z`` (signed integers) adds a beacon to the
      scanner started most recently;
    * every other line (for example a blank separator) is ignored.

    Each scanner is added to the result as soon as its header is seen, so
    trailing or repeated blank lines cannot register the same scanner twice,
    and an empty file yields an empty list instead of failing.

    Args:
        filename: Path of the text file to read.

    Returns:
        The scanners in file order, each holding its beacons as ``Point``
        objects in the order they were listed.

    Raises:
        ValueError: If a beacon line appears before any scanner header.
    """
    scanners = []
    with open(filename) as file:
        for line_number, line in enumerate(file, start=1):
            if match := re.search(r"scanner (\d+)", line):
                # Register immediately; completion does not depend on a
                # separator line following the beacon list.
                scanners.append(Scanner(number=int(match.group(1)), beacons=[]))
            elif match := re.match(r"(-?\d+),(-?\d+),(-?\d+)", line):
                if not scanners:
                    raise ValueError(
                        f"line {line_number}: beacon listed before any scanner header"
                    )
                pos = Point(tuple(map(int, match.groups())))
                scanners[-1].beacons.append(pos)
            # Any other line carries no data and is skipped.

    return scanners

In [None]:
def make_rotations():
    """Return the 24 axis-aligned rotations as ``(perm, sign)`` pairs.

    ``perm[i]`` names the source axis that feeds output axis ``i`` and
    ``sign[i]`` is the factor (+1 or -1) applied to it. The three cyclic
    permutations are paired with the sign triples containing an even number
    of -1 entries, and the three single-swap permutations with the triples
    containing an odd number, giving 12 + 12 combinations.
    """
    cyclic_perms = [(0, 1, 2), (1, 2, 0), (2, 0, 1)]
    even_flip_signs = [(1, 1, 1), (-1, -1, 1), (-1, 1, -1), (1, -1, -1)]
    swap_perms = [(0, 2, 1), (1, 0, 2), (2, 1, 0)]
    odd_flip_signs = [(-1, 1, 1), (1, -1, 1), (1, 1, -1), (-1, -1, -1)]

    # Permutation varies slowest, sign fastest, cyclic group first.
    combos = [(perm, sign) for perm in cyclic_perms for sign in even_flip_signs]
    combos += [(perm, sign) for perm in swap_perms for sign in odd_flip_signs]
    return combos

In [None]:
@dataclass(frozen=True)
class Point:
    """Immutable, hashable wrapper around a 3-component coordinate tuple."""

    pos: Tuple[int, int, int]

    def __sub__(self, other):
        """Return a new Point holding the component-wise ``self - other``."""
        mine, theirs = self.pos, other.pos
        return Point(tuple(mine[axis] - theirs[axis] for axis in (0, 1, 2)))

    def rotate(self, rotation):
        """Return a new Point with axes permuted and sign-flipped.

        ``rotation`` is a ``(perm, sign)`` pair: output component ``k`` is
        ``pos[perm[k]] * sign[k]``.
        """
        axes, signs = rotation
        coords = self.pos
        turned = []
        for k in (0, 1, 2):
            turned.append(coords[axes[k]] * signs[k])
        return Point(tuple(turned))

    def distance(self, other):
        """Return the sum of absolute per-axis differences to ``other``."""
        total = 0
        for axis in (0, 1, 2):
            total += abs(self.pos[axis] - other.pos[axis])
        return total

@dataclass
class Scanner:
    """A scanner together with the beacons it reported.

    ``pos`` starts as ``None``. The notebook later assigns it a ``Point``
    (and calls ``Point.distance`` on it), so it is annotated as
    ``Optional[Point]`` rather than as a bare coordinate tuple.
    """

    number: int  # scanner id taken from the report header
    beacons: List[Point]  # beacon coordinates reported by this scanner
    pos: Optional[Point] = None  # scanner location; None until it is assigned


In [None]:
def solve_scanner(ref_scanner, new_scanner, min_overlap=12):
    """Try to align ``new_scanner`` with ``ref_scanner``.

    For each candidate from ``make_rotations()``, every beacon of
    ``new_scanner`` is rotated and its offset to every beacon of
    ``ref_scanner`` is tallied. If the most frequent offset occurs at least
    ``min_overlap`` times, that rotation and offset are passed to
    ``rot_trans_scanner`` and its result is returned.

    Args:
        ref_scanner: Scanner whose beacons define the target frame.
        new_scanner: Scanner to align.
        min_overlap: Minimum number of beacon pairs that must share one
            offset for the alignment to be accepted (default 12).

    Returns:
        The value returned by ``rot_trans_scanner`` on success, otherwise
        ``False`` (also when either scanner has no beacons).
    """
    if not ref_scanner.beacons or not new_scanner.beacons:
        # No beacon pairs to compare; previously this crashed with an
        # IndexError when reading the most common offset.
        return False

    for rotation in make_rotations():
        # Rotate each new beacon once per rotation instead of once per
        # reference beacon.
        rotated = [beacon.rotate(rotation) for beacon in new_scanner.beacons]
        # Same tally order as before (reference outer, new inner), so ties
        # in most_common() resolve identically.
        diffs = collections.Counter(
            new_beacon_rot - ref_beacon
            for ref_beacon in ref_scanner.beacons
            for new_beacon_rot in rotated
        )
        translation, hits = diffs.most_common(1)[0]
        if hits >= min_overlap:
            return rot_trans_scanner(new_scanner, rotation, translation)

    return False

In [None]:
def rot_trans_scanner(scanner, rotation, translation):
    """Rotate and shift a scanner's beacons in place, then record its position.

    Every beacon is replaced by ``beacon.rotate(rotation) - translation``
    inside the existing ``scanner.beacons`` list, and ``scanner.pos`` is set
    to the origin point minus ``translation``.

    Returns:
        The same ``scanner`` object that was passed in.
    """
    points = scanner.beacons
    for idx in range(len(points)):
        points[idx] = points[idx].rotate(rotation) - translation

    origin = Point((0, 0, 0))
    scanner.pos = origin - translation
    return scanner

# Part 1

In [None]:
# Load every scanner report; the first one listed is taken as the reference
# frame and placed at the origin, the rest still have to be aligned.
unsolved_scanners = get_scanners("day19.input")
reference_scanner = unsolved_scanners.pop(0)
reference_scanner.pos = Point((0, 0, 0))
solved_scanners = [reference_scanner]

In [None]:
# Take unsolved scanners one at a time and try to align each with any scanner
# that is already solved. A scanner that matches nothing yet goes back to the
# front of the queue so it is retried after the others.
stalled_attempts = 0  # consecutive failures since the last successful alignment
while unsolved_scanners:
    to_solve = unsolved_scanners.pop()
    for ref in solved_scanners:
        if solved := solve_scanner(ref, to_solve):
            solved_scanners.append(solved)
            print(f"Solved {solved.number} from {ref.number}")
            stalled_attempts = 0
            break
    else:
        unsolved_scanners.insert(0, to_solve)
        stalled_attempts += 1
        if stalled_attempts >= len(unsolved_scanners):
            # Every remaining scanner has now failed against the current set
            # of solved scanners; retrying would loop forever.
            remaining = sorted(s.number for s in unsolved_scanners)
            raise RuntimeError(
                f"Cannot align scanners {remaining} with any solved scanner"
            )

In [None]:
# Gather the distinct beacon points across all solved scanners; the count is
# displayed as this cell's output.
beacons = set()
for scanner in solved_scanners:
    beacons.update(scanner.beacons)
len(beacons)

# Part 2

In [None]:
# Distance between every ordered pair of solved scanner positions
# (each scanner is also paired with itself, contributing a zero).
distances = [
    scanner.pos.distance(other.pos)
    for scanner in solved_scanners
    for other in solved_scanners
]

In [None]:
# Largest of the pairwise scanner distances; shown as this cell's output.
max(distances)