Skip to content

Commit

Permalink
Gemma GPTQ checks: skip logprob checks
Browse files Browse the repository at this point in the history
This test fails somewhat regularly due to non-determinism and this
test is primarily to verify that we are loading a model which doesn't
have `float16` as the default dtype correctly.
  • Loading branch information
danieldk committed May 30, 2024
1 parent 36dd160 commit 967ced2
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions integration-tests/models/test_flash_gemma_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,20 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle):

@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq(flash_gemma_gptq, response_snapshot):
async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):
response = await flash_gemma_gptq.generate(
"Test request", max_new_tokens=10, decoder_input_details=True
)

assert response.details.generated_tokens == 10
assert response == response_snapshot
assert response == ignore_logprob_response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot):
async def test_flash_gemma_gptq_all_params(
flash_gemma_gptq, ignore_logprob_response_snapshot
):
response = await flash_gemma_gptq.generate(
"Test request",
max_new_tokens=10,
Expand All @@ -44,13 +46,13 @@ async def test_flash_gemma_gptq_all_params(flash_gemma_gptq, response_snapshot):
)

assert response.details.generated_tokens == 10
assert response == response_snapshot
assert response == ignore_logprob_response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_load(
flash_gemma_gptq, generate_load, response_snapshot
flash_gemma_gptq, generate_load, ignore_logprob_response_snapshot
):
responses = await generate_load(
flash_gemma_gptq, "Test request", max_new_tokens=10, n=4
Expand All @@ -59,4 +61,4 @@ async def test_flash_gemma_gptq_load(
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])

assert responses == response_snapshot
assert responses == ignore_logprob_response_snapshot

0 comments on commit 967ced2

Please sign in to comment.