Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add workgroup chipletgroup strategy to workgroup reordering pass #17811

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
// GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
// GFX942-SAME: chip = <wgp_count = 304>>
// GFX942-SAME: chip = <wgp_count = 304, chiplet_count = 8>>

// GFX940: target = #iree_gpu.target<arch = "gfx940",
// GFX940-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>]
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/GPU/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ createConvertVectorReductionToGPUPass(
bool expandSubgroupReduction = true,
std::function<int(mlir::FunctionOpInterface)> getWarpSize = nullptr);

enum class ReorderWorkgroupsStrategy { None, Swizzle, Transpose };
enum class ReorderWorkgroupsStrategy { None, ChipletGroup, Swizzle, Transpose };

/// Reorders workgroup IDs.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
Expand Down
11 changes: 8 additions & 3 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,15 @@ def ReorderWorkgroupsPass :
let dependentDialects = ["::mlir::affine::AffineDialect"];
let options = [
Option<"strategy", "strategy", "std::string", /*default=*/"",
"Workgroup reordering strategy, one of: '' (none), 'transpose', 'swizzle'">,
Option<"logTile", "logTile", "unsigned",
"Workgroup reordering strategy, one of: '' (none), 'transpose', 'swizzle', 'chipletgroup'">,
Option<"logSwTile", "logSwTile", "unsigned",
/*default=*/"0",
"The log2 of the tile size used for swizzling. (0: disabled, non-0: swizzling enabled)">,
"The log2 of the tile size used for swizzling. "
"(0: swizzling disabled, non-0: swizzling enabled)">,
Option<"logCgTile", "logCgTile", "unsigned",
/*default=*/"0",
"The log2 of the tile size used for chipletgroup. "
"(0: chipletgroup disabled, non-0: chipletgroup enabled)">,
kuhar marked this conversation as resolved.
Show resolved Hide resolved
];
}

Expand Down
169 changes: 158 additions & 11 deletions compiler/src/iree/compiler/Codegen/Common/GPU/WorkgroupReordering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,117 @@ makeSwizzledIds(Location loc, OpBuilder b, Value workgroupIdX,
return {swizzledIdX, swizzledIdY};
}

// Reoredering to make workgroup ids move slowly between chiplet groups.
qedawkins marked this conversation as resolved.
Show resolved Hide resolved

// The following example illustrates the concept behind this function:
bangtianliu marked this conversation as resolved.
Show resolved Hide resolved
// Currently, the GPU launches workgroups in a round-robin fashion across
// each XCD partition on the GPU.
// Assume we have 16 workgroups and XCDPartitionsOnGPU is 4.
// The default GPU schedule will launch workgroups {0, 1, 2, 3, ..., 15} in
// the following round-robin fashion:
bangtianliu marked this conversation as resolved.
Show resolved Hide resolved
// Partition 0: {0, 4, 8, 12}
// Partition 1: {1, 5, 9, 13}
// Partition 2: {2, 6, 10, 14}
// Partition 3: {3, 7, 11, 15}

// After reordering, the workgroup IDs are {0, 4, 8, 12, 1, ..., 15},
// resulting in the round-robin launching fashion:
bangtianliu marked this conversation as resolved.
Show resolved Hide resolved
// Partition 0: {0, 1, 2, 3}
// Partition 1: {4, 5, 6, 7}
// Partition 2: {8, 9, 10, 11}
// Partition 3: {12, 13, 14, 15}

