diff --git a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h new file mode 100644 index 0000000000000..bd65970d5bf77 --- /dev/null +++ b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h @@ -0,0 +1,29 @@ +//===- ConvertAVX512ToLLVM.h - Conversion Patterns from AVX512 to LLVM ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_EDGE_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ +#define MLIR_EDGE_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ + +#include + +namespace mlir { +class LLVMTypeConverter; +class ModuleOp; +template class OpPassBase; +class OwningRewritePatternList; + +/// Collect a set of patterns to convert from the AVX512 dialect to LLVM. +void populateAVX512ToLLVMConversionPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns); + +/// Create a pass to convert AVX512 operations to the LLVMIR dialect. +std::unique_ptr> createConvertAVX512ToLLVMPass(); + +} // namespace mlir + +#endif // MLIR_EDGE_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512.td b/mlir/include/mlir/Dialect/AVX512/AVX512.td new file mode 100644 index 0000000000000..917af2e1cc04b --- /dev/null +++ b/mlir/include/mlir/Dialect/AVX512/AVX512.td @@ -0,0 +1,99 @@ +//===-- AVX512Ops.td - AVX512 dialect operation definitions *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the AVX512 dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef AVX512_OPS +#define AVX512_OPS + +include "mlir/Interfaces/SideEffects.td" + +//===----------------------------------------------------------------------===// +// AVX512 dialect definition +//===----------------------------------------------------------------------===// + +def AVX512_Dialect : Dialect { + let name = "avx512"; + let cppNamespace = "avx512"; +} + +//===----------------------------------------------------------------------===// +// AVX512 op definitions +//===----------------------------------------------------------------------===// + +class AVX512_Op traits = []> : + Op {} + +def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect, + AllTypesMatch<["src", "a", "dst"]>, + TypesMatchWith<"imm has the same number of bits as elements in dst", + "dst", "imm", + "IntegerType::get(($_self.cast().getShape()[0])," + " $_self.getContext())">]> { + let summary = "Masked roundscale op"; + let description = [{ + The mask.rndscale op is an AVX512 specific op that can lower to the proper + LLVMAVX512 operation: `llvm.mask.rndscale.ps.512` or + `llvm.mask.rndscale.pd.512` instruction depending on the type of vectors it + is applied to. + + From the Intel Intrinsics Guide: + ================================ + Round packed floating-point elements in `a` to the number of fraction bits + specified by `imm`, and store the results in `dst` using writemask `k` + (elements are copied from src when the corresponding mask bit is not set). + }]; + // Supports vector<16xf32> and vector<8xf64>. + let arguments = (ins VectorOfLengthAndType<[16, 8], [F32, F64]>:$src, + I32:$k, + VectorOfLengthAndType<[16, 8], [F32, F64]>:$a, + AnyTypeOf<[I16, I8]>:$imm, + // TODO(ntv): figure rounding out (optional operand?). + I32:$rounding + ); + let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst); + let assemblyFormat = + "$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)"; +} + +def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect, + AllTypesMatch<["src", "a", "b", "dst"]>, + TypesMatchWith<"k has the same number of bits as elements in dst", + "dst", "k", + "IntegerType::get(($_self.cast().getShape()[0])," + " $_self.getContext())">]> { + let summary = "ScaleF op"; + let description = [{ + The `mask.scalef` op is an AVX512 specific op that can lower to the proper + LLVMAVX512 operation: `llvm.mask.scalef.ps.512` or + `llvm.mask.scalef.pd.512` depending on the type of MLIR vectors it is + applied to. + + From the Intel Intrinsics Guide: + ================================ + Scale the packed floating-point elements in `a` using values from `b`, and + store the results in `dst` using writemask `k` (elements are copied from src + when the corresponding mask bit is not set). + }]; + // Supports vector<16xf32> and vector<8xf64>. + let arguments = (ins VectorOfLengthAndType<[16, 8], [F32, F64]>:$src, + VectorOfLengthAndType<[16, 8], [F32, F64]>:$a, + VectorOfLengthAndType<[16, 8], [F32, F64]>:$b, + AnyTypeOf<[I16, I8]>:$k, + // TODO(ntv): figure rounding out (optional operand?). + I32:$rounding + ); + let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst); + // Fully specified by traits. + let assemblyFormat = + "$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)"; +} + +#endif // AVX512_OPS diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h new file mode 100644 index 0000000000000..aeec2b728a113 --- /dev/null +++ b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h @@ -0,0 +1,31 @@ +//===- AVX512Dialect.h - MLIR Dialect for AVX512 ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Target dialect for AVX512 in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AVX512_AVX512DIALECT_H_ +#define MLIR_DIALECT_AVX512_AVX512DIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffects.h" + +namespace mlir { +namespace avx512 { + +#define GET_OP_CLASSES +#include "mlir/Dialect/AVX512/AVX512.h.inc" + +#include "mlir/Dialect/AVX512/AVX512Dialect.h.inc" + +} // namespace avx512 +} // namespace mlir + +#endif // MLIR_DIALECT_AVX512_AVX512DIALECT_H_ diff --git a/mlir/include/mlir/Dialect/AVX512/CMakeLists.txt b/mlir/include/mlir/Dialect/AVX512/CMakeLists.txt new file mode 100644 index 0000000000000..5868760077a68 --- /dev/null +++ b/mlir/include/mlir/Dialect/AVX512/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_dialect(AVX512 avx512 AVX512Doc) diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt index 27cbe93783469..32b24264ba69b 100644 --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(AffineOps) +add_subdirectory(AVX512) add_subdirectory(FxpMathOps) add_subdirectory(GPU) add_subdirectory(Linalg) diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt index 796b4a68a2b14..99f55c2fb0ef2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -20,3 +20,9 @@ add_public_tablegen_target(MLIRNVVMConversionsIncGen) set(LLVM_TARGET_DEFINITIONS ROCDLOps.td) mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions) add_public_tablegen_target(MLIRROCDLConversionsIncGen) + +add_mlir_dialect(LLVMAVX512 llvm_avx512 LLVMAVX512Doc) + +set(LLVM_TARGET_DEFINITIONS LLVMAVX512.td) +mlir_tablegen(LLVMAVX512Conversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRLLVMAVX512ConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td new file mode 100644 index 0000000000000..12668c4da41be --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td @@ -0,0 +1,52 @@ +//===-- LLVMAVX512.td - LLVMAVX512 dialect op definitions --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the LLVMAVX512 dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVMIR_AVX512_OPS +#define LLVMIR_AVX512_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// LLVMAVX512 dialect definition +//===----------------------------------------------------------------------===// + +def LLVMAVX512_Dialect : Dialect { + let name = "llvm_avx512"; + let cppNamespace = "LLVM"; +} + +//----------------------------------------------------------------------------// +// MLIR LLVM AVX512 intrinsics using the MLIR LLVM Dialect type system +//----------------------------------------------------------------------------// + +class LLVMAVX512_IntrOp traits = []> : + LLVM_IntrOpBase; + +def LLVM_x86_avx512_mask_rndscale_ps_512 : + LLVMAVX512_IntrOp<"mask.rndscale.ps.512">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_x86_avx512_mask_rndscale_pd_512 : + LLVMAVX512_IntrOp<"mask.rndscale.pd.512">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_x86_avx512_mask_scalef_ps_512 : + LLVMAVX512_IntrOp<"mask.scalef.ps.512">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_x86_avx512_mask_scalef_pd_512 : + LLVMAVX512_IntrOp<"mask.scalef.pd.512">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; + +#endif // AVX512_OPS diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h new file mode 100644 index 0000000000000..27b98fd189107 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h @@ -0,0 +1,30 @@ +//===- LLVMAVX512Dialect.h - MLIR Dialect for LLVMAVX512 --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Target dialect for LLVMAVX512 in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_LLVMAVX512DIALECT_H_ +#define MLIR_DIALECT_LLVMIR_LLVMAVX512DIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace LLVM { + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMAVX512.h.inc" + +#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h.inc" + +} // namespace LLVM +} // namespace mlir + +#endif // MLIR_DIALECT_LLVMIR_LLVMAVX512DIALECT_H_ diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index dbd7b0b0b9828..74b42243f2360 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -510,8 +510,9 @@ class OpAsmParser { ArrayRef(type), loc, result); } template - ParseResult resolveOperands(Operands &&operands, Types &&types, - llvm::SMLoc loc, SmallVectorImpl &result) { + std::enable_if_t::value, ParseResult> + resolveOperands(Operands &&operands, Types &&types, llvm::SMLoc loc, + SmallVectorImpl &result) { size_t operandSize = std::distance(operands.begin(), operands.end()); size_t typeSize = std::distance(types.begin(), types.end()); if (operandSize != typeSize) diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index c0a7ca04081f8..9a14a1586c7ff 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -14,9 +14,11 @@ #ifndef MLIR_INITALLDIALECTS_H_ #define MLIR_INITALLDIALECTS_H_ +#include "mlir/Dialect/AVX512/AVX512Dialect.h" #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/FxpMathOps/FxpMathOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -38,8 +40,10 @@ namespace mlir { inline void registerAllDialects() { static bool init_once = []() { registerDialect(); + registerDialect(); registerDialect(); registerDialect(); + registerDialect(); registerDialect(); registerDialect(); registerDialect(); diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h index b358cfa8802ed..c1cac45816df7 100644 --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -15,6 +15,7 @@ #define MLIR_INITALLPASSES_H_ #include "mlir/Analysis/Passes.h" +#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" #include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" @@ -78,6 +79,9 @@ inline void registerAllPasses() { createSymbolDCEPass(); createLocationSnapshotPass({}); + // AVX512 + createConvertAVX512ToLLVMPass(); + // GPUtoRODCLPass createLowerGpuOpsToROCDLOpsPass(); diff --git a/mlir/lib/Conversion/AVX512ToLLVM/CMakeLists.txt b/mlir/lib/Conversion/AVX512ToLLVM/CMakeLists.txt new file mode 100644 index 0000000000000..5573f6ca1618c --- /dev/null +++ b/mlir/lib/Conversion/AVX512ToLLVM/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRAVX512ToLLVM + ConvertAVX512ToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/AVX512ToLLVM +) + +set(LIBS + MLIRAVX512 + MLIRLLVMAVX512 + MLIRLLVMIR + MLIRStandardToLLVM + MLIRTransforms + LLVMCore + LLVMSupport + ) + +add_dependencies(MLIRAVX512ToLLVM ${LIBS}) +target_link_libraries(MLIRAVX512ToLLVM PUBLIC ${LIBS}) diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp new file mode 100644 index 0000000000000..7a8c1e81fcb8a --- /dev/null +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -0,0 +1,193 @@ +//===- ConvertAVX512ToLLVM.cpp - Convert AVX512 to the LLVM dialect -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Dialect/AVX512/AVX512Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; +using namespace mlir::vector; +using namespace mlir::avx512; + +template Type getSrcVectorElementType(OpTy op) { + return op.src().getType().template cast().getElementType(); +} + +// TODO(ntv, zinenko): Code is currently copy-pasted and adapted from the code +// 1-1 LLVM conversion. It would better if it were properly exposed in core and +// reusable. +/// Basic lowering implementation for one-to-one rewriting from AVX512 Ops to +/// LLVM Dialect Ops. Convert the type of the result to an LLVM type, pass +/// operands as is, preserve attributes. +template +LogicalResult matchAndRewriteOneToOne(const ConvertToLLVMPattern &lowering, + LLVMTypeConverter &typeConverter, + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) { + unsigned numResults = op->getNumResults(); + + Type packedType; + if (numResults != 0) { + packedType = typeConverter.packFunctionResults(op->getResultTypes()); + if (!packedType) + return failure(); + } + + auto newOp = rewriter.create(op->getLoc(), packedType, operands, + op->getAttrs()); + + // If the operation produced 0 or 1 result, return them immediately. + if (numResults == 0) + return rewriter.eraseOp(op), success(); + if (numResults == 1) + return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)), + success(); + + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + SmallVector results; + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + auto type = typeConverter.convertType(op->getResult(i).getType()); + results.push_back(rewriter.create( + op->getLoc(), type, newOp.getOperation()->getResult(0), + rewriter.getI64ArrayAttr(i))); + } + rewriter.replaceOp(op, results); + return success(); +} + +// TODO(ntv): Patterns are too verbose due to the fact that we have 1 op (e.g. +// MaskRndScaleOp) and different possible target ops. It would be better to take +// a Functor so that all these conversions become 1-liners. +struct MaskRndScaleOpPS512Conversion : public ConvertToLLVMPattern { + explicit MaskRndScaleOpPS512Conversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!getSrcVectorElementType(cast(op)).isF32()) + return failure(); + return matchAndRewriteOneToOne( + *this, this->typeConverter, op, operands, rewriter); + } +}; + +struct MaskRndScaleOpPD512Conversion : public ConvertToLLVMPattern { + explicit MaskRndScaleOpPD512Conversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!getSrcVectorElementType(cast(op)).isF64()) + return failure(); + return matchAndRewriteOneToOne( + *this, this->typeConverter, op, operands, rewriter); + } +}; + +struct ScaleFOpPS512Conversion : public ConvertToLLVMPattern { + explicit ScaleFOpPS512Conversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!getSrcVectorElementType(cast(op)).isF32()) + return failure(); + return matchAndRewriteOneToOne( + *this, this->typeConverter, op, operands, rewriter); + } +}; + +struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern { + explicit ScaleFOpPD512Conversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!getSrcVectorElementType(cast(op)).isF64()) + return failure(); + return matchAndRewriteOneToOne( + *this, this->typeConverter, op, operands, rewriter); + } +}; + +/// Populate the given list with patterns that convert from AVX512 to LLVM. +void mlir::populateAVX512ToLLVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + MLIRContext *ctx = converter.getDialect()->getContext(); + // clang-format off + patterns.insert(ctx, converter); + // clang-format on +} + +namespace { +struct ConvertAVX512ToLLVMPass : public ModulePass { + void runOnModule() override; +}; +} // namespace + +void ConvertAVX512ToLLVMPass::runOnModule() { + // Convert to the LLVM IR dialect. + OwningRewritePatternList patterns; + LLVMTypeConverter converter(&getContext()); + populateAVX512ToLLVMConversionPatterns(converter, patterns); + populateVectorToLLVMConversionPatterns(converter, patterns); + populateStdToLLVMConversionPatterns(converter, patterns); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalDialect(); + target.addIllegalDialect(); + target.addDynamicallyLegalOp( + [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); + if (failed( + applyPartialConversion(getModule(), target, patterns, &converter))) { + signalPassFailure(); + } +} + +std::unique_ptr> mlir::createConvertAVX512ToLLVMPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "convert-avx512-to-llvm", + "Convert the operations from the avx512 dialect into the LLVM dialect"); diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 2f1826a1e2991..fbf3e12594935 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(AffineToStandard) +add_subdirectory(AVX512ToLLVM) add_subdirectory(GPUToCUDA) add_subdirectory(GPUToNVVM) add_subdirectory(GPUToROCDL) diff --git a/mlir/lib/Dialect/AVX512/CMakeLists.txt b/mlir/lib/Dialect/AVX512/CMakeLists.txt new file mode 100644 index 0000000000000..0fc6da6240a93 --- /dev/null +++ b/mlir/lib/Dialect/AVX512/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(MLIRAVX512 + IR/AVX512Dialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AVX512 + + DEPENDS + MLIRAVX512IncGen + ) +target_link_libraries(MLIRAVX512 + PUBLIC + MLIRIR + LLVMSupport + ) diff --git a/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp new file mode 100644 index 0000000000000..aade931ee4e7e --- /dev/null +++ b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp @@ -0,0 +1,35 @@ +//===- AVX512Ops.cpp - MLIR AVX512 ops implementation ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AVX512 dialect and its operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/AVX512/AVX512Dialect.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +avx512::AVX512Dialect::AVX512Dialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context) { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/AVX512/AVX512.cpp.inc" + >(); +} + +namespace mlir { +namespace avx512 { +#define GET_OP_CLASSES +#include "mlir/Dialect/AVX512/AVX512.cpp.inc" +} // namespace avx512 +} // namespace mlir + diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index fe99044a90e65..0bcc794894cc6 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(AVX512) add_subdirectory(AffineOps) add_subdirectory(FxpMathOps) add_subdirectory(GPU) diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index 2e53d29f768da..148bc4bef3e88 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -24,6 +24,26 @@ target_link_libraries(MLIRLLVMIR MLIRSupport ) +add_mlir_dialect_library(MLIRLLVMAVX512 + IR/LLVMAVX512Dialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR + + DEPENDS + MLIRLLVMAVX512IncGen + MLIRLLVMAVX512ConversionsIncGen + ) +target_link_libraries(MLIRLLVMAVX512 + PUBLIC + LLVMAsmParser + MLIRIR + MLIRLLVMIR + MLIRSideEffects + LLVMSupport + LLVMCore + ) + add_mlir_dialect_library(MLIRNVVMIR IR/NVVMDialect.cpp diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp new file mode 100644 index 0000000000000..bde81144fb54b --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp @@ -0,0 +1,36 @@ +//===- LLVMAVX512Dialect.cpp - MLIR LLVMAVX512 ops implementation ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the LLVMAVX512 dialect and its operations. +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/IntrinsicsX86.h" + +#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +LLVM::LLVMAVX512Dialect::LLVMAVX512Dialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context) { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc" + >(); +} + +namespace mlir { +namespace LLVM { +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc" +} // namespace LLVM +} // namespace mlir diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt index b68bfa8d3cf20..9bc37cab1093e 100644 --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -18,6 +18,22 @@ target_link_libraries(MLIRTargetLLVMIRModuleTranslation MLIRTranslation ) +add_mlir_library(MLIRTargetAVX512 + LLVMIR/LLVMAVX512Intr.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR + DEPENDS + MLIRLLVMAVX512ConversionsIncGen + ) +target_link_libraries(MLIRTargetAVX512 + PUBLIC + MLIRIR + MLIRLLVMAVX512 + MLIRLLVMIR + MLIRTargetLLVMIRModuleTranslation + ) + add_mlir_library(MLIRTargetLLVMIR LLVMIR/ConvertFromLLVMIR.cpp LLVMIR/ConvertToLLVMIR.cpp diff --git a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp new file mode 100644 index 0000000000000..216ae862d4b2c --- /dev/null +++ b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp @@ -0,0 +1,51 @@ +//===- AVX512Intr.cpp - Convert MLIR LLVM dialect to LLVM intrinsics ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR LLVM and AVX512 dialects +// and LLVM IR with AVX intrinsics. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Translation.h" +#include "llvm/IR/IntrinsicsX86.h" + +using namespace mlir; + +namespace { +class LLVMAVX512ModuleTranslation : public LLVM::ModuleTranslation { + friend LLVM::ModuleTranslation; + +public: + using LLVM::ModuleTranslation::ModuleTranslation; + +protected: + LogicalResult convertOperation(Operation &opInst, + llvm::IRBuilder<> &builder) override { +#include "mlir/Dialect/LLVMIR/LLVMAVX512Conversions.inc" + + return LLVM::ModuleTranslation::convertOperation(opInst, builder); + } +}; + +std::unique_ptr translateLLVMAVX512ModuleToLLVMIR(Operation *m) { + return LLVM::ModuleTranslation::translateModule( + m); +} +} // end namespace + +static TranslateFromMLIRRegistration + reg("avx512-mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) { + auto llvmModule = translateLLVMAVX512ModuleToLLVMIR(module); + if (!llvmModule) + return failure(); + + llvmModule->print(output, nullptr); + return success(); + }); diff --git a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir new file mode 100644 index 0000000000000..936819e27eb90 --- /dev/null +++ b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt %s -convert-avx512-to-llvm | mlir-opt | FileCheck %s + +func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8) + -> (vector<16xf32>, vector<8xf64>) +{ + // CHECK: llvm_avx512.mask.rndscale.ps.512 + %0 = avx512.mask.rndscale %a, %i32, %a, %i16, %i32: vector<16xf32> + // CHECK: llvm_avx512.mask.rndscale.pd.512 + %1 = avx512.mask.rndscale %b, %i32, %b, %i8, %i32: vector<8xf64> + + // CHECK: llvm_avx512.mask.scalef.ps.512 + %a0 = avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32> + // CHECK: llvm_avx512.mask.scalef.pd.512 + %a1 = avx512.mask.scalef %b, %b, %b, %i8, %i32: vector<8xf64> + + return %a0, %a1: vector<16xf32>, vector<8xf64> +} diff --git a/mlir/test/Dialect/AVX512/roundtrip.mlir b/mlir/test/Dialect/AVX512/roundtrip.mlir new file mode 100644 index 0000000000000..bd23103fa432a --- /dev/null +++ b/mlir/test/Dialect/AVX512/roundtrip.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s + +func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8) + -> (vector<16xf32>, vector<8xf64>) +{ + // CHECK: avx512.mask.rndscale {{.*}}: vector<16xf32> + %0 = avx512.mask.rndscale %a, %i32, %a, %i16, %i32 : vector<16xf32> + // CHECK: avx512.mask.rndscale {{.*}}: vector<8xf64> + %1 = avx512.mask.rndscale %b, %i32, %b, %i8, %i32 : vector<8xf64> + return %0, %1: vector<16xf32>, vector<8xf64> +} + +func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8) + -> (vector<16xf32>, vector<8xf64>) +{ + // CHECK: avx512.mask.scalef {{.*}}: vector<16xf32> + %0 = avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32> + // CHECK: avx512.mask.scalef {{.*}}: vector<8xf64> + %1 = avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64> + return %0, %1: vector<16xf32>, vector<8xf64> +} diff --git a/mlir/test/Target/avx512.mlir b/mlir/test/Target/avx512.mlir new file mode 100644 index 0000000000000..5e75a98dc4ef8 --- /dev/null +++ b/mlir/test/Target/avx512.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate --avx512-mlir-to-llvmir | FileCheck %s + +// CHECK-LABEL: define <16 x float> @LLVM_x86_avx512_mask_ps_512 +llvm.func @LLVM_x86_avx512_mask_ps_512(%a: !llvm<"<16 x float>">, + %b: !llvm.i32, + %c: !llvm.i16) + -> (!llvm<"<16 x float>">) +{ + // CHECK: call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float> + %0 = "llvm_avx512.mask.rndscale.ps.512"(%a, %b, %a, %c, %b) : + (!llvm<"<16 x float>">, !llvm.i32, !llvm<"<16 x float>">, !llvm.i16, !llvm.i32) -> !llvm<"<16 x float>"> + // CHECK: call <16 x float> @llvm.x86.avx512.mask.scalef.ps.512(<16 x float> + %1 = "llvm_avx512.mask.scalef.ps.512"(%a, %a, %a, %c, %b) : + (!llvm<"<16 x float>">, !llvm<"<16 x float>">, !llvm<"<16 x float>">, !llvm.i16, !llvm.i32) -> !llvm<"<16 x float>"> + llvm.return %1: !llvm<"<16 x float>"> +} + +// CHECK-LABEL: define <8 x double> @LLVM_x86_avx512_mask_pd_512 +llvm.func @LLVM_x86_avx512_mask_pd_512(%a: !llvm<"<8 x double>">, + %b: !llvm.i32, + %c: !llvm.i8) + -> (!llvm<"<8 x double>">) +{ + // CHECK: call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double> + %0 = "llvm_avx512.mask.rndscale.pd.512"(%a, %b, %a, %c, %b) : + (!llvm<"<8 x double>">, !llvm.i32, !llvm<"<8 x double>">, !llvm.i8, !llvm.i32) -> !llvm<"<8 x double>"> + // CHECK: call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double> + %1 = "llvm_avx512.mask.scalef.pd.512"(%a, %a, %a, %c, %b) : + (!llvm<"<8 x double>">, !llvm<"<8 x double>">, !llvm<"<8 x double>">, !llvm.i8, !llvm.i32) -> !llvm<"<8 x double>"> + llvm.return %1: !llvm<"<8 x double>"> +} diff --git a/mlir/tools/mlir-translate/CMakeLists.txt b/mlir/tools/mlir-translate/CMakeLists.txt index d665789e5bd0b..bf7a92509912a 100644 --- a/mlir/tools/mlir-translate/CMakeLists.txt +++ b/mlir/tools/mlir-translate/CMakeLists.txt @@ -5,6 +5,7 @@ set(LIBS MLIRPass MLIRSPIRV MLIRSPIRVSerialization + MLIRTargetAVX512 MLIRTargetLLVMIR MLIRTargetNVVMIR MLIRTargetROCDLIR @@ -13,6 +14,7 @@ set(LIBS ) set(FULL_LIBS MLIRSPIRVSerialization + MLIRTargetAVX512 MLIRTargetLLVMIR MLIRTargetNVVMIR MLIRTargetROCDLIR