192 changes: 191 additions & 1 deletion mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -537,6 +537,195 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
}
};

/// This pattern handles the unrolling of LoadGatherOp with explicit offsets
/// (scattered load). It unrolls the offsets and mask operands to the target
/// shape and creates multiple LoadGatherOps with the unrolled operands.
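/// For example (illustrative; shapes follow the tests below): with a target
/// shape of [16], a load of vector<32xf32> using vector<32xindex> offsets and
/// a vector<32xi1> mask becomes two loads of vector<16xf32>, each taking a
/// vector<16xindex>/vector<16xi1> slice of the original operands.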
struct UnrollLoadGatherOpWithOffset
: public UnrollPattern<xegpu::LoadGatherOp> {
using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
Value offsets = op.getOffsets();
Value mask = op.getMask();

// Only handle the case where offsets are present (scattered load)
if (!offsets)
return failure();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();

SmallVector<int64_t> targetMaskShape(*targetShape);
int64_t chunkSize = 1;
if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
chunkSize = intAttr.getInt();
}

// Unroll mask and offsets with correct shape
VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
Type elemTy = valueTy.getElementType();
VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;
SmallVector<Type> convertedOffsetTypes;
SmallVector<Value> convertedOffsets;

if (chunkSize > 1) {
// For chunked loads, mask and offsets have one less dimension
targetMaskShape.pop_back();
int64_t blockedChunkSize = targetShape->back();
int64_t numNewChunks = chunkSize / blockedChunkSize;
chunkSize = blockedChunkSize;
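// For example (illustrative, matching the chunked tests below): chunk_size =
// 4 blocked to a target chunk of 2 gives numNewChunks = 2, so every
// offset/mask slice is expanded into two sub-chunk accesses.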

convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

SmallVector<Value> convertedMasksBase =
pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
SmallVector<Value> convertedOffsetsBase =
pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

for (auto maskVal : convertedMasksBase)
convertedMasks.append(numNewChunks, maskVal);

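// Replicate each base offset slice numNewChunks times, advancing the copy
// for sub-chunk i by i * blockedChunkSize elements.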
for (auto [baseOffset, offsetType] :
llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
for (int64_t i = 0; i < numNewChunks; ++i) {
Value inc = arith::ConstantIndexOp::create(rewriter, loc,
i * blockedChunkSize);
Value incVec =
vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
Value offsetVal =
arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
convertedOffsets.push_back(offsetVal);
}
}
} else {
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
convertedMasks =
pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
convertedOffsets =
pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
}

SmallVector<Value> newOps;
for (auto [o, m] : llvm::zip(convertedOffsets, convertedMasks)) {
auto newOp = xegpu::LoadGatherOp::create(
rewriter, loc, newValueTy, op.getSource(), o, m,
rewriter.getI64IntegerAttr(chunkSize), op.getL1HintAttr(),
op.getL2HintAttr(), op.getL3HintAttr());
newOps.push_back(newOp);
}

Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
rewriter.replaceOp(op, castOp);
return success();
}
};

/// This pattern handles the unrolling of StoreScatterOp with explicit offsets
/// (scattered store). It unrolls the offsets and mask operands to the target
/// shape and creates multiple StoreScatterOps with the unrolled operands.
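/// The unrolling mirrors UnrollLoadGatherOpWithOffset: e.g. (illustrative,
/// following the tests below), a vector<32xf32> store with target shape [16]
/// becomes two vector<16xf32> stores over sliced offsets and mask.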
struct UnrollStoreScatterOpWithOffsets
: public UnrollPattern<xegpu::StoreScatterOp> {
using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
Value offsets = op.getOffsets();
Value mask = op.getMask();

// Only handle the case where offsets are present (scattered store)
if (!offsets)
return failure();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape)
return failure();

int64_t chunkSize = 1;
if (auto chunkSizeAttr = op->getAttr("chunk_size")) {
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(chunkSizeAttr))
chunkSize = intAttr.getInt();
}

SmallVector<int64_t> targetMaskShape(*targetShape);
VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());

SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;
SmallVector<Type> convertedOffsetTypes;
SmallVector<Value> convertedOffsets;

if (chunkSize > 1) {
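// As in the gather case, chunked stores drop the chunk dimension from the
// mask/offset shapes and split each chunk into numNewChunks sub-chunks.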
targetMaskShape.pop_back();
int64_t blockedChunkSize = targetShape->back();
int64_t numNewChunks = chunkSize / blockedChunkSize;
chunkSize = blockedChunkSize;

convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
convertedOffsetTypes = getUnrolledTypes(offsetsTy, targetMaskShape);

SmallVector<Value> convertedMasksBase =
pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);
SmallVector<Value> convertedOffsetsBase =
pack(offsets, convertedOffsetTypes, targetMaskShape, loc, rewriter);

for (auto maskVal : convertedMasksBase)
convertedMasks.append(numNewChunks, maskVal);

for (auto [baseOffset, offsetType] :
llvm::zip(convertedOffsetsBase, convertedOffsetTypes)) {
for (int64_t i = 0; i < numNewChunks; ++i) {
Value inc = arith::ConstantIndexOp::create(rewriter, loc,
i * blockedChunkSize);
Value incVec =
vector::BroadcastOp::create(rewriter, loc, offsetType, inc);
Value offsetVal =
arith::AddIOp::create(rewriter, loc, baseOffset, incVec);
convertedOffsets.push_back(offsetVal);
}
}
} else {
convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
convertedMasks =
pack(mask, convertedMaskTypes, targetMaskShape, loc, rewriter);

convertedOffsetTypes = getUnrolledTypes(offsetsTy, *targetShape);
convertedOffsets =
pack(offsets, convertedOffsetTypes, *targetShape, loc, rewriter);
}

SmallVector<Type> convertedValTypes =
getUnrolledTypes(valueTy, *targetShape);
SmallVector<Value> convertedValues =
pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);

for (auto [v, o, m] :
llvm::zip(convertedValues, convertedOffsets, convertedMasks)) {
xegpu::StoreScatterOp::create(rewriter, loc, v, op.getDest(), o, m,
rewriter.getI64IntegerAttr(chunkSize),
op.getL1HintAttr(), op.getL2HintAttr(),
op.getL3HintAttr());
}

rewriter.eraseOp(op);
return success();
}
};

struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
@@ -766,6 +955,7 @@ void mlir::xegpu::populateXeGPUUnrollPatterns(
.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, UnrollCreateDescOp,
UnrollLoadGatherOp, UnrollStoreScatterOp, UnrollPrefetchOp,
- UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp>(
+ UnrollUpdateOffsetOp, UnrollLoadMatrixOp, UnrollStoreMatrixOp,
+ UnrollLoadGatherOpWithOffset, UnrollStoreScatterOpWithOffsets>(
patterns.getContext(), options);
}
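For reference, registering these patterns typically follows the shape of the in-tree test pass. The sketch below is a minimal, hypothetical driver, not part of this change: the helper name runXeGPUUnrollSketch and the hard-coded [16] native shape are assumptions for illustration.

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver: unroll gather/scatter ops to a fixed [16] native shape.
static LogicalResult runXeGPUUnrollSketch(Operation *root) {
  xegpu::UnrollOptions options;
  options.setNativeShapeFn(
      [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        // Illustrative policy only; real passes derive this from inst_data.
        if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
          return SmallVector<int64_t>{16};
        return std::nullopt;
      });

  RewritePatternSet patterns(root->getContext());
  xegpu::populateXeGPUUnrollPatterns(patterns, options);
  // Named applyPatternsAndFoldGreedily in older MLIR revisions.
  return applyPatternsGreedily(root, std::move(patterns));
}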
91 changes: 91 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -210,6 +210,27 @@ gpu.module @test {
gpu.return %ld : vector<32xf32>
}

//-----


// CHECK-LABEL: load_with_offsets
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
gpu.func @load_with_offsets(%src: ui64) -> vector<32xf32> {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>

%c17 = arith.constant 17: index
%mask = vector.create_mask %c17: vector<32xi1>
%ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>

gpu.return %ld : vector<32xf32>
}

//-----

// CHECK-LABEL: prefetch
Expand Down Expand Up @@ -254,6 +275,28 @@ gpu.module @test {

gpu.return
}

//-----

// CHECK-LABEL: store_with_offsets
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
gpu.func @store_with_offsets(%src: ui64) {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>

%c17 = arith.constant 17: index
%mask = vector.create_mask %c17: vector<32xi1>

%st_vec = arith.constant dense<1023.0>: vector<32xf32>
xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32xf32>, ui64, vector<32xindex>, vector<32xi1>

gpu.return
}

//-----
// CHECK-LABEL: create_tdesc_step_chunk
Expand Down Expand Up @@ -319,6 +362,29 @@ gpu.module @test {
gpu.return %ld : vector<32x4xf32>
}

//-----
// CHECK-LABEL: load_with_offsets_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32x4xf32>
// CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
// CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
// CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
// CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
// CHECK-COUNT-4: xegpu.load {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>
gpu.func @load_with_offsets_chunk(%src: ui64) -> vector<32x4xf32> {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>

%c17 = arith.constant 17: index
%mask = vector.create_mask %c17: vector<32xi1>
%ld = xegpu.load %src[%cst], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32x4xf32>
gpu.return %ld : vector<32x4xf32>
}

//-----
// CHECK-LABEL: store_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
@@ -342,6 +408,31 @@ gpu.module @test {
gpu.return
}

//-----
// CHECK-LABEL: store_with_offsets_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK: [[cst:%.+]] = arith.constant dense<1.023000e+03> : vector<16x2xf32>
// CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
// CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
// CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
// CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
// CHECK-COUNT-4: xegpu.store {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16x2xf32>, ui64, vector<16xindex>, vector<16xi1>
gpu.func @store_with_offsets_chunk(%src: ui64) {
%cst = arith.constant dense<[
0, 8, 16, 24, 32, 40, 48, 56,
64, 72, 80, 88, 96, 104, 112, 120,
128, 136, 144, 152, 160, 168, 176, 184,
192, 200, 208, 216, 224, 232, 240, 248
]> : vector<32xindex>

%c17 = arith.constant 17: index
%mask = vector.create_mask %c17: vector<32xi1>

%st_vec = arith.constant dense<1023.>: vector<32x4xf32>
xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 4, layout = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : vector<32x4xf32>, ui64, vector<32xindex>, vector<32xi1>
gpu.return
}

//-----
// CHECK-LABEL: prefetch_chunk
// CHECK-SAME: [[arg0:%.+]]: ui64