In [52]:
from typing import Tuple, List
vector = Tuple[float, float]

def squared_distance(p: vector, q: vector) -> float:
    '''Returns the squared distance between points p and q'''
    (px, py), (qx, qy) = p, q
    return (px - qx)**2 + (py - qy)**2

def main(points: List[vector]) -> float:
    Px = sorted(points, key=lambda x: x[0])
    Py = sorted(points, key=lambda x: x[1])
    return closest_pair(Px, Py)

def closest_pair(Px: List[vector], Py: List[vector]) -> float:
    if len(Px) > 3:
        mid = len(Px) // 2
        Qx, Rx = Px[:mid], Px[mid:]
        Qy, Ry = Py[:mid], Py[mid:]

        d1 = closest_pair(Qx, Qy)
        d2 = closest_pair(Rx, Ry)
        min_d = min(d1, d2)
        d3 = closest_split_pair(Px, Py, min_d)
        return min_d if not d3 else min(min_d, d3)
    min_d = None
    for p in Px:
        for q in Px:
            if p != q:
                d = squared_distance(p,q)
                if not min_d or d < min_d:
                    min_d = d
    return min_d
                
def closest_split_pair(Px: List[vector], Py: List[vector], d: float) -> float:
    median_x = Px[len(Px) // 2][0] 
    Sy = [point for point in Py 
          if median_x - d <= point[0] <= median_x + d]
    min_d = None
    for i in range(0, len(Sy) - 1):
        for j in range(i, i + min(8, len(Sy) - i)):
            dist_ij = squared_distance(Sy[i], Sy[j])
            if not min_d or squared_distance(Sy[i], Sy[j]) < min_d:
                min_d = dist_ij
    return min_d
            
    
    
    
        


In [67]:
from typing import Tuple, List, NamedTuple

class Point(NamedTuple):
    x: float
    y: float

def squared_distance(p: Point, q: Point) -> float:
    '''Returns the squared distance between points p and q'''
    (px, py), (qx, qy) = p, q
    return (px - qx)**2 + (py - qy)**2

def main(points: List[Point]) -> float:
    Px = sorted(points, key=lambda p: p.x)
    Py = sorted(points, key=lambda p: p.y)
    return closest_pair(Px, Py)

def closest_pair(Px: List[Point], Py: List[Point]) -> float:
    if len(Px) > 3:
        mid = len(Px) // 2
        Qx, Rx = Px[:mid], Px[mid:]
        Qy, Ry = Py[:mid], Py[mid:]

        d1 = closest_pair(Qx, Qy)
        d2 = closest_pair(Rx, Ry)
        min_d = min(d1, d2)
        d3 = closest_split_pair(Px, Py, min_d)
        return min_d if not d3 else min(min_d, d3)
    min_d = None
    for p in Px:
        for q in Px:
            if p != q:
                d = squared_distance(p,q)
                if not min_d or d < min_d:
                    min_d = d
    return min_d
                
def closest_split_pair(Px: List[Point], Py: List[Point], d: float) -> float:
    median_x = Px[len(Px) // 2].x
    Sy = [point for point in Py 
          if median_x - d <= point.x <= median_x + d]
    min_d = None
    for i in range(0, len(Sy) - 1):
        for j in range(i, i + min(8, len(Sy) - i)):
            dist_ij = squared_distance(Sy[i], Sy[j])
            if not min_d or squared_distance(Sy[i], Sy[j]) < min_d:
                min_d = dist_ij
    return min_d
            
    
    
    
        


In [69]:
import unittest

tests = (
    (
        [Point(-7, 9), Point(-8, -1), Point(-5, -5)],
        25
    ),
    (
        [Point(15, -15), Point(-16, -1), Point(8, 2), Point(16, 3), Point(-20, 13), Point(-6, 18)],
        65
    ),
    (
        [Point(-16, -30), Point(29, -27), Point(-28, -26), Point(-5, -19), Point(0, -16), Point(-2, -6), Point(28, -3), Point(-13, 1), Point(-4, 10), Point(-14, 17)],
        34
    ),
    (
        [Point(24, -45), Point(-25, -43), Point(-36, -41), Point(0, -37), Point(-9, -35), Point(33, -31), Point(-34, -30), Point(-23, -26), Point(23, -25), Point(15, -20), Point(28, -9), Point(-10, -3), Point(-28, 1), Point(-45, 16), Point(-35, 17), Point(44, 20), Point(22, 22), Point(-11, 42), Point(20, 43), Point(16, 45)],
        20
    ),
    (
        [Point(x,y) for x, y in [(-4, -90), (-47, -89), (44, -88), (17, -85), (-28, -82), (60, -78), (88, -76), (89, -73), (-64, -63), (86, -59), (4, -58), (-44, -54), (-83, -53), (-54, -46), (-72, -45), (18, -44), (-74, -38), (58, -30), (84, -28), (-49, -26), (-38, -24), (87, -23), (50, -22), (33, -18), (14, -17), (74, -10), (-59, -9), (-29, -8), (-17, -1), (-3, 4), (-71, 12), (69, 14), (-32, 16), (68, 28), (2, 31), (-69, 36), (-57, 45), (11, 46), (-40, 54), (25, 55), (52, 67), (61, 68), (27, 69), (49, 72), (-79, 75), (-39, 76), (-60, 77), (82, 82), (-48, 87), (-53, 88)]],
        10
    )
)

def check(points, min_d):
    return min_d == main(points)

class TestClosestPair(unittest.TestCase):
    def test_01(self):
        self.assertTrue(check(tests[0][0], tests[0][1]))

    def test_02(self):
        self.assertTrue(check(tests[1][0], tests[1][1]))

    def test_03(self):
        self.assertTrue(check(tests[2][0], tests[2][1]))

    def test_04(self):
        self.assertTrue(check(tests[3][0], tests[3][1]))
       
    def test_05(self):
        self.assertTrue(check(tests[4][0], tests[4][1]))

unittest.main(argv=['first-arg-is-ignored'], verbosity = 3, exit = False)


test_01 (__main__.TestClosestPair) ... FAIL
test_02 (__main__.TestClosestPair) ... ok
test_03 (__main__.TestClosestPair) ... ok
test_04 (__main__.TestClosestPair) ... ok
test_05 (__main__.TestClosestPair) ... ok

FAIL: test_01 (__main__.TestClosestPair)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "<ipython-input-69-934345aa8a19>", line 31, in test_01
    self.assertTrue(check(tests[0][0], tests[0][1]))
AssertionError: False is not true

----------------------------------------------------------------------
Ran 5 tests in 0.008s

FAILED (failures=1)


<unittest.main.TestProgram at 0x28fc5930588>