Small optimizations #5
Conversation
This will only enable compilation for decoding. Note that there is not a big speedup for now, probably because the slot's buffer size increases over time, triggering recompilation.
Logits post-processing is not very heavyweight, and doing it on the CPU actually accelerates decoding, because compilation is not re-triggered.
LGTM! 🤗
@@ -512,8 +523,11 @@ def _generate_token(
 # Save KV cache
 self.past_key_values = outputs.past_key_values
 # Barrier for XLA model
-xm.mark_step(wait=False)
+xm.mark_step()
For my knowledge: we were not waiting before, so why has this been changed here?
The default is wait=False, and I did not want to give the false impression that I am changing the default behaviour, so I just removed the default parameter.
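The point can be shown with a plain-Python stand-in (mark_step below is a hypothetical stub, not the real torch_xla function): omitting an argument that equals the declared default leaves behaviour unchanged.

```python
# Hypothetical stub mirroring a function whose default is wait=False.
def mark_step(wait=False):
    return "sync" if wait else "async"

# Calling without the argument is identical to passing the default explicitly.
assert mark_step() == mark_step(wait=False)
print(mark_step())  # async
```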
@@ -44,13 +45,19 @@ def create_request(
 seed: int = 0,
 repetition_penalty: float = 1.0,
 ):
+# For these tests we can safely set typical_p to 1.0 (default)
+typical_p = 1.0
+if do_sample == False:
if not do_sample
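The suggested idiom can be sketched as follows (pick_strategy is a hypothetical helper, used only to contrast the two styles):

```python
def pick_strategy(do_sample: bool) -> str:
    # Truthiness check, as the reviewer suggests, instead of `do_sample == False`.
    if not do_sample:
        return "greedy"
    return "sample"

print(pick_strategy(False))  # greedy
print(pick_strategy(True))   # sample
```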
@@ -35,13 +36,19 @@ def create_request(
 seed: int = 0,
 repetition_penalty: float = 1.0,
 ):
+# For these tests we can safely set typical_p to 1.0 (default)
+typical_p = 1.0
+if do_sample == False:
if not do_sample
4dda7e9 to 401eea6
What does this PR do?
This PR adds a few optimizations made in preparation for compiling the model for decoding. Note that compilation is still not enabled by default, due to a bug I am currently investigating.
On the other hand, I spent some time profiling the code with the XLA profiling API, and I was able to understand that adding a few xm.mark_step calls improved performance, and that the token processing code actually runs faster when executed on the CPU, because that avoids recompilation.