Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -379,29 +379,41 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
);

let builders = [
AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $inst_data,
"llvm::ArrayRef<int32_t>": $lane_layout,
"llvm::ArrayRef<int32_t>": $lane_data),
[{
auto sg_layout = DenseI32ArrayAttr();
auto sg_data = DenseI32ArrayAttr();
auto inst_data = DenseI32ArrayAttr();
auto order = DenseI32ArrayAttr();
return $_get($_ctxt, sg_layout, sg_data, inst_data,
return $_get($_ctxt, sg_layout, sg_data,
DenseI32ArrayAttr::get($_ctxt, inst_data),
DenseI32ArrayAttr::get($_ctxt, lane_layout),
DenseI32ArrayAttr::get($_ctxt, lane_data), order);
}]>,
AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,
"llvm::ArrayRef<int32_t>": $lane_data,
"llvm::ArrayRef<int32_t>": $order),
"llvm::ArrayRef<int32_t>": $lane_data),
[{
return $_get($_ctxt,
/*sg_layout =*/ nullptr,
/*sg_data =*/ nullptr,
/*inst_data =*/ nullptr,
auto sg_layout = DenseI32ArrayAttr();
auto sg_data = DenseI32ArrayAttr();
auto inst_data = DenseI32ArrayAttr();
auto order = DenseI32ArrayAttr();
return $_get($_ctxt, sg_layout, sg_data, inst_data,
DenseI32ArrayAttr::get($_ctxt, lane_layout),
DenseI32ArrayAttr::get($_ctxt, lane_data),
DenseI32ArrayAttr::get($_ctxt, order));
DenseI32ArrayAttr::get($_ctxt, lane_data), order);
}]>,
// AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $lane_layout,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clean up?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, the constructor was not used and its signature conflicted with the new inst_data one, so it can be removed altogether, until we need order somewhere.

// "llvm::ArrayRef<int32_t>": $lane_data,
// "llvm::ArrayRef<int32_t>": $order),
// [{
// return $_get($_ctxt,
// /*sg_layout =*/ nullptr,
// /*sg_data =*/ nullptr,
// /*inst_data =*/ nullptr,
// DenseI32ArrayAttr::get($_ctxt, lane_layout),
// DenseI32ArrayAttr::get($_ctxt, lane_data),
// DenseI32ArrayAttr::get($_ctxt, order));
// }]>,
AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
"DenseI32ArrayAttr": $lane_data,
"DenseI32ArrayAttr": $order),
Expand Down
30 changes: 0 additions & 30 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h

This file was deleted.

7 changes: 6 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
let options = [Option<
"printOnly", "print-analysis-only", "bool",
/*default=*/"false",
"Print the result of layout propagation analysis and exit.">];
"Print the result of layout propagation analysis and exit.">,
Option<
"assumeUnrolled", "assume-unrolled", "bool",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this option be an enumeration, so that propagation can be applied at the "lane", "inst", and "subgroup" levels? A higher level implies that the lower levels are propagated too, so "assumeUnrolled = true" could be replaced with "level = lane" here, making the option more extensible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not an enum, but a string.
For subgroup, we must have a user layout on anchor ops to propagate? It's not like lane/inst fields, which are tightly coupled to hw subgroup size and/or instruction size.
Anyway, this is a topic for a different PR. For now, we can do lane and inst.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point, yes we expect user to set sg_layout/sg_data.

/*default=*/"false",
"If the input IR has SG-sized tiles matching instruction sizes, omit `inst_data`.">
];
}

def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute"> {
Expand Down
82 changes: 76 additions & 6 deletions mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
#include <map>
#include <string>

#define DEBUG_TYPE "xegpu-uarch"

using namespace mlir;
using namespace mlir::xegpu::uArch;

