In [190]:
from pathlib import Path
from itertools import combinations
import numpy as np
from operator import mul
from functools import reduce
from copy import deepcopy

# Puzzle input, resolved relative to the notebook's working directory.
input_file = Path("input.txt")

def print_matrix(matrix: list, title: str = "Matrix"):
    """Print ``matrix`` rendered as a NumPy array, preceded by a title-cased heading.

    An empty ``title`` suppresses the heading line. The matrix is followed by an
    extra ``"\\n"`` argument to print(), which leaves a blank line after it.
    """
    heading = title.title() if title else None
    if heading is not None:
        print(heading)
    print(np.array(matrix), "\n")

class Point:
    """A point in 3-D integer space that can be linked to other Points.

    Equality and hashing use only the coordinates, so two Points with the same
    coordinates are interchangeable as set members / dict keys regardless of their
    connections. Connections are undirected: ``connect`` links both sides.
    """

    def __init__(self, x: "int | str", y: "int | str", z: "int | str"):
        # Coordinates may arrive as strings straight from the input file; int()
        # normalises them (and raises ValueError on non-numeric text).
        self._x = int(x)
        self._y = int(y)
        self._z = int(z)
        # Points directly connected to this one (kept symmetric by connect()).
        self._connections = set()

    def __repr__(self) -> str:
        return f"Point({self._x},{self._y},{self._z} connections={len(self._connections)})"

    def __eq__(self, other):
        # NotImplemented for foreign types lets `Point(...) == 5` evaluate to False
        # instead of raising AttributeError on the missing `_x`.
        if not isinstance(other, Point):
            return NotImplemented
        return (self._x, self._y, self._z) == (other._x, other._y, other._z)

    def __hash__(self):
        return hash((self._x, self._y, self._z))

    def x(self) -> int:
        """Return the x coordinate."""
        return self._x

    def distance(self, other: "Point") -> np.float64:
        """Return the Euclidean distance between this Point and ``other``."""
        # The annotations in this class are strings on purpose: a bare `Point` is
        # evaluated while the class body is still executing (before the name exists)
        # and raises NameError on a fresh kernel.
        p1 = np.array((self._x, self._y, self._z))
        p2 = np.array((other._x, other._y, other._z))
        return np.linalg.norm(p1 - p2)

    def connect(self, point: "Point") -> None:
        """Link this Point and ``point`` in both directions (repeating it is a no-op)."""
        self._connections.add(point)
        point._connections.add(self)

    def traverse(self, visited: list) -> list:
        """Depth-first walk over every Point reachable from this one.

        Each reachable Point not already in ``visited`` is appended to it, starting
        with ``self``; the same list object is returned.
        """
        # Iterative DFS with a set for membership: the recursive version exceeded
        # Python's recursion limit on long chains, and `in visited` on a list made
        # the walk quadratic.
        seen = set(visited)
        stack = [self]
        while stack:
            point = stack.pop()
            if point in seen:
                continue
            seen.add(point)
            visited.append(point)
            stack.extend(point._connections)
        return visited
# Unused draft of an adjacency-map Graph, kept inside a string literal so it is never
# executed (Point stores its own connections in `_connections` instead). It would not
# run as-is if un-quoted:
#   - `def _add_node(self, point1, point2: Point): -> None:` is a syntax error;
#   - `set(point2)` tries to iterate a Point (a one-element set needs `{point2}`).
'''
class Graph:
    def __init__(self):
        # keys are Points, values represent connections by set
        self._nodes = {}

    def _add_node(self, point1, point2: Point): -> None:
        if point1 in self._nodes:
            self._nodes[point1].add(point2)
        else:
            self._nodes[point1] = set(point2)
        
    def connect(self, point1, point2: Point) -> None:
        self._add_node(point1, point2)
        self._add_node(point2, point1)

    # Depth First Search traversal of the connected points starting from the given one
    def traverse(self, start: Point) -> list:
        if not start in self._nodes:
            raise ValueError(f"Start point {start} is not part of the graph")
        
        return self._traverse(start, [])

    def _traverse(self, point: Point, visited: list) -> list:
        visited.append(point)
        for p in self._nodes[point]:
            if not p in visited:
                self._traverse(p, visited)
        return visited
'''

def calculate_distance_of_point_pairs(points: set) -> list:
    """Return every unordered pair of ``points`` with its distance, nearest pair first.

    Computes all N * (N - 1) / 2 pairwise distances and returns a list of
    ``(distance, point_a, point_b)`` tuples sorted by increasing distance. Every
    pair is returned; callers decide how many of the closest pairs to use.
    """
    result = [
        (point1.distance(point2), point1, point2)
        for point1, point2 in combinations(points, 2)
    ]
    # Sort on the distance only: with a plain tuple sort, two equal distances fall
    # through to comparing the Points themselves, which define no ordering and raise
    # TypeError. sorted() is stable, so tied pairs keep their combinations() order.
    return sorted(result, key=lambda entry: entry[0])

