Skip to content
This repository has been archived by the owner on Apr 16, 2024. It is now read-only.

Commit

Permalink
Merge pull request #3 from mpewsey/spatial-hash
Browse files Browse the repository at this point in the history
Refactor SpatialHash
  • Loading branch information
mpewsey committed May 5, 2019
2 parents 9e84f8f + 566eb3a commit 7b5317b
Showing 1 changed file with 135 additions and 90 deletions.
225 changes: 135 additions & 90 deletions civpy/survey/spatial_hash.py
Expand Up @@ -64,15 +64,15 @@ def _add_points(self, points):
self._dim = points.shape[1]
self.points = points

hashes = self._multi_hash(points)
hash_dict = self._dict
hashes = self._multi_hash(points, norm=True)
odict = self._dict

for i, h in enumerate(hashes):
if h not in hash_dict:
hash_dict[h] = []
hash_dict[h].append(i)
if h not in odict:
odict[h] = []
odict[h].append(i)

def _multi_hash(self, points):
def _multi_hash(self, points, norm):
"""
Returns a list of dictionary hash keys corresponding to the input
points.
Expand All @@ -81,25 +81,51 @@ def _multi_hash(self, points):
----------
points : list
A list of points of shape (N, D).
norm : bool
If True, normalizes the points to their grid index. Otherwise,
assumes that the input points are grid indices.
"""
points = np.asarray(points)
points = (points // self._grid).astype('int')
return [hash(tuple(x)) for x in points]
if norm:
points = np.asarray(points) // self._grid

return (hash(tuple(x)) for x in points)

def _hash(self, point):
def _hash(self, point, norm):
"""
Returns the hash key corresponding to the input point.
Parameters
----------
point : list
A list of shape (D,).
norm : bool
If True, normalizes the points to their grid index. Otherwise,
assumes that the input points are grid indices.
"""
point = np.asarray(point)
point = (point // self._grid).astype('int')
if norm:
point = np.asarray(point) // self._grid

return hash(tuple(point))

def get(self, point):
def multi_get(self, points, norm=True):
"""
Parameters
----------
points : list
A list of points of shape (N, D).
norm : bool
If True, normalizes the points to their grid index. Otherwise,
assumes that the input points are grid indices.
"""
result = []
odict = self._dict

for x in self._multi_hash(points, norm):
result.extend(odict.get(x, []))

return np.asarray(np.unique(result), dtype='int')

def get(self, point, norm=True):
"""
Returns the point indices correesponding to the same hash as the input
point.
Expand All @@ -108,9 +134,33 @@ def get(self, point):
----------
point : list
A list of shape (D,).
norm : bool
If True, normalizes the points to their grid index. Otherwise,
assumes that the input points are grid indices.
"""
h = self._hash(point)
return self._dict.get(h, [])
point = self._hash(point, norm)
return self._dict.get(point, [])

def _query_point_hash(self, point, ro, ri):
# Calculate worst case offsets
diag = self._grid * self._dim**0.5
ri = max(ri - diag, 0)
ro = ro + diag

# Create meshgrid of hash indices
p = np.column_stack([point - ro, point + ro]) // self._grid
p = [np.arange(a, b+1) for a, b in p]
p = np.array(np.meshgrid(*p), dtype='int').T.reshape(-1, self._dim)

# Filter hashes by distance
dist = np.linalg.norm(point - self._grid * p, axis=1)

if ri == 0:
p = p[dist <= ro]
else:
p = p[(dist <= ro) & (dist >= ri)]

return self.multi_get(p, norm=False)

def query_point(self, point, ro, ri=0):
"""
Expand All @@ -128,98 +178,93 @@ def query_point(self, point, ro, ri=0):
"""
point = np.asarray(point)
self._check_shape(point)

# Get hash filtered points
result = self._query_point_hash(point, ro, ri)
p = self.points[result]

# Filter points by distance
dist = np.linalg.norm(p - point, axis=1)

if ri == 0:
f = (dist <= ro)
else:
f = (dist <= ro) & (dist >= ri)

return result[f][dist[f].argsort()]

def _query_range_hash(self, a, b, ro, ri, u, l):
# Calculate worst case offsets
diag = self._grid * self._dim**0.5
hi = max(ri - diag, 0)
ho = ro + diag
result = [[]]

x = np.column_stack([point - ro, point + ro])
x = (x // self._grid).astype('int')
x = [np.arange(a, b+1) for a, b in x]
x = np.array(np.meshgrid(*x), dtype='float').T.reshape(-1, self._dim)
x *= self._grid

# Filter hashes
dist = np.linalg.norm(point - x, axis=1)
x = x[(dist <= ho) & (dist >= hi)]

for p in x:
p = self.get(p)
result.append(p)

# Evaluate points
result = np.unique(np.concatenate(result)).astype('int')
x = self.points[result]
dist = np.linalg.norm(x - point, axis=1)
f = (dist <= ro) & (dist >= ri)
result = result[f][dist[f].argsort()]

return result

def query_range(self, start, stop, ro, ri=0):
ri = max(ri - diag, 0)
ro = ro + diag

# Create meshgrid of hash indices
x = np.column_stack([a - ro, a - ro]).min(axis=1)
y = np.column_stack([b + ro, b + ro]).max(axis=1)

p = np.column_stack([x, y]) // self._grid
p = [np.arange(x, y+1) for x, y in p]
p = np.array(np.meshgrid(*p), dtype='int').T.reshape(-1, self._dim)

# Filter hashes by projection and offset
v = self._grid * p - b
proj = np.dot(v, u)
dist = np.linalg.norm(v - proj.reshape(-1, 1) * u, axis=1)
del v

if ri == 0:
p = p[(proj >= -diag) & (proj <= l+diag) & (dist <= ro)]
else:
p = p[(proj >= -diag) & (proj <= l+diag) & (dist <= ro) & (dist >= ri)]

return self.multi_get(p, norm=False)

def query_range(self, a, b, ro, ri=0):
"""
Returns an array of point indices for all points along the specified
range within the inner and outer offsets.
Parameters
----------
start : list
a : list
The starting point for the range. The point should be of shape (D,).
stop : list
b : list
The ending point for the range. The point should be of shape (D,).
ro : float
The outer offset beyond which points will be excluded.
ri : float
The inner offset before which points will be excluded.
"""
start = np.asarray(start)
stop = np.asarray(stop)
self._check_shape(start)
self._check_shape(stop)
unit = stop - start
length = np.linalg.norm(unit)
a = np.asarray(a)
b = np.asarray(b)
self._check_shape(a)
self._check_shape(b)

if length == 0:
return self.query_point(start, ro, ri)
# Create unit vector for range
u = a - b
l = np.linalg.norm(u)

unit = unit / length
diag = self._grid * self._dim**0.5
hi = max(ri - diag, 0)
ho = ro + diag
mi = -diag
mo = length + diag
result = [[]]

a = np.column_stack([start - ro, stop - ro]).min(axis=1)
b = np.column_stack([start + ro, stop + ro]).max(axis=1)

x = np.column_stack([a, b])
x = (x // self._grid).astype('int')
x = [np.arange(a, b+1) for a, b in x]
x = np.array(np.meshgrid(*x), dtype='float').T.reshape(-1, self._dim)
x *= self._grid

# Filter hashes
s = x - start
proj = np.dot(s, unit)
off = np.linalg.norm(s - np.expand_dims(proj, 1)*unit, axis=1)
x = x[(proj >= mi) & (proj <= mo) & (off <= ho) & (off >= hi)]

for p in x:
p = self.get(p)
result.append(p)

# Evaluate points
result = np.unique(np.concatenate(result)).astype('int')
x = self.points[result]
s = x - start
proj = np.dot(s, unit)
off = np.linalg.norm(s - np.expand_dims(proj, 1)*unit, axis=1)
f = (proj >= 0) & (proj <= length) & (off <= ro) & (off >= ri)
result = result[f][off[f].argsort()]

return result
if l == 0:
return self.query_point(a, ro, ri)

u = u / l

# Get hash filtered points
result = self._query_range_hash(a, b, ro, ri, u, l)
p = self.points[result]

# Filter points by projection and offset
v = p - b
proj = np.dot(v, u)
dist = np.linalg.norm(v - proj.reshape(-1, 1) * u, axis=1)

if ri == 0:
f = (proj >= 0) & (proj <= l) & (dist <= ro)
else:
f = (proj >= 0) & (proj <= l) & (dist <= ro) & (dist >= ri)

return result[f][dist[f].argsort()]

def _plot_1d(self, ax, sym):
"""
Expand Down

0 comments on commit 7b5317b

Please sign in to comment.