diff --git a/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h b/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h
index cc7487a98..e659b9947 100644
--- a/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h
+++ b/include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h
@@ -111,9 +111,16 @@ class TargetSlotLattice : public dataflow::Lattice<TargetSlot> {
 class TargetSlotAnalysis
     : public dataflow::SparseBackwardDataFlowAnalysis<TargetSlotLattice> {
  public:
-  explicit TargetSlotAnalysis(DataFlowSolver &solver,
-                              SymbolTableCollection &symbolTable)
-      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+  explicit TargetSlotAnalysis(
+      DataFlowSolver &solver, SymbolTableCollection &symbolTable,
+      // The dataflow solver is a private member of the base analysis
+      // class, so if we want to access it we have to get it explicitly from
+      // the caller. It's required that this solver is pre-loaded with a
+      // SparseConstantPropagation analysis. I'd like a better way to do
+      // this: maybe pass a callback?
+      const DataFlowSolver *sccpAnalysis)
+      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
+        sccpAnalysis(sccpAnalysis) {}
   ~TargetSlotAnalysis() override = default;
 
   using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
@@ -125,6 +132,9 @@ class TargetSlotAnalysis
   void visitBranchOperand(OpOperand &operand) override {};
   void visitCallOperand(OpOperand &operand) override {};
   void setToExitState(TargetSlotLattice *lattice) override {};
+
+ private:
+  const DataFlowSolver *sccpAnalysis;
 };
 
 }  // namespace target_slot_analysis
diff --git a/lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp b/lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp
index 3fe126af1..d0f44a833 100644
--- a/lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp
+++ b/lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp
@@ -1,8 +1,10 @@
 #include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h"
 
 #include "lib/Dialect/Utils.h"
-#include "llvm/include/llvm/ADT/TypeSwitch.h"    // from @llvm-project
-#include "llvm/include/llvm/Support/Debug.h"     // from @llvm-project
+#include "llvm/include/llvm/ADT/TypeSwitch.h"  // from @llvm-project
+#include "llvm/include/llvm/Support/Debug.h"   // from @llvm-project
+#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"  // from @llvm-project
+#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h"  // from @llvm-project
 #include "mlir/include/mlir/Analysis/DataFlowFramework.h"  // from @llvm-project
 #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"    // from @llvm-project
 #include "mlir/include/mlir/IR/Operation.h"                // from @llvm-project
@@ -22,15 +24,47 @@ void TargetSlotAnalysis::visitOperation(
   llvm::TypeSwitch<Operation &>(*op)
       .Case<tensor::InsertOp>([&](auto insertOp) {
         LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; });
-        auto insertIndexRes = get1DExtractionIndex(insertOp);
+        auto insertIndices = insertOp.getIndices();
+        if (insertIndices.size() != 1) {
+          LLVM_DEBUG(llvm::dbgs() << "At " << insertOp
+                                  << " can't handle >1D insertion index\n");
+          return;
+        }
+
+        Value insertIndexValue = insertOp.getIndices()[0];
+        const dataflow::Lattice<dataflow::ConstantValue> *insertIndexLattice =
+            sccpAnalysis
+                ->lookupState<dataflow::Lattice<dataflow::ConstantValue>>(
+                    insertIndexValue);
+
+        if (insertIndexLattice) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "At " << insertOp << " SCCP analysis gives lattice of "
+                     << *insertIndexLattice << "\n");
+        }
+
         // If the target slot can't be statically determined, we can't
         // propagate anything through the IR.
-        if (failed(insertIndexRes)) return;
+        if (!insertIndexLattice ||
+            insertIndexLattice->getValue().isUninitialized() ||
+            !insertIndexLattice->getValue().getConstantValue()) {
+          LLVM_DEBUG(
+              llvm::dbgs()
+              << "At " << insertOp
+              << " can't statically determine constant insertion index\n");
+          return;
+        }
 
+        Attribute insertIndexAttr =
+            insertIndexLattice->getValue().getConstantValue();
+        auto insertIndexIntAttr = insertIndexAttr.dyn_cast<IntegerAttr>();
+        assert(insertIndexIntAttr &&
+               "If 1D insertion index is constant, it must be integer");
+        int64_t insertIndexConst = insertIndexIntAttr.getInt();
+
         // The target slot propagates to the value inserted, which is the first
         // positional argument
         TargetSlotLattice *lattice = operands[0];
-        TargetSlot newSlot = TargetSlot{insertIndexRes.value()};
+        TargetSlot newSlot = TargetSlot{insertIndexConst};
         LLVM_DEBUG({
           llvm::dbgs() << "Joining " << lattice->getValue() << " and "
                        << newSlot << " --> "
diff --git a/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp b/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp
index 816d6b5b5..0768dc9ec 100644
--- a/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp
+++ b/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp
@@ -42,27 +42,47 @@ struct InsertRotate : impl::InsertRotateBase<InsertRotate> {
 
     SymbolTableCollection symbolTable;
     DataFlowSolver solver;
-    // These two upstream analyses are required dependencies for any sparse
-    // dataflow analysis, or else the analysis will be a no-op. Cf.
+
+    // These two upstream analyses are required to be instantiated in any
+    // sparse dataflow analysis, or else the analysis will be a no-op. Cf.
     // https://github.com/llvm/llvm-project/issues/58922
     solver.load<dataflow::DeadCodeAnalysis>();
     solver.load<dataflow::SparseConstantPropagation>();
-    solver.load<target_slot_analysis::TargetSlotAnalysis>(symbolTable);
     if (failed(solver.initializeAndRun(getOperation()))) {
       getOperation()->emitOpError() << "Failed to run the analysis.\n";
       signalPassFailure();
       return;
     }
 
+    // We want to use the result of the sparse constant propagation from the
+    // first dataflow solver as an input to the target slot analysis. For some
+    // reason, actually running `--sccp` before this pass causes the IR to
+    // simplify away some operations that are needed to properly identify
+    // target slots. So the SparseConstantPropagation above is a simulated
+    // folding of arith operations, so as to identify when insertion indices
+    // are statically inferable.
+    //
+    // TODO(#572): find a better way to depend dataflow analyses on each other.
+    DataFlowSolver solver2;
+    solver2.load<dataflow::DeadCodeAnalysis>();
+    solver2.load<dataflow::SparseConstantPropagation>();
+    solver2.load<target_slot_analysis::TargetSlotAnalysis>(symbolTable,
+                                                           &solver);
+    if (failed(solver2.initializeAndRun(getOperation()))) {
+      getOperation()->emitOpError() << "Failed to run the analysis.\n";
+      signalPassFailure();
+      return;
+    }
+
     // Annotate all arith ops with their target slot attribute, so that it can
     // be matched in the DRR rules.
     OpBuilder builder(context);
     getOperation()->walk([&](Operation *op) {
       if (op->getNumResults() == 0) return;
       auto *targetSlotLattice =
-          solver.lookupState<target_slot_analysis::TargetSlotLattice>(
+          solver2.lookupState<target_slot_analysis::TargetSlotLattice>(
               op->getResult(0));
-      if (targetSlotLattice->getValue().isInitialized()) {
+      if (targetSlotLattice && targetSlotLattice->getValue().isInitialized()) {
         op->setAttr(
             "target_slot",
             builder.getIndexAttr(targetSlotLattice->getValue().getValue()));
diff --git a/tests/heir_simd_vectorizer/box_blur_64x64.mlir b/tests/heir_simd_vectorizer/box_blur_64x64.mlir
index 8e6fb5d57..0ef7ddf8b 100644
--- a/tests/heir_simd_vectorizer/box_blur_64x64.mlir
+++ b/tests/heir_simd_vectorizer/box_blur_64x64.mlir
@@ -29,6 +29,7 @@ module {
   // CHECK-NEXT:     secret.yield %[[v15]]
   // CHECK-NEXT:   } -> !secret.secret<tensor<4096xi16>>
   // CHECK-NEXT:   return %[[v0]]
+
   func.func @box_blur(%arg0: tensor<4096xi16>) -> tensor<4096xi16> {
     %c4096 = arith.constant 4096 : index
     %c64 = arith.constant 64 : index
diff --git a/tests/heir_simd_vectorizer/roberts_cross_4x4.mlir b/tests/heir_simd_vectorizer/roberts_cross_4x4.mlir
new file mode 100644
index 000000000..f2f855fab
--- /dev/null
+++ b/tests/heir_simd_vectorizer/roberts_cross_4x4.mlir
@@ -0,0 +1,82 @@
+// Ported from https://github.com/MarbleHE/HECO/blob/3e13744233ab0c09030a41ef98b4e061b6fa2eac/evaluation/benchmark/heco_input/robertscross_4x4.mlir
+
+// RUN: heir-opt --secretize=entry-function=roberts_cross --wrap-generic --canonicalize --cse \
+// RUN:   --heir-simd-vectorizer %s | FileCheck %s
+
+module{
+  // CHECK-LABEL: @roberts_cross
+  // CHECK-SAME: (%[[arg0:.*]]: !secret.secret<tensor<16xi16>>) -> !secret.secret<tensor<16xi16>> {
+  // CHECK-NEXT:   %[[c15:.*]] = arith.constant 15 : index
+  // CHECK-NEXT:   %[[c11:.*]] = arith.constant 11 : index
+  // CHECK-NEXT:   secret.generic ins(%[[arg0]] : !secret.secret<tensor<16xi16>>) {
+  // CHECK-NEXT:   ^bb0(%[[arg1:.*]]: tensor<16xi16>):
+  // CHECK-NEXT:     %[[v1:.*]] = tensor_ext.rotate %[[arg1]], %[[c11]]
+  // CHECK-NEXT:     %[[v2:.*]] = arith.subi %[[v1]], %[[arg1]]
+  // CHECK-NEXT:     %[[v3:.*]] = tensor_ext.rotate %[[arg1]], %[[c15]]
+  // CHECK-NEXT:     %[[v4:.*]] = arith.subi %[[v1]], %[[v3]]
+  // CHECK-NEXT:     %[[v5:.*]] = arith.muli %[[v2]], %[[v2]]
+  // CHECK-NEXT:     %[[v6:.*]] = arith.muli %[[v4]], %[[v4]]
+  // CHECK-NEXT:     %[[v7:.*]] = arith.addi %[[v5]], %[[v6]]
+  func.func @roberts_cross(%img: tensor<16xi16>) -> tensor<16xi16> {
+    %c16 = arith.constant 16 : index
+    %c4 = arith.constant 4 : index
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
+    %c-1 = arith.constant -1 : index
+
+    // Each point p = img[x][y], where x is row and y is column, in the new image will equal:
+    //   (img[x-1][y-1] - img[x][y])^2 + (img[x-1][y] - img[x][y-1])^2
+    %r = affine.for %x = 0 to 4 iter_args(%imgx = %img) -> tensor<16xi16> {
+      %1 = affine.for %y = 0 to 4 iter_args(%imgy = %imgx) -> tensor<16xi16> {
+
+        // fetch img[x-1][y-1]
+        %4 = arith.addi %x, %c-1 : index
+        %5 = arith.muli %4, %c4 : index
+        %6 = arith.addi %y, %c-1 : index
+        %7 = arith.addi %5, %6 : index
+        %8 = arith.remui %7, %c16 : index
+        %9 = tensor.extract %img[%8] : tensor<16xi16>
+
+        // fetch img[x][y]
+        %10 = arith.muli %x, %c4 : index
+        %11 = arith.addi %10, %y : index
+        %12 = arith.remui %11, %c16 : index
+        %13 = tensor.extract %img[%12] : tensor<16xi16>
+
+        // subtract those two
+        %14 = arith.subi %9, %13 : i16
+
+        // fetch img[x-1][y]
+        %15 = arith.addi %x, %c-1 : index
+        %16 = arith.muli %15, %c4 : index
+        %17 = arith.addi %y, %c-1 : index
+        %18 = arith.addi %16, %17 : index
+        %19 = arith.remui %18, %c16 : index
+        %20 = tensor.extract %img[%19] : tensor<16xi16>
+
+        // fetch img[x][y-1]
+        %21 = arith.muli %x, %c4 : index
+        %22 = arith.addi %y, %c-1 : index
+        %23 = arith.addi %21, %22 : index
+        %24 = arith.remui %23, %c16 : index
+        %25 = tensor.extract %img[%24] : tensor<16xi16>
+
+        // subtract those two
+        %26 = arith.subi %20, %25 : i16
+
+        // square each difference
+        %27 = arith.muli %14, %14 : i16
+        %28 = arith.muli %26, %26 : i16
+
+        // add the squares
+        %29 = arith.addi %27, %28 : i16
+
+        // save to result[x][y]
+        %30 = tensor.insert %29 into %imgy[%12] : tensor<16xi16>
+        affine.yield %30: tensor<16xi16>
+      }
+      affine.yield %1 : tensor<16xi16>
+    }
+    return %r : tensor<16xi16>
+  }
+}
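
Not part of the patch: the core mechanism above is running SparseConstantPropagation in its own DataFlowSolver and then reading that solver's lattice for the insertion index, so TargetSlotAnalysis only propagates a slot when the index folds to a constant. A minimal standalone sketch of that lookup pattern follows; it uses upstream MLIR include paths rather than HEIR's "mlir/include/mlir/..." prefixes, and lookupConstantIndex is a hypothetical helper name, not an API introduced by this patch.

// Sketch, assuming upstream MLIR headers: run SparseConstantPropagation over
// `root` and report the integer value of `index` if it is statically known.
#include <cstdint>
#include <optional>

#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/BuiltinAttributes.h"

using namespace mlir;

std::optional<int64_t> lookupConstantIndex(Operation *root, Value index) {
  DataFlowSolver solver;
  // Both analyses must be loaded or any sparse analysis silently becomes a
  // no-op (see https://github.com/llvm/llvm-project/issues/58922).
  solver.load<dataflow::DeadCodeAnalysis>();
  solver.load<dataflow::SparseConstantPropagation>();
  if (failed(solver.initializeAndRun(root))) return std::nullopt;

  // The same lookup the patch performs inside TargetSlotAnalysis.
  const auto *lattice =
      solver.lookupState<dataflow::Lattice<dataflow::ConstantValue>>(index);
  if (!lattice || lattice->getValue().isUninitialized() ||
      !lattice->getValue().getConstantValue())
    return std::nullopt;

  auto intAttr = dyn_cast<IntegerAttr>(lattice->getValue().getConstantValue());
  if (!intAttr) return std::nullopt;
  return intAttr.getInt();
}

The patch keeps this SCCP run in a separate solver and passes a pointer to it into TargetSlotAnalysis because, as the new constructor comment notes, the base analysis class does not expose its own solver; TODO(#572) tracks a cleaner way to express that dependency.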
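
Also not part of the patch: the constants 11 and 15 expected by the roberts_cross FileCheck lines are consistent with the test's index arithmetic once the 4x4 image is flattened row-major into 16 slots: img[x-1][y-1] sits -5 slots from img[x][y] and img[x][y-1] sits -1 slot away, and -5 and -1 taken modulo 16 give 11 and 15. The exact rotation direction is a property of tensor_ext.rotate; the snippet below only checks the modular arithmetic and is purely illustrative.

// Illustrative arithmetic check of the rotation amounts in roberts_cross_4x4.
// Wraps a (possibly negative) slot offset into the range [0, n).
constexpr int wrap(int offset, int n) { return ((offset % n) + n) % n; }

int main() {
  constexpr int cols = 4, slots = 16;
  // img[x-1][y-1] is (-cols - 1) = -5 slots from img[x][y].
  static_assert(wrap(-cols - 1, slots) == 11, "matches %c11 in the CHECK lines");
  // img[x][y-1] is -1 slot from img[x][y].
  static_assert(wrap(-1, slots) == 15, "matches %c15 in the CHECK lines");
  return 0;
}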