# Playground

AOC Day 8 Part 1

In [26]:
# Sample puzzle input: one comma-separated "x,y,z" triple per entry.
input = [
    "162,817,812",
    "57,618,57",
    "906,360,560",
    "592,479,940",
    "352,342,300",
    "466,668,158",
    "542,29,236",
    "431,825,988",
    "739,650,466",
    "52,470,668",
    "216,146,977",
    "819,987,18",
    "117,168,530",
    "805,96,715",
    "346,949,466",
    "970,615,88",
    "941,993,340",
    "862,61,35",
    "984,92,344",
    "425,690,689",
]

# Split each entry once, then convert its three fields to an (x, y, z) int tuple.
coordinates = [
    (int(a), int(b), int(c)) for a, b, c in (row.split(",") for row in input)
]

# Bare expression so the notebook displays the parsed list.
coordinates

[(162, 817, 812),
 (57, 618, 57),
 (906, 360, 560),
 (592, 479, 940),
 (352, 342, 300),
 (466, 668, 158),
 (542, 29, 236),
 (431, 825, 988),
 (739, 650, 466),
 (52, 470, 668),
 (216, 146, 977),
 (819, 987, 18),
 (117, 168, 530),
 (805, 96, 715),
 (346, 949, 466),
 (970, 615, 88),
 (941, 993, 340),
 (862, 61, 35),
 (984, 92, 344),
 (425, 690, 689)]

We can use a min-heap to store the pairs based on the distance.

In [27]:
import math
from heapq import heappush

def insert_to_heap(A):
    """Build a min-heap of the Euclidean distances between every pair of points.

    Args:
        A: Sequence of points; the first three items of each point are read
            as its x, y and z coordinates.

    Returns:
        A list that satisfies the ``heapq`` invariant. Each entry is a
        ``(distance, i, j)`` tuple with ``i < j`` being indices into ``A``,
        so popping from it yields pairs from closest to farthest.
    """
    # https://docs.python.org/3/library/heapq.html
    heap = []

    # Read the points from the argument `A`; relying on a notebook-level
    # global here would silently ignore whatever the caller passed in.
    for i, ci in enumerate(A):
        x1, y1, z1 = ci[0], ci[1], ci[2]
        for j in range(i + 1, len(A)):
            cj = A[j]
            x2, y2, z2 = cj[0], cj[1], cj[2]
            distance = math.sqrt((x1 - x2)**2 + (y1 - y2)**2 + (z1 - z2)**2)
            # The indices break ties between equal distances deterministically.
            heappush(heap, (distance, i, j))

    return heap

The graph above needs to reflect the following connections.

<div align="center">
    <img src="./img.png" style="max-height:500px">
</div>

Let's try to model this as a [Union-Find](https://yuminlee2.medium.com/union-find-algorithm-ffa9cd7d2dba) problem. This usually appears under "[Disjoint Set](https://en.wikipedia.org/wiki/Disjoint-set_data_structure)" problem sets.

In union finding, we have two subroutines:

1. `find()`: Finds the root of the set an element belongs to; two elements are in the same set exactly when they have the same root.
2. `union()`: Merge the sets if the two given elements `a` and `b` are disjoint.

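There is also a third operation, `make_set()`, but we do not need it as a separate step: every point starts out as its own set when the `parent` and `size` arrays are initialized, and the pairwise `distance` calculation given in the problem only decides the order in which the sets are merged.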

NB: The main reference for this implementation was [this Medium article](https://yuminlee2.medium.com/union-find-algorithm-ffa9cd7d2dba#6382).

In [28]:
from collections import defaultdict
from heapq import heappop

def f(coordinates, k):
    """Join the ``k`` closest pairs of points and score the resulting groups.

    The pairs are taken in order of increasing distance from the heap built by
    ``insert_to_heap``; each pair is merged with Union-Find (a pair whose
    points are already in the same group still counts towards ``k``).

    Args:
        coordinates: Sequence of ``(x, y, z)`` points.
        k: Number of closest pairs to process. Values larger than the number
            of pairs are clamped; values <= 0 process nothing.

    Returns:
        The product of the sizes of the three largest groups (of all groups
        if there are fewer than three; ``1`` for empty input).
    """
    heap = insert_to_heap(coordinates)

    # Step 1: Initialize parent and size arrays with one slot per *point*
    # (not per pair -- the heap holds n * (n - 1) / 2 entries).
    # Originally, every point is the root of its own single-element set.
    n = len(coordinates)
    parent = list(range(n))
    size = [1] * n

    # Finds the root of an existing set.
    # Note that here `node` is an index to our 3D vector.
    def find(node):
        while node != parent[node]:
            # Path halving: point the node at its grandparent as we walk up.
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    def union(node1, node2):
        # Step 2.a: Find root parent and check if two subsets are in the same set.
        root1, root2 = find(node1), find(node2)

        if root1 == root2:
            return False

        # Step 2.b: Check which set is larger and merge the smaller one into it.
        if size[root1] > size[root2]:
            parent[root2] = root1
            # The surviving root now owns every element of the absorbed set.
            size[root1] += size[root2]
        else:
            parent[root1] = root2
            size[root2] += size[root1]

        return True

    # Step 2: Process the k shortest edges (never more than exist).
    for _ in range(max(0, min(k, len(heap)))):
        _, i, j = heappop(heap)
        union(i, j)

    # Step 3: A root's `size` entry is the size of its whole group.
    sizes = sorted(
        (size[node] for node in range(n) if parent[node] == node),
        reverse=True,
    )

    product = 1
    for s in sizes[:3]:
        product *= s

    return product

# Sample input: process the 10 closest pairs, then multiply the three largest group sizes.
f(coordinates, 10)

40

$40$ is the correct answer.

Let's try it for our `input.txt`.

In [29]:
# Solve Part 1 for the real puzzle input stored next to the notebook.
with open("input.txt") as file:
    # One "x,y,z" entry per line; strip the trailing newline/whitespace.
    input = list(map(str.rstrip, file))

    # Split each entry once and convert its first three fields to ints.
    coordinates = [
        (int(parts[0]), int(parts[1]), int(parts[2]))
        for parts in (entry.split(",") for entry in input)
    ]

    # Process the 1000 closest pairs and print the resulting product.
    answer = f(coordinates, 1000)
    print(answer)
    

133574


$133574$ is the correct answer!

---

### Part II