diff --git a/python/huggingfaceserver/huggingfaceserver/generative_model.py b/python/huggingfaceserver/huggingfaceserver/generative_model.py index 4919bfeecc..d1265da21d 100644 --- a/python/huggingfaceserver/huggingfaceserver/generative_model.py +++ b/python/huggingfaceserver/huggingfaceserver/generative_model.py @@ -414,7 +414,10 @@ async def create_completion( for seq in stop ] stop_sequence_stopping_criteria = StopSequenceStoppingCriteria( - input_length=inputs["input_ids"].shape[-1], + # Encoder-decoder output excludes the input tokens, so the offset is 0 + input_length=( + 0 if self.is_encoder_decoder else inputs["input_ids"].shape[-1] + ), stop_sequences=stop_sequences, ) stopping_criteria = StoppingCriteriaList([stop_sequence_stopping_criteria]) @@ -432,6 +435,7 @@ async def create_completion( request=request, generate_queue=response_queue, system_fingerprint=self.system_fingerprint, + stop_sequence_stopping_criteria=stop_sequence_stopping_criteria, ) else: outputs = await response_queue.get() diff --git a/python/huggingfaceserver/huggingfaceserver/test_model.py b/python/huggingfaceserver/huggingfaceserver/test_model.py index 45be906269..4287102fca 100644 --- a/python/huggingfaceserver/huggingfaceserver/test_model.py +++ b/python/huggingfaceserver/huggingfaceserver/test_model.py @@ -108,6 +108,19 @@ async def test_t5(t5_model: HuggingfaceGenerativeModel): assert response.choices[0].text == "wir setzen Worte" +@pytest.mark.asyncio +async def test_t5_stopping_criteria(t5_model: HuggingfaceGenerativeModel): + params = CreateCompletionRequest( + model="t5-small", + prompt="translate from English to German: we are making words", + stop=["setzen "], + stream=False, + ) + request = CompletionRequest(params=params) + response = await t5_model.create_completion(request) + assert response.choices[0].text == "wir setzen" + + @pytest.mark.asyncio async def test_t5_bad_params(t5_model: HuggingfaceGenerativeModel): params = CreateCompletionRequest(