
Conversation

@vvchernov vvchernov commented Feb 12, 2024

  • Recheck that the HF TGI API and vLLM take the same approach to repetition penalty
  • Correct the calculation of the repetition penalty; include prompt tokens in it (see the sketch below)
  • Clean code: rename sampling_metadata -> sampling_state everywhere
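For context, a minimal sketch of the repetition-penalty semantics that HF TGI and vLLM share, and that the first two checklist items target: every token that has appeared in the prompt or the generated output is penalized, with positive logits divided by the penalty and negative logits multiplied by it. The function name and shapes below are illustrative, not this PR's actual code.

```python
import torch

def apply_repetition_penalty(
    logits: torch.Tensor,     # (vocab_size,) next-token logits for one sequence
    token_ids: torch.Tensor,  # 1-D LongTensor of ids seen so far (prompt + output)
    penalty: float,           # > 1.0 discourages repeating seen tokens
) -> torch.Tensor:
    logits = logits.clone()
    scores = logits[token_ids]
    # Positive logits shrink and negative logits grow more negative,
    # so every previously seen token becomes less likely.
    logits[token_ids] = torch.where(scores > 0, scores / penalty, scores * penalty)
    return logits
```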

@vvchernov vvchernov marked this pull request as draft February 12, 2024 07:43
@vvchernov vvchernov marked this pull request as ready for review February 12, 2024 17:31
@vvchernov (Author)

cc @sunggg

@sunggg sunggg left a comment

Thank you for the quick action, @vvchernov! A couple questions.



def adjust_logits(logits, sampling_metadata, vocab_size):
def get_bin_counts_and_mask(
@sunggg

Same comment as above. Would it make sense to move this to SamplingState prep and perform it on the CPU? If the perf impact is not bad, it seems better to prepare it there, since we are looking for a way to make the prep process more async.
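For reference, a sketch of what a vLLM-style get_bin_counts_and_mask could look like: it is pure tensor bookkeeping, so it could run on CPU during SamplingState prep just as well as on GPU. The signature, the padding convention, and the shapes are assumptions, not this PR's code.

```python
import torch

def get_bin_counts_and_mask(
    tokens: torch.Tensor,  # (num_seqs, max_len) token ids, padded with vocab_size
    vocab_size: int,
    num_seqs: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-sequence histogram over the vocabulary; the extra column absorbs
    # the padding id and is dropped afterwards.
    bin_counts = torch.zeros(
        (num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device
    )
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0  # True where a token appeared at least once
    return bin_counts, mask
```

The mask is what the repetition penalty needs (which tokens appeared at all), while the counts would support frequency-style penalties.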

@binarybana

Valery talked about this in the issue, repasting some of it here:

In my view, the right place is SamplingState, but that is recalculated for every new request. This does not look good, since in general it can be done once, but redoing it would take a lot of time. Maybe SamplingParams in Request should be replaced by SamplingState, with some sharing for the parameter n > 1.
What do you think about it?
For now I plan to save it in SamplingParams as a quick fix, but we need a better place.

@sunggg sunggg commented Feb 12, 2024

Ah, thanks @binarybana, this is a great point. @vvchernov, can we follow up on this? We can separate what does and does not need to be updated each iteration, as you raised.

@vvchernov (Author)

Oh, thanks @binarybana for sharing my ideas! I restated them in other words above.

> can we follow up on this? We can separate what does and does not need to be updated each iteration, as you raised.

First of all, this looks like a separate task; if supporting a correct repetition penalty is high priority, it is better to do that first. I see the task as a redesign of the API: it should be done carefully, and we need more detailed discussion. I can prepare a separate draft PR with some rough code that implements my idea, and we can discuss there how it should look. I am also still thinking about logits_processor. Right now it looks like it was made for JSON mode, but in the general view it does the same thing as the sampler. It would be better to rethink this as part of our redesign as well.

@binarybana

@vvchernov, thanks for the quick action on this one. Can you also check what the performance impact is with and without this enabled? Ideally for sequences of ~2k input tokens and 50 output tokens, under fairly loaded conditions (5-15 VUs), to match customer workloads.

assert torch.allclose(expected, new_logits)


def _test_penalties_checker():
@sunggg

Can we also add a testcase?

@vvchernov (Author)

My colleague is extending the tests for sampling parameters, particularly for repetition_penalty. See #200.

@sunggg

It is okay for it to be very basic, so please add one. We have not been doing this so far, but I'd like to recommend that every PR include unit tests that validate at least the basic functionality. With the SLM migration, we will set up CI as well.
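A very basic testcase along the lines requested could look like the sketch below; the checker here is a hypothetical stand-in, since the real validation lives in this PR's penalty-checker code.

```python
import pytest

def validate_repetition_penalty(penalty: float) -> None:
    # Hypothetical stand-in for the checker under discussion:
    # HF/vLLM-style repetition penalties must be strictly positive.
    if penalty <= 0:
        raise ValueError(f"repetition_penalty must be positive, got {penalty}")

@pytest.mark.parametrize("penalty", [0.5, 1.0, 1.3])
def test_repetition_penalty_accepts_valid_values(penalty):
    validate_repetition_penalty(penalty)  # should not raise

@pytest.mark.parametrize("penalty", [0.0, -1.0])
def test_repetition_penalty_rejects_non_positive(penalty):
    with pytest.raises(ValueError):
        validate_repetition_penalty(penalty)
```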

@vvchernov vvchernov force-pushed the vc/repetition_penalty branch from 4c554a0 to f439b97 on February 14, 2024 09:21
return sampling_state


def _test_temperature(temp=0, batch_size=1):
@vvchernov (Author)

@sunggg, why do we have an _ prefix on these test functions? I added such a prefix to functions that cannot be run under pytest, but the functions in test_sampler.py do not require any compiled model and can be called under pytest.

If there is no special reason, I propose removing the underscore.
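For context on why the underscore matters: pytest's default discovery only collects functions whose names match the test* pattern, so a leading underscore keeps a helper out of collection. A minimal illustration:

```python
# Under pytest's default discovery rule (python_functions = "test*"),
# only the first function below is collected as a test:

def test_temperature():   # collected and run by pytest
    assert True

def _test_temperature():  # ignored: the name does not start with "test"
    assert True
```

Renaming the functions (or widening python_functions in the pytest config) would make them runnable under pytest, as proposed.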

@vvchernov (Author)

We can ask Iliya to do it in his PR (#200), since he is extending these tests now.

@sunggg

No specific reason, I just followed the other testcases. I'm okay with removing it for pytest.

@binarybana binarybana merged commit 2ebbacf into octoml:batch-serving Feb 14, 2024
@sunggg sunggg commented Feb 14, 2024

Based on the offline discussion, we decided to merge this given the urgency. @vvchernov, let's add the test case in #200.

@vvchernov vvchernov deleted the vc/repetition_penalty branch February 16, 2024 09:02