Refactor Pytorch model.generate method to work on TPU #18661

Open
mikcnt opened this issue Aug 17, 2022 · 20 comments
Labels
WIP: Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress


@mikcnt
Contributor

mikcnt commented Aug 17, 2022

Feature request

Refactor the PyTorch (PT) version of model.generate for text-generation models to make it XLA-compatible and speed up inference on TPU.

Motivation

Right now, model.generate in PT is extremely slow on TPU compared to CPU and GPU. This is probably because some operations in the PT version of model.generate are not XLA-compatible, so the generation process falls back to the CPU. This makes inference on TPU infeasible. A major refactor has already been done on the TF counterpart, so it would be nice to have the PT version working as well.

A more in-depth discussion with @gante took place in #12322 and in this Hugging Face discussion.

Your contribution

If the HF team is interested, I can definitely help with the work.

@gante
Member

gante commented Aug 17, 2022

cc @patrickvonplaten

@patrickvonplaten
Contributor

Hey @mikcnt,

This sounds like a very cool project, and I think we should focus on it sooner or later. I don't have time to take a closer look right now, but my advice would be:

  • I think you're totally right that PyTorch/XLA often falls back to the CPU, which is why it is so slow. We're luckier with JAX and TF, because there the code fails outright when things fall back to the CPU.
  • It'll take some time to get this fully working, so we should start with the easiest case: see what code changes are needed to make greedy(...) work with PyTorch/XLA.
  • To set expectations: PyTorch's generate method is one of Transformers' most-used functions. It's extremely important, and we're trying very hard to keep the code readable and easy to understand. If making it XLA-compatible requires too many changes or makes the code too hard to read, we might conclude that it's not worth it and instead add it as an "experimental" separate function rather than in the main generate. Also, @michaelbenayoun @mfuntowicz, is this perhaps something we'd want only in Optimum rather than in Transformers?
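To make the "start with greedy(...)" suggestion concrete, here is a minimal, framework-free sketch of greedy decoding with static shapes, the general idea behind XLA-friendly generation loops: the output buffer is preallocated at max_length, each new token is written at an index instead of being appended, and the loop runs a fixed number of steps instead of breaking early on EOS. The toy model, the PAD_ID/EOS_ID values, and the function names are hypothetical illustrations, not Transformers code.

```python
from typing import Callable, List

PAD_ID = 0  # hypothetical padding token id
EOS_ID = 1  # hypothetical end-of-sequence token id
VOCAB = 5   # hypothetical vocabulary size for the toy model


def greedy_static(
    prompt: List[int],
    next_token_logits: Callable[[List[int], int], List[float]],
    max_length: int,
) -> List[int]:
    """Greedy decoding over a fixed-size buffer (sketch of the static-shape idea)."""
    if len(prompt) > max_length:
        raise ValueError("prompt is longer than max_length")
    # Preallocate the whole output once, so its length never changes.
    ids = [PAD_ID] * max_length
    ids[: len(prompt)] = prompt
    cur_len = len(prompt)
    finished = False
    # Fixed trip count: no early `break`, i.e. no data-dependent control flow.
    for _ in range(max_length - len(prompt)):
        logits = next_token_logits(ids, cur_len)
        best = max(range(len(logits)), key=logits.__getitem__)  # argmax
        # After EOS has been produced, keep writing padding instead of stopping.
        ids[cur_len] = PAD_ID if finished else best
        finished = finished or best == EOS_ID
        cur_len += 1
    return ids


def toy_model(ids: List[int], cur_len: int) -> List[float]:
    """Hypothetical model: predicts last token + 1, or EOS once it runs out of vocab."""
    last = ids[cur_len - 1]
    target = last + 1 if last + 1 < VOCAB else EOS_ID
    return [1.0 if t == target else 0.0 for t in range(VOCAB)]


print(greedy_static([2], toy_model, max_length=6))  # [2, 3, 4, 1, 0, 0]
```

The output always has exactly max_length entries, whether or not EOS was reached, which is the property a compiled forward pass relies on.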

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@divyanshuaggarwal

divyanshuaggarwal commented Sep 24, 2022

Hi,

Any updates on this? When can we expect the generate function to work on TPUs? Also, will it be part of Transformers or Optimum, as @patrickvonplaten mentioned above?

@patrickvonplaten
Contributor

patrickvonplaten commented Sep 27, 2022

Sadly, I won't have time to look into this anytime soon. @gante, maybe?

@gante
Member

gante commented Sep 28, 2022

Added to my generate task queue 👍

@divyanshuaggarwal it would be part of transformers!

@divyanshuaggarwal

Thanks @gante!

@gante gante self-assigned this Nov 17, 2022
@gante gante added the WIP label Nov 17, 2022
@divyanshuaggarwal

Hi @gante, I just noticed this has been marked WIP. Any ETA on when we can expect this feature?

@sgugger
Collaborator

sgugger commented Dec 12, 2022

This is not a prioritized feature as you can already use TPUs for generation in Flax and TensorFlow. Since you can easily convert a model from one framework to the other, there is an easy workaround :-)

@deveworld

Is there any update on this PR?

@gante
Member

gante commented Jun 12, 2023

@deveworld we are currently exploring PT-level optimizations, including the static shapes needed for XLA (TPU). A significant upgrade in this direction is likely in the next releases (keep an eye out :) )

@divyanshuaggarwal

@gante folks from Meta were able to run Llama inference on TPU using PyTorch/XLA. This might be helpful for this issue.

https://pytorch.org/blog/path-achieve-low-inference-latency/?utm_content=254892693&utm_medium=social&utm_source=linkedin&hss_channel=lcp-78618366

@verityw

verityw commented Aug 11, 2023

Has there been any update on this? When is the next release likely to come out?

@gante
Member

gante commented Aug 16, 2023

We have some code ready that makes the generation loop friendly to compiled forward passes (e.g. with torch.compile). It is pretty much the same algorithm we use with TF/Flax + XLA.

However, there are performance regressions on some devices, and the PyTorch team is having a look. We will include these changes when the performance bump is consistent across devices.

Meanwhile, feel free to adapt code from this repo/PR.

@verityw

verityw commented Aug 16, 2023

I see. Will this work on TPU, then? Or are TPUs among the devices experiencing performance regressions?

I also looked into the Optimum Neuron greedy decode implementation. While it no longer requires moving computations to CPU, running inference on TPU with it seems significantly slower than on GPU.

@gante
Member

gante commented Aug 17, 2023

@verityw I can't confirm. We are aiming for models that are fully compatible with, and efficient under, torch.compile(); there may be additional issues when selecting the XLA backend :)

@paulbricman

Any update on this? I'm trying to use trl and peft on a TPU slice (to run tests for yet another HF-aspiring library), but these newer parts of the ecosystem currently seem to support only torch, and the underlying transformers does not support torch in an XLA-friendly way.

I looked into it a bit, and it seems that both mostly wrap transformers' generate(), so maybe an XLA-friendly version of it would help across the board? I also expect to run into other XLA-related awkwardness in trl's backward step, but I don't have a good intuition for that. I'd love any pointers on what it takes to make them XLA-friendly and how far the stack is from that.

@gante
Member

gante commented Nov 7, 2023

Not far from seeing the light, actually!

Our current major endeavor in generate is supporting different types of caches. By default, caches grow with the input length, but XLA needs a fixed-size cache, which we will be adding as part of this task. In turn, this should make the forward pass of most models XLA-compatible (or close to it).
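A rough sketch of what a fixed-size cache means, assuming a single attention head and plain Python lists: storage is allocated once with max_cache_len slots, each decoding step writes its key/value into its own slot instead of appending, and attention runs over every slot with the unused ones masked out, so the sizes seen by the forward pass never change. StaticKVCache and attend are hypothetical names for illustration, not the cache classes Transformers ships.

```python
import math
from typing import List


class StaticKVCache:
    """Hypothetical fixed-size key/value cache for one attention head.

    Storage is allocated once with `max_cache_len` slots and never grows;
    a boolean mask records which slots currently hold real entries.
    """

    def __init__(self, max_cache_len: int, head_dim: int) -> None:
        self.keys = [[0.0] * head_dim for _ in range(max_cache_len)]
        self.values = [[0.0] * head_dim for _ in range(max_cache_len)]
        self.valid = [False] * max_cache_len

    def update(self, position: int, key: List[float], value: List[float]) -> None:
        # Write in place at `position`; appending would change the cache size.
        if not 0 <= position < len(self.keys):
            raise IndexError("position exceeds the preallocated cache length")
        self.keys[position] = list(key)
        self.values[position] = list(value)
        self.valid[position] = True


def attend(query: List[float], cache: StaticKVCache) -> List[float]:
    """Scaled dot-product attention over *all* slots, masking unused ones."""
    if not any(cache.valid):
        raise ValueError("cache holds no entries yet")
    scale = 1.0 / math.sqrt(len(query))
    scores = [
        sum(q * k for q, k in zip(query, key)) * scale if ok else float("-inf")
        for key, ok in zip(cache.keys, cache.valid)
    ]
    peak = max(scores)  # finite, since at least one slot is valid
    weights = [math.exp(s - peak) for s in scores]  # exp(-inf) == 0.0 for masked slots
    total = sum(weights)
    dim = len(cache.values[0])
    return [sum(w * v[i] for w, v in zip(weights, cache.values)) / total for i in range(dim)]


cache = StaticKVCache(max_cache_len=4, head_dim=2)
cache.update(0, key=[1.0, 0.0], value=[1.0, 0.0])
cache.update(1, key=[0.0, 1.0], value=[0.0, 1.0])
print(attend([1.0, 1.0], cache))  # both valid slots score equally: [0.5, 0.5]
```

Without the mask, the two empty slots would take part in the softmax and skew the result; masking is what lets a preallocated buffer behave like a cache that only holds the tokens seen so far.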

@mmcclean-aws

Any updates on this @gante ?

@gante
Member

gante commented Jan 12, 2024

Yes: #27931 (it is a prerequisite :) )

9 participants