Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][Support] A cache for cyclical replacers/maps #98202

Merged
merged 4 commits into from
Jul 12, 2024

Conversation

zyx-billy
Copy link
Contributor

@zyx-billy zyx-billy commented Jul 9, 2024

This is a support data structure that acts as a cache for replacer-like functions that map values between two domains. The difference compared to just using a map to cache in-out pairs is that this class is able to handle replacer logic that is self-recursive (and thus may cause infinite recursion in the naive case).

This class provides a hook for the user to perform cycle pruning when a cycle is identified, and is able to perform context-sensitive caching so that the replacement result for an input that is part of a pruned cycle can be distinct from the replacement result for the same input when it is not part of a cycle.

In addition, this class allows deferring cycle pruning until specific inputs are repeated. This is useful for cases where not all elements in a cycle can perform pruning. The user still must guarantee that at least one element in any given cycle can perform pruning. Even if not, an assertion will eventually be tripped instead of infinite recursion (the run-time is linearly bounded by the maximum cycle length of its input).


Users of this class are pushed in stacked PRs:

@zyx-billy zyx-billy marked this pull request as ready for review July 9, 2024 21:04
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jul 9, 2024
@zyx-billy zyx-billy requested a review from Mogball July 9, 2024 21:05
@llvmbot
Copy link
Collaborator

llvmbot commented Jul 9, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Billy Zhu (zyx-billy)

Changes

This is a support data structure that acts as a cache for replacer-like functions that map values between two domains. The difference compared to just using a map to cache in-out pairs is that this class is able to handle replacer logic that is self-recursive (and thus may cause infinite recursion in the naive case).

This class provides a hook for the user to perform cycle pruning when a cycle is identified, and is able to perform context-sensitive caching so that the replacement result for an input that is part of a pruned cycle can be distinct from the replacement result for the same input when it is not part of a cycle.

In addition, this class allows deferring cycle pruning until specific inputs are repeated. This is useful for cases where not all elements in a cycle can perform pruning. The user still must guarantee that at least one element in any given cycle can perform pruning. Even if not, an assertion will eventually be tripped instead of infinite recursion (the run-time is linearly bounded by the maximum cycle length of its input).


Users of this class are pushed in stacked PRs:


Patch is 25.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/98202.diff

3 Files Affected:

  • (added) mlir/include/mlir/Support/CyclicReplacerCache.h (+277)
  • (modified) mlir/unittests/Support/CMakeLists.txt (+1)
  • (added) mlir/unittests/Support/CyclicReplacerCacheTest.cpp (+474)
