Add additive_repetition_penalty sampler setting. #3627

Merged · 6 commits · Oct 23, 2023

Changes from all commits
1 change: 1 addition & 0 deletions api-examples/api-example-chat-stream.py
@@ -52,6 +52,7 @@ async def run(user_input, history):
         'tfs': 1,
         'top_a': 0,
         'repetition_penalty': 1.18,
+        'additive_repetition_penalty': 0,
         'repetition_penalty_range': 0,
         'top_k': 40,
         'min_length': 0,

1 change: 1 addition & 0 deletions api-examples/api-example-chat.py
@@ -46,6 +46,7 @@ def run(user_input, history):
         'tfs': 1,
         'top_a': 0,
         'repetition_penalty': 1.18,
+        'additive_repetition_penalty': 0,
         'repetition_penalty_range': 0,
         'top_k': 40,
         'min_length': 0,

1 change: 1 addition & 0 deletions api-examples/api-example-stream.py
@@ -35,6 +35,7 @@ async def run(context):
         'tfs': 1,
         'top_a': 0,
         'repetition_penalty': 1.18,
+        'additive_repetition_penalty': 0,
         'repetition_penalty_range': 0,
         'top_k': 40,
         'min_length': 0,

1 change: 1 addition & 0 deletions api-examples/api-example.py
@@ -27,6 +27,7 @@ def run(prompt):
         'tfs': 1,
         'top_a': 0,
         'repetition_penalty': 1.18,
+        'additive_repetition_penalty': 0,
         'repetition_penalty_range': 0,
         'top_k': 40,
         'min_length': 0,

1 change: 1 addition & 0 deletions docs/03 ‐ Parameters Tab.md
@@ -35,6 +35,7 @@ For more information about the parameters, the [transformers documentation](http
 * **top_p**: If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.
 * **top_k**: Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.
 * **repetition_penalty**: Penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.
+* **additive_repetition_penalty**: Similar to repetition_penalty, but with an additive offset on the raw token scores instead of a multiplicative factor. It may generate better results. 0 means no penalty, higher value = less repetition, lower value = more repetition.
 * **repetition_penalty_range**: The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used.
 * **typical_p**: If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.
 * **tfs**: Tries to detect a tail of low-probability tokens in the distribution and removes those tokens. See [this blog post](https://www.trentonbricken.com/Tail-Free-Sampling/) for details. The closer to 0, the more discarded tokens.

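The docs entry added above is easiest to see on raw scores. The following is a minimal arithmetic sketch (toy logits, not the module's actual tensors) of what each setting does to tokens that already appeared in the context:

import torch

# Toy scores for three tokens that already appeared in the context.
score = torch.tensor([2.0, 0.5, -1.0])

# Multiplicative penalty (repetition_penalty = 1.18): positive scores are
# divided, negative ones multiplied, so both move toward "less likely".
penalty = 1.18
multiplicative = torch.where(score < 0, score * penalty, score / penalty)
print(multiplicative)  # tensor([ 1.6949,  0.4237, -1.1800])

# Additive penalty (additive_repetition_penalty = 1.5): subtract a flat
# offset, independent of the score's sign or magnitude.
additive = score - 1.5
print(additive)  # tensor([ 0.5000, -1.0000, -2.5000])

Because the offset does not scale with the score, the additive variant penalizes high-confidence and low-confidence repeats by the same absolute amount, which is the intuition behind the "may generate better results" claim above.
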
1 change: 1 addition & 0 deletions extensions/api/util.py
@@ -32,6 +32,7 @@ def build_parameters(body, chat=False):
         'tfs': float(body.get('tfs', 1)),
         'top_a': float(body.get('top_a', 0)),
         'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
+        'additive_repetition_penalty': float(body.get('additive_repetition_penalty', body.get('additive_rep_pen', 0))),
         'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)),
         'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
         'top_k': int(body.get('top_k', 0)),

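For API clients, the new key is accepted in the request body alongside the existing sampler settings, with additive_rep_pen as the short alias (mirroring rep_pen). A minimal sketch against the blocking API endpoint used by api-example.py, assuming a server running on the default localhost:5000; the prompt and values are illustrative:

import requests

# Hypothetical payload; any key left out falls back to the defaults
# in build_parameters above, so older clients are unaffected.
request = {
    'prompt': 'Once upon a time',
    'max_new_tokens': 200,
    'repetition_penalty': 1.0,           # disable the multiplicative penalty
    'additive_repetition_penalty': 1.5,  # rely on the additive offset instead
    'repetition_penalty_range': 0,       # 0 = consider the whole context
}

response = requests.post('http://localhost:5000/api/v1/generate', json=request)
if response.status_code == 200:
    print(response.json()['results'][0]['text'])
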
1 change: 1 addition & 0 deletions extensions/openai/defaults.py
@@ -10,6 +10,7 @@
     'top_p': 1.0,
     'top_k': 1,  # choose 20 for chat in absence of another default
     'repetition_penalty': 1.18,
+    'additive_repetition_penalty': 0,
     'repetition_penalty_range': 0,
     'encoder_repetition_penalty': 1.0,
     'suffix': None,

7 changes: 7 additions & 0 deletions modules/loaders.py
@@ -153,6 +153,7 @@
         'tfs',
         'top_a',
         'repetition_penalty',
+        'additive_repetition_penalty',
         'repetition_penalty_range',
         'encoder_repetition_penalty',
         'no_repeat_ngram_size',
@@ -186,6 +187,7 @@
         'tfs',
         'top_a',
         'repetition_penalty',
+        'additive_repetition_penalty',
         'repetition_penalty_range',
         'encoder_repetition_penalty',
         'no_repeat_ngram_size',
@@ -244,6 +246,7 @@
         'tfs',
         'top_a',
         'repetition_penalty',
+        'additive_repetition_penalty',
         'repetition_penalty_range',
         'encoder_repetition_penalty',
         'no_repeat_ngram_size',
@@ -273,6 +276,7 @@
         'tfs',
         'top_a',
         'repetition_penalty',
+        'additive_repetition_penalty',
         'repetition_penalty_range',
         'encoder_repetition_penalty',
         'no_repeat_ngram_size',
@@ -306,6 +310,7 @@
         'tfs',
         'top_a',
         'repetition_penalty',
+        'additive_repetition_penalty',
         'repetition_penalty_range',
         'encoder_repetition_penalty',
         'no_repeat_ngram_size',
@@ -353,6 +358,7 @@
         'tfs',
         'top_a',
         'repetition_penalty',
+        'additive_repetition_penalty',
         'repetition_penalty_range',
         'encoder_repetition_penalty',
         'no_repeat_ngram_size',
@@ -389,6 +395,7 @@
         'tfs',
         'top_a',
         'repetition_penalty',
+        'additive_repetition_penalty',
         'repetition_penalty_range',
         'encoder_repetition_penalty',
         'no_repeat_ngram_size',

1 change: 1 addition & 0 deletions modules/presets.py
@@ -16,6 +16,7 @@ def default_preset():
         'tfs': 1,
         'top_a': 0,
         'repetition_penalty': 1,
+        'additive_repetition_penalty': 0,
         'repetition_penalty_range': 0,
         'encoder_repetition_penalty': 1,
         'no_repeat_ngram_size': 0,

23 changes: 16 additions & 7 deletions modules/sampler_hijack.py
@@ -139,11 +139,12 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
     Copied from the transformers library
     '''

-    def __init__(self, penalty: float, _range: int):
-        if not isinstance(penalty, float) or not (penalty > 0):
-            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
+    def __init__(self, penalty: float, additive_penalty: float, _range: int):
+        if not (penalty > 0):
+            raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}")

         self.penalty = penalty
+        self.additive_penalty = additive_penalty
         self._range = _range

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -153,6 +154,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:

         # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
         score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        score -= self.additive_penalty

         scores.scatter_(1, input_ids, score)
         return scores
@@ -185,14 +187,20 @@ def get_logits_warper_patch(self, generation_config):


 def get_logits_processor_patch(self, **kwargs):
-    result = self._get_logits_processor_old(**kwargs)
-    repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
     repetition_penalty = kwargs['generation_config'].repetition_penalty
+    additive_repetition_penalty = kwargs['generation_config'].additive_repetition_penalty
+    repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
+    do_rep_pen_hijack = (repetition_penalty > 1) or (additive_repetition_penalty > 0)
+    if do_rep_pen_hijack:
+        # Make sure that a RepetitionPenaltyLogitsProcessor will be created
+        kwargs['generation_config'].repetition_penalty = 1.1  # must set to some value > 1
+
+    result = self._get_logits_processor_old(**kwargs)

-    if repetition_penalty_range > 0:
+    if do_rep_pen_hijack:
         for i in range(len(result)):
             if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
-                result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, repetition_penalty_range)
+                result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, additive_repetition_penalty, repetition_penalty_range)

     return result

@@ -205,6 +213,7 @@ def generation_config_init_patch(self, **kwargs):
     self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
     self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
     self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
+    self.additive_repetition_penalty = kwargs.pop("additive_repetition_penalty", 0)


 def hijack_samplers():

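Putting the pieces together: the hijack forces transformers to instantiate a RepetitionPenaltyLogitsProcessor whenever either penalty is active, then swaps it for the range-aware version, which applies the multiplicative factor first and the additive offset second. A standalone sketch of that combined per-token behavior (a simplified re-implementation for illustration, not the module itself):

import torch

def penalize(input_ids, scores, penalty=1.0, additive_penalty=0.0, _range=0):
    # Restrict the penalty to the most recent `_range` tokens
    # (0 means the whole context), mirroring the hijacked processor.
    if _range > 0:
        input_ids = input_ids[:, -_range:]
    score = torch.gather(scores, 1, input_ids)
    # Multiplicative step, then the flat additive offset.
    score = torch.where(score < 0, score * penalty, score / penalty)
    score -= additive_penalty
    scores.scatter_(1, input_ids, score)
    return scores

scores = torch.tensor([[3.0, 1.0, -2.0, 0.5]])  # logits over a 4-token vocab
input_ids = torch.tensor([[0, 2]])              # tokens 0 and 2 appeared earlier
print(penalize(input_ids, scores, penalty=1.18, additive_penalty=0.5))
# tensor([[ 2.0424,  1.0000, -2.8600,  0.5000]])

Note that the 1.1 written into generation_config is only a placeholder to make _get_logits_processor_old create a processor that can be replaced; the values actually applied are the ones captured before the call.
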
2 changes: 1 addition & 1 deletion modules/text_generation.py
@@ -273,7 +273,7 @@ def apply_stopping_strings(reply, all_stop_strings):

 def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
-    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
+    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
         generate_params[k] = state[k]

     if state['negative_prompt'] != '':

1 change: 1 addition & 0 deletions modules/ui.py
@@ -105,6 +105,7 @@ def list_interface_input_elements():
         'epsilon_cutoff',
         'eta_cutoff',
         'repetition_penalty',
+        'additive_repetition_penalty',
         'repetition_penalty_range',
         'encoder_repetition_penalty',
         'no_repeat_ngram_size',

1 change: 1 addition & 0 deletions modules/ui_parameters.py
@@ -31,6 +31,7 @@ def create_ui(default_preset):
         shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
         shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
         shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
+        shared.gradio['additive_repetition_penalty'] = gr.Slider(0, 4, value=generate_params['additive_repetition_penalty'], step=0.05, label='additive_repetition_penalty')
         shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range')
         shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
         shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=generate_params['tfs'], step=0.01, label='tfs')