From acee3ad542c27417ed2e7f8c7d528815067d6ab9 Mon Sep 17 00:00:00 2001
From: Jeremy Kun
Date: Wed, 27 Mar 2024 19:44:21 -0700
Subject: [PATCH] Rewrite rotation analysis for new ports of HECO eval artifacts

This rewrite still has some efficiency issues, but it removes the hacks
from the previous lattice-based dataflow analysis and supports the
examples ported in this PR, which the previous analysis could not handle.

Ports:
- dot_product
- linear_polynomial
- quadratic_polynomial
---
 .../RotationAnalysis/RotationAnalysis.h       | 328 ++++++++++--------
 .../RotationAnalysis/RotationAnalysis.cpp     | 150 +++++---
 .../SelectVariableNames.cpp                   |   2 -
 .../TensorExt/Transforms/RotateAndReduce.cpp  | 172 ++-------
 tests/heir_simd_vectorizer/dot_product_8.mlir |  18 +
 .../hamming_distance.mlir                     |   6 -
 .../linear_polynomial_64.mlir                 |  22 ++
 .../quadratic_polynomial.mlir                 |  25 ++
 .../secret_to_bgv/hamming_distance_1024.mlir  |   4 +-
 tests/tensor_ext/rotate_and_reduce.mlir       |  24 ++
 10 files changed, 403 insertions(+), 348 deletions(-)
 create mode 100644 tests/heir_simd_vectorizer/dot_product_8.mlir
 create mode 100644 tests/heir_simd_vectorizer/linear_polynomial_64.mlir
 create mode 100644 tests/heir_simd_vectorizer/quadratic_polynomial.mlir

diff --git a/include/Analysis/RotationAnalysis/RotationAnalysis.h b/include/Analysis/RotationAnalysis/RotationAnalysis.h
index a1b2f2313..0ddc6348b 100644
--- a/include/Analysis/RotationAnalysis/RotationAnalysis.h
+++ b/include/Analysis/RotationAnalysis/RotationAnalysis.h
@@ -1,14 +1,15 @@
 #ifndef INCLUDE_ANALYSIS_ROTATIONANALYSIS_ROTATIONANALYSIS_H_
 #define INCLUDE_ANALYSIS_ROTATIONANALYSIS_ROTATIONANALYSIS_H_
 
-#include <unordered_set>
+#include <set>
 
-#include "llvm/include/llvm/Support/Debug.h"  // from @llvm-project
-#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h"  // from @llvm-project
-#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
-#include "mlir/include/mlir/IR/Diagnostics.h"  // from @llvm-project
-#include "mlir/include/mlir/IR/Operation.h"  // from @llvm-project
-#include "mlir/include/mlir/IR/Value.h"  // from @llvm-project
+#include "llvm/include/llvm/Support/Debug.h"  // from @llvm-project
+#include "mlir/include/mlir/Analysis/DataFlowFramework.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Diagnostics.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/OperationSupport.h"  // from @llvm-project
+#include "mlir/include/mlir/IR/Value.h"  // from @llvm-project
 
 #define DEBUG_TYPE "rotation-analysis"
 
