diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 45bc07a542..c40250d5df 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -16,3 +16,40 @@ repos:
       - id: fmt
       - id: cargo-check
       - id: clippy
+- repo: local
+  hooks:
+    - id: check-openapi-update
+      name: check openapi spec update
+      entry: python
+      language: system
+      pass_filenames: false
+      always_run: true
+      args:
+        - -c
+        - |
+          import os
+          import sys
+          import subprocess
+
+          def get_changed_files():
+              result = subprocess.run(['git', 'diff', '--name-only', '--cached'], capture_output=True, text=True)  # --cached: pre-commit stashes unstaged changes, so compare the index against HEAD
+              return result.stdout.splitlines()
+
+          changed_files = get_changed_files()
+          router_files = [f for f in changed_files if f.startswith('router/')]
+
+          if not router_files:
+              print("No router files changed. Skipping OpenAPI spec check.")
+              sys.exit(0)
+
+          openapi_file = 'docs/openapi.json'
+          if not os.path.exists(openapi_file):
+              print(f"Error: {openapi_file} does not exist.")
+              sys.exit(1)
+
+          if openapi_file not in changed_files:
+              print(f"Error: Router files were updated, but {openapi_file} was not updated.")
+              sys.exit(1)
+          else:
+              print(f"{openapi_file} has been updated along with router changes.")
+              sys.exit(0)
diff --git a/docs/openapi.json b/docs/openapi.json
index 79c3b80f87..7dc159a8ec 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -10,7 +10,7 @@
       "name": "Apache 2.0",
       "url": "https://www.apache.org/licenses/LICENSE-2.0"
     },