// The return value is each workgroup's permuted Id
// In the above example:
bangtianliu marked this conversation as resolved.
Show resolved Hide resolved
// linearedId 0's permuted Id is still 0
// linearedId 1's permiuted Id is 4
static Value chipletAwareWorkgroupReordering(Location loc, OpBuilder b,
Value linearizedId,
Value workgroupCountX,
Value workgroupCountY,
int64_t XCDParitionsOnGPU) {
Value numChipletsVal =
b.createOrFold<arith::ConstantIndexOp>(loc, XCDParitionsOnGPU);
Value workgroupCount =
b.create<arith::MulIOp>(loc, workgroupCountX, workgroupCountY);
Value workgroupCountPerChiplet =
b.create<arith::DivUIOp>(loc, workgroupCount, numChipletsVal);
Value chipletId = b.create<arith::RemUIOp>(loc, linearizedId, numChipletsVal);
Value wgIdWithinChiplet =
b.create<arith::DivUIOp>(loc, linearizedId, numChipletsVal);
Value reorderedId = b.create<arith::AddIOp>(
loc, wgIdWithinChiplet,
b.create<arith::MulIOp>(loc, chipletId, workgroupCountPerChiplet));

// Handle the remainder part.
Value constOne = b.createOrFold<arith::ConstantIndexOp>(loc, 1);
Value lastWorkgroupId =
b.create<arith::SubIOp>(loc, workgroupCount, constOne);
Value modulatedLastWorkgroupId = b.create<arith::SubIOp>(
loc, lastWorkgroupId,
b.create<arith::RemUIOp>(loc, workgroupCount, numChipletsVal));
Value isGreaterThanFinalWorkgroupId = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ugt, linearizedId, modulatedLastWorkgroupId);
Value finalId = b.create<arith::SelectOp>(loc, isGreaterThanFinalWorkgroupId,
linearizedId, reorderedId);
qedawkins marked this conversation as resolved.
Show resolved Hide resolved

return finalId;
}

// Chiplet-aware workgroup reordering strategy: reordering + super-grouping.
// Step 1: Reorder the workgroup grid to move slowly between
// chiplet groups (Function: chipletAwareWorkgroupReordering).
// Step 2: Implement 'super-grouping' of workgroups before switching to the next
// column.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Say what the return value is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not addressed.

static std::pair<Value, Value>
makeChipletGroupedIds(Location loc, OpBuilder b, Value workgroupIdX,
Value workgroupIdY, Value workgroupCountX,
Value workgroupCountY, unsigned chipletGroupTile,
unsigned numXCDs) {
// Create one dimension ID for workgroup
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Create one dimension ID for workgroup
// Create one dimension ID for workgroup.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not addressed

Value linearized =
b.create<arith::MulIOp>(loc, workgroupIdY, workgroupCountX);
linearized = b.create<arith::AddIOp>(loc, linearized, workgroupIdX);

assert(numXCDs > 1);
qedawkins marked this conversation as resolved.
Show resolved Hide resolved
// Map chiplets to perform a spatially local tile operation.
// Reorder the linearized ID such that every consecutive group of chiplets
// is the slowest-changing dimension in the grid.
// Emphircally found that two chiplets as a group has better locality
// throughout.
linearized = chipletAwareWorkgroupReordering(
loc, b, linearized, workgroupCountX, workgroupCountY, numXCDs / 2);

// Detailed explaination about the idea behind the below implementation:
// the L2 Cache Optimizations subsection in
// https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#
bangtianliu marked this conversation as resolved.
Show resolved Hide resolved
// Emphircally, found rowGroupSize=16 for mi300x achieves good performance
unsigned rowGroupSize = chipletGroupTile;
Value rowGroupSizeVal =
b.createOrFold<arith::ConstantIndexOp>(loc, rowGroupSize);
// group every 16 workgroups along Y dimension
// Number of workgroups in the group
bangtianliu marked this conversation as resolved.
Show resolved Hide resolved
Value numWorkGroupsPerRowBlock =
b.create<arith::MulIOp>(loc, rowGroupSizeVal, workgroupCountX);

Value groupId =
b.create<arith::DivUIOp>(loc, linearized, numWorkGroupsPerRowBlock);
Value firstRowID = b.create<arith::MulIOp>(loc, groupId, rowGroupSizeVal);

Value currentRowGroupSize = b.create<arith::MinUIOp>(
loc, b.create<arith::SubIOp>(loc, workgroupCountY, firstRowID),
rowGroupSizeVal);

Value newY = b.create<arith::AddIOp>(
loc, firstRowID,
b.create<arith::RemUIOp>(loc, linearized, currentRowGroupSize));

Value newX = b.create<arith::DivUIOp>(
loc, b.create<arith::RemUIOp>(loc, linearized, numWorkGroupsPerRowBlock),
currentRowGroupSize);
return {newX, newY};
}