@@ -16,200 +17,249 @@ namespace mlir {
 namespace heir {
 namespace rotation_analysis {
 
-// A wrapper around a mapping from a single tensor SSA value to a set of its
-// access indices.
-class RotationSets {
+// A PartialReduction represents a subset of an arithmetic op tree that reduces
+// values within a tensor to a scalar (present in index zero of the result
+// tensor).
+//
+// It is "partial" in the sense that it may not reduce across all elements of a
+// tensor, and it is used in the analysis to accumulate reduced tensor indices
+// across the IR.
+//
+// It also stores a reference to the SSA value that identifies the "end" of the
+// computation (i.e., the SSA value that contains the result of the reduction).
+class PartialReduction {
  public:
-  enum class Status {
-    // The tensor value has not been set
-    Uninitialized,
-
-    // The rotation set is in a normal state.
-    Normal,
-
-    // The rotation set has a property that makes it invalid for later
-    // optimizations:
-    //
-    //  - It involves operations touch more than one source tensor (not
-    //    including value-semantic outputs)
-    Overdetermined
-
-  };
-
- public:
-  RotationSets() = default;
-  ~RotationSets() = default;
-
-  // Clear the member data, i.e., set the value back to an uninitialized
-  // state.
-  void clear() {
-    accessedIndices.clear();
-    status = Status::Uninitialized;
-  }
-  bool empty() const { return accessedIndices.empty(); }
-  bool isOverdetermined() const { return status == Status::Overdetermined; }
-
-  bool isUninitialized() const { return status == Status::Uninitialized; }
-  void addRotation(int64_t index) { accessedIndices.insert(index); }
-  bool operator==(const RotationSets &rhs) const {
-    return tensor == rhs.tensor && status == rhs.status &&
-           accessedIndices == rhs.accessedIndices;
+  // Returns true if the accessed indices constitute all indices of the reduced
+  // tensor.
+  bool isComplete() const {
+    auto tensorType = tensor.getType().dyn_cast<RankedTensorType>();
+    assert(tensorType &&
+           "Internal state of RotationAnalysis is broken; tensor must have a "
+           "ranked tensor type");
+
+    // std::set is ordered, so min/max is first/last element of the set
+    int64_t minIndex = *accessedIndices.begin();
+    int64_t maxIndex = *accessedIndices.rbegin();
+    return minIndex == 0 && maxIndex == tensorType.getShape()[0] - 1 &&
+           accessedIndices.size() == tensorType.getShape()[0];
   }
 
-  const std::unordered_set<int64_t> &getAccessedIndices() const {
+  const std::set<int64_t> &getAccessedIndices() const {
    return accessedIndices;
  }

  Value getTensor() const { return tensor; }

+  Value getRoot() const { return root; }
+
  void print(raw_ostream &os) const {
-    os << tensor << ": [";
+    os << "{ opName: " << (opName.has_value() ? opName->getStringRef() : "None")
+       << "; " << " tensor: " << tensor << "; " << "rotations: [";
    for (auto index : accessedIndices) {
      os << index << ", ";
    }
-    os << "]";
+    os << "]; root: " << root << "; }";
  }

-  static RotationSets overdetermined() {
-    RotationSets sets;
-    sets.status = Status::Overdetermined;
-    return sets;
-  }
-
-  static RotationSets from(Value tensor) {
-    RotationSets sets;
-    if (!tensor.getType().isa<RankedTensorType>()) {
-      sets.status = Status::Uninitialized;
-      return sets;
-    }
-
-    sets.status = Status::Normal;
-    sets.tensor = tensor;
-    if (auto blockArg = dyn_cast<BlockArgument>(tensor)) {
-      sets.addRotation(0);
-    }
-    return sets;
+  // Construct a "leaf" of a reduction, i.e., a PartialReduction that represents
+  // no operations applied to a starting tensor SSA value.
+  static PartialReduction initializeFromValue(Value tensor) {
+    PartialReduction reduction;
+    reduction.tensor = tensor;
+    reduction.root = tensor;
+    reduction.opName = std::nullopt;
+    // In the FHE world, the only extractible element (without a rotation) of a
+    // packed ciphertext is the constant term, i.e., the first element of the
+    // tensor. So a tensor by itself is always considered a reduction by that
+    // first element.
+    reduction.addRotation(0);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Initializing at " << tensor << " with rotations [0]\n");
+    return reduction;
  }

  // Shift the rotation indices by the given amount. This helps in a situation
  // where an IR repeatedly rotates by 1, to ensure that rotations accumulate
  // like {1, 2, 3, ...} rather than {1, 1, 1, ...}
-  static RotationSets rotate(const RotationSets &lhs, const int64_t shift) {
-    if (lhs.status == Status::Overdetermined) {
-      return overdetermined();
-    }
-
-    RotationSets shifted;
-    shifted.status = Status::Normal;
+  static PartialReduction rotate(const PartialReduction &lhs,
+                                 const int64_t shift, Value result) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "Rotating\n\t";
+      lhs.print(llvm::dbgs());
+      llvm::dbgs() << " by " << shift;
+    });
+    PartialReduction shifted;
    shifted.tensor = lhs.tensor;
+    shifted.opName = lhs.opName;
+    shifted.root = result;
    int64_t size =
        llvm::cast<RankedTensorType>(lhs.tensor.getType()).getShape()[0];
+    assert(!lhs.accessedIndices.empty() &&
+           "Internal state of RotationAnalysis is broken; empty rotation sets "
+           "should be impossible");
    for (auto index : lhs.accessedIndices) {
      shifted.addRotation((index + shift) % size);
    }
+    LLVM_DEBUG({
+      llvm::dbgs() << " to\n\t";
+      shifted.print(llvm::dbgs());
+      llvm::dbgs() << "\n";
+    });
    return shifted;
  }

-  static RotationSets join(const RotationSets &lhs, const RotationSets &rhs) {
-    if (lhs.status == Status::Overdetermined ||
-        rhs.status == Status::Overdetermined) {
-      return overdetermined();
+  // Determine if two PartialReductions are legal to join at an op whose
+  // OperationName is given.
+  static bool canJoin(const PartialReduction &lhs, const PartialReduction &rhs,
+                      OperationName opName) {
+    if (lhs.tensor != rhs.tensor) {
+      return false;
    }

-    if (rhs.status == Status::Uninitialized || rhs.accessedIndices.empty())
-      return lhs;
-    if (lhs.status == Status::Uninitialized || lhs.accessedIndices.empty())
-      return rhs;
-
-    if (lhs.tensor != rhs.tensor) {
-      LLVM_DEBUG({
-        llvm::dbgs() << "Joining rotations of different tensors: " << lhs.tensor
-                     << " and " << rhs.tensor << "\n";
-      });
-      return overdetermined();
+    // If neither the lhs nor the rhs op is set, then any op is legal.
+    if (lhs.opName.has_value() || rhs.opName.has_value()) {
+      // Otherwise, if both ops are set, then they must agree with each other
+      // and the new op.
+      if (lhs.opName.has_value() && rhs.opName.has_value() &&
+          (*lhs.opName != *rhs.opName || *lhs.opName != opName)) {
+        return false;
+      }
+
+      // Otherwise, at least one of lhs and rhs must have a set op name, and it
+      // must agree with the new op.
+      auto materializedOpName =
+          lhs.opName.has_value() ? *lhs.opName : *rhs.opName;
+      if (materializedOpName != opName) {
+        return false;
+      }
    }

-    LLVM_DEBUG({
-      llvm::dbgs() << "Joining :" << lhs.tensor << " and " << rhs.tensor
-                   << "\n";
-    });
-    RotationSets merged;
-    merged.status = Status::Normal;
+    // If the two partial reductions have access indices in common, then they
+    // cannot be joined because some indices would be contributing multiple
+    // times to the overall reduction. Maybe we could improve this in the
+    // future so that we could handle a kind of reduction that sums the same
+    // index twice, but likely it is better to account for that in a different
+    // fashion.
+    auto smaller =
+        lhs.accessedIndices.size() < rhs.accessedIndices.size() ? lhs : rhs;
+    auto larger =
+        lhs.accessedIndices.size() >= rhs.accessedIndices.size() ? lhs : rhs;
+    return std::all_of(smaller.accessedIndices.begin(),
+                       smaller.accessedIndices.end(), [&](int64_t index) {
+                         return larger.accessedIndices.count(index) == 0;
+                       });
+  }
+
+  // Join two partial reductions. This assumes the lhs and rhs have already
+  // been checked to have compatible tensors and opNames via canJoin.
+  static PartialReduction join(const PartialReduction &lhs,
+                               const PartialReduction &rhs, Value newRoot,
+                               OperationName opName) {
+    assert(!lhs.accessedIndices.empty() &&
+           "Internal state of RotationAnalysis is broken; empty rotation sets "
+           "should be impossible");
+    assert(!rhs.accessedIndices.empty() &&
+           "Internal state of RotationAnalysis is broken; empty rotation sets "
+           "should be impossible");
+
+    PartialReduction merged;
    merged.tensor = lhs.tensor;
+    merged.root = newRoot;
+    merged.opName = opName;
    for (auto index : lhs.accessedIndices) {
      merged.addRotation(index);
    }
    for (auto index : rhs.accessedIndices) {
      merged.addRotation(index);
    }
-    return merged;
-  }
-
-  // Assuming two not-overdetermined rotation sets, compute the overlap in
-  // their access indices.
-  static RotationSets overlap(const RotationSets &lhs,
-                              const RotationSets &rhs) {
-    assert(!lhs.isOverdetermined() && !rhs.isOverdetermined() &&
-           "Expected inputs to RotationSets::overlap to be not overdetermined");
-    if (lhs.status == Status::Uninitialized || lhs.empty()) {
-      return lhs;
-    }
-
-    if (rhs.status == Status::Uninitialized || rhs.empty()) {
-      return rhs;
-    }
-
-    RotationSets merged;
-    merged.status = Status::Normal;
-    merged.tensor = lhs.tensor;
-    for (auto index : lhs.accessedIndices) {
-      if (rhs.accessedIndices.count(index)) merged.addRotation(index);
-    }
+    LLVM_DEBUG({
+      llvm::dbgs() << "Joining\n\t";
+      lhs.print(llvm::dbgs());
+      llvm::dbgs() << " and\n\t";
+      rhs.print(llvm::dbgs());
+      llvm::dbgs() << " to get\n\t";
+      merged.print(llvm::dbgs());
+      llvm::dbgs() << "\n";
+    });
    return merged;
  }

 private:
-  /// The accessed indices of a single SSA value of tensor type.
+  // The SSA value being reduced
  Value tensor;

+  // The root of the reduction tree constructed so far, e.g., the result of the
+  // last op in a linear chain of reduction operations. During
+  // rotate-and-reduce, this represents the final SSA value that is replaced by
+  // an optimized set of rotations.
+  Value root;
+
+  // The operation performed in the reduction.
+  //
+  // Set to std::nullopt if no binary operation is applied (i.e., the reduction
+  // is a raw tensor at the leaf of a reduction tree).
+  std::optional<OperationName> opName;
+
+  // The set of indices of `tensor` accumulated by the reduction so far.
+  //
  // There is likely a data structure that can more efficiently represent a set
  // of intervals of integers, which properly merges adjacent intervals as
  // values are added. Java/Guava has RangeSet, and boost has interval_set.
-  std::unordered_set<int64_t> accessedIndices;
-  Status status = Status::Uninitialized;
+  // For now we use std::set which is implemented as a binary tree and ordered
+  // by the index values.
+  std::set<int64_t> accessedIndices;
};

-inline raw_ostream &operator<<(raw_ostream &os, const RotationSets &v) {
+inline raw_ostream &operator<<(raw_ostream &os, const PartialReduction &v) {
  v.print(os);
  return os;
}

-class RotationLattice : public dataflow::Lattice<RotationSets> {
+/// An analysis that identifies, for each tensor-typed SSA value, the set of
+/// partial reductions of associative, commutative binary arithmetic operations
+/// that reduce it to a scalar via tensor_ext.rotate ops.
+class RotationAnalysis {
 public:
-  using Lattice::Lattice;
-};
+  // The constructor requires a DataFlowSolver initialized with a sparse
+  // constant propagation analysis, which is used to determine the static
+  // values of rotation shifts.
+ RotationAnalysis(const DataFlowSolver &solver) : solver(solver){}; + ~RotationAnalysis() = default; -/// An analysis that identifies, for each SSA value, the set of underlying -/// tensors and rotations of those tensors, provided constant rotation shifts -/// can be determined. -class RotationAnalysis - : public dataflow::SparseForwardDataFlowAnalysis { - public: - explicit RotationAnalysis(DataFlowSolver &solver) - : SparseForwardDataFlowAnalysis(solver) {} - ~RotationAnalysis() override = default; - using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; + void run(Operation *op); + + /// Add partial reduction + void addPartialReduction(PartialReduction reduction) { + rootToPartialReductions[reduction.getRoot()].emplace_back(reduction); + } - // Given the computed results of the operation, update its operand lattice - // values. - void visitOperation(Operation *op, ArrayRef operands, - ArrayRef results) override; + /// Add a tensor value as the start of a new reduction to the internal + /// reduction mappings. + void initializeFromValueIfTensor(Value value) { + if (RankedTensorType tensorType = + value.getType().dyn_cast()) { + addPartialReduction(PartialReduction::initializeFromValue(value)); + } + } + + const std::vector &getRootedReductionsAt( + Value value) const { + return rootToPartialReductions.at(value); + } + + private: + // The constant propagation analysis used to determine the static values of + // rotation shifts. + const DataFlowSolver &solver; - void setToEntryState(RotationLattice *lattice) override; + // A mapping from a root of a PartialReduction to its PartitalReduction. Note + // each tensor SSA value can be the root of many partial reductions. + llvm::DenseMap> rootToPartialReductions; }; } // namespace rotation_analysis diff --git a/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp b/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp index 91377fb3d..66401e78b 100644 --- a/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp +++ b/lib/Analysis/RotationAnalysis/RotationAnalysis.cpp @@ -1,73 +1,113 @@ #include "include/Analysis/RotationAnalysis/RotationAnalysis.h" #include "include/Dialect/TensorExt/IR/TensorExtOps.h" -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project -#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project namespace mlir { namespace heir { namespace rotation_analysis { -void RotationAnalysis::visitOperation( - Operation *op, ArrayRef operands, - ArrayRef results) { - llvm::TypeSwitch(*op) - 
.Case([&](auto rotateOp) { - LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; }); - auto shiftConstantOp = - rotateOp.getShift().template getDefiningOp(); - // If the rotation shift can't be statically determined, we can't - // propagate anything through the IR. - if (!shiftConstantOp) return; +void RotationAnalysis::run(Operation *op) { + op->walk([&](Operation *op) { + // If the op has no tensor results and no regions, then there's nothing to + // do. The operation may consume a tensor but cannot further reduce it. + if (op->getNumRegions() == 0 && + llvm::none_of(op->getResultTypes(), + [](Type type) { return type.isa(); })) { + return WalkResult::advance(); + } - int64_t shiftValue = - dyn_cast(shiftConstantOp.getValue()).getInt(); + // Each tensor result can be the start of a new reduction. + for (Value result : op->getResults()) { + initializeFromValueIfTensor(result); + } - // The target slot propagates from the tensor argument to the result; - // the tensor argument is first in the tablegen definition. - const RotationLattice *lattice = operands[0]; - RotationSets latticeRotations = lattice->getValue(); - - // If it's a block argument, then there is no initialized lattice value - // and we can override it with a "zero rotation" - auto blockArg = dyn_cast(rotateOp.getTensor()); - if (blockArg) { - latticeRotations = RotationSets::from(blockArg); + // Block args within regions can be the start of a new reduction. + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (Value arg : block.getArguments()) { + initializeFromValueIfTensor(arg); } - RotationSets rotated = - RotationSets::rotate(latticeRotations, shiftValue); + } + } - for (RotationLattice *r : results) { - ChangeResult result = r->join(rotated); - propagateIfChanged(r, result); - } - }) - .Default([&](Operation &op) { - // By default, an op propagates its result target slots to all its - // operands. - for (OpOperand &operand : op.getOpOperands()) { - auto *latticeOperand = operands[operand.getOperandNumber()]; + // Each op now gets special treatment. + // + // - Rotate ops shift the accessIndices of their tensor operand's + // reductions if the shift is known to be constant. + // - Binary ops join partial reductions of operands and set the opName. + // - Everything else is ignored. + llvm::TypeSwitch(*op) + .Case([&](auto rotateOp) { + LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; }); + const dataflow::Lattice *shiftLattice = + solver.lookupState>( + rotateOp.getShift()); - for (RotationLattice *r : results) { - ChangeResult result = r->join(*latticeOperand); - // If the operand is a block arg, this additionally treats this as - // a zero rotation. If the underlying tensor differs across - // operands, this will also cause a Status::TooManyTensors. - // Otherwise, the join is a no-op. - result |= r->join(RotationSets::from(operand.get())); - propagateIfChanged(r, result); + if (shiftLattice) { + LLVM_DEBUG(llvm::dbgs() << "At " << rotateOp + << " SCCP analysis gives lattice of " + << *shiftLattice << "\n"); } - } - }); -} -void RotationAnalysis::setToEntryState(RotationLattice *lattice) { - lattice->getValue().clear(); + // If the rotation shift can't be statically determined, we can't + // propagate anything through the IR. 
+ if (!shiftLattice || shiftLattice->getValue().isUninitialized() || + !shiftLattice->getValue().getConstantValue()) { + LLVM_DEBUG( + llvm::dbgs() + << "At " << rotateOp + << " can't statically determine constant insertion index\n"); + return; + } + auto shiftValue = shiftLattice->getValue() + .getConstantValue() + .dyn_cast() + .getInt(); + + // For each partial reduction the tensor operand is a root of, + // rotate the accessed indices appropriately. + Value tensor = rotateOp.getTensor(); + Value result = rotateOp.getResult(); + for (const auto &reduction : rootToPartialReductions[tensor]) { + addPartialReduction( + PartialReduction::rotate(reduction, shiftValue, result)); + } + }) + .Case([&](auto arithOp) { + LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << arithOp << "\n"; }); + Value lhs = arithOp.getLhs(); + Value rhs = arithOp.getRhs(); + Value newRoot = arithOp.getResult(); + OperationName opName = arithOp.getOperation()->getName(); + + // This is inefficient, but what can we do better here? I suspect a + // better approach may be to identify cases in which only one of these + // reductions needs to be kept because it's "the best" according to + // some metric (e.g., it monotonically increases the number of indices + // and all else stays the same). But for now even on the + // box_blur_64x64 example this is far from the bottleneck. + for (const auto &lhsReduction : rootToPartialReductions[lhs]) { + for (const auto &rhsReduction : rootToPartialReductions[rhs]) { + if (PartialReduction::canJoin(lhsReduction, rhsReduction, + opName)) { + addPartialReduction(PartialReduction::join( + lhsReduction, rhsReduction, newRoot, opName)); + } + } + } + }); + + return WalkResult::advance(); + }); } } // namespace rotation_analysis diff --git a/lib/Analysis/SelectVariableNames/SelectVariableNames.cpp b/lib/Analysis/SelectVariableNames/SelectVariableNames.cpp index 13b2493c0..544dc5b01 100644 --- a/lib/Analysis/SelectVariableNames/SelectVariableNames.cpp +++ b/lib/Analysis/SelectVariableNames/SelectVariableNames.cpp @@ -1,7 +1,5 @@ #include "include/Analysis/SelectVariableNames/SelectVariableNames.h" -#include - #include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project diff --git a/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp index 20d320b65..45e0a1129 100644 --- a/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp +++ b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp @@ -23,6 +23,8 @@ #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#define DEBUG_NAME "rotate-and-reduce" + namespace mlir { namespace heir { namespace tensor_ext { @@ -37,95 +39,12 @@ struct RotateAndReduce : impl::RotateAndReduceBase { using RotateAndReduceBase::RotateAndReduceBase; template - void tryReplaceRotations(ArithOp op, Value tensor, - DenseSet &visited, - DataFlowSolver &solver) { - // The dataflow analysis provides some guarantees, but not enough - // to prove that we can replace the op with the rotate-and-reduce trick - // while still maintaining program correctness. - // - // We need to do some more complicated checks to ensure that: the op tree - // all contains the same op type (all sum or all mul), and that the - // accessed rotations are included only once in the reduction. 
- // This cannot be done during the dataflow analysis itself due to the - // monotonicity requirements of the framework. + void tryReplaceRotations( + ArithOp op, const rotation_analysis::PartialReduction &reduction) { LLVM_DEBUG(llvm::dbgs() << "Trying to replace rotations ending in " << *op << "\n"); - SetVector backwardSlice; - BackwardSliceOptions options; - // asserts that the parent op has a single region with a single block. - options.omitBlockArguments = false; - - DenseSet visitedReductionOps; - DenseMap opCounts; - opCounts[op->getName().getStringRef()]++; - - getBackwardSlice(op.getOperation(), &backwardSlice, options); - - for (Operation *upstreamOpPtr : backwardSlice) { - auto result = - llvm::TypeSwitch(upstreamOpPtr) - .Case( - [&](auto upstreamOp) { return success(); }) - // Ignore generic ops - .template Case( - [&](auto upstreamOp) { return success(); }) - .template Case([&](auto - upstreamOp) { - opCounts[upstreamOp->getName().getStringRef()]++; - // More than one reduction op is mixed in the reduction. - if (opCounts.size() > 1) { - LLVM_DEBUG(llvm::dbgs() - << "Not replacing op because reduction " - "contains multiple incompatible ops " - << op->getName() << " and " - << upstreamOp->getName() << "\n"); - return failure(); - } - - // Inspect the lattice values at the join point, - // and fail if there is any overlap - auto *lhsLattice = - solver.lookupState( - upstreamOp.getLhs()); - auto *rhsLattice = - solver.lookupState( - upstreamOp.getRhs()); - LLVM_DEBUG(llvm::dbgs() - << "Computing overlap of " - << "lhs: " << lhsLattice->getValue() << "\n" - << "rhs: " << rhsLattice->getValue() << "\n"); - auto mergedLattice = rotation_analysis::RotationSets::overlap( - lhsLattice->getValue(), rhsLattice->getValue()); - LLVM_DEBUG(llvm::dbgs() - << "Overlap is: " << mergedLattice << "\n"); - if (!mergedLattice.empty()) { - LLVM_DEBUG( - llvm::dbgs() - << "Not replacing op because reduction " - "may not be a simple reduction of the input tensor\n" - << "lhs: " << lhsLattice->getValue() << "\n" - << "rhs: " << rhsLattice->getValue() << "\n"); - return failure(); - } - - visitedReductionOps.insert(upstreamOp); - return success(); - }) - .Default([&](Operation *op) { - LLVM_DEBUG(llvm::dbgs() << "Not continuing because type switch " - "encountered unsupported op " - << op->getName() << "\n"); - return failure(); - }); - - if (failed(result)) { - return; - } - } - - // From here we know we will succeed. auto b = ImplicitLocOpBuilder(op->getLoc(), op); + auto tensor = reduction.getTensor(); Operation *finalOp; auto tensorShape = tensor.getType().cast().getShape(); for (int64_t shiftSize = tensorShape[0] / 2; shiftSize > 0; @@ -140,12 +59,6 @@ struct RotateAndReduce : impl::RotateAndReduceBase { [[maybe_unused]] auto *parentOp = op->getParentOp(); op->replaceAllUsesWith(finalOp); LLVM_DEBUG(llvm::dbgs() << "Post-replacement: " << *parentOp << "\n"); - - // Mark all ops in the reduction as visited so we don't try to replace them - // twice. - for (Operation *visitedOp : visitedReductionOps) { - visited.insert(visitedOp); - } } template @@ -263,6 +176,11 @@ struct RotateAndReduce : impl::RotateAndReduceBase { // The test for a match is now: does the number of accessed indices exactly // match the size of the tensor? I.e., does each tensor element show up // exactly once in the reduction? 
+ if (inputTensors.empty()) { + LLVM_DEBUG(llvm::dbgs() + << "Not replacing op because it accesses no tensors\n"); + return; + } auto tensorShape = inputTensors.begin()->getType().cast().getShape(); if (tensorShape.size() != 1 || tensorShape[0] != accessIndices.size()) { @@ -308,7 +226,6 @@ struct RotateAndReduce : impl::RotateAndReduceBase { // https://github.com/llvm/llvm-project/issues/58922 solver.load(); solver.load(); - solver.load(); if (failed(solver.initializeAndRun(getOperation()))) { getOperation()->emitOpError() << "Failed to run dataflow analysis.\n"; @@ -316,60 +233,29 @@ struct RotateAndReduce : impl::RotateAndReduceBase { return; } - LLVM_DEBUG({ - getOperation()->walk([&](Operation *op) { - if (op->getNumResults() == 0) return; - auto *targetSlotLattice = - solver.lookupState( - op->getResult(0)); - if (targetSlotLattice->getValue().isOverdetermined()) { - llvm::dbgs() << "Rotation lattice for " << *op - << " is overdetermined\n"; - } else if (targetSlotLattice->getValue().empty()) { - llvm::dbgs() << "Rotation lattice for " << *op << " is empty\n"; - } else { - SmallVector sortedRotations( - targetSlotLattice->getValue().getAccessedIndices().begin(), - targetSlotLattice->getValue().getAccessedIndices().end()); - llvm::sort(sortedRotations); - std::string stringified = llvm::join( - llvm::map_range(sortedRotations, - [](int64_t i) { return std::to_string(i); }), - ","); - llvm::dbgs() << "Rotation lattice for " << *op << ": " << stringified - << "\n"; - } - }); - }); - + rotation_analysis::RotationAnalysis rotationAnalysis(solver); + rotationAnalysis.run(getOperation()); DenseSet visited; getOperation()->walk( [&](Operation *op) { - if (op->getNumResults() == 0) return; - auto *targetSlotLattice = - solver.lookupState( - op->getResult(0)); - if (targetSlotLattice->getValue().isUninitialized() || - targetSlotLattice->getValue().isOverdetermined()) { - return; - } - - auto tensor = targetSlotLattice->getValue().getTensor(); - auto accessIndices = - targetSlotLattice->getValue().getAccessedIndices(); - int64_t tensorSize = - tensor.getType().cast().getShape()[0]; - if (accessIndices.size() == tensorSize) { - llvm::TypeSwitch(*op) - .Case([&](auto arithOp) { - tryReplaceRotations(arithOp, tensor, visited, - solver); - }) - .Case([&](auto arithOp) { - tryReplaceRotations(arithOp, tensor, visited, - solver); - }); + for (Value result : op->getResults()) { + if (!result.getType().isa()) { + continue; + } + + for (const auto &reduction : + rotationAnalysis.getRootedReductionsAt(result)) { + if (reduction.isComplete()) { + llvm::TypeSwitch(*op) + .Case([&](auto arithOp) { + tryReplaceRotations(arithOp, reduction); + }) + .Case([&](auto arithOp) { + tryReplaceRotations(arithOp, reduction); + }); + } + } } }); diff --git a/tests/heir_simd_vectorizer/dot_product_8.mlir b/tests/heir_simd_vectorizer/dot_product_8.mlir new file mode 100644 index 000000000..ebf39dd19 --- /dev/null +++ b/tests/heir_simd_vectorizer/dot_product_8.mlir @@ -0,0 +1,18 @@ +// RUN: heir-opt --secretize=entry-function=dot_product --wrap-generic --canonicalize --cse \ +// RUN: --heir-simd-vectorizer %s | FileCheck %s + +// CHECK-LABEL: func @dot_product +// CHECK-COUNT-3: tensor_ext.rotate +// CHECK-NOT: tensor_ext.rotate +func.func @dot_product(%arg0: tensor<8xi16>, %arg1: tensor<8xi16>) -> i16 { + %c0 = arith.constant 0 : index + %c0_si16 = arith.constant 0 : i16 + %0 = affine.for %arg2 = 0 to 8 iter_args(%iter = %c0_si16) -> (i16) { + %1 = tensor.extract %arg0[%arg2] : tensor<8xi16> + %2 = tensor.extract 
%arg1[%arg2] : tensor<8xi16> + %3 = arith.muli %1, %2 : i16 + %4 = arith.addi %iter, %3 : i16 + affine.yield %4 : i16 + } + return %0 : i16 +} diff --git a/tests/heir_simd_vectorizer/hamming_distance.mlir b/tests/heir_simd_vectorizer/hamming_distance.mlir index 84151900f..243004fbe 100644 --- a/tests/heir_simd_vectorizer/hamming_distance.mlir +++ b/tests/heir_simd_vectorizer/hamming_distance.mlir @@ -9,15 +9,9 @@ // CHECK-NEXT: arith.addi // CHECK-NEXT: tensor_ext.rotate // CHECK-NEXT: arith.addi -// CHECK-NEXT: tensor_ext.rotate -// CHECK-NEXT: arith.addi // CHECK-NEXT: tensor.extract // CHECK-NEXT: secret.yield -// TODO(#521): Fix rotate-and-reduce to work on this IR. -// The problem is that the lattice identifies the rotate-version of this IR as -// being overdetermined. - func.func @hamming(%arg0: tensor<4xi16>, %arg1: tensor<4xi16>) -> i16 { %c0 = arith.constant 0 : index %c0_si16 = arith.constant 0 : i16 diff --git a/tests/heir_simd_vectorizer/linear_polynomial_64.mlir b/tests/heir_simd_vectorizer/linear_polynomial_64.mlir new file mode 100644 index 000000000..8288adfd0 --- /dev/null +++ b/tests/heir_simd_vectorizer/linear_polynomial_64.mlir @@ -0,0 +1,22 @@ +// Ported from: https://github.com/MarbleHE/HECO/blob/3e13744233ab0c09030a41ef98b4e061b6fa2eac/evaluation/comparison/heco_input/linearpolynomial_64.mlir + +// RUN: heir-opt --secretize=entry-function=linear_polynomial --wrap-generic --canonicalize --cse \ +// RUN: --heir-simd-vectorizer %s | FileCheck %s + +// CHECK-LABEL: @linear_polynomial +// CHECK: secret.generic +// CHECK-NOT: tensor_ext.rotate +func.func @linear_polynomial(%a: tensor<64xi16>, %b: tensor<64xi16>, %x: tensor<64xi16>, %y: tensor<64xi16>) -> tensor<64xi16> { + %0 = affine.for %i = 0 to 64 iter_args(%iter = %y) -> (tensor<64xi16>) { + %ai = tensor.extract %a[%i] : tensor<64xi16> + %bi = tensor.extract %b[%i] : tensor<64xi16> + %xi = tensor.extract %x[%i] : tensor<64xi16> + %yi = tensor.extract %y[%i] : tensor<64xi16> + %axi = arith.muli %ai, %xi : i16 + %t1 = arith.subi %yi, %axi : i16 + %t2 = arith.subi %t1, %bi : i16 + %out = tensor.insert %t2 into %iter[%i] : tensor<64xi16> + affine.yield %out : tensor<64xi16> + } + return %0 : tensor<64xi16> +} diff --git a/tests/heir_simd_vectorizer/quadratic_polynomial.mlir b/tests/heir_simd_vectorizer/quadratic_polynomial.mlir new file mode 100644 index 000000000..c53f6f8ab --- /dev/null +++ b/tests/heir_simd_vectorizer/quadratic_polynomial.mlir @@ -0,0 +1,25 @@ +// Ported from: https://github.com/MarbleHE/HECO/blob/3e13744233ab0c09030a41ef98b4e061b6fa2eac/evaluation/comparison/heco_input/quadraticpolynomial_64.mlir + +// RUN: heir-opt --secretize=entry-function=quadratic_polynomial --wrap-generic --canonicalize --cse \ +// RUN: --heir-simd-vectorizer %s | FileCheck %s + +// CHECK-LABEL: @quadratic_polynomial +// CHECK: secret.generic +// CHECK-NOT: tensor_ext.rotate +func.func @quadratic_polynomial(%a: tensor<64xi16>, %b: tensor<64xi16>, %c: tensor<64xi16>, %x: tensor<64xi16>, %y: tensor<64xi16>) -> tensor<64xi16> { + %0 = affine.for %i = 0 to 64 iter_args(%iter = %y) -> (tensor<64xi16>) { + %ai = tensor.extract %a[%i] : tensor<64xi16> + %bi = tensor.extract %b[%i] : tensor<64xi16> + %ci = tensor.extract %c[%i] : tensor<64xi16> + %xi = tensor.extract %x[%i] : tensor<64xi16> + %yi = tensor.extract %y[%i] : tensor<64xi16> + %axi = arith.muli %ai, %xi : i16 + %t1 = arith.addi %axi, %bi : i16 + %t2 = arith.muli %xi, %t1 : i16 + %t3 = arith.addi %t2, %ci : i16 + %t4 = arith.subi %yi, %t3 : i16 + %out = tensor.insert %t4 into 
%iter[%i] : tensor<64xi16>
+    affine.yield %out : tensor<64xi16>
+  }
+  return %0 : tensor<64xi16>
+}
diff --git a/tests/secret_to_bgv/hamming_distance_1024.mlir b/tests/secret_to_bgv/hamming_distance_1024.mlir
index 37a611b15..392a1053b 100644
--- a/tests/secret_to_bgv/hamming_distance_1024.mlir
+++ b/tests/secret_to_bgv/hamming_distance_1024.mlir
@@ -7,9 +7,7 @@
 // CHECK: bgv.sub
 // CHECK-NEXT: bgv.mul
 // CHECK-NEXT: bgv.relinearize
-
-// TODO(#521): After rotate-and-reduce works, only check for 10 bg.rotate
-// CHECK-COUNT-1023: bgv.rotate
+// CHECK-COUNT-10: bgv.rotate
 // CHECK: bgv.extract
 // CHECK-NEXT: return
 
diff --git a/tests/tensor_ext/rotate_and_reduce.mlir b/tests/tensor_ext/rotate_and_reduce.mlir
index fa1c20596..6ace9307d 100644
--- a/tests/tensor_ext/rotate_and_reduce.mlir
+++ b/tests/tensor_ext/rotate_and_reduce.mlir
@@ -563,3 +563,27 @@ func.func @reduce_add_and_mul(%arg1: tensor<32xi16>) -> i16 {
   %out = arith.addi %extracted, %extracted_2 : i16
   return %out : i16
 }
+
+
+// This test caused rotate-and-reduce to crash, so is here as a regression test
+// without any particular assertion required.
+// CHECK-LABEL: @test_dot_product_regression
+func.func @test_dot_product_regression(%arg0: !secret.secret<tensor<8xi16>>, %arg1: !secret.secret<tensor<8xi16>>) -> !secret.secret<i16> {
+  %c3 = arith.constant 3 : index
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = secret.generic ins(%arg0, %arg1 : !secret.secret<tensor<8xi16>>, !secret.secret<tensor<8xi16>>) {
+  ^bb0(%arg2: tensor<8xi16>, %arg3: tensor<8xi16>):
+    %1 = arith.muli %arg2, %arg3 : tensor<8xi16>
+    %2 = tensor_ext.rotate %1, %c1 : tensor<8xi16>, index
+    %3 = arith.addi %2, %1 : tensor<8xi16>
+    %4 = tensor_ext.rotate %1, %c2 : tensor<8xi16>, index
+    %5 = arith.addi %4, %3 : tensor<8xi16>
+    %6 = tensor_ext.rotate %1, %c3 : tensor<8xi16>, index
+    %7 = arith.addi %6, %5 : tensor<8xi16>
+    %extracted = tensor.extract %7[%c0] : tensor<8xi16>
+    secret.yield %extracted : i16
+  } -> !secret.secret<i16>
+  return %0 : !secret.secret<i16>
+}