Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions mlir/include/mlir/Conversion/ComplexCommon/DivisionConverter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//===- DivisionConverter.h - Complex division conversion ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_COMPLEXCOMMON_DIVISIONCONVERTER_H
#define MLIR_CONVERSION_COMPLEXCOMMON_DIVISIONCONVERTER_H

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

namespace mlir {
namespace complex {
/// convert a complex division to the LLVM dialect using algebraic method
void convertDivToLLVMUsingAlgebraic(ConversionPatternRewriter &rewriter,
Location loc, Value lhsRe, Value lhsIm,
Value rhsRe, Value rhsIm,
LLVM::FastmathFlagsAttr fmf,
Value *resultRe, Value *resultIm);

/// convert a complex division to the arith/math dialects using algebraic method
void convertDivToStandardUsingAlgebraic(ConversionPatternRewriter &rewriter,
Location loc, Value lhsRe, Value lhsIm,
Value rhsRe, Value rhsIm,
arith::FastMathFlagsAttr fmf,
Value *resultRe, Value *resultIm);

/// convert a complex division to the LLVM dialect using Smith's method
void convertDivToLLVMUsingRangeReduction(ConversionPatternRewriter &rewriter,
Location loc, Value lhsRe, Value lhsIm,
Value rhsRe, Value rhsIm,
LLVM::FastmathFlagsAttr fmf,
Value *resultRe, Value *resultIm);

/// convert a complex division to the arith/math dialects using Smith's method
void convertDivToStandardUsingRangeReduction(
ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm,
Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe,
Value *resultIm);

} // namespace complex
} // namespace mlir

#endif // MLIR_CONVERSION_COMPLEXCOMMON_DIVISIONCONVERTER_H
8 changes: 6 additions & 2 deletions mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#define MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_

#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
class DialectRegistry;
Expand Down Expand Up @@ -39,8 +41,10 @@ class ComplexStructBuilder : public StructBuilder {
};

/// Populate the given list with patterns that convert from Complex to LLVM.
void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
void populateComplexToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
mlir::complex::ComplexRangeFlags complexRange =
mlir::complex::ComplexRangeFlags::basic);

void registerConvertComplexToLLVMInterface(DialectRegistry &registry);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#ifndef MLIR_CONVERSION_COMPLEXTOSTANDARD_COMPLEXTOSTANDARD_H_
#define MLIR_CONVERSION_COMPLEXTOSTANDARD_COMPLEXTOSTANDARD_H_

#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
Expand All @@ -18,10 +20,15 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"

/// Populate the given list with patterns that convert from Complex to Standard.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns);
void populateComplexToStandardConversionPatterns(
RewritePatternSet &patterns,
mlir::complex::ComplexRangeFlags complexRange =
mlir::complex::ComplexRangeFlags::improved);

/// Create a pass to convert Complex operations to the Standard dialect.
std::unique_ptr<Pass> createConvertComplexToStandardPass();
std::unique_ptr<Pass>
createConvertComplexToStandardPass(ConvertComplexToStandardOptions options);

} // namespace mlir

Expand Down
22 changes: 22 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,17 @@ def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
def ConvertComplexToLLVMPass : Pass<"convert-complex-to-llvm"> {
let summary = "Convert Complex dialect to LLVM dialect";
let dependentDialects = ["LLVM::LLVMDialect"];

let options = [
Option<"complexRange", "complex-range", "::mlir::complex::ComplexRangeFlags",
/*default=*/"::mlir::complex::ComplexRangeFlags::basic",
"Control the intermediate calculation of complex number division",
[{::llvm::cl::values(
clEnumValN(::mlir::complex::ComplexRangeFlags::improved, "improved", "improved"),
clEnumValN(::mlir::complex::ComplexRangeFlags::basic, "basic", "basic (default)"),
clEnumValN(::mlir::complex::ComplexRangeFlags::none, "none", "none")
)}]>,
];
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -308,6 +319,17 @@ def ConvertComplexToStandard : Pass<"convert-complex-to-standard"> {
let summary = "Convert Complex dialect to standard dialect";
let constructor = "mlir::createConvertComplexToStandardPass()";
let dependentDialects = ["math::MathDialect"];

let options = [
Option<"complexRange", "complex-range", "::mlir::complex::ComplexRangeFlags",
/*default=*/"::mlir::complex::ComplexRangeFlags::improved",
"Control the intermediate calculation of complex number division",
[{::llvm::cl::values(
clEnumValN(::mlir::complex::ComplexRangeFlags::improved, "improved", "improved (default)"),
clEnumValN(::mlir::complex::ComplexRangeFlags::basic, "basic", "basic"),
clEnumValN(::mlir::complex::ComplexRangeFlags::none, "none", "none")
)}]>,
];
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ add_mlir_dialect(ComplexOps complex)
add_mlir_doc(ComplexOps ComplexOps Dialects/ -gen-dialect-doc -dialect=complex)

set(LLVM_TARGET_DEFINITIONS ComplexAttributes.td)
mlir_tablegen(ComplexEnums.h.inc -gen-enum-decls)
mlir_tablegen(ComplexEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(ComplexAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(ComplexAttributes.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRComplexAttributesIncGen)
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Complex/IR/Complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@

#include "mlir/Dialect/Complex/IR/ComplexOpsDialect.h.inc"

//===----------------------------------------------------------------------===//
// Complex Dialect Enums
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Complex/IR/ComplexEnums.h.inc"

//===----------------------------------------------------------------------===//
// Complex Dialect Operations
//===----------------------------------------------------------------------===//
Expand Down
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef COMPLEX_BASE
#define COMPLEX_BASE

include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"

def Complex_Dialect : Dialect {
Expand All @@ -24,4 +25,19 @@ def Complex_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
}

//===----------------------------------------------------------------------===//
// Complex_ComplexRangeFlags
//===----------------------------------------------------------------------===//

def Complex_CRF_improved : I32BitEnumAttrCaseBit<"improved", 0>;
def Complex_CRF_basic : I32BitEnumAttrCaseBit<"basic", 1>;
def Complex_CRF_none : I32BitEnumAttrCaseBit<"none", 2>;

def Complex_ComplexRangeFlags : I32BitEnumAttr<
"ComplexRangeFlags",
"Complex range flags",
[Complex_CRF_improved, Complex_CRF_basic, Complex_CRF_none]> {
let cppNamespace = "::mlir::complex";
}

#endif // COMPLEX_BASE
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_subdirectory(ArmSMEToSCF)
add_subdirectory(ArmSMEToLLVM)
add_subdirectory(AsyncToLLVM)
add_subdirectory(BufferizationToMemRef)
add_subdirectory(ComplexCommon)
add_subdirectory(ComplexToLibm)
add_subdirectory(ComplexToLLVM)
add_subdirectory(ComplexToSPIRV)
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/Conversion/ComplexCommon/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
add_mlir_conversion_library(MLIRComplexDivisionConversion
DivisionConverter.cpp

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRComplexDialect
MLIRLLVMDialect
MLIRMathDialect
)
Loading