Skip to content

Commit

Permalink
Merge pull request #11 from carsonfarmer/api-simplification
Browse files Browse the repository at this point in the history
Simplify and move API around
  • Loading branch information
cjqf committed Jul 4, 2016
2 parents a5ba9e9 + bf19c9c commit 9236496
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 74 deletions.
2 changes: 1 addition & 1 deletion fastpair/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
# Copyright (c) 2002-2015, David Eppstein
# Licensed under the MIT Licence (http://opensource.org/licenses/MIT).

from .base import FastPair, interact
from .base import FastPair
85 changes: 40 additions & 45 deletions fastpair/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,29 +38,22 @@
from operator import itemgetter
from collections import defaultdict
import scipy.spatial.distance as dist
from scipy import mean as _mean, array as _array

__all__ = ["interact", "FastPair", "dist", "default_dist"]

default_dist = dist.euclidean
__all__ = ["FastPair", "dist"]


def interact(u, v):
"""Compute element-wise mean(s) from two arrays."""
return tuple(_mean(_array([u, v]), axis=0))


class _adict(dict):
class attrdict(dict):
"""Simple dict with support for accessing elements as attributes."""
def __init__(self, *args, **kwargs):
super(_adict, self).__init__(*args, **kwargs)
super(attrdict, self).__init__(*args, **kwargs)
self.__dict__ = self


class FastPair(object):
"""FastPair 'sketch' class.
"""
def __init__(self, min_points=10, dist=default_dist, merge=interact):
def __init__(self, min_points=10, dist=dist.euclidean):
"""Initialize an empty FastPair data-structure.
Parameters
Expand All @@ -75,19 +68,11 @@ def __init__(self, min_points=10, dist=default_dist, merge=interact):
from `scipy.spatial.distance` will do the trick. By default, the
Euclidean distance function is used. This function should play
nicely with the `merge` function.
merge : function, default=scipy.mean
Can be any Python function that returns a single 'point' from two
input 'points'. By default, the element-wise mean(s) from two input
point arrays is used. If a user has a 'special' point class; for
example, one that represents cluster centroids, then the user can
specify a function that returns valid clusters. This function
should play nicely with the `dist` function.
"""
self.min_points = min_points
self.dist = dist
self.merge = merge
self.initialized = False # Has the data-structure been initialized?
self.neighbors = defaultdict(_adict) # Dict of neighbor points and dists
self.neighbors = defaultdict(attrdict) # Dict of neighbor points and dists
self.points = list() # Internal point set; entries may be non-unique

def __add__(self, p):
Expand Down Expand Up @@ -129,6 +114,16 @@ def __contains__(self, p):
def __iter__(self):
return iter(self.points)

def __getitem__(self, item):
if not item in self:
raise KeyError("{} not found".format(item))
return self.neighbors[item]

def __setitem__(self, item, value):
if not item in self:
raise KeyError("{} not found".format(item))
self._update_point(item, value)

def build(self, points=None):
"""Build a FastPairs data-structure from a set of (new) points.
Expand Down Expand Up @@ -179,7 +174,7 @@ def closest_pair(self):
"""
if len(self) < 2:
raise ValueError("Must have `npoints >= 2` to form a pair.")
elif len(self) < self.min_points:
elif not self.initialized:
return self.closest_pair_brute_force()
a = self.points[0] # Start with first point
d = self.neighbors[a].dist
Expand All @@ -194,6 +189,24 @@ def closest_pair_brute_force(self):
"""Find closest pair using brute-force algorithm."""
return _closest_pair_brute_force(self.points)

def sdist(self, p):
"""Compute distances from input to all other points in data-structure.
This returns an iterator over all other points and their distance
from the input point `p`. The resulting iterator returns tuples with
the first item giving the distance, and the second item giving in
neighbor point. The `min` of this iterator is essentially a brute-
force 'nearest-neighbor' calculation. To do this, supply `itemgetter`
(or a lambda version) as the `key` argument to `min`.
Examples
--------
>>> fp = FastPair().build(points)
>>> min(fp.sdist(point), key=itemgetter(0))
"""
return ((self.dist(a, b), b) for a, b in
zip(cycle([p]), self.points) if b != a)

