Summary of this patch (text before the first "diff --git" line is ignored by patch tools):
  * VectorToXeGPU.cpp: the ConvertVectorToXeGPUPass additionally registers
    populatePrepareVectorToMMAPatterns (declared in VectorToGPU.h) next to the
    existing XeGPU conversion patterns.
  * contract-to-xegpu.mlir: the transposed-A / transposed-B contraction cases move
    from negative tests (vector.contract left in place) to positive tests that
    expect a vector.transpose of the transposed operand followed by xegpu.dpas.

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index c81bb4b455b98..2bece58a119b3 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -846,6 +847,7 @@ struct ConvertVectorToXeGPUPass
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateVectorToXeGPUConversionPatterns(patterns);
+    populatePrepareVectorToMMAPatterns(patterns);
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
   }
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
index 38bda39d3aca2..292e4ff882000 100644
--- a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -76,6 +76,56 @@ func.func @dpas_large_dims(%lhs: vector<128x512xf16>, %rhs: vector<512x256xf16>,
 
 // -----
 
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @gemm_transpose_a(%lhs: vector<16x8xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<16x8xf16>, vector<16x16xf16> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @gemm_transpose_a(
+// CHECK-SAME: %[[LHS:.+]]: vector<16x8xf16>,
+// CHECK-SAME: %[[RHS:.+]]: vector<16x16xf16>,
+// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf32>
+// CHECK: %[[LHS_TRANSPOSED:.+]] = vector.transpose %[[LHS]], [1, 0] : vector<16x8xf16> to vector<8x16xf16>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[LHS_TRANSPOSED]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<8x16xf32>
+// CHECK: return %[[DPAS]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
+    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
+  return %3 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: @gemm_transpose_b(
+// CHECK-SAME: %[[LHS:.+]]: vector<8x16xf16>,
+// CHECK-SAME: %[[RHS:.+]]: vector<16x16xf16>,
+// CHECK-SAME: %[[ACC:.+]]: vector<8x16xf32>
+// CHECK: %[[RHS_TRANSPOSED:.+]] = vector.transpose %[[RHS]], [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[LHS]], %[[RHS_TRANSPOSED]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<8x16xf32>
+// CHECK: return %[[DPAS]]
+
+// -----
+
 // For simplicity, only plain data layouts are currently supported.
 // VNNI packing is applied later as a separate lowering step.
 
@@ -130,39 +180,3 @@ func.func @negative_accumulator_shape(%lhs: vector<8x16xf16>, %rhs: vector<16x16
 
 // CHECK-LABEL: @negative_accumulator_shape(
 // CHECK: vector.contract
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d2, d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @negative_gemm_transpose_a(%lhs: vector<16x8xf16>, %rhs: vector<16x16xf16>,
-    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
-  %3 = vector.contract
-    {indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<16x8xf16>, vector<16x16xf16> into vector<8x16xf32>
-  return %3 : vector<8x16xf32>
-}
-
-// CHECK-LABEL: @negative_gemm_transpose_a(
-// CHECK: vector.contract
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16xf16>,
-    %acc: vector<8x16xf32>) -> vector<8x16xf32> {
-  %3 = vector.contract
-    {indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
-  return %3 : vector<8x16xf32>
-}
-
-// CHECK-LABEL: @negative_gemm_transpose_b(
-// CHECK: vector.contract