diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake index 6589458ab7894..9b05b70231dba 100644 --- a/mlir/cmake/modules/AddMLIR.cmake +++ b/mlir/cmake/modules/AddMLIR.cmake @@ -762,3 +762,103 @@ function(mlir_target_link_libraries target type) target_link_libraries(${target} ${type} ${ARGN}) endif() endfunction() + +# Extracts LIT tests embedded in `Testable` records in `tblgen_file` +# and generates a file per test in `output_dir` +# +# Example usage: +# # Extract tests from MyPasses.td and generate them in test/Passes/ +# add_embedded_lit_tests(MyPassesEmbeddedTests +# ${CMAKE_CURRENT_SOURCE_DIR}/include/MyPasses.td +# ${CMAKE_CURRENT_SOURCE_DIR}/test/Passes/) +# +# # This will: +# # 1. Process MyPasses.td with mlir-tblgen --gen-lit-tests +# # 2. Extract individual test files to test/Passes/ +# # 3. Generate files like: test/Passes/generated_MyPass_test1.mlir +# +function(add_embedded_lit_tests target tblgen_file output_dir) + set(LLVM_TARGET_DEFINITIONS ${tblgen_file}) + + # Extraction script content + set(EXTRACT_SCRIPT_CONTENT [[ + # Generated extraction script + if(NOT CONSOLIDATED_FILE) + message(FATAL_ERROR "CONSOLIDATED_FILE variable is required") + endif() + + if(NOT OUTPUT_DIR) + message(FATAL_ERROR "OUTPUT_DIR variable is required") + endif() + + if(NOT EXISTS ${CONSOLIDATED_FILE}) + message(FATAL_ERROR "Consolidated file does not exist: ${CONSOLIDATED_FILE}") + endif() + + # Read the consolidated file + file(READ ${CONSOLIDATED_FILE} file_content) + + # Split into lines for processing + string(REPLACE "\n" ";" lines "${file_content}") + + set(current_filename "") + set(current_content "") + set(in_test_block FALSE) + set(extracted_test_files) + + foreach(line IN LISTS lines) + # Check for filename line + if(line MATCHES "^// File: (.+)$") + set(current_filename "${CMAKE_MATCH_1}") + endif() + + # Check for BEGIN marker + if(line MATCHES "^// --- BEGIN .+ ---$") + set(in_test_block TRUE) + set(current_content "") + # Check for END marker + elseif(line MATCHES "^// --- END .+ ---$") + set(in_test_block FALSE) + + # Write the extracted content to file + if(current_filename AND current_content) + file(MAKE_DIRECTORY ${OUTPUT_DIR}) + file(WRITE ${OUTPUT_DIR}/${current_filename} "${current_content}") + message(STATUS "Extracted test file: ${current_filename}") + list(APPEND extracted_test_files ${current_filename}) + endif() + + set(current_filename "") + set(current_content "") + # Collect content within BEGIN/END block + elseif(in_test_block) + string(APPEND current_content "${line}\n") + endif() + endforeach() + + list(LENGTH extracted_test_files num_extracted_files) + message(STATUS "Extracted ${num_extracted_files} test files to ${OUTPUT_DIR}") + ]]) + + # Write extraction script to a file in the build directory + file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/extract_lit_tests.cmake "${EXTRACT_SCRIPT_CONTENT}") + + # Process tblgen_file and generate a file with all embedded LIT + # tests in tblgen_file + get_filename_component(tblgen_name ${tblgen_file} NAME_WE) + set(consolidated_output_file ${tblgen_name}_extracted_lit_tests.txt) + mlir_tablegen(${consolidated_output_file} --gen-lit-tests) + + # Add public tablegen target to trigger builds on changes in tblgen_file + add_public_tablegen_target(${target}) + + # Call the extraction script to extract all LIT tests into individual + # `.mlir` test files + add_custom_command(TARGET ${target} POST_BUILD + COMMAND ${CMAKE_COMMAND} + -DCONSOLIDATED_FILE=${CMAKE_CURRENT_BINARY_DIR}/${consolidated_output_file} + -DOUTPUT_DIR=${output_dir} + -P ${CMAKE_CURRENT_BINARY_DIR}/extract_lit_tests.cmake + COMMENT "Extracting LIT tests to individual files" + ) +endfunction() \ No newline at end of file diff --git a/mlir/include/mlir/IR/Testable.td b/mlir/include/mlir/IR/Testable.td new file mode 100644 index 0000000000000..15814ed1bd939 --- /dev/null +++ b/mlir/include/mlir/IR/Testable.td @@ -0,0 +1,40 @@ +//===-- Testable.td - Testable type definition file --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definition of the `Testable` type. +// +// Any type whose records can have corresponding LIT tests (eg - Pass) can extend +// `Testable` in order to be able to embed LIT tests within record definitions. +// +//===----------------------------------------------------------------------===// + +#ifndef TESTABLE +#define TESTABLE + +// Represents a LIT test record in TableGen +class LitTest run = [], list check = []> { + // The name of the generated test file + string testFileName = name; + + // The IR snippet/code to be tested + code irSnippet = snippet; + + // The RUN commands for the test (e.g., "mlir-opt %s") + list runLines = run; + + // Expected output patterns (CHECK lines) + list checkLines = check; +} + +// Base class for elements that can have auto-generated LIT tests +class Testable { + // List of LIT tests associated with this element + list tests = []; +} + +#endif // TESTABLE \ No newline at end of file diff --git a/mlir/include/mlir/Pass/PassBase.td b/mlir/include/mlir/Pass/PassBase.td index e37f9735e2241..50ea44419ca24 100644 --- a/mlir/include/mlir/Pass/PassBase.td +++ b/mlir/include/mlir/Pass/PassBase.td @@ -14,6 +14,8 @@ #ifndef MLIR_PASS_PASSBASE #define MLIR_PASS_PASSBASE +include "mlir/IR/Testable.td" + //===----------------------------------------------------------------------===// // Options //===----------------------------------------------------------------------===// @@ -62,7 +64,7 @@ class Statistic { // Pass //===----------------------------------------------------------------------===// -class PassBase { +class PassBase : Testable { // The command line argument of the pass. string argument = passArg; diff --git a/mlir/test/mlir-tblgen/gen-lit-tests.td b/mlir/test/mlir-tblgen/gen-lit-tests.td new file mode 100644 index 0000000000000..40a03fb2b2d60 --- /dev/null +++ b/mlir/test/mlir-tblgen/gen-lit-tests.td @@ -0,0 +1,65 @@ +// RUN: mlir-tblgen -gen-lit-tests -I %S/../../include -dialect=test %s | FileCheck %s + +include "mlir/Pass/PassBase.td" +include "mlir/IR/Testable.td" + +def TestPassWithEmbeddedLitTests : Pass<"test-pass-with-embedded-lit-tests"> { + let summary = "pass summary"; + let description = [{ + Pass description + }]; + + let tests = [ + LitTest< + "lit_test_file_1.mlir", + [{ + func.func @test1() { + return 42; + } + }], + [ + "// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s", + ], + [ + "// RANDOM-CHECK-LABEL: func.func @test1", + ] + >, + LitTest< + "lit_test_file_2.mlir", + [{ + func.func @test2() { + return 42; + } + }], + [ + "// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s", + ], + [ + "// RANDOM-CHECK-LABEL: func.func @test2", + ] + >, + ]; +} + +// CHECK-LABEL: // Generated 2 LIT test files +// CHECK: // Use the following files for LIT testing: + +// CHECK: // File: generated_TestPassWithEmbeddedLitTests_lit_test_file_1.mlir +// CHECK: // --- BEGIN generated_TestPassWithEmbeddedLitTests_lit_test_file_1.mlir --- +// CHECK: // RUN: mlir-opt %s --verify-roundtrip | FileCheck %s +// CHECK: // Generated from TableGen definition: TestPassWithEmbeddedLitTests +// CHECK: func.func @test1() { +// CHECK: return 42; +// CHECK: } +// CHECK: // RANDOM-CHECK-LABEL: func.func @test1 +// CHECK: --- END generated_TestPassWithEmbeddedLitTests_lit_test_file_1.mlir --- + +// CHECK: // File: generated_TestPassWithEmbeddedLitTests_lit_test_file_2.mlir +// CHECK: // --- BEGIN generated_TestPassWithEmbeddedLitTests_lit_test_file_2.mlir --- +// CHECK: // RUN: mlir-opt %s --verify-roundtrip | FileCheck %s +// CHECK: // Generated from TableGen definition: TestPassWithEmbeddedLitTests +// CHECK: func.func @test2() { +// CHECK: return 42; +// CHECK: } +// CHECK: // RANDOM-CHECK-LABEL: func.func @test2 +// CHECK: // --- END generated_TestPassWithEmbeddedLitTests_lit_test_file_2.mlir --- \ No newline at end of file diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt index 2a7ef7e0576c8..e721f1e26a2bd 100644 --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -16,6 +16,7 @@ add_tablegen(mlir-tblgen MLIR EnumsGen.cpp EnumPythonBindingGen.cpp FormatGen.cpp + LitTestGen.cpp LLVMIRConversionGen.cpp LLVMIRIntrinsicGen.cpp mlir-tblgen.cpp diff --git a/mlir/tools/mlir-tblgen/LitTestGen.cpp b/mlir/tools/mlir-tblgen/LitTestGen.cpp new file mode 100644 index 0000000000000..49a092fa9879f --- /dev/null +++ b/mlir/tools/mlir-tblgen/LitTestGen.cpp @@ -0,0 +1,170 @@ +//===- LitTestGen.cpp - LIT test generator ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// LitTestGen extracts `LitTest` records from `Testable` TableGen records and +// generates corresponding LIT test files. +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Operator.h" +#include "mlir/TableGen/Pass.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Path.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +#include + +using namespace mlir; +using namespace mlir::tblgen; +using llvm::formatv; +using llvm::RecordKeeper; + +static llvm::cl::OptionCategory litTestGenCategory("Options for -gen-lit-tests"); +static llvm::cl::opt + outputDir("output-dir", + llvm::cl::desc("Output directory for generated test files"), + llvm::cl::cat(litTestGenCategory), + llvm::cl::value_desc("directory")); + + +/// Cpp type corresponding to the `LitTest` record type in TableGen +struct LitTest { + std::string sourceDefName; + std::string testFileName; + std::string irSnippet; + llvm::SmallVector runLines; + llvm::SmallVector checkLines; +}; + +static llvm::SmallVector extractTestsFromRecord(const llvm::Record *record, + llvm::StringRef dialectName = "") { + llvm::SmallVector tests; + + // Check if the record has a tests field + const llvm::RecordVal *testsVal = record->getValue("tests"); + if (!testsVal) + return tests; + + const llvm::ListInit *testsList = + llvm::dyn_cast_or_null(testsVal->getValue()); + if (!testsList) + return tests; + + for (const llvm::Init *init : testsList->getElements()) { + const llvm::DefInit *defInit = llvm::dyn_cast(init); + if (!defInit) + continue; + + const llvm::Record *testRec = defInit->getDef(); + + // Extract fields from LitTest record + std::string name = testRec->getValueAsString("testFileName").str(); + std::string irSnippet = testRec->getValueAsString("irSnippet").str(); + + llvm::SmallVector runLines; + llvm::for_each(*testRec->getValueAsListInit("runLines"), [&](const llvm::Init *init) { + runLines.emplace_back(llvm::cast(init)->getValue()); + }); + + llvm::SmallVector checkLines; + llvm::for_each(*testRec->getValueAsListInit("checkLines"), [&](const llvm::Init *init) { + checkLines.emplace_back(llvm::cast(init)->getValue()); + }); + + tests.push_back(LitTest { + record->getName().str(), + name, + irSnippet, + runLines, + checkLines, + }); + } + + return tests; +} + +/// Extract tests from passes +static llvm::SmallVector extractPassTests(const RecordKeeper &records) { + llvm::SmallVector tests; + + // Check if PassBase class exists before trying to get derived definitions + if (records.getClass("PassBase")) { + for (const llvm::Record *def : records.getAllDerivedDefinitions("PassBase")) { + if (def->isAnonymous()) + continue; + + auto passTests = extractTestsFromRecord(def, "passes"); + tests.insert(tests.end(), passTests.begin(), passTests.end()); + } + } + + return tests; +} + +/// Generate a LIT test file for an IR test +static void generateTestFile(const LitTest &test, llvm::raw_ostream &os) { + // Add RUN lines + for (const auto& runLine : test.runLines) { + os << "\n" << runLine << "\n"; + } + + os << "// Generated from TableGen definition: " << test.sourceDefName << "\n\n"; + + // Add the test body + os << test.irSnippet << "\n"; + + // Add CHECK lines + for (const auto& checkLine : test.checkLines) { + os << "\n" << checkLine << "\n"; + } +} + +/// Main function to generate all IR test test files +static void generateLitTests(const RecordKeeper &records, raw_ostream &os) { + llvm::SmallVector allTests; + + // Extract tests from different definition types (only passes for now) + auto passTests = extractPassTests(records); + + allTests.insert(allTests.end(), passTests.begin(), passTests.end()); + + if (allTests.empty()) { + os << "// No LitTest record found in any TableGen definition\n"; + return; + } + + // Generate summary + os << "// Generated " << allTests.size() << " LIT test files\n"; + os << "// Use the following files for LIT testing:\n\n"; + + // Generate file list and content for each test + for (const auto& test : allTests) { + std::string testFileName = formatv("generated_{0}_{1}", test.sourceDefName, test.testFileName); + os << "// File: " << testFileName << "\n"; + + os << "// --- BEGIN " << testFileName << " ---\n"; + generateTestFile(test, os); + os << "// --- END " << testFileName << " ---\n\n"; + } +} + +//===----------------------------------------------------------------------===// +// Generator Registration +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration + genLitTests("gen-lit-tests", "Generate LIT test files for `Testable` TableGen records", + [](const RecordKeeper &records, raw_ostream &os) { + generateLitTests(records, os); + return false; + }); \ No newline at end of file