-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add rotation replacement in rotate-and-reduce
- Loading branch information
Showing
8 changed files
with
808 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# RotationAnalysis analysis pass | ||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
exports_files( | ||
["RotationAnalysis.h"], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
#ifndef INCLUDE_ANALYSIS_ROTATIONANALYSIS_ROTATIONANALYSIS_H_ | ||
#define INCLUDE_ANALYSIS_ROTATIONANALYSIS_ROTATIONANALYSIS_H_ | ||
|
||
#include <unordered_set> | ||
|
||
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project | ||
#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project | ||
|
||
#define DEBUG_TYPE "rotation-analysis" | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace rotation_analysis { | ||
|
||
// A wrapper around a mapping from a single tensor SSA value to a set of its | ||
// access indices. | ||
class RotationSets { | ||
public: | ||
enum class Status { | ||
// The tensor value has not been set | ||
Uninitialized, | ||
|
||
// The rotation set is in a normal state. | ||
Normal, | ||
|
||
// The rotation set has a property that makes it invalid for later | ||
// optimizations: | ||
// | ||
// - It involves operations touch more than one source tensor (not | ||
// including value-semantic outputs) | ||
Overdetermined | ||
|
||
}; | ||
|
||
public: | ||
RotationSets() = default; | ||
~RotationSets() = default; | ||
|
||
// Clear the member data, i.e., set the value back to an uninitialized | ||
// state. | ||
void clear() { | ||
accessedIndices.clear(); | ||
status = Status::Uninitialized; | ||
} | ||
|
||
bool empty() const { return accessedIndices.empty(); } | ||
|
||
bool isOverdetermined() const { return status == Status::Overdetermined; } | ||
|
||
bool isUninitialized() const { return status == Status::Uninitialized; } | ||
|
||
void addRotation(int64_t index) { accessedIndices.insert(index); } | ||
|
||
bool operator==(const RotationSets &rhs) const { | ||
return tensor == rhs.tensor && status == rhs.status && | ||
accessedIndices == rhs.accessedIndices; | ||
} | ||
|
||
const std::unordered_set<int64_t> &getAccessedIndices() const { | ||
return accessedIndices; | ||
} | ||
|
||
Value getTensor() const { return tensor; } | ||
|
||
void print(raw_ostream &os) const { | ||
os << tensor << ": ["; | ||
for (auto index : accessedIndices) { | ||
os << index << ", "; | ||
} | ||
os << "]"; | ||
} | ||
|
||
static RotationSets overdetermined() { | ||
RotationSets sets; | ||
sets.status = Status::Overdetermined; | ||
return sets; | ||
} | ||
|
||
static RotationSets from(Value tensor) { | ||
RotationSets sets; | ||
if (!tensor.getType().isa<RankedTensorType>()) { | ||
sets.status = Status::Uninitialized; | ||
return sets; | ||
} | ||
|
||
sets.status = Status::Normal; | ||
sets.tensor = tensor; | ||
if (auto blockArg = dyn_cast<BlockArgument>(tensor)) { | ||
sets.addRotation(0); | ||
} | ||
return sets; | ||
} | ||
|
||
// Shift the rotation indices by the given amount. This helps in a situation | ||
// where an IR repeatedly rotates by 1, to ensure that rotations accumulate | ||
// like {1, 2, 3, ...} rather than {1, 1, 1, ...} | ||
static RotationSets rotate(const RotationSets &lhs, const int64_t shift) { | ||
if (lhs.status == Status::Overdetermined) { | ||
return overdetermined(); | ||
} | ||
|
||
RotationSets shifted; | ||
shifted.status = Status::Normal; | ||
shifted.tensor = lhs.tensor; | ||
int64_t size = | ||
llvm::cast<RankedTensorType>(lhs.tensor.getType()).getShape()[0]; | ||
for (auto index : lhs.accessedIndices) { | ||
shifted.addRotation((index + shift) % size); | ||
} | ||
return shifted; | ||
} | ||
|
||
static RotationSets join(const RotationSets &lhs, const RotationSets &rhs) { | ||
if (lhs.status == Status::Overdetermined || | ||
rhs.status == Status::Overdetermined) { | ||
return overdetermined(); | ||
} | ||
|
||
if (rhs.status == Status::Uninitialized || rhs.accessedIndices.empty()) | ||
return lhs; | ||
if (lhs.status == Status::Uninitialized || lhs.accessedIndices.empty()) | ||
return rhs; | ||
|
||
if (lhs.tensor != rhs.tensor) { | ||
LLVM_DEBUG({ | ||
llvm::dbgs() << "Joining rotations of different tensors: " << lhs.tensor | ||
<< " and " << rhs.tensor << "\n"; | ||
}); | ||
return overdetermined(); | ||
} | ||
|
||
LLVM_DEBUG({ | ||
llvm::dbgs() << "Joining :" << lhs.tensor << " and " << rhs.tensor | ||
<< "\n"; | ||
}); | ||
RotationSets merged; | ||
merged.status = Status::Normal; | ||
merged.tensor = lhs.tensor; | ||
for (auto index : lhs.accessedIndices) { | ||
merged.addRotation(index); | ||
} | ||
for (auto index : rhs.accessedIndices) { | ||
merged.addRotation(index); | ||
} | ||
return merged; | ||
} | ||
|
||
// Assuming two not-overdetermined rotation sets, compute the overlap in | ||
// their access indices. | ||
static RotationSets overlap(const RotationSets &lhs, | ||
const RotationSets &rhs) { | ||
assert(!lhs.isOverdetermined() && !rhs.isOverdetermined() && | ||
"Expected inputs to RotationSets::overlap to be not overdetermined"); | ||
if (lhs.status == Status::Uninitialized || lhs.empty()) { | ||
return lhs; | ||
} | ||
|
||
if (rhs.status == Status::Uninitialized || rhs.empty()) { | ||
return rhs; | ||
} | ||
|
||
RotationSets merged; | ||
merged.status = Status::Normal; | ||
merged.tensor = lhs.tensor; | ||
for (auto index : lhs.accessedIndices) { | ||
if (rhs.accessedIndices.count(index)) merged.addRotation(index); | ||
} | ||
return merged; | ||
} | ||
|
||
private: | ||
/// The accessed indices of a single SSA value of tensor type. | ||
Value tensor; | ||
|
||
// There is likely a data structure that can more efficiently represent a set | ||
// of intervals of integers, which properly merges adjacent intervals as | ||
// values are added. Java/Guava has RangeSet, and boost has interval_set. | ||
std::unordered_set<int64_t> accessedIndices; | ||
Status status = Status::Uninitialized; | ||
}; | ||
|
||
// Stream insertion for RotationSets; formatting is delegated entirely to
// RotationSets::print so both entry points produce the same text.
inline raw_ostream &operator<<(raw_ostream &os, const RotationSets &v) {
  v.print(os);
  // Hand the stream back so insertions can be chained.
  return os;
}
|
||
// Lattice element used by the rotation analysis: a dataflow lattice whose
// payload is a RotationSets value. It adds nothing beyond the base class.
class RotationLattice : public dataflow::Lattice<RotationSets> {
 public:
  // Inherit every constructor of the base lattice unchanged.
  using dataflow::Lattice<RotationSets>::Lattice;
};
|
||
/// An analysis that identifies, for each SSA value, the set of underlying | ||
/// tensors and rotations of those tensors, provided constant rotation shifts | ||
/// can be determined. | ||
class RotationAnalysis | ||
: public dataflow::SparseForwardDataFlowAnalysis<RotationLattice> { | ||
public: | ||
explicit RotationAnalysis(DataFlowSolver &solver) | ||
: SparseForwardDataFlowAnalysis(solver) {} | ||
~RotationAnalysis() override = default; | ||
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; | ||
|
||
// Given the computed results of the operation, update its operand lattice | ||
// values. | ||
void visitOperation(Operation *op, ArrayRef<const RotationLattice *> operands, | ||
ArrayRef<RotationLattice *> results) override; | ||
|
||
void setToEntryState(RotationLattice *lattice) override; | ||
}; | ||
|
||
} // namespace rotation_analysis | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // INCLUDE_ANALYSIS_ROTATIONANALYSIS_ROTATIONANALYSIS_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# RotationAnalysis analysis pass | ||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "RotationAnalysis", | ||
srcs = ["RotationAnalysis.cpp"], | ||
hdrs = ["@heir//include/Analysis/RotationAnalysis:RotationAnalysis.h"], | ||
deps = [ | ||
"@heir//lib/Dialect:Utils", | ||
"@heir//lib/Dialect/TensorExt/IR:Dialect", | ||
"@llvm-project//llvm:Support", | ||
"@llvm-project//mlir:Analysis", | ||
"@llvm-project//mlir:ArithDialect", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:TensorDialect", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
#include "include/Analysis/RotationAnalysis/RotationAnalysis.h" | ||
|
||
#include "include/Dialect/TensorExt/IR/TensorExtOps.h" | ||
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project | ||
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace rotation_analysis { | ||
|
||
// Transfer function of the rotation analysis.
//
//  - tensor_ext.rotate with a constant integer shift: the rotation set of the
//    tensor operand is shifted by that amount and joined into every result.
//    When the shift is not a statically known integer, nothing is propagated.
//  - Any other op: the rotation sets of all operands are joined into every
//    result.
void RotationAnalysis::visitOperation(
    Operation *op, ArrayRef<const RotationLattice *> operands,
    ArrayRef<RotationLattice *> results) {
  llvm::TypeSwitch<Operation &>(*op)
      .Case<tensor_ext::RotateOp>([&](auto rotateOp) {
        LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; });
        auto shiftConstantOp =
            rotateOp.getShift().template getDefiningOp<arith::ConstantOp>();
        // If the rotation shift can't be statically determined, we can't
        // propagate anything through the IR.
        if (!shiftConstantOp) return;

        // arith.constant may hold a non-integer attribute; dyn_cast yields a
        // null attribute in that case, which must not be dereferenced.
        auto shiftAttr = dyn_cast<IntegerAttr>(shiftConstantOp.getValue());
        if (!shiftAttr) return;
        const int64_t shiftValue = shiftAttr.getInt();

        // The target slot propagates from the tensor argument to the result;
        // the tensor argument is first in the tablegen definition.
        const RotationLattice *lattice = operands[0];
        RotationSets latticeRotations = lattice->getValue();

        // If it's a block argument, then there is no initialized lattice value
        // and we can override it with a "zero rotation"
        auto blockArg = dyn_cast<BlockArgument>(rotateOp.getTensor());
        if (blockArg) {
          latticeRotations = RotationSets::from(blockArg);
        }
        RotationSets rotated =
            RotationSets::rotate(latticeRotations, shiftValue);

        for (RotationLattice *r : results) {
          ChangeResult result = r->join(rotated);
          propagateIfChanged(r, result);
        }
      })
      .Default([&](Operation &genericOp) {
        // By default, an op's results accumulate the rotation sets of all of
        // its operands.
        for (OpOperand &operand : genericOp.getOpOperands()) {
          auto *latticeOperand = operands[operand.getOperandNumber()];
          // Depends only on the operand, so compute it once rather than once
          // per result.
          const RotationSets operandSeed = RotationSets::from(operand.get());

          for (RotationLattice *r : results) {
            ChangeResult result = r->join(*latticeOperand);
            // If the operand is a block arg, this additionally treats this as
            // a zero rotation. If the underlying tensor differs across
            // operands, this will also cause a Status::Overdetermined.
            // Otherwise, the join is a no-op.
            result |= r->join(operandSeed);
            propagateIfChanged(r, result);
          }
        }
      });
}
|
||
// Entry state of the analysis: the lattice's RotationSets is cleared, which
// drops its recorded indices and marks it Uninitialized.
void RotationAnalysis::setToEntryState(RotationLattice *lattice) {
  RotationSets &entryState = lattice->getValue();
  entryState.clear();
}
|
||
} // namespace rotation_analysis | ||
} // namespace heir | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.