In [1]:
from lib import get_data
# Load dataset 8 twice via lib.get_data: once with the 'test' flag, once without.
# NOTE(review): '\n' is presumably the record separator — confirm in lib.get_data.
test_data, data = get_data(8, '\n', 'test'), get_data(8, '\n')

In [60]:
from math import sqrt
type Box = tuple[int, int, int]
class Pair:
    def __init__(self, jxn: tuple[Box, Box], dist: float): 
        self.jxn = jxn 
        self.dist = dist
    def __str__(self) -> str:
        return f"{self.jxn}: {self.dist}"

def get_coord(box: str) -> Box:
    """Parse a comma-separated line into a tuple of its first three fields as ints.

    Fields beyond the third are ignored. A line with fewer than three fields
    raises ``IndexError``; a field that is not an integer raises ``ValueError``.
    """
    fields = box.split(',')
    # Index field by field, so conversion happens left to right just like
    # three explicit int() calls would.
    return tuple(int(fields[axis]) for axis in range(3))

def get_distance(p: Box, q: Box) -> float:
    """Return the Euclidean distance between the first three components of p and q."""
    squared_total = 0
    for axis in range(3):
        delta = p[axis] - q[axis]
        squared_total += delta * delta
    return sqrt(squared_total)

def calculate_paired_distances(data: list[str]) -> list[Pair]:
    """Return a Pair for every unordered pair of boxes in *data*, sorted by distance.

    Each non-blank line of *data* is parsed once with ``get_coord``. For boxes
    at positions i < j the pair is stored as ``(box_i, box_j)``. The sort is
    stable, so pairs at equal distance keep their input order.
    """
    from itertools import combinations

    # Parse every line exactly once. Previously data[j] was re-parsed inside
    # the inner loop, so each line was parsed O(n) times. Blank lines
    # (e.g. a trailing newline) are skipped instead of crashing int().
    coords = [get_coord(line) for line in data if line.strip()]
    # combinations() yields (coords[i], coords[j]) for i < j in the same order
    # as the original nested index loops.
    pairs = [
        Pair((b1, b2), get_distance(b1, b2))
        for b1, b2 in combinations(coords, 2)
    ]
    pairs.sort(key=lambda pair: pair.dist)
    return pairs

def add_pair_to_circuits(circuits: list[set[Box]], pair: Pair):
    """Record the connection *pair* in *circuits* and return the new circuit list.

    The first circuit that contains either box of the pair is extended in place
    with both boxes and keeps its position. Any later circuit that also contains
    one of the boxes is merged into that first circuit and left out of the
    result. If no circuit touches the pair, a new two-box circuit is appended.

    Note: set objects inside *circuits* are mutated in place.
    NOTE(review): assumes the circuits are pairwise disjoint (which repeated
    calls starting from an empty list maintain) — confirm for other callers.
    """
    first_box, second_box = pair.jxn[0], pair.jxn[1]
    merged_into = None
    result = []
    for circuit in circuits:
        if first_box not in circuit and second_box not in circuit:
            # Untouched by this pair: carry it over unchanged.
            result.append(circuit)
            continue
        if merged_into is None:
            # First circuit hit: grow it with both boxes and keep it.
            circuit.update((first_box, second_box))
            merged_into = circuit
            result.append(circuit)
        else:
            # The pair bridges two circuits: fold this one into the first hit
            # and drop it from the result.
            merged_into.update(circuit)
    if merged_into is None:
        result.append(set([first_box, second_box]))
    return result

In [None]:
from math import prod

def part_one(data: list[str], top: int):
    """Connect the *top* closest box pairs and multiply the three largest circuit sizes.

    If fewer than three circuits exist, the product covers only those present.
    """
    circuits: list[set[Box]] = []
    for pair in calculate_paired_distances(data)[:top]:
        circuits = add_pair_to_circuits(circuits, pair)
    largest_first = sorted(map(len, circuits), reverse=True)
    return prod(largest_first[:3])

# A notebook cell only echoes its last expression, so the example result was
# previously discarded; print both explicitly.
print(part_one(test_data, 10))
print(part_one(data, 1000))