diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 41ecb279feb89..5c477c8897b47 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -3727,7 +3727,7 @@ struct server_context {
                 }
             } else {
                 if (slot.n_prompt_tokens() >= slot.n_ctx) {
-                    send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
+                    send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
                     slot.release();
                     continue;
                 }
@@ -4955,9 +4955,17 @@ int main(int argc, char ** argv) {
                // Everything else, including multimodal completions.
                inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
            }
-
+            const size_t n_ctx_slot = ctx_server.n_ctx / ctx_server.params_base.n_parallel;
            tasks.reserve(inputs.size());
            for (size_t i = 0; i < inputs.size(); i++) {
+                auto n_prompt_tokens = inputs[i].size();
+                if (n_prompt_tokens >= n_ctx_slot) {
+                    json error_data = format_error_response("the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
+                    error_data["n_prompt_tokens"] = n_prompt_tokens;
+                    error_data["n_ctx"] = n_ctx_slot;
+                    res_error(res, error_data);
+                    return;
+                }
                server_task task = server_task(type);

                task.id = ctx_server.queue_tasks.get_new_id();
diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py
index 6e5a3488e789b..d56d3d5f178b8 100644
--- a/tools/server/tests/unit/test_chat_completion.py
+++ b/tools/server/tests/unit/test_chat_completion.py
@@ -408,6 +408,28 @@ def test_context_size_exceeded():
     assert res.body["error"]["n_ctx"] == server.n_ctx // server.n_slots


+def test_context_size_exceeded_stream():
+    global server
+    server.start()
+    try:
+        for _ in server.make_stream_request("POST", "/chat/completions", data={
+            "messages": [
+                {"role": "system", "content": "Book"},
+                {"role": "user", "content": "What is the best book"},
+            ] * 100, # make the prompt too long
+            "stream": True}):
+            pass
+        assert False, "Should have failed"
+    except ServerError as e:
+        assert e.code == 400
+        assert "error" in e.body
+        assert e.body["error"]["type"] == "exceed_context_size_error"
+        assert e.body["error"]["n_prompt_tokens"] > 0
+        assert server.n_ctx is not None
+        assert server.n_slots is not None
+        assert e.body["error"]["n_ctx"] == server.n_ctx // server.n_slots
+
+
 @pytest.mark.parametrize(
     "n_batch,batch_count,reuse_cache",
     [
diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py
index abd6fff10d0d1..4ba3d43c33044 100644
--- a/tools/server/tests/utils.py
+++ b/tools/server/tests/utils.py
@@ -35,6 +35,12 @@ class ServerResponse:
     body: dict | Any


+class ServerError(Exception):
+    def __init__(self, code, body):
+        self.code = code
+        self.body = body
+
+
 class ServerProcess:
     # default options
     debug: bool = False
@@ -297,6 +303,8 @@ def make_stream_request(
             response = requests.post(url, headers=headers, json=data, stream=True)
         else:
             raise ValueError(f"Unimplemented method: {method}")
+        if response.status_code != 200:
+            raise ServerError(response.status_code, response.json())
        for line_bytes in response.iter_lines():
            line = line_bytes.decode("utf-8")
            if '[DONE]' in line: