
Small optimizations #5

Merged: 5 commits merged into main from small-optimizations, Mar 22, 2024
Conversation

@tengomucho (Collaborator) commented Mar 20, 2024

What does this PR do?

This PR adds a few optimizations made in preparation for model compilation during decoding. Note that compilation is still not enabled by default, due to a bug I am currently investigating.
Separately, I spent some time profiling the code with the XLA profiling API. This showed that adding a few xm.mark_step calls improved performance, and that the token post-processing code ends up running faster if executed on the CPU, because that avoids retriggering recompilation.

This will only enable compilation for decoding. Note that there is no
big speedup for now, probably because the slot's buffer size grows over
time, triggering recompilation.
Logits post-processing is not very heavyweight, and doing it on the CPU
actually speeds up decoding, because compilation is not retriggered.
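The decode-loop shape the description implies can be sketched as follows. This is an illustrative outline, not the PR's actual code: `model_forward`, `mark_step`, and `postprocess_on_cpu` are all stand-ins. In the real code, the barrier is torch_xla's `xm.mark_step()` and the forward pass runs on an XLA device, while token selection stays on the CPU.

```python
# Illustrative sketch of the decode loop described above. All names are
# stand-ins: in the real code the barrier is torch_xla's xm.mark_step() and
# the forward pass runs on an XLA device; post-processing stays on the CPU.

def model_forward(input_ids):
    # Stand-in for the compiled XLA forward pass; returns raw logits.
    return [float(t) for t in input_ids]

def mark_step():
    # Stand-in for xm.mark_step(): cuts the lazy-tensor graph so the
    # accumulated operations are compiled and dispatched to the device.
    pass

def postprocess_on_cpu(logits):
    # Token selection runs on the CPU so that its data-dependent control
    # flow does not retrigger XLA compilation. Here: greedy argmax.
    return max(range(len(logits)), key=lambda i: logits[i])

def decode(input_ids, steps):
    tokens = list(input_ids)
    for _ in range(steps):
        logits = model_forward(tokens)
        mark_step()  # barrier after the device-side work
        tokens.append(postprocess_on_cpu(logits))
    return tokens
```

The key design point is the split: everything before the barrier is traced and compiled once; everything after it is ordinary eager CPU code that XLA never sees, so varying control flow there cannot cause recompilation.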
@tengomucho tengomucho marked this pull request as ready for review March 20, 2024 15:57
@mfuntowicz (Member) left a comment

LGTM! 🤗

@@ -512,8 +523,11 @@ def _generate_token(
         # Save KV cache
         self.past_key_values = outputs.past_key_values
         # Barrier for XLA model
-        xm.mark_step(wait=False)
+        xm.mark_step()
@mfuntowicz (Member):

For my knowledge: we were not waiting before, so why has this changed here?

@tengomucho (Collaborator, Author):

The default is wait=False, and I did not want to give the false impression that I was changing the default behaviour, so I just removed the explicit parameter.
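The point about defaults can be illustrated with plain Python keyword arguments. This is a generic sketch: `mark_step` here is a stand-in, not torch_xla's function.

```python
# Generic illustration of the reply above: omitting a keyword argument is
# exactly equivalent to passing its default value explicitly.
# mark_step here is a stand-in, not torch_xla's.

def mark_step(wait=False):
    return "blocking" if wait else "non-blocking"

# Both calls behave identically, because wait defaults to False:
assert mark_step() == mark_step(wait=False) == "non-blocking"
```

So dropping `wait=False` from the call site is a pure readability change with no behavioural difference.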

@@ -44,13 +45,19 @@ def create_request(
seed: int = 0,
repetition_penalty: float = 1.0,
):
# For these tests we can safely set typical_p to 1.0 (default)
typical_p = 1.0
if do_sample == False:
@mfuntowicz (Member):

if not do_sample

@@ -35,13 +36,19 @@ def create_request(
seed: int = 0,
repetition_penalty: float = 1.0,
):
# For these tests we can safely set typical_p to 1.0 (default)
typical_p = 1.0
if do_sample == False:
@mfuntowicz (Member):

if not do_sample
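The suggested idiom behaves the same as the explicit comparison for real `bool` values, and is the conventional Python spelling. A quick check (the `greedy_requested` helper is hypothetical, just to exercise the expression):

```python
# The review suggestion: prefer `not do_sample` over `do_sample == False`.
# For actual bool values the two are equivalent; `not` is the idiomatic form.
# (They differ on non-bool falsy values: `not None` is True, while
# `None == False` is False.)

def greedy_requested(do_sample: bool) -> bool:
    # Hypothetical helper illustrating the suggested idiom.
    return not do_sample

assert greedy_requested(False) is True
assert greedy_requested(True) is False
```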

@tengomucho tengomucho merged commit a8452e7 into main Mar 22, 2024
1 check passed
@tengomucho tengomucho deleted the small-optimizations branch March 22, 2024 09:02