From f717e123a323c5675f9ba3245eafb69982c554cf Mon Sep 17 00:00:00 2001 From: Curtis Maddalozzo Date: Tue, 30 Apr 2024 18:32:45 -0400 Subject: [PATCH] Fix Huggingface server stopping criteria (#3659) * Encoder-decoder models do not include input tokens in their output Signed-off-by: Curtis Maddalozzo * Pass stopping criteria into streamer Signed-off-by: Curtis Maddalozzo --------- Signed-off-by: Curtis Maddalozzo --- .../huggingfaceserver/generative_model.py | 6 +++++- .../huggingfaceserver/test_model.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/python/huggingfaceserver/huggingfaceserver/generative_model.py b/python/huggingfaceserver/huggingfaceserver/generative_model.py index 4919bfeecc..d1265da21d 100644 --- a/python/huggingfaceserver/huggingfaceserver/generative_model.py +++ b/python/huggingfaceserver/huggingfaceserver/generative_model.py @@ -414,7 +414,10 @@ async def create_completion( for seq in stop ] stop_sequence_stopping_criteria = StopSequenceStoppingCriteria( - input_length=inputs["input_ids"].shape[-1], + # Encoder-decoder models do not include input tokens in output + input_length=( + 0 if self.is_encoder_decoder else inputs["input_ids"].shape[-1] + ), stop_sequences=stop_sequences, ) stopping_criteria = StoppingCriteriaList([stop_sequence_stopping_criteria]) @@ -432,6 +435,7 @@ async def create_completion( request=request, generate_queue=response_queue, system_fingerprint=self.system_fingerprint, + stop_sequence_stopping_criteria=stop_sequence_stopping_criteria, ) else: outputs = await response_queue.get() diff --git a/python/huggingfaceserver/huggingfaceserver/test_model.py b/python/huggingfaceserver/huggingfaceserver/test_model.py index 45be906269..4287102fca 100644 --- a/python/huggingfaceserver/huggingfaceserver/test_model.py +++ b/python/huggingfaceserver/huggingfaceserver/test_model.py @@ -108,6 +108,19 @@ async def test_t5(t5_model: HuggingfaceGenerativeModel): assert response.choices[0].text == "wir setzen Worte" +@pytest.mark.asyncio +async def test_t5_stopping_criteria(t5_model: HuggingfaceGenerativeModel): + params = CreateCompletionRequest( + model="t5-small", + prompt="translate from English to German: we are making words", + stop=["setzen "], + stream=False, + ) + request = CompletionRequest(params=params) + response = await t5_model.create_completion(request) + assert response.choices[0].text == "wir setzen" + + @pytest.mark.asyncio async def test_t5_bad_params(t5_model: HuggingfaceGenerativeModel): params = CreateCompletionRequest(