
Commit 04c4e9d

[mlir][XeGPU][Transform] Add vectorlinearize transform pass. (#158084)
Use upstream patterns to create a vector-linearize pass needed for lowering to XeVM. It linearizes n-D vectors to 1-D vectors. This is needed because `vector-to-llvm` does not linearize all vectors.
1 parent cdd7898 commit 04c4e9d
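
To illustrate what the pass does (a hedged sketch based on the new test's CHECK patterns, not part of the patch; the function name @example is made up): an elementwise op on an n-D vector is rewritten to operate on a flattened 1-D vector, with vector.shape_cast ops at the boundaries.

// Before: a 2-D elementwise add.
func.func @example(%a: vector<2x2xf32>, %b: vector<2x2xf32>) -> vector<2x2xf32> {
  %0 = arith.addf %a, %b : vector<2x2xf32>
  return %0 : vector<2x2xf32>
}

// After -xegpu-vector-linearize (roughly): the add runs on vector<4xf32>,
// bracketed by shape_casts back to the original 2-D type.
func.func @example(%a: vector<2x2xf32>, %b: vector<2x2xf32>) -> vector<2x2xf32> {
  %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
  %1 = vector.shape_cast %b : vector<2x2xf32> to vector<4xf32>
  %2 = arith.addf %0, %1 : vector<4xf32>
  %3 = vector.shape_cast %2 : vector<4xf32> to vector<2x2xf32>
  return %3 : vector<2x2xf32>
}

The tests added in this commit check this kind of structure for vector.insert, transpose, broadcast, gather, load/store, and code mixed with XeGPU ops.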

File tree

4 files changed: +384 -0 lines changed


mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td

Lines changed: 9 additions & 0 deletions
@@ -75,4 +75,13 @@ def XeGPUBlocking: Pass<"xegpu-blocking"> {
                           "index::IndexDialect"];
}

def XeGPUVectorLinearize : Pass<"xegpu-vector-linearize"> {
  let summary = "Linearize n-D vectors to 1-D vectors";
  let description = [{
    This pass linearizes n-D vectors to 1-D vectors for lowering to XeVM.
  }];
  let dependentDialects = ["arith::ArithDialect", "memref::MemRefDialect",
                           "scf::SCFDialect", "ub::UBDialect", "vector::VectorDialect"];
}

#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD

mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
  XeGPUUnroll.cpp
  XeGPUWgToSgDistribute.cpp
  XeGPUPropagateLayout.cpp
  XeGPUVectorLinearize.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
//===-- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"

#include <optional>

namespace mlir {
namespace xegpu {
#define GEN_PASS_DEF_XEGPUVECTORLINEARIZE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
} // namespace xegpu
} // namespace mlir

#define DEBUG_TYPE "xegpu-vector-linearize"

using namespace mlir;

namespace {
struct XeGPUVectorLinearizePass final
    : public xegpu::impl::XeGPUVectorLinearizeBase<XeGPUVectorLinearizePass> {
  void runOnOperation() override {
    // vector.broadcast and vector.gather require progressive lowering.
    {
      RewritePatternSet patterns(&getContext());
      vector::populateVectorBroadcastLoweringPatterns(patterns);
      vector::populateVectorGatherLoweringPatterns(patterns);
      vector::populateVectorGatherToConditionalLoadPatterns(patterns);
      // vector.transpose lowering.
      // Shuffle16x16 will fall back to Shuffle1D for non-16x16 sizes.
      vector::populateVectorTransposeLoweringPatterns(
          patterns, vector::VectorTransposeLowering::Shuffle16x16);
      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
        return signalPassFailure();
    }

    // Unroll load/store from <d1xd2x...xdk> to (d1*d2*...*d(k-1)) slices of
    // <1x1x...x1xdk>.
    {
      RewritePatternSet patterns(&getContext());
      vector::UnrollVectorOptions vectorOptions;
      vectorOptions.setNativeShapeFn(
          [](Operation *op) -> std::optional<SmallVector<int64_t>> {
            auto extractVectorType = [](Operation *op) -> VectorType {
              if (auto loadOp = dyn_cast<vector::LoadOp>(op))
                return loadOp.getVectorType();
              if (auto storeOp = dyn_cast<vector::StoreOp>(op))
                return storeOp.getVectorType();
              return nullptr;
            };

            VectorType vecType = extractVectorType(op);
            if (!vecType)
              return std::nullopt;

            // Only handle rank >= 2 so we actually unroll something.
            int64_t rank = vecType.getRank();
            if (rank < 2)
              return std::nullopt;

            ArrayRef<int64_t> shape = vecType.getShape();
            // Produce native shape: 1 x 1 x ... x (original last dim).
            SmallVector<int64_t> native(rank, 1);
            native.back() = shape.back();
            return native;
          });
      vector::populateVectorUnrollPatterns(patterns, vectorOptions);
      if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
        LDBG() << "Unroll failed.";
        return signalPassFailure();
      }
    }

    // Use vector linearization patterns.
    {
      MLIRContext &context = getContext();
      TypeConverter converter;
      RewritePatternSet patterns(&context);
      ConversionTarget target(context);
      vector::populateForVectorLinearize(converter, target);
      vector::populateVectorLinearizeBasePatterns(converter, target, patterns);
      vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target,
                                                            patterns);
      scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                           target);
      if (failed(applyPartialConversion(getOperation(), target,
                                        std::move(patterns)))) {
        LDBG() << "Linearization failed.";
        return signalPassFailure();
      }
    }
  }
};
} // namespace
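
As a worked illustration of the unroll stage's native-shape callback above (a hedged sketch that mirrors the @test_vector_store_load_4x4_f16 case in the new test below; @unroll_example is a made-up name): for a vector<4x4xf16> load/store the callback returns [1, 4], so the op is unrolled into per-row slices, and after the later linearization stage plus -canonicalize only plain 1-D loads and stores remain.

// Input: a single 2-D load/store pair.
func.func @unroll_example(%buf: memref<4x4xf16>) {
  %c0 = arith.constant 0 : index
  %v = vector.load %buf[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
  vector.store %v, %buf[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
  return
}

// Expected result after the pass (roughly): four row loads and four row stores
// of vector<4xf16>, one per leading index.
func.func @unroll_example(%buf: memref<4x4xf16>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %r0 = vector.load %buf[%c0, %c0] : memref<4x4xf16>, vector<4xf16>
  %r1 = vector.load %buf[%c1, %c0] : memref<4x4xf16>, vector<4xf16>
  %r2 = vector.load %buf[%c2, %c0] : memref<4x4xf16>, vector<4xf16>
  %r3 = vector.load %buf[%c3, %c0] : memref<4x4xf16>, vector<4xf16>
  vector.store %r0, %buf[%c0, %c0] : memref<4x4xf16>, vector<4xf16>
  vector.store %r1, %buf[%c1, %c0] : memref<4x4xf16>, vector<4xf16>
  vector.store %r2, %buf[%c2, %c0] : memref<4x4xf16>, vector<4xf16>
  vector.store %r3, %buf[%c3, %c0] : memref<4x4xf16>, vector<4xf16>
  return
}
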
Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
// RUN: mlir-opt %s -split-input-file -xegpu-vector-linearize -canonicalize | FileCheck %s

// CHECK-LABEL: test_vector_insert_2d_idx
// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<4xf32>) -> vector<2x8x4xf32>
// CHECK: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[SRC]]
// CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 64, 65, 66, 67, 16, 17, 18, 19, 20, 21,
// CHECK-SAME: 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
// CHECK-SAME: 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<4xf32>
// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
// CHECK: return %[[RES]] : vector<2x8x4xf32>
func.func @test_vector_insert_2d_idx(%arg0: vector<2x8x4xf32>, %arg1: vector<4xf32>) -> vector<2x8x4xf32> {
  %0 = vector.insert %arg1, %arg0[0, 3]: vector<4xf32> into vector<2x8x4xf32>
  return %0 : vector<2x8x4xf32>
}

// -----
// CHECK-LABEL: test_vector_transpose
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8xf32>) -> vector<8x2xf32>
// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8xf32> to vector<16xf32>
// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
// CHECK: [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<16xf32>, vector<16xf32>
// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
// CHECK: return %[[RES]] : vector<8x2xf32>
func.func @test_vector_transpose(%arg: vector<2x8xf32>) -> vector<8x2xf32> {
  %0 = vector.transpose %arg, [1, 0] : vector<2x8xf32> to vector<8x2xf32>
  return %0 : vector<8x2xf32>
}

// -----
// CHECK-LABEL: test_vector_transpose_16x16
// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
// CHECK-COUNT-62: vector.shuffle
func.func @test_vector_transpose_16x16(%arg: vector<16x16xf32>) -> vector<16x16xf32> {
  %0 = vector.transpose %arg, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
  return %0 : vector<16x16xf32>
}

// -----

// CHECK-LABEL: func.func @test_vector_store_load_4x4_f16
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf16>)
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[LOAD0:.*]] = vector.load %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
// CHECK: %[[LOAD1:.*]] = vector.load %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
// CHECK: %[[LOAD2:.*]] = vector.load %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
// CHECK: %[[LOAD3:.*]] = vector.load %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
// CHECK: vector.store %[[LOAD0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
// CHECK: vector.store %[[LOAD1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
// CHECK: vector.store %[[LOAD2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
// CHECK: vector.store %[[LOAD3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
func.func @test_vector_store_load_4x4_f16(%buffer: memref<4x4xf16>) {
  %c0 = arith.constant 0 : index
  %0 = vector.load %buffer[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
  vector.store %0, %buffer[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
  return
}

// -----
// CHECK-LABEL: func.func @test_vector_store_load_4x4x4
// CHECK-SAME: (%[[BUF:.*]]: memref<4x4x4xf32>)
// Constants (order not important)
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// All 16 scalar-slice (row/col plane) loads of 1D vectors
// CHECK-COUNT-16: vector.load {{.*}} : memref<4x4x4xf32>, vector<4xf32>
// No remaining 3D vector load
// CHECK-NOT: vector.load {{.*}} : memref<4x4x4xf32>, vector<4x4x4xf32>
// All 16 stores of 1D vectors
// CHECK-COUNT-16: vector.store {{.*}} : memref<4x4x4xf32>, vector<4xf32>
// CHECK: return
func.func @test_vector_store_load_4x4x4(%buffer: memref<4x4x4xf32>) {
  %c0 = arith.constant 0 : index
  %0 = vector.load %buffer[%c0, %c0, %c0] : memref<4x4x4xf32>, vector<4x4x4xf32>
  vector.store %0, %buffer[%c0, %c0, %c0] : memref<4x4x4xf32>, vector<4x4x4xf32>
  return
}

// -----
// CHECK-LABEL: func.func @test_linearize_index
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2xindex>, %[[ARG1:.*]]: vector<2x2xi32>) -> vector<2x2xindex>
// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK: %[[CAST1:.*]] = vector.shape_cast %[[ARG1]] : vector<2x2xi32> to vector<4xi32>
// CHECK: %[[CAST2:.*]] = vector.shape_cast %[[ARG0]] : vector<2x2xindex> to vector<4xindex>
// CHECK: %[[ADDI:.*]] = arith.addi %[[CAST2]], %[[CST]] : vector<4xindex>
// CHECK: %[[INDEX_CAST1:.*]] = arith.index_cast %[[ADDI]] : vector<4xindex> to vector<4xi32>
// CHECK: %[[MULI:.*]] = arith.muli %[[INDEX_CAST1]], %[[CAST1]] : vector<4xi32>
// CHECK: %[[INDEX_CAST2:.*]] = arith.index_cast %[[MULI]] : vector<4xi32> to vector<4xindex>
// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[INDEX_CAST2]] : vector<4xindex> to vector<2x2xindex>
// CHECK: return %[[RESULT]] : vector<2x2xindex>
func.func @test_linearize_index(%arg0: vector<2x2xindex>, %arg1: vector<2x2xi32>) -> vector<2x2xindex> {
  %0 = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xindex>
  // Arith and math ops are handled in generic way, check some of them
  %1 = arith.addi %arg0, %0 : vector<2x2xindex>
  %2 = arith.index_cast %1 : vector<2x2xindex> to vector<2x2xi32>
  %3 = arith.muli %2, %arg1 : vector<2x2xi32>
  %4 = arith.index_cast %3 : vector<2x2xi32> to vector<2x2xindex>
  return %4 : vector<2x2xindex>
}

// -----
// CHECK-LABEL: func.func @broadcast_stretch_at_start
// CHECK-SAME: (%[[ARG0:.*]]: vector<1x4xf32>) -> vector<3x4xf32>
// CHECK: %[[POISON:.*]] = ub.poison : vector<12xf32>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[ARG0]] : vector<1x4xf32> to vector<4xf32>
// CHECK: %[[SHUFFLE1:.*]] = vector.shuffle %[[POISON]], %[[CAST]] [12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11] : vector<12xf32>, vector<4xf32>
// CHECK: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[CAST]] [0, 1, 2, 3, 12, 13, 14, 15, 8, 9, 10, 11] : vector<12xf32>, vector<4xf32>
// CHECK: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[CAST]] [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15] : vector<12xf32>, vector<4xf32>
// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<12xf32> to vector<3x4xf32>
func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> {
  %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32>
  return %0 : vector<3x4xf32>
}

// -----
// CHECK-LABEL: func.func @broadcast_stretch_at_end
// CHECK-SAME: (%[[ARG0:.*]]: vector<4x1xf32>) -> vector<4x3xf32>
// CHECK: %[[POISON:.*]] = ub.poison : vector<12xf32>
// CHECK: %[[EXTRACT1:.*]] = vector.extract %[[ARG0]][0, 0] : f32 from vector<4x1xf32>
// CHECK: %[[BROADCAST1:.*]] = vector.broadcast %[[EXTRACT1]] : f32 to vector<3xf32>
// CHECK: vector.shuffle
// CHECK: %[[EXTRACT2:.*]] = vector.extract %[[ARG0]][1, 0] : f32 from vector<4x1xf32>
// CHECK: %[[BROADCAST2:.*]] = vector.broadcast %[[EXTRACT2]] : f32 to vector<3xf32>
// CHECK: vector.shuffle
// CHECK: %[[EXTRACT3:.*]] = vector.extract %[[ARG0]][2, 0] : f32 from vector<4x1xf32>
// CHECK: %[[BROADCAST3:.*]] = vector.broadcast %[[EXTRACT3]] : f32 to vector<3xf32>
// CHECK: vector.shuffle
// CHECK: %[[EXTRACT4:.*]] = vector.extract %[[ARG0]][3, 0] : f32 from vector<4x1xf32>
// CHECK: %[[BROADCAST4:.*]] = vector.broadcast %[[EXTRACT4]] : f32 to vector<3xf32>
// CHECK: vector.shuffle
// CHECK: vector.shape_cast {{.*}} : vector<12xf32> to vector<4x3xf32>
func.func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
  %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32>
  return %0 : vector<4x3xf32>
}

// -----
// CHECK-LABEL: func.func @broadcast_stretch_in_middle
// CHECK-SAME: (%[[ARG0:.*]]: vector<4x1x2xf32>) -> vector<4x3x2xf32>
// CHECK: ub.poison : vector<6xf32>
// CHECK: ub.poison : vector<24xf32>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[ARG0]] : vector<4x1x2xf32> to vector<8xf32>
// CHECK-COUNT-20: vector.shuffle
// CHECK: vector.shape_cast {{.*}} : vector<24xf32> to vector<4x3x2xf32>
// CHECK-NOT: vector.broadcast
func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> {
  %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32>
  return %0 : vector<4x3x2xf32>
}

// CHECK-LABEL: func.func @gather_memref_2d
// CHECK-SAME: (%arg0: memref<?x?xf32>, %arg1: vector<2x3xindex>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {

// CHECK: %0 = ub.poison : vector<6xf32>
// CHECK: %c1 = arith.constant 1 : index
// CHECK: %c0 = arith.constant 0 : index
// CHECK: %1 = vector.shape_cast %arg3 : vector<2x3xf32> to vector<6xf32>

// First shuffle + if ladder for row 0
// CHECK: %2 = vector.shuffle %1, %1 [0, 1, 2]
// CHECK: %3 = vector.extract %arg2[0, 0]
// CHECK: %4 = vector.extract %arg1[0, 0]
// CHECK: %5 = arith.addi %4, %c1
// CHECK: %6 = scf.if %3 -> (vector<3xf32>) {
// CHECK: %{{.*}} = vector.load %arg0[%c0, %5] : memref<?x?xf32>, vector<1xf32>
// CHECK: %{{.*}} = vector.extract {{.*}}[0] : f32
// CHECK: %{{.*}} = vector.insert {{.*}}, %2 [0] : f32 into vector<3xf32>
// CHECK: scf.yield {{.*}} : vector<3xf32>
// CHECK: } else {
// CHECK: scf.yield %2 : vector<3xf32>
// CHECK: }

// CHECK: %7 = vector.extract %arg2[0, 1]
// CHECK: %8 = vector.extract %arg1[0, 1]
// CHECK: %9 = arith.addi %8, %c1
// CHECK: %10 = scf.if %7 -> (vector<3xf32>)

// … (similar checks for the rest of row 0, then row 1)

// CHECK: %15 = vector.shuffle %0, %{{.*}} [6, 7, 8, 3, 4, 5]
// CHECK: %16 = vector.shuffle %1, %1 [3, 4, 5]

// Row 1 if ladder checks
// CHECK: %17 = vector.extract %arg2[1, 0]
// CHECK: %18 = vector.extract %arg1[1, 0]
// CHECK: %19 = arith.addi %18, %c1
// CHECK: %20 = scf.if %17 -> (vector<3xf32>)

// … (similar checks for remaining row 1 inserts)

// Final reshuffle and cast
// CHECK: %29 = vector.shuffle %15, %{{.*}} [0, 1, 2, 6, 7, 8]
// CHECK: %30 = vector.shape_cast %29 : vector<6xf32> to vector<2x3xf32>
// CHECK: return %30 : vector<2x3xf32>
func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
}

// -----
// Check for vector linearization interoperability with XeGPU dialect ops.
// The `xegpu-vector-linearize` pass does not itself affect the XeGPU ops.

// CHECK: gpu.func @test_kernel(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) kernel {
// CHECK: %c0 = arith.constant 0 : index
// CHECK: %cst = arith.constant dense<0.000000e+00> : vector<64xf16>
// CHECK: %cst_0 = arith.constant dense<5.000000e+00> : vector<64xf32>

// CHECK: %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0]
// CHECK: %1 = xegpu.load_nd %0
// CHECK: %2 = vector.shape_cast %1 : vector<8x16xf16> to vector<128xf16>
// CHECK: %3 = vector.shuffle %2, %cst {{.*}} : vector<128xf16>, vector<64xf16>
// CHECK: %4 = vector.shape_cast %3 : vector<128xf16> to vector<8x16xf16>

// CHECK: %5 = xegpu.create_nd_tdesc %arg1[%c0, %c0]
// CHECK: %6 = xegpu.load_nd %5
// CHECK: %7 = vector.shape_cast %6 : vector<16x16xf16> to vector<256xf16>
// CHECK: %8 = vector.shuffle %7, %cst {{.*}} : vector<256xf16>, vector<64xf16>
// CHECK: %9 = vector.shape_cast %8 : vector<256xf16> to vector<16x16xf16>

// CHECK: %10 = xegpu.dpas %4, %9 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
// CHECK: %11 = vector.shape_cast %10 : vector<8x16xf32> to vector<128xf32>
// CHECK: %12 = vector.shuffle %11, %11 {{.*}} : vector<128xf32>, vector<128xf32>
// CHECK: %13 = arith.addf %12, %cst_0 : vector<64xf32>
// CHECK: %14 = vector.shuffle %11, %13 {{.*}} : vector<128xf32>, vector<64xf32>
// CHECK: %15 = vector.shape_cast %14 : vector<128xf32> to vector<8x16xf32>

// CHECK: %16 = xegpu.create_nd_tdesc %arg2[%c0, %c0]
// CHECK: xegpu.store_nd %15, %16
// CHECK: gpu.return

gpu.module @test_kernel {
  gpu.func @test_kernel(%A: memref<8x16xf16>, %B: memref<16x16xf16>, %C: memref<8x16xf32>) kernel {
    %c0 = arith.constant 0 : index
    %cst_vec_0 = arith.constant dense<0.000000e+00> : vector<8x8xf16>
    %cst_vec_1 = arith.constant dense<0.000000e+00> : vector<8x8xf16>
    %cst_vec_2 = arith.constant dense<5.000000e+00> : vector<8x8xf32>
    %a_tdesc = xegpu.create_nd_tdesc %A[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 1>>
    %a_val = xegpu.load_nd %a_tdesc : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<array_length = 1>> -> vector<8x16xf16>
    %a_val_0 = vector.insert_strided_slice %cst_vec_0, %a_val{offsets = [0, 0], sizes = [8, 8], strides = [1, 1]}: vector<8x8xf16> into vector<8x16xf16>
    %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 1>>

    %b_val = xegpu.load_nd %b_tdesc : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 1>> -> vector<16x16xf16>
    %b_val_0 = vector.insert_strided_slice %cst_vec_1, %b_val{offsets = [0, 0], sizes = [8, 8], strides = [1, 1]}: vector<8x8xf16> into vector<16x16xf16>
    %c_val = xegpu.dpas %a_val_0, %b_val_0 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
    %c_val_0 = vector.extract_strided_slice %c_val {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x16xf32> to vector<8x8xf32>
    %c_addf = arith.addf %c_val_0, %cst_vec_2 : vector<8x8xf32>
    %c_result = vector.insert_strided_slice %c_addf, %c_val {offsets = [0, 0], sizes = [8, 8], strides = [1, 1]} : vector<8x8xf32> into vector<8x16xf32>
    %c_tdesc = xegpu.create_nd_tdesc %C[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<array_length = 1>>
    xegpu.store_nd %c_result, %c_tdesc : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
    gpu.return
  }
}
