Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
Expand Down Expand Up @@ -597,6 +598,137 @@ struct SgToWiMultiDimReduction
}
};

/// Helper to compute distributed coordinates for matrix ops.
/// When not using subgroup_block_io, each workitem computes its own
/// coordinates based on the layout and lane ID.
///
/// The helper emits a `gpu.lane_id` op, asks `layout` for the coordinates
/// that correspond to that lane within `payloadShape`, and adds the result to
/// `origOffsets` via `xegpu::addWithRightAligned`.
///
/// Returns the combined coordinates as SSA values, or an empty vector when
/// `layout.computeDistributedCoords` fails. Callers treat an empty result as
/// a match failure.
static SmallVector<Value> computeDistributedCoordsForMatrixOp(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is just a copy & paste of computeDistributedCoordinatesForMatrixOp(), please move it to XeGPULayoutImpl

Copy link
Copy Markdown
Contributor Author

@nbpatel nbpatel Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, but the other one (in the old pass) will be removed eventually once we move to this pass

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Jianhui-Li are you ok with this?

ConversionPatternRewriter &rewriter, Location loc,
xegpu::DistributeLayoutAttr layout, ArrayRef<int64_t> payloadShape,
ValueRange origOffsets) {
// The lane id is created with a null upper-bound attribute.
Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
/*upperBound=*/mlir::IntegerAttr());
auto maybeCoords =
layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
// An empty vector is the failure signal of this helper.
if (failed(maybeCoords))
return {};
// Only the first (and only expected) coordinate set is consumed below.
assert(maybeCoords.value().size() == 1 &&
"Expected one set of distributed offsets");
// Combine the per-lane coordinates with the offsets of the original op.
SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
getAsOpFoldResult(origOffsets));
// NOTE(review): CastTo<Value> assumes every entry of `ofrVec` holds a Value
// rather than an Attribute -- confirm that addWithRightAligned guarantees it.
return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
}

/// This pattern distributes a subgroup-level LoadMatrix op to workitem-level.
///
/// The op is replaced by a new LoadMatrix op whose result type is the vector
/// type returned by `getDistVecTypeBasedOnLaneLayout` for the op's layout.
/// Unless the op carries the subgroup_block_io attribute, its offsets are
/// replaced by the coordinates produced by
/// `computeDistributedCoordsForMatrixOp`; with the attribute present the
/// original offsets are forwarded as SSA values. The new op is created with an
/// empty layout attribute.
struct SgToWiLoadMatrix : public OpConversionPattern<xegpu::LoadMatrixOp> {
using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto layout = op.getLayoutAttr();
// If no layout, nothing to do.
if (!layout)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we could generalize it to a legality check for AnchorLayoutInterface ops and omit this check in patterns.

return failure();

// The subgroup-level payload is the op result; it must be a vector.
VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
if (!sgPayloadTy)
return rewriter.notifyMatchFailure(
op, "the matrix op payload must be a vector type");

auto loc = op.getLoc();
auto offsets = op.getMixedOffsets();
if (offsets.empty())
return rewriter.notifyMatchFailure(op, "the load op must have offsets");

// Per-workitem payload type derived from the layout.
FailureOr<VectorType> distPayloadTyOrFailure =
getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
if (failed(distPayloadTyOrFailure))
return rewriter.notifyMatchFailure(
op, "Failed to distribute matrix op payload based on layout.");

// Turn the mixed offsets into SSA values so they can be passed as operands.
SmallVector<Value> offsetsAsValues =
vector::getAsValues(rewriter, loc, offsets);

// Default to the original offsets; they are only replaced by per-lane
// coordinates when the subgroup_block_io attribute is absent.
SmallVector<Value> newCoords = offsetsAsValues;
if (!op.getSubgroupBlockIoAttr()) {
newCoords = computeDistributedCoordsForMatrixOp(
rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
if (newCoords.empty())
return rewriter.notifyMatchFailure(
op, "Failed to compute distributed coordinates.");
}

// Every offset is supplied through `newCoords`, so each constant-offset
// slot is filled with ShapedType::kDynamic.
SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
ShapedType::kDynamic);
DenseI64ArrayAttr newConstOffsetsAttr =
rewriter.getDenseI64ArrayAttr(newConstOffsets);

// The replacement keeps the subgroup_block_io attribute and gets an empty
// layout attribute.
auto newOp = xegpu::LoadMatrixOp::create(
rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
xegpu::DistributeLayoutAttr{});
rewriter.replaceOp(op, newOp.getResult());
return success();
}
};

