[Bugfix] fix crash if max_tokens=None (vllm-project#2570)
NikolaBorisov committed Jan 31, 2024
1 parent 7ce229f commit 33ef7a6
Showing 2 changed files with 26 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tests/test_regression.py
@@ -22,6 +22,19 @@ def test_duplicated_ignored_sequence_group():
    assert len(prompts) == len(outputs)


def test_max_tokens_none():
    sampling_params = SamplingParams(temperature=0.01,
                                     top_p=0.1,
                                     max_tokens=None)
    llm = LLM(model="facebook/opt-125m",
              max_num_batched_tokens=4096,
              tensor_parallel_size=1)
    prompts = ["Just say hello!"]
    outputs = llm.generate(prompts, sampling_params=sampling_params)

    assert len(prompts) == len(outputs)


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])
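
The commit page shows only the regression tests; the crash itself came from code that assumed max_tokens is always an int. Below is a minimal, hypothetical sketch of how an engine could resolve max_tokens=None to a concrete token budget. The helper name resolve_max_tokens and its parameters are illustrative assumptions, not vLLM's actual implementation.

from typing import Optional


def resolve_max_tokens(max_tokens: Optional[int],
                       model_max_len: int,
                       prompt_len: int) -> int:
    """Return an explicit token budget when the caller passed None."""
    if max_tokens is None:
        # With no explicit cap, generate until EOS or until the
        # model's context window is exhausted.
        return max(model_max_len - prompt_len, 0)
    return max_tokens


# Example: a 2048-token context with a 5-token prompt leaves 2043 tokens.
assert resolve_max_tokens(None, model_max_len=2048, prompt_len=5) == 2043
assert resolve_max_tokens(16, model_max_len=2048, prompt_len=5) == 16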
13 changes: 13 additions & 0 deletions tests/test_sampling_params.py
@@ -0,0 +1,13 @@
"""Tests for the SamplingParams class.
"""
from vllm import SamplingParams


def test_max_tokens_none():
    """max_tokens=None should be allowed"""
    SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None)


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])
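
For reference, a short usage sketch of the behavior these tests exercise: passing max_tokens=None end to end and reading the generated text. It assumes the RequestOutput objects returned by llm.generate (output.prompt, output.outputs[0].text); treat it as an illustration rather than part of the commit.

from vllm import LLM, SamplingParams

# max_tokens=None lets generation run until EOS (or the context limit).
sampling_params = SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None)
llm = LLM(model="facebook/opt-125m")

outputs = llm.generate(["Just say hello!"], sampling_params=sampling_params)
for output in outputs:
    # Each RequestOutput carries the prompt and its completions.
    print(output.prompt, "->", output.outputs[0].text)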
