[mlir][ArmSME] Add basic lowering of vector.transfer_write to zero
This patch adds support for lowering a 'vector.transfer_write' of zeroes
of type 'vector<[16]x[16]xi8>' to the SME 'zero {za}' instruction [1],
which zeroes the entire accumulator, and then writes it out to memory
with the 'str' instruction [2].

This contributes to supporting a path from 'linalg.fill' to SME.

[1] https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
[2] https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/STR--Store-vector-from-ZA-array-
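
As a rough sketch (simplified from the test added in this patch; SSA names
are illustrative), the lowering rewrites

    %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
    vector.transfer_write %cst, %mem[%c0, %c0] {in_bounds = [true, true]}
        : vector<[16]x[16]xi8>, memref<?x?xi8>

into, schematically:

    %mask = arith.constant 255 : i32
    "arm_sme.intr.zero"(%mask) : (i32) -> ()  // zero all of ZA
    scf.for %vnum = %c0 to %num_za_vectors step %c1 {
      // ... compute address %ptr of tile slice %vnum in %mem ...
      "arm_sme.intr.str"(%vnum_i32, %ptr) : (i32, !llvm.ptr) -> ()
    }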

Reviewed By: awarzynski, dcaballe, WanderAway

Differential Revision: https://reviews.llvm.org/D152508
c-rhodes committed Jul 3, 2023
1 parent 39b5a02 commit 564713c
Showing 11 changed files with 379 additions and 1 deletion.
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_ARMSME_IR_ARMSME_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -33,6 +33,7 @@ def ArmSME_Dialect : Dialect {
https://developer.arm.com/documentation/ddi0616
https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
}];
let dependentDialects = ["scf::SCFDialect"];
}

//===----------------------------------------------------------------------===//
@@ -119,6 +120,11 @@ def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">;
def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;

def LLVM_aarch64_sme_str
: ArmSME_IntrOp<"str">,
Arguments<(ins Arg<I32, "Index">,
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;

def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;

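In IR, the new 'str' intrinsic op takes the tile-slice (vector) number and a
destination pointer; e.g., as generated by the lowering added below (SSA names
illustrative):

    "arm_sme.intr.str"(%vnum_i32, %gep) : (i32, !llvm.ptr) -> ()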
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -15,6 +15,11 @@ class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;

namespace arm_sme {
void populateVectorTransferLoweringPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace arm_sme

/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM
/// intrinsics.
void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
@@ -109,6 +109,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
if (armSME) {
configureArmSMELegalizeForExportTarget(target);
populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
arm_sme::populateVectorTransferLoweringPatterns(converter, patterns);
}
if (amx) {
configureAMXLegalizeForExportTarget(target);
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
Expand Up @@ -10,5 +10,6 @@ add_mlir_dialect_library(MLIRArmSMEDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRLLVMDialect
MLIRSCFDialect
MLIRSideEffectInterfaces
)
3 changes: 3 additions & 0 deletions mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRArmSMETransforms
EnableArmStreaming.cpp
LegalizeForLLVMExport.cpp
LowerVectorOps.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -12,5 +13,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
MLIRArmSMEDialect
MLIRFuncDialect
MLIRLLVMCommonConversion
MLIRVectorDialect
MLIRSCFDialect
MLIRPass
)
4 changes: 3 additions & 1 deletion mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;
using namespace mlir::arm_sme;
@@ -51,7 +52,8 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(

void mlir::configureArmSMELegalizeForExportTarget(
LLVMConversionTarget &target) {
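// The vector.transfer_write lowering emits 'scf.for'/'scf.yield' and the
// 'zero'/'str' intrinsics, so these must be marked legal for the conversion.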
target.addLegalOp<arm_sme::aarch64_sme_za_enable,
target.addLegalOp<scf::ForOp, scf::YieldOp, arm_sme::aarch64_sme_zero,
arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_za_enable,
arm_sme::aarch64_sme_za_disable>();

// Mark 'func.func' ops as legal if either:
111 changes: 111 additions & 0 deletions mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
@@ -0,0 +1,111 @@
//===- LowerVectorOps.cpp - Lower vector ops to SME -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrite patterns to lower vector dialect ops to ArmSME.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::arm_sme;

// Minimum (vscale = 1) number of i8 elements in each dimension of the ZA0.B
// tile; both dimensions scale with vscale at runtime.
static constexpr unsigned kMinNumElts = 16;
// 8-bit mask with all bits set: selects all eight 64-bit ZA tiles
// (ZA0.D-ZA7.D), so the 'zero' intrinsic zeroes the entire ZA array.
static constexpr unsigned kZeroZAMask = 255;

/// Returns true if 'val' is a splat of zero, false otherwise.
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
if (llvm::isa<FloatType>(elemType))
return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
if (llvm::isa<IntegerType>(elemType))
return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
return false;
}
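
For example (as exercised by the tests added below), this predicate accepts the
zero splat

    %cst = arith.constant dense<0> : vector<[16]x[16]xi8>

but rejects 'dense<1>'. The float branch is not yet reachable, since the
pattern below only matches i8 element types; it anticipates future element
types.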

namespace {
/// Lower a 'vector.transfer_write' of a zero splat to an 'arm_sme.intr.zero'
/// op followed by a loop of 'arm_sme.intr.str' ops. Currently only supports the
/// 2-D scalable vector type 'vector<[16]x[16]xi8>' that maps to the ZA0.B SME
/// virtual tile. This will be extended to support more element types.
struct TransferWriteToArmSMEZeroLowering
: public ConvertOpToLLVMPattern<vector::TransferWriteOp> {
using ConvertOpToLLVMPattern<vector::TransferWriteOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(vector::TransferWriteOp write, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto vType = write.getVectorType();
if (vType.getRank() != 2)
return failure();
if (vType.getShape() != ArrayRef<int64_t>({kMinNumElts, kMinNumElts}))
return failure();
if (vType.getElementType() != rewriter.getI8Type())
return failure();
// Both dimensions must be scalable.
if (!llvm::all_of(vType.getScalableDims(), [](bool dim) { return dim; }))
return failure();

auto memRefType = llvm::dyn_cast<MemRefType>(write.getSource().getType());
if (!memRefType)
return failure();

auto constant = write.getVector().getDefiningOp<arith::ConstantOp>();
if (!constant)
return failure();

auto denseAttr = dyn_cast<DenseElementsAttr>(constant.getValueAttr());
if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
return failure();

auto loc = write.getLoc();

// Create 'arm_sme.intr.zero' intrinsic to zero ZA.
auto tile = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(kZeroZAMask));
rewriter.create<arm_sme::aarch64_sme_zero>(loc, tile);

// Create loop that iterates from 0 to SVLB-1 inclusive (the number of
// vectors in ZA) and stores each ZA vector to memory.
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto minElems = rewriter.create<arith::ConstantIndexOp>(loc, kMinNumElts);
auto vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto upperBound = rewriter.create<arith::MulIOp>(loc, minElems, vscale);
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
rewriter.setInsertionPointToStart(forOp.getBody());

// Create 'arm_sme.intr.str' intrinsic to store ZA vector.
auto vnumI64 = rewriter.create<arith::IndexCastUIOp>(
loc, rewriter.getI64Type(), forOp.getInductionVar());
auto offset =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getSource(),
ValueRange{vnumI64, offset}, rewriter);
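// 'ptr' now addresses element (vnum, 0) of the memref, i.e. the start of
// tile slice 'vnum': alignedBase + vnum * stride0.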
auto vnumI32 = rewriter.create<arith::IndexCastUIOp>(
loc, rewriter.getI32Type(), forOp.getInductionVar());
rewriter.create<arm_sme::aarch64_sme_str>(loc, vnumI32, ptr);

rewriter.eraseOp(write);

return success();
}
};
} // namespace

void mlir::arm_sme::populateVectorTransferLoweringPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<TransferWriteToArmSMEZeroLowering>(converter);
}
104 changes: 104 additions & 0 deletions mlir/test/Dialect/ArmSME/vector-ops.mlir
@@ -0,0 +1,104 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s

// CHECK-LABEL: @transfer_write_2d_zero_i8
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[C255:.*]] = arith.constant 255 : i32
// CHECK-NEXT: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index
// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index
// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] {
// CHECK-NEXT: %[[VNUM_I64:.*]] = arith.index_castui %[[VNUM]] : index to i64
// CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]] : i64
// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]] : i64
// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
// CHECK-NEXT: %[[VNUM_I32:.*]] = arith.index_castui %[[VNUM]] : index to i32
// CHECK-NEXT: "arm_sme.intr.str"(%[[VNUM_I32]], %[[GEP]]) : (i32, !llvm.ptr) -> ()
func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0> : vector<[16]x[16]xi8>
vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
return
}

// -----

// The following tests check that the 'vector.transfer_write' ->
// 'arm_sme.intr.zero' lowering only occurs for vector types with the correct
// rank, shape, element type and number of scalable dims.

// CHECK-LABEL: @transfer_write_2d_zero__bad_type
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.intr.zero
func.func @transfer_write_2d_zero__bad_type(%arg0 : memref<?x?xi4>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0> : vector<[16]x[16]xi4>
vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi4>, memref<?x?xi4>
return
}

// -----

// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.intr.zero
func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0> : vector<[8]x[8]xi8>
vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xi8>, memref<?x?xi8>
return
}

// -----

// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.intr.zero
func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref<?x?x?xi8>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0> : vector<[16]x[16]x[16]xi8>
vector.transfer_write %cst, %arg0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<[16]x[16]x[16]xi8>, memref<?x?x?xi8>
return
}

// -----

// CHECK-LABEL: @transfer_write_2d_zero__non_memref_type
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.intr.zero
func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> tensor<?x?xi8> {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0> : vector<[16]x[16]xi8>
%0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
return %0 : tensor<?x?xi8>
}

// -----

// CHECK-LABEL: @transfer_write_2d_zero__non_zero_value
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.intr.zero
func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<1> : vector<[16]x[16]xi8>
vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
return
}

// -----

// CHECK-LABEL: @transfer_write_2d_zero__vec_unknown_defining_op
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.intr.zero
func.func @transfer_write_2d_zero__vec_unknown_defining_op(%arg0 : memref<?x?xi8>, %arg1 : vector<[16]x[16]xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
return
}
