In [3]:
import math
from dataclasses import dataclass

import numpy as np
from scipy.spatial.distance import euclidean


@dataclass
class JunctionBox:
    """A junction box: an integer identifier plus an integer 3-D position."""

    identifier: int
    x_coord: int
    y_coord: int
    z_coord: int

    def coords(self):
        """Return the position as an ``(x, y, z)`` tuple."""
        return self.x_coord, self.y_coord, self.z_coord

    def __str__(self):
        """Render as ``JunctionBox <identifier>: (<x>, <y>, <z>)``."""
        x, y, z = self.coords()
        return f"JunctionBox {self.identifier}: ({x}, {y}, {z})"


def parse_junction_boxes(input_text: str) -> list[JunctionBox]:
    """Parse newline-separated ``x,y,z`` rows into junction boxes.

    Args:
        input_text: One box per line, written as three comma-separated
            integers. Surrounding whitespace and blank lines are ignored.

    Returns:
        The boxes in input order. Each box's ``identifier`` equals its
        position in the returned list. An empty or all-blank input gives
        an empty list.

    Raises:
        ValueError: If a non-blank line does not consist of exactly three
            comma-separated integers.
    """
    # Drop blank rows before numbering: an empty input would otherwise reach
    # the tuple unpacking below as a single "" row and fail there, and
    # numbering only real rows keeps identifiers equal to list positions.
    rows = [row for row in input_text.strip().split("\n") if row.strip()]

    junction_boxes = []
    for idx, row in enumerate(rows):
        x_str, y_str, z_str = row.split(",")
        junction_boxes.append(
            JunctionBox(
                identifier=idx,
                x_coord=int(x_str),
                y_coord=int(y_str),
                z_coord=int(z_str),
            )
        )
    return junction_boxes


def calculate_distances(
    junction_boxes: list[JunctionBox],
) -> list[tuple[tuple[int, int], float]]:
    """List every unordered pair of boxes with its Euclidean distance.

    Each entry is ``((id_a, id_b), distance)``, where ``id_a`` belongs to the
    box that appears earlier in ``junction_boxes``. The result is ordered by
    ascending distance; pairs at equal distance keep their generation order
    because the sort is stable.
    """
    pairs = [
        (
            (first.identifier, second.identifier),
            euclidean(first.coords(), second.coords()),
        )
        for position, first in enumerate(junction_boxes)
        # Pairing only with later boxes visits each unordered pair once.
        for second in junction_boxes[position + 1 :]
    ]
    return sorted(pairs, key=lambda entry: entry[1])


def calculate_networks(
    distances: list[tuple[tuple[int, int], float]],
    max_connections: None | int = 10,
) -> list[list[int]] | tuple[int]:
    networks: list[set[int]] = []

    if max_connections is not None:
        distances = distances[:max_connections]

    for box_ids, _ in distances:
        for network in networks:
            if box_ids[0] in network or box_ids[1] in network:
                network.update(box_ids)
                break
        else:
            networks.append(set(box_ids))

        for count, network in enumerate(networks):
            for other_network in networks[count + 1 :]:
                if network.intersection(other_network):
                    network.update(other_network)
                    networks.remove(other_network)
                    break
        if (len(networks) == 1) and (len(networks[0]) == len(junction_boxes)):
            return box_ids
    networks.sort(key=lambda x: len(x), reverse=True)
    return networks


def calculate_score_part_1(networks: list[set[int]]) -> int:
    """Return the product of the sizes of the first three networks.

    Only ``networks[:3]`` is read, so the caller decides which networks count
    by ordering the list. Fewer than three networks multiply whatever is
    present; an empty list gives 1.
    """
    # math.prod keeps the result an exact Python int, matching the annotation
    # (no fixed-width NumPy scalar, no float 1.0 for an empty list).
    return math.prod(len(network) for network in networks[:3])


def calculate_score_part_2(
    junction_boxes: list[JunctionBox], networks: tuple[int, int]
) -> int:
    """Multiply the x coordinates of the two boxes named by ``networks``.

    Args:
        junction_boxes: All boxes. Each identifier in ``networks`` is used
            directly as an index into this list, so the box with identifier
            ``i`` must sit at position ``i``.
        networks: Despite the name, a pair of box identifiers rather than a
            list of networks. Only its first two items are read.

    Returns:
        The product of the two boxes' ``x_coord`` values.
    """
    first_id, second_id = networks[0], networks[1]
    return junction_boxes[first_id].x_coord * junction_boxes[second_id].x_coord


# Read the whole puzzle input up front.
with open("inputs/day8.txt") as f:
    input_text = f.read()

# Both parts share the same boxes and the same sorted pair list.
junction_boxes = parse_junction_boxes(input_text)
distances = calculate_distances(junction_boxes)

# Part 1 applies only the first 1000 connections; part 2 applies all of them.
part_1_networks = calculate_networks(distances, max_connections=1000)
part_2_networks = calculate_networks(distances, max_connections=None)

part_1_score = calculate_score_part_1(part_1_networks)
print(f"Part 1: {part_1_score}")
part_2_score = calculate_score_part_2(junction_boxes, part_2_networks)
print(f"Part 2: {part_2_score}")

Part 1: 352584
Part 2: 9617397716
