[ADT] Add TrieRawHashMap #69528

Open: cachemeifyoucan wants to merge 4 commits into main

Conversation

cachemeifyoucan
Collaborator

Implement TrieRawHashMap, which stores objects in a trie keyed by the hash of each object.

The user must supply the hashing function and guarantee that the hash of every inserted object is unique. Hash collisions are not supported.

This is part of the LLVMCAS implementation; the overall change is in #68448.

@llvmbot
Collaborator

llvmbot commented Oct 18, 2023

@llvm/pr-subscribers-llvm-support

@llvm/pr-subscribers-llvm-adt

Author: Steven Wu (cachemeifyoucan)

Patch is 46.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69528.diff

6 Files Affected:

  • (added) llvm/include/llvm/ADT/TrieRawHashMap.h (+398)
  • (modified) llvm/lib/Support/CMakeLists.txt (+1)
  • (added) llvm/lib/Support/TrieHashIndexGenerator.h (+89)
  • (added) llvm/lib/Support/TrieRawHashMap.cpp (+483)
  • (modified) llvm/unittests/ADT/CMakeLists.txt (+1)
  • (added) llvm/unittests/ADT/TrieRawHashMapTest.cpp (+342)
diff --git a/llvm/include/llvm/ADT/TrieRawHashMap.h b/llvm/include/llvm/ADT/TrieRawHashMap.h
new file mode 100644
index 000000000000000..baa08e214ce6fd7
--- /dev/null
+++ b/llvm/include/llvm/ADT/TrieRawHashMap.h
@@ -0,0 +1,398 @@
+//===- TrieRawHashMap.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_TRIERAWHASHMAP_H
+#define LLVM_ADT_TRIERAWHASHMAP_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <atomic>
+#include <optional>
+
+namespace llvm {
+
+class raw_ostream;
+
+/// TrieRawHashMap - is a lock-free thread-safe trie that can be used to
+/// store/index data based on a hash value. It can be customized to work with
+/// any hash algorithm or store any data.
+///
+/// Data structure:
+/// Data node stored in the Trie contains both hash and data:
+/// struct {
+///    HashT Hash;
+///    DataT Data;
+/// };
+///
+/// Data is stored/indexed via a prefix tree, where each node in the tree can be
+/// either the root, a sub-trie or a data node. Assuming a 4-bit hash and two
+/// data objects {0001, A} and {0100, B}, it can be stored in a trie
+/// (assuming Root has 2 bits, SubTrie has 1 bit):
+///  +--------+
+///  |Root[00]| -> {0001, A}
+///  |    [01]| -> {0100, B}
+///  |    [10]| (empty)
+///  |    [11]| (empty)
+///  +--------+
+///
+/// Inserting a new object {0010, C} will result in:
+///  +--------+    +----------+
+///  |Root[00]| -> |SubTrie[0]| -> {0001, A}
+///  |        |    |       [1]| -> {0010, C}
+///  |        |    +----------+
+///  |    [01]| -> {0100, B}
+///  |    [10]| (empty)
+///  |    [11]| (empty)
+///  +--------+
+/// Note object A is sunk down to a sub-trie during the insertion. All the
+/// nodes are inserted through compare-exchange to keep the trie thread-safe
+/// and lock-free.
+///
+/// To find an object in the trie, walk the tree with prefix of the hash until
+/// the data node is found. Then the hash is compared with the hash stored in
+/// the data node to see if it is the same object.
+///
+/// Hash collision is not allowed so it is recommended to use trie with a
+/// "strong" hashing algorithm. A well-distributed hash can also result in
+/// better performance and memory usage.
+///
+/// It currently does not support iteration and deletion.
+
+/// Base class for a lock-free thread-safe hash-mapped trie.
+class ThreadSafeTrieRawHashMapBase {
+public:
+  static constexpr size_t TrieContentBaseSize = 4;
+  static constexpr size_t DefaultNumRootBits = 6;
+  static constexpr size_t DefaultNumSubtrieBits = 4;
+
+private:
+  template <class T> struct AllocValueType {
+    char Base[TrieContentBaseSize];
+    std::aligned_union_t<sizeof(T), T> Content;
+  };
+
+protected:
+  template <class T>
+  static constexpr size_t DefaultContentAllocSize = sizeof(AllocValueType<T>);
+
+  template <class T>
+  static constexpr size_t DefaultContentAllocAlign = alignof(AllocValueType<T>);
+
+  template <class T>
+  static constexpr size_t DefaultContentOffset =
+      offsetof(AllocValueType<T>, Content);
+
+public:
+  void operator delete(void *Ptr) { ::free(Ptr); }
+
+  LLVM_DUMP_METHOD void dump() const;
+  void print(raw_ostream &OS) const;
+
+protected:
+  /// Result of a lookup. Suitable for an insertion hint. Maybe could be
+  /// expanded into an iterator of sorts, but likely not useful (visiting
+  /// everything in the trie should probably be done some way other than
+  /// through an iterator pattern).
+  class PointerBase {
+  protected:
+    void *get() const { return I == -2u ? P : nullptr; }
+
+  public:
+    PointerBase() noexcept = default;
+    PointerBase(PointerBase &&) = default;
+    PointerBase(const PointerBase &) = default;
+    PointerBase &operator=(PointerBase &&) = default;
+    PointerBase &operator=(const PointerBase &) = default;
+
+  private:
+    friend class ThreadSafeTrieRawHashMapBase;
+    explicit PointerBase(void *Content) : P(Content), I(-2u) {}
+    PointerBase(void *P, unsigned I, unsigned B) : P(P), I(I), B(B) {}
+
+    bool isHint() const { return I != -1u && I != -2u; }
+
+    void *P = nullptr;
+    unsigned I = -1u;
+    unsigned B = 0;
+  };
+
+  /// Find the stored content with hash.
+  PointerBase find(ArrayRef<uint8_t> Hash) const;
+
+  /// Insert and return the stored content.
+  PointerBase
+  insert(PointerBase Hint, ArrayRef<uint8_t> Hash,
+         function_ref<const uint8_t *(void *Mem, ArrayRef<uint8_t> Hash)>
+             Constructor);
+
+  ThreadSafeTrieRawHashMapBase() = delete;
+
+  ThreadSafeTrieRawHashMapBase(
+      size_t ContentAllocSize, size_t ContentAllocAlign, size_t ContentOffset,
+      std::optional<size_t> NumRootBits = std::nullopt,
+      std::optional<size_t> NumSubtrieBits = std::nullopt);
+
+  /// Destructor, which asserts if there's anything to do. Subclasses should
+  /// call \a destroyImpl().
+  ///
+  /// \pre \a destroyImpl() was already called.
+  ~ThreadSafeTrieRawHashMapBase();
+  void destroyImpl(function_ref<void(void *ValueMem)> Destructor);
+
+  ThreadSafeTrieRawHashMapBase(ThreadSafeTrieRawHashMapBase &&RHS);
+
+  // Move assignment can be implemented in a thread-safe way if NumRootBits and
+  // NumSubtrieBits are stored inside the Root.
+  ThreadSafeTrieRawHashMapBase &
+  operator=(ThreadSafeTrieRawHashMapBase &&RHS) = delete;
+
+  // No copy.
+  ThreadSafeTrieRawHashMapBase(const ThreadSafeTrieRawHashMapBase &) = delete;
+  ThreadSafeTrieRawHashMapBase &
+  operator=(const ThreadSafeTrieRawHashMapBase &) = delete;
+
+  // Debug functions. Implementation details and not guaranteed to be
+  // thread-safe.
+  PointerBase getRoot() const;
+  unsigned getStartBit(PointerBase P) const;
+  unsigned getNumBits(PointerBase P) const;
+  unsigned getNumSlotUsed(PointerBase P) const;
+  std::string getTriePrefixAsString(PointerBase P) const;
+  unsigned getNumTries() const;
+  // Visit next trie in the allocation chain.
+  PointerBase getNextTrie(PointerBase P) const;
+
+private:
+  friend class TrieRawHashMapTestHelper;
+  const unsigned short ContentAllocSize;
+  const unsigned short ContentAllocAlign;
+  const unsigned short ContentOffset;
+  unsigned short NumRootBits;
+  unsigned short NumSubtrieBits;
+  struct ImplType;
+  // ImplPtr is owned by ThreadSafeTrieRawHashMapBase and needs to be freed in
+  // destroyImpl.
+  std::atomic<ImplType *> ImplPtr;
+  ImplType &getOrCreateImpl();
+  ImplType *getImpl() const;
+};
+
+/// Lock-free thread-safe hash-mapped trie.
+template <class T, size_t NumHashBytes>
+class ThreadSafeTrieRawHashMap : public ThreadSafeTrieRawHashMapBase {
+public:
+  using HashT = std::array<uint8_t, NumHashBytes>;
+
+  class LazyValueConstructor;
+  struct value_type {
+    const HashT Hash;
+    T Data;
+
+    value_type(value_type &&) = default;
+    value_type(const value_type &) = default;
+
+    value_type(ArrayRef<uint8_t> Hash, const T &Data)
+        : Hash(makeHash(Hash)), Data(Data) {}
+    value_type(ArrayRef<uint8_t> Hash, T &&Data)
+        : Hash(makeHash(Hash)), Data(std::move(Data)) {}
+
+  private:
+    friend class LazyValueConstructor;
+
+    struct EmplaceTag {};
+    template <class... ArgsT>
+    value_type(ArrayRef<uint8_t> Hash, EmplaceTag, ArgsT &&...Args)
+        : Hash(makeHash(Hash)), Data(std::forward<ArgsT>(Args)...) {}
+
+    static HashT makeHash(ArrayRef<uint8_t> HashRef) {
+      HashT Hash;
+      std::copy(HashRef.begin(), HashRef.end(), Hash.data());
+      return Hash;
+    }
+  };
+
+  using ThreadSafeTrieRawHashMapBase::operator delete;
+  using HashType = HashT;
+
+  using ThreadSafeTrieRawHashMapBase::dump;
+  using ThreadSafeTrieRawHashMapBase::print;
+
+private:
+  template <class ValueT> class PointerImpl : PointerBase {
+    friend class ThreadSafeTrieRawHashMap;
+
+    ValueT *get() const {
+      if (void *B = PointerBase::get())
+        return reinterpret_cast<ValueT *>(B);
+      return nullptr;
+    }
+
+  public:
+    ValueT &operator*() const {
+      assert(get());
+      return *get();
+    }
+    ValueT *operator->() const {
+      assert(get());
+      return get();
+    }
+    explicit operator bool() const { return get(); }
+
+    PointerImpl() = default;
+    PointerImpl(PointerImpl &&) = default;
+    PointerImpl(const PointerImpl &) = default;
+    PointerImpl &operator=(PointerImpl &&) = default;
+    PointerImpl &operator=(const PointerImpl &) = default;
+
+  protected:
+    PointerImpl(PointerBase Result) : PointerBase(Result) {}
+  };
+
+public:
+  class pointer;
+  class const_pointer;
+  class pointer : public PointerImpl<value_type> {
+    friend class ThreadSafeTrieRawHashMap;
+    friend class const_pointer;
+
+  public:
+    pointer() = default;
+    pointer(pointer &&) = default;
+    pointer(const pointer &) = default;
+    pointer &operator=(pointer &&) = default;
+    pointer &operator=(const pointer &) = default;
+
+  private:
+    pointer(PointerBase Result) : pointer::PointerImpl(Result) {}
+  };
+
+  class const_pointer : public PointerImpl<const value_type> {
+    friend class ThreadSafeTrieRawHashMap;
+
+  public:
+    const_pointer() = default;
+    const_pointer(const_pointer &&) = default;
+    const_pointer(const const_pointer &) = default;
+    const_pointer &operator=(const_pointer &&) = default;
+    const_pointer &operator=(const const_pointer &) = default;
+
+    const_pointer(const pointer &P) : const_pointer::PointerImpl(P) {}
+
+  private:
+    const_pointer(PointerBase Result) : const_pointer::PointerImpl(Result) {}
+  };
+
+  class LazyValueConstructor {
+  public:
+    value_type &operator()(T &&RHS) {
+      assert(Mem && "Constructor already called, or moved away");
+      return assign(::new (Mem) value_type(Hash, std::move(RHS)));
+    }
+    value_type &operator()(const T &RHS) {
+      assert(Mem && "Constructor already called, or moved away");
+      return assign(::new (Mem) value_type(Hash, RHS));
+    }
+    template <class... ArgsT> value_type &emplace(ArgsT &&...Args) {
+      assert(Mem && "Constructor already called, or moved away");
+      return assign(::new (Mem)
+                        value_type(Hash, typename value_type::EmplaceTag{},
+                                   std::forward<ArgsT>(Args)...));
+    }
+
+    LazyValueConstructor(LazyValueConstructor &&RHS)
+        : Mem(RHS.Mem), Result(RHS.Result), Hash(RHS.Hash) {
+      RHS.Mem = nullptr; // Moved away, cannot call.
+    }
+    ~LazyValueConstructor() { assert(!Mem && "Constructor never called!"); }
+
+  private:
+    value_type &assign(value_type *V) {
+      Mem = nullptr;
+      Result = V;
+      return *V;
+    }
+    friend class ThreadSafeTrieRawHashMap;
+    LazyValueConstructor() = delete;
+    LazyValueConstructor(void *Mem, value_type *&Result, ArrayRef<uint8_t> Hash)
+        : Mem(Mem), Result(Result), Hash(Hash) {
+      assert(Hash.size() == sizeof(HashT) && "Invalid hash");
+      assert(Mem && "Invalid memory for construction");
+    }
+    void *Mem;
+    value_type *&Result;
+    ArrayRef<uint8_t> Hash;
+  };
+
+  /// Insert with a hint. Default-constructed hint will work, but it's
+  /// recommended to start with a lookup to avoid overhead in object creation
+  /// if it already exists.
+  pointer insertLazy(const_pointer Hint, ArrayRef<uint8_t> Hash,
+                     function_ref<void(LazyValueConstructor)> OnConstruct) {
+    return pointer(ThreadSafeTrieRawHashMapBase::insert(
+        Hint, Hash, [&](void *Mem, ArrayRef<uint8_t> Hash) {
+          value_type *Result = nullptr;
+          OnConstruct(LazyValueConstructor(Mem, Result, Hash));
+          return Result->Hash.data();
+        }));
+  }
+
+  pointer insertLazy(ArrayRef<uint8_t> Hash,
+                     function_ref<void(LazyValueConstructor)> OnConstruct) {
+    return insertLazy(const_pointer(), Hash, OnConstruct);
+  }
+
+  pointer insert(const_pointer Hint, value_type &&HashedData) {
+    return insertLazy(Hint, HashedData.Hash, [&](LazyValueConstructor C) {
+      C(std::move(HashedData.Data));
+    });
+  }
+
+  pointer insert(const_pointer Hint, const value_type &HashedData) {
+    return insertLazy(Hint, HashedData.Hash,
+                      [&](LazyValueConstructor C) { C(HashedData.Data); });
+  }
+
+  pointer find(ArrayRef<uint8_t> Hash) {
+    assert(Hash.size() == std::tuple_size<HashT>::value);
+    return ThreadSafeTrieRawHashMapBase::find(Hash);
+  }
+
+  const_pointer find(ArrayRef<uint8_t> Hash) const {
+    assert(Hash.size() == std::tuple_size<HashT>::value);
+    return ThreadSafeTrieRawHashMapBase::find(Hash);
+  }
+
+  ThreadSafeTrieRawHashMap(std::optional<size_t> NumRootBits = std::nullopt,
+                           std::optional<size_t> NumSubtrieBits = std::nullopt)
+      : ThreadSafeTrieRawHashMapBase(DefaultContentAllocSize<value_type>,
+                                     DefaultContentAllocAlign<value_type>,
+                                     DefaultContentOffset<value_type>,
+                                     NumRootBits, NumSubtrieBits) {}
+
+  ~ThreadSafeTrieRawHashMap() {
+    if constexpr (std::is_trivially_destructible<value_type>::value)
+      this->destroyImpl(nullptr);
+    else
+      this->destroyImpl(
+          [](void *P) { static_cast<value_type *>(P)->~value_type(); });
+  }
+
+  // Move constructor okay.
+  ThreadSafeTrieRawHashMap(ThreadSafeTrieRawHashMap &&) = default;
+
+  // No move assignment or any copy.
+  ThreadSafeTrieRawHashMap &operator=(ThreadSafeTrieRawHashMap &&) = delete;
+  ThreadSafeTrieRawHashMap(const ThreadSafeTrieRawHashMap &) = delete;
+  ThreadSafeTrieRawHashMap &
+  operator=(const ThreadSafeTrieRawHashMap &) = delete;
+};
+
+} // namespace llvm
+
+#endif // LLVM_ADT_TRIERAWHASHMAP_H
diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt
index b96d62c7a6224d6..677f52677a27f8a 100644
--- a/llvm/lib/Support/CMakeLists.txt
+++ b/llvm/lib/Support/CMakeLists.txt
@@ -239,6 +239,7 @@ add_llvm_component_library(LLVMSupport
   TimeProfiler.cpp
   Timer.cpp
   ToolOutputFile.cpp
+  TrieRawHashMap.cpp
   Twine.cpp
   TypeSize.cpp
   Unicode.cpp
diff --git a/llvm/lib/Support/TrieHashIndexGenerator.h b/llvm/lib/Support/TrieHashIndexGenerator.h
new file mode 100644
index 000000000000000..c9e9b70e10d3c77
--- /dev/null
+++ b/llvm/lib/Support/TrieHashIndexGenerator.h
@@ -0,0 +1,89 @@
+//===- TrieHashIndexGenerator.h ---------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_SUPPORT_TRIEHASHINDEXGENERATOR_H
+#define LLVM_LIB_SUPPORT_TRIEHASHINDEXGENERATOR_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include <optional>
+
+namespace llvm {
+
+struct IndexGenerator {
+  size_t NumRootBits;
+  size_t NumSubtrieBits;
+  ArrayRef<uint8_t> Bytes;
+  std::optional<size_t> StartBit = std::nullopt;
+
+  size_t getNumBits() const {
+    assert(StartBit);
+    size_t TotalNumBits = Bytes.size() * 8;
+    assert(*StartBit <= TotalNumBits);
+    return std::min(*StartBit ? NumSubtrieBits : NumRootBits,
+                    TotalNumBits - *StartBit);
+  }
+  size_t next() {
+    size_t Index;
+    if (!StartBit) {
+      StartBit = 0;
+      Index = getIndex(Bytes, *StartBit, NumRootBits);
+    } else {
+      *StartBit += *StartBit ? NumSubtrieBits : NumRootBits;
+      assert((*StartBit - NumRootBits) % NumSubtrieBits == 0);
+      Index = getIndex(Bytes, *StartBit, NumSubtrieBits);
+    }
+    return Index;
+  }
+
+  size_t hint(unsigned Index, unsigned Bit) {
+    assert(Index >= 0);
+    assert(Bit < Bytes.size() * 8);
+    assert(Bit == 0 || (Bit - NumRootBits) % NumSubtrieBits == 0);
+    StartBit = Bit;
+    return Index;
+  }
+
+  size_t getCollidingBits(ArrayRef<uint8_t> CollidingBits) const {
+    assert(StartBit);
+    return getIndex(CollidingBits, *StartBit, NumSubtrieBits);
+  }
+
+  static size_t getIndex(ArrayRef<uint8_t> Bytes, size_t StartBit,
+                         size_t NumBits) {
+    assert(StartBit < Bytes.size() * 8);
+
+    Bytes = Bytes.drop_front(StartBit / 8u);
+    StartBit %= 8u;
+    size_t Index = 0;
+    for (uint8_t Byte : Bytes) {
+      size_t ByteStart = 0, ByteEnd = 8;
+      if (StartBit) {
+        ByteStart = StartBit;
+        Byte &= (1u << (8 - StartBit)) - 1u;
+        StartBit = 0;
+      }
+      size_t CurrentNumBits = ByteEnd - ByteStart;
+      if (CurrentNumBits > NumBits) {
+        Byte >>= CurrentNumBits - NumBits;
+        CurrentNumBits = NumBits;
+      }
+      Index <<= CurrentNumBits;
+      Index |= Byte & ((1u << CurrentNumBits) - 1u);
+
+      assert(NumBits >= CurrentNumBits);
+      NumBits -= CurrentNumBits;
+      if (!NumBits)
+        break;
+    }
+    return Index;
+  }
+};
+
+} // namespace llvm
+
+#endif // LLVM_LIB_SUPPORT_TRIEHASHINDEXGENERATOR_H
diff --git a/llvm/lib/Support/TrieRawHashMap.cpp b/llvm/lib/Support/TrieRawHashMap.cpp
new file mode 100644
index 000000000000000..af4cd8b57aed214
--- /dev/null
+++ b/llvm/lib/Support/TrieRawHashMap.cpp
@@ -0,0 +1,483 @@
+//===- TrieRawHashMap.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/TrieRawHashMap.h"
+#include "TrieHashIndexGenerator.h"
+#include "llvm/ADT/LazyAtomicPointer.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ThreadSafeAllocator.h"
+#include "llvm/Support/raw_ostream.h"
+#include <memory>
+
+using namespace llvm;
+
+namespace {
+struct TrieNode {
+  const bool IsSubtrie = false;
+
+  TrieNode(bool IsSubtrie) : IsSubtrie(IsSubtrie) {}
+
+  static void *operator new(size_t Size) { return ::malloc(Size); }
+  void operator delete(void *Ptr) { ::free(Ptr); }
+};
+
+struct TrieContent final : public TrieNode {
+  const uint8_t ContentOffset;
+  const uint8_t HashSize;
+  const uint8_t HashOffset;
+
+  void *getValuePointer() const {
+    auto Content = reinterpret_cast<const uint8_t *>(this) + ContentOffset;
+    return const_cast<uint8_t *>(Content);
+  }
+
+  ArrayRef<uint8_t> getHash() const {
+    auto *Begin = reinterpret_cast<const uint8_t *>(this) + HashOffset;
+    return ArrayRef(Begin, Begin + HashSize);
+  }
+
+  TrieContent(size_t ContentOffset, size_t HashSize, size_t HashOffset)
+      : TrieNode(/*IsSubtrie=*/false), ContentOffset(ContentOffset),
+        HashSize(HashSize), HashOffset(HashOffset) {}
+};
+static_assert(sizeof(TrieContent) ==
+                  ThreadSafeTrieRawHashMapBase::TrieContentBaseSize,
+              "Check header assumption!");
+
+class TrieSubtrie final : public TrieNode {
+public:
+  TrieNode *get(size_t I) const { return Slots[I].load(); }
+
+  TrieSubtrie *
+  sink(size_t I, TrieContent &Content, size_t NumSubtrieBits, size_t NewI,
+       function_ref<TrieSubtrie *(std::unique_ptr<TrieSubtrie>)> Saver);
+
+  static std::unique_ptr<TrieSubtrie> create(size_t StartBit, size_t NumBits);
+
+  explicit TrieSubtrie(size_t StartBit, size_t NumBits);
+
+private:
+  // FIXME: Use a bitset to speed up access:
+  //
+  //     std::array<std::atomic<uint64_t>, NumSlots/64> IsSet;
+  //
+  // This will avoid needing to visit sparsely filled slots in
+  // \a ThreadSafeTrieRawHashMapBase::destroyImpl() when there's a non-trivial
+  // destructor.
+  //
+  // It would also greatly speed up iteration, if we add that some day, and
+  // allow get() to return one level sooner.
+  //
+  // This would be the algorithm for updating IsSet (after updating Slots):
+  //
+  //     std::atomic<uint64_t> &Bits = IsSet[I.High];
+  //     const uint64_t NewBit = 1ULL << I.Low;
+  //     uint64_t Old = 0;
+  //     while (!Bits.compare_exchange_weak(Old, Old | NewBit))
+  //       ;...
[truncated]

@github-actions

github-actions bot commented Oct 18, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@bzcheeseman bzcheeseman left a comment

Don't block on me, just a drive-by comment while skimming the change :) I'll take a more in-depth look when I get the chance (and plz ping if I am holding something up)

friend class llvm::ThreadSafeTrieRawHashMapBase;

public:
/// Linked list for ownership of tries. The pointer is owned by TrieSubtrie.
Collaborator

Maybe "This is an owning pointer"? ("The pointer is owned by TrieSubtrie" confused me a bit (it's not about what owns the pointer, but the memory it points to - and "which TrieSubtrie" - this or some other instance, etc))

Collaborator Author

This is not an owning pointer. Tries are also chained together in a singly linked list as they are allocated, which gives a well-defined destruction order. When destroying the entire TrieRawHashMap, a single owning pointer is not enough, because both the data stored in the trie and the trie itself must be destroyed. The trie structure has to stay intact while the data is destroyed (see the comments in ThreadSafeTrieRawHashMapBase::destroyImpl). It achieves that by walking the allocation linked list through this pointer in two passes:

  • Destroy all data
  • Destroy tries

Collaborator Author

This also avoids using recursion to destroy the tries, since the number of tries can be very large and recursion could overflow the stack.
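The two-pass, non-recursive teardown described in this thread can be sketched roughly as follows. This is a single-threaded illustration, not the code from the patch: `Subtrie`, `Map`, `NextAllocated`, and `destroy()` are hypothetical names, and plain `int` values stand in for the stored data.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a subtrie node (not the type from the patch).
struct Subtrie {
  std::vector<int *> Data;          // Data objects reachable from this node.
  Subtrie *NextAllocated = nullptr; // Non-owning link in allocation order.
};

// Hypothetical owner that threads every allocated node onto one list.
struct Map {
  Subtrie *Head = nullptr;

  Subtrie *allocate() {
    Subtrie *S = new Subtrie();
    S->NextAllocated = Head; // Chain the new node in front of the list.
    Head = S;
    return S;
  }

  // Tears everything down and returns how many data objects were destroyed.
  std::size_t destroy() {
    std::size_t NumDestroyed = 0;
    // Pass 1: destroy the data while every node is still alive.
    for (Subtrie *S = Head; S; S = S->NextAllocated) {
      for (int *D : S->Data) {
        delete D;
        ++NumDestroyed;
      }
    }
    // Pass 2: free the nodes with a loop, so stack depth stays constant.
    while (Head) {
      Subtrie *Next = Head->NextAllocated;
      delete Head;
      Head = Next;
    }
    return NumDestroyed;
  }
};
```

Because both passes walk the list with a loop, the stack depth does not grow with the number of nodes.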

size_t Size = sizeof(TrieSubtrie) + getTrieTailSize(StartBit, NumBits);
void *Memory = ::malloc(Size);
TrieSubtrie *S = ::new (Memory) TrieSubtrie(StartBit, NumBits);
return std::unique_ptr<TrieSubtrie>(S);
Collaborator

This would produce mismatched new/delete. If the object is created with malloc+placement new, it should be destroyed with an explicit dtor call (if it's non-trivial) and a call to free - not a call to delete. I think?
(a unique_ptr with a custom deleter could be used to address this issue)

Collaborator Author

delete is actually overridden to call free(). I am not proud of this, but it is tricky to use placement new here and still return a unique_ptr.
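For reference, the custom-deleter alternative suggested above might look roughly like this. It is a standalone sketch, not a proposed change to the patch: `Node`, `FreeDeleter`, `makeNode`, and the `LiveNodes` counter are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstdlib>
#include <memory>
#include <new>

// Counts live nodes so the pairing of construction and destruction is visible.
int LiveNodes = 0;

// Hypothetical node type standing in for a malloc-allocated trie node.
struct Node {
  int Value;
  explicit Node(int V) : Value(V) { ++LiveNodes; }
  ~Node() { --LiveNodes; }
};

// Deleter that mirrors the allocation: explicit destructor call, then free().
struct FreeDeleter {
  void operator()(Node *N) const {
    N->~Node();
    std::free(N);
  }
};

using NodePtr = std::unique_ptr<Node, FreeDeleter>;

// Allocate raw storage with malloc and construct the node in place.
NodePtr makeNode(int V) {
  void *Memory = std::malloc(sizeof(Node));
  if (!Memory)
    return nullptr;
  return NodePtr(::new (Memory) Node(V));
}
```

With this shape, no `operator delete` overload is needed, at the cost of spelling out the deleter type wherever the pointer type appears.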

@dwblaikie
Collaborator

Per https://llvm.org/docs/GitHub.html#updating-pull-requests it'd be helpful to, "When updating a pull request, you should push additional “fix up” commits to your branch instead of force pushing. This makes it easier for GitHub to track the context of previous review comments. Consider using the built-in support for fixups in git."

(like, I can't readily tell what things changed in your most recent update - so I can see if/how different feedback was/wasn't addressed)

(I know we're all learning how to do pull requests in LLVM - I still haven't sent any of my own out... so I'm certainly no expert at making them, but I'm slowly learning what makes them easier/harder to review, at least)

@cachemeifyoucan
Collaborator Author

Per https://llvm.org/docs/GitHub.html#updating-pull-requests it'd be helpful to, "When updating a pull request, you should push additional “fix up” commits to your branch instead of force pushing. This makes it easier for GitHub to track the context of previous review comments. Consider using the built-in support for fixups in git."

Oops, I didn't realize we had already published an official guideline. I will definitely do that in the future. In the meantime, let me see if the review history can snap back to the code if I restore the branch to its old state and add a fixup commit.

@cachemeifyoucan
Collaborator Author

Ping~


namespace llvm {

struct IndexGenerator {
Contributor

nit: Could you add comments to this struct and its functions? It also may be worth moving them out-of-line, they are maybe a little long to be fully inlined.

Collaborator Author

Added comments for the class and APIs. Let me know if the inline implementation is still hard to read. This is currently a private header, and I can definitely move some of the implementation out to its own cpp file.

Contributor

I would probably move this to a .cpp if it was me. The comments help a lot though, thanks!
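As a reading aid for this thread, here is a simplified standalone sketch of what an index generator of this kind computes: the slot index at each trie level, taken from consecutive bit ranges of the hash. It assumes bits are consumed most-significant-first, which is how the `getIndex` in the diff appears to behave. `extractBits` and `slotPath` are hypothetical helpers, not functions from the patch, and unlike the patch they assert rather than clamp when the hash runs out of bits.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Read NumBits bits starting at bit offset StartBit, most significant bit
// first. Hypothetical helper; requires the range to lie inside the hash.
std::size_t extractBits(const std::vector<std::uint8_t> &Hash,
                        std::size_t StartBit, std::size_t NumBits) {
  assert(StartBit + NumBits <= Hash.size() * 8);
  std::size_t Index = 0;
  for (std::size_t Bit = StartBit; Bit < StartBit + NumBits; ++Bit) {
    std::uint8_t Byte = Hash[Bit / 8];
    unsigned Shift = 7 - static_cast<unsigned>(Bit % 8);
    Index = (Index << 1) | ((Byte >> Shift) & 1u);
  }
  return Index;
}

// Slot indices for the first Levels levels: the root consumes NumRootBits,
// every level below it consumes NumSubtrieBits.
std::vector<std::size_t> slotPath(const std::vector<std::uint8_t> &Hash,
                                  std::size_t NumRootBits,
                                  std::size_t NumSubtrieBits,
                                  std::size_t Levels) {
  std::vector<std::size_t> Path;
  std::size_t StartBit = 0;
  for (std::size_t Level = 0; Level < Levels; ++Level) {
    std::size_t NumBits = Level == 0 ? NumRootBits : NumSubtrieBits;
    Path.push_back(extractBits(Hash, StartBit, NumBits));
    StartBit += NumBits;
  }
  return Path;
}
```

With the header's example (2 root bits, 1 subtrie bit), the hashes 0001 and 0010, stored in the high nibble of a byte, share root slot 00 and separate at the next bit, which matches the diagram in TrieRawHashMap.h.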

size_t NumBits) {
size_t Size = sizeof(TrieSubtrie) + getTrieTailSize(StartBit, NumBits);
void *Memory = ::malloc(Size);
TrieSubtrie *S = ::new (Memory) TrieSubtrie(StartBit, NumBits);
Contributor

Isn't operator new overloaded above? Could you use that?

I would echo David's comment below that this is a bit fussy but if this is what you have to do I believe you :)

Contributor

Actually, is https://llvm.org/doxygen/classllvm_1_1TrailingObjects.html something that might be useful here? It looks like that's what you're going for, and this might be a little cleaner?

Collaborator Author

I guess using TrailingObjects is a cleaner solution, but it does come with some downsides (unless I missed how it can be used). Currently, the trailing storage in TrieSubtrie is dynamically allocated based on a parameter, and I don't know how to get TrailingObjects to work with a dynamically sized object. So there are two ways to get it working:

  • instead of a trailing object, just put a trailing pointer there.
  • in the current implementation, the size depends only on the configuration, so technically we can make it a template parameter and then have a static size for constructing the trailing object. The main problem is that we haven't fine-tuned the configuration yet, and templating would prevent a trie implementation that picks the bit size per level, unless all size variations of the node type are created ahead of time.

@dexonsmith any opinion on this?
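To make the layout under discussion concrete, here is a minimal hand-rolled sketch of a node whose slot array is sized at runtime and placed directly after the object. It only illustrates the general malloc-plus-tail technique. It is not the patch's `TrieSubtrie` and does not use `llvm::TrailingObjects`; `SlotNode`, `create`, `destroy`, and `slot` are hypothetical names.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <new>

// Hypothetical node with a runtime-sized tail of atomic pointer slots.
class SlotNode {
public:
  using Slot = std::atomic<void *>;

  // Allocate header and 2^NumBits slots in one malloc'ed block.
  static SlotNode *create(std::size_t NumBits) {
    std::size_t NumSlots = std::size_t(1) << NumBits;
    std::size_t Size = sizeof(SlotNode) + NumSlots * sizeof(Slot);
    void *Memory = std::malloc(Size);
    if (!Memory)
      return nullptr;
    return ::new (Memory) SlotNode(NumSlots);
  }

  // Slots are trivially destructible, so only the header needs a dtor call.
  static void destroy(SlotNode *N) {
    if (!N)
      return;
    N->~SlotNode();
    std::free(N);
  }

  std::size_t size() const { return NumSlots; }

  Slot &slot(std::size_t I) {
    assert(I < NumSlots);
    return slots()[I];
  }

private:
  explicit SlotNode(std::size_t N) : NumSlots(N) {
    // Construct every trailing slot as an empty (null) entry.
    for (std::size_t I = 0; I < NumSlots; ++I)
      ::new (static_cast<void *>(slots() + I)) Slot(nullptr);
  }

  // The tail starts immediately after the header object.
  Slot *slots() { return reinterpret_cast<Slot *>(this + 1); }

  std::size_t NumSlots;
};

static_assert(alignof(SlotNode) >= alignof(SlotNode::Slot),
              "tail must be suitably aligned after the header");
```

The slot count is an ordinary runtime value here, which is the property the author wants to keep; the price is the manual size and alignment bookkeeping that TrailingObjects is meant to hide.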

NumRootBits(NumRootBits ? *NumRootBits : DefaultNumRootBits),
NumSubtrieBits(NumSubtrieBits ? *NumSubtrieBits : DefaultNumSubtrieBits),
ImplPtr(nullptr) {
assert((!NumRootBits || *NumRootBits < 20) &&
Contributor

Out of curiosity, why do we have this restriction?

Collaborator Author

I don't know. @dexonsmith, any insight?

I can see why it is not a reasonable choice, but I don't think anything bad would happen if you forced it (I don't think we will overflow, even on a 32-bit architecture). I removed it for now.

Collaborator

The point is to catch misuse, not logic errors, so it's okay for the assertion to be relaxed if there's another reasonable limit.

The limit of 20 bits ensures the root's allocation is under 8MB (1M entries). On some platforms, larger contiguous allocations might fail. An assertion could catch platform-dependent memory failures.

Certainly, I imagine a limit of 29 would be palatable? (500M entries, 4GB allocation for the root node)

In any case, it feels like someone is holding the trie "wrong" for root nodes this big. Going bigger is certainly valid in some sense, but I question the benefit of supporting it. If you really want to optimize for super large data sets, you probably want a file-backed trie instead.

I prefer the original limit of 20, unless someone has a specific argument for something else. (I don't feel strongly about what the limit is, just that there should be one.)
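The figures in this comment follow from a small calculation, sketched below under the assumption that each root slot is one 8-byte pointer (a 64-bit target). `rootSlotBytes` is a hypothetical helper, not part of the patch.

```cpp
#include <cassert>
#include <cstdint>

// Bytes needed for the root's slot array: 2^NumRootBits slots of SlotSize
// bytes each. SlotSize defaults to 8, an assumed 64-bit pointer size.
constexpr std::uint64_t rootSlotBytes(unsigned NumRootBits,
                                      std::uint64_t SlotSize = 8) {
  return (std::uint64_t(1) << NumRootBits) * SlotSize;
}

// 20 bits: about 1M entries, 8 MiB.
static_assert(rootSlotBytes(20) == 8ull * 1024 * 1024, "8 MiB at 20 bits");
// 29 bits: about 500M entries, 4 GiB.
static_assert(rootSlotBytes(29) == 4ull * 1024 * 1024 * 1024,
              "4 GiB at 29 bits");
```

Since the original assertion is `*NumRootBits < 20`, the largest accepted value is 19 bits, which is 4 MiB under the same assumption and therefore stays under the 8MB mentioned above.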

@cachemeifyoucan cachemeifyoucan force-pushed the eng/PR-trie-raw-hash-map branch 2 times, most recently from c96baad to 5088211 Compare October 30, 2023 18:28
Implement TrieRawHashMap which stores objects into a Trie based on the
hash of the object.

User needs to supply the hashing function and guarantees the uniqueness of
the hash for the objects to be inserted. Hash collision is not
supported
Contributor

@bzcheeseman bzcheeseman left a comment

Aside from the trailing object open discussion, LGTM. I'll circle back once a direction has been decided-upon there :)


const_pointer(PointerBase Result) : const_pointer::PointerImpl(Result) {}
};

class LazyValueConstructor {
Contributor

This could probably use a comment and/or usage example? Another "not super clear from inspection" case.
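One way to picture the pattern being asked about is a standalone toy version of lazy construction. As far as the header shows, `insertLazy` hands the callback a constructor object bound to already-allocated memory, and the callback decides how to build the value in place. The sketch below imitates that idea with a single-threaded `std::map`. `ToyMap` and `LazyConstructor` are hypothetical names; this is not the patch's API and has none of its thread-safety properties.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <new>
#include <string>
#include <utility>

// Hypothetical handle given to the callback; builds the value in place once.
template <class T> class LazyConstructor {
public:
  LazyConstructor(void *Mem, T *&Result) : Mem(Mem), Result(Result) {}

  template <class... ArgsT> T &emplace(ArgsT &&...Args) {
    assert(Mem && "value already constructed");
    Result = ::new (Mem) T(std::forward<ArgsT>(Args)...);
    Mem = nullptr; // Mark the storage as used.
    return *Result;
  }

private:
  void *Mem;
  T *&Result;
};

// Toy single-threaded map: storage is reserved first, the value is only
// constructed by the callback when the key is not present yet.
template <class T> class ToyMap {
  struct Slot {
    alignas(T) unsigned char Buffer[sizeof(T)];
    T *Value = nullptr;
    ~Slot() {
      if (Value)
        Value->~T();
    }
  };
  std::map<int, Slot> Slots;

public:
  T &insertLazy(int Key,
                const std::function<void(LazyConstructor<T>)> &OnConstruct) {
    auto It = Slots.find(Key);
    if (It != Slots.end())
      return *It->second.Value; // Already present: callback is not invoked.
    Slot &S = Slots[Key];
    OnConstruct(LazyConstructor<T>(S.Buffer, S.Value));
    assert(S.Value && "callback must construct the value");
    return *S.Value;
  }
};
```

The benefit of this shape is that an expensive value is never built when the key already exists, and the callback is free to choose between copy, move, or emplace-style construction.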
