# Day 8
## Part 1

The obvious solution is to sort distances. How many points are there?

In [1]:
from advent import read_input, Point3D, distance

def parse_data(s):
    return [
        Point3D(*(int(x) for x in line.split(",")))
        for line in s.strip().splitlines()
    ]

data = parse_data(read_input())
len(data)

1000

Ok, about half a million, that makes comparing each distance tractable.

Compare each unique pair of points, using a heap queue (max variant, need to upgrade Python to 3.14) to track the $n$ shortest pairs.

In [12]:
import heapq

def n_shortest_distances(n, points):
    h = []
    for i, p1 in enumerate(points):
        for p2 in points[i + 1:]:
            d = distance(p1, p2)
            if len(h) < n:
                heapq.heappush_max(h, (d, p1, p2))
            else:
                x = heapq.heappop_max(h)
                if d < x[0]:
                    heapq.heappush_max(h, (d, p1, p2))
                else:
                    heapq.heappush_max(h, x)
    return h

Create a graph of connected junction boxes.

In [26]:
import networkx as nx
import math

def circuits(n, junction_boxes):
    g = nx.Graph()
    g.add_edges_from([
        (p1, p2) 
        for _, p1, p2 in n_shortest_distances(n, junction_boxes)
    ])
    return list(nx.connected_components(g))

def part_1(data, n=1000):
    return math.prod(sorted([len(x) for x in circuits(n, data)], reverse=True)[:3])

test_data = parse_data("""162,817,812
57,618,57
906,360,560
592,479,940
352,342,300
466,668,158
542,29,236
431,825,988
739,650,466
52,470,668
216,146,977
819,987,18
117,168,530
805,96,715
346,949,466
970,615,88
941,993,340
862,61,35
984,92,344
425,690,689
""")

assert part_1(test_data, 10) == 40

In [28]:
data = parse_data(read_input())

part_1(data)

131580

## Part 2

In [30]:
def shortest_distances(points):
    h = []
    for i, p1 in enumerate(points):
        for p2 in points[i + 1:]:
            d = distance(p1, p2)
            heapq.heappush(h, (d, p1, p2))
    while h:
        yield heapq.heappop(h)

def part_2(points):
    g = nx.Graph()
    g.add_nodes_from(points)
    for _, p1, p2 in shortest_distances(points):
        g.add_edge(p1, p2)
        if nx.is_connected(g):
            return p1.x * p2.x

assert part_2(test_data) == 25272

In [31]:
part_2(data)

6844224

That's slow.

In [32]:
%%timeit 

part_2(data)

9.7 s ± 18.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Checking for full connectivity at each step is the expensive part but I'm not sure I can be bothered optimising it.