[Inference]Adapt repetition_penalty and no_repeat_ngram_size #5708

Merged
Changes from 3 commits
14 changes: 14 additions & 0 deletions colossalai/inference/batch_bucket.py
@@ -102,6 +102,10 @@ def use_spec_dec(self) -> bool:
def num_tokens_to_verify(self) -> int:
return self._num_tokens_to_verify

@property
def batch_token_ids(self):
return self.get_batch_token_ids()

def set_use_spec_dec(self, num_tokens_to_verify: int = 5) -> None:
"""Set batch bucket to use speculatvie decoding.
This will notify the adjust the lengths of inputs during modeling,
@@ -328,6 +332,7 @@ def pop_n_seqs(
seqs.append(seq)
if not self.is_compact:
self._make_compact()

return seqs, block_tables

def pop_finished(
@@ -432,6 +437,7 @@ def merge(self, other: "BatchBucket") -> List[int]:
block_tables = torch.stack(block_tables_li)
self.add_seqs(seqs, alloc_block_tables=block_tables)
unmerged_ids = other.seqs_ids

return unmerged_ids

########## The following methods are expected to be used in modeling ###########
@@ -504,6 +510,14 @@ def get_sequence_lengths(self) -> torch.Tensor:
sequence_lengths = self.seq_lengths[: self.current_batch_size]
return sequence_lengths.to(device=self.device)

def get_batch_token_ids(self) -> List[torch.LongTensor]:
assert self.is_compact # Debug usage
out = []
for seq_id, _ in self._sequences_indexes.items():
seq: Sequence = self._sequences_dict[seq_id]
out.append(torch.tensor(seq.input_token_id + seq.output_token_id, device=self.device))
return out

# For compatibility
@property
def fd_inter_tensor(self) -> None:
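For context, the new `batch_token_ids` property returns one tensor per sequence, containing the prompt tokens followed by the tokens generated so far. A minimal standalone sketch of the same aggregation, assuming only the `input_token_id` / `output_token_id` fields shown above (the tiny stand-in class is illustrative, not the real `Sequence`):

```python
from typing import List

import torch


class _Seq:
    """Illustrative stand-in for the Sequence objects held by a BatchBucket."""

    def __init__(self, input_token_id: List[int], output_token_id: List[int]):
        self.input_token_id = input_token_id      # prompt tokens
        self.output_token_id = output_token_id    # tokens generated so far


def batch_token_ids(seqs: List[_Seq], device: str = "cpu") -> List[torch.Tensor]:
    # One 1-D tensor per sequence, mirroring get_batch_token_ids() above.
    return [torch.tensor(s.input_token_id + s.output_token_id, device=device) for s in seqs]


print(batch_token_ids([_Seq([1, 2, 3], [4, 5]), _Seq([7, 8], [9])]))
# [tensor([1, 2, 3, 4, 5]), tensor([7, 8, 9])]
```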
10 changes: 7 additions & 3 deletions colossalai/inference/config.py
@@ -99,7 +99,9 @@ class InferenceConfig:
early_stopping (Optional[bool]): Whether to stop the generation when all beam hypotheses have finished or not, defaults to False.
top_k (Optional[int]): The number of highest probability vocabulary tokens to keep for top-k-filtering, defaults to None.
top_p (Optional[float]): The cumulative probability threshold for retaining tokens with a total probability above it, defaults to None.
min_p (Optional[float]): The minimum probability to keep for top-p filtering, defaults to None.
temperature (Optional[float]): Controls the randomness of sampling, defaults to 1.0.
repetition_penalty (Optional[float]): The parameter that influences the model's treatment of new tokens in relation to their appearance in the prompt and the generated text. Values greater than 1 incentivize the model to introduce new tokens, whereas values less than 1 incentivize token repetition, defaults to 1.0.
no_repeat_ngram_size (Optional[int]): If no_repeat_ngram_size > 0, any n-gram of that size can only appear once in the generated output, defaults to 0.
n_spec_tokens (int): The maximum number of speculating tokens, defaults to None.
glimpse_large_kv (bool): Whether to use large KV in drafter model, defaults to False.
block_size (int): The number of blocks in a logical block, defaults to 16.
@@ -136,7 +138,9 @@ class InferenceConfig:
early_stopping: Optional[bool] = False
top_k: Optional[int] = None
top_p: Optional[float] = None
min_p: Optional[float] = None
temperature: Optional[float] = 1.0
no_repeat_ngram_size: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0

# speculative decoding configs
max_n_spec_tokens: int = 5
@@ -213,7 +217,7 @@ def to_generation_config(self, model_config) -> GenerationConfig:
"do_sample": self.do_sample,
"num_beams": self.beam_width,
}
for type in ["top_k", "top_p", "min_p"]:
for type in ["repetition_penalty", "no_repeat_ngram_size", "temperature", "top_k", "top_p"]:
if hasattr(self, type):
meta_config[type] = getattr(self, type)
for type in ["pad_token_id", "bos_token_id", "eos_token_id"]:
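A hypothetical usage sketch for the two new fields; the field names and `to_generation_config(model_config)` come from the diff above, while the model checkpoint name and the assumption that `model_config` is a Hugging Face config object are illustrative:

```python
from transformers import AutoConfig

from colossalai.inference.config import InferenceConfig

inference_config = InferenceConfig(
    temperature=0.8,
    top_k=50,
    repetition_penalty=1.2,    # >1 discourages tokens that have already appeared
    no_repeat_ngram_size=3,    # forbid any 3-gram from appearing twice
)

# to_generation_config() now forwards repetition_penalty and no_repeat_ngram_size
# to the GenerationConfig as well (see the updated loop above).
model_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint
generation_config = inference_config.to_generation_config(model_config)
```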
14 changes: 12 additions & 2 deletions colossalai/inference/core/request_handler.py
@@ -336,10 +336,20 @@ def search_tokens(self, generation_config: GenerationConfig, logits):
Sample tokens for finished requests.
"""

# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
# process repetition_penalty
for type in ["repetition_penalty", "no_repeat_ngram_size"]:
if type in config_dict and config_dict[type] is not None:
if not self.prefill_bb.is_empty:
batch = self.prefill_bb
else:
batch = self.running_bb
logits = logit_processor(type, logits, config_dict[type], batch)

# do logit processor
if generation_config.do_sample:
# NOTE: need to decide the granularity to process logits (sequence or batch)
config_dict = generation_config.to_dict()
# process temperature, top_k, top_p
for type in ["temperature", "top_k", "top_p"]:
if type in config_dict and config_dict[type] is not None:
logits = logit_processor(type, logits, config_dict[type])
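The calls above rely on the generalized `logit_processor` dispatcher (signature change in `logit_processors.py` below), which forwards any extra arguments to the registered processor. A hedged usage sketch; the tensor shape and the `batch_bucket` variable are illustrative:

```python
import torch

from colossalai.inference.logit_processors import logit_processor

logits = torch.randn(4, 32000)  # (batch_size, vocab_size), illustrative shape

# Scalar-only processors keep the old calling convention.
logits = logit_processor("temperature", logits, 0.7)

# Batch-aware processors additionally receive the BatchBucket, e.g.:
#   logits = logit_processor("repetition_penalty", logits, 1.2, batch_bucket)
#   logits = logit_processor("no_repeat_ngram_size", logits, 3, batch_bucket)
```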
80 changes: 74 additions & 6 deletions colossalai/inference/logit_processors.py
@@ -1,7 +1,13 @@
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/generation/logits_process.py

import torch
import torch.nn.functional as F

from colossalai.inference.batch_bucket import BatchBucket
from colossalai.logging import get_dist_logger

_LOGIT_PROCESSOR_MAP = {}
logger = get_dist_logger(__name__)


def register_logit_processor(process_type):
@@ -17,6 +23,68 @@ def register(func):
return register


@register_logit_processor("no_repeat_ngram_size")
def no_repeat_ngram_size_logit_process(logits, ngram_size: int, batch: BatchBucket):
"""
Enforces that no n-gram of the given size is repeated, to avoid repeated word sequences.
"""

if not isinstance(ngram_size, int) or ngram_size < 0:
raise ValueError(f"'no_repeat_ngram_size={ngram_size}' has to be a non-negative integer.")

if ngram_size != 0:
batch_token_ids = batch.batch_token_ids
batch_size = len(batch_token_ids)

for batch_id in range(batch_size):
current_token_ids = batch_token_ids[batch_id]
current_len = current_token_ids.size(0)
if current_len + 1 < ngram_size:
continue

token_ids_list = current_token_ids.tolist()

ngrams_dict = {}

for ngram in zip(*[token_ids_list[i:] for i in range(ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
ngrams_dict[prev_ngram_tuple] = ngrams_dict.get(prev_ngram_tuple, []) + [ngram[-1]]

prev_ngrams = tuple(token_ids_list[current_len + 1 - ngram_size : current_len])
banned_token = ngrams_dict.get(prev_ngrams, [])

logits[batch_id, banned_token] = -float("inf")

return logits
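For illustration, the banning rule above can be reproduced standalone: collect the prefixes of every n-gram seen so far, then ban the tokens that would complete an already-seen n-gram after the current trailing prefix. A minimal sketch (not ColossalAI code; the helper name is made up):

```python
# Standalone illustration of the n-gram banning rule.
def banned_next_tokens(token_ids, ngram_size):
    if len(token_ids) + 1 < ngram_size:
        return []
    ngrams = {}
    # Record, for every (n-1)-token prefix, which tokens have followed it.
    for ngram in zip(*[token_ids[i:] for i in range(ngram_size)]):
        ngrams.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
    # Tokens that would complete an already-seen n-gram after the current tail.
    prefix = tuple(token_ids[len(token_ids) + 1 - ngram_size:])
    return ngrams.get(prefix, [])


# The 2-gram (5, 7) already occurred, so 7 is banned after the trailing 5.
print(banned_next_tokens([5, 7, 3, 5], ngram_size=2))  # [7]
```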


@register_logit_processor("repetition_penalty")
def repetition_penalty_logit_process(logits, penalty: float, batch: BatchBucket):
"""
Apply the penalty to tokens that already appear in the prompt and the generated text.
"""

if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"'penalty={penalty}' has to be a strictly positive float.")

logit_list = []

# TODO(yuehuayingxueluo) This is only a temporary implementation. Later, we will implement presence_penalties, frequency_penalties, and repetition_penalties using CUDA kernels.
if penalty != 1.0:
batch_token_ids = batch.batch_token_ids
for batch_id in range(len(batch_token_ids)):
current_logit = logits[batch_id]
current_token = batch_token_ids[batch_id]

current_score = torch.gather(current_logit, 0, current_token)
current_score = torch.where(current_score < 0, current_score * penalty, current_score / penalty)
logit_list.append(current_logit.scatter(0, current_token, current_score))

logits = torch.stack(logit_list)

return logits
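For illustration, the penalty rule above rescales the logits of tokens that already appear in the sequence: negative scores are multiplied by the penalty and positive ones divided by it, so any penalty greater than 1 pushes repeated tokens down. A minimal standalone sketch:

```python
import torch

logits = torch.tensor([2.0, -1.0, 0.5, 3.0])   # scores for a 4-token vocabulary
seen_tokens = torch.tensor([1, 3])              # token ids already in the sequence
penalty = 1.5

scores = logits.gather(0, seen_tokens)                                # tensor([-1., 3.])
scores = torch.where(scores < 0, scores * penalty, scores / penalty)  # tensor([-1.5, 2.])
print(logits.scatter(0, seen_tokens, scores))   # tensor([ 2.0000, -1.5000,  0.5000,  2.0000])
```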


@register_logit_processor("temperature")
def temperature_logit_process(logits, temperature: float):
"""
@@ -68,14 +136,13 @@ def top_p_logit_processor(logits, top_p: float):
return logits


def logit_processor(processor: str, logits, attrs):
def logit_processor(processor: str, logits, *args, **kwargs):
"""
do logit process for given logits.

Args:
processor(str): the type of logit processor
logits(torch.Tensor): input logits
attrs(dict): attrs of the logit processor

Returns:
logits after process
@@ -84,8 +151,9 @@ def logit_processor(processor: str, logits, attrs):
return logits
else:
func = _LOGIT_PROCESSOR_MAP[processor]
try:
logits = func(logits, attrs)
except Exception:
return logits
# try:
logits = func(logits, *args, **kwargs)
# except Exception as e:
# logger.warning(f"An exception ({e}) occurred during the logit processing ({processor}), skip this logit processing step.")
# return logits
return logits