218 changes: 153 additions & 65 deletions mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,6 +97,64 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return success();
}

static void computeMixedShapesStrides(PatternRewriter &rewriter, Location loc,
SmallVector<OpFoldResult> &mixedShapes,
SmallVector<OpFoldResult> &mixedStrides,
SmallVector<int64_t> &strides,
TypedValue<MemRefType> src) {
auto srcTy = src.getType();
// If any dimension is dynamic, the source's shape and strides must be
// provided explicitly.
SmallVector<Value> sourceDims;
unsigned srcRank = srcTy.getRank();
for (unsigned i = 0; i < srcRank; ++i)
sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));

for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
if (shape == ShapedType::kDynamic)
mixedShapes.push_back(sourceDims[idx]);
else
mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
}

// Compute strides in reverse order.
Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
// Last stride is guaranteed to be static and unit.
mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
accStride =
arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
if (strides[i] == ShapedType::kDynamic)
mixedStrides.push_back(accStride);
else
mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
}
std::reverse(mixedStrides.begin(), mixedStrides.end());
}
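
To make the stride recurrence concrete, here is a minimal sketch (hypothetical names, assuming a row-major memref<?x?xf32> source) of the IR this helper emits and the mixed values it collects:

// Hypothetical source: %src : memref<?x?xf32>, strides [?, 1].
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%d0 = memref.dim %src, %c0 : memref<?x?xf32>
%d1 = memref.dim %src, %c1 : memref<?x?xf32>
// mixedShapes  = [%d0, %d1]
// mixedStrides = [%d1, 1] -- accStride starts at 1 and picks up %d1
// for the leading dimension; the unit trailing stride stays static.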

static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
Location loc,
xegpu::TensorDescType descType,
TypedValue<MemRefType> src) {
MemRefType srcTy = src.getType();
auto [strides, offset] = srcTy.getStridesAndOffset();

xegpu::CreateNdDescOp ndDesc;
if (srcTy.hasStaticShape())
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
else {
SmallVector<OpFoldResult> mixedShapes;
SmallVector<OpFoldResult> mixedStrides;
computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides, strides,
src);

ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
mixedShapes, mixedStrides);
}

return ndDesc;
}

static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
xegpu::TensorDescType descType, TypedValue<MemRefType> src,
@@ -109,45 +167,22 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
getAsOpFoldResult(offsets));
} else {
// In case of any dynamic shapes, source's shape and strides have to be
// explicitly provided.
SmallVector<Value> sourceDims;
unsigned srcRank = srcTy.getRank();
for (unsigned i = 0; i < srcRank; ++i)
sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));

SmallVector<int64_t> constOffsets;
SmallVector<Value> dynOffsets;
SmallVector<OpFoldResult> mixedOffsets;
for (Value offset : offsets) {
std::optional<int64_t> staticVal = getConstantIntValue(offset);
if (!staticVal)
dynOffsets.push_back(offset);
constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
}

SmallVector<Value> dynShapes;
for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
if (shape == ShapedType::kDynamic)
dynShapes.push_back(sourceDims[idx]);
if (staticVal)
mixedOffsets.push_back(rewriter.getI64IntegerAttr(staticVal.value()));
else
mixedOffsets.push_back(offset);
}

// Compute strides in reverse order.
SmallVector<Value> dynStrides;
Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
// Last stride is guaranteed to be static and unit.
for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
accStride =
arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
if (strides[i] == ShapedType::kDynamic)
dynStrides.push_back(accStride);
}
std::reverse(dynStrides.begin(), dynStrides.end());
SmallVector<OpFoldResult> mixedShapes;
SmallVector<OpFoldResult> mixedStrides;
computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides, strides,
src);

ndDesc = xegpu::CreateNdDescOp::create(
rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
DenseI64ArrayAttr::get(rewriter.getContext(), strides));
rewriter, loc, descType, src, mixedOffsets, mixedShapes, mixedStrides);
}

return ndDesc;
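
As a sketch of the resulting IR (illustrative syntax; the exact printed form of xegpu.create_nd_tdesc may differ), the dynamic branch now passes offsets, shape, and strides as mixed static/dynamic lists instead of the old split dynamic-values/constant-attributes form:

// Hypothetical %src : memref<?x?xf32> with dynamic dims %d0, %d1.
%desc = xegpu.create_nd_tdesc %src[%off, 0], shape : [%d0, %d1], strides : [%d1, 1]
    : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>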
@@ -523,21 +558,35 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
descShape, elementType, /*array_length=*/1,
/*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

xegpu::CreateNdDescOp ndDesc =
createNdDescriptor(rewriter, loc, descType,
dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
readOp.getIndices());

DenseI64ArrayAttr transposeAttr =
!isTransposeLoad ? nullptr
: DenseI64ArrayAttr::get(rewriter.getContext(),
ArrayRef<int64_t>{1, 0});
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
/*packed=*/nullptr, transposeAttr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
xegpu::LoadNdOp loadOp;

if (vecTy.getRank() == readOp.getBase().getType().getRank()) {
xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
rewriter, loc, descType,
dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));

loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
getAsOpFoldResult(readOp.getIndices()),
/*packed=*/nullptr, transposeAttr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
} else {
xegpu::CreateNdDescOp ndDesc =
createNdDescriptor(rewriter, loc, descType,
dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
readOp.getIndices());

loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
/*packed=*/nullptr, transposeAttr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
}
rewriter.replaceOp(readOp, loadOp);

return success();
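
The rank-matching branch changes where the offsets live, mirroring the test updates later in this diff; a condensed before/after sketch for a static 2-D source:

// Before: offsets folded into descriptor creation.
%desc0 = xegpu.create_nd_tdesc %src[%i, %j] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
%v0 = xegpu.load_nd %desc0 : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
// After: offsets supplied to the load itself.
%desc1 = xegpu.create_nd_tdesc %src : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
%v1 = xegpu.load_nd %desc1[%i, %j] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>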
@@ -579,17 +628,30 @@ struct TransferWriteLowering
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
xegpu::MemorySpace::Global);
xegpu::CreateNdDescOp ndDesc =
createNdDescriptor(rewriter, loc, descType,
dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
writeOp.getIndices());

// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto storeOp =
xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
xegpu::StoreNdOp storeOp;
if (vecTy.getRank() == writeOp.getBase().getType().getRank()) {
xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
rewriter, loc, descType,
dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));

storeOp =
xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
getAsOpFoldResult(writeOp.getIndices()),
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
} else {
xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
rewriter, loc, descType,
dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
writeOp.getIndices());

storeOp =
xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
}
rewriter.replaceOp(writeOp, storeOp);

return success();
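
The store path follows the same shape; a sketch of the rank-matching case (hypothetical values, type list abbreviated to what the op prints):

%desc = xegpu.create_nd_tdesc %dst : memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
xegpu.store_nd %vec, %desc[%i, %j] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>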
@@ -674,19 +736,32 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {

// Boundary check is available only for block instructions.
bool boundaryCheck = vecTy.getRank() > 1;
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;

auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
boundaryCheck, xegpu::MemorySpace::Global);
xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto loadNdOp = xegpu::LoadNdOp::create(
rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
xegpu::LoadNdOp loadNdOp;

if (vecTy.getRank() == loadOp.getBase().getType().getRank()) {
xegpu::CreateNdDescOp ndDesc =
createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
loadNdOp = xegpu::LoadNdOp::create(
rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
/*packed=*/nullptr, /*transpose=*/nullptr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
} else {
xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
loadNdOp =
xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
/*packed=*/nullptr, /*transpose=*/nullptr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
}
rewriter.replaceOp(loadOp, loadNdOp);

return success();
@@ -711,15 +786,28 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto storeNdOp =
xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
xegpu::StoreNdOp storeNdOp;
if (vecTy.getRank() == storeOp.getBase().getType().getRank()) {
xegpu::CreateNdDescOp ndDesc =
createNdDescriptor(rewriter, loc, descType, storeOp.getBase());

storeNdOp =
xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
getAsOpFoldResult(storeOp.getIndices()),
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
} else {
xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

storeNdOp = xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
}

rewriter.replaceOp(storeOp, storeNdOp);

return success();
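
For both vector.load and vector.store, the else branch preserves the pre-existing lowering: when the vector rank is lower than the source rank, offsets stay on the descriptor. A sketch with hypothetical shapes:

// 1-D access into a 2-D source keeps the offsets at descriptor creation.
%desc = xegpu.create_nd_tdesc %src[%i, %j] : memref<8x16xf32> -> !xegpu.tensor_desc<16xf32>
%v = xegpu.load_nd %desc : !xegpu.tensor_desc<16xf32> -> vector<16xf32>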
12 changes: 8 additions & 4 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -215,8 +215,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

// if shape and strides are from Memref, we don't need attributes for them
// to keep the IR print clean.
if (staticShape == memrefShape && staticStrides == memrefStrides) {
// to keep the IR print clean (only do so in the fully static case; otherwise
// the printer would fail trying to print an empty array attribute).
if (staticShape == memrefShape && staticStrides == memrefStrides &&
dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -277,8 +279,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

// if shape and strides are from Memref, we don't need attributes for them
// to keep the IR print clean.
if (staticShape == memrefShape && staticStrides == memrefStrides) {
// to keep the IR print clean (only do so in the fully static case; otherwise
// the printer would fail trying to print an empty array attribute).
if (staticShape == memrefShape && staticStrides == memrefStrides &&
dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
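
Only the fully static case drops the attributes, so a descriptor built from a static memref still prints in its short form (as exercised by the tests below), while mixed builds keep explicit shape and strides. Illustrative:

// Fully static source: shape and strides are implied by the memref type.
%desc = xegpu.create_nd_tdesc %src : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>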
4 changes: 2 additions & 2 deletions mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -72,9 +72,9 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<7x15xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

// -----
4 changes: 2 additions & 2 deletions mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -74,9 +74,9 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<7x64xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>

// -----

@@ -83,9 +83,9 @@ gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
// LOAD-ND-LABEL: @load_zero_pad_out_of_bounds(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<32x64xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// LOAD-ND-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]

// LOAD-GATHER-LABEL: @load_zero_pad_out_of_bounds(
Expand All @@ -109,9 +109,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
// LOAD-ND-SAME: %[[SRC:.+]]: memref<32x64xf32>,
// LOAD-ND-SAME: %[[OFFSET1:.+]]: index,
// LOAD-ND-SAME: %[[OFFSET2:.+]]: index
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// LOAD-ND-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET1]], %[[OFFSET2]]] <{transpose = array<i64: 1, 0>}>
// LOAD-ND-SAME: -> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]

@@ -126,9 +126,9 @@ gpu.func @store_out_of_bounds(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<7x64xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// STORE-ND-SAME: %[[SRC]]
// STORE-ND-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>

// STORE-SCATTER-LABEL: @store_out_of_bounds(
// STORE-SCATTER: vector.transfer_write