def _find_neighbor(self, p):
"""Find and update nearest neighbor of a given point."""
# If no neighbors available, set flag for `_update_point` to find
Expand All @@ -216,12 +229,6 @@ def _find_neighbor(self, p):
self.neighbors[p].neigh = q
return dict(self.neighbors[p]) # Return plain ol' dict

def merge_closest(self):
dist, (a, b) = self.closest_pair()
c = self.merge(a, b)
self -= b
return self._update_point(a, c)

def _update_point(self, old, new):
"""Update point location, neighbors, and distances.
Expand Down Expand Up @@ -255,26 +262,14 @@ def _update_point(self, old, new):
self.neighbors[q].dist = d
return dict(self.neighbors[new])

def sdist(self, p):
"""Compute distances from input to all other points in data-structure.
This returns an iterator over all other points and their distance
from the input point `p`. The resulting iterator returns tuples with
the first item giving the distance, and the second item giving in
neighbor point. The `min` of this iterator is essentially a brute-
force 'nearest-neighbor' calculation. To do this, supply `itemgetter`
(or a lambda version) as the `key` argument to `min`.
Examples
--------
>>> fp = FastPair().build(points)
>>> min(fp.sdist(point), key=itemgetter(0))
"""
return ((self.dist(a, b), b) for a, b in
zip(cycle([p]), self.points) if b != a)
# def merge_closest(self):
# dist, (a, b) = self.closest_pair()
# c = self.merge(a, b)
# self -= b
# return self._update_point(a, c)


def _closest_pair_brute_force(pts, dst=default_dist):
def _closest_pair_brute_force(pts, dst=dist.euclidean):
"""Compute closest pair of points using brute-force algorithm.
Notes
Expand Down
46 changes: 18 additions & 28 deletions fastpair/test/test_fastpair.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from itertools import cycle, combinations, groupby
import random
import pytest
from fastpair import FastPair, interact
from fastpair import FastPair
from math import isinf, isnan

from scipy import mean, array, unique
Expand All @@ -38,13 +38,9 @@ def rand_tuple(dim=2):
return tuple([random.random() for _ in range(dim)])


def to_codebook(X, part):
"""Calculates centroids according to flat cluster assignment."""
codebook = []
X = array(X)
for i in unique(part):
codebook.append(tuple(X[part == i].mean(0)))
return codebook
def interact(u, v):
"""Compute element-wise mean(s) from two arrays."""
return tuple(mean(array([u, v]), axis=0))


# Setup fixtures
Expand Down Expand Up @@ -192,29 +188,23 @@ def test_update_point(self):
assert res[1] == neigh["neigh"]

def test_merge_closest(self):
# Still failing sometimes...
ps = PointSet()
fp1 = FastPair().build(ps)
fp2 = FastPair().build(ps)
# This needs to be 'fleshed' out more... lots of things to test here
random.seed(1234)
ps = PointSet(d=4)
fp = FastPair().build(ps)
# fp2 = FastPair().build(ps)
n = len(ps)
while n >= 2:
dist, (a, b) = fp1.closest_pair()
dist, (a, b) = fp.closest_pair()
new = interact(a, b)
fp1 -= b # Drop b
fp1._update_point(a, new)
fp2.merge_closest()
fp -= b # Drop b
fp._update_point(a, new)
n -= 1
assert len(fp1) == len(fp2) == 1 # == len(fp2)
assert fp1.points == fp2.points # == fp2.points
# Compare points
assert contains_same(list(fp1.neighbors.keys()), list(fp2.neighbors.keys()))
# Compare neighbors
assert contains_same([n["neigh"] for n in fp1.neighbors.values()],
[n["neigh"] for n in fp2.neighbors.values()])
# Compare dists
assert all_close([n["dist"] for n in fp1.neighbors.values()],
[n["dist"] for n in fp2.neighbors.values()])
assert len(fp) == 1 == n
points = [(0.69903599809571437, 0.52457534006594131,
0.7614753848101149, 0.37011695654655385)]
assert all_close(fp.points[0], points[0])
# Should have < 2 points now...
with pytest.raises(ValueError):
fp1.closest_pair()
fp2.closest_pair()
fp.closest_pair()
# fp2.closest_pair()

0 comments on commit 9236496

Please sign in to comment.