From 62d75f687f81d44258a516ceb4760b3a100b96d1 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 18 Nov 2025 14:01:22 +0000 Subject: [PATCH] Revert "[mlir][tosa] Add a pass to narrow i64 to i32 (#165581)" This reverts commit c61c5d29334c7ff044ba46bff17e1f3d57e230a3. --- .../mlir/Dialect/Tosa/Transforms/Passes.td | 23 -- .../Dialect/Tosa/Transforms/CMakeLists.txt | 1 - .../Tosa/Transforms/TosaNarrowI64ToI32.cpp | 310 ------------------ .../tosa-narrow-i64-to-i32-aggressive.mlir | 81 ----- .../Dialect/Tosa/tosa-narrow-i64-to-i32.mlir | 162 --------- 5 files changed, 577 deletions(-) delete mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp delete mode 100644 mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir delete mode 100644 mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index 420e58192b8fd..14b00b04ccc18 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -166,27 +166,4 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> { ]; } -def TosaNarrowI64ToI32Pass : Pass<"tosa-narrow-i64-to-i32", "func::FuncOp"> { - let summary = "Narrow I64 TOSA operations to I32"; - let description = [{ - This pass narrows TOSA operations with 64-bit integer tensor types to - 32-bit integer tensor types. This can be useful for backends that do not - support the EXT-INT64 extension of TOSA. - }]; - - let options = [ - Option<"aggressiveRewrite", "aggressive-rewrite", "bool", "false", - "If enabled, all TOSA operations are rewritten, regardless or whether the narrowing" - "is safe. This option may lead to data loss if not used carefully.">, - Option<"convertFunctionBoundaries", "convert-function-boundaries", "bool", "false", - "If enabled, the pass will convert function I/O types as well. Otherwise casts will" - "be inserted at the I/O boundaries."> - ]; - - let dependentDialects = [ - "func::FuncDialect", - "tosa::TosaDialect", - ]; -} - #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 987ce4ed870c9..41b338d6e7189 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -12,7 +12,6 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaTypeConverters.cpp TosaProfileCompliance.cpp TosaValidation.cpp - TosaNarrowI64ToI32.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp deleted file mode 100644 index ddaf7d8a5e033..0000000000000 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaNarrowI64ToI32.cpp +++ /dev/null @@ -1,310 +0,0 @@ -//===- TosaNarrowI64ToI32.cpp ---------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This pass narrows TOSA operations with 64-bit integer tensor types to -// 32-bit integer tensor types. This can be useful for backends that do not -// support the EXT-INT64 extension of TOSA. The pass has two options: -// -// - aggressive-rewrite - If enabled, all TOSA operations are rewritten, -// regardless or whether the narrowing is safe. This option may lead to -// data loss if not used carefully. -// - convert-function-boundaries - If enabled, the pass will convert function -// I/O types as well. Otherwise casts will be inserted at the I/O -// boundaries. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Tosa/Transforms/Passes.h" - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Func/Transforms/FuncConversions.h" -#include "mlir/IR/Verifier.h" -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace tosa { -#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS -#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" -} // namespace tosa -} // namespace mlir - -using namespace mlir; -using namespace mlir::tosa; - -namespace { - -LogicalResult convertGenericOp(Operation *op, ValueRange operands, - ConversionPatternRewriter &rewriter, - const TypeConverter *typeConverter) { - // Convert types of results - SmallVector newResults; - if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults))) - return failure(); - - // Create a new operation state - OperationState state(op->getLoc(), op->getName().getStringRef(), operands, - newResults, {}, op->getSuccessors()); - - for (const NamedAttribute &namedAttribute : op->getAttrs()) { - const Attribute attribute = namedAttribute.getValue(); - - // Convert integer attribute type - if (const auto intAttr = dyn_cast(attribute)) { - const std::optional convertedAttribute = - typeConverter->convertTypeAttribute(intAttr.getType(), attribute); - state.addAttribute(namedAttribute.getName(), convertedAttribute.value()); - continue; - } - - if (const auto typeAttr = dyn_cast(attribute)) { - Type type = typeAttr.getValue(); - const std::optional convertedAttribute = - typeConverter->convertTypeAttribute(type, attribute); - if (!convertedAttribute) - return rewriter.notifyMatchFailure(op, - "Failed to convert type attribute."); - state.addAttribute(namedAttribute.getName(), convertedAttribute.value()); - continue; - } - - if (const auto denseElementsAttr = dyn_cast(attribute)) { - const Type type = denseElementsAttr.getType(); - const std::optional convertedAttribute = - typeConverter->convertTypeAttribute(type, denseElementsAttr); - if (!convertedAttribute) - return rewriter.notifyMatchFailure( - op, "Failed to convert dense elements attribute."); - state.addAttribute(namedAttribute.getName(), convertedAttribute.value()); - continue; - } - - state.addAttribute(namedAttribute.getName(), attribute); - } - - for (Region ®ion : op->getRegions()) { - Region *newRegion = state.addRegion(); - rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin()); - if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter))) - return failure(); - } - - Operation *newOp = rewriter.create(state); - rewriter.replaceOp(op, newOp->getResults()); - return success(); -} - -// =========================== -// Aggressive rewrite patterns -// =========================== - -class ConvertGenericOp : public ConversionPattern { -public: - ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context) - : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {} - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - if (!isa(op)) - return rewriter.notifyMatchFailure( - op, - "Support for operations other than TOSA has not been implemented."); - - return convertGenericOp(op, operands, rewriter, typeConverter); - } -}; - -// =============================== -// Bounds checked rewrite patterns -// =============================== - -class ConvertArgMaxOpWithBoundsChecking - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(tosa::ArgMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - // Output type can be narrowed based on the size of the axis dimension - const int32_t axis = op.getAxis(); - const auto inputType = dyn_cast(adaptor.getInput().getType()); - if (!inputType || !inputType.isStaticDim(axis)) - return rewriter.notifyMatchFailure( - op, "Requires a static axis dimension for bounds checking."); - const int64_t axisDim = inputType.getDimSize(axis); - if (axisDim >= std::numeric_limits::max()) - return rewriter.notifyMatchFailure( - op, "Axis dimension is too large to narrow safely."); - - const Type resultType = op.getOutput().getType(); - const Type newResultType = typeConverter->convertType(resultType); - rewriter.replaceOpWithNewOp(op, newResultType, - adaptor.getInput(), axis); - return success(); - } -}; - -class ConvertCastOpWithBoundsChecking - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(tosa::CastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - const auto inputType = dyn_cast(adaptor.getInput().getType()); - const auto resultType = dyn_cast(op.getResult().getType()); - if (!inputType || !resultType) - return failure(); - - const auto elementInputIntType = - dyn_cast(inputType.getElementType()); - const auto elementResultIntType = - dyn_cast(resultType.getElementType()); - if (elementInputIntType && elementResultIntType && - elementInputIntType.getWidth() > elementResultIntType.getWidth()) - return rewriter.notifyMatchFailure( - op, "Narrowing cast may lead to data loss."); - - rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(resultType), adaptor.getInput()); - return success(); - } -}; - -template -class ConvertTypedOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - return convertGenericOp(op, adaptor.getOperands(), rewriter, - this->getTypeConverter()); - } -}; - -struct TosaNarrowI64ToI32 - : public tosa::impl::TosaNarrowI64ToI32PassBase { -public: - explicit TosaNarrowI64ToI32() = default; - explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options) - : TosaNarrowI64ToI32() { - this->aggressiveRewrite = options.aggressiveRewrite; - this->convertFunctionBoundaries = options.convertFunctionBoundaries; - } - - void runOnOperation() override { - MLIRContext *context = &getContext(); - - TypeConverter typeConverter; - typeConverter.addConversion([](Type type) -> Type { return type; }); - typeConverter.addConversion([](IntegerType type) -> Type { - if (!type.isInteger(64)) - return type; - return IntegerType::get(type.getContext(), 32); - }); - typeConverter.addConversion( - [&typeConverter](RankedTensorType type) -> Type { - const Type elementType = type.getElementType(); - if (!elementType.isInteger(64)) - return type; - return RankedTensorType::get(type.getShape(), - typeConverter.convertType(elementType)); - }); - - const auto materializeCast = [](OpBuilder &builder, Type resultType, - ValueRange inputs, Location loc) -> Value { - if (inputs.size() != 1) - return Value(); - return tosa::CastOp::create(builder, loc, resultType, inputs.front()); - }; - typeConverter.addSourceMaterialization(materializeCast); - typeConverter.addTargetMaterialization(materializeCast); - - typeConverter.addTypeAttributeConversion( - [](IntegerType type, IntegerAttr attribute) -> Attribute { - const APInt value = attribute.getValue().truncSSat(32); - return IntegerAttr::get(IntegerType::get(type.getContext(), 32), - value); - }); - typeConverter.addTypeAttributeConversion( - [&typeConverter](ShapedType type, - DenseIntElementsAttr attr) -> Attribute { - const ShapedType newType = - cast(typeConverter.convertType(type)); - const auto oldElementType = cast(type.getElementType()); - const auto newElementType = - cast(newType.getElementType()); - if (oldElementType.getWidth() == newElementType.getWidth()) - return attr; - - DenseElementsAttr mapped = - attr.mapValues(newElementType, [&](const APInt &v) { - return v.truncSSat(newElementType.getWidth()); - }); - return mapped; - }); - - ConversionTarget target(*context); - target.addDynamicallyLegalDialect( - [&typeConverter](Operation *op) { - return typeConverter.isLegal(op->getResultTypes()) && - typeConverter.isLegal(op->getOperandTypes()); - }); - if (convertFunctionBoundaries) { - target.addDynamicallyLegalOp( - [&typeConverter](func::FuncOp op) { - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()); - }); - target.addDynamicallyLegalOp([](func::ReturnOp op) { - const FunctionType funcType = - op->getParentOfType().getFunctionType(); - return llvm::equal(op.getOperandTypes(), funcType.getResults()); - }); - } else { - target.addDynamicallyLegalOp( - [](func::FuncOp op) { return true; }); - target.addDynamicallyLegalOp( - [](func::ReturnOp op) { return true; }); - } - - RewritePatternSet patterns(context); - if (convertFunctionBoundaries) { - populateFunctionOpInterfaceTypeConversionPattern( - patterns, typeConverter); - populateReturnOpTypeConversionPattern(patterns, typeConverter); - } - if (aggressiveRewrite) { - patterns.add(typeConverter, context); - } else { - // Tensor - patterns.add(typeConverter, context); - // Data layout - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - // Type conversion - patterns.add(typeConverter, context); - // Controlflow - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - } - - if (failed( - applyFullConversion(getOperation(), target, std::move(patterns)))) - signalPassFailure(); - } -}; - -} // namespace diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir deleted file mode 100644 index 1a36177a37033..0000000000000 --- a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32-aggressive.mlir +++ /dev/null @@ -1,81 +0,0 @@ -// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="aggressive-rewrite=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT -// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="aggressive-rewrite=1 convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND - -// CHECK-LABEL: test_i64_argmax_large_axis_dim -func.func @test_i64_argmax_large_axis_dim(%arg0: tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> { - // DEFAULT: tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi32> - %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> - return %0 : tensor<1x513x513xi64> -} - -// ----- - -// CHECK-LABEL: test_convert_input_parameters -// DEFAULT: %[[IN:.*]]: tensor<1x513x513x3xi64> -// FUNCBOUND: %[[IN:.*]]: tensor<1x513x513x3xi32> -func.func @test_convert_input_parameters(%arg0: tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xf32> { - // DEFAULT: %[[FUNC_BOUND_CAST:.*]] = tosa.cast %[[IN]] : (tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xi32> - // DEFAULT: %[[CAST1:.*]] = tosa.cast %[[FUNC_BOUND_CAST]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xi32> - // FUNCBOUND: %[[CAST1:.*]] = tosa.cast %[[IN]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xi32> - %0 = tosa.cast %arg0 : (tensor<1x513x513x3xi64>) -> tensor<1x513x513x3xi32> - - // COMMON: %[[CAST2:.*]] = tosa.cast %[[CAST1]] : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xf32> - %1 = tosa.cast %0 : (tensor<1x513x513x3xi32>) -> tensor<1x513x513x3xf32> - return %1 : tensor<1x513x513x3xf32> -} - -// ----- - -// CHECK-LABEL: test_add -// DEFAULT: %[[IN0:.*]]: tensor<13x21x1xi64>, %[[IN1:.*]]: tensor<13x21x3xi64> -// FUNCBOUND: %[[IN0:.*]]: tensor<13x21x1xi32>, %[[IN1:.*]]: tensor<13x21x3xi32> -func.func @test_add(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { - // DEFAULT-DAG: %[[FUNC_BOUND_CAST0:.*]] = tosa.cast %[[IN0]] : (tensor<13x21x1xi64>) -> tensor<13x21x1xi32> - // DEFAULT-DAG: %[[FUNC_BOUND_CAST1:.*]] = tosa.cast %[[IN1]] : (tensor<13x21x3xi64>) -> tensor<13x21x3xi32> - // DEFAULT: %[[ADD:.*]] = tosa.add %[[FUNC_BOUND_CAST0]], %[[FUNC_BOUND_CAST1]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32> - // DEFAULT: %[[CAST:.*]] = tosa.cast %[[ADD]] : (tensor<13x21x3xi32>) -> tensor<13x21x3xi64> - // DEFAULT: return %[[CAST]] : tensor<13x21x3xi64> - // FUNCBOUND: %[[ADD:.*]] = tosa.add %[[IN0]], %[[IN1]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32> - // FUNCBOUND: return %[[ADD]] : tensor<13x21x3xi32> - %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64> - return %0 : tensor<13x21x3xi64> -} - -// ----- - -// CHECK-LABEL: test_regions -// DEFAULT: %[[IN0:.*]]: tensor, %[[IN1:.*]]: tensor -func.func @test_regions(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // DEFAULT-DAG: %[[CAST0:.*]] = tosa.cast %[[IN0]] : (tensor) -> tensor - // DEFAULT-DAG: %[[CAST1:.*]] = tosa.cast %[[IN1]] : (tensor) -> tensor - // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if - %0 = tosa.cond_if %arg2 : tensor -> (tensor) { - // DEFAULT: %[[ADD:.*]] = tosa.add %[[CAST0]], %[[CAST1]] : (tensor, tensor) -> tensor - // FUNCBOUND: %[[ADD:.*]] = tosa.add %[[IN0]], %[[IN1]] : (tensor, tensor) -> tensor - %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor - // COMMON: tosa.yield %[[ADD]] : tensor - tosa.yield %1 : tensor - } else { - // DEFAULT: %[[SUB:.*]] = tosa.sub %[[CAST0]], %[[CAST1]] : (tensor, tensor) -> tensor - // FUNCBOUND: %[[SUB:.*]] = tosa.sub %[[IN0]], %[[IN1]] : (tensor, tensor) -> tensor - %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor - // COMMON: tosa.yield %[[SUB]] : tensor - tosa.yield %1 : tensor - } - // DEFAULT: %[[OUT:.*]] = tosa.cast %[[IF_RESULT]] : (tensor) -> tensor - // DEFAULT: return %[[OUT]] : tensor - // FUNCBOUND: return %[[IF_RESULT]] : tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: test_const -func.func @test_const() -> tensor<2xi64> { - // COMMON: %[[CONST:.*]] = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32> - %0 = "tosa.const"() <{values = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64> - // DEFAULT: %[[OUT:.*]] = tosa.cast %[[CONST]] : (tensor<2xi32>) -> tensor<2xi64> - // DEFAULT: return %[[OUT]] : tensor<2xi64> - // FUNCBOUND: return %[[CONST]] : tensor<2xi32> - return %0 : tensor<2xi64> -} diff --git a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir b/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir deleted file mode 100644 index a14483fcdd7b0..0000000000000 --- a/mlir/test/Dialect/Tosa/tosa-narrow-i64-to-i32.mlir +++ /dev/null @@ -1,162 +0,0 @@ -// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="convert-function-boundaries=0" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,DEFAULT -// RUN: mlir-opt -split-input-file -verify-diagnostics -tosa-narrow-i64-to-i32="convert-function-boundaries=1" %s | FileCheck %s --allow-unused-prefixes --check-prefixes=COMMON,FUNCBOUND - -// ----- - -// CHECK-LABEL: test_i64_argmax -func.func @test_i64_argmax(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64> { - // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi32> - %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64> - - // DEFAULT: %[[CAST:.*]] = tosa.cast %[[ARGMAX]] : (tensor<1x513x513xi32>) -> tensor<1x513x513xi64> - // FUNCBOUND: return %[[ARGMAX]] : tensor<1x513x513xi32> - return %0 : tensor<1x513x513xi64> -} - -// ----- - -// CHECK-LABEL: test_i64_argmax_cast -func.func @test_i64_argmax_cast(%arg0: tensor<1x513x513x19xi8>) -> tensor<1x513x513xf32> { - // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi32> - %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x19xi8>) -> tensor<1x513x513xi64> - // COMMON: tosa.cast %[[ARGMAX]] : (tensor<1x513x513xi32>) -> tensor<1x513x513xf32> - %1 = tosa.cast %0 : (tensor<1x513x513xi64>) -> tensor<1x513x513xf32> - return %1 : tensor<1x513x513xf32> -} - -// ----- - -// CHECK-LABEL: test_i64_argmax_large_axis_dim -func.func @test_i64_argmax_large_axis_dim(%arg0: tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> { - // expected-error @+1 {{failed to legalize operation 'tosa.argmax'}} - %0 = tosa.argmax %arg0 {axis = 3 : i32} : (tensor<1x513x513x2147483650xi8>) -> tensor<1x513x513xi64> - return %0 : tensor<1x513x513xi64> -} - -// ----- - -// CHECK-LABEL: test_add -func.func @test_add(%arg0: tensor<13x21x1xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { - // expected-error @+1 {{failed to legalize operation 'tosa.add'}} - %0 = tosa.add %arg0, %arg1 : (tensor<13x21x1xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64> - return %0 : tensor<13x21x3xi64> -} - -// ----- - -// CHECK-LABEL: test_regions -func.func @test_regions(%arg0: tensor<1x2xi32>, %arg1: tensor<1xi32>, %arg2: tensor) -> tensor<1xi32> { - // COMMON: %[[IF_RESULT:.*]] = tosa.cond_if %arg2 : tensor -> tensor<1xi32> - %0 = tosa.cond_if %arg2 : tensor -> tensor<1xi32> { - // COMMON: %[[ARGMAX:.*]] = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<1x2xi32>) -> tensor<1xi32> - %1 = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<1x2xi32>) -> tensor<1xi64> - // COMMON: %[[CAST:.*]] = tosa.cast %[[ARGMAX]] : (tensor<1xi32>) -> tensor<1xi32> - %2 = tosa.cast %1 : (tensor<1xi64>) -> tensor<1xi32> - // COMMON: tosa.yield %[[CAST]] : tensor<1xi32> - tosa.yield %2 : tensor<1xi32> - } else { - tosa.yield %arg1 : tensor<1xi32> - } - // COMMON: return %[[IF_RESULT]] : tensor<1xi32> - return %0 : tensor<1xi32> -} - -// ----- - -// CHECK-LABEL: test_concat -func.func @test_concat(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi64>) -> tensor<26x21x3xi64> { - // COMMON: tosa.concat %{{.*}}, %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<26x21x3xi32> - %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<26x21x3xi64> - return %0 : tensor<26x21x3xi64> -} - -// ----- - -// CHECK-LABEL: test_pad -func.func @test_pad(%arg0: tensor<13x21x3xi64>, %arg1: tensor<1xi64>) -> tensor<15x23x5xi64> { - %padding = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6> - // COMMON: tosa.pad %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<6>, tensor<1xi32>) -> tensor<15x23x5xi32> - %1 = tosa.pad %arg0, %padding, %arg1 : (tensor<13x21x3xi64>, !tosa.shape<6>, tensor<1xi64>) -> tensor<15x23x5xi64> - return %1 : tensor<15x23x5xi64> -} - -// ----- - -// CHECK-LABEL: test_reshape -func.func @test_reshape(%arg0: tensor<13x21x3xi64>) -> tensor<1x819xi64> { - %1 = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2> - // COMMON: tosa.reshape %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<2>) -> tensor<1x819xi32> - %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xi64>, !tosa.shape<2>) -> tensor<1x819xi64> - return %0 : tensor<1x819xi64> -} - -// ----- - -// CHECK-LABEL: test_reverse -func.func @test_reverse(%arg0: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { - // COMMON: tosa.reverse %{{.*}} {axis = 0 : i32} : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32> - %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64> - return %0 : tensor<13x21x3xi64> -} - -// ----- - -// CHECK-LABEL: test_slice -func.func @test_slice(%arg0: tensor<13x21x3xi64>) -> tensor<4x11x1xi64> { - %0 = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> - %1 = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3> - // COMMON: tosa.slice %{{.*}}, %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xi32> - %2 = tosa.slice %arg0, %0, %1 : (tensor<13x21x3xi64>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xi64> - return %2 : tensor<4x11x1xi64> -} - -// ----- - -// CHECK-LABEL: test_tile -func.func @test_tile(%arg0: tensor<13x21x3xi64>) -> tensor<39x21x6xi64> { - %cst = tosa.const_shape { values = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3> - // COMMON: tosa.tile %{{.*}}, %{{.*}} : (tensor<13x21x3xi32>, !tosa.shape<3>) -> tensor<39x21x6xi32> - %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xi64>, !tosa.shape<3>) -> tensor<39x21x6xi64> - return %0 : tensor<39x21x6xi64> -} - -// ----- - -// CHECK-LABEL: transpose -func.func @test_transpose(%arg0: tensor<13x21x3xi64>) -> tensor<3x13x21xi64> { - // COMMON: tosa.transpose %{{.*}} {perms = array} : (tensor<13x21x3xi32>) -> tensor<3x13x21xi32> - %1 = tosa.transpose %arg0 {perms = array} : (tensor<13x21x3xi64>) -> tensor<3x13x21xi64> - return %1 : tensor<3x13x21xi64> -} - -// ----- - -// CHECK-LABEL: test_transition_to_i64 -func.func @test_transition_to_i64(%arg0: tensor<1xi32>) -> tensor<1xi64> { - // COMMON: %[[CAST:.*]] = tosa.cast %arg0 : (tensor<1xi32>) -> tensor<1xi32> - %0 = tosa.cast %arg0 : (tensor<1xi32>) -> tensor<1xi64> - // COMMON: %[[IDENTITY1:.*]] = tosa.identity %[[CAST]] : (tensor<1xi32>) -> tensor<1xi32> - %1 = tosa.identity %0 : (tensor<1xi64>) -> tensor<1xi64> - // COMMON: %[[IDENTITY2:.*]] = tosa.identity %[[IDENTITY1]] : (tensor<1xi32>) -> tensor<1xi32> - %2 = tosa.identity %1 : (tensor<1xi64>) -> tensor<1xi64> - // DEFAULT: %[[OUT_CAST:.*]] = tosa.cast %[[IDENTITY2]] : (tensor<1xi32>) -> tensor<1xi64> - // DEFAULT: return %[[OUT_CAST]] : tensor<1xi64> - // FUNCBOUND: return %[[IDENTITY2]] : tensor<1xi32> - return %2 : tensor<1xi64> -} - -// ----- - -// CHECK-LABEL: test_transition_from_i64 -func.func @test_transition_from_i64(%arg0: tensor<1xi64>) -> tensor<1xi32> { - // DEFAULT: %[[CAST:.*]] = tosa.cast %arg0 : (tensor<1xi64>) -> tensor<1xi32> - // DEFAULT: %[[IDENTITY1:.*]] = tosa.identity %[[CAST]] : (tensor<1xi32>) -> tensor<1xi32> - // FUNCBOUND: %[[IDENTITY1:.*]] = tosa.identity %arg0 : (tensor<1xi32>) -> tensor<1xi32> - %0 = tosa.identity %arg0 : (tensor<1xi64>) -> tensor<1xi64> - // COMMON: %[[IDENTITY2:.*]] = tosa.identity %[[IDENTITY1]] : (tensor<1xi32>) -> tensor<1xi32> - %1 = tosa.identity %0 : (tensor<1xi64>) -> tensor<1xi64> - // COMMON: %[[OUT_CAST:.*]] = tosa.cast %[[IDENTITY2]] : (tensor<1xi32>) -> tensor<1xi32> - %2 = tosa.cast %1 : (tensor<1xi64>) -> tensor<1xi32> - // COMMON: return %[[OUT_CAST]] : tensor<1xi32> - return %2 : tensor<1xi32> -}