Skip to content

Commit

Permalink
[mlir][BuiltinTypes] Return VectorType from VectorType::Builder conve…
Browse files Browse the repository at this point in the history
…rsion operator

0-D vectors are now supported, so the special case of returning the just
the element type can now be removed.

A few callers that relied on the old behaviour have been updated.

Reviewed By: awarzynski, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D159122
  • Loading branch information
MacDue committed Aug 30, 2023
1 parent 715cde0 commit 296d5cb
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 19 deletions.
7 changes: 1 addition & 6 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,7 @@ class VectorType::Builder {
return *this;
}

/// In the particular case where the vector has a single dimension that we
/// drop, return the scalar element type.
// TODO: unify once we have a VectorType that supports 0-D.
operator Type() {
if (shape.empty())
return elementType;
operator VectorType() {
return VectorType::get(shape, elementType, scalableDims);
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2216,7 +2216,7 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
return failure();
if (mask.size() != 1)
return failure();
Type resType = VectorType::Builder(v1VectorType).setShape({1});
VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
shuffleOp.getV1());
Expand Down
23 changes: 11 additions & 12 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,20 @@ static Value reshapeLoad(Location loc, Value val, VectorType type,
PatternRewriter &rewriter) {
if (index == -1)
return val;
Type lowType = VectorType::Builder(type).dropDim(0);
Type lowType = type.getRank() > 1 ? VectorType::Builder(type).dropDim(0)
: type.getElementType();
// At extraction dimension?
if (index == 0)
return rewriter.create<vector::ExtractOp>(loc, lowType, val, pos);
// Unroll leading dimensions.
VectorType vType = cast<VectorType>(lowType);
Type resType = VectorType::Builder(type).dropDim(index);
auto resVectorType = cast<VectorType>(resType);
VectorType resType = VectorType::Builder(type).dropDim(index);
Value result = rewriter.create<arith::ConstantOp>(
loc, resVectorType, rewriter.getZeroAttr(resVectorType));
for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) {
loc, resType, rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, val, d);
Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
result =
rewriter.create<vector::InsertOp>(loc, resVectorType, load, result, d);
result = rewriter.create<vector::InsertOp>(loc, resType, load, result, d);
}
return result;
}
Expand All @@ -120,13 +119,13 @@ static Value reshapeStore(Location loc, Value val, Value result,
if (index == 0)
return rewriter.create<vector::InsertOp>(loc, type, val, result, pos);
// Unroll leading dimensions.
Type lowType = VectorType::Builder(type).dropDim(0);
VectorType vType = cast<VectorType>(lowType);
Type insType = VectorType::Builder(vType).dropDim(0);
VectorType lowType = VectorType::Builder(type).dropDim(0);
Type insType = lowType.getRank() > 1 ? VectorType::Builder(lowType).dropDim(0)
: lowType.getElementType();
for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
Value ext = rewriter.create<vector::ExtractOp>(loc, vType, result, d);
Value ext = rewriter.create<vector::ExtractOp>(loc, lowType, result, d);
Value ins = rewriter.create<vector::ExtractOp>(loc, insType, val, d);
Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
Value sto = reshapeStore(loc, ins, ext, lowType, index - 1, pos, rewriter);
result = rewriter.create<vector::InsertOp>(loc, type, sto, result, d);
}
return result;
Expand Down

0 comments on commit 296d5cb

Please sign in to comment.