// Patch summary (commentary only; everything from the `diff --git` line onward is unchanged).
//
// 1. mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
//    In the hunk under `struct UnrolledOuterProductGenerator`, the promoted
//    vector type is now built with `vecType.clone(promotedType)` instead of
//    `VectorType::get(vecType.getShape(), promotedType)`. The removed call
//    forwarded only the shape and the new element type of `vecType`.
//    Presumably the clone also carries over the scalable-dimension flags of
//    `vecType` -- TODO confirm against the VectorType API.
//
// 2. mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
//    Adds scalable-vector variants of the existing tests: @matmul_scalable,
//    @matmul_0_scalable, @matmul_0_mixed_scalable, @matmul_1_scalable,
//    @matmul_2_scalable, @matmul_3_scalable and @matmul_4_scalable.
//    @matmul_0_mixed_scalable is the one that checks an element-type
//    promotion on a scalable operand (`arith.extf` from vector<[3]xf16> to
//    vector<[3]xf32>).
//    The last hunk also adds the attribute aliases #matmat_accesses_5 and
//    #matmat_trait_5; no user of them is visible in this excerpt.
//
// NOTE(review): the original line breaks and indentation of this patch were
// collapsed, so it will not apply as-is. Restore one diff line per text line
// before feeding it to `git apply` / `patch`.
// NOTE(review): the context lines `isa(dstElementType)` and
// `rewriter.create(loc, promotedType, v)` look like they lost their
// angle-bracketed template arguments during extraction -- confirm against the
// upstream file before relying on them.
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 5463a7bd8f4c84..6dbe36e605e9a7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -418,7 +418,7 @@ struct UnrolledOuterProductGenerator return v; Type promotedType = dstElementType; if (vecType) - promotedType = VectorType::get(vecType.getShape(), promotedType); + promotedType = vecType.clone(promotedType); if (isa(dstElementType)) return rewriter.create(loc, promotedType, v); return rewriter.create(loc, promotedType, v); diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir index 44fb23088cea93..6933b24a32a830 100644 --- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir @@ -169,6 +169,42 @@ func.func @matmul(%arg0: vector<2x4xf32>, return %0 : vector<2x3xf32> } +// CHECK-LABEL: func @matmul_scalable +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x[3]xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK-SAME: : vector<2x4xf32> to vector<4x2xf32> +// +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<4x[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK-SAME: : vector<2xf32>, vector<[3]xf32> +// +// CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32> +// CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<4x[3]xf32> +// CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]] +// CHECK-SAME: : 
vector<2xf32>, vector<[3]xf32> +// +// CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32> +// CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<4x[3]xf32> +// CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]] +// CHECK-SAME: : vector<2xf32>, vector<[3]xf32> +// +// CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32> +// CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<4x[3]xf32> +// CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]] +// CHECK-SAME: : vector<2xf32>, vector<[3]xf32> +// +// CHECK: return %[[c3]] : vector<2x[3]xf32> +func.func @matmul_scalable(%arg0: vector<2x4xf32>, + %arg1: vector<4x[3]xf32>, + %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> { + %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 + : vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32> + return %0 : vector<2x[3]xf32> +} + // CHECK-LABEL: func @matmul_0 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, @@ -186,6 +222,23 @@ func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto return %0 : vector<2x3xf32> } +// CHECK-LABEL: func @matmul_0_scalable +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<2x[3]xf32> +func.func @matmul_0_scalable(%arg0: vector<2x1xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>) +-> vector<2x[3]xf32> +{ + %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 + : 
vector<2x1xf32>, vector<1x[3]xf32> into vector<2x[3]xf32> + return %0 : vector<2x[3]xf32> +} + // CHECK-LABEL: func @matmul_0_mixed // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>, // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>, @@ -205,6 +258,25 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2: return %0 : vector<2x3xf32> } +// CHECK-LABEL: func @matmul_0_mixed_scalable +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf16>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<1x[3]xf16> +// CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32> +// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]] +// CHECK: return %[[c0]] : vector<2x[3]xf32> +func.func @matmul_0_mixed_scalable(%arg0: vector<2x1xf16>, %arg1: vector<1x[3]xf16>, %arg2: vector<2x[3]xf32>) +-> vector<2x[3]xf32> +{ + %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 + : vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32> + return %0 : vector<2x[3]xf32> +} + #matmat_accesses_1 = [ affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (n, k)>, @@ -233,6 +305,24 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto return %0 : vector<2x3xf32> } +// CHECK-LABEL: func @matmul_1_scalable +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from 
vector<1x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<2x[3]xf32> +func.func @matmul_1_scalable(%arg0: vector<2x1xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>) +-> vector<2x[3]xf32> +{ + %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2 + : vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32> + return %0 : vector<2x[3]xf32> +} + #matmat_accesses_2 = [ affine_map<(m, n, k) -> (k, m)>, affine_map<(m, n, k) -> (k, n)>, @@ -259,6 +349,22 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto return %0 : vector<2x3xf32> } +// CHECK-LABEL: func @matmul_2_scalable +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32> +// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<2x[3]xf32> +func.func @matmul_2_scalable(%arg0: vector<1x2xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>) +-> vector<2x[3]xf32> +{ + %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2 + : vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32> + return %0 : vector<2x[3]xf32> +} + #matmat_accesses_3 = [ affine_map<(m, n, k) -> (k, m)>, affine_map<(m, n, k) -> (n, k)>, @@ -286,6 +392,23 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto return %0 : vector<2x3xf32> } +// CHECK-LABEL: func @matmul_3_scalable +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32> +// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 
0] +// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<2x[3]xf32> +func.func @matmul_3_scalable(%arg0: vector<1x2xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>) +-> vector<2x[3]xf32> +{ + %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2 + : vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32> + return %0 : vector<2x[3]xf32> +} + #matmat_accesses_4 = [ affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (k, n)>, @@ -313,6 +436,33 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto return %0 : vector<3x2xf32> } +// CHECK-LABEL: func @matmul_4_scalable +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<[2]x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x[2]xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32> +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] +// CHECK: return %[[c0]] : vector<3x[2]xf32> +func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x[2]xf32>) +-> vector<3x[2]xf32> +{ + %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2 + : vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32> + return %0 : vector<3x[2]xf32> +} + +#matmat_accesses_5 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (n, m)> +] +#matmat_trait_5 = { + indexing_maps = #matmat_accesses_5, + iterator_types = ["parallel", "parallel", "reduction"] +} + // CHECK-LABEL: @masked_matvec_mk_k_m // CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32> // CHECK-SAME: 
%[[VEC:.+]]: vector<2xf32>