Commit ea00593

[MLIR][XeGPU][Quickfix] Disable block count in propagation (#170304)
A previous PR, #169267, reintroduced the block count into layout propagation after it had been removed in #168504. This PR patches the issue.
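For reference, a minimal sketch of what the patch amounts to, mirroring the prefetch_nd path in XeGPUPropagateLayout.cpp shown in the diff below (the store_nd path is changed the same way): the block count (bCount) is simply dropped from the instruction-width computation.

    // After this patch: only the innermost tile dimension and the block width
    // constrain the instruction width; bCount is no longer passed.
    int instWidth = xegpu::getLargestDivisor(
        static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
    if (instWidth == -1)
      prefetch.emitWarning(
          "No suitable instruction multiple found for the given shape.");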
1 parent: a8ef3c8

2 files changed: +5 -4 lines


mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Lines changed: 2 additions & 4 deletions
@@ -517,8 +517,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
   auto [bWidth, bHeight, bCount] = blockWHC.value();
   SmallVector<int> instData;
   int instWidth = xegpu::getLargestDivisor(
-      static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
-      bCount);
+      static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth);
   if (instWidth == -1)
     prefetch.emitWarning(
         "No suitable instruction multiple found for the given shape.");
@@ -759,8 +758,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
   auto [bWidth, bHeight, bCount] = blockWHC.value();
   SmallVector<int> instData;
   int instWidth = xegpu::getLargestDivisor(
-      static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
-      bCount);
+      static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth);
   if (instWidth == -1)
     store.emitWarning(
         "No suitable instruction multiple found for the given shape.");

mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,8 @@
 // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
 // CHECK: %[[TDESC_SRC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
 // CHECK: %[[TDESC_DST:.*]] = xegpu.create_nd_tdesc %[[ARG1]] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+// CHECK: xegpu.prefetch_nd %[[TDESC_SRC]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, layout = #xegpu.layout<inst_data = [8, 16]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
 // CHECK: %[[LOADED:.*]] = xegpu.load_nd %0 <{layout = #xegpu.layout<inst_data = [8, 16]>}> {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} :
 // CHECK-SAME: !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<8x32xf32>
 // CHECK: xegpu.store_nd %[[LOADED]], %[[TDESC_DST]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -16,6 +18,7 @@ func.func @load_store_no_array_len(%arg0: memref<8x32xf32>, %arg1: memref<8x32xf
   %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
   %0 = xegpu.create_nd_tdesc %arg0 : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32>
   %1 = xegpu.create_nd_tdesc %arg1 : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32>
+  xegpu.prefetch_nd %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<8x32xf32>
   %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x32xf32> -> vector<8x32xf32>
   xegpu.store_nd %2, %1 : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32>
   return