/// Rewrites a subgroup-level StoreMatrix op into its workitem-level form.
///
/// The stored data is cast to the vector type that
/// `getDistVecTypeBasedOnLaneLayout` yields for the op's layout. When the op
/// has no subgroup_block_io attribute the offsets are replaced by the
/// coordinates from `computeDistributedCoordsForMatrixOp`; otherwise the
/// original offsets are forwarded as SSA values. The replacement op is built
/// with an empty layout attribute and the original op is erased.
struct SgToWiStoreMatrix : public OpConversionPattern<xegpu::StoreMatrixOp> {
  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Ops without a layout are left to other patterns.
    auto laneLayout = op.getLayoutAttr();
    if (!laneLayout)
      return failure();

    // The subgroup-level payload is the stored data; it must be a vector.
    auto sgDataTy = dyn_cast<VectorType>(op.getData().getType());
    if (!sgDataTy)
      return rewriter.notifyMatchFailure(
          op, "the matrix op payload must be a vector type");

    auto loc = op.getLoc();
    auto mixedOffsets = op.getMixedOffsets();
    if (mixedOffsets.empty())
      return rewriter.notifyMatchFailure(op, "the store op must have offsets");

    // Per-workitem payload type derived from the layout.
    FailureOr<VectorType> wiDataTy =
        getDistVecTypeBasedOnLaneLayout(laneLayout, sgDataTy);
    if (failed(wiDataTy))
      return rewriter.notifyMatchFailure(
          op, "Failed to distribute matrix op payload based on layout.");

    // Materialize the mixed offsets as SSA values.
    SmallVector<Value> offsetValues =
        vector::getAsValues(rewriter, loc, mixedOffsets);

    // With subgroup_block_io the original offsets are used unchanged;
    // otherwise each workitem gets its own coordinates.
    auto blockIoAttr = op.getSubgroupBlockIoAttr();
    SmallVector<Value> coords;
    if (blockIoAttr) {
      coords = offsetValues;
    } else {
      coords = computeDistributedCoordsForMatrixOp(
          rewriter, loc, laneLayout, sgDataTy.getShape(), offsetValues);
      if (coords.empty())
        return rewriter.notifyMatchFailure(
            op, "Failed to compute distributed coordinates.");
    }

    // All offsets travel through `coords`, so every constant-offset slot is
    // filled with ShapedType::kDynamic.
    SmallVector<int64_t> dynamicSlots(op.getConstOffsets().size(),
                                      ShapedType::kDynamic);
    DenseI64ArrayAttr constOffsetsAttr =
        rewriter.getDenseI64ArrayAttr(dynamicSlots);

    // Bring the converted data operand to the distributed vector type.
    auto wiData = castValueTo(
        rewriter, cast<TypedValue<VectorType>>(adaptor.getData()), *wiDataTy);

    xegpu::StoreMatrixOp::create(rewriter, loc, TypeRange{}, wiData,
                                 adaptor.getMemDesc(), ValueRange(coords),
                                 constOffsetsAttr, blockIoAttr,
                                 xegpu::DistributeLayoutAttr{});
    rewriter.eraseOp(op);
    return success();
  }
};

