| 
 | 1 | +//===-- llvm/ADT/RadixTree.h - Radix Tree implementation --------*- C++ -*-===//  | 
 | 2 | +//  | 
 | 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.  | 
 | 4 | +// See https://llvm.org/LICENSE.txt for license information.  | 
 | 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception  | 
 | 6 | +//===----------------------------------------------------------------------===//  | 
 | 7 | +//  | 
 | 8 | +// This file implements a Radix Tree.  | 
 | 9 | +//  | 
 | 10 | +//===----------------------------------------------------------------------===//  | 
 | 11 | + | 
 | 12 | +#ifndef LLVM_ADT_RADIXTREE_H  | 
 | 13 | +#define LLVM_ADT_RADIXTREE_H  | 
 | 14 | + | 
 | 15 | +#include "llvm/ADT/ADL.h"  | 
 | 16 | +#include "llvm/ADT/STLExtras.h"  | 
 | 17 | +#include "llvm/ADT/iterator.h"  | 
 | 18 | +#include "llvm/ADT/iterator_range.h"  | 
 | 19 | +#include <cassert>  | 
 | 20 | +#include <cstddef>  | 
 | 21 | +#include <iterator>  | 
 | 22 | +#include <limits>  | 
 | 23 | +#include <list>  | 
 | 24 | +#include <utility>  | 
 | 25 | + | 
 | 26 | +namespace llvm {  | 
 | 27 | + | 
 | 28 | +/// \brief A Radix Tree implementation.  | 
 | 29 | +///  | 
 | 30 | +/// A Radix Tree (also known as a compact prefix tree or radix trie) is a  | 
 | 31 | +/// data structure that stores a dynamic set or associative array where keys  | 
 | 32 | +/// are strings and values are associated with these keys. Unlike a regular  | 
 | 33 | +/// trie, the edges of a radix tree can be labeled with sequences of characters  | 
 | 34 | +/// as well as single characters. This makes radix trees more efficient for  | 
 | 35 | +/// storing sparse data sets, where many nodes in a regular trie would have  | 
 | 36 | +/// only one child.  | 
 | 37 | +///  | 
 | 38 | +/// This implementation supports arbitrary key types that can be iterated over  | 
 | 39 | +/// (e.g., `std::string`, `std::vector<char>`, `ArrayRef<char>`). The key type  | 
 | 40 | +/// must provide `begin()` and `end()` for iteration.  | 
 | 41 | +///  | 
 | 42 | +/// The tree stores `std::pair<const KeyType, T>` as its value type.  | 
 | 43 | +///  | 
 | 44 | +/// Example usage:  | 
 | 45 | +/// \code  | 
 | 46 | +///   llvm::RadixTree<StringRef, int> Tree;  | 
 | 47 | +///   Tree.emplace("apple", 1);  | 
 | 48 | +///   Tree.emplace("grapefruit", 2);  | 
 | 49 | +///   Tree.emplace("grape", 3);  | 
 | 50 | +///  | 
 | 51 | +///   // Find prefixes  | 
 | 52 | +///   for (const auto &[Key, Value] : Tree.find_prefixes("grapefruit juice")) {  | 
 | 53 | +///     // pair will be {"grape", 3}  | 
 | 54 | +///     // pair will be {"grapefruit", 2}  | 
 | 55 | +///     llvm::outs() << Key << ": " << Value << "\n";  | 
 | 56 | +///   }  | 
 | 57 | +///  | 
 | 58 | +///   // Iterate over all elements  | 
 | 59 | +///   for (const auto &[Key, Value] : Tree)  | 
 | 60 | +///     llvm::outs() << Key << ": " << Value << "\n";  | 
 | 61 | +/// \endcode  | 
 | 62 | +///  | 
 | 63 | +/// \note  | 
 | 64 | +/// The `RadixTree` takes ownership of the `KeyType` and `T` objects  | 
 | 65 | +/// inserted into it. When an element is removed or the tree is destroyed,  | 
 | 66 | +/// these objects will be destructed.  | 
 | 67 | +/// However, if `KeyType` is a reference-like type, e.g., StringRef or range,  | 
 | 68 | +/// the user must guarantee that the referenced data has a lifetime longer than  | 
 | 69 | +/// the tree.  | 
 | 70 | +template <typename KeyType, typename T> class RadixTree {  | 
 | 71 | +public:  | 
 | 72 | +  using key_type = KeyType;  | 
 | 73 | +  using mapped_type = T;  | 
 | 74 | +  using value_type = std::pair<const KeyType, mapped_type>;  | 
 | 75 | + | 
 | 76 | +private:  | 
 | 77 | +  using KeyConstIteratorType =  | 
 | 78 | +      decltype(adl_begin(std::declval<const key_type &>()));  | 
 | 79 | +  using KeyConstIteratorRangeType = iterator_range<KeyConstIteratorType>;  | 
 | 80 | +  using KeyValueType =  | 
 | 81 | +      remove_cvref_t<decltype(*adl_begin(std::declval<key_type &>()))>;  | 
 | 82 | +  using ContainerType = std::list<value_type>;  | 
 | 83 | + | 
 | 84 | +  /// Represents an internal node in the Radix Tree.  | 
 | 85 | +  struct Node {  | 
 | 86 | +    KeyConstIteratorRangeType Key{KeyConstIteratorType{},  | 
 | 87 | +                                  KeyConstIteratorType{}};  | 
 | 88 | +    std::vector<Node> Children;  | 
 | 89 | + | 
 | 90 | +    /// An iterator to the value associated with this node.  | 
 | 91 | +    ///  | 
 | 92 | +    /// If this node does not have a value (i.e., it's an internal node that  | 
 | 93 | +    /// only serves as a path to other values), this iterator will be equal  | 
 | 94 | +    /// to default constructed `ContainerType::iterator()`.  | 
 | 95 | +    typename ContainerType::iterator Value;  | 
 | 96 | + | 
 | 97 | +    /// The first character of the Key. Used for fast child lookup.  | 
 | 98 | +    KeyValueType KeyFront;  | 
 | 99 | + | 
 | 100 | +    Node() = default;  | 
 | 101 | +    Node(const KeyConstIteratorRangeType &Key)  | 
 | 102 | +        : Key(Key), KeyFront(*Key.begin()) {  | 
 | 103 | +      assert(!Key.empty());  | 
 | 104 | +    }  | 
 | 105 | + | 
 | 106 | +    Node(Node &&) = default;  | 
 | 107 | +    Node &operator=(Node &&) = default;  | 
 | 108 | + | 
 | 109 | +    Node(const Node &) = delete;  | 
 | 110 | +    Node &operator=(const Node &) = delete;  | 
 | 111 | + | 
 | 112 | +    const Node *findChild(const KeyConstIteratorRangeType &Key) const {  | 
 | 113 | +      if (Key.empty())  | 
 | 114 | +        return nullptr;  | 
 | 115 | +      for (const Node &Child : Children) {  | 
 | 116 | +        assert(!Child.Key.empty()); // Only root can be empty.  | 
 | 117 | +        if (Child.KeyFront == *Key.begin())  | 
 | 118 | +          return &Child;  | 
 | 119 | +      }  | 
 | 120 | +      return nullptr;  | 
 | 121 | +    }  | 
 | 122 | + | 
 | 123 | +    Node *findChild(const KeyConstIteratorRangeType &Query) {  | 
 | 124 | +      const Node *This = this;  | 
 | 125 | +      return const_cast<Node *>(This->findChild(Query));  | 
 | 126 | +    }  | 
 | 127 | + | 
 | 128 | +    size_t countNodes() const {  | 
 | 129 | +      size_t R = 1;  | 
 | 130 | +      for (const Node &C : Children)  | 
 | 131 | +        R += C.countNodes();  | 
 | 132 | +      return R;  | 
 | 133 | +    }  | 
 | 134 | + | 
 | 135 | +    ///  | 
 | 136 | +    /// Splits the current node into two.  | 
 | 137 | +    ///  | 
 | 138 | +    /// This function is used when a new key needs to be inserted that shares  | 
 | 139 | +    /// a common prefix with the current node's key, but then diverges.  | 
 | 140 | +    /// The current `Key` is truncated to the common prefix, and a new child  | 
 | 141 | +    /// node is created for the remainder of the original node's `Key`.  | 
 | 142 | +    ///  | 
 | 143 | +    /// \param SplitPoint An iterator pointing to the character in the current  | 
 | 144 | +    ///                   `Key` where the split should occur.  | 
 | 145 | +    void split(KeyConstIteratorType SplitPoint) {  | 
 | 146 | +      Node Child(make_range(SplitPoint, Key.end()));  | 
 | 147 | +      Key = make_range(Key.begin(), SplitPoint);  | 
 | 148 | + | 
 | 149 | +      Children.swap(Child.Children);  | 
 | 150 | +      std::swap(Value, Child.Value);  | 
 | 151 | + | 
 | 152 | +      Children.emplace_back(std::move(Child));  | 
 | 153 | +    }  | 
 | 154 | +  };  | 
 | 155 | + | 
 | 156 | +  /// Root always corresponds to the empty key, which is the shortest possible  | 
 | 157 | +  /// prefix for everything.  | 
 | 158 | +  Node Root;  | 
 | 159 | +  ContainerType KeyValuePairs;  | 
 | 160 | + | 
 | 161 | +  /// Finds or creates a new tail or leaf node corresponding to the `Key`.  | 
 | 162 | +  Node &findOrCreate(KeyConstIteratorRangeType Key) {  | 
 | 163 | +    Node *Curr = &Root;  | 
 | 164 | +    if (Key.empty())  | 
 | 165 | +      return *Curr;  | 
 | 166 | + | 
 | 167 | +    for (;;) {  | 
 | 168 | +      auto [I1, I2] = llvm::mismatch(Key, Curr->Key);  | 
 | 169 | +      Key = make_range(I1, Key.end());  | 
 | 170 | + | 
 | 171 | +      if (I2 != Curr->Key.end()) {  | 
 | 172 | +        // Match is partial. Either query is too short, or there is mismatching  | 
 | 173 | +        // character. Split either way, and put new node in between of the  | 
 | 174 | +        // current and its children.  | 
 | 175 | +        Curr->split(I2);  | 
 | 176 | + | 
 | 177 | +        // Split was caused by mismatch, so `findChild` would fail.  | 
 | 178 | +        break;  | 
 | 179 | +      }  | 
 | 180 | + | 
 | 181 | +      Node *Child = Curr->findChild(Key);  | 
 | 182 | +      if (!Child)  | 
 | 183 | +        break;  | 
 | 184 | + | 
 | 185 | +      // Move to child with the same first character.  | 
 | 186 | +      Curr = Child;  | 
 | 187 | +    }  | 
 | 188 | + | 
 | 189 | +    if (Key.empty()) {  | 
 | 190 | +      // The current node completely matches the key, return it.  | 
 | 191 | +      return *Curr;  | 
 | 192 | +    }  | 
 | 193 | + | 
 | 194 | +    // `Key` is a suffix of original `Key` unmatched by path from the `Root` to  | 
 | 195 | +    // the `Curr`, and we have no candidate in the children to match more.  | 
 | 196 | +    // Create a new one.  | 
 | 197 | +    return Curr->Children.emplace_back(Key);  | 
 | 198 | +  }  | 
 | 199 | + | 
 | 200 | +  ///  | 
 | 201 | +  /// An iterator for traversing prefixes search results.  | 
 | 202 | +  ///  | 
 | 203 | +  /// This iterator is used by `find_prefixes` to traverse the tree and find  | 
 | 204 | +  /// elements that are prefixes to the given key. It's a forward iterator.  | 
 | 205 | +  ///  | 
 | 206 | +  /// \tparam MappedType The type of the value pointed to by the iterator.  | 
 | 207 | +  ///                    This will be `value_type` for non-const iterators  | 
 | 208 | +  ///                    and `const value_type` for const iterators.  | 
 | 209 | +  template <typename MappedType>  | 
 | 210 | +  class IteratorImpl  | 
 | 211 | +      : public iterator_facade_base<IteratorImpl<MappedType>,  | 
 | 212 | +                                    std::forward_iterator_tag, MappedType> {  | 
 | 213 | +    const Node *Curr = nullptr;  | 
 | 214 | +    KeyConstIteratorRangeType Query{KeyConstIteratorType{},  | 
 | 215 | +                                    KeyConstIteratorType{}};  | 
 | 216 | + | 
 | 217 | +    void findNextValid() {  | 
 | 218 | +      while (Curr && Curr->Value == typename ContainerType::iterator())  | 
 | 219 | +        advance();  | 
 | 220 | +    }  | 
 | 221 | + | 
 | 222 | +    void advance() {  | 
 | 223 | +      assert(Curr);  | 
 | 224 | +      if (Query.empty()) {  | 
 | 225 | +        Curr = nullptr;  | 
 | 226 | +        return;  | 
 | 227 | +      }  | 
 | 228 | + | 
 | 229 | +      Curr = Curr->findChild(Query);  | 
 | 230 | +      if (!Curr) {  | 
 | 231 | +        Curr = nullptr;  | 
 | 232 | +        return;  | 
 | 233 | +      }  | 
 | 234 | + | 
 | 235 | +      auto [I1, I2] = llvm::mismatch(Query, Curr->Key);  | 
 | 236 | +      if (I2 != Curr->Key.end()) {  | 
 | 237 | +        Curr = nullptr;  | 
 | 238 | +        return;  | 
 | 239 | +      }  | 
 | 240 | +      Query = make_range(I1, Query.end());  | 
 | 241 | +    }  | 
 | 242 | + | 
 | 243 | +    friend class RadixTree;  | 
 | 244 | +    IteratorImpl(const Node *C, const KeyConstIteratorRangeType &Q)  | 
 | 245 | +        : Curr(C), Query(Q) {  | 
 | 246 | +      findNextValid();  | 
 | 247 | +    }  | 
 | 248 | + | 
 | 249 | +  public:  | 
 | 250 | +    IteratorImpl() = default;  | 
 | 251 | + | 
 | 252 | +    MappedType &operator*() const { return *Curr->Value; }  | 
 | 253 | + | 
 | 254 | +    IteratorImpl &operator++() {  | 
 | 255 | +      advance();  | 
 | 256 | +      findNextValid();  | 
 | 257 | +      return *this;  | 
 | 258 | +    }  | 
 | 259 | + | 
 | 260 | +    bool operator==(const IteratorImpl &Other) const {  | 
 | 261 | +      return Curr == Other.Curr;  | 
 | 262 | +    }  | 
 | 263 | +  };  | 
 | 264 | + | 
 | 265 | +public:  | 
 | 266 | +  RadixTree() = default;  | 
 | 267 | +  RadixTree(RadixTree &&) = default;  | 
 | 268 | +  RadixTree &operator=(RadixTree &&) = default;  | 
 | 269 | + | 
 | 270 | +  using prefix_iterator = IteratorImpl<value_type>;  | 
 | 271 | +  using const_prefix_iterator = IteratorImpl<const value_type>;  | 
 | 272 | + | 
 | 273 | +  using iterator = typename ContainerType::iterator;  | 
 | 274 | +  using const_iterator = typename ContainerType::const_iterator;  | 
 | 275 | + | 
 | 276 | +  /// Returns true if the tree is empty.  | 
 | 277 | +  bool empty() const { return KeyValuePairs.empty(); }  | 
 | 278 | + | 
 | 279 | +  /// Returns the number of elements in the tree.  | 
 | 280 | +  size_t size() const { return KeyValuePairs.size(); }  | 
 | 281 | + | 
 | 282 | +  /// Returns the number of nodes in the tree.  | 
 | 283 | +  ///  | 
 | 284 | +  /// This function counts all internal nodes in the tree. It can be useful for  | 
 | 285 | +  /// understanding the memory footprint or complexity of the tree structure.  | 
 | 286 | +  size_t countNodes() const { return Root.countNodes(); }  | 
 | 287 | + | 
 | 288 | +  /// Returns an iterator to the first element.  | 
 | 289 | +  iterator begin() { return KeyValuePairs.begin(); }  | 
 | 290 | +  const_iterator begin() const { return KeyValuePairs.begin(); }  | 
 | 291 | + | 
 | 292 | +  /// Returns an iterator to the end of the tree.  | 
 | 293 | +  iterator end() { return KeyValuePairs.end(); }  | 
 | 294 | +  const_iterator end() const { return KeyValuePairs.end(); }  | 
 | 295 | + | 
 | 296 | +  /// Constructs and inserts a new element into the tree.  | 
 | 297 | +  ///  | 
 | 298 | +  /// This function constructs an element in place within the tree. If an  | 
 | 299 | +  /// element with the same key already exists, the insertion fails and the  | 
 | 300 | +  /// function returns an iterator to the existing element along with `false`.  | 
 | 301 | +  /// Otherwise, the new element is inserted and the function returns an  | 
 | 302 | +  /// iterator to the new element along with `true`.  | 
 | 303 | +  ///  | 
 | 304 | +  /// \param Key The key of the element to construct.  | 
 | 305 | +  /// \param Args Arguments to forward to the constructor of the mapped_type.  | 
 | 306 | +  /// \return A pair consisting of an iterator to the inserted element (or to  | 
 | 307 | +  ///         the element that prevented insertion) and a boolean value  | 
 | 308 | +  ///         indicating whether the insertion took place.  | 
 | 309 | +  template <typename... Ts>  | 
 | 310 | +  std::pair<iterator, bool> emplace(key_type &&Key, Ts &&...Args) {  | 
 | 311 | +    // We want to make new `Node` to refer key in the container, not the one  | 
 | 312 | +    // from the argument.  | 
 | 313 | +    // FIXME: Determine that we need a new node, before expanding  | 
 | 314 | +    // `KeyValuePairs`.  | 
 | 315 | +    const value_type &NewValue = KeyValuePairs.emplace_front(  | 
 | 316 | +        std::move(Key), T(std::forward<Ts>(Args)...));  | 
 | 317 | +    Node &Node = findOrCreate(NewValue.first);  | 
 | 318 | +    bool HasValue = Node.Value != typename ContainerType::iterator();  | 
 | 319 | +    if (!HasValue)  | 
 | 320 | +      Node.Value = KeyValuePairs.begin();  | 
 | 321 | +    else  | 
 | 322 | +      KeyValuePairs.pop_front();  | 
 | 323 | +    return {Node.Value, !HasValue};  | 
 | 324 | +  }  | 
 | 325 | + | 
 | 326 | +  ///  | 
 | 327 | +  /// Finds all elements whose keys are prefixes of the given `Key`.  | 
 | 328 | +  ///  | 
 | 329 | +  /// This function returns an iterator range over all elements in the tree  | 
 | 330 | +  /// whose keys are prefixes of the provided `Key`. For example, if the tree  | 
 | 331 | +  /// contains "abcde", "abc", "abcdefgh", and `Key` is "abcde", this function  | 
 | 332 | +  /// would return iterators to "abcde" and "abc".  | 
 | 333 | +  ///  | 
 | 334 | +  /// \param Key The key to search for prefixes of.  | 
 | 335 | +  /// \return An `iterator_range` of `const_prefix_iterator`s, allowing  | 
 | 336 | +  ///         iteration over the found prefix elements.  | 
 | 337 | +  /// \note The returned iterators reference the `Key` provided by the caller.  | 
 | 338 | +  ///       The caller must ensure that `Key` remains valid for the lifetime  | 
 | 339 | +  ///       of the iterators.  | 
 | 340 | +  iterator_range<const_prefix_iterator>  | 
 | 341 | +  find_prefixes(const key_type &Key) const {  | 
 | 342 | +    return iterator_range<const_prefix_iterator>{  | 
 | 343 | +        const_prefix_iterator(&Root, KeyConstIteratorRangeType(Key)),  | 
 | 344 | +        const_prefix_iterator{}};  | 
 | 345 | +  }  | 
 | 346 | +};  | 
 | 347 | + | 
 | 348 | +} // namespace llvm  | 
 | 349 | + | 
 | 350 | +#endif // LLVM_ADT_RADIXTREE_H  | 
0 commit comments