Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][ArmSME] Add basic lowering of vector.transfer_write to zero
This patch adds support for lowering a 'vector.transfer_write' of zeroes and type 'vector<[16x16]xi8>' to the SME 'zero {za}' instruction [1], which zeroes the entire accumulator, and then writing it out to memory with the 'str' instruction [2]. This contributes to supporting a path from 'linalg.fill' to SME. [1] https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles- [2] https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/STR--Store-vector-from-ZA-array- Reviewed By: awarzynski, dcaballe, WanderAway Differential Revision: https://reviews.llvm.org/D152508
- Loading branch information
Showing
11 changed files
with
379 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
//===- LowerVectorOps.cpp - Lower vector ops to SME -----------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements rewrite patterns to lower vector dialect ops to ArmSME. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" | ||
#include "mlir/Conversion/LLVMCommon/Pattern.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/ArmSME/IR/ArmSME.h" | ||
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" | ||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" | ||
#include "mlir/Dialect/SCF/IR/SCF.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
|
||
using namespace mlir; | ||
using namespace mlir::arm_sme; | ||
|
||
static constexpr unsigned kMinNumElts = 16; | ||
static constexpr unsigned kZeroZAMask = 255; | ||
|
||
/// Tells whether 'val' is a splat whose single repeated element is zero.
/// Only float and integer element types are recognised; for any other
/// 'elemType', or when 'val' is null or not a splat, the result is false.
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  const bool isFloatElem = llvm::isa<FloatType>(elemType);
  const bool isIntElem = !isFloatElem && llvm::isa<IntegerType>(elemType);

  // Reject unsupported element kinds, null attributes and non-splats up front.
  if (!(isFloatElem || isIntElem) || !val || !val.isSplat())
    return false;

  return isFloatElem ? val.getSplatValue<APFloat>().isZero()
                     : val.getSplatValue<APInt>().isZero();
}
|
||
namespace { | ||
/// Lower 'vector.transfer_write' op to 'arm_sme.intr.zero' op. Currently only | ||
/// supports 2d scalable vector type 'vector<[16x16]xi8>' that maps to the ZA0.B | ||
/// SME virtual tile. This will be extended to support more element types. | ||
struct TransferWriteToArmSMEZeroLowering | ||
: public ConvertOpToLLVMPattern<vector::TransferWriteOp> { | ||
using ConvertOpToLLVMPattern<vector::TransferWriteOp>::ConvertOpToLLVMPattern; | ||
|
||
LogicalResult | ||
matchAndRewrite(vector::TransferWriteOp write, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
auto vType = write.getVectorType(); | ||
if (vType.getRank() != 2) | ||
return failure(); | ||
if (vType.getShape() != ArrayRef<int64_t>({kMinNumElts, kMinNumElts})) | ||
return failure(); | ||
if (vType.getElementType() != rewriter.getI8Type()) | ||
return failure(); | ||
if (vType.getScalableDims().size() != 2) | ||
return failure(); | ||
|
||
auto memRefType = llvm::dyn_cast<MemRefType>(write.getSource().getType()); | ||
if (!memRefType) | ||
return failure(); | ||
|
||
auto constant = write.getVector().getDefiningOp<arith::ConstantOp>(); | ||
if (!constant) | ||
return failure(); | ||
|
||
auto denseAttr = dyn_cast<DenseElementsAttr>(constant.getValueAttr()); | ||
if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr)) | ||
return failure(); | ||
|
||
auto loc = write.getLoc(); | ||
|
||
// Create 'arm_sme.intr.zero' intrinsic to zero ZA. | ||
auto tile = rewriter.create<arith::ConstantOp>( | ||
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(kZeroZAMask)); | ||
rewriter.create<arm_sme::aarch64_sme_zero>(loc, tile); | ||
|
||
// Create loop that iterates from 0 to SVLB-1 inclusive (the number of | ||
// vectors in ZA) and stores each ZA vector to memory. | ||
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); | ||
auto minElems = rewriter.create<arith::ConstantIndexOp>(loc, kMinNumElts); | ||
auto vscale = | ||
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); | ||
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); | ||
auto upperBound = rewriter.create<arith::MulIOp>(loc, minElems, vscale); | ||
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); | ||
rewriter.setInsertionPointToStart(forOp.getBody()); | ||
|
||
// Create 'arm_sme.intr.str' intrinsic to store ZA vector. | ||
auto vnumI64 = rewriter.create<arith::IndexCastUIOp>( | ||
loc, rewriter.getI64Type(), forOp.getInductionVar()); | ||
auto offset = | ||
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0); | ||
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getSource(), | ||
ValueRange{vnumI64, offset}, rewriter); | ||
auto vnumI32 = rewriter.create<arith::IndexCastUIOp>( | ||
loc, rewriter.getI32Type(), forOp.getInductionVar()); | ||
rewriter.create<arm_sme::aarch64_sme_str>(loc, vnumI32, ptr); | ||
|
||
rewriter.eraseOp(write); | ||
|
||
return success(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
/// Adds the 'vector.transfer_write' -> ArmSME zero lowering pattern defined in | ||
/// this file to 'patterns'; the pattern is constructed with 'converter'. | ||
void mlir::arm_sme::populateVectorTransferLoweringPatterns( | ||
LLVMTypeConverter &converter, RewritePatternSet &patterns) { | ||
patterns.add<TransferWriteToArmSMEZeroLowering>(converter); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s | ||
|
||
// CHECK-LABEL: @transfer_write_2d_zero_i8 | ||
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>) | ||
// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> | ||
// CHECK: %[[C255:.*]] = arith.constant 255 : i32 | ||
// CHECK-NEXT: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> () | ||
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index | ||
// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index | ||
// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 | ||
// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index | ||
// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index | ||
// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index | ||
// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { | ||
// CHECK-NEXT: %[[VNUM_I64:.*]] = arith.index_castui %[[VNUM]] : index to i64 | ||
// CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64 | ||
// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> | ||
// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> | ||
// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]] : i64 | ||
// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]] : i64 | ||
// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 | ||
// CHECK-NEXT: %[[VNUM_I32:.*]] = arith.index_castui %[[VNUM]] : index to i32 | ||
// CHECK-NEXT: "arm_sme.intr.str"(%[[VNUM_I32]], %[[GEP]]) : (i32, !llvm.ptr) -> () | ||
// Positive case: an all-zero vector<[16]x[16]xi8> constant is written to a | ||
// 2-D i8 memref at indices [0, 0]. | ||
func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) { | ||
%c0 = arith.constant 0 : index | ||
%cst = arith.constant dense<0> : vector<[16]x[16]xi8> | ||
vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8> | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero' | ||
// lowering only occurs for vector types of correct rank, shape, element size | ||
// and number of scalable dims. | ||
|
||
// CHECK-LABEL: @transfer_write_2d_zero__bad_type | ||
// CHECK: vector.transfer_write | ||
// CHECK-NOT: arm_sme.intr.zero | ||
// Negative case: the element type is i4 rather than i8. | ||
func.func @transfer_write_2d_zero__bad_type(%arg0 : memref<?x?xi4>) { | ||
%c0 = arith.constant 0 : index | ||
%cst = arith.constant dense<0> : vector<[16]x[16]xi4> | ||
vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi4>, memref<?x?xi4> | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: @transfer_write_2d_zero__bad_shape | ||
// CHECK: vector.transfer_write | ||
// CHECK-NOT: arm_sme.intr.zero | ||
// Negative case: the vector shape is [8]x[8] rather than [16]x[16]. | ||
func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref<?x?xi8>) { | ||
%c0 = arith.constant 0 : index | ||
%cst = arith.constant dense<0> : vector<[8]x[8]xi8> | ||
vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xi8>, memref<?x?xi8> | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: @transfer_write_2d_zero__bad_rank | ||
// CHECK: vector.transfer_write | ||
// CHECK-NOT: arm_sme.intr.zero | ||
// Negative case: the vector is rank 3 rather than rank 2. | ||
func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref<?x?x?xi8>) { | ||
%c0 = arith.constant 0 : index | ||
%cst = arith.constant dense<0> : vector<[16]x[16]x[16]xi8> | ||
vector.transfer_write %cst, %arg0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<[16]x[16]x[16]xi8>, memref<?x?x?xi8> | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: @transfer_write_2d_zero__non_memref_type | ||
// CHECK: vector.transfer_write | ||
// CHECK-NOT: arm_sme.intr.zero | ||
// Negative case: the destination is a tensor rather than a memref. | ||
func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> tensor<?x?xi8> { | ||
%c0 = arith.constant 0 : index | ||
%cst = arith.constant dense<0> : vector<[16]x[16]xi8> | ||
%0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8> | ||
return %0 : tensor<?x?xi8> | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: @transfer_write_2d_zero__non_zero_value | ||
// CHECK: vector.transfer_write | ||
// CHECK-NOT: arm_sme.intr.zero | ||
// Negative case: the written constant is a splat of 1, not 0. | ||
func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref<?x?xi8>) { | ||
%c0 = arith.constant 0 : index | ||
%cst = arith.constant dense<1> : vector<[16]x[16]xi8> | ||
vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8> | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: @transfer_write_2d_zero__vec_unknown_defining_op | ||
// CHECK: vector.transfer_write | ||
// CHECK-NOT: arm_sme.intr.zero | ||
// Negative case: the written vector is a function argument, so it is not | ||
// defined by a constant op. | ||
func.func @transfer_write_2d_zero__vec_unknown_defining_op(%arg0 : memref<?x?xi8>, %arg1 : vector<[16]x[16]xi8>) { | ||
%c0 = arith.constant 0 : index | ||
vector.transfer_write %arg1, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8> | ||
return | ||
} |
Oops, something went wrong.