Skip to content

Commit

Permalink
Merge pull request #575 from j2kun:heco-examples-2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 620849711
  • Loading branch information
Copybara-Service committed Apr 1, 2024
2 parents fe238e9 + 3c1bd1a commit 81e45d8
Show file tree
Hide file tree
Showing 10 changed files with 403 additions and 348 deletions.
328 changes: 189 additions & 139 deletions include/Analysis/RotationAnalysis/RotationAnalysis.h

Large diffs are not rendered by default.

150 changes: 95 additions & 55 deletions lib/Analysis/RotationAnalysis/RotationAnalysis.cpp
Original file line number Diff line number Diff line change
@@ -1,73 +1,113 @@
#include "include/Analysis/RotationAnalysis/RotationAnalysis.h"

#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace rotation_analysis {

void RotationAnalysis::visitOperation(
Operation *op, ArrayRef<const RotationLattice *> operands,
ArrayRef<RotationLattice *> results) {
llvm::TypeSwitch<Operation &>(*op)
.Case<tensor_ext::RotateOp>([&](auto rotateOp) {
LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; });
auto shiftConstantOp =
rotateOp.getShift().template getDefiningOp<arith::ConstantOp>();
// If the rotation shift can't be statically determined, we can't
// propagate anything through the IR.
if (!shiftConstantOp) return;
void RotationAnalysis::run(Operation *op) {
op->walk<WalkOrder::PreOrder>([&](Operation *op) {
// If the op has no tensor results and no regions, then there's nothing to
// do. The operation may consume a tensor but cannot further reduce it.
if (op->getNumRegions() == 0 &&
llvm::none_of(op->getResultTypes(),
[](Type type) { return type.isa<RankedTensorType>(); })) {
return WalkResult::advance();
}

int64_t shiftValue =
dyn_cast<IntegerAttr>(shiftConstantOp.getValue()).getInt();
// Each tensor result can be the start of a new reduction.
for (Value result : op->getResults()) {
initializeFromValueIfTensor(result);
}

// The target slot propagates from the tensor argument to the result;
// the tensor argument is first in the tablegen definition.
const RotationLattice *lattice = operands[0];
RotationSets latticeRotations = lattice->getValue();

// If it's a block argument, then there is no initialized lattice value
// and we can override it with a "zero rotation"
auto blockArg = dyn_cast<BlockArgument>(rotateOp.getTensor());
if (blockArg) {
latticeRotations = RotationSets::from(blockArg);
// Block args within regions can be the start of a new reduction.
for (Region &region : op->getRegions()) {
for (Block &block : region) {
for (Value arg : block.getArguments()) {
initializeFromValueIfTensor(arg);
}
RotationSets rotated =
RotationSets::rotate(latticeRotations, shiftValue);
}
}

for (RotationLattice *r : results) {
ChangeResult result = r->join(rotated);
propagateIfChanged(r, result);
}
})
.Default([&](Operation &op) {
// By default, an op propagates its result target slots to all its
// operands.
for (OpOperand &operand : op.getOpOperands()) {
auto *latticeOperand = operands[operand.getOperandNumber()];
// Each op now gets special treatment.
//
// - Rotate ops shift the accessIndices of their tensor operand's
// reductions if the shift is known to be constant.
// - Binary ops join partial reductions of operands and set the opName.
// - Everything else is ignored.
llvm::TypeSwitch<Operation &>(*op)
.Case<tensor_ext::RotateOp>([&](auto rotateOp) {
LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; });
const dataflow::Lattice<dataflow::ConstantValue> *shiftLattice =
solver.lookupState<dataflow::Lattice<dataflow::ConstantValue>>(
rotateOp.getShift());

for (RotationLattice *r : results) {
ChangeResult result = r->join(*latticeOperand);
// If the operand is a block arg, this additionally treats this as
// a zero rotation. If the underlying tensor differs across
// operands, this will also cause a Status::TooManyTensors.
// Otherwise, the join is a no-op.
result |= r->join(RotationSets::from(operand.get()));
propagateIfChanged(r, result);
if (shiftLattice) {
LLVM_DEBUG(llvm::dbgs() << "At " << rotateOp
<< " SCCP analysis gives lattice of "
<< *shiftLattice << "\n");
}
}
});
}

void RotationAnalysis::setToEntryState(RotationLattice *lattice) {
lattice->getValue().clear();
// If the rotation shift can't be statically determined, we can't
// propagate anything through the IR.
if (!shiftLattice || shiftLattice->getValue().isUninitialized() ||
!shiftLattice->getValue().getConstantValue()) {
LLVM_DEBUG(
llvm::dbgs()
<< "At " << rotateOp
<< " can't statically determine constant insertion index\n");
return;
}
auto shiftValue = shiftLattice->getValue()
.getConstantValue()
.dyn_cast<IntegerAttr>()
.getInt();

// For each partial reduction the tensor operand is a root of,
// rotate the accessed indices appropriately.
Value tensor = rotateOp.getTensor();
Value result = rotateOp.getResult();
for (const auto &reduction : rootToPartialReductions[tensor]) {
addPartialReduction(
PartialReduction::rotate(reduction, shiftValue, result));
}
})
.Case<arith::AddIOp, arith::MulIOp>([&](auto arithOp) {
LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << arithOp << "\n"; });
Value lhs = arithOp.getLhs();
Value rhs = arithOp.getRhs();
Value newRoot = arithOp.getResult();
OperationName opName = arithOp.getOperation()->getName();

// This is inefficient, but what can we do better here? I suspect a
// better approach may be to identify cases in which only one of these
// reductions needs to be kept because it's "the best" according to
// some metric (e.g., it monotonically increases the number of indices
// and all else stays the same). But for now even on the
// box_blur_64x64 example this is far from the bottleneck.
for (const auto &lhsReduction : rootToPartialReductions[lhs]) {
for (const auto &rhsReduction : rootToPartialReductions[rhs]) {
if (PartialReduction::canJoin(lhsReduction, rhsReduction,
opName)) {
addPartialReduction(PartialReduction::join(
lhsReduction, rhsReduction, newRoot, opName));
}
}
}
});

