Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][XeGPU] Add dpas and named barrier ops #88439

Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ add_mlir_dialect(XeGPU xegpu)
add_mlir_doc(XeGPU XeGPU Dialects/ -gen-dialect-doc -dialect=xegpu)

set(LLVM_TARGET_DEFINITIONS XeGPU.td)
mlir_tablegen(XeGPUAttrs.h.inc -gen-attrdef-decls)
mlir_tablegen(XeGPUAttrs.cpp.inc -gen-attrdef-defs)
mlir_tablegen(XeGPUAttrs.h.inc -gen-attrdef-decls -attrdefs-dialect=xegpu)
mlir_tablegen(XeGPUAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=xegpu)
add_public_tablegen_target(MLIRXeGPUAttrsIncGen)
add_dependencies(mlir-headers MLIRXeGPUAttrsIncGen)

set(LLVM_TARGET_DEFINITIONS XeGPU.td)
set(LLVM_TARGET_DEFINITIONS XeGPUAttrs.td)
mlir_tablegen(XeGPUEnums.h.inc -gen-enum-decls)
mlir_tablegen(XeGPUEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRXeGPUEnumsIncGen)
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_DIALECT_XEGPU_IR_XEGPU_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TypeUtilities.h"
Expand All @@ -19,7 +20,7 @@

namespace mlir {
namespace xegpu {
// placeholder
class TensorDescType;
} // namespace xegpu
} // namespace mlir

Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"

class XeGPUAttr<string name, string attrMnemonic, list<Trait> traits = [],
Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ def XeGPU_Dialect : Dialect {
let summary = "The XeGPU dialect that models Intel GPU's ISA";
let description = [{
The XeGPU dialect models Intel Xe ISA semantics but works at vector and
TensorDesc data type. It provides 1:1 mappings to match Xe instructions
TensorDesc data type. It provides 1:1 mappings to match Xe instructions
like DPAS and 2D block load. The matrix size being processed at this level
exactly matches the hardware instructions or the intrinsic supported by
the lower-level GPU compiler.
}];

let dependentDialects = ["arith::ArithDialect"];

let useDefaultTypePrinterParser = true;
let useDefaultAttributePrinterParser = true;
}
Expand Down
154 changes: 151 additions & 3 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#ifndef MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
#define MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD

include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/XeGPU/IR/XeGPUAttrs.td"
include "mlir/Dialect/XeGPU/IR/XeGPUDialect.td"
include "mlir/Dialect/XeGPU/IR/XeGPUTypes.td"
Expand All @@ -35,7 +35,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:

static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
if (mlir::succeeded(parser.parseLess())) {
if (mlir::succeeded(parser.parseOptionalLess())) {
if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
return failure();
}
Expand Down Expand Up @@ -253,7 +253,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "Tensor
a block of data from memory to register. It takes a set of optional cache
hints for each level of cache, L1, L2 and L3. If hardware does not have a
correspoding cache, Corresponding cache hint attribute will be masked.
vnni transform is an hardware feature for Intel GPU, which is used to
VNNI transformation is an hardware feature for Intel GPU, which is used to
do data packing during the load for B operand of matrix operation, if
the bit width of the data type is less then 32 bits, e.g., fp16. And
transpose is another Intel hardware feature, which will do transpose
Expand Down Expand Up @@ -662,4 +662,152 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
}];
}

