In [1]:
from lib import get_data
# Load the two inputs used below.
# NOTE(review): assumes get_data returns the input as a list of str lines (every
# function below takes data: list[str]); the extra 'test' argument presumably
# selects the small example input — confirm against lib.get_data.
test_data = get_data(8, '\n', 'test')
data = get_data(8, '\n')

In [89]:
type Box = tuple[int, int, int]

class Pair:
    def __init__(self, jxn: tuple[Box, Box], dist: float): 
        self.jxn = jxn 
        self.dist = dist
    def __str__(self) -> str:
        return f"{self.jxn}: {self.dist}"

In [101]:
from math import sqrt

def build_box(box: str) -> Box:
    """Parse a comma-separated line into a tuple of its first three ints.

    Fields past the third are ignored. Fewer than three fields raises
    IndexError, and a non-numeric field raises ValueError.
    """
    fields = box.split(',')
    return tuple(int(fields[axis]) for axis in range(3))

def get_boxes(data: list[str]) -> set[Box]:
    """Parse every line with build_box; duplicate boxes collapse into one entry."""
    return {build_box(line) for line in data}

def get_distance(p: Box, q: Box) -> float:
    """Euclidean distance between the first three coordinates of p and q."""
    squared_total = 0
    for axis in range(3):
        delta = p[axis] - q[axis]
        squared_total += delta * delta
    return sqrt(squared_total)

def build_pair(b1: Box, b2: Box) -> Pair:
    """Wrap two boxes in a Pair together with their Euclidean distance."""
    return Pair(jxn=(b1, b2), dist=get_distance(b1, b2))

def calculate_paired_distances(data: list[str]) -> list[Pair]:
    """Build every unordered pair of boxes from ``data``, sorted by distance.

    Args:
        data: one comma-separated "x,y,z" line per box (parsed by build_box).

    Returns:
        A list of Pair objects in ascending ``dist`` order. Pairs are generated
        in (i, j) order with i < j, and ``list.sort`` is stable, so pairs with
        equal distances keep that generation order.
    """
    # Parse each line once; re-parsing data[j] inside the inner loop cost
    # O(n^2) string parses for n lines.
    boxes = [build_box(line) for line in data]
    pairs: list[Pair] = [
        build_pair(first, boxes[j])
        for i, first in enumerate(boxes)
        for j in range(i + 1, len(boxes))
    ]
    pairs.sort(key=lambda pair: pair.dist)
    return pairs

def add_pair(circuit: set[Box], pair: Pair):
    """Insert both boxes of ``pair`` into ``circuit`` in place; returns None."""
    for endpoint in (pair.jxn[0], pair.jxn[1]):
        circuit.add(endpoint)

def add_pair_to_circuits(circuits: list[set[Box]], pair: Pair) -> list[set[Box]]:
    """Connect the two boxes of ``pair`` and return the updated circuit list.

    - The first circuit containing either box receives both boxes (mutated
      in place) and stays in the result at its original position.
    - Any later circuit containing either box also receives both boxes, is
      merged into that first circuit, and is dropped from the result.
    - If no circuit contains either box, a new two-box set is appended.

    Circuits not touched by the pair are carried over unchanged and in order.
    The input list itself is not modified, but the sets inside it may be.
    """
    left, right = pair.jxn[0], pair.jxn[1]
    survivors: list[set[Box]] = []
    host = None  # first circuit touched by this pair, if any
    for circuit in circuits:
        if left not in circuit and right not in circuit:
            survivors.append(circuit)
        elif host is None:
            circuit.add(left)
            circuit.add(right)
            host = circuit
            survivors.append(circuit)
        else:
            # The pair bridges two circuits: fold this one into the host.
            circuit.add(left)
            circuit.add(right)
            for box in circuit:
                host.add(box)
    if host is None:
        survivors.append({left, right})
    return survivors

def get_last_needed_jxn_idx(circuits_count: list[int]) -> int:
    """Return the index of the last entry in ``circuits_count`` that is not 1.

    Returns -1 when every entry equals 1 (including an empty list). With that
    convention, ``result + 1`` is always the index where the trailing run of
    1s starts (``len(circuits_count)`` if the last entry is not 1). The old
    fallback of 0 was indistinguishable from "entry 0 is not 1", so
    ``result + 1`` pointed one past the real starting index.
    """
    return next(
        (i for i in range(len(circuits_count) - 1, -1, -1) if circuits_count[i] != 1),
        -1,
    )

In [None]:
from math import prod

def part_one(data: list[str], top: int):
    """Connect the ``top`` closest pairs and multiply the three largest circuit sizes.

    Boxes never touched by a pair form no circuit here. They would contribute
    a factor of 1 anyway. With fewer than three circuits, the product covers
    those that exist (``prod([])`` is 1).
    """
    closest = calculate_paired_distances(data)[:top]
    circuits: list[set[Box]] = []
    for pair in closest:
        circuits = add_pair_to_circuits(circuits, pair)
    sizes = sorted((len(circuit) for circuit in circuits), reverse=True)
    return prod(sizes[:3])

# A cell only echoes its final expression, so return both results together;
# otherwise the example-input answer is computed and silently discarded.
# The second argument is how many of the closest pairs to connect.
part_one(test_data, 10), part_one(data, 1000)

In [None]:
def part_two(data: list[str]):
    """Product of the first coordinates of the pair that first unites all boxes.

    Pairs are connected in ascending distance order. The number of separate
    groups is the count of multi-box circuits plus the boxes no pair has
    touched yet. Connecting a pair never increases that number, so the
    answer comes from the first pair after which it equals 1. We stop there
    instead of processing every remaining O(n^2) pair.

    Raises:
        ValueError: if the boxes never end up in a single circuit (e.g. the
            input has fewer than two lines, so there are no pairs).
    """
    pairs = calculate_paired_distances(data)
    unconnected = get_boxes(data)
    circuits: list[set[Box]] = []
    for pair in pairs:
        circuits = add_pair_to_circuits(circuits, pair)
        unconnected.discard(pair.jxn[0])
        unconnected.discard(pair.jxn[1])
        if len(circuits) + len(unconnected) == 1:
            return pair.jxn[0][0] * pair.jxn[1][0]
    raise ValueError("boxes never join into a single circuit")

# A cell only echoes its final expression, so return both results together;
# otherwise the example-input answer is computed and silently discarded.
part_two(test_data), part_two(data)