diff --git a/cmake/Modules/FindMLIR.cmake b/cmake/Modules/FindMLIR.cmake index 14ad214b6eb..3713ebf7017 100644 --- a/cmake/Modules/FindMLIR.cmake +++ b/cmake/Modules/FindMLIR.cmake @@ -28,7 +28,11 @@ else() set(MLIR_LIB_DIR ${MLIR_ROOT_DIR}/lib) # To be done: add the required MLIR libraries. Hopefully we don't have to manually list all MLIR libs. - set(MLIR_LIBRARIES "") + if(EXISTS "${MLIR_LIB_DIR}/MLIRIR.lib") + set(MLIR_LIBRARIES ${MLIR_LIB_DIR}/MLIRIR.lib ${MLIR_LIB_DIR}/MLIRSupport.lib) + elseif(EXISTS "${MLIR_LIB_DIR}/libMLIRIR.a") + set(MLIR_LIBRARIES ${MLIR_LIB_DIR}/libMLIRIR.a ${MLIR_LIB_DIR}/libMLIRSupport.a) + endif() # XXX: This function is untested and will need adjustment. function(mlir_tablegen) diff --git a/dmd/dmodule.d b/dmd/dmodule.d index d41d926a267..5d0f313cd56 100644 --- a/dmd/dmodule.d +++ b/dmd/dmodule.d @@ -550,6 +550,8 @@ version (IN_LLVM) objExt = global.ll_ext; else if (global.params.output_s) objExt = global.s_ext; + else if (global.params.output_mlir) + objExt = global.mlir_ext; if (objExt) objfile = setOutfilename(global.params.objname, global.params.objdir, filename, objExt); diff --git a/dmd/globals.d b/dmd/globals.d index e7c819857d9..a9e6d967973 100644 --- a/dmd/globals.d +++ b/dmd/globals.d @@ -306,6 +306,7 @@ version (IN_LLVM) // LDC stuff OUTPUTFLAG output_ll; + OUTPUTFLAG output_mlir; OUTPUTFLAG output_bc; OUTPUTFLAG output_s; OUTPUTFLAG output_o; @@ -345,6 +346,7 @@ extern (C++) struct Global version (IN_LLVM) { const(char)[] ll_ext; + const(char)[] mlir_ext; const(char)[] bc_ext; const(char)[] s_ext; const(char)[] ldc_version; @@ -510,6 +512,7 @@ else vendor = "LDC"; obj_ext = "o"; ll_ext = "ll"; + mlir_ext = "mlir"; bc_ext = "bc"; s_ext = "s"; diff --git a/dmd/globals.h b/dmd/globals.h index edebcce38f8..5e13c57ebb7 100644 --- a/dmd/globals.h +++ b/dmd/globals.h @@ -268,6 +268,7 @@ struct Param // LDC stuff OUTPUTFLAG output_ll; + OUTPUTFLAG output_mlir; OUTPUTFLAG output_bc; OUTPUTFLAG output_s; OUTPUTFLAG output_o; @@ -304,6 +305,7 @@ struct Global DString obj_ext; #if IN_LLVM DString ll_ext; + DString mlir_ext; //MLIR code DString bc_ext; DString s_ext; DString ldc_version; diff --git a/driver/cl_options.cpp b/driver/cl_options.cpp index 8bc6850b692..64bdab3e358 100644 --- a/driver/cl_options.cpp +++ b/driver/cl_options.cpp @@ -204,6 +204,9 @@ cl::opt output_bc("output-bc", cl::desc("Write LLVM bitcode"), cl::opt output_ll("output-ll", cl::desc("Write LLVM IR"), cl::ZeroOrMore); +cl::opt output_mlir("output-mlir", cl::desc("Write MLIR"), + cl::ZeroOrMore); + cl::opt output_s("output-s", cl::desc("Write native assembly"), cl::ZeroOrMore); diff --git a/driver/cl_options.h b/driver/cl_options.h index ec1716d14a2..d23cabfa08b 100644 --- a/driver/cl_options.h +++ b/driver/cl_options.h @@ -52,6 +52,7 @@ extern cl::opt objectDir; extern cl::opt soname; extern cl::opt output_bc; extern cl::opt output_ll; +extern cl::opt output_mlir; extern cl::opt output_s; extern cl::opt output_o; extern cl::opt ddocDir; diff --git a/driver/codegenerator.cpp b/driver/codegenerator.cpp index 25366ae6b49..ec00cba4580 100644 --- a/driver/codegenerator.cpp +++ b/driver/codegenerator.cpp @@ -219,8 +219,16 @@ void inlineAsmDiagnosticHandler(const llvm::SMDiagnostic &d, void *context, } // anonymous namespace namespace ldc { -CodeGenerator::CodeGenerator(llvm::LLVMContext &context, bool singleObj) - : context_(context), moduleCount_(0), singleObj_(singleObj), ir_(nullptr) { +CodeGenerator::CodeGenerator(llvm::LLVMContext &context, +#if LDC_MLIR_ENABLED + mlir::MLIRContext &mlirContext, +#endif + bool singleObj) + : context_(context), +#if LDC_MLIR_ENABLED + mlirContext_(mlirContext), +#endif + moduleCount_(0), singleObj_(singleObj), ir_(nullptr) { // Set the context to discard value names when not generating textual IR. if (!global.params.output_ll) { context_.setDiscardValueNames(true); @@ -274,7 +282,6 @@ void CodeGenerator::finishLLModule(Module *m) { if (moduleCount_ == 1) { insertBitcodeFiles(ir_->module, ir_->context(), global.params.bitcodeFiles); } - writeAndFreeLLModule(m->objfile.toChars()); } @@ -341,4 +348,58 @@ void CodeGenerator::emit(Module *m) { Logger::disable(); } } + +#if LDC_MLIR_ENABLED +void CodeGenerator::emitMLIR(Module *m) { + bool const loggerWasEnabled = Logger::enabled(); + if (m->llvmForceLogging && !loggerWasEnabled) { + Logger::enable(); + } + + IF_LOG Logger::println("CodeGenerator::emitMLIR(%s)", m->toPrettyChars()); + LOG_SCOPE; + + if (global.params.verbose_cg) { + printf("codegen: %s (%s)\n", m->toPrettyChars(), m->srcfile.toChars()); + } + + if (global.errors) { + Logger::println("Aborting because of errors"); + fatal(); + } + + mlir::OwningModuleRef module; + /*module = mlirGen(mlirContext, m, irs); + if(!module){ + IF_LOG Logger::println("Error generating MLIR:'%s'", llpath.c_str()); + fatal(); + }*/ + + writeMLIRModule(&module, m->objfile.toChars()); + + if (m->llvmForceLogging && !loggerWasEnabled) { + Logger::disable(); + } +} + +void CodeGenerator::writeMLIRModule(mlir::OwningModuleRef *module, + const char *filename) { + // Write MLIR + if (global.params.output_mlir) { + const auto llpath = replaceExtensionWith(global.mlir_ext, filename); + Logger::println("Writting MLIR to %s\n", llpath.c_str()); + std::error_code errinfo; + llvm::raw_fd_ostream aos(llpath, errinfo, llvm::sys::fs::F_None); + + if (aos.has_error()) { + error(Loc(), "Cannot write MLIR file '%s': %s", llpath.c_str(), + errinfo.message().c_str()); + fatal(); + } + + // module->print(aos); + } +} + +#endif } diff --git a/driver/codegenerator.h b/driver/codegenerator.h index 5ced3e206b0..986047581bc 100644 --- a/driver/codegenerator.h +++ b/driver/codegenerator.h @@ -20,21 +20,40 @@ #pragma once #include "gen/irstate.h" +#if LDC_MLIR_ENABLED +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#endif namespace ldc { class CodeGenerator { public: - CodeGenerator(llvm::LLVMContext &context, bool singleObj); + CodeGenerator(llvm::LLVMContext &context, +#if LDC_MLIR_ENABLED + mlir::MLIRContext &mlirContext, +#endif + bool singleObj); + ~CodeGenerator(); void emit(Module *m); +#if LDC_MLIR_ENABLED + void emitMLIR(Module *m); +#endif + private: void prepareLLModule(Module *m); void finishLLModule(Module *m); void writeAndFreeLLModule(const char *filename); +#if LDC_MLIR_ENABLED + void writeMLIRModule(mlir::OwningModuleRef *module, const char *filename); +#endif llvm::LLVMContext &context_; +#if LDC_MLIR_ENABLED + mlir::MLIRContext &mlirContext_; +#endif int moduleCount_; bool const singleObj_; IRState *ir_; diff --git a/driver/main.cpp b/driver/main.cpp index bf146925b3f..7c815d6f47e 100644 --- a/driver/main.cpp +++ b/driver/main.cpp @@ -439,11 +439,13 @@ void parseCommandLine(Strings &sourceFiles) { global.params.output_o = (opts::output_o == cl::BOU_UNSET && - !(opts::output_bc || opts::output_ll || opts::output_s)) + !(opts::output_bc || opts::output_ll || opts::output_s || + opts::output_mlir)) ? OUTPUTFLAGdefault : opts::output_o == cl::BOU_TRUE ? OUTPUTFLAGset : OUTPUTFLAGno; global.params.output_bc = opts::output_bc ? OUTPUTFLAGset : OUTPUTFLAGno; global.params.output_ll = opts::output_ll ? OUTPUTFLAGset : OUTPUTFLAGno; + global.params.output_mlir = opts::output_mlir ? OUTPUTFLAGset : OUTPUTFLAGno; global.params.output_s = opts::output_s ? OUTPUTFLAGset : OUTPUTFLAGno; global.params.cov = (global.params.covPercent <= 100); @@ -509,9 +511,20 @@ void parseCommandLine(Strings &sourceFiles) { strcmp(ext, global.s_ext.ptr) == 0) { global.params.output_s = OUTPUTFLAGset; global.params.output_o = OUTPUTFLAGno; + } else if (opts::output_mlir.getNumOccurrences() == 0 && + strcmp(ext, global.mlir_ext.ptr) == 0) { + global.params.output_mlir = OUTPUTFLAGset; + global.params.output_o = OUTPUTFLAGno; } } +#ifndef LDC_MLIR_ENABLED + if (global.params.output_mlir == OUTPUTFLAGset) { + error(Loc(), "MLIR output requested but this LDC was built without MLIR support"); + fatal(); + } +#endif + if (soname.getNumOccurrences() > 0 && !global.params.dll) { error(Loc(), "-soname can be used only when building a shared library"); } @@ -1090,7 +1103,13 @@ int cppmain() { void codegenModules(Modules &modules) { // Generate one or more object/IR/bitcode files/dcompute kernels. if (global.params.obj && !modules.empty()) { +#if LDC_MLIR_ENABLED + mlir::MLIRContext mlircontext; + ldc::CodeGenerator cg(getGlobalContext(), mlircontext, + global.params.oneobj); +#else ldc::CodeGenerator cg(getGlobalContext(), global.params.oneobj); +#endif DComputeCodeGenManager dccg(getGlobalContext()); std::vector computeModules; // When inlining is enabled, we are calling semantic3 on function @@ -1113,7 +1132,12 @@ void codegenModules(Modules &modules) { const auto atCompute = hasComputeAttr(m); if (atCompute == DComputeCompileFor::hostOnly || atCompute == DComputeCompileFor::hostAndDevice) { - cg.emit(m); +#if LDC_MLIR_ENABLED + if (global.params.output_mlir == OUTPUTFLAGset) + cg.emitMLIR(m); + else +#endif + cg.emit(m); } if (atCompute != DComputeCompileFor::hostOnly) { computeModules.push_back(m); diff --git a/driver/toobj.cpp b/driver/toobj.cpp index a42398c95ae..d7d49d9c57c 100644 --- a/driver/toobj.cpp +++ b/driver/toobj.cpp @@ -309,6 +309,23 @@ bool shouldDoLTO(llvm::Module *m) { } } // end of anonymous namespace +std::string replaceExtensionWith(const DArray &ext, + const char *filename) { + const auto outputFlags = {global.params.output_o, global.params.output_bc, + global.params.output_ll, global.params.output_s, + global.params.output_mlir}; + const auto numOutputFiles = + std::count_if(outputFlags.begin(), outputFlags.end(), + [](OUTPUTFLAG flag) { return flag != 0; }); + + if (numOutputFiles == 1) + return filename; + llvm::SmallString<128> buffer(filename); + llvm::sys::path::replace_extension(buffer, + llvm::StringRef(ext.ptr, ext.length)); + return {buffer.data(), buffer.size()}; +} + void writeModule(llvm::Module *m, const char *filename) { const bool doLTO = shouldDoLTO(m); const bool outputObj = shouldOutputObjectFile(); @@ -349,29 +366,13 @@ void writeModule(llvm::Module *m, const char *filename) { } } - const auto outputFlags = {global.params.output_o, global.params.output_bc, - global.params.output_ll, global.params.output_s}; - const auto numOutputFiles = - std::count_if(outputFlags.begin(), outputFlags.end(), - [](OUTPUTFLAG flag) { return flag != 0; }); - - const auto replaceExtensionWith = - [=](const DArray &ext) -> std::string { - if (numOutputFiles == 1) - return filename; - llvm::SmallString<128> buffer(filename); - llvm::sys::path::replace_extension(buffer, - llvm::StringRef(ext.ptr, ext.length)); - return {buffer.data(), buffer.size()}; - }; - // write LLVM bitcode const bool emitBitcodeAsObjectFile = doLTO && outputObj && !global.params.output_bc; if (global.params.output_bc || emitBitcodeAsObjectFile) { std::string bcpath = emitBitcodeAsObjectFile ? filename - : replaceExtensionWith(global.bc_ext); + : replaceExtensionWith(global.bc_ext, filename); Logger::println("Writing LLVM bitcode to: %s\n", bcpath.c_str()); std::error_code errinfo; llvm::raw_fd_ostream bos(bcpath.c_str(), errinfo, llvm::sys::fs::F_None); @@ -413,7 +414,7 @@ void writeModule(llvm::Module *m, const char *filename) { // write LLVM IR if (global.params.output_ll) { - const auto llpath = replaceExtensionWith(global.ll_ext); + const auto llpath = replaceExtensionWith(global.ll_ext, filename); Logger::println("Writing LLVM IR to: %s\n", llpath.c_str()); std::error_code errinfo; llvm::raw_fd_ostream aos(llpath.c_str(), errinfo, llvm::sys::fs::F_None); @@ -435,7 +436,7 @@ void writeModule(llvm::Module *m, const char *filename) { llvm::sys::fs::createUniqueFile("ldc-%%%%%%%.s", buffer); spath = {buffer.data(), buffer.size()}; } else { - spath = replaceExtensionWith(global.s_ext); + spath = replaceExtensionWith(global.s_ext, filename); } Logger::println("Writing asm to: %s\n", spath.c_str()); diff --git a/driver/toobj.h b/driver/toobj.h index 9b13349ea32..5914313c720 100644 --- a/driver/toobj.h +++ b/driver/toobj.h @@ -12,9 +12,14 @@ //===----------------------------------------------------------------------===// #pragma once +#include +#include "dmd/root/dcompat.h" namespace llvm { class Module; } void writeModule(llvm::Module *m, const char *filename); + +std::string replaceExtensionWith(const DArray &ext, + const char *filename);