-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[tosa] : Enhance EqualizeRanks to handle dynamic dimensions. #168564
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Sayan Saha (sahas3) ChangesLegalizing following IR to fails with This is because of the following check in Based on the broadcast semantics defined in https://mlir.llvm.org/docs/Traits/Broadcastable/#dimension-inference I think it's legal to allow
Full diff: https://github.com/llvm/llvm-project/pull/168564.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index 62c015a85ee36..bb52d15026367 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -70,6 +70,8 @@ namespace {
// If lower=[a], higher=[a, a], [a] reshaped into [1, a].
// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
+// If lower=[c], higher=[?, ?, c], [c] reshaped into [1, 1, c].
+// If lower=[?], higher=[?, ?, ?], [?] reshaped into [1, 1, ?].
LogicalResult
computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
ArrayRef<int64_t> lowerRankShape,
@@ -87,7 +89,12 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
higherRankDim = higherRankShape[i + rankDiff];
lowerRankDim = lowerRankShape[i];
- if (lowerRankDim != 1 && higherRankDim != 1 &&
+ auto isKnownStaticShapeNotEqualToOne = [](int64_t dim) {
+ return dim != 1 && dim != ShapedType::kDynamic;
+ };
+
+ if (isKnownStaticShapeNotEqualToOne(lowerRankDim) &&
+ isKnownStaticShapeNotEqualToOne(higherRankDim) &&
lowerRankDim != higherRankDim)
return failure();
|
|
I've verified that with the fix the IR is legalized correctly to However, I am not sure what test I can add here to lock this down. If you have any suggestions on adding unit-tests please let me know. Thanks! |
🐧 Linux x64 Test Results
|
Thanks for fixing this! How about adding the test case you constructed somewhere here? https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir |
Thanks for the suggestion, but can you please elaborate more on this? The The only use of For |
For clarification, I verified this by mimicking the change in the llvm directory synced in |
Ah, i see your problem. Thanks for the clarification. If i understood it correctly, we simply don't have the mechanism in the llvm-project to trigger this specific test case. To me this test case seems should be put in the TF repo: somewhere here (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/tosa/tests) If we really need to put it in the llvm-project, could we create a dummy pass to verify it? I have never done something like this before. I think your proposal may also work. I haven't touched on this for a time long so nothing creative came to my mind immediately. cc my colleagues @Tai78641 and @lhutton1 if they have better ideas 🤔 |
|
LGTM |
|
Thanks for the suggestions.
I'll make a note of adding this LIT test when this change lands on the TF side. If no further changes are required, can one of you please approve the PR. Thanks! |
Jerry-Ge
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Legalizing following IR to
tosausingtf-tosa-optfromtensorflowrepo:fails with
This is because of the following check in
computeReshapeOutputcalled fromEqualizeRanksfunction:Based on the broadcast semantics defined in https://mlir.llvm.org/docs/Traits/Broadcastable/#dimension-inference I think it's legal to allow
lowerRankDim != higherRankDimif one of them is dynamic. At runtime verifier should enforce thatIt's not necessary to error out during the op construction time.