Conversation

@sahas3 (Member) commented Nov 18, 2025

Legalizing the following IR to TOSA using tf-tosa-opt from the TensorFlow repo:

func.func @main(%arg0: tensor<?x?x?x?xf32>) -> tensor<?x?x?x5xf32> {
    %0 = "tfl.pseudo_const"() <{value = dense<0.000000e+00> : tensor<5xf32>}> : () -> tensor<5xf32>
    %1 = tfl.add(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32>
    return %1 : tensor<?x?x?x5xf32>
  }

fails with

error: 'tosa.add' op operands don't have matching ranks
    %1 = tfl.add(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32>
         ^
tfl.mlir:3:10: note: see current operation: %1 = "tosa.add"(%arg0, %0) : (tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32>
// -----// IR Dump After TosaLegalizeTFLPass Failed (tosa-legalize-tfl) //----- //
"func.func"() <{function_type = (tensor<?x?x?x?xf32>) -> tensor<?x?x?x5xf32>, sym_name = "main"}> ({
^bb0(%arg0: tensor<?x?x?x?xf32>):
  %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<5xf32>}> : () -> tensor<5xf32>
  %1 = "tosa.add"(%arg0, %0) : (tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32>
  "func.return"(%1) : (tensor<?x?x?x5xf32>) -> ()
}) : () -> ()

This is because of the following check in computeReshapeOutput, called from the EqualizeRanks function:

if (lowerRankDim != 1 && higherRankDim != 1 &&
        lowerRankDim != higherRankDim)
      return failure();

Based on the broadcast semantics defined in https://mlir.llvm.org/docs/Traits/Broadcastable/#dimension-inference, I think it's legal to allow lowerRankDim != higherRankDim when one of them is dynamic. The runtime verifier should instead enforce that:

  1. if lowerRankDim is dynamic and higherRankDim is static (or vice-versa), the dynamic dim matches the static dim at runtime
  2. if both dims are dynamic, they match at runtime

It's not necessary to error out at op construction time.

@llvmbot (Member) commented Nov 18, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Sayan Saha (sahas3)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/168564.diff

1 file affected:

  • (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+8-1)
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 62c015a85ee36..bb52d15026367 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -70,6 +70,8 @@ namespace {
 // If lower=[a], higher=[a, a], [a] reshaped into [1, a].
 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
+// If lower=[c], higher=[?, ?, c], [c] reshaped into [1, 1, c].
+// If lower=[?], higher=[?, ?, ?], [?] reshaped into [1, 1, ?].
 LogicalResult
 computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
                      ArrayRef<int64_t> lowerRankShape,
@@ -87,7 +89,12 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
     higherRankDim = higherRankShape[i + rankDiff];
     lowerRankDim = lowerRankShape[i];
 
-    if (lowerRankDim != 1 && higherRankDim != 1 &&
+    auto isKnownStaticShapeNotEqualToOne = [](int64_t dim) {
+      return dim != 1 && dim != ShapedType::kDynamic;
+    };
+
+    if (isKnownStaticShapeNotEqualToOne(lowerRankDim) &&
+        isKnownStaticShapeNotEqualToOne(higherRankDim) &&
         lowerRankDim != higherRankDim)
       return failure();
 

@sahas3 (Member, Author) commented Nov 18, 2025

I've verified that with the fix the IR is legalized correctly to

func.func @main(%arg0: tensor<?x?x?x?xf32>) -> tensor<?x?x?x5xf32> {
    %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1x1x5xf32>}> : () -> tensor<1x1x1x5xf32>
    %1 = tosa.add %arg0, %0 : (tensor<?x?x?x?xf32>, tensor<1x1x1x5xf32>) -> tensor<?x?x?x5xf32>
    return %1 : tensor<?x?x?x5xf32>
  }

However, I am not sure what test I can add here to lock this down. If you have any suggestions on adding unit-tests please let me know. Thanks!

@sahas3 sahas3 requested review from Jerry-Ge and sjarus November 18, 2025 16:40
@github-actions bot commented Nov 18, 2025

🐧 Linux x64 Test Results

  • 7099 tests passed
  • 594 tests skipped

@Jerry-Ge (Member) commented:

> I've verified that with the fix the IR is legalized correctly to
>
> func.func @main(%arg0: tensor<?x?x?x?xf32>) -> tensor<?x?x?x5xf32> {
>     %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1x1x5xf32>}> : () -> tensor<1x1x1x5xf32>
>     %1 = tosa.add %arg0, %0 : (tensor<?x?x?x?xf32>, tensor<1x1x1x5xf32>) -> tensor<?x?x?x5xf32>
>     return %1 : tensor<?x?x?x5xf32>
>   }
>
> However, I am not sure what test I can add here to lock this down. If you have any suggestions on adding unit-tests please let me know. Thanks!

Thanks for fixing this! How about adding the test case you constructed somewhere here? https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

@sahas3 (Member, Author) commented Nov 18, 2025

> How about adding the test case you constructed somewhere here? https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Thanks for the suggestion, but can you please elaborate on this?

The EqualizeRanks function operates on tensor<?x?x?x?xf32> and tensor<5xf32> to expand the latter to tensor<1x1x1x5xf32> while the tosa.add op is being constructed in my legalization example. However, it's not possible to create a tosa.add (or another tosa op with broadcasting semantics) whose operands have types tensor<?x?x?x?xf32> and tensor<5xf32> in a way that triggers this code path.

The only users of EqualizeRanks in llvm-project are TosaDecomposeDepthWise, TosaDecomposeTransposeConv and TosaMakeBroadcastable. The first two passes don't handle dynamic dimensions, so the fix in this change cannot be triggered from them -- it may be possible to enhance those passes to handle a dynamic batch dimension, which could then exercise this code. I can look into that separately.

For TosaMakeBroadcastable, the pass may even be a no-op now, since tosa ops are expected to have their operands already broadcast at construction time; otherwise the verify method errors out.

@sahas3 (Member, Author) commented Nov 18, 2025

> I've verified that with the fix the IR is legalized correctly to
>
> func.func @main(%arg0: tensor<?x?x?x?xf32>) -> tensor<?x?x?x5xf32> {
>     %0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1x1x5xf32>}> : () -> tensor<1x1x1x5xf32>
>     %1 = tosa.add %arg0, %0 : (tensor<?x?x?x?xf32>, tensor<1x1x1x5xf32>) -> tensor<?x?x?x5xf32>
>     return %1 : tensor<?x?x?x5xf32>
>   }

For clarification, I verified this by mimicking the change in the llvm directory synced into the TensorFlow repo.

@Jerry-Ge (Member) commented:

> > How about adding the test case you constructed somewhere here? https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
>
> Thanks for the suggestion, but can you please elaborate on this?
>
> The EqualizeRanks function operates on tensor<?x?x?x?xf32> and tensor<5xf32> to expand the latter to tensor<1x1x1x5xf32> while the tosa.add op is being constructed in my legalization example. However, it's not possible to create a tosa.add (or another tosa op with broadcasting semantics) whose operands have types tensor<?x?x?x?xf32> and tensor<5xf32> in a way that triggers this code path.
>
> The only users of EqualizeRanks in llvm-project are TosaDecomposeDepthWise, TosaDecomposeTransposeConv and TosaMakeBroadcastable. The first two passes don't handle dynamic dimensions, so the fix in this change cannot be triggered from them -- it may be possible to enhance those passes to handle a dynamic batch dimension, which could then exercise this code. I can look into that separately.
>
> For TosaMakeBroadcastable, the pass may even be a no-op now, since tosa ops are expected to have their operands already broadcast at construction time; otherwise the verify method errors out.

Ah, I see your problem. Thanks for the clarification. If I understand correctly, we simply don't have a mechanism in llvm-project to trigger this specific test case.

To me, this test case seems like it should go in the TF repo, somewhere here: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/tosa/tests

If we really need to put it in llvm-project, could we create a dummy pass to verify it? I have never done something like this before, but I think your proposal may also work.

I haven't touched this for a long time, so nothing creative comes to mind immediately. cc my colleagues @Tai78641 and @lhutton1 in case they have better ideas 🤔

@Tai78641 (Contributor) commented:

LGTM.
It is hard to test EqualizeRanks directly in TOSA because it is a utility function.
I would just test it in the TFL-to-TOSA legalization lit tests.

@sahas3 (Member, Author) commented Nov 19, 2025

Thanks for the suggestions.

> I would just test it in the TFL-to-TOSA legalization lit tests.

I'll make a note of adding this LIT test when this change lands on the TF side.

If no further changes are required, can one of you please approve the PR? Thanks!

@Jerry-Ge (Member) left a comment


LGTM

@sahas3 sahas3 merged commit 6ad1623 into llvm:main Nov 19, 2025
10 checks passed
