Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/standardize query args #395

Merged
merged 19 commits into from Aug 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 1 addition & 2 deletions ChangeLog.md
Expand Up @@ -7,6 +7,7 @@ Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to
### Added
* Ability to specify NeighborQuery objects as points for neighbor-based pair computes.
* Various validation tests.
* Added standard method for preprocessing arguments of pair computations.

### Changed
* All compute objects that perform neighbor computations now use NeighborQuery internally.
Expand All @@ -22,8 +23,6 @@ Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to

### Removed
* The freud.util module.

### Removed
* Python 2 is no longer supported. Python 3.5+ is required.

## v1.2.2 - 2019-08-15
Expand Down
36 changes: 18 additions & 18 deletions cpp/locality/AABBQuery.h
Expand Up @@ -55,11 +55,11 @@ class AABBQuery : public NeighborQuery
this->validateQueryArgs(args);
if (args.mode == QueryArgs::ball)
{
return queryBall(query_points, n_query_points, args.rmax, args.exclude_ii);
return queryBall(query_points, n_query_points, args.r_max, args.exclude_ii);
}
else if (args.mode == QueryArgs::nearest)
{
return query(query_points, n_query_points, args.nn, args.rmax, args.scale, args.exclude_ii);
return query(query_points, n_query_points, args.num_neighbors, args.r_max, args.scale, args.exclude_ii);
}
else
{
Expand All @@ -76,7 +76,7 @@ class AABBQuery : public NeighborQuery
unsigned int num_neighbors, bool exclude_ii = false) const
{
throw std::runtime_error("AABBQuery k-nearest-neighbor queries must use the function signature that "
"provides rmax and scale guesses.");
"provides r_max and scale guesses.");
}

std::shared_ptr<NeighborQueryIterator> query(const vec3<float>* query_points, unsigned int n_query_points, unsigned int num_neighbors,
Expand All @@ -101,26 +101,26 @@ class AABBQuery : public NeighborQuery
AABBTree m_aabb_tree; //!< AABB tree of points

protected:
//! Validate the combination of specified arguments.
/*! Add to parent function to account for the various arguments
* specifically required for AABBQuery nearest neighbor queries.
*/
virtual void validateQueryArgs(QueryArgs& args) const
{
if (args.mode == QueryArgs::ball)
{
if (args.rmax == -1)
throw std::runtime_error("You must set rmax in the query arguments.");
}
else if (args.mode == QueryArgs::nearest)
NeighborQuery::validateQueryArgs(args);
if (args.mode == QueryArgs::nearest)
{
if (args.nn == -1)
throw std::runtime_error("You must set nn in the query arguments.");
if (args.scale == -1)
if (args.scale == QueryArgs::DEFAULT_SCALE)
{
args.scale = float(1.1);
}
if (args.rmax == -1)
if (args.r_max == QueryArgs::DEFAULT_R_MAX)
{
// By default, we use 1/10 the smallest box dimension as the guessed query distance.
vec3<float> L = this->getBox().getL();
float rmax = std::min(L.x, L.y);
args.rmax = this->getBox().is2D() ? float(0.1) * rmax : float(0.1) * std::min(rmax, L.z);
float r_max = std::min(L.x, L.y);
r_max = this->getBox().is2D() ? r_max : std::min(r_max, L.z);
args.r_max = float(0.1) * r_max;
}
}
}
Expand Down Expand Up @@ -152,7 +152,7 @@ class AABBIterator : virtual public NeighborQueryIterator
virtual ~AABBIterator() {}

//! Computes the image vectors to query for
void updateImageVectors(float rmax, bool _check_rmax = true);
void updateImageVectors(float r_max, bool _check_r_max = true);

protected:
const AABBQuery* m_aabb_query; //!< Link to the AABBQuery object
Expand Down Expand Up @@ -204,12 +204,12 @@ class AABBQueryBallIterator : virtual public AABBIterator
public:
//! Constructor
AABBQueryBallIterator(const AABBQuery* neighbor_query, const vec3<float>* points, unsigned int N, float r,
bool exclude_ii, bool _check_rmax = true)
bool exclude_ii, bool _check_r_max = true)
: NeighborQueryIterator(neighbor_query, points, N, exclude_ii),
AABBIterator(neighbor_query, points, N, exclude_ii), m_r(r), cur_image(0), cur_node_idx(0),
cur_ref_p(0)
{
updateImageVectors(m_r, _check_rmax);
updateImageVectors(m_r, _check_r_max);
}

