diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td index 14b00b04ccc18..34572c5c4d131 100644 --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -105,6 +105,15 @@ def TosaReduceTransposes : Pass<"tosa-reduce-transposes", "func::FuncOp"> { }]; } +def TosaArithConstantToTosaConstPass + : Pass<"tosa-arith-const-to-tosa-const", "func::FuncOp"> { + let summary = "Convert tensor arith.constant operations into tosa.const"; + let description = [{ + Normalizes tensor-valued arith.constant operations into tosa.const so that + subsequent TOSA passes operate on a consistent representation of constants. + }]; +} + def TosaConvertIntegerTypeToSignless : Pass<"tosa-convert-integer-type-to-signless", "func::FuncOp"> { let summary = "Convert integer types to signless"; let description = [{ diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt index 41b338d6e7189..46c299834e2df 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaAttachTarget.cpp + TosaArithConstantToConst.cpp TosaConvertIntegerTypeToSignless.cpp TosaDecomposeTransposeConv.cpp TosaDecomposeDepthwise.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp new file mode 100644 index 0000000000000..8ddde9c05724e --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaArithConstantToConst.cpp @@ -0,0 +1,126 @@ +//===- TosaArithConstantToConst.cpp ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass that converts tensor-valued arith.constant ops +// into tosa.const so that TOSA pipelines operate on a uniform constant form. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace tosa { +#define GEN_PASS_DEF_TOSAARITHCONSTANTTOTOSACONSTPASS +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" +} // namespace tosa +} // namespace mlir + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +// NOTE: TOSA pipelines already lower their constants through shared Arith +// folding passes, so tensor literals often come back as `arith.constant` even +// after the IR is otherwise TOSA-only. Keep this normalization with the rest of +// the TOSA transforms so any client can re-establish a canonical `tosa.const` +// representation without needing a full Arith->TOSA conversion library. + +/// Returns true when `elementType` is natively representable by tosa.const. +static bool isSupportedElementType(Type elementType) { + if (isa(elementType)) + return true; + + if (auto intType = dyn_cast(elementType)) + return intType.isSignless() || intType.isUnsigned(); + + if (isa(elementType)) + return true; + + if (isa(elementType)) + return true; + + return false; +} + +class ArithConstantToTosaConst : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::ConstantOp constOp, + PatternRewriter &rewriter) const override { + // TOSA constant verification requires a ranked, statically shaped tensor. + auto resultType = dyn_cast(constOp.getResult().getType()); + if (!resultType || !resultType.hasStaticShape()) + return failure(); + + if (!isSupportedElementType(resultType.getElementType())) + return failure(); + + Attribute attr = constOp.getValueAttr(); + auto elementsAttr = dyn_cast(attr); + if (!elementsAttr) + return failure(); + + auto attrType = dyn_cast(elementsAttr.getType()); + if (!attrType || !attrType.hasStaticShape()) + return failure(); + + if (attrType != resultType) { + // Allow reshape when the payload can be reinterpreted without altering + // the number of elements or element type. Dense resource attributes + // cannot be reshaped losslessly, so bail out in that case. + if (!isa(elementsAttr)) + return failure(); + + if (attrType.getElementType() != resultType.getElementType()) + return failure(); + + auto denseAttr = cast(elementsAttr); + if (denseAttr.getNumElements() != resultType.getNumElements()) + return failure(); + + elementsAttr = denseAttr.reshape(resultType); + } + + auto newConst = tosa::ConstOp::create(rewriter, constOp.getLoc(), + resultType, elementsAttr); + rewriter.replaceOp(constOp, newConst.getResult()); + return success(); + } +}; + +struct TosaArithConstantToTosaConstPass + : public tosa::impl::TosaArithConstantToTosaConstPassBase< + TosaArithConstantToTosaConstPass> { + using Base::Base; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir b/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir new file mode 100644 index 0000000000000..3f54a68ed3c00 --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-arith-const-to-tosa-const.mlir @@ -0,0 +1,110 @@ +// RUN: mlir-opt %s --tosa-arith-const-to-tosa-const --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @rewrite_f32_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: return %[[CST]] +func.func @rewrite_f32_tensor() -> tensor<2xf32> { + %c = arith.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32> + return %c : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_i32_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: return %[[CST]] +func.func @rewrite_i32_tensor() -> tensor<3xi32> { + %c = arith.constant dense<[1, 0, -1]> : tensor<3xi32> + return %c : tensor<3xi32> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_i1_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[true, false]> : tensor<2xi1>}> : () -> tensor<2xi1> +func.func @rewrite_i1_tensor() -> tensor<2xi1> { + %c = arith.constant dense<[true, false]> : tensor<2xi1> + return %c : tensor<2xi1> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_rank0_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<1.234500e+00> : tensor}> : () -> tensor +func.func @rewrite_rank0_tensor() -> tensor { + %c = arith.constant dense<1.234500e+00> : tensor + return %c : tensor +} + +// ----- + +// CHECK-LABEL: func.func @preserve_scalar_i32 +// CHECK: %[[CST:.*]] = arith.constant 42 : i32 +func.func @preserve_scalar_i32() -> i32 { + %c = arith.constant 42 : i32 + return %c : i32 +} + +// ----- + +// CHECK-LABEL: func.func @preserve_index_tensor +// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1]> : tensor<2xindex> +func.func @preserve_index_tensor() -> tensor<2xindex> { + %c = arith.constant dense<[0, 1]> : tensor<2xindex> + return %c : tensor<2xindex> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_resource_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense_resource : tensor<4xf32>}> : () -> tensor<4xf32> +func.func @rewrite_resource_tensor() -> tensor<4xf32> { + %c = arith.constant dense_resource<"blob1"> : tensor<4xf32> + return %c : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_quant_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[10, 20]> : tensor<2xui8>}> : () -> tensor<2xui8> +func.func @rewrite_quant_tensor() -> tensor<2xui8> { + %c = arith.constant dense<[10, 20]> : tensor<2xui8> + return %c : tensor<2xui8> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_quant_uniform_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["10", "20"]> : tensor<2x!quant.uniform>}> : () -> tensor<2x!quant.uniform> +func.func @rewrite_quant_uniform_tensor() -> tensor<2x!quant.uniform> { + %c = arith.constant dense<["10", "20"]> : tensor<2x!quant.uniform> + return %c : tensor<2x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_reshape_collapse_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1, 2, 3, 4]> : tensor<4xi32>}> : () -> tensor<4xi32> +func.func @rewrite_reshape_collapse_tensor() -> tensor<4xi32> { + %c = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> + %d = tensor.collapse_shape %c [[0, 1]] : tensor<2x2xi32> into tensor<4xi32> + return %d : tensor<4xi32> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_fp8_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<[1.000000e+00, -5.000000e-01]> : tensor<2xf8E4M3FN>}> : () -> tensor<2xf8E4M3FN> +func.func @rewrite_fp8_tensor() -> tensor<2xf8E4M3FN> { + %c = arith.constant dense<[1.0, -0.5]> : tensor<2xf8E4M3FN> + return %c : tensor<2xf8E4M3FN> +} + +// ----- + +// CHECK-LABEL: func.func @rewrite_mxint8_tensor +// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8>}> : () -> tensor<2x!tosa.mxint8> +func.func @rewrite_mxint8_tensor() -> tensor<2x!tosa.mxint8> { + %c = arith.constant dense<["0x00", "0x7F"]> : tensor<2x!tosa.mxint8> + return %c : tensor<2x!tosa.mxint8> +}