From 9546bc3d33956dfb43c52e9381f3005918d51dc3 Mon Sep 17 00:00:00 2001 From: Javier Torres Date: Mon, 24 Nov 2025 10:33:32 +0100 Subject: [PATCH 1/5] Add preliminary mcp tests --- Cargo.lock | 199 +++++++++++++++++++++++++++++ encoderfile-core/Cargo.toml | 4 + encoderfile-core/tests/test_mcp.rs | 70 ++++++++++ 3 files changed, 273 insertions(+) create mode 100644 encoderfile-core/tests/test_mcp.rs diff --git a/Cargo.lock b/Cargo.lock index e8a8fa0..2a34a56 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1051,8 +1051,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -1062,9 +1064,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasip2", + "wasm-bindgen", ] [[package]] @@ -1206,6 +1210,23 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", + "webpki-roots", +] + [[package]] name = "hyper-timeout" version = "0.5.2" @@ -1534,6 +1555,12 @@ version = "0.4.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "lua-src" version = "548.1.2" @@ -2362,6 +2389,61 @@ dependencies = [ "pulldown-cmark", ] +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +dependencies = [ + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.59.0", +] + [[package]] name = "quote" version = "1.0.42" @@ -2563,23 +2645,45 @@ dependencies = [ "http-body", "http-body-util", "hyper", + "hyper-rustls", "hyper-util", "js-sys", "log", "percent-encoding", "pin-project-lite", + "quinn", + "rustls", + "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper", "tokio", + "tokio-rustls", + "tokio-util", "tower", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", + "webpki-roots", +] + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.16", + "libc", + "untrusted", + "windows-sys 0.52.0", ] [[package]] @@ -2598,6 +2702,7 @@ dependencies = [ "paste", "pin-project-lite", "rand 0.9.2", + "reqwest", "rmcp-macros", "schemars", "serde", @@ -2644,15 +2749,41 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustls" +version = "0.23.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd3c25631629d034ce7cd9940adc9d45762d46de2b0f57193c4443b92c6d4d40" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + [[package]] name = "rustls-pki-types" version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94182ad936a0c91c324cd46c6511b9510ed16af436d7b5bab34beab0afd55f7a" dependencies = [ + "web-time", "zeroize", ] +[[package]] +name = "rustls-webpki" +version = "0.103.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8572f3c2cb9934231157b45499fc41e1f58c589fdfb81a844ba873265e80f8eb" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -2977,6 +3108,12 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.110" @@ -3103,6 +3240,21 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinyvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokenizers" version = "0.22.1" @@ -3165,6 +3317,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.17" @@ -3519,6 +3681,12 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "ureq" version = "3.1.4" @@ -3736,6 +3904,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.82" @@ -3765,6 +3946,15 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "webpki-roots" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e8983c3ab33d6fb807cfcdad2491c4ea8cbc8ed839181c7dfd9c67c83e261b2" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "which" version = "8.0.0" @@ -3866,6 +4056,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.59.0" diff --git a/encoderfile-core/Cargo.toml b/encoderfile-core/Cargo.toml index eccef9d..36802aa 100644 --- a/encoderfile-core/Cargo.toml +++ b/encoderfile-core/Cargo.toml @@ -53,6 +53,10 @@ tonic-prost-build = "0.14.2" rand = "0.9.2" tower = "0.5.2" +[dev-dependencies.rmcp] +version = "0.8.0" +features = ["client", "transport-streamable-http-client-reqwest"] + [dependencies.mlua] version = "0.11.4" features = [ "lua54", "vendored",] diff --git a/encoderfile-core/tests/test_mcp.rs b/encoderfile-core/tests/test_mcp.rs new file mode 100644 index 0000000..4338b41 --- /dev/null +++ b/encoderfile-core/tests/test_mcp.rs @@ -0,0 +1,70 @@ +const LOCALHOST: &str = "localhost"; + +use encoderfile_core::{ + AppState, test_utils::embedding_state, transport::mcp, +}; +use tower_http::trace::DefaultOnResponse; +use tokio::net::TcpListener; +use anyhow::Result; +use rmcp::{ + ServiceExt, + transport::StreamableHttpClientTransport, + model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation} +}; + +async fn run_mcp(addr: String, state: AppState) -> Result<()>{ + let model_type = state.model_type.clone(); + let router = mcp::make_router(state).layer( + tower_http::trace::TraceLayer::new_for_http() + // TODO check if otel is enabled + // .make_span_with(crate::middleware::format_span) + .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)), + ); + tracing::info!("Running {:?} MCP server on {}", model_type, &addr); + let listener = TcpListener::bind(addr).await?; + axum::serve(listener, router).await?; + Ok(()) +} + +#[tokio::test] +async fn test_mcp() { + let port = 9100; + let addr = format!("{}:{}", LOCALHOST, port); + let dummy_state = embedding_state(); + let mcp_server = tokio::spawn(run_mcp(addr, dummy_state)); + // Client usage copied over from https://github.com/modelcontextprotocol/rust-sdk/blob/main/examples/clients/src/streamable_http.rs + let client_transport = StreamableHttpClientTransport::from_uri(format!("http://{}:{}/mcp", LOCALHOST, port)); + let client_info = ClientInfo { + protocol_version: Default::default(), + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "test sse client".to_string(), + title: None, + version: "0.0.1".to_string(), + website_url: None, + icons: None, + }, + }; + let client = client_info.serve(client_transport).await.inspect_err(|e| { + tracing::error!("client error: {:?}", e); + }).unwrap(); + // Initialize + let server_info = client.peer_info(); + tracing::info!("Connected to server: {server_info:#?}"); + + // List tools + let tools = client.list_tools(Default::default()).await.expect("list tools failed"); + tracing::info!("Available tools: {tools:#?}"); + + let tool_result = client + .call_tool(CallToolRequestParam { + name: "increment".into(), + arguments: serde_json::json!({}).as_object().cloned(), + }) + .await.expect("call tool failed"); + tracing::info!("Tool result: {tool_result:#?}"); + client.cancel().await.unwrap(); + mcp_server.abort(); +} + + From a1d39f071707a7a2e85b94901696466553305e34 Mon Sep 17 00:00:00 2001 From: Javier Torres Date: Mon, 24 Nov 2025 13:43:16 +0100 Subject: [PATCH 2/5] Add single test for embedding (macro pending) --- Cargo.lock | 46 +++++++++++++++++++++++- encoderfile-core/Cargo.toml | 5 +++ encoderfile-core/src/common/embedding.rs | 6 ++-- encoderfile-core/tests/test_mcp.rs | 20 +++++++++-- 4 files changed, 70 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3436869..489d935 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -802,6 +802,7 @@ dependencies = [ "schemars", "serde", "serde_json", + "test-log", "thiserror", "tokenizers", "tokio", @@ -830,12 +831,33 @@ dependencies = [ "serde_json", ] +[[package]] +name = "env_filter" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bf3c259d255ca70051b30e2e95b5446cdb8949ac4cd22c0d7fd634d89f568e2" +dependencies = [ + "log", +] + [[package]] name = "env_home" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7f84e12ccf0a7ddc17a6c41c93326024c42920d7ee630d04950e6926645c0fe" +[[package]] +name = "env_logger" +version = "0.11.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "log", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -2441,7 +2463,7 @@ dependencies = [ "once_cell", "socket2", "tracing", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -3201,6 +3223,28 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "test-log" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b" +dependencies = [ + "env_logger", + "test-log-macros", + "tracing-subscriber", +] + +[[package]] +name = "test-log-macros" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thiserror" version = "2.0.17" diff --git a/encoderfile-core/Cargo.toml b/encoderfile-core/Cargo.toml index d6ec5a2..69b1988 100644 --- a/encoderfile-core/Cargo.toml +++ b/encoderfile-core/Cargo.toml @@ -25,6 +25,10 @@ required-features = [ "dev-utils", "transport",] name = "test_http" required-features = [ "dev-utils", "transport",] +[[test]] +name = "test_mcp" +required-features = [ "dev-utils", "transport",] + [[test]] name = "test_models" required-features = [ "dev-utils",] @@ -66,6 +70,7 @@ tonic-prost-build = "0.14.2" [dev-dependencies] rand = "0.9.2" tower = "0.5.2" +test-log = "0.2.18" [dev-dependencies.rmcp] version = "0.8.0" diff --git a/encoderfile-core/src/common/embedding.rs b/encoderfile-core/src/common/embedding.rs index 485e837..fc6f6b6 100644 --- a/encoderfile-core/src/common/embedding.rs +++ b/encoderfile-core/src/common/embedding.rs @@ -10,7 +10,7 @@ pub struct EmbeddingRequest { pub metadata: Option>, } -#[derive(Debug, Serialize, ToSchema, JsonSchema, utoipa::ToResponse)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema, utoipa::ToResponse)] pub struct EmbeddingResponse { pub results: Vec, pub model_id: String, @@ -18,12 +18,12 @@ pub struct EmbeddingResponse { pub metadata: Option>, } -#[derive(Debug, Serialize, ToSchema, JsonSchema)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)] pub struct TokenEmbeddingSequence { pub embeddings: Vec, } -#[derive(Debug, Serialize, ToSchema, JsonSchema)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)] pub struct TokenEmbedding { pub embedding: Vec, pub token_info: Option, diff --git a/encoderfile-core/tests/test_mcp.rs b/encoderfile-core/tests/test_mcp.rs index 4338b41..32fb6f2 100644 --- a/encoderfile-core/tests/test_mcp.rs +++ b/encoderfile-core/tests/test_mcp.rs @@ -1,7 +1,8 @@ const LOCALHOST: &str = "localhost"; use encoderfile_core::{ - AppState, test_utils::embedding_state, transport::mcp, + AppState, dev_utils::embedding_state, transport::mcp, + common::{EmbeddingRequest, EmbeddingResponse, ModelType}, }; use tower_http::trace::DefaultOnResponse; use tokio::net::TcpListener; @@ -27,6 +28,7 @@ async fn run_mcp(addr: String, state: AppState) -> Result<()>{ } #[tokio::test] +#[test_log::test] async fn test_mcp() { let port = 9100; let addr = format!("{}:{}", LOCALHOST, port); @@ -56,13 +58,25 @@ async fn test_mcp() { let tools = client.list_tools(Default::default()).await.expect("list tools failed"); tracing::info!("Available tools: {tools:#?}"); + assert_eq!(tools.tools.len(), 1); + assert_eq!(tools.tools[0].name, "run_encoder"); + + let test_params = EmbeddingRequest { + inputs: vec!["This is a test.".to_string(), "This is another test.".to_string()], + metadata: None + }; let tool_result = client .call_tool(CallToolRequestParam { - name: "increment".into(), - arguments: serde_json::json!({}).as_object().cloned(), + name: "run_encoder".into(), + arguments: serde_json::json!(test_params).as_object().cloned(), }) .await.expect("call tool failed"); tracing::info!("Tool result: {tool_result:#?}"); + let embeddings_response: EmbeddingResponse = serde_json::from_value( + tool_result.structured_content + .expect("No structured content found")) + .expect("failed to parse tool result"); + assert_eq!(embeddings_response.results.len(), 2); client.cancel().await.unwrap(); mcp_server.abort(); } From 66c45afa35ac2ed73dc75cb47805454ca598f918 Mon Sep 17 00:00:00 2001 From: Javier Torres Date: Mon, 24 Nov 2025 14:18:43 +0100 Subject: [PATCH 3/5] Fix fmt --- encoderfile-core/tests/test_mcp.rs | 52 +++++++++++++++++++----------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/encoderfile-core/tests/test_mcp.rs b/encoderfile-core/tests/test_mcp.rs index 32fb6f2..43fd33e 100644 --- a/encoderfile-core/tests/test_mcp.rs +++ b/encoderfile-core/tests/test_mcp.rs @@ -1,19 +1,21 @@ const LOCALHOST: &str = "localhost"; +use anyhow::Result; use encoderfile_core::{ - AppState, dev_utils::embedding_state, transport::mcp, + AppState, common::{EmbeddingRequest, EmbeddingResponse, ModelType}, + dev_utils::embedding_state, + transport::mcp, }; -use tower_http::trace::DefaultOnResponse; -use tokio::net::TcpListener; -use anyhow::Result; use rmcp::{ ServiceExt, + model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation}, transport::StreamableHttpClientTransport, - model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation} }; +use tokio::net::TcpListener; +use tower_http::trace::DefaultOnResponse; -async fn run_mcp(addr: String, state: AppState) -> Result<()>{ +async fn run_mcp(addr: String, state: AppState) -> Result<()> { let model_type = state.model_type.clone(); let router = mcp::make_router(state).layer( tower_http::trace::TraceLayer::new_for_http() @@ -35,7 +37,8 @@ async fn test_mcp() { let dummy_state = embedding_state(); let mcp_server = tokio::spawn(run_mcp(addr, dummy_state)); // Client usage copied over from https://github.com/modelcontextprotocol/rust-sdk/blob/main/examples/clients/src/streamable_http.rs - let client_transport = StreamableHttpClientTransport::from_uri(format!("http://{}:{}/mcp", LOCALHOST, port)); + let client_transport = + StreamableHttpClientTransport::from_uri(format!("http://{}:{}/mcp", LOCALHOST, port)); let client_info = ClientInfo { protocol_version: Default::default(), capabilities: ClientCapabilities::default(), @@ -47,38 +50,49 @@ async fn test_mcp() { icons: None, }, }; - let client = client_info.serve(client_transport).await.inspect_err(|e| { - tracing::error!("client error: {:?}", e); - }).unwrap(); + let client = client_info + .serve(client_transport) + .await + .inspect_err(|e| { + tracing::error!("client error: {:?}", e); + }) + .unwrap(); // Initialize let server_info = client.peer_info(); tracing::info!("Connected to server: {server_info:#?}"); // List tools - let tools = client.list_tools(Default::default()).await.expect("list tools failed"); + let tools = client + .list_tools(Default::default()) + .await + .expect("list tools failed"); tracing::info!("Available tools: {tools:#?}"); assert_eq!(tools.tools.len(), 1); assert_eq!(tools.tools[0].name, "run_encoder"); let test_params = EmbeddingRequest { - inputs: vec!["This is a test.".to_string(), "This is another test.".to_string()], - metadata: None + inputs: vec![ + "This is a test.".to_string(), + "This is another test.".to_string(), + ], + metadata: None, }; let tool_result = client .call_tool(CallToolRequestParam { name: "run_encoder".into(), arguments: serde_json::json!(test_params).as_object().cloned(), }) - .await.expect("call tool failed"); + .await + .expect("call tool failed"); tracing::info!("Tool result: {tool_result:#?}"); let embeddings_response: EmbeddingResponse = serde_json::from_value( - tool_result.structured_content - .expect("No structured content found")) - .expect("failed to parse tool result"); + tool_result + .structured_content + .expect("No structured content found"), + ) + .expect("failed to parse tool result"); assert_eq!(embeddings_response.results.len(), 2); client.cancel().await.unwrap(); mcp_server.abort(); } - - From 050919d191e1fca33305e6d3ca4a0fba72a89beb Mon Sep 17 00:00:00 2001 From: Javier Torres Date: Tue, 25 Nov 2025 18:44:46 +0100 Subject: [PATCH 4/5] Add templates for all 4 servers (fails now) --- Cargo.lock | 6 +- encoderfile-core/Cargo.toml | 8 + .../src/common/sentence_embedding.rs | 2 +- .../src/common/sequence_classification.rs | 4 +- .../src/common/token_classification.rs | 6 +- encoderfile-core/tests/test_mcp.rs | 190 ++++++++++++------ 6 files changed, 142 insertions(+), 74 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 489d935..811469f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -785,6 +785,8 @@ dependencies = [ "clap_derive", "codspeed-divan-compat", "dotenv", + "hyper", + "hyper-util", "mlua", "ndarray", "ndarray-stats", @@ -1211,9 +1213,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1744436df46f0bde35af3eda22aeaba453aada65d8f1c171cd8a5f59030bd69f" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" dependencies = [ "atomic-waker", "bytes", diff --git a/encoderfile-core/Cargo.toml b/encoderfile-core/Cargo.toml index 69b1988..469b0aa 100644 --- a/encoderfile-core/Cargo.toml +++ b/encoderfile-core/Cargo.toml @@ -72,6 +72,14 @@ rand = "0.9.2" tower = "0.5.2" test-log = "0.2.18" +[dev-dependencies.hyper-util] +version = "0.1.18" +features = ["server-graceful"] + +[dev-dependencies.hyper] +version = "1.8.1" +features = ["http1"] + [dev-dependencies.rmcp] version = "0.8.0" features = ["client", "transport-streamable-http-client-reqwest"] diff --git a/encoderfile-core/src/common/sentence_embedding.rs b/encoderfile-core/src/common/sentence_embedding.rs index 2ed348c..d2e1a3f 100644 --- a/encoderfile-core/src/common/sentence_embedding.rs +++ b/encoderfile-core/src/common/sentence_embedding.rs @@ -10,7 +10,7 @@ pub struct SentenceEmbeddingRequest { pub metadata: Option>, } -#[derive(Debug, Serialize, ToSchema, JsonSchema, utoipa::ToResponse)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema, utoipa::ToResponse)] pub struct SentenceEmbeddingResponse { pub results: Vec, pub model_id: String, diff --git a/encoderfile-core/src/common/sequence_classification.rs b/encoderfile-core/src/common/sequence_classification.rs index 63bf0d1..cdb901a 100644 --- a/encoderfile-core/src/common/sequence_classification.rs +++ b/encoderfile-core/src/common/sequence_classification.rs @@ -10,7 +10,7 @@ pub struct SequenceClassificationRequest { pub metadata: Option>, } -#[derive(Debug, Serialize, ToSchema, JsonSchema, utoipa::ToResponse)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema, utoipa::ToResponse)] pub struct SequenceClassificationResponse { pub results: Vec, pub model_id: String, @@ -18,7 +18,7 @@ pub struct SequenceClassificationResponse { pub metadata: Option>, } -#[derive(Debug, Serialize, ToSchema, JsonSchema)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)] pub struct SequenceClassificationResult { pub logits: Vec, pub scores: Vec, diff --git a/encoderfile-core/src/common/token_classification.rs b/encoderfile-core/src/common/token_classification.rs index 900ca86..7d9b124 100644 --- a/encoderfile-core/src/common/token_classification.rs +++ b/encoderfile-core/src/common/token_classification.rs @@ -10,7 +10,7 @@ pub struct TokenClassificationRequest { pub metadata: Option>, } -#[derive(Debug, Serialize, ToSchema, JsonSchema, utoipa::ToResponse)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema, utoipa::ToResponse)] pub struct TokenClassificationResponse { pub results: Vec, pub model_id: String, @@ -18,12 +18,12 @@ pub struct TokenClassificationResponse { pub metadata: Option>, } -#[derive(Debug, Serialize, ToSchema, JsonSchema)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)] pub struct TokenClassificationResult { pub tokens: Vec, } -#[derive(Debug, Serialize, ToSchema, JsonSchema)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)] pub struct TokenClassification { pub token_info: super::token::TokenInfo, pub scores: Vec, diff --git a/encoderfile-core/tests/test_mcp.rs b/encoderfile-core/tests/test_mcp.rs index 43fd33e..6b2d490 100644 --- a/encoderfile-core/tests/test_mcp.rs +++ b/encoderfile-core/tests/test_mcp.rs @@ -3,7 +3,6 @@ const LOCALHOST: &str = "localhost"; use anyhow::Result; use encoderfile_core::{ AppState, - common::{EmbeddingRequest, EmbeddingResponse, ModelType}, dev_utils::embedding_state, transport::mcp, }; @@ -14,8 +13,9 @@ use rmcp::{ }; use tokio::net::TcpListener; use tower_http::trace::DefaultOnResponse; +use tokio::sync::oneshot; -async fn run_mcp(addr: String, state: AppState) -> Result<()> { +async fn run_mcp(addr: String, state: AppState, receiver: oneshot::Receiver<()>, done_sender: oneshot::Sender<()>) -> Result<()> { let model_type = state.model_type.clone(); let router = mcp::make_router(state).layer( tower_http::trace::TraceLayer::new_for_http() @@ -25,74 +25,132 @@ async fn run_mcp(addr: String, state: AppState) -> Result<()> { ); tracing::info!("Running {:?} MCP server on {}", model_type, &addr); let listener = TcpListener::bind(addr).await?; - axum::serve(listener, router).await?; + axum::serve(listener, router) + .with_graceful_shutdown( + async { + receiver.await; + tracing::info!("Received shutdown signal, shutting down"); + done_sender.send(()); + () + }) + .await; Ok(()) } -#[tokio::test] -#[test_log::test] -async fn test_mcp() { - let port = 9100; - let addr = format!("{}:{}", LOCALHOST, port); - let dummy_state = embedding_state(); - let mcp_server = tokio::spawn(run_mcp(addr, dummy_state)); - // Client usage copied over from https://github.com/modelcontextprotocol/rust-sdk/blob/main/examples/clients/src/streamable_http.rs - let client_transport = - StreamableHttpClientTransport::from_uri(format!("http://{}:{}/mcp", LOCALHOST, port)); - let client_info = ClientInfo { - protocol_version: Default::default(), - capabilities: ClientCapabilities::default(), - client_info: Implementation { - name: "test sse client".to_string(), - title: None, - version: "0.0.1".to_string(), - website_url: None, - icons: None, - }, - }; - let client = client_info - .serve(client_transport) - .await - .inspect_err(|e| { - tracing::error!("client error: {:?}", e); - }) - .unwrap(); - // Initialize - let server_info = client.peer_info(); - tracing::info!("Connected to server: {server_info:#?}"); +macro_rules! test_mcp_server_impl { + ($mod_name:ident, $state_func:ident, $req_type:ident, $resp_type:ident) => { + mod $mod_name { + use encoderfile_core::{ + common::{$req_type, $resp_type}, + dev_utils::$state_func, + }; + use rmcp::{ + ServiceExt, + transport::StreamableHttpClientTransport, + model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation}, + }; + use tokio::sync::oneshot; - // List tools - let tools = client - .list_tools(Default::default()) - .await - .expect("list tools failed"); - tracing::info!("Available tools: {tools:#?}"); + const LOCALHOST: &str = "localhost"; + const PORT: i32 = 9100; - assert_eq!(tools.tools.len(), 1); - assert_eq!(tools.tools[0].name, "run_encoder"); + #[tokio::test] + #[test_log::test] + async fn $mod_name() { + let addr = format!("{}:{}", LOCALHOST, PORT); + let dummy_state = $state_func(); + let (sender, receiver) = oneshot::channel(); + let (done_sender, done_receiver) = oneshot::channel(); + let mcp_server = tokio::spawn(super::run_mcp(addr, dummy_state, receiver, done_sender)); + // Client usage copied over from https://github.com/modelcontextprotocol/rust-sdk/blob/main/examples/clients/src/streamable_http.rs + let client_transport = + StreamableHttpClientTransport::from_uri(format!("http://{}:{}/mcp", LOCALHOST, PORT)); + let client_info = ClientInfo { + protocol_version: Default::default(), + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "test sse client".to_string(), + title: None, + version: "0.0.1".to_string(), + website_url: None, + icons: None, + }, + }; + let client = client_info + .serve(client_transport) + .await + .inspect_err(|e| { + tracing::error!("client error: {:?}", e); + }) + .unwrap(); + // Initialize + let server_info = client.peer_info(); + tracing::info!("Connected to server: {server_info:#?}"); - let test_params = EmbeddingRequest { - inputs: vec![ - "This is a test.".to_string(), - "This is another test.".to_string(), - ], - metadata: None, - }; - let tool_result = client - .call_tool(CallToolRequestParam { - name: "run_encoder".into(), - arguments: serde_json::json!(test_params).as_object().cloned(), - }) - .await - .expect("call tool failed"); - tracing::info!("Tool result: {tool_result:#?}"); - let embeddings_response: EmbeddingResponse = serde_json::from_value( - tool_result - .structured_content - .expect("No structured content found"), - ) - .expect("failed to parse tool result"); - assert_eq!(embeddings_response.results.len(), 2); - client.cancel().await.unwrap(); - mcp_server.abort(); + // List tools + let tools = client + .list_tools(Default::default()) + .await + .expect("list tools failed"); + tracing::info!("Available tools: {tools:#?}"); + + assert_eq!(tools.tools.len(), 1); + assert_eq!(tools.tools[0].name, "run_encoder"); + + let test_params = $req_type { + inputs: vec![ + "This is a test.".to_string(), + "This is another test.".to_string(), + ], + metadata: None, + }; + let tool_result = client + .call_tool(CallToolRequestParam { + name: "run_encoder".into(), + arguments: serde_json::json!(test_params).as_object().cloned(), + }) + .await + .expect("call tool failed"); + tracing::info!("Tool result: {tool_result:#?}"); + let embeddings_response: $resp_type = serde_json::from_value( + tool_result + .structured_content + .expect("No structured content found"), + ) + .expect("failed to parse tool result"); + assert_eq!(embeddings_response.results.len(), 2); + client.cancel().await; + sender.send(()); + done_receiver.await; + } + } + } } + + +test_mcp_server_impl!( + test_mcp_embedding, + embedding_state, + EmbeddingRequest, + EmbeddingResponse +); + +test_mcp_server_impl!( + test_mcp_sentence_embedding, + sentence_embedding_state, + SentenceEmbeddingRequest, + SentenceEmbeddingResponse +); + +test_mcp_server_impl!( + test_mcp_token_classification, + token_classification_state, + TokenClassificationRequest, + TokenClassificationResponse +); +test_mcp_server_impl!( + test_mcp_sequence_classification, + sequence_classification_state, + SequenceClassificationRequest, + SequenceClassificationResponse +); \ No newline at end of file From 1bf388273a822bc3f98a6702c0fbcc6fd339d25a Mon Sep 17 00:00:00 2001 From: Javier Torres Date: Wed, 26 Nov 2025 12:28:30 +0100 Subject: [PATCH 5/5] Serialize mcp tests --- encoderfile-core/tests/test_mcp.rs | 72 +++++++++++++++--------------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/encoderfile-core/tests/test_mcp.rs b/encoderfile-core/tests/test_mcp.rs index 6b2d490..7fd4e1f 100644 --- a/encoderfile-core/tests/test_mcp.rs +++ b/encoderfile-core/tests/test_mcp.rs @@ -1,21 +1,10 @@ -const LOCALHOST: &str = "localhost"; - use anyhow::Result; -use encoderfile_core::{ - AppState, - dev_utils::embedding_state, - transport::mcp, -}; -use rmcp::{ - ServiceExt, - model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation}, - transport::StreamableHttpClientTransport, -}; +use encoderfile_core::{AppState, transport::mcp}; use tokio::net::TcpListener; -use tower_http::trace::DefaultOnResponse; use tokio::sync::oneshot; +use tower_http::trace::DefaultOnResponse; -async fn run_mcp(addr: String, state: AppState, receiver: oneshot::Receiver<()>, done_sender: oneshot::Sender<()>) -> Result<()> { +async fn run_mcp(addr: String, state: AppState, receiver: oneshot::Receiver<()>) -> Result<()> { let model_type = state.model_type.clone(); let router = mcp::make_router(state).layer( tower_http::trace::TraceLayer::new_for_http() @@ -26,45 +15,42 @@ async fn run_mcp(addr: String, state: AppState, receiver: oneshot::Receiver<()>, tracing::info!("Running {:?} MCP server on {}", model_type, &addr); let listener = TcpListener::bind(addr).await?; axum::serve(listener, router) - .with_graceful_shutdown( - async { - receiver.await; - tracing::info!("Received shutdown signal, shutting down"); - done_sender.send(()); - () - }) - .await; + .with_graceful_shutdown(async { + receiver.await.ok(); + tracing::info!("Received shutdown signal, shutting down"); + }) + .await + .expect("Error while shutting down server"); Ok(()) } macro_rules! test_mcp_server_impl { ($mod_name:ident, $state_func:ident, $req_type:ident, $resp_type:ident) => { - mod $mod_name { + pub mod $mod_name { use encoderfile_core::{ common::{$req_type, $resp_type}, dev_utils::$state_func, }; use rmcp::{ ServiceExt, - transport::StreamableHttpClientTransport, model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation}, + transport::StreamableHttpClientTransport, }; use tokio::sync::oneshot; const LOCALHOST: &str = "localhost"; const PORT: i32 = 9100; - #[tokio::test] - #[test_log::test] - async fn $mod_name() { + pub async fn $mod_name() { let addr = format!("{}:{}", LOCALHOST, PORT); let dummy_state = $state_func(); let (sender, receiver) = oneshot::channel(); - let (done_sender, done_receiver) = oneshot::channel(); - let mcp_server = tokio::spawn(super::run_mcp(addr, dummy_state, receiver, done_sender)); + let _mcp_server = tokio::spawn(super::run_mcp(addr, dummy_state, receiver)); // Client usage copied over from https://github.com/modelcontextprotocol/rust-sdk/blob/main/examples/clients/src/streamable_http.rs - let client_transport = - StreamableHttpClientTransport::from_uri(format!("http://{}:{}/mcp", LOCALHOST, PORT)); + let client_transport = StreamableHttpClientTransport::from_uri(format!( + "http://{}:{}/mcp", + LOCALHOST, PORT + )); let client_info = ClientInfo { protocol_version: Default::default(), capabilities: ClientCapabilities::default(), @@ -119,15 +105,13 @@ macro_rules! test_mcp_server_impl { ) .expect("failed to parse tool result"); assert_eq!(embeddings_response.results.len(), 2); - client.cancel().await; - sender.send(()); - done_receiver.await; + client.cancel().await.expect("Error cancelling the agent"); + sender.send(()).expect("Error sending end of test signal"); } } - } + }; } - test_mcp_server_impl!( test_mcp_embedding, embedding_state, @@ -148,9 +132,23 @@ test_mcp_server_impl!( TokenClassificationRequest, TokenClassificationResponse ); + test_mcp_server_impl!( test_mcp_sequence_classification, sequence_classification_state, SequenceClassificationRequest, SequenceClassificationResponse -); \ No newline at end of file +); + +#[tokio::test] +#[test_log::test] +async fn test_mcp_servers() { + self::test_mcp_embedding::test_mcp_embedding().await; + tracing::info!("Testing embedding"); + self::test_mcp_sentence_embedding::test_mcp_sentence_embedding().await; + tracing::info!("Testing sentence embedding"); + self::test_mcp_token_classification::test_mcp_token_classification().await; + tracing::info!("Testing token classification"); + self::test_mcp_sequence_classification::test_mcp_sequence_classification().await; + tracing::info!("Testing sequence classification"); +}