From fc1656c2eafddf88cbf6312e777fa6b158350cc0 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah"
Date: Thu, 11 Sep 2025 14:05:03 +0000
Subject: [PATCH 1/6] [XeGPU][Transform] Add vectorlinearize transform pass.

Use upstream patterns to create a vector-linearize pass needed for
lowering to XeVM. Linearizes n-D vectors to 1-D vectors.
---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |   9 ++
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |   1 +
 .../XeGPU/Transforms/XeGPUVectorLinearize.cpp | 111 ++++++++++++++++++
 3 files changed, 121 insertions(+)
 create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 59dca9f0d852a..77c57ccb0746f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -75,4 +75,13 @@ def XeGPUBlocking: Pass<"xegpu-blocking"> {
                          "index::IndexDialect"];
 }
 
+def XeGPUVectorLinearize : Pass<"xegpu-vector-linearize"> {
+  let summary = "Linearize n-D vectors to 1-D vectors";
+  let description = [{
+    This pass linearizes n-D vectors to 1-D vectors for lowering to XeVM.
+  }];
+  let dependentDialects = ["arith::ArithDialect", "memref::MemRefDialect",
+                           "scf::SCFDialect", "vector::VectorDialect"];
+}
+
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 9c178d1d85642..e6f76067094ce 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUUnroll.cpp
   XeGPUWgToSgDistribute.cpp
   XeGPUPropagateLayout.cpp
+  XeGPUVectorLinearize.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
new file mode 100644
index 0000000000000..a6a68716547c9
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
@@ -0,0 +1,111 @@
+//===- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors
+//-------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <optional>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUVECTORLINEARIZE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-vector-linearize"
+
+using namespace mlir;
+
+namespace {
+struct XeGPUVectorLinearizePass final
+    : public xegpu::impl::XeGPUVectorLinearizeBase<XeGPUVectorLinearizePass> {
+  void runOnOperation() override {
+    auto *context = &getContext();
+
+    // vector.broadcast and vector.gather require progressive lowering
+    {
+      mlir::RewritePatternSet patterns(&getContext());
+      mlir::vector::populateVectorBroadcastLoweringPatterns(patterns);
+      mlir::vector::populateVectorGatherLoweringPatterns(patterns);
+      mlir::vector::populateVectorGatherToConditionalLoadPatterns(patterns);
+      // vector.transpose lowering
+      // Shuffle16x16 will fall back to Shuffle1D for non 16x16 sizes.
+      mlir::vector::populateVectorTransposeLoweringPatterns(
+          patterns, mlir::vector::VectorTransposeLowering::Shuffle16x16);
+      (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+    }
+
+    // Unroll load store from < to M <1xN> load/stores and then linearize
+    {
+      mlir::RewritePatternSet patterns(&getContext());
+      mlir::vector::UnrollVectorOptions vectorOptions;
+      vectorOptions.setNativeShapeFn(
+          [](mlir::Operation *op) -> std::optional<mlir::SmallVector<int64_t>> {
+            // Only unroll for vector::LoadOp and vector::StoreOp
+            if (mlir::isa<mlir::vector::LoadOp>(op)) {
+              if (auto vecType = mlir::dyn_cast<mlir::VectorType>(
+                      op->getResult(0).getType())) {
+                auto shape = vecType.getShape();
+                if (shape.size() == 2)
+                  return mlir::SmallVector<int64_t>{1, shape[1]};
+              }
+            }
+            if (mlir::isa<mlir::vector::StoreOp>(op)) {
+              if (auto vecType = mlir::dyn_cast<mlir::VectorType>(
+                      op->getOperand(0).getType())) {
+                auto shape = vecType.getShape();
+                if (shape.size() == 2)
+                  return mlir::SmallVector<int64_t>{1, shape[1]};
+              }
+            }
+            return std::nullopt;
+          });
+      mlir::vector::populateVectorUnrollPatterns(patterns, vectorOptions);
+      (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+    }
+
+    // Use upstream linearization patterns
+    {
+      mlir::MLIRContext &context = getContext();
+      mlir::TypeConverter converter;
+      mlir::RewritePatternSet patterns(&context);
+      mlir::ConversionTarget target(context);
+      mlir::vector::populateForVectorLinearize(converter, target);
+      mlir::vector::populateVectorLinearizeBasePatterns(converter, target,
+                                                        patterns);
+      mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
+          converter, target, patterns);
+      mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+          converter, patterns, target);
+      if (failed(applyPartialConversion(getOperation(), target,
+                                        std::move(patterns))))
+        return signalPassFailure();
+    }
+
+    mlir::TypeConverter typeConverter;
+    mlir::RewritePatternSet patterns(context);
+    mlir::ConversionTarget target(*context);
+    typeConverter.addConversion([](mlir::Type type) { return type; });
+
+    target.addIllegalOp();
+    target.addLegalOp();
+    target.addLegalOp();
+    
target.addLegalDialect();
+  }
+};
+} // namespace

From 884a06924bde2ccd73633977bcceed9bc10579e0 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah"
Date: Mon, 15 Sep 2025 22:12:20 +0000
Subject: [PATCH 2/6] Address review comments.

Add test case.
---
 .../XeGPU/Transforms/XeGPUVectorLinearize.cpp |  52 +--
 .../Dialect/XeGPU/xegpu-vector-linearize.mlir | 362 ++++++++++++++++++
 2 files changed, 383 insertions(+), 31 deletions(-)
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
index a6a68716547c9..78648042ae127 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
@@ -35,8 +35,6 @@ namespace {
 struct XeGPUVectorLinearizePass final
     : public xegpu::impl::XeGPUVectorLinearizeBase<XeGPUVectorLinearizePass> {
   void runOnOperation() override {
-    auto *context = &getContext();
-
     // vector.broadcast and vector.gather require progressive lowering
     {
       mlir::RewritePatternSet patterns(&getContext());
@@ -56,30 +54,32 @@ struct XeGPUVectorLinearizePass final
       mlir::vector::UnrollVectorOptions vectorOptions;
       vectorOptions.setNativeShapeFn(
          [](mlir::Operation *op) -> std::optional<mlir::SmallVector<int64_t>> {
-            // Only unroll for vector::LoadOp and vector::StoreOp
-            if (mlir::isa<mlir::vector::LoadOp>(op)) {
-              if (auto vecType = mlir::dyn_cast<mlir::VectorType>(
-                      op->getResult(0).getType())) {
-                auto shape = vecType.getShape();
-                if (shape.size() == 2)
-                  return mlir::SmallVector<int64_t>{1, shape[1]};
-              }
-            }
-            if (mlir::isa<mlir::vector::StoreOp>(op)) {
-              if (auto vecType = mlir::dyn_cast<mlir::VectorType>(
-                      op->getOperand(0).getType())) {
-                auto shape = vecType.getShape();
-                if (shape.size() == 2)
-                  return mlir::SmallVector<int64_t>{1, shape[1]};
-              }
-            }
-            return std::nullopt;
+            auto extractVectorType =
+                [](mlir::Operation *op) -> mlir::VectorType {
+              if (auto loadOp = mlir::dyn_cast<mlir::vector::LoadOp>(op))
+                return mlir::dyn_cast<mlir::VectorType>(
+                    loadOp.getResult().getType());
+              if (auto storeOp = mlir::dyn_cast<mlir::vector::StoreOp>(op))
+                return mlir::dyn_cast<mlir::VectorType>(
+                    storeOp.getValueToStore().getType());
+              return nullptr;
+            };
+
+            auto vecType = extractVectorType(op);
+            if (!vecType)
+              return std::nullopt;
+
+            auto shape = vecType.getShape();
+            if (shape.size() != 2)
+              return std::nullopt;
+
+            return mlir::SmallVector<int64_t>{1, shape[1]};
          });
       mlir::vector::populateVectorUnrollPatterns(patterns, vectorOptions);
       (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
     }
 
-    // Use upstream linearization patterns
+    // Use vector linearization patterns
     {
       mlir::MLIRContext &context = getContext();
       mlir::TypeConverter converter;
@@ -96,16 +96,6 @@ struct XeGPUVectorLinearizePass final
                                         std::move(patterns))))
         return signalPassFailure();
     }
-
-    mlir::TypeConverter typeConverter;
-    mlir::RewritePatternSet patterns(context);
-    mlir::ConversionTarget target(*context);
-    typeConverter.addConversion([](mlir::Type type) { return type; });
-
-    target.addIllegalOp();
-    target.addLegalOp();
-    target.addLegalOp();
-    target.addLegalDialect();
   }
 };
 } // namespace
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
new file mode 100644
index 0000000000000..61720884002c2
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -0,0 +1,362 @@
+// RUN: mlir-opt %s -split-input-file -xegpu-vector-linearize | FileCheck %s
+
+// CHECK-LABEL: @test_linearize
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xf32>) -> vector<2x2xf32> {
+// CHECK: %[[T0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x2xf32> to 
vector<4xf32> +// CHECK: %[[CST:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32> +// CHECK: %[[T1:.*]] = math.sin %[[T0]] : vector<4xf32> +// CHECK: %[[T2:.*]] = arith.addf %[[T0]], %[[CST]] : vector<4xf32> +// CHECK: %[[T3:.*]] = arith.addf %[[T2]], %[[T1]] : vector<4xf32> +// CHECK: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<4xf32> to vector<2x2xf32> +// CHECK: return %[[T4]] : vector<2x2xf32> +func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> { + %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> +// Arith and math ops are handled in generic way, check some of them + %1 = math.sin %arg0 : vector<2x2xf32> + %2 = arith.addf %arg0, %0 : vector<2x2xf32> + %3 = arith.addf %2, %1 : vector<2x2xf32> + return %3 : vector<2x2xf32> +} + +// ----- + +// CHECK-LABEL: test_const_novector +// CHECK: %[[R:.*]] = arith.constant 42 : i32 +// CHECK: return %[[R]] : i32 +func.func @test_const_novector() -> i32 { + %0 = arith.constant 42 : i32 + return %0 : i32 +} + +// ----- +// CHECK-LABEL: test_create_mask +// CHECK: vector.create_mask {{.*}} : vector<16xi1> +func.func @test_create_mask() -> vector<1x16xi1> { + %c0 = arith.constant 0 : index + %c20 = arith.constant 20 : index + %0 = vector.create_mask %c0, %c20 : vector<1x16xi1> + return %0 : vector<1x16xi1> +} + +// ----- +// CHECK-LABEL: test_extract_strided_slice +// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<8x16xf32>) -> vector<8x8xf32> +// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<8x16xf32> to vector<128xf32> +// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] +// CHECK: [8, 9, 10, 11, 12, 13, 14, 15, +// CHECK: 24, 25, 26, 27, 28, 29, 30, 31, +// CHECK: 40, 41, 42, 43, 44, 45, 46, 47, +// CHECK: 56, 57, 58, 59, 60, 61, 62, 63, +// CHECK: 72, 73, 74, 75, 76, 77, 78, 79, +// CHECK: 88, 89, 90, 91, 92, 93, 94, 95, +// CHECK: 104, 105, 106, 107, 108, 109, 110, 111, +// CHECK: 120, 121, 122, 123, 124, 125, 126, 127] : vector<128xf32>, vector<128xf32> +// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<8x8xf32> +// CHECK: return %[[RES]] : vector<8x8xf32> +func.func @test_extract_strided_slice_1(%arg0 : vector<8x16xf32>) -> vector<8x8xf32> { + %0 = vector.extract_strided_slice %arg0 { sizes = [8, 8], strides = [1, 1], offsets = [0, 8]} + : vector<8x16xf32> to vector<8x8xf32> + return %0 : vector<8x8xf32> +} + +// ----- +// CHECK-LABEL: test_extract_strided_slice_2 +// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x32x8xf32>) -> vector<1x8x8xf32> +// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x32x8xf32> to vector<512xf32> +// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] +// CHECK: [448, 449, 450, 451, 452, 453, 454, 455, +// CHECK: 456, 457, 458, 459, 460, 461, 462, 463, +// CHECK: 464, 465, 466, 467, 468, 469, 470, 471, +// CHECK: 472, 473, 474, 475, 476, 477, 478, 479, +// CHECK: 480, 481, 482, 483, 484, 485, 486, 487, +// CHECK: 488, 489, 490, 491, 492, 493, 494, 495, +// CHECK: 496, 497, 498, 499, 500, 501, 502, 503, +// CHECK: 504, 505, 506, 507, 508, 509, 510, 511] : vector<512xf32>, vector<512xf32> +// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<1x8x8xf32> +// CHECK: return %[[RES]] : vector<1x8x8xf32> +func.func @test_extract_strided_slice_2(%arg0 : vector<2x32x8xf32>) -> vector<1x8x8xf32> { + %0 = vector.extract_strided_slice %arg0 { offsets = [1, 24], strides = [1, 1], sizes = [1, 8] } + : vector<2x32x8xf32> to vector<1x8x8xf32> + 
return %0 : vector<1x8x8xf32> +} + +// ----- +// CHECK-LABEL: test_vector_shuffle +// CHECK-SAME: (%[[ORIG_ARG1:.*]]: vector<4x4xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>) -> vector<8x4xf32> { +// CHECK: %[[ARG2:.*]] = vector.shape_cast %[[ORIG_ARG2]] : vector<4x4xf32> to vector<16xf32> +// CHECK: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x4xf32> to vector<16xf32> +// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG1]], %[[ARG2]] +// CHECK: [0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, +// CHECK: 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<32xf32> to vector<8x4xf32> +// CHECK: return %[[RES]] : vector<8x4xf32> +func.func @test_vector_shuffle(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>) -> vector<8x4xf32> { + %0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x4xf32>, vector<4x4xf32> + return %0 : vector<8x4xf32> +} + +// ----- +// CHECK-LABEL: test_vector_extract +// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x4xf32>) -> vector<8x4xf32> +// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x4xf32> to vector<64xf32> +// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] +// CHECK: [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, +// CHECK: 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<64xf32> +// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<32xf32> to vector<8x4xf32> +// CHECK: return %[[RES]] : vector<8x4xf32> +func.func @test_vector_extract(%arg0: vector<2x8x4xf32>) -> vector<8x4xf32> { + %0 = vector.extract %arg0[1]: vector<8x4xf32> from vector<2x8x4xf32> + return %0 : vector<8x4xf32> +} + +// ----- +// CHECK-LABEL: test_vector_insert +// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> +// CHECK: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32> +// CHECK: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32> +// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]] +// CHECK: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, +// CHECK-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, +// CHECK-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32> +// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32> +// CHECK: return %[[RES]] : vector<2x8x4xf32> +func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> { + %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32> + return %0 : vector<2x8x4xf32> +} + +// ----- +// CHECK-LABEL: test_vector_insert_2d_idx +// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<4xf32>) -> vector<2x8x4xf32> +// CHECK: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32> +// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[SRC]] +// CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 64, 65, 66, 67, 16, 17, 18, 19, 20, 21, +// CHECK-SAME: 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, +// CHECK-SAME: 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<4xf32> +// CHECK: %[[RES:.*]] = vector.shape_cast 
%[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32> +// CHECK: return %[[RES]] : vector<2x8x4xf32> +func.func @test_vector_insert_2d_idx(%arg0: vector<2x8x4xf32>, %arg1: vector<4xf32>) -> vector<2x8x4xf32> { + %0 = vector.insert %arg1, %arg0[0, 3]: vector<4xf32> into vector<2x8x4xf32> + return %0 : vector<2x8x4xf32> +} + +// ----- +// CHECK-LABEL: test_vector_transpose +// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8xf32>) -> vector<8x2xf32> +// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8xf32> to vector<16xf32> +// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] +// CHECK: [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<16xf32>, vector<16xf32> +// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32> +// CHECK: return %[[RES]] : vector<8x2xf32> +func.func @test_vector_transpose(%arg: vector<2x8xf32>) -> vector<8x2xf32> { + %0 = vector.transpose %arg, [1, 0] : vector<2x8xf32> to vector<8x2xf32> + return %0 : vector<8x2xf32> +} + +// ----- +// CHECK-LABEL: test_vector_transpose_16x16 +// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 
30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, 
vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +func.func @test_vector_transpose_16x16(%arg: vector<16x16xf32>) -> vector<16x16xf32> { + %0 = vector.transpose %arg, [1, 0] : vector<16x16xf32> to vector<16x16xf32> + return %0 : vector<16x16xf32> +} + +// ----- +// CHECK-LABEL: func.func @test_vector_store_load_4x4 +// CHECK-SAME: (%[[MEMREF:.*]]: memref<4x4xf32>) +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[V0:.*]] = vector.load %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> +// CHECK: %[[V1:.*]] = vector.load %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> +// CHECK: %[[V2:.*]] = vector.load %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> +// CHECK: %[[V3:.*]] = vector.load %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store %[[V0]], %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store %[[V1]], %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store %[[V2]], %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store %[[V3]], %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> +func.func @test_vector_store_load_4x4(%buffer: memref<4x4xf32>) { + %c0 = arith.constant 0 : index + %0 = vector.load %buffer[%c0, %c0] : memref<4x4xf32>, vector<4x4xf32> + vector.store %0, %buffer[%c0, %c0] : memref<4x4xf32>, 
vector<4x4xf32> + return +} + +// ----- + +func.func @test_vector_store_load_4x4_f16(%buffer: memref<4x4xf16>) { + %c0 = arith.constant 0 : index + %0 = vector.load %buffer[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> + vector.store %0, %buffer[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16> + return +} +// CHECK-LABEL: func.func @test_vector_store_load_4x4_f16 +// CHECK-SAME: (%[[MEMREF:.*]]: memref<4x4xf16>) +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[V0:.*]] = vector.load %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> +// CHECK: %[[V1:.*]] = vector.load %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> +// CHECK: %[[V2:.*]] = vector.load %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> +// CHECK: %[[V3:.*]] = vector.load %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> +// CHECK: vector.store %[[V0]], %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> +// CHECK: vector.store %[[V1]], %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> +// CHECK: vector.store %[[V2]], %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> +// CHECK: vector.store %[[V3]], %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> + +// ----- +// CHECK-LABEL: @test_linearize_index +// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xindex>, %[[ARG1:.*]]: vector<2x2xi32>) -> vector<2x2xindex> { +// CHECK: %[[T0:.*]] = vector.shape_cast %[[ARG1]] : vector<2x2xi32> to vector<4xi32> +// CHECK: %[[T1:.*]] = vector.shape_cast %[[ARG0]] : vector<2x2xindex> to vector<4xindex> +// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> +// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[CST]] : vector<4xindex> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : vector<4xindex> to vector<4xi32> +// CHECK: %[[T4:.*]] = arith.muli %[[T3]], %[[T0]] : vector<4xi32> +// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : vector<4xi32> to vector<4xindex> +// CHECK: %[[T6:.*]] = vector.shape_cast %[[T5]] : vector<4xindex> to vector<2x2xindex> +// CHECK: return %[[T6]] : vector<2x2xindex> +func.func @test_linearize_index(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi32>) -> vector<2x2xindex> { + %0 = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xindex> +// Arith and math ops are handled in generic way, check some of them + %1 = arith.addi %arg0, %0 : vector<2x2xindex> + %2 = arith.index_cast %1 : vector<2x2xindex> to vector<2x2xi32> + %3 = arith.muli %2, %arg1 : vector<2x2xi32> + %4 = arith.index_cast %3 : vector<2x2xi32> to vector<2x2xindex> + return %4 : vector<2x2xindex> +} + +// ----- +// CHECK-LABEL: @add_kernel_f32 +// CHECK: %[[CST0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex> +// CHECK: %[[CST1:.*]] = arith.constant dense<[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : vector<16xindex> +// CHECK: %[[T0:.*]] = vector.splat %{{.*}} : vector<16xindex> +// CHECK: %[[T1:.*]] = arith.addi %[[T0]], %[[CST0]] : vector<16xindex> +// CHECK: %[[T2:.*]] = arith.addi %[[T0]], %[[CST1]] : vector<16xindex> +// CHECK: %[[T3:.*]] = arith.index_cast %[[T1]] : vector<16xindex> to vector<16xi32> +// CHECK: %[[T4:.*]] = arith.index_cast %[[T2]] : vector<16xindex> to vector<16xi32> +// CHECK: %[[T5:.*]] = vector.splat %{{.*}} : vector<16xi32> +// CHECK: %[[T6:.*]] = arith.addi %[[T5]], %[[T3]] : vector<16xi32> 
+// CHECK: %[[T7:.*]] = arith.addi %[[T5]], %[[T4]] : vector<16xi32>
+// CHECK: %[[T8:.*]] = arith.index_cast %[[T6]] : vector<16xi32> to vector<16xindex>
+// CHECK: %[[T9:.*]] = arith.index_cast %[[T7]] : vector<16xi32> to vector<16xindex>
+gpu.module @add_kernel_f32 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @add_kernel_f32(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+    %cst = arith.constant dense<true> : vector<16xi1>
+    %c32 = arith.constant 32 : index
+    %c1024_i32 = arith.constant 1024 : i32
+    %cst_0 = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
+    %cst_1 = arith.constant dense<[[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]> : vector<1x16xindex>
+    %thread_id_x = gpu.thread_id x
+    %thread_id_y = gpu.thread_id y
+    %block_dim_y = gpu.block_dim y
+    %0 = arith.muli %thread_id_x, %block_dim_y : index
+    %1 = arith.addi %0, %thread_id_y : index
+    %cast = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
+    %cast_2 = memref.cast %arg1 : memref<*xf32> to memref<?xf32>
+    %cast_3 = memref.cast %arg2 : memref<*xf32> to memref<?xf32>
+    %2 = arith.remsi %1, %c32 : index
+    %3 = arith.muli %2, %c32 : index
+    %4 = vector.splat %3 : vector<1x16xindex>
+    %5 = arith.addi %4, %cst_0 : vector<1x16xindex>
+    %6 = arith.addi %4, %cst_1 : vector<1x16xindex>
+    %7 = arith.index_cast %5 : vector<1x16xindex> to vector<1x16xi32>
+    %8 = arith.index_cast %6 : vector<1x16xindex> to vector<1x16xi32>
+    %block_id_x = gpu.block_id x
+    %9 = arith.index_cast %block_id_x : index to i32
+    %10 = arith.muli %9, %c1024_i32 : i32
+    %11 = vector.splat %10 : vector<1x16xi32>
+    %12 = arith.addi %11, %7 : vector<1x16xi32>
+    %13 = arith.addi %11, %8 : vector<1x16xi32>
+    %14 = arith.index_cast %12 : vector<1x16xi32> to vector<1x16xindex>
+    %15 = arith.index_cast %13 : vector<1x16xi32> to vector<1x16xindex>
+    %16 = vector.shape_cast %14 : vector<1x16xindex> to vector<16xindex>
+    %17 = xegpu.create_tdesc %cast, %16 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
+    %18 = vector.shape_cast %15 : vector<1x16xindex> to vector<16xindex>
+    %19 = xegpu.create_tdesc %cast, %18 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
+    %20 = xegpu.load %17, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32>
+    %21 = vector.shape_cast %20 : vector<16xf32> to vector<1x16xf32>
+    %22 = xegpu.load %19, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32>
+    %23 = vector.shape_cast %22 : vector<16xf32> to vector<1x16xf32>
+    %24 = xegpu.create_tdesc %cast_2, %16 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
+    %25 = xegpu.create_tdesc %cast_2, %18 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
+    %26 = xegpu.load %24, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32>
+    %27 = vector.shape_cast %26 : vector<16xf32> to vector<1x16xf32>
+    %28 = xegpu.load %25, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32>
+    %29 = vector.shape_cast %28 : vector<16xf32> to vector<1x16xf32>
+    %30 = arith.addf %21, %27 : vector<1x16xf32>
+    %31 = arith.addf %23, %29 : vector<1x16xf32>
+    %32 = xegpu.create_tdesc %cast_3, %16 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
+    %33 = xegpu.create_tdesc %cast_3, %18 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
+    %34 = vector.shape_cast %30 : vector<1x16xf32> to vector<16xf32>
+    xegpu.store %34, %32, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1>
+    %35 = vector.shape_cast %31 : vector<1x16xf32> to vector<16xf32>
+    xegpu.store %35, %33, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1>
+    gpu.return
+  }
+}

From f12ddd03cdeed752b7b6784e81642ae6df46cc1c Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah"
Date: Wed, 17 Sep 2025 23:06:48 +0000
Subject: [PATCH 3/6] Address review comments.

Update the test case to remove duplication with vector-linearize.
Add new test cases for XeGPU, vector.broadcast, vector.gather.
---
 .../XeGPU/Transforms/XeGPUVectorLinearize.cpp |  67 ++-
 .../Dialect/XeGPU/xegpu-vector-linearize.mlir | 529 ++++++++----------
 2 files changed, 258 insertions(+), 338 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
index 78648042ae127..2bb302f4287c4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
@@ -1,5 +1,4 @@
-//===- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors
-//-------===//
+//===-- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors --===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -37,31 +36,29 @@ struct XeGPUVectorLinearizePass final
   void runOnOperation() override {
     // vector.broadcast and vector.gather require progressive lowering
     {
-      mlir::RewritePatternSet patterns(&getContext());
-      mlir::vector::populateVectorBroadcastLoweringPatterns(patterns);
-      mlir::vector::populateVectorGatherLoweringPatterns(patterns);
-      mlir::vector::populateVectorGatherToConditionalLoadPatterns(patterns);
+      RewritePatternSet patterns(&getContext());
+      vector::populateVectorBroadcastLoweringPatterns(patterns);
+      vector::populateVectorGatherLoweringPatterns(patterns);
+      vector::populateVectorGatherToConditionalLoadPatterns(patterns);
       // vector.transpose lowering
       // Shuffle16x16 will fall back to Shuffle1D for non 16x16 sizes.
-      mlir::vector::populateVectorTransposeLoweringPatterns(
-          patterns, mlir::vector::VectorTransposeLowering::Shuffle16x16);
-      (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+      vector::populateVectorTransposeLoweringPatterns(
+          patterns, vector::VectorTransposeLowering::Shuffle16x16);
+      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+        return signalPassFailure();
     }
 
-    // Unroll load store from < to M <1xN> load/stores and then linearize
+    // Unroll load store from <MxN> to M <1xN> load/stores and then linearize
     {
-      mlir::RewritePatternSet patterns(&getContext());
-      mlir::vector::UnrollVectorOptions vectorOptions;
+      RewritePatternSet patterns(&getContext());
+      vector::UnrollVectorOptions vectorOptions;
       vectorOptions.setNativeShapeFn(
-         [](mlir::Operation *op) -> std::optional<mlir::SmallVector<int64_t>> {
-            auto extractVectorType =
-                [](mlir::Operation *op) -> mlir::VectorType {
-              if (auto loadOp = mlir::dyn_cast<mlir::vector::LoadOp>(op))
-                return mlir::dyn_cast<mlir::VectorType>(
-                    loadOp.getResult().getType());
-              if (auto storeOp = mlir::dyn_cast<mlir::vector::StoreOp>(op))
-                return mlir::dyn_cast<mlir::VectorType>(
-                    storeOp.getValueToStore().getType());
+          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
+            auto extractVectorType = [](Operation *op) -> VectorType {
+              if (auto loadOp = dyn_cast<vector::LoadOp>(op))
+                return loadOp.getVectorType();
+              if (auto storeOp = dyn_cast<vector::StoreOp>(op))
+                return storeOp.getVectorType();
               return nullptr;
             };
 
@@ -73,25 +70,25 @@ struct XeGPUVectorLinearizePass final
             if (shape.size() != 2)
              return std::nullopt;
 
-            return mlir::SmallVector<int64_t>{1, shape[1]};
+            return SmallVector<int64_t>{1, shape[1]};
          });
-      mlir::vector::populateVectorUnrollPatterns(patterns, vectorOptions);
-      (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+      vector::populateVectorUnrollPatterns(patterns, vectorOptions);
+      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+        return signalPassFailure();
     }
 
     // Use vector linearization patterns
     {
-      mlir::MLIRContext &context = getContext();
-      mlir::TypeConverter converter;
-      mlir::RewritePatternSet patterns(&context);
-      mlir::ConversionTarget target(context);
-      mlir::vector::populateForVectorLinearize(converter, target);
-      mlir::vector::populateVectorLinearizeBasePatterns(converter, target,
-                                                        patterns);
-      mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
-          converter, target, patterns);
-      mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
-          converter, patterns, target);
+      MLIRContext &context = getContext();
+      TypeConverter converter;
+      RewritePatternSet patterns(&context);
+      ConversionTarget target(context);
+      vector::populateForVectorLinearize(converter, target);
+      vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
+      vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
+                                                            patterns);
+      scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                           target);
       if (failed(applyPartialConversion(getOperation(), target,
                                         std::move(patterns))))
         return signalPassFailure();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 61720884002c2..9985736e2cafb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -1,131 +1,5 @@
-// RUN: mlir-opt %s -split-input-file -xegpu-vector-linearize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -xegpu-vector-linearize -canonicalize | FileCheck %s
 
-// CHECK-LABEL: @test_linearize
-// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xf32>) -> vector<2x2xf32> {
-// CHECK: %[[T0:.*]] = 
vector.shape_cast %[[ARG0]] : vector<2x2xf32> to vector<4xf32> -// CHECK: %[[CST:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32> -// CHECK: %[[T1:.*]] = math.sin %[[T0]] : vector<4xf32> -// CHECK: %[[T2:.*]] = arith.addf %[[T0]], %[[CST]] : vector<4xf32> -// CHECK: %[[T3:.*]] = arith.addf %[[T2]], %[[T1]] : vector<4xf32> -// CHECK: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<4xf32> to vector<2x2xf32> -// CHECK: return %[[T4]] : vector<2x2xf32> -func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> { - %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> -// Arith and math ops are handled in generic way, check some of them - %1 = math.sin %arg0 : vector<2x2xf32> - %2 = arith.addf %arg0, %0 : vector<2x2xf32> - %3 = arith.addf %2, %1 : vector<2x2xf32> - return %3 : vector<2x2xf32> -} - -// ----- - -// CHECK-LABEL: test_const_novector -// CHECK: %[[R:.*]] = arith.constant 42 : i32 -// CHECK: return %[[R]] : i32 -func.func @test_const_novector() -> i32 { - %0 = arith.constant 42 : i32 - return %0 : i32 -} - -// ----- -// CHECK-LABEL: test_create_mask -// CHECK: vector.create_mask {{.*}} : vector<16xi1> -func.func @test_create_mask() -> vector<1x16xi1> { - %c0 = arith.constant 0 : index - %c20 = arith.constant 20 : index - %0 = vector.create_mask %c0, %c20 : vector<1x16xi1> - return %0 : vector<1x16xi1> -} - -// ----- -// CHECK-LABEL: test_extract_strided_slice -// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<8x16xf32>) -> vector<8x8xf32> -// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<8x16xf32> to vector<128xf32> -// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] -// CHECK: [8, 9, 10, 11, 12, 13, 14, 15, -// CHECK: 24, 25, 26, 27, 28, 29, 30, 31, -// CHECK: 40, 41, 42, 43, 44, 45, 46, 47, -// CHECK: 56, 57, 58, 59, 60, 61, 62, 63, -// CHECK: 72, 73, 74, 75, 76, 77, 78, 79, -// CHECK: 88, 89, 90, 91, 92, 93, 94, 95, -// CHECK: 104, 105, 106, 107, 108, 109, 110, 111, -// CHECK: 120, 121, 122, 123, 124, 125, 126, 127] : vector<128xf32>, vector<128xf32> -// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<8x8xf32> -// CHECK: return %[[RES]] : vector<8x8xf32> -func.func @test_extract_strided_slice_1(%arg0 : vector<8x16xf32>) -> vector<8x8xf32> { - %0 = vector.extract_strided_slice %arg0 { sizes = [8, 8], strides = [1, 1], offsets = [0, 8]} - : vector<8x16xf32> to vector<8x8xf32> - return %0 : vector<8x8xf32> -} - -// ----- -// CHECK-LABEL: test_extract_strided_slice_2 -// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x32x8xf32>) -> vector<1x8x8xf32> -// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x32x8xf32> to vector<512xf32> -// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] -// CHECK: [448, 449, 450, 451, 452, 453, 454, 455, -// CHECK: 456, 457, 458, 459, 460, 461, 462, 463, -// CHECK: 464, 465, 466, 467, 468, 469, 470, 471, -// CHECK: 472, 473, 474, 475, 476, 477, 478, 479, -// CHECK: 480, 481, 482, 483, 484, 485, 486, 487, -// CHECK: 488, 489, 490, 491, 492, 493, 494, 495, -// CHECK: 496, 497, 498, 499, 500, 501, 502, 503, -// CHECK: 504, 505, 506, 507, 508, 509, 510, 511] : vector<512xf32>, vector<512xf32> -// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<1x8x8xf32> -// CHECK: return %[[RES]] : vector<1x8x8xf32> -func.func @test_extract_strided_slice_2(%arg0 : vector<2x32x8xf32>) -> vector<1x8x8xf32> { - %0 = vector.extract_strided_slice %arg0 { offsets = [1, 24], strides = [1, 1], sizes = [1, 8] } 
- : vector<2x32x8xf32> to vector<1x8x8xf32> - return %0 : vector<1x8x8xf32> -} - -// ----- -// CHECK-LABEL: test_vector_shuffle -// CHECK-SAME: (%[[ORIG_ARG1:.*]]: vector<4x4xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>) -> vector<8x4xf32> { -// CHECK: %[[ARG2:.*]] = vector.shape_cast %[[ORIG_ARG2]] : vector<4x4xf32> to vector<16xf32> -// CHECK: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x4xf32> to vector<16xf32> -// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG1]], %[[ARG2]] -// CHECK: [0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, -// CHECK: 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<32xf32> to vector<8x4xf32> -// CHECK: return %[[RES]] : vector<8x4xf32> -func.func @test_vector_shuffle(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>) -> vector<8x4xf32> { - %0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x4xf32>, vector<4x4xf32> - return %0 : vector<8x4xf32> -} - -// ----- -// CHECK-LABEL: test_vector_extract -// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x4xf32>) -> vector<8x4xf32> -// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x4xf32> to vector<64xf32> -// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]] -// CHECK: [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, -// CHECK: 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<64xf32> -// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<32xf32> to vector<8x4xf32> -// CHECK: return %[[RES]] : vector<8x4xf32> -func.func @test_vector_extract(%arg0: vector<2x8x4xf32>) -> vector<8x4xf32> { - %0 = vector.extract %arg0[1]: vector<8x4xf32> from vector<2x8x4xf32> - return %0 : vector<8x4xf32> -} - -// ----- -// CHECK-LABEL: test_vector_insert -// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> -// CHECK: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32> -// CHECK: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32> -// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]] -// CHECK: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, -// CHECK-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, -// CHECK-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32> -// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32> -// CHECK: return %[[RES]] : vector<2x8x4xf32> -func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> { - %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32> - return %0 : vector<2x8x4xf32> -} - -// ----- // CHECK-LABEL: test_vector_insert_2d_idx // CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<4xf32>) -> vector<2x8x4xf32> // CHECK: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32> @@ -157,133 +31,49 @@ func.func @test_vector_transpose(%arg: vector<2x8xf32>) -> vector<8x2xf32> { // CHECK-LABEL: test_vector_transpose_16x16 // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, 
vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: 
vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 
7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> -// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> +// CHECK-62: vector.shuffle func.func @test_vector_transpose_16x16(%arg: vector<16x16xf32>) -> vector<16x16xf32> { %0 = vector.transpose %arg, [1, 0] : vector<16x16xf32> to vector<16x16xf32> return %0 : vector<16x16xf32> } -// ----- -// CHECK-LABEL: func.func @test_vector_store_load_4x4 -// CHECK-SAME: (%[[MEMREF:.*]]: memref<4x4xf32>) -// CHECK: %[[C3:.*]] = arith.constant 3 : index -// CHECK: %[[C2:.*]] = arith.constant 2 : index -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[V0:.*]] = vector.load %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> -// CHECK: %[[V1:.*]] = vector.load %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> -// CHECK: %[[V2:.*]] = vector.load %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> -// CHECK: %[[V3:.*]] = vector.load %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> -// CHECK: vector.store %[[V0]], %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> -// CHECK: vector.store %[[V1]], %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> -// CHECK: vector.store %[[V2]], %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> -// CHECK: vector.store %[[V3]], %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf32>, vector<4xf32> -func.func @test_vector_store_load_4x4(%buffer: memref<4x4xf32>) { - %c0 = arith.constant 0 : index - %0 = vector.load %buffer[%c0, %c0] : memref<4x4xf32>, vector<4x4xf32> - vector.store %0, %buffer[%c0, %c0] : memref<4x4xf32>, vector<4x4xf32> - return -} - // ----- +// CHECK-LABEL: func.func @test_vector_store_load_4x4_f16 +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf16>) +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[LOAD0:.*]] = vector.load %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> +// CHECK: %[[LOAD1:.*]] = vector.load %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> +// CHECK: %[[LOAD2:.*]] = vector.load %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> +// CHECK: %[[LOAD3:.*]] = vector.load %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> +// CHECK: vector.store %[[LOAD0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> +// CHECK: vector.store %[[LOAD1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> +// CHECK: vector.store %[[LOAD2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> +// CHECK: vector.store %[[LOAD3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16> func.func @test_vector_store_load_4x4_f16(%buffer: memref<4x4xf16>) { %c0 = arith.constant 0 : index %0 = vector.load %buffer[%c0, %c0] : 
   vector.store %0, %buffer[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
   return
 }
-// CHECK-LABEL: func.func @test_vector_store_load_4x4_f16
-// CHECK-SAME: (%[[MEMREF:.*]]: memref<4x4xf16>)
-// CHECK: %[[C3:.*]] = arith.constant 3 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[V0:.*]] = vector.load %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-// CHECK: %[[V1:.*]] = vector.load %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-// CHECK: %[[V2:.*]] = vector.load %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-// CHECK: %[[V3:.*]] = vector.load %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-// CHECK: vector.store %[[V0]], %[[MEMREF]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-// CHECK: vector.store %[[V1]], %[[MEMREF]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-// CHECK: vector.store %[[V2]], %[[MEMREF]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-// CHECK: vector.store %[[V3]], %[[MEMREF]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
-
 // -----
-// CHECK-LABEL: @test_linearize_index
-// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xindex>, %[[ARG1:.*]]: vector<2x2xi32>) -> vector<2x2xindex> {
-// CHECK: %[[T0:.*]] = vector.shape_cast %[[ARG1]] : vector<2x2xi32> to vector<4xi32>
-// CHECK: %[[T1:.*]] = vector.shape_cast %[[ARG0]] : vector<2x2xindex> to vector<4xindex>
-// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[CST]] : vector<4xindex>
-// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : vector<4xindex> to vector<4xi32>
-// CHECK: %[[T4:.*]] = arith.muli %[[T3]], %[[T0]] : vector<4xi32>
-// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : vector<4xi32> to vector<4xindex>
-// CHECK: %[[T6:.*]] = vector.shape_cast %[[T5]] : vector<4xindex> to vector<2x2xindex>
-// CHECK: return %[[T6]] : vector<2x2xindex>
+// CHECK-LABEL: func.func @test_linearize_index
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xindex>, %[[ARG1:.*]]: vector<2x2xi32>) -> vector<2x2xindex>
+// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[CAST1:.*]] = vector.shape_cast %[[ARG1]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[ARG0]] : vector<2x2xindex> to vector<4xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST2]], %[[CST]] : vector<4xindex>
+// CHECK: %[[INDEX_CAST1:.*]] = arith.index_cast %[[ADDI]] : vector<4xindex> to vector<4xi32>
+// CHECK: %[[MULI:.*]] = arith.muli %[[INDEX_CAST1]], %[[CAST1]] : vector<4xi32>
+// CHECK: %[[INDEX_CAST2:.*]] = arith.index_cast %[[MULI]] : vector<4xi32> to vector<4xindex>
+// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[INDEX_CAST2]] : vector<4xindex> to vector<2x2xindex>
+// CHECK: return %[[RESULT]] : vector<2x2xindex>
 func.func @test_linearize_index(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi32>) -> vector<2x2xindex> {
   %0 = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xindex>
-// Arith and math ops are handled in generic way, check some of them
+  // Arith and math ops are handled in generic way, check some of them
   %1 = arith.addi %arg0, %0 : vector<2x2xindex>
   %2 = arith.index_cast %1 : vector<2x2xindex> to vector<2x2xi32>
   %3 = arith.muli %2, %arg1 : vector<2x2xi32>
@@ -292,71 +82,204 @@ func.func @test_linearize_index(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi32>
 }
 
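+// For reference, the generic elementwise rule above is: cast the n-D
+// operands to 1-D, apply the op at rank 1, and cast the result back.
+// Illustrative sketch only (hypothetical values, not a FileCheck'd case):
+//   %lhs = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
+//   %rhs = vector.shape_cast %b : vector<2x2xf32> to vector<4xf32>
+//   %sum = arith.addf %lhs, %rhs : vector<4xf32>
+//   %res = vector.shape_cast %sum : vector<4xf32> to vector<2x2xf32>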
 // -----
-// CHECK-LABEL: @add_kernel_f32
-// CHECK: %[[CST0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
-// CHECK: %[[CST1:.*]] = arith.constant dense<[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : vector<16xindex>
-// CHECK: %[[T0:.*]] = vector.splat %{{.*}} : vector<16xindex>
-// CHECK: %[[T1:.*]] = arith.addi %[[T0]], %[[CST0]] : vector<16xindex>
-// CHECK: %[[T2:.*]] = arith.addi %[[T0]], %[[CST1]] : vector<16xindex>
-// CHECK: %[[T3:.*]] = arith.index_cast %[[T1]] : vector<16xindex> to vector<16xi32>
-// CHECK: %[[T4:.*]] = arith.index_cast %[[T2]] : vector<16xindex> to vector<16xi32>
-// CHECK: %[[T5:.*]] = vector.splat %{{.*}} : vector<16xi32>
-// CHECK: %[[T6:.*]] = arith.addi %[[T5]], %[[T3]] : vector<16xi32>
-// CHECK: %[[T7:.*]] = arith.addi %[[T5]], %[[T4]] : vector<16xi32>
-// CHECK: %[[T8:.*]] = arith.index_cast %[[T6]] : vector<16xi32> to vector<16xindex>
-// CHECK: %[[T9:.*]] = arith.index_cast %[[T7]] : vector<16xi32> to vector<16xindex>
-gpu.module @add_kernel_f32 attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} {
-  gpu.func @add_kernel_f32(%arg0: memref<*xf32>, %arg1: memref<*xf32>, %arg2: memref<*xf32>) kernel attributes {VectorComputeFunctionINTEL, known_block_size = array, known_grid_size = array, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
-    %cst = arith.constant dense<true> : vector<16xi1>
-    %c32 = arith.constant 32 : index
-    %c1024_i32 = arith.constant 1024 : i32
-    %cst_0 = arith.constant dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
-    %cst_1 = arith.constant dense<[[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]> : vector<1x16xindex>
-    %thread_id_x = gpu.thread_id x
-    %thread_id_y = gpu.thread_id y
-    %block_dim_y = gpu.block_dim y
-    %0 = arith.muli %thread_id_x, %block_dim_y : index
-    %1 = arith.addi %0, %thread_id_y : index
-    %cast = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
-    %cast_2 = memref.cast %arg1 : memref<*xf32> to memref<?xf32>
-    %cast_3 = memref.cast %arg2 : memref<*xf32> to memref<?xf32>
-    %2 = arith.remsi %1, %c32 : index
-    %3 = arith.muli %2, %c32 : index
-    %4 = vector.splat %3 : vector<1x16xindex>
-    %5 = arith.addi %4, %cst_0 : vector<1x16xindex>
-    %6 = arith.addi %4, %cst_1 : vector<1x16xindex>
-    %7 = arith.index_cast %5 : vector<1x16xindex> to vector<1x16xi32>
-    %8 = arith.index_cast %6 : vector<1x16xindex> to vector<1x16xi32>
-    %block_id_x = gpu.block_id x
-    %9 = arith.index_cast %block_id_x : index to i32
-    %10 = arith.muli %9, %c1024_i32 : i32
-    %11 = vector.splat %10 : vector<1x16xi32>
-    %12 = arith.addi %11, %7 : vector<1x16xi32>
-    %13 = arith.addi %11, %8 : vector<1x16xi32>
-    %14 = arith.index_cast %12 : vector<1x16xi32> to vector<1x16xindex>
-    %15 = arith.index_cast %13 : vector<1x16xi32> to vector<1x16xindex>
-    %16 = vector.shape_cast %14 : vector<1x16xindex> to vector<16xindex>
-    %17 = xegpu.create_tdesc %cast, %16 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
-    %18 = vector.shape_cast %15 : vector<1x16xindex> to vector<16xindex>
-    %19 = xegpu.create_tdesc %cast, %18 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
-    %20 = xegpu.load %17, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32>
-    %21 = vector.shape_cast %20 : vector<16xf32> to vector<1x16xf32>
-    %22 = xegpu.load %19, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32>
-    %23 = vector.shape_cast %22 : vector<16xf32> to vector<1x16xf32>
-    %24 = xegpu.create_tdesc %cast_2, %16 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
-    %25 = xegpu.create_tdesc %cast_2, %18 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
-    %26 = xegpu.load %24, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32>
-    %27 = vector.shape_cast %26 : vector<16xf32> to vector<1x16xf32>
-    %28 = xegpu.load %25, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16xf32>
-    %29 = vector.shape_cast %28 : vector<16xf32> to vector<1x16xf32>
-    %30 = arith.addf %21, %27 : vector<1x16xf32>
-    %31 = arith.addf %23, %29 : vector<1x16xf32>
-    %32 = xegpu.create_tdesc %cast_3, %16 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
-    %33 = xegpu.create_tdesc %cast_3, %18 : memref<?xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>
-    %34 = vector.shape_cast %30 : vector<1x16xf32> to vector<16xf32>
-    xegpu.store %34, %32, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1>
-    %35 = vector.shape_cast %31 : vector<1x16xf32> to vector<16xf32>
-    xegpu.store %35, %33, %cst <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1>
+// CHECK-LABEL: func.func @broadcast_stretch_at_start
+// CHECK-SAME: (%[[ARG0:.*]]: vector<1x4xf32>) -> vector<3x4xf32>
+// CHECK: %[[POISON:.*]] = ub.poison : vector<12xf32>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[ARG0]] : vector<1x4xf32> to vector<4xf32>
+// CHECK: %[[SHUFFLE1:.*]] = vector.shuffle %[[POISON]], %[[CAST]] [12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11] : vector<12xf32>, vector<4xf32>
+// CHECK: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[CAST]] [0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11] : vector<12xf32>, vector<4xf32>
+// CHECK: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[CAST]] [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15] : vector<12xf32>, vector<4xf32>
+// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<12xf32> to vector<3x4xf32>
+func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> {
+  %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32>
+  return %0 : vector<3x4xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @broadcast_stretch_at_end
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4x1xf32>) -> vector<4x3xf32>
+// CHECK: %[[POISON:.*]] = ub.poison : vector<12xf32>
+// CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG0]][0, 0] : f32 from vector<4x1xf32>
+// CHECK: %[[BROADCAST1:.*]] = vector.broadcast %[[EXTRACT1]] : f32 to vector<3xf32>
+// CHECK: vector.shuffle
+// CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1, 0] : f32 from vector<4x1xf32>
+// CHECK: %[[BROADCAST2:.*]] = vector.broadcast %[[EXTRACT2]] : f32 to vector<3xf32>
+// CHECK: vector.shuffle
+// CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG0]][2, 0] : f32 from vector<4x1xf32>
+// CHECK: %[[BROADCAST3:.*]] = vector.broadcast %[[EXTRACT3]] : f32 to vector<3xf32>
+// CHECK: vector.shuffle
+// CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][3, 0] : f32 from vector<4x1xf32>
+// CHECK: %[[BROADCAST4:.*]] = vector.broadcast %[[EXTRACT4]] : f32 to vector<3xf32>
+// CHECK: vector.shuffle
+// CHECK: vector.shape_cast {{.*}} : vector<12xf32> to vector<4x3xf32>
+func.func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
+  %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32>
+  return %0 : vector<4x3xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @broadcast_stretch_in_middle
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4x1x2xf32>) -> vector<4x3x2xf32>
+// CHECK: ub.poison : vector<6xf32>
+// CHECK: ub.poison : vector<24xf32>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[ARG0]] : vector<4x1x2xf32> to vector<8xf32>
+// CHECK-COUNT-20: vector.shuffle
+// CHECK: vector.shape_cast {{.*}} : vector<24xf32> to vector<4x3x2xf32>
+// CHECK-NOT: vector.broadcast
+func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> {
+  %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32>
+  return %0 : vector<4x3x2xf32>
+}
+
+// CHECK-LABEL: func.func @gather_memref_2d
+// CHECK-SAME: (%arg0: memref<?x?xf32>, %arg1: vector<2x3xindex>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
+
+// CHECK: %0 = ub.poison : vector<6xf32>
+// CHECK: %c1 = arith.constant 1 : index
+// CHECK: %c0 = arith.constant 0 : index
+// CHECK: %1 = vector.shape_cast %arg3 : vector<2x3xf32> to vector<6xf32>
+
+// First shuffle + if ladder for row 0
+// CHECK: %2 = vector.shuffle %1, %1 [0, 1, 2]
+// CHECK: %3 = vector.extract %arg2[0, 0]
+// CHECK: %4 = vector.extract %arg1[0, 0]
+// CHECK: %5 = arith.addi %4, %c1
+// CHECK: %6 = scf.if %3 -> (vector<3xf32>) {
+// CHECK: %{{.*}} = vector.load %arg0[%c0, %5] : memref<?x?xf32>, vector<1xf32>
+// CHECK: %{{.*}} = vector.extract {{.*}}[0] : f32
+// CHECK: %{{.*}} = vector.insert {{.*}}, %2 [0] : f32 into vector<3xf32>
+// CHECK: scf.yield {{.*}} : vector<3xf32>
+// CHECK: } else {
+// CHECK: scf.yield %2 : vector<3xf32>
+// CHECK: }
+
+// CHECK: %7 = vector.extract %arg2[0, 1]
+// CHECK: %8 = vector.extract %arg1[0, 1]
+// CHECK: %9 = arith.addi %8, %c1
+// CHECK: %10 = scf.if %7 -> (vector<3xf32>)
+
+// … (similar checks for the rest of row 0, then row 1)
+
+// CHECK: %15 = vector.shuffle %0, %{{.*}} [6, 7, 8, 3, 4, 5]
+// CHECK: %16 = vector.shuffle %1, %1 [3, 4, 5]
+
+// Row 1 if ladder checks
+// CHECK: %17 = vector.extract %arg2[1, 0]
+// CHECK: %18 = vector.extract %arg1[1, 0]
+// CHECK: %19 = arith.addi %18, %c1
+// CHECK: %20 = scf.if %17 -> (vector<3xf32>)

+// … (similar checks for remaining row 1 inserts)
+
+// Final reshuffle and cast
+// CHECK: %29 = vector.shuffle %15, %{{.*}} [0, 1, 2, 6, 7, 8]
+// CHECK: %30 = vector.shape_cast %29 : vector<6xf32> to vector<2x3xf32>
+// CHECK: return %30 : vector<2x3xf32>
+func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+}
+
+// -----
+// Check for vector linearization in XeGPU dialect.
+// The vector<64xf16> loaded from memory is linearized into 4 vector<8xf16> using vector.shuffle ops.
+// The pattern is similar to the one used in test_vector_transpose_16x16 above.
+gpu.module @test_kernel {
+  // CHECK-LABEL: gpu.func @test_kernel
+  gpu.func @test_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf32>) kernel {
+    %c24 = arith.constant 24 : index
+    %c16 = arith.constant 16 : index
+    %c8 = arith.constant 8 : index
+    %c0 = arith.constant 0 : index
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<32x32xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>
+    %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<64xf16>
+    // CHECK: %[[V1:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<64xf16>, vector<64xf16>
+    %2 = vector.shape_cast %1 : vector<64xf16> to vector<2x32x1xf16>
+    %3 = vector.extract %2[0] : vector<32x1xf16> from vector<2x32x1xf16>
+    // CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[V1]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xf16>, vector<32xf16>
+    %4 = vector.extract_strided_slice %3 {offsets = [0], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    // CHECK: %[[V3:.*]] = vector.shuffle %[[V1]], %[[V1]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
+    %5 = vector.extract_strided_slice %3 {offsets = [8], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    // CHECK: %[[V4:.*]] = vector.shuffle %[[V1]], %[[V1]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xf16>, vector<32xf16>
+    %6 = vector.extract_strided_slice %3 {offsets = [16], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    %7 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x32xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>
+    %8 = xegpu.load_nd %7 <{packed}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<64xf16>
+    // CHECK: %[[V5:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<64xf16>, vector<64xf16>
+    %9 = vector.shape_cast %8 : vector<64xf16> to vector<2x32x1xf16>
+    %10 = vector.extract %9[0] : vector<32x1xf16> from vector<2x32x1xf16>
+    // CHECK: %[[V6:.*]] = vector.shuffle %[[V5]], %[[V5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
+    %11 = vector.extract_strided_slice %10 {offsets = [0], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
+    %12 = vector.extract %9[1] : vector<32x1xf16> from vector<2x32x1xf16>
+    // CHECK: %[[V7:.*]] = vector.shuffle %{{.*}}, %{{.*}} [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf16>, vector<64xf16>
+    // CHECK: %[[V8:.*]] = vector.shuffle %[[V7]], %[[V7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
+    %13 = vector.extract_strided_slice %12 {offsets = [0], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
+    // CHECK: %[[V9:.*]] = vector.shuffle %[[V1]], %[[V1]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
+    %14 = vector.extract_strided_slice %3 {offsets = [24], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    %15 = vector.extract %2[1] : vector<32x1xf16> from vector<2x32x1xf16>
+    // CHECK: %[[V10:.*]] = vector.shuffle %{{.*}}, %{{.*}} [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf16>, vector<64xf16>
+    // CHECK: %[[V11:.*]] = vector.shuffle %[[V10]], %[[V10]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xf16>, vector<32xf16>
+    %16 = vector.extract_strided_slice %15 {offsets = [0], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    // CHECK: %[[V12:.*]] = vector.shuffle %[[V10]], %[[V10]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
+    %17 = vector.extract_strided_slice %15 {offsets = [8], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    // CHECK: %[[V13:.*]] = vector.shuffle %[[V10]], %[[V10]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xf16>, vector<32xf16>
+    %18 = vector.extract_strided_slice %15 {offsets = [16], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    // CHECK: %[[V14:.*]] = vector.shuffle %[[V5]], %[[V5]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
+    %19 = vector.extract_strided_slice %10 {offsets = [16], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
+    // CHECK: %[[V15:.*]] = vector.shuffle %[[V7]], %[[V7]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
+    %20 = vector.extract_strided_slice %12 {offsets = [16], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
+    // CHECK: %[[V16:.*]] = vector.shuffle %[[V10]], %[[V10]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
+    %21 = vector.extract_strided_slice %15 {offsets = [24], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
+    // CHECK-NOT: vector.shape_cast
+    // CHECK-NOT: vector.extract
+    // CHECK-NOT: vector.extract_strided_slice
+    %22 = vector.shape_cast %4 : vector<8x1xf16> to vector<8xf16>
+    %23 = vector.shape_cast %11 : vector<16x1xf16> to vector<16xf16>
+    %24 = xegpu.dpas %22, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %25 = vector.shape_cast %13 : vector<16x1xf16> to vector<16xf16>
+    %26 = xegpu.dpas %22, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %27 = vector.shape_cast %5 : vector<8x1xf16> to vector<8xf16>
+    %28 = xegpu.dpas %27, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %29 = xegpu.dpas %27, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %30 = vector.shape_cast %6 : vector<8x1xf16> to vector<8xf16>
+    %31 = xegpu.dpas %30, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %32 = xegpu.dpas %30, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %33 = vector.shape_cast %14 : vector<8x1xf16> to vector<8xf16>
+    %34 = xegpu.dpas %33, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %35 = xegpu.dpas %33, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
+    %36 = vector.shape_cast %16 : vector<8x1xf16> to vector<8xf16>
+    %37 = vector.shape_cast %19 : vector<16x1xf16> to vector<16xf16>
+    %38 = xegpu.dpas %36, %37, %24 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %39 = vector.shape_cast %20 : vector<16x1xf16> to vector<16xf16>
+    %40 = xegpu.dpas %36, %39, %26 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %41 = vector.shape_cast %17 : vector<8x1xf16> to vector<8xf16>
+    %42 = xegpu.dpas %41, %37, %28 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %43 = xegpu.dpas %41, %39, %29 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %44 = vector.shape_cast %18 : vector<8x1xf16> to vector<8xf16>
+    %45 = xegpu.dpas %44, %37, %31 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %46 = xegpu.dpas %44, %39, %32 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %47 = vector.shape_cast %21 : vector<8x1xf16> to vector<8xf16>
+    %48 = xegpu.dpas %47, %37, %34 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %49 = xegpu.dpas %47, %39, %35 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+    %50 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %38, %50 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %51 = xegpu.create_nd_tdesc %arg2[%c0, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %40, %51 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %52 = xegpu.create_nd_tdesc %arg2[%c8, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %42, %52 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %53 = xegpu.create_nd_tdesc %arg2[%c8, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %43, %53 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %54 = xegpu.create_nd_tdesc %arg2[%c16, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %45, %54 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %55 = xegpu.create_nd_tdesc %arg2[%c16, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %46, %55 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %56 = xegpu.create_nd_tdesc %arg2[%c24, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %48, %56 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %57 = xegpu.create_nd_tdesc %arg2[%c24, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %49, %57 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
     gpu.return
   }
 }
+
+
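(Presumably the test file opens with a RUN line along the lines of
`// RUN: mlir-opt %s -split-input-file -xegpu-vector-linearize | FileCheck %s`;
the header sits outside the quoted hunks, so treat the exact flags as an
assumption. The `// -----` separators above rely on -split-input-file to give
each case its own module.)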
From 224d3beb6ff97bc7b289e6b54417958455df833b Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah"
Date: Thu, 18 Sep 2025 17:55:53 +0000
Subject: [PATCH 4/6] Address review comments.

---
 .../Dialect/XeGPU/xegpu-vector-linearize.mlir | 141 ++++++------------
 1 file changed, 48 insertions(+), 93 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 9985736e2cafb..ec98172c478ea 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -183,101 +183,56 @@ func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask
 }
 
 // -----
-// Check for vector linearization in XeGPU dialect.
-// The vector<64xf16> loaded from memory is linearized into 4 vector<8xf16> using vector.shuffle ops.
-// The pattern is similar to the one used in test_vector_transpose_16x16 above.
+// Check for vector linearization interoperability with XeGPU dialect ops.
+// The `xegpu-vector-linearize` pass does not itself affect the XeGPU ops.
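+//
+// As a guide to the shuffles checked below (an illustrative sketch with
+// hypothetical 4x4 / 2x2 shapes, not the ones used in this test):
+// vector.insert_strided_slice linearizes into a single vector.shuffle of the
+// flattened destination and source, e.g. inserting at offsets [0, 0]:
+//   %d = vector.shape_cast %dst : vector<4x4xf32> to vector<16xf32>
+//   %s = vector.shape_cast %src : vector<2x2xf32> to vector<4xf32>
+//   %r = vector.shuffle %d, %s [16, 17, 2, 3, 18, 19, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf32>, vector<4xf32>
+// Indices >= 16 pick elements of the (flattened) inserted source.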
+
+// CHECK: gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel {
+// CHECK: %c0 = arith.constant 0 : index
+// CHECK: %cst = arith.constant dense<0.000000e+00> : vector<64xf16>
+// CHECK: %cst_0 = arith.constant dense<5.000000e+00> : vector<64xf32>
+
+// CHECK: %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0]
+// CHECK: %1 = xegpu.load_nd %0
+// CHECK: %2 = vector.shape_cast %1 : vector<8x16xf16> to vector<128xf16>
+// CHECK: %3 = vector.shuffle %2, %cst {{.*}} : vector<128xf16>, vector<64xf16>
+// CHECK: %4 = vector.shape_cast %3 : vector<128xf16> to vector<8x16xf16>
+
+// CHECK: %5 = xegpu.create_nd_tdesc %arg1[%c0, %c0]
+// CHECK: %6 = xegpu.load_nd %5
+// CHECK: %7 = vector.shape_cast %6 : vector<16x16xf16> to vector<256xf16>
+// CHECK: %8 = vector.shuffle %7, %cst {{.*}} : vector<256xf16>, vector<64xf16>
+// CHECK: %9 = vector.shape_cast %8 : vector<256xf16> to vector<16x16xf16>
+
+// CHECK: %10 = xegpu.dpas %4, %9 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK: %11 = vector.shape_cast %10 : vector<8x16xf32> to vector<128xf32>
+// CHECK: %12 = vector.shuffle %11, %11 {{.*}} : vector<128xf32>, vector<128xf32>
+// CHECK: %13 = arith.addf %12, %cst_0 : vector<64xf32>
+// CHECK: %14 = vector.shuffle %11, %13 {{.*}} : vector<128xf32>, vector<64xf32>
+// CHECK: %15 = vector.shape_cast %14 : vector<128xf32> to vector<8x16xf32>
+
+// CHECK: %16 = xegpu.create_nd_tdesc %arg2[%c0, %c0]
+// CHECK: xegpu.store_nd %15, %16
+// CHECK: gpu.return
+
 gpu.module @test_kernel {
-  // CHECK-LABEL: gpu.func @test_kernel
-  gpu.func @test_kernel(%arg0: memref<32x32xf16>, %arg1: memref<32x32xf16>, %arg2: memref<32x32xf32>) kernel {
-    %c24 = arith.constant 24 : index
-    %c16 = arith.constant 16 : index
-    %c8 = arith.constant 8 : index
+  gpu.func @test_kernel(%A: memref<8x16xf16>, %B: memref<16x16xf16>, %C: memref<8x16xf32>) kernel {
     %c0 = arith.constant 0 : index
-    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<32x32xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>
-    %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<64xf16>
-    // CHECK: %[[V1:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<64xf16>, vector<64xf16>
-    %2 = vector.shape_cast %1 : vector<64xf16> to vector<2x32x1xf16>
-    %3 = vector.extract %2[0] : vector<32x1xf16> from vector<2x32x1xf16>
-    // CHECK: %[[V2:.*]] = vector.shuffle %[[V1]], %[[V1]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xf16>, vector<32xf16>
-    %4 = vector.extract_strided_slice %3 {offsets = [0], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    // CHECK: %[[V3:.*]] = vector.shuffle %[[V1]], %[[V1]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
-    %5 = vector.extract_strided_slice %3 {offsets = [8], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    // CHECK: %[[V4:.*]] = vector.shuffle %[[V1]], %[[V1]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xf16>, vector<32xf16>
-    %6 = vector.extract_strided_slice %3 {offsets = [16], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    %7 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x32xf16> -> !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr>
-    %8 = xegpu.load_nd %7 <{packed}> : !xegpu.tensor_desc<32x16xf16, #xegpu.block_tdesc_attr> -> vector<64xf16>
-    // CHECK: %[[V5:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<64xf16>, vector<64xf16>
-    %9 = vector.shape_cast %8 : vector<64xf16> to vector<2x32x1xf16>
-    %10 = vector.extract %9[0] : vector<32x1xf16> from vector<2x32x1xf16>
-    // CHECK: %[[V6:.*]] = vector.shuffle %[[V5]], %[[V5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
-    %11 = vector.extract_strided_slice %10 {offsets = [0], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
-    %12 = vector.extract %9[1] : vector<32x1xf16> from vector<2x32x1xf16>
-    // CHECK: %[[V7:.*]] = vector.shuffle %{{.*}}, %{{.*}} [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf16>, vector<64xf16>
-    // CHECK: %[[V8:.*]] = vector.shuffle %[[V7]], %[[V7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
-    %13 = vector.extract_strided_slice %12 {offsets = [0], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
-    // CHECK: %[[V9:.*]] = vector.shuffle %[[V1]], %[[V1]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
-    %14 = vector.extract_strided_slice %3 {offsets = [24], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    %15 = vector.extract %2[1] : vector<32x1xf16> from vector<2x32x1xf16>
-    // CHECK: %[[V10:.*]] = vector.shuffle %{{.*}}, %{{.*}} [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf16>, vector<64xf16>
-    // CHECK: %[[V11:.*]] = vector.shuffle %[[V10]], %[[V10]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xf16>, vector<32xf16>
-    %16 = vector.extract_strided_slice %15 {offsets = [0], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    // CHECK: %[[V12:.*]] = vector.shuffle %[[V10]], %[[V10]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xf16>, vector<32xf16>
-    %17 = vector.extract_strided_slice %15 {offsets = [8], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    // CHECK: %[[V13:.*]] = vector.shuffle %[[V10]], %[[V10]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xf16>, vector<32xf16>
-    %18 = vector.extract_strided_slice %15 {offsets = [16], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    // CHECK: %[[V14:.*]] = vector.shuffle %[[V5]], %[[V5]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
-    %19 = vector.extract_strided_slice %10 {offsets = [16], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
-    // CHECK: %[[V15:.*]] = vector.shuffle %[[V7]], %[[V7]] [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
-    %20 = vector.extract_strided_slice %12 {offsets = [16], sizes = [16], strides = [1]} : vector<32x1xf16> to vector<16x1xf16>
-    // CHECK: %[[V16:.*]] = vector.shuffle %[[V10]], %[[V10]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf16>, vector<32xf16>
-    %21 = vector.extract_strided_slice %15 {offsets = [24], sizes = [8], strides = [1]} : vector<32x1xf16> to vector<8x1xf16>
-    // CHECK-NOT: vector.shape_cast
-    // CHECK-NOT: vector.extract
-    // CHECK-NOT: vector.extract_strided_slice
-    %22 = vector.shape_cast %4 : vector<8x1xf16> to vector<8xf16>
-    %23 = vector.shape_cast %11 : vector<16x1xf16> to vector<16xf16>
-    %24 = xegpu.dpas %22, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %25 = vector.shape_cast %13 : vector<16x1xf16> to vector<16xf16>
-    %26 = xegpu.dpas %22, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %27 = vector.shape_cast %5 : vector<8x1xf16> to vector<8xf16>
-    %28 = xegpu.dpas %27, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %29 = xegpu.dpas %27, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %30 = vector.shape_cast %6 : vector<8x1xf16> to vector<8xf16>
-    %31 = xegpu.dpas %30, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %32 = xegpu.dpas %30, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %33 = vector.shape_cast %14 : vector<8x1xf16> to vector<8xf16>
-    %34 = xegpu.dpas %33, %23 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %35 = xegpu.dpas %33, %25 : vector<8xf16>, vector<16xf16> -> vector<8xf32>
-    %36 = vector.shape_cast %16 : vector<8x1xf16> to vector<8xf16>
-    %37 = vector.shape_cast %19 : vector<16x1xf16> to vector<16xf16>
-    %38 = xegpu.dpas %36, %37, %24 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %39 = vector.shape_cast %20 : vector<16x1xf16> to vector<16xf16>
-    %40 = xegpu.dpas %36, %39, %26 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %41 = vector.shape_cast %17 : vector<8x1xf16> to vector<8xf16>
-    %42 = xegpu.dpas %41, %37, %28 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %43 = xegpu.dpas %41, %39, %29 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %44 = vector.shape_cast %18 : vector<8x1xf16> to vector<8xf16>
-    %45 = xegpu.dpas %44, %37, %31 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %46 = xegpu.dpas %44, %39, %32 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %47 = vector.shape_cast %21 : vector<8x1xf16> to vector<8xf16>
-    %48 = xegpu.dpas %47, %37, %34 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %49 = xegpu.dpas %47, %39, %35 : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
-    %50 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %38, %50 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-    %51 = xegpu.create_nd_tdesc %arg2[%c0, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %40, %51 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-    %52 = xegpu.create_nd_tdesc %arg2[%c8, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %42, %52 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-    %53 = xegpu.create_nd_tdesc %arg2[%c8, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %43, %53 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-    %54 = xegpu.create_nd_tdesc %arg2[%c16, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %45, %54 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-    %55 = xegpu.create_nd_tdesc %arg2[%c16, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %46, %55 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-    %56 = xegpu.create_nd_tdesc %arg2[%c24, %c0] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %48, %56 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
-    %57 = xegpu.create_nd_tdesc %arg2[%c24, %c16] : memref<32x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-    xegpu.store_nd %49, %57 : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    %cst_vec_0 = arith.constant dense<0.000000e+00> : vector<8x8xf16>
+    %cst_vec_1 = arith.constant dense<0.000000e+00> : vector<8x8xf16>
+    %cst_vec_2 = arith.constant dense<5.000000e+00> : vector<8x8xf32>
+    %a_tdesc = xegpu.create_nd_tdesc %A[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr>
+    %a_val = xegpu.load_nd %a_tdesc : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr> -> vector<8x16xf16>
+    %a_val_0 = vector.insert_strided_slice %cst_vec_0, %a_val {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x8xf16> into vector<8x16xf16>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr>
+
+    %b_val = xegpu.load_nd %b_tdesc : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr> -> vector<16x16xf16>
+    %b_val_0 = vector.insert_strided_slice %cst_vec_1, %b_val {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x8xf16> into vector<16x16xf16>
+    %c_val = xegpu.dpas %a_val_0, %b_val_0 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+    %c_val_0 = vector.extract_strided_slice %c_val {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x16xf32> to vector<8x8xf32>
+    %c_addf = arith.addf %c_val_0, %cst_vec_2 : vector<8x8xf32>
+    %c_result = vector.insert_strided_slice %c_addf, %c_val {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x8xf32> into vector<8x16xf32>
+    %c_tdesc = xegpu.create_nd_tdesc %C[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
+    xegpu.store_nd %c_result, %c_tdesc : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
     gpu.return
   }
 }

From b4f8cbfb33041ac0231b452d516ce6525b9ff34b Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah"
Date: Thu, 18 Sep 2025 21:25:16 +0000
Subject: [PATCH 5/6] Address review comments.

Add vector unroll support for n-D load/store.
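
As a sketch of what the unroll does (shapes illustrative, not taken from the
tests): with the native shape <1x1x...x1xdk>, a 3-D transfer such as

  %v = vector.load %buf[%c0, %c0, %c0] : memref<2x2x4xf32>, vector<2x2x4xf32>

is unrolled into 2*2 = 4 loads of vector<1x1x4xf32> (at offsets [0, 0, 0],
[0, 1, 0], [1, 0, 0] and [1, 1, 0]); the subsequent linearization step then
collapses those slices into plain vector<4xf32> loads, as the new
test_vector_store_load_4x4x4 case below checks with CHECK-COUNT-16.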
---
 .../XeGPU/Transforms/XeGPUVectorLinearize.cpp | 36 +++++++++++++++----
 .../Dialect/XeGPU/xegpu-vector-linearize.mlir | 23 ++++++++++++
 2 files changed, 52 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
index 2bb302f4287c4..24da724bf6d81 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
@@ -16,6 +16,8 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
 
 #include <optional>
 
@@ -48,7 +50,8 @@ struct XeGPUVectorLinearizePass final
         return signalPassFailure();
     }
 
-    // Unroll load store from <MxN> to M <1xN> load/stores and then linearize
+    // Unroll load/store from <d1xd2x...xdk> to (d1*d2*...*d(k-1)) slices of
+    // <1x1x...x1xdk>.
     {
       RewritePatternSet patterns(&getContext());
      vector::UnrollVectorOptions vectorOptions;
@@ -62,19 +65,36 @@ struct XeGPUVectorLinearizePass final
             return nullptr;
           };
 
-          auto vecType = extractVectorType(op);
+          VectorType vecType = extractVectorType(op);
           if (!vecType)
             return std::nullopt;
 
-          auto shape = vecType.getShape();
-          if (shape.size() != 2)
+          // Only handle rank >= 2 so we actually unroll something.
+          int64_t rank = vecType.getRank();
+          if (rank < 2)
             return std::nullopt;
-          return SmallVector<int64_t>{1, shape[1]};
+          ArrayRef<int64_t> shape = vecType.getShape();
+          // Bail if any of the (rank-1) leading dims are dynamic (can't fully
+          // unroll).
+          for (int64_t i = 0; i < rank - 1; ++i)
+            if (shape[i] == ShapedType::kDynamic) {
+              LLVM_DEBUG(llvm::dbgs()
+                         << "Dynamic leading dim " << i << " in " << vecType
+                         << " prevents full unroll.\n");
+              return std::nullopt;
+            }
+
+          // Produce native shape: 1 x 1 x ... x (original last dim).
+          SmallVector<int64_t> native(rank, 1);
+          native.back() = shape.back();
+          return native;
         });
       vector::populateVectorUnrollPatterns(patterns, vectorOptions);
-      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+        LLVM_DEBUG(llvm::dbgs() << "Unroll failed.\n");
         return signalPassFailure();
+      }
     }
 
     // Use vector linearization patterns
@@ -90,8 +110,10 @@ struct XeGPUVectorLinearizePass final
       scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                            target);
       if (failed(applyPartialConversion(getOperation(), target,
-                                        std::move(patterns))))
+                                        std::move(patterns)))) {
+        LLVM_DEBUG(llvm::dbgs() << "Linearization failed.\n");
         return signalPassFailure();
+      }
     }
   }
 };
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index ec98172c478ea..0bb7d7d3d8b1b 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -59,6 +59,29 @@ func.func @test_vector_store_load_4x4_f16(%buffer: memref<4x4xf16>) {
   vector.store %0, %buffer[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
   return
 }
+
+// -----
+// CHECK-LABEL: func.func @test_vector_store_load_4x4x4
+// CHECK-SAME: (%[[BUF:.*]]: memref<4x4x4xf32>)
+// Constants (order not important)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// All 16 scalar-slice (row/col plane) loads of 1D vectors
+// CHECK-COUNT-16: vector.load {{.*}} : memref<4x4x4xf32>, vector<4xf32>
+// No remaining 3D vector load
+// CHECK-NOT: vector.load {{.*}} : memref<4x4x4xf32>, vector<4x4x4xf32>
+// All 16 stores of 1D vectors
+// CHECK-COUNT-16: vector.store {{.*}} : memref<4x4x4xf32>, vector<4xf32>
+// CHECK: return
+func.func @test_vector_store_load_4x4x4(%buffer: memref<4x4x4xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %buffer[%c0, %c0, %c0] : memref<4x4x4xf32>, vector<4x4x4xf32>
+  vector.store %0, %buffer[%c0, %c0, %c0] : memref<4x4x4xf32>, vector<4x4x4xf32>
+  return
+}
+
 // -----
 // CHECK-LABEL: func.func @test_linearize_index
 // CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xindex>, %[[ARG1:.*]]: vector<2x2xi32>) -> vector<2x2xindex>

From 6b22d6d3936d5de3b640ae2f65d372c6cf2f3798 Mon Sep 17 00:00:00 2001
From: "Shahneous Bari, Md Abdullah"
Date: Thu, 18 Sep 2025 22:05:25 +0000
Subject: [PATCH 6/6] Address review comments.

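The new ub::UBDialect dependency is needed because the upstream
linearization patterns materialize ub.poison placeholders for
not-yet-assembled linear vectors (the broadcast tests above check for lines
such as `%0 = ub.poison : vector<12xf32>`), and a pass must declare every
dialect whose ops its rewrites may create.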
---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |  2 +-
 .../XeGPU/Transforms/XeGPUVectorLinearize.cpp | 15 +++------------
 2 files changed, 4 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 77c57ccb0746f..83b128e2c7cbf 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -81,7 +81,7 @@ def XeGPUVectorLinearize : Pass<"xegpu-vector-linearize"> {
     This pass linearizes n-D vectors to 1-D vectors for lowering to XeVM.
   }];
   let dependentDialects = ["arith::ArithDialect", "memref::MemRefDialect",
-                           "scf::SCFDialect", "vector::VectorDialect"];
+                           "scf::SCFDialect", "ub::UBDialect", "vector::VectorDialect"];
 }
 
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
index 24da724bf6d81..e31c37a2459ad 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
 #include "llvm/Support/raw_ostream.h"
 
 #include <optional>
@@ -75,16 +76,6 @@ struct XeGPUVectorLinearizePass final
             return std::nullopt;
 
           ArrayRef<int64_t> shape = vecType.getShape();
-          // Bail if any of the (rank-1) leading dims are dynamic (can't fully
-          // unroll).
-          for (int64_t i = 0; i < rank - 1; ++i)
-            if (shape[i] == ShapedType::kDynamic) {
-              LLVM_DEBUG(llvm::dbgs()
-                         << "Dynamic leading dim " << i << " in " << vecType
-                         << " prevents full unroll.\n");
-              return std::nullopt;
-            }
-
           // Produce native shape: 1 x 1 x ... x (original last dim).
           SmallVector<int64_t> native(rank, 1);
           native.back() = shape.back();
@@ -92,7 +83,7 @@ struct XeGPUVectorLinearizePass final
         });
       vector::populateVectorUnrollPatterns(patterns, vectorOptions);
       if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
-        LLVM_DEBUG(llvm::dbgs() << "Unroll failed.\n");
+        LDBG() << "Unroll failed.";
         return signalPassFailure();
       }
     }
@@ -111,7 +102,7 @@ struct XeGPUVectorLinearizePass final
         target);
       if (failed(applyPartialConversion(getOperation(), target,
                                         std::move(patterns)))) {
-        LLVM_DEBUG(llvm::dbgs() << "Linearization failed.\n");
+        LDBG() << "Linearization failed.";
         return signalPassFailure();
       }
     }
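
With an assertions-enabled build, the LDBG/LLVM_DEBUG tracing added in this
series can be surfaced by filtering on the pass's debug type (the flag value
below mirrors the pass name; the exact invocation is a sketch, assuming a
standalone mlir-opt build):

  mlir-opt -xegpu-vector-linearize -debug-only=xegpu-vector-linearize input.mlir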