From 9ffbed8e03b5a758ac747722e5d9dc1b1dc7dd53 Mon Sep 17 00:00:00 2001 From: Cosmin Boaca Date: Sun, 20 Dec 2015 00:40:46 +0200 Subject: [PATCH] Reworked lower_bound implementation witout stack Implemented upper_bound in terms of lower_bound Changed lower_bound an upper_bound return type to iterator instead of node_ptr --- .gitignore | 1 + include/boost/trie/trie.hpp | 114 +++++++++++++++--------------------- test/set.cpp | 35 +++++++++++ 3 files changed, 84 insertions(+), 66 deletions(-) diff --git a/.gitignore b/.gitignore index b99099b..3b73bad 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ tags *.obj *.exe Jamroot +*a.out* libs/ build/ diff --git a/include/boost/trie/trie.hpp b/include/boost/trie/trie.hpp index 7c21128..8fcb83e 100644 --- a/include/boost/trie/trie.hpp +++ b/include/boost/trie/trie.hpp @@ -526,93 +526,75 @@ class trie { // upper_bound() to find the first node that greater than the key template - node_ptr upper_bound(Iter first, Iter last) + iterator upper_bound(Iter first, Iter last) { - node_ptr cur = const_cast(&root); - // use a stack to store iterator in order to avoid the iterator cannot go backward - std::stack< Iter > si; - for (; first != last; ++first) - { - si.push(first); - const key_type& cur_key = *first; - typename node_type::children_iter ci = cur->children.find(cur_key, node_comparator); - // using upper_bound needs comparison in every step, so using find until ci == NULL - if (ci == cur->children.end()) - { - // find a node that - ci = cur->children.upper_bound(cur_key, node_comparator); - si.pop(); - while (ci == cur->children.end()) - { - if (cur->parent == NULL) - return &root; - cur = cur->parent; - ci = cur->children.upper_bound(*si.top()); - } - cur = ci->second; - while (cur->no_value()) - { - cur = &(*(cur->children.begin())); - } - return cur; - } - cur = &(*ci); + std::pair lb_result = lower_bound(first, last); + // Full match + if (lb_result.second) { + ++lb_result.first; } - // if find a full match, then increment it - iterator tmp(cur); - tmp.trie_node_increment(); - cur = tmp.tnode; - return cur; + return lb_result.first; } template - node_ptr upper_bound(const Container &container) + iterator upper_bound(const Container &container) { return upper_bound(container.begin(), container.end()); } // lower_bound() template - node_ptr lower_bound(Iter first, Iter last) + std::pair lower_bound(Iter first, Iter last) { - node_ptr cur = const_cast(&root); - // use a stack to store iterator in order to avoid the iterator cannot go backward - std::stack< Iter > si; + typedef typename node_type::children_iter children_iterator; + node_ptr cur = &root; + node_ptr last_lb_candidate = NULL; + for (; first != last; ++first) { - si.push(first); const key_type& cur_key = *first; - typename node_type::children_iter ci = cur->children.find(cur_key); - // using upper_bound needs comparison in every step, so using find until ci == NULL - if (ci == cur->children.end()) - { - // find a node that - ci = cur->children.upper_bound(cur_key); - si.pop(); - while (ci == cur->children.end()) - { - if (cur->parent == NULL) - return &root; - cur = cur->parent; - ci = cur->children.upper_bound(*si.top()); - } - cur = ci->second; - while (cur->no_value()) - { - cur = &(*(cur->children.begin())); + children_iterator child_iter = + cur->children.find(cur_key, node_comparator); + if (child_iter == cur->children.end()) { + break; + } + children_iterator lb_candidate_iter = child_iter; + lb_candidate_iter++; + if (lb_candidate_iter != cur->children.end()) { + last_lb_candidate = &(*lb_candidate_iter); + } + cur = &(*child_iter); + } + + if (first != last) { + children_iterator lb_candidate_iter = + cur->children.upper_bound(*first, node_comparator); + if (lb_candidate_iter == cur->children.end()) { + if (last_lb_candidate != NULL) { + cur = last_lb_candidate; + } else { + return std::make_pair(&root, false); } - return cur; + } else { + return std::make_pair(&(*lb_candidate_iter), false); } - cur = &(*ci); } - // lower_bound() needn't increment here!!! - return cur; + + if (!cur->no_value() && first == last) { + return std::make_pair(cur, true); + } + + while (cur->no_value()) { + cur = &(*cur->children.begin()); + } + + return std::make_pair(cur, false); } template - node_ptr lower_bound(const Container &container) + iterator lower_bound(const Container &container) { - return lower_bound(container.begin(), container.end()); + return lower_bound(container.begin(), container.end()).first; } // equal_range() is the same as find_prefix? the meaning is different @@ -778,7 +760,7 @@ class trie { erase_node(cur); } - void swap(const trie_type& t) + void swap(trie_type& t) { // is it OK? std::swap(root, t.root); diff --git a/test/set.cpp b/test/set.cpp index 60ff30f..ea02263 100644 --- a/test/set.cpp +++ b/test/set.cpp @@ -184,11 +184,46 @@ void iterator_operator_minus() BOOST_TEST(riter == t.rbegin()); } +void lower_bound_test() { + tsci t; + std::string s1 = "aaa", s2 = "aab", s3 = "abc"; + t.insert(s1); + t.insert(s2); + t.insert(s3); + BOOST_TEST(t.lower_bound(s1) == t.find(s1)); + BOOST_TEST(t.lower_bound(std::string("abb")) == t.find(std::string("abc"))); + BOOST_TEST(t.lower_bound(std::string("b")) == t.end()); + t.insert(std::string("abcdef")); + BOOST_TEST(t.lower_bound(std::string("abcd")) == t.find(std::string("abcdef"))); + t.insert(std::string("bbcccc")); + t.insert(std::string("bbd")); + BOOST_TEST(t.lower_bound(std::string("bbcccd")) == t.find(std::string("bbd"))); +} + +void upper_bound_test() { + tsci t; + std::string s1 = "aaa", s2 = "aab", s3 = "abc"; + t.insert(s1); + t.insert(s2); + t.insert(s3); + BOOST_TEST(t.upper_bound(std::string("abb")) == t.find(std::string("abc"))); + BOOST_TEST(t.upper_bound(std::string("b")) == t.end()); + BOOST_TEST(t.upper_bound(s3) == t.end()); + BOOST_TEST(t.upper_bound(s1) == t.find(s2)); + t.insert(std::string("abcdef")); + BOOST_TEST(t.upper_bound(std::string("abcd")) == t.find(std::string("abcdef"))); + t.insert(std::string("bbcccc")); + t.insert(std::string("bbd")); + BOOST_TEST(t.upper_bound(std::string("bbcccd")) == t.find(std::string("bbd"))); +} + int main() { insert_erase_test(); insert_find_test(); copy_test(); iterator_operator_plus(); iterator_operator_minus(); + lower_bound_test(); + upper_bound_test(); return boost::report_errors(); }