[Possible Bug]: Device management issue when using Torch Modules inside Keras Models #18411

Open
soumik12345 opened this issue Sep 4, 2023 · 3 comments

@soumik12345
Contributor

I attempted to use torch.nn.Module instances inside a keras_core.Model without wrapping them in TorchModuleWrapper (assuming the wrapping would be applied behind the scenes). However, when I pass a torch.cuda.FloatTensor to the model, it raises the following error:

RuntimeError: Exception encountered when calling Classifier.call().

Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

Arguments received by Classifier.call():
  • inputs=torch.Tensor(shape=torch.Size([1, 1, 28, 28]), dtype=float32)

This error is not encountered when trainable modules are explicitly wrapped with TorchModuleWrapper.

Notebook to reproduce the issue: https://colab.research.google.com/drive/1UO8uY86Ff-lNq5_7vKeLAb5sK99UFDWJ?usp=sharing

@fchollet
Member

fchollet commented Sep 5, 2023

Thanks for the report. It is surprising that there would be any difference between the two, because the autowrapping doesn't do anything besides creating the TorchModuleWrapper.

What happens if you place the module on device before setting it as a model attribute?

@soumik12345
Contributor Author

> What happens if you place the module on device before setting it as a model attribute?

Hi @fchollet,
If the modules are placed on the device before they are set as model attributes, the trainable variables do not appear to be tracked.

Notebook to reproduce issue: https://colab.research.google.com/drive/1ETFBY9duT_snn-HY92ptUWVO-4iBesQZ?usp=sharing

@fchollet fchollet transferred this issue from keras-team/keras-core Sep 22, 2023
@github-actions

github-actions bot commented Oct 7, 2023

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
