
Support returning single element tuple #1471

Closed
ZihengJiang opened this issue Oct 10, 2022 · 8 comments
Labels
bug Something isn't working

Comments

@ZihengJiang
Collaborator

For example:

class Mod(torch.nn.Module):
    def forward(self, a):
        return (torch.tanh(a),)

When such a module is compiled with torch-mlir, it raises the following error:

torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
error: failed to legalize unresolved materialization from '!torch.tuple<tensor>' to '!torch.tensor' that remained live after conversion
note: see current operation: %3 = "builtin.unrealized_conversion_cast"(%2) : (!torch.tuple<tensor>) -> !torch.tensor
note: see existing live user here: "func.return"(%3) : (!torch.tensor) -> ()
@ZihengJiang ZihengJiang added the bug Something isn't working label Oct 10, 2022
@ZihengJiang ZihengJiang changed the title Don't support returning single element tuple Support returning single element tuple Oct 10, 2022
@ramiro050
Collaborator

Hi @ZihengJiang, torch-mlir has no support for returning a tuple with a single element in it because, at the backend contract level, the IR produced would be impossible to differentiate from the IR produced by a function that returns a single tensor.

For example, if you have a function

def some_func(...):
    return (a, b, c)

this becomes the following in MLIR:

func @some_func(...) -> (!torch.vtensor, !torch.vtensor, !torch.vtensor) { ... }

Therefore, one would expect a function

def some_func(...):
    return (a,)

to become:

func @some_func(...) -> (!torch.vtensor) { ... }

However, this is the same MLIR that we would expect from a function that returns a single tensor. Because the signature is ambiguous, we can support only one of the two cases, and torch-mlir chose to support returning a single tensor rather than a single-element tuple.

@ZihengJiang
Collaborator Author

Hi @ramiro050 , thanks for the clarification! In this case, can we convert the single element tuple to a single tensor automatically instead of raising the error?

@silvasean
Contributor

> Hi @ramiro050 , thanks for the clarification! In this case, can we convert the single element tuple to a single tensor automatically instead of raising the error?

We cannot do this automatically. It would have to be done by the user before their code reaches Torch-MLIR. We have some code to do this on FX graphs that we have been copying from project to project: https://github.com/iree-org/iree-torch/blob/94911c8fc068d135de2a810699d399e10b1cc43f/torchdynamo_poc/utils.py#L55

@ZihengJiang
Collaborator Author

Um, OK. In my case I only have some TorchScript models, so modifying them might be hard. Thank you!

@silvasean
Contributor

@ZihengJiang In the past the way I have dealt with this is to define a wrapper module around the original model that unpacks the output, rather than modifying the model itself.
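A minimal sketch of that wrapper approach (the class name `TupleUnwrapper` is hypothetical, not something provided by torch-mlir): the wrapper delegates to the original model and unpacks a single-element tuple into a bare tensor, so the compiled function's signature becomes unambiguous without touching the original model.

```python
import torch


class TupleUnwrapper(torch.nn.Module):
    """Hypothetical wrapper: forwards to the wrapped model and unpacks
    a single-element tuple result into a bare tensor."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        out = self.model(*args, **kwargs)
        # Unwrap (tensor,) -> tensor before returning.
        if isinstance(out, tuple) and len(out) == 1:
            return out[0]
        return out
```

One would then script or compile `TupleUnwrapper(Mod())` instead of `Mod()` itself, and re-wrap the result in a tuple on the caller side if downstream code expects one.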

@dellis23
Collaborator

I was running into a similar issue when trying to return a nested tuple of tensors. Is there a way we could improve the error messages? My error was:

<unknown>:0: error: unsupported by backend contract: tensor with unknown rank
<unknown>:0: note: see current operation: %10 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<*,f32>

@silvasean
Contributor

silvasean commented Oct 13, 2022

@dellis23 I think right here in the code we could check the current function signature and raise an error if any of the args or returns are not tensors or scalars -- is this something you could take on? Happy to help.
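An illustrative sketch of such a check (hypothetical, written standalone; the real fix would live inside torch-mlir's import path and inspect the function signature there): reject any return value that is not a tensor or scalar with an actionable message, rather than failing later with an opaque backend-contract error.

```python
import torch

# Scalar types permitted by the backend contract in this sketch.
_SCALAR_TYPES = (int, float, bool)


def check_return_value(value):
    """Hypothetical early check: raise a clear error for return values
    that cannot be represented (tuples, lists, dicts, ...)."""
    if isinstance(value, (torch.Tensor, *_SCALAR_TYPES)):
        return value
    raise TypeError(
        f"unsupported return type '{type(value).__name__}': functions "
        "compiled with torch-mlir must return tensors or scalars; "
        "unpack tuple or nested returns before compiling"
    )
```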

@dellis23
Collaborator

> @dellis23 I think right here in the code we could check the current function signature and raise an error if any of the args or returns are not tensors or scalars -- is this something you could take on? Happy to help.

Yup, I can do that.