//! Empty Destructor
Expand Down
4 changes: 2 additions & 2 deletions cpp/locality/NeighborComputeFunctional.h
Expand Up @@ -116,7 +116,7 @@ class NeighborQueryNeighborIterator : public NeighborIterator
// the number of neighbors to look for.
if(qargs.exclude_ii && (qargs.mode == QueryArgs::QueryType::nearest))
{
++m_qargs.nn;
++m_qargs.num_neighbors;
}

// check if nq is a pointer to a RawPoints object
Expand Down Expand Up @@ -309,7 +309,7 @@ void loopOverNeighborQuery(const NeighborQuery* neighbor_query, const vec3<float
// the number of neighbors to look for.
if(qargs.exclude_ii && (qargs.mode == QueryArgs::QueryType::nearest))
{
++qargs.nn;
++qargs.num_neighbors;
}
// if nlist does not exist, check if neighbor_query is an actual NeighborQuery
std::shared_ptr<NeighborQueryIterator> iter;
Expand Down
5 changes: 5 additions & 0 deletions cpp/locality/NeighborQuery.cc
Expand Up @@ -7,4 +7,9 @@ namespace freud { namespace locality {

const NeighborBond NeighborQueryIterator::ITERATOR_TERMINATOR(-1, -1, 0);

const QueryArgs::QueryType QueryArgs::DEFAULT_MODE(QueryArgs::none);
const unsigned int QueryArgs::DEFAULT_NUM_NEIGHBORS(0xffffffff);
const float QueryArgs::DEFAULT_R_MAX(-1.0);
const float QueryArgs::DEFAULT_SCALE(-1.0);
const bool QueryArgs::DEFAULT_EXCLUDE_II(false);
}; }; // end namespace freud::locality
73 changes: 57 additions & 16 deletions cpp/locality/NeighborQuery.h
Expand Up @@ -20,32 +20,40 @@

namespace freud { namespace locality {

//! (Almost) POD class to hold information about generic queries.
//! POD class to hold information about generic queries.
/*! This class provides a standard method for specifying the type of query to
* perform with a NeighborQuery object. Rather than calling queryBall
* specifically, for example, the user can call a generic querying function and
* provide an instance of this class to specify the nature of the query.
*/
struct QueryArgs
{
//! Define constructor
/*! We must violate the strict POD nature of the class to support default
* values for parameters.
//! Default constructor.
/*! We set default values for all parameters here.
*/
QueryArgs() : nn(-1), rmax(-1), scale(-1), exclude_ii(false) {}
QueryArgs() : mode(DEFAULT_MODE), num_neighbors(DEFAULT_NUM_NEIGHBORS), r_max(DEFAULT_R_MAX),
scale(DEFAULT_SCALE), exclude_ii(DEFAULT_EXCLUDE_II) {}

//! Enumeration for types of queries.
enum QueryType
{
ball, //! Query based on distance cutoff.
nearest //! Query based on number of requested neighbors.
none, //! Default query type to avoid implicit default types.
ball, //! Query based on distance cutoff.
nearest, //! Query based on number of requested neighbors.
};

QueryType mode; //! Whether to perform a ball or k-nearest neighbor query.
int nn; //! The number of nearest neighbors to find.
float rmax; //! The cutoff distance within which to find neighbors
unsigned int num_neighbors; //! The number of nearest neighbors to find.
float r_max; //! The cutoff distance within which to find neighbors
float scale; //! The scale factor to use when performing repeated ball queries to find a specified number
//! of nearest neighbors.
bool exclude_ii; //! If true, exclude self-neighbors.

static const QueryType DEFAULT_MODE; //!< Default mode.
static const unsigned int DEFAULT_NUM_NEIGHBORS; //!< Default number of neighbors.
static const float DEFAULT_R_MAX; //!< Default query distance.
static const float DEFAULT_SCALE; //!< Default scaling parameter for AABB nearest neighbor queries.
static const bool DEFAULT_EXCLUDE_II; //!< Default for whether or not to include self-neighbors.
};

// Forward declare the iterator
Expand Down Expand Up @@ -90,11 +98,11 @@ class NeighborQuery
this->validateQueryArgs(args);
if (args.mode == QueryArgs::ball)
{
return this->queryBall(query_points, n_query_points, args.rmax, args.exclude_ii);
return this->queryBall(query_points, n_query_points, args.r_max, args.exclude_ii);
}
else if (args.mode == QueryArgs::nearest)
{
return this->query(query_points, n_query_points, args.nn, args.exclude_ii);
return this->query(query_points, n_query_points, args.num_neighbors, args.exclude_ii);
}
else
{
Expand Down Expand Up @@ -125,7 +133,7 @@ class NeighborQuery
}

//! Get the number of reference points
const unsigned int getNPoints() const
unsigned int getNPoints() const
{
return m_n_points;
}
Expand All @@ -141,17 +149,50 @@ class NeighborQuery
}

protected:
//! Validate the combination of specified arguments.
/*! Before checking if the combination of parameters currently set is
* valid, this function first attempts to infer a mode if one is not set in
* order to allow the user to specify certain simple minimal argument
* combinations (e.g. just an r_max) without having to specify the mode
* explicitly.
*/
virtual void validateQueryArgs(QueryArgs& args) const
{
inferMode(args);
// Validate remaining arguments.
if (args.mode == QueryArgs::ball)
{
if (args.rmax == -1)
throw std::runtime_error("You must set rmax in the query arguments.");
if (args.r_max == QueryArgs::DEFAULT_R_MAX)
throw std::runtime_error("You must set r_max in the query arguments when performing ball queries.");
if (args.num_neighbors != QueryArgs::DEFAULT_NUM_NEIGHBORS)
throw std::runtime_error("You cannot set num_neighbors in the query arguments when performing ball queries.");
}
else if (args.mode == QueryArgs::nearest)
{
if (args.nn == -1)
throw std::runtime_error("You must set nn in the query arguments.");
if (args.num_neighbors == QueryArgs::DEFAULT_NUM_NEIGHBORS)
throw std::runtime_error("You must set num_neighbors in the query arguments when performing number of neighbor queries.");
}
}

//! Try to determine the query mode if one is not specified.
/*! If no mode is specified and a number of neighbors is specified, the
* query mode must be a nearest neighbors query (all other arguments can
* reasonably modify that query). Otherwise, if a max distance is set we
* can assume a ball query is desired.
*/
virtual void inferMode(QueryArgs& args) const
{
// Infer mode if possible.
if (args.mode == QueryArgs::none)
{
if (args.num_neighbors != QueryArgs::DEFAULT_NUM_NEIGHBORS)
{
args.mode = QueryArgs::nearest;
}
else if (args.r_max != QueryArgs::DEFAULT_R_MAX)
{
args.mode = QueryArgs::ball;
}
}
}

Expand Down
5 changes: 3 additions & 2 deletions freud/_locality.pxd
Expand Up @@ -19,13 +19,14 @@ cdef extern from "NeighborBond.h" namespace "freud::locality":
cdef extern from "NeighborQuery.h" namespace "freud::locality":

ctypedef enum QueryType "freud::locality::QueryArgs::QueryType":
none "freud::locality::QueryArgs::QueryType::none"
ball "freud::locality::QueryArgs::QueryType::ball"
nearest "freud::locality::QueryArgs::QueryType::nearest"

cdef cppclass QueryArgs:
QueryType mode
int nn
float rmax
int num_neighbors
float r_max
float scale
bool exclude_ii

Expand Down
3 changes: 3 additions & 0 deletions freud/common.pxd
@@ -1,2 +1,5 @@
cdef class Compute:
cdef public _called_compute

cdef class PairCompute(Compute):
pass
89 changes: 87 additions & 2 deletions freud/common.pyx
Expand Up @@ -8,9 +8,17 @@ import freud.box

from functools import wraps

cimport freud.box
cimport freud.locality

cdef class Compute:
R"""Parent class implementing functions to prevent access of
uncomputed values.
R"""Parent class for all compute classes in freud.

Currently, the primary purpose of this class is implementing functions to
prevent access of uncomputed values. This is accomplished by maintaining a
dictionary of compute functions in a class that have been called and
decorating class properties with the names of the compute function that
must be called to populate that property.

To use this class, one would do, for example,

Expand Down Expand Up @@ -125,6 +133,83 @@ cdef class Compute:
return wrapper


cdef class PairCompute(Compute):
R"""Parent class for all compute classes in freud that depend on finding
nearest neighbors.
bdice marked this conversation as resolved.
Show resolved Hide resolved

The purpose of this class is to consolidate some of the logic for parsing
the numerous possible inputs to the compute calls of such classes. In
particular, this class contains a helper function that calls the necessary
functions to create NeighborQuery and NeighborList classes as needed, as
well as dealing with boxes and query arguments.

.. moduleauthor:: Vyas Ramasubramani <vramasub@umich.edu>
"""

def preprocess_arguments(self, box, points, query_points=None, nlist=None,
query_args=None):
"""Process standard compute arguments into freud's internal types by
calling all the required internal functions.

This function handles the preprocessing of boxes and points into
:class:`freud.locality.NeighborQuery` objects, the determination of how
to handle the NeighborList object, the creation of default query
arguments as needed, deciding what `query_points` are, and setting the
appropriate `exclude_ii` flag.

Args:
box (:class:`freud.box.Box`):
Simulation box.
points ((:math:`N_{points}`, 3) :class:`numpy.ndarray`):
Reference points used to calculate the RDF.
query_points ((:math:`N_{query_points}`, 3) :class:`numpy.ndarray`, optional):
Points used to calculate the RDF. Uses :code:`points` if
not provided or :code:`None`.
nlist (:class:`freud.locality.NeighborList`, optional):
NeighborList to use to find bonds (Default value =
:code:`None`).
query_args (dict): A dictionary of query arguments (Default value =
:code:`None`).
""" # noqa E501
cdef freud.box.Box b = freud.common.convert_box(box)

cdef freud.locality.NeighborQuery nq = freud.locality.make_default_nq(
box, points)
vyasr marked this conversation as resolved.
Show resolved Hide resolved
cdef freud.locality.NlistptrWrapper nlistptr = \
freud.locality.NlistptrWrapper(nlist)

cdef freud.locality._QueryArgs qargs
if query_args is not None:
qargs = freud.locality._QueryArgs.from_dict(query_args)
else:
try:
qargs = freud.locality._QueryArgs.from_dict(
self.default_query_args)
qargs.update({'exclude_ii': query_points is None})
except ValueError:
# If a NeighborList was provided, then the user need not
# provide QueryArgs.
if nlist is None:
raise
else:
qargs = freud.locality._QueryArgs()

if query_points is None:
query_points = nq.points
query_points = freud.common.convert_array(
query_points, shape=(None, 3))
cdef const float[:, ::1] l_query_points = query_points
cdef unsigned int num_query_points = l_query_points.shape[0]
return (b, nq, nlistptr, qargs, l_query_points, num_query_points)

@property
def default_query_args(self):
raise ValueError(
"The {} class does not provide default query arguments. You must "
"either provide query arguments or a neighbor list to this "
"compute method.".format(type(self).__name__))


def convert_array(array, shape=None, dtype=np.float32):
"""Function which takes a given array, checks the dimensions and shape,
and converts to a supplied dtype.
Expand Down