Support returning single element tuple #1471
Hi @ZihengJiang, torch-mlir does not support returning a tuple with a single element because, at the backend contract level, the resulting IR would be indistinguishable from the IR produced by a function that returns a single tensor. For example, a function

    def some_func(...):
        return (a, b, c)

becomes the following in MLIR:

    func @some_func(...) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor) { ... }

Therefore, one would expect a function

    def some_func(...):
        return (a,)

to become:

    func @some_func(...) -> (!torch.vtensor) { ... }

However, this is exactly the MLIR we would expect from a function that returns a single tensor. Because the signature is ambiguous, only one of the two cases can be supported, and torch-mlir chose to support returning a single tensor rather than a single-element tuple.
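To make the ambiguity concrete, here is a toy Python model of how a returned value is flattened into a list of result types. It is purely illustrative: `Tensor` and `result_signature` are hypothetical names, not torch-mlir code.

```python
class Tensor:
    """Placeholder standing in for a torch tensor (hypothetical)."""


def result_signature(returned) -> str:
    """Toy model of lowering a Python return value to an MLIR result list.

    A tuple is flattened into one result per element; anything else
    becomes a single result. Illustration only, not torch-mlir code.
    """
    results = returned if isinstance(returned, tuple) else (returned,)
    return "(" + ", ".join("!torch.vtensor" for _ in results) + ")"


a, b, c = Tensor(), Tensor(), Tensor()
print(result_signature((a, b, c)))  # (!torch.vtensor, !torch.vtensor, !torch.vtensor)
print(result_signature((a,)))       # (!torch.vtensor)
print(result_signature(a))          # (!torch.vtensor)
```

The last two calls produce the same signature, so once the tuple is flattened there is no way to tell whether the original function returned `(a,)` or `a`.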
Hi @ramiro050, thanks for the clarification! In that case, could the single-element tuple be converted to a single tensor automatically instead of raising an error?
We cannot do this automatically; it has to be done by the user before their code reaches Torch-MLIR. We have some code that does this on FX graphs, which we have been copying from project to project: https://github.com/iree-org/iree-torch/blob/94911c8fc068d135de2a810699d399e10b1cc43f/torchdynamo_poc/utils.py#L55
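For FX graphs, a minimal sketch of such a rewrite might look like the following. It is written against the public `torch.fx` API and is not a copy of the linked helper; the names `unwrap_single_tuple_return` and `ReturnsOneTuple` are our own, for illustration.

```python
import torch
import torch.fx


def unwrap_single_tuple_return(gm: torch.fx.GraphModule) -> bool:
    """If the graph returns a one-element tuple, return the element instead.

    Returns True if the graph was rewritten, False otherwise.
    """
    for node in gm.graph.nodes:
        if node.op != "output":
            continue
        # An FX output node carries the returned value as its single argument.
        result = node.args[0]
        if isinstance(result, (tuple, list)) and len(result) == 1:
            node.args = (result[0],)
            gm.graph.lint()  # sanity-check the modified graph
            gm.recompile()   # regenerate forward() from the graph
            return True
    return False


class ReturnsOneTuple(torch.nn.Module):
    def forward(self, x):
        return (x + 1,)


gm = torch.fx.symbolic_trace(ReturnsOneTuple())
unwrap_single_tuple_return(gm)
print(type(gm(torch.zeros(2))))  # <class 'torch.Tensor'>
```

The function reports whether it changed anything, so calling it on a graph that already returns a bare value is a harmless no-op.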
Um, OK. In my case I only have some TorchScript models, so modifying them might be hard. Thank you!
@ZihengJiang In the past, the way I have dealt with this is to define a wrapper module around the original model that unpacks the output, rather than modifying the model itself.
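A minimal sketch of that wrapper approach, assuming the model's `forward` takes one tensor and returns a one-element tuple. `ModelReturningTuple` and `UnwrapSingleTuple` are hypothetical names. Because the scripted model is only wrapped, not edited, this should also work when all you have is a TorchScript model:

```python
from typing import Tuple

import torch


class ModelReturningTuple(torch.nn.Module):
    """Hypothetical stand-in for an existing model that returns a 1-tuple."""

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
        return (x * 2,)


class UnwrapSingleTuple(torch.nn.Module):
    """Wraps a model and returns the sole element of its 1-tuple output."""

    def __init__(self, inner: torch.nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Indexing a tuple with a constant is supported by TorchScript.
        return self.inner(x)[0]


# The inner model is already scripted (it could equally come from
# torch.jit.load); scripting the wrapper reuses it unchanged.
scripted_inner = torch.jit.script(ModelReturningTuple())
wrapped = torch.jit.script(UnwrapSingleTuple(scripted_inner))
```

The scripted `wrapped` module returns a plain tensor, so it can be passed to torch-mlir in place of the original.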
I ran into a similar issue when trying to return a nested tuple of tensors. Is there a way we could improve the error messages? My error was:
@dellis23 I think right here in the code we could check the current function signature and raise an error if any argument or return value is not a tensor or a scalar. Is this something you could take on? Happy to help.
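A rough Python sketch of such a check, run on a scripted function's TorchScript schema before import. It only illustrates the idea and is not torch-mlir's code (the proposed check would live inside torch-mlir itself); `check_signature` and its messages are hypothetical. Following the behavior described earlier in the thread, it accepts tensors, scalars, and flat tuples of two or more of them as returns, and rejects single-element and nested tuples with a targeted message:

```python
import torch

# TorchScript type kinds treated as "tensor or scalar" in this sketch.
_LEAF_KINDS = {"TensorType", "IntType", "FloatType", "BoolType"}


def check_signature(fn) -> None:
    """Raise ValueError with a specific message if `fn`'s signature is unsupported.

    `fn` is a scripted function or a scripted module's method, i.e. anything
    exposing a TorchScript `.schema`.
    """
    schema = fn.schema
    for arg in schema.arguments:
        if arg.name == "self":  # skip the module receiver on methods
            continue
        if arg.type.kind() not in _LEAF_KINDS:
            raise ValueError(f"argument '{arg.name}' has unsupported type {arg.type}")
    for ret in schema.returns:
        ty = ret.type
        if ty.kind() != "TupleType":
            if ty.kind() not in _LEAF_KINDS:
                raise ValueError(f"return value has unsupported type {ty}")
            continue
        elements = ty.elements()
        if len(elements) == 1:
            raise ValueError(
                "returning a single-element tuple is not supported; "
                "return the element itself"
            )
        for elem in elements:
            if elem.kind() not in _LEAF_KINDS:
                raise ValueError(
                    f"tuple element has unsupported type {elem}; only tensors "
                    "and scalars are supported inside a returned tuple"
                )
```

For a scripted module, the same check could be applied to `module.forward`, whose schema includes the `self` argument skipped above.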
Yup, I can do that.
For example:
When such a module is compiled with torch-mlir, it raises the following error: