From 6d0db3e935ad65535e69c51d19e58d99a05fc728 Mon Sep 17 00:00:00 2001 From: tdrussell <6509934+tdrussell@users.noreply.github.com> Date: Fri, 18 Aug 2023 20:53:17 -0500 Subject: [PATCH 1/5] Add additive_repetition_penalty sampler setting. --- extensions/api/util.py | 1 + modules/loaders.py | 5 +++++ modules/presets.py | 1 + modules/sampler_hijack.py | 23 +++++++++++++++++------ modules/text_generation.py | 2 +- modules/ui.py | 1 + modules/ui_parameters.py | 5 ++++- 7 files changed, 30 insertions(+), 8 deletions(-) diff --git a/extensions/api/util.py b/extensions/api/util.py index 032a9e5c93..30d8819ea0 100644 --- a/extensions/api/util.py +++ b/extensions/api/util.py @@ -31,6 +31,7 @@ def build_parameters(body, chat=False): 'tfs': float(body.get('tfs', 1)), 'top_a': float(body.get('top_a', 0)), 'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))), + 'additive_repetition_penalty': float(body.get('additive_repetition_penalty', body.get('additive_rep_pen', 0))), 'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)), 'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)), 'top_k': int(body.get('top_k', 0)), diff --git a/modules/loaders.py b/modules/loaders.py index 7444555f96..cc888d5e2d 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -116,6 +116,7 @@ 'tfs', 'top_a', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', @@ -146,6 +147,7 @@ 'tfs', 'top_a', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', @@ -183,6 +185,7 @@ 'tfs', 'top_a', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', @@ -213,6 +216,7 @@ 'tfs', 'top_a', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', @@ -254,6 +258,7 @@ 'tfs', 'top_a', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', diff --git a/modules/presets.py b/modules/presets.py index 32b7f71c52..289ee02c73 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -16,6 +16,7 @@ def default_preset(): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1, + 'additive_repetition_penalty': 0, 'repetition_penalty_range': 0, 'encoder_repetition_penalty': 1, 'no_repeat_ngram_size': 0, diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index d5ebbb7690..6890d90de6 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -127,11 +127,12 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor): Copied from the transformers library ''' - def __init__(self, penalty: float, _range: int): - if not isinstance(penalty, float) or not (penalty > 0): - raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") + def __init__(self, penalty: float, additive_penalty: float, _range: int): + if not (penalty > 0): + raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}") self.penalty = penalty + self.additive_penalty = additive_penalty self._range = _range def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: @@ -141,6 +142,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to # if score < 0 then repetition penalty has 
to be multiplied to reduce the previous token probability score = torch.where(score < 0, score * self.penalty, score / self.penalty) + score -= self.additive_penalty scores.scatter_(1, input_ids, score) return scores @@ -172,14 +174,22 @@ def get_logits_warper_patch(self, generation_config): def get_logits_processor_patch(self, **kwargs): - result = self._get_logits_processor_old(**kwargs) repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range repetition_penalty = kwargs['generation_config'].repetition_penalty + additive_repetition_penalty = kwargs['generation_config'].additive_repetition_penalty + need_rep_pen_hijack = (repetition_penalty_range > 0) or (additive_repetition_penalty > 0) + if need_rep_pen_hijack: + # Make sure it always creates a RepetitionPenaltyLogitsProcessor + kwargs['generation_config'].repetition_penalty = 1.1 # must set to some value > 1 + result = self._get_logits_processor_old(**kwargs) + if need_rep_pen_hijack: + # Now set the rep_pen back to the actual value (just in case) + kwargs['generation_config'].repetition_penalty = repetition_penalty - if repetition_penalty_range > 0: + if need_rep_pen_hijack: for i in range(len(result)): if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor': - result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, repetition_penalty_range) + result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, additive_repetition_penalty, repetition_penalty_range) return result @@ -192,6 +202,7 @@ def generation_config_init_patch(self, **kwargs): self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1) self.mirostat_tau = kwargs.pop("mirostat_tau", 5) self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0) + self.additive_repetition_penalty = kwargs.pop("additive_repetition_penalty", 0) def hijack_samplers(): diff --git a/modules/text_generation.py b/modules/text_generation.py index 30e81355e9..9d6958a6b2 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -236,7 +236,7 @@ def apply_stopping_strings(reply, all_stop_strings): def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False): generate_params = {} - for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']: + for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']: generate_params[k] = state[k] if state['negative_prompt'] != '': diff --git a/modules/ui.py b/modules/ui.py index 15f24d859c..821e390707 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -99,6 +99,7 @@ def list_interface_input_elements(): 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index a0f9515898..92fdcac050 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py 
@@ -36,6 +36,7 @@ def create_ui(default_preset): with gr.Column(): shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty') + shared.gradio['additive_repetition_penalty'] = gr.Slider(0, 4, value=generate_params['additive_repetition_penalty'], step=0.05, label='additive_repetition_penalty') shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range') shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty') shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size') @@ -79,7 +80,9 @@ def create_ui(default_preset): ### eta_cutoff In units of 1e-4; a reasonable value is 3. Should be used with top_p, top_k, and epsilon_cutoff set to 0. ### repetition_penalty - Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition. + Exponential penalty factor for repeating prior tokens. This is a multiplicative factor on the raw token scores. 1 means no penalty, higher value = less repetition, lower value = more repetition. + ### additive_repetition_penalty + Similar to repetition_penalty, but with an additive offset on the raw token scores instead of a multiplicative factor. 0 means no penalty, higher value = less repetition, lower value = more repetition. ### repetition_penalty_range The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used. ### encoder_repetition_penalty From 4878994ae11aeeb81ea22e19c920030d70938e74 Mon Sep 17 00:00:00 2001 From: tdrussell <6509934+tdrussell@users.noreply.github.com> Date: Sun, 20 Aug 2023 11:08:14 -0500 Subject: [PATCH 2/5] Log what token probs changed the most from rep pen. 
--- modules/sampler_hijack.py | 27 ++++++++++++++++++++++----- modules/shared.py | 1 + modules/text_generation.py | 15 +++++++++++++++ 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index 6890d90de6..9f8ef41edb 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -10,6 +10,7 @@ TemperatureLogitsWarper ) +import modules.shared as shared class TailFreeLogitsWarper(LogitsWarper): def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): @@ -134,9 +135,10 @@ def __init__(self, penalty: float, additive_penalty: float, _range: int): self.penalty = penalty self.additive_penalty = additive_penalty self._range = _range + shared.rep_pen_diffs = {} def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - + original_scores = scores.clone() input_ids = input_ids[:, -self._range:] score = torch.gather(scores, 1, input_ids) @@ -145,6 +147,21 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to score -= self.additive_penalty scores.scatter_(1, input_ids, score) + + # Find out what probabilities changed the most + old_probs = torch.nn.functional.softmax(original_scores, dim=-1) + new_probs = torch.nn.functional.softmax(scores, dim=-1) + prob_diff = (new_probs - old_probs).squeeze() + increase_only = torch.maximum(prob_diff, torch.zeros_like(prob_diff)) + decrease_only = torch.maximum(-prob_diff, torch.zeros_like(prob_diff)) + if 'increase' in shared.rep_pen_diffs: + shared.rep_pen_diffs['increase'] += increase_only + else: + shared.rep_pen_diffs['increase'] = increase_only + if 'decrease' in shared.rep_pen_diffs: + shared.rep_pen_diffs['decrease'] += decrease_only + else: + shared.rep_pen_diffs['decrease'] = decrease_only return scores @@ -177,16 +194,16 @@ def get_logits_processor_patch(self, **kwargs): repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range repetition_penalty = kwargs['generation_config'].repetition_penalty additive_repetition_penalty = kwargs['generation_config'].additive_repetition_penalty - need_rep_pen_hijack = (repetition_penalty_range > 0) or (additive_repetition_penalty > 0) - if need_rep_pen_hijack: + do_rep_pen_hijack = (repetition_penalty > 1) or (additive_repetition_penalty > 0) + if do_rep_pen_hijack: # Make sure it always creates a RepetitionPenaltyLogitsProcessor kwargs['generation_config'].repetition_penalty = 1.1 # must set to some value > 1 result = self._get_logits_processor_old(**kwargs) - if need_rep_pen_hijack: + if do_rep_pen_hijack: # Now set the rep_pen back to the actual value (just in case) kwargs['generation_config'].repetition_penalty = repetition_penalty - if need_rep_pen_hijack: + if do_rep_pen_hijack: for i in range(len(result)): if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor': result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, additive_repetition_penalty, repetition_penalty_range) diff --git a/modules/shared.py b/modules/shared.py index 385b99da1d..a65f79e9ab 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,6 +14,7 @@ is_seq2seq = False model_dirty_from_training = False lora_names = [] +rep_pen_diffs = {} # Generation variables stop_everything = False diff --git a/modules/text_generation.py b/modules/text_generation.py index 9d6958a6b2..687ad15e3c 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -322,6 +322,21 @@ def 
generate_with_streaming(**kwargs): original_tokens = len(original_input_ids[0]) new_tokens = len(output) - (original_tokens if not shared.is_seq2seq else 0) print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') + + def print_most_changed_tokens(positive_prob_diffs): + diffs, sorted_indices = torch.sort(positive_prob_diffs, descending=True) + top_n = 30 + diffs = diffs[:top_n].tolist() + tokens = shared.tokenizer.batch_decode([[idx] for idx in sorted_indices[:top_n]]) + for token, diff in zip(tokens, diffs): + token = token.encode('unicode_escape') + print(f'{token}: {diff}') + if shared.args.verbose and shared.rep_pen_diffs: + print('Most penalized tokens due to repetition penalty:') + print_most_changed_tokens(shared.rep_pen_diffs['decrease']) + print() + print('Most boosted tokens due to repetition penalty:') + print_most_changed_tokens(shared.rep_pen_diffs['increase']) return From 28c933778f99a5948576c56de7b9b7417bf7015c Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 22 Oct 2023 21:50:16 -0700 Subject: [PATCH 3/5] Minor changes --- modules/loaders.py | 2 ++ modules/sampler_hijack.py | 14 +++++--------- modules/shared.py | 1 - 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/modules/loaders.py b/modules/loaders.py index cf51079e43..c7e5d80031 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -246,6 +246,7 @@ 'tfs', 'top_a', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', @@ -394,6 +395,7 @@ 'tfs', 'top_a', 'repetition_penalty', + 'additive_repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index e6bc8e7ee4..ddc8b021e8 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -148,7 +148,7 @@ def __init__(self, penalty: float, additive_penalty: float, _range: int): self._range = _range def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - original_scores = scores.clone() + input_ids = input_ids[:, -self._range:] score = torch.gather(scores, 1, input_ids) @@ -158,13 +158,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to scores.scatter_(1, input_ids, score) - # Find out what probabilities changed the most - old_probs = torch.nn.functional.softmax(original_scores, dim=-1) - new_probs = torch.nn.functional.softmax(scores, dim=-1) - prob_diff = (new_probs - old_probs).squeeze() - increase_only = torch.maximum(prob_diff, torch.zeros_like(prob_diff)) - decrease_only = torch.maximum(-prob_diff, torch.zeros_like(prob_diff)) - return scores @@ -195,13 +188,16 @@ def get_logits_warper_patch(self, generation_config): def get_logits_processor_patch(self, **kwargs): - repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range repetition_penalty = kwargs['generation_config'].repetition_penalty additive_repetition_penalty = kwargs['generation_config'].additive_repetition_penalty + repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range + do_rep_pen_hijack = (repetition_penalty > 1) or (additive_repetition_penalty > 0) + if do_rep_pen_hijack: # Make sure it always creates a RepetitionPenaltyLogitsProcessor kwargs['generation_config'].repetition_penalty = 1.1 # must set to some value > 1 + 
result = self._get_logits_processor_old(**kwargs) if do_rep_pen_hijack: # Now set the rep_pen back to the actual value (just in case) diff --git a/modules/shared.py b/modules/shared.py index b67832dfa4..3744d551f3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -15,7 +15,6 @@ is_seq2seq = False model_dirty_from_training = False lora_names = [] -rep_pen_diffs = {} # Generation variables stop_everything = False From dc6335cf0a3616867a92773efc7b49b20b243c79 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 22 Oct 2023 22:04:52 -0700 Subject: [PATCH 4/5] Minor changes --- modules/sampler_hijack.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index ddc8b021e8..c0c85c2dec 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -157,7 +157,6 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to score -= self.additive_penalty scores.scatter_(1, input_ids, score) - return scores @@ -191,17 +190,12 @@ def get_logits_processor_patch(self, **kwargs): repetition_penalty = kwargs['generation_config'].repetition_penalty additive_repetition_penalty = kwargs['generation_config'].additive_repetition_penalty repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range - do_rep_pen_hijack = (repetition_penalty > 1) or (additive_repetition_penalty > 0) - if do_rep_pen_hijack: - # Make sure it always creates a RepetitionPenaltyLogitsProcessor + # Make sure that a RepetitionPenaltyLogitsProcessor will be created kwargs['generation_config'].repetition_penalty = 1.1 # must set to some value > 1 result = self._get_logits_processor_old(**kwargs) - if do_rep_pen_hijack: - # Now set the rep_pen back to the actual value (just in case) - kwargs['generation_config'].repetition_penalty = repetition_penalty if do_rep_pen_hijack: for i in range(len(result)): From 2c857e0dbb7936fb92421501e848b93e2a7ed250 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 22 Oct 2023 22:07:22 -0700 Subject: [PATCH 5/5] Add to API examples --- api-examples/api-example-chat-stream.py | 1 + api-examples/api-example-chat.py | 1 + api-examples/api-example-stream.py | 1 + api-examples/api-example.py | 1 + extensions/openai/defaults.py | 1 + 5 files changed, 5 insertions(+) diff --git a/api-examples/api-example-chat-stream.py b/api-examples/api-example-chat-stream.py index bfa5d4f580..31bd120cea 100644 --- a/api-examples/api-example-chat-stream.py +++ b/api-examples/api-example-chat-stream.py @@ -52,6 +52,7 @@ async def run(user_input, history): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1.18, + 'additive_repetition_penalty': 0, 'repetition_penalty_range': 0, 'top_k': 40, 'min_length': 0, diff --git a/api-examples/api-example-chat.py b/api-examples/api-example-chat.py index b2a1e1e42b..e7c0ae7d78 100644 --- a/api-examples/api-example-chat.py +++ b/api-examples/api-example-chat.py @@ -46,6 +46,7 @@ def run(user_input, history): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1.18, + 'additive_repetition_penalty': 0, 'repetition_penalty_range': 0, 'top_k': 40, 'min_length': 0, diff --git a/api-examples/api-example-stream.py b/api-examples/api-example-stream.py index 966ca6f62d..ad907196de 100644 --- a/api-examples/api-example-stream.py +++ b/api-examples/api-example-stream.py @@ -35,6 +35,7 @@ async def run(context): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1.18, + 'additive_repetition_penalty': 0, 
'repetition_penalty_range': 0, 'top_k': 40, 'min_length': 0, diff --git a/api-examples/api-example.py b/api-examples/api-example.py index d9fd60d05c..2f0267f294 100644 --- a/api-examples/api-example.py +++ b/api-examples/api-example.py @@ -27,6 +27,7 @@ def run(prompt): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1.18, + 'additive_repetition_penalty': 0, 'repetition_penalty_range': 0, 'top_k': 40, 'min_length': 0, diff --git a/extensions/openai/defaults.py b/extensions/openai/defaults.py index 2ebade8272..1115ba97ff 100644 --- a/extensions/openai/defaults.py +++ b/extensions/openai/defaults.py @@ -10,6 +10,7 @@ 'top_p': 1.0, 'top_k': 1, # choose 20 for chat in absence of another default 'repetition_penalty': 1.18, + 'additive_repetition_penalty': 0, 'repetition_penalty_range': 0, 'encoder_repetition_penalty': 1.0, 'suffix': None,
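
The snippet below is a standalone, illustrative restatement of the sampler logic this series adds. It is not part of the patches; the function name, defaults, and test numbers are invented for the example, but the tensor operations mirror the __call__ body of RepetitionPenaltyLogitsProcessorWithRange in modules/sampler_hijack.py as of patch 4.

import torch

def apply_repetition_penalties(input_ids, scores, penalty=1.0, additive_penalty=0.0, _range=0):
    # Restrict the penalty to the last `_range` context tokens. With _range == 0 the
    # slice [:, -0:] keeps every token, matching "0 makes all tokens be used" in the UI help.
    input_ids = input_ids[:, -_range:]
    score = torch.gather(scores, 1, input_ids)

    # Multiplicative penalty: divide positive scores, multiply negative ones, so the
    # adjustment always pushes a previously seen token's score down.
    score = torch.where(score < 0, score * penalty, score / penalty)

    # Additive penalty: subtract a flat offset from the raw score of every seen token.
    score -= additive_penalty

    scores.scatter_(1, input_ids, score)
    return scores

# Quick check with made-up numbers: token 7 appears in the context and had a raw score of 2.0.
ids = torch.tensor([[5, 7, 7, 9]])
logits = torch.zeros(1, 12)
logits[0, 7] = 2.0
apply_repetition_penalties(ids, logits, penalty=1.18, additive_penalty=0.5)
print(logits[0, 7].item())  # 2.0 / 1.18 - 0.5 ≈ 1.19; tokens absent from the context keep their score

Because the multiplicative factor scales with the magnitude of a score while the additive offset shifts every repeated token by the same amount, the two settings act differently on high- versus low-confidence repeated tokens and can be combined.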
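
For API users, the new field goes into the same request body the bundled examples build. A minimal sketch, assuming the default blocking-API endpoint and response shape used by api-example.py; the host, port, prompt, and the non-default parameter values here are placeholders:

import requests

request = {
    'prompt': 'Write a short story about a lighthouse keeper.',
    'max_new_tokens': 200,
    'do_sample': True,
    'temperature': 0.7,
    'top_p': 0.9,
    'repetition_penalty': 1.0,            # leave the multiplicative penalty off...
    'additive_repetition_penalty': 0.35,  # ...and rely on the additive one instead
    'repetition_penalty_range': 1024,     # penalize only the last 1024 context tokens
}

response = requests.post('http://127.0.0.1:5000/api/v1/generate', json=request)
if response.status_code == 200:
    print(response.json()['results'][0]['text'])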