Skip to content

[mlir] Expose linearize/delinearize lowering transforms #144156

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@ enum class BoundType;

namespace affine {
class AffineApplyOp;
class AffineDelinearizeIndexOp;
class AffineLinearizeIndexOp;

/// Lowers `affine.delinearize_index` into a sequence of division and remainder
/// operations, replacing `op` with the computed multi-index values.
/// NOTE(review): the current implementation has no failure path and always
/// returns success(); the LogicalResult return appears to reserve room for
/// future failure modes — confirm before relying on failure handling.
LogicalResult lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
                                            AffineDelinearizeIndexOp op);

/// Lowers `affine.linearize_index` into a sequence of multiplications and
/// additions. Make a best effort to sort the input indices so that
/// the most loop-invariant terms are at the left of the additions
/// to enable loop-invariant code motion. As above, currently always
/// returns success().
LogicalResult lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
                                          AffineLinearizeIndexOp op);
Comment on lines +38 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't follow why we return a LogicalResult for these methods. It seems that it should always succeed. Should they just return void, or Value and SmallVector<Value> for further transformation?


/// Populate patterns that expand affine index operations into more fundamental
/// operations (not necessarily restricted to Affine dialect).
Expand Down
218 changes: 111 additions & 107 deletions mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,126 +84,130 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
return result;
}

/// Lowers `affine.delinearize_index` into arith-dialect floordiv / rem /
/// select operations and replaces `op` with the resulting values.
/// Always returns success() — there is no failure path in this rewrite.
LogicalResult
affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
                                      AffineDelinearizeIndexOp op) {
  Location loc = op.getLoc();
  Value linearIdx = op.getLinearIndex();
  unsigned numResults = op.getNumResults();
  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
  // When the op carries one basis entry per result, the leading entry is not
  // needed to compute strides for the remaining results, so drop it here.
  if (numResults == staticBasis.size())
    staticBasis = staticBasis.drop_front();

  // Single-result delinearization is the identity: forward the linear index.
  if (numResults == 1) {
    rewriter.replaceOp(op, linearIdx);
    return success();
  }

  SmallVector<Value> results;
  results.reserve(numResults);
  // Strides for each result position, derived from the (possibly mixed
  // static/dynamic) basis. Basis values are treated as non-negative here.
  SmallVector<Value> strides =
      computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                     /*knownNonNegative=*/true);

  Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);

  // Outermost result: the linear index floor-divided by the largest stride.
  Value initialPart =
      rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
  results.push_back(initialPart);

  // Emits `linearIdx mod stride` with Euclidean (non-negative) semantics:
  // arith.remsi can be negative for a negative dividend, so add the stride
  // back in that case via a select.
  auto emitModTerm = [&](Value stride) -> Value {
    Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
    Value remainderNegative = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, remainder, zero);
    // If the correction is relevant, this term is <= stride, which is known
    // to be positive in `index`. Otherwise, while 2 * stride might overflow,
    // this branch won't be taken, so the risk of `poison` is fine.
    Value corrected = rewriter.create<arith::AddIOp>(
        loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
    Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
                                                 corrected, remainder);
    return mod;
  };

  // Generate all the intermediate parts: each is
  // (linearIdx mod thisStride) / nextStride.
  for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
    Value thisStride = strides[i];
    Value nextStride = strides[i + 1];
    Value modulus = emitModTerm(thisStride);
    // We know both inputs are positive, so floorDiv == div.
    // This could potentially be a divui, but it's not clear if that would
    // cause issues.
    Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
    results.push_back(divided);
  }

  // Innermost result: just the remainder modulo the smallest stride.
  results.push_back(emitModTerm(strides.back()));

  rewriter.replaceOp(op, results);
  return success();
}

/// Lowers `affine.linearize_index` into a chain of nsw multiplications and
/// additions, ordering the summands so that terms invariant in the most
/// enclosing loops come first (enabling later loop-invariant code motion).
/// Always returns success() — there is no failure path in this rewrite.
LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
                                                  AffineLinearizeIndexOp op) {
  // Should be folded away, included here for safety.
  if (op.getMultiIndex().empty()) {
    rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
    return success();
  }

  Location loc = op.getLoc();
  ValueRange multiIndex = op.getMultiIndex();
  size_t numIndexes = multiIndex.size();
  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
  // As in the delinearize lowering: a basis entry per index means the leading
  // entry is not needed for stride computation, so drop it.
  if (numIndexes == staticBasis.size())
    staticBasis = staticBasis.drop_front();

  // `disjoint` lets computeStrides assume non-negative basis values.
  SmallVector<Value> strides =
      computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                     /*knownNonNegative=*/op.getDisjoint());
  // Each entry pairs a (scaled) term with the number of enclosing loops it is
  // invariant in, which drives the sort below.
  SmallVector<std::pair<Value, int64_t>> scaledValues;
  scaledValues.reserve(numIndexes);

  // Note: strides doesn't contain a value for the final element (stride 1)
  // and everything else lines up. We use the "mutable" accessor so we can get
  // our hands on an `OpOperand&` for the loop invariant counting function.
  for (auto [stride, idxOp] :
       llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
    Value scaledIdx = rewriter.create<arith::MulIOp>(
        loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
    int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
    scaledValues.emplace_back(scaledIdx, numHoistableLoops);
  }
  // The last index has implicit stride 1, so it is added unscaled.
  scaledValues.emplace_back(
      multiIndex.back(),
      numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));

  // Sort by how many enclosing loops there are, ties implicitly broken by
  // size of the stride.
  llvm::stable_sort(scaledValues,
                    [&](auto l, auto r) { return l.second > r.second; });

  // Fold the terms left-to-right so the most-hoistable partial sums appear
  // first in the addition chain.
  Value result = scaledValues.front().first;
  for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
    std::ignore = numHoistableLoops;
    result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
                                            arith::IntegerOverflowFlags::nsw);
  }
  rewriter.replaceOp(op, result);
  return success();
}

namespace {
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
/// operations.
///
/// This pattern is a thin adapter: the actual lowering logic lives in the
/// standalone `affine::lowerAffineDelinearizeIndexOp` entry point so that
/// callers can invoke the transform directly without driving a pattern set.
struct LowerDelinearizeIndexOps
    : public OpRewritePattern<AffineDelinearizeIndexOp> {
  using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    // Delegate to the shared implementation; it replaces `op` on success.
    // The body previously inlined here was left behind as unreachable dead
    // code ahead of this return — it has been removed.
    return affine::lowerAffineDelinearizeIndexOp(rewriter, op);
  }
};

/// Lowers `affine.linearize_index` into a sequence of multiplications and
/// additions. Make a best effort to sort the input indices so that
/// the most loop-invariant terms are at the left of the additions
/// to enable loop-invariant code motion.
///
/// This pattern is a thin adapter: the actual lowering logic lives in the
/// standalone `affine::lowerAffineLinearizeIndexOp` entry point so that
/// callers can invoke the transform directly without driving a pattern set.
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    // Delegate to the shared implementation; it replaces `op` on success.
    // The body previously inlined here was left behind as unreachable dead
    // code ahead of this return — it has been removed.
    return affine::lowerAffineLinearizeIndexOp(rewriter, op);
  }
};

Expand Down
Loading