def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]> {
let summary = "It performs mma computation";

let description = [{DPAS performs matrix multiplication on matrix A of `mxk`
size, B of `kxn` size, and accumulate on matrix C of `mxn` to the same size
matrix , `m=8`, `n=16` and `k=8 * 32/bit_width_of_elem_type`. So for fp16
data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`,
and `C/D: vector<8x16xf32>`. Besides the matrix size requirements, DPAS
also requires A and B to be loaded with the required data layout. Specially,
VNNI layout is required for B operand. It is achieved via setting `vnni_axis = 0`
of the corresponding `load_nd` operator. To keep both operands as 3D vector,
operand A is loaded via setting `vnni_axis = 1` without impacting the
physical layouts change in register. Due to the VNNI transformation, A and B operands
are represented as 3D vector, with the last dimension representing the VNNI factor,
which is computed as `32/bit_width_of_elem_type`. Therefore, `A: vector<8x16xf16>`
is represented as `A: vector<4x8x2xf16>`, and `B:vector<16x16xf16>` is
chencha3 marked this conversation as resolved.
Show resolved Hide resolved
represented as `B: vector<8x16x2xf16>`.

Note: on PVC, the hardware can perform load with VNN transformation when data
chencha3 marked this conversation as resolved.
Show resolved Hide resolved
element type is 16-bit or lower precision, taking 2 or 4 elements from
the first dimension and inserted into the newly added innermost dimension.
}];

let arguments = (ins
XeGPU_DpasOpType : $lhs,
XeGPU_DpasOpType : $rhs,
Optional<XeGPU_Vector2DType>: $acc);
let results = (outs XeGPU_Vector2DType: $result);

let extraClassDeclaration = [{
VectorType getLhsType() {
return getLhs().getType();
}

VectorType getRhsType() {
return getRhs().getType();
}

VectorType getAccType() {
if (getAcc())
return getAcc().getType();
return {};
}

VectorType getResultType() {
return getResult().getType();
}
}];

let assemblyFormat = [{
$lhs `,` $rhs (`,` $acc^)? attr-dict `:` type($lhs)`,` type($rhs) (`,` type($acc)^)? `->` type($result)
}];

let hasVerifier = 1;
}

def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", [Pure,
AllElementTypesMatch<["tensorDesc", "value", "result"]>,
AllShapesMatch<["tensorDesc", "mask", "value", "result"]>]> {
let summary = "A ready-modify-write operation. ";

let description = [{
`AtomicRMWOp` has same semantic to `memref.atomic_rmw`, except that
it work on a `TensorDescType` object while `memref.atomic_rmw` works
on a `MemRefType` object. It also has a `mask` variable, which has the
same shape with `TensorDesc`, to enable or disable some data points of
the `TensorDesc`.
}];

let arguments = (ins
AtomicRMWKindAttr:$kind,
XeGPU_TensorDesc:$tensorDesc,
XeGPU_MaskType:$mask,
XeGPU_ValueType:$value);

let results = (outs XeGPU_ValueType:$result);

let assemblyFormat = [{
$kind $tensorDesc `,` $mask `,` $value attr-dict `:`
type($tensorDesc) `,` type($mask) `,` type($value) `->` type($result)
}];
}

def XeGPU_AllocNbarrierOp: XeGPU_Op<"alloc_nbarrier", []> {
let summary = "It allocates a set of named barriers.";
let description = [{AllocNbarrier is to create a set of named barriers as
specified by `nbarrier_num`. Named barriers are workgroup level resources,
and are shared by all threads in the workgroup. For example, there are
up to 32 barriers (range 0-31) for each Xecore on PVC. A typical use case
chencha3 marked this conversation as resolved.
Show resolved Hide resolved
is that a workgroup is partitioned into N subgroups of threads (N <= 32),
and each subgroup coordinating their work with a separate barrier with id
range from 0 to N respectively.}];
let arguments = (ins I64Attr: $nbarrier_num);
let assemblyFormat = "$nbarrier_num attr-dict";
}

def XeGPU_InitNbarrierOp: XeGPU_Op<"init_nbarrier", []> {
let summary = "It assigns a named barrier to the current thread.";
let description = [{InitNbarrierOp assigns the named barrier with the specified
barrier ID (0~31) to the current thread. Multiple threads may bind to the
same named barrier, and the `participant_thread_num` specifies the total
number of threads associated with the nbarrier. It returns an object of
NbarrierType representing the barrier}];

let arguments = (ins I8: $nbarrier_id,
I8: $participant_thread_num);
let results = (outs XeGPU_Nbarrier: $result);
let assemblyFormat = [{
$nbarrier_id `,` $participant_thread_num attr-dict `:`
type($nbarrier_id) `,` type($participant_thread_num) `->` qualified(type($result))
}];
}

def XeGPU_NbarrierArriveOp: XeGPU_Op<"nbarrier_arrive", []> {
let summary = "It signals the arrival at the named barrier.";
let description = [{NbarrierArriveOp signals the hardware (or other threads)
that the current thread has produced its data for the consumer threads. When
the hardware signalled by `participant_thread_num` threads for the named barrier,
it will notify the threads waiting for the named barrier to continue their work.}];

let arguments = (ins XeGPU_Nbarrier: $nbarrier);
let assemblyFormat = [{ $nbarrier attr-dict `:` qualified(type($nbarrier))}];
}

def XeGPU_NbarrierWaitOp: XeGPU_Op<"nbarrier_wait", []> {
let summary = "It waits for a named barrier.";
let description = [{NbarrierWaitOp signals the hardware which named barrier
the current thread is waiting for, such that it can get notified when the
named barrier is completed.}];
let arguments = (ins XeGPU_Nbarrier: $nbarrier);
let assemblyFormat = [{ $nbarrier attr-dict `:` qualified(type($nbarrier)) }];
}

def XeGPU_FenceOp: XeGPU_Op<"fence", []> {
let summary = "It synchronizes memory accesses.";
let description = [{It synchronizes the memory access between
write and following read or write.
1. `Memory_kind` describes the memory kind. "global" means the global memory,
"slm" means the share local memory.
2. `Fence_scope` describes the scope of fence. "local" means that the scope would be
within each XeCore. "tile" means the scope would be across XeCore with one tile.
}];
let arguments = (ins XeGPU_MemoryScopeAttr: $memory_kind,
StrAttr: $fence_scope);
let assemblyFormat = [{`memory_kind` `=` `` $memory_kind `,` `fence_scope` `=` $fence_scope attr-dict}];
let extraClassDeclaration = extraBaseClassDeclaration;
}

#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,15 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",

}


def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";

let extraClassDeclaration = [{
static NbarrierType get(mlir::MLIRContext *context) {
return Base::get(context);
};
}];
}

#endif // MLIR_DIALECT_XEGPU_IR_XEGPUTYPES_TD
23 changes: 23 additions & 0 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,29 @@ LogicalResult StoreScatterOp::verify() {

return success();
}
//===----------------------------------------------------------------------===//
// XeGPU_DpasOp
//===----------------------------------------------------------------------===//
LogicalResult DpasOp::verify() {
int64_t lhsRank = getLhsType().getRank();
int64_t rhsRank = getRhsType().getRank();

if (lhsRank != rhsRank || lhsRank != 3)
return emitOpError(
"lhs and rhs rank does not match for dpas op, or their rank is not 3.");

if (getAcc() && getAccType() != getResultType())
return emitOpError("Accumulator and Result for dpas op should have the "
"same type (both shape and element type).");

auto lhsShape = getLhsType().getShape();
auto rhsShape = getRhsType().getShape();
if (lhsShape[1] != rhsShape[0] || lhsShape[2] != rhsShape[2])
return emitOpError("K-dimension or vnni-factor mismatch.");

return success();
}


} // namespace xegpu
} // namespace mlir
Expand Down
57 changes: 56 additions & 1 deletion mlir/test/Dialect/XeGPU/XeGPUOps.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ gpu.func @test_prefetch_vc(%src: ui64) {
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
%1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
// CHECK: xegpu.prefetch %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
gpu.return
}

Expand Down Expand Up @@ -121,4 +121,59 @@ gpu.func @test_create_update_tdesc_vc(%src: ui64) {
gpu.return
}

// CHECK: gpu.func @test_dpas_vc(%[[arg0:.*]]: vector<8x8x2xf16>, %[[arg1:.*]]: vector<8x16x2xf16>)
gpu.func @test_dpas_vc(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
// CHECK: %0 = xegpu.dpas %[[arg0]], %[[arg1]] : vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
%1 = xegpu.dpas %a, %b: vector<8x8x2xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
gpu.return
}

// CHECK: gpu.func @test_atomic_rmw(%[[arg0:.*]]: ui64, %[[arg1:.*]]: vector<16xf32>, %[[arg2:.*]]: vector<16xi1>)
gpu.func @test_atomic_rmw(%src: ui64, %value : vector<16xf32>, %mask : vector<16xi1>) {
//CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : ui64 -> !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>
%1 = xegpu.create_tdesc %src[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]: ui64 -> !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>
//CHECK: %[[R1:.*]] = xegpu.atomic_rmw addf %[[R0]], %[[arg2]], %[[arg1]] : <16xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
gpu.return
}

// CHECK: gpu.func @alloc_nbarrier({{.*}}) {
gpu.func @alloc_nbarrier() {
// CHECK: xegpu.alloc_nbarrier
xegpu.alloc_nbarrier 8
gpu.return
}

// CHECK: gpu.func @init_nbarrier({{.*}}) {
gpu.func @init_nbarrier() {
//CHECK: %[[c1:.*]] = arith.constant 1 : i8
//CHECK: %[[c16:.*]] = arith.constant 16 : i8
%nbarrier_id = arith.constant 1 : i8
%threads_count = arith.constant 16 : i8
//CHECK: xegpu.init_nbarrier %[[c1]], %[[c16]] : i8, i8 -> !xegpu.nbarrier
%nbarrier = xegpu.init_nbarrier %nbarrier_id, %threads_count : i8, i8 -> !xegpu.nbarrier
gpu.return
}

// CHECK: gpu.func @nbarrier_arrive(%[[arg0:.*]]: !xegpu.nbarrier) {
gpu.func @nbarrier_arrive(%nbarrier : !xegpu.nbarrier) {
//CHECK: xegpu.nbarrier_arrive %[[arg0]] : !xegpu.nbarrier
xegpu.nbarrier_arrive %nbarrier : !xegpu.nbarrier
gpu.return
}

// CHECK: gpu.func @nbarrier_wait(%[[arg0:.*]]: !xegpu.nbarrier) {
gpu.func @nbarrier_wait(%nbarrier : !xegpu.nbarrier) {
//CHECK: xegpu.nbarrier_wait %[[arg0]] : !xegpu.nbarrier
xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier
gpu.return
}

// CHECK-LABEL: gpu.func @fence({{.*}}) {
gpu.func @fence() {
//CHECK: xegpu.fence memory_kind = global, fence_scope = "local"
xegpu.fence memory_kind = global, fence_scope = "local"
gpu.return
}

}
Loading