[MLIR][XeGPU] Add support for Convert Layout from Wg to Sg (#178922)
Conversation
|
@llvm/pr-subscribers-mlir-gpu. Author: Nishant Patel (nbpatel). Changes: full diff at https://github.com/llvm/llvm-project/pull/178922.diff. 5 files affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 2cbec50772b98..fefc8c9903497 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1593,9 +1593,8 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}];
let description = [{
- This operation loads a 2D block of data from shared local memory (SLM) as specified
- by the provided 2D `mem_desc`. Only 2D memory descriptors are supported; use the
- subview operation to obtain a compatible 2D `mem_desc` from a higher-rank descriptor if needed.
+ This operation loads an nD block of data from shared local memory (SLM) as specified
+ by the provided nD `mem_desc`. Memory descriptors of rank 2 or higher are supported.
This operation serves as an anchor through which users assign a layout attribute
to govern computation distribution.
@@ -1665,9 +1664,8 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands)}];
let description = [{
- This operation stores a 2D `data` fragment into the shared local memory region
- specified by a 2D `mem_desc`. Only 2D memory descriptors are supported; use the
- subview operation to obtain a 2D `mem_desc` from a higher-rank descriptor if needed.
+ This operation stores an nD `data` fragment into the shared local memory region
+ specified by an nD `mem_desc`. Memory descriptors of rank 2 or higher are supported.
This operation serves as an anchor through which users assign a layout attribute
to govern computation distribution.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 91ba07a8e0256..c7226c7ebbd5d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -186,8 +186,8 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
return success();
}
- if (mdescTy.getRank() != 2)
- return emitError() << "mem_desc must be 2D.";
+ if (mdescTy.getRank() < 2)
+ return emitError() << "mem_desc must be 2D or greater.";
ArrayRef<int64_t> dataShape = dataTy.getShape();
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 45a002b63abd6..8dbca952cc8c2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -604,44 +604,124 @@ struct WgToSgElementwiseOp : public ConversionPattern {
struct WgToSgConvertLayoutOp
: public OpConversionPattern<xegpu::ConvertLayoutOp> {
using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
+
LogicalResult
matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
- auto input = op.getInputLayout();
- auto target = op.getTargetLayout();
+ VectorType resultType = op.getResult().getType();
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+ auto inputLayout = op.getInputLayout();
+ auto targetLayout = op.getTargetLayout();
- if (!input || !target || !input.isForWorkgroup() ||
- !target.isForWorkgroup())
+ if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
+ !targetLayout.isForWorkgroup())
return rewriter.notifyMatchFailure(
op, "Input and target layouts must have subgroup layout");
- SmallVector<int64_t> inputSgLayout = input.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> inputSgData = input.getEffectiveSgDataAsInt();
- DenseI32ArrayAttr inputOrder = input.getOrder();
- SmallVector<int64_t> targetSgLayout = target.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> targetSgData = target.getEffectiveSgDataAsInt();
- DenseI32ArrayAttr targetOrder = target.getOrder();
-
- // TODO: currently we only support for optimal case, where input and
- // output has the same sg_layout and sg_data, so SLM is not involved.
- if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
- inputOrder != targetOrder)
+ SmallVector<int64_t> inputSgLayout =
+ inputLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
+ SmallVector<int64_t> targetSgLayout =
+ targetLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
+
+ auto hasUnitLeadingDims = [](ArrayRef<int64_t> shape) {
+ if (shape.size() <= 2)
+ return true;
+ for (size_t i = 0; i + 2 < shape.size(); ++i)
+ if (shape[i] != 1)
+ return false;
+ return true;
+ };
+
+ if (wgShape.size() > 2) {
+ if (!hasUnitLeadingDims(inputSgData) || !hasUnitLeadingDims(targetSgData))
+ return rewriter.notifyMatchFailure(
+ op, "rank > 2 requires unit leading dims for sg_data");
+ }
+
+ // Fast path: if sg_layout and sg_data are identical, no SLM needed
+ if (inputSgLayout == targetSgLayout && inputSgData == targetSgData) {
+ inputLayout = inputLayout.dropSgLayoutAndData();
+ targetLayout = targetLayout.dropSgLayoutAndData();
+
+ SmallVector<Value> newOps(adaptor.getSource());
+ if (inputLayout && targetLayout) {
+ for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
+ auto newOp = xegpu::ConvertLayoutOp::create(
+ rewriter, loc, src.getType(), src, inputLayout, targetLayout);
+ newOps[i] = newOp;
+ }
+ }
+ rewriter.replaceOpWithMultiple(op, {newOps});
+ return success();
+ }
+
+ // SLM path: layouts differ, need cross-subgroup data redistribution
+ Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
+
+ SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
+
+ // Calculate SLM size requirements
+ auto bitWidth = elemTy.getIntOrFloatBitWidth();
+ auto bytesPerElement = bitWidth / 8;
+ auto slmSize = computeProduct(slmShape) * bytesPerElement;
+
+ // Allocate SLM
+ auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
+ auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
+
+ auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
+ elemTy, nullptr);
+ auto memDesc =
+ xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
+
+ auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
+ rewriter.getIndexType(), nullptr);
+
+ // STORE PHASE: Each subgroup stores in SLM using input layout
+ auto storeCoords = inputLayout.computeDistributedCoords(
+ rewriter, loc, sgId.getResult(), wgShape);
+ if (failed(storeCoords))
return failure();
- input = input.dropSgLayoutAndData();
- target = target.dropSgLayoutAndData();
+ // Store to SLM
+ for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
+ SmallVector<OpFoldResult> storeMatrixOffsets;
+ for (Value coord : coords) {
+ storeMatrixOffsets.push_back(coord);
+ }
+ xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
+ storeMatrixOffsets, nullptr /*layout*/);
+ }
+
+ gpu::BarrierOp::create(rewriter, loc);
+
+ // LOAD PHASE: Each target subgroup loads from SLM using target layout
+ auto loadCoords = targetLayout.computeDistributedCoords(
+ rewriter, loc, sgId.getResult(), wgShape);
+ if (failed(loadCoords))
+ return failure();
+
+ VectorType loadType = VectorType::get(targetSgData, elemTy);
- SmallVector<Value> newOps(adaptor.getSource());
- if (input && target) {
- // keep the ConvertLayoutOp for rest fields, e.g., inst_data.
- for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
- auto newOp = xegpu::ConvertLayoutOp::create(
- rewriter, op.getLoc(), src.getType(), src, input, target);
- newOps[i] = newOp;
+ // Load vectors from SLM
+ SmallVector<Value> finalResults;
+ for (auto coords : *loadCoords) {
+ SmallVector<OpFoldResult> loadMatrixOffsets;
+ for (Value coord : coords) {
+ loadMatrixOffsets.push_back(coord);
}
+ auto loadOp = xegpu::LoadMatrixOp::create(
+ rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
+ targetLayout.dropSgLayoutAndData());
+
+ finalResults.push_back(loadOp.getResult());
}
- rewriter.replaceOpWithMultiple(op, {newOps});
+
+ rewriter.replaceOpWithMultiple(op, {finalResults});
return success();
}
};
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index f2011ab86e9e9..e6376e3ecb4cd 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -852,7 +852,7 @@ func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>)
// -----
func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
- // expected-error@+1 {{mem_desc must be 2D}}
+ // expected-error@+1 {{mem_desc must be 2D or greater}}
%data = xegpu.load_matrix %arg0[16]: !xegpu.mem_desc<64xf16> -> vector<16xf16>
return
}
@@ -873,7 +873,7 @@ func.func @store_mem_desc_invalid_data_size(%arg0: !xegpu.mem_desc<16x64xf16>, %
// -----
func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: vector<32xf16>) {
- // expected-error@+1 {{mem_desc must be 2D.}}
+ // expected-error@+1 {{mem_desc must be 2D or greater}}
xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16>
return
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 1fc2328d09046..d4b611c713674 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -839,6 +839,84 @@ gpu.module @test_distribution {
gpu.return
}
+ // CHECK-LABEL: convert_layout_slm
+ // CHECK-SAME: %[[ARG0:.*]]: memref<128x256xf32>
+ gpu.func @convert_layout_slm(%arg0: memref<128x256xf32>) {
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[MUL_Y:.*]] = arith.muli %[[SGIDY]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[MUL_X:.*]] = arith.muli %[[SGIDX]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[OFF_Y:.*]] = arith.remui %[[MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[OFF_X:.*]] = arith.remui %[[MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[OFF_Y]], %[[OFF_X]]] : memref<128x256xf32> -> !xegpu.tensor_desc<32x16xf32, #xegpu.layout<inst_data = [16, 16]>>
+ // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> : !xegpu.tensor_desc<32x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<32x16xf32>
+ // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<131072xi8, 3>
+ // CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<131072xi8, 3> -> !xegpu.mem_desc<128x256xf32>
+ // CHECK-DAG: %[[SGID_STORE:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[STORE_X:.*]] = arith.remui %[[SGID_STORE]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_Y_TMP:.*]] = arith.divui %[[SGID_STORE]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_Y:.*]] = arith.remui %[[STORE_Y_TMP]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[STORE_MUL_Y:.*]] = arith.muli %[[STORE_Y]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[STORE_MUL_X:.*]] = arith.muli %[[STORE_X]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_Y:.*]] = arith.remui %[[STORE_MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_X:.*]] = arith.remui %[[STORE_MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[STORE_OFF_Y]], %[[STORE_OFF_X]]] : vector<32x16xf32>, !xegpu.mem_desc<128x256xf32>, index, index
+ // CHECK-DAG: gpu.barrier
+ // CHECK-DAG: %[[LOAD_X:.*]] = arith.remui %[[SGID_STORE]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Y_TMP:.*]] = arith.divui %[[SGID_STORE]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Y:.*]] = arith.remui %[[LOAD_Y_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_MUL_Y:.*]] = arith.muli %[[LOAD_Y]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[LOAD_MUL_X:.*]] = arith.muli %[[LOAD_X]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_Y:.*]] = arith.remui %[[LOAD_MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_X:.*]] = arith.remui %[[LOAD_MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[LOAD_OFF_Y]], %[[LOAD_OFF_X]]] <{layout = #xegpu.layout<inst_data = [16, 16]>}>: !xegpu.mem_desc<128x256xf32>, index, index -> vector<16x32xf32>
+ %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>>
+ %1 = xegpu.load_nd %0 {layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>} : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>> -> vector<128x256xf32>
+ %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>,
+ target_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32], inst_data = [16, 16]>}> : vector<128x256xf32>
+ gpu.return
+ }
+
+ gpu.func @convert_layout_3D(%arg0: memref<?xf32>) {
+ // CHECK-DAG: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>} dense<0> : vector<1x32x16xindex>
+ // CHECK-DAG: %[[CST_0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>} dense<true> : vector<1x32x16xi1>
+ // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST]]], %[[CST_0]] <{chunk_size = 1 : i64, layout = #xegpu.layout<inst_data = [1, 16, 16]>}> : memref<?xf32>, vector<1x32x16xindex>, vector<1x32x16xi1> -> vector<1x32x16xf32>
+ // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<1048576xi8, 3>
+ // CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<1048576xi8, 3> -> !xegpu.mem_desc<8x128x256xf32>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[STORE_X:.*]] = arith.remui %[[SGID]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_YZ_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_Y:.*]] = arith.remui %[[STORE_YZ_TMP]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[STORE_Z_TMP:.*]] = arith.divui %[[STORE_YZ_TMP]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[STORE_Z:.*]] = arith.remui %[[STORE_Z_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[STORE_MUL_Y:.*]] = arith.muli %[[STORE_Y]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[STORE_MUL_X:.*]] = arith.muli %[[STORE_X]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_Z:.*]] = arith.remui %[[STORE_Z]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_Y:.*]] = arith.remui %[[STORE_MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_X:.*]] = arith.remui %[[STORE_MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[STORE_OFF_Z]], %[[STORE_OFF_Y]], %[[STORE_OFF_X]]] : vector<1x32x16xf32>, !xegpu.mem_desc<8x128x256xf32>, index, index, index
+ // CHECK-DAG: gpu.barrier
+ // CHECK-DAG: %[[LOAD_X:.*]] = arith.remui %[[SGID]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_YZ_TMP:.*]] = arith.divui %[[SGID]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Y:.*]] = arith.remui %[[LOAD_YZ_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Z_TMP:.*]] = arith.divui %[[LOAD_YZ_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Z:.*]] = arith.remui %[[LOAD_Z_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_MUL_Y:.*]] = arith.muli %[[LOAD_Y]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[LOAD_MUL_X:.*]] = arith.muli %[[LOAD_X]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_Z:.*]] = arith.remui %[[LOAD_Z]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_Y:.*]] = arith.remui %[[LOAD_MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_X:.*]] = arith.remui %[[LOAD_MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[LOAD_OFF_Z]], %[[LOAD_OFF_Y]], %[[LOAD_OFF_X]]] <{layout = #xegpu.layout<inst_data = [1, 16, 16]>}>: !xegpu.mem_desc<8x128x256xf32>, index, index, index -> vector<1x16x32xf32>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} dense<0> : vector<8x128x256xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} dense<true> : vector<8x128x256xi1>
+ %1 = xegpu.load %arg0[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} : memref<?xf32>, vector<8x128x256xindex>, vector<8x128x256xi1> -> vector<8x128x256xf32>
+ %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>,
+ target_layout = #xegpu.layout<sg_layout = [8, 8, 8], sg_data = [1, 16, 32], inst_data = [1, 16, 16]>}> : vector<8x128x256xf32>
+ gpu.return
+ }
+
// CHECK-LABEL: distribute_nested_slice
// CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32>
// CHECK: %[[V1:.*]] = vector.broadcast %[[V0]] : vector<32x1x32x1xf32> to vector<32x16x32x16xf32>
|
|
@llvm/pr-subscribers-mlir. Author: Nishant Patel (nbpatel). Changes: full diff at https://github.com/llvm/llvm-project/pull/178922.diff. 5 files affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 2cbec50772b98..fefc8c9903497 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1593,9 +1593,8 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}];
let description = [{
- This operation loads a 2D block of data from shared local memory (SLM) as specified
- by the provided 2D `mem_desc`. Only 2D memory descriptors are supported; use the
- subview operation to obtain a compatible 2D `mem_desc` from a higher-rank descriptor if needed.
+ This operation loads an nD block of data from shared local memory (SLM) as specified
+ by the provided nD `mem_desc`. Memory descriptors of rank 2 or higher are supported.
This operation serves as an anchor through which users assign a layout attribute
to govern computation distribution.
@@ -1665,9 +1664,8 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands)}];
let description = [{
- This operation stores a 2D `data` fragment into the shared local memory region
- specified by a 2D `mem_desc`. Only 2D memory descriptors are supported; use the
- subview operation to obtain a 2D `mem_desc` from a higher-rank descriptor if needed.
+ This operation stores an nD `data` fragment into the shared local memory region
+ specified by an nD `mem_desc`. Memory descriptors of rank 2 or higher are supported.
This operation serves as an anchor through which users assign a layout attribute
to govern computation distribution.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 91ba07a8e0256..c7226c7ebbd5d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -186,8 +186,8 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
return success();
}
- if (mdescTy.getRank() != 2)
- return emitError() << "mem_desc must be 2D.";
+ if (mdescTy.getRank() < 2)
+ return emitError() << "mem_desc must be 2D or greater.";
ArrayRef<int64_t> dataShape = dataTy.getShape();
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 45a002b63abd6..8dbca952cc8c2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -604,44 +604,124 @@ struct WgToSgElementwiseOp : public ConversionPattern {
struct WgToSgConvertLayoutOp
: public OpConversionPattern<xegpu::ConvertLayoutOp> {
using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
+
LogicalResult
matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
- auto input = op.getInputLayout();
- auto target = op.getTargetLayout();
+ VectorType resultType = op.getResult().getType();
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+ auto inputLayout = op.getInputLayout();
+ auto targetLayout = op.getTargetLayout();
- if (!input || !target || !input.isForWorkgroup() ||
- !target.isForWorkgroup())
+ if (!inputLayout || !targetLayout || !inputLayout.isForWorkgroup() ||
+ !targetLayout.isForWorkgroup())
return rewriter.notifyMatchFailure(
op, "Input and target layouts must have subgroup layout");
- SmallVector<int64_t> inputSgLayout = input.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> inputSgData = input.getEffectiveSgDataAsInt();
- DenseI32ArrayAttr inputOrder = input.getOrder();
- SmallVector<int64_t> targetSgLayout = target.getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> targetSgData = target.getEffectiveSgDataAsInt();
- DenseI32ArrayAttr targetOrder = target.getOrder();
-
- // TODO: currently we only support for optimal case, where input and
- // output has the same sg_layout and sg_data, so SLM is not involved.
- if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
- inputOrder != targetOrder)
+ SmallVector<int64_t> inputSgLayout =
+ inputLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> inputSgData = inputLayout.getEffectiveSgDataAsInt();
+ SmallVector<int64_t> targetSgLayout =
+ targetLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
+
+ auto hasUnitLeadingDims = [](ArrayRef<int64_t> shape) {
+ if (shape.size() <= 2)
+ return true;
+ for (size_t i = 0; i + 2 < shape.size(); ++i)
+ if (shape[i] != 1)
+ return false;
+ return true;
+ };
+
+ if (wgShape.size() > 2) {
+ if (!hasUnitLeadingDims(inputSgData) || !hasUnitLeadingDims(targetSgData))
+ return rewriter.notifyMatchFailure(
+ op, "rank > 2 requires unit leading dims for sg_data");
+ }
+
+ // Fast path: if sg_layout and sg_data are identical, no SLM needed
+ if (inputSgLayout == targetSgLayout && inputSgData == targetSgData) {
+ inputLayout = inputLayout.dropSgLayoutAndData();
+ targetLayout = targetLayout.dropSgLayoutAndData();
+
+ SmallVector<Value> newOps(adaptor.getSource());
+ if (inputLayout && targetLayout) {
+ for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
+ auto newOp = xegpu::ConvertLayoutOp::create(
+ rewriter, loc, src.getType(), src, inputLayout, targetLayout);
+ newOps[i] = newOp;
+ }
+ }
+ rewriter.replaceOpWithMultiple(op, {newOps});
+ return success();
+ }
+
+ // SLM path: layouts differ, need cross-subgroup data redistribution
+ Type elemTy = cast<VectorType>(op.getSource().getType()).getElementType();
+
+ SmallVector<int64_t> slmShape = llvm::to_vector(wgShape);
+
+ // Calculate SLM size requirements
+ auto bitWidth = elemTy.getIntOrFloatBitWidth();
+ auto bytesPerElement = bitWidth / 8;
+ auto slmSize = computeProduct(slmShape) * bytesPerElement;
+
+ // Allocate SLM
+ auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
+ auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
+
+ auto memDescType = xegpu::MemDescType::get(rewriter.getContext(), slmShape,
+ elemTy, nullptr);
+ auto memDesc =
+ xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
+
+ auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
+ rewriter.getIndexType(), nullptr);
+
+ // STORE PHASE: Each subgroup stores in SLM using input layout
+ auto storeCoords = inputLayout.computeDistributedCoords(
+ rewriter, loc, sgId.getResult(), wgShape);
+ if (failed(storeCoords))
return failure();
- input = input.dropSgLayoutAndData();
- target = target.dropSgLayoutAndData();
+ // Store to SLM
+ for (auto [src, coords] : llvm::zip(adaptor.getSource(), *storeCoords)) {
+ SmallVector<OpFoldResult> storeMatrixOffsets;
+ for (Value coord : coords) {
+ storeMatrixOffsets.push_back(coord);
+ }
+ xegpu::StoreMatrixOp::create(rewriter, loc, src, memDesc.getResult(),
+ storeMatrixOffsets, nullptr /*layout*/);
+ }
+
+ gpu::BarrierOp::create(rewriter, loc);
+
+ // LOAD PHASE: Each target subgroup loads from SLM using target layout
+ auto loadCoords = targetLayout.computeDistributedCoords(
+ rewriter, loc, sgId.getResult(), wgShape);
+ if (failed(loadCoords))
+ return failure();
+
+ VectorType loadType = VectorType::get(targetSgData, elemTy);
- SmallVector<Value> newOps(adaptor.getSource());
- if (input && target) {
- // keep the ConvertLayoutOp for rest fields, e.g., inst_data.
- for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
- auto newOp = xegpu::ConvertLayoutOp::create(
- rewriter, op.getLoc(), src.getType(), src, input, target);
- newOps[i] = newOp;
+ // Load vectors from SLM
+ SmallVector<Value> finalResults;
+ for (auto coords : *loadCoords) {
+ SmallVector<OpFoldResult> loadMatrixOffsets;
+ for (Value coord : coords) {
+ loadMatrixOffsets.push_back(coord);
}
+ auto loadOp = xegpu::LoadMatrixOp::create(
+ rewriter, loc, loadType, memDesc.getResult(), loadMatrixOffsets,
+ targetLayout.dropSgLayoutAndData());
+
+ finalResults.push_back(loadOp.getResult());
}
- rewriter.replaceOpWithMultiple(op, {newOps});
+
+ rewriter.replaceOpWithMultiple(op, {finalResults});
return success();
}
};
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index f2011ab86e9e9..e6376e3ecb4cd 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -852,7 +852,7 @@ func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>)
// -----
func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
- // expected-error@+1 {{mem_desc must be 2D}}
+ // expected-error@+1 {{mem_desc must be 2D or greater}}
%data = xegpu.load_matrix %arg0[16]: !xegpu.mem_desc<64xf16> -> vector<16xf16>
return
}
@@ -873,7 +873,7 @@ func.func @store_mem_desc_invalid_data_size(%arg0: !xegpu.mem_desc<16x64xf16>, %
// -----
func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: vector<32xf16>) {
- // expected-error@+1 {{mem_desc must be 2D.}}
+ // expected-error@+1 {{mem_desc must be 2D or greater}}
xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16>
return
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 1fc2328d09046..d4b611c713674 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -839,6 +839,84 @@ gpu.module @test_distribution {
gpu.return
}
+ // CHECK-LABEL: convert_layout_slm
+ // CHECK-SAME: %[[ARG0:.*]]: memref<128x256xf32>
+ gpu.func @convert_layout_slm(%arg0: memref<128x256xf32>) {
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[SGIDX:.*]] = arith.remui %[[SGID]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[SGIDY_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[SGIDY:.*]] = arith.remui %[[SGIDY_TMP]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[MUL_Y:.*]] = arith.muli %[[SGIDY]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[MUL_X:.*]] = arith.muli %[[SGIDX]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[OFF_Y:.*]] = arith.remui %[[MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[OFF_X:.*]] = arith.remui %[[MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[OFF_Y]], %[[OFF_X]]] : memref<128x256xf32> -> !xegpu.tensor_desc<32x16xf32, #xegpu.layout<inst_data = [16, 16]>>
+ // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> : !xegpu.tensor_desc<32x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<32x16xf32>
+ // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<131072xi8, 3>
+ // CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<131072xi8, 3> -> !xegpu.mem_desc<128x256xf32>
+ // CHECK-DAG: %[[SGID_STORE:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[STORE_X:.*]] = arith.remui %[[SGID_STORE]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_Y_TMP:.*]] = arith.divui %[[SGID_STORE]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_Y:.*]] = arith.remui %[[STORE_Y_TMP]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[STORE_MUL_Y:.*]] = arith.muli %[[STORE_Y]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[STORE_MUL_X:.*]] = arith.muli %[[STORE_X]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_Y:.*]] = arith.remui %[[STORE_MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_X:.*]] = arith.remui %[[STORE_MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[STORE_OFF_Y]], %[[STORE_OFF_X]]] : vector<32x16xf32>, !xegpu.mem_desc<128x256xf32>, index, index
+ // CHECK-DAG: gpu.barrier
+ // CHECK-DAG: %[[LOAD_X:.*]] = arith.remui %[[SGID_STORE]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Y_TMP:.*]] = arith.divui %[[SGID_STORE]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Y:.*]] = arith.remui %[[LOAD_Y_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_MUL_Y:.*]] = arith.muli %[[LOAD_Y]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[LOAD_MUL_X:.*]] = arith.muli %[[LOAD_X]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_Y:.*]] = arith.remui %[[LOAD_MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_X:.*]] = arith.remui %[[LOAD_MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[LOAD_OFF_Y]], %[[LOAD_OFF_X]]] <{layout = #xegpu.layout<inst_data = [16, 16]>}>: !xegpu.mem_desc<128x256xf32>, index, index -> vector<16x32xf32>
+ %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>>
+ %1 = xegpu.load_nd %0 {layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>} : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>> -> vector<128x256xf32>
+ %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [32, 16], inst_data = [16, 16]>,
+ target_layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32], inst_data = [16, 16]>}> : vector<128x256xf32>
+ gpu.return
+ }
+
+ gpu.func @convert_layout_3D(%arg0: memref<?xf32>) {
+ // CHECK-DAG: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>} dense<0> : vector<1x32x16xindex>
+ // CHECK-DAG: %[[CST_0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>} dense<true> : vector<1x32x16xi1>
+ // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST]]], %[[CST_0]] <{chunk_size = 1 : i64, layout = #xegpu.layout<inst_data = [1, 16, 16]>}> : memref<?xf32>, vector<1x32x16xindex>, vector<1x32x16xi1> -> vector<1x32x16xf32>
+ // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<1048576xi8, 3>
+ // CHECK-DAG: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<1048576xi8, 3> -> !xegpu.mem_desc<8x128x256xf32>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[STORE_X:.*]] = arith.remui %[[SGID]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_YZ_TMP:.*]] = arith.divui %[[SGID]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_Y:.*]] = arith.remui %[[STORE_YZ_TMP]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[STORE_Z_TMP:.*]] = arith.divui %[[STORE_YZ_TMP]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[STORE_Z:.*]] = arith.remui %[[STORE_Z_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[STORE_MUL_Y:.*]] = arith.muli %[[STORE_Y]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[STORE_MUL_X:.*]] = arith.muli %[[STORE_X]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_Z:.*]] = arith.remui %[[STORE_Z]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_Y:.*]] = arith.remui %[[STORE_MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[STORE_OFF_X:.*]] = arith.remui %[[STORE_MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[LOAD]], %[[MDESC]][%[[STORE_OFF_Z]], %[[STORE_OFF_Y]], %[[STORE_OFF_X]]] : vector<1x32x16xf32>, !xegpu.mem_desc<8x128x256xf32>, index, index, index
+ // CHECK-DAG: gpu.barrier
+ // CHECK-DAG: %[[LOAD_X:.*]] = arith.remui %[[SGID]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_YZ_TMP:.*]] = arith.divui %[[SGID]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Y:.*]] = arith.remui %[[LOAD_YZ_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Z_TMP:.*]] = arith.divui %[[LOAD_YZ_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_Z:.*]] = arith.remui %[[LOAD_Z_TMP]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_MUL_Y:.*]] = arith.muli %[[LOAD_Y]], %[[C16:.*]] : index
+ // CHECK-DAG: %[[LOAD_MUL_X:.*]] = arith.muli %[[LOAD_X]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_Z:.*]] = arith.remui %[[LOAD_Z]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_Y:.*]] = arith.remui %[[LOAD_MUL_Y]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[LOAD_OFF_X:.*]] = arith.remui %[[LOAD_MUL_X]], %[[C256:.*]] : index
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MDESC]][%[[LOAD_OFF_Z]], %[[LOAD_OFF_Y]], %[[LOAD_OFF_X]]] <{layout = #xegpu.layout<inst_data = [1, 16, 16]>}>: !xegpu.mem_desc<8x128x256xf32>, index, index, index -> vector<1x16x32xf32>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} dense<0> : vector<8x128x256xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} dense<true> : vector<8x128x256xi1>
+ %1 = xegpu.load %arg0[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>} : memref<?xf32>, vector<8x128x256xindex>, vector<8x128x256xi1> -> vector<8x128x256xf32>
+ %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [8, 4, 16], sg_data = [1, 32, 16], inst_data = [1, 16, 16]>,
+ target_layout = #xegpu.layout<sg_layout = [8, 8, 8], sg_data = [1, 16, 32], inst_data = [1, 16, 16]>}> : vector<8x128x256xf32>
+ gpu.return
+ }
+
// CHECK-LABEL: distribute_nested_slice
// CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32>
// CHECK: %[[V1:.*]] = vector.broadcast %[[V0]] : vector<32x1x32x1xf32> to vector<32x16x32x16xf32>
|
| } | ||
|
|
||
| // Fast path: if sg_layout and sg_data are identical, no SLM needed | ||
| if (inputSgLayout == targetSgLayout && inputSgData == targetSgData) { |
There was a problem hiding this comment.
Use `isEqualTo()` here,
or consider adding an `isCompatible()` method to XeGPUAttrs.td and using it here.
| targetLayout.getEffectiveSgLayoutAsInt(); | ||
| SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt(); | ||
|
|
||
| auto hasUnitLeadingDims = [](ArrayRef<int64_t> shape) { |
There was a problem hiding this comment.
I don't think this check is necessary. We should enhance the inst_data layout and subgroup distribution to get it supported.
|
|
||
| // Allocate SLM | ||
| auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3); | ||
| auto slm = memref::AllocaOp::create(rewriter, loc, slmTy); |
There was a problem hiding this comment.
Should we consider moving this allocation to the beginning of the function?
Allocating inside control-flow structures is less than ideal.
There was a problem hiding this comment.
I'm not sure I understand the question. Where is the control flow here?
There was a problem hiding this comment.
I mean that allocating memory inside loops would not be great.
| } | ||
| rewriter.replaceOpWithMultiple(op, {newOps}); | ||
|
|
There was a problem hiding this comment.
Shouldn't we insert another barrier here?
If not, a thread may store to the same SLM region before the consumer has finished loading it back.
|
|
||
| // Fast path: if sg_layout and sg_data are identical, no SLM needed | ||
| if (llvm::equal(inputSgLayout, targetSgLayout) && | ||
| llvm::equal(inputSgData, targetSgData)) { |
There was a problem hiding this comment.
I meant isEqualTo() in XeGPUAttrs.td
There was a problem hiding this comment.
But won't that also compare inst_data and the other fields?
No description provided.