Skip to content

Commit

Permalink
Merge pull request #20 from mchalela/refactor
Browse files Browse the repository at this point in the history
Periodicity class implemented within GriSPy
  • Loading branch information
mchalela committed Nov 24, 2021
2 parents 01e6386 + 981edfe commit 45670c7
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 217 deletions.
8 changes: 6 additions & 2 deletions grispy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
GriSPy is a regular grid search algorithm for quick nearest-neighbor lookup.
"""

__all__ = ["Grid", "GriSPy"]


__version__ = "0.2.0"

Expand All @@ -24,3 +22,9 @@

from .core import Grid, GriSPy
from .periodicity import Periodicity

# =============================================================================
# ALL
# =============================================================================

__all__ = ["Grid", "GriSPy", "Periodicity"]
197 changes: 29 additions & 168 deletions grispy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,6 @@
EMPTY_ARRAY = np.array([], dtype=int)


# =============================================================================
# PERIODICITY CONF CLASS
# =============================================================================


@attr.s(frozen=True)
class PeriodicityConf:
"""Internal representation of the periodicity of the Grid."""

periodic_flag = attr.ib()
pd_hi = attr.ib()
pd_low = attr.ib()
periodic_edges = attr.ib()
periodic_direc = attr.ib()


# =============================================================================
# MAIN CLASS
# =============================================================================
Expand Down Expand Up @@ -544,7 +528,7 @@ class GriSPy(Grid):
Total number of cells.
cell_width: ndarray
Cell size in each dimension.
periodic_flag: bool
isperiodic: bool
If any dimension has periodicity.
periodic_conf: grispy.core.PeriodicityConf
Statistics and intermediate results to make easy and fast the searchs
Expand All @@ -564,9 +548,8 @@ def __attrs_post_init__(self):
"""Init more params and build the grid."""
super().__attrs_post_init__()

self.periodic, self.periodic_conf = self._build_periodicity(
periodic=self.periodic, dim=self.dim
)
if isinstance(self.periodic, dict):
self.periodic = Periodicity(edges=self.periodic, dim=self.dim)

@metric.validator
def _validate_metric(self, attr, value):
Expand All @@ -581,132 +564,32 @@ def _validate_metric(self, attr, value):

@periodic.validator
def _validate_periodic(self, attr, value):
"""Validate if dict or Periodicity instance.
The rest of the validation is handled by Periodicty validators.
"""
# Chek if dict
if not isinstance(value, dict):
if not isinstance(value, (dict, Periodicity)):
raise TypeError(
"Periodicity: Argument must be a dictionary. "
"Got instead type {}".format(type(value))
)

# If dict is empty means no perioity, stop validation.
if len(value) == 0:
return

# Check if keys and values are valid
for k, v in value.items():
# Check if integer
if not isinstance(k, int):
raise TypeError(
"Periodicity: Keys must be integers. "
"Got instead type {}".format(type(k))
)

# Check if tuple or None
if not (isinstance(v, tuple) or v is None):
raise TypeError(
"Periodicity: Values must be tuples. "
"Got instead type {}".format(type(v))
)
if v is None:
continue

# Check if edges are valid numbers
has_valid_number = all(
[
isinstance(v[0], (int, float)),
isinstance(v[1], (int, float)),
]
"Periodicity: Argument must be of type dictionary or "
"Periodicity. Got instead type {}".format(type(value))
)
if not has_valid_number:
raise TypeError(
"Periodicity: Argument must be a tuple of "
"2 real numbers as edge descriptors. "
)

# Check that first number is lower than second
if not v[0] < v[1]:
raise ValueError(
"Periodicity: First argument in tuple must be "
"lower than second argument."
)

# =========================================================================
# PROPERTIES
# =========================================================================

@property
def periodic_flag(self):
"""Proxy to ``periodic_conf_.periodic_flag``."""
return self.periodic_conf.periodic_flag
def isperiodic(self):
"""Proxy to ``periodic.isperiodic``."""
return self.periodic.isperiodic

# =========================================================================
# INTERNAL IMPLEMENTATION
# =========================================================================

def _build_periodicity(self, periodic, dim):
"""Cleanup the periodicity configuration.
Remove the unnecessary axis from the periodic dict and also creates
a configuration for use in the search.
"""
# assume no periodicity
cleaned_periodic = {}

periodic_flag = False
pd_hi, pd_low = None, None
periodic_edges, periodic_direc = None, None

periodic_flag = any([x is not None for x in list(periodic.values())])

# now check if periodic
if periodic_flag:

pd_hi = np.ones((1, dim)) * np.inf
pd_low = np.ones((1, dim)) * -np.inf
periodic_edges = []
for k in range(dim):
aux = periodic.get(k)
cleaned_periodic[k] = aux
if aux:
pd_low[0, k] = aux[0]
pd_hi[0, k] = aux[1]
aux = np.insert(aux, 1, 0.0)
else:
aux = np.zeros((1, 3))
periodic_edges = np.hstack(
[
periodic_edges,
np.tile(aux, (3 ** (dim - 1 - k), 3 ** k)).T.ravel(),
]
)

periodic_edges = periodic_edges.reshape(dim, 3 ** dim).T
periodic_edges -= periodic_edges[::-1]
periodic_edges = np.unique(periodic_edges, axis=0)

mask = periodic_edges.sum(axis=1, dtype=bool)
periodic_edges = periodic_edges[mask]

periodic_direc = np.sign(periodic_edges)

return cleaned_periodic, PeriodicityConf(
periodic_flag=periodic_flag,
pd_hi=pd_hi,
pd_low=pd_low,
periodic_edges=periodic_edges,
periodic_direc=periodic_direc,
)

def _distance(self, centre_0, centres):
"""Compute distance between points.
metric options: 'euclid', 'sphere'
Notes: In the case of 'sphere' metric, the input units must be degrees.
"""
"""Compute distance between points."""
if len(centres) == 0:
return EMPTY_ARRAY.copy()
metric_func = (
Expand Down Expand Up @@ -826,53 +709,33 @@ def _get_neighbor_cells(
return neighbor_cells

def _near_boundary(self, centres, distance_upper_bound):
mask = np.zeros((len(centres), self.dim), dtype=bool)
"""Check if given centres are within distance of the grid boundary."""
window = np.zeros((len(centres), self.dim), dtype=bool)
for k in range(self.dim):
if self.periodic[k] is None:
continue
mask[:, k] = (
window[:, k] = (
abs(centres[:, k] - self.periodic[k][0]) < distance_upper_bound
)
mask[:, k] += (
window[:, k] += (
abs(centres[:, k] - self.periodic[k][1]) < distance_upper_bound
)
return mask.sum(axis=1, dtype=bool)

def _mirror(self, centre, distance_upper_bound):
pd_hi, pd_low, periodic_edges, periodic_direc = (
self.periodic_conf.pd_hi,
self.periodic_conf.pd_low,
self.periodic_conf.periodic_edges,
self.periodic_conf.periodic_direc,
)

mirror_centre = centre - periodic_edges

mask = periodic_direc * distance_upper_bound
mask = mask + mirror_centre
mask = (mask >= pd_low) * (mask <= pd_hi)
mask = np.prod(mask, 1, dtype=bool)
return mirror_centre[mask]
return window.sum(axis=1, dtype=bool)

def _mirror_universe(self, centres, distance_upper_bound):
"""Generate Terran centres in the Mirror Universe."""
terran_centres = np.array([[]] * self.dim).T
terran_indices = np.array([], dtype=int)
near_boundary = self._near_boundary(centres, distance_upper_bound)
if not np.any(near_boundary):
terran_centres = np.array([[]] * self.dim).T
terran_indices = np.array([], dtype=int)
return terran_centres, terran_indices

for i, centre in enumerate(centres):
if not near_boundary[i]:
continue
mirror_centre = self._mirror(centre, distance_upper_bound[i])
if len(mirror_centre) > 0:
terran_centres = np.concatenate(
(terran_centres, mirror_centre), axis=0
)
terran_indices = np.concatenate(
(terran_indices, np.repeat(i, len(mirror_centre)))
)
terran_centres = self.periodic.mirror(centres[near_boundary], levels=1)
# track original indices
multiplicity = self.periodic.multiplicity(levels=1)
indices = np.arange(len(centres))[near_boundary]
terran_indices = np.repeat(indices, multiplicity)

return terran_centres, terran_indices

# =========================================================================
Expand Down Expand Up @@ -907,9 +770,7 @@ def set_periodicity(self, periodic, inplace=False):
if inplace:
periodic_attr = attr.fields(GriSPy).periodic
periodic_attr.validator(self, periodic_attr, periodic)
self.periodic, self.periodic_conf = self._build_periodicity(
periodic=periodic, dim=self.dim
)
self.periodic = Periodicity(periodic, dim=self.dim)
else:
return GriSPy(
data=self.data,
Expand Down Expand Up @@ -982,7 +843,7 @@ def bubble_neighbors(
)

# We need to generate mirror centres for periodic boundaries...
if self.periodic_flag:
if self.isperiodic:
terran_centres, terran_indices = self._mirror_universe(
centres, distance_upper_bound
)
Expand Down Expand Up @@ -1099,7 +960,7 @@ def shell_neighbors(
)

# We need to generate mirror centres for periodic boundaries...
if self.periodic_flag:
if self.isperiodic:
terran_centres, terran_indices = self._mirror_universe(
centres, distance_upper_bound
)
Expand Down
Loading

0 comments on commit 45670c7

Please sign in to comment.