diff --git a/src/mlpack/core/metrics/lmetric.hpp b/src/mlpack/core/metrics/lmetric.hpp
index a84b2c610e1..6f350bf9bed 100644
--- a/src/mlpack/core/metrics/lmetric.hpp
+++ b/src/mlpack/core/metrics/lmetric.hpp
@@ -83,6 +83,12 @@ class LMetric
   static typename VecTypeA::elem_type Evaluate(const VecTypeA& a,
                                                const VecTypeB& b);
 
+  //! Evaluate the metric with a bound that permits stopping early.
+  template<typename VecTypeA, typename VecTypeB>
+  static typename VecTypeA::elem_type Evaluate(
+      const VecTypeA& a, const VecTypeB& b,
+      typename VecTypeA::elem_type bound);
+
   //! Serialize the metric (nothing to do).
   template<typename Archive>
   void Serialize(Archive& /* ar */, const unsigned int /* version */) { }
diff --git a/src/mlpack/core/metrics/lmetric_impl.hpp b/src/mlpack/core/metrics/lmetric_impl.hpp
index 4128ae9bdfd..3398cbf2650 100644
--- a/src/mlpack/core/metrics/lmetric_impl.hpp
+++ b/src/mlpack/core/metrics/lmetric_impl.hpp
@@ -14,7 +14,6 @@
 // In case it hasn't been included.
 #include "lmetric.hpp"
 
