diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 9ecdb74f4d82e..91ee89919e58e 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -1593,4 +1593,90 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects, let hasVerifier = 1; } +def EmitC_ClassOp + : EmitC_Op<"class", [AutomaticAllocationScope, IsolatedFromAbove, + OpAsmOpInterface, SymbolTable, + Symbol]#GraphRegionNoTerminator.traits> { + let summary = + "Represents a C++ class definition, encapsulating fields and methods."; + + let description = [{ + The `emitc.class` operation defines a C++ class, acting as a container + for its data fields (`emitc.field`) and methods (`emitc.func`). + It creates a distinct scope, isolating its contents from the surrounding + MLIR region, similar to how C++ classes encapsulate their internals. + + Example: + + ```mlir + emitc.class @modelClass { + emitc.field @fieldName0 : !emitc.array<1xf32> = {emitc.opaque = "input_tensor"} + emitc.func @execute() { + %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t + %1 = get_field @fieldName0 : !emitc.array<1xf32> + %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue + return + } + } + ``` + }]; + + let arguments = (ins SymbolNameAttr:$sym_name); + + let regions = (region AnyRegion:$body); + + let extraClassDeclaration = [{ + // Returns the body block containing class members and methods. + Block &getBlock(); + }]; + + let hasCustomAssemblyFormat = 1; + + let assemblyFormat = [{ $sym_name attr-dict-with-keyword $body }]; +} + +def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> { + let summary = "A field within a class"; + let description = [{ + The `emitc.field` operation declares a named field within an `emitc.class` + operation. The field's type must be an EmitC type. + + Example: + + ```mlir + // Example with an attribute: + emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "another_feature"} + // Example with no attribute: + emitc.field @fieldName0 : !emitc.array<1xf32> + ``` + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type, + OptionalAttr:$attrs); + + let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}]; + + let hasVerifier = 1; +} + +def EmitC_GetFieldOp + : EmitC_Op<"get_field", [Pure, DeclareOpInterfaceMethods< + SymbolUserOpInterface>]> { + let summary = "Obtain access to a field within a class instance"; + let description = [{ + The `emitc.get_field` operation retrieves the lvalue of a + named field from a given class instance. + + Example: + + ```mlir + %0 = get_field @fieldName0 : !emitc.array<1xf32> + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$field_name); + let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result); + let assemblyFormat = "$field_name `:` type($result) attr-dict"; +} + #endif // MLIR_DIALECT_EMITC_IR_EMITC diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h index 5a103f181c76b..1af4aa06fa811 100644 --- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h @@ -15,6 +15,7 @@ namespace mlir { namespace emitc { #define GEN_PASS_DECL_FORMEXPRESSIONSPASS +#define GEN_PASS_DECL_WRAPFUNCINCLASSPASS #include "mlir/Dialect/EmitC/Transforms/Passes.h.inc" //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td index f46b705ca2dfe..74c49132b61f6 100644 --- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td @@ -20,4 +20,42 @@ def FormExpressionsPass : Pass<"form-expressions"> { let dependentDialects = ["emitc::EmitCDialect"]; } +def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> { + let summary = "Wrap functions in classes, using arguments as fields."; + let description = [{ + This pass transforms `emitc.func` operations into `emitc.class` operations. + Function arguments become fields of the class, and the function body is moved + to a new `execute` method within the class. + If the corresponding function argument has attributes (accessed via `argAttrs`), + these attributes are attached to the field operation. + Otherwise, the field is created without additional attributes. + + Example: + + ```mlir + emitc.func @model(%input_data : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}) attributes { } { + %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t + %1 = subscript %input_data[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue + return + } + // becomes + emitc.class @modelClass { + emitc.field @input_tensor : !emitc.array<1xf32> {emitc.opaque = "input_tensor"} + emitc.func @execute() { + %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t + %1 = get_field @input_tensor : !emitc.array<1xf32> + %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue + return + } + } + ``` + }]; + let dependentDialects = ["emitc::EmitCDialect"]; + let options = [Option< + "namedAttribute", "named-attribute", "std::string", + /*default=*/"", + "Attribute key used to extract field names from function argument's " + "dictionary attributes">]; +} + #endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h index 2574acd7d48e0..a4e8fe10ff853 100644 --- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h @@ -28,6 +28,10 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder); /// Populates `patterns` with expression-related patterns. void populateExpressionPatterns(RewritePatternSet &patterns); +/// Populates 'patterns' with func-related patterns. +void populateFuncPatterns(RewritePatternSet &patterns, + StringRef namedAttribute); + } // namespace emitc } // namespace mlir diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index e602210c2dc6c..d17c4afab016c 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -1400,6 +1400,49 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) { builder.getNamedAttr("id", builder.getStringAttr(id))); } +//===----------------------------------------------------------------------===// +// FieldOp +//===----------------------------------------------------------------------===// +LogicalResult FieldOp::verify() { + if (!isSupportedEmitCType(getType())) + return emitOpError("expected valid emitc type"); + + Operation *parentOp = getOperation()->getParentOp(); + if (!parentOp || !isa(parentOp)) + return emitOpError("field must be nested within an emitc.class operation"); + + StringAttr symName = getSymNameAttr(); + if (!symName || symName.getValue().empty()) + return emitOpError("field must have a non-empty symbol name"); + + if (!getAttrs()) + return success(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// GetFieldOp +//===----------------------------------------------------------------------===// +LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr(); + FieldOp fieldOp = + symbolTable.lookupNearestSymbolFrom(*this, fieldNameAttr); + if (!fieldOp) + return emitOpError("field '") + << fieldNameAttr << "' not found in the class"; + + Type getFieldResultType = getResult().getType(); + Type fieldType = fieldOp.getType(); + + if (fieldType != getFieldResultType) + return emitOpError("result type ") + << getFieldResultType << " does not match field '" << fieldNameAttr + << "' type " << fieldType; + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt index 19b80b22bd84b..baf67afc30072 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIREmitCTransforms Transforms.cpp FormExpressions.cpp TypeConversions.cpp + WrapFuncInClass.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp new file mode 100644 index 0000000000000..17d436f6df028 --- /dev/null +++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp @@ -0,0 +1,112 @@ +//===- WrapFuncInClass.cpp - Wrap Emitc Funcs in classes -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/Transforms/Passes.h" +#include "mlir/Dialect/EmitC/Transforms/Transforms.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" + +using namespace mlir; +using namespace emitc; + +namespace mlir { +namespace emitc { +#define GEN_PASS_DEF_WRAPFUNCINCLASSPASS +#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc" + +namespace { +struct WrapFuncInClassPass + : public impl::WrapFuncInClassPassBase { + using WrapFuncInClassPassBase::WrapFuncInClassPassBase; + void runOnOperation() override { + Operation *rootOp = getOperation(); + + RewritePatternSet patterns(&getContext()); + populateFuncPatterns(patterns, namedAttribute); + + walkAndApplyPatterns(rootOp, std::move(patterns)); + } +}; + +} // namespace +} // namespace emitc +} // namespace mlir + +class WrapFuncInClass : public OpRewritePattern { +public: + WrapFuncInClass(MLIRContext *context, StringRef attrName) + : OpRewritePattern(context), attributeName(attrName) {} + + LogicalResult matchAndRewrite(emitc::FuncOp funcOp, + PatternRewriter &rewriter) const override { + + auto className = funcOp.getSymNameAttr().str() + "Class"; + ClassOp newClassOp = rewriter.create(funcOp.getLoc(), className); + + SmallVector> fields; + rewriter.createBlock(&newClassOp.getBody()); + rewriter.setInsertionPointToStart(&newClassOp.getBody().front()); + + auto argAttrs = funcOp.getArgAttrs(); + for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) { + StringAttr fieldName; + Attribute argAttr = nullptr; + + fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx)); + if (argAttrs && idx < argAttrs->size()) + argAttr = (*argAttrs)[idx]; + + TypeAttr typeAttr = TypeAttr::get(val.getType()); + fields.push_back({fieldName, typeAttr}); + rewriter.create(funcOp.getLoc(), fieldName, typeAttr, + argAttr); + } + + rewriter.setInsertionPointToEnd(&newClassOp.getBody().front()); + FunctionType funcType = funcOp.getFunctionType(); + Location loc = funcOp.getLoc(); + FuncOp newFuncOp = + rewriter.create(loc, ("execute"), funcType); + + rewriter.createBlock(&newFuncOp.getBody()); + newFuncOp.getBody().takeBody(funcOp.getBody()); + + rewriter.setInsertionPointToStart(&newFuncOp.getBody().front()); + std::vector newArguments; + newArguments.reserve(fields.size()); + for (auto &[fieldName, attr] : fields) { + GetFieldOp arg = + rewriter.create(loc, attr.getValue(), fieldName); + newArguments.push_back(arg); + } + + for (auto [oldArg, newArg] : + llvm::zip(newFuncOp.getArguments(), newArguments)) { + rewriter.replaceAllUsesWith(oldArg, newArg); + } + + llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true); + if (failed(newFuncOp.eraseArguments(argsToErase))) + newFuncOp->emitOpError("failed to erase all arguments using BitVector"); + + rewriter.replaceOp(funcOp, newClassOp); + return success(); + } + +private: + StringRef attributeName; +}; + +void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns, + StringRef namedAttribute) { + patterns.add(patterns.getContext(), namedAttribute); +} diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir new file mode 100644 index 0000000000000..c67a0c197fcd9 --- /dev/null +++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.name_hint' %s | FileCheck %s + +module attributes { } { + emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.name_hint = "another_feature"}, + %arg1: !emitc.array<1xf32> {emitc.name_hint = "some_feature"}, + %arg2: !emitc.array<1xf32> {emitc.name_hint = "output_0"}) attributes { } { + %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t + %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue + %2 = load %1 : + %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue + %4 = load %3 : + %5 = add %2, %4 : (f32, f32) -> f32 + %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue + assign %5 : f32 to %6 : + return + } +} + + +// CHECK: module { +// CHECK-NEXT: emitc.class @modelClass { +// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"} +// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"} +// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"} +// CHECK-NEXT: emitc.func @execute() { +// CHECK-NEXT: get_field @fieldName0 : !emitc.array<1xf32> +// CHECK-NEXT: get_field @fieldName1 : !emitc.array<1xf32> +// CHECK-NEXT: get_field @fieldName2 : !emitc.array<1xf32> +// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t +// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue +// CHECK-NEXT: load {{.*}} : +// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue +// CHECK-NEXT: load {{.*}} : +// CHECK-NEXT: add {{.*}}, {{.*}} : (f32, f32) -> f32 +// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue +// CHECK-NEXT: assign {{.*}} : f32 to {{.*}} : +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir new file mode 100644 index 0000000000000..92ed20c4b14e3 --- /dev/null +++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt --wrap-emitc-func-in-class %s | FileCheck %s + +emitc.func @foo(%arg0 : !emitc.array<1xf32>) { + emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> () + emitc.return +} + +// CHECK: module { +// CHECK-NEXT: emitc.class @fooClass { +// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> +// CHECK-NEXT: emitc.func @execute() { +// CHECK-NEXT: %0 = get_field @fieldName0 : !emitc.array<1xf32> +// CHECK-NEXT: call_opaque "bar"(%0) : (!emitc.array<1xf32>) -> () +// CHECK-NEXT: return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: }