
Add static KV cache and test on Gemma-2B #4

Merged
merged 13 commits into main from static-compilation on Mar 15, 2024

Conversation

@tengomucho (Collaborator) commented Mar 13, 2024

What does this PR do?

This PR adapts the TGI server to better take advantage of PyTorch/XLA graphs. Relevant changes:

  • Model compilation is disabled by default. This is because it does not always work, and it is sometimes slower than running without compilation, since XLA builds a graph either way.
  • Create tensors on device directly, to avoid copying them.
  • Added support for static KV cache whenever possible.
  • Added Gemma-2b example (it uses static KV cache).

All this leads to general performance improvements: even though I added a test for a new model, the test run takes 4m40s, whereas before the tests were running in 5m21s.
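For reference, here is a minimal sketch (not the PR's code) of what using the static KV cache for Gemma-2B on an XLA device can look like with transformers; the model id, prompt, and generation settings are illustrative assumptions.

import torch
import torch_xla.core.xla_model as xm
from transformers import AutoModelForCausalLM, AutoTokenizer

device = xm.xla_device()  # the XLA (TPU) device
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b").to(device)

# Opt into the static KV cache supported by Gemma and Llama in transformers.
model.generation_config.cache_implementation = "static"

# Build inputs, move them to the device, and generate without gradients.
inputs = tokenizer("Hello", return_tensors="pt").to(device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=20)
xm.mark_step()  # let XLA materialize the pending graph

print(tokenizer.decode(output[0]))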

@tengomucho tengomucho force-pushed the static-compilation branch 2 times, most recently from c342e4f to 5542841 on March 13, 2024 at 16:46
If the DBG_DEVICE env var is set, it will be used to set the device for the
model.
This will avoid loading the model twice.
Make compilation optional; it can be enabled with the environment
variable DBG_COMPILE. This is because:

1. Some models produce bugs when the model is compiled
   (notably Gemma).
2. The shapes of the model's inference input params change, triggering
   recompilation and leading to slow performance.
3. With the added xm.mark_step, performance is actually better when the
   model is not compiled. XLA builds a graph anyway, so performance is
   going to be good.
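Below is a minimal sketch of the opt-in compilation and device selection described in this commit message, assuming the environment variables are read where the model is loaded; the stand-in model and the missing explicit backend are simplifications.

import os
import torch

# Stand-in model; the real server loads the actual checkpoint here.
model = torch.nn.Linear(8, 8)

# If DBG_DEVICE is set, it is used as the device for the model.
device = torch.device(os.environ["DBG_DEVICE"]) if "DBG_DEVICE" in os.environ else torch.device("cpu")
model = model.to(device)

# Compilation is opt-in: only compile when DBG_COMPILE is set.
if os.environ.get("DBG_COMPILE"):
    model = torch.compile(model)  # the real code would pick an XLA-aware backend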
This is to reduce useless gradient calculations.
This will make it possible to pass different params for different model
configurations later.
Some models, like Gemma and Llama, support static KV cache in
transformers. For these, it is possible to use this feature, leading to
much higher performance.
Also manually install accelerate to avoid memory issues when loading
Gemma.
The test produces different results because some operations are now done
in a slightly different order.
@tengomucho tengomucho marked this pull request as ready for review March 14, 2024 17:12
@mfuntowicz (Member) left a comment

LGTM - a few more comments for further reflection moving forward - Congrats!

self._id = id
self._tokenizer = tokenizer
self.clear()
self._device = device
Member

Maybe let's do the conversion from str to torch.device() right away here to ensure we can fail fast if this device doesn't exist and avoid overhead later down the road?

Collaborator Author

The conversion does not check that the device is available. The only way I found to check if the device is available is to invoke the torch_xla API directly. I can add a check before mapping the model if you wish.

Collaborator Author

As discussed offline, adding such a check is probably unnecessary, given that the check will be done implicitly while mapping the model.
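To illustrate the point discussed in this thread, a small sketch: torch.device() only parses the string, while the torch_xla API has to be queried to know which devices actually exist (get_xla_supported_devices is assumed to be available in the installed torch_xla version).

import torch
import torch_xla.core.xla_model as xm

dev = torch.device("xla:0")               # parses fine even with no TPU attached
devices = xm.get_xla_supported_devices()  # actually queries the XLA runtime
print(dev, devices)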

)
# Update mask only if it was set previously
if self._mask is not None:
    self._mask = torch.cat([self._mask, torch.tensor([1], device=self._device, dtype=self._mask.dtype)])
@mfuntowicz (Member) Mar 15, 2024

Maybe for later: can this concatenation be replaced by an in-place set from 0 to 1?

Collaborator Author

sure, I'll take a note.

Collaborator Author

Having said that: this is handled in a transparent way by models that use the static cache; I guess they already do that inside the model.
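For reference, a minimal sketch of the in-place alternative suggested above, assuming the attention mask is preallocated to the static cache length; all names here are illustrative, not the PR's code.

import torch

max_len = 16
mask = torch.zeros(max_len, dtype=torch.int64)  # preallocated mask, all positions disabled
mask[:4] = 1                                    # positions already filled

# Instead of torch.cat, flip the next position from 0 to 1 in place:
next_pos = int(mask.sum())
mask[next_pos] = 1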

@tengomucho tengomucho merged commit fdcd7ea into main Mar 15, 2024
1 check passed
@mfuntowicz mfuntowicz deleted the static-compilation branch March 28, 2024 13:28