Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Test for different distances

  • Loading branch information...
commit 7fffd82b1c29e25e2a472018acfd6044ccb1a889 1 parent c80c154
@mbrucher authored
Showing with 48 additions and 26 deletions.
  1. +10 −7 kdtree/kdtree.h
  2. +38 −19 kdtree/test_kdtree.cpp
View
17 kdtree/kdtree.h
@@ -119,9 +119,10 @@ namespace Search
boost::scoped_ptr<MyNode> root;
- bool test_distance_too_great(MyNode* node, const ContainerType& point, double max_dist) const
+ template<class Distance2>
+ bool test_distance_too_great(MyNode* node, const ContainerType& point, double max_dist, const Distance2& new_distance) const
{
- return (max_dist - distance(node->middle, node->maxpoint) < distance(point, node->middle));
+ return (max_dist - new_distance(node->middle, node->maxpoint) < new_distance(point, node->middle));
}
public:
@@ -150,7 +151,9 @@ namespace Search
typedef std::multimap<double, MyNode*> NodeContainer;
typedef std::multimap<double, ContainerType> MapContainer;
- std::vector<ContainerType> knn(const ContainerType& point, unsigned long k) const
+
+ template<class Distance2>
+ std::vector<ContainerType> knn(const ContainerType& point, unsigned long k, const Distance2& new_distance) const
{
std::vector<ContainerType> result;
@@ -160,7 +163,7 @@ namespace Search
}
NodeContainer nodes;
- nodes.insert(std::make_pair(distance(point, root->middle), root.get()));
+ nodes.insert(std::make_pair(new_distance(point, root->middle), root.get()));
MapContainer points;
while (!nodes.empty())
@@ -175,17 +178,17 @@ namespace Search
MyNode* current_node = nodes.begin()->second;
nodes.erase(nodes.begin());
- if (points.size() > k && test_distance_too_great(current_node, point, points.rbegin()->first))
+ if (points.size() > k && test_distance_too_great(current_node, point, points.rbegin()->first, new_distance))
{
continue;
}
if(current_node->children.empty())
{
- current_node->add_nodes(point, distance, nodes);
+ current_node->add_nodes(point, new_distance, nodes);
}
else
{
- current_node->add_children(point, distance, points);
+ current_node->add_children(point, new_distance, points);
}
}
View
57 kdtree/test_kdtree.cpp
@@ -29,6 +29,16 @@ float euclidian(const Point& p1, const Point& p2)
return std::sqrt(dist);
}
+float absolute(const Point& p1, const Point& p2)
+{
+ float dist = 0;
+ for(unsigned long i = 0; i < p1.size(); ++i)
+ {
+ dist += std::abs(p1[i] - p2[i]);
+ }
+ return dist;
+}
+
PointContainer generate(long size)
{
PointContainer data;
@@ -50,12 +60,13 @@ PointContainer generate(long size)
return data;
}
-PointContainer knn(const PointContainer& container, const Point& data, int k)
+template<class Distance>
+PointContainer knn(const PointContainer& container, const Point& data, int k, const Distance& distance)
{
std::multimap<float, Point> point_map;
for(PointContainer::const_iterator it = container.begin(); it != container.end(); ++it)
{
- point_map.insert(std::make_pair(euclidian(data, *it), *it));
+ point_map.insert(std::make_pair(distance(data, *it), *it));
}
PointContainer result;
@@ -78,6 +89,29 @@ std::ostream& operator<<(std::ostream& stream, const Point& data)
return stream;
}
+template<class Point, class Data, class Tree, class Distance>
+void compare_searches(const Point& point, const Data& data, const Tree& tree, const Distance& distance)
+{
+ boost::posix_time::ptime time = boost::posix_time::microsec_clock::local_time();
+ PointContainer result = knn(data, point, KNNSIZE, distance);
+ std::cout << "Out time (linear) " << (boost::posix_time::microsec_clock::local_time() - time) << std::endl;
+
+ time = boost::posix_time::microsec_clock::local_time();
+ PointContainer result2 = tree.knn(point, KNNSIZE, distance);
+ std::cout << "Out time (kdtree) " << (boost::posix_time::microsec_clock::local_time() - time) << std::endl;
+
+ bool boolean = true;
+ for(int i = 0; i < std::min(result.size(), result2.size()); ++i)
+ {
+ for(int j = 0; j < POINTSIZE; ++j)
+ {
+ boolean = boolean && (result[i][j] == result2[i][j]);
+ }
+ std::cout << i << std::endl << result[i] << result2[i];
+ }
+ std::cout << "Result " << (boolean && (result.size() == result2.size())) << std::endl;
+}
+
int main(int argc, char** argv)
{
PointContainer data = generate(VECTORLENGTH);
@@ -100,23 +134,8 @@ int main(int argc, char** argv)
// tree.dump(stream);
Point zero(2, 0.f);
- time = boost::posix_time::microsec_clock::local_time();
- PointContainer result = knn(data, zero, KNNSIZE);
- std::cout << "Out time (linear) " << (boost::posix_time::microsec_clock::local_time() - time) << std::endl;
-
- time = boost::posix_time::microsec_clock::local_time();
- PointContainer result2 = tree.knn(zero, KNNSIZE);
- std::cout << "Out time (kdtree) " << (boost::posix_time::microsec_clock::local_time() - time) << std::endl;
+ compare_searches(zero, data, tree, euclidian);
+ compare_searches(zero, data, tree, absolute);
- bool boolean = true;
- for(int i = 0; i < std::min(result.size(), result2.size()); ++i)
- {
- for(int j = 0; j < POINTSIZE; ++j)
- {
- boolean = boolean && (result[i][j] == result2[i][j]);
- }
- std::cout << i << std::endl << result[i] << result2[i];
- }
- std::cout << "Result " << (boolean && (result.size() == result2.size())) << std::endl;
return EXIT_SUCCESS;
}
Please sign in to comment.
Something went wrong with that request. Please try again.