
Allow gradient calculation in DiffusionPipeline.__call__ #529

Closed
JoaoLages opened this issue Sep 16, 2022 · 3 comments
Labels
stale Issues that haven't received updates

Comments

@JoaoLages

Currently, gradient calculation is disabled when calling a DiffusionPipeline object (see the example here).
IMO, this makes the library less flexible for some use cases. For example, this framework had to copy the whole __call__ method just to get gradients attached to the output image tensor.

@anton-l
Member

anton-l commented Sep 21, 2022

Hi @JoaoLages! I feel like torch.no_grad is a good default for pipelines, as they're intended only for inference by design. Additional functionality would require implementing a custom pipeline, indeed.

For your use case, however, there's a neat Python trick to unwrap the no_grad-decorated method and keep its contents: pipeline.__call__ = pipeline.__call__.__wrapped__. Could you check whether it works for you?

@JoaoLages
Author


I tried that at an early stage, but AFAIK it is not possible to override __call__ that way.
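This failure is consistent with how Python resolves special methods: pipeline(...) looks up __call__ on the class, not on the instance, so assigning pipeline.__call__ on the instance does not change what the call syntax runs. Also, pipeline.__call__.__wrapped__ is the plain, unbound function, so it would not receive self if called through the instance attribute. The sketch below illustrates this with hypothetical stand-ins (fake_no_grad, ToyPipeline), assuming the real torch.no_grad decorator likewise uses functools.wraps and therefore exposes __wrapped__. It contrasts the instance-level override with two alternatives that do take effect: calling the unwrapped function with the instance passed explicitly, or subclassing.

```python
import functools


# Hypothetical stand-in for torch.no_grad used as a decorator. It is assumed
# to use functools.wraps, which exposes the original function as __wrapped__.
def fake_no_grad(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Tag the result so we can tell whether the wrapper ran.
        return ("no_grad", func(*args, **kwargs))
    return wrapper


class ToyPipeline:  # hypothetical stand-in for a DiffusionPipeline subclass
    @fake_no_grad
    def __call__(self, prompt):
        return f"image for {prompt}"


pipe = ToyPipeline()

# Instance-level override: the right-hand side is the unbound original
# function, and pipe(...) still resolves __call__ on the class, so the
# decorated wrapper keeps running.
pipe.__call__ = pipe.__call__.__wrapped__
instance_result = pipe("cat")

# Alternative 1: call the unwrapped function directly, passing self explicitly.
direct_result = ToyPipeline.__call__.__wrapped__(pipe, "cat")


# Alternative 2: subclass and expose the undecorated implementation as
# __call__, leaving the original class untouched.
class GradToyPipeline(ToyPipeline):
    __call__ = ToyPipeline.__call__.__wrapped__


sub_result = GradToyPipeline()("cat")
```

With the real library, the subclass variant would presumably read class GradPipeline(StableDiffusionPipeline): __call__ = StableDiffusionPipeline.__call__.__wrapped__, provided the installed torch and diffusers versions keep __wrapped__ on the decorated method. Patching the class attribute instead of subclassing would also take effect, but it changes behavior for every instance of that pipeline class.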

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
