[mlir][ArmSME] Add rudimentary support for tile spills to the stack (llvm#76086)

This adds very basic (and inelegant) support for something like spilling and
reloading tiles, for when a function uses more SME tiles than physically exist.

This is purely implemented to prevent the compiler from aborting if a
function uses too many tiles (e.g. due to bad unrolling), but it is
expected to perform very poorly.

Currently, this works in two stages:

During tile allocation, if we run out of tiles, rather than giving up we
switch to allocating 'in-memory' tile IDs. These are tile IDs that start
at 16 (which is higher than any real tile ID). A warning is also
emitted for each (root) tile op assigned an in-memory tile ID:

```
warning: failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance
```
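A minimal sketch of that first stage follows. This is not the actual allocator code from the patch; `TileAllocatorState`, `takeFreeTile`, and `allocateTileId` are made-up names, and the real allocation also accounts for which tile type is required:

```cpp
#include <optional>

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h" // kInMemoryTileIdBase
#include "mlir/IR/Operation.h"

// Made-up bookkeeping for this sketch. Physical ZA tile IDs are all < 16, so
// any ID >= kInMemoryTileIdBase (16) unambiguously marks an "in-memory" tile.
struct TileAllocatorState {
  llvm::SmallVector<unsigned> freeTiles; // Unused physical tile IDs.
  unsigned nextInMemoryTileId = mlir::arm_sme::kInMemoryTileIdBase;

  std::optional<unsigned> takeFreeTile() {
    if (freeTiles.empty())
      return std::nullopt;
    return freeTiles.pop_back_val();
  }
};

unsigned allocateTileId(mlir::Operation *rootTileOp,
                        TileAllocatorState &state) {
  if (std::optional<unsigned> tileId = state.takeFreeTile())
    return *tileId; // A physical tile was still free.
  // Out of physical tiles: warn and hand out an in-memory tile ID instead of
  // aborting compilation. The op will later be lowered via a stack spill.
  rootTileOp->emitWarning(
      "failed to allocate SME virtual tile to operation, all tile operations "
      "will go through memory, expect degraded performance");
  return state.nextInMemoryTileId++;
}
```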

Everything after this works as normal until `-convert-arm-sme-to-llvm`.

Here the in-memory tile op:

```mlir
arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
```

is lowered to:

```mlir
// At function entry:
%alloca = memref.alloca ... : memref<?x?xty>

// Around the op:
// Swap the contents of %alloca and tile 0.
scf.for %slice_idx {
  %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
  "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx)  <{tile_id = 0 : i32}>
  vector.store %current_slice, %alloca[%slice_idx, %c0]
}
// Execute op using tile 0.
arm_sme.tile_op { tile_id = 0 }
// Swap the contents of %alloca and tile 0.
// This restores tile 0 to its original state.
scf.for %slice_idx {
  %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
  "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx)  <{tile_id = 0 : i32}>
  vector.store %current_slice, %alloca[%slice_idx, %c0]
}
```
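The two `scf.for` loops are genuine swaps, not just a store and a reload. A plain C++ model of one swap loop, with arrays standing in for ZA tile 0 and the `%alloca` buffer (purely illustrative, not from the patch), shows the mechanics:

```cpp
#include <array>
#include <cassert>

int main() {
  constexpr int numSlices = 4, sliceLen = 4;
  using Slice = std::array<int, sliceLen>;
  std::array<Slice, numSlices> tile{};   // Stands in for ZA tile 0.
  std::array<Slice, numSlices> buffer{}; // Stands in for the %alloca spill slot.
  for (int i = 0; i < numSlices; ++i) {
    tile[i].fill(100 + i);   // Live contents of tile 0.
    buffer[i].fill(200 + i); // Contents of the spilled (in-memory) tile.
  }

  // One swap loop: reading the current slice *before* overwriting it is what
  // makes this an exchange rather than a plain load.
  for (int sliceIdx = 0; sliceIdx < numSlices; ++sliceIdx) {
    Slice currentSlice = tile[sliceIdx]; // arm_sme.intr.read.horiz
    tile[sliceIdx] = buffer[sliceIdx];   // arm_sme.intr.ld1h.horiz
    buffer[sliceIdx] = currentSlice;     // vector.store
  }

  // Tile 0 now holds the spilled tile's data, and the buffer holds tile 0's
  // old data, so the second swap loop after the op restores everything.
  assert(tile[0][0] == 200 && buffer[0][0] == 100);
}
```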

This is inserted during the lowering to LLVM, as spilling/reloading
registers is a very low-level concept that can't really be modeled
correctly at a high level in MLIR.

Note: this always does the worst-case full-tile swap. It could be
optimized to only spill/load the data the tile op will actually use, which
could be just a single slice. It also makes no use of liveness, which could
allow reusing tiles. However, these improvements are not seen as important,
as correct code should only use the available number of tiles.
MacDue committed Jan 12, 2024
1 parent 8751bbe commit 5417a5f
Showing 8 changed files with 595 additions and 61 deletions.
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -25,8 +25,9 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"

namespace mlir::arm_sme {
static constexpr unsigned kInMemoryTileIdBase = 16;
#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
}
} // namespace mlir::arm_sme

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
31 changes: 31 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -97,6 +97,11 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
// This operation does not allocate a tile.
return std::nullopt;
}]
>,
InterfaceMethod<
"Returns the VectorType of the tile used by this operation.",
/*returnType=*/"VectorType",
/*methodName=*/"getTileType"
>
];

@@ -117,6 +122,11 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
rewriter.replaceOp($_op, newOp);
return newOp;
}

bool isInMemoryTile() {
auto tileId = getTileId();
return tileId && tileId.getInt() >= kInMemoryTileIdBase;
}
}];

let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
@@ -331,6 +341,9 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
return arm_sme::getSMETileType(getVectorType());
}
VectorType getTileType() {
return getVectorType();
}
}];
let assemblyFormat = "attr-dict `:` type($res)";
}
@@ -407,6 +420,9 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
return arm_sme::getSMETileType(getVectorType());
}
VectorType getTileType() {
return getVectorType();
}
}];

let builders = [
@@ -475,6 +491,9 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getValueToStore().getType());
}
VectorType getTileType() {
return getVectorType();
}
}];

let builders = [
@@ -539,6 +558,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
VectorType getTileType() {
return getVectorType();
}
}];

let assemblyFormat = [{
@@ -596,6 +618,9 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getTile().getType());
}
VectorType getTileType() {
return getVectorType();
}
}];

let assemblyFormat = [{
@@ -688,6 +713,9 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [

let extraClassDeclaration = [{
VectorType getSliceType() { return getResult().getType(); }
VectorType getTileType() {
return ::llvm::cast<VectorType>(getTile().getType());
}
}];

let assemblyFormat = [{
@@ -780,6 +808,9 @@ let arguments = (ins
return arm_sme::getSMETileType(getResultType());
return std::nullopt;
}
VectorType getTileType() {
return getResultType();
}
}];
}

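The diff above mostly threads a new `getTileType()` hook through each tile op, alongside the `isInMemoryTile()` helper on the interface. Below is a hedged usage sketch, not code from this patch (`collectSpillSlotTypes` is an invented name), of how a later lowering could use both to find spilled tiles and size their stack slots:

```cpp
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// For every tile op that was assigned an in-memory tile ID, derive the memref
// type of its spill slot from the tile's vector type, e.g.
// vector<[8]x[8]xi16> -> memref<?x?xi16> (scalable sizes, hence dynamic dims).
void collectSpillSlotTypes(func::FuncOp func,
                           SmallVectorImpl<MemRefType> &slotTypes) {
  func.walk([&](arm_sme::ArmSMETileOpInterface tileOp) {
    if (!tileOp.isInMemoryTile())
      return;
    VectorType tileType = tileOp.getTileType();
    slotTypes.push_back(
        MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
                        tileType.getElementType()));
  });
}
```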
