-
Notifications
You must be signed in to change notification settings - Fork 13.5k
bugfix: Respect n_predict=-2 in server (#12264) #12323
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b85e149
f3fdca7
7199eb9
8511ec5
ff41929
f94e105
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1321,17 +1321,24 @@ struct server_slot { | |
| && are_lora_equal(lora, other_slot.lora); | ||
| } | ||
|
|
||
| // There are two caps on the budget of a single request: | ||
| // * [params.n_predict] | ||
| // * [global_params.n_predict] | ||
| // This function returns true if the request is not limited by either of them. | ||
| bool has_budget(const common_params & global_params) { | ||
| if (params.n_predict == -1 && global_params.n_predict == -1) { | ||
| return true; // limitless | ||
| } | ||
| n_remaining = INT32_MAX; | ||
|
|
||
| n_remaining = -1; | ||
| // The request or the server has specified limits on the number of tokens to generate. | ||
| if ((params.n_predict >= 0) || (global_params.n_predict >= 0)) { | ||
| n_remaining = std::min(n_remaining, params.n_predict - n_decoded); | ||
| } | ||
|
|
||
| if (params.n_predict != -1) { | ||
| n_remaining = params.n_predict - n_decoded; | ||
| } else if (global_params.n_predict != -1) { | ||
| n_remaining = global_params.n_predict - n_decoded; | ||
| // The request or the server has limits based on the context window. | ||
| if (params.n_predict == -2 || global_params.n_predict == -2) { | ||
| n_remaining = std::min(n_remaining, n_ctx - n_decoded); | ||
|
Comment on lines
+1334
to
+1341
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tbh this code has become a bit hard to understand now (and probably that's why the CI fails) What I recommend here is instead of having this double Indeed, my idea is much simpler: why risk modifying all of this while we can just set |
||
| } | ||
|
|
||
| return n_remaining > 0; // true while budget remains | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -143,13 +143,15 @@ def test_consistent_result_same_seed(n_slots: int): | |
| def test_different_result_different_seed(n_slots: int): | ||
| global server | ||
| server.n_slots = n_slots | ||
| server.n_predict = -1 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't get why it is needed here. The default value of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a bit of a readability and maintainability issue in this file in that the server object is global, and as a test fixture is reused across tests, even though every test changes its parameters. |
||
| server.start() | ||
| last_res = None | ||
| for seed in range(4): | ||
| res = server.make_request("POST", "/completion", data={ | ||
| "prompt": "I believe the meaning of life is", | ||
| "seed": seed, | ||
| "temperature": 1.0, | ||
| "n_predict": -1, | ||
| "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed | ||
| }) | ||
| if last_res is not None: | ||
|
|
@@ -426,3 +428,18 @@ def test_cancel_request(): | |
| time.sleep(1) # wait for HTTP_POLLING_SECONDS | ||
| res = server.make_request("GET", "/slots") | ||
| assert res.body[0]["is_processing"] == False | ||
|
|
||
|
|
||
| def test_context_window_sized_completion(): | ||
| server = ServerPreset.tinyllama2() | ||
| server.n_ctx = 16 | ||
| server.n_predict = -1 | ||
| server.start() | ||
| res = server.make_request("POST", "/completion", data={ | ||
| "n_predict": -2, | ||
| "prompt": "The 50 states in the US are ", | ||
| }) | ||
| assert res.status_code == 200 | ||
| assert res.body["timings"]["predicted_n"] == server.n_ctx | ||
| assert res.body["stop_type"] == "limit" | ||
| assert type(res.body["has_new_line"]) == bool | ||
Uh oh!
There was an error while loading. Please reload this page.