/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
/// workitem-level.
///
Expand Down Expand Up @@ -901,5 +1033,6 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
SgToWiMultiDimReduction>(typeConverter, patterns.getContext());
SgToWiMultiDimReduction, SgToWiLoadMatrix, SgToWiStoreMatrix>(
typeConverter, patterns.getContext());
}
64 changes: 64 additions & 0 deletions mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -461,3 +461,67 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
gpu.return
}
}

// -----
// load_matrix and store_matrix with coordinate computation (offsets [0,0]).
// Input: vector<2x8xf32> payload with lane_layout = [2, 8], lane_data = [1, 1],
// no subgroup_block_io. The expected output has a vector<1x1xf32> payload and
// load offsets derived from gpu.lane_id through remui/divui ops; a second
// gpu.lane_id is expected before the store.
gpu.module @xevm_module {
// CHECK-LABEL: gpu.func @load_store_matrix_1
// CHECK-DAG: %[[LANE_ID1:.*]] = gpu.lane_id
// CHECK-DAG: %[[R1:.*]] = arith.remui %[[LANE_ID1]], %{{.*}} : index
// CHECK-DAG: %[[D1:.*]] = arith.divui %[[LANE_ID1]], %{{.*}} : index
// CHECK-DAG: %[[R2:.*]] = arith.remui %[[D1]], %{{.*}} : index
// CHECK-DAG: %[[ROW:.*]] = arith.remui %[[R2]], %{{.*}} : index
// CHECK-DAG: %[[COL:.*]] = arith.remui %[[R1]], %{{.*}} : index
// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[ROW]], %[[COL]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
// CHECK: %[[LANE_ID2:.*]] = gpu.lane_id
// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
%c0 = arith.constant 0 : index
%1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
gpu.return
}
}

// -----
// load_matrix and store_matrix with non-zero offsets [0,1].
// Input: vector<8x4xf32> payload with lane_layout = [4, 4], lane_data = [2, 1],
// no subgroup_block_io. The expected output has a vector<2x1xf32> payload; the
// row offset chain contains an arith.muli and the column offset ends in an
// arith.addi.
gpu.module @xevm_module {
// CHECK-LABEL: gpu.func @load_store_matrix_2
// CHECK-DAG: %[[LANE_ID1:.*]] = gpu.lane_id
// CHECK-DAG: %[[R1:.*]] = arith.remui %[[LANE_ID1]], %{{.*}} : index
// CHECK-DAG: %[[D1:.*]] = arith.divui %[[LANE_ID1]], %{{.*}} : index
// CHECK-DAG: %[[R2:.*]] = arith.remui %[[D1]], %{{.*}} : index
// CHECK-DAG: %[[MUL:.*]] = arith.muli %[[R2]], %{{.*}} : index
// CHECK-DAG: %[[ROW:.*]] = arith.remui %[[MUL]], %{{.*}} : index
// CHECK-DAG: %[[R3:.*]] = arith.remui %[[R1]], %{{.*}} : index
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[R3]], %{{.*}} : index
// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[ROW]], %[[ADD]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
// CHECK: %[[LANE_ID2:.*]] = gpu.lane_id
// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.load_matrix %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32>
xegpu.store_matrix %1, %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index
gpu.return
}
}

// -----
// load_matrix and store_matrix with subgroup_block_io (no coordinate computation).
// Input: vector<16x2xf32> payload with lane_layout = [16, 1],
// lane_data = [1, 1] and the subgroup_block_io attribute set. The expected
// output keeps subgroup_block_io, has a vector<1x2xf32> payload, and does not
// require any gpu.lane_id based offset computation.
gpu.module @xevm_module {
// CHECK-LABEL: gpu.func @load_store_matrix_3
// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
// CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index -> vector<1x2xf32>
// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
// CHECK-SAME: vector<1x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index
gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.load_matrix %arg0[%c0, %c1] <{subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}> :
!xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index -> vector<16x2xf32>
xegpu.store_matrix %1, %arg0[%c0, %c1] <{subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>}> :
vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index
gpu.return
}
}