Skip to content

Commit

Permalink
[mlir][ArmSME] Add tile load op and extend tile store tile size support
Browse files Browse the repository at this point in the history
This extends the existing 'arm_sme.tile_store' op to support all tile
sizes and adds a new op 'arm_sme.tile_load', as well as lowerings from
vector -> custom ops and custom ops -> intrinsics. Currently there's no
lowering for i128.

Depends on D154867

Reviewed By: awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D155306
  • Loading branch information
c-rhodes committed Jul 25, 2023
1 parent cee4494 commit ca9a335
Show file tree
Hide file tree
Showing 12 changed files with 1,069 additions and 66 deletions.
69 changes: 61 additions & 8 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
Original file line number Diff line number Diff line change
Expand Up @@ -224,21 +224,74 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
let assemblyFormat = "attr-dict `:` type($res)";
}

def TileLoadOp : ArmSME_Op<"tile_load"> {
let summary = "Tile load operation";
let description = [{
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
with the shape defined by the 2D scalable vector type of the result tile.
The slice of memory must be contiguous. The memref must be either rank 1 or
rank 2 with dynamic dimensions, since the operation is scalable, and the
element type must be a scalar that matches the element type of the result.

Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
```

Example 2: Load a FP 32-bit element ZA tile from memory.
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
```

Example 3: Load a 128-bit element ZA tile from memory.
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices);
let results = (outs SMETile:$result);

let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
}];

let assemblyFormat =
"$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
}

def TileStoreOp : ArmSME_Op<"tile_store"> {
let summary = "Tile store operation";
let description = [{
Store a 2D SME "virtual tile" to memory.

NOTE: At the moment it is assumed that the element type is `i8` and that
there's only one "virtual tile".
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
with the shape defined by the 2D scalable vector type of the tile being
stored. The slice of memory must be contiguous. The memref must be either
rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
and the element type must be a scalar that matches the element type of the
result.

Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
```

Example:
Example 2: Store a FP 32-bit element ZA tile to memory.
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
```

Example 3: Store a 128-bit element ZA tile to memory.
```mlir
arm_sme.tile_store %0, %arg0[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
```
}];
let arguments = (ins nxnxv16i8:$valueToStore,
let arguments = (ins SMETile:$valueToStore,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices);
let extraClassDeclaration = [{
Expand Down Expand Up @@ -304,7 +357,7 @@ def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
class ArmSME_IntrLoadOp<string mnemonic>
: ArmSME_IntrOp<mnemonic>,
Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
Arg<LLVM_AnyPointer, "Load address", [MemRead]>,
Arg<LLVM_AnyPointer, "Load address">,
Arg<I32, "Virtual tile ID">,
Arg<I32, "Tile slice">)>;

Expand Down
38 changes: 38 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===- Utils.h - General ArmSME transformation utilities --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines prototypes for various utilities for the ArmSME
// dialect. These are not passes by themselves but are used either by passes,
// optimization sequences, or in turn by other transformation utilities.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_

#include "mlir/Dialect/ArmSME/IR/ArmSME.h"

namespace mlir {
namespace arm_sme {

/// Return minimum number of elements for the given element `type` in
/// a vector of SVL bits.
unsigned getSMETileSliceMinNumElts(Type type);

/// Returns true if `type` is a valid element type for an SME tile or false
/// otherwise.
bool isValidSMETileElementType(Type type);

/// Returns true if `vType` is a valid vector type for an SME tile or false
/// otherwise.
bool isValidSMETileVectorType(VectorType vType);

} // namespace arm_sme
} // namespace mlir

#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
1 change: 1 addition & 0 deletions mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ add_mlir_conversion_library(MLIRVectorToArmSME

LINK_LIBS PUBLIC
MLIRArmSMEDialect
MLIRArmSMEUtils
MLIRLLVMCommonConversion
)
36 changes: 35 additions & 1 deletion mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"

#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"

Expand Down Expand Up @@ -76,9 +77,42 @@ struct TransferWriteToArmSMELowering
}
};

/// Conversion pattern for vector.load.
struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::LoadOp load,
PatternRewriter &rewriter) const override {
if (!arm_sme::isValidSMETileVectorType(load.getVectorType()))
return failure();

rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
load, load.getVectorType(), load.getBase(), load.getIndices());

return success();
}
};

/// Conversion pattern for vector.store.
struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

LogicalResult matchAndRewrite(vector::StoreOp store,
PatternRewriter &rewriter) const override {
if (!arm_sme::isValidSMETileVectorType(store.getVectorType()))
return failure();

rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
store, store.getValueToStore(), store.getBase(), store.getIndices());

return success();
}
};

} // namespace

void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
patterns.add<TransferWriteToArmSMELowering>(&ctx);
patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering>(&ctx);
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSME/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(Utils)
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms

LINK_LIBS PUBLIC
MLIRArmSMEDialect
MLIRArmSMEUtils
MLIRFuncDialect
MLIRLLVMCommonConversion
MLIRVectorDialect
Expand Down

0 comments on commit ca9a335

Please sign in to comment.