
Conversation

@ssusie (Contributor) commented Oct 3, 2023

What does this PR do?

Enables users to run SDXL inference with PyTorch XLA. The README_sdxl.md guide is also updated.
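A minimal sketch of what this enables (assuming a TPU VM with torch and torch_xla installed; the checkpoint id and prompt are illustrative):

```python
import torch
import torch_xla.core.xla_model as xm
from diffusers import StableDiffusionXLPipeline

# Move SDXL to the XLA device (e.g. a TPU core).
device = xm.xla_device()
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to(device)

# Per this PR, the pipeline calls xm.mark_step() internally, so a plain call
# works. The first run is slow while XLA traces and compiles the graph;
# later calls with same-shaped inputs reuse the compiled graph.
image = pipe(prompt="a photo of an astronaut riding a horse on mars").images[0]
image.save("sdxl_xla_example.png")
```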

@patrickvonplaten @sayakpaul

Thanks.

@ssusie changed the title from "Adding PyTorch XLA support for sdxl" to "Adding PyTorch XLA support for sdxl inference" on Oct 3, 2023
@patrickvonplaten (Contributor) commented

Hey @ssusie,

Nice addition! Generally I'm OK with having this. However, could you maybe add torch_xla as a soft dependency here:

_torch_version = "N/A"
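For context, a minimal sketch of the soft-dependency pattern being requested, mirroring how `_torch_version` is handled in src/diffusers/utils/import_utils.py (the exact identifiers here are illustrative, not necessarily the names used in the merged change):

```python
import importlib.metadata
import importlib.util

# Soft dependency: detect torch_xla without importing it at module load time.
_torch_xla_available = importlib.util.find_spec("torch_xla") is not None
if _torch_xla_available:
    try:
        _torch_xla_version = importlib.metadata.version("torch_xla")
    except importlib.metadata.PackageNotFoundError:
        _torch_xla_available = False


def is_torch_xla_available():
    # XLA-only code paths (e.g. xm.mark_step calls) are guarded by this check,
    # so diffusers still imports cleanly when torch_xla is not installed.
    return _torch_xla_available
```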

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@ssusie (Contributor, Author) commented Oct 4, 2023

Thanks for the feedback, Patrick. Added the dependency in src/diffusers/utils/import_utils.py.

@sayakpaul (Member) left a comment

Looks good to me!

@sayakpaul (Member) commented

@ssusie could you run make style && make quality on your end?

@yiyixuxu (Collaborator) left a comment

thanks!

@ssusie (Contributor, Author) commented Oct 5, 2023

Thanks everyone for the reviews and comments. I ran make style and pushed the changes to the PR.

Comment on lines +123 to +124
speedup, we need to call the pipe again on the input with the same length
as the original prompt to reuse the optimized graph and get the performance
Member

Does the prompt length really matter here, given that the embeddings always have the same shape?

@ssusie (Contributor, Author)

Thanks for the great question. With the current setup, XLA needs the input to have the same length; otherwise it will recompile the whole graph. We could potentially cut the graph with xm.mark_step after the embeddings are calculated, but that is not addressed in this PR.
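To make the reuse behavior concrete, here is a rough sketch, where `pipe` is the SDXL pipeline on an XLA device as in the example in the PR description, and `timed` is a hypothetical helper added purely for illustration:

```python
import time


def timed(pipe, prompt):
    # Hypothetical timing helper: run one full pipeline call and time it.
    start = time.perf_counter()
    image = pipe(prompt=prompt, num_inference_steps=30).images[0]
    return image, time.perf_counter() - start


# First call: XLA traces and compiles the full graph (a one-time, slow step).
_, t_first = timed(pipe, "a photo of an astronaut riding a horse on mars")

# A repeat call with a same-length input (here, the identical prompt) reuses
# the optimized graph; per the discussion above, a different input length
# would trigger a full recompilation under the current setup.
_, t_cached = timed(pipe, "a photo of an astronaut riding a horse on mars")
print(f"first call: {t_first:.1f}s, cached graph: {t_cached:.1f}s")
```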

@ssusie (Contributor, Author) commented Oct 10, 2023

Thanks for the reviews and comments. Is there anything else to change or address, or can this be merged?

@patrickvonplaten (Contributor) commented

Let's merge it - great job @ssusie!

We don't have tests here yet, but I think this is fine to begin with :-)

@patrickvonplaten merged commit aea7383 into huggingface:main on Oct 11, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Added mark_step for sdxl to run with pytorch xla. Also updated README with instructions for xla

* adding soft dependency on torch_xla

* fix some styling

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024 (same commit message as above)