-    "version": "2.0.1"
+    "version": "2.1.1-dev0"
   },
   "paths": {
     "/": {
@@ -19,7 +19,6 @@
         "Text Generation Inference"
       ],
       "summary": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
-      "description": "Generate tokens if `stream == false` or a stream of token if `stream == true`",
       "operationId": "compat_generate",
       "requestBody": {
         "content": {
@@ -108,7 +107,6 @@
         "Text Generation Inference"
       ],
       "summary": "Generate tokens",
-      "description": "Generate tokens",
       "operationId": "generate",
       "requestBody": {
         "content": {
@@ -192,7 +190,6 @@
         "Text Generation Inference"
       ],
       "summary": "Generate a stream of token using Server-Sent Events",
-      "description": "Generate a stream of token using Server-Sent Events",
       "operationId": "generate_stream",
       "requestBody": {
         "content": {
@@ -276,7 +273,6 @@
         "Text Generation Inference"
      ],
       "summary": "Health check method",
-      "description": "Health check method",
       "operationId": "health",
       "responses": {
         "200": {
@@ -305,7 +301,6 @@
         "Text Generation Inference"
       ],
       "summary": "Text Generation Inference endpoint info",
-      "description": "Text Generation Inference endpoint info",
       "operationId": "get_model_info",
       "responses": {
         "200": {
@@ -327,7 +322,6 @@
         "Text Generation Inference"
       ],
       "summary": "Prometheus metrics scrape endpoint",
-      "description": "Prometheus metrics scrape endpoint",
       "operationId": "metrics",
       "responses": {
         "200": {
@@ -349,7 +343,6 @@
         "Text Generation Inference"
       ],
       "summary": "Tokenize inputs",
-      "description": "Tokenize inputs",
       "operationId": "tokenize",
       "requestBody": {
         "content": {
@@ -394,7 +387,6 @@
         "Text Generation Inference"
       ],
       "summary": "Generate tokens",
-      "description": "Generate tokens",
       "operationId": "chat_completions",
       "requestBody": {
         "content": {
@@ -483,7 +475,6 @@
         "Text Generation Inference"
       ],
       "summary": "Generate tokens",
-      "description": "Generate tokens",
       "operationId": "completions",
       "requestBody": {
         "content": {
@@ -626,7 +617,6 @@
         "type": "object",
         "required": [
           "id",
-          "object",
           "created",
           "model",
           "system_fingerprint",
@@ -653,9 +643,6 @@
             "type": "string",
             "example": "mistralai/Mistral-7B-Instruct-v0.2"
           },
-          "object": {
-            "type": "string"
-          },
           "system_fingerprint": {
             "type": "string"
           },
@@ -697,7 +684,6 @@
         "type": "object",
         "required": [
           "id",
-          "object",
           "created",
           "model",
           "system_fingerprint",
@@ -723,9 +709,6 @@
             "type": "string",
             "example": "mistralai/Mistral-7B-Instruct-v0.2"
          },
-          "object": {
-            "type": "string"
-          },
           "system_fingerprint": {
             "type": "string"
           }
@@ -756,34 +739,19 @@
             "nullable": true
           },
           "message": {
-            "$ref": "#/components/schemas/Message"
+            "$ref": "#/components/schemas/OutputMessage"
           }
         }
       },
       "ChatCompletionDelta": {
-        "type": "object",
-        "required": [
-          "role"
-        ],
-        "properties": {
-          "content": {
-            "type": "string",
-            "example": "What is Deep Learning?",
-            "nullable": true
-          },
-          "role": {
-            "type": "string",
-            "example": "user"
+        "oneOf": [
+          {
+            "$ref": "#/components/schemas/TextMessage"
           },
-          "tool_calls": {
-            "allOf": [
-              {
-                "$ref": "#/components/schemas/DeltaToolCall"
-              }
-            ],
-            "nullable": true
+          {
+            "$ref": "#/components/schemas/ToolCallDelta"
           }
-        }
+        ]
       },
       "ChatCompletionLogprob": {
         "type": "object",
@@ -903,6 +871,15 @@
           "example": 0.1,
           "nullable": true
         },
+        "response_format": {
+          "allOf": [
+            {
+              "$ref": "#/components/schemas/GrammarType"
+            }
+          ],
+          "default": "null",
+          "nullable": true
+        },
         "seed": {
           "type": "integer",
           "format": "int64",
@@ -1021,7 +998,6 @@
         "type": "object",
         "required": [
           "id",
-          "object",
           "created",
           "choices",
           "model",
@@ -1045,9 +1021,6 @@
           "model": {
             "type": "string"
           },
-          "object": {
-            "type": "string"
-          },
           "system_fingerprint": {
             "type": "string"
           }
@@ -1081,12 +1054,7 @@
            "example": "mistralai/Mistral-7B-Instruct-v0.2"
          },
          "prompt": {
-           "type": "array",
-           "items": {
-             "type": "string"
-           },
-           "description": "The prompt to generate completions for.",
-           "example": "What is Deep Learning?"
+           "$ref": "#/components/schemas/Prompt"
          },
          "repetition_penalty": {
            "type": "number",
@@ -1100,6 +1068,15 @@
            "nullable": true,
            "minimum": 0
          },
+         "stop": {
+           "type": "array",
+           "items": {
+             "type": "string"
+           },
+           "description": "Up to 4 sequences where the API will stop generating further tokens.",
+           "example": "null",
+           "nullable": true
+         },
          "stream": {
            "type": "boolean"
          },
@@ -1121,15 +1098,6 @@
           "description": "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the\ntokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.",
           "example": 0.95,
           "nullable": true
-        },
-        "stop": {
-          "type": "array",
-          "items": {
-            "type": "string"
-          },
-          "description": "Up to 4 sequences where the API will stop generating further tokens.",
-          "example": "null",
-          "nullable": true
        }
      }
    },
@@ -1272,8 +1240,16 @@
     "GenerateParameters": {
       "type": "object",
       "properties": {
+        "adapter_id": {
+          "type": "string",
+          "description": "LoRA adapter id.",
+          "default": "null",
+          "example": "null",
+          "nullable": true
+        },
         "best_of": {
           "type": "integer",
+          "description": "Generate best_of sequences and return the one with the highest token logprobs.",
           "default": "null",
           "example": 1,
           "nullable": true,
@@ -1282,20 +1258,24 @@
         },
         "decoder_input_details": {
           "type": "boolean",
+          "description": "Whether to return decoder input token logprobs and ids.",
           "default": "false"
         },
         "details": {
           "type": "boolean",
+          "description": "Whether to return generation details.",
           "default": "true"
         },
         "do_sample": {
           "type": "boolean",
+          "description": "Activate logits sampling.",
           "default": "false",
           "example": true
         },
         "frequency_penalty": {
           "type": "number",
           "format": "float",
+          "description": "The parameter for frequency penalty. 0.0 means no penalty.\nPenalize new tokens based on their existing frequency in the text so far,\ndecreasing the model's likelihood to repeat the same line verbatim.",
           "default": "null",
           "example": 0.1,
           "nullable": true,
@@ -1313,6 +1293,7 @@
         "max_new_tokens": {
           "type": "integer",
           "format": "int32",
+          "description": "Maximum number of tokens to generate.",
           "default": "100",
           "example": "20",
           "nullable": true,
@@ -1321,6 +1302,7 @@
         "repetition_penalty": {
           "type": "number",
           "format": "float",
+          "description": "The parameter for repetition penalty. 1.0 means no penalty.\nSee [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.",
           "default": "null",
           "example": 1.03,
           "nullable": true,
@@ -1328,6 +1310,7 @@
         },
         "return_full_text": {
           "type": "boolean",
+          "description": "Whether to prepend the prompt to the generated text.",
           "default": "null",
           "example": false,
           "nullable": true
@@ -1335,6 +1318,7 @@
         "seed": {
           "type": "integer",
           "format": "int64",
+          "description": "Random sampling seed.",
           "default": "null",
           "example": "null",
           "nullable": true,
@@ -1346,6 +1330,7 @@
           "items": {
             "type": "string"
           },
+          "description": "Stop generating tokens if a member of `stop` is generated.",
           "example": [
             "photographer"
           ],
@@ -1354,6 +1339,7 @@
         "temperature": {
           "type": "number",
           "format": "float",
+          "description": "The value used to modulate the logits distribution.",
           "default": "null",
           "example": 0.5,
           "nullable": true,
@@ -1362,6 +1348,7 @@
         "top_k": {
           "type": "integer",
           "format": "int32",
+          "description": "The number of highest probability vocabulary tokens to keep for top-k-filtering.",
           "default": "null",
           "example": 10,
           "nullable": true,
@@ -1370,6 +1357,7 @@
         "top_n_tokens": {
           "type": "integer",
           "format": "int32",
+          "description": "The number of highest probability vocabulary tokens to keep for top-n-filtering.",
           "default": "null",
           "example": 5,
           "nullable": true,
@@ -1379,6 +1367,7 @@
         "top_p": {
           "type": "number",
           "format": "float",
+          "description": "Top-p value for nucleus sampling.",
           "default": "null",
           "example": 0.95,
           "nullable": true,
@@ -1387,6 +1376,7 @@
         },
         "truncate": {
           "type": "integer",
+          "description": "Truncate input tokens to the given size.",
           "default": "null",
           "example": "null",
           "nullable": true,
@@ -1395,6 +1385,7 @@
         "typical_p": {
           "type": "number",
           "format": "float",
+          "description": "Typical Decoding mass.\nSee [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.",
           "default": "null",
           "example": 0.95,
           "nullable": true,
@@ -1403,6 +1394,7 @@
         },
         "watermark": {
           "type": "boolean",
+          "description": "Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).",
           "default": "false",
           "example": true
         }
@@ -1495,13 +1487,14 @@
         "max_concurrent_requests",
         "max_best_of",
         "max_stop_sequences",
-        "max_input_length",
+        "max_input_tokens",
         "max_total_tokens",
         "waiting_served_ratio",
         "max_batch_total_tokens",
         "max_waiting_tokens",
         "validation_workers",
         "max_client_batch_size",
+        "router",
         "version"
       ],
       "properties": {
@@ -1538,7 +1531,7 @@
           "example": "128",
           "minimum": 0
         },
-        "max_input_length": {
+        "max_input_tokens": {
           "type": "integer",
           "example": "1024",
           "minimum": 0
@@ -1581,6 +1574,11 @@
           "example": "e985a63cdc139290c5f700ff1929f0b5942cced2",
           "nullable": true
         },
+        "router": {
+          "type": "string",
+          "description": "Router Info",
+          "example": "text-generation-router"
+        },
         "sha": {
           "type": "string",
           "example": "null",
@@ -1593,7 +1591,6 @@
         },
         "version": {
           "type": "string",
-          "description": "Router Info",
           "example": "0.5.0"
         },
         "waiting_served_ratio": {
@@ -1606,13 +1603,12 @@
     "Message": {
       "type": "object",
       "required": [
-        "role"
+        "role",
+        "content"
       ],
       "properties": {
         "content": {
-          "type": "string",
-          "example": "My name is David and I",
-          "nullable": true
+          "$ref": "#/components/schemas/MessageContent"
         },
         "name": {
           "type": "string",
@@ -1622,13 +1618,6 @@
         "role": {
           "type": "string",
           "example": "user"
-        },
-        "tool_calls": {
-          "type": "array",
-          "items": {
-            "$ref": "#/components/schemas/ToolCall"
-          },
-          "nullable": true
         }
       }
     },
@@ -1817,9 +1806,7 @@
           "$ref": "#/components/schemas/FunctionDefinition"
         },
         "id": {
-          "type": "integer",
-          "format": "int32",
-          "minimum": 0
+          "type": "string"
        },
        "type": {
          "type": "string"
@@ -1828,22 +1815,24 @@
      },
      "ToolType": {
        "oneOf": [
+        {
+          "type": "object",
+          "default": null,
+          "nullable": true
+        },
+        {
+          "type": "string"
+        },
         {
           "type": "object",
           "required": [
-            "FunctionName"
+            "function"
           ],
           "properties": {
-            "FunctionName": {
-              "type": "string"
+            "function": {
+              "$ref": "#/components/schemas/FunctionName"
             }
           }
-        },
-        {
-          "type": "string",
-          "enum": [
-            "OneOf"
-          ]
         }
       ]
     },
diff --git a/router/src/main.rs b/router/src/main.rs
index 8618f57eb3..210cfca8d3 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -83,6 +83,8 @@ struct Args {
     disable_grammar_support: bool,
     #[clap(default_value = "4", long, env)]
     max_client_batch_size: usize,
+    #[clap(long, env, default_value_t = false)]
+    update_openapi_schema: bool,
 }

 #[tokio::main]
@@ -119,6 +121,7 @@ async fn main() -> Result<(), RouterError> {
         messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
+        update_openapi_schema,
     } = args;

     // Launch Tokio runtime
@@ -388,6 +391,7 @@ async fn main() -> Result<(), RouterError> {
         messages_api_enabled,
         disable_grammar_support,
         max_client_batch_size,
+        update_openapi_schema,
     )
     .await?;
     Ok(())
diff --git a/router/src/server.rs b/router/src/server.rs
index d24774f96c..3268d50b45 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1430,6 +1430,7 @@ pub async fn run(
     messages_api_enabled: bool,
     grammar_support: bool,
     max_client_batch_size: usize,
+    update_openapi_schema: bool,
 ) -> Result<(), WebServerError> {
     // OpenAPI documentation
     #[derive(OpenApi)]
@@ -1499,7 +1500,24 @@ pub async fn run(
     )]
     struct ApiDoc;

-    // Create state
+    if update_openapi_schema {
+        use std::io::Write;
+        let cargo_workspace =
+            std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
+        println!("Using workspace root {}", cargo_workspace);
+        let output_file = format!("{}/../docs/openapi.json", cargo_workspace);
+        println!("Writing OpenAPI schema to {}", output_file);
+
+        let openapi = ApiDoc::openapi();
+        let mut file = std::fs::File::create(output_file).expect("Unable to create file");
+        file.write_all(
+            openapi
+                .to_pretty_json()
+                .expect("Unable to serialize OpenAPI")
+                .as_bytes(),
+        )
+        .expect("Unable to write data");
+    }
     // Open connection, get model info and warmup
     let (scheduler, health_ext, shard_info, max_batch_total_tokens): (