Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 87 additions & 26 deletions mlir/include/mlir/Dialect/AMX/AMX.td
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,13 @@ def TileZeroOp : AMX_Op<"tile_zero", [
let summary = "tile zero operation";
let description = [{
Zeroes the destination tile, with the shape defined by the 2-dim
vector type of the result. This is eventually lowered into the
"tilezero" instruction with the corresponding tile configuration.
With memory-effects, each "tilezero" operation serves as a compilation
hint to use a separate tile register.
vector type of the result.

The operation is eventually lowered into the "tilezero" instruction
with the corresponding tile configuration.

With the write memory effect, each `amx.tile_zero` operation serves as
a compilation hint to use a separate tile register.

Example:

Expand Down Expand Up @@ -184,25 +187,53 @@ def TileZeroOp : AMX_Op<"tile_zero", [

def TileLoadOp : AMX_Op<"tile_load", [
AMXIntrinsicOpInterface,
MemoryEffects<[MemWrite]>
MemoryEffects<[MemWrite]>,
AttrSizedOperandSegments
]> {
let summary = "tile load operation";
let description = [{
Loads a tile from memory defined by a base and indices, with the
shape defined by the 2-dim vector type of the result. This is
eventually lowered into the "tileloadd" instruction with the
corresponding tile configuration. With memory-effects, each "tileload"
operation serves as a compilation hint to use a separate tile register.
Loads a tile from memory defined by a `base` and `indices`, with the
shape defined by the 2-dim vector type of the result.
The tile's rows are populated by reading contiguous elements starting
at the `base`. For each tile row, the `base` is incremented by `stride`
number of elements.

The tile is loaded using the following indexing scheme:

```
for row in range(tile_rows):
  mem_row = base[i0, i1, ..., iN + row * stride]
  for col in range(tile_cols):
    tile[row, col] = mem_row[col]
```

If the `stride` is not provided, then the `base` buffer must be at least
2-dimensional, and the `stride` is automatically inferred and corresponds
to the stride of the buffer's second innermost dimension.

The operation is eventually lowered into the "tileloadd" instruction
with the corresponding tile configuration.

With the write memory effect, each `amx.tile_load` operation serves as
a compilation hint to use a separate tile register.

Example:

```mlir
// Tile load from a 2-D memref with implicit stride.
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>

// Tile load from a 1-D memref with explicit stride.
%0 = amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
Variadic<Index>:$indices);
Variadic<Index>:$indices,
Optional<Index>:$stride);
let results = (outs AnyAMXTile:$res);
let builders = [
OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
Expand All @@ -219,30 +250,56 @@ def TileLoadOp : AMX_Op<"tile_load", [
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
"type($base) `into` qualified(type($res))";
let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict"
"`:` type($base) `into` qualified(type($res))";
let hasVerifier = 1;
}

def TileStoreOp : AMX_Op<"tile_store", [
AMXIntrinsicOpInterface
AMXIntrinsicOpInterface,
AttrSizedOperandSegments
]> {
let summary = "tile store operation";
let description = [{
Stores a tile to memory defined by a base and indices, with the
shape defined by the 2-dim vector type of the value. This is
eventually lowered into the "tilestored" instruction with the
corresponding tile configuration.
Stores a tile to memory defined by a `base` and `indices`, with the
shape defined by the 2-dim vector type of the value.
The tile's rows are written contiguously to the buffer starting at
the `base`. For each tile row, the `base` is incremented by `stride`
number of elements.

The tile is stored using the following indexing scheme:

```
for row in range(tile_rows):
  mem_row = base[i0, i1, ..., iN + row * stride]
  for col in range(tile_cols):
    mem_row[col] = tile[row, col]
```

If the `stride` is not provided, then the `base` buffer must be at least
2-dimensional, and the `stride` is automatically inferred and corresponds
to the stride of the buffer's second innermost dimension.

The operation is eventually lowered into the "tilestored" instruction
with the corresponding tile configuration.

Example:

```mlir
// Tile store to a 2-D memref with implicit stride.
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>

// Tile store to a 1-D memref with explicit stride.
amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
AnyAMXTile:$val);
AnyAMXTile:$val,
Optional<Index>:$stride);
let builders = [
OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
];
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
Expand All @@ -259,8 +316,8 @@ def TileStoreOp : AMX_Op<"tile_store", [
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
"type($base) `,` qualified(type($val))";
let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?"
"attr-dict `:` type($base) `,` qualified(type($val))";
let hasVerifier = 1;
}

Expand All @@ -276,8 +333,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
let description = [{
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
pairs of "bf16"). The operation is eventually lowered into the
"tdpbf16ps" instruction with the corresponding tile configuration.
pairs of "bf16").

The operation is eventually lowered into the "tdpbf16ps" instruction with
the corresponding tile configuration.

Example:

Expand Down Expand Up @@ -330,9 +389,11 @@ def TileMulIOp : AMX_Op<"tile_muli", [Pure,
into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
combinations (4 bytes packed into dwords in the columns of both the
source operand tiles; the zero or sign extension is specified with
the attributes and default to sign extended). The operation is eventually
lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud"
instructions with the corresponding tile configuration.
the attributes and defaults to sign extension).

The operation is eventually lowered into one of the "tdpbssd",
"tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
tile configuration.

Example:

Expand Down
99 changes: 63 additions & 36 deletions mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,22 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
}

/// Converts `elementStride`, a stride counted in elements of the element type
/// of `mType`, into the equivalent stride counted in bytes. The result is
/// produced as a 64-bit integer value.
static Value computeStrideInBytes(Location loc, MemRefType mType,
                                  Value elementStride, RewriterBase &rewriter) {
  // Size of one element in bytes (bit width divided by 8, integer division).
  unsigned elementBytes = mType.getElementType().getIntOrFloatBitWidth() / 8;
  Type i64Type = rewriter.getIntegerType(64);
  Value bytesPerElement = LLVM::ConstantOp::create(
      rewriter, loc, i64Type, rewriter.getI64IntegerAttr(elementBytes));
  // bytes = bytesPerElement * elementStride
  auto byteStride = LLVM::MulOp::create(rewriter, loc, i64Type, bytesPerElement,
                                        elementStride);
  return byteStride.getResult();
}

/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
static Value getStride(Location loc, MemRefType mType, Value base,
RewriterBase &rewriter) {
static Value inferStride(Location loc, MemRefType mType, Value base,
RewriterBase &rewriter) {
assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
int64_t preLast = mType.getRank() - 2;
Type llvmInt64Type = rewriter.getIntegerType(64);
Expand All @@ -94,11 +106,8 @@ static Value getStride(Location loc, MemRefType mType, Value base,
if (strides[preLast] == ShapedType::kDynamic) {
// Dynamic stride needs code to compute the stride at runtime.
MemRefDescriptor memrefDescriptor(base);
auto attr = rewriter.getI64IntegerAttr(bytes);
Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale,
memrefDescriptor.stride(rewriter, loc, preLast))
.getResult();
return computeStrideInBytes(
loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter);
}
// Use direct constant for static stride.
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
Expand All @@ -117,21 +126,39 @@ amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
return getTileSizes(getLoc(), getTileType(), rewriter);
}

LogicalResult amx::TileLoadOp::verify() {
MemRefType memrefTy = getMemRefType();
template <typename OpTy,
typename = std::enable_if_t<std::is_same_v<OpTy, amx::TileLoadOp> ||
std::is_same_v<OpTy, amx::TileStoreOp>>>
static LogicalResult tileTransferVerifier(OpTy op) {
MemRefType memrefTy = op.getMemRefType();
unsigned rank = memrefTy.getRank();
if (rank < 2)
return emitOpError("requires at least 2D memref");
if (getIndices().size() != rank)
return emitOpError("requires ") << rank << " indices";
SmallVector<int64_t> strides;
int64_t offset;
if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
strides.back() != 1)
return emitOpError("requires memref with unit innermost stride");
return verifyTileSize(*this, getTileType());
if (op.getIndices().size() != rank)
return op.emitOpError("requires ") << rank << " indices";

if (failed(verifyTileSize(op, op.getTileType())))
return failure();

// Validate basic buffer properties when the stride is implicit.
if (!op.getStride()) {
if (rank < 2)
return op.emitOpError("requires at least 2D memref");
SmallVector<int64_t> strides;
int64_t offset;
if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
strides.back() != 1)
return op.emitOpError("requires memref with unit innermost stride");
}

return success();
}

/// Convenience builder for a tile load without an explicit stride operand.
/// Forwards to the builder overload that additionally takes `stride`, passing
/// a null value for it.
void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res,
Value base, ValueRange indices) {
build(builder, state, res, base, indices, /*stride=*/nullptr);
}

/// Verifies the tile load by delegating to `tileTransferVerifier`.
LogicalResult amx::TileLoadOp::verify() {
  return tileTransferVerifier(*this);
}

SmallVector<Value>
amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
Expand All @@ -144,27 +171,23 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
intrinsicOperands.push_back(
getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
if (Value stride = adaptor.getStride())
intrinsicOperands.push_back(
computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
else
intrinsicOperands.push_back(
inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));

return intrinsicOperands;
}

LogicalResult amx::TileStoreOp::verify() {
MemRefType memrefTy = getMemRefType();
unsigned rank = memrefTy.getRank();
if (rank < 2)
return emitOpError("requires at least 2D memref");
if (getIndices().size() != rank)
return emitOpError("requires ") << rank << " indices";
SmallVector<int64_t> strides;
int64_t offset;
if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
strides.back() != 1)
return emitOpError("requires memref with unit innermost stride");
return verifyTileSize(*this, getTileType());
/// Convenience builder for a tile store without an explicit stride operand.
/// Forwards to the builder overload that additionally takes `stride`, passing
/// a null value for it.
void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state,
Value base, ValueRange indices, Value val) {
build(builder, state, base, indices, val, /*stride=*/nullptr);
}

/// Verifies the tile store by delegating to `tileTransferVerifier`.
LogicalResult amx::TileStoreOp::verify() {
  return tileTransferVerifier(*this);
}

SmallVector<Value>
amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
const LLVMTypeConverter &typeConverter,
Expand All @@ -177,8 +200,12 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
intrinsicOperands.push_back(
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
adaptor.getBase(), adaptor.getIndices()));
intrinsicOperands.push_back(
getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
if (Value stride = adaptor.getStride())
intrinsicOperands.push_back(
computeStrideInBytes(loc, getMemRefType(), stride, rewriter));
else
intrinsicOperands.push_back(
inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
intrinsicOperands.push_back(adaptor.getVal());

return intrinsicOperands;
Expand Down
Loading