diff --git a/mlir/include/mlir/Support/CyclicReplacerCache.h b/mlir/include/mlir/Support/CyclicReplacerCache.h
new file mode 100644
index 0000000000000..9a703676fff11
--- /dev/null
+++ b/mlir/include/mlir/Support/CyclicReplacerCache.h
@@ -0,0 +1,277 @@
+//===- CyclicReplacerCache.h ------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains helper classes for caching replacer-like functions that
+// map values between two domains. They are able to handle replacer logic that
+// contains self-recursion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_CACHINGREPLACER_H
+#define MLIR_SUPPORT_CACHINGREPLACER_H
+
+#include "mlir/IR/Visitors.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
+#include <set>
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// CyclicReplacerCache
+//===----------------------------------------------------------------------===//
+
+/// A cache for replacer-like functions that map values between two domains. The
+/// difference compared to just using a map to cache in-out pairs is that this
+/// class is able to handle replacer logic that is self-recursive (and thus may
+/// cause infinite recursion in the naive case).
+///
+/// This class provides a hook for the user to perform cycle pruning when a
+/// cycle is identified, and is able to perform context-sensitive caching so
+/// that the replacement result for an input that is part of a pruned cycle can
+/// be distinct from the replacement result for the same input when it is not
+/// part of a cycle.
+///
+/// In addition, this class allows deferring cycle pruning until specific inputs
+/// are repeated. This is useful for cases where not all elements in a cycle can
+/// perform pruning. The user still must guarantee that at least one element in
+/// any given cycle can perform pruning. Even if not, an assertion will
+/// eventually be tripped instead of infinite recursion (the run-time is
+/// linearly bounded by the maximum cycle length of its input).
+template <typename InT, typename OutT>
+class CyclicReplacerCache {
+public:
+  /// User-provided replacement function & cycle-breaking functions.
+  /// The cycle-breaking function must not make any more recursive invocations
+  /// to this cached replacer.
+  using CycleBreakerFn = std::function<std::optional<OutT>(const InT &)>;
+
+  CyclicReplacerCache() = delete;
+  CyclicReplacerCache(CycleBreakerFn cycleBreaker)
+      : cycleBreaker(std::move(cycleBreaker)) {}
+
+  /// A possibly unresolved cache entry.
+  /// If unresolved, the entry must be resolved before it goes out of scope.
+  struct CacheEntry {
+  public:
+    ~CacheEntry() { assert(result && "unresovled cache entry"); }
+
+    /// Check whether this node was repeated during recursive replacements.
+    /// This only makes sense to be called after all recursive replacements are
+    /// completed and the current element has resurfaced to the top of the
+    /// replacement stack.
+    bool wasRepeated() const {
+      // If the top frame includes itself as a dependency, then it must have
+      // been repeated.
+      ReplacementFrame &currFrame = cache.replacementStack.back();
+      size_t currFrameIndex = cache.replacementStack.size() - 1;
+      return currFrame.dependentFrames.count(currFrameIndex);
+    }
+
+    /// Resolve an unresolved cache entry by providing the result to be stored
+    /// in the cache.
+    void resolve(OutT result) {
+      assert(!this->result && "cache entry already resolved");
+      this->result = result;
+      cache.finalizeReplacement(element, result);
+    }
+
+    /// Get the resolved result if one exists.
+    std::optional<OutT> get() { return result; }
+
+  private:
+    friend class CyclicReplacerCache;
+    CacheEntry() = delete;
+    CacheEntry(CyclicReplacerCache<InT, OutT> &cache, InT element,
+               std::optional<OutT> result = std::nullopt)
+        : cache(cache), element(element), result(result) {}
+
+    CyclicReplacerCache<InT, OutT> &cache;
+    InT element;
+    std::optional<OutT> result;
+  };
+
+  /// Lookup the cache for a pre-calculated replacement for `element`.
+  /// If one exists, a resolved CacheEntry will be returned. Otherwise, an
+  /// unresolved CacheEntry will be returned, and the caller must resolve it
+  /// with the calculated replacement so it can be registered in the cache for
+  /// future use.
+  /// Multiple unresolved CacheEntries may be retrieved. However, any unresolved
+  /// CacheEntries that are returned must be resolved in reverse order of
+  /// retrieval, i.e. the last retrieved CacheEntry must be resolved first, and
+  /// the first retrieved CacheEntry must be resolved last. This should be
+  /// natural when used as a stack / inside recursion.
+  CacheEntry lookupOrInit(const InT &element);
+
+private:
+  /// Register the replacement in the cache and update the replacementStack.
+  void finalizeReplacement(const InT &element, const OutT &result);
+
+  CycleBreakerFn cycleBreaker;
+  DenseMap<InT, OutT> standaloneCache;
+
+  struct DependentReplacement {
+    OutT replacement;
+    /// The highest replacement frame index that this cache entry is dependent
+    /// on.
+    size_t highestDependentFrame;
+  };
+  DenseMap<InT, DependentReplacement> dependentCache;
+
+  struct ReplacementFrame {
+    /// The set of elements that is only legal while under this current frame.
+    /// They need to be removed from the cache when this frame is popped off the
+    /// replacement stack.
+    DenseSet<InT> dependingReplacements;
+    /// The set of frame indices that this current frame's replacement is
+    /// dependent on, ordered from highest to lowest.
+    std::set<size_t, std::greater<size_t>> dependentFrames;
+  };
+  /// Every element currently in the progress of being replaced pushes a frame
+  /// onto this stack.
+  SmallVector<ReplacementFrame> replacementStack;
+  /// Maps from each input element to its indices on the replacement stack.
+  DenseMap<InT, SmallVector<size_t, 2>> cyclicElementFrame;
+  /// If set to true, we are currently asking an element to break a cycle. No
+  /// more recursive invocations is allowed while this is true (the replacement
+  /// stack can no longer grow).
+  bool resolvingCycle = false;
+};
+
+template <typename InT, typename OutT>
+typename CyclicReplacerCache<InT, OutT>::CacheEntry
+CyclicReplacerCache<InT, OutT>::lookupOrInit(const InT &element) {
+  assert(!resolvingCycle &&
+         "illegal recursive invocation while breaking cycle");
+
+  if (auto it = standaloneCache.find(element); it != standaloneCache.end())
+    return CacheEntry(*this, element, it->second);
+
+  if (auto it = dependentCache.find(element); it != dependentCache.end()) {
+    // pdate the current top frame (the element that invoked this current
+    // replacement) to include any dependencies the cache entry had.
+    ReplacementFrame &currFrame = replacementStack.back();
+    currFrame.dependentFrames.insert(it->second.highestDependentFrame);
+    return CacheEntry(*this, element, it->second.replacement);
+  }
+
+  auto [it, inserted] = cyclicElementFrame.try_emplace(element);
+  if (!inserted) {
+    // This is a repeat of a known element. Try to break cycle here.
+    resolvingCycle = true;
+    std::optional<OutT> result = cycleBreaker(element);
+    resolvingCycle = false;
+    if (result) {
+      // Cycle was broken.
+      size_t dependentFrame = it->second.back();
+      dependentCache[element] = {*result, dependentFrame};
+      ReplacementFrame &currFrame = replacementStack.back();
+      // If this is a repeat, there is no replacement frame to pop. Mark the top
+      // frame as being dependent on this element.
+      currFrame.dependentFrames.insert(dependentFrame);
+
+      return CacheEntry(*this, element, *result);
+    }
+
+    // Cycle could not be broken.
+    // A legal setup must ensure at least one element of each cycle can break
+    // cycles. Under this setup, each element can be seen at most twice before
+    // the cycle is broken. If we see an element more than twice, we know this
+    // is an illegal setup.
+    assert(it->second.size() <= 2 && "illegal 3rd repeat of input");
+  }
+
+  // Otherwise, either this is the first time we see this element, or this
+  // element could not break this cycle.
+  it->second.push_back(replacementStack.size());
+  replacementStack.emplace_back();
+
+  return CacheEntry(*this, element);
+}
+
+template <typename InT, typename OutT>
+void CyclicReplacerCache<InT, OutT>::finalizeReplacement(const InT &element,
+                                                         const OutT &result) {
+  ReplacementFrame &currFrame = replacementStack.back();
+  // With the conclusion of this replacement frame, the current element is no
+  // longer a dependent element.
+  currFrame.dependentFrames.erase(replacementStack.size() - 1);
+
+  auto prevLayerIter = ++replacementStack.rbegin();
+  if (prevLayerIter == replacementStack.rend()) {
+    // If this is the last frame, there should be zero dependents.
+    assert(currFrame.dependentFrames.empty() &&
+           "internal error: top-level dependent replacement");
+    // Cache standalone result.
+    standaloneCache[element] = result;
+  } else if (currFrame.dependentFrames.empty()) {
+    // Cache standalone result.
+    standaloneCache[element] = result;
+  } else {
+    // Cache dependent result.
+    size_t highestDependentFrame = *currFrame.dependentFrames.begin();
+    dependentCache[element] = {result, highestDependentFrame};
+
+    // Otherwise, the previous frame inherits the same dependent frames.
+    prevLayerIter->dependentFrames.insert(currFrame.dependentFrames.begin(),
+                                          currFrame.dependentFrames.end());
+
+    // Mark this current replacement as a depending replacement on the closest
+    // dependent frame.
+    replacementStack[highestDependentFrame].dependingReplacements.insert(
+        element);
+  }
+
+  // All depending replacements in the cache must be purged.
+  for (InT key : currFrame.dependingReplacements)
+    dependentCache.erase(key);
+
+  replacementStack.pop_back();
+  auto it = cyclicElementFrame.find(element);
+  it->second.pop_back();
+  if (it->second.empty())
+    cyclicElementFrame.erase(it);
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer
+//===----------------------------------------------------------------------===//
+
+/// A helper class for cases where the input/output types of the replacer
+/// function is identical to the types stored in the cache. This class wraps
+/// the user-provided replacer function, and can be used in place of the user
+/// function.
+template <typename InT, typename OutT>
+class CachedCyclicReplacer {
+public:
+  using ReplacerFn = std::function<OutT(const InT &)>;
+  using CycleBreakerFn =
+      typename CyclicReplacerCache<InT, OutT>::CycleBreakerFn;
+
+  CachedCyclicReplacer() = delete;
+  CachedCyclicReplacer(ReplacerFn replacer, CycleBreakerFn cycleBreaker)
+      : replacer(std::move(replacer)), cache(std::move(cycleBreaker)) {}
+
+  OutT operator()(const InT &element) {
+    auto cacheEntry = cache.lookupOrInit(element);
+    if (std::optional<OutT> result = cacheEntry.get())
+      return *result;
+
+    OutT result = replacer(element);
+    cacheEntry.resolve(result);
+    return result;
+  }
+
+private:
+  ReplacerFn replacer;
+  CyclicReplacerCache<InT, OutT> cache;
+};
+
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_CACHINGREPLACER_H
diff --git a/mlir/unittests/Support/CMakeLists.txt b/mlir/unittests/Support/CMakeLists.txt
index 1dbf072bcbbfd..ec79a1c640909 100644
--- a/mlir/unittests/Support/CMakeLists.txt
+++ b/mlir/unittests/Support/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_unittest(MLIRSupportTests
+  CyclicReplacerCacheTest.cpp
   IndentedOstreamTest.cpp
   StorageUniquerTest.cpp
 )
diff --git a/mlir/unittests/Support/CyclicReplacerCacheTest.cpp b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
new file mode 100644
index 0000000000000..a4a92dbe147d4
--- /dev/null
+++ b/mlir/unittests/Support/CyclicReplacerCacheTest.cpp
@@ -0,0 +1,474 @@
+//===- CyclicReplacerCacheTest.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/CyclicReplacerCache.h"
+#include "llvm/ADT/SetVector.h"
+#include "gmock/gmock.h"
+#include <map>
+#include <set>
+
+using namespace mlir;
+
+TEST(CachedCyclicReplacerTest, testNoRecursion) {
+  CachedCyclicReplacer<int, bool> replacer(
+      /*replacer=*/[](int n) { return static_cast<bool>(n); },
+      /*cycleBreaker=*/[](int n) { return std::nullopt; });
+
+  EXPECT_EQ(replacer(3), true);
+  EXPECT_EQ(replacer(0), false);
+}
+
+TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) {
+  // Replacer cycles through ints 0 -> 1 -> 2 -> 0 -> ...
+  CachedCyclicReplacer<int, int> replacer(
+      /*replacer=*/[&](int n) { return replacer((n + 1) % 3); },
+      /*cycleBreaker=*/[&](int n) { return -1; });
+
+  // Starting at 0.
+  EXPECT_EQ(replacer(0), -1);
+  // Starting at 2.
+  EXPECT_EQ(replacer(2), -1);
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer: ChainRecursion
+//===----------------------------------------------------------------------===//
+
+/// This set of tests uses a replacer function that maps ints into vectors of
+/// ints.
+///
+/// The replacement result for input `n` is the replacement result of `(n+1)%3`
+/// appended with an element `42`. Theoretically, this will produce an
+/// infinitely long vector. The cycle-breaker function prunes this infinite
+/// recursion in the replacer logic by returning an empty vector upon the first
+/// re-occurrence of an input value.
+class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
+public:
+  // N ==> (N+1) % 3
+  // This will create a chain of infinite length without recursion pruning.
+  CachedCyclicReplacerChainRecursionPruningTest()
+      : replacer(
+            [&](int n) {
+              ++invokeCount;
+              std::vector<int> result = replacer((n + 1) % 3);
+              result.push_back(42);
+              return result;
+            },
+            [&](int n) -> std::optional<std::vector<int>> {
+              return baseCase.value_or(n) == n
+                         ? std::make_optional(std::vector<int>{})
+                         : std::nullopt;
+            }) {}
+
+  std::vector<int> getChain(unsigned N) { return std::vector<int>(N, 42); };
+
+  CachedCyclicReplacer<int, std::vector<int>> replacer;
+  int invokeCount = 0;
+  std::optional<int> baseCase = std::nullopt;
+};
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) {
+  // Starting at 0. Cycle length is 3.
+  EXPECT_EQ(replacer(0), getChain(3));
+  EXPECT_EQ(invokeCount, 3);
+
+  // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
+  invokeCount = 0;
+  EXPECT_EQ(replacer(1), getChain(5));
+  EXPECT_EQ(invokeCount, 2);
+
+  // Starting at 2. Cycle length is 4. Entire result is cached.
+  invokeCount = 0;
+  EXPECT_EQ(replacer(2), getChain(4));
+  EXPECT_EQ(invokeCount, 0);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere1) {
+  // Starting at 1. Cycle length is 3.
+  EXPECT_EQ(replacer(1), getChain(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific0) {
+  baseCase = 0;
+
+  // Starting at 0. Cycle length is 3.
+  EXPECT_EQ(replacer(0), getChain(3));
+  EXPECT_EQ(invokeCount, 3);
+}
+
+TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) {
+  baseCase = 0;
+
+  // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
+  EXPECT_EQ(replacer(1), getChain(5));
+  EXPECT_EQ(invokeCount, 5);
+
+  // Starting at 0. Cycle length is 3. Entire result is cached.
+  invokeCount = 0;
+  EXPECT_EQ(replacer(0), getChain(3));
+  EXPECT_EQ(invokeCount, 0);
+}
+
+//===----------------------------------------------------------------------===//
+// CachedCyclicReplacer: GraphReplacement
+//===----------------------------------------------------------------------===//
+
+/// This set of tests uses a replacer function that maps from cyclic graphs to
+/// trees, pruning out cycles in the process.
+///
+/// It consists of two helper classes:
+/// - Graph
+///   - A directed graph where nodes are non-negative integers.
+/// - PrunedGraph
+///   - A Graph where edges that used to cause cycles are now represented with
+///     an indirection (a recursionId).
+class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
+public:
+  /// A directed graph where nodes are non-negative integers.
+  struct Graph {
+    using Node = int64_t;
+
+    /// Use ordered containers for deterministic output.
+    /// Nodes without outgoing edges are considered nonexistent.
+    std::map<Node, std::set<Node>> edges;
+
+    void addEdge(Node src, Node sink) { edges[src].insert(sink); }
+
+    bool isCyclic() const {
+      DenseSet<Node> visited;
+      for (Node root : llvm::make_first_range(edges)) {
+        if (visited.contains(root))
+          continue;
+
+        SetVector<Node> path;
+        SmallVector<Node> workstack;
+        workstack.push_back(root);
+        while (!workstack.empty()) {
+          Node curr = workstack.back();
+          workstack.pop_back();
+
+          if (curr < 0) {
+            // A negative node signals the end of processing all of this node's
+            // children. Remove self from path.
+            assert(path.back() == -curr && "internal inconsistency");
+            path.pop_back();
+            continue;
+          }
+
+          if (path.contains(curr))
+            return true;
+
+          visited.insert(curr);
+          auto edgesIter = edges.find(curr);
+          if (edgesIter == edges.end() || edgesIter->second.empty())
+            continue;
+
+          path.insert(curr);
+          // Push negative node to signify recursion return.
+          workstack.push_back(-curr);
+          workstack.insert(workstack.end(), edgesIter->second.begin(),
+                           edgesIter->second.end());
+        }
+      }
+      return false;
+    }
+
+    /// Deterministic output for testing.
+    std::string serialize() const {
+      std::ostringstream oss;
+      for (const auto &[src, neighbors] : edges) {
+        oss << src << ":";
+        for (Graph::Node neighbor : neighbors)
+          oss << " " << neighbor;
+        oss << "\n";
+      }
+      return oss.str();
+    }
+  };
+
+  /// A Graph where edges that used to cause cycles (back-edges) are now
+  /// represented with an indirection (a recursionId).
+  ///
+  /// In addition to each node having an integer ID, each node also tracks the
+  /// original integer ID it had in the original graph. This way for every
+  /// back-edge, we can represent it as pointing to a new instance of the
+  /// original node. Then we mark the original node and the new instance with
+  /// a new unique recursionId to indicate that they're supposed to be the same
+  /// node.
+  struct PrunedGraph {
+    using Node = Graph::Node;
+    struct NodeInfo {
+      Graph::Node originalId;
+      /// A negative recursive index means not recursive. Otherwise nodes with
+      /// the same originalId & recursionId are the same node in the original
+      /// graph.
+      int64_t recursionId;
+    };
+
+    /// Add a regular non-recursive-self node.
+    Node addNode(Graph::Node originalId, int64_t recursionIndex = -1) {
+      Node id = nextConnectionId++;
+      info[id] = {originalId, recursionIndex};
+      return id;
+    }
+    /// Add a recursive-self...
[truncated]

@zyx-billy zyx-billy requested a review from gysit July 9, 2024 21:05
Copy link
Contributor

@Mogball Mogball left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing. In general, CacledCyclicReplacer doesn't seem very friendly to non-trivial InT and OutT. Whether you want to improve the generics here or just add a disclaimed to the class is up to you

mlir/include/mlir/Support/CyclicReplacerCache.h Outdated Show resolved Hide resolved
void resolve(OutT result) {
assert(!this->result && "cache entry already resolved");
this->result = result;
cache.finalizeReplacement(element, result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you call this before this->result is set? That would allow you to std::move(result) into this->result.

mlir/include/mlir/Support/CyclicReplacerCache.h Outdated Show resolved Hide resolved
mlir/include/mlir/Support/CyclicReplacerCache.h Outdated Show resolved Hide resolved
/// infinitely long vector. The cycle-breaker function prunes this infinite
/// recursion in the replacer logic by returning an empty vector upon the first
/// re-occurrence of an input value.
class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
namespace {
class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {

Make sure these test classes are in the anonymous namespace

@zyx-billy
Copy link
Contributor Author

Amazing. In general, CacledCyclicReplacer doesn't seem very friendly to non-trivial InT and OutT. Whether you want to improve the generics here or just add a disclaimed to the class is up to you

Yeah I was debating how general I should make it given it's currently only used internally for trivial types (raw pointers in the library, ints in tests). Until we have more uses for it, perhaps it's better if I just doubled down on only supporting trivial types and add a disclaimer. At least for most IR concepts in mlir a pointer should be enough.

@Mogball
Copy link
Contributor

Mogball commented Jul 11, 2024

SGTM

@zyx-billy zyx-billy merged commit ec50f58 into main Jul 12, 2024
7 checks passed
@zyx-billy zyx-billy deleted the users/zyx-billy/mlir/cyclic_replacer_cache branch July 12, 2024 05:08
zyx-billy added a commit that referenced this pull request Jul 12, 2024
The current `AttrTypeReplacer` does not allow for custom handling of
replacer functions that may cause self-recursion. For example, the
replacement of one attr/type may depend on the replacement of another
attr/type (by calling into the replacer manually again), which in turn
may depend on the replacement of the original attr/type.

To enable this functionality, this PR broke out the original
AttrTypeReplacer into two parts:
- An uncached base version (`detail::AttrTypeReplacerBase`) that allows
registering replacer functions and has logic for invoking it on
attr/types & their sub-elements
- A cached version (`AttrTypeReplacer`) that provides the same caching
as the original one. This is still the one used everywhere and behavior
is unchanged.

On top of the uncached base version, a `CyclicAttrTypeReplacer` is
introduced that provides caching & cycle-handling for replacer logic
that is cyclic. Cycle-breaking & caching is provided by the
`CyclicReplacerCache` from
#98202.

Both concrete implementations of the uncached base version use CRTP to
avoid dynamic dispatch. The base class merely provides replacer
registration & invocation, and is not meant to be used, or otherwise
extended elsewhere.
zyx-billy added a commit that referenced this pull request Jul 12, 2024
)

Use the new CyclicReplacerCache from
#98202 to support importing of
recursive DITypes in LLVM dialect's DebugImporter. This helps simplify
the implementation, allows for separate testing of the cache infra
itself, and as a result we even got more efficient translations.
aaryanshukla pushed a commit to aaryanshukla/llvm-project that referenced this pull request Jul 14, 2024
This is a support data structure that acts as a cache for replacer-like
functions that map values between two domains. The difference compared
to just using a map to cache in-out pairs is that this class is able to
handle replacer logic that is self-recursive (and thus may cause
infinite recursion in the naive case).

This class provides a hook for the user to perform cycle pruning when a
cycle is identified, and is able to perform context-sensitive caching so
that the replacement result for an input that is part of a pruned cycle
can be distinct from the replacement result for the same input when it
is not part of a cycle.

In addition, this class allows deferring cycle pruning until specific
inputs are repeated. This is useful for cases where not all elements in
a cycle can perform pruning. The user still must guarantee that at least
one element in any given cycle can perform pruning. Even if not, an
assertion will eventually be tripped instead of infinite recursion (the
run-time is linearly bounded by the maximum cycle length of its input).
aaryanshukla pushed a commit to aaryanshukla/llvm-project that referenced this pull request Jul 14, 2024
The current `AttrTypeReplacer` does not allow for custom handling of
replacer functions that may cause self-recursion. For example, the
replacement of one attr/type may depend on the replacement of another
attr/type (by calling into the replacer manually again), which in turn
may depend on the replacement of the original attr/type.

To enable this functionality, this PR broke out the original
AttrTypeReplacer into two parts:
- An uncached base version (`detail::AttrTypeReplacerBase`) that allows
registering replacer functions and has logic for invoking it on
attr/types & their sub-elements
- A cached version (`AttrTypeReplacer`) that provides the same caching
as the original one. This is still the one used everywhere and behavior
is unchanged.

On top of the uncached base version, a `CyclicAttrTypeReplacer` is
introduced that provides caching & cycle-handling for replacer logic
that is cyclic. Cycle-breaking & caching is provided by the
`CyclicReplacerCache` from
llvm#98202.

Both concrete implementations of the uncached base version use CRTP to
avoid dynamic dispatch. The base class merely provides replacer
registration & invocation, and is not meant to be used, or otherwise
extended elsewhere.
aaryanshukla pushed a commit to aaryanshukla/llvm-project that referenced this pull request Jul 14, 2024
…m#98203)

Use the new CyclicReplacerCache from
llvm#98202 to support importing of
recursive DITypes in LLVM dialect's DebugImporter. This helps simplify
the implementation, allows for separate testing of the cache infra
itself, and as a result we even got more efficient translations.
@Hardcode84
Copy link
Contributor

Unittest seems to be broken under the MSVC

FAILED: tools/mlir/unittests/Support/CMakeFiles/MLIRSupportTests.dir/CyclicReplacerCacheTest.cpp.obj
C:\PROGRA~2\MICROS~2\2022\BUILDT~1\VC\Tools\MSVC\1435~1.322\bin\Hostx64\x64\cl.exe  /nologo /TP -DBUILD_EXAMPLES -DGTEST_HAS_RTTI=0 -DGTEST_NO_LLVM_SUPPORT=0 -DMLIR_INCLUDE_TESTS -DUNICODE -D_CRT_NONSTDC_NO_DEPRECATE -D_CRT_NONSTDC_NO_WARNINGS -D_CRT_SECURE_NO_DEPRECATE -D_CRT_SECURE_NO_WARNINGS -D_GLIBCXX_ASSERTIONS -D_HAS_EXCEPTIONS=0 -D_SCL_SECURE_NO_DEPRECATE -D_SCL_SECURE_NO_WARNINGS -D_UNICODE -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -D__STDC_LIMIT_MACROS -ID:\projs\llvm\llvm-build\tools\mlir\unittests\Support -ID:\projs\llvm\llvm-project\mlir\unittests\Support -ID:\projs\llvm\llvm-build\include -ID:\projs\llvm\llvm-project\llvm\include -ID:\projs\llvm\llvm-project\mlir\include -ID:\projs\llvm\llvm-build\tools\mlir\include -ID:\projs\llvm\llvm-project\third-party\unittest\googletest\include -ID:\projs\llvm\llvm-project\third-party\unittest\googlemock\include /DWIN32 /D_WINDOWS   /Zc:inline /Zc:preprocessor /Zc:__cplusplus /Oi /bigobj /permissive- /W4 -wd4141 -wd4146 -wd4244 -wd4267 -wd4291 -wd4351 -wd4456 -wd4457 -wd4458 -wd4459 -wd4503 -wd4624 -wd4722 -wd4100 -wd4127 -wd4512 -wd4505 -wd4610 -wd4510 -wd4702 -wd4245 -wd4706 -wd4310 -wd4701 -wd4703 -wd4389 -wd4611 -wd4805 -wd4204 -wd4577 -wd4091 -wd4592 -wd4319 -wd4709 -wd5105 -wd4324 -w14062 -we4238 /Gw /O2 /Ob2  -std:c++17 -MD  /EHs-c- /GR- -UNDEBUG /showIncludes /Fotools\mlir\unittests\Support\CMakeFiles\MLIRSupportTests.dir\CyclicReplacerCacheTest.cpp.obj /Fdtools\mlir\unittests\Support\CMakeFiles\MLIRSupportTests.dir\ /FS -c D:\projs\llvm\llvm-project\mlir\unittests\Support\CyclicReplacerCacheTest.cpp
D:\projs\llvm\llvm-project\mlir\unittests\Support\CyclicReplacerCacheTest.cpp(30): error C2326: 'auto CachedCyclicReplacerTest_testInPlaceRecursionPruneAnywhere_Test::TestBody::<lambda_1>::operator ()(int) const': function cannot access 'replacer'
Microsoft (R) C/C++ Optimizing Compiler Version 19.35.32216.1 for x64

CI passed, but looks like it uses older version

-- The C compiler identification is MSVC 19.29.30154.0
-- The CXX compiler identification is MSVC 19.29.30154.0

@zyx-billy
Copy link
Contributor Author

@Hardcode84 thanks for the heads up. Looks like this might've fixed it? #101021

@Hardcode84
Copy link
Contributor

Yes, it fixed it. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants