diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h
new file mode 100644
index 0000000000000..2faf19b788b3a
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h
@@ -0,0 +1,53 @@
+//===- DIExpressionLegalization.h - DIExpression Legalization Patterns ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declarations for known legalization patterns for DIExpressions that should
+// be performed before translation into LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONLEGALIZATION_H
+#define MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONLEGALIZATION_H
+
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h"
+
+namespace mlir {
+namespace LLVM {
+
+//===----------------------------------------------------------------------===//
+// Rewrite Patterns
+//===----------------------------------------------------------------------===//
+
+/// Adjacent DW_OP_LLVM_fragment should be merged into one.
+/// The merged fragment adds the two offsets and keeps the size of the first
+/// fragment.
+///
+/// E.g.
+/// #llvm.di_expression<[
+///   DW_OP_LLVM_fragment(32, 32), DW_OP_LLVM_fragment(32, 64)
+/// ]>
+/// =>
+/// #llvm.di_expression<[DW_OP_LLVM_fragment(64, 32)]>
+class MergeFragments : public DIExpressionRewriter::ExprRewritePattern {
+public:
+  OpIterT match(OpIterRange operators) const override;
+  SmallVector<OperatorT> replace(OpIterRange operators) const override;
+};
+
+//===----------------------------------------------------------------------===//
+// Runner
+//===----------------------------------------------------------------------===//
+
+/// Register all known legalization patterns declared here and apply them to
+/// all ops in `op`.
+void legalizeDIExpressionsRecursively(Operation *op);
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONLEGALIZATION_H
diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
new file mode 100644
index 0000000000000..2d9841518a633
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h
@@ -0,0 +1,68 @@
+//===- DIExpressionRewriter.h - Rewriter for DIExpression operators -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A driver for running rewrite patterns on DIExpression operators.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONREWRITER_H
+#define MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONREWRITER_H
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include <deque>
+
+namespace mlir {
+namespace LLVM {
+
+/// Rewriter for DIExpressionAttr.
+///
+/// Users of this rewriter register their own rewrite patterns. Each pattern
+/// matches on a contiguous range of LLVM DIExpressionElemAttrs, and can be
+/// used to rewrite it into a new range of DIExpressionElemAttrs of any length.
+class DIExpressionRewriter {
+public:
+  using OperatorT = LLVM::DIExpressionElemAttr;
+
+  class ExprRewritePattern {
+  public:
+    using OperatorT = DIExpressionRewriter::OperatorT;
+    using OpIterT = std::deque<OperatorT>::const_iterator;
+    using OpIterRange = llvm::iterator_range<OpIterT>;
+
+    virtual ~ExprRewritePattern() = default;
+    /// Checks whether a particular prefix of operators matches this pattern.
+    /// The provided argument is guaranteed non-empty.
+    /// Return the iterator after the last matched element.
+    /// Returning the begin iterator of the range means "no match".
+    virtual OpIterT match(OpIterRange) const = 0;
+    /// Replace the operators with a new list of operators.
+    /// The provided argument is guaranteed to be the same length as returned
+    /// by the `match` function.
+    virtual SmallVector<OperatorT> replace(OpIterRange) const = 0;
+  };
+
+  /// Register a rewrite pattern with the rewriter.
+  /// Rewrite patterns are attempted in the order of registration.
+  void addPattern(std::unique_ptr<ExprRewritePattern> pattern);
+
+  /// Simplify a DIExpression according to all the patterns registered.
+  /// An optional `maxNumRewrites` can be passed to limit the number of rewrites
+  /// that get applied.
+  LLVM::DIExpressionAttr
+  simplify(LLVM::DIExpressionAttr expr,
+           std::optional<uint64_t> maxNumRewrites = {}) const;
+
+private:
+  /// The registered patterns.
+  SmallVector<std::unique_ptr<ExprRewritePattern>> patterns;
+};
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_DIEXPRESSIONREWRITER_H
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
index 47a2a251bf3e8..c80494a440116 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRLLVMIRTransforms
   AddComdats.cpp
+  DIExpressionLegalization.cpp
+  DIExpressionRewriter.cpp
   DIScopeForLLVMFuncOp.cpp
   LegalizeForExport.cpp
   OptimizeForNVVM.cpp
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
new file mode 100644
index 0000000000000..7d3170bb96821
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionLegalization.cpp
@@ -0,0 +1,61 @@
+//===- DIExpressionLegalization.cpp - DIExpression Legalization Patterns --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"
+
+#include "llvm/BinaryFormat/Dwarf.h"
+
+using namespace mlir;
+using namespace LLVM;
+
+//===----------------------------------------------------------------------===//
+// MergeFragments
+//===----------------------------------------------------------------------===//
+
+/// Matches a prefix of two DW_OP_LLVM_fragment operators; returns the iterator
+/// past the second one, or `operators.begin()` if the prefix does not match.
+MergeFragments::OpIterT MergeFragments::match(OpIterRange operators) const {
+  OpIterT it = operators.begin();
+  for (unsigned i = 0; i < 2; ++i, ++it) {
+    if (it == operators.end() ||
+        it->getOpcode() != llvm::dwarf::DW_OP_LLVM_fragment)
+      return operators.begin();
+  }
+  return it;
+}
+
+/// Merges the two matched fragments into one: the offsets are added and the
+/// size of the earlier operator (the one closer to the IR value) is kept.
+SmallVector<MergeFragments::OperatorT>
+MergeFragments::replace(OpIterRange operators) const {
+  OpIterT it = operators.begin();
+  OperatorT first = *it;
+  OperatorT second = *++it;
+  uint64_t offset = first.getArguments()[0] + second.getArguments()[0];
+  uint64_t size = first.getArguments()[1];
+  OperatorT newOp = OperatorT::get(
+      first.getContext(), llvm::dwarf::DW_OP_LLVM_fragment, {offset, size});
+  return SmallVector<OperatorT>{newOp};
+}
+
+//===----------------------------------------------------------------------===//
+// Runner
+//===----------------------------------------------------------------------===//
+
+/// Registers the legalization patterns (currently only MergeFragments) and
+/// simplifies every DIExpressionAttr visited while recursing through `op`.
+void mlir::LLVM::legalizeDIExpressionsRecursively(Operation *op) {
+  LLVM::DIExpressionRewriter rewriter;
+  rewriter.addPattern(std::make_unique<MergeFragments>());
+
+  AttrTypeReplacer replacer;
+  replacer.addReplacement([&rewriter](LLVM::DIExpressionAttr expr) {
+    return rewriter.simplify(expr);
+  });
+  replacer.recursivelyReplaceElementsIn(op);
+}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
new file mode 100644
index 0000000000000..6fdb2f8c19647
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp
@@ -0,0 +1,75 @@
+//===- DIExpressionRewriter.cpp - Rewriter for DIExpression operators -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace LLVM;
+
+#define DEBUG_TYPE "llvm-di-expression-simplifier"
+
+//===----------------------------------------------------------------------===//
+// DIExpressionRewriter
+//===----------------------------------------------------------------------===//
+
+/// Appends `pattern` to the list of patterns that `simplify` tries in order.
+void DIExpressionRewriter::addPattern(
+    std::unique_ptr<ExprRewritePattern> pattern) {
+  patterns.emplace_back(std::move(pattern));
+}
+
+/// Applies the registered patterns to `expr` front to back, up to the limit.
+DIExpressionAttr
+DIExpressionRewriter::simplify(DIExpressionAttr expr,
+                               std::optional<uint64_t> maxNumRewrites) const {
+  ArrayRef<OperatorT> operators = expr.getOperations();
+
+  // `result` holds the finalized prefix and `inputs` the unprocessed postfix;
+  // concat(result, inputs) is always `operators` after some application of the
+  // rewrite patterns. A deque gives `inputs` efficient front insertion and
+  // removal; random access is not necessary for patterns.
+  std::deque<OperatorT> inputs(operators.begin(), operators.end());
+  SmallVector<OperatorT> result;
+
+  uint64_t numRewrites = 0;
+  while (!inputs.empty() &&
+         (!maxNumRewrites || numRewrites < *maxNumRewrites)) {
+    bool foundMatch = false;
+    for (const std::unique_ptr<ExprRewritePattern> &pattern : patterns) {
+      ExprRewritePattern::OpIterT matchEnd = pattern->match(inputs);
+      if (matchEnd == inputs.begin())
+        continue;
+
+      foundMatch = true;
+      SmallVector<OperatorT> replacement =
+          pattern->replace(llvm::make_range(inputs.cbegin(), matchEnd));
+      inputs.erase(inputs.begin(), matchEnd);
+      inputs.insert(inputs.begin(), replacement.begin(), replacement.end());
+      ++numRewrites;
+      break;
+    }
+
+    if (!foundMatch) {
+      // If no match, pass along the current operator.
+      result.push_back(inputs.front());
+      inputs.pop_front();
+    }
+  }
+
+  if (maxNumRewrites && !inputs.empty()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "LLVMDIExpressionSimplifier exceeded max num rewrites ("
+               << *maxNumRewrites << ")\n");
+    // Leftovers mean the limit stopped the loop; pass them through unchanged.
+    result.append(inputs.begin(), inputs.end());
+  }
+
+  return LLVM::DIExpressionAttr::get(expr.getContext(), result);
+}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp
index 61c1378d96121..1ac994fa5fb78 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/LegalizeForExport.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
 
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -79,6 +80,7 @@ struct LegalizeForExportPass
     : public LLVM::impl::LLVMLegalizeForExportBase<LegalizeForExportPass> {
   void runOnOperation() override {
     LLVM::ensureDistinctSuccessors(getOperation());
+    LLVM::legalizeDIExpressionsRecursively(getOperation());
   }
 };
 } // namespace
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ce46a194ea7d9..fbbfb5b83eb60 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
+#include "mlir/Dialect/LLVMIR/Transforms/DIExpressionLegalization.h"
 #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
@@ -1568,6 +1569,7 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
     return nullptr;
LLVM::ensureDistinctSuccessors(module); + LLVM::legalizeDIExpressionsRecursively(module); ModuleTranslation translator(module, std::move(llvmModule)); llvm::IRBuilder<> llvmBuilder(llvmContext); diff --git a/mlir/test/Dialect/LLVMIR/di-expression-legalization.mlir b/mlir/test/Dialect/LLVMIR/di-expression-legalization.mlir new file mode 100644 index 0000000000000..60fbc8135be62 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/di-expression-legalization.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt -llvm-legalize-for-export --split-input-file %s | FileCheck %s -check-prefix=CHECK-OPT +// RUN: mlir-translate -mlir-to-llvmir --split-input-file %s | FileCheck %s -check-prefix=CHECK-TRANSLATE + +#di_file = #llvm.di_file<"foo.c" in "/mlir/"> +#di_compile_unit = #llvm.di_compile_unit, sourceLanguage = DW_LANG_C, file = #di_file, producer = "MLIR", isOptimized = true, emissionKind = Full> +#di_subprogram = #llvm.di_subprogram +#i32_type = #llvm.di_basic_type +#i8_type = #llvm.di_basic_type + +// struct0: {i8, i32} +#struct0_first = #llvm.di_derived_type +#struct0_second = #llvm.di_derived_type +#struct0 = #llvm.di_composite_type + +// struct1: {i8, struct0} +#struct1_first = #llvm.di_derived_type +#struct1_second = #llvm.di_derived_type +#struct1 = #llvm.di_composite_type + +// struct2: {i32, struct1} +#struct2_first = #llvm.di_derived_type +#struct2_second = #llvm.di_derived_type +#struct2 = #llvm.di_composite_type + +#var0 = #llvm.di_local_variable +#var1 = #llvm.di_local_variable +#var2 = #llvm.di_local_variable + +#loc = loc("test.mlir":0:0) + +llvm.func @merge_fragments(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr) { + // CHECK-OPT: #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32)]> + // CHECK-TRANSLATE: !DIExpression(DW_OP_deref, DW_OP_LLVM_fragment, 32, 32)) + llvm.intr.dbg.value #var0 #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32)]> = %arg0 : !llvm.ptr loc(fused<#di_subprogram>[#loc]) + // CHECK-OPT: #llvm.di_expression<[DW_OP_deref, 
DW_OP_LLVM_fragment(64, 32)]> + // CHECK-TRANSLATE: !DIExpression(DW_OP_deref, DW_OP_LLVM_fragment, 64, 32)) + llvm.intr.dbg.value #var1 #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32), DW_OP_LLVM_fragment(32, 64)]> = %arg1 : !llvm.ptr loc(fused<#di_subprogram>[#loc]) + // CHECK-OPT: #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(96, 32)]> + // CHECK-TRANSLATE: !DIExpression(DW_OP_deref, DW_OP_LLVM_fragment, 96, 32)) + llvm.intr.dbg.value #var2 #llvm.di_expression<[DW_OP_deref, DW_OP_LLVM_fragment(32, 32), DW_OP_LLVM_fragment(32, 64), DW_OP_LLVM_fragment(32, 96)]> = %arg2 : !llvm.ptr loc(fused<#di_subprogram>[#loc]) + llvm.return +}