Skip to content

Commit

Permalink
[mlir][linalg] Fix the semantic use of a flag (#90081)
Browse files Browse the repository at this point in the history
`useInBoundsInsteadOfMasking` was doing the opposite i.e., when set to
true; was updating the mask instead of updating the inBounds.
  • Loading branch information
pashu123 committed Apr 26, 2024
1 parent 5350052 commit 8feedd5
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ bool isLinearizableVector(VectorType type);
/// for each dimension of the passed in tensor.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
ArrayRef<int64_t> readShape, Value padValue,
bool useInBoundsInsteadOfMasking = true);
bool useInBoundsInsteadOfMasking);

/// Returns success if `inputVectorSizes` is a valid masking configuraion for
/// given `shape`, i.e., it meets:
Expand Down
10 changes: 6 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1499,11 +1499,11 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
// If the input vector sizes are not provided, then the vector sizes are
// determined by the result tensor shape. In case the vector sizes aren't
// provided, we update the inBounds attribute instead of masking.
bool useInBoundsInsteadOfMasking = true;
bool useInBoundsInsteadOfMasking = false;
if (inputVectorSizes.empty()) {
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
useInBoundsInsteadOfMasking = false;
useInBoundsInsteadOfMasking = true;
}

// Create masked TransferReadOp.
Expand Down Expand Up @@ -1612,7 +1612,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
// to shape of source, then a mask is necessary.
Value readResult = vector::createReadOrMaskedRead(
rewriter, loc, unpackOp.getSource(),
ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue);
ArrayRef<int64_t>(readMaskShape.begin(), readMaskShape.end()), padValue,
/*useInBoundsInsteadOfMasking=*/false);

PackingMetadata packMetadata;
SmallVector<int64_t> lastDimToInsertPosPerm =
Expand Down Expand Up @@ -1669,7 +1670,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
(void)status; // prevent unused variable warning on non-assert builds
assert(succeeded(status) && "failed to reify result shapes");
auto maskedRead = vector::createReadOrMaskedRead(
rewriter, loc, padOp.getSource(), inputVectorSizes, padValue);
rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
/*useInBoundsInsteadOfMasking=*/false);
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, maskedRead, reifiedReturnShapes[0], inputVectorSizes);
newResults.push_back(write->getResult(0));
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
int64_t readRank = readShape.size();
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<bool> inBoundsVal(readRank, true);
if (!useInBoundsInsteadOfMasking) {
if (useInBoundsInsteadOfMasking) {
// Update the inBounds attribute.
for (unsigned i = 0; i < readRank; i++)
inBoundsVal[i] = (sourceShape[i] == readShape[i]) &&
Expand All @@ -359,7 +359,7 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
/*padding=*/padValue,
/*inBounds=*/inBoundsVal);

if (llvm::equal(readShape, sourceShape) || !useInBoundsInsteadOfMasking)
if (llvm::equal(readShape, sourceShape) || useInBoundsInsteadOfMasking)
return transferReadOp;
SmallVector<OpFoldResult> mixedSourceDims =
tensor::getMixedSizes(builder, loc, source);
Expand Down

0 comments on commit 8feedd5

Please sign in to comment.