diff --git a/ExampleMCPServer/Sources/main.swift b/ExampleMCPServer/Sources/main.swift new file mode 100644 index 0000000..8f55d1f --- /dev/null +++ b/ExampleMCPServer/Sources/main.swift @@ -0,0 +1,43 @@ +import Foundation +import JSONSchemaBuilder +import MCPServer + +let transport = Transport.stdio() +func proxy(_ transport: Transport) -> Transport { + var sendToDataSequence: AsyncStream.Continuation? + let dataSequence = AsyncStream.init { continuation in + sendToDataSequence = continuation + } + + Task { + for await data in transport.dataSequence { + mcpLogger.info("Reading data from transport: \(String(data: data, encoding: .utf8)!, privacy: .public)") + sendToDataSequence?.yield(data) + } + } + + return Transport( + writeHandler: { data in + mcpLogger.info("Writing data to transport: \(String(data: data, encoding: .utf8)!, privacy: .public)") + try await transport.writeHandler(data) + }, + dataSequence: dataSequence) +} + +// MARK: - RepeatToolInput + +@Schemable +struct RepeatToolInput { + let text: String +} + +let server = try await MCPServer( + info: Implementation(name: "test-server", version: "1.0.0"), + capabilities: ServerCapabilityHandlers(tools: [ + Tool(name: "repeat") { (input: RepeatToolInput) in + [.text(.init(text: input.text))] + }, + ]), + transport: proxy(transport)) + +try await server.waitForDisconnection() diff --git a/ExampleMCPServer/launch.sh b/ExampleMCPServer/launch.sh new file mode 100755 index 0000000..b5952cb --- /dev/null +++ b/ExampleMCPServer/launch.sh @@ -0,0 +1,4 @@ +#!/bin/zsh + +dir=$(dirname "$0") +(cd "$dir/.." && swift run ExampleMCPServer -q) diff --git a/ExampleMCPServer/readme.md b/ExampleMCPServer/readme.md new file mode 100644 index 0000000..2fdaca5 --- /dev/null +++ b/ExampleMCPServer/readme.md @@ -0,0 +1,11 @@ +# Inspect the server in the debugger: + +``` +nvm use 20.18.1 + +npx @modelcontextprotocol/inspector "$(pwd)/ExampleMCPServer/launch.sh" +``` + + +# Observe console logs: +- in Console.app, filter by `com.app.mcp` as the subsystem. diff --git a/MCPClient/Sources/DataChannel+StdioProcess.swift b/MCPClient/Sources/DataChannel+StdioProcess.swift index ec35356..fb9fd4f 100644 --- a/MCPClient/Sources/DataChannel+StdioProcess.swift +++ b/MCPClient/Sources/DataChannel+StdioProcess.swift @@ -63,6 +63,7 @@ extension DataChannel { return path.isEmpty ? executable : path } + // TODO: look at how to use /bin/zsh, at least on MacOS, to avoid needing to specify PATH to locate the executable let process = Process() process.executableURL = URL(fileURLWithPath: try path(for: executable)) process.arguments = args diff --git a/MCPClient/Sources/MCPClient.swift b/MCPClient/Sources/MCPClient.swift index 511c9dc..e910be0 100644 --- a/MCPClient/Sources/MCPClient.swift +++ b/MCPClient/Sources/MCPClient.swift @@ -1,25 +1,24 @@ -//// import Combine import Foundation -import MCPShared - -public typealias SamplingRequestHandler = ((CreateMessageRequest.Params) async throws -> CreateMessageRequest.Result) -public typealias ListRootsRequestHandler = ((ListRootsRequest.Params?) async throws -> ListRootsRequest.Result) +import MCPInterface // MARK: - MCPClient +// TODO: Support cancelling a request + public actor MCPClient: MCPClientInterface { // MARK: Lifecycle + /// Creates a MCP client and connects to the server through the provided transport. + /// The methods completes after connecting to the server. public init( info: Implementation, transport: Transport, capabilities: ClientCapabilityHandlers = .init()) async throws { try await self.init( - samplingRequestHandler: capabilities.sampling?.handler, - listRootRequestHandler: capabilities.roots?.handler, + capabilities: capabilities, connection: try MCPClientConnection( info: info, capabilities: ClientCapabilities( @@ -30,15 +29,18 @@ public actor MCPClient: MCPClientInterface { } init( - samplingRequestHandler: SamplingRequestHandler? = nil, - listRootRequestHandler: ListRootsRequestHandler? = nil, + capabilities: ClientCapabilityHandlers, connection: MCPClientConnectionInterface) async throws { // Initialize the connection, and then update server capabilities. self.connection = connection - self.samplingRequestHandler = samplingRequestHandler - self.listRootRequestHandler = listRootRequestHandler - try await connect() + self.capabilities = capabilities + serverInfo = try await Self.connectToServer(connection: connection) + + await startListeningToNotifications() + await startListeningToRequests() + startPinging() + Task { try await self.updateTools() } Task { try await self.updatePrompts() } Task { try await self.updateResources() } @@ -47,25 +49,27 @@ public actor MCPClient: MCPClientInterface { // MARK: Public - public var tools: ReadOnlyCurrentValueSubject, Never> { + public private(set) var serverInfo: ServerInfo + + public var tools: ReadOnlyCurrentValueSubject, Never> { get async { await .init(_tools.compactMap { $0 }.removeDuplicates().eraseToAnyPublisher()) } } - public var prompts: ReadOnlyCurrentValueSubject, Never> { + public var prompts: ReadOnlyCurrentValueSubject, Never> { get async { await .init(_prompts.compactMap { $0 }.removeDuplicates().eraseToAnyPublisher()) } } - public var resources: ReadOnlyCurrentValueSubject, Never> { + public var resources: ReadOnlyCurrentValueSubject, Never> { get async { await .init(_resources.compactMap { $0 }.removeDuplicates().eraseToAnyPublisher()) } } - public var resourceTemplates: ReadOnlyCurrentValueSubject, Never> { + public var resourceTemplates: ReadOnlyCurrentValueSubject, Never> { get async { await .init(_resourceTemplates.compactMap { $0 }.removeDuplicates().eraseToAnyPublisher()) } @@ -77,9 +81,8 @@ public actor MCPClient: MCPClientInterface { progressHandler: ((Double, Double?) -> Void)? = nil) async throws -> CallToolResult { - let connectionInfo = try getConnectionInfo() - guard connectionInfo.serverCapabilities.tools != nil else { - throw MCPClientError.notSupported + guard serverInfo.capabilities.tools != nil else { + throw MCPError.notSupported } var progressToken: String? = nil if let progressHandler { @@ -87,7 +90,7 @@ public actor MCPClient: MCPClientInterface { progressHandlers[token] = progressHandler progressToken = token } - let result = try await connectionInfo.connection.call( + let result = try await connection.call( toolName: name, arguments: arguments, progressToken: progressToken.map { .string($0) }) @@ -103,19 +106,17 @@ public actor MCPClient: MCPClientInterface { } public func getPrompt(named name: String, arguments: JSON? = nil) async throws -> GetPromptResult { - let connectionInfo = try getConnectionInfo() - guard connectionInfo.serverCapabilities.prompts != nil else { - throw MCPClientError.notSupported + guard serverInfo.capabilities.prompts != nil else { + throw MCPError.notSupported } - return try await connectionInfo.connection.getPrompt(.init(name: name, arguments: arguments)) + return try await connection.getPrompt(.init(name: name, arguments: arguments)) } public func readResource(uri: String) async throws -> ReadResourceResult { - let connectionInfo = try getConnectionInfo() - guard connectionInfo.serverCapabilities.resources != nil else { - throw MCPClientError.notSupported + guard serverInfo.capabilities.resources != nil else { + throw MCPError.notSupported } - return try await connectionInfo.connection.readResource(.init(uri: uri)) + return try await connection.readResource(.init(uri: uri)) } // MARK: Internal @@ -124,34 +125,39 @@ public actor MCPClient: MCPClientInterface { // MARK: Private - private struct ConnectionInfo { - let connection: MCPClientConnectionInterface - let serverInfo: Implementation - let serverCapabilities: ServerCapabilities - } + private let capabilities: ClientCapabilityHandlers - private let samplingRequestHandler: SamplingRequestHandler? - private let listRootRequestHandler: ListRootsRequestHandler? + private let _tools = CurrentValueSubject?, Never>(nil) + private let _prompts = CurrentValueSubject?, Never>(nil) + private let _resources = CurrentValueSubject?, Never>(nil) + private let _resourceTemplates = CurrentValueSubject?, Never>(nil) - private var connectionInfo: ConnectionInfo? + private var progressHandlers = [String: (progress: Double, total: Double?) -> Void]() + + private static func connectToServer(connection: MCPClientConnectionInterface) async throws -> ServerInfo { + let response = try await connection.initialize() + guard response.protocolVersion == MCP.protocolVersion else { + throw MCPClientError.versionMismatch + } - private let _tools = CurrentValueSubject?, Never>(nil) - private let _prompts = CurrentValueSubject?, Never>(nil) - private let _resources = CurrentValueSubject?, Never>(nil) - private let _resourceTemplates = CurrentValueSubject?, Never>(nil) + try await connection.acknowledgeInitialization() - private var progressHandlers = [String: (progress: Double, total: Double?) -> Void]() + return ServerInfo( + info: response.serverInfo, + capabilities: response.capabilities) + } - private func startListeningToNotifications() async throws { - let connectionInfo = try getConnectionInfo() - let notifications = await connectionInfo.connection.notifications + private func startListeningToNotifications() async { + let notifications = await connection.notifications Task { [weak self] in for await notification in notifications { switch notification { case .cancelled: + // TODO: Handle this break case .loggingMessage: + // TODO: Handle this break case .progress(let progressParams): @@ -169,15 +175,15 @@ public actor MCPClient: MCPClientInterface { try await self?.updateResources() case .resourceUpdated: + // TODO: Handle this break } } } } - private func startListeningToRequests() async throws { - let connectionInfo = try getConnectionInfo() - let requests = await connectionInfo.connection.requestsToHandle + private func startListeningToRequests() async { + let requests = await connection.requestsToHandle Task { [weak self] in for await(request, completion) in requests { guard let self else { @@ -188,117 +194,77 @@ public actor MCPClient: MCPClientInterface { } switch request { case .createMessage(let params): - if let handler = await self.samplingRequestHandler { - do { - completion(.success(try await handler(params))) - } catch { - completion(.failure(.init( - code: JRPCErrorCodes.internalError.rawValue, - message: error.localizedDescription))) - } - } else { - completion(.failure(.init( - code: JRPCErrorCodes.invalidRequest.rawValue, - message: "Sampling is not supported by this client"))) - } - + await completion(handle(request: params, with: capabilities.sampling?.handler, "Sampling")) case .listRoots(let params): - if let handler = await self.listRootRequestHandler { - do { - completion(.success(try await handler(params))) - } catch { - completion(.failure(.init( - code: JRPCErrorCodes.internalError.rawValue, - message: error.localizedDescription))) - } - } else { - completion(.failure(.init( - code: JRPCErrorCodes.invalidRequest.rawValue, - message: "Listing roots is not supported by this client"))) - } + await completion(handle(request: params, with: capabilities.roots?.handler, "Listing roots")) } } } } + private func handle( + request params: Params, + with handler: ((Params) async throws -> some Encodable)?, + _ requestName: String) + async -> AnyJRPCResponse + { + if let handler { + do { + return .success(try await handler(params)) + } catch { + return .failure(.init( + code: JRPCErrorCodes.internalError.rawValue, + message: error.localizedDescription)) + } + } else { + return .failure(.init( + code: JRPCErrorCodes.invalidRequest.rawValue, + message: "\(requestName) is not supported by this server")) + } + } + private func startPinging() { // TODO } private func updateTools() async throws { - let connectionInfo = try getConnectionInfo() - guard connectionInfo.serverCapabilities.tools != nil else { + guard serverInfo.capabilities.tools != nil else { // Tool calling not supported _tools.send(.notSupported) return } - let tools = try await connectionInfo.connection.listTools() + let tools = try await connection.listTools() _tools.send(.supported(tools)) } private func updatePrompts() async throws { - let connectionInfo = try getConnectionInfo() - guard connectionInfo.serverCapabilities.prompts != nil else { + guard serverInfo.capabilities.prompts != nil else { // Prompts calling not supported _prompts.send(.notSupported) return } - let prompts = try await connectionInfo.connection.listPrompts() + let prompts = try await connection.listPrompts() _prompts.send(.supported(prompts)) } private func updateResources() async throws { - let connectionInfo = try getConnectionInfo() - guard connectionInfo.serverCapabilities.resources != nil else { + guard serverInfo.capabilities.resources != nil else { // Resources calling not supported _resources.send(.notSupported) return } - let resources = try await connectionInfo.connection.listResources() + let resources = try await connection.listResources() _resources.send(.supported(resources)) } private func updateResourceTemplates() async throws { - let connectionInfo = try getConnectionInfo() - guard connectionInfo.serverCapabilities.resources != nil else { + guard serverInfo.capabilities.resources != nil else { // Resources calling not supported _resourceTemplates.send(.notSupported) return } - let resourceTemplates = try await connectionInfo.connection.listResourceTemplates() + let resourceTemplates = try await connection.listResourceTemplates() _resourceTemplates.send(.supported(resourceTemplates)) } - private func connect() async throws { - let response = try await connection.initialize() - guard response.protocolVersion == MCP.protocolVersion else { - throw MCPClientError.versionMismatch - } - - connectionInfo = ConnectionInfo( - connection: connection, - serverInfo: response.serverInfo, - serverCapabilities: response.capabilities) - - try await connection.acknowledgeInitialization() - try await startListeningToNotifications() - try await startListeningToRequests() - startPinging() - } - - private func getConnectionInfo() throws -> ConnectionInfo { - guard let connectionInfo else { - throw MCPClientInternalError.internalStateInconsistency - } - return connectionInfo - } - -} - -// MARK: - MCPClientInternalError - -public enum MCPClientInternalError: Error { - case alreadyConnectedOrConnecting - case notConnected - case internalStateInconsistency } diff --git a/MCPClient/Sources/MCPClientConnection.swift b/MCPClient/Sources/MCPClientConnection.swift index da0bb66..66d445a 100644 --- a/MCPClient/Sources/MCPClientConnection.swift +++ b/MCPClient/Sources/MCPClientConnection.swift @@ -1,10 +1,7 @@ import Foundation import JSONRPC -import MCPShared +import MCPInterface import MemberwiseInit -import OSLog - -private let mcpLogger = Logger(subsystem: Bundle.main.bundleIdentifier.map { "\($0).mcp" } ?? "com.app.mcp", category: "mcp") // MARK: - MCPClientConnection @@ -18,43 +15,21 @@ public actor MCPClientConnection: MCPClientConnectionInterface { transport: Transport) throws { + // Note: ideally we would subclass `MCPConnection`. However Swift actors don't support inheritance. + _connection = try MCPConnection(transport: transport) self.info = info self.capabilities = capabilities - jrpcSession = JSONRPCSession(channel: transport) - - var sendNotificationToStream: (ServerNotification) -> Void = { _ in } - notifications = AsyncStream() { continuation in - sendNotificationToStream = { continuation.yield($0) } - } - self.sendNotificationToStream = sendNotificationToStream - - // A bit hard to read... When a server request is received (sent to `askForRequestToBeHandle`), we yield it to the stream - // where we expect someone to be listening and handling the request. The handler then calls the completion `requestContinuation` - // which will be sent back as an async response to `askForRequestToBeHandle`. - var askForRequestToBeHandle: ((ServerRequest) async -> AnyJRPCResponse)? = nil - requestsToHandle = AsyncStream() { streamContinuation in - askForRequestToBeHandle = { request in - await withCheckedContinuation { (requestContinuation: CheckedContinuation) in - streamContinuation.yield((request, { response in - requestContinuation.resume(returning: response) - })) - } - } - } - self.askForRequestToBeHandle = askForRequestToBeHandle - - Task { await listenToIncomingMessages() } } // MARK: Public - public private(set) var notifications: AsyncStream - public private(set) var requestsToHandle: AsyncStream - public let info: Implementation public let capabilities: ClientCapabilities + public var notifications: AsyncStream { _connection.notifications } + public var requestsToHandle: AsyncStream { _connection.requestsToHandle } + public func initialize() async throws -> InitializeRequest.Result { let params = InitializeRequest.Params( protocolVersion: MCP.protocolVersion, @@ -77,7 +52,7 @@ public actor MCPClientConnection: MCPClientConnectionInterface { } public func listPrompts() async throws -> [Prompt] { - try await jrpcSession.send(ListPromptsRequest.Params(), getResults: { $0.prompts }, req: ListPromptsRequest.self) + try await jrpcSession.send(nil, getResults: { $0.prompts }, req: ListPromptsRequest.self) } public func getPrompt(_ params: GetPromptRequest.Params) async throws -> GetPromptRequest.Result { @@ -85,7 +60,7 @@ public actor MCPClientConnection: MCPClientConnectionInterface { } public func listResources() async throws -> [Resource] { - try await jrpcSession.send(ListResourcesRequest.Params(), getResults: { $0.resources }, req: ListResourcesRequest.self) + try await jrpcSession.send(nil, getResults: { $0.resources }, req: ListResourcesRequest.self) } public func readResource(_ params: ReadResourceRequest.Params) async throws -> ReadResourceRequest.Result { @@ -102,13 +77,13 @@ public actor MCPClientConnection: MCPClientConnectionInterface { public func listResourceTemplates() async throws -> [ResourceTemplate] { try await jrpcSession.send( - ListResourceTemplatesRequest.Params(), + nil, getResults: { $0.resourceTemplates }, req: ListResourceTemplatesRequest.self) } public func listTools() async throws -> [Tool] { - try await jrpcSession.send(ListToolsRequest.Params(), getResults: { $0.tools }, req: ListToolsRequest.self) + try await jrpcSession.send(nil, getResults: { $0.tools }, req: ListToolsRequest.self) } public func call( @@ -129,100 +104,15 @@ public actor MCPClientConnection: MCPClientConnectionInterface { try await jrpcSession.send(SetLevelRequest(params: params)) } - public func log(_ params: LoggingMessageNotification.Params) async throws { - try await jrpcSession.send(LoggingMessageNotification(params: params)) + public func notifyRootsListChanged() async throws { + try await jrpcSession.send(RootsListChangedNotification()) } // MARK: Private - private var sendNotificationToStream: ((ServerNotification) -> Void) = { _ in } - - private var askForRequestToBeHandle: ((ServerRequest) async -> AnyJRPCResponse)? = nil - - private let jrpcSession: JSONRPCSession - - private var eventHandlers = [String: (JSONRPCEvent) -> Void]() - - private func listenToIncomingMessages() async { - let events = await jrpcSession.eventSequence - Task { [weak self] in - for await event in events { - await self?.handle(receptionOf: event) - } - } - } - - private func handle(receptionOf event: JSONRPCEvent) { - switch event { - case .notification(_, let data): - do { - let notification = try JSONDecoder().decode(ServerNotification.self, from: data) - sendNotificationToStream(notification) - } catch { - mcpLogger - .error("Failed to decode notification \(String(data: data, encoding: .utf8) ?? "invalid data", privacy: .public)") - } - - case .request(_, let handler, let data): - // Respond to ping from the server - Task { await handler(handle(receptionOf: data)) } - - case .error(let error): - mcpLogger.error("Received error from server: \(error, privacy: .public)") - } - } - - private func handle(receptionOf request: Data) async -> AnyJRPCResponse { - if let serverRequest = try? JSONDecoder().decode(ServerRequest.self, from: request) { - guard let askForRequestToBeHandle else { - mcpLogger.error("Unable to handle request. The client MCP connection has not been set properly") - return .failure(.init( - code: JRPCErrorCodes.methodNotFound.rawValue, - message: "Unable to handle request. The client MCP connection has not been set properly")) - } - return await askForRequestToBeHandle(serverRequest) - } else if (try? JSONDecoder().decode(PingRequest.self, from: request)) != nil { - // Respond to ping from the server - return .success(PingRequest.Result()) - } - mcpLogger - .error( - "Received unknown request from server: \(String(data: request, encoding: .utf8) ?? "invalid data", privacy: .public)") - return .failure(.init( - code: JRPCErrorCodes.methodNotFound.rawValue, - message: "The request could not be decoded to a known type")) - } + private let _connection: MCPConnection -} - -extension JSONRPCSession { - func send(_ request: Req) async throws -> Req.Result { - let response: JSONRPCResponse = try await sendRequest(request.params, method: request.method) - return try response.content.get() + private var jrpcSession: JSONRPCSession { + _connection.jrpcSession } - - func send(_ notification: some MCPShared.Notification) async throws { - try await sendNotification(notification.params, method: notification.method) - } - - func send( - _ params: Req.Params, - getResults: (Req.Result) -> [Result], - req _: Req.Type = Req.self) - async throws -> [Result] - { - var cursor: String? = nil - var results = [Result]() - - while true { - let request = Req(params: params.updatingCursor(to: cursor)) - let response = try await send(request) - results.append(contentsOf: getResults(response)) - cursor = response.nextCursor - if cursor == nil { - return results - } - } - } - } diff --git a/MCPClient/Sources/MCPClientConnectionInterface.swift b/MCPClient/Sources/MCPClientConnectionInterface.swift index 24155ac..c46842c 100644 --- a/MCPClient/Sources/MCPClientConnectionInterface.swift +++ b/MCPClient/Sources/MCPClientConnectionInterface.swift @@ -1,14 +1,10 @@ import JSONRPC -import MCPShared - -public typealias AnyJRPCResponse = Swift.Result - -public typealias HandleServerRequest = (ServerRequest, (AnyJRPCResponse) -> Void) +import MCPInterface // MARK: - MCPClientConnectionInterface -/// The MCP JRPC Bridge is a stateless interface to the MCP server that provides a higher level Swift interface. +/// This is a stateless interface to the MCP server that provides a higher level Swift interface. /// It does not implement any of the stateful behaviors of the MCP server, such as subscribing to changes, detecting connection health, /// ensuring that the connection has been initialized before being used etc. /// @@ -59,6 +55,13 @@ public protocol MCPClientConnectionInterface { func requestCompletion(_ params: CompleteRequest.Params) async throws -> CompleteRequest.Result /// Set the log level that the server should use for this connection. func setLogLevel(_ params: SetLevelRequest.Params) async throws -> SetLevelRequest.Result - /// Log a message to the server. - func log(_ params: LoggingMessageNotification.Params) async throws + /// Send a roots list updated notification to the server + func notifyRootsListChanged() async throws +} + +// MARK: - ServerInfo + +public struct ServerInfo { + public let info: Implementation + public let capabilities: ServerCapabilities } diff --git a/MCPClient/Sources/MCPClientInterface.swift b/MCPClient/Sources/MCPClientInterface.swift index 710011c..4a90886 100644 --- a/MCPClient/Sources/MCPClientInterface.swift +++ b/MCPClient/Sources/MCPClientInterface.swift @@ -1,12 +1,42 @@ import JSONRPC -import MCPShared +import MCPInterface import MemberwiseInit // MARK: - MCPClientInterface -public protocol MCPClientInterface { } - -public typealias Transport = DataChannel +public protocol MCPClientInterface { + var serverInfo: ServerInfo { get async } + + /// The tools supported by the server, if tools are supported. + var tools: ReadOnlyCurrentValueSubject, Never> { get async } + /// The prompts supported by the server, if prompts are supported. + var prompts: ReadOnlyCurrentValueSubject, Never> { get async } + /// The resource provided by the server, if resources are supported. + var resources: ReadOnlyCurrentValueSubject, Never> { get async } + /// The resource templates supported by the server, if resources are supported. + var resourceTemplates: ReadOnlyCurrentValueSubject, Never> { get async } + + /// Invoke a tool provided by the server. + /// - Parameters: + /// - name: The name of the tool to call. + /// - arguments: The arguments to pass to the tool. + /// - progressHandler: A closure that will be called with the progress of the tool execution. The first parameter is the current progress, and the second the total progress to reach if known. + func callTool( + named name: String, + arguments: JSON?, + progressHandler: ((Double, Double?) -> Void)?) async throws -> CallToolResult + + /// Get a prompt provided by the server. + /// - Parameters: + /// - name: The name of the prompt to get. + /// - arguments: Arguments to use for templating the prompt. + func getPrompt(named name: String, arguments: JSON?) async throws -> GetPromptResult + + /// Read a specific resource URI. + /// - Parameters: + /// - uri: The URI of the resource to read. + func readResource(uri: String) async throws -> ReadResourceResult +} // MARK: - ClientCapabilityHandlers @@ -15,44 +45,14 @@ public typealias Transport = DataChannel /// Note: This is similar to `ClientCapabilities`, with the addition of the handler function. @MemberwiseInit(.public, _optionalsDefaultNil: true) public struct ClientCapabilityHandlers { - public let roots: CapabilityHandler? - public let sampling: CapabilityHandler? + public let roots: CapabilityHandler? + public let sampling: CapabilityHandler? // TODO: add experimental } // MARK: - MCPClientError public enum MCPClientError: Error { - case alreadyConnectedOrConnecting - case notConnected - case notSupported case versionMismatch case toolCallError(executionErrors: [CallToolResult.ExecutionError]) } - -// MARK: - ServerCapabilityState - -public enum ServerCapabilityState: Equatable { - case supported(_ capability: Capability) - case notSupported -} - -extension ServerCapabilityState { - public var capability: Capability? { - switch self { - case .supported(let capability): - return capability - case .notSupported: - return nil - } - } - - public func get() throws -> Capability { - switch self { - case .supported(let capability): - return capability - case .notSupported: - throw MCPClientError.notSupported - } - } -} diff --git a/MCPClient/Sources/MockMCPConnection.swift b/MCPClient/Sources/MockMCPConnection.swift index 58257cb..8e6f973 100644 --- a/MCPClient/Sources/MockMCPConnection.swift +++ b/MCPClient/Sources/MockMCPConnection.swift @@ -1,5 +1,5 @@ -import MCPShared +import MCPInterface #if DEBUG // TODO: move to a test helper package @@ -85,6 +85,9 @@ class MockMCPClientConnection: MCPClientConnectionInterface { /// This function is called when `log` is called var logStub: ((LoggingMessageNotification.Params) async throws -> Void)? + /// This function is called when `notifyRootsListChanged` is called + var notifyRootsListChangedStub: (() async throws -> Void)? + func initialize() async throws -> InitializeRequest.Result { if let initializeStub { return try await initializeStub() @@ -189,15 +192,17 @@ class MockMCPClientConnection: MCPClientConnectionInterface { } throw MockMCPClientConnectionError.notImplemented(function: "log") } + + func notifyRootsListChanged() async throws { + if let notifyRootsListChangedStub { + return try await notifyRootsListChangedStub() + } + throw MockMCPClientConnectionError.notImplemented(function: "notifyRootsListChanged") + } + } enum MockMCPClientConnectionError: Error { case notImplemented(function: String) } - -extension Transport { - static var noop: Transport { - .init(writeHandler: { _ in }, dataSequence: DataSequence { _ in }) - } -} #endif diff --git a/MCPClient/Tests/MCPClient/CallTool.swift b/MCPClient/Tests/CallTool.swift similarity index 96% rename from MCPClient/Tests/MCPClient/CallTool.swift rename to MCPClient/Tests/CallTool.swift index 82bba44..7cbba8a 100644 --- a/MCPClient/Tests/MCPClient/CallTool.swift +++ b/MCPClient/Tests/CallTool.swift @@ -1,5 +1,5 @@ -import MCPShared +import MCPInterface import SwiftTestingUtils import Testing @testable import MCPClient @@ -31,7 +31,7 @@ extension MCPClientTestSuite { } let sut = try await createMCPClient() - await #expect(throws: MCPClientError.self) { try await sut.callTool(named: "get_weather") } + await #expect(throws: MCPError.self) { try await sut.callTool(named: "get_weather") } } @Test("call tool fails when there has been an execution error") diff --git a/MCPClient/Tests/MCPClient/Initialization.swift b/MCPClient/Tests/Initialization.swift similarity index 62% rename from MCPClient/Tests/MCPClient/Initialization.swift rename to MCPClient/Tests/Initialization.swift index 6293a1d..41e2721 100644 --- a/MCPClient/Tests/MCPClient/Initialization.swift +++ b/MCPClient/Tests/Initialization.swift @@ -1,6 +1,7 @@ import Foundation -import MCPShared +import MCPInterface +import MCPTestingUtils import SwiftTestingUtils import Testing @testable import MCPClient @@ -34,58 +35,65 @@ extension MCPClientTestSuite { @Test("initialization with capabilities") func test_initializationWithCapabilities_sendsCorrectCapabilities() async throws { let transport = MockTransport() - let client = try await MCPClientConnectionTest.assert(executing: { - try await MCPClient( - info: Implementation(name: "test-client", version: "1.0.0"), - transport: transport.dataChannel, - capabilities: ClientCapabilityHandlers( - roots: .init(info: .init(listChanged: true), handler: { _ in .init(roots: []) }), - sampling: .init(handler: { _ in .init(role: .user, content: .text(.init(text: "hello")), model: "claude") }))) - }, triggers: [ - .request(""" - { - "id" : 1, - "jsonrpc" : "2.0", - "method" : "initialize", - "params" : { - "capabilities" : { - "roots" : { - "listChanged" : true + let client = try await MCPTestingUtils.assert( + clientTransport: transport, + serverTransport: nil, + serverRequestsHandler: connection.requestsToHandle, + clientRequestsHandler: nil, + serverNotifications: connection.notifications, + clientNotifications: nil, + executing: { + try await MCPClient( + info: Implementation(name: "test-client", version: "1.0.0"), + transport: transport.dataChannel, + capabilities: ClientCapabilityHandlers( + roots: .init(info: .init(listChanged: true), handler: { _ in .init(roots: []) }), + sampling: .init(handler: { _ in .init(role: .user, content: .text(.init(text: "hello")), model: "claude") }))) + }, triggers: [ + .clientSendsJrpc(""" + { + "id" : 1, + "jsonrpc" : "2.0", + "method" : "initialize", + "params" : { + "capabilities" : { + "roots" : { + "listChanged" : true + }, + "sampling" : { + + } }, - "sampling" : { - - } - }, - "clientInfo" : { - "name" : "test-client", - "version" : "1.0.0" - }, - "protocolVersion" : "2024-11-05" + "clientInfo" : { + "name" : "test-client", + "version" : "1.0.0" + }, + "protocolVersion" : "2024-11-05" + } } - } - """), - .response(""" - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "protocolVersion": "2024-11-05", - "capabilities": {}, - "serverInfo": { - "name": "ExampleServer", - "version": "1.0.0" + """), + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "serverInfo": { + "name": "ExampleServer", + "version": "1.0.0" + } } } + """), + .clientSendsJrpc(""" + { + "jsonrpc" : "2.0", + "method" : "notifications/initialized", + "params" : null } - """), - .request(""" - { - "jsonrpc" : "2.0", - "method" : "notifications/initialized", - "params" : null - } - """), - ], with: transport) + """), + ]) let clientCapabilities = await(client.connection as? MCPClientConnection)?.capabilities #expect(clientCapabilities?.roots?.listChanged == true) diff --git a/MCPClient/Tests/MCPClient/MCPClientTests.swift b/MCPClient/Tests/MCPClientTests.swift similarity index 79% rename from MCPClient/Tests/MCPClient/MCPClientTests.swift rename to MCPClient/Tests/MCPClientTests.swift index 42c774b..0955000 100644 --- a/MCPClient/Tests/MCPClient/MCPClientTests.swift +++ b/MCPClient/Tests/MCPClientTests.swift @@ -1,5 +1,5 @@ -import MCPShared +import MCPInterface import Testing @testable import MCPClient @@ -50,14 +50,17 @@ class MCPClientTest { let connection: MockMCPClientConnection func createMCPClient( - samplingRequestHandler: SamplingRequestHandler? = nil, - listRootRequestHandler: ListRootsRequestHandler? = nil, + samplingRequestHandler: CreateSamplingMessageRequest.Handler? = nil, + listRootRequestHandler: ListRootsRequest.Handler? = nil, connection: MCPClientConnectionInterface? = nil) async throws -> MCPClient { try await MCPClient( - samplingRequestHandler: samplingRequestHandler, - listRootRequestHandler: listRootRequestHandler, + capabilities: ClientCapabilityHandlers( + roots: listRootRequestHandler.map { .init( + info: .init(listChanged: true), + handler: $0) }, + sampling: samplingRequestHandler.map { .init(handler: $0) }), connection: connection ?? self.connection) } diff --git a/MCPClient/Tests/MCPConnection/CallToolTests.swift b/MCPClient/Tests/MCPConnection/CallToolTests.swift deleted file mode 100644 index 748cb9b..0000000 --- a/MCPClient/Tests/MCPConnection/CallToolTests.swift +++ /dev/null @@ -1,172 +0,0 @@ - -import JSONRPC -import MCPShared -import Testing -@testable import MCPClient - -// MARK: - MCPClientConnectionTestSuite.CallToolTests - -extension MCPClientConnectionTestSuite { - final class CallToolTests: MCPClientConnectionTest { - - // MARK: Internal - - @Test("call tool") - func test_callTool() async throws { - let weathers = try await assert(executing: { - try await self.sut.call( - toolName: self.tool.name, - arguments: .object([ - "location": .string("New York"), - ]), - progressToken: .string("toolCallId")) - .content - .map { $0.text } - }, sends: """ - { - "jsonrpc": "2.0", - "id": 1, - "method": "tools/call", - "params": { - "_meta" : { - "progressToken" : "toolCallId" - }, - "name": "get_weather", - "arguments": { - "location": "New York" - } - } - } - """, receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "content": [{ - "type": "text", - "text": "Current weather in New York:\\nTemperature: 72°F\\nConditions: Partly cloudy" - }] - } - } - """) - - #expect(weathers.map { $0?.text } == ["Current weather in New York:\nTemperature: 72°F\nConditions: Partly cloudy"]) - } - - @Test("protocol error") - func test_protocolError() async throws { - await assert( - executing: { - _ = try await self.sut.call(toolName: self.tool.name, arguments: .object([ - "location": .string("New York"), - ])) - }, - sends: """ - { - "jsonrpc": "2.0", - "id": 1, - "method": "tools/call", - "params": { - "name": "get_weather", - "arguments": { - "location": "New York" - } - } - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "error": { - "code": -32602, - "message": "Unknown tool: invalid_tool_name" - } - } - """) { error in - guard let error = error as? JSONRPCResponseError else { - Issue.record("Unexpected error type: \(error)") - return - } - - #expect(error.code == -32602) - #expect(error.message == "Unknown tool: invalid_tool_name") - #expect(error.data == nil) - } - } - - @Test("tool call error") - func test_toolCallError() async throws { - let response = try await assert( - executing: { - try await self.sut.call( - toolName: self.tool.name, - arguments: .object([ - "location": .string("New York"), - ])) - }, - sends: """ - { - "jsonrpc": "2.0", - "id": 1, - "method": "tools/call", - "params": { - "name": "get_weather", - "arguments": { - "location": "New York" - } - } - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "content": [{ - "type": "text", - "text": "Failed to fetch weather data: API rate limit exceeded" - }], - "isError": true - } - } - """) - #expect(response.isError == true) - #expect(response.content.map { $0.text?.text } == ["Failed to fetch weather data: API rate limit exceeded"]) -// { error in -// guard let error = error as? MCPClientError else { -// Issue.record("Unexpected error type: \(error)") -// return -// } -// -// switch error { -// case .toolCallError(let errors): -// #expect(errors.map { $0.text } == ["Failed to fetch weather data: API rate limit exceeded"]) -// default: -// Issue.record("Unexpected error type: \(error)") -// } -// } - } - - // MARK: Private - - private let tool = Tool( - name: "get_weather", - description: "Get current weather information for a location", - inputSchema: .array([])) - - } -} - -// MARK: - ToolArguments - -private struct ToolArguments: Encodable { - let location: String -} - -// MARK: - ToolResponse - -private struct ToolResponse: Decodable { - let type: String - let text: String -} diff --git a/MCPClient/Tests/MCPConnection/ClientInitializationTests.swift b/MCPClient/Tests/MCPConnection/ClientInitializationTests.swift deleted file mode 100644 index 5021f56..0000000 --- a/MCPClient/Tests/MCPConnection/ClientInitializationTests.swift +++ /dev/null @@ -1,159 +0,0 @@ - -import JSONRPC -import MCPShared -import SwiftTestingUtils -import Testing -@testable import MCPClient - -extension MCPClientConnectionTestSuite { - final class ClientInitializationTests: MCPClientConnectionTest { - - @Test("initialize connection") - func test_initializeConnection() async throws { - let initializationResult = try await assert( - executing: { - try await self.sut.initialize() - }, - sends: """ - { - "id" : 1, - "jsonrpc" : "2.0", - "method" : "initialize", - "params" : { - "capabilities" : { - "roots" : { - "listChanged" : true - }, - "sampling" : {} - }, - "clientInfo" : { - "name" : "TestClient", - "version" : "1.0.0" - }, - "protocolVersion" : "\(MCP.protocolVersion)" - } - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "protocolVersion": "2024-11-05", - "capabilities": { - "logging": {}, - "prompts": { - "listChanged": true - }, - "resources": { - "subscribe": true, - "listChanged": true - }, - "tools": { - "listChanged": true - } - }, - "serverInfo": { - "name": "ExampleServer", - "version": "1.0.0" - } - } - } - """) - #expect(initializationResult.serverInfo.name == "ExampleServer") - } - - @Test("initialize with error") - func test_initializeWithError() async throws { - await assert( - executing: { _ = try await self.sut.initialize() }, - sends: """ - { - "id" : 1, - "jsonrpc" : "2.0", - "method" : "initialize", - "params" : { - "capabilities" : { - "roots" : { - "listChanged" : true - }, - "sampling" : {} - }, - "clientInfo" : { - "name" : "TestClient", - "version" : "1.0.0" - }, - "protocolVersion" : "\(MCP.protocolVersion)" - } - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "error": { - "code": -32602, - "message": "Unsupported protocol version", - "data": { - "supported": ["2024-11-05"], - "requested": "1.0.0" - } - } - } - """) { error in - guard let notificationError = try? #require(error as? JSONRPCResponseError) else { - Issue.record("unexpected error type \(error)") - return - } - - #expect(notificationError.code == -32602) - #expect(notificationError.message == "Unsupported protocol version") - #expect(notificationError.data == [ - "supported": ["2024-11-05"], - "requested": "1.0.0", - ]) - } - } - - @Test("initialization acknowledgement") - func test_initializationAcknowledgement() async throws { - let notificationReceived = expectation(description: "notification received") - - transport.expect(messages: [ - { sendMessage in - sendMessage(""" - { - "jsonrpc" : "2.0", - "method" : "notifications/initialized", - "params" : null - } - """) - notificationReceived.fulfill() - }, - ]) - - try await sut.acknowledgeInitialization() - try await fulfillment(of: [notificationReceived]) - } - - @Test("deinitialization") - func test_deinitializationReleasesReferencedObjects() async throws { - // initialize the MCP connection. This will create a JRPC session. - try await test_initializeConnection() - - // Get pointers to values that we want to see dereferenced when MCPClientConnection is dereferenced - weak var weakTransport = transport - #expect(weakTransport != nil) - - // Replace the values referenced by this test class. - transport = MockTransport() - sut = try await MCPClientConnection( - info: sut.info, - capabilities: sut.capabilities, - transport: transport.dataChannel) - - // Verifies that the referenced objects are released. - #expect(weakTransport == nil) - } - } -} diff --git a/MCPClient/Tests/MCPConnection/CompletionTests.swift b/MCPClient/Tests/MCPConnection/CompletionTests.swift deleted file mode 100644 index 169a3db..0000000 --- a/MCPClient/Tests/MCPConnection/CompletionTests.swift +++ /dev/null @@ -1,51 +0,0 @@ - -import JSONRPC -import MCPShared -import Testing -@testable import MCPClient - -extension MCPClientConnectionTestSuite { - final class CompletionTests: MCPClientConnectionTest { - @Test("request completion") - func test_requestCompletion() async throws { - let resources = try await assert( - executing: { - try await self.sut.requestCompletion(CompleteRequest.Params( - ref: .prompt(PromptReference(name: "code_review")), - argument: .init(name: "language", value: "py"))) - }, - sends: """ - { - "jsonrpc": "2.0", - "id": 1, - "method": "completion/complete", - "params": { - "ref": { - "type": "ref/prompt", - "name": "code_review" - }, - "argument": { - "name": "language", - "value": "py" - } - } - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "completion": { - "values": ["python", "pytorch", "pyside"], - "total": 10, - "hasMore": true - } - } - } - """) - #expect(resources.completion.values == ["python", "pytorch", "pyside"]) - #expect(resources.completion.hasMore == true) - } - } -} diff --git a/MCPClient/Tests/MCPConnection/GetPromptTests.swift b/MCPClient/Tests/MCPConnection/GetPromptTests.swift deleted file mode 100644 index 79ca134..0000000 --- a/MCPClient/Tests/MCPConnection/GetPromptTests.swift +++ /dev/null @@ -1,160 +0,0 @@ - -import JSONRPC -import Testing -@testable import MCPClient - -extension MCPClientConnectionTestSuite { - final class GetPromptTests: MCPClientConnectionTest { - - @Test("get one prompt") - func test_getOnePrompt() async throws { - let prompts = try await assert( - executing: { - try await self.sut.getPrompt(.init(name: "code_review", arguments: .object([ - "code": .string("def hello():\n print('world')"), - ]))) - }, - sends: """ - { - "jsonrpc": "2.0", - "id": 1, - "method": "prompts/get", - "params": { - "name": "code_review", - "arguments": { - "code": "def hello():\\n print('world')" - } - } - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "description": "Code review prompt", - "messages": [ - { - "role": "user", - "content": { - "type": "text", - "text": "Please review this Python code:\\ndef hello():\\n print('world')" - } - } - ] - } - } - """) - #expect( - prompts.messages - .map { $0.content.text?.text } == ["Please review this Python code:\ndef hello():\n print('world')"]) - } - - @Test("get prompts of different types") - func test_getPromptsOfDifferentTypes() async throws { - let prompts = try await assert( - executing: { - try await self.sut.getPrompt(.init(name: "code_review", arguments: .object([ - "code": .string("def hello():\n print('world')"), - ]))) - }, - sends: """ - { - "jsonrpc": "2.0", - "id": 1, - "method": "prompts/get", - "params": { - "name": "code_review", - "arguments": { - "code": "def hello():\\n print('world')" - } - } - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "description": "Code review prompt", - "messages": [ - { - "role": "user", - "content": { - "type": "text", - "text": "Please review this Python code:\\ndef hello():\\n print('world')" - } - }, - { - "role": "user", - "content": { - "type": "image", - "data": "base64-encoded-image-data", - "mimeType": "image/png" - } - }, - { - "role": "user", - "content": { - "type": "resource", - "resource": { - "uri": "resource://example", - "mimeType": "text/plain", - "text": "Resource content" - } - } - } - ] - } - } - """) - #expect( - prompts.messages.map { $0.content.text?.text } == - ["Please review this Python code:\ndef hello():\n print('world')", nil, nil]) - #expect(prompts.messages.map { $0.content.image?.data } == [nil, "base64-encoded-image-data", nil]) - #expect(prompts.messages.map { $0.content.embeddedResource?.resource.text?.text } == [nil, nil, "Resource content"]) - } - - @Test("error when getting prompt") - func test_errorWhenGettingPrompt() async throws { - await assert( - executing: { try await self.sut.getPrompt(.init(name: "non_existent_code_review")) }, - sends: """ - { - "jsonrpc": "2.0", - "id": 1, - "method": "prompts/get", - "params": { - "name": "non_existent_code_review" - } - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "error": { - "code": -32002, - "message": "Prompt not found", - "data": { - "name": "non_existent_code_review" - } - } - } - """, - andFailsWith: { error in - guard let error = error as? JSONRPCResponseError else { - Issue.record("Unexpected error type: \(error)") - return - } - - #expect(error.code == -32002) - #expect(error.message == "Prompt not found") - #expect(error.data == .hash([ - "name": .string("non_existent_code_review"), - ])) - }) - } - - } -} diff --git a/MCPClient/Tests/MCPConnection/HandleServerRequestTests.swift b/MCPClient/Tests/MCPConnection/HandleServerRequestTests.swift deleted file mode 100644 index 4bb55ae..0000000 --- a/MCPClient/Tests/MCPConnection/HandleServerRequestTests.swift +++ /dev/null @@ -1,117 +0,0 @@ - -import JSONRPC -import MCPShared -import Testing -@testable import MCPClient - -extension MCPClientConnectionTestSuite { - final class HandleServerRequestTests: MCPClientConnectionTest { - - @Test("list roots") - func test_listRoots() async throws { - try await assert( - executing: { - // Handle the first incoming request - for await(request, completion) in await self.sut.requestsToHandle { - switch request { - case .listRoots(let params): - #expect(params == nil) - default: - Issue.record("Unexpected server request: \(request)") - } - completion(.success(ListRootsResult(roots: [.init(uri: "file:///home/user/projects/myproject", name: "My Project")]))) - break - } - }, triggers: [ - .response(""" - { - "jsonrpc": "2.0", - "id": 1, - "method": "roots/list" - } - """), - .request(""" - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "roots": [ - { - "uri": "file:///home/user/projects/myproject", - "name": "My Project" - } - ] - } - } - """), - ]) - } - - @Test("create sampling message") - func test_createMessage() async throws { - try await assert( - executing: { - // Handle the first incoming request - for await(request, completion) in await self.sut.requestsToHandle { - switch request { - case .createMessage(let params): - #expect(params.maxTokens == 100) - default: - Issue.record("Unexpected server request: \(request)") - } - completion(.success(CreateMessageResult( - role: .assistant, - content: .text(.init(text: "The capital of France is Paris.")), - model: "claude-3-sonnet-20240307", - stopReason: "endTurn"))) - break - } - }, triggers: [ - .response(""" - { - "jsonrpc": "2.0", - "id": 1, - "method": "sampling/createMessage", - "params": { - "messages": [ - { - "role": "user", - "content": { - "type": "text", - "text": "What is the capital of France?" - } - } - ], - "modelPreferences": { - "hints": [ - { - "name": "claude-3-sonnet" - } - ], - "intelligencePriority": 0.8, - "speedPriority": 0.5 - }, - "systemPrompt": "You are a helpful assistant.", - "maxTokens": 100 - } - } - """), - .request(""" - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "role": "assistant", - "content": { - "type": "text", - "text": "The capital of France is Paris." - }, - "model": "claude-3-sonnet-20240307", - "stopReason": "endTurn" - } - } - """), - ]) - } - } -} diff --git a/MCPClient/Tests/MCPConnection/ListPromptTests.swift b/MCPClient/Tests/MCPConnection/ListPromptTests.swift deleted file mode 100644 index 9f69afd..0000000 --- a/MCPClient/Tests/MCPConnection/ListPromptTests.swift +++ /dev/null @@ -1,141 +0,0 @@ - -import JSONRPC -import SwiftTestingUtils -import Testing -@testable import MCPClient - -extension MCPClientConnectionTestSuite { - final class ListPromptsTests: MCPClientConnectionTest { - - @Test("list prompts") - func test_listPrompts() async throws { - let prompts = try await assert( - executing: { - try await self.sut.listPrompts() - }, - sends: """ - { - "jsonrpc": "2.0", - "id": 1, - "method": "prompts/list", - "params": {} - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "prompts": [ - { - "name": "code_review", - "description": "Asks the LLM to analyze code quality and suggest improvements", - "arguments": [ - { - "name": "code", - "description": "The code to review", - "required": true - } - ] - } - ] - } - } - """) - #expect(prompts.map { $0.name } == ["code_review"]) - } - - @Test("list prompts with pagination") - func test_listPrompts_withPagination() async throws { - let prompts = try await assert( - executing: { try await self.sut.listPrompts() }, - triggers: [ - .request(""" - { - "id" : 1, - "jsonrpc" : "2.0", - "method" : "prompts/list", - "params" : {} - } - """), - .response(""" - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "prompts": [ - { - "name": "code_review", - "description": "Asks the LLM to analyze code quality and suggest improvements", - "arguments": [ - { - "name": "code", - "description": "The code to review", - "required": true - } - ] - } - ], - "nextCursor": "next-page-cursor" - } - } - """), - .request(""" - { - "id" : 2, - "jsonrpc" : "2.0", - "method" : "prompts/list", - "params" : { - "cursor": "next-page-cursor" - } - } - """), - .response(""" - { - "jsonrpc": "2.0", - "id": 2, - "result": { - "prompts": [ - { - "name": "test_code", - "description": "Asks the LLM to write a unit test for the code", - "arguments": [ - { - "name": "code", - "description": "The code to test", - "required": true - } - ] - } - ] - } - } - """), - ]) - #expect(prompts.map { $0.name } == ["code_review", "test_code"]) - } - - @Test("receiving prompts list changed notification") - func test_receivingPromptsListChangedNotification() async throws { - let notificationReceived = expectation(description: "Notification received") - Task { - for await notification in await sut.notifications { - switch notification { - case .promptListChanged: - notificationReceived.fulfill() - default: - Issue.record("Unexpected notification: \(notification)") - } - } - } - - transport.receive(message: """ - { - "jsonrpc": "2.0", - "method": "notifications/prompts/list_changed" - } - """) - try await fulfillment(of: [notificationReceived]) - } - } -} diff --git a/MCPClient/Tests/MCPConnection/ListResourceTemplates.swift b/MCPClient/Tests/MCPConnection/ListResourceTemplates.swift deleted file mode 100644 index dac74ec..0000000 --- a/MCPClient/Tests/MCPConnection/ListResourceTemplates.swift +++ /dev/null @@ -1,103 +0,0 @@ - -import JSONRPC -import Testing -@testable import MCPClient - -extension MCPClientConnectionTestSuite { - final class ListResourceTemplatesTests: MCPClientConnectionTest { - - @Test("list resource templates") - func test_listResourceTemplates() async throws { - let resources = try await assert( - executing: { - try await self.sut.listResourceTemplates() - }, - sends: """ - { - "jsonrpc": "2.0", - "id": 1, - "method": "resources/templates/list", - "params": {} - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "resourceTemplates": [ - { - "uriTemplate": "file:///{path}", - "name": "Project Files", - "description": "Access files in the project directory", - "mimeType": "application/octet-stream" - } - ] - } - } - """) - #expect(resources.map { $0.uriTemplate } == ["file:///{path}"]) - } - - @Test("list resource templates with pagination") - func test_listResourceTemplates_withPagination() async throws { - let resources = try await assert( - executing: { try await self.sut.listResourceTemplates() }, - triggers: [ - .request(""" - { - "jsonrpc": "2.0", - "id": 1, - "method": "resources/templates/list", - "params": {} - } - """), - .response(""" - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "resourceTemplates": [ - { - "uriTemplate": "file:///{path}", - "name": "Project Files", - "description": "Access files in the project directory", - "mimeType": "application/octet-stream" - } - ], - "nextCursor": "next-page-cursor" - } - } - """), - .request(""" - { - "id" : 2, - "jsonrpc" : "2.0", - "method" : "resources/templates/list", - "params" : { - "cursor": "next-page-cursor" - } - } - """), - .response(""" - { - "jsonrpc": "2.0", - "id": 2, - "result": { - "resourceTemplates": [ - { - "uriTemplate": "images:///{path}", - "name": "Project Images", - "description": "Access images in the project directory", - "mimeType": "image/jpeg" - } - ] - } - } - """), - ]) - #expect(resources.map { $0.uriTemplate } == ["file:///{path}", "images:///{path}"]) - } - - } -} diff --git a/MCPClient/Tests/MCPConnection/ListResourcesTest.swift b/MCPClient/Tests/MCPConnection/ListResourcesTest.swift deleted file mode 100644 index 6dedb7c..0000000 --- a/MCPClient/Tests/MCPConnection/ListResourcesTest.swift +++ /dev/null @@ -1,126 +0,0 @@ - -import JSONRPC -import SwiftTestingUtils -import Testing -@testable import MCPClient - -extension MCPClientConnectionTestSuite { - final class ListResourcesTests: MCPClientConnectionTest { - - @Test("list resources") - func test_listResources() async throws { - let resources = try await assert( - executing: { - try await self.sut.listResources() - }, - sends: """ - { - "jsonrpc": "2.0", - "id": 1, - "method": "resources/list", - "params": {} - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "resources": [ - { - "uri": "file:///project/src/main.rs", - "name": "main.rs", - "description": "Primary application entry point", - "mimeType": "text/x-rust" - } - ] - } - } - """) - #expect(resources.map { $0.name } == ["main.rs"]) - } - - @Test("list resources with pagination") - func test_listResources_withPagination() async throws { - let resources = try await assert( - executing: { try await self.sut.listResources() }, - triggers: [ - .request(""" - { - "id" : 1, - "jsonrpc" : "2.0", - "method" : "resources/list", - "params" : {} - } - """), - .response(""" - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "resources": [ - { - "uri": "file:///project/src/main.rs", - "name": "main.rs", - "description": "Primary application entry point", - "mimeType": "text/x-rust" - } - ], - "nextCursor": "next-page-cursor" - } - } - """), - .request(""" - { - "id" : 2, - "jsonrpc" : "2.0", - "method" : "resources/list", - "params" : { - "cursor": "next-page-cursor" - } - } - """), - .response(""" - { - "jsonrpc": "2.0", - "id": 2, - "result": { - "resources": [ - { - "uri": "file:///project/src/utils.rs", - "name": "utils.rs", - "description": "Some utils functions application entry point", - "mimeType": "text/x-rust" - } - ] - } - } - """), - ]) - #expect(resources.map { $0.name } == ["main.rs", "utils.rs"]) - } - - @Test("receiving resources list changed notification") - func test_receivingResourcesListChangedNotification() async throws { - let notificationReceived = expectation(description: "Notification received") - Task { - for await notification in await sut.notifications { - switch notification { - case .resourceListChanged: - notificationReceived.fulfill() - default: - Issue.record("Unexpected notification: \(notification)") - } - } - } - - transport.receive(message: """ - { - "jsonrpc": "2.0", - "method": "notifications/resources/list_changed" - } - """) - try await fulfillment(of: [notificationReceived]) - } - } -} diff --git a/MCPClient/Tests/MCPConnection/ListToolsTests.swift b/MCPClient/Tests/MCPConnection/ListToolsTests.swift deleted file mode 100644 index 8cc20e0..0000000 --- a/MCPClient/Tests/MCPConnection/ListToolsTests.swift +++ /dev/null @@ -1,153 +0,0 @@ - -import JSONRPC -import SwiftTestingUtils -import Testing -@testable import MCPClient - -extension MCPClientConnectionTestSuite { - final class ListToolsTests: MCPClientConnectionTest { - - @Test("list tools") - func test_listTools() async throws { - let tools = try await assert( - executing: { - try await self.sut.listTools() - }, - sends: """ - { - "id" : 1, - "jsonrpc" : "2.0", - "method" : "tools/list", - "params" : {} - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "tools": [ - { - "name": "get_weather", - "description": "Get current weather information for a location", - "inputSchema": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "City name or zip code" - } - }, - "required": ["location"] - } - } - ], - "nextCursor": null - } - } - """) - #expect(tools.map { $0.name } == ["get_weather"]) - } - - @Test("list tools with pagination") - func test_listTools_withPagination() async throws { - let tools = try await assert( - executing: { try await self.sut.listTools() }, - triggers: [ - .request(""" - { - "id" : 1, - "jsonrpc" : "2.0", - "method" : "tools/list", - "params" : {} - } - """), - .response(""" - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "tools": [ - { - "name": "get_weather", - "description": "Get current weather information for a location", - "inputSchema": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "City name or zip code" - } - }, - "required": ["location"] - } - } - ], - "nextCursor": "next-page-cursor" - } - } - """), - .request(""" - { - "id" : 2, - "jsonrpc" : "2.0", - "method" : "tools/list", - "params" : { - "cursor": "next-page-cursor" - } - } - """), - .response(""" - { - "jsonrpc": "2.0", - "id": 2, - "result": { - "tools": [ - { - "name": "get_time", - "description": "Get current time information for a location", - "inputSchema": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "City name or zip code" - } - }, - "required": ["location"] - } - } - ], - "nextCursor": null - } - } - """), - ]) - - #expect(tools.map { $0.name } == ["get_weather", "get_time"]) - } - - @Test("receiving list tool changed notification") - func test_receivingListToolChangedNotification() async throws { - let notificationReceived = expectation(description: "Notification received") - Task { - for await notification in await sut.notifications { - switch notification { - case .toolListChanged: - notificationReceived.fulfill() - default: - Issue.record("Unexpected notification: \(notification)") - } - } - } - - transport.receive(message: """ - { - "jsonrpc": "2.0", - "method": "notifications/tools/list_changed" - } - """) - try await fulfillment(of: [notificationReceived]) - } - } -} diff --git a/MCPClient/Tests/MCPConnection/LoggingTests.swift b/MCPClient/Tests/MCPConnection/LoggingTests.swift deleted file mode 100644 index 71c9037..0000000 --- a/MCPClient/Tests/MCPConnection/LoggingTests.swift +++ /dev/null @@ -1,97 +0,0 @@ - -import JSONRPC -import MCPShared -import SwiftTestingUtils -import Testing -@testable import MCPClient - -extension MCPClientConnectionTestSuite { - final class LoggingTests: MCPClientConnectionTest { - @Test("setting log level") - func test_settingLogLevel() async throws { - _ = try await assert( - executing: { - try await self.sut.setLogLevel(SetLevelRequest.Params(level: .debug)) - }, - sends: """ - { - "id" : 1, - "jsonrpc" : "2.0", - "method" : "logging/setLevel", - "params" : { - "level" : "debug" - } - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": {} - } - """) - } - - @Test("send a log") - func test_sendALog() async throws { - try await assert( - executing: { - try await self.sut.log(LoggingMessageNotification.Params(level: .debug, data: .string("Up and running!"))) - }, - triggers: [.request(""" - { - "jsonrpc" : "2.0", - "method" : "notifications/message", - "params" : { - "data" : "Up and running!", - "level" : "debug" - } - } - """)]) - } - - @Test("receive a server log notification") - func test_receivesServerLogNotification() async throws { - let notificationReceived = expectation(description: "Notification received") - Task { - for await notification in await sut.notifications { - switch notification { - case .loggingMessage(let message): - #expect(message.level == .error) - #expect(message.data == .object([ - "error": .string("Connection failed"), - "details": .object([ - "host": .string("localhost"), - "port": .number(5432), - ]), - ])) - notificationReceived.fulfill() - - default: - Issue.record("Unexpected notification: \(notification)") - } - } - } - - transport.receive(message: """ - { - "jsonrpc": "2.0", - "method": "notifications/message", - "params": { - "level": "error", - "logger": "database", - "data": { - "error": "Connection failed", - "details": { - "host": "localhost", - "port": 5432 - } - } - } - } - """) - try await fulfillment(of: [notificationReceived]) - } - - } -} diff --git a/MCPClient/Tests/MCPConnection/MCPConnectionTests.swift b/MCPClient/Tests/MCPConnection/MCPConnectionTests.swift deleted file mode 100644 index 8798aee..0000000 --- a/MCPClient/Tests/MCPConnection/MCPConnectionTests.swift +++ /dev/null @@ -1,169 +0,0 @@ - -import MCPShared -import SwiftTestingUtils -import Testing -@testable import MCPClient - -// MARK: - MCPClientConnectionTestSuite - -/// All the tests about `MCPClientConnection` -@Suite("MCP Connection") -class MCPClientConnectionTestSuite { } - -// MARK: - MCPClientConnectionTest - -/// A parent test class that provides a few util functions to assert that the interactions with the transport are as expected. -class MCPClientConnectionTest { - - // MARK: Lifecycle - - init() { - transport = MockTransport() - clientCapabilities = ClientCapabilities( - roots: .init(listChanged: true), - sampling: .init()) - sut = try! MCPClientConnection( - info: .init(name: "TestClient", version: "1.0.0"), - capabilities: clientCapabilities, - transport: transport.dataChannel) - } - - // MARK: Internal - - var transport: MockTransport - let clientCapabilities: ClientCapabilities - var sut: MCPClientConnection - - /// Asserts that the given task sends the expected requests and receives the expected responses. - /// - Parameters: - /// - task: The task to execute. - /// - messages: The sequence of messages relevant to the task. All responses are dequeued as soon as possible, and each request is awaited for until continuing to dequeue messages. - /// - transport: The transport to use. - static func assert( - executing task: @escaping () async throws -> Result, - triggers messages: [Message], - with transport: MockTransport) - async throws -> Result - { - var result: Result? = nil - var err: Error? = nil - - /// The next message that the system is expected to send. - var nextMessageToSent: (exp: SwiftTestingUtils.Expectation, message: String)? - - transport.sendMessage = { data in - if let (exp, message) = nextMessageToSent { - assertEqual(received: data, expected: message) - exp.fulfill() - } else { - Issue.record("Unexpected message sent: \(String(data: data, encoding: .utf8) ?? "Invalid data")") - } - } - - var i = 0 - let prepareNextExpectedMessage = { - if let request = messages[i..( - executing task: @escaping () async throws -> Result, - sends request: String, - receives response: String) - async throws -> Result - { - try await assert(executing: task, triggers: [ - .request(request), - .response(response), - ]) - } - - /// Asserts that the given task sends the expected requests and receives the expected responses. - /// - Parameters: - /// - task: The task to execute. - /// - messages: The sequence of messages relevant to the task. All responses are dequeued as soon as possible, and each request is awaited for until continuing to dequeue messages. - func assert( - executing task: @escaping () async throws -> Result, - triggers messages: [Message]) - async throws -> Result - { - try await Self.assert(executing: task, triggers: messages, with: transport) - } - - func assert( - executing task: @escaping () async throws -> Result, - sends request: String, - receives response: String, - andFailsWith errorHandler: (Error) -> Void) - async - { - do { - _ = try await assert(executing: task, sends: request, receives: response) - Issue.record("Expected the task to fail") - } catch { - // Expected - errorHandler(error) - } - } - -} diff --git a/MCPClient/Tests/MCPConnection/MockTransport.swift b/MCPClient/Tests/MCPConnection/MockTransport.swift deleted file mode 100644 index 0d91df4..0000000 --- a/MCPClient/Tests/MCPConnection/MockTransport.swift +++ /dev/null @@ -1,39 +0,0 @@ - -import Foundation -import JSONRPC -@testable import MCPClient - -final class MockTransport { - - // MARK: Lifecycle - - init() { - let dataSequence = AsyncStream() { continuation in - self.continuation = continuation - } - - dataChannel = DataChannel( - writeHandler: { [weak self] data in self?.handleWrite(data: data) }, - dataSequence: dataSequence) - } - - // MARK: Internal - - private(set) var dataChannel: DataChannel = .noop - - var sendMessage: (Data) -> Void = { _ in } - - func receive(message: String) { - let data = Data(message.utf8) - continuation?.yield(data) - } - - // MARK: Private - - private var continuation: AsyncStream.Continuation? - - private func handleWrite(data: Data) { - sendMessage(data) - } - -} diff --git a/MCPClient/Tests/MCPConnection/PingTests.swift b/MCPClient/Tests/MCPConnection/PingTests.swift deleted file mode 100644 index 268eb92..0000000 --- a/MCPClient/Tests/MCPConnection/PingTests.swift +++ /dev/null @@ -1,52 +0,0 @@ - -import JSONRPC -import Testing -@testable import MCPClient - -extension MCPClientConnectionTestSuite { - final class PingTests: MCPClientConnectionTest { - - @Test("sending ping") - func sendingPing() async throws { - try await assert( - executing: { - try await self.sut.ping() - }, - sends: """ - { - "id" : 1, - "jsonrpc" : "2.0", - "method" : "ping", - "params" : null - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": {} - } - """) - } - - @Test("receiving ping") - func receivingPing() async throws { - try await assert( - receiving: """ - { - "jsonrpc" : "2.0", - "id" : 1, - "method" : "ping" - } - """, - respondsWith: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": {} - } - """) - } - - } -} diff --git a/MCPClient/Tests/MCPConnection/ReadResourceTests.swift b/MCPClient/Tests/MCPConnection/ReadResourceTests.swift deleted file mode 100644 index 9097aac..0000000 --- a/MCPClient/Tests/MCPConnection/ReadResourceTests.swift +++ /dev/null @@ -1,203 +0,0 @@ - -import JSONRPC -import SwiftTestingUtils -import Testing -@testable import MCPClient - -extension MCPClientConnectionTestSuite { - final class ReadResourceTests: MCPClientConnectionTest { - - @Test("read one resource") - func test_readOneResource() async throws { - let resources = try await assert( - executing: { - try await self.sut.readResource(.init(uri: "file:///project/src/main.rs")) - }, - sends: """ - { - "jsonrpc": "2.0", - "id": 1, - "method": "resources/read", - "params": { - "uri": "file:///project/src/main.rs" - } - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "contents": [ - { - "uri": "file:///project/src/main.rs", - "mimeType": "text/x-rust", - "text": "fn main() {\\n println!(\\"Hello world!\\");\\n}" - } - ] - } - } - """) - #expect(resources.contents.map { $0.text?.text } == ["fn main() {\n println!(\"Hello world!\");\n}"]) - } - - @Test("read resources of different types") - func test_readResourcesOfDifferentTypes() async throws { - let resources = try await assert( - executing: { - try await self.sut.readResource(.init(uri: "file:///project/src/*")) - }, - sends: """ - { - "jsonrpc": "2.0", - "id": 1, - "method": "resources/read", - "params": { - "uri": "file:///project/src/*" - } - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "result": { - "contents": [ - { - "uri": "file:///project/src/main.rs", - "mimeType": "text/x-rust", - "text": "fn main() {\\n println!(\\"Hello world!\\");\\n}" - }, - { - "uri": "file:///project/src/main.rs", - "mimeType": "image/jpeg", - "blob": "base64-encoded-image-data" - } - ] - } - } - """) - #expect(resources.contents.map { $0.text?.text } == ["fn main() {\n println!(\"Hello world!\");\n}", nil]) - #expect(resources.contents.map { $0.blob?.mimeType } == [nil, "image/jpeg"]) - } - - @Test("error when reading resource") - func test_errorWhenReadingResource() async throws { - await assert( - executing: { try await self.sut.readResource(.init(uri: "file:///nonexistent.txt")) }, - sends: """ - { - "jsonrpc": "2.0", - "id": 1, - "method": "resources/read", - "params": { - "uri": "file:///nonexistent.txt" - } - } - """, - receives: """ - { - "jsonrpc": "2.0", - "id": 1, - "error": { - "code": -32002, - "message": "Resource not found", - "data": { - "uri": "file:///nonexistent.txt" - } - } - } - """, - andFailsWith: { error in - guard let error = error as? JSONRPCResponseError else { - Issue.record("Unexpected error type: \(error)") - return - } - - #expect(error.code == -32002) - #expect(error.message == "Resource not found") - #expect(error.data == .hash([ - "uri": .string("file:///nonexistent.txt"), - ])) - }) - } - - @Test("subscribing to resource updates") - func test_subscribingToResourceUpdates() async throws { - try await assert( - executing: { try await self.sut.subscribeToUpdateToResource(.init(uri: "file:///project/src/main.rs")) }, - triggers: [ - .request(""" - { - "id" : 1, - "jsonrpc" : "2.0", - "method" : "resources/subscribe", - "params" : { - "uri" : "file:///project/src/main.rs" - } - } - """), - .response(""" - { - "jsonrpc": "2.0", - "id": 1, - "result": {} - } - """), - ]) - } - - @Test("unsubscribing to resource updates") - func test_unsubscribingToResourceUpdates() async throws { - try await assert( - executing: { try await self.sut.unsubscribeToUpdateToResource(.init(uri: "file:///project/src/main.rs")) }, - triggers: [ - .request(""" - { - "id" : 1, - "jsonrpc" : "2.0", - "method" : "resources/unsubscribe", - "params" : { - "uri" : "file:///project/src/main.rs" - } - } - """), - .response(""" - { - "jsonrpc": "2.0", - "id": 1, - "result": {} - } - """), - ]) - } - - @Test("receiving resource update notification") - func test_receivingResourceUpdateNotification() async throws { - let notificationReceived = expectation(description: "Notification received") - Task { - for await notification in await self.sut.notifications { - switch notification { - case .resourceUpdated(let updateNotification): - #expect(updateNotification.uri == "file:///project/src/main.rs") - notificationReceived.fulfill() - - default: - Issue.record("Unexpected notification: \(notification)") - } - } - } - transport.receive(message: """ - { - "jsonrpc": "2.0", - "method": "notifications/resources/updated", - "params": { - "uri": "file:///project/src/main.rs" - } - } - """) - try await fulfillment(of: [notificationReceived]) - } - - } -} diff --git a/MCPClient/Tests/MCPConnection/TestUtils.swift b/MCPClient/Tests/MCPConnection/TestUtils.swift deleted file mode 100644 index 30e2754..0000000 --- a/MCPClient/Tests/MCPConnection/TestUtils.swift +++ /dev/null @@ -1,78 +0,0 @@ - -import Foundation -import Testing -@testable import MCPClient - -/// Asserts that the received JSON is equal to the expected JSON, allowing for any order of keys or spacing. -func assertEqual(received jsonData: Data, expected: String) { - do { - let received = try JSONSerialization.jsonObject(with: jsonData) - let receivedPrettyPrinted = try JSONSerialization.data(withJSONObject: received, options: [.sortedKeys, .prettyPrinted]) - - let expected = try JSONSerialization.jsonObject(with: expected.data(using: .utf8)!) - let expectedPrettyPrinted = try JSONSerialization.data(withJSONObject: expected, options: [.sortedKeys, .prettyPrinted]) - - #expect(String(data: receivedPrettyPrinted, encoding: .utf8)! == String(data: expectedPrettyPrinted, encoding: .utf8)!) - } catch { - Issue.record("Failed to compare JSON: \(error)") - } -} - -extension MockTransport { - - /// Expects the given messages to be sent. - /// Examples: - /// expect([ - /// "{ \"jsonrpc\": \"2.0\", \"id\": 1, \"result\": null }", - /// ]) - func expect(messages: [String]) { - expect(messages: messages.map { m in { $0(m) } }) - } - - /// Expects the given messages to be sent, calling the corresponding closure when needed. - /// Examples: - /// expect([ - /// { - /// firstMessageReceived.fulfill() - /// return "{ \"jsonrpc\": \"2.0\", \"id\": 1, \"result\": null }" - /// }, - /// ]) - func expect(messages: [((String) -> Void) -> Void]) { - var messagesCount = 0 - - sendMessage = { message in - defer { messagesCount += 1 } - guard messagesCount < messages.count else { - Issue.record(""" - Too many messages sent. Expected \(messages.count). Last message received: - \(String(data: message, encoding: .utf8) ?? "Invalid data") - """) - return - } - messages[messagesCount]() { expected in - assertEqual(received: message, expected: expected) - } - } - } -} - -// MARK: - TestError - -enum TestError: Error { - case expectationUnfulfilled - case internalError -} - -// MARK: - Message - -enum Message { - case request(_ value: String) - case response(_ value: String) - - var request: String? { - if case .request(let value) = self { - return value - } - return nil - } -} diff --git a/MCPClient/Tests/MCPClient/ServerRequests.swift b/MCPClient/Tests/ServerRequests.swift similarity index 93% rename from MCPClient/Tests/MCPClient/ServerRequests.swift rename to MCPClient/Tests/ServerRequests.swift index c25b99d..efaa211 100644 --- a/MCPClient/Tests/MCPClient/ServerRequests.swift +++ b/MCPClient/Tests/ServerRequests.swift @@ -1,6 +1,6 @@ import JSONRPC -import MCPShared +import MCPInterface import SwiftTestingUtils import Testing @testable import MCPClient @@ -13,7 +13,7 @@ extension MCPClientTestSuite { func test_listRootsIsSuccessfulWhenHandled() async throws { let expectation = expectation(description: "The request is responded to") let root = Root(uri: "//root", name: "root") - let listRootsRequestHandler: ListRootsRequestHandler = { _ in + let listRootsRequestHandler: ListRootsRequest.Handler = { _ in .init(roots: [root]) } @@ -48,7 +48,7 @@ extension MCPClientTestSuite { Issue.record("unexpected success") case .failure(let error): - #expect(error.message == "Listing roots is not supported by this client") + #expect(error.message == "Listing roots is not supported by this server") } expectation.fulfill() })) diff --git a/MCPInterface/Sources/Interfaces.swift b/MCPInterface/Sources/Interfaces.swift new file mode 100644 index 0000000..78dc9f6 --- /dev/null +++ b/MCPInterface/Sources/Interfaces.swift @@ -0,0 +1,66 @@ +import JSONRPC +import MemberwiseInit + +public typealias Transport = DataChannel + +extension Transport { + public static var noop: Transport { + .init(writeHandler: { _ in }, dataSequence: DataSequence { _ in }) + } +} + +public typealias AnyJRPCResponse = Swift.Result + +// MARK: - CapabilityHandler + +/// Describes a capability of a client/server (see `ClientCapabilities` and `ServerCapabilities`), as well as how it is handled. +@MemberwiseInit(.public, _optionalsDefaultNil: true) +public struct CapabilityHandler { + public let info: Info + public let handler: Handler +} + +extension CapabilityHandler where Info == EmptyObject { + public init(handler: Handler) { + self.init(info: .init(), handler: handler) + } +} + +// MARK: - CapabilityStatus + +/// Describes whether a given capability is supported by the other peer, +/// and if so provide details about which functionalities (subscription, listing changes) are supported. +public enum CapabilityStatus: Equatable { + case supported(_ capability: Capability) + case notSupported +} + +extension CapabilityStatus { + public var capability: Capability? { + switch self { + case .supported(let capability): + capability + case .notSupported: + nil + } + } + + public func get() throws -> Capability { + switch self { + case .supported(let capability): + return capability + case .notSupported: + throw MCPError.notSupported + } + } +} + +// MARK: - MCPError + +public enum MCPError: Error { + case notSupported +} + +public typealias HandleServerRequest = (ServerRequest, (AnyJRPCResponse) -> Void) + +public typealias HandleClientRequest = (ClientRequest, (AnyJRPCResponse) -> Void) diff --git a/MCPInterface/Sources/JRPC+helpers.swift b/MCPInterface/Sources/JRPC+helpers.swift new file mode 100644 index 0000000..9e2b784 --- /dev/null +++ b/MCPInterface/Sources/JRPC+helpers.swift @@ -0,0 +1,33 @@ +import JSONRPC + +extension JSONRPCSession { + public func send(_ request: Req) async throws -> Req.Result { + let response: JSONRPCResponse = try await sendRequest(request.params, method: request.method) + return try response.content.get() + } + + public func send(_ notification: some MCPInterface.Notification) async throws { + try await sendNotification(notification.params, method: notification.method) + } + + public func send( + _ params: Req.Params, + getResults: (Req.Result) -> [Result], + req _: Req.Type = Req.self) + async throws -> [Result] + { + var cursor: String? = nil + var results = [Result]() + + while true { + let request = Req(params: Req.Params.updating(cursor: cursor, from: params)) + let response = try await send(request) + results.append(contentsOf: getResults(response)) + cursor = response.nextCursor + if cursor == nil { + return results + } + } + } + +} diff --git a/MCPInterface/Sources/MCPConnection.swift b/MCPInterface/Sources/MCPConnection.swift new file mode 100644 index 0000000..2379c74 --- /dev/null +++ b/MCPInterface/Sources/MCPConnection.swift @@ -0,0 +1,110 @@ +import Foundation +import JSONRPC + +public typealias HandleRequest = (Request, (AnyJRPCResponse) -> Void) + +// MARK: - MCPConnection + +/// Note: this class is not thread safe, and should be used in a thread safe context (like within an actor). +package class MCPConnection { + + // MARK: Lifecycle + + public init( + transport: Transport) + throws + { + jrpcSession = JSONRPCSession(channel: transport) + + var sendNotificationToStream: (Notification) -> Void = { _ in } + notifications = AsyncStream() { continuation in + sendNotificationToStream = { continuation.yield($0) } + } + self.sendNotificationToStream = sendNotificationToStream + + // A bit hard to read... When a request is received (sent to `askForRequestToBeHandle`), we yield it to the stream + // where we expect someone to be listening and handling the request. The handler then calls the completion `requestContinuation` + // which will be sent back as an async response to `askForRequestToBeHandle`. + var askForRequestToBeHandle: ((Request) async -> AnyJRPCResponse)? = nil + requestsToHandle = AsyncStream>() { streamContinuation in + askForRequestToBeHandle = { request in + await withCheckedContinuation { (requestContinuation: CheckedContinuation) in + streamContinuation.yield((request, { response in + requestContinuation.resume(returning: response) + })) + } + } + } + self.askForRequestToBeHandle = askForRequestToBeHandle + + Task { await listenToIncomingMessages() } + } + + // MARK: Public + + public private(set) var notifications: AsyncStream + public private(set) var requestsToHandle: AsyncStream> + + // MARK: Package + + package let jrpcSession: JSONRPCSession + + // MARK: Private + + private var sendNotificationToStream: ((Notification) -> Void) = { _ in } + + private var askForRequestToBeHandle: ((Request) async -> AnyJRPCResponse)? = nil + + private var eventHandlers = [String: (JSONRPCEvent) -> Void]() + + private func listenToIncomingMessages() async { + let events = await jrpcSession.eventSequence + Task { [weak self] in + for await event in events { + self?.handle(receptionOf: event) + } + } + } + + private func handle(receptionOf event: JSONRPCEvent) { + switch event { + case .notification(_, let data): + do { + let notification = try JSONDecoder().decode(Notification.self, from: data) + sendNotificationToStream(notification) + } catch { + mcpLogger + .error("Failed to decode notification \(String(data: data, encoding: .utf8) ?? "invalid data", privacy: .public)") + } + + case .request(_, let handler, let data): + // Respond to ping from the other side + Task { await handler(handle(receptionOf: data)) } + + case .error(let error): + mcpLogger.error("Received error: \(error, privacy: .public)") + } + } + + private func handle(receptionOf request: Data) async -> AnyJRPCResponse { + if let decodedRequest = try? JSONDecoder().decode(Request.self, from: request) { + guard let askForRequestToBeHandle else { + mcpLogger.error("Unable to handle request. The MCP connection has not been set properly") + return .failure(.init( + code: JRPCErrorCodes.methodNotFound.rawValue, + message: "Unable to handle request. The MCP connection has not been set properly")) + } + return await askForRequestToBeHandle(decodedRequest) + } else if (try? JSONDecoder().decode(PingRequest.self, from: request)) != nil { + // Respond to ping + return .success(PingRequest.Result()) + } + mcpLogger + .error( + "Received unknown request: \(String(data: request, encoding: .utf8) ?? "invalid data", privacy: .public)") + return .failure(.init( + code: JRPCErrorCodes.methodNotFound.rawValue, + message: "The request could not be decoded to a known type")) + } + +} diff --git a/MCPClient/Sources/ReadOnlyCurrentValueSubject.swift b/MCPInterface/Sources/ReadOnlyCurrentValueSubject.swift similarity index 100% rename from MCPClient/Sources/ReadOnlyCurrentValueSubject.swift rename to MCPInterface/Sources/ReadOnlyCurrentValueSubject.swift diff --git a/MCPInterface/Sources/logger.swift b/MCPInterface/Sources/logger.swift new file mode 100644 index 0000000..ebbdc83 --- /dev/null +++ b/MCPInterface/Sources/logger.swift @@ -0,0 +1,3 @@ +import OSLog + +package let mcpLogger = Logger(subsystem: Bundle.main.bundleIdentifier.map { "\($0).mcp" } ?? "com.app.mcp", category: "mcp") diff --git a/MCPShared/Sources/mcp_interfaces/Constants.swift b/MCPInterface/Sources/mcp_interfaces/Constants.swift similarity index 86% rename from MCPShared/Sources/mcp_interfaces/Constants.swift rename to MCPInterface/Sources/mcp_interfaces/Constants.swift index 04a68e6..799d918 100644 --- a/MCPShared/Sources/mcp_interfaces/Constants.swift +++ b/MCPInterface/Sources/mcp_interfaces/Constants.swift @@ -42,3 +42,13 @@ enum Requests { static var setLoggingLevel: String { "logging/setLevel" } static var autocomplete: String { "completion/complete" } } + +// MARK: - ResourceTypes + +enum ResourceTypes { + static var resource: String { "resource" } + static var text: String { "text" } + static var image: String { "image" } + static var resourceReference: String { "ref/resource" } + static var promptReference: String { "ref/prompt" } +} diff --git a/MCPShared/Sources/mcp_interfaces/EmptyObject.swift b/MCPInterface/Sources/mcp_interfaces/EmptyObject.swift similarity index 100% rename from MCPShared/Sources/mcp_interfaces/EmptyObject.swift rename to MCPInterface/Sources/mcp_interfaces/EmptyObject.swift diff --git a/MCPShared/Sources/mcp_interfaces/Interface+extensions.swift b/MCPInterface/Sources/mcp_interfaces/Interface+extensions.swift similarity index 58% rename from MCPShared/Sources/mcp_interfaces/Interface+extensions.swift rename to MCPInterface/Sources/mcp_interfaces/Interface+extensions.swift index 18bad04..6889fac 100644 --- a/MCPShared/Sources/mcp_interfaces/Interface+extensions.swift +++ b/MCPInterface/Sources/mcp_interfaces/Interface+extensions.swift @@ -1,3 +1,4 @@ +import Foundation // MARK: AnyMeta + Codable extension AnyMeta { @@ -82,7 +83,7 @@ extension AnyParams { if values.isEmpty { value = nil } else { - value = JSON.object(values) + value = values } } @@ -100,9 +101,42 @@ extension AnyParams { } -// MARK: AnyParamsWithProgressToken + Encodable +// MARK: AnyParamsWithProgressToken + Codable extension AnyParamsWithProgressToken { + + // MARK: Lifecycle + + public init(from decoder: Decoder) throws { + let hash = try JSON(from: decoder) + switch hash { + case .array: + throw DecodingError.dataCorruptedError(in: try decoder.unkeyedContainer(), debugDescription: "Unexpected array") + case .object(var object): + if case .object(let metaHash) = object["_meta"] { + let data = try JSONEncoder().encode(metaHash) + _meta = try JSONDecoder().decode(MetaProgress.self, from: data) + } else { + _meta = nil + } + object.removeValue(forKey: "_meta") + if object.isEmpty { + value = nil + } else { + let data = try JSONEncoder().encode(object) + let json = try JSONDecoder().decode(JSON.self, from: data) + switch json { + case .array: + throw DecodingError.dataCorruptedError(in: try decoder.unkeyedContainer(), debugDescription: "Unexpected array") + case .object(let val): + value = val + } + } + } + } + + // MARK: Public + public func encode(to encoder: Encoder) throws { var container = encoder.container(keyedBy: String.self) if let _meta { @@ -114,8 +148,11 @@ extension AnyParamsWithProgressToken { } } -// MARK: TextContentOrImageContentOrEmbeddedResource + Decodable +// MARK: TextContentOrImageContentOrEmbeddedResource + Codable extension TextContentOrImageContentOrEmbeddedResource { + + // MARK: Lifecycle + public init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: String.self) let type = try container.decode(String.self, forKey: "type") @@ -130,6 +167,19 @@ extension TextContentOrImageContentOrEmbeddedResource { throw DecodingError.dataCorruptedError(forKey: "type", in: container, debugDescription: "Invalid content. Got type \(type)") } } + + // MARK: Public + + public func encode(to encoder: any Encoder) throws { + switch self { + case .text(let value): + try value.encode(to: encoder) + case .image(let value): + try value.encode(to: encoder) + case .embeddedResource(let value): + try value.encode(to: encoder) + } + } } extension TextContentOrImageContentOrEmbeddedResource { @@ -201,8 +251,11 @@ extension TextOrImageContent { } } -// MARK: TextOrBlobResourceContents + Decodable +// MARK: TextOrBlobResourceContents + Codable extension TextOrBlobResourceContents { + + // MARK: Lifecycle + public init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: String.self) if container.contains("text") { @@ -213,6 +266,17 @@ extension TextOrBlobResourceContents { self = .blob(try .init(from: decoder)) } } + + // MARK: Public + + public func encode(to encoder: Encoder) throws { + switch self { + case .text(let value): + try value.encode(to: encoder) + case .blob(let value): + try value.encode(to: encoder) + } + } } extension TextOrBlobResourceContents { @@ -231,26 +295,23 @@ extension TextOrBlobResourceContents { } } -// MARK: CallToolResult + Decodable -extension CallToolResult { - public init(from decoder: any Decoder) throws { - let container = try decoder.container(keyedBy: String.self) -// let isError = try container.decodeIfPresent(Bool.self, forKey: "isError") - isError = try container.decodeIfPresent(Bool.self, forKey: "isError") -// if isError == true { -// // If the response is an error, decode the messages using the error format -// let errors = try container.decode([ExecutionError].self, forKey: "content") -// throw MCPClientError.toolCallError(executionErrors: errors) -// } +// MARK: PromptOrResourceReference + Codable +extension PromptOrResourceReference { - content = try container.decode([TextContentOrImageContentOrEmbeddedResource].self, forKey: "content") - _meta = try container.decodeIfPresent(AnyMeta.self, forKey: "_meta") -// self.isError = false + // MARK: Lifecycle + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: String.self) + let type = try container.decode(String.self, forKey: "type") + if type == "ref/prompt" { + self = .prompt(try .init(from: decoder)) + } else { + self = .resource(try .init(from: decoder)) + } } -} -// MARK: PromptOrResourceReference + Encodable -extension PromptOrResourceReference { + // MARK: Public + public func encode(to encoder: Encoder) throws { switch self { case .prompt(let value): @@ -259,10 +320,68 @@ extension PromptOrResourceReference { try value.encode(to: encoder) } } + } -// MARK: ServerNotification + Decodable -extension ServerNotification { +extension PromptOrResourceReference { + public var prompt: PromptReference? { + guard case .prompt(let value) = self else { + return nil + } + return value + } + + public var resource: ResourceReference? { + guard case .resource(let value) = self else { + return nil + } + return value + } +} + +// MARK: ClientRequest + Decodable +extension ClientRequest { + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: String.self) + let method = try container.decode(String.self, forKey: "method") + + switch method { + case Requests.initialize: + self = .initialize(try InitializeRequest(from: decoder).params) + case Requests.listPrompts: + self = .listPrompts(try ListPromptsRequest(from: decoder).params) + case Requests.getPrompt: + self = .getPrompt(try GetPromptRequest(from: decoder).params) + case Requests.listResources: + self = .listResources(try ListResourcesRequest(from: decoder).params) + case Requests.readResource: + self = .readResource(try ReadResourceRequest(from: decoder).params) + case Requests.subscribeToResource: + self = .subscribeToResource(try SubscribeRequest(from: decoder).params) + case Requests.unsubscribeToResource: + self = .unsubscribeToResource(try UnsubscribeRequest(from: decoder).params) + case Requests.listResourceTemplates: + self = .listResourceTemplates(try ListResourceTemplatesRequest(from: decoder).params) + case Requests.listTools: + self = .listTools(try ListToolsRequest(from: decoder).params) + case Requests.callTool: + self = .callTool(try CallToolRequest(from: decoder).params) + case Requests.autocomplete: + self = .complete(try CompleteRequest(from: decoder).params) + case Requests.setLoggingLevel: + self = .setLogLevel(try SetLevelRequest(from: decoder).params) + default: + throw DecodingError.dataCorruptedError( + forKey: "method", + in: container, + debugDescription: "Invalid client request. Got method \(method)") + } + } +} + +// MARK: ClientNotification + Decodable + +extension ClientNotification { public init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: String.self) let method = try container.decode(String.self, forKey: "method") @@ -271,34 +390,58 @@ extension ServerNotification { self = .cancelled(try CancelledNotification(from: decoder).params) case Notifications.progress: self = .progress(try ProgressNotification(from: decoder).params) - case Notifications.loggingMessage: - self = .loggingMessage(try LoggingMessageNotification(from: decoder).params) - case Notifications.resourceUpdated: - self = .resourceUpdated(try ResourceUpdatedNotification(from: decoder).params) - case Notifications.resourceListChanged: - self = .resourceListChanged(try ResourceListChangedNotification(from: decoder).params ?? .init()) - case Notifications.toolListChanged: - self = .toolListChanged(try ToolListChangedNotification(from: decoder).params ?? .init()) - case Notifications.promptListChanged: - self = .promptListChanged(try PromptListChangedNotification(from: decoder).params ?? .init()) + case Notifications.initialized: + self = .initialized(try InitializedNotification(from: decoder).params ?? .init()) + case Notifications.rootsListChanged: + self = .rootsListChanged(try RootsListChangedNotification(from: decoder).params ?? .init()) default: throw DecodingError.dataCorruptedError( forKey: "method", in: container, - debugDescription: "Invalid server notification. Got method \(method)") + debugDescription: "Invalid client notification. Got method \(method)") } } } +// MARK: ServerRequest + Decodable extension ServerRequest { public init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: String.self) let method = try container.decode(String.self, forKey: "method") switch method { case Requests.createMessage: - self = .createMessage(try CreateMessageRequest(from: decoder).params) + self = .createMessage(try CreateSamplingMessageRequest(from: decoder).params) case Requests.listRoots: self = .listRoots(try ListRootsRequest(from: decoder).params) + default: + throw DecodingError.dataCorruptedError( + forKey: "method", + in: container, + debugDescription: "Invalid server request. Got method \(method)") + } + } +} + +// MARK: ServerNotification + Decodable +extension ServerNotification { + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: String.self) + let method = try container.decode(String.self, forKey: "method") + switch method { + case Notifications.cancelled: + self = .cancelled(try CancelledNotification(from: decoder).params) + case Notifications.progress: + self = .progress(try ProgressNotification(from: decoder).params) + case Notifications.loggingMessage: + self = .loggingMessage(try LoggingMessageNotification(from: decoder).params) + case Notifications.resourceUpdated: + self = .resourceUpdated(try ResourceUpdatedNotification(from: decoder).params) + case Notifications.resourceListChanged: + self = .resourceListChanged(try ResourceListChangedNotification(from: decoder).params ?? .init()) + case Notifications.toolListChanged: + self = .toolListChanged(try ToolListChangedNotification(from: decoder).params ?? .init()) + case Notifications.promptListChanged: + self = .promptListChanged(try PromptListChangedNotification(from: decoder).params ?? .init()) default: throw DecodingError.dataCorruptedError( forKey: "method", @@ -371,9 +514,9 @@ extension LoggingMessageNotification { public typealias CodingKeys = JRPCMessageCodingKeys } -// MARK: - CreateMessageRequest.CodingKeys +// MARK: - CreateSamplingMessageRequest.CodingKeys -extension CreateMessageRequest { +extension CreateSamplingMessageRequest { public typealias CodingKeys = JRPCMessageCodingKeys } @@ -389,6 +532,78 @@ extension RootsListChangedNotification { public typealias CodingKeys = JRPCMessageCodingKeys } +// MARK: - InitializeRequest.CodingKeys + +extension InitializeRequest { + public typealias CodingKeys = JRPCMessageCodingKeys +} + +// MARK: - ListResourcesRequest.CodingKeys + +extension ListResourcesRequest { + public typealias CodingKeys = JRPCMessageCodingKeys +} + +// MARK: - ListResourceTemplatesRequest.CodingKeys + +extension ListResourceTemplatesRequest { + public typealias CodingKeys = JRPCMessageCodingKeys +} + +// MARK: - ReadResourceRequest.CodingKeys + +extension ReadResourceRequest { + public typealias CodingKeys = JRPCMessageCodingKeys +} + +// MARK: - SubscribeRequest.CodingKeys + +extension SubscribeRequest { + public typealias CodingKeys = JRPCMessageCodingKeys +} + +// MARK: - UnsubscribeRequest.CodingKeys + +extension UnsubscribeRequest { + public typealias CodingKeys = JRPCMessageCodingKeys +} + +// MARK: - ListPromptsRequest.CodingKeys + +extension ListPromptsRequest { + public typealias CodingKeys = JRPCMessageCodingKeys +} + +// MARK: - GetPromptRequest.CodingKeys + +extension GetPromptRequest { + public typealias CodingKeys = JRPCMessageCodingKeys +} + +// MARK: - ListToolsRequest.CodingKeys + +extension ListToolsRequest { + public typealias CodingKeys = JRPCMessageCodingKeys +} + +// MARK: - CallToolRequest.CodingKeys + +extension CallToolRequest { + public typealias CodingKeys = JRPCMessageCodingKeys +} + +// MARK: - SetLevelRequest.CodingKeys + +extension SetLevelRequest { + public typealias CodingKeys = JRPCMessageCodingKeys +} + +// MARK: - CompleteRequest.CodingKeys + +extension CompleteRequest { + public typealias CodingKeys = JRPCMessageCodingKeys +} + extension TextContent { public init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: String.self) @@ -413,3 +628,17 @@ extension EmbeddedResource { resource = try container.decode(TextOrBlobResourceContents.self, forKey: "resource") } } + +extension ResourceReference { + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: String.self) + uri = try container.decode(String.self, forKey: "uri") + } +} + +extension PromptReference { + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: String.self) + name = try container.decode(String.self, forKey: "name") + } +} diff --git a/MCPShared/Sources/mcp_interfaces/Interface.swift b/MCPInterface/Sources/mcp_interfaces/Interface.swift similarity index 89% rename from MCPShared/Sources/mcp_interfaces/Interface.swift rename to MCPInterface/Sources/mcp_interfaces/Interface.swift index 1137785..49dfe1a 100644 --- a/MCPShared/Sources/mcp_interfaces/Interface.swift +++ b/MCPInterface/Sources/mcp_interfaces/Interface.swift @@ -21,6 +21,12 @@ public protocol HasMetaValue { var _meta: Meta? { get } } +// MARK: - Optional + HasMetaValue + +extension Optional: HasMetaValue where Wrapped: HasMetaValue { + public var _meta: Wrapped.Meta? { self?._meta } +} + // MARK: - MetaProgress @MemberwiseInit(.public, _optionalsDefaultNil: true) @@ -43,7 +49,7 @@ public struct AnyMeta: MetaType, Codable, Equatable, Sendable { @MemberwiseInit(.public, _optionalsDefaultNil: true) public struct AnyParams: HasMetaValue, Equatable, Codable { public let _meta: AnyMeta? - public let value: JSON? + public let value: [String: JSON.Value]? } // MARK: - AnyParamsWithProgressToken @@ -51,17 +57,19 @@ public struct AnyParams: HasMetaValue, Equatable, Codable { @MemberwiseInit(.public, _optionalsDefaultNil: true) public struct AnyParamsWithProgressToken: HasMetaValue, Codable, Equatable { public let _meta: MetaProgress? - public let value: JSON? + public let value: [String: JSON.Value]? } // MARK: - Request /// The _meta parameter is reserved by MCP to allow clients and servers to attach additional metadata to their notifications. -public protocol Request: Encodable where Params.Meta == MetaProgress { - associatedtype Params: Encodable & HasMetaValue - associatedtype Result: Decodable +public protocol Request: Codable where Params.Meta == MetaProgress { + associatedtype Params: Codable & HasMetaValue + associatedtype Result: Codable var method: String { get } - var params: Params? { get } + var params: Params { get } + + typealias Handler = (Params) async throws -> Result } // MARK: - Notification @@ -78,7 +86,7 @@ public protocol Notification: Codable, Equatable where Params.Meta == AnyMeta { /// This little boilerplate helps make Swift happy when a type fulfills with `params: Params` a protocol that requires `params: Params?`... public protocol HasParams { - associatedtype Params: Encodable & HasMetaValue + associatedtype Params: Codable & HasMetaValue var params: Params { get } } @@ -97,7 +105,7 @@ extension Notification where Self: HasParams { // MARK: - Result /// The _meta property is reserved by the protocol to allow clients and servers to attach additional metadata to their responses. -public protocol Result: Decodable, HasMetaValue where Meta == AnyMeta { } +public protocol Result: Codable, HasMetaValue where Meta == AnyMeta { } public typealias RequestId = StringOrNumber @@ -151,7 +159,7 @@ public struct InitializeRequest: Request, HasParams { public let params: Params @MemberwiseInit(.public, _optionalsDefaultNil: true) - public struct Params: HasMetaValue, Encodable { + public struct Params: HasMetaValue, Codable, Equatable { public let _meta: MetaProgress? /// The latest version of the Model Context Protocol that the client supports. The client MAY decide to support older versions as well. public let protocolVersion: String @@ -206,7 +214,7 @@ public struct ListChangedCapability: Codable, Equatable { /// Capabilities a client may support. Known capabilities are defined here, in this schema, but this is not a closed set: any client can define its own, additional capabilities. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct ClientCapabilities: Encodable { +public struct ClientCapabilities: Codable, Equatable { /// Experimental, non-standard capabilities that the client supports. public let experimental: JSON? /// Present if the client supports listing roots. @@ -219,7 +227,7 @@ public struct ClientCapabilities: Encodable { /// Capabilities that a server may support. Known capabilities are defined here, in this schema, but this is not a closed set: any server can define its own, additional capabilities. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct ServerCapabilities: Decodable, Equatable { +public struct ServerCapabilities: Codable, Equatable { // MARK: Public /// Experimental, non-standard capabilities that the server supports @@ -248,7 +256,7 @@ public struct Implementation: Codable, Equatable { /// A ping, issued by either the server or the client, to check that the other party is still alive. The receiver must promptly respond, or else may be disconnected. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct PingRequest: Request, Decodable { +public struct PingRequest: Request, Codable { public typealias Result = EmptyResult public let method = Requests.ping public let params: AnyParamsWithProgressToken? @@ -288,7 +296,17 @@ public protocol PaginationParams { /// If provided, the server should return results starting after this cursor. var cursor: Cursor? { get } - func updatingCursor(to cursor: Cursor?) -> Self + static func updating(cursor: Cursor?, from params: Self?) -> Self +} + +// MARK: - Optional + PaginationParams + +extension Optional: PaginationParams where Wrapped: PaginationParams { + public var cursor: Cursor? { self?.cursor } + + public static func updating(cursor: Cursor?, from params: Self?) -> Self { + Wrapped.updating(cursor: cursor, from: params ?? nil) + } } // MARK: - SharedPaginationParams @@ -298,8 +316,8 @@ public protocol PaginationParams { /// to make it easier to evolve the code if the protocol evolves in the future. @MemberwiseInit(.public, _optionalsDefaultNil: true) public struct SharedPaginationParams: PaginationParams, HasMetaValue, Codable, Equatable { - public func updatingCursor(to cursor: Cursor?) -> SharedPaginationParams { - .init(_meta: _meta, cursor: cursor) + public static func updating(cursor: Cursor?, from params: SharedPaginationParams?) -> SharedPaginationParams { + .init(_meta: params?._meta, cursor: cursor) } public let _meta: MetaProgress? @@ -309,7 +327,7 @@ public struct SharedPaginationParams: PaginationParams, HasMetaValue, Codable, E // MARK: - PaginatedRequest public protocol PaginatedRequest: Request where Result: PaginatedResult, Params: PaginationParams { - var params: Params? { get } + var params: Params { get } init(params: Params) } @@ -324,10 +342,6 @@ public protocol PaginatedResult: Result { /// Sent from the client to request a list of resources the server has. @MemberwiseInit(.public, _optionalsDefaultNil: true) public struct ListResourcesRequest: PaginatedRequest { - public init(params: SharedPaginationParams) { - self.params = params - } - public typealias Result = ListResourcesResult public let method = Requests.listResources @@ -380,7 +394,7 @@ public struct ReadResourceRequest: Request, HasParams { public let params: Params @MemberwiseInit(.public, _optionalsDefaultNil: true) - public struct Params: HasMetaValue, Encodable { + public struct Params: HasMetaValue, Codable, Equatable { public let _meta: MetaProgress? /// The URI of the resource to read. The URI can use any protocol; it is up to the server how to interpret it. /// @format uri @@ -418,7 +432,7 @@ public struct SubscribeRequest: Request, HasParams { public let params: Params @MemberwiseInit(.public, _optionalsDefaultNil: true) - public struct Params: HasMetaValue, Encodable { + public struct Params: HasMetaValue, Codable, Equatable { public let _meta: MetaProgress? /// The URI of the resource to subscribe to. The URI can use any protocol; it is up to the server how to interpret it. /// @format uri @@ -437,7 +451,7 @@ public struct UnsubscribeRequest: Request, HasParams { public let params: Params @MemberwiseInit(.public, _optionalsDefaultNil: true) - public struct Params: HasMetaValue, Encodable { + public struct Params: HasMetaValue, Codable, Equatable { public let _meta: MetaProgress? /// The URI of the resource to unsubscribe from. /// @format uri @@ -467,7 +481,7 @@ public struct ResourceUpdatedNotification: Notification, HasParams { /// A known resource that the server is capable of reading. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct Resource: Annotated, Decodable, Equatable { +public struct Resource: Annotated, Codable, Equatable { public let annotations: Annotations? /// The URI of this resource. /// @format uri @@ -486,7 +500,7 @@ public struct Resource: Annotated, Decodable, Equatable { /// A template description for resources available on the server. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct ResourceTemplate: Annotated, Decodable, Equatable { +public struct ResourceTemplate: Annotated, Codable, Equatable { public let annotations: Annotations? /// A URI template (according to RFC 6570) that can be used to construct resource URIs. /// @format uri @@ -515,7 +529,7 @@ public protocol ResourceContents { // MARK: - TextResourceContents @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct TextResourceContents: Decodable, ResourceContents { +public struct TextResourceContents: Codable, Equatable, ResourceContents { /// The URI of this resource. public let uri: String /// The MIME type of this resource, if known. @@ -527,7 +541,7 @@ public struct TextResourceContents: Decodable, ResourceContents { // MARK: - BlobResourceContents @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct BlobResourceContents: Decodable, ResourceContents { +public struct BlobResourceContents: Codable, Equatable, ResourceContents { /// The URI of this resource. public let uri: String /// The MIME type of this resource, if known. @@ -542,9 +556,6 @@ public struct BlobResourceContents: Decodable, ResourceContents { /// Sent from the client to request a list of prompts and prompt templates the server has. @MemberwiseInit(.public, _optionalsDefaultNil: true) public struct ListPromptsRequest: PaginatedRequest { - public init(params: SharedPaginationParams) { - self.params = params - } public typealias Result = ListPromptsResult @@ -573,7 +584,7 @@ public struct GetPromptRequest: Request, HasParams { public let params: Params @MemberwiseInit(.public, _optionalsDefaultNil: true) - public struct Params: HasMetaValue, Encodable { + public struct Params: HasMetaValue, Codable, Equatable { public let _meta: MetaProgress? /// The name of the prompt or prompt template. public let name: String @@ -597,7 +608,7 @@ public struct GetPromptResult: Result { /// A prompt or prompt template that the server offers. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct Prompt: Decodable, Equatable { +public struct Prompt: Codable, Equatable { /// The name of the prompt or prompt template. public let name: String /// An optional description of what this prompt provides @@ -610,7 +621,7 @@ public struct Prompt: Decodable, Equatable { /// Describes an argument that a prompt can accept. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct PromptArgument: Decodable, Equatable { +public struct PromptArgument: Codable, Equatable { /// The name of the argument. public let name: String /// A human-readable description of the argument. @@ -634,7 +645,7 @@ public enum Role: String, Codable { /// This is similar to `SamplingMessage`, but also supports the embedding of /// resources from the MCP server. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct PromptMessage: Decodable { +public struct PromptMessage: Codable { public let role: Role public let content: TextContentOrImageContentOrEmbeddedResource } @@ -646,9 +657,9 @@ public struct PromptMessage: Decodable { /// It is up to the client how best to render embedded resources for the benefit /// of the LLM and/or the user. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct EmbeddedResource: Annotated, Decodable { +public struct EmbeddedResource: Annotated, Codable, Equatable { public let annotations: Annotations? - public let type = "resource" + public let type = ResourceTypes.resource public let resource: TextOrBlobResourceContents } @@ -709,7 +720,7 @@ public struct CallToolResult: Result { /// An error that occurred during the execution of the tool. @MemberwiseInit(.public, _optionalsDefaultNil: true) - public struct ExecutionError: Error, Decodable { + public struct ExecutionError: Error, Codable { public let text: String } } @@ -725,7 +736,7 @@ public struct CallToolRequest: Request, HasParams { public let params: Params @MemberwiseInit(.public, _optionalsDefaultNil: true) - public struct Params: HasMetaValue, Encodable { + public struct Params: HasMetaValue, Codable, Equatable { public let _meta: MetaProgress? public let name: String public let arguments: JSON? @@ -747,11 +758,11 @@ public struct ToolListChangedNotification: Notification { // TODO: add the ability to cast this to a Tool while validating the schema /// Definition for a tool the client can call. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct Tool: Decodable, Equatable { +public struct Tool: Codable, Equatable { /// The name of the tool. public let name: String /// A human-readable description of the tool. - public let description: String + public let description: String? // TODO: Use a more specific type to represent the JSON schema type? /// A JSON Schema object defining the expected parameters for the tool. public let inputSchema: JSON @@ -768,7 +779,7 @@ public struct SetLevelRequest: Request, HasParams { public let params: Params @MemberwiseInit(.public, _optionalsDefaultNil: true) - public struct Params: HasMetaValue, Encodable { + public struct Params: HasMetaValue, Codable, Equatable { public let _meta: MetaProgress? /// The level of logging that the client wants to receive from the server. The server should send all logs at this level and higher (i.e., more severe) to the client as notifications/logging/message. public let level: LoggingLevel @@ -812,11 +823,11 @@ public enum LoggingLevel: String, Codable { case emergency } -// MARK: - CreateMessageRequest +// MARK: - CreateSamplingMessageRequest /// A request from the server to sample an LLM via the client. The client has full discretion over which model to select. The client should also inform the user before beginning sampling, to allow them to inspect the request (human in the loop) and decide whether to approve it. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct CreateMessageRequest: Request, HasParams, Codable { +public struct CreateSamplingMessageRequest: Request, HasParams, Codable { public typealias Result = CreateMessageResult @MemberwiseInit(.public, _optionalsDefaultNil: true) @@ -847,7 +858,7 @@ public struct CreateMessageRequest: Request, HasParams, Codable { /// The client's response to a sampling/create_message request from the server. The client should inform the user before returning the sampled message, to allow them to inspect the response (human in the loop) and decide whether to allow the server to see it. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct CreateMessageResult: Result, SamplingMessageInterface, Encodable { +public struct CreateMessageResult: Result, SamplingMessageInterface, Codable { public let role: Role public let content: TextOrImageContent public let _meta: AnyMeta? @@ -894,7 +905,7 @@ public struct Annotations: Codable, Equatable { @MemberwiseInit(.public, _optionalsDefaultNil: true) public struct TextContent: Annotated, Codable, Equatable { public let annotations: Annotations? - public let type = "text" + public let type = ResourceTypes.text /// The text content of the message. public let text: String } @@ -905,7 +916,7 @@ public struct TextContent: Annotated, Codable, Equatable { @MemberwiseInit(.public, _optionalsDefaultNil: true) public struct ImageContent: Annotated, Codable, Equatable { public let annotations: Annotations? - public let type = "image" + public let type = ResourceTypes.image /// The base64-encoded image data. public let data: String /// The MIME type of the image. Different providers may support different image types. @@ -991,13 +1002,13 @@ public struct CompleteRequest: Request, HasParams { public let params: Params @MemberwiseInit(.public, _optionalsDefaultNil: true) - public struct Params: HasMetaValue, Encodable { + public struct Params: HasMetaValue, Codable, Equatable { public let _meta: MetaProgress? public let ref: PromptOrResourceReference public let argument: Argument @MemberwiseInit(.public, _optionalsDefaultNil: true) - public struct Argument: Encodable { + public struct Argument: Codable, Equatable { /// The name of the argument public let name: String /// The value of the argument to use for completion matching. @@ -1015,7 +1026,7 @@ public struct CompleteResult: Result { public let completion: Completion @MemberwiseInit(.public, _optionalsDefaultNil: true) - public struct Completion: Decodable { + public struct Completion: Codable { /// An array of completion values. Must not exceed 100 items. public let values: [String] /// The total number of completion options available. This can exceed the number of values actually sent in the response. @@ -1029,8 +1040,8 @@ public struct CompleteResult: Result { /// A reference to a resource or resource template definition. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct ResourceReference: Encodable { - public let type = "ref/resource" +public struct ResourceReference: Codable, Equatable { + public let type = ResourceTypes.resourceReference /// The URI or URI template of the resource. /// @format uri-template public let uri: String @@ -1040,8 +1051,8 @@ public struct ResourceReference: Encodable { /// Identifies a prompt. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct PromptReference: Encodable { - public let type = "ref/prompt" +public struct PromptReference: Codable, Equatable { + public let type = ResourceTypes.promptReference /// The name of the prompt or prompt template public let name: String } @@ -1056,7 +1067,7 @@ public struct PromptReference: Encodable { /// This request is typically used when the server needs to understand the file system /// structure or access specific locations that the client has permission to read from. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct ListRootsRequest: Request, Decodable { +public struct ListRootsRequest: Request, Codable { public typealias Result = ListRootsResult @@ -1070,7 +1081,7 @@ public struct ListRootsRequest: Request, Decodable { /// This result contains an array of Root objects, each representing a root directory /// or file that the server can operate on. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct ListRootsResult: PaginatedResult, Encodable { +public struct ListRootsResult: PaginatedResult, Codable { public let _meta: AnyMeta? public let nextCursor: Cursor? public let roots: [Root] @@ -1080,7 +1091,7 @@ public struct ListRootsResult: PaginatedResult, Encodable { /// Represents a root directory or file that the server can operate on. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct Root: Codable { +public struct Root: Codable, Equatable { /// The URI identifying the root. This *must* start with file:// for now. /// This restriction may be relaxed in future versions of the protocol to allow /// other URI schemes. @@ -1099,14 +1110,9 @@ public struct Root: Codable { /// This notification should be sent whenever the client adds, removes, or modifies any root. /// The server should then request an updated list of roots using the ListRootsRequest. @MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct RootsListChangedNotification: Notification, HasParams { +public struct RootsListChangedNotification: Notification { public let method = Notifications.rootsListChanged - public let params: Params - - @MemberwiseInit(.public, _optionalsDefaultNil: true) - public struct Params: HasMetaValue, Codable, Equatable { - public let _meta: AnyMeta? - } + public let params: AnyParams? } // MARK: - MessageContext @@ -1129,7 +1135,7 @@ public struct SamplingMessage: SamplingMessageInterface, Codable, Equatable { // MARK: - TextContentOrImageContentOrEmbeddedResource -public enum TextContentOrImageContentOrEmbeddedResource: Decodable { +public enum TextContentOrImageContentOrEmbeddedResource: Codable, Equatable { case text(TextContent) case image(ImageContent) case embeddedResource(EmbeddedResource) @@ -1151,25 +1157,53 @@ public enum StringOrNumber: Codable, Equatable { // MARK: - TextOrBlobResourceContents -public enum TextOrBlobResourceContents: Decodable { +public enum TextOrBlobResourceContents: Codable, Equatable { case text(TextResourceContents) case blob(BlobResourceContents) } // MARK: - PromptOrResourceReference -public enum PromptOrResourceReference: Encodable { +public enum PromptOrResourceReference: Codable, Equatable { case prompt(PromptReference) case resource(ResourceReference) } +// MARK: - ClientRequest + +/// Requests that can be received by the client. +/// Note: the ping request is omitted since it is responded to by the connection layer. +public enum ClientRequest: Decodable, Equatable { + case initialize(InitializeRequest.Params) + case listPrompts(ListPromptsRequest.Params = nil) + case getPrompt(GetPromptRequest.Params) + case listResources(ListResourcesRequest.Params = nil) + case readResource(ReadResourceRequest.Params) + case subscribeToResource(SubscribeRequest.Params) + case unsubscribeToResource(UnsubscribeRequest.Params) + case listResourceTemplates(ListResourceTemplatesRequest.Params = nil) + case listTools(ListToolsRequest.Params = nil) + case callTool(CallToolRequest.Params) + case complete(CompleteRequest.Params) + case setLogLevel(SetLevelRequest.Params) +} + // MARK: - ClientNotification -public enum ClientNotification: Encodable, Equatable { - case cancelled(CancelledNotification) - case progress(ProgressNotification) - case initialized(InitializedNotification) - case rootsListChanged(RootsListChangedNotification) +public enum ClientNotification: Decodable, Equatable { + case cancelled(CancelledNotification.Params) + case progress(ProgressNotification.Params) + case initialized(InitializedNotification.Params) + case rootsListChanged(RootsListChangedNotification.Params) +} + +// MARK: - ServerRequest + +/// Requests that can be received by the server. +/// Note: the ping request is omitted since it is responded to by the connection layer. +public enum ServerRequest: Decodable, Equatable { + case createMessage(CreateSamplingMessageRequest.Params) + case listRoots(ListRootsRequest.Params = nil) } // MARK: - ServerNotification @@ -1183,12 +1217,3 @@ public enum ServerNotification: Decodable, Equatable { case toolListChanged(ToolListChangedNotification.Params) case promptListChanged(PromptListChangedNotification.Params) } - -// MARK: - ServerRequest - -/// Requests that can be received by the server. -/// Note: the ping request is omitted since it is responded to by the connection layer. -public enum ServerRequest: Decodable, Equatable { - case createMessage(CreateMessageRequest.Params) - case listRoots(ListRootsRequest.Params? = nil) -} diff --git a/MCPShared/Sources/mcp_interfaces/JSON+extensions.swift b/MCPInterface/Sources/mcp_interfaces/JSON+extensions.swift similarity index 100% rename from MCPShared/Sources/mcp_interfaces/JSON+extensions.swift rename to MCPInterface/Sources/mcp_interfaces/JSON+extensions.swift diff --git a/MCPInterface/Sources/mcp_interfaces/JSON.swift b/MCPInterface/Sources/mcp_interfaces/JSON.swift new file mode 100644 index 0000000..47538bc --- /dev/null +++ b/MCPInterface/Sources/mcp_interfaces/JSON.swift @@ -0,0 +1,115 @@ +import Foundation + +// MARK: - JSON + +public enum JSON: Codable, Equatable, Sendable { + case object(_ value: [String: JSON.Value]) + case array(_ value: [JSON.Value]) + + // MARK: - JSONValue + + public enum Value: Codable, Equatable, Sendable { + case string(_ value: String) + case object(_ value: [String: JSON.Value]) + case array(_ value: [JSON.Value]) + case bool(_ value: Bool) + case number(_ value: Double) + case null + } +} + +// MARK: ExpressibleByDictionaryLiteral + +extension JSON: ExpressibleByDictionaryLiteral { + public init(dictionaryLiteral elements: (String, JSON.Value)...) { + var object = [String: JSON.Value]() + + for element in elements { + object[element.0] = element.1 + } + + self = .object(object) + } +} + +// MARK: ExpressibleByArrayLiteral + +extension JSON: ExpressibleByArrayLiteral { + public init(arrayLiteral elements: JSON.Value...) { + var array = [JSON.Value]() + + for element in elements { + array.append(element) + } + + self = .array(array) + } +} + +// MARK: - JSON.Value + ExpressibleByNilLiteral + +extension JSON.Value: ExpressibleByNilLiteral { + public init(nilLiteral _: ()) { + self = .null + } +} + +// MARK: - JSON.Value + ExpressibleByDictionaryLiteral + +extension JSON.Value: ExpressibleByDictionaryLiteral { + public init(dictionaryLiteral elements: (String, JSON.Value)...) { + var object = [String: JSON.Value]() + + for element in elements { + object[element.0] = element.1 + } + + self = .object(object) + } +} + +// MARK: - JSON.Value + ExpressibleByStringLiteral + +extension JSON.Value: ExpressibleByStringLiteral { + public init(stringLiteral: String) { + self = .string(stringLiteral) + } +} + +// MARK: - JSON.Value + ExpressibleByIntegerLiteral + +extension JSON.Value: ExpressibleByIntegerLiteral { + public init(integerLiteral value: IntegerLiteralType) { + self = .number(Double(value)) + } +} + +// MARK: - JSON.Value + ExpressibleByFloatLiteral + +extension JSON.Value: ExpressibleByFloatLiteral { + public init(floatLiteral value: FloatLiteralType) { + self = .number(value) + } +} + +// MARK: - JSON.Value + ExpressibleByArrayLiteral + +extension JSON.Value: ExpressibleByArrayLiteral { + public init(arrayLiteral elements: JSON.Value...) { + var array = [JSON.Value]() + + for element in elements { + array.append(element) + } + + self = .array(array) + } +} + +// MARK: - JSON.Value + ExpressibleByBooleanLiteral + +extension JSON.Value: ExpressibleByBooleanLiteral { + public init(booleanLiteral value: BooleanLiteralType) { + self = .bool(value) + } +} diff --git a/MCPShared/Tests/interface/AnyMetaTests.swift b/MCPInterface/Tests/interface/AnyMetaTests.swift similarity index 98% rename from MCPShared/Tests/interface/AnyMetaTests.swift rename to MCPInterface/Tests/interface/AnyMetaTests.swift index d77872a..d3046c1 100644 --- a/MCPShared/Tests/interface/AnyMetaTests.swift +++ b/MCPInterface/Tests/interface/AnyMetaTests.swift @@ -1,6 +1,6 @@ import Foundation -import MCPShared +import MCPInterface import Testing extension MCPInterfaceTests { diff --git a/MCPShared/Tests/interface/AnyParamsTests.swift b/MCPInterface/Tests/interface/AnyParamsTests.swift similarity index 91% rename from MCPShared/Tests/interface/AnyParamsTests.swift rename to MCPInterface/Tests/interface/AnyParamsTests.swift index e279d1d..0fc13d4 100644 --- a/MCPShared/Tests/interface/AnyParamsTests.swift +++ b/MCPInterface/Tests/interface/AnyParamsTests.swift @@ -1,6 +1,6 @@ import Foundation -import MCPShared +import MCPInterface import Testing extension MCPInterfaceTests { @@ -31,9 +31,9 @@ extension MCPInterfaceTests { @Test func encodeWithNonMetaValues() throws { - try testEncoding(of: AnyParams(value: .object([ + try testEncoding(of: AnyParams(value: [ "key": .string("value"), - ])), """ + ]), """ { "key" : "value" } @@ -46,9 +46,9 @@ extension MCPInterfaceTests { _meta: .init(value: .object([ "meta_key": .string("meta_value"), ])), - value: .object([ + value: [ "key": .string("value"), - ])), """ + ]), """ { "_meta" : { "meta_key" : "meta_value" @@ -95,9 +95,9 @@ extension MCPInterfaceTests { { "key" : "value" } - """, AnyParams(value: .object([ + """, AnyParams(value: [ "key": .string("value"), - ]))) + ])) } @Test @@ -113,9 +113,9 @@ extension MCPInterfaceTests { _meta: .init(value: .object([ "meta_key": .string("meta_value"), ])), - value: .object([ + value: [ "key": .string("value"), - ]))) + ])) } // MARK: Private diff --git a/MCPInterface/Tests/interface/AnyParamsWithProgressTokenTests.swift b/MCPInterface/Tests/interface/AnyParamsWithProgressTokenTests.swift new file mode 100644 index 0000000..d9b17ef --- /dev/null +++ b/MCPInterface/Tests/interface/AnyParamsWithProgressTokenTests.swift @@ -0,0 +1,65 @@ + +import Foundation +import MCPInterface +import Testing + +extension MCPInterfaceTests { + struct AnyParamsWithProgressTokenTest { + + @Test + func encodeWithNoValues() throws { + try testEncodingDecoding(of: AnyParamsWithProgressToken(), """ + {} + """) + } + + @Test + func encodeWithStringProgressToken() throws { + try testEncodingDecoding(of: AnyParamsWithProgressToken(_meta: .init(progressToken: .string("123abc"))), """ + { + "_meta" : { + "progressToken" : "123abc" + } + } + """) + } + + @Test + func encodeWithNumberProgressToken() throws { + try testEncodingDecoding(of: AnyParamsWithProgressToken(_meta: .init(progressToken: .number(123456))), """ + { + "_meta" : { + "progressToken" : 123456 + } + } + """) + } + + @Test + func encodeWithNonMetaValues() throws { + try testEncodingDecoding(of: AnyParamsWithProgressToken(value: [ + "key": .string("value"), + ]), """ + { + "key" : "value" + } + """) + } + + @Test + func encodeWithMetaAndOtherParameters() throws { + try testEncodingDecoding(of: AnyParamsWithProgressToken( + _meta: .init(progressToken: .string("123abc")), + value: [ + "key": .string("value"), + ]), """ + { + "_meta" : { + "progressToken" : "123abc" + }, + "key" : "value" + } + """) + } + } +} diff --git a/MCPInterface/Tests/interface/ClientNotificationTests.swift b/MCPInterface/Tests/interface/ClientNotificationTests.swift new file mode 100644 index 0000000..f287d9e --- /dev/null +++ b/MCPInterface/Tests/interface/ClientNotificationTests.swift @@ -0,0 +1,89 @@ + +import Foundation +import MCPInterface +import Testing + +extension MCPInterfaceTests { + enum ClientNotificationTest { + + struct Deserialization { + + // MARK: Internal + + @Test + func decodeCancelledNotification() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "method": "notifications/cancelled", + "params": { + "requestId": "123", + "reason": "User requested cancellation" + } + } + """, + .cancelled(.init(requestId: .string("123"), reason: "User requested cancellation"))) + } + + @Test + func decodeProgressNotification() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": { + "progressToken": "abc123", + "progress": 50, + "total": 100 + } + } + """, + .progress(.init(progressToken: .string("abc123"), progress: 50, total: 100))) + } + + @Test + func decodeInitializedNotification() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "method": "notifications/initialized" + } + """, + .initialized(.init())) + } + + @Test + func decodeRootsListChangedNotification() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "method": "notifications/roots/list_changed" + } + """, + .rootsListChanged(.init())) + } + + @Test + func failsToDecodeBadValue() throws { + let data = """ + { + "jsonrpc": "2.0", + "method": "notifications/llm/unknown" + } + """.data(using: .utf8)! + #expect(throws: DecodingError.self) { try JSONDecoder().decode(ClientNotification.self, from: data) } + } + + // MARK: Private + + private func testDecoding(of json: String, _ value: ClientNotification) throws { + let data = json.data(using: .utf8)! + #expect(try JSONDecoder().decode(ClientNotification.self, from: data) == value) + } + } + } +} diff --git a/MCPInterface/Tests/interface/ClientRequestTests.swift b/MCPInterface/Tests/interface/ClientRequestTests.swift new file mode 100644 index 0000000..085352b --- /dev/null +++ b/MCPInterface/Tests/interface/ClientRequestTests.swift @@ -0,0 +1,231 @@ +import Foundation +import MCPInterface +import Testing + +extension MCPInterfaceTests { + enum ClientRequestTest { + + struct Deserialization { + + // MARK: Internal + + @Test + func decodeInitializeRequest() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "1.0", + "capabilities": { + "roots": { + "listChanged": true + } + }, + "clientInfo": { + "name": "TestClient", + "version": "1.0.0" + } + } + } + """, + .initialize(.init( + protocolVersion: "1.0", + capabilities: .init(roots: .init(listChanged: true)), + clientInfo: .init(name: "TestClient", version: "1.0.0")))) + } + + @Test + func decodeListPromptsRequest() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "prompts/list" + } + """, + .listPrompts(nil)) + } + + @Test + func decodeGetPromptRequest() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "prompts/get", + "params": { + "name": "code_review" + } + } + """, + .getPrompt(.init(name: "code_review"))) + } + + @Test + func decodeListResourcesRequest() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "resources/list" + } + """, + .listResources(nil)) + } + + @Test + func decodeReadResourceRequest() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "resources/read", + "params": { + "uri": "file:///example.txt" + } + } + """, + .readResource(.init(uri: "file:///example.txt"))) + } + + @Test + func decodeSubscribeToResourceRequest() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "resources/subscribe", + "params": { + "uri": "file:///example.txt" + } + } + """, + .subscribeToResource(.init(uri: "file:///example.txt"))) + } + + @Test + func decodeUnsubscribeToResourceRequest() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "resources/unsubscribe", + "params": { + "uri": "file:///example.txt" + } + } + """, + .unsubscribeToResource(.init(uri: "file:///example.txt"))) + } + + @Test + func decodeListResourceTemplatesRequest() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "resources/templates/list" + } + """, + .listResourceTemplates(nil)) + } + + @Test + func decodeListToolsRequest() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list" + } + """, + .listTools(nil)) + } + + @Test + func decodeCallToolRequest() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "format_code", + "arguments": { + "language": "swift", + "style": "default" + } + } + } + """, + .callTool(.init( + name: "format_code", + arguments: .object([ + "language": .string("swift"), + "style": .string("default"), + ])))) + } + + @Test + func decodeCompleteRequest() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "completion/complete", + "params": { + "ref": { + "type": "ref/prompt", + "name": "code_completion" + }, + "argument": { + "name": "prefix", + "value": "func test" + } + } + } + """, + .complete(.init( + ref: .prompt(.init(name: "code_completion")), + argument: .init(name: "prefix", value: "func test")))) + } + + @Test + func decodeSetLogLevelRequest() throws { + try testDecoding( + of: """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "logging/setLevel", + "params": { + "level": "debug" + } + } + """, + .setLogLevel(.init(level: .debug))) + } + + // MARK: Private + + private func testDecoding(of json: String, _ value: ClientRequest) throws { + let data = json.data(using: .utf8)! + let decodedValue = try JSONDecoder().decode(ClientRequest.self, from: data) + #expect(decodedValue == value) + } + } + } +} diff --git a/MCPShared/Tests/interface/InterfaceTests.swift b/MCPInterface/Tests/interface/InterfaceTests.swift similarity index 100% rename from MCPShared/Tests/interface/InterfaceTests.swift rename to MCPInterface/Tests/interface/InterfaceTests.swift diff --git a/MCPShared/Tests/interface/JSONTests.swift b/MCPInterface/Tests/interface/JSONTests.swift similarity index 98% rename from MCPShared/Tests/interface/JSONTests.swift rename to MCPInterface/Tests/interface/JSONTests.swift index fbd4990..2c6c3e2 100644 --- a/MCPShared/Tests/interface/JSONTests.swift +++ b/MCPInterface/Tests/interface/JSONTests.swift @@ -1,6 +1,6 @@ import Foundation -import MCPShared +import MCPInterface import Testing extension MCPInterfaceTests { diff --git a/MCPShared/Tests/interface/LoggingLevelTests.swift b/MCPInterface/Tests/interface/LoggingLevelTests.swift similarity index 99% rename from MCPShared/Tests/interface/LoggingLevelTests.swift rename to MCPInterface/Tests/interface/LoggingLevelTests.swift index c5f8c53..2a17071 100644 --- a/MCPShared/Tests/interface/LoggingLevelTests.swift +++ b/MCPInterface/Tests/interface/LoggingLevelTests.swift @@ -1,6 +1,6 @@ import Foundation -import MCPShared +import MCPInterface import Testing extension MCPInterfaceTests { enum LoggingLevelTest { diff --git a/MCPInterface/Tests/interface/PromptOrResourceReferenceTests.swift b/MCPInterface/Tests/interface/PromptOrResourceReferenceTests.swift new file mode 100644 index 0000000..99e04da --- /dev/null +++ b/MCPInterface/Tests/interface/PromptOrResourceReferenceTests.swift @@ -0,0 +1,38 @@ + +import Foundation +import MCPInterface +import Testing +extension MCPInterfaceTests { + struct PromptOrResourceReferenceTest { + + // MARK: Internal + + @Test + func encodePrompt() throws { + let value = PromptOrResourceReference.prompt(.init(name: "code_review")) + try testEncodingDecoding(of: value, """ + { + "name" : "code_review", + "type" : "ref/prompt" + } + """) + } + + @Test + func encodeResource() throws { + let value = PromptOrResourceReference.resource(.init(uri: "file:///foo_path")) + try testEncodingDecoding(of: value, """ + { + "type" : "ref/resource", + "uri" : "file:///foo_path" + } + """) + } + + // MARK: Private + + private func testEncoding(of value: PromptOrResourceReference, _ json: String) throws { + try testEncodingOf(value, json) + } + } +} diff --git a/MCPShared/Tests/interface/RoleTests.swift b/MCPInterface/Tests/interface/RoleTests.swift similarity index 97% rename from MCPShared/Tests/interface/RoleTests.swift rename to MCPInterface/Tests/interface/RoleTests.swift index a9374bd..b10fe00 100644 --- a/MCPShared/Tests/interface/RoleTests.swift +++ b/MCPInterface/Tests/interface/RoleTests.swift @@ -1,6 +1,6 @@ import Foundation -import MCPShared +import MCPInterface import Testing extension MCPInterfaceTests { enum RoleTest { diff --git a/MCPShared/Tests/interface/SerializationDeserializationTestUtils.swift b/MCPInterface/Tests/interface/SerializationDeserializationTestUtils.swift similarity index 60% rename from MCPShared/Tests/interface/SerializationDeserializationTestUtils.swift rename to MCPInterface/Tests/interface/SerializationDeserializationTestUtils.swift index b77d454..c5606b5 100644 --- a/MCPShared/Tests/interface/SerializationDeserializationTestUtils.swift +++ b/MCPInterface/Tests/interface/SerializationDeserializationTestUtils.swift @@ -25,15 +25,28 @@ func testDecodingEncodingOf(_ json: String, with _: T.Type) throws { #expect(expected == value) } -/// Test encoding the value to Json, and comparing it to the expectation. +/// Test that encoding the value gives the expected json. func testEncodingOf(_ value: some Encodable, _ json: String) throws { - let encoder = JSONEncoder() - let encoded = try encoder.encode(value) + let encoded = try JSONEncoder().encode(value) + let encodedString = try encoded.jsonString() + // Reformat the json expectation (pretty print, sort keys) let jsonData = json.data(using: .utf8)! - - let encodedString = try encoded.jsonString() let expected = try jsonData.jsonString() #expect(expected == encodedString) } + +// TODO: remove the 'Of': + +/// Test that decoding the json gives the expected value. +func testDecodingOf(_ value: T, _ json: String) throws { + let decoded = try JSONDecoder().decode(T.self, from: json.data(using: .utf8)!) + #expect(decoded == value) +} + +/// Test that encoding the value gives the expected json, and that decoding the json gives the expected value. +func testEncodingDecoding(of value: T, _ json: String) throws { + try testEncodingOf(value, json) + try testDecodingOf(value, json) +} diff --git a/MCPShared/Tests/interface/ServerNotificationTests.swift b/MCPInterface/Tests/interface/ServerNotificationTests.swift similarity index 99% rename from MCPShared/Tests/interface/ServerNotificationTests.swift rename to MCPInterface/Tests/interface/ServerNotificationTests.swift index 69ac9e9..962001c 100644 --- a/MCPShared/Tests/interface/ServerNotificationTests.swift +++ b/MCPInterface/Tests/interface/ServerNotificationTests.swift @@ -1,6 +1,6 @@ import Foundation -import MCPShared +import MCPInterface import Testing extension MCPInterfaceTests { diff --git a/MCPShared/Tests/interface/ServerRequestTests.swift b/MCPInterface/Tests/interface/ServerRequestTests.swift similarity index 99% rename from MCPShared/Tests/interface/ServerRequestTests.swift rename to MCPInterface/Tests/interface/ServerRequestTests.swift index 0e4a181..74b674a 100644 --- a/MCPShared/Tests/interface/ServerRequestTests.swift +++ b/MCPInterface/Tests/interface/ServerRequestTests.swift @@ -1,6 +1,6 @@ import Foundation -import MCPShared +import MCPInterface import Testing extension MCPInterfaceTests { diff --git a/MCPInterface/Tests/interface/TextContentOrImageContentOrEmbeddedResourceTests.swift b/MCPInterface/Tests/interface/TextContentOrImageContentOrEmbeddedResourceTests.swift new file mode 100644 index 0000000..b0a1c9a --- /dev/null +++ b/MCPInterface/Tests/interface/TextContentOrImageContentOrEmbeddedResourceTests.swift @@ -0,0 +1,65 @@ + +import Foundation +import MCPInterface +import Testing + +extension MCPInterfaceTests { + struct TextContentOrImageContentOrEmbeddedResourceTest { + + // MARK: Internal + + @Test + func decodeText() throws { + let json = """ + { + "type": "text", + "text": "Tool result text" + } + """ + let value = try decode(json) + + #expect(value.text?.text == "Tool result text") + try testEncodingDecoding(of: value, json) + } + + @Test + func decodeImage() throws { + let json = """ + { + "type": "image", + "data": "base64-encoded-data", + "mimeType": "image/png" + } + """ + let value = try decode(json) + + #expect(value.image?.data == "base64-encoded-data") + try testEncodingDecoding(of: value, json) + } + + @Test + func decodeResource() throws { + let json = """ + { + "type": "resource", + "resource": { + "uri": "resource://example", + "mimeType": "text/plain", + "text": "Resource content" + } + } + """ + let value = try decode(json) + + #expect(value.embeddedResource?.resource.text?.uri == "resource://example") + try testEncodingDecoding(of: value, json) + } + + // MARK: Private + + private func decode(_ jsonString: String) throws -> TextContentOrImageContentOrEmbeddedResource { + let data = jsonString.data(using: .utf8)! + return try JSONDecoder().decode(TextContentOrImageContentOrEmbeddedResource.self, from: data) + } + } +} diff --git a/MCPInterface/Tests/interface/TextOrBlobResourceContentsTests.swift b/MCPInterface/Tests/interface/TextOrBlobResourceContentsTests.swift new file mode 100644 index 0000000..33919e5 --- /dev/null +++ b/MCPInterface/Tests/interface/TextOrBlobResourceContentsTests.swift @@ -0,0 +1,48 @@ + +import Foundation +import MCPInterface +import Testing + +extension MCPInterfaceTests { + struct TextOrBlobResourceContentsTest { + + // MARK: Internal + + @Test + func decodeText() throws { + let json = """ + { + "uri": "file:///example.txt", + "mimeType": "text/plain", + "text": "Resource content" + } + """ + let value = try decode(json) + + #expect(value.text?.text == "Resource content") + try testEncodingDecoding(of: value, json) + } + + @Test + func decodeImage() throws { + let json = """ + { + "uri": "file:///example.png", + "mimeType": "image/png", + "blob": "base64-encoded-data" + } + """ + let value = try decode(json) + + #expect(value.blob?.blob == "base64-encoded-data") + try testEncodingDecoding(of: value, json) + } + + // MARK: Private + + private func decode(_ jsonString: String) throws -> TextOrBlobResourceContents { + let data = jsonString.data(using: .utf8)! + return try JSONDecoder().decode(TextOrBlobResourceContents.self, from: data) + } + } +} diff --git a/MCPShared/Tests/interface/TextOrImageContentTests.swift b/MCPInterface/Tests/interface/TextOrImageContentTests.swift similarity index 98% rename from MCPShared/Tests/interface/TextOrImageContentTests.swift rename to MCPInterface/Tests/interface/TextOrImageContentTests.swift index df309f0..26f0e67 100644 --- a/MCPShared/Tests/interface/TextOrImageContentTests.swift +++ b/MCPInterface/Tests/interface/TextOrImageContentTests.swift @@ -1,6 +1,6 @@ import Foundation -import MCPShared +import MCPInterface import Testing extension MCPInterfaceTests { diff --git a/MCPServer/Sources/Convenience/Schemable+extensions.swift b/MCPServer/Sources/Convenience/Schemable+extensions.swift new file mode 100644 index 0000000..ef509d6 --- /dev/null +++ b/MCPServer/Sources/Convenience/Schemable+extensions.swift @@ -0,0 +1,190 @@ +import Foundation +import JSONSchema +import JSONSchemaBuilder +import MCPInterface + +// MARK: - CallableTool + +// Allows to use @Schemable from `JSONSchemaBuilder` to define the input for tools as: +// +// @Schemable +// struct RepeatToolInput { +// let text: String +// } +// +// print(RepeatToolInput.schema.schemaValue) +// -> {"required":["text"],"properties":{"text":{"type":"string"}},"type":"object"} + +/// Definition for a tool the client can call. +public protocol CallableTool { + associatedtype Input: Decodable + /// A JSON Schema object defining the expected parameters for the tool. + var inputSchema: JSON { get } + /// The name of the tool. + var name: String { get } + /// A human-readable description of the tool. + var description: String? { get } + + func decodeInput(_ input: JSON?) throws -> Input + + func call(_ input: Input) async throws -> [TextContentOrImageContentOrEmbeddedResource] +} + +// MARK: - Tool + +public struct Tool: CallableTool { + + // MARK: Lifecycle + + public init( + name: String, + description: String? = nil, + inputSchema: JSON, + decodeInput: @escaping (Data) throws -> Input, + call: @escaping (Input) async throws -> [TextContentOrImageContentOrEmbeddedResource]) + { + self.name = name + self.description = description + self.inputSchema = inputSchema + _decodeInput = decodeInput + _call = call + } + + // MARK: Public + + public let name: String + + public let description: String? + + public let inputSchema: JSON + + public func call(_ input: Input) async throws -> [TextContentOrImageContentOrEmbeddedResource] { + try await _call(input) + } + + public func decodeInput(_ input: MCPInterface.JSON?) throws -> Input { + let data = try JSONEncoder().encode(input) + return try _decodeInput(data) + } + + // MARK: Private + + private let _call: (Input) async throws -> [TextContentOrImageContentOrEmbeddedResource] + + private let _decodeInput: (Data) throws -> Input +} + +extension Tool where Input: Schemable { + public init( + name: String, + description: String? = nil, + call: @escaping (Input) async throws -> [TextContentOrImageContentOrEmbeddedResource]) where Input.Schema.Output == Input + { + self.init( + name: name, + description: description, + inputSchema: Input.schema.schemaValue.json, + decodeInput: { data in + let json = try JSONDecoder().decode(JSONValue.self, from: data) + switch Input.schema.parse(json) { + case .valid(let value): + return value + case .invalid(let errors): + throw errors.first ?? MCPServerError.toolCallError(errors) + } + }, + call: call) + } +} + +extension Tool where Input: Decodable { + public init( + name: String, + description: String? = nil, + inputSchema: JSON, + call: @escaping (Input) async throws -> [TextContentOrImageContentOrEmbeddedResource]) + { + self.init( + name: name, + description: description, + inputSchema: inputSchema, + decodeInput: { data in + try JSONDecoder().decode(Input.self, from: data) + }, + call: call) + } +} + +extension CallableTool { + public func decodeInput(_ input: JSON?) throws -> Input { + let data = try JSONEncoder().encode(input) + return try JSONDecoder().decode(Input.self, from: data) + } + + public func call(_ input: JSON?) async throws -> [TextContentOrImageContentOrEmbeddedResource] { + let input = try decodeInput(input) + return try await call(input) + } +} + +extension Array where Element == any CallableTool { + func asRequestHandler(listToolChanged: Bool) + -> ListedCapabilityHandler + { + let toolsByName = [String: any CallableTool](uniqueKeysWithValues: map { ($0.name, $0) }) + + return .init( + info: .init(listChanged: listToolChanged), + handler: { request in + let name = request.name + guard let tool = toolsByName[name] else { + throw MCPError.notSupported + } + let arguments = request.arguments + do { + let content = try await tool.call(arguments) + return CallToolResult(content: content) + } catch { + return CallToolResult(content: [.text(.init(text: error.localizedDescription))], isError: true) + } + }, + listHandler: { _ in + ListToolsResult(tools: self.map { tool in MCPInterface.Tool( + name: tool.name, + description: tool.description, + inputSchema: tool.inputSchema) }) + }) + } +} + +/// Convert between the JSON representation from `JSONSchema` and ours +extension [KeywordIdentifier: JSONValue] { + fileprivate var json: JSON { + .object(mapValues { $0.value }) + } +} + +extension JSONValue { + fileprivate var value: JSON.Value { + switch self { + case .null: + .null + case .boolean(let value): + .bool(value) + case .number(let value): + .number(value) + case .string(let value): + .string(value) + case .array(let value): + .array(value.map { $0.value }) + case .object(let value): + .object(value.mapValues { $0.value }) + case .integer(let value): + .number(Double(value)) + } + } +} + +// MARK: - ParseIssue + Error + +extension ParseIssue: @retroactive Error { } diff --git a/MCPServer/Sources/Convenience/ServerCapabilityHandler.swift b/MCPServer/Sources/Convenience/ServerCapabilityHandler.swift new file mode 100644 index 0000000..27402aa --- /dev/null +++ b/MCPServer/Sources/Convenience/ServerCapabilityHandler.swift @@ -0,0 +1,20 @@ +extension ServerCapabilityHandlers { + /// Initialize a new `ServerCapabilityHandlers` with the given handlers. + /// - Parameters: + /// - logging: The logging handler. + /// - prompts: The prompts handler. + /// - tools: The list of supported tools (the request handlers will be created automatically). + /// - resources: The resources handler. + public init( + logging: SetLevelRequest.Handler? = nil, + prompts: ListedCapabilityHandler? = nil, + tools: [any CallableTool], + resources: ResourcesCapabilityHandler? = nil) + { + self.init( + logging: logging, + prompts: prompts, + tools: tools.asRequestHandler(listToolChanged: false), + resources: resources) + } +} diff --git a/MCPServer/Sources/DataChannel+stdio.swift b/MCPServer/Sources/DataChannel+stdio.swift new file mode 100644 index 0000000..f1be9f3 --- /dev/null +++ b/MCPServer/Sources/DataChannel+stdio.swift @@ -0,0 +1,19 @@ +import Foundation +import JSONRPC + +extension DataChannel { + + /// A `DataChannel` that uses the process stdio. + /// + /// Note: this is similar to `DataChannel.stdioPipe` but it ensures the data is flushed when written. + public static func stdio() -> DataChannel { + let writeHandler: DataChannel.WriteHandler = { data in + // TODO: upstream this to JSONRPC + var data = data + data.append(contentsOf: [UInt8(ascii: "\n")]) + FileHandle.standardOutput.write(data) + } + + return DataChannel(writeHandler: writeHandler, dataSequence: FileHandle.standardInput.dataStream) + } +} diff --git a/MCPServer/Sources/Exports.swift b/MCPServer/Sources/Exports.swift new file mode 100644 index 0000000..c107e00 --- /dev/null +++ b/MCPServer/Sources/Exports.swift @@ -0,0 +1,2 @@ + +@_exported import MCPInterface diff --git a/MCPServer/Sources/MCPServer.swift b/MCPServer/Sources/MCPServer.swift index 41b2ee1..e398aad 100644 --- a/MCPServer/Sources/MCPServer.swift +++ b/MCPServer/Sources/MCPServer.swift @@ -1,2 +1,310 @@ -//// -public actor MCPServer { } +import Combine +import JSONRPC +import MCPInterface + +// MARK: - MCPServer + +// TODO: Support cancelling request +// TODO: Support sending progress +// TODO: test MCPServer + +public actor MCPServer: MCPServerInterface { + + // MARK: Lifecycle + + /// Creates a MCP server and connects to the client through the provided transport. + /// The methods completes after connecting to the client. + public init( + info: Implementation, + capabilities: ServerCapabilityHandlers, + transport: Transport, + initializeRequestHook: @escaping InitializeRequestHook = { _ in }) + async throws { + connection = try MCPServerConnection( + info: info, + capabilities: capabilities.description, + transport: transport) + self.info = info + self.capabilities = capabilities + + clientInfo = try await Self.connectToClient( + connection: connection, + initializeRequestHook: initializeRequestHook, + capabilities: capabilities, + info: info) + + Task { + for await notification in await connection.notifications { + mcpLogger.log("Received notification: \(String(describing: notification), privacy: .public)") + } + } + await startListeningToNotifications() + await startListeningToRequests() + startPinging() + + Task { try await self.updateRoots() } + } + + // MARK: Public + + public private(set) var clientInfo: ClientInfo + + public var roots: ReadOnlyCurrentValueSubject, Never> { + get async { + await .init(_roots.compactMap { $0 }.removeDuplicates().eraseToAnyPublisher()) + } + } + + public func waitForDisconnection() async throws { + await withCheckedContinuation { (_ continuation: CheckedContinuation) in + // keep running forever + // TODO: handle disconnection from the transport. From ping? + didDisconnect = { + continuation.resume() + } + } + } + + public func getSampling(params: CreateSamplingMessageRequest.Params) async throws -> CreateSamplingMessageRequest.Result { + guard clientInfo.capabilities.sampling != nil else { + throw MCPError.notSupported + } + return try await connection.requestCreateMessage(params) + } + + public func log(params: LoggingMessageNotification.Params) async throws { + try await connection.log(params) + } + + public func notifyResourceUpdated(params: ResourceUpdatedNotification.Params) async throws { + guard capabilities.resources != nil else { + throw MCPError.notSupported + } + try await connection.notifyResourceUpdated(params) + } + + public func notifyResourceListChanged(params _: ResourceListChangedNotification.Params? = nil) async throws { + guard capabilities.resources != nil else { + throw MCPError.notSupported + } + try await connection.notifyResourceListChanged() + } + + public func notifyToolListChanged(params _: ToolListChangedNotification.Params? = nil) async throws { + guard capabilities.tools != nil else { + throw MCPError.notSupported + } + try await connection.notifyResourceListChanged() + } + + public func notifyPromptListChanged(params _: PromptListChangedNotification.Params? = nil) async throws { + guard capabilities.prompts != nil else { + throw MCPError.notSupported + } + try await connection.notifyResourceListChanged() + } + + public func update(tools: [any CallableTool]) async throws { + guard capabilities.tools?.info.listChanged == true else { + throw MCPError.notSupported + } + capabilities = .init( + logging: capabilities.logging, + prompts: capabilities.prompts, + tools: tools.asRequestHandler(listToolChanged: true), + resources: capabilities.resources) + + try await connection.notifyToolListChanged() + } + + // MARK: Private + + private let _roots = CurrentValueSubject?, Never>(nil) + + private let info: Implementation + + private var capabilities: ServerCapabilityHandlers + + /// Called once the client has disconnected. The closure should only be called once. + private var didDisconnect: () -> Void = { } + + private let connection: MCPServerConnection + + private static func connectToClient( + connection: MCPServerConnectionInterface, + initializeRequestHook: @escaping InitializeRequestHook, + capabilities: ServerCapabilityHandlers, + info: Implementation) + async throws -> ClientInfo + { + try await withCheckedThrowingContinuation { (_ continuation: CheckedContinuation) in + Task { + for await(request, completion) in await connection.requestsToHandle { + if case .initialize(let params) = request { + do { + try await initializeRequestHook(params) + completion(.success(InitializeRequest.Result( + protocolVersion: MCP.protocolVersion, + capabilities: capabilities.description, + serverInfo: info))) + + let clientInfo = ClientInfo( + info: params.clientInfo, + capabilities: params.capabilities) + continuation.resume(returning: clientInfo) + } catch { + completion(.failure(.init( + code: JRPCErrorCodes.internalError.rawValue, + message: error.localizedDescription))) + continuation.resume(throwing: error) + } + break + } else { + mcpLogger.error("Unexpected request received before initialization") + completion(.failure(.init( + code: JRPCErrorCodes.internalError.rawValue, + message: "Unexpected request received before initialization"))) + } + } + } + } + } + + private func updateRoots() async throws { + guard clientInfo.capabilities.roots != nil else { + // Tool calling not supported + _roots.send(.notSupported) + return + } + let roots = try await connection.listRoots() + _roots.send(.supported(roots.roots)) + } + + private func startPinging() { + // TODO + } + + private func handle( + request params: Params, + with handler: ((Params) async throws -> some Encodable)?, + _ requestName: String) + async -> AnyJRPCResponse + { + if let handler { + do { + return .success(try await handler(params)) + } catch { + return .failure(.init( + code: JRPCErrorCodes.internalError.rawValue, + message: error.localizedDescription)) + } + } else { + return .failure(.init( + code: JRPCErrorCodes.invalidRequest.rawValue, + message: "\(requestName) is not supported by this server")) + } + } + + private func startListeningToNotifications() async { + let notifications = await connection.notifications + Task { [weak self] in + for await notification in notifications { + switch notification { + case .cancelled: + // TODO: Handle this + break + + case .progress(let progressParams): + // TODO: Handle this + break + + case .initialized: + break + + case .rootsListChanged: + try await self?.updateRoots() + } + } + } + } + + private func startListeningToRequests() async { + let requests = await connection.requestsToHandle + Task { [weak self] in + for await(request, completion) in requests { + mcpLogger.log("Received request: \(String(describing: request), privacy: .public)") + + guard let self else { + completion(.failure(.init( + code: JRPCErrorCodes.internalError.rawValue, + message: "The server is gone"))) + return + } + + switch request { + case .initialize: + mcpLogger.error("initialization received twice") + completion(.failure(.init( + code: JRPCErrorCodes.internalError.rawValue, + message: "initialization received twice"))) + + case .listPrompts(let params): + await completion(handle(request: params, with: capabilities.prompts?.listHandler, "Listing prompts")) + + case .getPrompt(let params): + await completion(handle(request: params, with: capabilities.prompts?.handler, "Getting prompt")) + + case .listResources(let params): + await completion(handle(request: params, with: capabilities.resources?.listResource, "Listing resources")) + + case .readResource(let params): + await completion(handle(request: params, with: capabilities.resources?.readResource, "Reading resource")) + + case .subscribeToResource(let params): + await completion(handle(request: params, with: capabilities.resources?.subscribeToResource, "Subscribing to resource")) + + case .unsubscribeToResource(let params): + await completion(handle( + request: params, + with: capabilities.resources?.unsubscribeToResource, + "Unsubscribing to resource")) + + case .listResourceTemplates(let params): + await completion(handle( + request: params, + with: capabilities.resources?.listResourceTemplates, + "Listing resource templates")) + + case .listTools(let params): + await completion(handle(request: params, with: capabilities.tools?.listHandler, "Listing tools")) + + case .callTool(let params): + await completion(handle(request: params, with: capabilities.tools?.handler, "Tool calling")) + + case .complete(let params): + await completion(handle(request: params, with: capabilities.resources?.complete, "Resource completion")) + + case .setLogLevel(let params): + await completion(handle(request: params, with: capabilities.logging, "Setting log level")) + } + } + } + } + +} + +extension ServerCapabilityHandlers { + /// The MCP description of the supported server capabilities, inferred from which ones have handlers. + var description: ServerCapabilities { + ServerCapabilities( + experimental: nil, // TODO: support experimental requests + logging: logging != nil ? EmptyObject() : nil, + prompts: prompts?.info, + resources: resources.map { capability in + CapabilityInfo( + subscribe: capability.subscribeToResource != nil, + listChanged: capability.listChanged) + }, + tools: tools?.info) + } +} diff --git a/MCPServer/Sources/MCPServerConnection.swift b/MCPServer/Sources/MCPServerConnection.swift new file mode 100644 index 0000000..3190b97 --- /dev/null +++ b/MCPServer/Sources/MCPServerConnection.swift @@ -0,0 +1,83 @@ +import Foundation +import JSONRPC +import MCPInterface + +// MARK: - MCPClientConnection + +public actor MCPServerConnection: MCPServerConnectionInterface { + + // MARK: Lifecycle + + public init( + info: Implementation, + capabilities: ServerCapabilities, + transport: Transport) + throws + { + // Note: ideally we would subclass `MCPConnection`. However Swift actors don't support inheritance. + _connection = try MCPConnection(transport: transport) + self.info = info + self.capabilities = capabilities + } + + // MARK: Public + + public let info: Implementation + + public let capabilities: ServerCapabilities + + public var notifications: AsyncStream { _connection.notifications } + public var requestsToHandle: AsyncStream { _connection.requestsToHandle } + + public func ping() async throws { + // TODO: add timeout + _ = try await jrpcSession.send(PingRequest()) + } + + public func requestCreateMessage(_ params: CreateSamplingMessageRequest.Params) async throws -> CreateSamplingMessageRequest + .Result + { + try await jrpcSession.send(CreateSamplingMessageRequest(params: params)) + } + + public func listRoots() async throws -> ListRootsResult { + try await jrpcSession.send(ListRootsRequest()) + } + + public func notifyProgress(_ params: ProgressNotification.Params) async throws { + try await jrpcSession.send(ProgressNotification(params: params)) + } + + public func notifyResourceUpdated(_ params: ResourceUpdatedNotification.Params) async throws { + try await jrpcSession.send(ResourceUpdatedNotification(params: params)) + } + + public func notifyResourceListChanged(_ params: ResourceListChangedNotification.Params? = nil) async throws { + try await jrpcSession.send(ResourceListChangedNotification(params: params)) + } + + public func notifyToolListChanged(_ params: ToolListChangedNotification.Params? = nil) async throws { + try await jrpcSession.send(ToolListChangedNotification(params: params)) + } + + public func notifyPromptListChanged(_ params: PromptListChangedNotification.Params? = nil) async throws { + try await jrpcSession.send(PromptListChangedNotification(params: params)) + } + + public func notifyCancelled(_ params: CancelledNotification.Params) async throws { + try await jrpcSession.send(CancelledNotification(params: params)) + } + + public func log(_ params: LoggingMessageNotification.Params) async throws { + try await jrpcSession.send(LoggingMessageNotification(params: params)) + } + + // MARK: Private + + private let _connection: MCPConnection + + private var jrpcSession: JSONRPCSession { + _connection.jrpcSession + } + +} diff --git a/MCPServer/Sources/MCPServerConnectionInterface.swift b/MCPServer/Sources/MCPServerConnectionInterface.swift new file mode 100644 index 0000000..ad8f9c5 --- /dev/null +++ b/MCPServer/Sources/MCPServerConnectionInterface.swift @@ -0,0 +1,54 @@ +import JSONRPC +import MCPInterface + +// MARK: - MCPServerConnectionInterface + +/// The MCP JRPC Bridge is a stateless interface to the MCP server that provides a higher level Swift interface. +/// It does not implement any of the stateful behaviors of the MCP client, such as handling subscriptions, detecting connection health, +/// ensuring that the connection has been initialized before being used etc. +/// +/// For most use cases, `MCPServer` should be a preferred interface. +public protocol MCPServerConnectionInterface { + /// The notifications received by the client. + var notifications: AsyncStream { get async } + /// The requests received by the client that need to be responded to. + var requestsToHandle: AsyncStream { get async } + + /// Creates a new MCP JRPC Bridge. + /// This will create a new connection with the transport corresponding to the MCP client, but it will not handle the initialization request as specified by the MCP protocol. + /// The connection will be closed when this object is de-initialized. + init( + info: Implementation, + capabilities: ServerCapabilities, + transport: Transport) throws + + /// Send a ping to the client + func ping() async throws + + /// Request the client to create a message (LLM sampling) + func requestCreateMessage(_ params: CreateSamplingMessageRequest.Params) async throws -> CreateSamplingMessageRequest.Result + + /// Request the list of roots from the client + func listRoots() async throws -> ListRootsResult + + /// Send a progress notification to the client + func notifyProgress(_ params: ProgressNotification.Params) async throws + + /// Send a resource updated notification to the client + func notifyResourceUpdated(_ params: ResourceUpdatedNotification.Params) async throws + + /// Send a resource list changed notification to the client + func notifyResourceListChanged(_ params: ResourceListChangedNotification.Params?) async throws + + /// Send a tool list changed notification to the client + func notifyToolListChanged(_ params: ToolListChangedNotification.Params?) async throws + + /// Send a prompt list changed notification to the client + func notifyPromptListChanged(_ params: PromptListChangedNotification.Params?) async throws + + /// Send a logging message to the client + func log(_ params: LoggingMessageNotification.Params) async throws + + /// Send a cancellation notification to the client + func notifyCancelled(_ params: CancelledNotification.Params) async throws +} diff --git a/MCPServer/Sources/MCPServerInterface.swift b/MCPServer/Sources/MCPServerInterface.swift new file mode 100644 index 0000000..f9860e1 --- /dev/null +++ b/MCPServer/Sources/MCPServerInterface.swift @@ -0,0 +1,112 @@ +import Foundation +import JSONRPC +import MCPInterface +import MemberwiseInit + +// MARK: - MCPServerInterface + +public protocol MCPServerInterface { + var clientInfo: ClientInfo { get async } + + /// The client's roots. This will be update if the client changes them. + var roots: ReadOnlyCurrentValueSubject, Never> { get async } + /// A method that completes once the client has disconnected. + func waitForDisconnection() async throws + /// Ask the client to sample an LLM for the given parameters. + func getSampling(params: CreateSamplingMessageRequest.Params) async throws -> CreateSamplingMessageRequest.Result + /// Ask the client to log an event. + func log(params: LoggingMessageNotification.Params) async throws + /// Update the list of available tools, and notify the client that the list has changed. + func update(tools: [any CallableTool]) async throws + /// Notify the client that a specific resource has been updated. + func notifyResourceUpdated(params: ResourceUpdatedNotification.Params) async throws + /// Notify the client that the list of available resources has been updated. + func notifyResourceListChanged(params: ResourceListChangedNotification.Params?) async throws + /// Notify the client that the list of available tools has been updated. + func notifyToolListChanged(params: ToolListChangedNotification.Params?) async throws + /// Notify the client that the list of available prompts has been updated. + func notifyPromptListChanged(params: PromptListChangedNotification.Params?) async throws +} + +// MARK: - ServerCapabilityHandlers + +/// Capabilities that the server supports. +/// Each supported capability provides the handlers required to respond to the relevant requests from the client. +/// +/// Note: This is similar to `ServerCapabilities`, with the addition of the handler function. +@MemberwiseInit(.public, _optionalsDefaultNil: true) +public struct ServerCapabilityHandlers { + /// Present if the server supports sending log messages to the client. + public let logging: SetLevelRequest.Handler? + /// Present if the server offers any prompt templates. + public let prompts: ListedCapabilityHandler? + /// Present if the server offers any tools to call. + public let tools: ListedCapabilityHandler? + /// Present if the server offers any resources to read. + public let resources: ResourcesCapabilityHandler? +} + +// MARK: - ListedCapabilityHandler + +/// A capability that has a list of options (ex: prompts, tools, resources) +@MemberwiseInit(.public, _optionalsDefaultNil: true) +public struct ListedCapabilityHandler { + public let info: Info + public let handler: Handler + public let listHandler: ListHandler +} + +// MARK: - ResourcesCapabilityHandler + +/// All the handler functions required to support the `resources` capability. +public struct ResourcesCapabilityHandler { + + // MARK: Lifecycle + + public init( + listChanged: Bool = false, + readResource: @escaping ReadResourceRequest.Handler, + listResource: @escaping ListResourcesRequest.Handler, + listResourceTemplates: @escaping ListResourceTemplatesRequest.Handler, + subscribeToResource: SubscribeRequest.Handler? = nil, + unsubscribeToResource: UnsubscribeRequest.Handler? = nil, + complete: CompleteRequest.Handler? = nil) + { + self.listChanged = listChanged + self.readResource = readResource + self.listResource = listResource + self.listResourceTemplates = listResourceTemplates + self.subscribeToResource = subscribeToResource + self.unsubscribeToResource = unsubscribeToResource + self.complete = complete + } + + // MARK: Public + + /// Whether this server supports notifications for changes to the resource list. + public let listChanged: Bool + public let readResource: ReadResourceRequest.Handler + public let listResource: ListResourcesRequest.Handler + public let listResourceTemplates: ListResourceTemplatesRequest.Handler + public let subscribeToResource: SubscribeRequest.Handler? + public let unsubscribeToResource: UnsubscribeRequest.Handler? + public let complete: CompleteRequest.Handler? + +} + +public typealias InitializeRequestHook = (InitializeRequest.Params) async throws -> Void + +// MARK: - ClientInfo + +/// Information about the client the server is connected to. +public struct ClientInfo { + public let info: Implementation + public let capabilities: ClientCapabilities +} + +// MARK: - MCPServerError + +public enum MCPServerError: Error { + /// An error that occurred while calling a tool. + case toolCallError(_ errors: [Error]) +} diff --git a/MCPShared/Sources/Interfaces.swift b/MCPShared/Sources/Interfaces.swift deleted file mode 100644 index 166e0c8..0000000 --- a/MCPShared/Sources/Interfaces.swift +++ /dev/null @@ -1,19 +0,0 @@ -import JSONRPC -import MemberwiseInit - -public typealias Transport = DataChannel - -// MARK: - CapabilityHandler - -/// Describes a capability of a client/server (see `ClientCapabilities` and `ServerCapabilities`), as well as how it is handled. -@MemberwiseInit(.public, _optionalsDefaultNil: true) -public struct CapabilityHandler { - public let info: Info - public let handler: Handler -} - -extension CapabilityHandler where Info == EmptyObject { - public init(handler: Handler) { - self.init(info: .init(), handler: handler) - } -} diff --git a/MCPShared/Sources/mcp_interfaces/JSON.swift b/MCPShared/Sources/mcp_interfaces/JSON.swift deleted file mode 100644 index 09411a3..0000000 --- a/MCPShared/Sources/mcp_interfaces/JSON.swift +++ /dev/null @@ -1,20 +0,0 @@ - -// MARK: - JSON - -public enum JSON: Codable, Equatable, Sendable { - case object(_ value: [String: JSON.Value]) - case array(_ value: [JSON.Value]) - - // MARK: - JSONValue - - // TODO: look at instead aliasing JSONRPC.JSONValue which seems to have an error in its encoding for arrays/objects - public enum Value: Codable, Equatable, Sendable { - case string(_ value: String) - case object(_ value: [String: JSON.Value]) - case array(_ value: [JSON.Value]) - case bool(_ value: Bool) - case number(_ value: Double) - case null - } - -} diff --git a/MCPShared/Tests/interface/AnyParamsWithProgressTokenTests.swift b/MCPShared/Tests/interface/AnyParamsWithProgressTokenTests.swift deleted file mode 100644 index 03ade44..0000000 --- a/MCPShared/Tests/interface/AnyParamsWithProgressTokenTests.swift +++ /dev/null @@ -1,75 +0,0 @@ - -import Foundation -import MCPShared -import Testing - -extension MCPInterfaceTests { - enum AnyParamsWithProgressTokenTest { - struct Serialization { - - // MARK: Internal - - @Test - func encodeWithNoValues() throws { - try testEncoding(of: AnyParamsWithProgressToken(), """ - {} - """) - } - - @Test - func encodeWithStringProgressToken() throws { - try testEncoding(of: AnyParamsWithProgressToken(_meta: .init(progressToken: .string("123abc"))), """ - { - "_meta" : { - "progressToken" : "123abc" - } - } - """) - } - - @Test - func encodeWithNumberProgressToken() throws { - try testEncoding(of: AnyParamsWithProgressToken(_meta: .init(progressToken: .number(123456))), """ - { - "_meta" : { - "progressToken" : 123456 - } - } - """) - } - - @Test - func encodeWithNonMetaValues() throws { - try testEncoding(of: AnyParamsWithProgressToken(value: .object([ - "key": .string("value"), - ])), """ - { - "key" : "value" - } - """) - } - - @Test - func encodeWithMetaAndOtherParameters() throws { - try testEncoding(of: AnyParamsWithProgressToken( - _meta: .init(progressToken: .string("123abc")), - value: .object([ - "key": .string("value"), - ])), """ - { - "_meta" : { - "progressToken" : "123abc" - }, - "key" : "value" - } - """) - } - - // MARK: Private - - private func testEncoding(of value: AnyParamsWithProgressToken, _ json: String) throws { - try testEncodingOf(value, json) - } - } - } -} diff --git a/MCPShared/Tests/interface/PromptOrResourceReferenceTests.swift b/MCPShared/Tests/interface/PromptOrResourceReferenceTests.swift deleted file mode 100644 index 7177f3b..0000000 --- a/MCPShared/Tests/interface/PromptOrResourceReferenceTests.swift +++ /dev/null @@ -1,38 +0,0 @@ - -import Foundation -import MCPShared -import Testing -extension MCPInterfaceTests { - enum PromptOrResourceReferenceTest { - struct Serialization { - - // MARK: Internal - - @Test - func encodePrompt() throws { - try testEncoding(of: .prompt(.init(name: "code_review")), """ - { - "name" : "code_review", - "type" : "ref/prompt" - } - """) - } - - @Test - func encodeResource() throws { - try testEncoding(of: .resource(.init(uri: "file:///foo_path")), """ - { - "type" : "ref/resource", - "uri" : "file:///foo_path" - } - """) - } - - // MARK: Private - - private func testEncoding(of value: PromptOrResourceReference, _ json: String) throws { - try testEncodingOf(value, json) - } - } - } -} diff --git a/MCPShared/Tests/interface/TextContentOrImageContentOrEmbeddedResourceTests.swift b/MCPShared/Tests/interface/TextContentOrImageContentOrEmbeddedResourceTests.swift deleted file mode 100644 index f83273d..0000000 --- a/MCPShared/Tests/interface/TextContentOrImageContentOrEmbeddedResourceTests.swift +++ /dev/null @@ -1,62 +0,0 @@ - -import Foundation -import MCPShared -import Testing - -extension MCPInterfaceTests { - enum TextContentOrImageContentOrEmbeddedResourceTest { - - struct Deserialization { - - // MARK: Internal - - @Test - func decodeText() throws { - let value = try decode(""" - { - "type": "text", - "text": "Tool result text" - } - """) - - #expect(value.text?.text == "Tool result text") - } - - @Test - func decodeImage() throws { - let value = try decode(""" - { - "type": "image", - "data": "base64-encoded-data", - "mimeType": "image/png" - } - """) - - #expect(value.image?.data == "base64-encoded-data") - } - - @Test - func decodeResource() throws { - let value = try decode(""" - { - "type": "resource", - "resource": { - "uri": "resource://example", - "mimeType": "text/plain", - "text": "Resource content" - } - } - """) - - #expect(value.embeddedResource?.resource.text?.uri == "resource://example") - } - - // MARK: Private - - private func decode(_ jsonString: String) throws -> TextContentOrImageContentOrEmbeddedResource { - let data = jsonString.data(using: .utf8)! - return try JSONDecoder().decode(TextContentOrImageContentOrEmbeddedResource.self, from: data) - } - } - } -} diff --git a/MCPShared/Tests/interface/TextOrBlobResourceContentsTests.swift b/MCPShared/Tests/interface/TextOrBlobResourceContentsTests.swift deleted file mode 100644 index 1e68a44..0000000 --- a/MCPShared/Tests/interface/TextOrBlobResourceContentsTests.swift +++ /dev/null @@ -1,46 +0,0 @@ - -import Foundation -import MCPShared -import Testing - -extension MCPInterfaceTests { - enum TextOrBlobResourceContentsTest { - struct Deserialization { - - // MARK: Internal - - @Test - func decodeText() throws { - let value = try decode(""" - { - "uri": "file:///example.txt", - "mimeType": "text/plain", - "text": "Resource content" - } - """) - - #expect(value.text?.text == "Resource content") - } - - @Test - func decodeImage() throws { - let value = try decode(""" - { - "uri": "file:///example.png", - "mimeType": "image/png", - "blob": "base64-encoded-data" - } - """) - - #expect(value.blob?.blob == "base64-encoded-data") - } - - // MARK: Private - - private func decode(_ jsonString: String) throws -> TextOrBlobResourceContents { - let data = jsonString.data(using: .utf8)! - return try JSONDecoder().decode(TextOrBlobResourceContents.self, from: data) - } - } - } -} diff --git a/MCPSharedTesting/Tests/CallToolTests.swift b/MCPSharedTesting/Tests/CallToolTests.swift new file mode 100644 index 0000000..5c5ee08 --- /dev/null +++ b/MCPSharedTesting/Tests/CallToolTests.swift @@ -0,0 +1,188 @@ + +import JSONRPC +import MCPInterface +import Testing + +// MARK: - MCPConnectionTestSuite.CallToolTests + +extension MCPConnectionTestSuite { + final class CallToolTests: MCPConnectionTest { + + // MARK: Internal + + @Test("call tool") + func test_callTool() async throws { + let weathers = try await assert(executing: { + try await self.clientConnection.call( + toolName: self.tool.name, + arguments: .object([ + "location": .string("New York"), + ]), + progressToken: .string("toolCallId")) + .content + .map { $0.text } + }, triggers: [ + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "_meta" : { + "progressToken" : "toolCallId" + }, + "name": "get_weather", + "arguments": { + "location": "New York" + } + } + } + """), + .serverResponding { request in + guard case .callTool(let callToolRequest) = request else { + throw Issue.record("Unexpected request: \(request)") + } + #expect(callToolRequest.name == self.tool.name) + + return .success(CallToolResult( + content: [ + .text(.init( + text: "Current weather in New York:\nTemperature: 72°F\nConditions: Partly cloudy")), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "content": [{ + "type": "text", + "text": "Current weather in New York:\\nTemperature: 72°F\\nConditions: Partly cloudy" + }] + } + } + """), + ]) + + #expect(weathers.map { $0?.text } == ["Current weather in New York:\nTemperature: 72°F\nConditions: Partly cloudy"]) + } + + @Test("protocol error") + func test_protocolError() async throws { + await assert( + executing: { + _ = try await self.clientConnection.call(toolName: self.tool.name, arguments: .object([ + "location": .string("New York"), + ])) + }, + triggers: [ + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "get_weather", + "arguments": { + "location": "New York" + } + } + } + """), + .serverResponding { _ in + .failure(.init(code: -32602, message: "Unknown tool: invalid_tool_name")) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "error": { + "code": -32602, + "message": "Unknown tool: invalid_tool_name" + } + } + """), + ]) { error in + guard let error = error as? JSONRPCResponseError else { + Issue.record("Unexpected error type: \(error)") + return + } + + #expect(error.code == -32602) + #expect(error.message == "Unknown tool: invalid_tool_name") + #expect(error.data == nil) + } + } + + @Test("tool call error") + func test_toolCallError() async throws { + let response = try await assert( + executing: { + try await self.clientConnection.call( + toolName: self.tool.name, + arguments: .object([ + "location": .string("New York"), + ])) + }, + triggers: [ + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "get_weather", + "arguments": { + "location": "New York" + } + } + } + """), + .serverResponding { _ in + .success(CallToolResult( + content: [ + .text(.init( + text: "Failed to fetch weather data: API rate limit exceeded")), + ], + isError: true)) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "content": [{ + "type": "text", + "text": "Failed to fetch weather data: API rate limit exceeded" + }], + "isError": true + } + } + """), + ]) + #expect(response.isError == true) + #expect(response.content.map { $0.text?.text } == ["Failed to fetch weather data: API rate limit exceeded"]) + } + + // MARK: Private + + private let tool = Tool( + name: "get_weather", + description: "Get current weather information for a location", + inputSchema: .array([])) + + } +} + +// MARK: - ToolArguments + +private struct ToolArguments: Encodable { + let location: String +} + +// MARK: - ToolResponse + +private struct ToolResponse: Decodable { + let type: String + let text: String +} diff --git a/MCPSharedTesting/Tests/CompletionTests.swift b/MCPSharedTesting/Tests/CompletionTests.swift new file mode 100644 index 0000000..8113df9 --- /dev/null +++ b/MCPSharedTesting/Tests/CompletionTests.swift @@ -0,0 +1,60 @@ + +import JSONRPC +import MCPInterface +import Testing + +extension MCPConnectionTestSuite { + final class CompletionTests: MCPConnectionTest { + @Test("request completion") + func test_requestCompletion() async throws { + let resources = try await assert( + executing: { + try await self.clientConnection.requestCompletion(CompleteRequest.Params( + ref: .prompt(PromptReference(name: "code_review")), + argument: .init(name: "language", value: "py"))) + }, + triggers: [ + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "method": "completion/complete", + "params": { + "ref": { + "type": "ref/prompt", + "name": "code_review" + }, + "argument": { + "name": "language", + "value": "py" + } + } + } + """), + .serverResponding { request in + guard case .complete(let params) = request else { + throw Issue.record("Unexpected request: \(request)") + } + #expect(params.ref.prompt?.name == "code_review") + + return .success(CompleteResult(completion: .init(values: ["python", "pytorch", "pyside"], total: 10, hasMore: true))) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "completion": { + "values": ["python", "pytorch", "pyside"], + "total": 10, + "hasMore": true + } + } + } + """), + ]) + #expect(resources.completion.values == ["python", "pytorch", "pyside"]) + #expect(resources.completion.hasMore == true) + } + } +} diff --git a/MCPSharedTesting/Tests/CreateSamplingMessageTests.swift b/MCPSharedTesting/Tests/CreateSamplingMessageTests.swift new file mode 100644 index 0000000..06a4375 --- /dev/null +++ b/MCPSharedTesting/Tests/CreateSamplingMessageTests.swift @@ -0,0 +1,93 @@ + +import JSONRPC +import MCPInterface +import Testing + +extension MCPConnectionTestSuite { + final class CreateSamplingMessageTests: MCPConnectionTest { + + @Test("create sampling message") + func test_createSamplingMessage() async throws { + let sampledMessage = try await assert( + executing: { + try await self.serverConnection.requestCreateMessage( + .init( + messages: [ + .init( + role: .user, + content: .text(.init(text: "What is the capital of France?"))), + ], + modelPreferences: .init( + hints: [ + .init(name: "claude-3-sonnet"), + ], + speedPriority: 0.5, + intelligencePriority: 0.8), + systemPrompt: "You are a helpful assistant.", + maxTokens: 100)) + + }, + triggers: [ + .serverSendsJrpc( + """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "sampling/createMessage", + "params": { + "messages": [ + { + "role": "user", + "content": { + "type": "text", + "text": "What is the capital of France?" + } + } + ], + "modelPreferences": { + "hints": [ + { + "name": "claude-3-sonnet" + } + ], + "intelligencePriority": 0.8, + "speedPriority": 0.5 + }, + "systemPrompt": "You are a helpful assistant.", + "maxTokens": 100 + } + } + """), + .clientResponding { request in + guard case .createMessage(let params) = request else { + throw Issue.record("Unexpected request: \(request)") + } + #expect(params.modelPreferences?.hints?.first?.name == "claude-3-sonnet") + + return .success(CreateMessageResult( + role: .assistant, + content: .text(.init(text: "The capital of France is Paris.")), + model: "claude-3-sonnet-20240307", + stopReason: "endTurn")) + }, + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "role": "assistant", + "content": { + "type": "text", + "text": "The capital of France is Paris." + }, + "model": "claude-3-sonnet-20240307", + "stopReason": "endTurn" + } + } + """), + ]) + #expect(sampledMessage.content.text?.text == "The capital of France is Paris.") + } + + } +} diff --git a/MCPSharedTesting/Tests/GetPromptTests.swift b/MCPSharedTesting/Tests/GetPromptTests.swift new file mode 100644 index 0000000..5cb747f --- /dev/null +++ b/MCPSharedTesting/Tests/GetPromptTests.swift @@ -0,0 +1,204 @@ + +import JSONRPC +import MCPInterface +import Testing + +extension MCPConnectionTestSuite { + final class GetPromptTests: MCPConnectionTest { + + @Test("get one prompt") + func test_getOnePrompt() async throws { + let prompts = try await assert( + executing: { + try await self.clientConnection.getPrompt(.init(name: "code_review", arguments: .object([ + "code": .string("def hello():\n print('world')"), + ]))) + }, + triggers: [ + .clientSendsJrpc( + """ + { + "jsonrpc": "2.0", + "id": 1, + "method": "prompts/get", + "params": { + "name": "code_review", + "arguments": { + "code": "def hello():\\n print('world')" + } + } + } + """), + .serverResponding { request in + guard case .getPrompt(let params) = request else { + throw Issue.record("Unexpected request: \(request)") + } + #expect(params.name == "code_review") + + return .success(GetPromptResult( + description: "Code review prompt", + messages: [ + .init( + role: .user, + content: .text(.init(text: "Please review this Python code:\ndef hello():\n print('world')"))), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "description": "Code review prompt", + "messages": [ + { + "role": "user", + "content": { + "type": "text", + "text": "Please review this Python code:\\ndef hello():\\n print('world')" + } + } + ] + } + } + """), + ]) + #expect( + prompts.messages + .map { $0.content.text?.text } == ["Please review this Python code:\ndef hello():\n print('world')"]) + } + + @Test("get prompts of different types") + func test_getPromptsOfDifferentTypes() async throws { + let prompts = try await assert( + executing: { + try await self.clientConnection.getPrompt(.init(name: "code_review", arguments: .object([ + "code": .string("def hello():\n print('world')"), + ]))) + }, + triggers: [ + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "method": "prompts/get", + "params": { + "name": "code_review", + "arguments": { + "code": "def hello():\\n print('world')" + } + } + } + """), + .serverResponding { _ in + .success(GetPromptResult( + description: "Code review prompt", + messages: [ + .init( + role: .user, + content: .text(.init(text: "Please review this Python code:\ndef hello():\n print('world')"))), + .init( + role: .user, + content: .image(.init(data: "base64-encoded-image-data", mimeType: "image/png"))), + .init( + role: .user, + content: .embeddedResource(.init(resource: .text(.init( + uri: "resource://example", + mimeType: "text/plain", + text: "Resource content"))))), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "description": "Code review prompt", + "messages": [ + { + "role": "user", + "content": { + "type": "text", + "text": "Please review this Python code:\\ndef hello():\\n print('world')" + } + }, + { + "role": "user", + "content": { + "type": "image", + "data": "base64-encoded-image-data", + "mimeType": "image/png" + } + }, + { + "role": "user", + "content": { + "type": "resource", + "resource": { + "uri": "resource://example", + "mimeType": "text/plain", + "text": "Resource content" + } + } + } + ] + } + } + """), + ]) + #expect( + prompts.messages.map { $0.content.text?.text } == + ["Please review this Python code:\ndef hello():\n print('world')", nil, nil]) + #expect(prompts.messages.map { $0.content.image?.data } == [nil, "base64-encoded-image-data", nil]) + #expect(prompts.messages.map { $0.content.embeddedResource?.resource.text?.text } == [nil, nil, "Resource content"]) + } + + @Test("error when getting prompt") + func test_errorWhenGettingPrompt() async throws { + await assert( + executing: { try await self.clientConnection.getPrompt(.init(name: "non_existent_code_review")) }, + triggers: [ + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "method": "prompts/get", + "params": { + "name": "non_existent_code_review" + } + } + """), + .serverResponding { _ in + .failure(.init(code: -32002, message: "Prompt not found", data: [ + "name": "non_existent_code_review", + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "error": { + "code": -32002, + "message": "Prompt not found", + "data": { + "name": "non_existent_code_review" + } + } + } + """), + ], + andFailsWith: { error in + guard let error = error as? JSONRPCResponseError else { + Issue.record("Unexpected error type: \(error)") + return + } + + #expect(error.code == -32002) + #expect(error.message == "Prompt not found") + #expect(error.data == .hash([ + "name": .string("non_existent_code_review"), + ])) + }) + } + + } +} diff --git a/MCPSharedTesting/Tests/InitializationTests.swift b/MCPSharedTesting/Tests/InitializationTests.swift new file mode 100644 index 0000000..1e03d85 --- /dev/null +++ b/MCPSharedTesting/Tests/InitializationTests.swift @@ -0,0 +1,197 @@ + +import JSONRPC +import MCPInterface +import MCPTestingUtils +import SwiftTestingUtils +import Testing +@testable import MCPClient +@testable import MCPServer + +extension MCPConnectionTestSuite { + final class InitializationTests: MCPConnectionTest { + + @Test("initialize connection") + func test_initializeConnection() async throws { + let initializationResult = try await assert( + executing: { + try await self.clientConnection.initialize() + }, + triggers: [ + .clientSendsJrpc(""" + { + "id" : 1, + "jsonrpc" : "2.0", + "method" : "initialize", + "params" : { + "capabilities" : { + "roots" : { + "listChanged" : true + }, + "sampling" : {} + }, + "clientInfo" : { + "name" : "TestClient", + "version" : "1.0.0" + }, + "protocolVersion" : "\(MCP.protocolVersion)" + } + } + """), + .serverResponding { request in + guard case .initialize(let params) = request else { + throw Issue.record("Unexpected client request: \(request)") + } + #expect(params.capabilities.roots?.listChanged == true) + + return .success(InitializeResult( + protocolVersion: "2024-11-05", + capabilities: ServerCapabilities( + logging: .init(), + prompts: .init(listChanged: true), + resources: .init(subscribe: true, listChanged: true), + tools: .init(listChanged: true)), + serverInfo: .init( + name: "ExampleServer", + version: "1.0.0"))) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "protocolVersion": "2024-11-05", + "capabilities": { + "logging": {}, + "prompts": { + "listChanged": true + }, + "resources": { + "subscribe": true, + "listChanged": true + }, + "tools": { + "listChanged": true + } + }, + "serverInfo": { + "name": "ExampleServer", + "version": "1.0.0" + } + } + } + """), + ]) + #expect(initializationResult.serverInfo.name == "ExampleServer") + } + + @Test("initialize with error") + func test_initializeWithError() async throws { + await assert( + executing: { _ = try await self.clientConnection.initialize() }, + triggers: [ + .clientSendsJrpc(""" + { + "id" : 1, + "jsonrpc" : "2.0", + "method" : "initialize", + "params" : { + "capabilities" : { + "roots" : { + "listChanged" : true + }, + "sampling" : {} + }, + "clientInfo" : { + "name" : "TestClient", + "version" : "1.0.0" + }, + "protocolVersion" : "\(MCP.protocolVersion)" + } + } + """), + .serverResponding { request in + guard case .initialize(let params) = request else { + throw Issue.record("Unexpected client request: \(request)") + } + #expect(params.capabilities.roots?.listChanged == true) + + return .failure(.init( + code: -32602, + message: "Unsupported protocol version", + data: .hash([ + "supported": ["2024-11-05"], + "requested": "1.0.0", + ]))) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "error": { + "code": -32602, + "message": "Unsupported protocol version", + "data": { + "supported": ["2024-11-05"], + "requested": "1.0.0" + } + } + } + """), + ]) { error in + guard let responseError = try? #require(error as? JSONRPCResponseError) else { + Issue.record("unexpected error type \(error)") + return + } + + #expect(responseError.code == -32602) + #expect(responseError.message == "Unsupported protocol version") + #expect(responseError.data == [ + "supported": ["2024-11-05"], + "requested": "1.0.0", + ]) + } + } + + @Test("initialization acknowledgement") + func test_initializationAcknowledgement() async throws { + try await assert(executing: { + try await self.clientConnection.acknowledgeInitialization() + }, triggers: [ + .clientSendsJrpc( + """ + { + "jsonrpc" : "2.0", + "method" : "notifications/initialized", + "params" : null + } + """), + .serverReceiving { notification in + guard case .initialized(let params) = notification else { + throw Issue.record("Unexpected client notification: \(notification)") + } + #expect(params.value == nil) + }, + ]) + } + + @Test("deinitialization") + func test_deinitializationReleasesReferencedObjects() async throws { + // initialize the MCP connection. This will create a JRPC session. + try await test_initializeConnection() + + // Get pointers to values that we want to see dereferenced when MCPClientConnection is dereferenced + weak var weakTransport = clientTransport + #expect(weakTransport != nil) + + // Replace the values referenced by this test class. + clientTransport = MockTransport() + clientConnection = try await MCPClientConnection( + info: clientConnection.info, + capabilities: clientCapabilities, + transport: clientTransport.dataChannel) + + // Verifies that the referenced objects are released. + #expect(weakTransport == nil) + } + } +} diff --git a/MCPSharedTesting/Tests/ListPromptTests.swift b/MCPSharedTesting/Tests/ListPromptTests.swift new file mode 100644 index 0000000..1b103bd --- /dev/null +++ b/MCPSharedTesting/Tests/ListPromptTests.swift @@ -0,0 +1,192 @@ + +import JSONRPC +import MCPInterface +import SwiftTestingUtils +import Testing + +extension MCPConnectionTestSuite { + final class ListPromptsTests: MCPConnectionTest { + + @Test("list prompts") + func test_listPrompts() async throws { + let prompts = try await assert( + executing: { + try await self.clientConnection.listPrompts() + }, + triggers: [ + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "method": "prompts/list", + "params": {} + } + """), + .serverResponding { request in + guard case .listPrompts = request else { + throw Issue.record("Unexpected request: \(request)") + } + return .success(ListPromptsResult(prompts: [ + .init( + name: "code_review", + description: "Asks the LLM to analyze code quality and suggest improvements", + arguments: [ + .init( + name: "code", + description: "The code to review", + required: true), + ]), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "prompts": [ + { + "name": "code_review", + "description": "Asks the LLM to analyze code quality and suggest improvements", + "arguments": [ + { + "name": "code", + "description": "The code to review", + "required": true + } + ] + } + ] + } + } + """), + ]) + #expect(prompts.map { $0.name } == ["code_review"]) + } + + @Test("list prompts with pagination") + func test_listPrompts_withPagination() async throws { + let prompts = try await assert( + executing: { try await self.clientConnection.listPrompts() }, + triggers: [ + .clientSendsJrpc(""" + { + "id" : 1, + "jsonrpc" : "2.0", + "method" : "prompts/list", + "params" : {} + } + """), + .serverResponding { request in + guard case .listPrompts = request else { + throw Issue.record("Unexpected request: \(request)") + } + return .success(ListPromptsResult( + nextCursor: "next-page-cursor", + prompts: [ + .init( + name: "code_review", + description: "Asks the LLM to analyze code quality and suggest improvements", + arguments: [ + .init( + name: "code", + description: "The code to review", + required: true), + ]), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "prompts": [ + { + "name": "code_review", + "description": "Asks the LLM to analyze code quality and suggest improvements", + "arguments": [ + { + "name": "code", + "description": "The code to review", + "required": true + } + ] + } + ], + "nextCursor": "next-page-cursor" + } + } + """), + .clientSendsJrpc(""" + { + "id" : 2, + "jsonrpc" : "2.0", + "method" : "prompts/list", + "params" : { + "cursor": "next-page-cursor" + } + } + """), + .serverResponding { request in + guard case .listPrompts = request else { + throw Issue.record("Unexpected request: \(request)") + } + return .success(ListPromptsResult(prompts: [ + .init( + name: "test_code", + description: "Asks the LLM to write a unit test for the code", + arguments: [ + .init( + name: "code", + description: "The code to test", + required: true), + ]), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 2, + "result": { + "prompts": [ + { + "name": "test_code", + "description": "Asks the LLM to write a unit test for the code", + "arguments": [ + { + "name": "code", + "description": "The code to test", + "required": true + } + ] + } + ] + } + } + """), + ]) + #expect(prompts.map { $0.name } == ["code_review", "test_code"]) + } + + @Test("receiving prompts list changed notification") + func test_receivingPromptsListChangedNotification() async throws { + try await assert( + executing: { + try await self.serverConnection.notifyPromptListChanged() + }, + triggers: [ + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "method": "notifications/prompts/list_changed", + "params": null + } + """), + .clientReceiving { notification in + guard case .promptListChanged = notification else { + throw Issue.record("Unexpected notification: \(notification)") + } + }, + ]) + } + } +} diff --git a/MCPSharedTesting/Tests/ListResourceTemplates.swift b/MCPSharedTesting/Tests/ListResourceTemplates.swift new file mode 100644 index 0000000..6d6b911 --- /dev/null +++ b/MCPSharedTesting/Tests/ListResourceTemplates.swift @@ -0,0 +1,149 @@ + +import JSONRPC +import MCPInterface +import Testing + +extension MCPConnectionTestSuite { + final class ListResourceTemplatesTests: MCPConnectionTest { + + @Test("list resource templates") + func test_listResourceTemplates() async throws { + let resources = try await assert( + executing: { + try await self.clientConnection.listResourceTemplates() + }, + triggers: [ + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "method": "resources/templates/list", + "params": {} + } + """), + .serverResponding { request in + guard case .listResourceTemplates = request else { + throw Issue.record("Unexpected request: \(request)") + } + return .success(ListResourceTemplatesResult( + resourceTemplates: [ + .init( + uriTemplate: "file:///{path}", + name: "Project Files", + description: "Access files in the project directory", + mimeType: "application/octet-stream"), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "resourceTemplates": [ + { + "uriTemplate": "file:///{path}", + "name": "Project Files", + "description": "Access files in the project directory", + "mimeType": "application/octet-stream" + } + ] + } + } + """), + ]) + #expect(resources.map { $0.uriTemplate } == ["file:///{path}"]) + } + + @Test("list resource templates with pagination") + func test_listResourceTemplates_withPagination() async throws { + let resources = try await assert( + executing: { try await self.clientConnection.listResourceTemplates() }, + triggers: [ + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "method": "resources/templates/list", + "params": {} + } + """), + .serverResponding { request in + guard case .listResourceTemplates(let params) = request else { + throw Issue.record("Unexpected request: \(request)") + } + #expect(params.cursor == nil) + + return .success(ListResourceTemplatesResult( + nextCursor: "next-page-cursor", + resourceTemplates: [ + .init( + uriTemplate: "file:///{path}", + name: "Project Files", + description: "Access files in the project directory", + mimeType: "application/octet-stream"), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "resourceTemplates": [ + { + "uriTemplate": "file:///{path}", + "name": "Project Files", + "description": "Access files in the project directory", + "mimeType": "application/octet-stream" + } + ], + "nextCursor": "next-page-cursor" + } + } + """), + .clientSendsJrpc(""" + { + "id" : 2, + "jsonrpc" : "2.0", + "method" : "resources/templates/list", + "params" : { + "cursor": "next-page-cursor" + } + } + """), + .serverResponding { request in + guard case .listResourceTemplates(let params) = request else { + throw Issue.record("Unexpected request: \(request)") + } + #expect(params.cursor == "next-page-cursor") + + return .success(ListResourceTemplatesResult( + resourceTemplates: [ + .init( + uriTemplate: "images:///{path}", + name: "Project Images", + description: "Access images in the project directory", + mimeType: "image/jpeg"), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 2, + "result": { + "resourceTemplates": [ + { + "uriTemplate": "images:///{path}", + "name": "Project Images", + "description": "Access images in the project directory", + "mimeType": "image/jpeg" + } + ] + } + } + """), + ]) + #expect(resources.map { $0.uriTemplate } == ["file:///{path}", "images:///{path}"]) + } + + } +} diff --git a/MCPSharedTesting/Tests/ListResourcesTest.swift b/MCPSharedTesting/Tests/ListResourcesTest.swift new file mode 100644 index 0000000..31a43d4 --- /dev/null +++ b/MCPSharedTesting/Tests/ListResourcesTest.swift @@ -0,0 +1,169 @@ + +import JSONRPC +import MCPInterface +import SwiftTestingUtils +import Testing + +extension MCPConnectionTestSuite { + final class ListResourcesTests: MCPConnectionTest { + + @Test("list resources") + func test_listResources() async throws { + let resources = try await assert( + executing: { + try await self.clientConnection.listResources() + }, + triggers: [ + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "method": "resources/list", + "params": {} + } + """), + .serverResponding { request in + guard case .listResources = request else { + throw Issue.record("Unexpected request: \(request)") + } + return .success(ListResourcesResult( + resources: [ + Resource( + uri: "file:///project/src/main.rs", + name: "main.rs", + description: "Primary application entry point", + mimeType: "text/x-rust"), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "resources": [ + { + "uri": "file:///project/src/main.rs", + "name": "main.rs", + "description": "Primary application entry point", + "mimeType": "text/x-rust" + } + ] + } + } + """), + ]) + #expect(resources.map { $0.name } == ["main.rs"]) + } + + @Test("list resources with pagination") + func test_listResources_withPagination() async throws { + let resources = try await assert( + executing: { try await self.clientConnection.listResources() }, + triggers: [ + .clientSendsJrpc(""" + { + "id" : 1, + "jsonrpc" : "2.0", + "method" : "resources/list", + "params" : {} + } + """), + .serverResponding { request in + guard case .listResources(let params) = request else { + throw Issue.record("Unexpected request: \(request)") + } + #expect(params.cursor == nil) + + return .success(ListResourcesResult( + nextCursor: "next-page-cursor", + resources: [ + Resource( + uri: "file:///project/src/main.rs", + name: "main.rs", + description: "Primary application entry point", + mimeType: "text/x-rust"), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "resources": [ + { + "uri": "file:///project/src/main.rs", + "name": "main.rs", + "description": "Primary application entry point", + "mimeType": "text/x-rust" + } + ], + "nextCursor": "next-page-cursor" + } + } + """), + .clientSendsJrpc(""" + { + "id" : 2, + "jsonrpc" : "2.0", + "method" : "resources/list", + "params" : { + "cursor": "next-page-cursor" + } + } + """), + .serverResponding { request in + guard case .listResources(let params) = request else { + throw Issue.record("Unexpected request: \(request)") + } + #expect(params.cursor == "next-page-cursor") + + return .success(ListResourcesResult( + resources: [ + Resource( + uri: "file:///project/src/utils.rs", + name: "utils.rs", + description: "Some utils functions application entry point", + mimeType: "text/x-rust"), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 2, + "result": { + "resources": [ + { + "uri": "file:///project/src/utils.rs", + "name": "utils.rs", + "description": "Some utils functions application entry point", + "mimeType": "text/x-rust" + } + ] + } + } + """), + ]) + #expect(resources.map { $0.name } == ["main.rs", "utils.rs"]) + } + + @Test("receiving resources list changed notification") + func test_receivingResourcesListChangedNotification() async throws { + try await assert(executing: { + try await self.serverConnection.notifyResourceListChanged() + }, triggers: [ + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "method": "notifications/resources/list_changed", + "params" : null + } + """), + .clientReceiving { notification in + guard case .resourceListChanged = notification else { + throw Issue.record("Unexpected notification: \(notification)") + } + }, + ]) + } + } +} diff --git a/MCPSharedTesting/Tests/ListRootsTests.swift b/MCPSharedTesting/Tests/ListRootsTests.swift new file mode 100644 index 0000000..4c70b0c --- /dev/null +++ b/MCPSharedTesting/Tests/ListRootsTests.swift @@ -0,0 +1,71 @@ + +import JSONRPC +import MCPInterface +import SwiftTestingUtils +import Testing + +extension MCPConnectionTestSuite { + final class ListRootsTests: MCPConnectionTest { + + @Test("list roots") + func test_listRoots() async throws { + let roots = try await assert( + executing: { + try await self.serverConnection.listRoots() + }, + triggers: [ + .serverSendsJrpc(""" + { + "id" : 1, + "jsonrpc" : "2.0", + "method" : "roots/list", + "params" : null + } + """), + .clientResponding { request in + guard case .listRoots = request else { + throw Issue.record("Unexpected request: \(request)") + } + return .success(ListRootsResult(roots: [ + .init(uri: "file:///home/user/projects/myproject", name: "My Project"), + ])) + }, + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "roots": [ + { + "uri": "file:///home/user/projects/myproject", + "name": "My Project" + } + ] + } + } + """), + ]) + #expect(roots.roots.map { $0.name } == ["My Project"]) + } + + @Test("receiving roots list changed notification") + func test_receivingRootsListChangedNotification() async throws { + try await assert(executing: { + try await self.clientConnection.notifyRootsListChanged() + }, triggers: [ + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "method": "notifications/roots/list_changed", + "params": null + } + """), + .serverReceiving { notification in + guard case .rootsListChanged = notification else { + throw Issue.record("Unexpected notification: \(notification)") + } + }, + ]) + } + } +} diff --git a/MCPSharedTesting/Tests/ListToolsTests.swift b/MCPSharedTesting/Tests/ListToolsTests.swift new file mode 100644 index 0000000..e756eb6 --- /dev/null +++ b/MCPSharedTesting/Tests/ListToolsTests.swift @@ -0,0 +1,217 @@ + +import JSONRPC +import MCPInterface +import SwiftTestingUtils +import Testing + +extension MCPConnectionTestSuite { + final class ListToolsTests: MCPConnectionTest { + + @Test("list tools") + func test_listTools() async throws { + let tools = try await assert( + executing: { + try await self.clientConnection.listTools() + }, + triggers: [ + .clientSendsJrpc(""" + { + "id" : 1, + "jsonrpc" : "2.0", + "method" : "tools/list", + "params" : {} + } + """), + .serverResponding { request in + guard case .listTools = request else { + throw Issue.record("Unexpected request: \(request)") + } + return .success(ListToolsResult(tools: [ + .init( + name: "get_weather", + description: "Get current weather information for a location", + inputSchema: [ + "type": "object", + "properties": [ + "location": [ + "type": "string", + "description": "City name or zip code", + ], + ], + "required": ["location"], + ]), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": "get_weather", + "description": "Get current weather information for a location", + "inputSchema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name or zip code" + } + }, + "required": ["location"] + } + } + ] + } + } + """), + ]) + #expect(tools.map { $0.name } == ["get_weather"]) + } + + @Test("list tools with pagination") + func test_listTools_withPagination() async throws { + let tools = try await assert( + executing: { try await self.clientConnection.listTools() }, + triggers: [ + .clientSendsJrpc(""" + { + "id" : 1, + "jsonrpc" : "2.0", + "method" : "tools/list", + "params" : {} + } + """), + .serverResponding { request in + guard case .listTools(let params) = request else { + throw Issue.record("Unexpected request: \(request)") + } + #expect(params.cursor == nil) + + return .success(ListToolsResult(nextCursor: "next-page-cursor", tools: [ + .init( + name: "get_weather", + description: "Get current weather information for a location", + inputSchema: [ + "type": "object", + "properties": [ + "location": [ + "type": "string", + "description": "City name or zip code", + ], + ], + "required": ["location"], + ]), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": "get_weather", + "description": "Get current weather information for a location", + "inputSchema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name or zip code" + } + }, + "required": ["location"] + } + } + ], + "nextCursor": "next-page-cursor" + } + } + """), + .clientSendsJrpc(""" + { + "id" : 2, + "jsonrpc" : "2.0", + "method" : "tools/list", + "params" : { + "cursor": "next-page-cursor" + } + } + """), + .serverResponding { request in + guard case .listTools(let params) = request else { + throw Issue.record("Unexpected request: \(request)") + } + #expect(params.cursor == "next-page-cursor") + + return .success(ListToolsResult(tools: [ + .init( + name: "get_time", + description: "Get current time information for a location", + inputSchema: [ + "type": "object", + "properties": [ + "location": [ + "type": "string", + "description": "City name or zip code", + ], + ], + "required": ["location"], + ]), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 2, + "result": { + "tools": [ + { + "name": "get_time", + "description": "Get current time information for a location", + "inputSchema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name or zip code" + } + }, + "required": ["location"] + } + } + ] + } + } + """), + ]) + + #expect(tools.map { $0.name } == ["get_weather", "get_time"]) + } + + @Test("receiving list tool changed notification") + func test_receivingListToolChangedNotification() async throws { + let notificationReceived = expectation(description: "Notification received") + Task { + for await notification in await clientConnection.notifications { + switch notification { + case .toolListChanged: + notificationReceived.fulfill() + default: + Issue.record("Unexpected notification: \(notification)") + } + } + } + + clientTransport.receive(message: """ + { + "jsonrpc": "2.0", + "method": "notifications/tools/list_changed" + } + """) + try await fulfillment(of: [notificationReceived]) + } + } +} diff --git a/MCPSharedTesting/Tests/LoggingTests.swift b/MCPSharedTesting/Tests/LoggingTests.swift new file mode 100644 index 0000000..b4d6694 --- /dev/null +++ b/MCPSharedTesting/Tests/LoggingTests.swift @@ -0,0 +1,115 @@ + +import JSONRPC +import MCPInterface +import SwiftTestingUtils +import Testing + +extension MCPConnectionTestSuite { + final class LoggingTests: MCPConnectionTest { + @Test("setting log level") + func test_settingLogLevel() async throws { + _ = try await assert( + executing: { + try await self.clientConnection.setLogLevel(SetLevelRequest.Params(level: .debug)) + }, + triggers: [ + .clientSendsJrpc(""" + { + "id" : 1, + "jsonrpc" : "2.0", + "method" : "logging/setLevel", + "params" : { + "level" : "debug" + } + } + """), + .serverResponding { request in + guard case .setLogLevel(let params) = request else { + throw Issue.record("Unexpected request: \(request)") + } + #expect(params.level == .debug) + + return .success(SetLevelRequest.Result()) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": {} + } + """), + ]) + } + + @Test("send a log") + func test_sendALog() async throws { + try await assert( + executing: { + try await self.serverConnection.log(LoggingMessageNotification.Params(level: .debug, data: .string("Up and running!"))) + }, + triggers: [ + .serverSendsJrpc(""" + { + "jsonrpc" : "2.0", + "method" : "notifications/message", + "params" : { + "data" : "Up and running!", + "level" : "debug" + } + } + """), + .clientReceiving { notification in + guard case .loggingMessage(let message) = notification else { + throw Issue.record("Unexpected notification: \(notification)") + } + #expect(message.level == .debug) + #expect(message.data == "Up and running!") + }, + ]) + } + + @Test("receive a server log notification") + func test_receivesServerLogNotification() async throws { +// let notificationReceived = expectation(description: "Notification received") +// Task { +// for await notification in await sut.notifications { +// switch notification { +// case .loggingMessage(let message): +// #expect(message.level == .error) +// #expect(message.data == .object([ +// "error": .string("Connection failed"), +// "details": .object([ +// "host": .string("localhost"), +// "port": .number(5432), +// ]), +// ])) +// notificationReceived.fulfill() +// +// default: +// Issue.record("Unexpected notification: \(notification)") +// } +// } +// } +// +// transport.receive(message: """ +// { +// "jsonrpc": "2.0", +// "method": "notifications/message", +// "params": { +// "level": "error", +// "logger": "database", +// "data": { +// "error": "Connection failed", +// "details": { +// "host": "localhost", +// "port": 5432 +// } +// } +// } +// } +// """) +// try await fulfillment(of: [notificationReceived]) + } + + } +} diff --git a/MCPSharedTesting/Tests/PingTests.swift b/MCPSharedTesting/Tests/PingTests.swift new file mode 100644 index 0000000..6ef1577 --- /dev/null +++ b/MCPSharedTesting/Tests/PingTests.swift @@ -0,0 +1,58 @@ + +import JSONRPC +import Testing + +extension MCPConnectionTestSuite { + final class PingTests: MCPConnectionTest { + + @Test("client sending ping") + func clientSendingPing() async throws { + try await assert( + executing: { + try await self.clientConnection.ping() + }, + triggers: [ + .clientSendsJrpc(""" + { + "id" : 1, + "jsonrpc" : "2.0", + "method" : "ping", + "params" : null + } + """), + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": {} + } + """), + ]) + } + + @Test("server sending ping") + func serverSendingPing() async throws { + try await assert( + executing: { + try await self.serverConnection.ping() + }, + triggers: [ + .serverSendsJrpc(""" + { + "id" : 1, + "jsonrpc" : "2.0", + "method" : "ping", + "params" : null + } + """), + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": {} + } + """), + ]) + } + } +} diff --git a/MCPSharedTesting/Tests/ReadResourceTests.swift b/MCPSharedTesting/Tests/ReadResourceTests.swift new file mode 100644 index 0000000..1581040 --- /dev/null +++ b/MCPSharedTesting/Tests/ReadResourceTests.swift @@ -0,0 +1,253 @@ + +import JSONRPC +import MCPInterface +import SwiftTestingUtils +import Testing + +extension MCPConnectionTestSuite { + final class ReadResourceTests: MCPConnectionTest { + + @Test("read one resource") + func test_readOneResource() async throws { + let resources = try await assert( + executing: { + try await self.clientConnection.readResource(.init(uri: "file:///project/src/main.rs")) + }, + triggers: [ + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "method": "resources/read", + "params": { + "uri": "file:///project/src/main.rs" + } + } + """), + .serverResponding { request in + guard case .readResource(let params) = request else { + throw Issue.record("Unexpected request: \(request)") + } + #expect(params.uri == "file:///project/src/main.rs") + + return .success(ReadResourceResult( + contents: [ + .text(.init( + uri: "file:///project/src/main.rs", + mimeType: "text/x-rust", + text: "fn main() {\n println!(\"Hello world!\");\n}")), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "contents": [ + { + "uri": "file:///project/src/main.rs", + "mimeType": "text/x-rust", + "text": "fn main() {\\n println!(\\"Hello world!\\");\\n}" + } + ] + } + } + """), + ]) + #expect(resources.contents.map { $0.text?.text } == ["fn main() {\n println!(\"Hello world!\");\n}"]) + } + + @Test("read resources of different types") + func test_readResourcesOfDifferentTypes() async throws { + let resources = try await assert( + executing: { + try await self.clientConnection.readResource(.init(uri: "file:///project/src/*")) + }, + triggers: [ + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "method": "resources/read", + "params": { + "uri": "file:///project/src/*" + } + } + """), + .serverResponding { _ in + .success(ReadResourceResult( + contents: [ + .text(.init( + uri: "file:///project/src/main.rs", + mimeType: "text/x-rust", + text: "fn main() {\n println!(\"Hello world!\");\n}")), + .blob(.init( + uri: "file:///project/src/main.rs", + mimeType: "image/jpeg", + blob: "base64-encoded-image-data")), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": { + "contents": [ + { + "uri": "file:///project/src/main.rs", + "mimeType": "text/x-rust", + "text": "fn main() {\\n println!(\\"Hello world!\\");\\n}" + }, + { + "uri": "file:///project/src/main.rs", + "mimeType": "image/jpeg", + "blob": "base64-encoded-image-data" + } + ] + } + } + """), + ]) + #expect(resources.contents.map { $0.text?.text } == ["fn main() {\n println!(\"Hello world!\");\n}", nil]) + #expect(resources.contents.map { $0.blob?.mimeType } == [nil, "image/jpeg"]) + } + + @Test("error when reading resource") + func test_errorWhenReadingResource() async throws { + await assert( + executing: { try await self.clientConnection.readResource(.init(uri: "file:///nonexistent.txt")) }, + triggers: [ + .clientSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "method": "resources/read", + "params": { + "uri": "file:///nonexistent.txt" + } + } + """), + .serverResponding { _ in + .failure(.init(code: -32002, message: "Resource not found", data: [ + "uri": .string("file:///nonexistent.txt"), + ])) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "error": { + "code": -32002, + "message": "Resource not found", + "data": { + "uri": "file:///nonexistent.txt" + } + } + } + """), + ], + andFailsWith: { error in + guard let error = error as? JSONRPCResponseError else { + Issue.record("Unexpected error type: \(error)") + return + } + + #expect(error.code == -32002) + #expect(error.message == "Resource not found") + #expect(error.data == .hash([ + "uri": .string("file:///nonexistent.txt"), + ])) + }) + } + + @Test("subscribing to resource updates") + func test_subscribingToResourceUpdates() async throws { + try await assert( + executing: { try await self.clientConnection.subscribeToUpdateToResource(.init(uri: "file:///project/src/main.rs")) }, + triggers: [ + .clientSendsJrpc(""" + { + "id" : 1, + "jsonrpc" : "2.0", + "method" : "resources/subscribe", + "params" : { + "uri" : "file:///project/src/main.rs" + } + } + """), + .serverResponding { request in + guard case .subscribeToResource(let params) = request else { + throw Issue.record("Unexpected request: \(request)") + } + #expect(params.uri == "file:///project/src/main.rs") + + return .success(EmptyObject()) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": {} + } + """), + ]) + } + + @Test("unsubscribing to resource updates") + func test_unsubscribingToResourceUpdates() async throws { + try await assert( + executing: { try await self.clientConnection.unsubscribeToUpdateToResource(.init(uri: "file:///project/src/main.rs")) }, + triggers: [ + .clientSendsJrpc(""" + { + "id" : 1, + "jsonrpc" : "2.0", + "method" : "resources/unsubscribe", + "params" : { + "uri" : "file:///project/src/main.rs" + } + } + """), + .serverResponding { request in + guard case .unsubscribeToResource(let params) = request else { + throw Issue.record("Unexpected request: \(request)") + } + #expect(params.uri == "file:///project/src/main.rs") + + return .success(EmptyObject()) + }, + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "id": 1, + "result": {} + } + """), + ]) + } + + @Test("receiving resource update notification") + func test_receivingResourceUpdateNotification() async throws { + try await assert(executing: { + try await self.serverConnection.notifyResourceUpdated(.init(uri: "file:///project/src/main.rs")) + }, triggers: [ + .serverSendsJrpc(""" + { + "jsonrpc": "2.0", + "method": "notifications/resources/updated", + "params": { + "uri": "file:///project/src/main.rs" + } + } + """), + .clientReceiving { notification in + guard case .resourceUpdated(let params) = notification else { + throw Issue.record("Unexpected notification: \(notification)") + } + #expect(params.uri == "file:///project/src/main.rs") + }, + ]) + } + + } +} diff --git a/MCPSharedTesting/Tests/TestSuite.swift b/MCPSharedTesting/Tests/TestSuite.swift new file mode 100644 index 0000000..cd40717 --- /dev/null +++ b/MCPSharedTesting/Tests/TestSuite.swift @@ -0,0 +1,96 @@ +import MCPInterface +import MCPTestingUtils +import SwiftTestingUtils +import Testing +@testable import MCPClient +@testable import MCPServer + +// MARK: - MCPConnectionTestSuite + +/// All the tests about `MCPClientConnection` +@Suite("MCP Connection") +class MCPConnectionTestSuite { } + +// MARK: - MCPConnectionTest + +/// A parent test class that provides a few util functions to assert that the interactions with the transport are as expected. +class MCPConnectionTest { + + // MARK: Lifecycle + + init() { + clientTransport = MockTransport() + serverTransport = MockTransport() + + clientCapabilities = ClientCapabilities(roots: .init(listChanged: true), sampling: .init()) + serverCapabilities = ServerCapabilities( + logging: .init(), + prompts: .init(listChanged: true), + resources: .init(subscribe: true, listChanged: true), + tools: .init(listChanged: true)) + + clientConnection = try! MCPClientConnection( + info: .init(name: "TestClient", version: "1.0.0"), + capabilities: clientCapabilities, + transport: clientTransport.dataChannel) + + serverConnection = try! MCPServerConnection( + info: .init(name: "TestServer", version: "1.0.0"), + capabilities: serverCapabilities, + transport: serverTransport.dataChannel) + + clientTransport.onSendMessage { [weak self] message in + self?.serverTransport.receive(message: String(data: message, encoding: .utf8)!) + } + serverTransport.onSendMessage { [weak self] message in + self?.clientTransport.receive(message: String(data: message, encoding: .utf8)!) + } + } + + // MARK: Internal + + var clientTransport: MockTransport + var serverTransport: MockTransport + + let clientCapabilities: ClientCapabilities + let serverCapabilities: ServerCapabilities + + var clientConnection: MCPClientConnection + var serverConnection: MCPServerConnection +} + +// MARK: MCPConnectionsProvider + +extension MCPConnectionTest { + func assert( + executing task: @escaping () async throws -> Result, + triggers events: [Event]) + async throws -> Result + { + try await MCPTestingUtils.assert( + clientTransport: clientTransport, + serverTransport: serverTransport, + serverRequestsHandler: await clientConnection.requestsToHandle, + clientRequestsHandler: await serverConnection.requestsToHandle, + serverNotifications: await clientConnection.notifications, + clientNotifications: await serverConnection.notifications, + executing: task, + triggers: events) + } + + func assert( + executing task: @escaping () async throws -> Result, + triggers events: [Event], + andFailsWith errorHandler: (Error) -> Void) + async + { + do { + _ = try await assert(executing: task, triggers: events) + Issue.record("Expected the task to fail") + } catch { + // Expected + errorHandler(error) + } + } + +} diff --git a/MCPTestingUtils/Sources/MockTransport.swift b/MCPTestingUtils/Sources/MockTransport.swift new file mode 100644 index 0000000..60e6199 --- /dev/null +++ b/MCPTestingUtils/Sources/MockTransport.swift @@ -0,0 +1,92 @@ + +import Foundation +import JSONRPC +import MCPInterface +import Testing + +// MARK: - MockTransport + +public final class MockTransport { + + // MARK: Lifecycle + + public init() { + let dataSequence = AsyncStream() { continuation in + self.continuation = continuation + } + + dataChannel = DataChannel( + writeHandler: { [weak self] data in self?.handleWrite(data: data) }, + dataSequence: dataSequence) + } + + // MARK: Public + + public private(set) var dataChannel: DataChannel = .noop + + public func onSendMessage(_ hook: @escaping (Data) -> Void) { + let previousSendMessage = sendMessage + sendMessage = { message in + previousSendMessage(message) + hook(message) + } + } + + public func receive(message: String) { + let data = Data(message.utf8) + continuation?.yield(data) + } + + // MARK: Private + + private var sendMessage: (Data) -> Void = { _ in } + + private var continuation: AsyncStream.Continuation? + + private func handleWrite(data: Data) { + sendMessage(data) + } + +} + +extension MockTransport { + + /// Expects the given messages to be sent. + /// Examples: + /// expect([ + /// "{ \"jsonrpc\": \"2.0\", \"id\": 1, \"result\": null }", + /// ]) + public func expect(messages: [String]) { + expect(messages: messages.map { m in { $0(m) } }) + } + + /// Expects the given messages to be sent, calling the corresponding closure when needed. + /// Examples: + /// expect([ + /// { + /// firstMessageReceived.fulfill() + /// return "{ \"jsonrpc\": \"2.0\", \"id\": 1, \"result\": null }" + /// }, + /// ]) + public func expect(messages: [((String) -> Void) -> Void]) { + var messagesCount = 0 + + let previousSendMessage = sendMessage + + sendMessage = { message in + defer { messagesCount += 1 } + guard messagesCount < messages.count else { + Issue.record(""" + Too many messages sent. Expected \(messages.count). Last message received: + \(String(data: message, encoding: .utf8) ?? "Invalid data") + """) + return + } + messages[messagesCount]() { expected in + assertEqual(received: message, expected: expected) + } + + previousSendMessage(message) + } + } +} diff --git a/MCPTestingUtils/Sources/TestUtils.swift b/MCPTestingUtils/Sources/TestUtils.swift new file mode 100644 index 0000000..8a02df1 --- /dev/null +++ b/MCPTestingUtils/Sources/TestUtils.swift @@ -0,0 +1,237 @@ + +import Foundation +import JSONRPC +import MCPInterface +import SwiftTestingUtils +import Testing + +/// Asserts that the received JSON is equal to the expected JSON, allowing for any order of keys or spacing. +public func assertEqual(received jsonData: Data, expected: String) { + do { + let received = try JSONSerialization.jsonObject(with: jsonData) + let receivedPrettyPrinted = try JSONSerialization.data(withJSONObject: received, options: [.sortedKeys, .prettyPrinted]) + + let expected = try JSONSerialization.jsonObject(with: expected.data(using: .utf8)!) + let expectedPrettyPrinted = try JSONSerialization.data(withJSONObject: expected, options: [.sortedKeys, .prettyPrinted]) + + #expect(String(data: receivedPrettyPrinted, encoding: .utf8)! == String(data: expectedPrettyPrinted, encoding: .utf8)!) + } catch { + Issue.record("Failed to compare JSON: \(error)") + } +} + +// MARK: - TestError + +public enum TestError: Error { + case expectationUnfulfilled + case internalError +} + +// MARK: - Event + +public enum Event { + case clientSendsJrpc(_ value: String) + case serverSendsJrpc(_ value: String) + case serverResponding(_ request: (ClientRequest) async throws -> AnyJRPCResponse) + case clientResponding(_ request: (ServerRequest) async throws -> AnyJRPCResponse) + case serverReceiving(_ notification: (ClientNotification) async throws -> Void) + case clientReceiving(_ notification: (ServerNotification) async throws -> Void) +} + +/// Asserts that the given task sends the expected requests and receives the expected responses. +/// - Parameters: +/// - clientTransport: The transport to use to send messages to the client. +/// - serverTransport: The transport to use to send messages to the server. +/// - serverRequestsHandler: The client's handler that receives server requests. If nil, the client's requests (`clientSendsJrpc`) will be sent immediately. +/// - clientRequestsHandler: The server's handler that receives client requests. If nil, the server's requests (`serverSendsJrpc`) will be sent immediately. +/// - serverNotifications: The client's stream of notifications received from the server. +/// - clientNotifications: The server's stream of notifications received from the client. +/// - task: The task to execute. +/// - events: The sequence of events relevant to the task. +public func assert( + clientTransport: MockTransport?, + serverTransport: MockTransport?, + serverRequestsHandler: AsyncStream?, + clientRequestsHandler: AsyncStream?, + serverNotifications: AsyncStream?, + clientNotifications: AsyncStream?, + executing task: @escaping () async throws -> Result, + triggers events: [Event]) + async throws -> Result +{ + var result: Result? = nil + var err: Error? = nil + + /// The next JRPC message that is expected to be sent + var nextMessageToSent: (exp: SwiftTestingUtils.Expectation, clientMessage: String?, serverMessage: String?)? + + clientTransport?.onSendMessage { data in + if let (exp, message, _) = nextMessageToSent, let message { + assertEqual(received: data, expected: message) + exp.fulfill() + } else { + Issue.record("Unexpected message sent: \(String(data: data, encoding: .utf8) ?? "Invalid data")") + } + } + + serverTransport?.onSendMessage { data in + if let (exp, _, message) = nextMessageToSent, let message { + assertEqual(received: data, expected: message) + exp.fulfill() + } else { + Issue.record("Unexpected message sent: \(String(data: data, encoding: .utf8) ?? "Invalid data")") + } + } + + var i = 0 + let prepareNextExpectedMessage = { + loop: for j in i..