-
 namespace mlpack {
 namespace metric {
 
@@ -73,6 +72,46 @@ typename VecTypeA::elem_type LMetric<2, false>::Evaluate(
   return accu(arma::square(a - b));
 }
 
+// Reorder a and b so that coordinates with the largest |a_i - b_i| come first.
+template<typename VecTypeA, typename VecTypeB>
+void RotateVector(const VecTypeA& a, const VecTypeB& b,
+                  VecTypeA& vecAout, VecTypeB& vecBout)
+{
+  // sort_index() returns unsigned indices, so they are held in a uvec.
+  const arma::uvec order = arma::sort_index(arma::abs(a - b), "descend");
+  // Gather from dense copies in the element type; an unsigned integer
+  // matrix would truncate floating-point coordinates.
+  const arma::Col<typename VecTypeA::elem_type> aDense(a);
+  const arma::Col<typename VecTypeB::elem_type> bDense(b);
+  vecAout = aDense.elem(order);
+  vecBout = bDense.elem(order);
+}
+
+// L2-metric specialization that stops accumulating once a bound is reached.
+template<>
+template<typename VecTypeA, typename VecTypeB>
+typename VecTypeA::elem_type LMetric<2, true>::Evaluate(
+    const VecTypeA& a,
+    const VecTypeB& b,
+    typename VecTypeA::elem_type bound)
+{
+  // Returns the Euclidean distance of a and b, or a partial value that has
+  // already reached bound when the loop is cut short.  Squaring would turn a
+  // non-positive bound into a bogus threshold, so it disables the early exit.
+  typedef typename VecTypeA::elem_type ElemType;
+  const ElemType boundSquared = bound * bound;
+  ElemType sum = 0;
+  // Squared terms are non-negative, so no reordering is needed to stop early.
+  for (size_t i = 0; i < a.n_elem; ++i)
+  {
+    const ElemType diff = a[i] - b[i];
+    sum += diff * diff;
+    if (bound > 0 && sum >= boundSquared)
+      break;
+  }
+  return std::sqrt(sum);  // <2, true> is the root-taking variant.
+}
+
 // L3-metric specialization (not very likely to be used, but just in case).
 template<>
 template<typename VecTypeA, typename VecTypeB>
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree.hpp b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
index 7065c1313f5..790bb578c51 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree.hpp
@@ -356,6 +356,9 @@ class CoverTree
   //! Return the maximum distance to another node.
   ElemType MaxDistance(const CoverTree& other) const;
 
+  //! Like MaxDistance(), but hands the metric a bound based on bestDistance.
+  ElemType MaxDistanceNew(const CoverTree& other, const ElemType bestDistance);
+
   //! Return the maximum distance to another node given that the point-to-point
   //! distance has already been calculated.
   ElemType MaxDistance(const CoverTree& other, const ElemType distance) const;
@@ -363,6 +366,9 @@ class CoverTree
   //! Return the maximum distance to another point.
   ElemType MaxDistance(const arma::vec& other) const;
 
+  //! Like MaxDistance(), but hands the metric a bound based on bestDistance.
+  ElemType MaxDistanceNew(const arma::vec& other, const ElemType bestDistance);
+
   //! Return the maximum distance to another point given that the distance from
   //! the center to the point has already been calculated.
ElemType MaxDistance(const arma::vec& other, const ElemType distance) const;
diff --git a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
index fbcdcbe1b30..87de5e0720d 100644
--- a/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
+++ b/src/mlpack/core/tree/cover_tree/cover_tree_impl.hpp
@@ -701,7 +701,7 @@ size_t CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
   size_t bestIndex = 0;
   for (size_t i = 0; i < children.size(); ++i)
   {
-    ElemType distance = children[i]->MaxDistance(point);
+    ElemType distance = children[i]->MaxDistanceNew(point, bestDistance);
     if (distance >= bestDistance)
     {
       bestDistance = distance;
@@ -757,7 +757,7 @@ size_t CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
   size_t bestIndex = 0;
   for (size_t i = 0; i < children.size(); ++i)
   {
-    ElemType distance = children[i]->MaxDistance(queryNode);
+    ElemType distance = children[i]->MaxDistanceNew(queryNode, bestDistance);
     if (distance >= bestDistance)
     {
       bestDistance = distance;
@@ -845,6 +845,36 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
       furthestDescendantDistance + other.FurthestDescendantDistance();
 }
 
+
+// Variant of MaxDistance(const CoverTree&) that passes a bound derived from
+// bestDistance to the metric so the point-to-point evaluation may stop early.
+// NOTE(review): a value cut short at the bound only bounds the true maximum
+// distance from below, so callers storing it as the new bestDistance may not
+// select the exact furthest child -- confirm that this is acceptable.
+// NOTE(review): this patch defines the three-argument Evaluate() only for
+// LMetric<2, true>; other MetricType choices would need their own version.
+template<
+    typename MetricType,
+    typename StatisticType,
+    typename MatType,
+    typename RootPointPolicy
+>
+typename CoverTree<MetricType, StatisticType, MatType,
+    RootPointPolicy>::ElemType
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
+    MaxDistanceNew(const CoverTree& other, const ElemType bestDistance)
+{
+  // The furthest-descendant distances of both nodes are part of the maximum
+  // distance (the preceding function's return adds the same terms), so they
+  // are removed from the bound and added back to the result, which callers
+  // compare against bestDistance.
+  const ElemType descendantSum =
+      furthestDescendantDistance + other.FurthestDescendantDistance();
+  return metric->Evaluate(dataset->col(point),
+                          other.Dataset().col(other.Point()),
+                          bestDistance - descendantSum) + descendantSum;
+}
+
 template<
     typename MetricType,
     typename StatisticType,
@@ -876,6 +906,30 @@ CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
       furthestDescendantDistance;
 }
 
+
+// Variant of MaxDistance(const arma::vec&) that passes a bound derived from
+// bestDistance to the metric so the point-to-point evaluation may stop early.
+// NOTE(review): a value cut short at the bound is only a lower bound on the
+// true maximum distance -- confirm that callers tolerate this.
+template<
+    typename MetricType,
+    typename StatisticType,
+    typename MatType,
+    typename RootPointPolicy
+>
+typename CoverTree<MetricType, StatisticType, MatType,
+    RootPointPolicy>::ElemType
+CoverTree<MetricType, StatisticType, MatType, RootPointPolicy>::
+    MaxDistanceNew(const arma::vec& other, const ElemType bestDistance)
+{
+  // This node's furthest-descendant distance is part of the maximum distance,
+  // so it is removed from the bound and added back to the returned value.
+  return metric->Evaluate(dataset->col(point),
+                          other,
+                          bestDistance - furthestDescendantDistance) +
+      furthestDescendantDistance;
+}
+
 template<
     typename MetricType,
     typename StatisticType,