Expand All @@ -42,12 +40,61 @@ struct Xe2Plus : public uArch {
&instrs = {})
: uArch(archName, archDescription, regInfo, cacheInfo, instrs),
xeCore(xeCore) {}
// Subgroup size reported for Xe2Plus targets: 16.
int getSubgroupSize() const override { return 16; }
// Packed-format bit size reported for gather/scatter accesses: 32.
unsigned getPackedFormatBitSizeGatherScatter() const override { return 32; }
// Packed-format bit size reported for Xe2Plus targets: 16.
// NOTE(review): per the review discussion on this change, this value is the
// DPAS A-operand packing size and the method is expected to be renamed.
unsigned getPackedFormatBitSize() const override { return 16; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is getPackedFormatBitSize really getPackedFormatBitSizeDpasA?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, will be renamed. And I think it should be a member of dpas instruction per uarch instance. We might want to split this PR into two parts to have a substantial discussion in each: (1) uArch modification and (2) propagation option and uArch application in passes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For C, is it the same as B (32)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For C, the result is f32 so no packing needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For a generic lane data calculation for dpas operands, wouldn't the following format be desired in the dpas propagation
packingFactor = dpasInst->getOperand*A/B/C*PackingBitSize() / dataElemBitwidth?

It is not so much about whether we actually consider "packing" C.

// Packed-format bit size reported for the DPAS B operand: 32.
std::optional<unsigned> getPackedFormatBitSizeDpasB() const override {
return 32;
}
};

//===----------------------------------------------------------------------===//
// uArch instructions
//===----------------------------------------------------------------------===//
/// Subgroup-scope description of the 2D block write (STORE_ND) instruction.
struct StoreNdInstruction : public Instruction {
  StoreNdInstruction()
      : Instruction(InstructionKind::STORE_ND, InstructionScope::Subgroup) {}

  // Source :
  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
  // Writes 1, 2, 4, or 8 uints of data for each work item in the sub-group to
  // the specified pointer.
  /// Returns the supported per-lane vector lengths in ascending order.
  llvm::SmallVector<int> getSortedLaneVectorLengths() const {
    return {1, 2, 4, 8};
  }
};

/// Subgroup-scope description of the 2D block load (LOAD_ND) instruction.
struct LoadNdInstruction : public Instruction {
  LoadNdInstruction()
      : Instruction(InstructionKind::LOAD_ND, InstructionScope::Subgroup) {}

  // Source :
  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroups.html#_add_a_new_section_6_13_x_sub_group_read_and_write_functions
  // Reads 1, 2, 4, or 8 uints of data for each work item in the sub-group from
  // the specified pointer.
  /// Returns the supported per-lane vector lengths in ascending order.
  llvm::SmallVector<int> getSortedLaneVectorLengths() const {
    return {1, 2, 4, 8};
  }
};

/// Subgroup-scope description of the 2D block prefetch (PREFETCH_ND)
/// instruction.
struct PrefetchNdInstruction : public Instruction {
  PrefetchNdInstruction()
      : Instruction(InstructionKind::PREFETCH_ND, InstructionScope::Subgroup) {}

  // Source :
  // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
  /// Returns the supported per-lane vector lengths, in ascending order, for
  /// elements that are `elementBitwidth` bits wide. Only 8-, 16-, 32- and
  /// 64-bit elements are handled; any other width reaches llvm_unreachable.
  llvm::SmallVector<int> getSortedLaneVectorLengths(int elementBitwidth) const {
    if (elementBitwidth == 8 || elementBitwidth == 16)
      return {1, 2, 4, 8, 16};
    if (elementBitwidth == 32 || elementBitwidth == 64)
      return {1, 2, 4, 8};
    llvm_unreachable(
        "Unsupported element bitwidth for PrefetchNdInstruction");
  }
};

// struct to represent DPAS instruction
struct DPASInstruction : public Instruction, public MMAInstructionInterface {
DPASInstruction()
: Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
// Source:
// https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html

// Override all virtuals from MatrixOpInterface
virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
Expand All @@ -72,6 +119,9 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
};

//===----------------------------------------------------------------------===//
// uArch instances
//===----------------------------------------------------------------------===//
struct PVCuArch : public Xe2Plus {
// Maintains ownership of the instructions owned by PVCuArch
llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
Expand Down Expand Up @@ -101,9 +151,15 @@ struct PVCuArch : public Xe2Plus {
CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));

// Add the instructions-
auto dpas = std::make_shared<DPASInstruction>();
instructions.emplace(dpas->getInstructionKind(), dpas);
owned_instructions.push_back(dpas);
llvm::SmallVector<std::shared_ptr<Instruction>> instructionsToAdd{

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - formatting

std::make_shared<DPASInstruction>(),
std::make_shared<StoreNdInstruction>(),
std::make_shared<LoadNdInstruction>(),
std::make_shared<PrefetchNdInstruction>()};
for (auto &inst : instructionsToAdd) {
instructions.emplace(inst->getInstructionKind(), inst);
owned_instructions.push_back(inst);
}
}
};

Expand Down Expand Up @@ -139,10 +195,24 @@ struct BMGuArch : public Xe2Plus {
owned_instructions.push_back(dpas);
}
};

