From 93194a5075a0f1d86fcf5b57d98439ce58087dba Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 19 Aug 2024 16:00:48 +0000 Subject: [PATCH 1/6] feat: add /v1/models endpoint --- router/src/lib.rs | 28 ++++++++++++++++++++++++++++ router/src/server.rs | 25 +++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index 1b2ff153c97..6cefe5e267f 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1243,6 +1243,34 @@ pub(crate) struct ErrorResponse { pub error_type: String, } +#[derive(Serialize, Deserialize, ToSchema)] +pub(crate) struct ModelInfo { + #[schema(example = "gpt2")] + pub id: String, + #[schema(example = "model")] + pub object: String, + #[schema(example = 1686935002)] + pub created: u64, + #[schema(example = "openai")] + pub owned_by: String, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub(crate) struct ModelsInfo { + #[schema(example = "list")] + pub object: String, + pub data: Vec, +} + +impl Default for ModelsInfo { + fn default() -> Self { + ModelsInfo { + object: "list".to_string(), + data: Vec::new(), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/router/src/server.rs b/router/src/server.rs index 8ec7a8716ed..31242c2ada6 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -24,6 +24,7 @@ use crate::{ VertexResponse, }; use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools}; +use crate::{ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; @@ -116,6 +117,25 @@ async fn get_model_info(info: Extension) -> Json { Json(info.0) } +#[utoipa::path( +get, +tag = "Text Generation Inference", +path = "/v1/models", +responses((status = 200, description = "Served model info", body = ModelInfo)) +)] +#[instrument] +async fn openai_get_model_info(info: Extension) -> Json { + Json(ModelsInfo { + data: vec![ModelInfo { + id: info.0.model_id.clone(), + object: "model".to_string(), + created: 0, // TODO: determine how to get this + owned_by: info.0.model_id.clone(), + }], + ..Default::default() + }) +} + #[utoipa::path( post, tag = "Text Generation Inference", @@ -2206,7 +2226,7 @@ async fn start( // Define base and health routes let mut base_routes = Router::new() - .route("/", post(compat_generate)) + .route("/", post(openai_get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) @@ -2244,7 +2264,8 @@ async fn start( .route("/info", get(get_model_info)) .route("/health", get(health)) .route("/ping", get(health)) - .route("/metrics", get(metrics)); + .route("/metrics", get(metrics)) + .route("/v1/models", get(openai_get_model_info)); // Conditional AWS Sagemaker route let aws_sagemaker_route = if messages_api_enabled { From 8398d4f436cffe142586484ab0623ead063e8803 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 19 Aug 2024 16:00:48 +0000 Subject: [PATCH 2/6] feat: add /v1/models endpoint --- router/src/lib.rs | 28 ++++++++++++++++++++++++++++ router/src/server.rs | 27 ++++++++++++++++++++++++--- 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/router/src/lib.rs b/router/src/lib.rs index ce4f7c46754..d874c38bd6a 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1240,6 +1240,34 @@ pub(crate) struct ErrorResponse { pub error_type: String, } +#[derive(Serialize, Deserialize, ToSchema)] +pub(crate) struct ModelInfo { + #[schema(example = "gpt2")] + pub id: String, + #[schema(example = "model")] + pub object: String, + #[schema(example = 1686935002)] + pub created: u64, + #[schema(example = "openai")] + pub owned_by: String, +} + +#[derive(Serialize, Deserialize, ToSchema)] +pub(crate) struct ModelsInfo { + #[schema(example = "list")] + pub object: String, + pub data: Vec, +} + +impl Default for ModelsInfo { + fn default() -> Self { + ModelsInfo { + object: "list".to_string(), + data: Vec::new(), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/router/src/server.rs b/router/src/server.rs index 8ebd1a3316d..88a2e9ce31b 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -23,7 +23,8 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools}; +use crate::{ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; use axum::extract::Extension; use axum::http::{HeaderMap, HeaderValue, Method, StatusCode}; @@ -116,6 +117,25 @@ async fn get_model_info(info: Extension) -> Json { Json(info.0) } +#[utoipa::path( +get, +tag = "Text Generation Inference", +path = "/v1/models", +responses((status = 200, description = "Served model info", body = ModelInfo)) +)] +#[instrument] +async fn openai_get_model_info(info: Extension) -> Json { + Json(ModelsInfo { + data: vec![ModelInfo { + id: info.0.model_id.clone(), + object: "model".to_string(), + created: 0, // TODO: determine how to get this + owned_by: info.0.model_id.clone(), + }], + ..Default::default() + }) +} + #[utoipa::path( post, tag = "Text Generation Inference", @@ -2208,7 +2228,7 @@ async fn start( // Define base and health routes let mut base_routes = Router::new() - .route("/", post(compat_generate)) + .route("/", post(openai_get_model_info)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) @@ -2246,7 +2266,8 @@ async fn start( .route("/info", get(get_model_info)) .route("/health", get(health)) .route("/ping", get(health)) - .route("/metrics", get(metrics)); + .route("/metrics", get(metrics)) + .route("/v1/models", get(openai_get_model_info)); // Conditional AWS Sagemaker route let aws_sagemaker_route = if messages_api_enabled { From 997d7a102aa71a97c4c8aa7f10d0f8aef1397b91 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 27 Aug 2024 16:33:40 +0000 Subject: [PATCH 3/6] fix: remove unused type import --- router/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/src/server.rs b/router/src/server.rs index 88a2e9ce31b..0db053fcff3 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -23,7 +23,7 @@ use crate::{ CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, VertexRequest, VertexResponse, }; -use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType, Tools}; +use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice, ToolType}; use crate::{ModelInfo, ModelsInfo}; use async_stream::__private::AsyncStream; use axum::extract::Extension; From a76bd78486f4bedca7826922290916c4c154ed17 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 27 Aug 2024 16:34:37 +0000 Subject: [PATCH 4/6] fix: revert route typo --- router/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/src/server.rs b/router/src/server.rs index 0db053fcff3..518b50e1355 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -2228,7 +2228,7 @@ async fn start( // Define base and health routes let mut base_routes = Router::new() - .route("/", post(openai_get_model_info)) + .route("/", post(compat_generate)) .route("/generate", post(generate)) .route("/generate_stream", post(generate_stream)) .route("/v1/chat/completions", post(chat_completions)) From 8bfa11f636399647d12896183e1ccb2b91d0239a Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 27 Aug 2024 16:59:33 +0000 Subject: [PATCH 5/6] fix: update docs with new endpoint --- docs/openapi.json | 62 +++++++++++++++++++++++++++++++++++++++++++- router/src/server.rs | 10 +++++-- 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/docs/openapi.json b/docs/openapi.json index fd64a3ab714..e61091885df 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -556,6 +556,37 @@ } } } + }, + "/v1/models": { + "get": { + "tags": [ + "Text Generation Inference" + ], + "summary": "Get model info", + "operationId": "openai_get_model_info", + "responses": { + "200": { + "description": "Served model info", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ModelInfo" + } + } + } + }, + "404": { + "description": "Model not found", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ErrorResponse" + } + } + } + } + } + } } }, "components": { @@ -1747,6 +1778,35 @@ } ] }, + "ModelInfo": { + "type": "object", + "required": [ + "id", + "object", + "created", + "owned_by" + ], + "properties": { + "created": { + "type": "integer", + "format": "int64", + "example": 1686935002, + "minimum": 0 + }, + "id": { + "type": "string", + "example": "gpt2" + }, + "object": { + "type": "string", + "example": "model" + }, + "owned_by": { + "type": "string", + "example": "openai" + } + } + }, "OutputMessage": { "oneOf": [ { @@ -2094,4 +2154,4 @@ "description": "Hugging Face Text Generation Inference API" } ] -} +} \ No newline at end of file diff --git a/router/src/server.rs b/router/src/server.rs index 518b50e1355..24003d9586c 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -121,9 +121,13 @@ async fn get_model_info(info: Extension) -> Json { get, tag = "Text Generation Inference", path = "/v1/models", -responses((status = 200, description = "Served model info", body = ModelInfo)) +responses( +(status = 200, description = "Served model info", body = ModelInfo), +(status = 404, description = "Model not found", body = ErrorResponse), +) )] -#[instrument] +#[instrument(skip(info))] +/// Get model info async fn openai_get_model_info(info: Extension) -> Json { Json(ModelsInfo { data: vec![ModelInfo { @@ -1521,6 +1525,7 @@ chat_completions, completions, tokenize, metrics, +openai_get_model_info, ), components( schemas( @@ -1573,6 +1578,7 @@ ToolCall, Function, FunctionDefinition, ToolChoice, +ModelInfo, ) ), tags( From 5e14f5bed759cbc16e07ef828f85670857b41ed1 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 27 Aug 2024 17:01:15 +0000 Subject: [PATCH 6/6] fix: add to redocly ignore and lint --- .redocly.lint-ignore.yaml | 1 + docs/openapi.json | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.redocly.lint-ignore.yaml b/.redocly.lint-ignore.yaml index 382c9ab6447..13b80497ea0 100644 --- a/.redocly.lint-ignore.yaml +++ b/.redocly.lint-ignore.yaml @@ -77,3 +77,4 @@ docs/openapi.json: - '#/paths/~1tokenize/post' - '#/paths/~1v1~1chat~1completions/post' - '#/paths/~1v1~1completions/post' + - '#/paths/~1v1~1models/get' diff --git a/docs/openapi.json b/docs/openapi.json index e61091885df..691705f28ba 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -2154,4 +2154,4 @@ "description": "Hugging Face Text Generation Inference API" } ] -} \ No newline at end of file +}