diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 0aef4653b74ff..e048f8af7cc33 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -197,6 +197,16 @@ def Tosa_PadOpQuantInfoBuilder : OpBuilder< input, paddings); }]>; +// This builder is called on the TOSA variable operator with a variable type +// and optional initial value. The builder will extract var_shape and element type +// attributes from variable type. +def Tosa_VariableOpBuilder : OpBuilder< + (ins "StringRef":$name, "Type":$variable_type, "Attribute":$initial_value), + [{ + buildVariableOp($_builder, $_state, name, variable_type, initial_value); + }]>; + + // Wrapper over base I32EnumAttr to set common fields. class Tosa_I32Enum cases> : I32EnumAttr { diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h index 6fa4aedc1f0b0..a15f073bc5fcb 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -44,10 +44,14 @@ class PatternRewriter; namespace tosa { -ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, - Attribute &attr); -void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, - Attribute attr); +ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, + DenseElementsAttr &varShapeAttr, + TypeAttr &typeAttr, + Attribute &initialValueAttr); +void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, + DenseElementsAttr varShapeAttr, + TypeAttr typeAttr, + Attribute initialValueAttr); #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc" @@ -172,6 +176,9 @@ std::optional createZeroPointTensor(OpBuilder &builder, Location loc, Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val = 0); +// returns type of variable op +RankedTensorType getVariableType(VariableOp variableOp); + } // namespace tosa } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td index 5f99162907949..c8f2907f8dd1b 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td @@ -92,6 +92,7 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> { let arguments = (ins SymbolNameAttr:$name, + IndexElementsAttr:$var_shape, TypeAttr:$type, OptionalAttr:$initial_value ); @@ -101,12 +102,16 @@ def Tosa_VariableOp : Tosa_Op<"variable", []> { Extension<[Tosa_EXT_VARIABLE]>, ]; + let hasCustomAssemblyFormat = 1; + let assemblyFormat = [{ $name attr-dict - custom($type, $initial_value) + custom($var_shape, $type, $initial_value) }]; + let builders = [Tosa_VariableOpBuilder]; + let hasVerifier = 1; } diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp index 310566e692202..7dbccd19a0518 100644 --- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp +++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp @@ -26,8 +26,9 @@ class VariableOpConverter : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::VariableOp op, PatternRewriter &rewriter) const final { + auto variableType = tosa::getVariableType(op); auto newVariable = rewriter.create( - op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true, + op.getLoc(), op.getName(), variableType, /*is_mutable=*/true, op.getInitialValueAttr(), /*sym_visibility=*/nullptr); newVariable.setPrivate(); rewriter.replaceOp(op, newVariable); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 93a6a8be48df7..a22e6b7aa9791 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -131,6 +131,24 @@ SmallVector tosa::WhileOp::getLoopRegions() { return {&getBodyGraph()}; } +//===----------------------------------------------------------------------===// +// TOSA variable operator support. +//===----------------------------------------------------------------------===// + +static SmallVector convertToMlirShape(ArrayRef shape) { + return to_vector(llvm::map_range(shape, [](int64_t dim) { + return dim == -1 ? ShapedType::kDynamic : dim; + })); +} + +// returns type of variable op +RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) { + Type elementType = variableOp.getType(); + DenseIntElementsAttr varShapeAttr = variableOp.getVarShape(); + auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues())); + return RankedTensorType::get(shape, elementType); +} + //===----------------------------------------------------------------------===// // Tosa dialect initialization. //===----------------------------------------------------------------------===// @@ -177,42 +195,80 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value, // Parsers and printers //===----------------------------------------------------------------------===// -ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, - Attribute &attr) { +namespace { + +ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType, + DenseElementsAttr &varShapeAttr, + TypeAttr &typeAttr) { + if (auto shapedType = dyn_cast(parsedType)) { + if (!shapedType.hasRank()) + return parser.emitError(parser.getCurrentLocation()) + << "expected ranked type"; + + auto elementType = shapedType.getElementType(); + typeAttr = TypeAttr::get(elementType); + ArrayRef shape = shapedType.getShape(); + Builder builder(parser.getContext()); + varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape)); + return success(); + } + return parser.emitError(parser.getCurrentLocation()) + << "expected shaped type"; +} + +} // namespace + +// parses the optional initial value or type for a tosa variable +// with initial value: +// tosa.variable @name = dense<0.0> : tensor<1x8xf32> +// +// without initial value: +// tosa.variable @name : tensor<1x8xf32> +ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue( + OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, + Attribute &initialValueAttr) { if (succeeded(parser.parseOptionalEqual())) { - if (failed(parser.parseAttribute(attr))) { + if (failed(parser.parseAttribute(initialValueAttr))) { return parser.emitError(parser.getCurrentLocation()) << "expected attribute"; } - if (auto typedAttr = dyn_cast(attr)) { - typeAttr = TypeAttr::get(typedAttr.getType()); + if (auto typedAttr = dyn_cast(initialValueAttr)) { + return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr, + typeAttr); } - return success(); + return parser.emitError(parser.getCurrentLocation()) + << "expected Typed attr"; } - Type type; - if (failed(parser.parseColonType(type))) { - return parser.emitError(parser.getCurrentLocation()) << "expected type"; + initialValueAttr = nullptr; + Type parsedType; + if (failed(parser.parseColonType(parsedType))) { + return parser.emitError(parser.getCurrentLocation()) + << "expected type after colon"; } - typeAttr = TypeAttr::get(type); - - return success(); + return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr); } -void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, - Attribute attr) { +void mlir::tosa::printVariableOpTypeOrInitialValue( + OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, + TypeAttr typeAttr, Attribute initialValueAttr) { bool needsSpace = false; - auto typedAttr = dyn_cast_or_null(attr); - if (!typedAttr || typedAttr.getType() != type.getValue()) { + if (!dyn_cast_or_null(initialValueAttr)) { + auto shape = + convertToMlirShape(to_vector(varShapeAttr.getValues())); + Type elementType = typeAttr.getValue(); + RankedTensorType tensorType = + RankedTensorType::get(ArrayRef(shape), elementType); + auto tensorTypeAttr = TypeAttr::get(tensorType); p << ": "; - p.printAttribute(type); + p.printAttribute(tensorTypeAttr); needsSpace = true; // subsequent attr value needs a space separator } - if (attr) { + if (initialValueAttr) { if (needsSpace) p << ' '; p << "= "; - p.printAttribute(attr); + p.printAttribute(initialValueAttr); } } @@ -657,8 +713,9 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) { << symName << "' has not been declared by 'tosa.variable'"; // Verify type and shape - Type varType = cast(varOp.value()).getType(); - if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor") + auto variableType = getVariableType(varOp.value()); + if (errorIfTypeOrShapeMismatch(op, type, name, variableType, + "the input tensor") .failed()) return failure(); @@ -1103,6 +1160,33 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, result.types.push_back(outputType); } +static void buildVariableOp(OpBuilder &builder, OperationState &result, + StringRef name, Type variableType, + Attribute initialValue) { + const Location loc{result.location}; + auto nameAttr = builder.getStringAttr(name); + + auto shapedType = dyn_cast(variableType); + if (!shapedType) { + (void)emitError(loc, "variable type must be a shaped type"); + return; + } + if (!shapedType.hasRank()) { + (void)emitError(loc, "variable type must be a ranked type"); + return; + } + + auto elementType = shapedType.getElementType(); + auto elementTypeAttr = TypeAttr::get(elementType); + ArrayRef shape = shapedType.getShape(); + auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape)); + + result.addAttribute("name", nameAttr); + result.addAttribute("var_shape", varShapeAttr); + result.addAttribute("type", elementTypeAttr); + result.addAttribute("initial_value", initialValue); +} + //===----------------------------------------------------------------------===// // TOSA Operator Return Type Inference. //===----------------------------------------------------------------------===// @@ -1676,12 +1760,6 @@ LogicalResult tosa::PadOp::verify() { return success(); } -static SmallVector convertToMlirShape(ArrayRef shape) { - return to_vector(llvm::map_range(shape, [](int64_t dim) { - return dim == -1 ? ShapedType::kDynamic : dim; - })); -} - LogicalResult tosa::SliceOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, SliceOp::Adaptor adaptor, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 1a896c1464e1c..de08e7e9a4394 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -215,15 +215,8 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) { template <> LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::VariableOp op) { - ::mlir::Attribute attr = op.getInitialValueAttr(); - if (attr == nullptr) - return failure(); - - if (auto typedAttr = dyn_cast(attr)) { - addType(getElementTypeOrSelf(typedAttr)); - return success(); - } - return failure(); + addType(op.getType()); + return success(); } template <> diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index a3ee76bf7026c..d33fc902de3a1 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -238,10 +238,10 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { return true; } - template - bool levelCheckRank(Operation *op, const T &v, + // Perform the Level Rank check on the tensor type. + bool levelCheckRank(Operation *op, const Type typeToCheck, const StringRef operandOrResult, int32_t highest_rank) { - if (ShapedType type = dyn_cast(v.getType())) { + if (ShapedType type = dyn_cast(typeToCheck)) { if (!type.hasRank()) { op->emitOpError() << "failed level check: unranked tensor"; return false; @@ -255,10 +255,22 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { return true; } - // Perform the Level tensor size check on the input tensor. - bool levelCheckSize(Operation *op, const Value &v, + // Perform the Level Rank check on the tensor value. + bool levelCheckRank(Operation *op, const Value &v, + const StringRef operandOrResult, int32_t highest_rank) { + return levelCheckRank(op, v.getType(), operandOrResult, highest_rank); + } + + // Perform the Level tensor size check on the tensor type. + bool levelCheckSize(Operation *op, const Type &typeToCheck, const StringRef operandOrResult); + // Perform the Level tensor size check on the tensor value. + bool levelCheckSize(Operation *op, const Value &v, + const StringRef operandOrResult) { + return levelCheckSize(op, v.getType(), operandOrResult); + } + // Level check sizes of all operands and results of the operation. template bool levelCheckSizes(T tosaOp) { @@ -284,15 +296,6 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { return false; } - if (!op->getAttrs().empty()) { - for (NamedAttribute attr : op->getAttrs()) { - if (auto elemAttr = dyn_cast(attr.getValue())) { - if (!levelCheckRank(op, elemAttr, "attribute", tosaLevel.MAX_RANK)) - return false; - } - } - } - for (auto v : op->getResults()) { if (!levelCheckRank(op, v, "result", tosaLevel.MAX_RANK)) return false; @@ -596,6 +599,26 @@ bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) { return true; } +template <> +bool TosaValidation::levelCheckRanks(tosa::VariableOp tosaOp) { + auto op = tosaOp.getOperation(); + auto variableType = getVariableType(tosaOp); + if (!levelCheckRank(op, variableType, "variable type", tosaLevel.MAX_RANK)) + return false; + + return true; +} + +template <> +bool TosaValidation::levelCheckSizes(tosa::VariableOp tosaOp) { + auto op = tosaOp.getOperation(); + auto variableType = getVariableType(tosaOp); + if (!levelCheckSize(op, variableType, "variable type")) + return false; + + return true; +} + bool TosaValidation::levelCheckRanksAndSizes(Operation *op) { #define CHECK_RANKS_AND_SIZES(tosaOp) \ if (isa(op)) { \ @@ -714,10 +737,10 @@ bool TosaValidation::levelCheckRanksAndSizes(Operation *op) { return true; } -// Perform the Level tensor size check -bool TosaValidation::levelCheckSize(Operation *op, const Value &v, +// Perform the Level tensor size check on the tensor type. +bool TosaValidation::levelCheckSize(Operation *op, const Type &typeToCheck, const StringRef operandOrResult) { - if (ShapedType type = dyn_cast(v.getType())) { + if (ShapedType type = dyn_cast(typeToCheck)) { if (!type.hasRank()) { op->emitOpError() << "failed level check: unranked tensor"; return false; @@ -800,18 +823,21 @@ inline bool CompatibleTypes(const mlir::Type &type, } bool TosaValidation::CheckVariable(Operation *op) { - if (isa(op)) { - mlir::StringAttr nameAttr = cast(op->getAttr("name")); + if (auto variableOp = dyn_cast(op)) { + mlir::StringAttr nameAttr = variableOp.getNameAttr(); if (variablesMap.count(nameAttr)) { op->emitOpError() << "name has already been declared"; return false; } - auto typeAttr = cast(op->getAttr("type")); - mlir::Type type = typeAttr.getValue(); + auto elementType = variableOp.getType(); + DenseIntElementsAttr varShapeAttr = variableOp.getVarShape(); + SmallVector shape = to_vector(varShapeAttr.getValues()); + RankedTensorType variableType = + RankedTensorType::get(ArrayRef(shape), elementType); - variablesMap[nameAttr] = type; + variablesMap[nameAttr] = variableType; } return true; diff --git a/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir index 365b05ff084da..d2092753f1f58 100644 --- a/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir +++ b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --tosa-to-mlprogram %s -o -| FileCheck %s +// RUN: mlir-opt --tosa-to-mlprogram %s -split-input-file -o -| FileCheck %s module { // CHECK: ml_program.global private mutable @var_x(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32> @@ -10,4 +10,18 @@ module { %0 = tosa.variable_read @var_x : tensor<1xf32> return %0 : tensor<1xf32> } +} + +// ----- + +module { + // CHECK: ml_program.global private mutable @var_x : tensor + tosa.variable @var_x : tensor + func.func @test_stateful_ops(%arg0: tensor) -> (tensor) { + // CHECK: ml_program.global_store @var_x = %arg0 : tensor + tosa.variable_write @var_x, %arg0 : tensor + // CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor + %0 = tosa.variable_read @var_x : tensor + return %0 : tensor + } } \ No newline at end of file diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index c41f079ec526c..05505c3671674 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -564,6 +564,23 @@ func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: ten // ----- +func.func @test_variable_unranked(%arg0: tensor<2x4x8xi8>) -> () { + tosa.variable @stored_var : tensor<*xi8> + // expected-error@+1 {{custom op 'tosa.variable' expected ranked type}} + return +} + +// ----- + +func.func @test_variable_unranked_initial_value(%arg0: tensor<2x4x8xi8>) -> () { + // expected-error@+1 {{elements literal type must have static shape}} + tosa.variable @stored_var = dense<0> : tensor<*xi8> + // expected-error@+1 {{custom op 'tosa.variable' expected attribute}} + return +} + +// ----- + func.func @test_variable_duplicates(%arg0: tensor<2x4x8xi8>) -> () { tosa.variable @stored_var = dense<-1> : tensor<2x4x8xi8> // expected-error@+1 {{'tosa.variable' op illegal to have multiple declaration of 'stored_var'}} diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index e7d0a0e1fa4ea..223bf3b635e18 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -443,7 +443,7 @@ func.func @test_rescale_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi8>) -> tenso // ----- func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> { - // expected-error@+1 {{'tosa.const' op failed level check: attribute rank(shape) <= MAX_RANK}} + // expected-error@+1 {{'tosa.const' op failed level check: result rank(shape) <= MAX_RANK}} %0 = "tosa.const"() {values = dense<0> : tensor<1x1x1x1x1x1x1xi32>} : () -> tensor<1x1x1x1x1x1x1xi32> return %0: tensor<1x1x1x1x1x1x1xi32> } @@ -1089,7 +1089,8 @@ func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, % // ----- func.func @test_variable_read_write_tensor_size_invalid() -> () { - tosa.variable @stored_var = dense<3.14> : tensor<536870912xf32> + // expected-error@+1 {{'tosa.variable' op failed level check: variable type tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} + tosa.variable @stored_var : tensor<536870912xf32> // expected-error@+1 {{'tosa.variable_read' op failed level check: result tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} %0 = tosa.variable_read @stored_var : tensor<536870912xf32> // expected-error@+1 {{'tosa.variable_write' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} @@ -1156,8 +1157,8 @@ func.func @test_cond_if_rank_invalid(%arg0: tensor<1x1x1x1x1x1x1x1xf32>, %arg1: // ----- func.func @test_variable_read_write_rank_invalid() -> () { - // expected-error@+1 {{'tosa.variable' op failed level check: attribute rank(shape) <= MAX_RANK}} - tosa.variable @stored_var = dense<3.14> : tensor<1x1x1x1x1x1x1x1xf32> + // expected-error@+1 {{'tosa.variable' op failed level check: variable type rank(shape) <= MAX_RANK}} + tosa.variable @stored_var : tensor<1x1x1x1x1x1x1x1xf32> // expected-error@+1 {{'tosa.variable_read' op failed level check: result rank(shape) <= MAX_RANK}} %0 = tosa.variable_read @stored_var : tensor<1x1x1x1x1x1x1x1xf32> // expected-error@+1 {{'tosa.variable_write' op failed level check: operand rank(shape) <= MAX_RANK}} diff --git a/mlir/test/Dialect/Tosa/variables.mlir b/mlir/test/Dialect/Tosa/variables.mlir index 25f63331f39df..9953eb375d3ac 100644 --- a/mlir/test/Dialect/Tosa/variables.mlir +++ b/mlir/test/Dialect/Tosa/variables.mlir @@ -31,3 +31,48 @@ func.func @test_variable_tensor(%arg0: tensor<2x4x8xi32>) -> () { tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32> return } + +// ----- +// CHECK-LABEL: @test_variable_scalar_no_initial_value( +// CHECK-SAME: %[[ADD_VAL:.*]]: tensor) { +func.func @test_variable_scalar_no_initial_value(%arg0: tensor) -> () { + // CHECK: tosa.variable @stored_var : tensor + tosa.variable @stored_var : tensor + // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor + %0 = tosa.variable_read @stored_var : tensor + // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor, tensor) -> tensor + %1 = "tosa.add"(%arg0, %0) : (tensor, tensor) -> tensor + // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor + tosa.variable_write @stored_var, %1 : tensor + return +} + +// ----- +// CHECK-LABEL: @test_variable_tensor_no_initial_value( +// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) { +func.func @test_variable_tensor_no_initial_value(%arg0: tensor<2x4x8xi32>) -> () { + // CHECK: tosa.variable @stored_var : tensor<2x4x8xi32> + tosa.variable @stored_var : tensor<2x4x8xi32> + // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32> + %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32> + // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> + %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> + // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32> + tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32> + return +} + +// ----- +// CHECK-LABEL: @test_variable_tensor_with_unknowns( +// CHECK-SAME: %[[ADD_VAL:.*]]: tensor<2x4x8xi32>) { +func.func @test_variable_tensor_with_unknowns(%arg0: tensor<2x4x8xi32>) -> () { + // CHECK: tosa.variable @stored_var : tensor<2x?x8xi32> + tosa.variable @stored_var : tensor<2x?x8xi32> + // CHECK: %[[STORED_VAL:.*]] = tosa.variable_read @stored_var : tensor<2x4x8xi32> + %0 = tosa.variable_read @stored_var : tensor<2x4x8xi32> + // CHECK: %[[RESULT_ADD:.*]] = tosa.add %[[ADD_VAL]], %[[STORED_VAL]] : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> + %1 = "tosa.add"(%arg0, %0) : (tensor<2x4x8xi32>, tensor<2x4x8xi32>) -> tensor<2x4x8xi32> + // CHECK: tosa.variable_write @stored_var, %[[RESULT_ADD]] : tensor<2x4x8xi32> + tosa.variable_write @stored_var, %1 : tensor<2x4x8xi32> + return +}