diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h index d8d4027500f99c..c411010603ac61 100644 --- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -35,8 +35,8 @@ std::unique_ptr createTosaToLinalgNamed(); void addTosaToLinalgPasses( OpPassManager &pm, const TosaToLinalgOptions &options, // Note: Default to 'none' level unless otherwise specified. - tosa::ValidationOptions const &validationOptions = - tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None)); + tosa::TosaValidationOptions const &validationOptions = { + tosa::TosaProfileEnum::Undefined, false, tosa::TosaLevelEnum::None}); /// Populates conversion passes from TOSA dialect to Linalg dialect. void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns); diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h index 555d9bea18ba4d..a9bc3351f4cff0 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -34,6 +34,11 @@ class PatternRewriter; namespace tosa { +ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, + Attribute &attr); +void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, + Attribute attr); + #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc" } // namespace tosa diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td index d75f5dffa8716c..f9f25da1b649de 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td @@ -79,4 +79,71 @@ def Tosa_YieldOp : Tosa_Op<"yield", [ let assemblyFormat = "$inputs attr-dict `:` type($inputs)"; } +//===----------------------------------------------------------------------===// +// Operator: variable +//===----------------------------------------------------------------------===// +def Tosa_VariableOp : Tosa_Op<"variable", []> { + let summary = "Defines a variable"; + + let description = [{ + Defines a new TOSA variable. This is a mutable value. + Modifications are expressed using read/write semantics. + }]; + + let arguments = (ins + SymbolNameAttr:$name, + TypeAttr:$type, + OptionalAttr:$initial_value + ); + + let assemblyFormat = [{ + $name + attr-dict + custom($type, $initial_value) + }]; +} + +//===----------------------------------------------------------------------===// +// Operator: variable.write +//===----------------------------------------------------------------------===// +def Tosa_VariableWriteOp : Tosa_Op<"variable.write", []> { + let summary = "write_buffer operator"; + + let description = [{ + Assigns a value to pseudo-buffer resource holding a mutable tensor. + }]; + + let arguments = (ins + SymbolNameAttr:$name, + AnyType:$value + ); + + let assemblyFormat = [{ + $name attr-dict `,` $value `:` type($value) + }]; +} + +//===----------------------------------------------------------------------===// +// Operator: variable.read +//===----------------------------------------------------------------------===// +def Tosa_VariableReadOp : Tosa_Op<"variable.read", []> { + let summary = "read_buffer operator"; + + let description = [{ + Reads the value from a pseudo-buffer resource holding a mutable tensor. + }]; + + let arguments = (ins + SymbolNameAttr:$name + ); + + let results = (outs + AnyType:$value + ); + + let assemblyFormat = [{ + $name attr-dict `:` type($value) + }]; +} + #endif // TOSA_UTIL_OPS diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h index 940aed107e2f91..fbfc56dfe2cf4f 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -68,9 +68,6 @@ struct ValidationOptions { } }; -std::unique_ptr createTosaValidationPass( - ValidationOptions const &options = ValidationOptions()); - #define GEN_PASS_REGISTRATION #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index ac100a6d75c7c0..a0f670de20150f 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -89,13 +89,12 @@ def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level", let cppNamespace = "mlir::tosa"; } -def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> { +def TosaValidation : Pass<"tosa-validate", "mlir::ModuleOp"> { let summary = "Validates TOSA dialect"; let description = [{ This pass validates if input TOSA operations match the specification for given criteria, e.g. TOSA profile. }]; - let constructor = "createTosaValidationPass()"; let options = [ Option<"profile", "profile", "mlir::tosa::TosaProfileEnum", diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index 718e34ced8d7e7..3c54f85b033b0b 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -76,7 +76,7 @@ std::unique_ptr mlir::tosa::createTosaToLinalg() { void mlir::tosa::addTosaToLinalgPasses( OpPassManager &pm, const TosaToLinalgOptions &options, - tosa::ValidationOptions const &validationOptions) { + tosa::TosaValidationOptions const &validationOptions) { // Optional decompositions are designed to benefit linalg. if (!options.disableTosaDecompositions) pm.addNestedPass(tosa::createTosaOptionalDecompositions()); @@ -90,7 +90,6 @@ void mlir::tosa::addTosaToLinalgPasses( pm.addNestedPass(tosa::createTosaLayerwiseConstantFoldPass( {options.aggressiveReduceConstant})); pm.addNestedPass(tosa::createTosaMakeBroadcastablePass()); - pm.addNestedPass( - tosa::createTosaValidationPass(validationOptions)); + pm.addNestedPass(tosa::createTosaValidation(validationOptions)); pm.addNestedPass(tosa::createTosaToLinalg()); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 6db04fe38bcd35..ff34183f9a030a 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -146,6 +146,49 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value, return nullptr; } +//===----------------------------------------------------------------------===// +// Parsers and printers +//===----------------------------------------------------------------------===// + +ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, + Attribute &attr) { + if (succeeded(parser.parseOptionalEqual())) { + if (failed(parser.parseAttribute(attr))) { + return parser.emitError(parser.getCurrentLocation()) + << "expected attribute"; + } + if (auto typedAttr = attr.dyn_cast()) { + typeAttr = TypeAttr::get(typedAttr.getType()); + } + return success(); + } + + Type type; + if (failed(parser.parseColonType(type))) { + return parser.emitError(parser.getCurrentLocation()) << "expected type"; + } + typeAttr = TypeAttr::get(type); + + return success(); +} + +void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, + Attribute attr) { + bool needsSpace = false; + auto typedAttr = attr.dyn_cast_or_null(); + if (!typedAttr || typedAttr.getType() != type.getValue()) { + p << ": "; + p.printAttribute(type); + needsSpace = true; // subsequent attr value needs a space separator + } + if (attr) { + if (needsSpace) + p << ' '; + p << "= "; + p.printAttribute(attr); + } +} + //===----------------------------------------------------------------------===// // TOSA Operator Verifiers. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 52885e69c3924f..d686ce125c1351 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -14,6 +14,9 @@ #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc" +#include +#include + #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/Builders.h" @@ -96,12 +99,13 @@ static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0}; struct TosaValidation : public tosa::impl::TosaValidationBase { public: explicit TosaValidation() { populateConstantOperandChecks(); } - explicit TosaValidation(const ValidationOptions &options) : TosaValidation() { + explicit TosaValidation(const TosaValidationOptions &options) + : TosaValidation() { this->profile = options.profile; - this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment; + this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment; this->level = options.level; } - void runOnOperation() override; + void runOnOperation() final; LogicalResult applyConstantOperandCheck(Operation *op) { for (auto &checker : const_checkers) { @@ -113,6 +117,9 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { LogicalResult applyLevelCheck(Operation *op); + // check variable read/write data types against variable declarations + LogicalResult applyVariableCheck(Operation *op); + private: void populateConstantOperandChecks() { const_checkers.emplace_back(checkConstantOperandPad); @@ -398,8 +405,12 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { } } + bool CheckVariable(Operation *op); + bool CheckVariableReadOrWrite(Operation *op); + SmallVector> const_checkers; tosa_level_t tosa_level; + DenseMap variables_map; }; LogicalResult TosaValidation::applyLevelCheck(Operation *op) { @@ -427,6 +438,69 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) { return success(); } +inline bool CompatibleTypes(const mlir::Type &type, + const mlir::Type &declared_type) { + // for now, simply use type equality comparison + return type == declared_type; +} + +bool TosaValidation::CheckVariable(Operation *op) { + if (isa(op)) { + auto name_attr = cast(op->getAttr("name")); + + if (variables_map.count(&name_attr)) { + op->emitOpError() << "name has already been declared"; + return false; + } + + auto type_attr = cast(op->getAttr("type")); + mlir::Type type = type_attr.getValue(); + + variables_map[&name_attr] = type; + } + + return true; +} + +bool TosaValidation::CheckVariableReadOrWrite(Operation *op) { + if (isa(op) || + isa(op)) { + auto name_attr = cast(op->getAttr("name")); + + if (!variables_map.count(&name_attr)) { + op->emitOpError() << "name has not been declared"; + return false; + } + + auto var_type = variables_map[&name_attr]; + + for (auto v : op->getOperands()) { + auto type = v.getType(); + if (!CompatibleTypes(type, var_type)) { + op->emitOpError() << "operand type does not equal variable type"; + return false; + } + } + + for (auto v : op->getResults()) { + auto type = v.getType(); + if (!CompatibleTypes(type, var_type)) { + op->emitOpError() << "result type does not equal variable type"; + return false; + } + } + } + + return true; +} + +LogicalResult TosaValidation::applyVariableCheck(Operation *op) { + if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) { + return failure(); + } + return success(); +} + void TosaValidation::runOnOperation() { configLevelAndProfile(); getOperation().walk([&](Operation *op) { @@ -440,18 +514,18 @@ void TosaValidation::runOnOperation() { } } - // Some uses of TOSA rely on the constant operands of particular operations. + // Some uses of TOSA rely on the constant operands of particular + // operations. if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op))) signalPassFailure(); // do level checks if (failed(applyLevelCheck(op))) signalPassFailure(); + + // do variable type checks + if (failed(applyVariableCheck(op))) + signalPassFailure(); }); } } // namespace - -std::unique_ptr -mlir::tosa::createTosaValidationPass(ValidationOptions const &options) { - return std::make_unique(options); -} diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 7c58bb10b9c5ed..9233662e88db90 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -203,3 +203,48 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor< : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> } + +// ----- + +func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi32>) -> () { + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> + // expected-error@+1 {{'tosa.variable' op name has already been declared}} + tosa.variable @stored_var : tensor<1x4x8xi32> + return +} + +// ----- + +func.func @test_variable_read_type(%arg0: tensor<2x4x8xi32>) -> () { + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> + // expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}} + %0 = tosa.variable.read @stored_var : tensor<2x4x8xi16> + return +} + +// ----- + +func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi32>) -> () { + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> + // expected-error@+1 {{'tosa.variable.read' op result type does not equal variable type}} + %0 = tosa.variable.read @stored_var : tensor<1x4x8xi32> + return +} + +// ----- + +func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () { + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> + // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}} + tosa.variable.write @stored_var, %arg0 : tensor<2x4x8xi16> + return +} + +// ----- + +func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi32>) -> () { + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> + // expected-error@+1 {{'tosa.variable.write' op operand type does not equal variable type}} + tosa.variable.write @stored_var, %arg0 : tensor<1x4x8xi32> + return +} diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir new file mode 100644 index 00000000000000..9a26aa0bc8bf4d --- /dev/null +++ b/mlir/test/Dialect/Tosa/variables.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s + + +// ----- +// CHECK-LABEL: @test_variable_scalar( +// CHECK-SAME: %[[ADD_VAL:.*]]: tensor) { +func.func @test_variable_scalar(%arg0: tensor) -> () { + // CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor + tosa.variable @stored_var = dense<3.14> : tensor + // CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor + %0 = tosa.variable.read @stored_var : tensor + // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor, tensor) -> tensor + %1 = "tosa.add"(%arg0, %0) : (tensor, tensor) -> tensor + // CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor + tosa.variable.write @stored_var, %1 : tensor + return +} + +// ----- +// CHECK-LABEL: @test_variable_tensor( +// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) { +func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () { + // CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> + tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> + // CHECK: %[[STORED_VAL:.*]] = tosa.variable.read @stored_var : tensor<2x4x8xi32> + %0 = tosa.variable.read @stored_var : tensor<2x4x8xi32> + // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> + %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> + // CHECK: tosa.variable.write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32> + tosa.variable.write @stored_var, %1 : tensor<2x4x8xi32> + return +}