From 1c65d0e283ce3b8c271fb33b1154336a0e448ff0 Mon Sep 17 00:00:00 2001 From: Vitalii Shutov Date: Tue, 11 Nov 2025 17:10:32 +0000 Subject: [PATCH] [TOSA] Introduce arith.constant -> tosa.const normalization pass Add a standalone pass that rewrites tensor-valued `arith.constant` ops into `tosa.const`, normalize the TOSA backend contract. Co-authored-by: Shubham Signed-off-by: Vitalii Shutov Change-Id: I4e71926107633007a71bd1fcc3311a5da6d38849 --- .../mlir/Dialect/Tosa/Transforms/Passes.td | 9 ++ .../Dialect/Tosa/Transforms/CMakeLists.txt | 1 + .../Transforms/TosaArithConstantToConst.cpp | 117 ++++++++++++++++++ .../Tosa/arith-const-to-tosa-const.mlir | 84 +++++++++++++ 4 files changed, 211 insertions(+) create mode 100644 mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp create mode 100644 mlir/test/Dialect/Tosa/arith-const-to-tosa-const.mlir diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index 14b00b04ccc18..34572c5c4d131 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -105,6 +105,15 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> { }]; } +def TosaArithConstantToTosaConstPass + : Pass<"tosa-arith-const-to-tosa-const", "func::FuncOp"> { + let summary = "Convert tensor arith.constant operations into tosa.const"; + let description = [{ + Normalizes tensor-valued arith.constant operations into tosa.const so that + subsequent TOSA passes operate on a consistent representation of constants. + }]; +} + def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signless", "func::FuncOp"> { let summary = "Convert integer types to signless"; let description = [{ diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 41b338d6e7189..46c299834e2df 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaAttachTarget.cpp + TosaArithConstantToConst.cpp TosaConvertIntegerTypeToSignless.cpp TosaDecomposeTransposeConv.cpp TosaDecomposeDepthwise.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp new file mode 100644 index 0000000000000..72680489b92a2 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp @@ -0,0 +1,117 @@ +//===- TosaArithConstantToConst.cpp ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass that converts tensor-valued arith.constant ops +// into tosa.const so that TOSA pipelines operate on a uniform constant form. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace tosa { +#define GEN_PASS_DEF_TOSAARITHCONSTANTTOTOSACONSTPASS +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" +} // namespace tosa +} // namespace mlir + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +/// Returns true when `elementType` is natively representable by tosa.const. +static bool isSupportedElementType(Type elementType) { + if (isa(elementType)) + return true; + + if (auto intType = dyn_cast(elementType)) + return intType.isSignless() || intType.isUnsigned(); + + if (isa(elementType)) + return true; + + return false; +} + +class ArithConstantToTosaConst : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::ConstantOp constOp, + PatternRewriter &rewriter) const override { + // TOSA constant verification requires a ranked, statically shaped tensor. + auto resultType = dyn_cast(constOp.getResult().getType()); + if (!resultType || !resultType.hasStaticShape()) + return failure(); + + if (!isSupportedElementType(resultType.getElementType())) + return failure(); + + Attribute attr = constOp.getValueAttr(); + auto elementsAttr = dyn_cast(attr); + if (!elementsAttr) + return failure(); + + auto attrType = dyn_cast(elementsAttr.getType()); + if (!attrType || !attrType.hasStaticShape()) + return failure(); + + if (attrType != resultType) { + // Allow reshape when the payload can be reinterpreted without altering + // the number of elements or element type. Dense resource attributes + // cannot be reshaped losslessly, so bail out in that case. + if (!isa(elementsAttr)) + return failure(); + + if (attrType.getElementType() != resultType.getElementType()) + return failure(); + + auto denseAttr = cast(elementsAttr); + if (denseAttr.getNumElements() != resultType.getNumElements()) + return failure(); + + elementsAttr = denseAttr.reshape(resultType); + } + + auto newConst = tosa::ConstOp::create(rewriter, constOp.getLoc(), + resultType, elementsAttr); + rewriter.replaceOp(constOp, newConst.getResult()); + return success(); + } +}; + +struct TosaArithConstantToTosaConstPass + : public tosa::impl::TosaArithConstantToTosaConstPassBase< + TosaArithConstantToTosaConstPass> { + using Base::Base; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/test/Dialect/Tosa/arith-const-to-tosa-const.mlir b/mlir/test/Dialect/Tosa/arith-const-to-tosa-const.mlir new file mode 100644 index 0000000000000..38dbcf7e45b3c --- /dev/null +++ b/mlir/test/Dialect/Tosa/arith-const-to-tosa-const.mlir @@ -0,0 +1,84 @@ +// RUN: mlir-opt %s --tosa-arith-const-to-tosa-const --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @rewrite_f32_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: return %[[CST]] +func.func @rewrite_f32_tensor() -> tensor<2xf32> { + %c = arith.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32> + return %c : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_i32_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: return %[[CST]] +func.func @rewrite_i32_tensor() -> tensor<3xi32> { + %c = arith.constant dense<[1, 0, -1]> : tensor<3xi32> + return %c : tensor<3xi32> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_i1_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[true, false]> : tensor<2xi1>}> : () -> tensor<2xi1> +func.func @rewrite_i1_tensor() -> tensor<2xi1> { + %c = arith.constant dense<[true, false]> : tensor<2xi1> + return %c : tensor<2xi1> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_rank0_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<1.234500e+00> : tensor}> : () -> tensor +func.func @rewrite_rank0_tensor() -> tensor { + %c = arith.constant dense<1.234500e+00> : tensor + return %c : tensor +} + +// ----- + +// CHECK-LABEL: func.func @preserve_scalar_i32 +// CHECK: %[[CST:.*]] = arith.constant 42 : i32 +func.func @preserve_scalar_i32() -> i32 { + %c = arith.constant 42 : i32 + return %c : i32 +} + +// ----- + +// CHECK-LABEL: func.func @preserve_index_tensor +// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1]> : tensor<2xindex> +func.func @preserve_index_tensor() -> tensor<2xindex> { + %c = arith.constant dense<[0, 1]> : tensor<2xindex> + return %c : tensor<2xindex> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_resource_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense_resource : tensor<4xf32>}> : () -> tensor<4xf32> +func.func @rewrite_resource_tensor() -> tensor<4xf32> { + %c = arith.constant dense_resource<"blob1"> : tensor<4xf32> + return %c : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_quant_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[10, 20]> : tensor<2xui8>}> : () -> tensor<2xui8> +func.func @rewrite_quant_tensor() -> tensor<2xui8> { + %c = arith.constant dense<[10, 20]> : tensor<2xui8> + return %c : tensor<2xui8> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_quant_uniform_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["10", "20"]> : tensor<2x!quant.uniform>}> : () -> tensor<2x!quant.uniform> +func.func @rewrite_quant_uniform_tensor() -> tensor<2x!quant.uniform> { + %c = arith.constant dense<["10", "20"]> : tensor<2x!quant.uniform> + return %c : tensor<2x!quant.uniform> +} + +// -----