Skip to content

Commit

Permalink
Fix Huggingface server stopping criteria (#3659)
Browse files Browse the repository at this point in the history
* Encoder-decoder models do not include input tokens in their output

Signed-off-by: Curtis Maddalozzo <cmaddalozzo@bloomberg.net>

* Pass stopping criteria into streamer

Signed-off-by: Curtis Maddalozzo <cmaddalozzo@bloomberg.net>

---------

Signed-off-by: Curtis Maddalozzo <cmaddalozzo@bloomberg.net>
  • Loading branch information
cmaddalozzo committed Apr 30, 2024
1 parent 8cfb3e0 commit f717e12
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,10 @@ async def create_completion(
for seq in stop
]
stop_sequence_stopping_criteria = StopSequenceStoppingCriteria(
input_length=inputs["input_ids"].shape[-1],
# Encoder-decoder models do not include input tokens in output
input_length=(
0 if self.is_encoder_decoder else inputs["input_ids"].shape[-1]
),
stop_sequences=stop_sequences,
)
stopping_criteria = StoppingCriteriaList([stop_sequence_stopping_criteria])
Expand All @@ -432,6 +435,7 @@ async def create_completion(
request=request,
generate_queue=response_queue,
system_fingerprint=self.system_fingerprint,
stop_sequence_stopping_criteria=stop_sequence_stopping_criteria,
)
else:
outputs = await response_queue.get()
Expand Down
13 changes: 13 additions & 0 deletions python/huggingfaceserver/huggingfaceserver/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,19 @@ async def test_t5(t5_model: HuggingfaceGenerativeModel):
assert response.choices[0].text == "wir setzen Worte"


@pytest.mark.asyncio
async def test_t5_stopping_criteria(t5_model: HuggingfaceGenerativeModel):
params = CreateCompletionRequest(
model="t5-small",
prompt="translate from English to German: we are making words",
stop=["setzen "],
stream=False,
)
request = CompletionRequest(params=params)
response = await t5_model.create_completion(request)
assert response.choices[0].text == "wir setzen"


@pytest.mark.asyncio
async def test_t5_bad_params(t5_model: HuggingfaceGenerativeModel):
params = CreateCompletionRequest(
Expand Down

0 comments on commit f717e12

Please sign in to comment.