diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 115a11b346780..80337fc30bc66 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -201,9 +201,9 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder< // and optional initial value. The builder will extract var_shape and element type // attributes from variable type. def Tosa_VariableOpBuilder : OpBuilder< - (ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value), + (ins "StringRef":$sym_name, "Type":$variable_type, "Attribute":$initial_value), [{ - buildVariableOp($_builder, $_state, name, variable_type, initial_value); + buildVariableOp($_builder, $_state, sym_name, variable_type, initial_value); }]>; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td index d819cc198e3f2..f1a618e75061b 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td @@ -18,6 +18,7 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/VectorInterfaces.td" include "mlir/Dialect/Tosa/IR/TosaInterfaces.td" @@ -82,7 +83,7 @@ def Tosa_YieldOp : Tosa_Op<"yield", [ //===----------------------------------------------------------------------===// // Operator: variable //===----------------------------------------------------------------------===// -def Tosa_VariableOp : Tosa_Op<"variable", []> { +def Tosa_VariableOp : Tosa_Op<"variable", [Symbol]> { let summary = "Defines a variable"; let description = [{ @@ -91,7 +92,10 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> { }]; let arguments = (ins - SymbolNameAttr:$name, + // Note: "sym_name" is used as opposed to "name" in the specification, + // since a Symbol must be named "sym_name" for it to be recognised by + // the containing SymbolTable. + SymbolNameAttr:$sym_name, IndexElementsAttr:$var_shape, TypeAttr:$type, OptionalAttr:$initial_value @@ -105,14 +109,18 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> { let hasCustomAssemblyFormat = 1; let assemblyFormat = [{ - $name + $sym_name attr-dict custom($var_shape, $type, $initial_value) }]; let builders = [Tosa_VariableOpBuilder]; - let hasVerifier = 1; + let extraClassDeclaration = [{ + ::llvm::StringRef getName() { + return getSymName(); + } + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 332f1a0e5506f..c51b5e9cbfc78 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -905,56 +905,29 @@ static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) { return shapeAdaptor.getNumElements() == 1 ? success() : failure(); } -// Returns the first declaration point prior to this operation or failure if -// not found. -static FailureOr findVariableDecl(Operation *op, - StringRef symName) { - ModuleOp module = op->getParentOfType(); - tosa::VariableOp varOp = nullptr; - - // TODO: Adopt SymbolTable trait to Varible ops. - // Currently, the variable's definition point is searched via walk(), - // starting from the top-level ModuleOp and stopping at the point of use. Once - // TOSA control flow and variable extensions reach the complete state, may - // leverage MLIR's Symbol Table functionality to look up symbol and enhance - // the search to a TOSA specific graph traversal over the IR structure. - module.walk([&](Operation *tempOp) { - // Reach this op itself. - if (tempOp == op) { - return WalkResult::interrupt(); - } - - if (auto tosaOp = dyn_cast(tempOp)) { - if (symName == tosaOp.getName()) { - varOp = tosaOp; - return WalkResult::interrupt(); - } - } - - return WalkResult::advance(); - }); - - if (varOp) - return varOp; - - return failure(); -} - template static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) { - StringRef symName = op.getName(); - FailureOr varOp = findVariableDecl(op, symName); - if (failed(varOp)) + Operation *symTableOp = + op->template getParentWithTrait(); + if (!symTableOp) + // If the operation is not the scope of a symbol table, we cannot + // verify it against it's declaration. + return success(); + + SymbolTable symTable(symTableOp); + const auto varOp = symTable.lookup(op.getName()); + + // Verify prior declaration + if (!varOp) return op->emitOpError("'") - << symName << "' has not been declared by 'tosa.variable'"; + << op.getName() << "' has not been declared by 'tosa.variable'"; // Verify type and shape - auto variableType = getVariableType(varOp.value()); + auto variableType = getVariableType(varOp); if (errorIfTypeOrShapeMismatch(op, type, name, variableType, "the input tensor") .failed()) return failure(); - return success(); } @@ -1418,7 +1391,7 @@ static void buildVariableOp(OpBuilder &builder, OperationState &result, ArrayRef shape = shapedType.getShape(); auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape)); - result.addAttribute("name", nameAttr); + result.addAttribute("sym_name", nameAttr); result.addAttribute("var_shape", varShapeAttr); result.addAttribute("type", elementTypeAttr); result.addAttribute("initial_value", initialValue); @@ -4160,16 +4133,6 @@ LogicalResult tosa::SelectOp::verify() { return success(); } -LogicalResult tosa::VariableOp::verify() { - StringRef symName = getName(); - FailureOr varOp = findVariableDecl(*this, symName); - if (succeeded(varOp)) - return emitOpError("illegal to have multiple declaration of '") - << symName << "'"; - - return success(); -} - LogicalResult tosa::VariableReadOp::verify() { if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'") .failed()) diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 41c3243792259..e60f1c9b4a01a 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -573,64 +573,61 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten // ----- -func.func @test_variable_unranked(%arg0: tensor<2x4x8xi8>) -> () { +module { tosa.variable @stored_var : tensor<*xi8> // expected-error@+1 {{custom op 'tosa.variable' expected ranked type}} - return } // ----- -func.func @test_variable_unranked_initial_value(%arg0: tensor<2x4x8xi8>) -> () { +module { // expected-error@+1 {{elements literal type must have static shape}} tosa.variable @stored_var = dense<0> : tensor<*xi8> // expected-error@+1 {{custom op 'tosa.variable' expected attribute}} - return -} - -// ----- - -func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () { - tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> - // expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}} - tosa.variable @stored_var = dense<3> : tensor<1x4x8xi8> - return } // ----- -func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () { +module { tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> - // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}} - %0 = tosa.variable_read @stored_var : tensor<2x4x8xi16> - return + func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () { + // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i16') and the input tensor ('i8')}} + %0 = tosa.variable_read @stored_var : tensor<2x4x8xi16> + return + } } // ----- -func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () { +module { tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> - // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}} - %0 = tosa.variable_read @stored_var : tensor<1x4x8xi32> - return + func.func @test_variable_read_shape(%arg0: tensor<2x4x8xi8>) -> () { + // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('i8'}} + %0 = tosa.variable_read @stored_var : tensor<1x4x8xi32> + return + } } // ----- -func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () { +module { tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> - // expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}} - tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16> - return + func.func @test_variable_write_type(%arg0: tensor<2x4x8xi16>) -> () { + // expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i16') and the input tensor ('i8')}} + tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi16> + return + } } // ----- -func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () { +module { tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> - // expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}} - tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8> - return + func.func @test_variable_write_shape(%arg0: tensor<1x4x8xi8>) -> () { + // expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<1x4x8xi8>') and the input tensor ('tensor<2x4x8xi8>')}} + tosa.variable_write @stored_var, %arg0 : tensor<1x4x8xi8> + return + } } // ----- diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 3138ce2621a3a..1daabe9222a9b 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -310,21 +310,27 @@ func.func @test_identity(%arg0: tensor<13x21x3xi4>) -> tensor<13x21x3xi4> { } // ----- -func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () { +module { // expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}} tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> - // expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}} - %0 = tosa.variable_read @stored_var : tensor<2x4x8xi8> - return + + func.func @test_variable_read_type(%arg0: tensor<2x4x8xi8>) -> () { + // expected-error@+1 {{'tosa.variable_read' op illegal: requires [variable]}} + %0 = tosa.variable_read @stored_var : tensor<2x4x8xi8> + return + } } // ----- -func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () { +module { // expected-error@+1 {{'tosa.variable' op illegal: requires [variable] but not enabled in target}} tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> - // expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}} - tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8> - return + + func.func @test_variable_write_type(%arg0: tensor<2x4x8xi8>) -> () { + // expected-error@+1 {{'tosa.variable_write' op illegal: requires [variable]}} + tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi8> + return + } } // ----- diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 3742adf650408..5bf2dbb8d02b1 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1097,14 +1097,17 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, % // ----- -func.func @test_variable_read_write_tensor_size_invalid() -> () { +module { // expected-error@+1 {{'tosa.variable' op failed level check: variable type tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} tosa.variable @stored_var : tensor<536870912xf32> - // expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} - %0 = tosa.variable_read @stored_var : tensor<536870912xf32> - // expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} - tosa.variable_write @stored_var, %0 : tensor<536870912xf32> - return + + func.func @test_variable_read_write_tensor_size_invalid() -> () { + // expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} + %0 = tosa.variable_read @stored_var : tensor<536870912xf32> + // expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} + tosa.variable_write @stored_var, %0 : tensor<536870912xf32> + return + } } // ----- @@ -1165,14 +1168,17 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1: // ----- -func.func @test_variable_read_write_rank_invalid() -> () { +module { // expected-error@+1 {{'tosa.variable' op failed level check: variable type rank(shape) <= MAX_RANK}} tosa.variable @stored_var : tensor<1x1x1x1x1x1x1x1xf32> - // expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}} - %0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32> - // expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}} - tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32> - return + + func.func @test_variable_read_write_rank_invalid() -> () { + // expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}} + %0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32> + // expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}} + tosa.variable_write @stored_var, %0 : tensor<1x1x1x1x1x1x1x1xf32> + return + } } // ----- diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir index 9953eb375d3ac..0c104e8e8d7ea 100644 --- a/mlir/test/Dialect/Tosa/variables.mlir +++ b/mlir/test/Dialect/Tosa/variables.mlir @@ -3,76 +3,98 @@ // ----- -// CHECK-LABEL: @test_variable_scalar( -// CHECK-SAME: %[[ADD_VAL:.*]]: tensor) { -func.func @test_variable_scalar(%arg0: tensor) -> () { - // CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor + +module { + // CHECK: tosa.variable @stored_var = dense<3.140000e+00> : tensor tosa.variable @stored_var = dense<3.14> : tensor - // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor - %0 = tosa.variable_read @stored_var : tensor - // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor, tensor) -> tensor - %1 = "tosa.add"(%arg0, %0) : (tensor, tensor) -> tensor - // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor - tosa.variable_write @stored_var, %1 : tensor - return + + // CHECK-LABEL: @test_variable_scalar( + // CHECK-SAME: %[[ADD_VAL:.*]]: tensor) { + func.func @test_variable_scalar(%arg0: tensor) -> () { + // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor + %0 = tosa.variable_read @stored_var : tensor + // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor, tensor) -> tensor + %1 = "tosa.add"(%arg0, %0) : (tensor, tensor) -> tensor + // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor + tosa.variable_write @stored_var, %1 : tensor + return + } } + // ----- -// CHECK-LABEL: @test_variable_tensor( -// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) { -func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () { - // CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> + +module { + // CHECK: tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> - // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32> - %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32> - // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> - %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> - // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32> - tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32> - return + + // CHECK-LABEL: @test_variable_tensor( + // CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) { + func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () { + // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32> + %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32> + // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> + %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> + // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32> + tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32> + return + } } // ----- -// CHECK-LABEL: @test_variable_scalar_no_initial_value( -// CHECK-SAME: %[[ADD_VAL:.*]]: tensor) { -func.func @test_variable_scalar_no_initial_value(%arg0: tensor) -> () { - // CHECK: tosa.variable @stored_var : tensor + +module { + // CHECK: tosa.variable @stored_var : tensor tosa.variable @stored_var : tensor - // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor - %0 = tosa.variable_read @stored_var : tensor - // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor, tensor) -> tensor - %1 = "tosa.add"(%arg0, %0) : (tensor, tensor) -> tensor - // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor - tosa.variable_write @stored_var, %1 : tensor - return + + // CHECK-LABEL: @test_variable_scalar_no_initial_value( + // CHECK-SAME: %[[ADD_VAL:.*]]: tensor) { + func.func @test_variable_scalar_no_initial_value(%arg0: tensor) -> () { + // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor + %0 = tosa.variable_read @stored_var : tensor + // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor, tensor) -> tensor + %1 = "tosa.add"(%arg0, %0) : (tensor, tensor) -> tensor + // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor + tosa.variable_write @stored_var, %1 : tensor + return + } } // ----- -// CHECK-LABEL: @test_variable_tensor_no_initial_value( -// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) { -func.func @test_variable_tensor_no_initial_value(%arg0: tensor<2x4x8xi32>) -> () { - // CHECK: tosa.variable @stored_var : tensor<2x4x8xi32> + +module { + // CHECK: tosa.variable @stored_var : tensor<2x4x8xi32> tosa.variable @stored_var : tensor<2x4x8xi32> - // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32> - %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32> - // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> - %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> - // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32> - tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32> - return + + // CHECK-LABEL: @test_variable_tensor_no_initial_value( + // CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) { + func.func @test_variable_tensor_no_initial_value(%arg0: tensor<2x4x8xi32>) -> () { + // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32> + %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32> + // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> + %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> + // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32> + tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32> + return + } } + // ----- -// CHECK-LABEL: @test_variable_tensor_with_unknowns( -// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) { -func.func @test_variable_tensor_with_unknowns(%arg0: tensor<2x4x8xi32>) -> () { - // CHECK: tosa.variable @stored_var : tensor<2x?x8xi32> + +module { + // CHECK: tosa.variable @stored_var : tensor<2x?x8xi32> tosa.variable @stored_var : tensor<2x?x8xi32> - // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32> - %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32> - // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> - %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> - // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32> - tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32> - return + + // CHECK-LABEL: @test_variable_tensor_with_unknowns( + // CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) { + func.func @test_variable_tensor_with_unknowns(%arg0: tensor<2x4x8xi32>) -> () { + // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32> + %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32> + // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> + %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> + // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32> + tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32> + return + } } diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index 0128da729136e..430b06ad16c39 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -944,29 +944,27 @@ func.func @test_while_loop_cond_output_not_bool(%arg0: tensor<10xi32>, %arg1: te // ----- -func.func @test_variable_multiple_declaration() -> () { +module { + // expected-note@below {{see existing symbol definition here}} tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi32> - // expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}} + // expected-error@+1 {{redefinition of symbol named 'stored_var'}} tosa.variable @stored_var = dense<-3> : tensor<2x4x8xi32> - return } // ----- -func.func @test_variable_shape_mismatch() -> () { +module { // expected-error@+1 {{inferred shape of elements literal ([2]) does not match type ([3])}} tosa.variable @stored_var = dense<[3.14, 2.14]> : tensor<3xf32> // expected-error@+1 {{custom op 'tosa.variable' expected attribute}} - return } // ----- -func.func @test_variable_type_mismatch() -> () { +module { // expected-error@+1 {{expected integer elements, but parsed floating-point}} tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xi32> // expected-error@+1 {{custom op 'tosa.variable' expected attribute}} - return } // ----- @@ -979,20 +977,26 @@ func.func @test_variable_read_no_declaration() -> () { // ----- -func.func @test_variable_read_type_mismatch() -> () { +module { tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32> - // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}} - %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32> - return + + func.func @test_variable_read_type_mismatch() -> () { + // expected-error@+1 {{'tosa.variable_read' op require same element type for 'output1' ('i32') and the input tensor ('f32')}} + %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32> + return + } } // ----- -func.func @test_variable_read_shape_mismatch() -> () { +module { tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32> - // expected-error@+1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}} - %0 = tosa.variable_read @stored_var : tensor<2x4x8xf32> - return + + func.func @test_variable_read_shape_mismatch() -> () { + // expected-error@+1 {{'tosa.variable_read' op require same shapes for 'output1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}} + %0 = tosa.variable_read @stored_var : tensor<2x4x8xf32> + return + } } // ----- @@ -1005,20 +1009,26 @@ func.func @test_variable_write_no_declaration(%arg0: tensor) -> () { // ----- -func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () { +module { tosa.variable @stored_var = dense<-1.2> : tensor<2x4x8xf32> - // expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}} - tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32> - return + + func.func @test_variable_write_type_mismatch(%arg0: tensor<2x4x8xi32>) -> () { + // expected-error@+1 {{'tosa.variable_write' op require same element type for 'input1' ('i32') and the input tensor ('f32')}} + tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xi32> + return + } } // ----- -func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () { +module { tosa.variable @stored_var = dense<-1.2> : tensor<8x4x2xf32> - // expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}} - tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32> - return + + func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () { + // expected-error@+1 {{'tosa.variable_write' op require same shapes for 'input1' ('tensor<2x4x8xf32>') and the input tensor ('tensor<8x4x2xf32>')}} + tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32> + return + } } // -----