diff --git a/.github/workflows/premerge.yml b/.github/workflows/premerge.yml index 2a54b440..720c51cf 100644 --- a/.github/workflows/premerge.yml +++ b/.github/workflows/premerge.yml @@ -22,8 +22,8 @@ jobs: run: | wget https://apt.llvm.org/llvm.sh chmod +x llvm.sh - sudo ./llvm.sh 20 - sudo apt install libmlir-20-dev mlir-20-tools + sudo ./llvm.sh 23 + sudo apt install libmlir-23-dev mlir-23-tools - name: ccache uses: hendrikmuhs/ccache-action@v1.2 diff --git a/CMakePresets.json b/CMakePresets.json index fce0cf29..500e516a 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -20,6 +20,16 @@ "CMAKE_BUILD_TYPE": "Release", "CPM_SOURCE_CACHE": ".cache/CPM" } + }, + { + "name": "release-with-debug-info", + "displayName": "Release with Debug Info", + "generator": "Ninja", + "binaryDir": "build/release-with-debug-info", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CPM_SOURCE_CACHE": ".cache/CPM" + } } ], "buildPresets": [ @@ -34,6 +44,12 @@ "displayName": "Release Build", "configurePreset": "release", "configuration": "Release" + }, + { + "name": "release-with-debug-info", + "displayName": "Release with Debug Info Build", + "configurePreset": "release-with-debug-info", + "configuration": "RelWithDebInfo" } ], "testPresets": [ @@ -46,6 +62,11 @@ "name": "release", "displayName": "Test all in Release mode", "configurePreset": "release" + }, + { + "name": "release-with-debug-info", + "displayName": "Test all in Release with Debug Info mode", + "configurePreset": "release-with-debug-info" } ] diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 889a9d75..c31ceb1c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -298,7 +298,7 @@ target_link_libraries(python-cpp ) # LLVM backend -find_package(LLVM CONFIG 20.1) +find_package(LLVM CONFIG 23) if(ENABLE_LLVM_BACKEND AND NOT LLVM_FOUND) message(FATAL_ERROR "Could not find LLVM in the local environment") elseif(ENABLE_LLVM_BACKEND AND LLVM_FOUND) diff --git a/src/executable/bytecode/Bytecode.cpp b/src/executable/bytecode/Bytecode.cpp index 21cc7379..dc07dfdc 100644 --- a/src/executable/bytecode/Bytecode.cpp +++ b/src/executable/bytecode/Bytecode.cpp @@ -132,7 +132,9 @@ py::PyResult Bytecode::eval_loop(VirtualMachine &vm, Interpreter &int ASSERT((*vm.instruction_pointer()).get()); const auto ¤t_ip = vm.instruction_pointer(); const auto &instruction = *current_ip; - spdlog::debug("{} {}", (void *)instruction.get(), instruction->to_string()); + // spdlog::debug("{} {}", (void *)instruction.get(), instruction->to_string()); + // std::cout << std::format("{} {}", (void *)instruction.get(), instruction->to_string()) + // << std::endl; auto result = instruction->execute(vm, vm.interpreter()); // we left the current stack frame in the previous instruction if (vm.stack().size() != stack_depth) { diff --git a/src/executable/bytecode/instructions/ListToTuple.cpp b/src/executable/bytecode/instructions/ListToTuple.cpp index 4d820b43..f1b8f3b0 100644 --- a/src/executable/bytecode/instructions/ListToTuple.cpp +++ b/src/executable/bytecode/instructions/ListToTuple.cpp @@ -1,8 +1,11 @@ #include "ListToTuple.hpp" #include "runtime/PyList.hpp" +#include "runtime/PyString.hpp" #include "runtime/PyTuple.hpp" #include "vm/VM.hpp" +#include + using namespace py; PyResult ListToTuple::execute(VirtualMachine &vm, Interpreter &) const @@ -12,6 +15,14 @@ PyResult ListToTuple::execute(VirtualMachine &vm, Interpreter &) const ASSERT(std::holds_alternative(list)); auto *pylist = std::get(list); + if (!as(pylist)) { + std::cout << to_string() << std::endl; + if (!pylist) { + std::cout << "(null)" << std::endl; + } else { + std::cout << pylist->str().unwrap()->to_string() << std::endl; + } + } ASSERT(as(pylist)); auto result = PyTuple::create(as(pylist)->elements()); diff --git a/src/executable/bytecode/instructions/YieldFrom.cpp b/src/executable/bytecode/instructions/YieldFrom.cpp index 3075939a..657ed4d5 100644 --- a/src/executable/bytecode/instructions/YieldFrom.cpp +++ b/src/executable/bytecode/instructions/YieldFrom.cpp @@ -49,7 +49,7 @@ PyResult YieldFrom::execute(VirtualMachine &vm, Interpreter &interpreter) vm.reg(m_dst) = result.unwrap(); vm.reg(0) = result.unwrap(); vm.set_instruction_pointer(vm.instruction_pointer() - 1); - vm.pop_frame(true); + vm.pop_frame(false); } return result; diff --git a/src/executable/bytecode/instructions/YieldValue.cpp b/src/executable/bytecode/instructions/YieldValue.cpp index 275f432b..eadb612d 100644 --- a/src/executable/bytecode/instructions/YieldValue.cpp +++ b/src/executable/bytecode/instructions/YieldValue.cpp @@ -23,7 +23,7 @@ PyResult YieldValue::execute(VirtualMachine &vm, Interpreter &interpreter vm.reg(0) = result; - vm.pop_frame(true); + vm.pop_frame(false); return Ok(result); } diff --git a/src/executable/bytecode/serialization/deserialize.hpp b/src/executable/bytecode/serialization/deserialize.hpp index 65e2858a..9bcd704c 100644 --- a/src/executable/bytecode/serialization/deserialize.hpp +++ b/src/executable/bytecode/serialization/deserialize.hpp @@ -5,6 +5,7 @@ #include "serialize.hpp" #include "utilities.hpp" +#include #include #include diff --git a/src/executable/mlir/CMakeLists.txt b/src/executable/mlir/CMakeLists.txt index d283614b..3888d355 100644 --- a/src/executable/mlir/CMakeLists.txt +++ b/src/executable/mlir/CMakeLists.txt @@ -1,4 +1,4 @@ -find_package(LLVM 20.1 REQUIRED CONFIG) +find_package(LLVM 23 REQUIRED CONFIG) message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") find_package(MLIR CONFIG REQUIRED @@ -20,6 +20,8 @@ set(PYTHON_MLIR_BINARY_DIR ${PROJECT_BINARY_DIR}/src/executable/mlir) add_subdirectory(Conversion) add_subdirectory(Dialect) add_subdirectory(Target) +add_subdirectory(tools/python-mlir-opt) +add_subdirectory(test) add_library(python-mlir compile.cpp) target_link_libraries(python-mlir PRIVATE PythonMLIRDialect TargetPythonBytecode PythonConversionPasses) diff --git a/src/executable/mlir/Conversion/Passes.td b/src/executable/mlir/Conversion/Passes.td index 67c8bf45..d652f058 100644 --- a/src/executable/mlir/Conversion/Passes.td +++ b/src/executable/mlir/Conversion/Passes.td @@ -4,4 +4,30 @@ include "mlir/Pass/PassBase.td" def ConvertPythonToPythonBytecode : Pass<"convert-python-to-pythonbytecode"> { let summary = "Convert recognized Python ops to PythonCpp bytecode"; let constructor = "mlir::py::createPythonToPythonBytecodePass()"; +} + +def ConvertPyForLoop : Pass<"convert-py-forloop"> { + let summary = "Lower py.for_loop to emitpybytecode control flow"; + let constructor = "mlir::py::createConvertForLoopPass()"; +} + +def ConvertPyWhile : Pass<"convert-py-while"> { + let summary = "Lower py.while to emitpybytecode control flow"; + let constructor = "mlir::py::createConvertWhileLoopPass()"; +} + +def ConvertPyTry : Pass<"convert-py-try"> { + let summary = "Lower py.try (and its handler scopes) to emitpybytecode control flow"; + let constructor = "mlir::py::createConvertTryPass()"; +} + +def ConvertPyWith : Pass<"convert-py-with"> { + let summary = "Lower py.with to emitpybytecode control flow"; + let constructor = "mlir::py::createConvertWithPass()"; +} + +def MaterialiseReturnNone : Pass<"materialise-return-none", "::mlir::func::FuncOp"> { + let summary = "Insert emitpybytecode.LOAD_CONST(None) before zero-operand func.return ops"; + let constructor = "mlir::py::createMaterialiseReturnNonePass()"; + let dependentDialects = ["::mlir::emitpybytecode::EmitPythonBytecodeDialect"]; } \ No newline at end of file diff --git a/src/executable/mlir/Conversion/PythonToPythonBytecode/ArithPatterns.cpp b/src/executable/mlir/Conversion/PythonToPythonBytecode/ArithPatterns.cpp new file mode 100644 index 00000000..51c3a3ad --- /dev/null +++ b/src/executable/mlir/Conversion/PythonToPythonBytecode/ArithPatterns.cpp @@ -0,0 +1,147 @@ +#include "Conversion/PythonToPythonBytecode/LoweringHelpers.hpp" +#include "Conversion/PythonToPythonBytecode/PatternPopulators.hpp" + +#include "Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.hpp" +#include "Dialect/Python/IR/PythonOps.hpp" + +#include "executable/bytecode/instructions/BinaryOperation.hpp" +#include "executable/bytecode/instructions/Unary.hpp" +#include "utilities.hpp" + +#include "mlir/IR/PatternMatch.h" + +namespace mlir::py { +namespace { + + // Translate py.{binary,inplace_op}'s ArithOpKind enum to the + // bytecode-level BinaryOperation::Operation enum. The two are + // deliberately decoupled: the dialect enum is part of the IR + // contract; the bytecode enum is the wire format consumed by the + // VM, so the mapping between the two must stay explicit. + BinaryOperation::Operation py_kind_to_binary_op(mlir::py::ArithOpKind kind) + { + switch (kind) { + case mlir::py::ArithOpKind::add: + return BinaryOperation::Operation::PLUS; + case mlir::py::ArithOpKind::sub: + return BinaryOperation::Operation::MINUS; + case mlir::py::ArithOpKind::mod: + return BinaryOperation::Operation::MODULO; + case mlir::py::ArithOpKind::mul: + return BinaryOperation::Operation::MULTIPLY; + case mlir::py::ArithOpKind::exp: + return BinaryOperation::Operation::EXP; + case mlir::py::ArithOpKind::div: + return BinaryOperation::Operation::SLASH; + case mlir::py::ArithOpKind::fldiv: + return BinaryOperation::Operation::FLOORDIV; + case mlir::py::ArithOpKind::mmul: + return BinaryOperation::Operation::MATMUL; + case mlir::py::ArithOpKind::lshift: + return BinaryOperation::Operation::LEFTSHIFT; + case mlir::py::ArithOpKind::rshift: + return BinaryOperation::Operation::RIGHTSHIFT; + case mlir::py::ArithOpKind::and_: + return BinaryOperation::Operation::AND; + case mlir::py::ArithOpKind::or_: + return BinaryOperation::Operation::OR; + case mlir::py::ArithOpKind::xor_: + return BinaryOperation::Operation::XOR; + } + ASSERT_NOT_REACHED(); + } + + struct InplaceOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(InplaceOp op, + mlir::PatternRewriter &rewriter) const final + { + auto op_type = mlir::IntegerAttr::get(rewriter.getIntegerType(8, false), + static_cast(py_kind_to_binary_op(op.getKind()))); + + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), op.getDst(), op.getSrc(), op_type); + + return success(); + } + }; + + struct BinaryOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::py::BinaryOp op, + mlir::PatternRewriter &rewriter) const final + { + auto op_type = mlir::IntegerAttr::get(rewriter.getIntegerType(8, false), + static_cast(py_kind_to_binary_op(op.getKind()))); + rewriter.replaceOpWithNewOp( + op, op.getOutput().getType(), op.getLhs(), op.getRhs(), op_type); + return success(); + } + }; + + // Trivial 1:1 lowering of a py.unary_* op to emitpybytecode.UNARY_OP + // with the corresponding Unary::Operation enum baked in. + template + struct UnaryOpLowering : public mlir::OpRewritePattern + { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(From op, mlir::PatternRewriter &rewriter) const final + { + rewriter.template replaceOpWithNewOp( + op, op.getOutput().getType(), op.getInput(), static_cast(Kind)); + return mlir::success(); + } + }; + + using PositiveOpLowering = UnaryOpLowering; + using NegativeOpLowering = UnaryOpLowering; + using InvertOpLowering = UnaryOpLowering; + using NotOpLowering = UnaryOpLowering; + + struct CompareOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::py::CompareOp op, + mlir::PatternRewriter &rewriter) const final + { + auto lhs = op.getLhs(); + auto rhs = op.getRhs(); + auto op_type = mlir::IntegerAttr::get( + rewriter.getIntegerType(8, false), static_cast(op.getPredicate())); + + rewriter.replaceOpWithNewOp( + op, op.getOutput().getType(), lhs, rhs, op_type); + + return success(); + } + }; + + struct CastToBoolOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::py::CastToBoolOp op, + mlir::PatternRewriter &rewriter) const final + { + rewriter.replaceOpWithNewOp( + op, op.getValue().getType(), op.getValue()); + return success(); + } + }; + +}// namespace + +void populateArithPatterns(mlir::RewritePatternSet &patterns) +{ + auto *ctx = patterns.getContext(); + patterns.add(ctx); + patterns.add(ctx); +} + +}// namespace mlir::py diff --git a/src/executable/mlir/Conversion/PythonToPythonBytecode/AttributeSubscriptPatterns.cpp b/src/executable/mlir/Conversion/PythonToPythonBytecode/AttributeSubscriptPatterns.cpp new file mode 100644 index 00000000..972ab37d --- /dev/null +++ b/src/executable/mlir/Conversion/PythonToPythonBytecode/AttributeSubscriptPatterns.cpp @@ -0,0 +1,49 @@ +#include "Conversion/PythonToPythonBytecode/LoweringHelpers.hpp" +#include "Conversion/PythonToPythonBytecode/PatternPopulators.hpp" + +#include "Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.hpp" +#include "Dialect/Python/IR/PythonOps.hpp" + +#include "mlir/IR/PatternMatch.h" + +namespace mlir::py { +namespace { + + using LoadAttributeOpLowering = detail::DirectReplaceLowering; + + using DeleteAttributeOpLowering = detail::DirectReplaceLowering; + + using LoadMethodOpLowering = detail::DirectReplaceRegisterName; + + using BinarySubscriptOpLowering = detail::DirectReplaceLowering; + + using StoreSubscriptOpLowering = detail::DirectReplaceLowering; + + using DeleteSubscriptOpLowering = detail::DirectReplaceLowering; + + using StoreAttributeOpLowering = detail::DirectReplaceRegisterName; + +}// namespace + +void populateAttributeSubscriptPatterns(mlir::RewritePatternSet &patterns) +{ + auto *ctx = patterns.getContext(); + patterns.add(ctx); +} + +}// namespace mlir::py diff --git a/src/executable/mlir/Conversion/PythonToPythonBytecode/CMakeLists.txt b/src/executable/mlir/Conversion/PythonToPythonBytecode/CMakeLists.txt index 533d33b8..75b9be68 100644 --- a/src/executable/mlir/Conversion/PythonToPythonBytecode/CMakeLists.txt +++ b/src/executable/mlir/Conversion/PythonToPythonBytecode/CMakeLists.txt @@ -1,5 +1,12 @@ add_mlir_conversion_library(PythonToPythonBytecode PythonToPythonBytecode.cpp + ArithPatterns.cpp + AttributeSubscriptPatterns.cpp + CollectionPatterns.cpp + ControlFlowPatterns.cpp + FunctionPatterns.cpp + ImportPatterns.cpp + LoadStorePatterns.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/src/executable/mlir/Conversion/PythonToPythonBytecode diff --git a/src/executable/mlir/Conversion/PythonToPythonBytecode/CollectionPatterns.cpp b/src/executable/mlir/Conversion/PythonToPythonBytecode/CollectionPatterns.cpp new file mode 100644 index 00000000..5bd4cd11 --- /dev/null +++ b/src/executable/mlir/Conversion/PythonToPythonBytecode/CollectionPatterns.cpp @@ -0,0 +1,254 @@ +#include "Conversion/PythonToPythonBytecode/LoweringHelpers.hpp" +#include "Conversion/PythonToPythonBytecode/PatternPopulators.hpp" + +#include "Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.hpp" +#include "Dialect/Python/IR/PythonOps.hpp" + +#include "utilities.hpp" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/STLExtras.h" + +#include +#include + +namespace mlir::py { +namespace { + + struct BuildDictOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::py::BuildDictOp op, + mlir::PatternRewriter &rewriter) const final + { + const auto &requires_expansion = op.getRequiresExpansion(); + if (std::any_of(requires_expansion.begin(), + requires_expansion.end(), + [](const auto &el) { return el; })) { + std::optional result; + std::vector keys; + std::vector values; + + for (auto [key, value, to_expand] : + llvm::zip(op.getKeys(), op.getValues(), op.getRequiresExpansion())) { + if (to_expand) { + if (!result.has_value()) { + result = rewriter.create( + op.getLoc(), op.getOutput().getType(), keys, values); + keys.clear(); + values.clear(); + } + rewriter.create( + op.getLoc(), *result, value); + } else { + if (!result.has_value()) { + keys.push_back(key); + values.push_back(value); + } else { + ASSERT(keys.empty()); + ASSERT(values.empty()); + rewriter.create( + op.getLoc(), *result, key, value); + } + } + } + + ASSERT(result.has_value()); + ASSERT(keys.empty()); + ASSERT(values.empty()); + + rewriter.replaceOp(op, { *result }); + } else { + // Plain dict literal: lower 1:1. The register-pressure + // optimisation for large literals lives as a + // canonicalize pattern on emitpybytecode.BuildDict + // (ExpandLargeBuildDict in EmitPythonBytecode.cpp) so + // any path that reaches that op benefits, not just the + // one through this lowering. + rewriter.replaceOpWithNewOp( + op, op.getOutput().getType(), op.getKeys(), op.getValues()); + } + + return success(); + } + }; + + // Build a list incrementally by walking (element, requires_expansion) + // pairs: each "expand" entry produces a ListExtend (unpacking *args + // or **kwargs in literals), each non-expand entry produces a + // ListAppend. Returns the resulting BuildList value, which both + // BuildListOp and BuildTupleOp lowerings hand off to their final + // step (replaceOp / wrap in ListToTuple respectively). + mlir::emitpybytecode::BuildList build_list_with_expansion(mlir::PatternRewriter &rewriter, + mlir::Location loc, + mlir::Type list_type, + mlir::ValueRange elements, + llvm::ArrayRef requires_expansion) + { + auto list = + rewriter.create(loc, list_type, mlir::ValueRange{}); + for (auto [el, expand] : llvm::zip(elements, requires_expansion)) { + if (expand) { + rewriter.create(loc, list, el); + } else { + rewriter.create(loc, list, el); + } + } + return list; + } + + struct BuildListOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::py::BuildListOp op, + mlir::PatternRewriter &rewriter) const final + { + const auto &requires_expansion = op.getRequiresExpansion(); + if (std::any_of(requires_expansion.begin(), + requires_expansion.end(), + [](const auto &el) { return el == 1; })) { + auto list = build_list_with_expansion(rewriter, + op.getLoc(), + op.getOutput().getType(), + op.getElements(), + requires_expansion); + rewriter.replaceOp(op, list); + } else { + // Plain list literal: lower 1:1. The all-constants + // register-pressure rewrite lives as a canonicalize + // pattern on emitpybytecode.BuildList + // (FoldAllConstBuildListIntoExtend in + // EmitPythonBytecode.cpp), so any caller landing on + // the bytecode-level op benefits, not just this path. + rewriter.replaceOpWithNewOp( + op, op.getOutput().getType(), op.getElements()); + } + + return success(); + } + }; + + using ListAppendOpLowering = + detail::DirectReplaceLowering; + + using DictAddOpLowering = + detail::DirectReplaceLowering; + + struct BuildTupleOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::py::BuildTupleOp op, + mlir::PatternRewriter &rewriter) const final + { + const auto &requires_expansion = op.getRequiresExpansion(); + if (std::any_of(requires_expansion.begin(), + requires_expansion.end(), + [](const auto &el) { return el == 1; })) { + auto list = build_list_with_expansion(rewriter, + op.getLoc(), + op.getOutput().getType(), + op.getElements(), + requires_expansion); + rewriter.replaceOpWithNewOp( + op, op.getOutput().getType(), list); + } else { + rewriter.replaceOpWithNewOp( + op, op.getOutput().getType(), op.getElements()); + } + + return success(); + } + }; + + struct BuildSetOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::py::BuildSetOp op, + mlir::PatternRewriter &rewriter) const final + { + const auto &requires_expansion = op.getRequiresExpansion(); + if (std::any_of(requires_expansion.begin(), + requires_expansion.end(), + [](const auto &el) { return el == 1; })) { + std::vector elements; + std::optional set; + for (auto [el, expand] : llvm::zip(op.getElements(), requires_expansion)) { + if (expand) { + if (!set.has_value()) { + set = rewriter.create( + op->getLoc(), op.getOutput().getType(), elements); + } else { + for (auto el : elements) { + rewriter.create( + op.getLoc(), *set, el); + } + } + elements.clear(); + rewriter.create(op.getLoc(), *set, el); + } else { + elements.push_back(el); + } + } + ASSERT(set.has_value()); + for (auto el : elements) { + rewriter.create(op.getLoc(), *set, el); + } + rewriter.replaceOp(op, *set); + } else { + rewriter.replaceOpWithNewOp( + op, op.getOutput().getType(), op.getElements()); + } + + return success(); + } + }; + + using SetAddOpLowering = + detail::DirectReplaceLowering; + + using BuildStringOpLowering = + detail::DirectReplaceLowering; + + struct FormatValueOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::py::FormatValueOp op, + mlir::PatternRewriter &rewriter) const final + { + rewriter.replaceOpWithNewOp(op, + op.getOutput().getType(), + op.getValue(), + static_cast(op.getConversion())); + + return success(); + } + }; + + using BuildSliceOpLowering = + detail::DirectReplaceLowering; + +}// namespace + +void populateCollectionPatterns(mlir::RewritePatternSet &patterns) +{ + auto *ctx = patterns.getContext(); + patterns.add(ctx); +} + +}// namespace mlir::py diff --git a/src/executable/mlir/Conversion/PythonToPythonBytecode/ControlFlowPatterns.cpp b/src/executable/mlir/Conversion/PythonToPythonBytecode/ControlFlowPatterns.cpp new file mode 100644 index 00000000..46bdea79 --- /dev/null +++ b/src/executable/mlir/Conversion/PythonToPythonBytecode/ControlFlowPatterns.cpp @@ -0,0 +1,174 @@ +#include "Conversion/PythonToPythonBytecode/LoweringHelpers.hpp" +#include "Conversion/PythonToPythonBytecode/PatternPopulators.hpp" + +#include "Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.hpp" +#include "Dialect/Python/IR/PythonOps.hpp" + +#include "utilities.hpp" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir::py { +namespace { + + struct ConditionalBranchOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::cf::CondBranchOp op, + mlir::PatternRewriter &rewriter) const final + { + auto cond = (*op.getODSOperands(0).begin()); + ASSERT(cond.getDefiningOp()); + rewriter.replaceOpWithNewOp(op, + cond, + op.getTrueDest(), + op.getTrueDestOperands(), + op.getFalseDest(), + op.getFalseDestOperands()); + return success(); + } + }; + + struct CondBranchSubclassOpLowering + : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::py::CondBranchSubclassOp op, + mlir::PatternRewriter &rewriter) const final + { + rewriter.replaceOpWithNewOp(op, + op.getObjectType(), + op.getTrueDestOperands(), + op.getFalseDestOperands(), + op.getTrueDest(), + op.getFalseDest()); + + return success(); + } + }; + + using LoadAssertionErrorOpLowering = detail::DirectReplaceLowering; + + using WithExceptStartOpLowering = detail::DirectReplaceLowering; + + using ClearExceptionStateOpLowering = + detail::DirectReplaceLowering; + + struct RaiseOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + /// Find the first parent operation of the given type, or nullptr if there is + /// no ancestor operation. + template static mlir::Operation *getParentOfType(mlir::Region *region) + { + do { + if ((... || mlir::isa(*region->getParentOp()))) + return region->getParentOp(); + } while ((region = region->getParentRegion())); + return nullptr; + } + + static mlir::Block *get_handler(mlir::Operation *op, mlir::PatternRewriter &rewriter) + { + // find possible catch block in order to not clobber an active result register + auto *handler_op = + getParentOfType( + op->getParentRegion()); + ASSERT(handler_op); + return llvm::TypeSwitch(handler_op) + .Case([](mlir::py::TryOp op) { + return op.getHandlers().empty() ? &op.getFinally().front() + : &op.getHandlers().front().front(); + }) + .Case([](mlir::py::WithOp op) { return op->getParentOp()->getBlock(); }) + .Case([&rewriter](mlir::func::FuncOp op) { + auto insertion_point = rewriter.getInsertionPoint(); + auto *return_block = rewriter.createBlock(&op.getRegion()); + auto value = + rewriter.create(op.getLoc(), rewriter.getNoneType()); + rewriter.create(op.getLoc(), mlir::ValueRange{ value }); + rewriter.setInsertionPoint(insertion_point->getBlock(), insertion_point); + return return_block; + }) + .Default([](mlir::Operation *) -> mlir::Block * { + // Structurally unreachable: getParentOfType only + // walks for TryOp/WithOp/FuncOp, and the preceding + // ASSERT rules out nullptr, so one of the three + // Cases above must match. + ASSERT_NOT_REACHED(); + }); + } + + mlir::LogicalResult matchAndRewrite(mlir::py::RaiseOp op, + mlir::PatternRewriter &rewriter) const final + { + if (auto exception = op.getException()) { + rewriter.replaceOpWithNewOp( + op, exception, op.getCause(), get_handler(op, rewriter)); + } else { + rewriter.replaceOpWithNewOp( + op, get_handler(op, rewriter)); + } + + return success(); + } + }; + + using YieldOpLowering = + detail::DirectReplaceLowering; + + struct YieldFromOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::py::YieldFromOp op, + mlir::PatternRewriter &rewriter) const final + { + auto iterator = rewriter.create( + op.getLoc(), op.getIterable().getType(), op.getIterable()); + auto value = rewriter.create(op.getLoc(), rewriter.getNoneType()); + + rewriter.replaceOpWithNewOp( + op, iterator.getType(), iterator, value); + + return success(); + } + }; + + using UnpackSequenceOpLowering = detail::DirectReplaceLowering; + + using UnpackExpandOpLowering = detail::DirectReplaceLowering; + + using GetAwaitableOpLowering = detail::DirectReplaceLowering; + +}// namespace + +void populateControlFlowPatterns(mlir::RewritePatternSet &patterns) +{ + auto *ctx = patterns.getContext(); + patterns.add(ctx); +} + +}// namespace mlir::py diff --git a/src/executable/mlir/Conversion/PythonToPythonBytecode/FunctionPatterns.cpp b/src/executable/mlir/Conversion/PythonToPythonBytecode/FunctionPatterns.cpp new file mode 100644 index 00000000..357426d5 --- /dev/null +++ b/src/executable/mlir/Conversion/PythonToPythonBytecode/FunctionPatterns.cpp @@ -0,0 +1,251 @@ +#include "Conversion/PythonToPythonBytecode/LoweringHelpers.hpp" +#include "Conversion/PythonToPythonBytecode/PatternPopulators.hpp" + +#include "Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.hpp" +#include "Dialect/Python/IR/PythonOps.hpp" + +#include "utilities.hpp" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Visitors.h" + +#include +#include + +namespace mlir::py { +namespace { + + struct CallFunctionLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(py::FunctionCallOp op, + mlir::PatternRewriter &rewriter) const final + { + auto callee = op.getCallee(); + auto args = op.getArgs(); + + if (op.getRequiresArgsExpansion() || op.getRequiresKwargsExpansion()) { + ASSERT(args.size() <= 1); + ASSERT(op.getKwargs().size() <= 1); + rewriter.replaceOpWithNewOp(op, + op.getOutput().getType(), + callee, + op.getRequiresArgsExpansion() ? args.front() : nullptr, + op.getRequiresKwargsExpansion() ? op.getKwargs().front() : nullptr); + } else if (!op.getKeywords().empty()) { + rewriter.replaceOpWithNewOp( + op, op.getOutput().getType(), callee, args, op.getKeywords(), op.getKwargs()); + } else { + rewriter.replaceOpWithNewOp( + op, op.getOutput().getType(), callee, args); + } + + return success(); + } + }; + + // py.return -> func.return. py.return relaxes func.return's + // HasParent constraint so that return statements emitted + // inside py.try / py.with regions don't trip the verifier + // pre-lowering. By the time this pattern fires, the surrounding + // region ops have been flattened into the enclosing func.func's + // body, so func.return is well-formed. + struct ReturnOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::py::ReturnOp op, + mlir::PatternRewriter &rewriter) const final + { + rewriter.replaceOpWithNewOp(op, op.getValue()); + return mlir::success(); + } + }; + + struct MakeFunctionOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + mlir::LogicalResult matchAndRewrite(mlir::py::MakeFunctionOp op, + mlir::PatternRewriter &rewriter) const final + { + auto module = op->getParentOfType(); + auto function_definition = module.lookupSymbol(op.getFunctionName()); + ASSERT(function_definition); + ASSERT(mlir::isa(*function_definition)); + + auto sym_name = rewriter.create(op.getLoc(), + mlir::py::PyObjectType::get(rewriter.getContext()), + rewriter.getStringAttr(op.getFunctionName())); + + auto captures_tuple = [&]() -> mlir::Value { + if (op.getCaptures().empty()) { return nullptr; } + std::vector captures_vec; + for (auto attr : op.getCaptures()) { + auto name = mlir::cast(attr).getValue(); + captures_vec.push_back(rewriter.create( + op.getLoc(), mlir::py::PyObjectType::get(getContext()), name)); + } + return rewriter.create( + op.getLoc(), mlir::py::PyObjectType::get(getContext()), captures_vec); + }(); + rewriter.replaceOpWithNewOp(op, + mlir::py::PyObjectType::get(rewriter.getContext()), + sym_name, + op.getDefaults(), + op.getKwDefaults(), + captures_tuple); + + return success(); + } + }; + + struct FuncOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::func::FuncOp op, + mlir::PatternRewriter &rewriter) const final + { + if (op.isPrivate()) { return success(); } + populate_arguments(op, rewriter); + return success(); + } + + void populate_arguments(mlir::func::FuncOp &op, mlir::OpBuilder &builder) const + { + for (size_t i = 0; i < op.getNumArguments(); ++i) { + auto arg_name = op.getArgAttr(i, "llvm.name"); + ASSERT(arg_name); + detail::add_identifier_to( + op, "locals", mlir::cast(arg_name).getValue(), builder); + } + } + }; + + struct ClassDefinitionOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::py::ClassDefinitionOp op, + mlir::PatternRewriter &rewriter) const final + { + auto module = op->getParentOfType(); + rewriter.setInsertionPointToEnd(module.getBody()); + + auto func_type = rewriter.getFunctionType(mlir::TypeRange{}, + mlir::TypeRange{ mlir::py::PyObjectType::get(rewriter.getContext()) }); + auto class_fn_definition = rewriter.create(op.getLoc(), + op.getMangledName(), + func_type, + mlir::ArrayRef{}, + mlir::ArrayRef{}); + + class_fn_definition->setAttr("is_class", rewriter.getBoolAttr(true)); + + if (auto cellvars = op->getAttrOfType("cellvars")) { + auto cell_names = cellvars.getValue(); + if (std::find_if(cell_names.begin(), + cell_names.end(), + [](mlir::Attribute name) { + return mlir::cast(name) == "__class__"; + }) + != cell_names.end()) { + + mlir::py::ClassReturnOp return_op; + op.getBody().walk([&return_op](mlir::Operation *child_op) { + if (mlir::isa(child_op)) { + return WalkResult::skip(); + } + if (auto cr = mlir::dyn_cast(child_op)) { + return_op = cr; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + ASSERT(return_op); + ASSERT(return_op->getParentOp() == op.getOperation()); + ASSERT(return_op.getValue().getDefiningOp()); + rewriter.setInsertionPoint(return_op.getValue().getDefiningOp()); + rewriter.replaceOpWithNewOp( + return_op.getValue().getDefiningOp(), + mlir::py::PyObjectType::get(getContext()), + mlir::StringRef{ "__class__" }); + } + } + + auto attr = class_fn_definition->getAttrs().vec(); + attr.insert(attr.end(), op->getAttrs().begin(), op->getAttrs().end()); + class_fn_definition->setAttrs(attr); + + // Convert all py.class_return ops in the class body to + // func.return so that the body, once inlined into the + // synthesised func.func, has a valid terminator. + op.getBody().walk([&rewriter](mlir::py::ClassReturnOp cr) { + rewriter.setInsertionPoint(cr); + rewriter.replaceOpWithNewOp( + cr, mlir::ValueRange{ cr.getValue() }); + }); + + auto *end = class_fn_definition.addEntryBlock(); + rewriter.setInsertionPointToStart(end); + rewriter.inlineRegionBefore(op.getBody(), &class_fn_definition.getBody().front()); + rewriter.eraseBlock(end); + + rewriter.setInsertionPoint(op); + auto class_name = rewriter.create(op.getLoc(), + mlir::py::PyObjectType::get(rewriter.getContext()), + rewriter.getStringAttr(op.getMangledName())); + + auto captures_tuple = [&]() -> mlir::Value { + if (op.getCaptures().empty()) { return {}; } + std::vector captures_vec; + for (auto attr : op.getCaptures()) { + auto name = mlir::cast(attr).getValue(); + captures_vec.push_back(rewriter.create( + op.getLoc(), mlir::py::PyObjectType::get(getContext()), name)); + } + return rewriter.create( + op.getLoc(), mlir::py::PyObjectType::get(getContext()), captures_vec); + }(); + + auto class_fn = rewriter.create(op.getLoc(), + mlir::py::PyObjectType::get(rewriter.getContext()), + class_name, + mlir::ValueRange{}, + mlir::ValueRange{}, + captures_tuple); + + auto class_builder = rewriter.create( + op.getLoc(), mlir::py::PyObjectType::get(rewriter.getContext())); + std::vector args{ class_fn, class_name }; + args.insert(args.end(), op.getBases().begin(), op.getBases().end()); + rewriter.replaceOpWithNewOp(op, + op.getOutput().getType(), + class_builder, + args, + op.getKeywords(), + op.getKwargs(), + false, + false); + + return success(); + } + }; + +}// namespace + +void populateFunctionPatterns(mlir::RewritePatternSet &patterns) +{ + auto *ctx = patterns.getContext(); + patterns.add(ctx); +} + +}// namespace mlir::py diff --git a/src/executable/mlir/Conversion/PythonToPythonBytecode/ImportPatterns.cpp b/src/executable/mlir/Conversion/PythonToPythonBytecode/ImportPatterns.cpp new file mode 100644 index 00000000..05194168 --- /dev/null +++ b/src/executable/mlir/Conversion/PythonToPythonBytecode/ImportPatterns.cpp @@ -0,0 +1,59 @@ +#include "Conversion/PythonToPythonBytecode/LoweringHelpers.hpp" +#include "Conversion/PythonToPythonBytecode/PatternPopulators.hpp" + +#include "Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.hpp" +#include "Dialect/Python/IR/PythonOps.hpp" +#include "Dialect/Python/IR/PythonTypes.hpp" + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir::py { +namespace { + + // py.import lowers to LOAD_CONST(level) + a BuildTuple over the + // from_list constants, then IMPORT_NAME. The bytecode form takes + // the level and from-list as actual operands, so the conversion + // materialises them as constants here rather than threading them + // through as attributes on the bytecode op. + struct ImportOpLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::py::ImportOp op, + mlir::PatternRewriter &rewriter) const final + { + auto name = op.getName(); + auto level = rewriter.create( + op.getLoc(), op.getModule().getType(), rewriter.getUI32IntegerAttr(op.getLevel())); + std::vector els; + for (auto attr : op.getFromList()) { + auto from = mlir::cast(attr); + els.push_back(rewriter.create( + op.getLoc(), op.getModule().getType(), from)); + } + auto from_list = rewriter.create( + op.getLoc(), op.getModule().getType(), els); + rewriter.replaceOpWithNewOp( + op, op.getModule().getType(), name, level, from_list); + + return success(); + } + }; + + using ImportAllOpLowering = + detail::DirectReplaceLowering; + + using ImportFromOpLowering = + detail::DirectReplaceLowering; + +}// namespace + +void populateImportPatterns(mlir::RewritePatternSet &patterns) +{ + patterns.add( + patterns.getContext()); +} + +}// namespace mlir::py diff --git a/src/executable/mlir/Conversion/PythonToPythonBytecode/LoadStorePatterns.cpp b/src/executable/mlir/Conversion/PythonToPythonBytecode/LoadStorePatterns.cpp new file mode 100644 index 00000000..4426dbca --- /dev/null +++ b/src/executable/mlir/Conversion/PythonToPythonBytecode/LoadStorePatterns.cpp @@ -0,0 +1,131 @@ +#include "Conversion/PythonToPythonBytecode/LoweringHelpers.hpp" +#include "Conversion/PythonToPythonBytecode/PatternPopulators.hpp" + +#include "Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.hpp" +#include "Dialect/Python/IR/PythonAttributes.hpp" +#include "Dialect/Python/IR/PythonOps.hpp" + +#include "mlir/IR/PatternMatch.h" + +namespace mlir::py { +namespace { + + // py.constant -> emitpybytecode.LOAD_CONST, with a special case for + // the ellipsis attribute (which has its own bytecode-level op). + struct ConstantLoadLowering : public mlir::OpRewritePattern + { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(py::ConstantOp op, + mlir::PatternRewriter &rewriter) const final + { + auto constant_value = op.getValue(); + + auto ellipsis = + mlir::detail::AttributeUniquer::get(getContext()); + if (op.getValue() == ellipsis) { + rewriter.replaceOpWithNewOp( + op, op.getOutput().getType()); + return success(); + } + + rewriter.replaceOpWithNewOp( + op, op.getOutput().getType(), constant_value); + + return success(); + } + }; + + // Trivial 1:1 lowering of py.load_* (and similar single-result name- + // referencing ops) to the corresponding emitpybytecode.LOAD_* op. + template + struct LoadNameLikeLowering : public mlir::OpRewritePattern + { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(From op, mlir::PatternRewriter &rewriter) const final + { + detail::register_name_with_parent(op, op.getNameAttr(), rewriter, Kind); + rewriter.template replaceOpWithNewOp( + op, op.getOutput().getType(), op.getNameAttr()); + return mlir::success(); + } + }; + + // Trivial 1:1 lowering of py.store_* to emitpybytecode.STORE_*. + template + struct StoreNameLikeLowering : public mlir::OpRewritePattern + { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(From op, mlir::PatternRewriter &rewriter) const final + { + detail::register_name_with_parent(op, op.getNameAttr(), rewriter, Kind); + rewriter.template replaceOpWithNewOp(op, op.getNameAttr(), op.getValue()); + return mlir::success(); + } + }; + + // Trivial 1:1 lowering of py.delete_* to emitpybytecode.DELETE_*. + template + struct DeleteNameLikeLowering : public mlir::OpRewritePattern + { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(From op, mlir::PatternRewriter &rewriter) const final + { + detail::register_name_with_parent(op, op.getNameAttr(), rewriter, Kind); + rewriter.template replaceOpWithNewOp(op, op.getNameAttr()); + return mlir::success(); + } + }; + + using LoadNameLowering = LoadNameLikeLowering; + using LoadDerefLowering = + LoadNameLikeLowering; + using LoadFastLowering = LoadNameLikeLowering; + using LoadGlobalLowering = LoadNameLikeLowering; + + using StoreNameLowering = + StoreNameLikeLowering; + using StoreDerefLowering = + StoreNameLikeLowering; + using StoreFastLowering = StoreNameLikeLowering; + using StoreGlobalLowering = StoreNameLikeLowering; + + using DeleteNameLowering = + DeleteNameLikeLowering; + using DeleteDerefLowering = + DeleteNameLikeLowering; + using DeleteFastLowering = DeleteNameLikeLowering; + using DeleteGlobalLowering = DeleteNameLikeLowering; + +}// namespace + +void populateLoadStorePatterns(mlir::RewritePatternSet &patterns) +{ + auto *ctx = patterns.getContext(); + patterns.add(ctx); + patterns.add( + ctx); + patterns.add( + ctx); +} + +}// namespace mlir::py diff --git a/src/executable/mlir/Conversion/PythonToPythonBytecode/LoweringHelpers.hpp b/src/executable/mlir/Conversion/PythonToPythonBytecode/LoweringHelpers.hpp new file mode 100644 index 00000000..57c81b56 --- /dev/null +++ b/src/executable/mlir/Conversion/PythonToPythonBytecode/LoweringHelpers.hpp @@ -0,0 +1,110 @@ +#pragma once + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" + +#include +#include + +namespace mlir::py::detail { + +// Adds `identifier` to the StrArrayAttr named `attr_name` on `fn` if +// not already present. Used by the conversion patterns to track the +// function-level "locals" / "names" sets that the bytecode emitter +// consumes. +inline void add_identifier_to(mlir::func::FuncOp fn, + mlir::StringRef attr_name, + mlir::StringRef identifier, + mlir::OpBuilder &builder) +{ + if (fn->hasAttr(attr_name)) { + auto arr = mlir::cast(fn->getAttr(attr_name)).getValue(); + if (std::find_if(arr.begin(), + arr.end(), + [identifier](mlir::Attribute attr) { + return mlir::cast(attr).getValue() == identifier; + }) + != arr.end()) { + return; + } + std::vector names_vec; + std::transform( + arr.begin(), arr.end(), std::back_inserter(names_vec), [](mlir::Attribute attr) { + return mlir::cast(attr).getValue(); + }); + names_vec.emplace_back(identifier); + fn->setAttr(attr_name, builder.getStrArrayAttr(names_vec)); + } else { + fn->setAttr(attr_name, builder.getStrArrayAttr({ identifier })); + } +} + +inline void + add_identifier(mlir::func::FuncOp fn, mlir::StringRef identifier, mlir::OpBuilder &builder) +{ + add_identifier_to(fn, "names", identifier, builder); +} + +// Categorizes how a Load/Store/Delete pattern should register its name +// with the parent FuncOp before lowering. +enum class NameKind { + None,// no registration (local-scope name / cell deref) + Local,// registers in the func's "locals" attribute + Global,// registers in the func's "names" attribute +}; + +inline void register_name_with_parent(mlir::Operation *op, + mlir::StringAttr name, + mlir::OpBuilder &builder, + NameKind kind) +{ + if (kind == NameKind::None) { return; } + auto fn = mlir::cast_or_null(op->getParentOp()); + assert(fn); + add_identifier_to(fn, kind == NameKind::Local ? "locals" : "names", name.getValue(), builder); +} + +// Generic 1:1 lowering for ops whose source and target dialect schemas +// match exactly: same operand types/order, same attribute names/types, +// same result types. Forwards everything from source op to target op. +template +struct DirectReplaceLowering : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(From op, mlir::PatternRewriter &rewriter) const final + { + rewriter.template replaceOpWithNewOp(op, + op.getOperation()->getResultTypes(), + op.getOperation()->getOperands(), + op.getOperation()->getAttrs()); + return mlir::success(); + } +}; + +// Like DirectReplaceLowering, but additionally registers the name +// returned by NameGetter in the parent FuncOp's "names" attribute. +// Used for ops whose names must appear in the function-level name +// table the bytecode emitter consumes (LoadMethod, StoreAttribute). +template +struct DirectReplaceRegisterName : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(From op, mlir::PatternRewriter &rewriter) const final + { + auto parent_fn = op->template getParentOfType(); + add_identifier(parent_fn, (op.*NameGetter)(), rewriter); + rewriter.template replaceOpWithNewOp(op, + op.getOperation()->getResultTypes(), + op.getOperation()->getOperands(), + op.getOperation()->getAttrs()); + return mlir::success(); + } +}; + +}// namespace mlir::py::detail diff --git a/src/executable/mlir/Conversion/PythonToPythonBytecode/PatternPopulators.hpp b/src/executable/mlir/Conversion/PythonToPythonBytecode/PatternPopulators.hpp new file mode 100644 index 00000000..78bb037e --- /dev/null +++ b/src/executable/mlir/Conversion/PythonToPythonBytecode/PatternPopulators.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include "mlir/IR/PatternMatch.h" + +namespace mlir::py { + +// Per-op-family pattern populators. PythonToPythonBytecodePass calls +// each populator to assemble its full RewritePatternSet. Splitting +// these out keeps the monolithic conversion file from ballooning past +// the point where it can be reviewed in one sitting. + +void populateArithPatterns(mlir::RewritePatternSet &patterns); +void populateAttributeSubscriptPatterns(mlir::RewritePatternSet &patterns); +void populateCollectionPatterns(mlir::RewritePatternSet &patterns); +void populateControlFlowPatterns(mlir::RewritePatternSet &patterns); +void populateFunctionPatterns(mlir::RewritePatternSet &patterns); +void populateImportPatterns(mlir::RewritePatternSet &patterns); +void populateLoadStorePatterns(mlir::RewritePatternSet &patterns); + +}// namespace mlir::py diff --git a/src/executable/mlir/Conversion/PythonToPythonBytecode/PythonToPythonBytecode.cpp b/src/executable/mlir/Conversion/PythonToPythonBytecode/PythonToPythonBytecode.cpp index 46373a87..117a5742 100644 --- a/src/executable/mlir/Conversion/PythonToPythonBytecode/PythonToPythonBytecode.cpp +++ b/src/executable/mlir/Conversion/PythonToPythonBytecode/PythonToPythonBytecode.cpp @@ -1,4 +1,5 @@ #include "Conversion/PythonToPythonBytecode/PythonToPythonBytecode.hpp" +#include "Conversion/PythonToPythonBytecode/PatternPopulators.hpp" #include "Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.hpp" #include "Dialect/Python/IR/Dialect.hpp" #include "Dialect/Python/IR/PythonAttributes.hpp" @@ -38,1329 +39,78 @@ namespace mlir { namespace py { namespace { - struct ConstantLoadLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(py::ConstantOp op, - mlir::PatternRewriter &rewriter) const final - { - auto constant_value = op.getValue(); - - - auto ellipsis = - mlir::detail::AttributeUniquer::get(getContext()); - if (op.getValue() == ellipsis) { - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType()); - return success(); - } - - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), constant_value); - - return success(); - } - }; - - struct StoreNameLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(py::StoreNameOp op, - mlir::PatternRewriter &rewriter) const final - { - auto object_name = op.getNameAttr(); - auto object_value = op.getValue(); - - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), object_name, object_value); - - return success(); - } - }; - - struct StoreDerefLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(py::StoreDerefOp op, - mlir::PatternRewriter &rewriter) const final - { - auto object_name = op.getNameAttr(); - auto object_value = op.getValue(); - - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), object_name, object_value); - - return success(); - } - }; - - template struct LocalDeclarationInterface : public T - { - template - LocalDeclarationInterface(Args &&...args) : T(std::forward(args)...) - {} - - void addLocalIdentifierToParentFunction(mlir::func::FuncOp &fn, - mlir::StringAttr identifier, - mlir::OpBuilder &builder) const - { - if (fn->hasAttr("locals")) { - auto names = fn->getAttr("locals"); - std::vector names_vec; - auto arr = names.cast().getValue(); - if (std::find_if(arr.begin(), - arr.end(), - [identifier](mlir::Attribute attr) { - return attr.cast().getValue() == identifier; - }) - != arr.end()) { - return; - } - std::transform(arr.begin(), - arr.end(), - std::back_inserter(names_vec), - [](mlir::Attribute attr) { - return attr.cast().getValue(); - }); - names_vec.emplace_back(identifier); - fn->setAttr("locals", builder.getStrArrayAttr(names_vec)); - } else { - fn->setAttr("locals", builder.getStrArrayAttr({ identifier })); - } - } - }; - - namespace { - void add_identifier(mlir::func::FuncOp &fn, - mlir::StringRef identifier, - mlir::OpBuilder &builder) - { - if (fn->hasAttr("names")) { - auto names = fn->getAttr("names"); - std::vector names_vec; - auto arr = names.cast().getValue(); - if (std::find_if(arr.begin(), - arr.end(), - [identifier](mlir::Attribute attr) { - return attr.cast().getValue() == identifier; - }) - != arr.end()) { - return; - } - std::transform(arr.begin(), - arr.end(), - std::back_inserter(names_vec), - [](mlir::Attribute attr) { - return attr.cast().getValue(); - }); - names_vec.emplace_back(identifier); - fn->setAttr("names", builder.getStrArrayAttr(names_vec)); - } else { - fn->setAttr("names", builder.getStrArrayAttr({ identifier })); - } - } - - void build_const(mlir::func::FuncOp &fn, std::vector values) {} - }// namespace - - template struct GlobalDeclarationInterface : public T - { - template - GlobalDeclarationInterface(Args &&...args) : T(std::forward(args)...) - {} - - void addGlobalIdentifierToParentFunction(mlir::func::FuncOp &fn, - mlir::StringAttr identifier, - mlir::OpBuilder &builder) const - { - add_identifier(fn, identifier, builder); - } - }; - - struct StoreFastLowering - : public LocalDeclarationInterface> - { - using LocalDeclarationInterface< - OpRewritePattern>::LocalDeclarationInterface; - - mlir::LogicalResult matchAndRewrite(py::StoreFastOp op, - mlir::PatternRewriter &rewriter) const final - { - auto object_name = op.getNameAttr(); - auto object_value = op.getValue(); - - auto parent = op.getOperation()->getParentOp(); - auto fn = mlir::cast_or_null(parent); - ASSERT(fn); - addLocalIdentifierToParentFunction(fn, object_name, rewriter); - - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), object_name, object_value); - - return success(); - } - }; - - struct StoreGlobalLowering - : public GlobalDeclarationInterface> - { - using GlobalDeclarationInterface< - OpRewritePattern>::GlobalDeclarationInterface; - - mlir::LogicalResult matchAndRewrite(py::StoreGlobalOp op, - mlir::PatternRewriter &rewriter) const final - { - auto parent = op.getOperation()->getParentOp(); - auto fn = mlir::cast_or_null(parent); - ASSERT(fn); - auto object_name = op.getNameAttr(); - addGlobalIdentifierToParentFunction(fn, object_name, rewriter); - - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), object_name, op.getValue()); - - return success(); - } - }; - - struct LoadNameLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(py::LoadNameOp op, - mlir::PatternRewriter &rewriter) const final - { - auto object_name = op.getNameAttr(); - - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), object_name); - - return success(); - } - }; - - - struct LoadGlobalLowering - : public GlobalDeclarationInterface> - { - using GlobalDeclarationInterface< - OpRewritePattern>::GlobalDeclarationInterface; - - mlir::LogicalResult matchAndRewrite(py::LoadGlobalOp op, - mlir::PatternRewriter &rewriter) const final - { - auto parent = op.getOperation()->getParentOp(); - auto fn = mlir::cast_or_null(parent); - ASSERT(fn); - auto object_name = op.getNameAttr(); - addGlobalIdentifierToParentFunction(fn, object_name, rewriter); - - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), object_name); - - return success(); - } - }; - - - struct LoadFastLowering - : public LocalDeclarationInterface> - { - using LocalDeclarationInterface< - mlir::OpRewritePattern>::LocalDeclarationInterface; - - mlir::LogicalResult matchAndRewrite(py::LoadFastOp op, - mlir::PatternRewriter &rewriter) const final - { - auto parent = op.getOperation()->getParentOp(); - auto fn = mlir::cast_or_null(parent); - ASSERT(fn); - auto object_name = op.getNameAttr(); - addLocalIdentifierToParentFunction(fn, object_name, rewriter); - - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), object_name); - - return success(); - } - }; - - struct LoadDerefLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(py::LoadDerefOp op, - mlir::PatternRewriter &rewriter) const final - { - auto parent = op.getOperation()->getParentOp(); - auto fn = mlir::cast_or_null(parent); - ASSERT(fn); - auto object_name = op.getNameAttr(); - - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), object_name); - - return success(); - } - }; - - struct DeleteNameLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(py::DeleteNameOp op, - mlir::PatternRewriter &rewriter) const final - { - auto object_name = op.getNameAttr(); - - rewriter.replaceOpWithNewOp(op, object_name); - - return success(); - } - }; - - struct DeleteFastLowering - : public LocalDeclarationInterface> - { - using LocalDeclarationInterface< - mlir::OpRewritePattern>::LocalDeclarationInterface; - - mlir::LogicalResult matchAndRewrite(py::DeleteFastOp op, - mlir::PatternRewriter &rewriter) const final - { - auto parent = op.getOperation()->getParentOp(); - auto fn = mlir::cast_or_null(parent); - ASSERT(fn); - auto object_name = op.getNameAttr(); - addLocalIdentifierToParentFunction(fn, object_name, rewriter); - - rewriter.replaceOpWithNewOp(op, object_name); - - return success(); - } - }; - - struct DeleteGlobalLowering - : public GlobalDeclarationInterface> - { - using GlobalDeclarationInterface< - OpRewritePattern>::GlobalDeclarationInterface; - - mlir::LogicalResult matchAndRewrite(py::DeleteGlobalOp op, - mlir::PatternRewriter &rewriter) const final - { - auto parent = op.getOperation()->getParentOp(); - auto fn = mlir::cast_or_null(parent); - ASSERT(fn); - auto object_name = op.getNameAttr(); - addGlobalIdentifierToParentFunction(fn, object_name, rewriter); - - rewriter.replaceOpWithNewOp(op, object_name); - - return success(); - } - }; - - struct DeleteDerefLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(py::DeleteDerefOp op, - mlir::PatternRewriter &rewriter) const final - { - auto parent = op.getOperation()->getParentOp(); - auto fn = mlir::cast_or_null(parent); - ASSERT(fn); - auto object_name = op.getNameAttr(); - - rewriter.replaceOpWithNewOp(op, object_name); - - return success(); - } - }; - - struct CallFunctionLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(py::FunctionCallOp op, - mlir::PatternRewriter &rewriter) const final - { - auto callee = op.getCallee(); - auto args = op.getArgs(); - - if (op.getRequiresArgsExpansion() || op.getRequiresKwargsExpansion()) { - ASSERT(args.size() <= 1); - ASSERT(op.getKwargs().size() <= 1); - rewriter.replaceOpWithNewOp(op, - op.getOutput().getType(), - callee, - op.getRequiresArgsExpansion() ? args.front() : nullptr, - op.getRequiresKwargsExpansion() ? op.getKwargs().front() : nullptr); - } else if (!op.getKeywords().empty()) { - rewriter.replaceOpWithNewOp( - op, - op.getOutput().getType(), - callee, - args, - op.getKeywords(), - op.getKwargs()); - } else { - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), callee, args); - } - - return success(); - } - }; - - struct ConditionalBranchOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::cf::CondBranchOp op, - mlir::PatternRewriter &rewriter) const final - { - auto cond = (*op.getODSOperands(0).begin()); - if (mlir::isa(cond.getDefiningOp())) { - auto bool_cast = mlir::cast(cond.getDefiningOp()); - auto pycond = rewriter.replaceOpWithNewOp( - bool_cast, bool_cast.getValue().getType(), bool_cast.getValue()); - rewriter.replaceOpWithNewOp(op, - pycond.getValue(), - op.getTrueDest(), - op.getTrueDestOperands(), - op.getFalseDest(), - op.getFalseDestOperands()); - } else { - ASSERT(mlir::isa(cond.getDefiningOp())); - rewriter.replaceOpWithNewOp(op, - mlir::cast(cond.getDefiningOp()) - .getValue(), - op.getTrueDest(), - op.getTrueDestOperands(), - op.getFalseDest(), - op.getFalseDestOperands()); - } - return success(); - } - }; - - struct CondBranchSubclassOpLowering - : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::CondBranchSubclassOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp(op, - op.getObjectType(), - op.getTrueDestOperands(), - op.getFalseDestOperands(), - op.getTrueDest(), - op.getFalseDest()); - - return success(); - } - }; - - struct CompareOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::Compare op, - mlir::PatternRewriter &rewriter) const final - { - auto lhs = op.getLhs(); - auto rhs = op.getRhs(); - auto op_type = mlir::IntegerAttr::get( - rewriter.getIntegerType(8, false), static_cast(op.getPredicate())); - - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), lhs, rhs, op_type); - - return success(); - } - }; - - struct InplaceOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(InplaceOp op, - mlir::PatternRewriter &rewriter) const final - { - auto kind = [&op]() { - switch (op.getKind()) { - case py::InplaceOpKind::add: { - return BinaryOperation::Operation::PLUS; - } break; - case py::InplaceOpKind::sub: { - return BinaryOperation::Operation::MINUS; - } break; - case py::InplaceOpKind::mod: { - return BinaryOperation::Operation::MODULO; - } break; - case py::InplaceOpKind::mul: { - return BinaryOperation::Operation::MULTIPLY; - } break; - case py::InplaceOpKind::exp: { - return BinaryOperation::Operation::EXP; - } break; - case py::InplaceOpKind::div: { - return BinaryOperation::Operation::SLASH; - } break; - case py::InplaceOpKind::fldiv: { - return BinaryOperation::Operation::FLOORDIV; - } break; - case py::InplaceOpKind::lshift: { - return BinaryOperation::Operation::LEFTSHIFT; - } break; - case py::InplaceOpKind::rshift: { - return BinaryOperation::Operation::RIGHTSHIFT; - } break; - case py::InplaceOpKind::and_: { - return BinaryOperation::Operation::AND; - } break; - case py::InplaceOpKind::or_: { - return BinaryOperation::Operation::OR; - } break; - case py::InplaceOpKind::xor_: { - return BinaryOperation::Operation::XOR; - } break; - case py::InplaceOpKind::mmul: { - return BinaryOperation::Operation::MATMUL; - } break; - } - ASSERT_NOT_REACHED(); - }(); - auto dst = op.getDst(); - auto src = op.getSrc(); - auto op_type = mlir::IntegerAttr::get( - rewriter.getIntegerType(8, false), static_cast(kind)); - - rewriter.replaceOpWithNewOp( - op, op.getResult().getType(), dst, src, op_type); - - return success(); - } - }; - - template - struct BinaryOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(BinaryOpType op, - mlir::PatternRewriter &rewriter) const final - { - auto lhs = op.getLhs(); - auto rhs = op.getRhs(); - auto op_type = mlir::IntegerAttr::get( - rewriter.getIntegerType(8, false), static_cast(OperationEnumType)); - - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), lhs, rhs, op_type); - - return success(); - } - }; - -#define BINARY_OP_LOWERING(OPNAME, BINARYOP_ENUM) \ - using OPNAME##Lowering = BinaryOpLowering - - BINARY_OP_LOWERING(BinaryAddOp, PLUS); - BINARY_OP_LOWERING(BinarySubtractOp, MINUS); - BINARY_OP_LOWERING(BinaryModuloOp, MODULO); - BINARY_OP_LOWERING(BinaryMultiplyOp, MULTIPLY); - BINARY_OP_LOWERING(BinaryExpOp, EXP); - BINARY_OP_LOWERING(BinaryDivOp, SLASH); - BINARY_OP_LOWERING(BinaryFloorDivOp, FLOORDIV); - BINARY_OP_LOWERING(BinaryMatMulOp, MATMUL); - BINARY_OP_LOWERING(LeftShiftOp, LEFTSHIFT); - BINARY_OP_LOWERING(RightShiftOp, RIGHTSHIFT); - BINARY_OP_LOWERING(LogicalAndOp, AND); - BINARY_OP_LOWERING(LogicalOrOp, OR); - BINARY_OP_LOWERING(LogicalXorOp, XOR); - -#undef BINARY_OP_LOWERING - - struct LoadAssertionErrorOpLowering - : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::LoadAssertionError op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType()); - - return success(); - } - }; - - struct PositiveOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::PositiveOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp(op, - op.getOutput().getType(), - op.getInput(), - static_cast(Unary::Operation::POSITIVE)); - - return success(); - } - }; - - struct NegativeOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::NegativeOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp(op, - op.getOutput().getType(), - op.getInput(), - static_cast(Unary::Operation::NEGATIVE)); - - return success(); - } - }; - - struct InvertOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::InvertOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp(op, - op.getOutput().getType(), - op.getInput(), - static_cast(Unary::Operation::INVERT)); - - return success(); - } - }; - - struct NotOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::NotOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp(op, - op.getOutput().getType(), - op.getInput(), - static_cast(Unary::Operation::NOT)); - - return success(); - } - }; - - struct BuildDictOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::BuildDictOp op, - mlir::PatternRewriter &rewriter) const final - { - const auto &requires_expansion = op.getRequiresExpansion(); - if (std::any_of(requires_expansion.begin(), - requires_expansion.end(), - [](const auto &el) { return el; })) { - std::optional result; - std::vector keys; - std::vector values; - - for (auto [key, value, to_expand] : - llvm::zip(op.getKeys(), op.getValues(), op.getRequiresExpansion())) { - if (to_expand) { - if (!result.has_value()) { - result = rewriter.create( - op.getLoc(), op.getOutput().getType(), keys, values); - keys.clear(); - values.clear(); - } - rewriter.create( - op.getLoc(), *result, value); - } else { - if (!result.has_value()) { - keys.push_back(key); - values.push_back(value); - } else { - ASSERT(keys.empty()); - ASSERT(values.empty()); - rewriter.create( - op.getLoc(), *result, key, value); - } - } - } - - ASSERT(result.has_value()); - ASSERT(keys.empty()); - ASSERT(values.empty()); - - rewriter.replaceOp(op, { *result }); - } else { - // TODO: this is a hack to avoid spilling to the stack when building large - // dictionaries from literals - if (op.getValues().size() > 10) { - auto keys = op.getKeys(); - auto values = op.getValues(); - rewriter.setInsertionPointAfterValue(keys.front()); - auto result = rewriter.create(op->getLoc(), - op.getOutput().getType(), - mlir::ValueRange{}, - mlir::ValueRange{}); - - for (auto [key, value] : llvm::zip(keys, values)) { - rewriter.setInsertionPointAfterValue(value); - rewriter.create( - op.getLoc(), result, key, value); - } - rewriter.replaceOp(op, result); - } else { - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), op.getKeys(), op.getValues()); - } - } - - return success(); - } - }; - - struct BuildListOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::BuildListOp op, - mlir::PatternRewriter &rewriter) const final - { - const auto &requires_expansion = op.getRequiresExpansion(); - auto known_at_compiletime = [](mlir::Value element) -> bool { - ASSERT(element.getDefiningOp()); - return mlir::isa(element.getDefiningOp()) - || mlir::isa(element.getDefiningOp()); - }; - if (std::any_of(requires_expansion.begin(), - requires_expansion.end(), - [](const auto &el) { return el == 1; })) { - auto list = rewriter.create( - op.getLoc(), op.getOutput().getType(), ValueRange{}); - for (auto [el, expand] : llvm::zip(op.getElements(), requires_expansion)) { - if (expand) { - rewriter.create( - op.getLoc(), list, el); - } else { - rewriter.create( - op.getLoc(), list, el); - } - } - rewriter.replaceOp(op, list); - } else if (std::all_of(op.getElements().begin(), - op.getElements().end(), - known_at_compiletime)) { - std::vector elements; - elements.reserve(op.getElements().size()); - for (const auto &el : op.getElements()) { - if (el.getDefiningOp()) { - elements.push_back(el.getDefiningOp().getValue()); - } else { - ASSERT(el.getDefiningOp()); - elements.push_back( - el.getDefiningOp().getValue()); - } - } - auto loc = op.getLoc(); - auto output_type = op.getOutput().getType(); - auto list = rewriter.create( - loc, output_type, ::mlir::ValueRange{}); - auto tuple = rewriter.create( - loc, output_type, mlir::ArrayAttr::get(getContext(), elements)); - rewriter.create(loc, list, tuple); - rewriter.replaceOp(op, list); - } else { - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), op.getElements()); - } - - return success(); - } - }; - - struct ListAppendOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::ListAppendOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp( - op, op.getList(), op.getValue()); - - return success(); - } - }; - - struct DictAddOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::DictAddOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp( - op, op.getDict(), op.getKey(), op.getValue()); - - return success(); - } - }; - - struct BuildTupleOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::BuildTupleOp op, - mlir::PatternRewriter &rewriter) const final - { - const auto &requires_expansion = op.getRequiresExpansion(); - if (std::any_of(requires_expansion.begin(), - requires_expansion.end(), - [](const auto &el) { return el == 1; })) { - auto list = rewriter.create( - op.getLoc(), op.getOutput().getType(), ValueRange{}); - for (auto [el, expand] : llvm::zip(op.getElements(), requires_expansion)) { - if (expand) { - rewriter.create( - op.getLoc(), list, el); - } else { - rewriter.create( - op.getLoc(), list, el); - } - } - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), list); - } else { - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), op.getElements()); - } - - return success(); - } - }; - - struct BuildSetOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::BuildSetOp op, - mlir::PatternRewriter &rewriter) const final - { - const auto &requires_expansion = op.getRequiresExpansion(); - if (std::any_of(requires_expansion.begin(), - requires_expansion.end(), - [](const auto &el) { return el == 1; })) { - std::vector elements; - std::optional set; - for (auto [el, expand] : llvm::zip(op.getElements(), requires_expansion)) { - if (expand) { - if (!set.has_value()) { - set = rewriter.create( - op->getLoc(), op.getOutput().getType(), elements); - } else { - for (auto el : elements) { - rewriter.create( - op.getLoc(), *set, el); - } - } - elements.clear(); - rewriter.create(op.getLoc(), *set, el); - } else { - elements.push_back(el); - } - } - ASSERT(set.has_value()); - for (auto el : elements) { - rewriter.create(op.getLoc(), *set, el); - } - rewriter.replaceOp(op, *set); - } else { - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), op.getElements()); - } - - return success(); - } - }; - - struct SetAddOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::SetAddOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp( - op, op.getSet(), op.getValue()); - - return success(); - } - }; - - struct BuildStringOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::BuildStringOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), op.getElements()); - - return success(); - } - }; - - struct FormatValueOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::FormatValueOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp(op, - op.getOutput().getType(), - op.getValue(), - static_cast(op.getConversion())); - - return success(); - } - }; - - struct LoadAttributeOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::LoadAttributeOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), op.getSelf(), op.getAttr()); - - return success(); - } - }; - - struct DeleteAttributeOpLowering - : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::DeleteAttributeOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp( - op, op.getSelf(), op.getAttr()); - - return success(); - } - }; - - struct LoadMethodOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::LoadMethodOp op, - mlir::PatternRewriter &rewriter) const final - { - auto parent_fn = op->getParentOfType(); - add_identifier(parent_fn, op.getMethodName(), rewriter); - - rewriter.replaceOpWithNewOp( - op, op.getMethod().getType(), op.getSelf(), op.getMethodName()); - - return success(); - } - }; - - struct BinarySubscriptOpLowering - : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::BinarySubscriptOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), op.getSelf(), op.getSubscript()); - - return success(); - } - }; - - struct StoreSubscriptOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::StoreSubscriptOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp( - op, op.getSelf(), op.getSubscript(), op.getValue()); - - return success(); - } - }; - - struct DeleteSubscriptOpLowering - : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::DeleteSubscriptOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp( - op, op.getSelf(), op.getSubscript()); - - return success(); - } - }; - - struct StoreAttributeOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::StoreAttributeOp op, - mlir::PatternRewriter &rewriter) const final - { - auto parent_fn = op->getParentOfType(); - add_identifier(parent_fn, op.getAttribute(), rewriter); - - rewriter.replaceOpWithNewOp( - op, op.getSelf(), op.getAttribute(), op.getValue()); - - return success(); - } - }; - - struct BuildSliceOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::BuildSliceOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp( - op, op.getSlice().getType(), op.getLower(), op.getUpper(), op.getStep()); - - return success(); - } - }; - - struct MakeFunctionOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - mlir::LogicalResult matchAndRewrite(mlir::py::MakeFunctionOp op, - mlir::PatternRewriter &rewriter) const final - { - auto module = op->getParentOfType(); - auto function_definition = module.lookupSymbol(op.getFunctionName()); - ASSERT(function_definition); - ASSERT(mlir::isa(*function_definition)); - - auto sym_name = rewriter.create(op.getLoc(), - mlir::py::PyObjectType::get(rewriter.getContext()), - rewriter.getStringAttr(op.getFunctionName())); - - auto captures_tuple = [&]() -> mlir::Value { - if (op.getCaptures().empty()) { return nullptr; } - std::vector captures_vec; - for (auto name : op.getCaptures().getValues()) { - captures_vec.push_back(rewriter.create( - op.getLoc(), mlir::py::PyObjectType::get(getContext()), name)); - } - return rewriter.create( - op.getLoc(), mlir::py::PyObjectType::get(getContext()), captures_vec); - }(); - rewriter.replaceOpWithNewOp(op, - mlir::py::PyObjectType::get(rewriter.getContext()), - sym_name, - op.getDefaults(), - op.getKwDefaults(), - captures_tuple); - - return success(); - } - }; - - struct FuncOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::func::FuncOp op, - mlir::PatternRewriter &rewriter) const final - { - if (op.isPrivate()) { return success(); } - // if (op.isPrivate() && op.getSymName() == "__hidden_init__") { return success(); } - populate_arguments(op, rewriter); - - return success(); - } - - void populate_arguments(mlir::func::FuncOp &op, mlir::OpBuilder &builder) const - { - for (size_t i = 0; i < op.getNumArguments(); ++i) { - auto arg_name = op.getArgAttr(i, "llvm.name"); - ASSERT(arg_name); - add_local(op, arg_name.cast().getValue(), builder); - } - } - - void add_local(mlir::func::FuncOp &fn, - mlir::StringRef identifier, - mlir::OpBuilder &builder) const - { - if (fn->hasAttr("locals")) { - auto names = fn->getAttr("locals"); - std::vector names_vec; - auto arr = names.cast().getValue(); - if (std::find_if(arr.begin(), - arr.end(), - [identifier](mlir::Attribute attr) { - return attr.cast().getValue() == identifier; - }) - != arr.end()) { - return; - } - std::transform(arr.begin(), - arr.end(), - std::back_inserter(names_vec), - [](mlir::Attribute attr) { - return attr.cast().getValue(); - }); - names_vec.emplace_back(identifier); - fn->setAttr("locals", builder.getStrArrayAttr(names_vec)); - } else { - fn->setAttr("locals", builder.getStrArrayAttr({ identifier })); - } - } - }; - - struct ClassDefinitionOpLowering - : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::ClassDefinitionOp op, - mlir::PatternRewriter &rewriter) const final - { - auto module = op->getParentOfType(); - rewriter.setInsertionPointToEnd(module.getBody()); - - auto func_type = rewriter.getFunctionType(mlir::TypeRange{}, - mlir::TypeRange{ mlir::py::PyObjectType::get(rewriter.getContext()) }); - auto class_fn_definition = rewriter.create(op.getLoc(), - op.getMangledName(), - func_type, - mlir::ArrayRef{}, - mlir::ArrayRef{}); - - class_fn_definition->setAttr("is_class", rewriter.getBoolAttr(true)); - - if (auto cellvars = op->getAttrOfType("cellvars")) { - auto cell_names = cellvars.getValue(); - if (std::find_if(cell_names.begin(), - cell_names.end(), - [](mlir::Attribute name) { - return name.cast() == "__class__"; - }) - != cell_names.end()) { - - mlir::Operation *return_op_; - op.getBody().walk( - [&return_op_](mlir::Operation *child_op) { - if (mlir::isa( - child_op)) { - return WalkResult::skip(); - } - if (mlir::isa(child_op)) { - return_op_ = child_op; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - ASSERT(return_op_); - auto return_op = mlir::cast(return_op_); - ASSERT(return_op.getOperands().size() == 1); - ASSERT(return_op->getParentOp() == op.getOperation()); - ASSERT(return_op.getOperand(0).getDefiningOp()); - rewriter.setInsertionPoint(return_op.getOperand(0).getDefiningOp()); - rewriter.replaceOpWithNewOp( - return_op.getOperand(0).getDefiningOp(), - mlir::py::PyObjectType::get(getContext()), - mlir::StringRef{ "__class__" }); - } - } - - auto attr = class_fn_definition->getAttrs().vec(); - attr.insert(attr.end(), op->getAttrs().begin(), op->getAttrs().end()); - class_fn_definition->setAttrs(attr); - - auto *end = class_fn_definition.addEntryBlock(); - rewriter.setInsertionPointToStart(end); - rewriter.inlineRegionBefore(op.getBody(), &class_fn_definition.getBody().front()); - rewriter.eraseBlock(end); - - rewriter.setInsertionPoint(op); - auto class_name = rewriter.create(op.getLoc(), - mlir::py::PyObjectType::get(rewriter.getContext()), - rewriter.getStringAttr(op.getMangledName())); - - auto captures_tuple = [&]() -> mlir::Value { - if (op.getCaptures().empty()) { return {}; } - std::vector captures_vec; - for (auto name : op.getCaptures().getValues()) { - captures_vec.push_back(rewriter.create( - op.getLoc(), mlir::py::PyObjectType::get(getContext()), name)); - } - return rewriter.create( - op.getLoc(), mlir::py::PyObjectType::get(getContext()), captures_vec); - }(); - - auto class_fn = rewriter.create(op.getLoc(), - mlir::py::PyObjectType::get(rewriter.getContext()), - class_name, - mlir::ValueRange{}, - mlir::ValueRange{}, - captures_tuple); - - auto class_builder = rewriter.create( - op.getLoc(), mlir::py::PyObjectType::get(rewriter.getContext())); - std::vector args{ class_fn, class_name }; - args.insert(args.end(), op.getBases().begin(), op.getBases().end()); - rewriter.replaceOpWithNewOp(op, - op.getOutput().getType(), - class_builder, - args, - op.getKeywords(), - op.getKwargs(), - false, - false); - - return success(); - } - }; - - struct ForLoopOpLowering : public mlir::OpRewritePattern - { - private: - std::function yield_op_callback( - mlir::PatternRewriter &rewriter, - mlir::Block *condition_start, - mlir::Block *end_block) const - { - return [this, &rewriter, condition_start, end_block](mlir::Operation *operation) { - auto parent_is_orelse = [](mlir::Operation *operation) { - auto forloop_op = operation->getParentOfType(); - if (!forloop_op) { return false; } - return &forloop_op.getOrelse() == operation->getParentRegion(); - }; - // llvm::outs() << "ForOpLowering 1:\n"; - // operation->print(llvm::outs()); - // llvm::outs() << '\n'; - // llvm::outs().flush(); - + // All non-region-bearing patterns now live in per-family files: + // {Arith, AttributeSubscript, Collection, ControlFlow, Function, + // Import, LoadStore}Patterns.cpp, registered via the + // populate*Patterns() entry points below. + // The shared DirectReplaceLowering / DirectReplaceRegisterName + // helpers and add_identifier* utilities moved to LoweringHelpers.hpp. + // What remains in this file are the four region-bearing structural + // patterns (ForLoop / While / Try / With), the + // PythonToPythonBytecodePass, and the dedicated single-pattern + // passes that wrap the structural ones. + + // Shared walker used by both ForLoopOpLowering and WhileOpLowering + // to lower py.br_yield ops nested inside a loop body to cf.br ops + // that target the right block (continue→condition / step, break→ + // end). Nested loops are walked into their orelse regions only; + // the loop body itself is skipped because the nested loop will be + // lowered by its own pattern. + // + // `skip_op` allows a caller to short-circuit on yield ops whose + // enclosing loop matches a specific predicate — ForLoopOpLowering + // uses this to ignore yields that bind to the *outer* for-loop's + // orelse (which shouldn't be lowered as part of the inner loop + // pass). + void replace_loop_branch_yields(mlir::PatternRewriter &rewriter, + mlir::Region ®ion, + mlir::Block *continue_target, + mlir::Block *break_target, + llvm::function_ref skip_op) + { + std::function callback = + [&rewriter, continue_target, break_target, skip_op, &callback]( + mlir::Operation *operation) { if (auto loop = mlir::dyn_cast(operation)) { if (loop.getOrelse().empty()) { return WalkResult::skip(); } - // llvm::outs() << "ForOpLowering - ForLoopOp or else\n"; - // loop.getOrelse().front().print(llvm::outs()); - // llvm::outs() << '\n'; - // llvm::outs().flush(); - loop.getOrelse().walk( - yield_op_callback(rewriter, condition_start, end_block)); + loop.getOrelse().walk(callback); return WalkResult::skip(); } if (auto loop = mlir::dyn_cast(operation)) { if (loop.getOrelse().empty()) { return WalkResult::skip(); } - // llvm::outs() << "ForOpLowering - WhileOp or else\n"; - loop.getOrelse().walk( - yield_op_callback(rewriter, condition_start, end_block)); + loop.getOrelse().walk(callback); return WalkResult::skip(); } - - // llvm::outs() << "ForOpLowering 2:\n"; - // operation->print(llvm::outs()); - // llvm::outs() << '\n'; - // llvm::outs().flush(); - - if (auto yield_op = mlir::dyn_cast(operation)) { - static_assert(mlir::py::ControlFlowYield::hasTrait:: - Impl>()); - if (!yield_op.getKind().has_value() - && mlir::isa(yield_op->getParentOp())) { - return WalkResult::advance(); - } - // is this hacky? maybe there is a better way of ignoring the else branch of - // a for loop - if (parent_is_orelse(operation)) { return WalkResult::advance(); } - rewriter.setInsertionPoint(yield_op); - if (!yield_op.getKind().has_value() - || yield_op.getKind().value() == py::LoopOpKind::continue_) { - rewriter.replaceOpWithNewOp( - yield_op, condition_start); - } else if (yield_op.getKind().value() == py::LoopOpKind::break_) { - rewriter.replaceOpWithNewOp(yield_op, end_block); - } + auto yield_op = mlir::dyn_cast(operation); + if (!yield_op) { return WalkResult::advance(); } + static_assert(mlir::py::BranchYieldOp::hasTrait::Impl>()); + // Kindless yields under try/with/try-handler don't + // participate in the loop's continue/break flow. + if (!yield_op.getKind().has_value() + && mlir::isa(yield_op->getParentOp())) { + return WalkResult::advance(); + } + if (skip_op && skip_op(yield_op)) { return WalkResult::advance(); } + rewriter.setInsertionPoint(yield_op); + if (!yield_op.getKind().has_value() + || yield_op.getKind().value() == py::LoopOpKind::continue_) { + rewriter.replaceOpWithNewOp(yield_op, continue_target); + } else if (yield_op.getKind().value() == py::LoopOpKind::break_) { + rewriter.replaceOpWithNewOp(yield_op, break_target); } return WalkResult::advance(); }; - } - - std::vector getIterators(mlir::py::ForLoopOp op, - mlir::emitpybytecode::GetIter current_iterator) const - { - std::vector iterators; - - iterators.push_back(current_iterator); - - auto parent = op->getParentOfType(); - while (parent) { - auto iterable = parent.getIterable(); - ASSERT(!iterable.getUsers().empty()); - auto iterator = *iterable.getUsers().begin(); - ASSERT(mlir::isa(*iterator)); - iterators.insert( - iterators.end() - 1, mlir::cast(*iterator)); - parent = parent->getParentOfType(); - } + region.walk(callback); + } - return iterators; - } - - public: + struct ForLoopOpLowering : public mlir::OpRewritePattern + { using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(mlir::py::ForLoopOp op, mlir::PatternRewriter &rewriter) const final { - // llvm::outs() << "ForLoopOp rewrite:\n"; - // op->print(llvm::outs()); - // llvm::outs() << "-------------------------\n"; - // llvm::outs().flush(); - auto *initBlock = rewriter.getInsertionBlock(); auto initPos = rewriter.getInsertionPoint(); @@ -1373,37 +123,28 @@ namespace py { // advance iterator auto iterator_next_block = rewriter.createBlock(endBlock); - // iterator_next_block->addArgument(iterator.getType(), op.getStep().getLoc()); rewriter.setInsertionPointToEnd(initBlock); - const auto &iterators = getIterators(op, iterator); rewriter.create(op.getStep().getLoc(), iterator_next_block); rewriter.setInsertionPointToStart(iterator_next_block); rewriter.create(op.getStep().getLoc(), - // iterator_next_block->getArgument(0), - iterators.front(), + iterator, &op.getStep().front(), op.getOrelse().empty() ? endBlock : &op.getOrelse().front()); ASSERT(!op.getStep().empty()); auto *iterator_exit_block = &op.getStep().back(); ASSERT(iterator_exit_block->getTerminator()); - // iterator_exit_block->print(llvm::outs()); - // llvm::outs() << '\n'; - // llvm::outs().flush(); - ASSERT(mlir::isa(iterator_exit_block->getTerminator())); + ASSERT(mlir::isa(iterator_exit_block->getTerminator())); rewriter.setInsertionPointToEnd(iterator_exit_block); rewriter.replaceOpWithNewOp( - iterator_exit_block->getTerminator(), &op.getBody().front() /*, iterators*/); + iterator_exit_block->getTerminator(), &op.getBody().front()); auto *for_iter_block = rewriter.createBlock(&op.getBody()); - // for (const auto &it : iterators) { - // for_iter_block->addArgument(it.getType(), op.getStep().getLoc()); - // } rewriter.create(op.getStep().getLoc(), - iterators.front(), + iterator, &op.getStep().front(), op.getOrelse().empty() ? endBlock : &op.getOrelse().front()); @@ -1411,12 +152,15 @@ namespace py { rewriter.inlineRegionBefore( op.getStep(), *op->getParentRegion(), endBlock->getIterator()); - // for (const auto &it : iterators) { - // op.getBody().addArgument(it.getType(), op.getStep().getLoc()); - // } - - op.getBody().walk( - yield_op_callback(rewriter, for_iter_block, endBlock)); + // Skip yields whose enclosing for-loop sits inside an + // outer for-loop's orelse — those belong to the outer + // pattern's rewrite, not this one. + auto skip_orelse_yields = [](mlir::py::BranchYieldOp y) { + auto forloop_op = y->getParentOfType(); + return forloop_op && &forloop_op.getOrelse() == y->getParentRegion(); + }; + replace_loop_branch_yields( + rewriter, op.getBody(), for_iter_block, endBlock, skip_orelse_yields); ASSERT(!op.getBody().empty()); auto *body_exit_block = &op.getBody().back(); @@ -1427,7 +171,7 @@ namespace py { if (!op.getOrelse().empty()) { auto *orelse_exit_block = &op.getOrelse().back(); ASSERT(orelse_exit_block->getTerminator()); - if (mlir::isa(orelse_exit_block->getTerminator())) { + if (mlir::isa(orelse_exit_block->getTerminator())) { rewriter.setInsertionPointToEnd(orelse_exit_block); rewriter.replaceOpWithNewOp( orelse_exit_block->getTerminator(), endBlock); @@ -1438,8 +182,6 @@ namespace py { rewriter.eraseOp(op); - // llvm::outs() << "ForLoopOp rewrite end\n"; - // llvm::outs().flush(); return success(); } @@ -1447,57 +189,6 @@ namespace py { struct WhileOpLowering : public mlir::OpRewritePattern { - private: - std::function yield_op_callback( - mlir::PatternRewriter &rewriter, - mlir::Block *condition_start, - mlir::Block *end_block) const - { - return [this, &rewriter, condition_start, end_block](mlir::Operation *operation) { - // llvm::outs() << "WhileOpLowering 1:\n"; - // operation->print(llvm::outs()); - // llvm::outs() << '\n'; - // llvm::outs().flush(); - if (auto loop = mlir::dyn_cast(operation)) { - if (loop.getOrelse().empty()) { return WalkResult::skip(); } - // llvm::outs() << "WhileOpLowering - ForLoopOp or else\n"; - loop.getOrelse().walk( - yield_op_callback(rewriter, condition_start, end_block)); - return WalkResult::skip(); - } - if (auto loop = mlir::dyn_cast(operation)) { - if (loop.getOrelse().empty()) { return WalkResult::skip(); } - // llvm::outs() << "WhileOpLowering - WhileOp or else\n"; - loop.getOrelse().walk( - yield_op_callback(rewriter, condition_start, end_block)); - return WalkResult::skip(); - } - // llvm::outs() << "WhileOpLowering 2:\n"; - // operation->print(llvm::outs()); - // llvm::outs() << '\n'; - // llvm::outs().flush(); - if (auto yield_op = mlir::dyn_cast(operation)) { - static_assert(mlir::py::ControlFlowYield::hasTrait:: - Impl>()); - if (!yield_op.getKind().has_value() - && mlir::isa(yield_op->getParentOp())) { - return WalkResult::advance(); - } - rewriter.setInsertionPoint(yield_op); - if (!yield_op.getKind().has_value() - || yield_op.getKind().value() == py::LoopOpKind::continue_) { - rewriter.replaceOpWithNewOp( - yield_op, condition_start); - } else if (yield_op.getKind().value() == py::LoopOpKind::break_) { - rewriter.replaceOpWithNewOp(yield_op, end_block); - } - } - return WalkResult::advance(); - }; - } - - public: using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(mlir::py::WhileOp op, @@ -1535,15 +226,18 @@ namespace py { rewriter.eraseOp(condition_op); rewriter.inlineRegionBefore(condition, endBlock); - op.getBody().walk( - yield_op_callback(rewriter, &condition_start, endBlock)); + replace_loop_branch_yields(rewriter, + op.getBody(), + &condition_start, + endBlock, + /*skip_op=*/{}); rewriter.inlineRegionBefore(op.getBody(), endBlock); // if (!op.getOrelse().empty()) { // auto *orelse_exit_block = &op.getOrelse().back(); // ASSERT(orelse_exit_block->getTerminator()); - // if (mlir::isa(orelse_exit_block->getTerminator())) { + // if (mlir::isa(orelse_exit_block->getTerminator())) { // rewriter.setInsertionPointToEnd(orelse_exit_block); // rewriter.replaceOpWithNewOp( // orelse_exit_block->getTerminator(), endBlock); @@ -1566,17 +260,17 @@ namespace py { { if (region.empty()) { return; } region.walk([callback](mlir::Operation *childOp) { - static_assert(mlir::py::ControlFlowYield::hasTrait::Impl>()); + static_assert(mlir::py::BranchYieldOp::hasTrait::Impl>()); if (mlir::isa(childOp) || mlir::isa(childOp) || mlir::isa(childOp) || mlir::isa(childOp) - || mlir::isa(childOp)) { + || mlir::isa(childOp)) { return WalkResult::skip(); } - if (mlir::isa(childOp) - && !mlir::cast(childOp).getKind().has_value()) { + if (mlir::isa(childOp) + && !mlir::cast(childOp).getKind().has_value()) { callback(childOp); return WalkResult::skip(); } @@ -1655,7 +349,7 @@ namespace py { auto &handler = op.getHandlers().front(); ASSERT(handler.getBlocks().size() == 1); auto handler_scope = - mlir::cast(handler.front().getTerminator()); + mlir::cast(handler.front().getTerminator()); ASSERT(handler_scope); rewriter.create(op.getLoc(), body_start, @@ -1674,7 +368,7 @@ namespace py { ASSERT(handler.getBlocks().size() == 1); auto handler_scope = - mlir::cast(handler.front().getTerminator()); + mlir::cast(handler.front().getTerminator()); ASSERT(handler_scope); if (!handler_scope.getCond().empty()) { @@ -1684,7 +378,7 @@ namespace py { rewriter.setInsertionPoint(cond); auto &next_handler = op.getHandlers()[idx + 1]; ASSERT(next_handler.getBlocks().size() == 1); - auto next_handler_scope = mlir::cast( + auto next_handler_scope = mlir::cast( next_handler.front().getTerminator()); ASSERT(next_handler_scope); @@ -1721,7 +415,7 @@ namespace py { auto &handler = op.getHandlers().back(); ASSERT(handler.getBlocks().size() == 1); auto handler_scope = - mlir::cast(handler.front().getTerminator()); + mlir::cast(handler.front().getTerminator()); ASSERT(handler_scope); if (!handler_scope.getCond().empty()) { auto cond = mlir::cast( @@ -1805,13 +499,13 @@ namespace py { op.getBody().walk([&rewriter, exit_block, cleanup_block]( mlir::Operation *childOp) { - static_assert(mlir::py::ControlFlowYield::hasTrait::Impl>()); + static_assert(mlir::py::BranchYieldOp::hasTrait::Impl>()); if (mlir::isa(childOp) || mlir::isa(childOp) || mlir::isa(childOp) || mlir::isa(childOp) - || mlir::isa(childOp)) { + || mlir::isa(childOp)) { return WalkResult::skip(); } if (auto op = mlir::dyn_cast(childOp)) { @@ -1826,7 +520,7 @@ namespace py { rewriter.replaceOpWithNewOp( op, BlockRange{ cleanup_block }); } - } else if (auto op = mlir::dyn_cast(childOp); + } else if (auto op = mlir::dyn_cast(childOp); op && !op.getKind().has_value()) { auto *current = op->getBlock(); auto *next = rewriter.splitBlock(current, op->getIterator()); @@ -1840,6 +534,15 @@ namespace py { rewriter.inlineRegionBefore(op.getBody(), endBlock); + // Multi-item with-statements (with a, b, c: ...) are not + // yet supported end-to-end: MLIRGenerator currently TODOs + // out for items().size() > 1, so the dialect op only ever + // arrives here with a single item. The loops below over + // op.getItems() exist for shape symmetry with the future + // multi-item version but bail explicitly until that work + // lands. + ASSERT(op.getItems().size() == 1 + && "WithOp lowering does not yet support multiple context managers"); rewriter.setInsertionPointToStart(cleanup_block); for (const auto &item : op.getItems()) { auto exit = rewriter.create(item.getLoc(), @@ -1862,12 +565,10 @@ namespace py { rewriter.setInsertionPointToStart(reraise_block); rewriter.create(item.getLoc(), endBlock); - // TODO: handle multiple handlers rewriter.setInsertionPointToStart(continue_block); rewriter.create(item.getLoc()); rewriter.create(op.getLoc(), endBlock); } - // rewriter.create(op.getLoc(), endBlock); rewriter.setInsertionPointToStart(exit_block); for (const auto &item : op.getItems()) { @@ -1906,331 +607,186 @@ namespace py { } }; - struct WithExceptStartOpLowering - : public mlir::OpRewritePattern + + struct PythonToPythonBytecodePass + : public PassWrapper> { - using OpRewritePattern::OpRewritePattern; + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PythonToPythonBytecodePass) - mlir::LogicalResult matchAndRewrite(mlir::py::WithExceptStartOp op, - mlir::PatternRewriter &rewriter) const final + void getDependentDialects(DialectRegistry ®istry) const override { - rewriter.replaceOpWithNewOp( - op, op.getOutput().getType(), op.getExitMethod()); - return success(); + registry.insert(); } - }; - struct ClearExceptionStateOpLowering - : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; + StringRef getArgument() const final { return "python-to-pythonbytecode"; } - mlir::LogicalResult matchAndRewrite(mlir::py::ClearExceptionStateOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp(op); - return success(); - } + void runOnOperation() final; }; - struct RaiseOpLowering : public mlir::OpRewritePattern + // Pass scaffolds for the four region-bearing control-flow ops. + // Each pass applies its single lowering pattern greedily on the + // module. Dialect dependencies match PythonToPythonBytecodePass's + // (Python source dialect + EmitPythonBytecode target dialect); the + // patterns also create cf::BranchOp / func::FuncOp internally, but + // those dialects are already loaded by the time the pipeline runs. + template + struct SinglePatternConversionPass : public PassWrapper> { - using OpRewritePattern::OpRewritePattern; - - /// Find the first parent operation of the given type, or nullptr if there is - /// no ancestor operation. - template - static mlir::Operation *getParentOfType(mlir::Region *region) + void getDependentDialects(DialectRegistry ®istry) const override { - do { - if ((... || mlir::isa(*region->getParentOp()))) - return region->getParentOp(); - } while ((region = region->getParentRegion())); - return nullptr; + registry.insert(); } - static mlir::Block *get_handler(mlir::Operation *op, mlir::PatternRewriter &rewriter) - { - // find possible catch block in order to not clobber an active result register - auto *handler_op = - getParentOfType( - op->getParentRegion()); - ASSERT(handler_op); - return llvm::TypeSwitch(handler_op) - .Case([](mlir::py::TryOp op) { - return op.getHandlers().empty() ? &op.getFinally().front() - : &op.getHandlers().front().front(); - }) - .Case([](mlir::py::WithOp op) { return op->getParentOp()->getBlock(); }) - .Case([&rewriter](mlir::func::FuncOp op) { - auto insertion_point = rewriter.getInsertionPoint(); - auto *return_block = rewriter.createBlock(&op.getRegion()); - auto value = rewriter.create( - op.getLoc(), rewriter.getNoneType()); - rewriter.create( - op.getLoc(), mlir::ValueRange{ value }); - rewriter.setInsertionPoint(insertion_point->getBlock(), insertion_point); - return return_block; - }) - .Default([](mlir::Operation *op) { - TODO(); - return nullptr; - }); - } + StringRef getArgument() const final { return Argument; } - mlir::LogicalResult matchAndRewrite(mlir::py::RaiseOp op, - mlir::PatternRewriter &rewriter) const final + void runOnOperation() final { - if (auto exception = op.getException()) { - rewriter.replaceOpWithNewOp( - op, exception, op.getCause(), get_handler(op, rewriter)); - } else { - rewriter.replaceOpWithNewOp( - op, get_handler(op, rewriter)); - } - - return success(); - } - }; - - struct ImportOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; + mlir::RewritePatternSet patterns(&this->getContext()); + patterns.template add(&this->getContext()); - mlir::LogicalResult matchAndRewrite(mlir::py::ImportOp op, - mlir::PatternRewriter &rewriter) const final - { - auto name = op.getName(); - auto level = rewriter.create(op.getLoc(), - op.getModule().getType(), - rewriter.getUI32IntegerAttr(op.getLevel())); - std::vector els; - for (auto from : op.getFromList().getValues()) { - els.push_back(rewriter.create( - op.getLoc(), op.getModule().getType(), rewriter.getStringAttr(from))); - } - auto from_list = rewriter.create( - op.getLoc(), op.getModule().getType(), els); - rewriter.replaceOpWithNewOp( - op, op.getModule().getType(), name, level, from_list); + GreedyRewriteConfig config; + config.setStrictness(GreedyRewriteStrictness::AnyOp); + config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Normal); + config.setUseTopDownTraversal(true); + FrozenRewritePatternSet frozen{ std::move(patterns) }; - return success(); + (void)applyPatternsGreedily(this->getOperation(), frozen, config); } }; - struct ImportAllOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::ImportAllOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp(op, op.getModule()); - - return success(); - } - }; + inline constexpr char kConvertForLoopArg[] = "convert-py-forloop"; + inline constexpr char kConvertWhileLoopArg[] = "convert-py-while"; + inline constexpr char kConvertTryArg[] = "convert-py-try"; + inline constexpr char kConvertWithArg[] = "convert-py-with"; - struct ImportFromOpLowering : public mlir::OpRewritePattern + struct ConvertForLoopPass + : public SinglePatternConversionPass { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::ImportFromOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp( - op, op.getModule().getType(), op.getModule(), op.getName()); - - return success(); - } + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertForLoopPass) }; - struct CastToBoolOpLowering : public mlir::OpRewritePattern + struct ConvertWhileLoopPass + : public SinglePatternConversionPass { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::CastToBoolOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp( - op, op.getValue().getType(), op.getValue()); - return success(); - } + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertWhileLoopPass) }; - struct YieldOpLowering : public mlir::OpRewritePattern + struct ConvertTryPass + : public SinglePatternConversionPass { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::YieldOp op, - mlir::PatternRewriter &rewriter) const final - { - rewriter.replaceOpWithNewOp( - op, op.getValue().getType(), op.getValue()); - return success(); - } + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertTryPass) }; - struct YieldFromOpLowering : public mlir::OpRewritePattern + struct ConvertWithPass + : public SinglePatternConversionPass { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite(mlir::py::YieldFromOp op, - mlir::PatternRewriter &rewriter) const final - { - auto iterator = rewriter.create( - op.getLoc(), op.getIterable().getType(), op.getIterable()); - auto value = - rewriter.create(op.getLoc(), rewriter.getNoneType()); - - rewriter.replaceOpWithNewOp( - op, iterator.getType(), iterator, value); - - return success(); - } + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertWithPass) }; - struct UnpackSequenceOpLowering : public mlir::OpRewritePattern + // Pattern: rewrite a zero-operand func.return by inserting a None + // constant operand and, if RemoveDeadValues also rewrote the + // parent FuncOp's signature to return nothing, restoring its + // declared result type to PyObjectType. The bytecode emitter + // assumes every function returns a value (Python's "every + // function returns at minimum None") regardless of whether MLIR + // sees the result as used. + // + // Reaches for emitpybytecode::ConstantOp because the pass runs + // after PythonToPythonBytecodePass has already lowered + // py.constant; using py.constant here would re-introduce an + // illegal source-dialect op into the lowered IR. + struct MaterialiseReturnNonePattern : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using mlir::OpRewritePattern::OpRewritePattern; - mlir::LogicalResult matchAndRewrite(mlir::py::UnpackSequenceOp op, + mlir::LogicalResult matchAndRewrite(mlir::func::ReturnOp op, mlir::PatternRewriter &rewriter) const final { - rewriter.replaceOpWithNewOp( - op, op.getUnpackedValues().getType(), op.getIterable()); - return success(); + if (op.getNumOperands() != 0) { return mlir::failure(); } + auto parent = op->getParentOfType(); + if (!parent) { return mlir::failure(); } + auto pyobject_ty = mlir::py::PyObjectType::get(rewriter.getContext()); + rewriter.setInsertionPoint(op); + auto none = rewriter.create( + op.getLoc(), pyobject_ty, rewriter.getUnitAttr()); + rewriter.replaceOpWithNewOp(op, mlir::ValueRange{ none }); + + // Restore the function signature if RemoveDeadValues stripped + // the result type. Plain assignment to the function-type + // attribute is fine here because the parent op's properties + // aren't tracked by the pattern rewriter's mutation tracking + // (we already produced a successful match-and-rewrite via + // replaceOpWithNewOp above). + if (parent.getFunctionType().getNumResults() == 0) { + auto fn_ty = parent.getFunctionType(); + parent.setFunctionType(rewriter.getFunctionType( + fn_ty.getInputs(), mlir::TypeRange{ pyobject_ty })); + } + return mlir::success(); } }; - struct UnpackExpandOpLowering : public mlir::OpRewritePattern + struct MaterialiseReturnNonePass + : public PassWrapper> { - using OpRewritePattern::OpRewritePattern; + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterialiseReturnNonePass) - mlir::LogicalResult matchAndRewrite(mlir::py::UnpackExpandOp op, - mlir::PatternRewriter &rewriter) const final + void getDependentDialects(DialectRegistry ®istry) const override { - rewriter.replaceOpWithNewOp( - op, op.getUnpackedValues().getType(), op.getRest().getType(), op.getIterable()); - return success(); + registry.insert(); } - }; - struct GetAwaitableOpLowering : public mlir::OpRewritePattern - { - using OpRewritePattern::OpRewritePattern; + StringRef getArgument() const final { return "materialise-return-none"; } - mlir::LogicalResult matchAndRewrite(mlir::py::GetAwaitableOp op, - mlir::PatternRewriter &rewriter) const final + void runOnOperation() final { - rewriter.replaceOpWithNewOp( - op, op.getIterator().getType(), op.getIterable()); - return success(); - } - }; + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); - struct PythonToPythonBytecodePass - : public PassWrapper> - { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PythonToPythonBytecodePass) + GreedyRewriteConfig config; + config.setStrictness(GreedyRewriteStrictness::AnyOp); + FrozenRewritePatternSet frozen{ std::move(patterns) }; - void getDependentDialects(DialectRegistry ®istry) const override - { - registry.insert(); + (void)applyPatternsGreedily(getOperation(), frozen, config); } - - StringRef getArgument() const final { return "python-to-pythonbytecode"; } - - void runOnOperation() final; }; }// namespace void PythonToPythonBytecodePass::runOnOperation() { - ConversionTarget target(getContext()); - target.addLegalDialect(); - - target.addLegalOp(); - target.addLegalOp(); - target.addDynamicallyLegalOp([](mlir::func::FuncOp op) { - // don't convert this special function, which is the entry point of a module - return op.isPrivate() && op.getSymName() == "__hidden_init__"; - }); - target.addIllegalDialect(); - mlir::RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - patterns.add( - &getContext()); - patterns - .add( - &getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns - .add( - &getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add( - &getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); + populateArithPatterns(patterns); + populateAttributeSubscriptPatterns(patterns); + populateCollectionPatterns(patterns); + populateControlFlowPatterns(patterns); + populateFunctionPatterns(patterns); + populateImportPatterns(patterns); + populateLoadStorePatterns(patterns); + // ForLoop / While / Try / With lowerings remain in this file but + // run in dedicated passes (ConvertPyForLoop / While / Try / With) + // ahead of this monolithic conversion pass, so canonicalize / CSE + // can simplify between their structural rewrites. GreedyRewriteConfig config; - config.strictMode = GreedyRewriteStrictness::AnyOp; - config.enableRegionSimplification = GreedySimplifyRegionLevel::Disabled; - config.useTopDownTraversal = true; + config.setStrictness(GreedyRewriteStrictness::AnyOp); + config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Normal); + config.setUseTopDownTraversal(true); FrozenRewritePatternSet frozen_patterns{ std::move(patterns) }; - // getOperation()->print(llvm::outs()); - // llvm::outs() << "-----------------------------------------------\n\n\n"; - // llvm::outs().flush(); - - // Currently ignoring the return value as it seems to always fail, even though the - // transformation seems to generate the expected output - (void)applyPatternsAndFoldGreedily(getOperation(), frozen_patterns, config); + // applyPatternsGreedily returns failure() when the driver hits + // its iteration limit without reaching a fixed point. The + // remaining work is to figure out which pattern keeps firing + // (likely one that always replaces-with-itself in some edge + // case) and either fix it or change the pass to use full + // dialect conversion. For now the IR is verified after the + // pass runs (PassManager's default), so a real failure would + // surface there; treating the rewriter's return as + // signalPassFailure() would be a false positive today. + (void)applyPatternsGreedily(getOperation(), frozen_patterns, config); } std::unique_ptr createPythonToPythonBytecodePass() @@ -2238,5 +794,24 @@ namespace py { return std::make_unique(); } + std::unique_ptr createConvertForLoopPass() + { + return std::make_unique(); + } + + std::unique_ptr createConvertWhileLoopPass() + { + return std::make_unique(); + } + + std::unique_ptr createConvertTryPass() { return std::make_unique(); } + + std::unique_ptr createConvertWithPass() { return std::make_unique(); } + + std::unique_ptr createMaterialiseReturnNonePass() + { + return std::make_unique(); + } + }// namespace py }// namespace mlir \ No newline at end of file diff --git a/src/executable/mlir/Conversion/PythonToPythonBytecode/PythonToPythonBytecode.hpp b/src/executable/mlir/Conversion/PythonToPythonBytecode/PythonToPythonBytecode.hpp index ee25d161..432c4de8 100644 --- a/src/executable/mlir/Conversion/PythonToPythonBytecode/PythonToPythonBytecode.hpp +++ b/src/executable/mlir/Conversion/PythonToPythonBytecode/PythonToPythonBytecode.hpp @@ -8,6 +8,23 @@ class Pass; namespace py { std::unique_ptr createPythonToPythonBytecodePass(); -} + + // Dedicated passes for the four region-bearing control-flow ops. + // Each runs only its own lowering pattern and is meant to slot into + // the pipeline ahead of the monolithic conversion pass, so that + // canonicalize / CSE can be interleaved between them. Plan step 18. + std::unique_ptr createConvertForLoopPass(); + std::unique_ptr createConvertWhileLoopPass(); + std::unique_ptr createConvertTryPass(); + std::unique_ptr createConvertWithPass(); + + // Materialise an emitpybytecode constant None as the operand of any + // zero-operand func.return inside a func.func whose result type is + // non-empty. Used as a follow-up to mlir::createRemoveDeadValuesPass, + // which can strip the return's operand when its producer becomes + // dead, leaving zero-operand returns that violate the bytecode + // emitter's exactly-one-operand invariant. + std::unique_ptr createMaterialiseReturnNonePass(); +}// namespace py }// namespace mlir diff --git a/src/executable/mlir/Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.cpp b/src/executable/mlir/Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.cpp index 6764139a..b4d3f8c1 100644 --- a/src/executable/mlir/Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.cpp +++ b/src/executable/mlir/Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.cpp @@ -5,8 +5,10 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/FunctionImplementation.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "EmitPythonBytecode/IR/EmitPythonBytecodeDialect.cpp.inc" @@ -42,6 +44,108 @@ namespace emitpybytecode { } return SuccessorOperands(0, mlir::MutableOperandRange{ getOperation(), 0, 0 }); } + + namespace { + // Register-pressure relief for large dict literals. A single + // BUILD_DICT consuming N keys + N values forces 2N live values + // to coexist in registers immediately before the call; the + // linear-scan allocator handles that by spilling, which is + // expensive (no stack slots — only register-register moves + + // Push/Pop). For literals with more than ~10 entries, splitting + // into an empty BUILD_DICT followed by streamed DICT_ADD ops + // (each emitted next to its value's defining op) drops the + // live-value count to 3 (dict, key, value). The threshold is + // empirical; a future register-pressure-aware allocation pass + // (plan step 19) would let us pick this dynamically. + struct ExpandLargeBuildDict : public mlir::OpRewritePattern + { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(BuildDict op, + mlir::PatternRewriter &rewriter) const final + { + if (op.getValues().size() <= 10) { return mlir::failure(); } + auto keys = op.getKeys(); + auto values = op.getValues(); + rewriter.setInsertionPointAfterValue(keys.front()); + auto result = rewriter.create( + op->getLoc(), op.getOutput().getType(), mlir::ValueRange{}, mlir::ValueRange{}); + + for (auto [key, value] : llvm::zip(keys, values)) { + rewriter.setInsertionPointAfterValue(value); + rewriter.create(op.getLoc(), result, key, value); + } + rewriter.replaceOp(op, result); + return mlir::success(); + } + }; + }// namespace + + void BuildDict::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns, + mlir::MLIRContext *context) + { + patterns.add(context); + } + + namespace { + // Register-pressure relief for all-constants list literals + // (`[1, 2, "x"]`). A direct BUILD_LIST with N operands holds N + // live values until the op runs; rewriting to an empty + // BUILD_LIST + a single LOAD_CONST tuple + LIST_EXTEND keeps the + // live-value count at 2 (list, tuple). Same motivation and same + // "step 19 will subsume this" caveat as ExpandLargeBuildDict. + struct FoldAllConstBuildListIntoExtend : public mlir::OpRewritePattern + { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(BuildList op, + mlir::PatternRewriter &rewriter) const final + { + auto elements_ops = op.getElements(); + if (elements_ops.empty()) { return mlir::failure(); } + std::vector elements; + elements.reserve(elements_ops.size()); + for (auto el : elements_ops) { + auto k = el.getDefiningOp(); + if (!k) { return mlir::failure(); } + elements.push_back(k.getValue()); + } + auto loc = op.getLoc(); + auto output_type = op.getOutput().getType(); + auto list = rewriter.create(loc, output_type, mlir::ValueRange{}); + auto tuple = rewriter.create( + loc, output_type, mlir::ArrayAttr::get(getContext(), elements)); + rewriter.create(loc, list, tuple); + rewriter.replaceOp(op, list); + return mlir::success(); + } + }; + }// namespace + + void BuildList::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns, + mlir::MLIRContext *context) + { + patterns.add(context); + } + + mlir::LogicalResult ConstantOp::verify() + { + mlir::Attribute attr = getValue(); + // Accepted attribute kinds match the TypeSwitch in + // PythonBytecodeEmitter::emitOperation(emitpybytecode::ConstantOp). + // EllipsisAttr is not in the list - the conversion pass lowers it to + // LoadEllipsisOp instead. + if (mlir::isa(attr)) { + return mlir::success(); + } + return emitOpError() << "value attribute has unsupported kind: " << attr; + } }// namespace emitpybytecode }// namespace mlir diff --git a/src/executable/mlir/Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.td b/src/executable/mlir/Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.td index 12eab8bc..4f5c1e6a 100644 --- a/src/executable/mlir/Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.td +++ b/src/executable/mlir/Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.td @@ -1,98 +1,116 @@ include "EmitPythonBytecodeBase.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "Python/IR/PythonTypes.td" // Base class for EmitPythonBytecode dialect ops. class EmitPythonBytecode_Op traits = []> : Op; -def EmitPythonBytecode_ConstantOp: EmitPythonBytecode_Op<"LOAD_CONST"> { +def EmitPythonBytecode_ConstantOp: EmitPythonBytecode_Op<"LOAD_CONST", [Pure]> { let arguments = (ins AnyAttr:$value); let results = (outs Python_PyObjectType:$output); + + let hasVerifier = 1; } -def EmitPythonBytecode_LoadEllipsisOp: EmitPythonBytecode_Op<"LOAD_ELLIPSIS"> { +def EmitPythonBytecode_LoadEllipsisOp: EmitPythonBytecode_Op<"LOAD_ELLIPSIS", [Pure]> { let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_LoadNameOp: EmitPythonBytecode_Op<"LOAD_NAME"> { +// LOAD_* ops resolve a name in the local / module / global / cell +// environment — they're *readers* of the binding state. Marking them +// MemRead on DefaultResource is what hooks them up to RDV's per- +// resource liveness analysis: a STORE_NAME's MemWrite to the same +// resource is then kept alive because a reachable reader exists. +// (The matching may-raise NameError / UnboundLocalError side effect +// is modelled separately on the source dialect via the +// PythonExceptionStateResource; this dialect runs post-conversion +// where the bytecode emitter handles exception state via its own +// resource-less instructions.) +def EmitPythonBytecode_LoadNameOp: EmitPythonBytecode_Op<"LOAD_NAME", + [MemoryEffects<[MemRead]>]> { let arguments = (ins StrAttr:$name); let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_LoadFastOp: EmitPythonBytecode_Op<"LOAD_FAST"> { +def EmitPythonBytecode_LoadFastOp: EmitPythonBytecode_Op<"LOAD_FAST", + [MemoryEffects<[MemRead]>]> { let arguments = (ins StrAttr:$name); let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_LoadGlobalOp: EmitPythonBytecode_Op<"LOAD_GLOBAL"> { +def EmitPythonBytecode_LoadGlobalOp: EmitPythonBytecode_Op<"LOAD_GLOBAL", + [MemoryEffects<[MemRead]>]> { let arguments = (ins StrAttr:$name); let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_LoadDerefOp: EmitPythonBytecode_Op<"LOAD_DEREF"> { +def EmitPythonBytecode_LoadDerefOp: EmitPythonBytecode_Op<"LOAD_DEREF", + [MemoryEffects<[MemRead]>]> { let arguments = (ins StrAttr:$name); let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_LoadClosureOp: EmitPythonBytecode_Op<"LOAD_CLOSURE"> { +def EmitPythonBytecode_LoadClosureOp: EmitPythonBytecode_Op<"LOAD_CLOSURE", + [MemoryEffects<[MemRead]>]> { let arguments = (ins StrAttr:$name); let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_StoreNameOp: EmitPythonBytecode_Op<"STORE_NAME"> { +def EmitPythonBytecode_StoreNameOp: EmitPythonBytecode_Op<"STORE_NAME", [MemoryEffects<[MemWrite]>]> { let arguments = (ins StrAttr:$name, Python_PyObjectType:$object); - - let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_StoreGlobalOp: EmitPythonBytecode_Op<"STORE_GLOBAL"> { +def EmitPythonBytecode_StoreGlobalOp: EmitPythonBytecode_Op<"STORE_GLOBAL", [MemoryEffects<[MemWrite]>]> { let arguments = (ins StrAttr:$name, Python_PyObjectType:$object); - - let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_StoreFastOp: EmitPythonBytecode_Op<"STORE_FAST"> { +def EmitPythonBytecode_StoreFastOp: EmitPythonBytecode_Op<"STORE_FAST", [MemoryEffects<[MemWrite]>]> { let arguments = (ins StrAttr:$name, Python_PyObjectType:$object); - - let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_StoreDerefOp: EmitPythonBytecode_Op<"STORE_DEREF"> { +def EmitPythonBytecode_StoreDerefOp: EmitPythonBytecode_Op<"STORE_DEREF", [MemoryEffects<[MemWrite]>]> { let arguments = (ins StrAttr:$name, Python_PyObjectType:$object); - - let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_DeleteNameOp: EmitPythonBytecode_Op<"DELETE_NAME"> { +def EmitPythonBytecode_DeleteNameOp: EmitPythonBytecode_Op<"DELETE_NAME", [MemoryEffects<[MemWrite]>]> { let arguments = (ins StrAttr:$name); } -def EmitPythonBytecode_DeleteFastOp: EmitPythonBytecode_Op<"DELETE_FAST"> { +def EmitPythonBytecode_DeleteFastOp: EmitPythonBytecode_Op<"DELETE_FAST", [MemoryEffects<[MemWrite]>]> { let arguments = (ins StrAttr:$name); } -def EmitPythonBytecode_DeleteGlobalOp: EmitPythonBytecode_Op<"DELETE_GLOBAL"> { +def EmitPythonBytecode_DeleteGlobalOp: EmitPythonBytecode_Op<"DELETE_GLOBAL", [MemoryEffects<[MemWrite]>]> { let arguments = (ins StrAttr:$name); } -def EmitPythonBytecode_DeleteDerefOp: EmitPythonBytecode_Op<"DELETE_DEREF"> { +def EmitPythonBytecode_DeleteDerefOp: EmitPythonBytecode_Op<"DELETE_DEREF", [MemoryEffects<[MemWrite]>]> { let arguments = (ins StrAttr:$name); } -def EmitPythonBytecode_FunctionCallOp: EmitPythonBytecode_Op<"CALL"> { +// CALL / CALL_EXPAND / CALL_KW / BINARY_OP / INPLACE_OP / UNARY / +// COMPARE / TO_BOOL / FORMAT_VALUE all dispatch to user-overridable +// dunders or run arbitrary user code. MemWrite is the conservative +// effect: it keeps RemoveDeadValues from stripping a call whose +// result happens to be unused (the user code's side effects are +// still observable). +def EmitPythonBytecode_FunctionCallOp: EmitPythonBytecode_Op<"CALL", + [MemoryEffects<[MemWrite]>]> { let arguments = (ins Python_PyObjectType:$callee, Variadic:$args); let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_FunctionCallExOp: EmitPythonBytecode_Op<"CALL_EXPAND", [AttrSizedOperandSegments]> { +def EmitPythonBytecode_FunctionCallExOp: EmitPythonBytecode_Op<"CALL_EXPAND", + [AttrSizedOperandSegments, MemoryEffects<[MemWrite]>]> { let arguments = (ins Python_PyObjectType:$callee, Optional:$args, Optional:$kwargs); @@ -100,7 +118,8 @@ def EmitPythonBytecode_FunctionCallExOp: EmitPythonBytecode_Op<"CALL_EXPAND", [A let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_FunctionCallWithKeywordsOp: EmitPythonBytecode_Op<"CALL_KW", [AttrSizedOperandSegments]> { +def EmitPythonBytecode_FunctionCallWithKeywordsOp: EmitPythonBytecode_Op<"CALL_KW", + [AttrSizedOperandSegments, MemoryEffects<[MemWrite]>]> { let arguments = (ins Python_PyObjectType:$callee, Variadic:$args, Builtin_DenseStringElementsAttr:$keywords, @@ -109,13 +128,15 @@ def EmitPythonBytecode_FunctionCallWithKeywordsOp: EmitPythonBytecode_Op<"CALL_K let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_BinaryOp: EmitPythonBytecode_Op<"BINARY_OP"> { +def EmitPythonBytecode_BinaryOp: EmitPythonBytecode_Op<"BINARY_OP", + [MemoryEffects<[MemWrite]>]> { let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs, UI8Attr:$operation_type); let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_InplaceOp: EmitPythonBytecode_Op<"INPLACE_OP"> { +def EmitPythonBytecode_InplaceOp: EmitPythonBytecode_Op<"INPLACE_OP", + [MemoryEffects<[MemWrite]>]> { let arguments = (ins Python_PyObjectType:$dst, Python_PyObjectType:$src, UI8Attr:$operation_type); let results = (outs Python_PyObjectType:$output); @@ -163,13 +184,14 @@ def EmitPythonBytecode_JumpIfNotException: EmitPythonBytecode_Op<"JUMP_IF_NOT_EX }]; } -def EmitPythonBytecode_Compare: EmitPythonBytecode_Op<"COMPARE"> { +def EmitPythonBytecode_Compare: EmitPythonBytecode_Op<"COMPARE", + [MemoryEffects<[MemWrite]>]> { let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs, UI8Attr:$predicate); let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_LoadAssertionError : EmitPythonBytecode_Op<"LOAD_ASSERTION_ERROR"> { +def EmitPythonBytecode_LoadAssertionError : EmitPythonBytecode_Op<"LOAD_ASSERTION_ERROR", [Pure]> { let summary = [{ Loads AssertionError type. This is guaranteed to be the builtin AssertionError. @@ -193,7 +215,8 @@ def EmitPythonBytecode_ReRaiseOp: EmitPythonBytecode_Op<"RERAISE", [Terminator]> let summary = "Re-raise the last exception"; } -def EmitPythonBytecode_UnaryOp : EmitPythonBytecode_Op<"UNARY"> { +def EmitPythonBytecode_UnaryOp : EmitPythonBytecode_Op<"UNARY", + [MemoryEffects<[MemWrite]>]> { let summary = "Unary operation"; let arguments = (ins Python_PyObjectType:$input, UI8Attr:$operation_type); @@ -201,23 +224,35 @@ def EmitPythonBytecode_UnaryOp : EmitPythonBytecode_Op<"UNARY"> { let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_BuildDict : EmitPythonBytecode_Op<"BUILD_DICT", [SameVariadicOperandSize]> { +// Build / mutate / coerce collection ops: each constructs an object +// with distinct Python identity or mutates an existing object's +// state. MemWrite documents both the "different invocations produce +// different identities" property (so CSE can't merge two BuildList +// calls) and the explicit-mutation property (DICT_ADD etc.). +def EmitPythonBytecode_BuildDict : EmitPythonBytecode_Op<"BUILD_DICT", + [SameVariadicOperandSize, MemoryEffects<[MemWrite]>]> { let summary = "Build a dictionary object using keys and values"; let arguments = (ins Variadic:$keys, Variadic:$values); let results = (outs Python_PyObjectType:$output); + + // Register-pressure relief for large dict literals. See + // ExpandLargeBuildDict in EmitPythonBytecode.cpp. + let hasCanonicalizer = 1; } -def EmitPythonBytecode_DictUpdate : EmitPythonBytecode_Op<"DICT_UPDATE"> { +def EmitPythonBytecode_DictUpdate : EmitPythonBytecode_Op<"DICT_UPDATE", + [MemoryEffects<[MemWrite]>]> { let summary = "Updates a dictionary's entries using another mappable type"; let arguments = (ins Python_PyObjectType:$dict, Python_PyObjectType:$mappable); } -def EmitPythonBytecode_DictAdd : EmitPythonBytecode_Op<"DICT_ADD"> { +def EmitPythonBytecode_DictAdd : EmitPythonBytecode_Op<"DICT_ADD", + [MemoryEffects<[MemWrite]>]> { let summary = "Adds a key value pair to a dictionary"; let arguments = (ins Python_PyObjectType:$dict, @@ -225,27 +260,35 @@ def EmitPythonBytecode_DictAdd : EmitPythonBytecode_Op<"DICT_ADD"> { Python_PyObjectType:$value); } -def EmitPythonBytecode_BuildList : EmitPythonBytecode_Op<"BUILD_LIST"> { +def EmitPythonBytecode_BuildList : EmitPythonBytecode_Op<"BUILD_LIST", + [MemoryEffects<[MemWrite]>]> { let summary = "Builds a Python list"; let arguments = (ins Variadic:$elements); let results = (outs Python_PyObjectType:$output); + + // Register-pressure relief for all-constants list literals. See + // FoldAllConstBuildListIntoExtend in EmitPythonBytecode.cpp. + let hasCanonicalizer = 1; } -def EmitPythonBytecode_ListExtend : EmitPythonBytecode_Op<"LIST_EXTEND"> { +def EmitPythonBytecode_ListExtend : EmitPythonBytecode_Op<"LIST_EXTEND", + [MemoryEffects<[MemWrite]>]> { let summary = "Extends a Python list"; let arguments = (ins Python_PyObjectType:$list, Python_PyObjectType:$iterable); } -def EmitPythonBytecode_ListAppend : EmitPythonBytecode_Op<"LIST_APPEND"> { +def EmitPythonBytecode_ListAppend : EmitPythonBytecode_Op<"LIST_APPEND", + [MemoryEffects<[MemWrite]>]> { let summary = "Append to a Python list"; let arguments = (ins Python_PyObjectType:$list, Python_PyObjectType:$value); } -def EmitPythonBytecode_ListToTuple : EmitPythonBytecode_Op<"LIST_TO_TUPLE"> { +def EmitPythonBytecode_ListToTuple : EmitPythonBytecode_Op<"LIST_TO_TUPLE", + [MemoryEffects<[MemWrite]>]> { let summary = "Builds a Python tuple from a list"; let arguments = (ins Python_PyObjectType:$list); @@ -253,7 +296,8 @@ def EmitPythonBytecode_ListToTuple : EmitPythonBytecode_Op<"LIST_TO_TUPLE"> { let results = (outs Python_PyObjectType:$tuple); } -def EmitPythonBytecode_BuildTuple : EmitPythonBytecode_Op<"BUILD_TUPLE"> { +def EmitPythonBytecode_BuildTuple : EmitPythonBytecode_Op<"BUILD_TUPLE", + [MemoryEffects<[MemWrite]>]> { let summary = "Builds a Python tuple"; let arguments = (ins Variadic:$elements); @@ -261,7 +305,8 @@ def EmitPythonBytecode_BuildTuple : EmitPythonBytecode_Op<"BUILD_TUPLE"> { let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_BuildSet : EmitPythonBytecode_Op<"BUILD_SET"> { +def EmitPythonBytecode_BuildSet : EmitPythonBytecode_Op<"BUILD_SET", + [MemoryEffects<[MemWrite]>]> { let summary = "Builds a Python set"; let arguments = (ins Variadic:$elements); @@ -269,19 +314,22 @@ def EmitPythonBytecode_BuildSet : EmitPythonBytecode_Op<"BUILD_SET"> { let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_SetAdd : EmitPythonBytecode_Op<"SET_ADD"> { - let summary = "Adds elemets to a set"; +def EmitPythonBytecode_SetAdd : EmitPythonBytecode_Op<"SET_ADD", + [MemoryEffects<[MemWrite]>]> { + let summary = "Adds an element to a set"; - let arguments = (ins Python_PyObjectType:$set, Python_PyObjectType:$element); + let arguments = (ins Python_PyObjectType:$set, Python_PyObjectType:$value); } -def EmitPythonBytecode_SetUpdate : EmitPythonBytecode_Op<"SET_UPDATE"> { +def EmitPythonBytecode_SetUpdate : EmitPythonBytecode_Op<"SET_UPDATE", + [MemoryEffects<[MemWrite]>]> { let summary = "Updates a set with an iterable"; let arguments = (ins Python_PyObjectType:$set, Python_PyObjectType:$iterable); } -def EmitPythonBytecode_BuildString : EmitPythonBytecode_Op<"BUILD_STRING"> { +def EmitPythonBytecode_BuildString : EmitPythonBytecode_Op<"BUILD_STRING", + [MemoryEffects<[MemWrite]>]> { let summary = "Build a string"; let arguments = (ins Variadic:$elements); @@ -289,17 +337,22 @@ def EmitPythonBytecode_BuildString : EmitPythonBytecode_Op<"BUILD_STRING"> { let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_BuildSlice: EmitPythonBytecode_Op<"BUILD_SLICE"> { +def EmitPythonBytecode_BuildSlice: EmitPythonBytecode_Op<"BUILD_SLICE", + [MemoryEffects<[MemWrite]>]> { let summary = "Build a slice object"; + // `step` is optional: matches py.build_slice in the source dialect, + // and the bytecode emitter selects the 3- vs 4-arg BuildSlice + // instruction constructor based on whether step is present. let arguments = (ins Python_PyObjectType:$lower, Python_PyObjectType:$upper, - Python_PyObjectType:$step); + Optional:$step); let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_FormatValue : EmitPythonBytecode_Op<"FORMAT_VALUE"> { +def EmitPythonBytecode_FormatValue : EmitPythonBytecode_Op<"FORMAT_VALUE", + [MemoryEffects<[MemWrite]>]> { let summary = "Format value as a string using specified conversion"; let arguments = (ins Python_PyObjectType:$value, @@ -308,7 +361,11 @@ def EmitPythonBytecode_FormatValue : EmitPythonBytecode_Op<"FORMAT_VALUE"> { let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_LoadAttribute: EmitPythonBytecode_Op<"LOAD_ATTRIBUTE"> { +// Attribute / method / subscript access: each dispatches to a +// user-overridable __getattr__ / __getitem__ / __setitem__ / +// __setattr__ / __delitem__ / __delattr__. +def EmitPythonBytecode_LoadAttribute: EmitPythonBytecode_Op<"LOAD_ATTRIBUTE", + [MemoryEffects<[MemWrite]>]> { let summary = "Load an attribute"; let arguments = (ins Python_PyObjectType:$self, @@ -317,14 +374,16 @@ def EmitPythonBytecode_LoadAttribute: EmitPythonBytecode_Op<"LOAD_ATTRIBUTE"> { let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_DeleteAttribute: EmitPythonBytecode_Op<"DELETE_ATTRIBUTE"> { +def EmitPythonBytecode_DeleteAttribute: EmitPythonBytecode_Op<"DELETE_ATTRIBUTE", + [MemoryEffects<[MemWrite]>]> { let summary = "Delete an attribute"; let arguments = (ins Python_PyObjectType:$self, - StrAttr:$attribute); + StrAttr:$attr); } -def EmitPythonBytecode_BinarySubscript: EmitPythonBytecode_Op<"BINARY_SUBSCRIPT"> { +def EmitPythonBytecode_BinarySubscript: EmitPythonBytecode_Op<"BINARY_SUBSCRIPT", + [MemoryEffects<[MemWrite]>]> { let summary = "Subscript an object"; let arguments = (ins Python_PyObjectType:$self, @@ -333,7 +392,8 @@ def EmitPythonBytecode_BinarySubscript: EmitPythonBytecode_Op<"BINARY_SUBSCRIPT" let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_StoreSubscript: EmitPythonBytecode_Op<"STORE_SUBSCRIPT"> { +def EmitPythonBytecode_StoreSubscript: EmitPythonBytecode_Op<"STORE_SUBSCRIPT", + [MemoryEffects<[MemWrite]>]> { let summary = "Store value in object using a subscript"; let arguments = (ins Python_PyObjectType:$self, @@ -341,22 +401,25 @@ def EmitPythonBytecode_StoreSubscript: EmitPythonBytecode_Op<"STORE_SUBSCRIPT"> Python_PyObjectType:$value); } -def EmitPythonBytecode_DeleteSubscript: EmitPythonBytecode_Op<"DELETE_SUBSCRIPT"> { +def EmitPythonBytecode_DeleteSubscript: EmitPythonBytecode_Op<"DELETE_SUBSCRIPT", + [MemoryEffects<[MemWrite]>]> { let summary = "Delete value in object using a subscript"; let arguments = (ins Python_PyObjectType:$self, Python_PyObjectType:$subscript); } -def EmitPythonBytecode_StoreAttribute: EmitPythonBytecode_Op<"STORE_ATTRIBUTE"> { +def EmitPythonBytecode_StoreAttribute: EmitPythonBytecode_Op<"STORE_ATTRIBUTE", + [MemoryEffects<[MemWrite]>]> { let summary = "Store value in object using an attribute"; let arguments = (ins Python_PyObjectType:$self, - StrAttr:$attribute, + StrAttr:$attr, Python_PyObjectType:$value); } -def EmitPythonBytecode_LoadMethod: EmitPythonBytecode_Op<"LOAD_METHOD"> { +def EmitPythonBytecode_LoadMethod: EmitPythonBytecode_Op<"LOAD_METHOD", + [MemoryEffects<[MemWrite]>]> { let summary = "Load a method"; let arguments = (ins Python_PyObjectType:$self, @@ -365,7 +428,8 @@ def EmitPythonBytecode_LoadMethod: EmitPythonBytecode_Op<"LOAD_METHOD"> { let results = (outs Python_PyObjectType:$method); } -def EmitPythonBytecode_MakeFunction : EmitPythonBytecode_Op<"MAKE_FUNCTION", [AttrSizedOperandSegments]> { +def EmitPythonBytecode_MakeFunction : EmitPythonBytecode_Op<"MAKE_FUNCTION", + [AttrSizedOperandSegments, MemoryEffects<[MemWrite]>]> { let summary = "Instantiates a new function object"; let arguments = (ins Python_PyObjectType:$sym_name, @@ -405,21 +469,25 @@ def EmitPythonBytecode_SetupWith : EmitPythonBytecode_Op<"SETUP_WITH", [Terminat let successors = (successor AnySuccessor:$body, AnySuccessor:$handler); } -def EmitPythonBytecode_WithExceptStart : EmitPythonBytecode_Op<"WITH_EXCEPT_START"> { +def EmitPythonBytecode_WithExceptStart : EmitPythonBytecode_Op<"WITH_EXCEPT_START", + [MemoryEffects<[MemWrite]>]> { let arguments = (ins Python_PyObjectType:$exit_method); let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_ClearExceptionState: EmitPythonBytecode_Op<"CLEAR_EXCEPTION_STATE"> { +def EmitPythonBytecode_ClearExceptionState: EmitPythonBytecode_Op<"CLEAR_EXCEPTION_STATE", + [MemoryEffects<[MemWrite]>]> { let summary = "Clears the interpreter exception state"; } -def EmitPythonBytecode_LeaveExceptionHandle : EmitPythonBytecode_Op<"LEAVE_EXCEPTION_HANDLE"> { +def EmitPythonBytecode_LeaveExceptionHandle : EmitPythonBytecode_Op<"LEAVE_EXCEPTION_HANDLE", + [MemoryEffects<[MemWrite]>]> { let summary = "Leave current exception handling"; } -def EmitPythonBytecode_ImportName : EmitPythonBytecode_Op<"IMPORT_NAME"> { +def EmitPythonBytecode_ImportName : EmitPythonBytecode_Op<"IMPORT_NAME", + [MemoryEffects<[MemWrite]>]> { let summary = "Import Python module"; let arguments = (ins StrAttr:$name, @@ -429,7 +497,8 @@ def EmitPythonBytecode_ImportName : EmitPythonBytecode_Op<"IMPORT_NAME"> { let results = (outs Python_PyObjectType:$module); } -def EmitPythonBytecode_ImportFrom : EmitPythonBytecode_Op<"IMPORT_FROM"> { +def EmitPythonBytecode_ImportFrom : EmitPythonBytecode_Op<"IMPORT_FROM", + [MemoryEffects<[MemWrite]>]> { let summary = "Import an object from a Python module using its name"; let arguments = (ins Python_PyObjectType:$module, @@ -438,13 +507,15 @@ def EmitPythonBytecode_ImportFrom : EmitPythonBytecode_Op<"IMPORT_FROM"> { let results = (outs Python_PyObjectType:$object); } -def EmitPythonBytecode_ImportAll : EmitPythonBytecode_Op<"IMPORT_ALL"> { +def EmitPythonBytecode_ImportAll : EmitPythonBytecode_Op<"IMPORT_ALL", + [MemoryEffects<[MemWrite]>]> { let summary = "Import all objects from a Python module"; let arguments = (ins Python_PyObjectType:$module); } -def EmitPythonBytecode_GetIter: EmitPythonBytecode_Op<"GET_ITER"> { +def EmitPythonBytecode_GetIter: EmitPythonBytecode_Op<"GET_ITER", + [MemoryEffects<[MemWrite]>]> { let summary = "Convenience op that calls Python's iter function on an iterable"; let arguments = (ins Python_PyObjectType:$iterable); @@ -461,7 +532,8 @@ def EmitPythonBytecode_ForIter: EmitPythonBytecode_Op<"FOR_ITER", [Terminator, let successors = (successor AnySuccessor:$body, AnySuccessor:$continuation); } -def EmitPythonBytecode_CastToBool: EmitPythonBytecode_Op<"TO_BOOL"> { +def EmitPythonBytecode_CastToBool: EmitPythonBytecode_Op<"TO_BOOL", + [MemoryEffects<[MemWrite]>]> { let summary = "Converts a value to a Python bool"; let arguments = (ins Python_PyObjectType:$value); @@ -469,13 +541,14 @@ def EmitPythonBytecode_CastToBool: EmitPythonBytecode_Op<"TO_BOOL"> { let results = (outs Python_PyObjectType:$output); } -def EmitPythonBytecode_LoadBuildClass: EmitPythonBytecode_Op<"LOAD_BUILD_CLASS"> { +def EmitPythonBytecode_LoadBuildClass: EmitPythonBytecode_Op<"LOAD_BUILD_CLASS", [Pure]> { let summary = "Loads the builtin class builder factory function"; let results = (outs Python_PyObjectType:$class_builder); } -def EmitPythonBytecode_Yield : EmitPythonBytecode_Op<"YIELD_VALUE"> { +def EmitPythonBytecode_Yield : EmitPythonBytecode_Op<"YIELD_VALUE", + [MemoryEffects<[MemWrite]>]> { let summary = "Yield value from generator"; let arguments = (ins Python_PyObjectType:$value); @@ -483,7 +556,8 @@ def EmitPythonBytecode_Yield : EmitPythonBytecode_Op<"YIELD_VALUE"> { let results = (outs Python_PyObjectType:$received); } -def EmitPythonBytecode_YieldFromIter : EmitPythonBytecode_Op<"YIELD_FROM_ITER"> { +def EmitPythonBytecode_YieldFromIter : EmitPythonBytecode_Op<"YIELD_FROM_ITER", + [MemoryEffects<[MemWrite]>]> { let summary = "Get the iterator from iterable with special cases for generators and coroutines"; let arguments = (ins Python_PyObjectType:$iterable); @@ -491,7 +565,8 @@ def EmitPythonBytecode_YieldFromIter : EmitPythonBytecode_Op<"YIELD_FROM_ITER"> let results = (outs Python_PyObjectType:$iterator); } -def EmitPythonBytecode_YieldFrom : EmitPythonBytecode_Op<"YIELD_FROM"> { +def EmitPythonBytecode_YieldFrom : EmitPythonBytecode_Op<"YIELD_FROM", + [MemoryEffects<[MemWrite]>]> { let summary = "Yield iterator from generator"; let arguments = (ins Python_PyObjectType:$iterator, Python_PyObjectType:$value); @@ -499,7 +574,8 @@ def EmitPythonBytecode_YieldFrom : EmitPythonBytecode_Op<"YIELD_FROM"> { let results = (outs Python_PyObjectType:$received); } -def UnpackSequenceOp: EmitPythonBytecode_Op<"UNPACK_SEQUENCE"> { +def UnpackSequenceOp: EmitPythonBytecode_Op<"UNPACK_SEQUENCE", + [MemoryEffects<[MemWrite]>]> { let summary = "Unpack sequence"; let arguments = (ins Python_PyObjectType:$iterable); @@ -507,7 +583,8 @@ def UnpackSequenceOp: EmitPythonBytecode_Op<"UNPACK_SEQUENCE"> { let results = (outs Variadic:$unpacked_values); } -def UnpackExpandOp: EmitPythonBytecode_Op<"UNPACK_EXPAND"> { +def UnpackExpandOp: EmitPythonBytecode_Op<"UNPACK_EXPAND", + [MemoryEffects<[MemWrite]>]> { let summary = "Unpack and expand rest"; let arguments = (ins Python_PyObjectType:$iterable); @@ -516,10 +593,11 @@ def UnpackExpandOp: EmitPythonBytecode_Op<"UNPACK_EXPAND"> { Python_PyObjectType:$rest); } -def GetAwaitableOp : EmitPythonBytecode_Op<"GET_AWAITABLE"> { +def GetAwaitableOp : EmitPythonBytecode_Op<"GET_AWAITABLE", + [MemoryEffects<[MemWrite]>]> { let summary = "Get awaitable"; let arguments = (ins Python_PyObjectType:$iterable); let results = (outs Python_PyObjectType:$iterator); -} \ No newline at end of file +} diff --git a/src/executable/mlir/Dialect/EmitPythonBytecode/IR/EmitPythonBytecodeBase.td b/src/executable/mlir/Dialect/EmitPythonBytecode/IR/EmitPythonBytecodeBase.td index 006d37eb..4baa9c58 100644 --- a/src/executable/mlir/Dialect/EmitPythonBytecode/IR/EmitPythonBytecodeBase.td +++ b/src/executable/mlir/Dialect/EmitPythonBytecode/IR/EmitPythonBytecodeBase.td @@ -5,5 +5,30 @@ def EmitPythonBytecode_Dialect : Dialect { let cppNamespace = "::mlir::emitpybytecode"; let summary = "Dialect to generate Python bytecode for the PythonCpp VM from MLIR."; - let description = "TODO"; + let description = [{ +The EmitPythonBytecode dialect is the lower-level half of the +compilation pipeline. It models the wire format consumed by the +PythonCpp register-based VM as MLIR ops: each op corresponds 1:1 +(or close to it) to a bytecode instruction emitted by +`codegen::translateToPythonBytecode`. Ops mirror CPython opcode +names (LOAD_FAST, BINARY_OP, CALL, …) in uppercase to make the +correspondence obvious when reading IR dumps. + +Compared to the upstream `python` dialect, ops here are +deliberately denormalised: enum kinds are encoded as raw +`ui8`/`ui16` attributes matching the VM wire format, control flow +is expressed as CFG branches (no region-bearing structured ops), +and most "may run user code" semantic refinement has already +happened. As a result, this dialect is not a useful target for +high-level optimisation — its job is to provide a stable IR +checkpoint right before bytecode emission so that +canonicalize / CSE can dedupe LOAD_CONST and similar pure ops +introduced during conversion, and so that the bytecode emitter +itself is a structurally simple walker. + +New ops belong here when they directly model a VM instruction +that doesn't already exist; semantic refinement (operator +dispatch, scope-aware name resolution, etc.) belongs in the +`python` dialect instead. + }]; } \ No newline at end of file diff --git a/src/executable/mlir/Dialect/Python/CMakeLists.txt b/src/executable/mlir/Dialect/Python/CMakeLists.txt index f8b9b680..35a2f794 100644 --- a/src/executable/mlir/Dialect/Python/CMakeLists.txt +++ b/src/executable/mlir/Dialect/Python/CMakeLists.txt @@ -17,6 +17,7 @@ add_mlir_library(PythonMLIRDialect MLIRPass MLIRSupport MLIRTransforms + MLIRTransformUtils ) target_include_directories(PythonMLIRDialect PUBLIC diff --git a/src/executable/mlir/Dialect/Python/IR/Dialect.hpp b/src/executable/mlir/Dialect/Python/IR/Dialect.hpp index 9eea94df..dae0de55 100644 --- a/src/executable/mlir/Dialect/Python/IR/Dialect.hpp +++ b/src/executable/mlir/Dialect/Python/IR/Dialect.hpp @@ -1,5 +1,45 @@ #pragma once #include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "Python/IR/Dialect.h.inc" \ No newline at end of file +#include "Python/IR/Dialect.h.inc" + +namespace mlir::py { + +// Resource representing the Python interpreter's active-exception state. +// Used as the write-resource for Load* ops: a load that targets a name not +// bound in the current scope raises NameError / UnboundLocalError, which +// is a visible side effect. Modelling that effect as a MemWrite on this +// resource prevents canonicalize/DCE from eliminating a load whose result +// happens to be unused. +struct PythonExceptionStateResource + : public ::mlir::SideEffects::Resource::Base +{ + ::mlir::StringRef getName() const override { return "PythonExceptionState"; } +}; + +namespace OpTrait { + // Trait asserting that an op carries a non-empty StringAttr "name". Used + // by the {Load,Store,Delete}{Name,Fast,Global,Deref} family in + // PythonOps.td. The check is defensive: code-gen always supplies a + // non-empty name, but a malformed parse or a future bug shouldn't + // silently produce an op with no binding target. + template + class NamedOp : public ::mlir::OpTrait::TraitBase + { + public: + static ::mlir::LogicalResult verifyTrait(::mlir::Operation *op) + { + auto name = op->getAttrOfType<::mlir::StringAttr>("name"); + if (!name) { return op->emitOpError("requires a 'name' StringAttr"); } + if (name.getValue().empty()) { + return op->emitOpError("requires a non-empty 'name' attribute"); + } + return ::mlir::success(); + } + }; +}// namespace OpTrait + +}// namespace mlir::py \ No newline at end of file diff --git a/src/executable/mlir/Dialect/Python/IR/Ops.cpp b/src/executable/mlir/Dialect/Python/IR/Ops.cpp index 867c30a8..4cb17522 100644 --- a/src/executable/mlir/Dialect/Python/IR/Ops.cpp +++ b/src/executable/mlir/Dialect/Python/IR/Ops.cpp @@ -19,6 +19,20 @@ namespace mlir { namespace py { + namespace { + // MLIR 23 removed RegionBranchPoint::getRegionOrNull(). The new API exposes + // the terminator op via getTerminatorPredecessorOrNull(); the region is the + // terminator's parent region. Returns nullptr when the branch point is the + // parent op. + mlir::Region *predecessor_region(mlir::RegionBranchPoint point) + { + if (point.isParent()) { return nullptr; } + auto term = point.getTerminatorPredecessorOrNull(); + if (!term) { return nullptr; } + return term.getOperation()->getParentRegion(); + } + }// namespace + void PythonDialect::initialize() { addOperations< @@ -85,12 +99,9 @@ namespace py { void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::py::PyEllipsisType) + mlir::py::EllipsisAttr value) { - ConstantOp::build(builder, - state, - PyObjectType::get(builder.getContext()), - EllipsisAttr::get(builder.getContext())); + ConstantOp::build(builder, state, PyObjectType::get(builder.getContext()), value); } void ConstantOp::build(mlir::OpBuilder &builder, @@ -108,6 +119,120 @@ namespace py { return mlir::detail::AttributeUniquer::get(context); } + mlir::LogicalResult ConstantOp::verify() + { + mlir::Attribute attr = getValue(); + // Accepted attribute kinds correspond to the constant kinds the + // ConstantOp builders can construct, plus EllipsisAttr (lowered to + // LoadEllipsisOp by the conversion pass). + if (mlir::isa(attr)) { + return mlir::success(); + } + return emitOpError() << "value attribute has unsupported kind: " << attr; + } + + mlir::LogicalResult FunctionCallOp::verify() + { + // Two distinct shapes share this op: + // + // 1. Plain kwargs: keywords[i] names kwargs[i], so the two lists + // are parallel and must agree in length. + // + // 2. Expansion: when requires_{args,kwargs}_expansion is set, + // the corresponding operand list holds at most one value (a + // tuple or dict to splat), and keywords is empty regardless + // of how many actual keyword names the call will produce at + // runtime. MLIRGenerator currently still emits a kwargs + // operand alongside an args-expansion-only call, so we can't + // use the parallel rule in that case. + if (getRequiresArgsExpansion()) { + if (getArgs().size() > 1) { + return emitOpError() + << "requires_args_expansion expects at most one args operand, got " + << getArgs().size(); + } + } + if (getRequiresKwargsExpansion()) { + if (getKwargs().size() > 1) { + return emitOpError() + << "requires_kwargs_expansion expects at most one kwargs operand, got " + << getKwargs().size(); + } + } + if (getRequiresArgsExpansion() || getRequiresKwargsExpansion()) { return mlir::success(); } + const auto keywords_size = getKeywords().size(); + const auto kwargs_size = getKwargs().size(); + if (keywords_size != kwargs_size) { + return emitOpError() << "has " << keywords_size << " keyword name(s) but " + << kwargs_size << " kwargs value(s)"; + } + return mlir::success(); + } + + mlir::LogicalResult UnpackSequenceOp::verify() + { + // Variadic results, but a zero-result unpack is meaningless: there + // would be no Python-level target receiving the unpacked values + // and the bytecode emitter would still pay for the iterator + // machinery. MLIRGenerator only emits this for sequence-assignment + // targets, which always carry at least one binding. + if (getUnpackedValues().empty()) { + return emitOpError() << "must produce at least one unpacked value"; + } + return mlir::success(); + } + + mlir::LogicalResult ClassDefinitionOp::verify() + { + // keywords[i] names kwargs[i] — the two lists must agree in length. + // Unlike py.call, ClassDefinitionOp has no expansion flag, so the + // parallel rule always applies. + const auto keywords_size = getKeywords().size(); + const auto kwargs_size = getKwargs().size(); + if (keywords_size != kwargs_size) { + return emitOpError() << "has " << keywords_size << " keyword name(s) but " + << kwargs_size << " kwargs value(s)"; + } + return mlir::success(); + } + + mlir::LogicalResult BuildDictOp::verify() + { + // SameVariadicOperandSize already enforces keys.size() == values.size(). + // requires_expansion is one bool per kv pair. + const auto expansion_size = getRequiresExpansion().size(); + const auto keys_size = getKeys().size(); + if (expansion_size != keys_size) { + return emitOpError() << "requires_expansion has " << expansion_size + << " entries but op has " << keys_size << " key/value pairs"; + } + return mlir::success(); + } + + namespace { + template mlir::LogicalResult verify_elementwise_expansion(Op op) + { + const auto expansion_size = op.getRequiresExpansion().size(); + const auto elements_size = op.getElements().size(); + if (expansion_size != elements_size) { + return op.emitOpError() << "requires_expansion has " << expansion_size + << " entries but op has " << elements_size << " elements"; + } + return mlir::success(); + } + }// namespace + + mlir::LogicalResult BuildListOp::verify() { return verify_elementwise_expansion(*this); } + mlir::LogicalResult BuildTupleOp::verify() { return verify_elementwise_expansion(*this); } + mlir::LogicalResult BuildSetOp::verify() { return verify_elementwise_expansion(*this); } + SuccessorOperands CondBranchSubclassOp::getSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); @@ -115,29 +240,52 @@ namespace py { index == 0 ? getTrueDestOperandsMutable() : getFalseDestOperandsMutable()); } + namespace { + // Shared getSuccessorInputs implementation for all four + // RegionBranchOpInterface ops in this dialect (WhileOp, ForLoopOp, + // TryOp, TryHandlerOp). MLIR 23 split the (region, block-args) + // pair that used to be carried by RegionSuccessor: the region is now + // in RegionSuccessor and the inputs come from this method. The + // inputs MLIR expects here are the operands the parent op forwards + // INTO the destination (which the verifier matches against the op's + // operand list along each control-flow edge). + // + // None of these ops actually forward operands into their regions or + // out of them: the for-loop iterator is consumed by FOR_ITER inside + // `step`; the while condition is computed inside `condition`; try's + // regions communicate via the exception state, not block args; and + // the ops' results are produced by the BranchYieldOp terminator, + // not threaded through the parent successor. Return empty in all + // cases. + mlir::ValueRange region_or_block_arguments(mlir::Operation *, mlir::RegionSuccessor) + { + return mlir::ValueRange{}; + } + }// namespace + // Based on CIR loop interface implementation void WhileOp::getSuccessorRegions(mlir::RegionBranchPoint point, llvm::SmallVectorImpl ®ions) { // Branching to first region: go to condition. if (point.isParent()) { - regions.emplace_back(&getCondition(), getCondition().getArguments()); + regions.emplace_back(&getCondition()); } // Branching from condition: go to body or, exit or orelse if non-empty. - else if (point.getRegionOrNull() == &getCondition()) { + else if (predecessor_region(point) == &getCondition()) { if (getOrelse().empty()) { - regions.emplace_back(RegionSuccessor(getOperation()->getResults())); + regions.emplace_back(RegionSuccessor::parent()); } else { - regions.emplace_back(&getOrelse(), getOrelse().getArguments()); + regions.emplace_back(&getOrelse()); } - regions.emplace_back(&getBody(), getBody().getArguments()); + regions.emplace_back(&getBody()); } // Branching from body: go to condition. - else if (point.getRegionOrNull() == &getBody()) { - regions.emplace_back(&getCondition(), getCondition().getArguments()); + else if (predecessor_region(point) == &getBody()) { + regions.emplace_back(&getCondition()); } // Branching from orelse - can't go anywhere else. - else if (point.getRegionOrNull() == &getOrelse()) { + else if (predecessor_region(point) == &getOrelse()) { } else { llvm_unreachable("unexpected branch origin"); } @@ -148,139 +296,259 @@ namespace py { { // Branching to first region: go to step. if (point.isParent()) { - regions.emplace_back(&getStep(), getStep().getArguments()); + regions.emplace_back(&getStep()); } // Branching from condition: go to body or, exit or orelse if non-empty. - else if (point.getRegionOrNull() == &getStep()) { + else if (predecessor_region(point) == &getStep()) { if (getOrelse().empty()) { - regions.emplace_back(getOperation()->getResults()); + regions.emplace_back(RegionSuccessor::parent()); } else { - regions.emplace_back(&getOrelse(), getOrelse().getArguments()); + regions.emplace_back(&getOrelse()); } - regions.emplace_back(&getBody(), getBody().getArguments()); + regions.emplace_back(&getBody()); } // Branching from body: go to step. - else if (point.getRegionOrNull() == &getBody()) { - regions.emplace_back(&getStep(), getStep().getArguments()); + else if (predecessor_region(point) == &getBody()) { + regions.emplace_back(&getStep()); } // Branching from orelse - can't go anywhere else. - else if (point.getRegionOrNull() == &getOrelse()) { + else if (predecessor_region(point) == &getOrelse()) { } else { llvm_unreachable("unexpected branch origin"); } } + mlir::ValueRange WhileOp::getSuccessorInputs(mlir::RegionSuccessor successor) + { + return region_or_block_arguments(getOperation(), successor); + } + + mlir::ValueRange ForLoopOp::getSuccessorInputs(mlir::RegionSuccessor successor) + { + return region_or_block_arguments(getOperation(), successor); + } + void TryOp::getSuccessorRegions(mlir::RegionBranchPoint point, llvm::SmallVectorImpl ®ions) { // Branching to first region: go to try body. if (point.isParent()) { - regions.emplace_back(&getBody(), getBody().getArguments()); + regions.emplace_back(&getBody()); } // Branching from try body: go to first handler and orelse block if non-empty, if there no // handlers go to finally - else if (point.getRegionOrNull() == &getBody()) { + else if (predecessor_region(point) == &getBody()) { if (!getHandlers().empty()) { - regions.emplace_back(&getHandlers().front(), getHandlers().front().getArguments()); - if (!getOrelse().empty()) { - regions.emplace_back(&getOrelse(), getOrelse().getArguments()); - } + regions.emplace_back(&getHandlers().front()); + if (!getOrelse().empty()) { regions.emplace_back(&getOrelse()); } } else { assert(getOrelse().empty()); - regions.emplace_back(&getFinally(), getFinally().getArguments()); + regions.emplace_back(&getFinally()); } } // Branching from handler: go to next handler if there is one, if not go to finally. else if (auto it = std::find_if(getHandlers().begin(), getHandlers().end(), [&point]( - mlir::Region &handler) { return point.getRegionOrNull() == &handler; }); + mlir::Region &handler) { return predecessor_region(point) == &handler; }); it != getHandlers().end()) { if (std::next(it) != getHandlers().end()) { it++; - regions.emplace_back(&*it, it->getArguments()); - } - if (!getFinally().empty()) { - regions.emplace_back(&getFinally(), getFinally().getArguments()); + regions.emplace_back(&*it); } + if (!getFinally().empty()) { regions.emplace_back(&getFinally()); } // regions.emplace_back(getOperation()->getParentRegion()); } // Branch from orelse: go to finally or parent - else if (point.getRegionOrNull() == &getOrelse()) { + else if (predecessor_region(point) == &getOrelse()) { if (!getFinally().empty()) { - regions.emplace_back(&getFinally(), getFinally().getArguments()); + regions.emplace_back(&getFinally()); } else { regions.emplace_back(getOperation()->getParentRegion()); } } // Branch from finally: go to parent - else if (point.getRegionOrNull() == &getFinally()) { + else if (predecessor_region(point) == &getFinally()) { regions.emplace_back(getOperation()->getParentRegion()); } } - void ControlFlowYield::getSuccessorRegions(llvm::ArrayRef operands, + void BranchYieldOp::getSuccessorRegions(llvm::ArrayRef operands, llvm::SmallVectorImpl ®ions) { - static_assert(ControlFlowYield::hasTrait< - mlir::OpTrait::HasParent::Impl>()); + static_assert(BranchYieldOp::hasTrait< + mlir::OpTrait::HasParent::Impl>()); if (getKind().has_value()) { return; } - auto result = - llvm::TypeSwitch(getOperation()->getParentOp()) - .Case([®ions](TryOp op) -> LogicalResult { - // regions.emplace_back(&op.getRegion()); - llvm_unreachable("TODO"); - return failure(); - }) - .Case([this, ®ions](ForLoopOp op) -> LogicalResult { - if (getOperation()->getParentRegion() == &op.getStep()) { - regions.emplace_back(&op.getBody()); - } else if (getOperation()->getParentRegion() == &op.getBody()) { - regions.emplace_back(&op.getStep()); - } else if (getOperation()->getParentRegion() == &op.getOrelse()) { - } else { - llvm_unreachable("unexpected branch origin"); - } - return success(); - }) - .Case([®ions](WithOp op) -> LogicalResult { - regions.emplace_back(op->getParentRegion()); - return success(); - }) - .Case([this, ®ions](WhileOp op) -> LogicalResult { - if (getOperation()->getParentRegion() == &op.getCondition()) { - regions.emplace_back(&op.getBody()); - } else if (getOperation()->getParentRegion() == &op.getBody()) { - regions.emplace_back(&op.getCondition()); - } else if (getOperation()->getParentRegion() == &op.getOrelse()) { - } else { - llvm_unreachable("unexpected branch origin"); - } - return success(); - }) - .Case([this, ®ions](TryHandlerScope op) -> LogicalResult { - llvm_unreachable("todo"); - return failure(); - }) - .Default([](Operation *) -> LogicalResult { - llvm_unreachable("TODO"); - std::abort(); - return failure(); - }); + auto result = llvm::TypeSwitch(getOperation()->getParentOp()) + .Case([this, ®ions](TryOp op) -> LogicalResult { + // Fallthrough (no exception) successors for a yield inside + // a TryOp. The exception → handler edges are not modeled + // here; they're induced by the exception state, not by a + // BranchYieldOp. Successor of yield from: + // body -> orelse if non-empty, else finally if + // non-empty, else parent. + // handler -> finally if non-empty, else parent. + // orelse -> finally if non-empty, else parent. + // finally -> parent. + // Matches the "exit-to-parent" convention used by + // TryOp::getSuccessorRegions (emplaces the containing + // region rather than RegionSuccessor::parent()). + auto *parent_region = getOperation()->getParentRegion(); + auto exit_to_finally_or_parent = [&] { + if (!op.getFinally().empty()) { + regions.emplace_back(&op.getFinally()); + } else { + regions.emplace_back(op->getParentRegion()); + } + }; + if (parent_region == &op.getBody()) { + if (!op.getOrelse().empty()) { + regions.emplace_back(&op.getOrelse()); + } else { + exit_to_finally_or_parent(); + } + } else if (parent_region == &op.getOrelse()) { + exit_to_finally_or_parent(); + } else if (parent_region == &op.getFinally()) { + regions.emplace_back(op->getParentRegion()); + } else { + bool in_handler = false; + for (mlir::Region &handler : op.getHandlers()) { + if (parent_region == &handler) { + in_handler = true; + break; + } + } + if (!in_handler) { llvm_unreachable("unexpected branch origin"); } + exit_to_finally_or_parent(); + } + return success(); + }) + .Case([this, ®ions](ForLoopOp op) -> LogicalResult { + if (getOperation()->getParentRegion() == &op.getStep()) { + regions.emplace_back(&op.getBody()); + } else if (getOperation()->getParentRegion() == &op.getBody()) { + regions.emplace_back(&op.getStep()); + } else if (getOperation()->getParentRegion() == &op.getOrelse()) { + } else { + llvm_unreachable("unexpected branch origin"); + } + return success(); + }) + .Case([®ions](WithOp op) -> LogicalResult { + regions.emplace_back(op->getParentRegion()); + return success(); + }) + .Case([this, ®ions](WhileOp op) -> LogicalResult { + if (getOperation()->getParentRegion() == &op.getCondition()) { + regions.emplace_back(&op.getBody()); + } else if (getOperation()->getParentRegion() == &op.getBody()) { + regions.emplace_back(&op.getCondition()); + } else if (getOperation()->getParentRegion() == &op.getOrelse()) { + } else { + llvm_unreachable("unexpected branch origin"); + } + return success(); + }) + .Case([this, ®ions](TryHandlerOp op) -> LogicalResult { + // TryHandlerOp models a single except-clause: cond is + // the type-match test, handler is the body run on match. + // Yield from cond can fall through to the handler (match) + // or out to the parent (no match - try the next clause or + // re-raise). Yield from handler exits the scope. + auto *parent_region = getOperation()->getParentRegion(); + if (parent_region == &op.getCond()) { + regions.emplace_back(&op.getHandler()); + regions.emplace_back(op->getParentRegion()); + } else if (parent_region == &op.getHandler()) { + regions.emplace_back(op->getParentRegion()); + } else { + llvm_unreachable("unexpected branch origin"); + } + return success(); + }) + .Default([](Operation *) -> LogicalResult { + // Unreachable in verified IR: the static_assert above + // constrains BranchYieldOp's parent op to one of the + // five Cases handled. A failure here means an + // unverified or malformed op slipped past verification. + llvm_unreachable("BranchYieldOp has unexpected parent op kind"); + }); assert(result.succeeded()); } - void TryHandlerScope::getSuccessorRegions(mlir::RegionBranchPoint point, + mlir::ValueRange TryOp::getSuccessorInputs(mlir::RegionSuccessor successor) + { + return region_or_block_arguments(getOperation(), successor); + } + + mlir::ValueRange TryHandlerOp::getSuccessorInputs(mlir::RegionSuccessor successor) + { + return region_or_block_arguments(getOperation(), successor); + } + + void TryHandlerOp::getSuccessorRegions(mlir::RegionBranchPoint point, llvm::SmallVectorImpl ®ions) { - if (point.getRegionOrNull() == &getCond()) { - regions.emplace_back(&getHandler(), getHandler().getArguments()); - } + if (predecessor_region(point) == &getCond()) { regions.emplace_back(&getHandler()); } regions.emplace_back(getOperation()->getParentRegion()); } + + namespace { + // Forward the value of a preceding py.store_fast of the same name in + // the same block when no intervening op kills the binding. Locals + // can only be touched by store_fast / delete_fast in the current + // FuncOp, so most ops (including FunctionCall) are safe to walk + // past - with one exception: ops with regions (py.for_loop / + // py.while / py.try / py.with / py.if-like ops) may contain a + // store_fast or delete_fast targeting this name in their bodies. + // We don't recurse into those regions yet, so conservatively bail + // if any region-bearing op sits between the candidate store and the + // load. + struct ForwardStoreFastToLoadFast : public mlir::OpRewritePattern + { + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(LoadFastOp load, + mlir::PatternRewriter &rewriter) const final + { + const auto name = load.getName(); + for (mlir::Operation *prev = load->getPrevNode(); prev != nullptr; + prev = prev->getPrevNode()) { + if (auto store = mlir::dyn_cast(prev)) { + if (store.getName() == name) { + rewriter.replaceOp(load, store.getValue()); + return mlir::success(); + } + } + if (auto del = mlir::dyn_cast(prev)) { + if (del.getName() == name) { + // The binding was deleted between the store + // and the load - can't forward. + return mlir::failure(); + } + } + if (prev->getNumRegions() > 0) { + // A nested region (loop body, try/with body, ...) + // might contain a store or delete of `name`. Be + // conservative and bail rather than try to prove it + // doesn't. + return mlir::failure(); + } + } + return mlir::failure(); + } + }; + }// namespace + + void LoadFastOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns, + mlir::MLIRContext *context) + { + patterns.add(context); + } }// namespace py }// namespace mlir diff --git a/src/executable/mlir/Dialect/Python/IR/PythonAttributes.td b/src/executable/mlir/Dialect/Python/IR/PythonAttributes.td index 43367128..db687e23 100644 --- a/src/executable/mlir/Dialect/Python/IR/PythonAttributes.td +++ b/src/executable/mlir/Dialect/Python/IR/PythonAttributes.td @@ -38,8 +38,14 @@ def Python_FormatValueConversionAttr : I64EnumAttr< let cppNamespace = "::mlir::py"; } -def Python_InplaceOpKindAttr : I64EnumAttr< - "InplaceOpKind", "", +// Shared by py.binary and py.inplace_op. The set of operations is +// identical between binary and inplace forms in Python (the operator +// is binary; only the assignment differs), so a single enum avoids +// drift between the two and a redundant translation switch in the +// lowering pass. Covers arithmetic, bitwise, and shift operators, +// matching Python's numeric protocols. +def Python_ArithOpKindAttr : I64EnumAttr< + "ArithOpKind", "", [ I64EnumAttrCase<"add", 0>, I64EnumAttrCase<"sub", 1>, diff --git a/src/executable/mlir/Dialect/Python/IR/PythonOps.hpp b/src/executable/mlir/Dialect/Python/IR/PythonOps.hpp index a040df58..c694349b 100644 --- a/src/executable/mlir/Dialect/Python/IR/PythonOps.hpp +++ b/src/executable/mlir/Dialect/Python/IR/PythonOps.hpp @@ -1,6 +1,7 @@ #pragma once #include "Dialect.hpp" +#include "PythonAttributes.hpp" #include "PythonTypes.hpp" #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/src/executable/mlir/Dialect/Python/IR/PythonOps.td b/src/executable/mlir/Dialect/Python/IR/PythonOps.td index c97e5231..2ef4f5d4 100644 --- a/src/executable/mlir/Dialect/Python/IR/PythonOps.td +++ b/src/executable/mlir/Dialect/Python/IR/PythonOps.td @@ -3,11 +3,28 @@ include "Python/IR/PythonTypes.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/SymbolInterfaces.td" class Python_Op traits = []> : Op; -def ConstantOp : Python_Op<"constant"> { +// Resource for the "may raise NameError / UnboundLocalError" effect that +// Load* ops have. Defined in C++ as mlir::py::PythonExceptionStateResource +// in Dialect.hpp. Marking Load* with a MemWrite on this resource prevents +// canonicalize/DCE from eliminating loads whose results are unused (the +// exception side effect is observable). +def PythonExceptionStateResource : + Resource<"::mlir::py::PythonExceptionStateResource">; +def PyExceptionStateWrite : MemWrite; + +// Verifies the op carries a non-empty StringAttr "name". Applied to the +// {Load,Store,Delete}{Name,Fast,Global,Deref} variable-access family. +// Implementation: mlir::py::OpTrait::NamedOp in Dialect.hpp. +def NamedOp : NativeOpTrait<"NamedOp"> { + let cppNamespace = "::mlir::py::OpTrait"; +} + +def ConstantOp : Python_Op<"constant", [Pure]> { let summary = "Build a PyObject from a constant"; let arguments = (ins AnyAttr:$value); @@ -21,12 +38,15 @@ def ConstantOp : Python_Op<"constant"> { OpBuilder<(ins "NoneType":$value)>, OpBuilder<(ins "StringAttr":$value)>, OpBuilder<(ins "std::vector":$value)>, - OpBuilder<(ins "::mlir::py::PyEllipsisType":$value)>, + OpBuilder<(ins "::mlir::py::EllipsisAttr":$value)>, OpBuilder<(ins "::mlir::ArrayRef":$elements)>, ]; + + let hasVerifier = 1; } -def LoadNameOp : Python_Op<"load_name"> { +def LoadNameOp : Python_Op<"load_name", + [NamedOp, MemoryEffects<[MemRead, PyExceptionStateWrite]>]> { let summary = "Load PyObject from the local environment"; let arguments = (ins StrAttr:$name); @@ -34,15 +54,23 @@ def LoadNameOp : Python_Op<"load_name"> { let results = (outs Python_PyObjectType:$output); } -def LoadFastOp : Python_Op<"load_fast"> { +def LoadFastOp : Python_Op<"load_fast", + [NamedOp, MemoryEffects<[MemRead, PyExceptionStateWrite]>]> { let summary = "Load PyObject bound variable"; let arguments = (ins StrAttr:$name); let results = (outs Python_PyObjectType:$output); + + // Forward the value of a preceding py.store_fast of the same name in + // the same block when no intervening op writes to that local. Locals + // can only be touched by store_fast / delete_fast in the same FuncOp, + // so most ops (including FunctionCall) don't kill the forwarding. + let hasCanonicalizer = 1; } -def LoadGlobalOp : Python_Op<"load_global"> { +def LoadGlobalOp : Python_Op<"load_global", + [NamedOp, MemoryEffects<[MemRead, PyExceptionStateWrite]>]> { let summary = "Load PyObject from the global environment"; let arguments = (ins StrAttr:$name); @@ -50,7 +78,8 @@ def LoadGlobalOp : Python_Op<"load_global"> { let results = (outs Python_PyObjectType:$output); } -def LoadDerefOp : Python_Op<"load_deref"> { +def LoadDerefOp : Python_Op<"load_deref", + [NamedOp, MemoryEffects<[MemRead, PyExceptionStateWrite]>]> { let summary = "Load PyObject from a cell"; let arguments = (ins StrAttr:$name); @@ -58,57 +87,49 @@ def LoadDerefOp : Python_Op<"load_deref"> { let results = (outs Python_PyObjectType:$output); } -def StoreNameOp : Python_Op<"store_name"> { +def StoreNameOp : Python_Op<"store_name", [NamedOp, MemoryEffects<[MemWrite]>]> { let summary = "Store value to the local environment"; let arguments = (ins StrAttr:$name, Python_PyObjectType:$value); - - let results = (outs Python_PyObjectType:$output); } -def StoreFastOp : Python_Op<"store_fast"> { +def StoreFastOp : Python_Op<"store_fast", [NamedOp, MemoryEffects<[MemWrite]>]> { let summary = "Store value to local bound variable"; let arguments = (ins StrAttr:$name, Python_PyObjectType:$value); - - let results = (outs Python_PyObjectType:$output); } -def StoreGlobalOp : Python_Op<"store_global"> { +def StoreGlobalOp : Python_Op<"store_global", [NamedOp, MemoryEffects<[MemWrite]>]> { let summary = "Store value in the global environment"; let arguments = (ins StrAttr:$name, Python_PyObjectType:$value); - - let results = (outs Python_PyObjectType:$output); } -def StoreDerefOp : Python_Op<"store_deref"> { +def StoreDerefOp : Python_Op<"store_deref", [NamedOp, MemoryEffects<[MemWrite]>]> { let summary = "Store PyObject to a cell"; let arguments = (ins StrAttr:$name, Python_PyObjectType:$value); - - let results = (outs Python_PyObjectType:$output); } -def DeleteNameOp : Python_Op<"delete_name"> { +def DeleteNameOp : Python_Op<"delete_name", [NamedOp, MemoryEffects<[MemWrite]>]> { let summary = "Delete value from the local environment"; let arguments = (ins StrAttr:$name); } -def DeleteFastOp : Python_Op<"delete_fast"> { +def DeleteFastOp : Python_Op<"delete_fast", [NamedOp, MemoryEffects<[MemWrite]>]> { let summary = "Delete value from the bound variable"; let arguments = (ins StrAttr:$name); } -def DeleteGlobalOp : Python_Op<"delete_global"> { +def DeleteGlobalOp : Python_Op<"delete_global", [NamedOp, MemoryEffects<[MemWrite]>]> { let summary = "Delete value from the global environment"; let arguments = (ins StrAttr:$name); } -def DeleteDerefOp : Python_Op<"delete_deref"> { +def DeleteDerefOp : Python_Op<"delete_deref", [NamedOp, MemoryEffects<[MemWrite]>]> { let summary = "Delete value from a cell"; let arguments = (ins StrAttr:$name); @@ -117,15 +138,25 @@ def DeleteDerefOp : Python_Op<"delete_deref"> { def MakeFunctionOp : Python_Op<"make_function", [AttrSizedOperandSegments]> { let summary = "Make a function using the local environment and code object"; + // captures is the list of free-variable names this function closes + // over. Encoded as a plain StrArrayAttr instead of a + // DenseStringElementsAttr — there's no tensor/vector semantics + // here, and StrArrayAttr prints as ["a", "b"] in textual IR + // instead of the noisy dense<["a", "b"]> : tensor<2x!llvm.ptr>. let arguments = (ins FlatSymbolRefAttr:$function_name, Variadic:$defaults, Variadic:$kw_defaults, - Builtin_DenseStringElementsAttr:$captures); + StrArrayAttr:$captures); let results = (outs Python_PyObjectType:$func_object); } -def FunctionCallOp : Python_Op<"call", [AttrSizedOperandSegments]> { +// MemWrite: the strongest case for the conservative effect — a +// generic call dispatches to arbitrary user code (or a builtin +// that can mutate any object reachable through args). Anything +// less than MemWrite would let a future pass illegally reorder or +// CSE calls. +def FunctionCallOp : Python_Op<"call", [AttrSizedOperandSegments, MemoryEffects<[MemWrite]>]> { let summary = "A generic function call"; let arguments = (ins Python_PyObjectType:$callee, @@ -137,123 +168,39 @@ def FunctionCallOp : Python_Op<"call", [AttrSizedOperandSegments]> { ); let results = (outs Python_PyObjectType:$output); -} - -def BinaryAddOp : Python_Op<"add"> { - let summary = "Generic binary addition operation"; - - let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs); - - let results = (outs Python_PyObjectType:$output); -} - -def BinarySubtractOp : Python_Op<"sub"> { - let summary = "Generic binary subtraction operation"; - - let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs); - - let results = (outs Python_PyObjectType:$output); -} - -def BinaryModuloOp : Python_Op<"mod"> { - let summary = "Generic binary modulo operation"; - - let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs); - - let results = (outs Python_PyObjectType:$output); -} - -def BinaryMultiplyOp : Python_Op<"mul"> { - let summary = "Generic binary multiplication operation"; - - let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs); - - let results = (outs Python_PyObjectType:$output); -} - -def BinaryExpOp : Python_Op<"exp"> { - let summary = "Generic binary exponential operation"; - - let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs); - - let results = (outs Python_PyObjectType:$output); -} - -def BinaryDivOp : Python_Op<"div"> { - let summary = "Generic binary division operation"; - - let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs); - - let results = (outs Python_PyObjectType:$output); -} - -def BinaryFloorDivOp : Python_Op<"floordiv"> { - let summary = "Generic binary floor division operation"; - let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs); - - let results = (outs Python_PyObjectType:$output); + let hasVerifier = 1; } -def BinaryMatMulOp : Python_Op<"matmul"> { - let summary = "Generic binary matrix multiplication operation"; - - let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs); +// MemWrite for the same reason as the unary ops above: __add__, +// __sub__, etc. are user-overridable and can run arbitrary code. +def BinaryOp : Python_Op<"binary", [MemoryEffects<[MemWrite]>]> { + let summary = "Generic binary operation dispatched on kind attribute"; - let results = (outs Python_PyObjectType:$output); -} - -def LeftShiftOp : Python_Op<"lshift"> { - let summary = "Generic binary left shift operation"; - - let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs); - - let results = (outs Python_PyObjectType:$output); -} - -def RightShiftOp : Python_Op<"rshift"> { - let summary = "Generic binary right shift operation"; - - let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs); - - let results = (outs Python_PyObjectType:$output); -} - -def LogicalAndOp : Python_Op<"logical_and"> { - let summary = "Generic binary logical and operation"; - - let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs); - - let results = (outs Python_PyObjectType:$output); -} - -def LogicalOrOp : Python_Op<"logical_or"> { - let summary = "Generic binary logical or operation"; - - let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs); - - let results = (outs Python_PyObjectType:$output); -} - -def LogicalXorOp : Python_Op<"logical_xor"> { - let summary = "Generic binary logical xor operation"; - - let arguments = (ins Python_PyObjectType:$lhs, Python_PyObjectType:$rhs); + let arguments = (ins Python_ArithOpKindAttr:$kind, + Python_PyObjectType:$lhs, + Python_PyObjectType:$rhs); let results = (outs Python_PyObjectType:$output); } -def InplaceOp : Python_Op<"inplace_op"> { +def InplaceOp : Python_Op<"inplace_op", [MemoryEffects<[MemWrite]>]> { let summary = "Generic inplace operation"; let arguments = (ins Python_PyObjectType:$src, Python_PyObjectType:$dst, - Python_InplaceOpKindAttr:$kind); + Python_ArithOpKindAttr:$kind); let results = (outs Python_PyObjectType:$result); } -def PositiveOp : Python_Op<"pos"> { +// MemWrite is conservative: each of these ops dispatches to a user- +// overridable dunder (__pos__, __neg__, __invert__, __bool__ + +// negation) that can have arbitrary side effects in Python. We can't +// use Load*'s MemRead+PyExceptionStateWrite pattern (which assumes +// the op only touches private slots) because Python's operator +// dispatch can run any user code. +def PositiveOp : Python_Op<"pos", [MemoryEffects<[MemWrite]>]> { let summary = "Unary positive operation"; let arguments = (ins Python_PyObjectType:$input); @@ -261,7 +208,7 @@ def PositiveOp : Python_Op<"pos"> { let results = (outs Python_PyObjectType:$output); } -def NegativeOp : Python_Op<"neg"> { +def NegativeOp : Python_Op<"neg", [MemoryEffects<[MemWrite]>]> { let summary = "Unary negative operation"; let arguments = (ins Python_PyObjectType:$input); @@ -269,7 +216,7 @@ def NegativeOp : Python_Op<"neg"> { let results = (outs Python_PyObjectType:$output); } -def InvertOp : Python_Op<"inv"> { +def InvertOp : Python_Op<"inv", [MemoryEffects<[MemWrite]>]> { let summary = "Unary invert operation"; let arguments = (ins Python_PyObjectType:$input); @@ -277,7 +224,7 @@ def InvertOp : Python_Op<"inv"> { let results = (outs Python_PyObjectType:$output); } -def NotOp : Python_Op<"not"> { +def NotOp : Python_Op<"not", [MemoryEffects<[MemWrite]>]> { let summary = "Unary negation operation"; let arguments = (ins Python_PyObjectType:$input); @@ -285,7 +232,9 @@ def NotOp : Python_Op<"not"> { let results = (outs Python_PyObjectType:$output); } -def CastToBoolOp : Python_Op<"as_bool"> { +// MemWrite: as_bool calls __bool__ (or falls back to __len__), +// both user-overridable. +def CastToBoolOp : Python_Op<"as_bool", [MemoryEffects<[MemWrite]>]> { let summary = "Convert Python object to bool"; let arguments = (ins Python_PyObjectType:$value); @@ -293,7 +242,8 @@ def CastToBoolOp : Python_Op<"as_bool"> { let results = (outs I1:$output); } -def Compare : Python_Op<"cmp"> { +// MemWrite: __eq__/__lt__/__contains__/etc. are user-overridable. +def CompareOp : Python_Op<"cmp", [MemoryEffects<[MemWrite]>]> { let summary = "Compare two operands"; let arguments = (ins Python_CmpPredicateAttr:$predicate, @@ -303,7 +253,7 @@ def Compare : Python_Op<"cmp"> { let results = (outs Python_PyObjectType:$output); } -def LoadAssertionError : Python_Op<"load_assertion_error"> { +def LoadAssertionError : Python_Op<"load_assertion_error", [Pure]> { let summary = [{ Loads AssertionError type. This is guaranteed to be the builtin AssertionError. @@ -320,9 +270,11 @@ def BuildDictOp : Python_Op<"build_dict", [SameVariadicOperandSize]> { DenseBoolArrayAttr:$requires_expansion); let results = (outs Python_PyObjectType:$output); + + let hasVerifier = 1; } -def DictAddOp : Python_Op<"dict_add"> { +def DictAddOp : Python_Op<"dict_add", [MemoryEffects<[MemWrite]>]> { let summary = "Adds a key/value pair to the provided dictionary"; let arguments = (ins Python_PyObjectType:$dict, @@ -337,9 +289,11 @@ def BuildListOp : Python_Op<"build_list"> { DenseBoolArrayAttr:$requires_expansion); let results = (outs Python_PyObjectType:$output); + + let hasVerifier = 1; } -def ListAppendOp : Python_Op<"list_append"> { +def ListAppendOp : Python_Op<"list_append", [MemoryEffects<[MemWrite]>]> { let summary = "Append to a Python list"; let arguments = (ins Python_PyObjectType:$list, Python_PyObjectType:$value); @@ -352,6 +306,8 @@ def BuildTupleOp : Python_Op<"build_tuple"> { DenseBoolArrayAttr:$requires_expansion); let results = (outs Python_PyObjectType:$output); + + let hasVerifier = 1; } def BuildSetOp : Python_Op<"build_set"> { @@ -361,9 +317,11 @@ def BuildSetOp : Python_Op<"build_set"> { DenseBoolArrayAttr:$requires_expansion); let results = (outs Python_PyObjectType:$output); + + let hasVerifier = 1; } -def SetAddOp : Python_Op<"set_add"> { +def SetAddOp : Python_Op<"set_add", [MemoryEffects<[MemWrite]>]> { let summary = "Add a value to a set"; let arguments = (ins Python_PyObjectType:$set, @@ -378,7 +336,8 @@ def BuildStringOp : Python_Op<"build_string"> { let results = (outs Python_PyObjectType:$output); } -def FormatValueOp : Python_Op<"format"> { +// MemWrite: __format__ / __str__ / __repr__ are user-overridable. +def FormatValueOp : Python_Op<"format", [MemoryEffects<[MemWrite]>]> { let summary = "Format an object as a string using the requested conversion"; let arguments = (ins Python_PyObjectType:$value, @@ -387,7 +346,11 @@ def FormatValueOp : Python_Op<"format"> { let results = (outs Python_PyObjectType:$output); } -def LoadAttributeOp: Python_Op<"load_attribute"> { +// MemWrite: attribute access goes through __getattribute__ / +// __getattr__, both user-overridable. Unlike load_fast / load_name +// (which read from private slots and are MemRead-safe), this can +// run any user code. +def LoadAttributeOp : Python_Op<"load_attribute", [MemoryEffects<[MemWrite]>]> { let summary = "Load an attribute"; let arguments = (ins Python_PyObjectType:$self, @@ -396,14 +359,15 @@ def LoadAttributeOp: Python_Op<"load_attribute"> { let results = (outs Python_PyObjectType:$output); } -def DeleteAttributeOp: Python_Op<"delete_attribute"> { +def DeleteAttributeOp : Python_Op<"delete_attribute", [MemoryEffects<[MemWrite]>]> { let summary = "Delete an attribute"; let arguments = (ins Python_PyObjectType:$self, StrAttr:$attr); } -def BinarySubscriptOp: Python_Op<"subscript"> { +// MemWrite: __getitem__ is user-overridable. +def BinarySubscriptOp : Python_Op<"subscript", [MemoryEffects<[MemWrite]>]> { let summary = "Subscript an object"; let arguments = (ins Python_PyObjectType:$self, @@ -412,7 +376,7 @@ def BinarySubscriptOp: Python_Op<"subscript"> { let results = (outs Python_PyObjectType:$output); } -def StoreSubscriptOp: Python_Op<"store_subscript"> { +def StoreSubscriptOp : Python_Op<"store_subscript", [MemoryEffects<[MemWrite]>]> { let summary = "Store value in object using a subscript"; let arguments = (ins Python_PyObjectType:$self, @@ -420,22 +384,24 @@ def StoreSubscriptOp: Python_Op<"store_subscript"> { Python_PyObjectType:$value); } -def DeleteSubscriptOp: Python_Op<"delete_subscript"> { +def DeleteSubscriptOp : Python_Op<"delete_subscript", [MemoryEffects<[MemWrite]>]> { let summary = "Delete value in object using a subscript"; let arguments = (ins Python_PyObjectType:$self, Python_PyObjectType:$subscript); } -def StoreAttributeOp: Python_Op<"store_attribute"> { +def StoreAttributeOp : Python_Op<"store_attribute", [MemoryEffects<[MemWrite]>]> { let summary = "Store value in object using an attribute"; let arguments = (ins Python_PyObjectType:$self, - StrAttr:$attribute, + StrAttr:$attr, Python_PyObjectType:$value); } -def LoadMethodOp: Python_Op<"load_method"> { +// MemWrite: like load_attribute, this triggers attribute lookup +// which is user-overridable via __getattribute__ / __getattr__. +def LoadMethodOp : Python_Op<"load_method", [MemoryEffects<[MemWrite]>]> { let summary = "Load a method"; let arguments = (ins Python_PyObjectType:$self, @@ -444,15 +410,17 @@ def LoadMethodOp: Python_Op<"load_method"> { let results = (outs Python_PyObjectType:$method); } -def UnpackSequenceOp: Python_Op<"unpack"> { +def UnpackSequenceOp : Python_Op<"unpack"> { let summary = "Unpack sequence"; let arguments = (ins Python_PyObjectType:$iterable); let results = (outs Variadic:$unpacked_values); + + let hasVerifier = 1; } -def UnpackExpandOp: Python_Op<"unpack_ex"> { +def UnpackExpandOp : Python_Op<"unpack_ex"> { let summary = "Unpack iterable with expansion"; let arguments = (ins Python_PyObjectType:$iterable); @@ -461,73 +429,60 @@ def UnpackExpandOp: Python_Op<"unpack_ex"> { Python_PyObjectType:$rest); } -def ForIterOp: Python_Op<"for_iter", [Terminator]> { - let summary = "Convenience op for for loops"; - - let description = [{ -Jumps to `end` if calling Python's `next` builtin raises StopIteration. -Otherwise, returns the result of `next`. -If an exception that is not a subclass of StopIteration is raised, that -exception is thrown. - }]; - - let arguments = (ins Python_PyObjectType:$iterator); - - let successors = (successor AnySuccessor:$start, AnySuccessor:$end); - - let results = (outs Python_PyObjectType:$value); -} - -def ForLoopOp: Python_Op<"for_loop", [DeclareOpInterfaceMethods]> { +def ForLoopOp : Python_Op<"for_loop", [DeclareOpInterfaceMethods]> { let summary = "For loop representation"; let arguments = (ins Python_PyObjectType:$iterable); + // orelse is AnyRegion (not MinSizedRegion<1>) so canonicalize can drop + // the empty entry block when there is no Python `else` clause. The + // getSuccessorRegions impl, the lowering pass, and the bytecode emitter + // all handle a 0-block orelse as "no else, branch to parent". let regions = (region MinSizedRegion<1>:$body, SizedRegion<1>:$step, - MinSizedRegion<1>:$orelse); + AnyRegion:$orelse); } -def WhileOp: Python_Op<"while", [DeclareOpInterfaceMethods]> { +def WhileOp : Python_Op<"while", [DeclareOpInterfaceMethods]> { let summary = "While loop representation"; let regions = (region AnyRegion:$condition, AnyRegion:$body, AnyRegion:$orelse); } -def ConditionOp: Python_Op<"condition", [Terminator, - ParentOneOf<["WhileOp", "TryHandlerScope"]>]> { +def ConditionOp : Python_Op<"condition", [Terminator, + ParentOneOf<["WhileOp", "TryHandlerOp"]>]> { let summary = "Condition of a while loop or a catch statement"; let arguments = (ins Python_PyObjectType:$cond); } -def TryHandlerScope: Python_Op<"catch", [Terminator, +def TryHandlerOp : Python_Op<"catch", [Terminator, HasParent<"TryOp">, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods]> { let regions = (region AnyRegion:$cond, AnyRegion:$handler); } -def TryOp: Python_Op<"try", [DeclareOpInterfaceMethods]> { +def TryOp : Python_Op<"try", [DeclareOpInterfaceMethods]> { let regions = (region AnyRegion:$body, AnyRegion:$orelse, AnyRegion:$finally, VariadicRegion>:$handlers); } -def WithOp: Python_Op<"with"> { +def WithOp : Python_Op<"with"> { let arguments = (ins Variadic:$items); let regions = (region AnyRegion:$body); } -def WithExceptStartOp: Python_Op<"with_except_start"> { +def WithExceptStartOp : Python_Op<"with_except_start"> { let arguments = (ins Python_PyObjectType:$exit_method); let results = (outs Python_PyObjectType:$output); } -def RaiseOp: Python_Op<"raise", [Terminator, +def RaiseOp : Python_Op<"raise", [Terminator, AttrSizedOperandSegments]> { let summary = "Raises an exception"; @@ -544,18 +499,19 @@ def RaiseOp: Python_Op<"raise", [Terminator, ]; } -def ClearExceptionStateOp: Python_Op<"clear_exc_state"> { +def ClearExceptionStateOp : Python_Op<"clear_exc_state", + [MemoryEffects<[MemWrite]>]> { let summary = "Clears the interpreter exception state"; } -def ControlFlowYield: Python_Op<"cf_yield", [ReturnLike, +def BranchYieldOp : Python_Op<"br_yield", [ReturnLike, DeclareOpInterfaceMethods, Terminator, ParentOneOf<["TryOp", "ForLoopOp", "WithOp", "WhileOp", - "TryHandlerScope"]>]> { + "TryHandlerOp"]>]> { let summary = "Yield control to the parent operation"; let arguments = (ins OptionalAttr:$kind); @@ -571,7 +527,7 @@ def ControlFlowYield: Python_Op<"cf_yield", [ReturnLike, }]; } -def CondBranchSubclassOp: Python_Op<"cond_br_subclass", [Terminator, +def CondBranchSubclassOp : Python_Op<"cond_br_subclass", [Terminator, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { let summary = "Jump if the current active exception is not of `object_type`"; @@ -593,7 +549,10 @@ def CondBranchSubclassOp: Python_Op<"cond_br_subclass", [Terminator, def ImportOp : Python_Op<"import"> { let summary = "Import a Python module"; - let arguments = (ins StrAttr:$name, Builtin_DenseStringElementsAttr:$from_list, UI32Attr:$level); + // from_list is the list of names introduced by `from m import a, b`. + // Plain StrArrayAttr (prints as ["a", "b"]); see same rationale on + // py.make_function.captures. + let arguments = (ins StrAttr:$name, StrArrayAttr:$from_list, UI32Attr:$level); let results = (outs Python_PyObjectType:$module); } @@ -612,19 +571,56 @@ def ImportAllOp : Python_Op<"import_all"> { let arguments = (ins Python_PyObjectType:$module); } -def ClassDefinitionOp: Python_Op<"class", [AttrSizedOperandSegments]> { +def ClassDefinitionOp : Python_Op<"class", [AttrSizedOperandSegments]> { let summary = "Class definition"; + // captures is a list of free-variable names — no tensor semantics, + // so a plain StrArrayAttr is the natural fit (prints as ["a", "b"] + // instead of dense<...> : tensor<...>). let arguments = (ins StrAttr:$name, StrAttr:$mangled_name, Variadic:$bases, Builtin_DenseStringElementsAttr:$keywords, Variadic:$kwargs, - Builtin_DenseStringElementsAttr:$captures); + StrArrayAttr:$captures); let regions = (region AnyRegion:$body); let results = (outs Python_PyObjectType:$output); + + let hasVerifier = 1; +} + +// Terminator for the body of a py.class. Carries the class object (or +// __class__ cell) that the conversion pass forwards into the synthesised +// func.return when it lowers the class body into a func.func. Existed as +// func.return historically, but that op's verifier requires its parent to +// be func.func - which the class region is not - so canonicalize tripped +// the verifier whenever it ran pre-lowering. +def ClassReturnOp : Python_Op<"class_return", [Terminator, + HasParent<"ClassDefinitionOp">, + ReturnLike]> { + let summary = "Yield a value from a class definition body"; + + let arguments = (ins Python_PyObjectType:$value); +} + +// Function return that does not require its immediate parent to be a +// func.func. Used in place of func.return inside py.try / py.with +// regions where the return statement's enclosing FuncOp is several +// regions up. The conversion pass flattens those regions into the +// FuncOp's body and then rewrites py.return to func.return. +// +// NOTE: deliberately NOT marked ReturnLike. ReturnLike causes MLIR's +// RegionBranchOpInterface dataflow to treat the op as a branch from +// the enclosing region to its parent, and verify the operand count +// against getSuccessorInputs. py.return is meant to short-circuit +// straight through to the func.func boundary, not through each +// enclosing try/with's region-branch interface. +def ReturnOp : Python_Op<"return", [Terminator]> { + let summary = "Return from the enclosing func.func"; + + let arguments = (ins Variadic:$value); } def YieldOp : Python_Op<"yield"> { @@ -646,9 +642,16 @@ def YieldFromOp : Python_Op<"yield_from"> { def BuildSliceOp : Python_Op<"build_slice"> { let summary = "Build a slice object from object subscripting"; + // `step` is optional: the runtime BuildSlice instruction supports a + // step-less 2-operand form, and most Python slices in the wild + // (arr[:5], arr[1:], arr[1:5]) have no step. Omitting the None + // constant when step is missing saves a register per such slice. + // `lower` and `upper` stay required — the VM's BuildSlice::execute + // dereferences both unconditionally; rep-as-None is materialized + // by MLIRGenerator when the Python source lacks them. let arguments = (ins Python_PyObjectType:$lower, Python_PyObjectType:$upper, - Python_PyObjectType:$step); + Optional:$step); let results = (outs Python_PyObjectType:$slice); } diff --git a/src/executable/mlir/Dialect/Python/IR/PythonTypes.td b/src/executable/mlir/Dialect/Python/IR/PythonTypes.td index 87ee6e6a..09fe100a 100644 --- a/src/executable/mlir/Dialect/Python/IR/PythonTypes.td +++ b/src/executable/mlir/Dialect/Python/IR/PythonTypes.td @@ -12,10 +12,4 @@ def Python_PyObjectType: Python_Type<"PyObject", "object", []> { let summary = "Python object type"; let description = "Represents the Python object type"; -} - -def Python_PyEllipsisType: Python_Type<"PyEllipsis", "ellipsis", []> { - let summary = "Python ellipsis type (...)"; - - let description = "Represents the Python singleton ellipsis type"; } \ No newline at end of file diff --git a/src/executable/mlir/Dialect/Python/MLIRGenerator.cpp b/src/executable/mlir/Dialect/Python/MLIRGenerator.cpp index faecd838..ed12763f 100644 --- a/src/executable/mlir/Dialect/Python/MLIRGenerator.cpp +++ b/src/executable/mlir/Dialect/Python/MLIRGenerator.cpp @@ -48,18 +48,18 @@ void add_name(mlir::OpBuilder &builder, mlir::StringRef name, mlir::Operation *f if (fn->hasAttr("names")) { auto names = fn->getAttr("names"); std::vector names_vec; - auto arr = names.cast().getValue(); + auto arr = mlir::cast(names).getValue(); if (std::find_if(arr.begin(), arr.end(), [name](mlir::Attribute attr) { - return attr.cast().getValue() == name; + return mlir::cast(attr).getValue() == name; }) != arr.end()) { return; } std::transform( arr.begin(), arr.end(), std::back_inserter(names_vec), [](mlir::Attribute attr) { - return attr.cast().getValue(); + return mlir::cast(attr).getValue(); }); names_vec.emplace_back(name); fn->setAttr("names", builder.getStrArrayAttr(names_vec)); @@ -75,18 +75,18 @@ void add_cell_variable(mlir::OpBuilder &builder, mlir::StringRef name, mlir::Ope if (fn->hasAttr("cellvars")) { auto names = fn->getAttr("cellvars"); std::vector names_vec; - auto arr = names.cast().getValue(); + auto arr = mlir::cast(names).getValue(); if (std::find_if(arr.begin(), arr.end(), [name](mlir::Attribute attr) { - return attr.cast().getValue() == name; + return mlir::cast(attr).getValue() == name; }) != arr.end()) { return; } std::transform( arr.begin(), arr.end(), std::back_inserter(names_vec), [](mlir::Attribute attr) { - return attr.cast().getValue(); + return mlir::cast(attr).getValue(); }); names_vec.emplace_back(name); fn->setAttr("cellvars", builder.getStrArrayAttr(names_vec)); @@ -102,18 +102,18 @@ void add_free_variable(mlir::OpBuilder &builder, mlir::StringRef name, mlir::Ope if (fn->hasAttr("freevars")) { auto names = fn->getAttr("freevars"); std::vector names_vec; - auto arr = names.cast().getValue(); + auto arr = mlir::cast(names).getValue(); if (std::find_if(arr.begin(), arr.end(), [name](mlir::Attribute attr) { - return attr.cast().getValue() == name; + return mlir::cast(attr).getValue() == name; }) != arr.end()) { return; } std::transform( arr.begin(), arr.end(), std::back_inserter(names_vec), [](mlir::Attribute attr) { - return attr.cast().getValue(); + return mlir::cast(attr).getValue(); }); names_vec.emplace_back(name); fn->setAttr("freevars", builder.getStrArrayAttr(names_vec)); @@ -169,31 +169,6 @@ template mlir::Operation *getParentOfType(mlir::Region *re namespace codegen { -class SSABuilder -{ - std::unordered_map> m_current_def; - - public: - void write_variable(std::string varname, mlir::Block *block, mlir::Value value) - { - m_current_def[varname][block] = std::move(value); - } - - mlir::Value read_variable(std::string varname, mlir::Block *block) - { - if (auto var_block_it = m_current_def.find(varname); var_block_it != m_current_def.end()) { - if (auto it = var_block_it->second.find(block); it != var_block_it->second.end()) { - // local value numbering - return it->second; - } - } - return read_variable_recursive(std::move(varname), block); - } - - private: - mlir::Value read_variable_recursive(std::string varname, mlir::Block *block) { TODO(); } -}; - struct Context::ContextImpl { mlir::MLIRContext m_ctx; @@ -208,8 +183,6 @@ struct Context::ContextImpl } mlir::Type pyobject_type() { return mlir::py::PyObjectType::get(&m_ctx); } - - mlir::py::PyEllipsisType pyellipsis_type() { return mlir::py::PyEllipsisType::get(&m_ctx); } }; mlir::MLIRContext &Context::ctx() { return m_impl->m_ctx; } @@ -234,7 +207,6 @@ MLIRGenerator::MLIRGenerator(Context &ctx) : m_context(ctx) m_context.ctx().loadDialect(); m_context.ctx().loadDialect(); m_context.ctx().loadDialect(); - // m_builder = std::make_unique(); } bool MLIRGenerator::compile(std::shared_ptr m, @@ -356,69 +328,48 @@ void MLIRGenerator::store_name(std::string_view name, switch (visibility) { case VariablesResolver::Visibility::NAME: { m_context.builder().create( - loc(m_context.builder(), m_context.filename(), location), - m_context->pyobject_type(), - name, - value->value); + loc(m_context.builder(), m_context.filename(), location), name, value->value); } break; case VariablesResolver::Visibility::LOCAL: { m_context.builder().create( - loc(m_context.builder(), m_context.filename(), location), - m_context->pyobject_type(), - name, - value->value); + loc(m_context.builder(), m_context.filename(), location), name, value->value); } break; case VariablesResolver::Visibility::EXPLICIT_GLOBAL: case VariablesResolver::Visibility::IMPLICIT_GLOBAL: { if (&m_scope.front() == &scope()) { m_context.builder().create( - loc(m_context.builder(), m_context.filename(), location), - m_context->pyobject_type(), - name, - value->value); + loc(m_context.builder(), m_context.filename(), location), name, value->value); } else { auto current_fn = getParentOfType( m_context.builder().getInsertionBlock()->getParent()); add_name(m_context.builder(), name, current_fn); m_context.builder().create( - loc(m_context.builder(), m_context.filename(), location), - m_context->pyobject_type(), - name, - value->value); + loc(m_context.builder(), m_context.filename(), location), name, value->value); } } break; case VariablesResolver::Visibility::CELL: { auto parent = getParentOfType( m_context.builder().getInsertionBlock()->getParent()); - auto arr = parent->getAttr("cellvars").cast().getValue(); + auto arr = mlir::cast(parent->getAttr("cellvars")).getValue(); ASSERT(std::find_if(arr.begin(), arr.end(), [name](mlir::Attribute attr) { - return attr.cast().getValue() == mlir::StringRef{ name }; + return mlir::cast(attr).getValue() == mlir::StringRef{ name }; }) != arr.end()); m_context.builder().create( - loc(m_context.builder(), m_context.filename(), location), - m_context->pyobject_type(), - name, - value->value); + loc(m_context.builder(), m_context.filename(), location), name, value->value); } break; case VariablesResolver::Visibility::FREE: { auto parent = getParentOfType( m_context.builder().getInsertionBlock()->getParent()); - auto arr = parent->getAttr("freevars").cast().getValue(); + auto arr = mlir::cast(parent->getAttr("freevars")).getValue(); ASSERT(std::find_if(arr.begin(), arr.end(), [name](mlir::Attribute attr) { - return attr.cast().getValue() == mlir::StringRef{ name }; + return mlir::cast(attr).getValue() == mlir::StringRef{ name }; }) != arr.end()); m_context.builder().create( - loc(m_context.builder(), m_context.filename(), location), - m_context->pyobject_type(), - name, - value->value); + loc(m_context.builder(), m_context.filename(), location), name, value->value); } break; case VariablesResolver::Visibility::HIDDEN: { m_context.builder().create( - loc(m_context.builder(), m_context.filename(), location), - m_context->pyobject_type(), - name, - value->value); + loc(m_context.builder(), m_context.filename(), location), name, value->value); } break; } } @@ -473,9 +424,9 @@ MLIRGenerator::MLIRValue *MLIRGenerator::load_name(std::string_view name, case VariablesResolver::Visibility::CELL: { auto parent = getParentOfType( m_context.builder().getInsertionBlock()->getParent()); - auto arr = parent->getAttr("cellvars").cast().getValue(); + auto arr = mlir::cast(parent->getAttr("cellvars")).getValue(); ASSERT(std::find_if(arr.begin(), arr.end(), [name](mlir::Attribute attr) { - return attr.cast().getValue() == mlir::StringRef{ name }; + return mlir::cast(attr).getValue() == mlir::StringRef{ name }; }) != arr.end()); return new_value(m_context.builder().create( loc(m_context.builder(), m_context.filename(), location), @@ -485,9 +436,9 @@ MLIRGenerator::MLIRValue *MLIRGenerator::load_name(std::string_view name, case VariablesResolver::Visibility::FREE: { auto parent = getParentOfType( m_context.builder().getInsertionBlock()->getParent()); - auto arr = parent->getAttr("freevars").cast().getValue(); + auto arr = mlir::cast(parent->getAttr("freevars")).getValue(); ASSERT(std::find_if(arr.begin(), arr.end(), [name](mlir::Attribute attr) { - return attr.cast().getValue() == mlir::StringRef{ name }; + return mlir::cast(attr).getValue() == mlir::StringRef{ name }; }) != arr.end()); return new_value(m_context.builder().create( loc(m_context.builder(), m_context.filename(), location), @@ -761,53 +712,53 @@ ast::Value *MLIRGenerator::visit(const ast::AugAssign *node) auto result = [&]() { auto make_binop = [this, &node]( - ast::Value *value, ast::Value *target, mlir::py::InplaceOpKind kind) { + ast::Value *value, ast::Value *target, mlir::py::ArithOpKind kind) { return new_value(m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), m_context->pyobject_type(), static_cast(value)->value, static_cast(target)->value, - mlir::py::InplaceOpKindAttr::get(&m_context.ctx(), kind))); + mlir::py::ArithOpKindAttr::get(&m_context.ctx(), kind))); }; switch (node->op()) { case ast::BinaryOpType::PLUS: { - return make_binop(value, target, mlir::py::InplaceOpKind::add); + return make_binop(value, target, mlir::py::ArithOpKind::add); } break; case ast::BinaryOpType::MINUS: { - return make_binop(value, target, mlir::py::InplaceOpKind::sub); + return make_binop(value, target, mlir::py::ArithOpKind::sub); } break; case ast::BinaryOpType::MODULO: { - return make_binop(value, target, mlir::py::InplaceOpKind::mod); + return make_binop(value, target, mlir::py::ArithOpKind::mod); } break; case ast::BinaryOpType::MULTIPLY: { - return make_binop(value, target, mlir::py::InplaceOpKind::mul); + return make_binop(value, target, mlir::py::ArithOpKind::mul); } break; case ast::BinaryOpType::EXP: { - return make_binop(value, target, mlir::py::InplaceOpKind::exp); + return make_binop(value, target, mlir::py::ArithOpKind::exp); } break; case ast::BinaryOpType::SLASH: { - return make_binop(value, target, mlir::py::InplaceOpKind::div); + return make_binop(value, target, mlir::py::ArithOpKind::div); } break; case ast::BinaryOpType::FLOORDIV: { - return make_binop(value, target, mlir::py::InplaceOpKind::fldiv); + return make_binop(value, target, mlir::py::ArithOpKind::fldiv); } break; case ast::BinaryOpType::MATMUL: { - return make_binop(value, target, mlir::py::InplaceOpKind::mmul); + return make_binop(value, target, mlir::py::ArithOpKind::mmul); } break; case ast::BinaryOpType::LEFTSHIFT: { - return make_binop(value, target, mlir::py::InplaceOpKind::lshift); + return make_binop(value, target, mlir::py::ArithOpKind::lshift); } break; case ast::BinaryOpType::RIGHTSHIFT: { - return make_binop(value, target, mlir::py::InplaceOpKind::rshift); + return make_binop(value, target, mlir::py::ArithOpKind::rshift); } break; case ast::BinaryOpType::AND: { - return make_binop(value, target, mlir::py::InplaceOpKind::and_); + return make_binop(value, target, mlir::py::ArithOpKind::and_); } break; case ast::BinaryOpType::OR: { - return make_binop(value, target, mlir::py::InplaceOpKind::or_); + return make_binop(value, target, mlir::py::ArithOpKind::or_); } break; case ast::BinaryOpType::XOR: { - return make_binop(value, target, mlir::py::InplaceOpKind::xor_); + return make_binop(value, target, mlir::py::ArithOpKind::xor_); } break; } ASSERT_NOT_REACHED(); @@ -848,7 +799,7 @@ ast::Value *MLIRGenerator::visit(const ast::Break *node) m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), b); m_context.builder().setInsertionPointToStart(b); - m_context.builder().create( + m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), mlir::py::LoopOpKindAttr::get(&m_context.ctx(), mlir::py::LoopOpKind::break_)); return nullptr; @@ -858,112 +809,43 @@ ast::Value *MLIRGenerator::visit(const ast::BinaryExpr *node) { auto lhs = static_cast(node->lhs()->codegen(this))->value; auto rhs = static_cast(node->rhs()->codegen(this))->value; + auto location = loc(m_context.builder(), m_context.filename(), node->source_location()); - switch (node->op_type()) { - case ast::BinaryOpType::PLUS: { - auto result = m_context.builder().create( - loc(m_context.builder(), m_context.filename(), node->source_location()), + auto build_binary = [&](mlir::py::ArithOpKind kind) { + return new_value(m_context.builder().create(location, m_context->pyobject_type(), + mlir::py::ArithOpKindAttr::get(&m_context.ctx(), kind), lhs, - rhs); - return new_value(result); - } break; - case ast::BinaryOpType::MINUS: { - auto result = m_context.builder().create( - loc(m_context.builder(), m_context.filename(), node->source_location()), - m_context->pyobject_type(), - lhs, - rhs); - return new_value(result); - } break; - case ast::BinaryOpType::MODULO: { - auto result = m_context.builder().create( - loc(m_context.builder(), m_context.filename(), node->source_location()), - m_context->pyobject_type(), - lhs, - rhs); - return new_value(result); - } break; - case ast::BinaryOpType::MULTIPLY: { - auto result = m_context.builder().create( - loc(m_context.builder(), m_context.filename(), node->source_location()), - m_context->pyobject_type(), - lhs, - rhs); - return new_value(result); - } break; - case ast::BinaryOpType::EXP: { - auto result = m_context.builder().create( - loc(m_context.builder(), m_context.filename(), node->source_location()), - m_context->pyobject_type(), - lhs, - rhs); - return new_value(result); - } break; - case ast::BinaryOpType::SLASH: { - auto result = m_context.builder().create( - loc(m_context.builder(), m_context.filename(), node->source_location()), - m_context->pyobject_type(), - lhs, - rhs); - return new_value(result); - } break; - case ast::BinaryOpType::FLOORDIV: { - auto result = m_context.builder().create( - loc(m_context.builder(), m_context.filename(), node->source_location()), - m_context->pyobject_type(), - lhs, - rhs); - return new_value(result); - } break; - case ast::BinaryOpType::MATMUL: { - auto result = m_context.builder().create( - loc(m_context.builder(), m_context.filename(), node->source_location()), - m_context->pyobject_type(), - lhs, - rhs); - return new_value(result); - } break; - case ast::BinaryOpType::LEFTSHIFT: { - auto result = m_context.builder().create( - loc(m_context.builder(), m_context.filename(), node->source_location()), - m_context->pyobject_type(), - lhs, - rhs); - return new_value(result); - } break; - case ast::BinaryOpType::RIGHTSHIFT: { - auto result = m_context.builder().create( - loc(m_context.builder(), m_context.filename(), node->source_location()), - m_context->pyobject_type(), - lhs, - rhs); - return new_value(result); - } break; - case ast::BinaryOpType::AND: { - auto result = m_context.builder().create( - loc(m_context.builder(), m_context.filename(), node->source_location()), - m_context->pyobject_type(), - lhs, - rhs); - return new_value(result); - } break; - case ast::BinaryOpType::OR: { - auto result = m_context.builder().create( - loc(m_context.builder(), m_context.filename(), node->source_location()), - m_context->pyobject_type(), - lhs, - rhs); - return new_value(result); - } break; - case ast::BinaryOpType::XOR: { - auto result = m_context.builder().create( - loc(m_context.builder(), m_context.filename(), node->source_location()), - m_context->pyobject_type(), - lhs, - rhs); - return new_value(result); - } break; + rhs)); + }; + + switch (node->op_type()) { + case ast::BinaryOpType::PLUS: + return build_binary(mlir::py::ArithOpKind::add); + case ast::BinaryOpType::MINUS: + return build_binary(mlir::py::ArithOpKind::sub); + case ast::BinaryOpType::MODULO: + return build_binary(mlir::py::ArithOpKind::mod); + case ast::BinaryOpType::MULTIPLY: + return build_binary(mlir::py::ArithOpKind::mul); + case ast::BinaryOpType::EXP: + return build_binary(mlir::py::ArithOpKind::exp); + case ast::BinaryOpType::SLASH: + return build_binary(mlir::py::ArithOpKind::div); + case ast::BinaryOpType::FLOORDIV: + return build_binary(mlir::py::ArithOpKind::fldiv); + case ast::BinaryOpType::MATMUL: + return build_binary(mlir::py::ArithOpKind::mmul); + case ast::BinaryOpType::LEFTSHIFT: + return build_binary(mlir::py::ArithOpKind::lshift); + case ast::BinaryOpType::RIGHTSHIFT: + return build_binary(mlir::py::ArithOpKind::rshift); + case ast::BinaryOpType::AND: + return build_binary(mlir::py::ArithOpKind::and_); + case ast::BinaryOpType::OR: + return build_binary(mlir::py::ArithOpKind::or_); + case ast::BinaryOpType::XOR: + return build_binary(mlir::py::ArithOpKind::xor_); } ASSERT_NOT_REACHED(); @@ -1268,13 +1150,13 @@ ast::Value *MLIRGenerator::visit(const ast::ClassDefinition *node) if (class_scope->requires_class_ref) { auto *__class__ = load_name("__class__", node->source_location()); store_name("__classcell__", __class__, node->source_location()); - m_context.builder().create( - m_context.builder().getUnknownLoc(), mlir::ValueRange{ __class__->value }); + m_context.builder().create( + m_context.builder().getUnknownLoc(), __class__->value); } else { auto result = m_context.builder().create( m_context.builder().getUnknownLoc(), m_context.builder().getNoneType()); - m_context.builder().create( - m_context.builder().getUnknownLoc(), mlir::ValueRange{ result }); + m_context.builder().create( + m_context.builder().getUnknownLoc(), result); } } @@ -1283,10 +1165,7 @@ ast::Value *MLIRGenerator::visit(const ast::ClassDefinition *node) std::vector captures_ref; captures_ref.reserve(captures.size()); for (const auto &el : captures) { captures_ref.push_back(el); } - output.setCapturesAttr(mlir::DenseStringElementsAttr::get( - mlir::VectorType::get({ static_cast(captures.size()) }, - mlir::StringAttr::get(&m_context.ctx()).getType()), - captures_ref)); + output.setCapturesAttr(m_context.builder().getStrArrayAttr(captures_ref)); store_name(node->name(), new_value(output), node->source_location()); @@ -1318,7 +1197,7 @@ ast::Value *MLIRGenerator::visit(const ast::Continue *node) m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), b); m_context.builder().setInsertionPointToStart(b); - m_context.builder().create( + m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), mlir::py::LoopOpKindAttr::get(&m_context.ctx(), mlir::py::LoopOpKind::continue_)); return nullptr; @@ -1326,7 +1205,7 @@ ast::Value *MLIRGenerator::visit(const ast::Continue *node) ast::Value *MLIRGenerator::visit(const ast::Compare *node) { - std::optional result; + std::optional result; auto lhs = static_cast(node->lhs()->codegen(this))->value; const auto &comparators = node->comparators(); const auto &ops = node->ops(); @@ -1337,7 +1216,7 @@ ast::Value *MLIRGenerator::visit(const ast::Compare *node) switch (op) { case ast::Compare::OpType::Eq: { - result = m_context.builder().create( + result = m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), m_context->pyobject_type(), mlir::py::CmpPredicateAttr::get(&m_context.ctx(), mlir::py::CmpPredicate::eq), @@ -1345,7 +1224,7 @@ ast::Value *MLIRGenerator::visit(const ast::Compare *node) rhs); } break; case ast::Compare::OpType::NotEq: { - result = m_context.builder().create( + result = m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), m_context->pyobject_type(), mlir::py::CmpPredicateAttr::get(&m_context.ctx(), mlir::py::CmpPredicate::ne), @@ -1353,7 +1232,7 @@ ast::Value *MLIRGenerator::visit(const ast::Compare *node) rhs); } break; case ast::Compare::OpType::Lt: { - result = m_context.builder().create( + result = m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), m_context->pyobject_type(), mlir::py::CmpPredicateAttr::get(&m_context.ctx(), mlir::py::CmpPredicate::lt), @@ -1361,7 +1240,7 @@ ast::Value *MLIRGenerator::visit(const ast::Compare *node) rhs); } break; case ast::Compare::OpType::LtE: { - result = m_context.builder().create( + result = m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), m_context->pyobject_type(), mlir::py::CmpPredicateAttr::get(&m_context.ctx(), mlir::py::CmpPredicate::le), @@ -1369,7 +1248,7 @@ ast::Value *MLIRGenerator::visit(const ast::Compare *node) rhs); } break; case ast::Compare::OpType::Gt: { - result = m_context.builder().create( + result = m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), m_context->pyobject_type(), mlir::py::CmpPredicateAttr::get(&m_context.ctx(), mlir::py::CmpPredicate::gt), @@ -1377,7 +1256,7 @@ ast::Value *MLIRGenerator::visit(const ast::Compare *node) rhs); } break; case ast::Compare::OpType::GtE: { - result = m_context.builder().create( + result = m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), m_context->pyobject_type(), mlir::py::CmpPredicateAttr::get(&m_context.ctx(), mlir::py::CmpPredicate::ge), @@ -1385,7 +1264,7 @@ ast::Value *MLIRGenerator::visit(const ast::Compare *node) rhs); } break; case ast::Compare::OpType::Is: { - result = m_context.builder().create( + result = m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), m_context->pyobject_type(), mlir::py::CmpPredicateAttr::get(&m_context.ctx(), mlir::py::CmpPredicate::is), @@ -1393,7 +1272,7 @@ ast::Value *MLIRGenerator::visit(const ast::Compare *node) rhs); } break; case ast::Compare::OpType::IsNot: { - result = m_context.builder().create( + result = m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), m_context->pyobject_type(), mlir::py::CmpPredicateAttr::get(&m_context.ctx(), mlir::py::CmpPredicate::isnot), @@ -1401,7 +1280,7 @@ ast::Value *MLIRGenerator::visit(const ast::Compare *node) rhs); } break; case ast::Compare::OpType::In: { - result = m_context.builder().create( + result = m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), m_context->pyobject_type(), mlir::py::CmpPredicateAttr::get(&m_context.ctx(), mlir::py::CmpPredicate::in), @@ -1409,7 +1288,7 @@ ast::Value *MLIRGenerator::visit(const ast::Compare *node) rhs); } break; case ast::Compare::OpType::NotIn: { - result = m_context.builder().create( + result = m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), m_context->pyobject_type(), mlir::py::CmpPredicateAttr::get(&m_context.ctx(), mlir::py::CmpPredicate::notin), @@ -1489,7 +1368,7 @@ ast::Value *MLIRGenerator::visit(const ast::Constant *node) [this, node](py::Ellipsis) -> ast::Value * { mlir::py::ConstantOp op = m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), - m_context->pyellipsis_type()); + mlir::py::EllipsisAttr::get(&m_context.ctx())); return new_value(op); }, [](auto) -> ast::Value * { @@ -1578,7 +1457,7 @@ ast::Value *MLIRGenerator::visit(const ast::For *node) m_context->pyobject_type(), m_context.builder().getUnknownLoc())); assign(node->target(), iterator, node->target()->source_location()); - m_context.builder().create(m_context.builder().getUnknownLoc()); + m_context.builder().create(m_context.builder().getUnknownLoc()); m_context.builder().setInsertionPointToStart(&body_start); for (const auto &el : node->body()) { el->codegen(this); } @@ -1587,7 +1466,7 @@ ast::Value *MLIRGenerator::visit(const ast::For *node) .getInsertionBlock() ->back() .hasTrait()) { - m_context.builder().create(m_context.builder().getUnknownLoc()); + m_context.builder().create(m_context.builder().getUnknownLoc()); } if (!node->orelse().empty()) { @@ -1595,7 +1474,7 @@ ast::Value *MLIRGenerator::visit(const ast::For *node) for (const auto &el : node->orelse()) { el->codegen(this); } if (m_context.builder().getBlock()->empty() || !m_context.builder().getBlock()->back().hasTrait()) { - m_context.builder().create(loc( + m_context.builder().create(loc( m_context.builder(), m_context.filename(), node->body().back()->source_location())); } } @@ -1749,8 +1628,7 @@ ast::Value *MLIRGenerator::visit(const ast::Import *node) { for (const auto &n : node->names()) { // empty from_list - auto from_list = mlir::DenseStringElementsAttr::get( - mlir::VectorType::get({ 0 }, mlir::StringAttr::get(&m_context.ctx()).getType()), {}); + auto from_list = m_context.builder().getStrArrayAttr({}); const uint32_t level = 0; auto module = new_value(m_context.builder().create( @@ -1780,10 +1658,7 @@ ast::Value *MLIRGenerator::visit(const ast::ImportFrom *node) names.reserve(node->names().size()); for (const auto &n : node->names()) { names.emplace_back(n.name); } - auto from_list = mlir::DenseStringElementsAttr::get( - mlir::VectorType::get({ static_cast(names.size()) }, - mlir::StringAttr::get(&m_context.ctx()).getType()), - names); + auto from_list = m_context.builder().getStrArrayAttr(names); auto module = m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location()), @@ -1906,7 +1781,20 @@ ast::Value *MLIRGenerator::visit(const ast::Module *m) m_context.builder().create(m_context.builder().getUnknownLoc(), "__hidden_init__", m_context.builder().getFunctionType({}, { m_context->pyobject_type() })); - module_fn.setPrivate(); + // Public visibility: the module entry's side effects (storing names + // into the module dict, etc.) escape this MLIR module because the + // bytecode runtime invokes it and importers observe the resulting + // bindings. Public also tells RemoveDeadValuesPass to skip its + // processFuncOp (which would otherwise strip the return value and + // propagate "dead" backward through every side-effecting op in the + // body — see compile.cpp). + module_fn.setPublic(); + // is_module_entry is the stable signal the bytecode emitter and any + // pipeline-internal pass uses to detect the module-entry FuncOp, + // without relying on the symbol name (a user-defined Python function + // could accidentally collide) or on the symbol visibility (which + // other passes may rewrite). + module_fn->setAttr("is_module_entry", m_context.builder().getBoolAttr(true)); auto *entry_block = module_fn.addEntryBlock(); auto *exit_block = module_fn.addBlock(); m_context.builder().setInsertionPointToEnd(entry_block); @@ -2041,7 +1929,13 @@ codegen::MLIRGenerator::MLIRValue *MLIRGenerator::build_comprehension( func_type, mlir::ArrayRef{}, args_attrs); - f.setPrivate(); + // Public visibility: Python functions are invoked at runtime via + // MAKE_FUNCTION + CALL ops (a value-based runtime dispatch), not + // via func.call. MLIR's CallOpInterface-based analyses — including + // RemoveDeadValuesPass — don't see the runtime dispatch as a caller + // and would otherwise strip "uncalled" private functions' arguments + // and return values, which breaks the runtime dispatch. + f.setPublic(); std::vector captures; { @@ -2068,7 +1962,7 @@ codegen::MLIRGenerator::MLIRValue *MLIRGenerator::build_comprehension( auto for_loop = m_context.builder().create( loc(m_context.builder(), m_context.filename(), generator->source_location()), iterable); - m_context.builder().create( + m_context.builder().create( loc(m_context.builder(), m_context.filename(), generator->source_location())); // iterator { @@ -2076,7 +1970,7 @@ codegen::MLIRGenerator::MLIRValue *MLIRGenerator::build_comprehension( auto iterator = new_value(for_loop.getStep().addArgument( m_context->pyobject_type(), m_context.builder().getUnknownLoc())); assign(generator->target(), iterator, generator->target()->source_location()); - m_context.builder().create( + m_context.builder().create( m_context.builder().getUnknownLoc()); } // loop body @@ -2123,7 +2017,7 @@ codegen::MLIRGenerator::MLIRValue *MLIRGenerator::build_comprehension( &body_continue); } m_context.builder().setInsertionPointToStart(&body_continue); - m_context.builder().create( + m_context.builder().create( loc(m_context.builder(), m_context.filename(), generator->source_location()), @@ -2151,7 +2045,7 @@ codegen::MLIRGenerator::MLIRValue *MLIRGenerator::build_comprehension( next_generator(iter, generator); } container_update(container); - m_context.builder().create(m_context.builder().getUnknownLoc()); + m_context.builder().create(m_context.builder().getUnknownLoc()); m_context.builder().setInsertionPointToEnd(entry_block); m_context.builder().getBlock()->back().erase(); @@ -2179,10 +2073,7 @@ codegen::MLIRGenerator::MLIRValue *MLIRGenerator::build_comprehension( mangled_name, mlir::ValueRange{}, mlir::ValueRange{}, - mlir::DenseStringElementsAttr::get( - mlir::VectorType::get({ static_cast(captures.size()) }, - mlir::StringAttr::get(&m_context.ctx()).getType()), - captures_ref)); + m_context.builder().getStrArrayAttr(captures_ref)); auto iterable = static_cast(generators.front()->iter()->codegen(this))->value; @@ -2230,6 +2121,11 @@ codegen::MLIRGenerator::MLIRValue *MLIRGenerator::build_slice( return new_value(static_cast(*idx.value->codegen(this)).value); }, [this, location](ast::Subscript::Slice slice) -> MLIRValue * { + // lower / upper are materialized as None constants when + // the Python source omits them (a[:5] etc.) because the + // runtime BuildSlice::execute dereferences both. step is + // genuinely optional at the dialect + runtime level: + // pass nullptr when missing and save a register. auto lower = slice.lower ? static_cast(*slice.lower->codegen(this)).value : m_context.builder().create( @@ -2240,11 +2136,8 @@ codegen::MLIRGenerator::MLIRValue *MLIRGenerator::build_slice( : m_context.builder().create( loc(m_context.builder(), m_context.filename(), location), m_context.builder().getNoneType()); - auto step = slice.step - ? static_cast(*slice.step->codegen(this)).value - : m_context.builder().create( - loc(m_context.builder(), m_context.filename(), location), - m_context.builder().getNoneType()); + auto step = slice.step ? static_cast(*slice.step->codegen(this)).value + : mlir::Value{}; return new_value(m_context.builder().create( loc(m_context.builder(), m_context.filename(), location), m_context->pyobject_type(), @@ -2431,8 +2324,9 @@ void MLIRGenerator::return_value(MLIRValue *value, const SourceLocation &source_ scope().finally_blocks = std::move(finally_blocks); } - m_context.builder().create( - loc(m_context.builder(), m_context.filename(), source_location), value->value); + m_context.builder().create( + loc(m_context.builder(), m_context.filename(), source_location), + mlir::ValueRange{ value->value }); } MLIRGenerator::RAIIScope MLIRGenerator::setup_function(mlir::func::FuncOp &f, @@ -2473,6 +2367,88 @@ MLIRGenerator::RAIIScope MLIRGenerator::setup_function(mlir::func::FuncOp &f, return function_scope; } +// Codegen each AST node and collect the resulting SSA values (as the +// existing MLIRValue* indirection used elsewhere on MLIRGenerator's private +// API). +std::vector evaluate_expressions_for_make_function(MLIRGenerator &gen, + const std::vector> &expressions) +{ + std::vector values; + values.reserve(expressions.size()); + for (const auto &expr : expressions) { + auto *v = expr->codegen(&gen); + ASSERT(v); + values.push_back(static_cast(v)); + } + return values; +} + +std::vector evaluate_default_arguments_for_make_function( + MLIRGenerator &gen, + const std::shared_ptr &args) +{ + return evaluate_expressions_for_make_function(gen, args->defaults()); +} + +std::vector evaluate_keyword_default_arguments_for_make_function( + MLIRGenerator &gen, + const std::shared_ptr &args) +{ + std::vector kw_defaults; + kw_defaults.reserve(args->kw_defaults().size()); + for (const auto &default_ : args->kw_defaults()) { + if (default_) { + kw_defaults.push_back(static_cast(default_->codegen(&gen))); + } + } + return kw_defaults; +} + +// Apply a decorator chain to the value currently stored under +// function_name. Decorators are applied in reverse order (innermost first) +// per Python's @ semantics. +void apply_decorators_for_make_function(MLIRGenerator &gen, + const std::vector &decorator_functions, + const std::string &function_name, + const SourceLocation &source_location) +{ + if (decorator_functions.empty()) { return; } + + auto &builder = gen.m_context.builder(); + auto empty_keywords_attr = mlir::DenseStringElementsAttr::get( + mlir::VectorType::get({ 0 }, mlir::StringAttr::get(&gen.m_context.ctx()).getType()), {}); + + mlir::Value arg = gen.load_name(function_name, source_location)->value; + for (auto *decorator : decorator_functions | std::ranges::views::reverse) { + mlir::Value decorator_function = decorator->value; + arg = builder.create(decorator_function.getLoc(), + gen.m_context->pyobject_type(), + decorator_function, + mlir::ValueRange{ arg }, + empty_keywords_attr, + mlir::ValueRange{}, + false, + false); + } + gen.store_name(function_name, gen.new_value(arg), source_location); +} + +std::vector MLIRGenerator::collect_function_captures(const std::string &mangled_name) +{ + std::vector captures; + const auto &scope = *m_variable_visibility.at(mangled_name); + // Cell captures first, then free captures. Order matters: the bytecode + // emitter consumes this list and assigns indices in this order. + for (auto kind : { VariablesResolver::Visibility::CELL, VariablesResolver::Visibility::FREE }) { + for (const auto &el : scope.symbol_map.symbols) { + if (el.visibility == kind && scope.captures.contains(el.name)) { + captures.push_back(el.name); + } + } + } + return captures; +} + MLIRGenerator::MLIRValue *MLIRGenerator::make_function(const std::string &function_name, const std::string &mangled_name, const std::shared_ptr &args, @@ -2484,148 +2460,94 @@ MLIRGenerator::MLIRValue *MLIRGenerator::make_function(const std::string &functi { auto *last_block = m_context.builder().getBlock(); - std::vector decorator_functions; - decorator_functions.reserve(decorator_list.size()); - for (const auto &decorator_function : decorator_list) { - auto *f = decorator_function->codegen(this); - ASSERT(f); - decorator_functions.push_back(static_cast(f)->value); - } + // Evaluate decorator expressions, then default-value expressions, in the + // caller's scope (before we create the new FuncOp). + auto decorator_functions = evaluate_expressions_for_make_function(*this, decorator_list); + auto defaults = evaluate_default_arguments_for_make_function(*this, args); + auto kw_defaults = evaluate_keyword_default_arguments_for_make_function(*this, args); const size_t args_size = args->args().size() + args->posonlyargs().size() + args->kwonlyargs().size() + (args->vararg() != nullptr) + (args->kwarg() != nullptr); - - std::vector defaults; - for (const auto &default_ : args->defaults()) { - defaults.push_back(static_cast(default_->codegen(this))->value); - } - - std::vector kw_defaults; - kw_defaults.reserve(args->kw_defaults().size()); - for (const auto &default_ : args->kw_defaults()) { - if (default_) { - kw_defaults.push_back(static_cast(default_->codegen(this))->value); - } - } - std::vector param_types(args_size, m_context->pyobject_type()); - auto func_type = mlir::FunctionType::get(&m_context.ctx(), param_types, { m_context->pyobject_type() }); - std::vector args_attrs; - for (const auto &arg : args->argument_names()) { - std::vector arg_attrs; - arg_attrs.push_back( - m_context.builder().getNamedAttr("llvm.name", m_context.builder().getStringAttr(arg))); - args_attrs.push_back(m_context.builder().getDictionaryAttr(arg_attrs)); - } - for (const auto &arg : args->kw_only_argument_names()) { + // Build the per-argument llvm.name / llvm.kwonlyarg / llvm.vararg / + // llvm.kwarg attribute dicts. Order: positional, kwonly, vararg, kwarg + // - the downstream bytecode emitter relies on this order to assign + // argument indices. + auto &builder = m_context.builder(); + const auto make_arg_attr = [&](mlir::StringRef name, + std::initializer_list extras = {}) { std::vector arg_attrs; - arg_attrs.push_back( - m_context.builder().getNamedAttr("llvm.name", m_context.builder().getStringAttr(arg))); - arg_attrs.push_back(m_context.builder().getNamedAttr( - "llvm.kwonlyarg", m_context.builder().getBoolAttr(true))); - args_attrs.push_back(m_context.builder().getDictionaryAttr(arg_attrs)); + arg_attrs.push_back(builder.getNamedAttr("llvm.name", builder.getStringAttr(name))); + arg_attrs.insert(arg_attrs.end(), extras); + return builder.getDictionaryAttr(arg_attrs); + }; + std::vector args_attrs; + for (const auto &arg : args->argument_names()) { args_attrs.push_back(make_arg_attr(arg)); } + for (const auto &arg : args->kw_only_argument_names()) { + args_attrs.push_back(make_arg_attr( + arg, { builder.getNamedAttr("llvm.kwonlyarg", builder.getBoolAttr(true)) })); } - if (args->vararg()) { - std::vector arg_attrs; - arg_attrs.push_back(m_context.builder().getNamedAttr( - "llvm.name", m_context.builder().getStringAttr(args->vararg()->name()))); - arg_attrs.push_back( - m_context.builder().getNamedAttr("llvm.vararg", m_context.builder().getBoolAttr(true))); - args_attrs.push_back(m_context.builder().getDictionaryAttr(arg_attrs)); + args_attrs.push_back(make_arg_attr(args->vararg()->name(), + { builder.getNamedAttr("llvm.vararg", builder.getBoolAttr(true)) })); } - if (args->kwarg()) { - std::vector arg_attrs; - arg_attrs.push_back(m_context.builder().getNamedAttr( - "llvm.name", m_context.builder().getStringAttr(args->kwarg()->name()))); - arg_attrs.push_back( - m_context.builder().getNamedAttr("llvm.kwarg", m_context.builder().getBoolAttr(true))); - args_attrs.push_back(m_context.builder().getDictionaryAttr(arg_attrs)); + args_attrs.push_back(make_arg_attr(args->kwarg()->name(), + { builder.getNamedAttr("llvm.kwarg", builder.getBoolAttr(true)) })); } - m_context.builder().setInsertionPointToEnd( - &m_context.module().getBodyRegion().getBlocks().back()); - auto f = m_context.builder().create( - loc(m_context.builder(), m_context.filename(), source_location), + builder.setInsertionPointToEnd(&m_context.module().getBodyRegion().getBlocks().back()); + auto f = builder.create(loc(builder, m_context.filename(), source_location), mangled_name, func_type, mlir::ArrayRef{}, args_attrs); - m_context.builder().setInsertionPointToStart(f.addEntryBlock()); + builder.setInsertionPointToStart(f.addEntryBlock()); std::vector captures; - { [[maybe_unused]] auto function_scope = setup_function(f, function_name, mangled_name); - if (is_async) { f->setAttr("async", m_context.builder().getBoolAttr(true)); } + if (is_async) { f->setAttr("async", builder.getBoolAttr(true)); } - // captures.reserve(m_variable_visibility.at(mangled_name)->captures.size()); - for (const auto &el : m_variable_visibility.at(mangled_name)->symbol_map.symbols) { - if (el.visibility == VariablesResolver::Visibility::CELL - && m_variable_visibility.at(mangled_name)->captures.contains(el.name)) { - captures.push_back(el.name); - } - } - - for (const auto &el : m_variable_visibility.at(mangled_name)->symbol_map.symbols) { - if (el.visibility == VariablesResolver::Visibility::FREE - && m_variable_visibility.at(mangled_name)->captures.contains(el.name)) { - captures.push_back(el.name); - } - } + captures = collect_function_captures(mangled_name); - m_context.builder().setInsertionPointToStart(&f.front()); + builder.setInsertionPointToStart(&f.front()); for (const auto &el : body) { el->codegen(this); } - if (m_context.builder().getBlock()->empty() - || !m_context.builder().getBlock()->back().hasTrait()) { - auto none = m_context.builder().create( - m_context.builder().getUnknownLoc(), m_context.builder().getNoneType()); + if (builder.getBlock()->empty() + || !builder.getBlock()->back().hasTrait()) { + auto none = builder.create( + builder.getUnknownLoc(), builder.getNoneType()); return_value(new_value(none), source_location); } } - m_context.builder().setInsertionPointToEnd(last_block); + builder.setInsertionPointToEnd(last_block); std::vector captures_ref; captures_ref.reserve(captures.size()); for (const auto &el : captures) { captures_ref.push_back(el); } - auto fn_obj = new_value(m_context.builder().create( - loc(m_context.builder(), m_context.filename(), source_location), + std::vector defaults_values; + defaults_values.reserve(defaults.size()); + for (auto *v : defaults) { defaults_values.push_back(v->value); } + std::vector kw_defaults_values; + kw_defaults_values.reserve(kw_defaults.size()); + for (auto *v : kw_defaults) { kw_defaults_values.push_back(v->value); } + auto fn_obj = new_value(builder.create( + loc(builder, m_context.filename(), source_location), m_context->pyobject_type(), mangled_name, - defaults, - kw_defaults, - mlir::DenseStringElementsAttr::get( - mlir::VectorType::get({ static_cast(captures.size()) }, - mlir::StringAttr::get(&m_context.ctx()).getType()), - captures_ref))); + defaults_values, + kw_defaults_values, + builder.getStrArrayAttr(captures_ref))); if (is_anon) { return fn_obj; } store_name(function_name, fn_obj, source_location); - if (!decorator_functions.empty()) { - ASSERT(!is_anon); - mlir::Value arg = load_name(function_name, source_location)->value; - for (const auto &decorator_function : decorator_functions | std::ranges::views::reverse) { - arg = m_context.builder().create(decorator_function.getLoc(), - m_context->pyobject_type(), - decorator_function, - mlir::ValueRange{ arg }, - mlir::DenseStringElementsAttr::get( - mlir::VectorType::get({ 0 }, mlir::StringAttr::get(&m_context.ctx()).getType()), - {}), - mlir::ValueRange{}, - false, - false); - } - store_name(function_name, new_value(arg), source_location); - } - + apply_decorators_for_make_function(*this, decorator_functions, function_name, source_location); return nullptr; } @@ -2698,7 +2620,7 @@ ast::Value *MLIRGenerator::visit(const ast::Try *node) for (const auto &el : node->body()) { el->codegen(this); } if (m_context.builder().getBlock()->empty() || !m_context.builder().getBlock()->back().hasTrait()) { - m_context.builder().create( + m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location())); } @@ -2716,7 +2638,7 @@ ast::Value *MLIRGenerator::visit(const ast::Try *node) ASSERT(!handler_region.getBlocks().empty()); m_context.builder().setInsertionPointToStart(&handler_region.front()); - auto handler_op = m_context.builder().create( + auto handler_op = m_context.builder().create( loc(m_context.builder(), m_context.filename(), handler->source_location())); if (handler->type()) { @@ -2741,7 +2663,7 @@ ast::Value *MLIRGenerator::visit(const ast::Try *node) .getBlock() ->back() .hasTrait()) { - m_context.builder().create( + m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location())); } } @@ -2752,7 +2674,7 @@ ast::Value *MLIRGenerator::visit(const ast::Try *node) for (auto el : node->orelse()) { el->codegen(this); } if (m_context.builder().getBlock()->empty() || !m_context.builder().getBlock()->back().hasTrait()) { - m_context.builder().create( + m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location())); } } @@ -2765,7 +2687,7 @@ ast::Value *MLIRGenerator::visit(const ast::Try *node) for (auto el : node->finalbody()) { el->codegen(this); } if (m_context.builder().getBlock()->empty() || !m_context.builder().getBlock()->back().hasTrait()) { - m_context.builder().create( + m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location())); } } @@ -2860,7 +2782,7 @@ ast::Value *MLIRGenerator::visit(const ast::While *node) if (m_context.builder().getBlock()->empty() || !m_context.builder().getBlock()->back().hasTrait()) { - m_context.builder().create( + m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->body().back()->source_location())); } @@ -2869,7 +2791,7 @@ ast::Value *MLIRGenerator::visit(const ast::While *node) for (const auto &el : node->orelse()) { el->codegen(this); } if (m_context.builder().getBlock()->empty() || !m_context.builder().getBlock()->back().hasTrait()) { - m_context.builder().create(loc( + m_context.builder().create(loc( m_context.builder(), m_context.filename(), node->body().back()->source_location())); } } @@ -2933,7 +2855,7 @@ ast::Value *MLIRGenerator::visit(const ast::With *node) for (const auto &el : node->body()) { el->codegen(this); } if (m_context.builder().getBlock()->empty() || !m_context.builder().getBlock()->back().hasTrait()) { - m_context.builder().create( + m_context.builder().create( loc(m_context.builder(), m_context.filename(), node->source_location())); } scope().finally_blocks.pop_back(); diff --git a/src/executable/mlir/Dialect/Python/MLIRGenerator.hpp b/src/executable/mlir/Dialect/Python/MLIRGenerator.hpp index 89c00b4c..9c0e2653 100644 --- a/src/executable/mlir/Dialect/Python/MLIRGenerator.hpp +++ b/src/executable/mlir/Dialect/Python/MLIRGenerator.hpp @@ -38,13 +38,26 @@ class Context ContextImpl *operator->() { return m_impl.get(); } }; -class SSABuilder; - class MLIRGenerator : ast::CodeGenerator { struct MLIRValue; - // std::unique_ptr m_builder; + // Free-function helpers for make_function() defined in MLIRGenerator.cpp. + // Declared friends here so they can take MLIRValue* (a private inner + // struct) and reach the private store_name / load_name / new_value / + // m_context / m_variable_visibility members without exposing mlir:: types + // in this header. + friend std::vector evaluate_expressions_for_make_function(MLIRGenerator &, + const std::vector> &); + friend std::vector evaluate_default_arguments_for_make_function(MLIRGenerator &, + const std::shared_ptr &); + friend std::vector evaluate_keyword_default_arguments_for_make_function( + MLIRGenerator &, + const std::shared_ptr &); + friend void apply_decorators_for_make_function(MLIRGenerator &, + const std::vector &, + const std::string &, + const SourceLocation &); private: struct Scope @@ -146,6 +159,8 @@ class MLIRGenerator : ast::CodeGenerator bool is_async, const SourceLocation &source_location); + std::vector collect_function_captures(const std::string &mangled_name); + MLIRValue *build_comprehension(std::string_view function_name, std::function container_factory, std::function container_update, diff --git a/src/executable/mlir/Target/PythonBytecode/LinearScanRegisterAllocation.hpp b/src/executable/mlir/Target/PythonBytecode/LinearScanRegisterAllocation.hpp index 348ae5d5..a3f3cc01 100644 --- a/src/executable/mlir/Target/PythonBytecode/LinearScanRegisterAllocation.hpp +++ b/src/executable/mlir/Target/PythonBytecode/LinearScanRegisterAllocation.hpp @@ -8,10 +8,12 @@ #include "mlir/IR/Builders.h" #include "utilities.hpp" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Format.h" #include #include +#include #include #include #include @@ -38,7 +40,9 @@ class LinearScanRegisterAllocation size_t idx; }; - // Stack spill location (currently not used - we abort if we run out of registers) + // Stack spill location (reserved for future use; iterative spilling via STORE_FAST/LOAD_FAST + // ensures all final assignments are Reg, so StackLocation never appears in value2mem_map + // after a completed allocation pass). struct StackLocation { size_t idx; @@ -52,22 +56,45 @@ class LinearScanRegisterAllocation using LiveIntervalSet = std::multiset; - // Track registers that are reserved for FOR_ITER iterators - // Maps loop variable value -> iterator register index - ValueMapping foriter_reserved_regs; + static constexpr size_t kRegCount = kNumRegisters; // Live interval analysis results (stored for visualization) std::optional live_interval_analysis; + // Monotonically increasing across passes to generate unique spill slot names + size_t m_spill_slot_count{ 0 }; + + // Set to true by spill_value(); causes analyse() to restart the pass + bool m_spills_emitted{ false }; + /** - * Run register allocation on the function + * Run register allocation on the function. + * + * Uses an iterative strategy: if register pressure forces a spill, spill code + * (STORE_FAST / LOAD_FAST) is inserted into the IR and the entire analysis is + * restarted so that the new ops receive proper live intervals and register assignments. + * This repeats until a complete allocation with no spills is achieved. */ void analyse(mlir::func::FuncOp &func, mlir::OpBuilder builder) + { + m_spill_slot_count = 0; + do { + m_spills_emitted = false; + value2mem_map.clear(); + live_interval_analysis.reset(); + run_single_pass(func, builder); + } while (m_spills_emitted); + } + + private: + /** + * Single pass of liveness analysis + linear scan. Invoked by analyse(). + * Returns without completing allocation if a spill is needed (m_spills_emitted is set). + */ + void run_single_pass(mlir::func::FuncOp &func, mlir::OpBuilder &builder) { auto logger = get_regalloc_logger(); - // Enable debug logging temporarily to diagnose ForIter bug - logger->set_level(spdlog::level::debug); - logger->info("Starting linear scan register allocation"); + logger->info("Starting linear scan register allocation pass"); // Run live interval analysis live_interval_analysis = LiveIntervalAnalysis{}; @@ -88,8 +115,8 @@ class LinearScanRegisterAllocation LiveIntervalSet inactive; LiveIntervalSet handled; - // 32 available registers - std::bitset<32> free; + // kNumRegisters available registers + std::bitset free; free.set(); // Pre-allocate r0 for operations that clobber it @@ -100,21 +127,45 @@ class LinearScanRegisterAllocation const auto &cur = *unhandled.begin(); unhandled = unhandled.subspan(1, unhandled.size() - 1); - logger->trace("Processing interval: {}", to_string(cur.value)); + logger->trace( + "Processing interval: {} [start={}, end={}]", cur.value, cur.start(), cur.end()); + + // Log active intervals before expiring + if (logger->should_log(spdlog::level::trace)) { + logger->trace("Active intervals before expire:"); + for (const auto &interval : active) { + if (auto it = value2mem_map.find(interval.value); it != value2mem_map.end()) { + if (std::holds_alternative(it->second)) { + logger->trace(" {} in r{} [start={}, end={}]", + interval.value, + std::get(it->second).idx, + interval.start(), + interval.end()); + } + } + } + } // Expire old intervals and free their registers - expire_old_intervals(cur, active, inactive, handled, free, logger); + bool was_inactive = expire_old_intervals(cur, active, inactive, handled, free, logger); // Collect available registers for this interval auto available_regs = collect_available_registers( cur, free, inactive, unhandled, *live_interval_analysis); if (available_regs.none()) { - logger->error("No available registers for {}", to_string(cur.value)); - TODO();// Should implement spilling + logger->info("Register pressure: spilling for {}", cur.value); + spill_value(cur, active, func, builder); + return;// Restart from scratch with updated IR } else { - allocate_register( - cur, free, available_regs, builder, *live_interval_analysis, active, logger); + allocate_register(cur, + free, + available_regs, + builder, + *live_interval_analysis, + active, + was_inactive, + logger); } } @@ -131,10 +182,198 @@ class LinearScanRegisterAllocation } } - private: /** - * Pre-allocate register 0 for operations that must use it - * (function calls, yield, etc.) + * Returns true if the interval's value is a GET_ITER result. + * GET_ITER intervals are deliberately collapsed to contiguous spans and must never be spilled. + */ + static bool is_get_iter_result(const LiveIntervalAnalysis::LiveInterval &interval) + { + if (!std::holds_alternative(interval.value)) { return false; } + auto val = std::get(interval.value); + return val.getDefiningOp() && mlir::isa(val.getDefiningOp()); + } + + /** + * Returns true if the interval's value is a regular (non-block-argument) op result that + * can be spilled using the standard STORE_FAST-after-def / LOAD_FAST-before-use pattern. + */ + static bool is_regular_op_result(const LiveIntervalAnalysis::LiveInterval &interval) + { + if (!std::holds_alternative(interval.value)) { return false; } + auto val = std::get(interval.value); + return !mlir::isa(val); + } + + /** + * Spill a regular op-result value: insert STORE_FAST after its defining op and + * LOAD_FAST before each existing use, replacing those uses with the reload. + */ + void do_spill_op_result(mlir::Value victim_value, + mlir::StringAttr name_attr, + mlir::OpBuilder &builder) + { + // Collect existing uses before inserting the STORE_FAST (which adds a new use) + llvm::SmallVector uses; + for (auto &use : victim_value.getUses()) { uses.push_back(&use); } + + // After the defining op: STORE_FAST to save the spilled value + auto *def_op = victim_value.getDefiningOp(); + ASSERT(def_op); + builder.setInsertionPointAfter(def_op); + builder.create( + def_op->getLoc(), name_attr, victim_value); + + // Before each original use: LOAD_FAST to reload the spilled value + for (auto *use : uses) { + auto *user_op = use->getOwner(); + builder.setInsertionPoint(user_op); + auto reload = builder.create( + user_op->getLoc(), victim_value.getType(), name_attr); + use->set(reload.getOutput()); + } + } + + /** + * Spill a block argument: insert STORE_FAST as the first op of its block and + * LOAD_FAST before each use, replacing those uses with the reload. + * + * Used for both regular block arguments and ForwardedOutput loop variables + * (whose corresponding value is the body block's first argument). + */ + void do_spill_block_argument(mlir::BlockArgument arg, + mlir::StringAttr name_attr, + mlir::OpBuilder &builder) + { + auto *bb = arg.getOwner(); + ASSERT(!bb->empty()); + auto loc = bb->front().getLoc(); + + // Collect existing uses before inserting STORE_FAST + llvm::SmallVector uses; + for (auto &use : arg.getUses()) { uses.push_back(&use); } + + // Insert STORE_FAST at the very start of the block + builder.setInsertionPoint(bb, bb->begin()); + builder.create(loc, name_attr, arg); + + // Before each original use: LOAD_FAST to reload the value + for (auto *use : uses) { + auto *user_op = use->getOwner(); + builder.setInsertionPoint(user_op); + auto reload = builder.create( + user_op->getLoc(), arg.getType(), name_attr); + use->set(reload.getOutput()); + } + } + + /** + * Spill a value to a named local variable slot to free up a register. + * + * Victim selection (in priority order): + * 1. The active non-block-arg op-result with the latest end point (cheapest to spill). + * 2. Active block argument or ForwardedOutput (spilled at block entry). + * 3. cur itself if the above don't apply. + * + * GET_ITER results are never eligible: their intervals are collapsed to contiguous spans + * and must stay alive throughout the loop. + * + * After inserting spill code, sets m_spills_emitted = true so analyse() restarts. + */ + void spill_value(const LiveIntervalAnalysis::LiveInterval &cur, + LiveIntervalSet &active, + mlir::func::FuncOp &func, + mlir::OpBuilder &builder) + { + auto logger = get_regalloc_logger(); + + // --- Victim selection --- + // Pass 1: prefer regular op-results (cheapest path) + const LiveIntervalAnalysis::LiveInterval *victim_interval = nullptr; + for (auto it = active.rbegin(); it != active.rend(); ++it) { + if (!is_regular_op_result(*it)) { continue; } + if (is_get_iter_result(*it)) { continue; } + victim_interval = &(*it); + break; + } + + // Pass 2: fall back to block arguments / ForwardedOutputs + if (!victim_interval) { + for (auto it = active.rbegin(); it != active.rend(); ++it) { + if (is_get_iter_result(*it)) { continue; } + victim_interval = &(*it); + break; + } + } + + // --- Choose: spill active victim or cur --- + const LiveIntervalAnalysis::LiveInterval *to_spill = nullptr; + if (victim_interval && victim_interval->end() > cur.end()) { + to_spill = victim_interval; + logger->info("Spilling active {} (end={}) to free register for {} (end={})", + victim_interval->value, + victim_interval->end(), + cur.value, + cur.end()); + } else if (!is_get_iter_result(cur)) { + to_spill = &cur; + logger->info("Spilling current {} (end={})", cur.value, cur.end()); + } else { + logger->error( + "Cannot spill: all live intervals are GET_ITER results — function requires " + "more than {} registers.", + kRegCount); + TODO(); + return; + } + + // --- Allocate spill slot --- + const std::string spill_name = "__spill_" + std::to_string(m_spill_slot_count++); + logger->info("Spilling {} to slot '{}'", to_spill->value, spill_name); + + auto *ctx = func->getContext(); + auto name_attr = mlir::StringAttr::get(ctx, spill_name); + { + auto existing = func->getAttr("locals"); + llvm::SmallVector locals; + if (existing) { + auto arr = mlir::cast(existing); + locals.assign(arr.begin(), arr.end()); + } + locals.push_back(name_attr); + func->setAttr("locals", mlir::ArrayAttr::get(ctx, locals)); + } + + // --- Dispatch spill by value type --- + if (std::holds_alternative(to_spill->value)) { + auto val = std::get(to_spill->value); + if (mlir::isa(val)) { + // Block argument: spill at block entry + do_spill_block_argument(mlir::cast(val), name_attr, builder); + } else { + // Regular op result: spill after defining op + do_spill_op_result(val, name_attr, builder); + } + } else { + // ForwardedOutput: the loop variable from FOR_ITER lives as the body block's arg + ASSERT(std::holds_alternative(to_spill->value)); + auto [op_ptr, idx] = std::get(to_spill->value); + auto for_iter = mlir::cast(op_ptr); + auto body_arg = for_iter.getBody()->getArgument(idx); + do_spill_block_argument(body_arg, name_attr, builder); + } + + m_spills_emitted = true; + } + + /** + * Pre-allocate r0 for operations that directly clobber it (CALL, YIELD, etc.) + * + * This ensures that values produced by these operations are assigned to r0, + * matching the VM's behavior where these operations place results directly in r0. + * + * Note: Block arguments are NOT pre-allocated here. If a block argument receives + * values from r0-clobbering operations in different registers, MOVE instructions + * will be inserted at block boundaries during bytecode emission. */ void preallocate_r0_clobbering_operations( std::span unhandled, @@ -144,50 +383,36 @@ class LinearScanRegisterAllocation auto logger = get_regalloc_logger(); for (const auto &interval : unhandled) { - bool needs_r0 = false; - - // Check if this value directly clobbers r0 - if (std::holds_alternative(interval.value)) { - auto value = std::get(interval.value); - if (clobbers_r0(value)) { - needs_r0 = true; - logger->debug("Value {} clobbers r0", to_string(value)); - } - } + // Only pre-allocate values that directly clobber r0 + if (!std::holds_alternative(interval.value)) { continue; } - // Check if this value flows from something that clobbers r0 - if (!needs_r0) { - if (auto it = live_interval_analysis.block_input_mappings.find(interval.value); - it != live_interval_analysis.block_input_mappings.end()) { - for (auto mapped_value : it->second) { - if (std::holds_alternative(mapped_value)) { continue; } - if (clobbers_r0(std::get(mapped_value))) { - needs_r0 = true; - logger->debug("Value {} flows from r0-clobbering value", - to_string(interval.value)); - break; - } - } - } - } + auto value = std::get(interval.value); + + // Skip block arguments - they don't clobber r0 themselves + if (mlir::isa(value)) { continue; } - if (needs_r0) { + if (clobbers_r0(value)) { + // Pre-allocate r0 but DON'T add to inactive - let it be processed normally + // in the main loop to handle conflicts and spilling if needed value2mem_map.insert_or_assign(interval.value, Reg{ .idx = 0 }); - inactive.insert(interval); + logger->debug("Pre-allocated r0 for: {}", interval.value); } } } /** * Expire intervals that are no longer alive and free their registers + * Returns true if cur was found in inactive and moved to active */ - void expire_old_intervals(const LiveIntervalAnalysis::LiveInterval &cur, + bool expire_old_intervals(const LiveIntervalAnalysis::LiveInterval &cur, LiveIntervalSet &active, LiveIntervalSet &inactive, LiveIntervalSet &handled, - std::bitset<32> &free, + std::bitset &free, std::shared_ptr &logger) { + bool cur_was_inactive = false; + // Expire active intervals for (auto it = active.begin(); it != active.end();) { const auto &interval = *it; @@ -199,7 +424,9 @@ class LinearScanRegisterAllocation it = active.erase(it); free_register(interval, free, logger); } else if (!interval.alive_at(cur.start())) { - // Interval temporarily not alive (goes inactive) + // Interval temporarily not alive (goes inactive). + // GET_ITER intervals are guaranteed contiguous by extend_iterator_liveness(), + // so they will never take this branch during their loop span. inactive.insert(interval); it = active.erase(it); free_register(interval, free, logger); @@ -216,6 +443,8 @@ class LinearScanRegisterAllocation // Current interval was previously allocated (e.g., r0 clobbering) active.insert(interval); it = inactive.erase(it); + cur_was_inactive = true; + logger->debug("Moved cur from inactive to active: {}", cur.value); } else if (interval.end() < cur.start()) { // Interval completely expired handled.insert(interval); @@ -225,21 +454,26 @@ class LinearScanRegisterAllocation active.insert(interval); it = inactive.erase(it); mark_register_used(interval, free, logger); + logger->debug("Reactivated interval: {}", interval.value); } else { ++it; } } + + return cur_was_inactive; } /** * Collect available registers, accounting for special constraints */ - std::bitset<32> collect_available_registers(const LiveIntervalAnalysis::LiveInterval &cur, - const std::bitset<32> &free, + std::bitset collect_available_registers( + const LiveIntervalAnalysis::LiveInterval &cur, + const std::bitset &free, LiveIntervalSet &inactive, std::span unhandled, const LiveIntervalAnalysis &live_interval_analysis) { + auto logger = get_regalloc_logger(); auto available = free; // Exclude registers used by overlapping inactive intervals @@ -271,73 +505,31 @@ class LinearScanRegisterAllocation /** * Apply special constraints for specific operations: - * - GetIter cannot use r0 (reserved for function call results) - * - ForIter loop variable cannot use the same register as its iterator + * - GetIter cannot use r0 (reserved for function call results by VM convention) + * - BuildList cannot use r0 (ListExtend internally calls the iterator protocol + * which executes Python bytecode and triggers pop_frame(true), propagating + * the callee's r0 into the caller's r0, overwriting the list) */ void apply_special_constraints(const LiveIntervalAnalysis::LiveInterval &cur, - std::bitset<32> &available, - const LiveIntervalAnalysis &live_interval_analysis) + std::bitset &available, + const LiveIntervalAnalysis & /*live_interval_analysis*/) { auto logger = get_regalloc_logger(); - // GetIter: cannot use r0 if (std::holds_alternative(cur.value)) { auto value = std::get(cur.value); + // GetIter: cannot use r0 (r0 is reserved for function call results) if (value.getDefiningOp() && mlir::isa(value.getDefiningOp())) { available.set(0, false); logger->debug("GetIter result cannot use r0"); } - } - - // FIX FOR FORITER BUG: - // ForIter loop variable (ForwardedOutput) cannot use same register as iterator - if (std::holds_alternative(cur.value)) { - auto forwarded = std::get(cur.value); - if (auto for_iter = mlir::dyn_cast(forwarded.first)) { - // Get the iterator value - auto iterator = for_iter.getIterator(); - - logger->debug("Processing ForIter loop variable, looking for iterator register"); - - // Find what register the iterator is assigned to - // Need to check both as mlir::Value and potentially through block argument mappings - std::optional iterator_reg; - - if (auto it = value2mem_map.find(iterator); it != value2mem_map.end()) { - if (std::holds_alternative(it->second)) { - iterator_reg = std::get(it->second).idx; - logger->debug("Found iterator in r{}", *iterator_reg); - } - } - - // Also check block argument mappings - if (!iterator_reg.has_value()) { - if (auto it = live_interval_analysis.block_input_mappings.find(iterator); - it != live_interval_analysis.block_input_mappings.end()) { - for (const auto &mapped : it->second) { - if (auto reg_it = value2mem_map.find(mapped); - reg_it != value2mem_map.end()) { - if (std::holds_alternative(reg_it->second)) { - iterator_reg = std::get(reg_it->second).idx; - logger->debug( - "Found iterator via block mapping in r{}", *iterator_reg); - break; - } - } - } - } - } - - if (iterator_reg.has_value()) { - available.set(*iterator_reg, false); - logger->info( - "ForIter loop variable CANNOT use r{} (iterator register)", *iterator_reg); - } else { - logger->error("ForIter iterator register not found - BUG NOT FIXED!"); - // This is a critical error - the iterator must be allocated before the loop - // variable - } + // BuildList: cannot use r0 (ListExtend iterates Python iterators which + // trigger pop_frame(true) and overwrite r0 with the iterator's return value) + if (value.getDefiningOp() + && mlir::isa(value.getDefiningOp())) { + available.set(0, false); + logger->debug("BuildList result cannot use r0"); } } } @@ -346,11 +538,12 @@ class LinearScanRegisterAllocation * Allocate a register for the current interval */ void allocate_register(const LiveIntervalAnalysis::LiveInterval &cur, - std::bitset<32> &free, - const std::bitset<32> &available, + std::bitset &free, + const std::bitset &available, mlir::OpBuilder &builder, const LiveIntervalAnalysis &live_interval_analysis, LiveIntervalSet &active, + bool was_inactive, std::shared_ptr &logger) { std::optional cur_reg; @@ -366,7 +559,7 @@ class LinearScanRegisterAllocation if (available.test(i)) { cur_reg = i; value2mem_map.insert_or_assign(cur.value, Reg{ .idx = i }); - logger->debug("Allocated r{} to {}", i, to_string(cur.value)); + logger->debug("Allocated r{} to {}", i, cur.value); break; } } @@ -376,68 +569,31 @@ class LinearScanRegisterAllocation // Handle case where the chosen register is not free (need to save/restore) if (!free.test(*cur_reg)) { + logger->info("Register conflict: r{} is not free, handling conflict for {}", + *cur_reg, + cur.value); handle_register_conflict( - cur, *cur_reg, available, builder, live_interval_analysis, logger); + cur, *cur_reg, available, free, builder, live_interval_analysis, logger); } else { + logger->debug("Marking r{} as not free for {} [{}..{})", + *cur_reg, + cur.value, + cur.start(), + cur.end()); free.set(*cur_reg, false); } - active.insert(cur); - - // CRITICAL FIX FOR FORITER BUG: - // When allocating a block argument that is a FOR_ITER loop variable, reserve the - // iterator register for the duration of the loop to prevent it from being reused - if (std::holds_alternative(cur.value)) { - auto value = std::get(cur.value); - - // Check if this is a block argument - if (mlir::isa(value)) { - logger->debug("Allocated block argument: {}", to_string(cur.value)); - - // Check if this block argument comes from a FOR_ITER - if (auto it = live_interval_analysis.block_input_mappings.find(cur.value); - it != live_interval_analysis.block_input_mappings.end()) { - - for (const auto &input : it->second) { - if (std::holds_alternative(input)) { - auto forwarded = std::get(input); - - if (auto for_iter = mlir::dyn_cast( - forwarded.first)) { - logger->debug("Block argument is FOR_ITER loop variable"); - auto iterator = for_iter.getIterator(); - logger->debug("Iterator value: {}", to_string(iterator)); - - // Find the iterator's register - if (auto iter_it = value2mem_map.find(iterator); - iter_it != value2mem_map.end()) { - if (std::holds_alternative(iter_it->second)) { - auto iterator_reg = std::get(iter_it->second).idx; - logger->debug("Found iterator register: r{}", iterator_reg); - - // Reserve this register for the duration of the loop - // variable's lifetime - foriter_reserved_regs[cur.value] = iterator_reg; - - // Mark the iterator register as busy - free.set(iterator_reg, false); - - logger->info( - "FOR_ITER FIX: Reserved r{} (iterator) for loop " - "variable {}", - iterator_reg, - to_string(cur.value)); - } else { - logger->warn("Iterator register is not a Reg!"); - } - } else { - logger->error("Iterator not found in value2mem_map!"); - } - } - } - } - } - } + // Only insert into active if it wasn't already moved from inactive + if (!was_inactive) { + active.insert(cur); + logger->debug( + "Added to active: {} in r{} [{}..{})", cur.value, *cur_reg, cur.start(), cur.end()); + } else { + logger->debug("Already in active (was inactive): {} in r{} [{}..{})", + cur.value, + *cur_reg, + cur.start(), + cur.end()); } } @@ -447,7 +603,8 @@ class LinearScanRegisterAllocation */ void handle_register_conflict(const LiveIntervalAnalysis::LiveInterval &cur, size_t cur_reg, - const std::bitset<32> &available, + const std::bitset &available, + std::bitset &free, mlir::OpBuilder &builder, const LiveIntervalAnalysis &live_interval_analysis, std::shared_ptr &logger) @@ -492,8 +649,40 @@ class LinearScanRegisterAllocation value2mem_map.insert_or_assign(cur.value, Reg{ .idx = *scratch_reg }); - logger->debug("Register conflict: moved {} from r{} to r{} (scratch)", - to_string(current_value), + // BUG FIX: Mark the scratch register as not free + free.set(*scratch_reg, false); + + logger->info( + "Register conflict: moved {} from r{} to r{} (scratch), marked as not free", + current_value, + cur_reg, + *scratch_reg); + } else { + // ForwardedOutput: the defining op is the FOR_ITER terminator. + // The loop variable becomes available in the body block; insert PUSH before + // FOR_ITER and MOVE/POP at the start of the body block. + ASSERT(std::holds_alternative(cur.value)); + auto [op_ptr, idx] = std::get(cur.value); + auto *for_iter_op = op_ptr; + auto loc = for_iter_op->getLoc(); + + auto for_iter = mlir::cast(for_iter_op); + auto *body_block = for_iter.getBody(); + ASSERT(!body_block->empty()); + + // Save cur_reg before FOR_ITER, move loop variable to scratch, restore cur_reg + builder.setInsertionPoint(for_iter_op); + builder.create(loc, cur_reg); + builder.setInsertionPoint(body_block, body_block->begin()); + builder.create(loc, *scratch_reg, cur_reg); + builder.create(loc, cur_reg); + + value2mem_map.insert_or_assign(cur.value, Reg{ .idx = *scratch_reg }); + free.set(*scratch_reg, false); + + logger->info( + "Register conflict (ForwardedOutput): moved FOR_ITER loop var from r{} to r{} " + "(scratch)", cur_reg, *scratch_reg); } @@ -503,30 +692,25 @@ class LinearScanRegisterAllocation * Free a register when an interval expires */ void free_register(const LiveIntervalAnalysis::LiveInterval &interval, - std::bitset<32> &free, + std::bitset &free, std::shared_ptr &logger) { const auto reg = value2mem_map.at(interval.value); ASSERT(std::holds_alternative(reg)); size_t reg_idx = std::get(reg).idx; free.set(reg_idx, true); - logger->trace("Freed r{} from {}", reg_idx, to_string(interval.value)); - - // If this was a FOR_ITER loop variable with a reserved iterator register, free it too - if (auto it = foriter_reserved_regs.find(interval.value); - it != foriter_reserved_regs.end()) { - auto reserved_reg = it->second; - free.set(reserved_reg, true); - logger->info("FOR_ITER FIX: Freed reserved iterator register r{}", reserved_reg); - foriter_reserved_regs.erase(it); - } + logger->debug("Freed r{} from {} [{}..{})", + reg_idx, + interval.value, + interval.start(), + interval.end()); } /** * Mark a register as used when an interval becomes active */ void mark_register_used(const LiveIntervalAnalysis::LiveInterval &interval, - std::bitset<32> &free, + std::bitset &free, std::shared_ptr &logger) { const auto reg = value2mem_map.at(interval.value); @@ -534,16 +718,7 @@ class LinearScanRegisterAllocation size_t reg_idx = std::get(reg).idx; ASSERT(free.test(reg_idx)); free.set(reg_idx, false); - logger->trace("Marked r{} as used for {}", reg_idx, to_string(interval.value)); - - // If this is a FOR_ITER loop variable, also mark the iterator register as used - if (auto it = foriter_reserved_regs.find(interval.value); - it != foriter_reserved_regs.end()) { - auto reserved_reg = it->second; - free.set(reserved_reg, false); - logger->trace( - "FOR_ITER FIX: Marked reserved iterator register r{} as used", reserved_reg); - } + logger->trace("Marked r{} as used for {}", reg_idx, interval.value); } /** @@ -572,7 +747,7 @@ class LinearScanRegisterAllocation logger->debug("Final register assignments:"); for (const auto &[value, location] : value2mem_map) { if (std::holds_alternative(location)) { - logger->debug(" {} -> r{}", to_string(value), std::get(location).idx); + logger->debug(" {} -> r{}", value, std::get(location).idx); } } } @@ -616,7 +791,7 @@ class LinearScanRegisterAllocation // Print each value's liveness for (const auto &interval : live_interval_analysis->sorted_live_intervals) { // Print value name (truncate to 55 chars) - std::string value_str = to_string(interval.value); + std::string value_str = fmt::format("{}", interval.value); if (value_str.length() > 55) { value_str = value_str.substr(0, 52) + "..."; } llvm::outs() << llvm::format("%-55s", value_str.c_str()) << " | "; @@ -666,7 +841,7 @@ class LinearScanRegisterAllocation // Print each value's register assignment for (const auto &interval : live_interval_analysis->sorted_live_intervals) { // Print value name (truncate to 55 chars) - std::string value_str = to_string(interval.value); + std::string value_str = fmt::format("{}", interval.value); if (value_str.length() > 55) { value_str = value_str.substr(0, 52) + "..."; } llvm::outs() << llvm::format("%-55s", value_str.c_str()) << " | "; diff --git a/src/executable/mlir/Target/PythonBytecode/LiveAnalysis.hpp b/src/executable/mlir/Target/PythonBytecode/LiveAnalysis.hpp index 6a639fb9..715911e3 100644 --- a/src/executable/mlir/Target/PythonBytecode/LiveAnalysis.hpp +++ b/src/executable/mlir/Target/PythonBytecode/LiveAnalysis.hpp @@ -3,9 +3,11 @@ #include "Dialect/EmitPythonBytecode/IR/EmitPythonBytecode.hpp" #include "RegisterAllocationLogger.hpp" #include "RegisterAllocationTypes.hpp" +#include "mlir/Analysis/Liveness.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AsmState.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" #include "utilities.hpp" @@ -20,15 +22,25 @@ namespace codegen { /** - * LiveAnalysis performs backward dataflow analysis to determine which values are alive - * at each point in the program. This is the first step in register allocation. + * LiveAnalysis determines which values are alive at each point in the program. + * This is the first step in register allocation. * - * Uses the standard liveness algorithm: - * LiveOut[B] = union of LiveIn[S] for all successors S of B - * LiveIn[B] = Use[B] ∪ (LiveOut[B] - Def[B]) - * Iterate until fixed point + * The block-level fixed-point dataflow is delegated to mlir::Liveness — it + * already implements the standard LiveOut[B] = ∪ LiveIn[S] for S in succ(B), + * LiveIn[B] = Use[B] ∪ (LiveOut[B] - Def[B]) algorithm and handles loops / + * back-edges. What's project-specific stays here: * - * This correctly handles loops and back edges. + * 1. ForwardedOutput for FOR_ITER. The body block's first argument is the + * loop value produced by FOR_ITER's terminator, which MLIR can't model + * natively. We track it as a synthetic Value-like entity. + * + * 2. alive_at_timestep — a linearised per-operation list of "what's alive + * here". mlir::Liveness only exposes per-block / per-value queries; the + * register allocator wants the linearised view, so we materialise it on + * top of mlir::Liveness's block-level results. + * + * 3. block_input_mappings — for each value, the set of block arguments it + * could flow into via a CFG edge. Computed from the terminator operands. */ class LiveAnalysis { @@ -52,18 +64,20 @@ class LiveAnalysis block_input_mappings; /** - * Analyze the function to determine liveness information using backward dataflow + * Analyze the function to determine liveness information. */ void analyse(mlir::func::FuncOp &fn) { auto logger = get_regalloc_logger(); - logger->info( - "Starting backward dataflow live analysis for function: {}", fn.getName().str()); + logger->info("Starting live analysis for function: {}", fn.getName().str()); auto ®ion = fn.getRegion(); auto sorted_blocks = sortBlocks(region); - // Build block information and Use/Def sets + // Collect block operations and the project-specific bits MLIR doesn't + // model natively (ForwardedOutputs from FOR_ITER, value→block-arg + // edge mapping). Use/def computation is no longer needed here — + // mlir::Liveness owns the dataflow. std::map block_info; std::vector, mlir::BlockArgument>> block_parameters_to_args; @@ -72,12 +86,13 @@ class LiveAnalysis build_block_info(block, block_info[block], block_parameters_to_args, logger); } - // Run backward dataflow to compute LiveIn/LiveOut - compute_liveness(sorted_blocks, block_info, logger); + // Block-level live-in / live-out via the upstream analysis. + mlir::Liveness liveness(fn); - // Build alive_at_timestep from LiveIn/LiveOut + // Materialise alive_at_timestep on top of mlir::Liveness's block + // results, injecting ForwardedOutputs into the live sets as needed. std::map> blocks_span; - build_timesteps(sorted_blocks, block_info, blocks_span); + build_timesteps(sorted_blocks, block_info, liveness, blocks_span); // Propagate block argument inputs through the liveness information propagate_block_arguments(block_parameters_to_args, blocks_span); @@ -90,22 +105,13 @@ class LiveAnalysis private: /** - * Information about a single block for dataflow analysis + * Per-block ancillary state: the operation list (in topological order, + * needed because the linearised timesteps must match the order ops are + * walked by the bytecode emitter) plus the ForwardedOutputs created by + * this block's terminator (FOR_ITER's synthetic loop-var value). */ struct BlockInfo { - // Values used in this block (before being defined) - ValueSet use; - - // Values defined in this block - ValueSet def; - - // Values live at entry to this block (computed by dataflow) - ValueSet live_in; - - // Values live at exit from this block (computed by dataflow) - ValueSet live_out; - // Operations in this block (in order) std::vector operations; @@ -114,7 +120,8 @@ class LiveAnalysis }; /** - * Build Use/Def sets for a block + * Collect operations and project-specific terminator metadata for a block. + * Block-level live-in/out is computed separately by mlir::Liveness. */ void build_block_info(mlir::Block *block, BlockInfo &info, @@ -128,26 +135,12 @@ class LiveAnalysis std::abort(); } - // Build Use/Def sets - // For each operation, add operands to Use (if not already in Def), and add results to Def - for (auto &op : block->getOperations()) { - info.operations.push_back(&op); - - // Add operands to Use (if not already defined) - for (const auto &operand : op.getOperands()) { - if (!info.def.contains(operand)) { info.use.insert(operand); } - } - - // Add results to Def - for (const auto &result : op.getResults()) { info.def.insert(result); } - } + for (auto &op : block->getOperations()) { info.operations.push_back(&op); } // Handle terminators specially if (auto *terminator = block->getTerminator()) { handle_terminator(terminator, info, block_parameters_to_args, logger); } - - logger->debug("Block has {} uses, {} defs", info.use.size(), info.def.size()); } /** @@ -216,70 +209,7 @@ class LiveAnalysis } /** - * Compute LiveIn/LiveOut using backward dataflow iteration - */ - void compute_liveness(const std::vector &sorted_blocks, - std::map &block_info, - std::shared_ptr &logger) - { - logger->info("Running backward dataflow iteration to compute liveness"); - - // Initialize all LiveIn and LiveOut to empty (already done by default) - - // Iterate until fixed point - bool changed = true; - int iteration = 0; - - while (changed) { - changed = false; - iteration++; - - logger->debug("Dataflow iteration {}", iteration); - - // Process blocks in reverse post-order for better convergence - for (auto it = sorted_blocks.rbegin(); it != sorted_blocks.rend(); ++it) { - auto *block = *it; - auto &info = block_info[block]; - - // Save old LiveIn for convergence check - auto old_live_in = info.live_in; - - // LiveOut[B] = union of LiveIn[S] for all successors S - info.live_out.clear(); - for (auto *successor : block->getSuccessors()) { - const auto &succ_info = block_info[successor]; - info.live_out.insert(succ_info.live_in.begin(), succ_info.live_in.end()); - } - - // LiveIn[B] = Use[B] ∪ (LiveOut[B] - Def[B]) - info.live_in = info.use; - for (const auto &val : info.live_out) { - if (!info.def.contains(val)) { info.live_in.insert(val); } - } - - // Add ForwardedOutputs to LiveIn (they're "defined" by the terminator but need - // to be live for the successor) - for (const auto &fwd : info.forwarded_outputs) { info.live_in.insert(fwd); } - - // Check if LiveIn changed - if (info.live_in != old_live_in) { changed = true; } - } - } - - logger->info("Dataflow converged after {} iterations", iteration); - - // Debug: print LiveIn/LiveOut for each block - for (auto *block : sorted_blocks) { - const auto &info = block_info[block]; - logger->debug("Block {} LiveIn: {} values, LiveOut: {} values", - static_cast(block), - info.live_in.size(), - info.live_out.size()); - } - } - - /** - * Build alive_at_timestep from LiveIn/LiveOut + * Build alive_at_timestep from mlir::Liveness's block live-out info. * * Computes precise per-operation liveness by propagating backward within each block. * This ensures values are only marked alive when actually needed, not conservatively @@ -287,6 +217,7 @@ class LiveAnalysis */ void build_timesteps(const std::vector &sorted_blocks, const std::map &block_info, + const mlir::Liveness &liveness, std::map> &blocks_span) { auto logger = get_regalloc_logger(); @@ -300,8 +231,13 @@ class LiveAnalysis std::vector alive_before_op; alive_before_op.resize(info.operations.size()); - // Start from LiveOut (values alive at block exit) and work backward - ValueSet alive_after = info.live_out; + // Start from LiveOut (values alive at block exit) and work backward. + // mlir::Liveness returns a SmallPtrSet; the project's + // ValueSet is a variant set, so we + // convert. ForwardedOutputs aren't part of mlir::Liveness's view + // and are added separately below. + ValueSet alive_after; + for (mlir::Value v : liveness.getLiveOut(block)) { alive_after.insert(v); } for (size_t i = info.operations.size(); i-- > 0;) { auto *op = info.operations[i]; @@ -319,51 +255,53 @@ class LiveAnalysis alive_after = alive_before_op[i]; } - // For operations with side effects, ensure their results are kept alive - // even if they're not used, so they get proper register assignments. - // This is needed for operations like LoadAttribute that may raise exceptions - // or trigger descriptors, and must be emitted even if results are unused. + // For operations with side effects or that produce values in specific registers, + // ensure their results are kept alive so they get proper register assignments. + // This is needed for: + // - Operations that raise exceptions (LoadAttribute, etc.) + // - Operations that clobber r0 (CALL, YIELD, etc.) - their results MUST be tracked for (size_t i = 0; i < info.operations.size(); i++) { auto *op = info.operations[i]; - // Check if operation is pure (has no side effects) - // Non-pure operations must be emitted even if results are unused - if (!mlir::isPure(op)) { - // Re-add any results to ensure they get register assignments + bool needs_tracking = !mlir::isPure(op); + + // CRITICAL: Always track CALL operations - their results go to r0 and MUST have + // a live interval even if the result is unused + if (!needs_tracking) { + if (llvm::isa(op) + || llvm::isa(op) + || llvm::isa(op)) { + needs_tracking = true; + } + } + + if (needs_tracking) { + // Add results to the current operation's alive_before to ensure they appear + // in the timestep where the operation executes. This is semantically odd + // (the value doesn't exist before the operation), but it ensures the result + // gets a register allocation at the point where it's produced. for (auto result : op->getResults()) { alive_before_op[i].insert(result); } } } - // Add ForwardedOutputs to the first operation if they're in LiveIn - // (they're "defined" by the terminator but need to be live for the successor) + // Add ForwardedOutputs to the first operation's alive set. + // These are "defined" by the terminator but need to be live + // throughout the block so the successor (the FOR_ITER body) + // can pick the loop variable up via the same register. The + // previous custom dataflow unconditionally injected these + // into LiveIn before this point; mlir::Liveness doesn't know + // about them, so they're injected here directly. if (!info.operations.empty()) { - for (const auto &fwd : info.forwarded_outputs) { - if (info.live_in.contains(fwd)) { alive_before_op[0].insert(fwd); } - } + for (const auto &fwd : info.forwarded_outputs) { alive_before_op[0].insert(fwd); } } - // Sanity check: alive_before[0] should equal LiveIn - // (We computed it backward from LiveOut, should match forward computation) - if (!alive_before_op.empty() && alive_before_op[0] != info.live_in) { - logger->warn( - "Block {} liveness mismatch: alive_before[0] has {} values, LiveIn has {} " - "values", - static_cast(block), - alive_before_op[0].size(), - info.live_in.size()); - - // Debug: show the difference - logger->debug(" Values in LiveIn but not alive_before[0]:"); - for (const auto &val : info.live_in) { - if (!alive_before_op[0].contains(val)) { - logger->debug(" {}", to_string(val)); - } - } - logger->debug(" Values in alive_before[0] but not LiveIn:"); - for (const auto &val : alive_before_op[0]) { - if (!info.live_in.contains(val)) { logger->debug(" {}", to_string(val)); } - } - } + // Note: alive_before_op[0] may differ from mlir::Liveness's + // LiveIn(block) because the needs_tracking pass above adds + // impure operation results to the timestep of their defining + // op. These results are defined within this block so they + // cannot be in LiveIn. The discrepancy is intentional — it + // ensures impure ops get a register assignment at their + // definition site. // Now build timesteps in forward order using the computed liveness for (size_t i = 0; i < info.operations.size(); i++) { @@ -384,7 +322,7 @@ class LiveAnalysis start, end, info.operations.size(), - info.live_in.size()); + liveness.getLiveIn(block).size()); if (block->getTerminator() && mlir::isa(block->getTerminator())) { logger->debug(" ^ FOR_ITER block"); @@ -393,28 +331,51 @@ class LiveAnalysis } /** - * Propagate block argument values through liveness information + * Propagate block argument values through liveness information. + * + * This function transforms block arguments in the alive_at_timestep data into + * BlockArgumentInputs structures that track all the source values flowing into + * each block argument (PHI nodes in SSA form). + * + * For example, if: + * bb1: br ^bb3(%val1) + * bb2: br ^bb3(%val2) + * bb3(%arg): + * + * Then %arg will be transformed into BlockArgumentInputs{%arg, [%val1, %val2]} + * indicating that %arg can receive values from either %val1 or %val2 depending + * on which predecessor block was executed. + * + * This information is crucial for register allocation to ensure that: + * 1. All source values are allocated to compatible registers + * 2. The block argument is allocated to the same register as its sources + * 3. MOVE instructions are inserted if sources end up in different registers */ void propagate_block_arguments( const std::vector, mlir::BlockArgument>> &block_parameters_to_args, const std::map> &blocks_span) { + // For each (source_value, block_argument) pair collected during block analysis for (const auto &[param, arg] : block_parameters_to_args) { auto *bb = arg.getOwner(); const auto [start, end] = blocks_span.at(bb); auto block_timesteps = std::span{ alive_at_timestep.begin() + start, alive_at_timestep.begin() + end }; + // Replace all occurrences of the block argument with BlockArgumentInputs for (auto &ts : block_timesteps) { for (auto &val : ts) { + // Check if this is the block argument we're looking for if (std::holds_alternative(val) && mlir::isa(std::get(val)) && mlir::cast(std::get(val)) == arg) { + // First occurrence: create BlockArgumentInputs with this source val = BlockArgumentInputs{ arg, { param } }; block_input_mappings[param].insert(arg); } else if (std::holds_alternative(val) && std::get<0>(std::get(val)) == arg) { + // Subsequent occurrence: append this source to existing BlockArgumentInputs std::get<1>(std::get(val)).push_back(param); block_input_mappings[param].insert(arg); } @@ -424,26 +385,47 @@ class LiveAnalysis } /** - * Resolve transitive chains of block arguments + * Resolve transitive chains of block arguments. + * + * This function handles cases where a block argument receives values from other + * block arguments (transitive PHI nodes). For example: + * + * bb1: br ^bb2(%val1) + * bb2(%arg2): br ^bb3(%arg2) + * bb3(%arg3): + * + * Here %arg3 receives %arg2, and %arg2 receives %val1. We need to resolve this + * chain so that %arg3 is understood to ultimately receive %val1. + * + * The function works backwards through timesteps, following chains of block + * arguments until reaching concrete values, and updates the block_input_mappings + * accordingly. */ void resolve_block_argument_chains() { + // Process timesteps in reverse order for (auto &values : alive_at_timestep | std::views::reverse) { for (auto &value : values | std::views::reverse) { + // Convert BlockArgumentInputs back to plain block arguments for processing if (std::holds_alternative(value)) { value = std::get<0>(std::get(value)); } + // Find if this value maps to any block arguments auto start = std::visit(overloaded{ [this](const auto &v) { return block_input_mappings.find(v); }, [this](const BlockArgumentInputs &) { - TODO(); + // BlockArgumentInputs are converted to plain mlir::Value + // (block arguments) at lines 485-487 before this visitor + // runs, so this branch is unreachable. + ASSERT(false); return block_input_mappings.end(); }, }, value); + // Follow the chain of block arguments auto it = start; while (it != block_input_mappings.end()) { ASSERT(it->second.size() == 1); diff --git a/src/executable/mlir/Target/PythonBytecode/LiveIntervalAnalysis.hpp b/src/executable/mlir/Target/PythonBytecode/LiveIntervalAnalysis.hpp index d1b1a50e..921e683c 100644 --- a/src/executable/mlir/Target/PythonBytecode/LiveIntervalAnalysis.hpp +++ b/src/executable/mlir/Target/PythonBytecode/LiveIntervalAnalysis.hpp @@ -5,6 +5,7 @@ #include "RegisterAllocationTypes.hpp" #include +#include #include #include @@ -41,35 +42,31 @@ class LiveIntervalAnalysis size_t end() const { return std::get<1>(intervals.back()); } /** - * Check if this interval is alive at the given position + * Check if this interval is alive at the given position. + * Uses precise sub-interval membership rather than a conservative span check. + * GET_ITER intervals are collapsed to a single contiguous span by + * extend_iterator_liveness() so they remain alive throughout the loop. */ bool alive_at(size_t pos) const { - // FIXME: the commented code is correct, but currently there is no logic - // to populate a register when an interval goes from inactive to active - // (i.e., the register is potentially clobbered) - // return std::find_if(intervals.begin(), - // intervals.end(), - // [pos](const Interval &interval) { - // auto [start, end] = interval; - // return pos >= start && pos < end; - // }) - // != intervals.end(); - - // Conservative approximation: check only the full span - return pos >= start() && pos < end(); + return std::find_if(intervals.begin(), + intervals.end(), + [pos](const Interval &interval) { + auto [start, end] = interval; + return pos >= start && pos < end; + }) + != intervals.end(); } /** - * Check if this interval overlaps with another + * Check if this interval overlaps with another. + * Intervals are half-open [start, end), so [a,b) and [c,d) overlap iff a < d && c < b. */ bool overlaps(const LiveInterval &other) const { - // Naive quadratic search - could be optimized with interval tree for (const auto &[start, end] : intervals) { for (const auto &[other_start, other_end] : other.intervals) { - if (other_start >= start && other_start <= end) { return true; } - if (other_end >= start && other_end <= end) { return true; } + if (other_start < end && start < other_end) { return true; } } } return false; @@ -99,12 +96,15 @@ class LiveIntervalAnalysis for (const auto &el : value) { block_input_mappings[el].push_back(key); } } - // Build live intervals from liveness information + // Build live intervals from liveness information. + // An index map provides O(log n) lookup instead of O(n) linear scan per value. std::vector unsorted_live_intervals; + std::map, size_t, ValueMappingComparator> + interval_index; for (size_t timestep = 0; const auto &alive_values : live_analysis.alive_at_timestep) { for (const auto &alive_value : alive_values) { - update_interval(alive_value, timestep, unsorted_live_intervals); + update_interval(alive_value, timestep, unsorted_live_intervals, interval_index); } timestep++; } @@ -118,28 +118,17 @@ class LiveIntervalAnalysis sorted_live_intervals = std::move(unsorted_live_intervals); + // Collapse GET_ITER live intervals to contiguous spans. + // This ensures that iterator values stay permanently active throughout the loop, + // eliminating the need for inactive→active reload logic and simplifying the allocator. + extend_iterator_liveness(); + logger->info( "Live interval analysis complete. Found {} intervals", sorted_live_intervals.size()); - // Log intervals at debug level (temporarily for debugging) for (const auto &interval : sorted_live_intervals) { - // Log all intervals, especially GET_ITER - if (std::holds_alternative(interval.value)) { - auto val = std::get(interval.value); - if (val.getDefiningOp() - && mlir::isa(val.getDefiningOp())) { - logger->info("GET_ITER LiveInterval: start={}, end={}, {} sub-intervals", - interval.start(), - interval.end(), - interval.intervals.size()); - for (size_t i = 0; i < interval.intervals.size(); ++i) { - auto [s, e] = interval.intervals[i]; - logger->info(" Interval {}: [{}, {})", i, s, e); - } - } - } logger->trace("LiveInterval for {}: start={}, end={}", - to_string(interval.value), + interval.value, interval.start(), interval.end()); } @@ -147,13 +136,45 @@ class LiveIntervalAnalysis private: /** - * Update or create a live interval for the given value at the current timestep + * Collapse GET_ITER live intervals to a single contiguous span. + * + * Backward dataflow may produce gaps in a GET_ITER value's live interval when loop + * back-edges cause the iterator to appear as a separate sub-interval on each iteration. + * With the now-precise alive_at() check, such gaps would allow the allocator to move + * the iterator to inactive and potentially clobber its register. + * + * By collapsing [start, gap, end] → [start, end), the iterator remains permanently + * active throughout the loop. This is semantically correct because the iterator MUST + * survive for the entire duration of the FOR loop. + */ + void extend_iterator_liveness() + { + for (auto &interval : sorted_live_intervals) { + if (!std::holds_alternative(interval.value)) { continue; } + auto val = std::get(interval.value); + if (!val.getDefiningOp()) { continue; } + if (!mlir::isa(val.getDefiningOp())) { continue; } + + // Collapse all sub-intervals into one contiguous [start, end) span + const size_t first = interval.start(); + const size_t last = interval.end(); + interval.intervals.clear(); + interval.intervals.emplace_back(first, last); + } + } + + /** + * Update or create a live interval for the given value at the current timestep. + * + * Uses an index map for O(log n) lookup rather than a linear scan over all intervals. */ void update_interval( const std::variant &alive_value, size_t timestep, - std::vector &intervals) + std::vector &intervals, + std::map, size_t, ValueMappingComparator> + &interval_index) { // Extract the actual values to track from alive_value std::vector> values_to_track; @@ -169,26 +190,26 @@ class LiveIntervalAnalysis values_to_track.insert(values_to_track.end(), inputs.begin(), inputs.end()); } - // Update interval for each value + // Update interval for each value using the O(log n) index map for (auto value : values_to_track) { - auto it = std::find_if(intervals.begin(), - intervals.end(), - [&value](const auto &interval) { return interval.value == value; }); + auto idx_it = interval_index.find(value); - if (it == intervals.end()) { - // Create new interval + if (idx_it == interval_index.end()) { + // Create new interval and record its index + const size_t idx = intervals.size(); intervals.emplace_back( std::vector{ std::make_tuple(timestep, timestep + 1) }, value); + interval_index.emplace(value, idx); } else { // Extend existing interval - auto &value_intervals = it->intervals; + auto &value_intervals = intervals[idx_it->second].intervals; const size_t end = std::get<1>(value_intervals.back()); if (timestep == end) { // Consecutive timestep: extend the current interval std::get<1>(value_intervals.back())++; } else { - // Gap detected: start a new interval + // Gap detected: start a new sub-interval value_intervals.emplace_back(timestep, timestep + 1); } } diff --git a/src/executable/mlir/Target/PythonBytecode/RegisterAllocationLogger.hpp b/src/executable/mlir/Target/PythonBytecode/RegisterAllocationLogger.hpp index 5ee9d036..7a14fac9 100644 --- a/src/executable/mlir/Target/PythonBytecode/RegisterAllocationLogger.hpp +++ b/src/executable/mlir/Target/PythonBytecode/RegisterAllocationLogger.hpp @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -15,34 +16,97 @@ namespace codegen { using ForwardedOutput = std::pair; -// Helper functions to convert MLIR values to strings for logging -inline std::string to_string(const mlir::Value &value) -{ - std::string result; - llvm::raw_string_ostream os(result); - os << "Value@" << value.getImpl() << "["; - value.print(os); - os << "]"; - return result; } -inline std::string to_string(const ForwardedOutput &output) +template<> struct fmt::formatter { - std::string result; - llvm::raw_string_ostream os(result); - os << "ForwardedOutput@" << static_cast(output.first) << "["; - output.first->print(os); - os << ", result_idx=" << output.second << "]"; - return result; -} + template constexpr ParseContext::iterator parse(ParseContext &ctx) + { + return ctx.begin(); + } -inline std::string to_string(const std::variant &value) + template FmtContext::iterator format(mlir::Value value, FmtContext &ctx) const + { + std::string result; + llvm::raw_string_ostream os(result); + os << "Value@" << value.getImpl() << "["; + value.print(os); + os << "]"; + return fmt::format_to(ctx.out(), "{}", result); + } +}; + +template<> struct fmt::formatter { - if (std::holds_alternative(value)) { - return to_string(std::get(value)); - } else { - return to_string(std::get(value)); + template constexpr ParseContext::iterator parse(ParseContext &ctx) + { + return ctx.begin(); + } + + template + FmtContext::iterator format(codegen::ForwardedOutput output, FmtContext &ctx) const + { + std::string result; + llvm::raw_string_ostream os(result); + os << "ForwardedOutput@" << static_cast(output.first) << "["; + output.first->print(os); + os << ", result_idx=" << output.second << "]"; + return fmt::format_to(ctx.out(), "{}", result); } +}; + + +template<> struct fmt::formatter, char> +{ + + template constexpr ParseContext::iterator parse(ParseContext &ctx) + { + return ctx.begin(); + } + + template + FmtContext::iterator format(std::variant value, + FmtContext &ctx) const + { + if (std::holds_alternative(value)) { + return fmt::format_to(ctx.out(), "{}", std::get(value)); + } else { + return fmt::format_to(ctx.out(), "{}", std::get(value)); + } + } +}; + + +namespace codegen { + +// Parse log level string to spdlog level enum +inline spdlog::level::level_enum parse_log_level(const char *level_str, + spdlog::level::level_enum default_level) +{ + if (!level_str) { return default_level; } + + std::string level(level_str); + // Convert to lowercase for case-insensitive comparison + std::transform( + level.begin(), level.end(), level.begin(), [](unsigned char c) { return std::tolower(c); }); + + if (level == "trace") { + return spdlog::level::trace; + } else if (level == "debug") { + return spdlog::level::debug; + } else if (level == "info") { + return spdlog::level::info; + } else if (level == "warn" || level == "warning") { + return spdlog::level::warn; + } else if (level == "error" || level == "err") { + return spdlog::level::err; + } else if (level == "critical") { + return spdlog::level::critical; + } else if (level == "off") { + return spdlog::level::off; + } + + return default_level; } // Get the logger for register allocation @@ -50,11 +114,9 @@ inline std::shared_ptr get_regalloc_logger() { static auto logger = []() { auto l = spdlog::get("regalloc"); - if (!l) { - l = spdlog::stdout_color_mt("regalloc"); - // Set default level to warning (can be changed via spdlog::set_level) - l->set_level(spdlog::level::warn); - } + if (!l) { l = spdlog::stdout_color_mt("regalloc"); } + auto log_level = parse_log_level(std::getenv("LOG_REGALLOC"), spdlog::level::err); + l->set_level(log_level); return l; }(); return logger; diff --git a/src/executable/mlir/Target/PythonBytecode/RegisterAllocationTypes.hpp b/src/executable/mlir/Target/PythonBytecode/RegisterAllocationTypes.hpp index bbb88f42..bc4e0e4c 100644 --- a/src/executable/mlir/Target/PythonBytecode/RegisterAllocationTypes.hpp +++ b/src/executable/mlir/Target/PythonBytecode/RegisterAllocationTypes.hpp @@ -78,4 +78,10 @@ inline std::vector sortBlocks(mlir::Region ®ion) return { result.begin(), result.end() }; } +// Total number of registers available for allocation +constexpr size_t kNumRegisters = 32; + +// Register index where CALL/YIELD/WITH_EXCEPT_START results land at runtime +constexpr size_t kCallResultReg = 0; + }// namespace codegen diff --git a/src/executable/mlir/Target/PythonBytecode/TranslateToPythonBytecode.cpp b/src/executable/mlir/Target/PythonBytecode/TranslateToPythonBytecode.cpp index d14c0dfc..93be4222 100644 --- a/src/executable/mlir/Target/PythonBytecode/TranslateToPythonBytecode.cpp +++ b/src/executable/mlir/Target/PythonBytecode/TranslateToPythonBytecode.cpp @@ -78,6 +78,7 @@ #include "executable/bytecode/instructions/YieldFrom.hpp" #include "executable/bytecode/instructions/YieldLoad.hpp" #include "executable/bytecode/instructions/YieldValue.hpp" +#include "executable/mlir/Target/PythonBytecode/LinearScanRegisterAllocation.hpp" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" @@ -94,689 +95,17 @@ #include "runtime/Value.hpp" #include "utilities.hpp" #include "llvm/ADT/APSInt.h" -#include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/TypeSwitch.h" #include #include #include #include -#include using namespace mlir; namespace codegen { -namespace { - bool is_function_call(mlir::Value value) - { - return mlir::isa(value.getDefiningOp()) - || mlir::isa(value.getDefiningOp()) - || mlir::isa( - value.getDefiningOp()); - } - - bool clobbers_r0(mlir::Value value) - { - return is_function_call(value) - || mlir::isa(value.getDefiningOp()) - || mlir::isa(value.getDefiningOp()) - || mlir::isa(value.getDefiningOp()); - } - - std::vector sortBlocks(mlir::Region ®ion) - { - auto result = mlir::getBlocksSortedByDominance(region); - return { result.begin(), result.end() }; - } -}// namespace - -using ForwardedOutput = std::pair; - -template -using ValueMapping = std::map, - ValueT, - decltype([](const std::variant &lhs, - const std::variant &rhs) { - if (rhs.valueless_by_exception()) { - return false; - } else if (lhs.valueless_by_exception()) { - return true; - } else if (lhs.index() < rhs.index()) { - return true; - } else if (lhs.index() > rhs.index()) { - return false; - } - if (std::holds_alternative(lhs)) { - return std::get(lhs).getImpl() < std::get(rhs).getImpl(); - } - return std::get(lhs) < std::get(rhs); - })>; - -struct LiveAnalysis -{ - using BlockArgumentInputs = - std::tuple>>; - using AliveAtTimestepT = - std::vector>; - std::vector alive_at_timestep; - - ValueMapping(lhs).getImpl() - < static_cast(rhs).getImpl(); - })>> - block_input_mappings; - - void analyse(mlir::func::FuncOp &fn) - { - auto ®ion = fn.getRegion(); - - auto sorted_blocks = sortBlocks(region); - - auto add_value = [](AliveAtTimestepT::value_type value, AliveAtTimestepT &alive) { - auto it = std::find_if(alive.begin(), alive.end(), [&value](const auto &el) { - ASSERT(std::holds_alternative(el) - || std::holds_alternative(el)); - return el == value; - }); - if (it == alive.end()) { alive.push_back(std::move(value)); } - }; - - AsmState state{ fn.getOperation() }; - std::vector, mlir::BlockArgument>> - block_parameters_to_args; - std::map> blocks_span; - for (auto *block : sorted_blocks) { - const auto start = alive_at_timestep.size(); - if (!sortTopologically(block)) { std::abort(); } - // block->print(llvm::outs(), state); - // llvm::outs() << '\n'; - for (auto &op : block->getOperations()) { - auto &alive = alive_at_timestep.emplace_back(); - // ASSERT(op.getOpResults().size() <= 1); - for (const auto &result : op.getResults()) { add_value(result, alive); } - for (const auto &operand : op.getOperands()) { add_value(operand, alive); } - // if (!op.getResults().empty()) { - // llvm::outs() << "@" << (void *)Value{ op.getResults().back() }.getImpl() - // << ": "; - // op.print(llvm::outs(), state); - // llvm::outs() << '\n'; - // } - } - if (block->getTerminator()) { - if (auto branch = - dyn_cast(block->getTerminator())) { - auto *true_block = branch.getTrueDest(); - ASSERT( - branch.getTrueDestOperands().size() == true_block->getArguments().size()); - for (const auto &[p, arg] : - llvm::zip(branch.getTrueDestOperands(), true_block->getArguments())) { - block_parameters_to_args.emplace_back(p, arg); - } - auto *false_block = branch.getFalseDest(); - ASSERT( - branch.getFalseDestOperands().size() == false_block->getArguments().size()); - for (const auto &[p, arg] : - llvm::zip(branch.getFalseDestOperands(), false_block->getArguments())) { - block_parameters_to_args.emplace_back(p, arg); - } - } else if (auto branch = dyn_cast(block->getTerminator())) { - auto *jmp_block = branch.getDest(); - ASSERT(branch.getDestOperands().size() == jmp_block->getArguments().size()); - for (const auto &[p, arg] : - llvm::zip(branch.getDestOperands(), jmp_block->getArguments())) { - block_parameters_to_args.emplace_back(p, arg); - } - } else if (auto for_iter = - dyn_cast(block->getTerminator())) { - ASSERT(for_iter.getBody()->getArguments().size() == 1); - ASSERT(for_iter.getSuccessorOperands(0).getProducedOperandCount() == 1); - block_parameters_to_args.emplace_back( - ForwardedOutput{ for_iter, 0 }, for_iter.getBody()->getArgument(0)); - // start the lifetime of the for_iter returned value - auto &alive = alive_at_timestep.back(); - add_value(ForwardedOutput{ for_iter, 0 }, alive); - } - } - blocks_span.emplace(block, std::pair{ start, alive_at_timestep.size() }); - } - - for (const auto &[param, arg] : block_parameters_to_args) { - auto *bb = arg.getOwner(); - const auto [start, end] = blocks_span.at(bb); - auto block_timesteps = - std::span{ alive_at_timestep.begin() + start, alive_at_timestep.begin() + end }; - for (auto &ts : block_timesteps) { - for (auto &val : ts) { - if (std::holds_alternative(val) - && std::get(val).isa() - && std::get(val).cast() == arg) { - val = BlockArgumentInputs{ arg, { param } }; - block_input_mappings[param].insert(arg); - } else if (std::holds_alternative(val) - && std::get<0>(std::get(val)) == arg) { - std::get<1>(std::get(val)).push_back(param); - block_input_mappings[param].insert(arg); - } - } - } - } - - // auto printv = [](Value v) { - // llvm::outs() << v.getImpl() << '['; - // v.print(llvm::outs()); - // llvm::outs() << "]"; - // }; - // auto printo = [](ForwardedOutput o) { - // llvm::outs() << static_cast(o.first) << '['; - // o.first->print(llvm::outs()); - // llvm::outs() << ", " << o.second << "]"; - // }; - - // for (size_t idx = 0; const auto &values : alive_at_timestep) { - // llvm::outs() << idx++ << ": "; - // for (auto value : values) { - // std::visit( - // overloaded{ - // printv, - // printo, - // [printv, printo](const BlockArgumentInputs &b) { - // llvm::outs() << static_cast(std::get<0>(b)).getImpl() << '{'; - // for (const auto &v : std::get<1>(b)) { - // std::visit(overloaded{ printv, printo }, v); - // llvm::outs() << ", "; - // } - // llvm::outs() << "}"; - // }, - // }, - // value); - // llvm::outs() << ", "; - // } - // llvm::outs() << '\n'; - // } - - // for (const auto &[k, v] : block_input_mappings) { - // std::visit(overloaded{ printv, printo }, k); - // llvm::outs() << ": "; - // for (const auto &el : v) { - // printv(el); - // llvm::outs() << ", "; - // } - // llvm::outs() << '\n'; - // } - - for (auto &values : alive_at_timestep | std::ranges::views::reverse) { - for (auto &value : values | std::ranges::views::reverse) { - auto original_value = value; - if (std::holds_alternative(value)) { - value = std::get<0>(std::get(value)); - } - auto start = - std::visit(overloaded{ - [this](const auto &v) { return block_input_mappings.find(v); }, - [this](const BlockArgumentInputs &) { - TODO(); - return block_input_mappings.end(); - }, - }, - value); - // std::stack< - auto it = start; - while (it != block_input_mappings.end()) { - ASSERT(it->second.size() == 1); - value = *it->second.begin(); - start->second.erase(start->second.begin()); - start->second.insert(mlir::cast(std::get(value))); - it = block_input_mappings.find(std::get(value)); - } - } - } - - // for (size_t idx = 0; const auto &values : alive_at_timestep) { - // llvm::outs() << idx++ << ": "; - // for (auto value : values) { - // std::visit( - // overloaded{ - // printv, - // printo, - // [printv, printo](const BlockArgumentInputs &b) { - // llvm::outs() << static_cast(std::get<0>(b)).getImpl() << '{'; - // for (const auto &v : std::get<1>(b)) { - // std::visit(overloaded{ printv, printo }, v); - // llvm::outs() << ", "; - // } - // llvm::outs() << "}"; - // }, - // }, - // value); - // llvm::outs() << ", "; - // } - // llvm::outs() << '\n'; - // } - // llvm::outs().flush(); - } -}; - -struct LiveIntervalAnalysis -{ - struct LiveInterval - { - // start, end - using Interval = std::tuple; - std::vector intervals; - std::variant value; - - size_t start() const { return std::get<0>(intervals.front()); } - - size_t end() const { return std::get<1>(intervals.back()); } - - bool alive_at(size_t pos) const - { - // FIXME: the commented code is correct, but currently there is no logic - // to populate a register when an interval goes from inactive to active - // ie. the register is potentially clobbered - // return std::find_if(intervals.begin(), - // intervals.end(), - // [pos](const Interval &interval) { - // auto [start, end] = interval; - // return pos >= start && pos < end; - // }) - // != intervals.end(); - return pos >= start() && pos < end(); - } - - bool overlaps(const LiveInterval &other) const - { - // naive quadratic search - for (const auto &[a, b] : intervals) { - for (const auto &[c, d] : other.intervals) { - if (a < d && c < b) { return true; } - } - } - - return false; - } - }; - - std::vector sorted_live_intervals; - ValueMapping>> block_input_mappings; - - void analyse(mlir::func::FuncOp &func) - { - LiveAnalysis live_analysis{}; - live_analysis.analyse(func); - for (auto [key, value] : live_analysis.block_input_mappings) { - for (const auto &el : value) { block_input_mappings[el].push_back(key); } - } - - // auto printv = [](Value v) { - // llvm::outs() << v.getImpl() << '['; - // v.print(llvm::outs()); - // llvm::outs() << "]"; - // }; - // auto printo = [](ForwardedOutput o) { - // llvm::outs() << static_cast(o.first) << '['; - // o.first->print(llvm::outs()); - // llvm::outs() << ", " << o.second << "]"; - // }; - - // for (const auto &[k, v] : block_input_mappings) { - // std::visit(overloaded{ printv, printo }, k); - // llvm::outs() << ": {"; - // for (const auto &el : v) { - // std::visit(overloaded{ printv, printo }, el); - // llvm::outs() << ", "; - // } - // llvm::outs() << "}\n"; - // } - - std::vector unsorted_live_intervals; - auto update_interval = - [this, &unsorted_live_intervals]( - const std::variant - &inputs, - size_t current) { - std::vector> - compute_live_interval_values; - if (std::holds_alternative(inputs)) { - compute_live_interval_values.push_back(std::get(inputs)); - } else if (std::holds_alternative(inputs)) { - compute_live_interval_values.push_back(std::get(inputs)); - } else { - compute_live_interval_values.insert(compute_live_interval_values.end(), - std::get<1>(std::get(inputs)).begin(), - std::get<1>(std::get(inputs)).end()); - } - - for (auto value : compute_live_interval_values) { - if (auto it = std::find_if(unsorted_live_intervals.begin(), - unsorted_live_intervals.end(), - [&value](const auto &el) { return el.value == value; }); - it == unsorted_live_intervals.end()) { - unsorted_live_intervals.emplace_back( - std::vector{ std::make_tuple(current, current + 1) }, value); - } else { - auto &intervals = it->intervals; - const size_t end = std::get<1>(intervals.back()); - if (current == end) { - std::get<1>(intervals.back())++; - } else { - intervals.emplace_back(current, current + 1); - } - } - } - }; - - for (size_t i = 0; const auto &vals : live_analysis.alive_at_timestep) { - for (const auto &val : vals) { update_interval(val, i); } - i++; - } - - std::sort(unsorted_live_intervals.begin(), - unsorted_live_intervals.end(), - [](const LiveIntervalAnalysis::LiveInterval &lhs, - const LiveIntervalAnalysis::LiveInterval &rhs) { - return lhs.start() < rhs.start(); - }); - sorted_live_intervals = std::move(unsorted_live_intervals); - - // for (const auto &live_interval : sorted_live_intervals) { - // auto [intervals, value] = live_interval; - // if (std::holds_alternative(value)) { - // llvm::outs() << "@" - // << static_cast(std::get(value).getImpl()) - // << " "; - // } else { - // llvm::outs() << "[@" - // << static_cast(std::get(value).first) - // << ", " << std::get(value).second << "] "; - // } - // for (const auto &interval : intervals) { - // auto [start, end] = interval; - // llvm::outs() << fmt::format("[{}, {}[ ", start, end); - // } - // llvm::outs() << '\n'; - // } - } -}; - -struct LinearScanRegisterAllocation -{ - struct Reg - { - size_t idx; - }; - struct StackLocation - { - size_t idx; - }; - using ValueLocation = std::variant; - ValueMapping value2mem_map; - - void analyse(mlir::func::FuncOp &func, mlir::OpBuilder builder) - { - LiveIntervalAnalysis live_interval_analysis; - live_interval_analysis.analyse(func); - - auto unhandled = std::span(live_interval_analysis.sorted_live_intervals.begin(), - live_interval_analysis.sorted_live_intervals.end()); - ASSERT(std::is_sorted(unhandled.begin(), - unhandled.end(), - [](const LiveIntervalAnalysis::LiveInterval &lhs, - const LiveIntervalAnalysis::LiveInterval &rhs) { - return lhs.start() < rhs.start(); - })); - - auto increasing_endpoint_cmp = [](const LiveIntervalAnalysis::LiveInterval &lhs, - const LiveIntervalAnalysis::LiveInterval &rhs) { - return lhs.end() < rhs.end(); - }; - - std::multiset active; - std::multiset - inactive; - std::multiset - handled; - - std::bitset<32> free; - free.set(); - - for (const auto &interval : unhandled) { - // the result of a function call is always in Reg{0}, so we start by claiming Reg{0} for - // the result of all call operations - if (std::holds_alternative(interval.value) - && (std::get(interval.value).getDefiningOp() - && clobbers_r0(std::get(interval.value)))) { - value2mem_map.insert_or_assign( - std::get(interval.value), Reg{ .idx = 0 }); - inactive.insert(interval); - } - - // account for block arguments that could be the result of a function call - if (auto it = live_interval_analysis.block_input_mappings.find(interval.value); - it != live_interval_analysis.block_input_mappings.end()) { - for (auto mapped_value : it->second) { - if (std::holds_alternative(mapped_value)) { continue; } - if ((std::get(mapped_value).getDefiningOp() - && clobbers_r0(std::get(mapped_value)))) { - value2mem_map.insert_or_assign(interval.value, Reg{ .idx = 0 }); - inactive.insert(interval); - break; - } - } - } - } - - while (!unhandled.empty()) { - // llvm::outs() << "free: " << free << '\n'; - const auto &cur = *unhandled.begin(); - unhandled = unhandled.subspan(1, unhandled.size() - 1); - - // const_cast(cur.value).print(llvm::outs()); - // llvm::outs() << '\n'; - - // check for active intervals that expired - for (auto it = active.begin(); it != active.end();) { - const auto &interval = *it; - ASSERT(interval.value != cur.value); - if (interval.end() < cur.start()) { - handled.insert(interval); - it = active.erase(it); - const auto reg = value2mem_map.at(interval.value); - ASSERT(std::holds_alternative(reg)); - free.set(std::get(reg).idx, true); - } else if (!interval.alive_at(cur.start())) { - inactive.insert(interval); - it = active.erase(it); - const auto reg = value2mem_map.at(interval.value); - ASSERT(std::holds_alternative(reg)); - free.set(std::get(reg).idx, true); - } else { - ++it; - } - } - // check for inactive intervals that expired or become reactivated - for (auto it = inactive.begin(); it != inactive.end();) { - const auto &interval = *it; - if (interval.value == cur.value) { - ASSERT( - (std::holds_alternative(interval.value) - && std::get(interval.value).getDefiningOp() - && clobbers_r0(std::get(interval.value))) - || (live_interval_analysis.block_input_mappings.contains(interval.value) - && std::ranges::any_of( - live_interval_analysis.block_input_mappings.find(interval.value) - ->second, - [](auto mapped_value) { - if (std::holds_alternative(mapped_value)) { - return std::get(mapped_value).getDefiningOp() - && clobbers_r0(std::get(mapped_value)); - } - return false; - }))); - active.insert(interval); - it = inactive.erase(it); - } else if (interval.end() < cur.start()) { - handled.insert(interval); - it = inactive.erase(it); - } else if (interval.alive_at(cur.start())) { - active.insert(interval); - const auto reg = value2mem_map.at(interval.value); - ASSERT(std::holds_alternative(reg)); - ASSERT(free.test(std::get(reg).idx)); - free.set(std::get(reg).idx, false); - } else { - ++it; - } - } - - auto f = free; - // collect available registers - auto overlaps = - std::views::filter([&cur](const auto &interval) { return interval.overlaps(cur); }); - for (const auto &interval : inactive | overlaps) { - if (auto it = value2mem_map.find(interval.value); it != value2mem_map.end()) { - const auto reg = it->second; - ASSERT(std::holds_alternative(reg)); - - // if it is still inactive it should be ok if this register is still being used - // we just don't want it to be used when the interval becomes active - f.set(std::get(reg).idx, false); - } - } - - for (const auto &interval : unhandled | overlaps) { - if (auto it = value2mem_map.find(interval.value); it != value2mem_map.end()) { - const auto reg = it->second; - ASSERT(std::holds_alternative(reg)); - - // if it is unhandled it should be ok if this register is still being used - // we just don't want it to be used when the interval becomes active - f.set(std::get(reg).idx, false); - } - } - - if (f.none()) { - TODO(); - } else { - std::optional cur_reg; - if (auto it = value2mem_map.find(cur.value); it == value2mem_map.end()) { - for (size_t i = 0; i < f.size(); ++i) { - if (i == 0 && std::get(cur.value).getDefiningOp() - && mlir::isa( - std::get(cur.value).getDefiningOp())) { - continue; - } - if (f.test(i)) { - value2mem_map.insert_or_assign(cur.value, Reg{ .idx = i }); - cur_reg = i; - break; - } - } - } else { - ASSERT(std::holds_alternative(it->second)); - cur_reg = std::get(it->second).idx; - } - - ASSERT(cur_reg.has_value()); - // const_cast(cur.value).print(llvm::outs()); - // llvm::outs() << '\n'; - if (!free.test(*cur_reg)) { - std::optional scratch_reg; - for (size_t i = 1; i < f.size(); ++i) { - if (f.test(i)) { - scratch_reg = i; - break; - } - } - ASSERT(scratch_reg.has_value()); - if (std::holds_alternative(cur.value)) { - auto current_value = std::get(cur.value); - if (current_value.isa()) { - if (auto it = - live_interval_analysis.block_input_mappings.find(cur.value); - it != live_interval_analysis.block_input_mappings.end()) { - for (auto mapped_value : it->second) { - ASSERT(!std::holds_alternative(mapped_value)); - if ((std::get(mapped_value).getDefiningOp() - && clobbers_r0(std::get(mapped_value)))) { - ASSERT(current_value.isa()); - current_value = std::get(mapped_value); - break; - } - } - } - } - ASSERT(!current_value.isa()); - auto loc = current_value.getLoc(); - builder.setInsertionPoint(current_value.getDefiningOp()); - builder.create(loc, *cur_reg); - builder.setInsertionPointAfter(current_value.getDefiningOp()); - builder.create(loc, *scratch_reg, *cur_reg); - builder.create(loc, *cur_reg); - value2mem_map.insert_or_assign( - std::get(cur.value), Reg{ .idx = *scratch_reg }); - free.set(*scratch_reg, false); - } - } else { - free.set(*cur_reg, false); - } - active.insert(cur); - } - } - - decltype(value2mem_map) value2mem_map_additional; - for (auto [value, reg] : value2mem_map) { - if (auto it = live_interval_analysis.block_input_mappings.find(value); - it != live_interval_analysis.block_input_mappings.end()) { - for (auto mapped_value : it->second) { - value2mem_map_additional[mapped_value] = reg; - } - } - } - value2mem_map.merge(std::move(value2mem_map_additional)); - - { - // const auto end = - // std::max_element(live_interval_analysis.sorted_live_intervals.begin(), - // live_interval_analysis.sorted_live_intervals.end(), - // [](const auto &lhs, const auto &rhs) { - // return lhs.end() < rhs.end(); - // })->end(); - - // std::vector> values_by_timestep; - // for (size_t i = 0; i < end; ++i) { - // values_by_timestep.emplace_back(free.size(), nullptr); - // } - - // for (const auto &interval : live_interval_analysis.sorted_live_intervals) { - // for (auto [start, end] : interval.intervals) { - // for (; start < end; ++start) { - // auto reg = value2mem_map[interval.value]; - // if (std::holds_alternative(reg)) { - // values_by_timestep[start][std::get(reg).idx] = - // std::get(interval.value).getImpl(); - // } else { - // TODO(); - // } - // } - // } - // } - - // for (size_t i = 0; const auto &ts : values_by_timestep) { - // llvm::outs() << i++ << ' '; - // for (auto *val : ts) { llvm::outs() << val << ' '; } - // llvm::outs() << '\n'; - // } - } - - // llvm::outs() << "Register allocator modified function:\n"; - // func.print(llvm::outs()); - // llvm::outs() << '\n'; - } -}; - struct PythonBytecodeEmitter { struct FunctionInfo @@ -971,13 +300,12 @@ struct PythonBytecodeEmitter Register get_name_idx(StringRef name) const { - // llvm::outs() << const_cast(m_parent_fn).getName() << '\n'; auto names = const_cast(m_parent_fn).getOperation()->getAttr("names"); ASSERT(names); - auto names_array = names.cast(); + auto names_array = mlir::cast(names); auto it = std::find_if(names_array.begin(), names_array.end(), [&name](mlir::Attribute attr) { - return attr.cast().getValue() == name; + return mlir::cast(attr).getValue() == name; }); ASSERT(it != names_array.end()); const auto idx = std::distance(names_array.begin(), it); @@ -1008,13 +336,13 @@ struct PythonBytecodeEmitter std::vector &func_array) { auto attr = op.getOperation()->getAttr(array_name); if (attr) { - auto array = attr.cast(); + auto array = mlir::cast(attr); func_array.reserve(array.size()); std::transform(array.begin(), array.end(), std::back_inserter(func_array), [](mlir::Attribute attr) { - return attr.cast().getValue().str(); + return mlir::cast(attr).getValue().str(); }); } }; @@ -1029,8 +357,8 @@ struct PythonBytecodeEmitter && std::any_of(op.getAllArgAttrs().begin(), op.getAllArgAttrs().end(), [](mlir::Attribute arg_attr) { - auto vararg = - arg_attr.cast().getAs("llvm.vararg"); + auto vararg = mlir::cast(arg_attr).getAs( + "llvm.vararg"); return vararg && vararg.getValue(); })) { current_function().set_varargs(); @@ -1040,8 +368,8 @@ struct PythonBytecodeEmitter && std::any_of(op.getAllArgAttrs().begin(), op.getAllArgAttrs().end(), [](mlir::Attribute arg_attr) { - auto vararg = - arg_attr.cast().getAs("llvm.kwarg"); + auto vararg = mlir::cast(arg_attr).getAs( + "llvm.kwarg"); return vararg && vararg.getValue(); })) { current_function().set_kwargs(); @@ -1051,7 +379,7 @@ struct PythonBytecodeEmitter auto kwonlyarg_count = std::count_if(op.getAllArgAttrs().begin(), op.getAllArgAttrs().end(), [](mlir::Attribute arg_attr) { - auto vararg = arg_attr.cast().getAs( + auto vararg = mlir::cast(arg_attr).getAs( "llvm.kwonlyarg"); return vararg && vararg.getValue(); }); @@ -1092,14 +420,15 @@ struct PythonBytecodeEmitter } if (current_function().m_flags.is_set(CodeFlags::Flag::VARARGS)) { - auto arg_name = std::find_if(op.getAllArgAttrs().begin(), + auto vararg_attr_it = std::find_if(op.getAllArgAttrs().begin(), op.getAllArgAttrs().end(), [](mlir::Attribute arg_attr) { - auto vararg = arg_attr.cast().getAs( - "llvm.vararg"); + auto vararg = + mlir::cast(arg_attr).getAs( + "llvm.vararg"); return vararg && vararg.getValue(); - }) - ->cast() + }); + auto arg_name = mlir::cast(*vararg_attr_it) .getAs("llvm.name") .getValue(); if (auto it = std::ranges::find(current_function().m_cellvars, arg_name); @@ -1110,14 +439,15 @@ struct PythonBytecodeEmitter } if (current_function().m_flags.is_set(CodeFlags::Flag::VARKEYWORDS)) { - auto arg_name = std::find_if(op.getAllArgAttrs().begin(), + auto kwarg_attr_it = std::find_if(op.getAllArgAttrs().begin(), op.getAllArgAttrs().end(), [](mlir::Attribute arg_attr) { - auto kwarg = arg_attr.cast().getAs( - "llvm.kwarg"); + auto kwarg = + mlir::cast(arg_attr).getAs( + "llvm.kwarg"); return kwarg && kwarg.getValue(); - }) - ->cast() + }); + auto arg_name = mlir::cast(*kwarg_attr_it) .getAs("llvm.name") .getValue(); if (auto it = std::ranges::find(current_function().m_cellvars, arg_name); @@ -1346,10 +676,6 @@ template<> LogicalResult PythonBytecodeEmitter::emitOperation(mlir::emitpybyteco template<> LogicalResult PythonBytecodeEmitter::emitOperation(mlir::emitpybytecode::ForIter &op) { - // auto this_block = std::find( - // m_sorted_blocks.top().begin(), m_sorted_blocks.top().end(), op.getOperation()->getBlock()); - // ASSERT(*(this_block + 1) == op.body()); - auto exit_label = m_block_labels.emplace_back(op.getContinuation(), std::make_shared