diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index aee269555bd34..67a880d2e5d61 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -1091,7 +1091,12 @@ def Vector_ShapeCastOp : described above is applied to each source/result tuple element pair. It is currently assumed that this operation does not require moving data, - and that it will be canonicalized away before lowering vector operations. + and that it will be folded away before lowering vector operations. + + There is an exception to the folding expectation when targeting + llvm.intr.matrix operations. We need a type conversion back and forth from a + 2-D MLIR vector to a 1-D flattened LLVM vector. Lowering shape_cast to LLVM + is supported in that particular case, for now. Examples: @@ -1108,6 +1113,14 @@ def Vector_ShapeCastOp : tuple, vector<9x2xf32>> ``` }]; + let extraClassDeclaration = [{ + VectorType getSourceVectorType() { + return source().getType().cast(); + } + VectorType getResultVectorType() { + return getResult().getType().cast(); + } + }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)"; } diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp index 8764d487dfb94..00089ebefd124 100644 --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -1171,6 +1171,75 @@ class ContractionOpLowering : public OpRewritePattern { } }; +/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D +/// vectors progressively on the way to targeting llvm.matrix intrinsics. 
+/// This iterates over the most major dimension of the 2-D vector and performs +/// rewrites into: +/// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D +class ShapeCastOp2DDownCastRewritePattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(vector::ShapeCastOp op, + PatternRewriter &rewriter) const override { + auto sourceVectorType = op.getSourceVectorType(); + auto resultVectorType = op.getResultVectorType(); + if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1) + return matchFailure(); + + auto loc = op.getLoc(); + auto elemType = sourceVectorType.getElementType(); + Value zero = rewriter.create(loc, elemType, + rewriter.getZeroAttr(elemType)); + Value desc = rewriter.create(loc, resultVectorType, zero); + unsigned mostMinorVectorSize = sourceVectorType.getShape()[1]; + for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) { + Value vec = rewriter.create(loc, op.source(), i); + desc = rewriter.create( + loc, vec, desc, + /*offsets=*/i * mostMinorVectorSize, /*strides=*/1); + } + rewriter.replaceOp(op, desc); + return matchSuccess(); + } +}; + +/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D +/// vectors progressively on the way back from llvm.matrix intrinsics. 
+/// This iterates over the most major dimension of the 2-D vector and performs +/// rewrites into: +/// vector.strided_slice from 1-D + vector.insert into 2-D +class ShapeCastOp2DUpCastRewritePattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(vector::ShapeCastOp op, + PatternRewriter &rewriter) const override { + auto sourceVectorType = op.getSourceVectorType(); + auto resultVectorType = op.getResultVectorType(); + if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2) + return matchFailure(); + + auto loc = op.getLoc(); + auto elemType = sourceVectorType.getElementType(); + Value zero = rewriter.create(loc, elemType, + rewriter.getZeroAttr(elemType)); + Value desc = rewriter.create(loc, resultVectorType, zero); + unsigned mostMinorVectorSize = resultVectorType.getShape()[1]; + for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) { + Value vec = rewriter.create( + loc, op.source(), /*offsets=*/i * mostMinorVectorSize, + /*sizes=*/mostMinorVectorSize, + /*strides=*/1); + desc = rewriter.create(loc, vec, desc, i); + } + rewriter.replaceOp(op, desc); + return matchSuccess(); + } +}; + } // namespace // TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp). 
@@ -1188,5 +1257,9 @@ void mlir::vector::populateVectorSlicesLoweringPatterns( void mlir::vector::populateVectorContractLoweringPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert(context); + patterns.insert(context); } diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir index 275fd0841a601..c5e40a7c18caf 100644 --- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir @@ -250,3 +250,42 @@ func @full_contract2(%arg0: vector<2x3xf32>, : vector<2x3xf32>, vector<3x2xf32> into f32 return %0 : f32 } + +// Shape up and downcasts for 2-D vectors, for supporting conversion to +// llvm.matrix operations +// CHECK-LABEL: func @shape_casts +func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) { + // CHECK: %[[cst:.*]] = constant dense<0.000000e+00> : vector<4xf32> + // CHECK: %[[cst22:.*]] = constant dense<0.000000e+00> : vector<2x2xf32> + // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2x2xf32> + // + // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]] + // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> + // + // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2x2xf32> + // + // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]] + // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> + // + %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32> + // CHECK: %[[add:.*]] = addf %[[in2]], %[[in2]] : vector<4xf32> + %r0 = addf %0, %0: vector<4xf32> + // + // CHECK: %[[ss0:.*]] = vector.strided_slice %[[add]] + // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : + // CHECK-SAME: vector<4xf32> to vector<2xf32> + // + // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] : + // CHECK-SAME: vector<2xf32> into vector<2x2xf32> + // + // 
CHECK: %[[s2:.*]] = vector.strided_slice %[[add]] + // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} : + // CHECK-SAME: vector<4xf32> to vector<2xf32> + // + // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] : + // CHECK-SAME: vector<2xf32> into vector<2x2xf32> + // + %1 = vector.shape_cast %r0 : vector<4xf32> to vector<2x2xf32> + // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32> + return %r0, %1 : vector<4xf32>, vector<2x2xf32> +}