[Param] Recheck and update repetition penalty parameter #202
Changes from all commits: e53ba4f, abe3d18, 3197fa7, 65ae93a, f439b97, fe8b5de, fa161b8, 77180aa, 52155ca
```diff
@@ -53,6 +53,9 @@ class SamplingTensors:
     mask_top_logprob: torch.Tensor
         Mask for requests with top_logprob.
         shape: (LOGPROB_TOP_K_MAX + 1, batch_size, )
+    mask_prompt: torch.Tensor
+        Mask for request with repetition penalty (prompt part)
+        shape: (batch_size, vocab_size)
     temperatures: torch.Tensor
         Tensor for temperature values
         shape: (batch_size, )
```
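The new `mask_prompt` entry describes a per-request boolean mask over the vocabulary marking which token ids occur in the prompt, so the repetition penalty can also cover prompt tokens. The PR does not show how `param.mask_prompt` is produced; the sketch below only illustrates one way such a mask could be built (the helper name `build_prompt_mask` is hypothetical, not part of this PR):

```python
import torch
from typing import List

def build_prompt_mask(prompt_token_ids: List[int], vocab_size: int) -> torch.Tensor:
    """Hypothetical helper: boolean mask of shape (vocab_size,) that is True
    for every token id that appears in the prompt."""
    mask = torch.zeros(vocab_size, dtype=torch.bool)
    ids = torch.tensor(prompt_token_ids, dtype=torch.long)
    mask[ids] = True  # repeated ids are fine, the entry just stays True
    return mask

# Example: vocab of 10 tokens, prompt uses ids 1 and 3 (3 twice).
mask = build_prompt_mask([1, 3, 3], vocab_size=10)
print(mask.nonzero(as_tuple=True)[0].tolist())  # [1, 3]
```

Stacking one such mask per request gives the `(batch_size, vocab_size)` tensor that `from_lists` assembles with `torch.stack(list_mask_prompt)` further down.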
```diff
@@ -85,6 +88,7 @@ class SamplingTensors:
     mask_random: torch.Tensor
     mask_greedy: torch.Tensor
     mask_top_logprob: torch.Tensor
+    mask_prompt: torch.Tensor
     temperatures: torch.Tensor
     top_ps: torch.Tensor
     top_ks: torch.Tensor
```
```diff
@@ -102,6 +106,7 @@ def from_lists(
         dev,
         list_mask_random: List[bool],
         list_mask_top_logprob: List[List[bool]],
+        list_mask_prompt: List[torch.Tensor],
         list_temperatures: List[float],
         list_top_ps: List[float],
         list_top_ks: List[int],
```
```diff
@@ -124,6 +129,7 @@ def from_lists(
         )
         # `mask_top_logprob` will be on cpu
         mask_top_logprob = torch.from_numpy(list_mask_top_logprob)
+        mask_prompt = torch.stack(list_mask_prompt)
         temp = torch.tensor(
             list_temperatures,
             dtype=dtype,
```
```diff
@@ -185,6 +191,7 @@ def from_lists(
             mask_random,
             mask_greedy,
             mask_top_logprob,
+            mask_prompt,
             temp.to(device=dev, non_blocking=True),
             top_ps.to(device=dev, non_blocking=True),
             top_ks.to(device=dev, non_blocking=True),
```
```diff
@@ -250,6 +257,7 @@ def from_sampling_params(
         vocab_size: int,
     ):
         list_mask_random = []
+        list_mask_prompt = []
         list_temperatures = []
         list_top_ps = []
         list_top_ks = []
```
```diff
@@ -307,6 +315,7 @@ def from_sampling_params(
             list_frequency_penalties.append(param.frequency_penalty)
             list_presence_penalties.append(param.presence_penalty)
             list_repetition_penalties.append(param.repetition_penalty)
+            list_mask_prompt.append(param.mask_prompt)

             if param.logit_bias_index:
                 assert param.logit_bias_value
```
```diff
@@ -348,6 +357,7 @@ def from_sampling_params(
             dev,
             list_mask_random,
             list_mask_top_logprob,
+            list_mask_prompt,
             list_temperatures,
             list_top_ps,
             list_top_ks,
```
```diff
@@ -372,20 +382,39 @@ def from_sampling_params(
         )


-def adjust_logits(logits, sampling_metadata, vocab_size):
+def get_bin_counts_and_mask(
```
Reviewer: Same comment as above. Would it make sense to move this to

binarybana: Valery talked about this in the issue, repasting some of it here:

Reviewer: Ah, thanks @binarybana, this is a great point. @vvchernov, can we follow up about this? We can separate what does and does not need to be updated each iteration, as you raised.

vvchernov (Author): Oh, thanks @binarybana for sharing my ideas! I will repeat them in other words here. First of all, this looks like a separate task; if supporting a correct repetition penalty is the high priority, it is better to do that first. The task as I see it is a redesign of the API, which should be done carefully and needs more detailed discussion. I can prepare a separate draft PR with some rough code that implements my idea, and we can discuss there how it should look. I am also still concerned about logits_processor: right now it looks like it was added for JSON mode, but in general it does the same thing as the sampler. We should rethink that as part of the redesign.
```diff
+    tokens: torch.Tensor,
+    vocab_size: int,
+    num_seqs: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
+                             dtype=torch.long,
+                             device=tokens.device)
+    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
+    bin_counts = bin_counts[:, :vocab_size]
+    mask = bin_counts > 0
+
+    return bin_counts, mask
+
+
+def adjust_logits(
+    logits: torch.Tensor,
+    sampling_state: SamplingState,
+    vocab_size: int):
     batch_size = logits.shape[0]
     (
         apply_top_p_top_k,
         apply_penalty,
         apply_bias,
         sampling_tensors,
     ) = (
-        sampling_metadata.apply_top_p_top_k,
-        sampling_metadata.apply_penalty,
-        sampling_metadata.apply_bias,
-        sampling_metadata.sampling_tensors,
+        sampling_state.apply_top_p_top_k,
+        sampling_state.apply_penalty,
+        sampling_state.apply_bias,
+        sampling_state.sampling_tensors,
     )
     (
+        prompt_mask,
         temp_t,
         top_ps_t,
         top_ks_t,
```
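To see what the new `get_bin_counts_and_mask` helper returns, here is a small self-contained reproduction of its logic on a toy batch. The padding convention is an assumption: vLLM-style samplers typically pad the per-request token lists with `vocab_size`, which is why the function allocates `vocab_size + 1` columns and slices the last one off.

```python
import torch

vocab_size = 5
# Two sequences of past output tokens, padded with `vocab_size` (assumed padding id).
tokens = torch.tensor([[0, 2, 2, 5],
                       [1, 1, 1, 3]])

num_seqs = tokens.shape[0]
bin_counts = torch.zeros((num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))  # count occurrences per token id
bin_counts = bin_counts[:, :vocab_size]                      # drop the padding column
mask = bin_counts > 0                                        # which token ids appeared at all

print(bin_counts.tolist())  # [[1, 0, 2, 0, 0], [0, 3, 0, 1, 0]]
print(mask.tolist())        # [[True, False, True, False, False], [False, True, False, True, False]]
```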
```diff
@@ -396,6 +425,7 @@ def adjust_logits(logits, sampling_metadata, vocab_size):
         logit_bias_indices_t,
         logit_bias_values_t,
     ) = (
+        sampling_tensors.mask_prompt,
         sampling_tensors.temperatures,
         sampling_tensors.top_ps,
         sampling_tensors.top_ks,
```
```diff
@@ -411,20 +441,30 @@ def adjust_logits(logits, sampling_metadata, vocab_size):
     # (e.g., repetition penalty, frequency/presence penalty, logit bias, temperature...)
     # in the right order.
     if apply_penalty:
+        bin_counts, output_mask = get_bin_counts_and_mask(
+            past_output_tokens_t,
+            vocab_size,
+            batch_size,
+        )
+
+        # It was checked that vLLM and HF approaches for repetition penalty are the same
+        # For calculation of it their combination is used (see references below)
+        # Calculate repetition penalty use vLLM approach
+        # https://github.com/vllm-project/vllm/blob/0580aab02ffe60fee50bddc80b787828eb233c44/vllm/model_executor/layers/sampler.py#L177
+        # and RepetitionPenaltyLogitsProcessor approach from HF TGI API
+        # https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L332C1-L339C22
+        # where score is logits
+        # https://github.com/huggingface/transformers/blob/de11e654c962d5b23eb53a4387cd637b01987491/src/transformers/generation/logits_process.py#L76C1-L78C92
+        repetition_penalties_t = repetition_penalties_t[:, None].repeat(1, vocab_size)
+        prompt_mask = prompt_mask.to(repetition_penalties_t.device)
+        repetition_penalties_t[~(prompt_mask | output_mask)] = 1.0
+        logits = torch.where(
+            logits > 0, logits / repetition_penalties_t, logits * repetition_penalties_t
+        )
-        bin_counts = torch.zeros(
-            (batch_size, vocab_size + 1), dtype=torch.long, device=logits.device
-        )
-        bin_counts.scatter_add_(
-            1, past_output_tokens_t, torch.ones_like(past_output_tokens_t)
-        )
-        bin_counts = bin_counts[:, :vocab_size]
-        mask = bin_counts > 0

         # Calculate frequency and presence penalties
         logits -= frequency_penalties_t.unsqueeze_(dim=1) * bin_counts
-        logits -= presence_penalties_t.unsqueeze_(dim=1) * mask
+        logits -= presence_penalties_t.unsqueeze_(dim=1) * output_mask

         # Adjust temperature
         logits.div_(temp_t.unsqueeze(dim=1))
```
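The penalty block above can be checked in isolation. The following sketch replays the same operations on hand-made tensors (the numbers are purely illustrative, not taken from the PR): the repetition penalty divides positive logits and multiplies negative ones, but only for tokens present in the prompt or the past output, and the frequency/presence penalties are then subtracted based on the output counts.

```python
import torch

vocab_size = 4
logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])

prompt_mask = torch.tensor([[True, False, False, False]])  # token 0 appeared in the prompt
output_mask = torch.tensor([[False, True, False, False]])  # token 1 was already generated
bin_counts  = torch.tensor([[0, 2, 0, 0]])                 # token 1 was generated twice

repetition_penalty = torch.tensor([1.5])
frequency_penalty  = torch.tensor([0.1])
presence_penalty   = torch.tensor([0.2])

# Repetition penalty applies only to tokens seen in the prompt or in past output.
rep = repetition_penalty[:, None].repeat(1, vocab_size)
rep[~(prompt_mask | output_mask)] = 1.0
logits = torch.where(logits > 0, logits / rep, logits * rep)
# token 0: 2.0 / 1.5 ~= 1.333, token 1: -1.0 * 1.5 = -1.5, tokens 2 and 3 unchanged

# Frequency/presence penalties apply only to previously generated tokens.
logits = logits - frequency_penalty[:, None] * bin_counts
logits = logits - presence_penalty[:, None] * output_mask
print(logits)  # tensor([[1.3333, -1.9000, 0.5000, 3.0000]])
```

Resetting the penalty to 1.0 everywhere outside `prompt_mask | output_mask` is what makes the prompt part of the penalty work without touching tokens that were never seen.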
```diff
@@ -447,7 +487,7 @@ class SamplingOutput:


 def sample(
     logits: torch.Tensor,
-    sampling_metadata: SamplingState,
+    sampling_state: SamplingState,
     check_safety: bool = False,
 ) -> SamplingOutput:
     def _is_safe_to_sample(prob_like):
```
```diff
@@ -457,7 +497,7 @@ def _is_safe_to_sample(prob_like):
         )

     res_greedy, res_random = None, None
-    sampling_tensors = sampling_metadata.sampling_tensors
+    sampling_tensors = sampling_state.sampling_tensors

     batch_size = logits.shape[0]
     mask_greedy_t, mask_random_t = (
```
```diff
@@ -466,13 +506,13 @@ def _is_safe_to_sample(prob_like):
     )

     next_tokens = np.empty((batch_size,), dtype=np.int64)
-    if sampling_metadata.has_greedy:
+    if sampling_state.has_greedy:
         res_greedy = torch.argmax(logits[mask_greedy_t], -1)
         np_mask_greedy = mask_greedy_t.cpu().numpy()
         next_tokens[np_mask_greedy] = res_greedy.cpu().numpy()

     probs_random = None
-    if sampling_metadata.has_random:
+    if sampling_state.has_random:
         probs_random = torch.softmax(logits[mask_random_t], dim=-1)
         if check_safety and not _is_safe_to_sample(probs_random):
             return None
```
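Independently of the `sampling_state` rename, the structure of `sample` is: greedy requests take an argmax, random requests go through softmax and multinomial sampling, and both write into a single `next_tokens` array via boolean masks. A minimal standalone sketch of that routing (it does not use the PR's `SamplingState`, which is not reproduced here):

```python
import numpy as np
import torch

torch.manual_seed(0)

logits = torch.randn(4, 8)                          # batch of 4 requests, vocab of 8
mask_greedy = torch.tensor([True, False, True, False])
mask_random = ~mask_greedy

next_tokens = np.empty((4,), dtype=np.int64)

# Greedy requests: plain argmax over their logits.
res_greedy = torch.argmax(logits[mask_greedy], -1)
next_tokens[mask_greedy.numpy()] = res_greedy.numpy()

# Random requests: softmax then multinomial sampling.
probs = torch.softmax(logits[mask_random], dim=-1)
res_random = torch.multinomial(probs, num_samples=1).squeeze(-1)
next_tokens[mask_random.numpy()] = res_random.numpy()

print(next_tokens)  # one sampled token id per request
```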
```diff
@@ -481,9 +521,9 @@ def _is_safe_to_sample(prob_like):
         next_tokens[np_mask_random] = res_random.cpu().numpy()

     logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * batch_size
-    if sampling_metadata.has_logprob:
+    if sampling_state.has_logprob:
         # If everything is random sampling, save one extra softmax
-        if not sampling_metadata.has_greedy:
+        if not sampling_state.has_greedy:
             assert probs_random is not None
             logprobs = torch.log(probs_random)
         else:
```
```diff
@@ -494,13 +534,13 @@ def _is_safe_to_sample(prob_like):
         all_top_logprobs, all_top_tokens = torch.topk(
             extended_logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
         )
-        mask = sampling_metadata.sampling_tensors.mask_top_logprob
+        mask = sampling_state.sampling_tensors.mask_top_logprob
         top_tokens = all_top_tokens[mask]
         top_logprobs = all_top_logprobs[mask]
-        for idx, batch_idx in enumerate(sampling_metadata.logprob_batch_indices):
+        for idx, batch_idx in enumerate(sampling_state.logprob_batch_indices):
             next_token = next_tokens[batch_idx]
-            assert sampling_metadata.sampling_params[batch_idx].logprobs
-            top_k = sampling_metadata.sampling_params[batch_idx].top_logprobs
+            assert sampling_state.sampling_params[batch_idx].logprobs
+            top_k = sampling_state.sampling_params[batch_idx].top_logprobs
             logprob_infos[batch_idx] = RawLogprobsInfo(
                 current_token_id=next_token,
                 current_logprob=logprobs[batch_idx][next_token],
```
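The logprob branch computes a single top-k of size `LOGPROB_TOP_K_MAX` for the whole batch and then gives each request only the number of entries it asked for. A reduced sketch of that pattern, with a made-up cap and per-request counts (the real constant and the `mask_top_logprob` bookkeeping live in the serving code and are omitted here):

```python
import torch

LOGPROB_TOP_K_MAX = 5  # assumed cap for illustration; the real constant is defined elsewhere
logprobs = torch.log_softmax(torch.randn(3, 10), dim=-1)  # 3 requests, vocab of 10

# One top-k call at the maximum size for the whole batch...
all_top_logprobs, all_top_tokens = torch.topk(
    logprobs, k=LOGPROB_TOP_K_MAX, dim=-1, largest=True, sorted=True
)

# ...then each request keeps only as many entries as it requested.
requested_top_k = [1, 3, 5]
for batch_idx, top_k in enumerate(requested_top_k):
    tokens = all_top_tokens[batch_idx, :top_k].tolist()
    values = all_top_logprobs[batch_idx, :top_k].tolist()
    print(batch_idx, list(zip(tokens, values)))
```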