Skip to content

Commit

Permalink
[mlir][ArmSME] Introduce custom TypeConverter for ArmSME
Browse files Browse the repository at this point in the history
At the moment, SME-to-LLVM lowerings rely entirely on
`LLVMTypeConverter`. This patch introduces a dedicated `TypeConverter`
that inherits from `LLVMTypeConverter` (it will also be used when
lowering ArmSME Ops to LLVM).

The new type converter merely disables lowerings for `VectorType` to
prevent 2-d scalable vectors (common in the context of ArmSME), e.g.

   `vector<[16]x[16]xi8>`,

entering the LLVM Type converter. LLVM does not support arrays of
scalable vectors and hence the need for specialisation. In the case of
SME such types are effectively eliminated when emitting LLVM IR
intrinsics for SME.

Differential Revision: https://reviews.llvm.org/D155365
  • Loading branch information
banach-space committed Jul 18, 2023
1 parent e65cabb commit 3fa5ee6
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 1 deletion.
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@
#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H
#define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Pass/Pass.h"

namespace mlir {

class RewritePatternSet;

namespace arm_sme {
//===----------------------------------------------------------------------===//
// The EnableArmStreaming pass.
//===----------------------------------------------------------------------===//
// Options for Armv9 Streaming SVE mode. By default, streaming-mode is part of
// the function interface (ABI) and the caller manages PSTATE.SM on entry/exit.
// In a locally streaming function PSTATE.SM is kept internal and the callee
Expand All @@ -33,6 +37,14 @@ createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default,
/// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
std::unique_ptr<Pass> createTileAllocationPass();

//===----------------------------------------------------------------------===//
// Type ArmSMETypeConverter pass.
//===----------------------------------------------------------------------===//
class ArmSMETypeConverter : public LLVMTypeConverter {
public:
ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
};

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRVectorToLLVM
MLIRArmNeonDialect
MLIRArmSMEDialect
MLIRArmSMETransforms
MLIRVectorToArmSME
MLIRArmSVEDialect
MLIRArmSVETransforms
MLIRAMXDialect
Expand Down
5 changes: 4 additions & 1 deletion mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
Expand Down Expand Up @@ -96,6 +97,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options);

if (armNeon) {
// TODO: we may or may not want to include in-dialect lowering to
// LLVM-compatible operations here. So far, all operations in the dialect
Expand All @@ -108,7 +111,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
}
if (armSME) {
configureArmSMELegalizeForExportTarget(target);
populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns);
}
if (amx) {
configureAMXLegalizeForExportTarget(target);
Expand Down
22 changes: 22 additions & 0 deletions mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===- ArmSMETypeConverter.cpp - Convert builtin to LLVM dialect types ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/ArmSME/Transforms/Passes.h"

using namespace mlir;
arm_sme::ArmSMETypeConverter::ArmSMETypeConverter(
MLIRContext *ctx, const LowerToLLVMOptions &options)
: LLVMTypeConverter(ctx, options) {
// Disable LLVM type conversion for vectors. This is to prevent 2-d scalable
// vectors (common in the context of ArmSME), e.g.
// `vector<[16]x[16]xi8>`,
// entering the LLVM Type converter. LLVM does not support arrays of scalable
// vectors, but in the case of SME such types are effectively eliminated when
// emitting ArmSME LLVM IR intrinsics.
addConversion([&](VectorType type) { return type; });
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRArmSMETransforms
ArmSMETypeConverter.cpp
EnableArmStreaming.cpp
LegalizeForLLVMExport.cpp
TileAllocation.cpp
Expand Down

0 comments on commit 3fa5ee6

Please sign in to comment.