diff --git a/src/mlpack/core/tree/ballbound.hpp b/src/mlpack/core/tree/ballbound.hpp index 68a16d759dd..14f289a14ee 100644 --- a/src/mlpack/core/tree/ballbound.hpp +++ b/src/mlpack/core/tree/ballbound.hpp @@ -16,14 +16,14 @@ namespace bound { /** * Ball bound encloses a set of points at a specific distance (radius) from a - * specific point (center). TMetricType is the custom metric type that defaults + * specific point (center). MetricType is the custom metric type that defaults * to the Euclidean (L2) distance. * + * @tparam MetricType metric type used in the distance measure. * @tparam VecType Type of vector (arma::vec or arma::sp_vec or similar). - * @tparam TMetricType metric type used in the distance measure. */ -template> +template, + typename VecType = arma::vec> class BallBound { public: @@ -31,8 +31,6 @@ class BallBound typedef typename VecType::elem_type ElemType; //! A public version of the vector type. typedef VecType Vec; - //! Needed for BinarySpaceTree. - typedef TMetricType MetricType; private: //! The radius of the ball bound. @@ -40,7 +38,7 @@ class BallBound //! The center of the ball bound. VecType center; //! The metric used in this bound. - TMetricType* metric; + MetricType* metric; /** * To know whether this object allocated memory to the metric member @@ -179,9 +177,9 @@ class BallBound ElemType Diameter() const { return 2 * radius; } //! Returns the distance metric used in this bound. - const TMetricType& Metric() const { return *metric; } + const MetricType& Metric() const { return *metric; } //! Modify the distance metric used in this bound. - TMetricType& Metric() { return *metric; } + MetricType& Metric() { return *metric; } //! Serialize the bound. template @@ -189,8 +187,8 @@ class BallBound }; //! A specialization of BoundTraits for this bound type. -template -struct BoundTraits> +template +struct BoundTraits> { //! These bounds are potentially loose in some dimensions. const static bool HasTightBounds = false; diff --git a/src/mlpack/core/tree/ballbound_impl.hpp b/src/mlpack/core/tree/ballbound_impl.hpp index 8e0e658fab5..885acb5a8e0 100644 --- a/src/mlpack/core/tree/ballbound_impl.hpp +++ b/src/mlpack/core/tree/ballbound_impl.hpp @@ -18,10 +18,10 @@ namespace mlpack { namespace bound { //! Empty Constructor. -template -BallBound::BallBound() : +template +BallBound::BallBound() : radius(std::numeric_limits::lowest()), - metric(new TMetricType()), + metric(new MetricType()), ownsMetric(true) { /* Nothing to do. */ } @@ -30,11 +30,11 @@ BallBound::BallBound() : * * @param dimension Dimensionality of ball bound. */ -template -BallBound::BallBound(const size_t dimension) : +template +BallBound::BallBound(const size_t dimension) : radius(std::numeric_limits::lowest()), center(dimension), - metric(new TMetricType()), + metric(new MetricType()), ownsMetric(true) { /* Nothing to do. */ } @@ -44,18 +44,18 @@ BallBound::BallBound(const size_t dimension) : * @param radius Radius of ball bound. * @param center Center of ball bound. */ -template -BallBound::BallBound(const ElemType radius, +template +BallBound::BallBound(const ElemType radius, const VecType& center) : radius(radius), center(center), - metric(new TMetricType()), + metric(new MetricType()), ownsMetric(true) { /* Nothing to do. */ } //! Copy Constructor. To prevent memory leaks. -template -BallBound::BallBound(const BallBound& other) : +template +BallBound::BallBound(const BallBound& other) : radius(other.radius), center(other.center), metric(other.metric), @@ -63,8 +63,8 @@ BallBound::BallBound(const BallBound& other) : { /* Nothing to do. */ } //! For the same reason as the copy constructor: to prevent memory leaks. -template -BallBound& BallBound::operator=( +template +BallBound& BallBound::operator=( const BallBound& other) { radius = other.radius; @@ -74,8 +74,8 @@ BallBound& BallBound::operator=( } //! Move constructor. -template -BallBound::BallBound(BallBound&& other) : +template +BallBound::BallBound(BallBound&& other) : radius(other.radius), center(other.center), metric(other.metric), @@ -89,17 +89,17 @@ BallBound::BallBound(BallBound&& other) : } //! Destructor to release allocated memory. -template -BallBound::~BallBound() +template +BallBound::~BallBound() { if (ownsMetric) delete metric; } //! Get the range in a certain dimension. -template -math::RangeType::ElemType> -BallBound::operator[](const size_t i) const +template +math::RangeType::ElemType> +BallBound::operator[](const size_t i) const { if (radius < 0) return math::Range(); @@ -110,8 +110,8 @@ BallBound::operator[](const size_t i) const /** * Determines if a point is within the bound. */ -template -bool BallBound::Contains(const VecType& point) const +template +bool BallBound::Contains(const VecType& point) const { if (radius < 0) return false; @@ -122,10 +122,10 @@ bool BallBound::Contains(const VecType& point) const /** * Calculates minimum bound-to-point squared distance. */ -template +template template -typename BallBound::ElemType -BallBound::MinDistance( +typename BallBound::ElemType +BallBound::MinDistance( const OtherVecType& point, typename boost::enable_if>* /* junk */) const { @@ -138,9 +138,9 @@ BallBound::MinDistance( /** * Calculates minimum bound-to-bound squared distance. */ -template -typename BallBound::ElemType -BallBound::MinDistance(const BallBound& other) +template +typename BallBound::ElemType +BallBound::MinDistance(const BallBound& other) const { if (radius < 0) @@ -156,10 +156,10 @@ BallBound::MinDistance(const BallBound& other) /** * Computes maximum distance. */ -template +template template -typename BallBound::ElemType -BallBound::MaxDistance( +typename BallBound::ElemType +BallBound::MaxDistance( const OtherVecType& point, typename boost::enable_if >* /* junk */) const { @@ -172,9 +172,9 @@ BallBound::MaxDistance( /** * Computes maximum distance. */ -template -typename BallBound::ElemType -BallBound::MaxDistance(const BallBound& other) +template +typename BallBound::ElemType +BallBound::MaxDistance(const BallBound& other) const { if (radius < 0) @@ -188,10 +188,10 @@ BallBound::MaxDistance(const BallBound& other) * * Example: bound1.MinDistanceSq(other) for minimum squared distance. */ -template +template template -math::RangeType::ElemType> -BallBound::RangeDistance( +math::RangeType::ElemType> +BallBound::RangeDistance( const OtherVecType& point, typename boost::enable_if >* /* junk */) const { @@ -206,9 +206,9 @@ BallBound::RangeDistance( } } -template -math::RangeType::ElemType> -BallBound::RangeDistance( +template +math::RangeType::ElemType> +BallBound::RangeDistance( const BallBound& other) const { if (radius < 0) @@ -226,9 +226,9 @@ BallBound::RangeDistance( /** * Expand the bound to include the given bound. * -template +template const BallBound& -BallBound::operator|=( +BallBound::operator|=( const BallBound& other) { double dist = metric->Evaluate(center, other); @@ -246,10 +246,10 @@ BallBound::operator|=( * The difference lies in the way we initialize the ball bound. The way we * expand the bound is same. */ -template +template template -const BallBound& -BallBound::operator|=(const MatType& data) +const BallBound& +BallBound::operator|=(const MatType& data) { if (radius < 0) { @@ -277,9 +277,9 @@ BallBound::operator|=(const MatType& data) } //! Serialize the BallBound. -template +template template -void BallBound::Serialize( +void BallBound::Serialize( Archive& ar, const unsigned int /* version */) { diff --git a/src/mlpack/core/tree/binary_space_tree/typedef.hpp b/src/mlpack/core/tree/binary_space_tree/typedef.hpp index 7d58f6750e7..28145d11bb2 100644 --- a/src/mlpack/core/tree/binary_space_tree/typedef.hpp +++ b/src/mlpack/core/tree/binary_space_tree/typedef.hpp @@ -103,7 +103,7 @@ template using BallTree = BinarySpaceTree; /** @@ -132,7 +132,7 @@ template using MeanSplitBallTree = BinarySpaceTree; } // namespace tree diff --git a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp index 28c5a0bf8f0..e4aa7c179fe 100644 --- a/src/mlpack/methods/neighbor_search/ns_model_impl.hpp +++ b/src/mlpack/methods/neighbor_search/ns_model_impl.hpp @@ -320,7 +320,7 @@ void NSModel::BuildModel(arma::mat&& referenceSet, { std::vector oldFromNewReferences; typename NSType::Tree* ballTree = - new typename NSType::Tree(std::move(referenceSet), + new typename NSType::Tree(std::move(referenceSet), oldFromNewReferences, leafSize); ballTreeNS = new NSType(ballTree, singleMode); diff --git a/src/mlpack/tests/serialization_test.cpp b/src/mlpack/tests/serialization_test.cpp index 9bddbc2c16d..73a960cf2fe 100644 --- a/src/mlpack/tests/serialization_test.cpp +++ b/src/mlpack/tests/serialization_test.cpp @@ -354,12 +354,12 @@ BOOST_AUTO_TEST_CASE(BallBoundTest) BOOST_AUTO_TEST_CASE(MahalanobisBallBoundTest) { - BallBound> b(100); + BallBound, arma::vec> b(100); b.Center().randu(); b.Radius() = 14.0; b.Metric().Covariance().randu(100, 100); - BallBound> xmlB, textB, binaryB; + BallBound, arma::vec> xmlB, textB, binaryB; SerializeObjectAll(b, xmlB, textB, binaryB);