/// Transpose IDs, i.e., changes the traversal order from left -> right then
/// top -> bottom to top -> bottom then left -> right.
static std::pair<Value, Value> makeTransposedIds(Location loc, OpBuilder b,
Expand Down Expand Up @@ -112,11 +223,12 @@ getWorkgroupCountsXY(OpBuilder &builder, FunctionOpInterface funcOp) {

static LogicalResult reorderWorkgroupsInFunc(FunctionOpInterface funcOp,
ReorderWorkgroupsStrategy strategy,
unsigned swizzleLogTile) {
unsigned logTile,
unsigned numXCDs = 2) {
assert(strategy != ReorderWorkgroupsStrategy::None &&
"Expected a concrete strategy");

unsigned swizzleTile = 1u << swizzleLogTile;
unsigned reorderWgTileSize = 1u << logTile;
IREE::HAL::InterfaceWorkgroupIDOp oldXId;
IREE::HAL::InterfaceWorkgroupIDOp oldYId;
unsigned numXIdOps = 0;
Expand Down Expand Up @@ -153,7 +265,11 @@ static LogicalResult reorderWorkgroupsInFunc(FunctionOpInterface funcOp,
if (strategy == ReorderWorkgroupsStrategy::Swizzle) {
std::tie(newWorkgroupIdX, newWorkgroupIdY) =
makeSwizzledIds(funcOp.getLoc(), builder, workgroupIdX, workgroupIdY,
workgroupCntX, workgroupCntY, swizzleTile);
workgroupCntX, workgroupCntY, reorderWgTileSize);
} else if (strategy == ReorderWorkgroupsStrategy::ChipletGroup) {
bangtianliu marked this conversation as resolved.
Show resolved Hide resolved
std::tie(newWorkgroupIdX, newWorkgroupIdY) = makeChipletGroupedIds(
funcOp.getLoc(), builder, workgroupIdX, workgroupIdY, workgroupCntX,
workgroupCntY, reorderWgTileSize, numXCDs);
} else {
assert(strategy == ReorderWorkgroupsStrategy::Transpose &&
"Unhandled strategy");
Expand Down Expand Up @@ -186,9 +302,9 @@ namespace {
struct ReorderWorkgroupsPass final
: impl::ReorderWorkgroupsPassBase<ReorderWorkgroupsPass> {
ReorderWorkgroupsPass(
ReorderWorkgroupsStrategy strategy, unsigned logSwizzleTile,
ReorderWorkgroupsStrategy strategy, unsigned logTile,
std::function<LogicalResult(mlir::FunctionOpInterface)> filterFn)
: reorderingStrategy(strategy), logSwizzleTile(logSwizzleTile),
: reorderingStrategy(strategy), reorderWgLogTileSize(logTile),
filterFn(std::move(filterFn)) {}

LogicalResult initializeOptions(
Expand All @@ -197,17 +313,25 @@ struct ReorderWorkgroupsPass final
if (failed(Pass::initializeOptions(options, errorHandler))) {
return failure();
}
logSwizzleTile = logTile;

auto selectedStrategy =
llvm::StringSwitch<FailureOr<ReorderWorkgroupsStrategy>>(strategy)
.Case("", ReorderWorkgroupsStrategy::None)
.Case("chipletgroup", ReorderWorkgroupsStrategy::ChipletGroup)
.Case("swizzle", ReorderWorkgroupsStrategy::Swizzle)
.Case("transpose", ReorderWorkgroupsStrategy::Transpose)
.Default(failure());
if (failed(selectedStrategy))
return failure();

reorderingStrategy = *selectedStrategy;
if (reorderingStrategy == ReorderWorkgroupsStrategy::Swizzle &&
reorderWgLogTileSize == 0)
reorderWgLogTileSize = logSwTile;
else if (reorderingStrategy == ReorderWorkgroupsStrategy::ChipletGroup &&
reorderWgLogTileSize == 0)
reorderWgLogTileSize = logCgTile;

return success();
}

Expand All @@ -216,7 +340,11 @@ struct ReorderWorkgroupsPass final
return;

if (reorderingStrategy == ReorderWorkgroupsStrategy::Swizzle &&
logSwizzleTile == 0)
reorderWgLogTileSize == 0)
return;

if (reorderingStrategy == ReorderWorkgroupsStrategy::ChipletGroup &&
reorderWgLogTileSize == 0)
return;

FunctionOpInterface funcOp = getOperation();
Expand All @@ -229,7 +357,26 @@ struct ReorderWorkgroupsPass final
llvm::dbgs() << "\n\n";
});

if (failed(reorderWorkgroupsInFunc(funcOp, reorderingStrategy, logTile))) {
uint32_t numXCDs = 1;
if (IREE::HAL::ExecutableTargetAttr targetAttr =
IREE::HAL::ExecutableTargetAttr::lookup(funcOp)) {
if (DictionaryAttr config = targetAttr.getConfiguration()) {
if (IREE::GPU::TargetAttr attr =
config.getAs<IREE::GPU::TargetAttr>("iree.gpu.target")) {
bangtianliu marked this conversation as resolved.
Show resolved Hide resolved
IREE::GPU::TargetChipAttr chipAttr = attr.getChip();
if (chipAttr)
numXCDs = chipAttr.getChipletCount();
}
}
}

LLVM_DEBUG(llvm::dbgs() << "Number of XCDs = " << numXCDs << "\n");
if (numXCDs == 1 &&
reorderingStrategy == ReorderWorkgroupsStrategy::ChipletGroup)
return;

if (failed(reorderWorkgroupsInFunc(funcOp, reorderingStrategy,
reorderWgLogTileSize, numXCDs))) {
LLVM_DEBUG(llvm::dbgs() << "Failed to reorder workgroups\n");
return;
}
Expand All @@ -244,16 +391,16 @@ struct ReorderWorkgroupsPass final
private:
ReorderWorkgroupsStrategy reorderingStrategy =
ReorderWorkgroupsStrategy::None;
unsigned logSwizzleTile = 0;
unsigned reorderWgLogTileSize = 0;
std::function<LogicalResult(mlir::FunctionOpInterface)> filterFn;
};
} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createReorderWorkgroups(
ReorderWorkgroupsStrategy strategy, unsigned swizzleLogTile,
ReorderWorkgroupsStrategy strategy, unsigned reorderWgLogTile,
std::function<LogicalResult(mlir::FunctionOpInterface)> filterFn) {
return std::make_unique<ReorderWorkgroupsPass>(strategy, swizzleLogTile,
return std::make_unique<ReorderWorkgroupsPass>(strategy, reorderWgLogTile,
filterFn);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-reorder-workgroups{strategy=swizzle logTile=3}))" \
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-reorder-workgroups{strategy=swizzle logSwTile=3}))" \
// RUN: --split-input-file %s | FileCheck --check-prefix=SWIZZLE %s

// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-reorder-workgroups{strategy=transpose}))" \
// RUN: --split-input-file %s | FileCheck --check-prefix=TRANSPOSE %s

func.func @matmul() {
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-reorder-workgroups{strategy=chipletgroup logCgTile=3}))" \
// RUN: --split-input-file %s | FileCheck --check-prefix=CHIPLETGROUP %s
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
{iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8,
storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>],
subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
chip = <wgp_count = 304, chiplet_count = 8>>, ukernels = "none"}>
func.func @matmul() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
%c0 = arith.constant 0 : index
%c128 = arith.constant 128 : index
%c96 = arith.constant 96 : index
Expand Down Expand Up @@ -55,6 +62,41 @@ func.func @matmul() {
// SWIZZLE: %[[S13:.*]] = arith.select %[[S12]], %[[WG_X]], %[[S6]] : index
// SWIZZLE: %[[S14:.*]] = arith.select %[[S12]], %[[WG_Y]], %[[S7]] : index

// CHIPLETGROUP-LABEL: func.func @matmul
// CHIPLETGROUP: %[[WG_X:.*]] = hal.interface.workgroup.id[0] : index
// CHIPLETGROUP: %[[WG_Y:.*]] = hal.interface.workgroup.id[1] : index
// CHIPLETGROUP: %[[WG_CNT_X:.*]] = hal.interface.workgroup.count[0] : index
// CHIPLETGROUP: %[[WG_CNT_Y:.*]] = hal.interface.workgroup.count[1] : index
// CHIPLETGROUP: %[[S0:.*]] = arith.muli %[[WG_Y]], %[[WG_CNT_X]] : index
// CHIPLETGROUP: %[[S1:.*]] = arith.addi %[[S0]], %[[WG_X]] : index
// CHIPLETGROUP: %[[CST4:.*]] = arith.constant 4 : index
// CHIPLETGROUP: %[[WG_CNT:.*]] = arith.muli %[[WG_CNT_X]], %[[WG_CNT_Y]] : index
// CHIPLETGROUP: %[[S3:.*]] = arith.divui %[[WG_CNT]], %[[CST4]] : index
// CHIPLETGROUP: %[[S4:.*]] = arith.remui %[[S1]], %[[CST4]] : index
// CHIPLETGROUP: %[[S5:.*]] = arith.divui %[[S1]], %[[CST4]] : index
// CHIPLETGROUP: %[[S6:.*]] = arith.muli %[[S4]], %[[S3]] : index
// CHIPLETGROUP: %[[S7:.*]] = arith.addi %[[S5]], %[[S6]] : index
// CHIPLETGROUP: %[[CST1:.*]] = arith.constant 1 : index
// CHIPLETGROUP: %[[S8:.*]] = arith.subi %[[WG_CNT]], %[[CST1]] : index
// CHIPLETGROUP: %[[S9:.*]] = arith.remui %[[WG_CNT]], %[[CST4]] : index
// CHIPLETGROUP: %[[S10:.*]] = arith.subi %[[S8]], %[[S9]] : index
// CHIPLETGROUP: %[[S11:.*]] = arith.cmpi ugt, %[[S1]], %[[S10]] : index
// CHIPLETGROUP: %[[S12:.*]] = arith.select %[[S11]], %[[S1]], %[[S7]] : index
// CHIPLETGROUP: %[[CST8:.*]] = arith.constant 8 : index
// CHIPLETGROUP: %[[S13:.*]] = arith.muli %[[CST8]], %[[WG_CNT_X]] : index
// CHIPLETGROUP: %[[S14:.*]] = arith.divui %[[S12]], %[[S13]] : index
// CHIPLETGROUP: %[[S15:.*]] = arith.muli %[[S14]], %[[CST8]] : index
// CHIPLETGROUP: %[[S16:.*]] = arith.subi %[[WG_CNT_Y]], %[[S15]] : index
// CHIPLETGROUP: %[[S17:.*]] = arith.minui %[[S16]], %[[CST8]] : index
// CHIPLETGROUP: %[[S18:.*]] = arith.remui %[[S12]], %[[S17]] : index
// CHIPLETGROUP: %[[S19:.*]] = arith.addi %[[S15]], %[[S18]] : index
// CHIPLETGROUP: %[[S20:.*]] = arith.remui %[[S12]], %[[S13]] : index
// CHIPLETGROUP: %[[S21:.*]] = arith.divui %[[S20]], %[[S17]] : index
// CHIPLETGROUP: %26 = affine.apply #map()[%[[S19]]]
// CHIPLETGROUP: %27 = affine.apply #map()[%workgroup_count_y_1]
// CHIPLETGROUP: %28 = affine.apply #map()[%[[S21]]]
// CHIPLETGROUP: %29 = affine.apply #map()[%workgroup_count_x_0]

// TRANSPOSE-LABEL: func.func @matmul
// TRANSPOSE: %[[WG_X:.*]] = hal.interface.workgroup.id[0] : index
// TRANSPOSE: %[[WG_Y:.*]] = hal.interface.workgroup.id[1] : index
Expand Down
Loading
Loading