diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td index 1236fede4d88b..cace63d32fd80 100644 --- a/mlir/include/mlir/Dialect/AMX/AMX.td +++ b/mlir/include/mlir/Dialect/AMX/AMX.td @@ -149,10 +149,13 @@ def TileZeroOp : AMX_Op<"tile_zero", [ let summary = "tile zero operation"; let description = [{ Zeroes the destination tile, with the shape defined by the 2-dim - vector type of the result. This is eventually lowered into the - "tilezero" instruction with the corresponding tile configuration. - With memory-effects, each "tilezero" operation serves as a compilation - hint to use a separate tile register. + vector type of the result. + + The operation is eventually lowered into the "tilezero" instruction + with the corresponding tile configuration. + + With the write memory effect, each `amx.tile_zero` operation serves as + a compilation hint to use a separate tile register. Example: @@ -184,25 +187,53 @@ def TileZeroOp : AMX_Op<"tile_zero", [ def TileLoadOp : AMX_Op<"tile_load", [ AMXIntrinsicOpInterface, - MemoryEffects<[MemWrite]> + MemoryEffects<[MemWrite]>, + AttrSizedOperandSegments ]> { let summary = "tile load operation"; let description = [{ - Loads a tile from memory defined by a base and indices, with the - shape defined by the 2-dim vector type of the result. This is - eventually lowered into the "tileloadd" instruction with the - corresponding tile configuration. With memory-effects, each "tileload" - operation serves as a compilation hint to use a separate tile register. + Loads a tile from memory defined by a `base` and `indices`, with the + shape defined by the 2-dim vector type of the result. + The tile's rows are populated by reading contiguous elements starting + at the `base`. For each tile row, the `base` is incremented by `stride` + number of elements. + + The tile is loaded using the following indexing scheme: + + ``` + for row in enumerate(tile_rows): + mem_row = base[i0, i1, ..., iN + row * stride] + for col in enumerate(tile_cols): + tile[row, col] = mem_row[col] + ``` + + If the `stride` is not provided, then the `base` buffer must be at least + 2-dimensional, and the `stride` is automatically inferred and corresponds + to the stride of the buffer's second innermost dimension. + + The operation is eventually lowered into the "tileloadd" instruction + with the corresponding tile configuration. + + With the write memory effect, each `amx.tile_load` operation serves as + a compilation hint to use a separate tile register. Example: ```mlir + // Tile load from a 2-D memref with implicit stride. %0 = amx.tile_load %arg0[%c0, %c0] : memref into !amx.tile<16x64xi8> + + // Tile load from a 1-D memref with explicit stride. + %0 = amx.tile_load %arg0[%c0], %stride : memref into !amx.tile<16x64xi8> ``` }]; let arguments = (ins Arg:$base, - Variadic:$indices); + Variadic:$indices, + Optional:$stride); let results = (outs AnyAMXTile:$res); + let builders = [ + OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)> + ]; let extraClassDeclaration = [{ MemRefType getMemRefType() { return ::llvm::cast(getBase().getType()); @@ -219,30 +250,56 @@ def TileLoadOp : AMX_Op<"tile_load", [ const ::mlir::LLVMTypeConverter &typeConverter, ::mlir::RewriterBase &rewriter); }]; - let assemblyFormat = "$base `[` $indices `]` attr-dict `:` " - "type($base) `into` qualified(type($res))"; + let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? 
attr-dict" + "`:` type($base) `into` qualified(type($res))"; let hasVerifier = 1; } def TileStoreOp : AMX_Op<"tile_store", [ - AMXIntrinsicOpInterface + AMXIntrinsicOpInterface, + AttrSizedOperandSegments ]> { let summary = "tile store operation"; let description = [{ - Stores a tile to memory defined by a base and indices, with the - shape defined by the 2-dim vector type of the value. This is - eventually lowered into the "tilestored" instruction with the - corresponding tile configuration. + Stores a tile to memory defined by a `base` and `indices`, with the + shape defined by the 2-dim vector type of the value. + The tile's rows are written contiguously to the buffer starting at + the `base`. For each tile row, the `base` is incremented by `stride` + number of elements. + + The tile is stored using the following indexing scheme: + + ``` + for row in enumerate(tile_rows): + mem_row = base[i0, i1, ..., iN + row * stride] + for col in enumerate(tile_cols): + mem_row[col] = tile[row, col] + ``` + + If the `stride` is not provided, then the `base` buffer must be at least + 2-dimensional, and the `stride` is automatically inferred and corresponds + to the stride of the buffer's second innermost dimension. + + The operation is eventually lowered into the "tilestored" instruction + with the corresponding tile configuration. Example: ```mlir + // Tile store to a 2-D memref with implicit stride. amx.tile_store %arg1[%c0, %c0], %0 : memref, !amx.tile<16x64xi8> + + // Tile store to a 1-D memref with explicit stride. + amx.tile_store %arg1[%c0], %0, %stride : memref, !amx.tile<16x64xi8> ``` }]; let arguments = (ins Arg:$base, Variadic:$indices, - AnyAMXTile:$val); + AnyAMXTile:$val, + Optional:$stride); + let builders = [ + OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)> + ]; let extraClassDeclaration = [{ MemRefType getMemRefType() { return ::llvm::cast(getBase().getType()); @@ -259,8 +316,8 @@ def TileStoreOp : AMX_Op<"tile_store", [ const ::mlir::LLVMTypeConverter &typeConverter, ::mlir::RewriterBase &rewriter); }]; - let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` " - "type($base) `,` qualified(type($val))"; + let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?" + "attr-dict `:` type($base) `,` qualified(type($val))"; let hasVerifier = 1; } @@ -276,8 +333,10 @@ def TileMulFOp : AMX_Op<"tile_mulf", [Pure, let description = [{ Multiplies a "m x k" tile with a "k x n" tile and accumulates the results into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with - pairs of "bf16"). The operation is eventually lowered into the - "tdpbf16ps" instruction with the corresponding tile configuration. + pairs of "bf16"). + + The operation is eventually lowered into the "tdpbf16ps" instruction with + the corresponding tile configuration. Example: @@ -330,9 +389,11 @@ def TileMulIOp : AMX_Op<"tile_muli", [Pure, into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8" combinations (4 bytes packed into dwords in the columns of both the source operand tiles; the zero or sign extension is specified with - the attributes and default to sign extended). The operation is eventually - lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud" - instructions with the corresponding tile configuration. + the attributes and default to sign extended). + + The operation is eventually lowered into one of the "tdpbssd", + "tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding + tile configuration. 
Example: diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 68990ef0dc0c3..d9c097c9a3c6f 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -80,10 +80,22 @@ static SmallVector getTileSizes(Location loc, amx::TileType tType, LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)}; } +/// Returns stride expressed in number of bytes for the given `elementStride` +/// stride encoded in number of elements of the type `mType`. +static Value computeStrideInBytes(Location loc, MemRefType mType, + Value elementStride, RewriterBase &rewriter) { + Type llvmInt64Type = rewriter.getIntegerType(64); + unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() / 8; + auto attr = rewriter.getI64IntegerAttr(bytes); + Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); + return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, elementStride) + .getResult(); +} + /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer /// shape may "envelop" the actual tile shape, and may be dynamically sized. -static Value getStride(Location loc, MemRefType mType, Value base, - RewriterBase &rewriter) { +static Value inferStride(Location loc, MemRefType mType, Value base, + RewriterBase &rewriter) { assert(mType.getRank() >= 2 && "Invalid shape for AMX strides"); int64_t preLast = mType.getRank() - 2; Type llvmInt64Type = rewriter.getIntegerType(64); @@ -94,11 +106,8 @@ static Value getStride(Location loc, MemRefType mType, Value base, if (strides[preLast] == ShapedType::kDynamic) { // Dynamic stride needs code to compute the stride at runtime. MemRefDescriptor memrefDescriptor(base); - auto attr = rewriter.getI64IntegerAttr(bytes); - Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); - return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale, - memrefDescriptor.stride(rewriter, loc, preLast)) - .getResult(); + return computeStrideInBytes( + loc, mType, memrefDescriptor.stride(rewriter, loc, preLast), rewriter); } // Use direct constant for static stride. auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); @@ -117,21 +126,39 @@ amx::TileZeroOp::getIntrinsicOperands(ArrayRef operands, return getTileSizes(getLoc(), getTileType(), rewriter); } -LogicalResult amx::TileLoadOp::verify() { - MemRefType memrefTy = getMemRefType(); +template || + std::is_same_v>> +static LogicalResult tileTransferVerifier(OpTy op) { + MemRefType memrefTy = op.getMemRefType(); unsigned rank = memrefTy.getRank(); - if (rank < 2) - return emitOpError("requires at least 2D memref"); - if (getIndices().size() != rank) - return emitOpError("requires ") << rank << " indices"; - SmallVector strides; - int64_t offset; - if (failed(memrefTy.getStridesAndOffset(strides, offset)) || - strides.back() != 1) - return emitOpError("requires memref with unit innermost stride"); - return verifyTileSize(*this, getTileType()); + if (op.getIndices().size() != rank) + return op.emitOpError("requires ") << rank << " indices"; + + if (failed(verifyTileSize(op, op.getTileType()))) + return failure(); + + // Validate basic buffer properties when the stride is implicit. 
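+  // Without an explicit stride, the stride is inferred from the buffer's
+  // second innermost dimension, which requires the memref to be at least
+  // 2-D and to have a unit innermost stride.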
+ if (!op.getStride()) { + if (rank < 2) + return op.emitOpError("requires at least 2D memref"); + SmallVector strides; + int64_t offset; + if (failed(memrefTy.getStridesAndOffset(strides, offset)) || + strides.back() != 1) + return op.emitOpError("requires memref with unit innermost stride"); + } + + return success(); +} + +void amx::TileLoadOp::build(OpBuilder &builder, OperationState &state, Type res, + Value base, ValueRange indices) { + build(builder, state, res, base, indices, /*stride=*/nullptr); } +LogicalResult amx::TileLoadOp::verify() { return tileTransferVerifier(*this); } + SmallVector amx::TileLoadOp::getIntrinsicOperands(ArrayRef operands, const LLVMTypeConverter &typeConverter, @@ -144,27 +171,23 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef operands, intrinsicOperands.push_back( LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), adaptor.getBase(), adaptor.getIndices())); - intrinsicOperands.push_back( - getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + if (Value stride = adaptor.getStride()) + intrinsicOperands.push_back( + computeStrideInBytes(loc, getMemRefType(), stride, rewriter)); + else + intrinsicOperands.push_back( + inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); return intrinsicOperands; } -LogicalResult amx::TileStoreOp::verify() { - MemRefType memrefTy = getMemRefType(); - unsigned rank = memrefTy.getRank(); - if (rank < 2) - return emitOpError("requires at least 2D memref"); - if (getIndices().size() != rank) - return emitOpError("requires ") << rank << " indices"; - SmallVector strides; - int64_t offset; - if (failed(memrefTy.getStridesAndOffset(strides, offset)) || - strides.back() != 1) - return emitOpError("requires memref with unit innermost stride"); - return verifyTileSize(*this, getTileType()); +void amx::TileStoreOp::build(OpBuilder &builder, OperationState &state, + Value base, ValueRange indices, Value val) { + build(builder, state, base, indices, val, /*stride=*/nullptr); } +LogicalResult amx::TileStoreOp::verify() { return tileTransferVerifier(*this); } + SmallVector amx::TileStoreOp::getIntrinsicOperands(ArrayRef operands, const LLVMTypeConverter &typeConverter, @@ -177,8 +200,12 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef operands, intrinsicOperands.push_back( LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), adaptor.getBase(), adaptor.getIndices())); - intrinsicOperands.push_back( - getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); + if (Value stride = adaptor.getStride()) + intrinsicOperands.push_back( + computeStrideInBytes(loc, getMemRefType(), stride, rewriter)); + else + intrinsicOperands.push_back( + inferStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); intrinsicOperands.push_back(adaptor.getVal()); return intrinsicOperands; diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir index 7e562b00a46a9..a109f42e9dea3 100644 --- a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir @@ -60,30 +60,74 @@ func.func @mulfp16(%arg0: memref, %arg1: memref) { return } -// CHECK-LABEL: strides( -// CHECK: %[[CST_64_1:.+]] = llvm.mlir.constant(64 : i64) : i64 -// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_1]] -// CHECK: %[[CST_128_1:.+]] = llvm.mlir.constant(128 : i64) : i64 -// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_1]] -// CHECK: 
llvm.mlir.constant(2 : i64) : i64 +/// Intrinsics require stride in number of bytes. +// CHECK-LABEL: strides_implicit( +// CHECK: %[[LOAD_STRIDE_1:.+]] = llvm.mlir.constant(32 : i64) : i64 +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_1]] +// CHECK: %[[LOAD_STRIDE_2:.+]] = llvm.mlir.constant(128 : i64) : i64 +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_2]] // CHECK: llvm.extractvalue %{{.+}}[4, 0] -// CHECK: %[[STRIDE_1:.+]] = llvm.mul -// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_1]] -// CHECK: %[[CST_64_2:.+]] = llvm.mlir.constant(64 : i64) : i64 -// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_64_2]] -// CHECK: %[[CST_128_2:.+]] = llvm.mlir.constant(128 : i64) : i64 -// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[CST_128_2]] -// CHECK: llvm.mlir.constant(2 : i64) : i64 +// CHECK: %[[LOAD_BUF_STRIDE:.+]] = llvm.extractvalue %{{.+}}[4, 0] +// CHECK: %[[LOAD_STRIDE_SCALE:.+]] = llvm.mlir.constant(4 : i64) : i64 +// CHECK: %[[LOAD_STRIDE_3:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE]], %[[LOAD_BUF_STRIDE]] +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_3]] +// CHECK: %[[STORE_STRIDE_1:.+]] = llvm.mlir.constant(32 : i64) : i64 +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_1]] +// CHECK: %[[STORE_STRIDE_2:.+]] = llvm.mlir.constant(128 : i64) : i64 +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_2]] // CHECK: llvm.extractvalue %{{.+}}[4, 0] -// CHECK: %[[STRIDE_2:.+]] = llvm.mul -// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STRIDE_2]] -func.func @strides(%arg0: memref<16x32xbf16>, %arg1: memref<16x32xbf16, strided<[64, 1]>>, %arg2: memref<16x32xbf16, strided<[?, 1]>>) { +// CHECK: %[[STORE_BUF_STRIDE:.+]] = llvm.extractvalue %{{.+}}[4, 0] +// CHECK: %[[STORE_STRIDE_SCALE:.+]] = llvm.mlir.constant(4 : i64) : i64 +// CHECK: %[[STORE_STRIDE_3:.+]] = llvm.mul %[[STORE_STRIDE_SCALE]], %[[STORE_BUF_STRIDE]] +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_3]] +func.func @strides_implicit(%arg0: memref<16x32xi8>, + %arg1: memref<32x32xbf16, strided<[64, 1]>>, + %arg2: memref<16x32xf32, strided<[?, 1]>>) { %0 = arith.constant 0 : index - %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xbf16> into !amx.tile<16x32xbf16> - %2 = amx.tile_load %arg1[%0, %0] : memref<16x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16> - %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xbf16, strided<[?, 1]>> into !amx.tile<16x32xbf16> - amx.tile_store %arg0[%0, %0], %3 : memref<16x32xbf16>, !amx.tile<16x32xbf16> - amx.tile_store %arg1[%0, %0], %1 : memref<16x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16> - amx.tile_store %arg2[%0, %0], %2 : memref<16x32xbf16, strided<[?, 1]>>, !amx.tile<16x32xbf16> + %1 = amx.tile_load %arg0[%0, %0] : memref<16x32xi8> into !amx.tile<16x32xi8> + %2 = amx.tile_load %arg1[%0, %0] : memref<32x32xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16> + %3 = amx.tile_load %arg2[%0, %0] : memref<16x32xf32, strided<[?, 1]>> into !amx.tile<16x16xf32> + amx.tile_store %arg0[%0, %0], %1 : memref<16x32xi8>, !amx.tile<16x32xi8> + amx.tile_store %arg1[%0, %0], %2 : 
memref<32x32xbf16, strided<[64, 1]>>, !amx.tile<16x32xbf16> + amx.tile_store %arg2[%0, %0], %3 : memref<16x32xf32, strided<[?, 1]>>, !amx.tile<16x16xf32> + return +} + +/// Intrinsics require stride in number of bytes. +// CHECK-LABEL: strides_explicit( +// CHECK-SAME: %[[STRIDE:.+]]: index +// CHECK-DAG: %[[STRIDE_I64:.+]] = builtin.unrealized_conversion_cast %[[STRIDE]] : index to i64 +// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index +// CHECK-DAG: %[[C64_I64:.+]] = builtin.unrealized_conversion_cast %[[C64]] : index to i64 +// CHECK: %[[LOAD_STRIDE_SCALE_1:.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[LOAD_STRIDE_1:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE_1]], %[[STRIDE_I64]] +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_1]] +// CHECK: %[[LOAD_STRIDE_SCALE_2:.+]] = llvm.mlir.constant(2 : i64) : i64 +// CHECK: %[[LOAD_STRIDE_2:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE_2]], %[[STRIDE_I64]] +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_2]] +// CHECK: %[[LOAD_STRIDE_SCALE_3:.+]] = llvm.mlir.constant(4 : i64) : i64 +// CHECK: %[[LOAD_STRIDE_3:.+]] = llvm.mul %[[LOAD_STRIDE_SCALE_3]], %[[C64_I64]] +// CHECK: llvm.call_intrinsic "llvm.x86.tileloadd64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[LOAD_STRIDE_3]] +// CHECK: %[[STORE_STRIDE_SCALE_1:.+]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[STORE_STRIDE_1:.+]] = llvm.mul %[[STORE_STRIDE_SCALE_1]], %[[STRIDE_I64]] +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_1]] +// CHECK: %[[STORE_STRIDE_SCALE_2:.+]] = llvm.mlir.constant(2 : i64) : i64 +// CHECK: %[[STORE_STRIDE_2:.+]] = llvm.mul %[[STORE_STRIDE_SCALE_2]], %[[STRIDE_I64]] +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_2]] +// CHECK: %[[STORE_STRIDE_SCALE_3:.+]] = llvm.mlir.constant(4 : i64) : i64 +// CHECK: %[[STORE_STRIDE_3:.+]] = llvm.mul %[[STORE_STRIDE_SCALE_3]], %[[C64_I64]] +// CHECK: llvm.call_intrinsic "llvm.x86.tilestored64.internal"(%{{.+}}, %{{.+}}, %{{.+}}, %[[STORE_STRIDE_3]] +func.func @strides_explicit(%stride: index, + %arg0: memref, + %arg1: memref<16x32xbf16>, + %arg2: memref<32x32xf32, strided<[64, 1]>>) { + %0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %1 = amx.tile_load %arg0[%0], %stride : memref into !amx.tile<16x32xi8> + %2 = amx.tile_load %arg1[%0, %0], %stride : memref<16x32xbf16> into !amx.tile<16x32xbf16> + %3 = amx.tile_load %arg2[%0, %0], %c64 : memref<32x32xf32, strided<[64, 1]>> into !amx.tile<16x16xf32> + amx.tile_store %arg0[%0], %1, %stride : memref, !amx.tile<16x32xi8> + amx.tile_store %arg1[%0, %0], %2, %stride : memref<16x32xbf16>, !amx.tile<16x32xbf16> + amx.tile_store %arg2[%0, %0], %3, %c64 : memref<32x32xf32, strided<[64, 1]>>, !amx.tile<16x16xf32> return } diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir index 1b7f781ae173d..3d0f276df6a26 100644 --- a/mlir/test/Dialect/AMX/roundtrip.mlir +++ b/mlir/test/Dialect/AMX/roundtrip.mlir @@ -1,5 +1,33 @@ // RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s +// CHECK-LABEL: tloadstore +// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}], %{{.*}} : +// CHECK-SAME: memref into !amx.tile<16x32xbf16> +// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : +// CHECK-SAME: memref into !amx.tile<16x32xbf16> +// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : +// 
CHECK-SAME: memref> into !amx.tile<16x32xbf16> +// CHECK: amx.tile_store %{{.*}}[%{{.*}}], %[[z]], %{{.*}} : +// CHECK-SAME: memref, !amx.tile<16x32xbf16> +// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[x]], %{{.*}} : +// CHECK-SAME: memref, !amx.tile<16x32xbf16> +// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[y]] : +// CHECK-SAME: memref>, !amx.tile<16x32xbf16> +func.func @tloadstore(%stride: index, + %arg0: memref, + %arg1: memref, + %arg2: memref>) { + %0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %1 = amx.tile_load %arg0[%0], %stride : memref into !amx.tile<16x32xbf16> + %2 = amx.tile_load %arg1[%0, %0], %stride : memref into !amx.tile<16x32xbf16> + %3 = amx.tile_load %arg2[%0, %0] : memref> into !amx.tile<16x32xbf16> + amx.tile_store %arg0[%0], %3, %stride : memref, !amx.tile<16x32xbf16> + amx.tile_store %arg1[%0, %0], %1, %stride : memref, !amx.tile<16x32xbf16> + amx.tile_store %arg2[%0, %0], %2 : memref>, !amx.tile<16x32xbf16> + return +} + // CHECK-LABEL: tzero // CHECK: amx.tile_zero : !amx.tile<16x16xbf16> // CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref, !amx.tile<16x16xbf16> diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir index abdf2fe3bd534..160a9ced46e21 100644 --- a/mlir/test/Target/LLVMIR/amx.mlir +++ b/mlir/test/Target/LLVMIR/amx.mlir @@ -23,6 +23,19 @@ func.func @amx_tile_load_store(%base: memref, %out: memref, return } +// CHECK-LABEL: define void @amx_tile_load_store_strided +func.func @amx_tile_load_store_strided(%base: memref, %out: memref, + %idx: index, %stride: index) +{ + // CHECK: call x86_amx @llvm.x86.tileloadd64.internal + // CHECK: call void @llvm.x86.tilestored64.internal + %val = amx.tile_load %base[%idx], %stride + : memref into !amx.tile<16x64xi8> + amx.tile_store %out[%idx], %val, %stride + : memref, !amx.tile<16x64xi8> + return +} + // CHECK-LABEL: define void @amx_tile_mulf_bf16 func.func @amx_tile_mulf_bf16( %matA: memref, %matB: memref, %idx: index,
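For reference, a minimal end-to-end sketch of how the explicit-stride forms introduced above compose with a tile multiply; the function name, the flat bf16/f32 buffers, and the single shared stride are illustrative assumptions, and only the op syntax follows the patch:

```mlir
// Load bf16 tiles from flat 1-D buffers using an explicit element stride,
// multiply-accumulate them, and store the f32 result back with the same stride.
func.func @tile_matmul_strided_sketch(%a: memref<?xbf16>, %b: memref<?xbf16>,
                                      %c: memref<?xf32>, %stride: index) {
  %i0 = arith.constant 0 : index
  // Fresh accumulator tile.
  %acc = amx.tile_zero : !amx.tile<16x16xf32>
  // Explicit stride: each tile row starts %stride elements after the previous one.
  %ta = amx.tile_load %a[%i0], %stride : memref<?xbf16> into !amx.tile<16x32xbf16>
  %tb = amx.tile_load %b[%i0], %stride : memref<?xbf16> into !amx.tile<16x32xbf16>
  %res = amx.tile_mulf %ta, %tb, %acc
    : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
  amx.tile_store %c[%i0], %res, %stride : memref<?xf32>, !amx.tile<16x16xf32>
  return
}
```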