/// Creates the uArch description that corresponds to `archName`.
/// The recognized names are "pvc" and "bmg"; every other name yields a null
/// pointer, so callers must check the result before dereferencing it.
inline std::shared_ptr<uArch> getUArch(const std::string &archName) {
  if (archName == "pvc")
    return std::make_shared<PVCuArch>();
  if (archName == "bmg")
    return std::make_shared<BMGuArch>();
  return nullptr;
}

} // namespace uArch
} // namespace xegpu
} // namespace mlir

//===----------------------------------------------------------------------===//
// Instruction implementations
//===----------------------------------------------------------------------===//

inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
Expand Down
25 changes: 22 additions & 3 deletions mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ namespace uArch {
// An enum class to represent the scope of an instruction
// (one of: Lane, Subgroup, Workgroup, Cluster).
enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
/// Kinds of instructions that a uArch description can expose.
enum class InstructionKind {
  DPAS,       // Dot Product Accumulate Systolic (DPAS) is a matrix
              // multiply-add operation
  STORE_ND,   // Subgroup-level 2D block write instruction
  LOAD_ND,    // Subgroup-level 2D block load instruction
  PREFETCH_ND // Subgroup-level 2D block prefetch instruction
  // @TODO: Add more instructions as needed
};

Expand All @@ -54,6 +57,12 @@ struct Instruction {
switch (instKind) {
case InstructionKind::DPAS:
return "dpas";
case InstructionKind::STORE_ND:
return "store_nd";
case InstructionKind::LOAD_ND:
return "load_nd";
case InstructionKind::PREFETCH_ND:
return "prefetch_nd";
}
llvm_unreachable("Unknown InstructionKind");
}
Expand Down Expand Up @@ -142,12 +151,22 @@ struct uArch {
: name(name), description(description),
registerFileInfo(registerFileInfo), cacheInfo(cacheInfo),
instructions(instructions) {}

// Virtual so that derived uArch types can be destroyed through a uArch
// pointer.
virtual ~uArch() = default;
// Get methods
// Returns the architecture name given to the constructor.
const std::string &getName() const { return name; }

// Returns the architecture description given to the constructor.
const std::string &getDescription() const { return description; }

// Target-specific parameters; every concrete uArch must implement these.
// Subgroup size of the target.
virtual int getSubgroupSize() const = 0;
// Packed-format bit size used for gather/scatter accesses.
virtual unsigned getPackedFormatBitSizeGatherScatter() const = 0;
// Packed-format bit size.
// NOTE(review): per the review discussion on this change, this refers to the
// DPAS A operand and is expected to be renamed — confirm before relying on it.
virtual unsigned getPackedFormatBitSize() const = 0;
// Packed-format bit size for the DPAS B operand. The optional return
// presumably lets a target report that no value applies — TODO confirm.
virtual std::optional<unsigned> getPackedFormatBitSizeDpasB() const = 0;

// Returns the instruction registered under `instKind`.
// The kind must already be present in `instructions`; this is only checked by
// the assert, so with asserts compiled out a missing kind falls through to
// whatever `instructions.at()` does for an absent key.
std::shared_ptr<Instruction> getInstruction(InstructionKind instKind) const {
assert(instructions.find(instKind) != instructions.end());
return instructions.at(instKind);
}

const std::map<RegisterFileType, RegisterFileInfo> &
getRegisterFileInfo() const {
return registerFileInfo;
Expand Down
16 changes: 9 additions & 7 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
Expand Down Expand Up @@ -226,8 +226,10 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
}

if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
return emitError()
<< "expected inst_data and lane_layout to have the same rank";
return emitError() << "expected inst_data and lane_layout to have the same "
"rank, got inst_data "
<< inst_data.size() << ", lane_layout "
<< lane_layout.size();
}

// sg_data is optional for Workgroup layout, but its presence requires
Expand Down Expand Up @@ -565,10 +567,10 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,

// for gather and scatter ops, Low-precision types are packed in 32-bit units.
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
int chunkAlignmentFactor =
bitWidth < targetinfo::packedSizeInBitsForGatherScatter
? targetinfo::packedSizeInBitsForGatherScatter / bitWidth
: 1;
constexpr int packingBitSizeGatherScatter{32};
int chunkAlignmentFactor = bitWidth < packingBitSizeGatherScatter
? packingBitSizeGatherScatter / bitWidth
: 1;
auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
if (scatterAttr) {
int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
Expand Down
Loading