def clasterize_points(points_by_distance: list, max_limit: int) -> None:
    """Connect the leading point pairs of ``points_by_distance``, in the order given.

    :param points_by_distance: ``(distance, point_a, point_b)`` tuples, presumably
        sorted nearest first (as calculate_distance_of_point_pairs returns them)
    :param max_limit: how many leading pairs to connect; a value <= 0 connects all
    Pairs whose Points are already connected still count toward ``max_limit``.
    """
    chosen = points_by_distance if max_limit <= 0 else points_by_distance[:max_limit]
    for _distance, first, second in chosen:
        first.connect(second)

def clasterize_points_until_fully_connected(points_by_distance: list, points: set) -> tuple:
    """Connect pairs in order until every Point in ``points`` is in a single cluster.

    Assumes the Points start with no connections.

    :param points_by_distance: ``(distance, point_a, point_b)`` tuples, nearest first
    :param points: every Point that must end up in one cluster
    :returns: ``(point_a, point_b)`` of the pair whose connection completed the cluster
    :raises ValueError: if the pairs run out before all Points are connected
    """
    # N clusters need at least N - 1 connections to merge, and each connection merges
    # at most two clusters, so completion cannot happen before pair N - 1. Counting
    # clusters is an expensive traversal, so only do it at that earliest possible index.
    earliest_possible = len(points) - 1
    for index, (_, point1, point2) in enumerate(points_by_distance, start=1):
        point1.connect(point2)
        if index < earliest_possible:
            continue
        cluster_count = len(calculate_clusters_size(points))
        if cluster_count == 1:
            # Earlier checks proved completion could not happen before this index,
            # so this pair is the one that joined the last two clusters.
            return (point1, point2)
        # At least cluster_count - 1 more merging connections are still required.
        earliest_possible = index + cluster_count - 1
    raise ValueError("ran out of point pairs before all points were connected")

def calculate_clusters_size(points: set) -> list:
    """Return the size of every connected cluster among ``points``, largest first.

    Clusters are discovered by traversing from a not-yet-seen Point; a cluster's size
    is the number of Points its ``traverse`` reaches. ``points`` itself is not modified.
    """
    remaining = set(points)
    sizes = []
    while remaining:
        seed = remaining.pop()
        members = seed.traverse([])
        sizes.append(len(members))
        # Every Point reached belongs to this cluster; never start a walk from it again.
        remaining.difference_update(members)
    sizes.sort(reverse=True)
    return sizes

# Parse the input: one "x,y,z" Point per line. Blank lines (e.g. a trailing empty line)
# are skipped; Point(*"".split(",")) would otherwise fail with a TypeError.
points = set()  # store input as set of Points
with input_file.open(mode="r", encoding="utf-8") as file:
    for line in file:
        line = line.strip()
        if not line:
            continue
        points.add(Point(*line.split(",")))

points = frozenset(points)
# Part 1: connect the len(points) closest pairs, then multiply the three largest cluster
# sizes. NOTE(review): the pair count is tied to len(points); confirm it matches the
# number of connections the puzzle requires for every input (e.g. a smaller sample).
clasterize_points(calculate_distance_of_point_pairs(points), len(points))
cluster_sizes = calculate_clusters_size(points)

# Part 2 (disabled). clasterize_points above has already connected these Points in
# place, so part 2 should start from freshly parsed Points instead of reusing them.
#point1, point2 = clasterize_points_until_fully_connected(calculate_distance_of_point_pairs(points), points)

print(f"Part1 answer: {reduce(mul, cluster_sizes[0:3])}")
#print(f"Part2 answer: {point1.x() * point2.x()}")



Part1 answer: 54180


In [105]:
# Scratch check: tuples sort by their first element. This only works because the first
# elements differ; Point defines no ordering, so a tie between two different Points
# would make sorted() compare them and raise TypeError.
sorted([(5, Point(x=10, y=5, z=5)), (2, Point(x=6, y=4, z=2)), (1, Point(x=6, y=6, z=6))])

[(1, Point(6,6,6)), (2, Point(6,4,2)), (5, Point(10,5,5))]

In [27]:
# Scratch check: Point equality compares coordinates, not object identity.
Point(x=1, y=2, z=3) == Point(1, 2, 3)

True

In [53]:
# Scratch check: Points hash by coordinates, so they can be set members; print set sizes.
# Set/list literals already build new objects, so no set(...) / list(...) wrapping needed.
s1 = {Point(1, 2, 3), Point(4, 5, 6)}
s2 = {Point(7, 8, 9), Point(0, 0, 0), Point(6, 6, 6)}
l = [s1, s2]
print([len(point_set) for point_set in l])

[2, 3]


In [147]:
# Scratch check: an open-ended slice from index 0 returns the whole list (a shallow copy).
[20, 6, 3, 5, 3, 2][:]

[20, 6, 3, 5, 3, 2]

In [185]:
# Scratch check: Points hash by their coordinates, so they can be used as dict keys.
dict([(Point(1, 2, 3), ["a"])])

{Point(1,2,3 connections=0): ['a']}