Rewrite rotation analysis for new ports of HECO eval artifacts
This rewrite still has some efficiency issues, but it removes the hacks
in the previous lattice-based dataflow analysis and supports the examples
ported in this PR that were previously unsupported. A plain-C++ model of
the targeted reduction pattern follows the list of ports below.

Ports

 - dot_product
 - linear_polynomial
 - quadratic_polynomial
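
All three ported kernels end in the same shape of computation: a full reduction of a packed vector, implemented with rotations so that every slot of the result holds the reduced value. As a reference point (this snippet is illustrative, not part of the commit, and assumes a power-of-two vector size), here is a plain-C++ model of the rotate-and-reduce pattern the rewritten analysis detects, applied to dot_product:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Cyclic left rotation, mirroring tensor_ext.rotate.
std::vector<int64_t> rotate(const std::vector<int64_t> &v, size_t shift) {
  std::vector<int64_t> out(v.size());
  for (size_t i = 0; i < v.size(); ++i) out[i] = v[(i + shift) % v.size()];
  return out;
}

// After the loop, every slot holds the sum of all original slots: log2(n)
// rotate-then-add steps instead of n - 1 sequential additions.
int64_t rotateAndReduceSum(std::vector<int64_t> v) {
  for (size_t shift = v.size() / 2; shift > 0; shift /= 2) {
    std::vector<int64_t> rotated = rotate(v, shift);
    for (size_t i = 0; i < v.size(); ++i) v[i] += rotated[i];
  }
  return v[0];
}

int main() {
  std::vector<int64_t> a = {1, 2, 3, 4}, b = {5, 6, 7, 8};
  std::vector<int64_t> prod(a.size());
  for (size_t i = 0; i < a.size(); ++i) prod[i] = a[i] * b[i];
  std::printf("%lld\n", static_cast<long long>(rotateAndReduceSum(prod)));  // 70
}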
j2kun committed Mar 31, 2024
1 parent fe238e9 commit 3c1bd1a
Showing 10 changed files with 403 additions and 348 deletions.
328 changes: 189 additions & 139 deletions include/Analysis/RotationAnalysis/RotationAnalysis.h

Large diffs are not rendered by default.
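
Since the header diff is collapsed, here is a rough sketch of the interface it appears to define, reconstructed from the call sites in RotationAnalysis.cpp and RotateAndReduce.cpp below. The member names come from those call sites; the exact signatures, return types, containers, and includes are assumptions.

#include "llvm/include/llvm/ADT/DenseMap.h"     // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVector.h"  // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h"  // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h"     // from @llvm-project
#include "mlir/include/mlir/IR/Value.h"         // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h"     // from @llvm-project

namespace mlir {
namespace heir {
namespace rotation_analysis {

// A PartialReduction tracks, for a single source tensor, which of its indices
// have been combined so far by one repeated binary op (add or mul), along
// with the SSA value that currently holds the partial result (its "root").
class PartialReduction {
 public:
  Value getTensor() const;
  // True when every index of the source tensor is accessed exactly once.
  bool isComplete() const;

  // Rotating a root by a constant shift re-roots the reduction at `result`
  // and shifts its accessed indices.
  static PartialReduction rotate(const PartialReduction &lhs, int64_t shift,
                                 Value result);
  // Two reductions can join when they reduce the same tensor with a
  // compatible op and their accessed indices do not overlap.
  static bool canJoin(const PartialReduction &lhs, const PartialReduction &rhs,
                      OperationName opName);
  static PartialReduction join(const PartialReduction &lhs,
                               const PartialReduction &rhs, Value newRoot,
                               OperationName opName);
};

// Unlike the previous lattice-based version, this analysis is driven by a
// single IR walk and only consults the solver for constant shift values.
class RotationAnalysis {
 public:
  explicit RotationAnalysis(DataFlowSolver &solver);
  void run(Operation *op);
  ArrayRef<PartialReduction> getRootedReductionsAt(Value value);

 private:
  void initializeFromValueIfTensor(Value value);
  void addPartialReduction(PartialReduction reduction);

  DataFlowSolver &solver;
  DenseMap<Value, SmallVector<PartialReduction>> rootToPartialReductions;
};

}  // namespace rotation_analysis
}  // namespace heir
}  // namespace mlir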

150 changes: 95 additions & 55 deletions lib/Analysis/RotationAnalysis/RotationAnalysis.cpp
@@ -1,73 +1,113 @@
#include "include/Analysis/RotationAnalysis/RotationAnalysis.h"

#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace rotation_analysis {

void RotationAnalysis::visitOperation(
Operation *op, ArrayRef<const RotationLattice *> operands,
ArrayRef<RotationLattice *> results) {
llvm::TypeSwitch<Operation &>(*op)
.Case<tensor_ext::RotateOp>([&](auto rotateOp) {
LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; });
auto shiftConstantOp =
rotateOp.getShift().template getDefiningOp<arith::ConstantOp>();
// If the rotation shift can't be statically determined, we can't
// propagate anything through the IR.
if (!shiftConstantOp) return;
void RotationAnalysis::run(Operation *op) {
op->walk<WalkOrder::PreOrder>([&](Operation *op) {
// If the op has no tensor results and no regions, then there's nothing to
// do. The operation may consume a tensor but cannot further reduce it.
if (op->getNumRegions() == 0 &&
llvm::none_of(op->getResultTypes(),
[](Type type) { return type.isa<RankedTensorType>(); })) {
return WalkResult::advance();
}

int64_t shiftValue =
dyn_cast<IntegerAttr>(shiftConstantOp.getValue()).getInt();
// Each tensor result can be the start of a new reduction.
for (Value result : op->getResults()) {
initializeFromValueIfTensor(result);
}

// The target slot propagates from the tensor argument to the result;
// the tensor argument is first in the tablegen definition.
const RotationLattice *lattice = operands[0];
RotationSets latticeRotations = lattice->getValue();

// If it's a block argument, then there is no initialized lattice value
// and we can override it with a "zero rotation"
auto blockArg = dyn_cast<BlockArgument>(rotateOp.getTensor());
if (blockArg) {
latticeRotations = RotationSets::from(blockArg);
// Block args within regions can be the start of a new reduction.
for (Region &region : op->getRegions()) {
for (Block &block : region) {
for (Value arg : block.getArguments()) {
initializeFromValueIfTensor(arg);
}
RotationSets rotated =
RotationSets::rotate(latticeRotations, shiftValue);
}
}

for (RotationLattice *r : results) {
ChangeResult result = r->join(rotated);
propagateIfChanged(r, result);
}
})
.Default([&](Operation &op) {
// By default, an op propagates its result target slots to all its
// operands.
for (OpOperand &operand : op.getOpOperands()) {
auto *latticeOperand = operands[operand.getOperandNumber()];
// Each op now gets special treatment.
//
// - Rotate ops shift the accessIndices of their tensor operand's
// reductions if the shift is known to be constant.
// - Binary ops join partial reductions of operands and set the opName.
// - Everything else is ignored.
llvm::TypeSwitch<Operation &>(*op)
.Case<tensor_ext::RotateOp>([&](auto rotateOp) {
LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; });
const dataflow::Lattice<dataflow::ConstantValue> *shiftLattice =
solver.lookupState<dataflow::Lattice<dataflow::ConstantValue>>(
rotateOp.getShift());

for (RotationLattice *r : results) {
ChangeResult result = r->join(*latticeOperand);
// If the operand is a block arg, this additionally treats this as
// a zero rotation. If the underlying tensor differs across
// operands, this will also cause a Status::TooManyTensors.
// Otherwise, the join is a no-op.
result |= r->join(RotationSets::from(operand.get()));
propagateIfChanged(r, result);
if (shiftLattice) {
LLVM_DEBUG(llvm::dbgs() << "At " << rotateOp
<< " SCCP analysis gives lattice of "
<< *shiftLattice << "\n");
}
}
});
}

void RotationAnalysis::setToEntryState(RotationLattice *lattice) {
lattice->getValue().clear();
// If the rotation shift can't be statically determined, we can't
// propagate anything through the IR.
if (!shiftLattice || shiftLattice->getValue().isUninitialized() ||
!shiftLattice->getValue().getConstantValue()) {
LLVM_DEBUG(
llvm::dbgs()
<< "At " << rotateOp
<< " can't statically determine constant insertion index\n");
return;
}
auto shiftValue = shiftLattice->getValue()
.getConstantValue()
.dyn_cast<IntegerAttr>()
.getInt();

// For each partial reduction the tensor operand is a root of,
// rotate the accessed indices appropriately.
Value tensor = rotateOp.getTensor();
Value result = rotateOp.getResult();
for (const auto &reduction : rootToPartialReductions[tensor]) {
addPartialReduction(
PartialReduction::rotate(reduction, shiftValue, result));
}
})
.Case<arith::AddIOp, arith::MulIOp>([&](auto arithOp) {
LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << arithOp << "\n"; });
Value lhs = arithOp.getLhs();
Value rhs = arithOp.getRhs();
Value newRoot = arithOp.getResult();
OperationName opName = arithOp.getOperation()->getName();

// This is inefficient, but what can we do better here? I suspect a
// better approach may be to identify cases in which only one of these
// reductions needs to be kept because it's "the best" according to
// some metric (e.g., it monotonically increases the number of indices
// and all else stays the same). But for now even on the
// box_blur_64x64 example this is far from the bottleneck.
for (const auto &lhsReduction : rootToPartialReductions[lhs]) {
for (const auto &rhsReduction : rootToPartialReductions[rhs]) {
if (PartialReduction::canJoin(lhsReduction, rhsReduction,
opName)) {
addPartialReduction(PartialReduction::join(
lhsReduction, rhsReduction, newRoot, opName));
}
}
}
});

return WalkResult::advance();
});
}

} // namespace rotation_analysis
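
To make the index bookkeeping concrete, here is a small, self-contained simulation of how a partial reduction's accessed-index set evolves under rotate and join for a 4-slot tensor. This is illustrative only: the std::set representation and the slot-0 convention are assumptions made for the example, not the commit's actual data structures.

#include <cassert>
#include <cstdint>
#include <set>

using IndexSet = std::set<int64_t>;

// A cyclic left rotation by `shift` makes source index (i + shift) mod n
// visible at slot i, so the indices a partial result reads shift by +shift.
IndexSet rotateIndices(const IndexSet &s, int64_t shift, int64_t n) {
  IndexSet out;
  for (int64_t i : s) out.insert(((i + shift) % n + n) % n);
  return out;
}

// A join is only useful when each source index contributes exactly once.
bool disjoint(const IndexSet &a, const IndexSet &b) {
  for (int64_t i : a)
    if (b.count(i)) return false;
  return true;
}

IndexSet join(const IndexSet &a, const IndexSet &b) {
  IndexSet out = a;
  out.insert(b.begin(), b.end());
  return out;
}

int main() {
  // Trace for s2 = s1 + rot(s1, 1), where s1 = v + rot(v, 2), n = 4.
  const int64_t n = 4;
  IndexSet v = {0};                       // a fresh root reads its own slot
  IndexSet r2 = rotateIndices(v, 2, n);   // {2}
  assert(disjoint(v, r2));
  IndexSet s1 = join(v, r2);              // {0, 2}
  IndexSet r1 = rotateIndices(s1, 1, n);  // {1, 3}
  assert(disjoint(s1, r1));
  IndexSet s2 = join(s1, r1);             // {0, 1, 2, 3}
  assert(static_cast<int64_t>(s2.size()) == n);  // isComplete()
}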
2 changes: 0 additions & 2 deletions lib/Analysis/SelectVariableNames/SelectVariableNames.cpp
@@ -1,7 +1,5 @@
 #include "include/Analysis/SelectVariableNames/SelectVariableNames.h"
 
-#include <string>
-
 #include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
 #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
 #include "mlir/include/mlir/IR/Value.h" // from @llvm-project
172 changes: 29 additions & 143 deletions lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
@@ -23,6 +23,8 @@
 #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
 #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
 
+#define DEBUG_TYPE "rotate-and-reduce"
+
 namespace mlir {
 namespace heir {
 namespace tensor_ext {
@@ -37,95 +39,12 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
   using RotateAndReduceBase::RotateAndReduceBase;
 
   template <typename ArithOp>
-  void tryReplaceRotations(ArithOp op, Value tensor,
-                           DenseSet<Operation *> &visited,
-                           DataFlowSolver &solver) {
-    // The dataflow analysis provides some guarantees, but not enough
-    // to prove that we can replace the op with the rotate-and-reduce trick
-    // while still maintaining program correctness.
-    //
-    // We need to do some more complicated checks to ensure that: the op tree
-    // all contains the same op type (all sum or all mul), and that the
-    // accessed rotations are included only once in the reduction.
-    // This cannot be done during the dataflow analysis itself due to the
-    // monotonicity requirements of the framework.
+  void tryReplaceRotations(
+      ArithOp op, const rotation_analysis::PartialReduction &reduction) {
     LLVM_DEBUG(llvm::dbgs()
               << "Trying to replace rotations ending in " << *op << "\n");
-    SetVector<Operation *> backwardSlice;
-    BackwardSliceOptions options;
-    // asserts that the parent op has a single region with a single block.
-    options.omitBlockArguments = false;
-
-    DenseSet<Operation *> visitedReductionOps;
-    DenseMap<llvm::StringRef, int> opCounts;
-    opCounts[op->getName().getStringRef()]++;
-
-    getBackwardSlice(op.getOperation(), &backwardSlice, options);
-
-    for (Operation *upstreamOpPtr : backwardSlice) {
-      auto result =
-          llvm::TypeSwitch<Operation *, LogicalResult>(upstreamOpPtr)
-              .Case<arith::ConstantOp, tensor_ext::RotateOp>(
-                  [&](auto upstreamOp) { return success(); })
-              // Ignore generic ops
-              .template Case<secret::GenericOp>(
-                  [&](auto upstreamOp) { return success(); })
-              .template Case<arith::AddIOp, arith::MulIOp>(
-                  [&](auto upstreamOp) {
-                    opCounts[upstreamOp->getName().getStringRef()]++;
-                    // More than one reduction op is mixed in the reduction.
-                    if (opCounts.size() > 1) {
-                      LLVM_DEBUG(llvm::dbgs()
-                                 << "Not replacing op because reduction "
-                                    "contains multiple incompatible ops "
-                                 << op->getName() << " and "
-                                 << upstreamOp->getName() << "\n");
-                      return failure();
-                    }
-
-                    // Inspect the lattice values at the join point,
-                    // and fail if there is any overlap
-                    auto *lhsLattice =
-                        solver.lookupState<rotation_analysis::RotationLattice>(
-                            upstreamOp.getLhs());
-                    auto *rhsLattice =
-                        solver.lookupState<rotation_analysis::RotationLattice>(
-                            upstreamOp.getRhs());
-                    LLVM_DEBUG(llvm::dbgs()
-                               << "Computing overlap of "
-                               << "lhs: " << lhsLattice->getValue() << "\n"
-                               << "rhs: " << rhsLattice->getValue() << "\n");
-                    auto mergedLattice =
-                        rotation_analysis::RotationSets::overlap(
-                            lhsLattice->getValue(), rhsLattice->getValue());
-                    LLVM_DEBUG(llvm::dbgs()
-                               << "Overlap is: " << mergedLattice << "\n");
-                    if (!mergedLattice.empty()) {
-                      LLVM_DEBUG(
-                          llvm::dbgs()
-                          << "Not replacing op because reduction may not be "
-                             "a simple reduction of the input tensor\n"
-                          << "lhs: " << lhsLattice->getValue() << "\n"
-                          << "rhs: " << rhsLattice->getValue() << "\n");
-                      return failure();
-                    }
-
-                    visitedReductionOps.insert(upstreamOp);
-                    return success();
-                  })
-              .Default([&](Operation *op) {
-                LLVM_DEBUG(llvm::dbgs()
-                           << "Not continuing because type switch "
-                              "encountered unsupported op "
-                           << op->getName() << "\n");
-                return failure();
-              });
-
-      if (failed(result)) {
-        return;
-      }
-    }
-
-    // From here we know we will succeed.
     auto b = ImplicitLocOpBuilder(op->getLoc(), op);
+    auto tensor = reduction.getTensor();
     Operation *finalOp;
     auto tensorShape = tensor.getType().cast<RankedTensorType>().getShape();
     for (int64_t shiftSize = tensorShape[0] / 2; shiftSize > 0;
@@ -140,12 +59,6 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
     [[maybe_unused]] auto *parentOp = op->getParentOp();
     op->replaceAllUsesWith(finalOp);
     LLVM_DEBUG(llvm::dbgs() << "Post-replacement: " << *parentOp << "\n");
-
-    // Mark all ops in the reduction as visited so we don't try to replace
-    // them twice.
-    for (Operation *visitedOp : visitedReductionOps) {
-      visited.insert(visitedOp);
-    }
   }
 
   template <typename ArithOp>
@@ -263,6 +176,11 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
     // The test for a match is now: does the number of accessed indices exactly
     // match the size of the tensor? I.e., does each tensor element show up
     // exactly once in the reduction?
+    if (inputTensors.empty()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Not replacing op because it accesses no tensors\n");
+      return;
+    }
     auto tensorShape =
         inputTensors.begin()->getType().cast<RankedTensorType>().getShape();
     if (tensorShape.size() != 1 || tensorShape[0] != accessIndices.size()) {
@@ -308,68 +226,36 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
     // https://github.com/llvm/llvm-project/issues/58922
     solver.load<dataflow::DeadCodeAnalysis>();
     solver.load<dataflow::SparseConstantPropagation>();
-    solver.load<rotation_analysis::RotationAnalysis>();
 
     if (failed(solver.initializeAndRun(getOperation()))) {
      getOperation()->emitOpError() << "Failed to run dataflow analysis.\n";
       signalPassFailure();
       return;
     }
 
-    LLVM_DEBUG({
-      getOperation()->walk([&](Operation *op) {
-        if (op->getNumResults() == 0) return;
-        auto *targetSlotLattice =
-            solver.lookupState<rotation_analysis::RotationLattice>(
-                op->getResult(0));
-        if (targetSlotLattice->getValue().isOverdetermined()) {
-          llvm::dbgs() << "Rotation lattice for " << *op
-                       << " is overdetermined\n";
-        } else if (targetSlotLattice->getValue().empty()) {
-          llvm::dbgs() << "Rotation lattice for " << *op << " is empty\n";
-        } else {
-          SmallVector<int64_t> sortedRotations(
-              targetSlotLattice->getValue().getAccessedIndices().begin(),
-              targetSlotLattice->getValue().getAccessedIndices().end());
-          llvm::sort(sortedRotations);
-          std::string stringified = llvm::join(
-              llvm::map_range(sortedRotations,
-                              [](int64_t i) { return std::to_string(i); }),
-              ",");
-          llvm::dbgs() << "Rotation lattice for " << *op << ": " << stringified
-                       << "\n";
-        }
-      });
-    });
-
-    DenseSet<Operation *> visited;
+    rotation_analysis::RotationAnalysis rotationAnalysis(solver);
+    rotationAnalysis.run(getOperation());
 
     getOperation()->walk<WalkOrder::PreOrder, ReverseIterator>(
         [&](Operation *op) {
           if (op->getNumResults() == 0) return;
-          auto *targetSlotLattice =
-              solver.lookupState<rotation_analysis::RotationLattice>(
-                  op->getResult(0));
-          if (targetSlotLattice->getValue().isUninitialized() ||
-              targetSlotLattice->getValue().isOverdetermined()) {
-            return;
-          }
-
-          auto tensor = targetSlotLattice->getValue().getTensor();
-          auto accessIndices =
-              targetSlotLattice->getValue().getAccessedIndices();
-          int64_t tensorSize =
-              tensor.getType().cast<RankedTensorType>().getShape()[0];
-          if (accessIndices.size() == tensorSize) {
-            llvm::TypeSwitch<Operation &>(*op)
-                .Case<arith::AddIOp>([&](auto arithOp) {
-                  tryReplaceRotations<arith::AddIOp>(arithOp, tensor, visited,
-                                                     solver);
-                })
-                .Case<arith::MulIOp>([&](auto arithOp) {
-                  tryReplaceRotations<arith::MulIOp>(arithOp, tensor, visited,
-                                                     solver);
-                });
+          for (Value result : op->getResults()) {
+            if (!result.getType().isa<RankedTensorType>()) {
+              continue;
+            }
+
+            for (const auto &reduction :
+                 rotationAnalysis.getRootedReductionsAt(result)) {
+              if (reduction.isComplete()) {
+                llvm::TypeSwitch<Operation &>(*op)
+                    .Case<arith::AddIOp>([&](auto arithOp) {
+                      tryReplaceRotations<arith::AddIOp>(arithOp, reduction);
+                    })
+                    .Case<arith::MulIOp>([&](auto arithOp) {
+                      tryReplaceRotations<arith::MulIOp>(arithOp, reduction);
+                    });
+              }
+            }
           }
         });
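
The body of the rewrite loop in tryReplaceRotations falls inside the collapsed hunk above. Based on the visible loop header, the builder b, and finalOp, it plausibly looks like the following fragment; the RotateOp builder arguments and the accumulator variable are assumptions, not a verbatim quote of the hidden lines.

// (Inside tryReplaceRotations, after `b`, `tensor`, and `tensorShape` are set
// up.) Each iteration halves the shift, rotating the running value and
// combining it with the reduction's op: log2(n) steps total.
Value runningResult = tensor;
for (int64_t shiftSize = tensorShape[0] / 2; shiftSize > 0; shiftSize /= 2) {
  auto shift = b.create<arith::ConstantOp>(b.getIndexAttr(shiftSize));
  auto rotated = b.create<tensor_ext::RotateOp>(runningResult, shift);
  auto combined = b.create<ArithOp>(runningResult, rotated.getResult());
  finalOp = combined.getOperation();
  runningResult = combined.getResult();
}
// finalOp then replaces all uses of the reduction's original final op.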

Diffs for the remaining changed files are not rendered here.
