From 54fc80db22c38b9c3e950ce8443c64a7ecefcf02 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 12 Mar 2024 10:40:24 -0700 Subject: [PATCH] refactor get1DExtractionIndex into util file --- include/Dialect/BUILD | 16 ++++++++ lib/Dialect/BUILD | 12 ++++++ lib/Dialect/TensorExt/Transforms/BUILD | 1 + .../Transforms/CollapseInsertionChains.cpp | 20 +--------- lib/Dialect/Utils.h | 37 +++++++++++++++++++ 5 files changed, 67 insertions(+), 19 deletions(-) create mode 100644 lib/Dialect/Utils.h diff --git a/include/Dialect/BUILD b/include/Dialect/BUILD index bebb71ecf..3eb3ce231 100644 --- a/include/Dialect/BUILD +++ b/include/Dialect/BUILD @@ -10,6 +10,7 @@ package( exports_files( [ "HEIRInterfaces.h", + "Utils.h", ], ) @@ -43,3 +44,18 @@ gentbl_cc_library( ":td_files", ], ) + +cc_library( + name = "Utils", + srcs = [ + "Utils.h", + ], + hdrs = [ + "Utils.h", + ], + deps = [ + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/lib/Dialect/BUILD b/lib/Dialect/BUILD index fb5356f75..9d031b48e 100644 --- a/lib/Dialect/BUILD +++ b/lib/Dialect/BUILD @@ -18,3 +18,15 @@ cc_library( "@llvm-project//mlir:IR", ], ) + +cc_library( + name = "Utils", + srcs = [ + "Utils.h", + ], + deps = [ + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/lib/Dialect/TensorExt/Transforms/BUILD b/lib/Dialect/TensorExt/Transforms/BUILD index f06b7becd..0e9f18828 100644 --- a/lib/Dialect/TensorExt/Transforms/BUILD +++ b/lib/Dialect/TensorExt/Transforms/BUILD @@ -44,6 +44,7 @@ cc_library( ], deps = [ "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen", + "@heir//lib/Dialect:Utils", "@heir//lib/Dialect/TensorExt/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", diff --git a/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp b/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp index 63c2a9983..f64337780 100644 --- a/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp +++ b/lib/Dialect/TensorExt/Transforms/CollapseInsertionChains.cpp @@ -4,6 +4,7 @@ #include #include "include/Dialect/TensorExt/IR/TensorExtOps.h" +#include "lib/Dialect/Utils.h" #include "llvm/include/llvm/Support/Casting.h" // from @llvm-project #include "llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project @@ -27,25 +28,6 @@ namespace tensor_ext { #define GEN_PASS_DEF_COLLAPSEINSERTIONCHAINS #include "include/Dialect/TensorExt/Transforms/Passes.h.inc" -template -FailureOr get1DExtractionIndex(Op op) { - auto insertIndices = op.getIndices(); - if (insertIndices.size() != 1) return failure(); - - // Each index must be constant; this may require running --canonicalize or - // -sccp before this pass to apply folding rules (use -sccp if you need to - // fold constants through control flow). - Value insertIndex = *insertIndices.begin(); - auto insertIndexConstOp = insertIndex.getDefiningOp(); - if (!insertIndexConstOp) return failure(); - - auto insertOffsetAttr = - llvm::dyn_cast(insertIndexConstOp.getValue()); - if (!insertOffsetAttr) return failure(); - - return insertOffsetAttr.getInt(); -} - /// A pattern that searches for sequences of extract + insert, where the /// indices extracted and inserted have the same offset, and replaced them with /// a single rotate operation. diff --git a/lib/Dialect/Utils.h b/lib/Dialect/Utils.h new file mode 100644 index 000000000..4c25a5e04 --- /dev/null +++ b/lib/Dialect/Utils.h @@ -0,0 +1,37 @@ +#ifndef INCLUDE_DIALECT_UTILS_H_ +#define INCLUDE_DIALECT_UTILS_H_ + +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace heir { + +/// Given a tensor::InsertOp or tensor::ExtractOp, and assuming the shape +/// of the input tensor is 1-dimensional and the input index is constant, +/// return the constant index value. If any of these conditions are not +/// met, return a failure. +template +FailureOr get1DExtractionIndex(Op op) { + auto insertIndices = op.getIndices(); + if (insertIndices.size() != 1) return failure(); + + // Each index must be constant; this may require running --canonicalize or + // -sccp before this pass to apply folding rules (use -sccp if you need to + // fold constants through control flow). + Value insertIndex = *insertIndices.begin(); + auto insertIndexConstOp = insertIndex.getDefiningOp(); + if (!insertIndexConstOp) return failure(); + + auto insertOffsetAttr = + llvm::dyn_cast(insertIndexConstOp.getValue()); + if (!insertOffsetAttr) return failure(); + + return insertOffsetAttr.getInt(); +} + +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_DIALECT_UTILS_H_