diff --git a/Cargo.lock b/Cargo.lock index 5a056fe..811469f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -785,6 +785,8 @@ dependencies = [ "clap_derive", "codspeed-divan-compat", "dotenv", + "hyper", + "hyper-util", "mlua", "ndarray", "ndarray-stats", @@ -802,6 +804,7 @@ dependencies = [ "schemars", "serde", "serde_json", + "test-log", "thiserror", "tokenizers", "tokio", @@ -830,12 +833,33 @@ dependencies = [ "serde_json", ] +[[package]] +name = "env_filter" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bf3c259d255ca70051b30e2e95b5446cdb8949ac4cd22c0d7fd634d89f568e2" +dependencies = [ + "log", +] + [[package]] name = "env_home" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7f84e12ccf0a7ddc17a6c41c93326024c42920d7ee630d04950e6926645c0fe" +[[package]] +name = "env_logger" +version = "0.11.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "log", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -1051,8 +1075,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -1062,9 +1088,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasip2", + "wasm-bindgen", ] [[package]] @@ -1185,9 +1213,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1744436df46f0bde35af3eda22aeaba453aada65d8f1c171cd8a5f59030bd69f" +checksum = 
"2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" dependencies = [ "atomic-waker", "bytes", @@ -1206,6 +1234,23 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", + "webpki-roots", +] + [[package]] name = "hyper-timeout" version = "0.5.2" @@ -1534,6 +1579,12 @@ version = "0.4.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "lua-src" version = "548.1.2" @@ -2362,6 +2413,61 @@ dependencies = [ "pulldown-cmark", ] +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +dependencies = [ + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.60.2", +] + [[package]] name = "quote" version = "1.0.42" @@ -2563,23 +2669,45 @@ dependencies = [ "http-body", "http-body-util", "hyper", + "hyper-rustls", "hyper-util", "js-sys", "log", "percent-encoding", "pin-project-lite", + "quinn", + "rustls", + "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper", "tokio", + "tokio-rustls", + "tokio-util", "tower", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", + "webpki-roots", +] + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.16", + "libc", + "untrusted", + "windows-sys 0.52.0", ] [[package]] @@ -2598,6 +2726,7 @@ dependencies = [ "paste", "pin-project-lite", "rand 0.9.2", + "reqwest", "rmcp-macros", "schemars", "serde", @@ -2644,15 +2773,41 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustls" +version = "0.23.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd3c25631629d034ce7cd9940adc9d45762d46de2b0f57193c4443b92c6d4d40" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + [[package]] name = "rustls-pki-types" version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94182ad936a0c91c324cd46c6511b9510ed16af436d7b5bab34beab0afd55f7a" dependencies = [ + "web-time", "zeroize", ] +[[package]] +name = "rustls-webpki" +version = "0.103.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8572f3c2cb9934231157b45499fc41e1f58c589fdfb81a844ba873265e80f8eb" +dependencies = [ + "ring", + 
"rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -2977,6 +3132,12 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.110" @@ -3064,6 +3225,28 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "test-log" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e33b98a582ea0be1168eba097538ee8dd4bbe0f2b01b22ac92ea30054e5be7b" +dependencies = [ + "env_logger", + "test-log-macros", + "tracing-subscriber", +] + +[[package]] +name = "test-log-macros" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thiserror" version = "2.0.17" @@ -3103,6 +3286,21 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinyvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokenizers" version = "0.22.1" @@ -3165,6 +3363,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + 
"rustls", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.17" @@ -3519,6 +3727,12 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "ureq" version = "3.1.4" @@ -3736,6 +3950,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.82" @@ -3765,6 +3992,15 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "webpki-roots" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e8983c3ab33d6fb807cfcdad2491c4ea8cbc8ed839181c7dfd9c67c83e261b2" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "which" version = "8.0.0" @@ -3866,6 +4102,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.59.0" diff --git a/encoderfile-core/Cargo.toml b/encoderfile-core/Cargo.toml index c2d5d04..469b0aa 100644 --- a/encoderfile-core/Cargo.toml +++ b/encoderfile-core/Cargo.toml @@ -25,6 +25,10 @@ required-features = [ "dev-utils", "transport",] name = "test_http" required-features = [ "dev-utils", "transport",] +[[test]] +name = "test_mcp" +required-features = [ 
"dev-utils", "transport",] + [[test]] name = "test_models" required-features = [ "dev-utils",] @@ -66,6 +70,19 @@ tonic-prost-build = "0.14.2" [dev-dependencies] rand = "0.9.2" tower = "0.5.2" +test-log = "0.2.18" + +[dev-dependencies.hyper-util] +version = "0.1.18" +features = ["server-graceful"] + +[dev-dependencies.hyper] +version = "1.8.1" +features = ["http1"] + +[dev-dependencies.rmcp] +version = "0.8.0" +features = ["client", "transport-streamable-http-client-reqwest"] [dependencies.axum] version = "0.8.6" diff --git a/encoderfile-core/src/common/embedding.rs b/encoderfile-core/src/common/embedding.rs index 485e837..fc6f6b6 100644 --- a/encoderfile-core/src/common/embedding.rs +++ b/encoderfile-core/src/common/embedding.rs @@ -10,7 +10,7 @@ pub struct EmbeddingRequest { pub metadata: Option>, } -#[derive(Debug, Serialize, ToSchema, JsonSchema, utoipa::ToResponse)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema, utoipa::ToResponse)] pub struct EmbeddingResponse { pub results: Vec, pub model_id: String, @@ -18,12 +18,12 @@ pub struct EmbeddingResponse { pub metadata: Option>, } -#[derive(Debug, Serialize, ToSchema, JsonSchema)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)] pub struct TokenEmbeddingSequence { pub embeddings: Vec, } -#[derive(Debug, Serialize, ToSchema, JsonSchema)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)] pub struct TokenEmbedding { pub embedding: Vec, pub token_info: Option, diff --git a/encoderfile-core/src/common/sentence_embedding.rs b/encoderfile-core/src/common/sentence_embedding.rs index 2ed348c..d2e1a3f 100644 --- a/encoderfile-core/src/common/sentence_embedding.rs +++ b/encoderfile-core/src/common/sentence_embedding.rs @@ -10,7 +10,7 @@ pub struct SentenceEmbeddingRequest { pub metadata: Option>, } -#[derive(Debug, Serialize, ToSchema, JsonSchema, utoipa::ToResponse)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema, utoipa::ToResponse)] pub struct 
SentenceEmbeddingResponse { pub results: Vec, pub model_id: String, diff --git a/encoderfile-core/src/common/sequence_classification.rs b/encoderfile-core/src/common/sequence_classification.rs index 63bf0d1..cdb901a 100644 --- a/encoderfile-core/src/common/sequence_classification.rs +++ b/encoderfile-core/src/common/sequence_classification.rs @@ -10,7 +10,7 @@ pub struct SequenceClassificationRequest { pub metadata: Option>, } -#[derive(Debug, Serialize, ToSchema, JsonSchema, utoipa::ToResponse)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema, utoipa::ToResponse)] pub struct SequenceClassificationResponse { pub results: Vec, pub model_id: String, @@ -18,7 +18,7 @@ pub struct SequenceClassificationResponse { pub metadata: Option>, } -#[derive(Debug, Serialize, ToSchema, JsonSchema)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)] pub struct SequenceClassificationResult { pub logits: Vec, pub scores: Vec, diff --git a/encoderfile-core/src/common/token_classification.rs b/encoderfile-core/src/common/token_classification.rs index 900ca86..7d9b124 100644 --- a/encoderfile-core/src/common/token_classification.rs +++ b/encoderfile-core/src/common/token_classification.rs @@ -10,7 +10,7 @@ pub struct TokenClassificationRequest { pub metadata: Option>, } -#[derive(Debug, Serialize, ToSchema, JsonSchema, utoipa::ToResponse)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema, utoipa::ToResponse)] pub struct TokenClassificationResponse { pub results: Vec, pub model_id: String, @@ -18,12 +18,12 @@ pub struct TokenClassificationResponse { pub metadata: Option>, } -#[derive(Debug, Serialize, ToSchema, JsonSchema)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)] pub struct TokenClassificationResult { pub tokens: Vec, } -#[derive(Debug, Serialize, ToSchema, JsonSchema)] +#[derive(Debug, Serialize, Deserialize, ToSchema, JsonSchema)] pub struct TokenClassification { pub token_info: super::token::TokenInfo, pub scores: Vec, 
diff --git a/encoderfile-core/tests/test_mcp.rs b/encoderfile-core/tests/test_mcp.rs
new file mode 100644
index 0000000..7fd4e1f
--- /dev/null
+++ b/encoderfile-core/tests/test_mcp.rs
@@ -0,0 +1,200 @@
+//! Integration tests for the MCP transport.
+//!
+//! Each model type is served over streamable HTTP on an OS-assigned port and
+//! exercised end to end through an `rmcp` client.
+
+use anyhow::Result;
+use encoderfile_core::{AppState, transport::mcp};
+use tokio::net::TcpListener;
+use tokio::sync::oneshot;
+use tower_http::trace::DefaultOnResponse;
+
+/// Serves the MCP router for `state` on an already-bound `listener`.
+///
+/// The listener is bound by the caller so that the socket exists before any
+/// client tries to connect and so that bind failures surface in the test body
+/// rather than inside a spawned task. The server stops once `receiver`
+/// resolves, which happens when the sender fires or is dropped.
+///
+/// # Errors
+///
+/// Returns an error if the listener's local address cannot be read or if
+/// `axum::serve` fails.
+async fn run_mcp(
+    listener: TcpListener,
+    state: AppState,
+    receiver: oneshot::Receiver<()>,
+) -> Result<()> {
+    let model_type = state.model_type.clone();
+    let router = mcp::make_router(state).layer(
+        tower_http::trace::TraceLayer::new_for_http()
+            // TODO check if otel is enabled
+            // .make_span_with(crate::middleware::format_span)
+            .on_response(DefaultOnResponse::new().level(tracing::Level::INFO)),
+    );
+    tracing::info!(
+        "Running {:?} MCP server on {}",
+        model_type,
+        listener.local_addr()?
+    );
+    axum::serve(listener, router)
+        .with_graceful_shutdown(async {
+            // `.ok()`: a dropped sender is treated as a shutdown request too.
+            receiver.await.ok();
+            tracing::info!("Received shutdown signal, shutting down");
+        })
+        .await?;
+    Ok(())
+}
+
+/// Generates `pub mod $mod_name` holding an async scenario of the same name.
+///
+/// The scenario serves the router built from `dev_utils::$state_func()`,
+/// connects an `rmcp` streamable-HTTP client, checks that `run_encoder` is the
+/// only advertised tool, calls it with a two-input `$req_type`, and checks that
+/// the structured content parses as a `$resp_type` with two results.
+macro_rules! test_mcp_server_impl {
+    ($mod_name:ident, $state_func:ident, $req_type:ident, $resp_type:ident) => {
+        pub mod $mod_name {
+            use encoderfile_core::{
+                common::{$req_type, $resp_type},
+                dev_utils::$state_func,
+            };
+            use rmcp::{
+                ServiceExt,
+                model::{CallToolRequestParam, ClientCapabilities, ClientInfo, Implementation},
+                transport::StreamableHttpClientTransport,
+            };
+            use tokio::net::TcpListener;
+            use tokio::sync::oneshot;
+
+            const LOCALHOST: &str = "localhost";
+
+            pub async fn $mod_name() {
+                // Port 0 lets the OS pick a free port, so scenarios cannot
+                // collide with each other or with other local processes.
+                // Binding here, before the client exists, guarantees the
+                // socket is listening when the first request is sent.
+                let listener = TcpListener::bind((LOCALHOST, 0))
+                    .await
+                    .expect("failed to bind MCP test listener");
+                let port = listener
+                    .local_addr()
+                    .expect("failed to read MCP test listener address")
+                    .port();
+                let dummy_state = $state_func();
+                let (sender, receiver) = oneshot::channel();
+                // The task is detached; log its error so a server failure is
+                // visible in the test output instead of being dropped.
+                let _mcp_server = tokio::spawn(async move {
+                    if let Err(e) = super::run_mcp(listener, dummy_state, receiver).await {
+                        tracing::error!("MCP server error: {:?}", e);
+                    }
+                });
+                // Client usage copied over from https://github.com/modelcontextprotocol/rust-sdk/blob/main/examples/clients/src/streamable_http.rs
+                let client_transport = StreamableHttpClientTransport::from_uri(format!(
+                    "http://{}:{}/mcp",
+                    LOCALHOST, port
+                ));
+                let client_info = ClientInfo {
+                    protocol_version: Default::default(),
+                    capabilities: ClientCapabilities::default(),
+                    client_info: Implementation {
+                        name: "test sse client".to_string(),
+                        title: None,
+                        version: "0.0.1".to_string(),
+                        website_url: None,
+                        icons: None,
+                    },
+                };
+                let client = client_info
+                    .serve(client_transport)
+                    .await
+                    .inspect_err(|e| {
+                        tracing::error!("client error: {:?}", e);
+                    })
+                    .unwrap();
+                // Initialize
+                let server_info = client.peer_info();
+                tracing::info!("Connected to server: {server_info:#?}");
+
+                // List tools
+                let tools = client
+                    .list_tools(Default::default())
+                    .await
+                    .expect("list tools failed");
+                tracing::info!("Available tools: {tools:#?}");
+
+                assert_eq!(tools.tools.len(), 1);
+                assert_eq!(tools.tools[0].name, "run_encoder");
+
+                let test_params = $req_type {
+                    inputs: vec![
+                        "This is a test.".to_string(),
+                        "This is another test.".to_string(),
+                    ],
+                    metadata: None,
+                };
+                let tool_result = client
+                    .call_tool(CallToolRequestParam {
+                        name: "run_encoder".into(),
+                        arguments: serde_json::json!(test_params).as_object().cloned(),
+                    })
+                    .await
+                    .expect("call tool failed");
+                tracing::info!("Tool result: {tool_result:#?}");
+                let response: $resp_type = serde_json::from_value(
+                    tool_result
+                        .structured_content
+                        .expect("No structured content found"),
+                )
+                .expect("failed to parse tool result");
+                assert_eq!(response.results.len(), 2);
+                client.cancel().await.expect("Error cancelling the agent");
+                sender.send(()).expect("Error sending end of test signal");
+            }
+        }
+    };
+}
+
+test_mcp_server_impl!(
+    test_mcp_embedding,
+    embedding_state,
+    EmbeddingRequest,
+    EmbeddingResponse
+);
+
+test_mcp_server_impl!(
+    test_mcp_sentence_embedding,
+    sentence_embedding_state,
+    SentenceEmbeddingRequest,
+    SentenceEmbeddingResponse
+);
+
+test_mcp_server_impl!(
+    test_mcp_token_classification,
+    token_classification_state,
+    TokenClassificationRequest,
+    TokenClassificationResponse
+);
+
+test_mcp_server_impl!(
+    test_mcp_sequence_classification,
+    sequence_classification_state,
+    SequenceClassificationRequest,
+    SequenceClassificationResponse
+);
+
+/// Runs every generated MCP scenario in sequence on one runtime.
+#[tokio::test]
+#[test_log::test]
+async fn test_mcp_servers() {
+    tracing::info!("Testing embedding");
+    self::test_mcp_embedding::test_mcp_embedding().await;
+    tracing::info!("Testing sentence embedding");
+    self::test_mcp_sentence_embedding::test_mcp_sentence_embedding().await;
+    tracing::info!("Testing token classification");
+    self::test_mcp_token_classification::test_mcp_token_classification().await;
+    tracing::info!("Testing sequence classification");
+    self::test_mcp_sequence_classification::test_mcp_sequence_classification().await;
+}