return WalkResult::advance();
});
}

} // namespace rotation_analysis
Expand Down
2 changes: 0 additions & 2 deletions lib/Analysis/SelectVariableNames/SelectVariableNames.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#include "include/Analysis/SelectVariableNames/SelectVariableNames.h"

#include <string>

#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
Expand Down
172 changes: 29 additions & 143 deletions lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

// Debug tag for this pass. LLVM_DEBUG(...) expands in terms of DEBUG_TYPE, so
// the macro must have exactly this name for the debug statements in this file
// to compile in assertion-enabled builds and to be selectable with
// `-debug-only=rotate-and-reduce`.
#define DEBUG_TYPE "rotate-and-reduce"

namespace mlir {
namespace heir {
namespace tensor_ext {
Expand All @@ -37,95 +39,12 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
using RotateAndReduceBase::RotateAndReduceBase;

template <typename ArithOp>
void tryReplaceRotations(ArithOp op, Value tensor,
DenseSet<Operation *> &visited,
DataFlowSolver &solver) {
// The dataflow analysis provides some guarantees, but not enough
// to prove that we can replace the op with the rotate-and-reduce trick
// while still maintaining program correctness.
//
// We need to do some more complicated checks to ensure that: the op tree
// all contains the same op type (all sum or all mul), and that the
// accessed rotations are included only once in the reduction.
// This cannot be done during the dataflow analysis itself due to the
// monotonicity requirements of the framework.
void tryReplaceRotations(
ArithOp op, const rotation_analysis::PartialReduction &reduction) {
LLVM_DEBUG(llvm::dbgs()
<< "Trying to replace rotations ending in " << *op << "\n");
SetVector<Operation *> backwardSlice;
BackwardSliceOptions options;
// asserts that the parent op has a single region with a single block.
options.omitBlockArguments = false;

DenseSet<Operation *> visitedReductionOps;
DenseMap<llvm::StringRef, int> opCounts;
opCounts[op->getName().getStringRef()]++;

getBackwardSlice(op.getOperation(), &backwardSlice, options);

for (Operation *upstreamOpPtr : backwardSlice) {
auto result =
llvm::TypeSwitch<Operation *, LogicalResult>(upstreamOpPtr)
.Case<arith::ConstantOp, tensor_ext::RotateOp>(
[&](auto upstreamOp) { return success(); })
// Ignore generic ops
.template Case<secret::GenericOp>(
[&](auto upstreamOp) { return success(); })
.template Case<arith::AddIOp, arith::MulIOp>([&](auto
upstreamOp) {
opCounts[upstreamOp->getName().getStringRef()]++;
// More than one reduction op is mixed in the reduction.
if (opCounts.size() > 1) {
LLVM_DEBUG(llvm::dbgs()
<< "Not replacing op because reduction "
"contains multiple incompatible ops "
<< op->getName() << " and "
<< upstreamOp->getName() << "\n");
return failure();
}

// Inspect the lattice values at the join point,
// and fail if there is any overlap
auto *lhsLattice =
solver.lookupState<rotation_analysis::RotationLattice>(
upstreamOp.getLhs());
auto *rhsLattice =
solver.lookupState<rotation_analysis::RotationLattice>(
upstreamOp.getRhs());
LLVM_DEBUG(llvm::dbgs()
<< "Computing overlap of "
<< "lhs: " << lhsLattice->getValue() << "\n"
<< "rhs: " << rhsLattice->getValue() << "\n");
auto mergedLattice = rotation_analysis::RotationSets::overlap(
lhsLattice->getValue(), rhsLattice->getValue());
LLVM_DEBUG(llvm::dbgs()
<< "Overlap is: " << mergedLattice << "\n");
if (!mergedLattice.empty()) {
LLVM_DEBUG(
llvm::dbgs()
<< "Not replacing op because reduction "
"may not be a simple reduction of the input tensor\n"
<< "lhs: " << lhsLattice->getValue() << "\n"
<< "rhs: " << rhsLattice->getValue() << "\n");
return failure();
}

visitedReductionOps.insert(upstreamOp);
return success();
})
.Default([&](Operation *op) {
LLVM_DEBUG(llvm::dbgs() << "Not continuing because type switch "
"encountered unsupported op "
<< op->getName() << "\n");
return failure();
});

if (failed(result)) {
return;
}
}

// From here we know we will succeed.
auto b = ImplicitLocOpBuilder(op->getLoc(), op);
auto tensor = reduction.getTensor();
Operation *finalOp;
auto tensorShape = tensor.getType().cast<RankedTensorType>().getShape();
for (int64_t shiftSize = tensorShape[0] / 2; shiftSize > 0;
Expand All @@ -140,12 +59,6 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
[[maybe_unused]] auto *parentOp = op->getParentOp();
op->replaceAllUsesWith(finalOp);
LLVM_DEBUG(llvm::dbgs() << "Post-replacement: " << *parentOp << "\n");

// Mark all ops in the reduction as visited so we don't try to replace them
// twice.
for (Operation *visitedOp : visitedReductionOps) {
visited.insert(visitedOp);
}
}

template <typename ArithOp>
Expand Down Expand Up @@ -263,6 +176,11 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
// The test for a match is now: does the number of accessed indices exactly
// match the size of the tensor? I.e., does each tensor element show up
// exactly once in the reduction?
if (inputTensors.empty()) {
LLVM_DEBUG(llvm::dbgs()
<< "Not replacing op because it accesses no tensors\n");
return;
}
auto tensorShape =
inputTensors.begin()->getType().cast<RankedTensorType>().getShape();
if (tensorShape.size() != 1 || tensorShape[0] != accessIndices.size()) {
Expand Down Expand Up @@ -308,68 +226,36 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
// https://github.com/llvm/llvm-project/issues/58922
solver.load<dataflow::DeadCodeAnalysis>();
solver.load<dataflow::SparseConstantPropagation>();
solver.load<rotation_analysis::RotationAnalysis>();

if (failed(solver.initializeAndRun(getOperation()))) {
getOperation()->emitOpError() << "Failed to run dataflow analysis.\n";
signalPassFailure();
return;
}

LLVM_DEBUG({
getOperation()->walk([&](Operation *op) {
if (op->getNumResults() == 0) return;
auto *targetSlotLattice =
solver.lookupState<rotation_analysis::RotationLattice>(
op->getResult(0));
if (targetSlotLattice->getValue().isOverdetermined()) {
llvm::dbgs() << "Rotation lattice for " << *op
<< " is overdetermined\n";
} else if (targetSlotLattice->getValue().empty()) {
llvm::dbgs() << "Rotation lattice for " << *op << " is empty\n";
} else {
SmallVector<int64_t> sortedRotations(
targetSlotLattice->getValue().getAccessedIndices().begin(),
targetSlotLattice->getValue().getAccessedIndices().end());
llvm::sort(sortedRotations);
std::string stringified = llvm::join(
llvm::map_range(sortedRotations,
[](int64_t i) { return std::to_string(i); }),
",");
llvm::dbgs() << "Rotation lattice for " << *op << ": " << stringified
<< "\n";
}
});
});

rotation_analysis::RotationAnalysis rotationAnalysis(solver);
rotationAnalysis.run(getOperation());
DenseSet<Operation *> visited;

getOperation()->walk<WalkOrder::PreOrder, ReverseIterator>(
[&](Operation *op) {
if (op->getNumResults() == 0) return;
auto *targetSlotLattice =
solver.lookupState<rotation_analysis::RotationLattice>(
op->getResult(0));
if (targetSlotLattice->getValue().isUninitialized() ||
targetSlotLattice->getValue().isOverdetermined()) {
return;
}

auto tensor = targetSlotLattice->getValue().getTensor();
auto accessIndices =
targetSlotLattice->getValue().getAccessedIndices();
int64_t tensorSize =
tensor.getType().cast<RankedTensorType>().getShape()[0];
if (accessIndices.size() == tensorSize) {
llvm::TypeSwitch<Operation &>(*op)
.Case<arith::AddIOp>([&](auto arithOp) {
tryReplaceRotations<arith::AddIOp>(arithOp, tensor, visited,
solver);
})
.Case<arith::MulIOp>([&](auto arithOp) {
tryReplaceRotations<arith::MulIOp>(arithOp, tensor, visited,
solver);
});
for (Value result : op->getResults()) {
if (!result.getType().isa<RankedTensorType>()) {
continue;
}

for (const auto &reduction :
rotationAnalysis.getRootedReductionsAt(result)) {
if (reduction.isComplete()) {
llvm::TypeSwitch<Operation &>(*op)
.Case<arith::AddIOp>([&](auto arithOp) {
tryReplaceRotations<arith::AddIOp>(arithOp, reduction);
})
.Case<arith::MulIOp>([&](auto arithOp) {
tryReplaceRotations<arith::MulIOp>(arithOp, reduction);
});
}
}
}
});

Expand Down
Loading

0 comments on commit 81e45d8

Please sign in to comment.