From 6992a4e8486110bd49d2cf45a86fbe6eda5d7fec Mon Sep 17 00:00:00 2001 From: Kyrian Obikwelu Date: Wed, 15 Oct 2025 22:09:12 +0100 Subject: [PATCH 1/4] feat(server): add bidirectional client communication support --- composer.json | 1 + .../CallToolRequestHandler.php | 73 ++++ .../ListToolsRequestHandler.php | 46 +++ examples/custom-method-handlers/server.php | 81 +--- examples/http-client-communication/server.php | 124 ++++++ .../stdio-client-communication/server.php | 115 ++++++ src/Capability/Discovery/SchemaGenerator.php | 12 +- src/Capability/Registry/ReferenceHandler.php | 12 + src/Schema/JsonRpc/Request.php | 35 +- src/Schema/JsonRpc/Response.php | 15 +- .../Result/CreateSamplingMessageResult.php | 46 +++ src/Server.php | 4 +- src/Server/Builder.php | 8 +- src/Server/ClientGateway.php | 183 +++++++++ .../Handler/Request/CallToolHandler.php | 7 + .../Request/CompletionCompleteHandler.php | 5 + .../Handler/Request/GetPromptHandler.php | 11 +- .../Handler/Request/InitializeHandler.php | 5 + .../Handler/Request/ListPromptsHandler.php | 4 + .../Request/ListResourceTemplatesHandler.php | 4 + .../Handler/Request/ListResourcesHandler.php | 4 + .../Handler/Request/ListToolsHandler.php | 4 + src/Server/Handler/Request/PingHandler.php | 5 + .../Handler/Request/ReadResourceHandler.php | 12 +- .../Request/RequestHandlerInterface.php | 5 + src/Server/Protocol.php | 360 +++++++++++++++--- src/Server/Transport/BaseTransport.php | 136 +++++++ src/Server/Transport/CallbackStream.php | 156 ++++++++ src/Server/Transport/InMemoryTransport.php | 31 +- .../Transport/ManagesTransportCallbacks.php | 82 ++++ src/Server/Transport/StdioTransport.php | 149 +++++--- .../Transport/StreamableHttpTransport.php | 232 ++++++----- src/Server/Transport/TransportInterface.php | 92 ++++- tests/Unit/JsonRpc/MessageFactoryTest.php | 4 +- .../Handler/Request/CallToolHandlerTest.php | 35 +- .../Handler/Request/GetPromptHandlerTest.php | 26 +- .../Request/ReadResourceHandlerTest.php | 18 +- tests/Unit/Server/ProtocolTest.php | 272 +++++++++---- 38 files changed, 1973 insertions(+), 441 deletions(-) create mode 100644 examples/custom-method-handlers/CallToolRequestHandler.php create mode 100644 examples/custom-method-handlers/ListToolsRequestHandler.php create mode 100644 examples/http-client-communication/server.php create mode 100644 examples/stdio-client-communication/server.php create mode 100644 src/Server/ClientGateway.php create mode 100644 src/Server/Transport/BaseTransport.php create mode 100644 src/Server/Transport/CallbackStream.php create mode 100644 src/Server/Transport/ManagesTransportCallbacks.php diff --git a/composer.json b/composer.json index ea472ccb..dfb304e7 100644 --- a/composer.json +++ b/composer.json @@ -61,6 +61,7 @@ "Mcp\\Example\\StdioDiscoveryCalculator\\": "examples/stdio-discovery-calculator/", "Mcp\\Example\\StdioEnvVariables\\": "examples/stdio-env-variables/", "Mcp\\Example\\StdioExplicitRegistration\\": "examples/stdio-explicit-registration/", + "Mcp\\Example\\CustomMethodHandlers\\": "examples/custom-method-handlers/", "Mcp\\Tests\\": "tests/" } }, diff --git a/examples/custom-method-handlers/CallToolRequestHandler.php b/examples/custom-method-handlers/CallToolRequestHandler.php new file mode 100644 index 00000000..22d95b39 --- /dev/null +++ b/examples/custom-method-handlers/CallToolRequestHandler.php @@ -0,0 +1,73 @@ + */ +class CallToolRequestHandler implements RequestHandlerInterface +{ + /** + * @param array $toolDefinitions + */ + public function __construct(private array $toolDefinitions) + { + } + + public function supports(Request $request): bool + { + return $request instanceof CallToolRequest; + } + + /** + * @return Response|Error + */ + public function handle(Request $request, SessionInterface $session): Response|Error + { + \assert($request instanceof CallToolRequest); + + $name = $request->name; + $args = $request->arguments ?? []; + + if (!isset($this->toolDefinitions[$name])) { + return new Error($request->getId(), Error::METHOD_NOT_FOUND, \sprintf('Tool not found: %s', $name)); + } + + try { + switch ($name) { + case 'say_hello': + $greetName = (string) ($args['name'] ?? 'world'); + $result = [new TextContent(\sprintf('Hello, %s!', $greetName))]; + break; + case 'sum': + $a = (float) ($args['a'] ?? 0); + $b = (float) ($args['b'] ?? 0); + $result = [new TextContent((string) ($a + $b))]; + break; + default: + $result = [new TextContent('Unknown tool')]; + } + + return new Response($request->getId(), new CallToolResult($result)); + } catch (\Throwable $e) { + return new Response($request->getId(), new CallToolResult([new TextContent('Tool execution failed')], true)); + } + } +} diff --git a/examples/custom-method-handlers/ListToolsRequestHandler.php b/examples/custom-method-handlers/ListToolsRequestHandler.php new file mode 100644 index 00000000..498f3a89 --- /dev/null +++ b/examples/custom-method-handlers/ListToolsRequestHandler.php @@ -0,0 +1,46 @@ + */ +class ListToolsRequestHandler implements RequestHandlerInterface +{ + /** + * @param array $toolDefinitions + */ + public function __construct(private array $toolDefinitions) + { + } + + public function supports(Request $request): bool + { + return $request instanceof ListToolsRequest; + } + + /** + * @return Response + */ + public function handle(Request $request, SessionInterface $session): Response + { + \assert($request instanceof ListToolsRequest); + + return new Response($request->getId(), new ListToolsResult(array_values($this->toolDefinitions), null)); + } +} diff --git a/examples/custom-method-handlers/server.php b/examples/custom-method-handlers/server.php index 9ff0a822..7554a0a7 100644 --- a/examples/custom-method-handlers/server.php +++ b/examples/custom-method-handlers/server.php @@ -13,19 +13,11 @@ require_once dirname(__DIR__).'/bootstrap.php'; chdir(__DIR__); -use Mcp\Schema\Content\TextContent; -use Mcp\Schema\JsonRpc\Error; -use Mcp\Schema\JsonRpc\Request; -use Mcp\Schema\JsonRpc\Response; -use Mcp\Schema\Request\CallToolRequest; -use Mcp\Schema\Request\ListToolsRequest; -use Mcp\Schema\Result\CallToolResult; -use Mcp\Schema\Result\ListToolsResult; +use Mcp\Example\CustomMethodHandlers\CallToolRequestHandler; +use Mcp\Example\CustomMethodHandlers\ListToolsRequestHandler; use Mcp\Schema\ServerCapabilities; use Mcp\Schema\Tool; use Mcp\Server; -use Mcp\Server\Handler\Request\RequestHandlerInterface; -use Mcp\Server\Session\SessionInterface; use Mcp\Server\Transport\StdioTransport; logger()->info('Starting MCP Custom Method Handlers (Stdio) Server...'); @@ -58,73 +50,8 @@ ), ]; -$listToolsHandler = new class($toolDefinitions) implements RequestHandlerInterface { - /** - * @param array $toolDefinitions - */ - public function __construct(private array $toolDefinitions) - { - } - - public function supports(Request $request): bool - { - return $request instanceof ListToolsRequest; - } - - public function handle(Request $request, SessionInterface $session): Response - { - assert($request instanceof ListToolsRequest); - - return new Response($request->getId(), new ListToolsResult(array_values($this->toolDefinitions), null)); - } -}; - -$callToolHandler = new class($toolDefinitions) implements RequestHandlerInterface { - /** - * @param array $toolDefinitions - */ - public function __construct(private array $toolDefinitions) - { - } - - public function supports(Request $request): bool - { - return $request instanceof CallToolRequest; - } - - public function handle(Request $request, SessionInterface $session): Response|Error - { - assert($request instanceof CallToolRequest); - - $name = $request->name; - $args = $request->arguments ?? []; - - if (!isset($this->toolDefinitions[$name])) { - return new Error($request->getId(), Error::METHOD_NOT_FOUND, sprintf('Tool not found: %s', $name)); - } - - try { - switch ($name) { - case 'say_hello': - $greetName = (string) ($args['name'] ?? 'world'); - $result = [new TextContent(sprintf('Hello, %s!', $greetName))]; - break; - case 'sum': - $a = (float) ($args['a'] ?? 0); - $b = (float) ($args['b'] ?? 0); - $result = [new TextContent((string) ($a + $b))]; - break; - default: - $result = [new TextContent('Unknown tool')]; - } - - return new Response($request->getId(), new CallToolResult($result)); - } catch (Throwable $e) { - return new Response($request->getId(), new CallToolResult([new TextContent('Tool execution failed')], true)); - } - } -}; - +$listToolsHandler = new ListToolsRequestHandler($toolDefinitions); +$callToolHandler = new CallToolRequestHandler($toolDefinitions); $capabilities = new ServerCapabilities(tools: true, resources: false, prompts: false); $server = Server::builder() diff --git a/examples/http-client-communication/server.php b/examples/http-client-communication/server.php new file mode 100644 index 00000000..dd38cfd4 --- /dev/null +++ b/examples/http-client-communication/server.php @@ -0,0 +1,124 @@ +fromGlobals(); + +$sessionDir = __DIR__.'/sessions'; +$capabilities = new ServerCapabilities(logging: true, tools: true); + +$server = Server::builder() + ->setServerInfo('HTTP Client Communication Demo', '1.0.0') + ->setLogger(logger()) + ->setContainer(container()) + ->setSession(new FileSessionStore($sessionDir)) + ->setCapabilities($capabilities) + ->addTool( + function (string $projectName, array $milestones, ClientGateway $client): array { + $client->log(LoggingLevel::Info, sprintf('Preparing project briefing for "%s"', $projectName)); + + $totalSteps = max(1, count($milestones)); + + foreach ($milestones as $index => $milestone) { + $progress = ($index + 1) / $totalSteps; + $message = sprintf('Analyzing milestone "%s"', $milestone); + + $client->progress(progress: $progress, total: 1, message: $message); + + usleep(150_000); // Simulate work being done + } + + $prompt = sprintf( + 'Draft a concise stakeholder briefing for the project "%s". Highlight key milestones: %s. Focus on risks and next steps.', + $projectName, + implode(', ', $milestones) + ); + + $response = $client->sample( + prompt: $prompt, + maxTokens: 400, + timeout: 90, + options: ['temperature' => 0.4] + ); + + if ($response instanceof JsonRpcError) { + throw new RuntimeException(sprintf('Sampling request failed (%d): %s', $response->code, $response->message)); + } + + $result = $response->result; + $content = $result->content instanceof TextContent ? trim((string) $result->content->text) : ''; + + $client->log(LoggingLevel::Info, 'Briefing ready, returning to caller.'); + + return [ + 'project' => $projectName, + 'milestones_reviewed' => $milestones, + 'briefing' => $content, + 'model' => $result->model, + 'stop_reason' => $result->stopReason, + ]; + }, + name: 'prepare_project_briefing', + description: 'Compile a stakeholder briefing with live logging, progress updates, and LLM sampling.' + ) + ->addTool( + function (string $serviceName, ClientGateway $client): array { + $client->log(LoggingLevel::Info, sprintf('Starting maintenance checks for "%s"', $serviceName)); + + $steps = [ + 'Verifying health metrics', + 'Checking recent deployments', + 'Reviewing alert stream', + 'Summarizing findings', + ]; + + foreach ($steps as $index => $step) { + $progress = ($index + 1) / count($steps); + + $client->progress(progress: $progress, total: 1, message: $step); + + usleep(120_000); // Simulate work being done + } + + $client->log(LoggingLevel::Info, sprintf('Maintenance checks complete for "%s"', $serviceName)); + + return [ + 'service' => $serviceName, + 'status' => 'operational', + 'notes' => 'No critical issues detected during automated sweep.', + ]; + }, + name: 'run_service_maintenance', + description: 'Simulate service maintenance with logging and progress notifications.' + ) + ->build(); + +$transport = new StreamableHttpTransport($request, $psr17Factory, $psr17Factory, logger()); + +$response = $server->run($transport); + +(new SapiEmitter())->emit($response); diff --git a/examples/stdio-client-communication/server.php b/examples/stdio-client-communication/server.php new file mode 100644 index 00000000..96d2f052 --- /dev/null +++ b/examples/stdio-client-communication/server.php @@ -0,0 +1,115 @@ +#!/usr/bin/env php +setServerInfo('STDIO Client Communication Demo', '1.0.0') + ->setLogger(logger()) + ->setContainer(container()) + ->setCapabilities($capabilities) + ->addTool( + function (string $incidentTitle, ClientGateway $client): array { + $client->log(LoggingLevel::Warning, sprintf('Incident triage started: %s', $incidentTitle)); + + $steps = [ + 'Collecting telemetry', + 'Assessing scope', + 'Coordinating responders', + ]; + + foreach ($steps as $index => $step) { + $progress = ($index + 1) / count($steps); + + $client->progress(progress: $progress, total: 1, message: $step); + + usleep(180_000); // Simulate work being done + } + + $prompt = sprintf( + 'Provide a concise response strategy for incident "%s" based on the steps completed: %s.', + $incidentTitle, + implode(', ', $steps) + ); + + $sampling = $client->sample( + prompt: $prompt, + maxTokens: 350, + timeout: 90, + options: ['temperature' => 0.5] + ); + + if ($sampling instanceof JsonRpcError) { + throw new RuntimeException(sprintf('Sampling request failed (%d): %s', $sampling->code, $sampling->message)); + } + + $result = $sampling->result; + $recommendation = $result->content instanceof TextContent ? trim((string) $result->content->text) : ''; + + $client->log(LoggingLevel::Info, sprintf('Incident triage completed for %s', $incidentTitle)); + + return [ + 'incident' => $incidentTitle, + 'recommended_actions' => $recommendation, + 'model' => $result->model, + ]; + }, + name: 'coordinate_incident_response', + description: 'Coordinate an incident response with logging, progress, and sampling.' + ) + ->addTool( + function (string $dataset, ClientGateway $client): array { + $client->log(LoggingLevel::Info, sprintf('Running quality checks on dataset "%s"', $dataset)); + + $tasks = [ + 'Validating schema', + 'Scanning for anomalies', + 'Reviewing statistical summary', + ]; + + foreach ($tasks as $index => $task) { + $progress = ($index + 1) / count($tasks); + + $client->progress(progress: $progress, total: 1, message: $task); + + usleep(140_000); // Simulate work being done + } + + $client->log(LoggingLevel::Info, sprintf('Dataset "%s" passed automated checks.', $dataset)); + + return [ + 'dataset' => $dataset, + 'status' => 'passed', + 'notes' => 'No significant integrity issues detected during automated checks.', + ]; + }, + name: 'run_dataset_quality_checks', + description: 'Perform dataset quality checks with progress updates and logging.' + ) + ->build(); + +$transport = new StdioTransport(); + +$status = $server->run($transport); + +exit($status); diff --git a/src/Capability/Discovery/SchemaGenerator.php b/src/Capability/Discovery/SchemaGenerator.php index 936a35cf..2557f559 100644 --- a/src/Capability/Discovery/SchemaGenerator.php +++ b/src/Capability/Discovery/SchemaGenerator.php @@ -12,6 +12,7 @@ namespace Mcp\Capability\Discovery; use Mcp\Capability\Attribute\Schema; +use Mcp\Server\ClientGateway; use phpDocumentor\Reflection\DocBlock\Tags\Param; /** @@ -409,10 +410,19 @@ private function parseParametersInfo(\ReflectionMethod|\ReflectionFunction $refl $parametersInfo = []; foreach ($reflection->getParameters() as $rp) { + $reflectionType = $rp->getType(); + + if ($reflectionType instanceof \ReflectionNamedType && !$reflectionType->isBuiltin()) { + $typeName = $reflectionType->getName(); + + if (is_a($typeName, ClientGateway::class, true)) { + continue; + } + } + $paramName = $rp->getName(); $paramTag = $paramTags['$'.$paramName] ?? null; - $reflectionType = $rp->getType(); $typeString = $this->getParameterTypeString($rp, $paramTag); $description = $this->docBlockParser->getParamDescription($paramTag); $hasDefault = $rp->isDefaultValueAvailable(); diff --git a/src/Capability/Registry/ReferenceHandler.php b/src/Capability/Registry/ReferenceHandler.php index b0333788..e3eb925f 100644 --- a/src/Capability/Registry/ReferenceHandler.php +++ b/src/Capability/Registry/ReferenceHandler.php @@ -13,6 +13,7 @@ use Mcp\Exception\InvalidArgumentException; use Mcp\Exception\RegistryException; +use Mcp\Server\ClientGateway; use Psr\Container\ContainerInterface; /** @@ -89,6 +90,17 @@ private function prepareArguments(\ReflectionFunctionAbstract $reflection, array $paramName = $parameter->getName(); $paramPosition = $parameter->getPosition(); + // Check if parameter is a special injectable type + $type = $parameter->getType(); + if ($type instanceof \ReflectionNamedType && !$type->isBuiltin()) { + $typeName = $type->getName(); + + if (ClientGateway::class === $typeName && isset($arguments['_session'])) { + $finalArgs[$paramPosition] = new ClientGateway($arguments['_session']); + continue; + } + } + if (isset($arguments[$paramName])) { $argument = $arguments[$paramName]; try { diff --git a/src/Schema/JsonRpc/Request.php b/src/Schema/JsonRpc/Request.php index 22083c75..cb8ed836 100644 --- a/src/Schema/JsonRpc/Request.php +++ b/src/Schema/JsonRpc/Request.php @@ -59,7 +59,13 @@ public static function fromArray(array $data): static $request->id = $data['id']; if (isset($data['params']['_meta'])) { - $request->meta = $data['params']['_meta']; + $meta = $data['params']['_meta']; + if ($meta instanceof \stdClass) { + $meta = (array) $meta; + } + if (\is_array($meta)) { + $request->meta = $meta; + } } return $request; @@ -75,6 +81,33 @@ public function getId(): string|int return $this->id; } + /** + * @return array|null + */ + public function getMeta(): ?array + { + return $this->meta; + } + + public function withId(string|int $id): static + { + $clone = clone $this; + $clone->id = $id; + + return $clone; + } + + /** + * @param array|null $meta + */ + public function withMeta(?array $meta): static + { + $clone = clone $this; + $clone->meta = $meta; + + return $clone; + } + /** * @return RequestData */ diff --git a/src/Schema/JsonRpc/Response.php b/src/Schema/JsonRpc/Response.php index f1521265..7f2d82ba 100644 --- a/src/Schema/JsonRpc/Response.php +++ b/src/Schema/JsonRpc/Response.php @@ -14,23 +14,26 @@ use Mcp\Exception\InvalidArgumentException; /** - * @author Kyrian Obikwelu + * @template TResult * * @phpstan-type ResponseData array{ * jsonrpc: string, * id: string|int, * result: array, * } + * + * @author Kyrian Obikwelu */ class Response implements MessageInterface { /** - * @param string|int $id this MUST be the same as the value of the id member in the Request Object - * @param ResultInterface|array $result the value of this member is determined by the method invoked on the Server + * @param string|int $id this MUST be the same as the value of the id member in the Request Object + * @param TResult $result the value of this member is determined by the method invoked on the Server */ public function __construct( public readonly string|int $id, - public readonly ResultInterface|array $result, + /** @var TResult */ + public readonly mixed $result, ) { } @@ -41,6 +44,8 @@ public function getId(): string|int /** * @param ResponseData $data + * + * @return self> */ public static function fromArray(array $data): self { @@ -67,7 +72,7 @@ public static function fromArray(array $data): self * @return array{ * jsonrpc: string, * id: string|int, - * result: ResultInterface, + * result: mixed, * } */ public function jsonSerialize(): array diff --git a/src/Schema/Result/CreateSamplingMessageResult.php b/src/Schema/Result/CreateSamplingMessageResult.php index b9e87c1a..986d6291 100644 --- a/src/Schema/Result/CreateSamplingMessageResult.php +++ b/src/Schema/Result/CreateSamplingMessageResult.php @@ -11,6 +11,7 @@ namespace Mcp\Schema\Result; +use Mcp\Exception\InvalidArgumentException; use Mcp\Schema\Content\AudioContent; use Mcp\Schema\Content\ImageContent; use Mcp\Schema\Content\TextContent; @@ -40,6 +41,51 @@ public function __construct( ) { } + /** + * @param array $data + */ + public static function fromArray(array $data): self + { + if (!isset($data['role']) || !\is_string($data['role'])) { + throw new InvalidArgumentException('Missing or invalid "role" in CreateSamplingMessageResult data.'); + } + + if (!isset($data['content']) || !\is_array($data['content'])) { + throw new InvalidArgumentException('Missing or invalid "content" in CreateSamplingMessageResult data.'); + } + + if (!isset($data['model']) || !\is_string($data['model'])) { + throw new InvalidArgumentException('Missing or invalid "model" in CreateSamplingMessageResult data.'); + } + + $role = Role::from($data['role']); + $contentPayload = $data['content']; + + $content = self::hydrateContent($contentPayload); + $stopReason = isset($data['stopReason']) && \is_string($data['stopReason']) ? $data['stopReason'] : null; + + return new self($role, $content, $data['model'], $stopReason); + } + + /** + * @param array $contentData + */ + private static function hydrateContent(array $contentData): TextContent|ImageContent|AudioContent + { + $type = $contentData['type'] ?? null; + + if (!\is_string($type)) { + throw new InvalidArgumentException('Missing or invalid "type" in sampling content payload.'); + } + + return match ($type) { + 'text' => TextContent::fromArray($contentData), + 'image' => ImageContent::fromArray($contentData), + 'audio' => AudioContent::fromArray($contentData), + default => throw new InvalidArgumentException(\sprintf('Unsupported sampling content type "%s".', $type)), + }; + } + /** * @return array{ * role: string, diff --git a/src/Server.php b/src/Server.php index 1eb24e8c..8657610a 100644 --- a/src/Server.php +++ b/src/Server.php @@ -43,12 +43,12 @@ public static function builder(): Builder */ public function run(TransportInterface $transport): mixed { - $this->logger->info('Running server...'); - $transport->initialize(); $this->protocol->connect($transport); + $this->logger->info('Running server...'); + try { return $transport->listen(); } finally { diff --git a/src/Server/Builder.php b/src/Server/Builder.php index 886e61a8..25a97bde 100644 --- a/src/Server/Builder.php +++ b/src/Server/Builder.php @@ -63,7 +63,7 @@ final class Builder private ?string $instructions = null; /** - * @var array + * @var array> */ private array $requestHandlers = []; @@ -177,6 +177,9 @@ public function setCapabilities(ServerCapabilities $serverCapabilities): self /** * Register a single custom method handler. */ + /** + * @param RequestHandlerInterface $handler + */ public function addRequestHandler(RequestHandlerInterface $handler): self { $this->requestHandlers[] = $handler; @@ -189,6 +192,9 @@ public function addRequestHandler(RequestHandlerInterface $handler): self * * @param iterable $handlers */ + /** + * @param iterable> $handlers + */ public function addRequestHandlers(iterable $handlers): self { foreach ($handlers as $handler) { diff --git a/src/Server/ClientGateway.php b/src/Server/ClientGateway.php new file mode 100644 index 00000000..58f76b20 --- /dev/null +++ b/src/Server/ClientGateway.php @@ -0,0 +1,183 @@ +notify(new ProgressNotification("Starting analysis...")); + * + * // Request LLM sampling from client + * $response = $client->request(new SamplingRequest($text)); + * + * return $response->content->text; + * } + * ``` + * + * @author Kyrian Obikwelu + */ +final class ClientGateway +{ + public function __construct( + private readonly SessionInterface $session, + ) { + } + + /** + * Send a notification to the client (fire and forget). + * + * This suspends the Fiber to let the transport flush the notification via SSE, + * then immediately resumes execution. + */ + public function notify(Notification $notification): void + { + \Fiber::suspend([ + 'type' => 'notification', + 'notification' => $notification, + 'session_id' => $this->session->getId()->toString(), + ]); + } + + /** + * Convenience method to send a logging notification to the client. + */ + public function log(LoggingLevel $level, mixed $data, ?string $logger = null): void + { + $this->notify(new LoggingMessageNotification($level, $data, $logger)); + } + + /** + * Convenience method to send a progress notification to the client. + */ + public function progress(float $progress, ?float $total = null, ?string $message = null): void + { + $meta = $this->session->get(Protocol::SESSION_ACTIVE_REQUEST_META, []); + $progressToken = $meta['progressToken'] ?? null; + + if (null === $progressToken) { + // Per the spec the client never asked for progress, so just bail. + return; + } + + $this->notify(new ProgressNotification($progressToken, $progress, $total, $message)); + } + + /** + * Send a request to the client and wait for a response (blocking). + * + * This suspends the Fiber and waits for the client to respond. The transport + * handles polling the session for the response and resuming the Fiber when ready. + * + * @param Request $request The request to send + * @param int $timeout Maximum time to wait for response (seconds) + * + * @return Response>|Error The client's response message + * + * @throws \RuntimeException If Fiber support is not available + */ + public function request(Request $request, int $timeout = 120): Response|Error + { + $response = \Fiber::suspend([ + 'type' => 'request', + 'request' => $request, + 'session_id' => $this->session->getId()->toString(), + 'timeout' => $timeout, + ]); + + if (!$response instanceof Response && !$response instanceof Error) { + throw new \RuntimeException('Transport returned an unexpected payload; expected a Response or Error message.'); + } + + return $response; + } + + /** + * Create and send an LLM sampling requests. + * + * @param CreateSamplingMessageRequest $request The request to send + * @param int $timeout The timeout in seconds + * + * @return Response|Error The sampling response + */ + public function createMessage(CreateSamplingMessageRequest $request, int $timeout = 120): Response|Error + { + $response = $this->request($request, $timeout); + + if ($response instanceof Error) { + return $response; + } + + $result = CreateSamplingMessageResult::fromArray($response->result); + + return new Response($response->getId(), $result); + } + + /** + * Convenience method for LLM sampling requests. + * + * @param string $prompt The prompt for the LLM + * @param int $maxTokens Maximum tokens to generate + * @param int $timeout The timeout in seconds + * @param array $options Additional sampling options (temperature, etc.) + * + * @return Response|Error The sampling response + */ + public function sample(string $prompt, int $maxTokens = 1000, int $timeout = 120, array $options = []): Response|Error + { + $preferences = $options['preferences'] ?? null; + if (\is_array($preferences)) { + $preferences = ModelPreferences::fromArray($preferences); + } + + if (null !== $preferences && !$preferences instanceof ModelPreferences) { + throw new \InvalidArgumentException('The "preferences" option must be an array or an instance of ModelPreferences.'); + } + + $samplingRequest = new CreateSamplingMessageRequest( + messages: [ + new SamplingMessage(Role::User, new TextContent(text: $prompt)), + ], + maxTokens: $maxTokens, + preferences: $preferences, + systemPrompt: $options['systemPrompt'] ?? null, + includeContext: $options['includeContext'] ?? null, + temperature: $options['temperature'] ?? null, + stopSequences: $options['stopSequences'] ?? null, + metadata: $options['metadata'] ?? null, + ); + + return $this->createMessage($samplingRequest, $timeout); + } +} diff --git a/src/Server/Handler/Request/CallToolHandler.php b/src/Server/Handler/Request/CallToolHandler.php index c1f10b9d..79413908 100644 --- a/src/Server/Handler/Request/CallToolHandler.php +++ b/src/Server/Handler/Request/CallToolHandler.php @@ -26,6 +26,8 @@ use Psr\Log\NullLogger; /** + * @implements RequestHandlerInterface + * * @author Christopher Hertel * @author Tobias Nyholm */ @@ -43,6 +45,9 @@ public function supports(Request $request): bool return $request instanceof CallToolRequest; } + /** + * @return Response|Error + */ public function handle(Request $request, SessionInterface $session): Response|Error { \assert($request instanceof CallToolRequest); @@ -58,6 +63,8 @@ public function handle(Request $request, SessionInterface $session): Response|Er throw new ToolNotFoundException($request); } + $arguments['_session'] = $session; + $result = $this->referenceHandler->handle($reference, $arguments); if (!$result instanceof CallToolResult) { diff --git a/src/Server/Handler/Request/CompletionCompleteHandler.php b/src/Server/Handler/Request/CompletionCompleteHandler.php index 6787e48b..c3d9f844 100644 --- a/src/Server/Handler/Request/CompletionCompleteHandler.php +++ b/src/Server/Handler/Request/CompletionCompleteHandler.php @@ -24,6 +24,8 @@ /** * Handles completion/complete requests. * + * @implements RequestHandlerInterface + * * @author Kyrian Obikwelu */ final class CompletionCompleteHandler implements RequestHandlerInterface @@ -39,6 +41,9 @@ public function supports(Request $request): bool return $request instanceof CompletionCompleteRequest; } + /** + * @return Response|Error + */ public function handle(Request $request, SessionInterface $session): Response|Error { \assert($request instanceof CompletionCompleteRequest); diff --git a/src/Server/Handler/Request/GetPromptHandler.php b/src/Server/Handler/Request/GetPromptHandler.php index cf321981..28e5e909 100644 --- a/src/Server/Handler/Request/GetPromptHandler.php +++ b/src/Server/Handler/Request/GetPromptHandler.php @@ -26,6 +26,8 @@ use Psr\Log\NullLogger; /** + * @implements RequestHandlerInterface + * * @author Tobias Nyholm */ final class GetPromptHandler implements RequestHandlerInterface @@ -42,6 +44,9 @@ public function supports(Request $request): bool return $request instanceof GetPromptRequest; } + /** + * @return Response|Error + */ public function handle(Request $request, SessionInterface $session): Response|Error { \assert($request instanceof GetPromptRequest); @@ -55,6 +60,8 @@ public function handle(Request $request, SessionInterface $session): Response|Er throw new PromptNotFoundException($request); } + $arguments['_session'] = $session; + $result = $this->referenceHandler->handle($reference, $arguments); $formatted = $reference->formatResult($result); @@ -67,11 +74,11 @@ public function handle(Request $request, SessionInterface $session): Response|Er } catch (PromptGetException|ExceptionInterface $e) { $this->logger->error('Error while handling prompt', ['prompt_name' => $promptName]); - return Error::forInternalError('Error while handling prompt', $request->getId()); + return Error::forInternalError('Error while handling prompt: '.$e->getMessage(), $request->getId()); } catch (\Throwable $e) { $this->logger->error('Error while handling prompt', ['prompt_name' => $promptName]); - return Error::forInternalError('Error while handling prompt', $request->getId()); + return Error::forInternalError('Error while handling prompt: '.$e->getMessage(), $request->getId()); } } } diff --git a/src/Server/Handler/Request/InitializeHandler.php b/src/Server/Handler/Request/InitializeHandler.php index 28bf109f..32eae194 100644 --- a/src/Server/Handler/Request/InitializeHandler.php +++ b/src/Server/Handler/Request/InitializeHandler.php @@ -21,6 +21,8 @@ use Mcp\Server\Session\SessionInterface; /** + * @implements RequestHandlerInterface + * * @author Christopher Hertel */ final class InitializeHandler implements RequestHandlerInterface @@ -35,6 +37,9 @@ public function supports(Request $request): bool return $request instanceof InitializeRequest; } + /** + * @return Response + */ public function handle(Request $request, SessionInterface $session): Response { \assert($request instanceof InitializeRequest); diff --git a/src/Server/Handler/Request/ListPromptsHandler.php b/src/Server/Handler/Request/ListPromptsHandler.php index 2db8a7ab..aa75fef0 100644 --- a/src/Server/Handler/Request/ListPromptsHandler.php +++ b/src/Server/Handler/Request/ListPromptsHandler.php @@ -20,6 +20,8 @@ use Mcp\Server\Session\SessionInterface; /** + * @implements RequestHandlerInterface + * * @author Tobias Nyholm */ final class ListPromptsHandler implements RequestHandlerInterface @@ -36,6 +38,8 @@ public function supports(Request $request): bool } /** + * @return Response + * * @throws InvalidCursorException */ public function handle(Request $request, SessionInterface $session): Response diff --git a/src/Server/Handler/Request/ListResourceTemplatesHandler.php b/src/Server/Handler/Request/ListResourceTemplatesHandler.php index 76b48bb0..ce77b62a 100644 --- a/src/Server/Handler/Request/ListResourceTemplatesHandler.php +++ b/src/Server/Handler/Request/ListResourceTemplatesHandler.php @@ -20,6 +20,8 @@ use Mcp\Server\Session\SessionInterface; /** + * @implements RequestHandlerInterface + * * @author Christopher Hertel */ final class ListResourceTemplatesHandler implements RequestHandlerInterface @@ -36,6 +38,8 @@ public function supports(Request $request): bool } /** + * @return Response + * * @throws InvalidCursorException */ public function handle(Request $request, SessionInterface $session): Response diff --git a/src/Server/Handler/Request/ListResourcesHandler.php b/src/Server/Handler/Request/ListResourcesHandler.php index 7e4a4ce7..4dc5ceb2 100644 --- a/src/Server/Handler/Request/ListResourcesHandler.php +++ b/src/Server/Handler/Request/ListResourcesHandler.php @@ -20,6 +20,8 @@ use Mcp\Server\Session\SessionInterface; /** + * @implements RequestHandlerInterface + * * @author Tobias Nyholm */ final class ListResourcesHandler implements RequestHandlerInterface @@ -36,6 +38,8 @@ public function supports(Request $request): bool } /** + * @return Response + * * @throws InvalidCursorException */ public function handle(Request $request, SessionInterface $session): Response diff --git a/src/Server/Handler/Request/ListToolsHandler.php b/src/Server/Handler/Request/ListToolsHandler.php index 81854a62..7c7f7788 100644 --- a/src/Server/Handler/Request/ListToolsHandler.php +++ b/src/Server/Handler/Request/ListToolsHandler.php @@ -20,6 +20,8 @@ use Mcp\Server\Session\SessionInterface; /** + * @implements RequestHandlerInterface + * * @author Christopher Hertel * @author Tobias Nyholm */ @@ -37,6 +39,8 @@ public function supports(Request $request): bool } /** + * @return Response + * * @throws InvalidCursorException When the cursor is invalid */ public function handle(Request $request, SessionInterface $session): Response diff --git a/src/Server/Handler/Request/PingHandler.php b/src/Server/Handler/Request/PingHandler.php index 378926c1..507680fa 100644 --- a/src/Server/Handler/Request/PingHandler.php +++ b/src/Server/Handler/Request/PingHandler.php @@ -18,6 +18,8 @@ use Mcp\Server\Session\SessionInterface; /** + * @implements RequestHandlerInterface + * * @author Christopher Hertel */ final class PingHandler implements RequestHandlerInterface @@ -27,6 +29,9 @@ public function supports(Request $request): bool return $request instanceof PingRequest; } + /** + * @return Response + */ public function handle(Request $request, SessionInterface $session): Response { \assert($request instanceof PingRequest); diff --git a/src/Server/Handler/Request/ReadResourceHandler.php b/src/Server/Handler/Request/ReadResourceHandler.php index 17d2781c..83f0d654 100644 --- a/src/Server/Handler/Request/ReadResourceHandler.php +++ b/src/Server/Handler/Request/ReadResourceHandler.php @@ -25,6 +25,8 @@ use Psr\Log\NullLogger; /** + * @implements RequestHandlerInterface + * * @author Tobias Nyholm */ final class ReadResourceHandler implements RequestHandlerInterface @@ -41,6 +43,9 @@ public function supports(Request $request): bool return $request instanceof ReadResourceRequest; } + /** + * @return Response|Error + */ public function handle(Request $request, SessionInterface $session): Response|Error { \assert($request instanceof ReadResourceRequest); @@ -55,7 +60,12 @@ public function handle(Request $request, SessionInterface $session): Response|Er throw new ResourceNotFoundException($request); } - $result = $this->referenceHandler->handle($reference, ['uri' => $uri]); + $arguments = [ + 'uri' => $uri, + '_session' => $session, + ]; + + $result = $this->referenceHandler->handle($reference, $arguments); if ($reference instanceof ResourceTemplateReference) { $formatted = $reference->formatResult($result, $uri, $reference->resourceTemplate->mimeType); diff --git a/src/Server/Handler/Request/RequestHandlerInterface.php b/src/Server/Handler/Request/RequestHandlerInterface.php index d89b2c1f..d81c0795 100644 --- a/src/Server/Handler/Request/RequestHandlerInterface.php +++ b/src/Server/Handler/Request/RequestHandlerInterface.php @@ -17,11 +17,16 @@ use Mcp\Server\Session\SessionInterface; /** + * @template TResult + * * @author Kyrian Obikwelu */ interface RequestHandlerInterface { public function supports(Request $request): bool; + /** + * @return Response|Error + */ public function handle(Request $request, SessionInterface $session): Response|Error; } diff --git a/src/Server/Protocol.php b/src/Server/Protocol.php index 11ec6a87..cbff5fbe 100644 --- a/src/Server/Protocol.php +++ b/src/Server/Protocol.php @@ -17,6 +17,7 @@ use Mcp\Schema\JsonRpc\Notification; use Mcp\Schema\JsonRpc\Request; use Mcp\Schema\JsonRpc\Response; +use Mcp\Schema\JsonRpc\ResultInterface; use Mcp\Schema\Request\InitializeRequest; use Mcp\Server\Handler\Notification\NotificationHandlerInterface; use Mcp\Server\Handler\Request\RequestHandlerInterface; @@ -31,17 +32,35 @@ /** * @final * + * @phpstan-import-type McpFiber from \Mcp\Server\Transport\TransportInterface + * @phpstan-import-type FiberSuspend from \Mcp\Server\Transport\TransportInterface + * * @author Christopher Hertel * @author Kyrian Obikwelu */ class Protocol { + /** Session key for request ID counter */ + private const SESSION_REQUEST_ID_COUNTER = '_mcp.request_id_counter'; + + /** Session key for pending outgoing requests */ + private const SESSION_PENDING_REQUESTS = '_mcp.pending_requests'; + + /** Session key for incoming client responses */ + private const SESSION_RESPONSES = '_mcp.responses'; + + /** Session key for outgoing message queue */ + private const SESSION_OUTGOING_QUEUE = '_mcp.outgoing_queue'; + + /** Session key for active request meta */ + public const SESSION_ACTIVE_REQUEST_META = '_mcp.active_request_meta'; + /** @var TransportInterface|null */ private ?TransportInterface $transport = null; /** - * @param array $requestHandlers - * @param array $notificationHandlers + * @param array>> $requestHandlers + * @param array $notificationHandlers */ public function __construct( private readonly array $requestHandlers, @@ -53,6 +72,14 @@ public function __construct( ) { } + /** + * @return TransportInterface + */ + public function getTransport(): TransportInterface + { + return $this->transport; + } + /** * Connect this protocol to a transport. * @@ -72,6 +99,14 @@ public function connect(TransportInterface $transport): void $this->transport->onSessionEnd([$this, 'destroySession']); + $this->transport->setOutgoingMessagesProvider([$this, 'consumeOutgoingMessages']); + + $this->transport->setPendingRequestsProvider([$this, 'getPendingRequests']); + + $this->transport->setResponseFinder([$this, 'checkResponse']); + + $this->transport->setFiberYieldHandler([$this, 'handleFiberYield']); + $this->logger->info('Protocol connected to transport', ['transport' => $transport::class]); } @@ -91,7 +126,7 @@ public function processInput(string $input, ?Uuid $sessionId): void } catch (\JsonException $e) { $this->logger->warning('Failed to decode json message.', ['exception' => $e]); $error = Error::forParseError($e->getMessage()); - $this->sendResponse($error, ['session_id' => $sessionId]); + $this->sendResponse($error, null); return; } @@ -121,13 +156,15 @@ private function handleInvalidMessage(InvalidInputMessageException $exception, S $this->logger->warning('Failed to create message.', ['exception' => $exception]); $error = Error::forInvalidRequest($exception->getMessage()); - $this->sendResponse($error, ['session_id' => $session->getId()]); + $this->sendResponse($error, $session); } private function handleRequest(Request $request, SessionInterface $session): void { $this->logger->info('Handling request.', ['request' => $request]); + $session->set(self::SESSION_ACTIVE_REQUEST_META, $request->getMeta()); + $handlerFound = false; foreach ($this->requestHandlers as $handler) { @@ -138,16 +175,41 @@ private function handleRequest(Request $request, SessionInterface $session): voi $handlerFound = true; try { - $response = $handler->handle($request, $session); - $this->sendResponse($response, ['session_id' => $session->getId()]); + /** @var McpFiber $fiber */ + $fiber = new \Fiber(fn () => $handler->handle($request, $session)); + + $result = $fiber->start(); + + if ($fiber->isSuspended()) { + if (\is_array($result) && isset($result['type'])) { + if ('notification' === $result['type']) { + $notification = $result['notification']; + $this->sendNotification($notification, $session); + } elseif ('request' === $result['type']) { + $request = $result['request']; + $timeout = $result['timeout'] ?? 120; + $this->sendRequest($request, $timeout, $session); + } + } + + $this->transport->attachFiberToSession($fiber, $session->getId()); + + return; + } else { + $finalResult = $fiber->getReturn(); + + $this->sendResponse($finalResult, $session); + } } catch (\InvalidArgumentException $e) { $this->logger->warning(\sprintf('Invalid argument: %s', $e->getMessage()), ['exception' => $e]); + $error = Error::forInvalidParams($e->getMessage(), $request->getId()); - $this->sendResponse($error, ['session_id' => $session->getId()]); + $this->sendResponse($error, $session); } catch (\Throwable $e) { $this->logger->error(\sprintf('Uncaught exception: %s', $e->getMessage()), ['exception' => $e]); + $error = Error::forInternalError($e->getMessage(), $request->getId()); - $this->sendResponse($error, ['session_id' => $session->getId()]); + $this->sendResponse($error, $session); } break; @@ -155,14 +217,25 @@ private function handleRequest(Request $request, SessionInterface $session): voi if (!$handlerFound) { $error = Error::forMethodNotFound(\sprintf('No handler found for method "%s".', $request::getMethod()), $request->getId()); - $this->sendResponse($error, ['session_id' => $session->getId()]); + $this->sendResponse($error, $session); } } + /** + * @param Response>|Error $response + */ private function handleResponse(Response|Error $response, SessionInterface $session): void { - $this->logger->info('Handling response.', ['response' => $response]); - // TODO: Implement response handling + $this->logger->info('Handling response from client.', ['response' => $response]); + + $messageId = $response->getId(); + + $session->set(self::SESSION_RESPONSES.".{$messageId}", $response->jsonSerialize()); + $session->forget(self::SESSION_ACTIVE_REQUEST_META); + + $this->logger->info('Client response stored in session', [ + 'message_id' => $messageId, + ]); } private function handleNotification(Notification $notification, SessionInterface $session): void @@ -183,56 +256,250 @@ private function handleNotification(Notification $notification, SessionInterface } /** - * @param array $context + * Sends a request to the client and returns the request ID. */ - public function sendRequest(Request $request, array $context = []): void + public function sendRequest(Request $request, int $timeout, SessionInterface $session): int { - $this->logger->info('Sending request.', ['request' => $request, 'context' => $context]); - // TODO: Implement request sending + $counter = $session->get(self::SESSION_REQUEST_ID_COUNTER, 1000); + $requestId = $counter++; + $session->set(self::SESSION_REQUEST_ID_COUNTER, $counter); + + $requestWithId = $request->withId($requestId); + + $this->logger->info('Queueing server request to client', [ + 'request_id' => $requestId, + 'method' => $request::getMethod(), + ]); + + $pending = $session->get(self::SESSION_PENDING_REQUESTS, []); + $pending[$requestId] = [ + 'request_id' => $requestId, + 'timeout' => $timeout, + 'timestamp' => time(), + ]; + $session->set(self::SESSION_PENDING_REQUESTS, $pending); + + $this->queueOutgoing($requestWithId, ['type' => 'request'], $session); + + return $requestId; } /** - * @param array $context + * Queues a notification for later delivery. */ - public function sendResponse(Response|Error $response, array $context = []): void + public function sendNotification(Notification $notification, SessionInterface $session): void { - $this->logger->info('Sending response.', ['response' => $response, 'context' => $context]); + $this->logger->info('Queueing server notification to client', [ + 'method' => $notification::getMethod(), + ]); - $encoded = null; + $this->queueOutgoing($notification, ['type' => 'notification'], $session); + } - try { - if ($response instanceof Response && [] === $response->result) { - $encoded = json_encode($response, \JSON_THROW_ON_ERROR | \JSON_FORCE_OBJECT); + /** + * Sends a response either immediately or queued for later delivery. + * + * @param Response>|Error $response + * @param array $context + */ + private function sendResponse(Response|Error $response, ?SessionInterface $session, array $context = []): void + { + if (null === $session) { + $this->logger->info('Sending immediate response', [ + 'response_id' => $response->getId(), + ]); + + try { + $encoded = json_encode($response, \JSON_THROW_ON_ERROR); + } catch (\JsonException $e) { + $this->logger->error('Failed to encode response to JSON.', [ + 'message_id' => $response->getId(), + 'exception' => $e, + ]); + + $fallbackError = new Error( + id: $response->getId(), + code: Error::INTERNAL_ERROR, + message: 'Response could not be encoded to JSON' + ); + + $encoded = json_encode($fallbackError, \JSON_THROW_ON_ERROR); } - $encoded = json_encode($response, \JSON_THROW_ON_ERROR); + $context['type'] = 'response'; + $this->transport->send($encoded, $context); + } else { + $this->logger->info('Queueing server response', [ + 'response_id' => $response->getId(), + ]); + + $this->queueOutgoing($response, ['type' => 'response'], $session); + } + } + + /** + * Helper to queue outgoing messages in session. + * + * @param Request|Notification|Response>|Error $message + * @param array $context + */ + private function queueOutgoing(Request|Notification|Response|Error $message, array $context, SessionInterface $session): void + { + try { + $encoded = json_encode($message, \JSON_THROW_ON_ERROR); } catch (\JsonException $e) { - $this->logger->error('Failed to encode response to JSON.', [ - 'message_id' => $response->getId(), + $this->logger->error('Failed to encode message to JSON.', [ 'exception' => $e, ]); - $fallbackError = new Error( - id: $response->getId(), - code: Error::INTERNAL_ERROR, - message: 'Response could not be encoded to JSON' - ); + return; + } + + $queue = $session->get(self::SESSION_OUTGOING_QUEUE, []); + $queue[] = [ + 'message' => $encoded, + 'context' => $context, + ]; + $session->set(self::SESSION_OUTGOING_QUEUE, $queue); + } - $encoded = json_encode($fallbackError, \JSON_THROW_ON_ERROR); + /** + * Consume (get and clear) all outgoing messages for a session. + * + * @return array}> + */ + public function consumeOutgoingMessages(Uuid $sessionId): array + { + $session = $this->sessionFactory->createWithId($sessionId, $this->sessionStore); + $queue = $session->get(self::SESSION_OUTGOING_QUEUE, []); + $session->set(self::SESSION_OUTGOING_QUEUE, []); + $session->save(); + + return $queue; + } + + /** + * Check for a response to a specific request ID. + * + * When a response is found, it is removed from the session, and the + * corresponding pending request is also cleared. + */ + /** + * @return Response>|Error|null + */ + public function checkResponse(int $requestId, Uuid $sessionId): Response|Error|null + { + $session = $this->sessionFactory->createWithId($sessionId, $this->sessionStore); + $responseData = $session->get(self::SESSION_RESPONSES.".{$requestId}"); + + if (null === $responseData) { + return null; } - $context['type'] = 'response'; - $this->transport->send($encoded, $context); + $this->logger->debug('Found and consuming client response.', [ + 'request_id' => $requestId, + 'session_id' => $sessionId->toRfc4122(), + ]); + + $session->set(self::SESSION_RESPONSES.".{$requestId}", null); + $pending = $session->get(self::SESSION_PENDING_REQUESTS, []); + unset($pending[$requestId]); + $session->set(self::SESSION_PENDING_REQUESTS, $pending); + $session->save(); + + try { + if (isset($responseData['error'])) { + return Error::fromArray($responseData); + } + + return Response::fromArray($responseData); + } catch (\Throwable $e) { + $this->logger->error('Failed to reconstruct client response from session.', [ + 'request_id' => $requestId, + 'exception' => $e, + 'response_data' => $responseData, + ]); + + return null; + } } /** - * @param array $context + * Get pending requests for a session. + * + * @return array The pending requests */ - public function sendNotification(Notification $notification, array $context = []): void + public function getPendingRequests(Uuid $sessionId): array { - $this->logger->info('Sending notification.', ['notification' => $notification, 'context' => $context]); - $context['type'] = 'notification'; - // TODO: Implement notification sending + $session = $this->sessionFactory->createWithId($sessionId, $this->sessionStore); + + return $session->get(self::SESSION_PENDING_REQUESTS, []); + } + + /** + * Handle values yielded by Fibers during transport-managed resumes. + * + * @param FiberSuspend|null $yieldedValue + */ + public function handleFiberYield(mixed $yieldedValue, ?Uuid $sessionId): void + { + if (!$sessionId) { + $this->logger->warning('Fiber yielded value without associated session context.'); + + return; + } + + if (!\is_array($yieldedValue) || !isset($yieldedValue['type'])) { + $this->logger->warning('Fiber yielded unexpected payload.', [ + 'payload' => $yieldedValue, + 'session_id' => $sessionId->toRfc4122(), + ]); + + return; + } + + $session = $this->sessionFactory->createWithId($sessionId, $this->sessionStore); + + $payloadSessionId = $yieldedValue['session_id'] ?? null; + if (\is_string($payloadSessionId) && $payloadSessionId !== $sessionId->toRfc4122()) { + $this->logger->warning('Fiber yielded payload with mismatched session ID.', [ + 'payload_session_id' => $payloadSessionId, + 'expected_session_id' => $sessionId->toRfc4122(), + ]); + } + + try { + if ('notification' === $yieldedValue['type']) { + $notification = $yieldedValue['notification'] ?? null; + if (!$notification instanceof Notification) { + $this->logger->warning('Fiber yielded notification without Notification instance.', [ + 'payload' => $yieldedValue, + ]); + + return; + } + + $this->sendNotification($notification, $session); + } elseif ('request' === $yieldedValue['type']) { + $request = $yieldedValue['request'] ?? null; + if (!$request instanceof Request) { + $this->logger->warning('Fiber yielded request without Request instance.', [ + 'payload' => $yieldedValue, + ]); + + return; + } + + $timeout = isset($yieldedValue['timeout']) ? (int) $yieldedValue['timeout'] : 120; + $this->sendRequest($request, $timeout, $session); + } else { + $this->logger->warning('Fiber yielded unknown operation type.', [ + 'type' => $yieldedValue['type'], + ]); + } + } finally { + $session->save(); + } } /** @@ -261,7 +528,7 @@ private function resolveSession(?Uuid $sessionId, array $messages): ?SessionInte // Spec: An initialize request must not be part of a batch. if (\count($messages) > 1) { $error = Error::forInvalidRequest('The "initialize" request MUST NOT be part of a batch.'); - $this->sendResponse($error, ['session_id' => $sessionId]); + $this->sendResponse($error, null); return null; } @@ -269,24 +536,31 @@ private function resolveSession(?Uuid $sessionId, array $messages): ?SessionInte // Spec: An initialize request must not have a session ID. if ($sessionId) { $error = Error::forInvalidRequest('A session ID MUST NOT be sent with an "initialize" request.'); - $this->sendResponse($error); + $this->sendResponse($error, null); return null; } - return $this->sessionFactory->create($this->sessionStore); + $session = $this->sessionFactory->create($this->sessionStore); + $this->logger->debug('Created new session for initialize', [ + 'session_id' => $session->getId()->toString(), + ]); + + $this->transport->setSessionId($session->getId()); + + return $session; } if (!$sessionId) { $error = Error::forInvalidRequest('A valid session id is REQUIRED for non-initialize requests.'); - $this->sendResponse($error, ['status_code' => 400]); + $this->sendResponse($error, null, ['status_code' => 400]); return null; } if (!$this->sessionStore->exists($sessionId)) { $error = Error::forInvalidRequest('Session not found or has expired.'); - $this->sendResponse($error, ['status_code' => 404]); + $this->sendResponse($error, null, ['status_code' => 404]); return null; } diff --git a/src/Server/Transport/BaseTransport.php b/src/Server/Transport/BaseTransport.php new file mode 100644 index 00000000..c2938b52 --- /dev/null +++ b/src/Server/Transport/BaseTransport.php @@ -0,0 +1,136 @@ + + */ +abstract class BaseTransport +{ + use ManagesTransportCallbacks; + + protected ?Uuid $sessionId = null; + + /** + * @var McpFiber|null + */ + protected ?\Fiber $sessionFiber = null; + + public function __construct( + protected readonly LoggerInterface $logger, + ) { + } + + public function initialize(): void + { + } + + public function close(): void + { + } + + public function setSessionId(?Uuid $sessionId): void + { + $this->sessionId = $sessionId; + } + + /** + * @param McpFiber $fiber + */ + public function attachFiberToSession(\Fiber $fiber, Uuid $sessionId): void + { + $this->sessionFiber = $fiber; + $this->sessionId = $sessionId; + } + + /** + * @return array}> + */ + protected function getOutgoingMessages(?Uuid $sessionId): array + { + if ($sessionId && \is_callable($this->outgoingMessagesProvider)) { + return ($this->outgoingMessagesProvider)($sessionId); + } + + return []; + } + + /** + * @return array> + */ + protected function getPendingRequests(?Uuid $sessionId): array + { + if ($sessionId && \is_callable($this->pendingRequestsProvider)) { + return ($this->pendingRequestsProvider)($sessionId); + } + + return []; + } + + /** + * @phpstan-return FiberResume + */ + protected function checkForResponse(int $requestId, ?Uuid $sessionId): Response|Error|null + { + if ($sessionId && \is_callable($this->responseFinder)) { + return ($this->responseFinder)($requestId, $sessionId); + } + + return null; + } + + /** + * @param FiberSuspend|null $yielded + */ + protected function handleFiberYield(mixed $yielded, ?Uuid $sessionId): void + { + if (null === $yielded || !\is_callable($this->fiberYieldHandler)) { + return; + } + + try { + ($this->fiberYieldHandler)($yielded, $sessionId); + } catch (\Throwable $e) { + $this->logger->error('Fiber yield handler failed.', [ + 'exception' => $e, + 'sessionId' => $sessionId?->toRfc4122(), + ]); + } + } + + protected function handleMessage(string $payload, ?Uuid $sessionId): void + { + if (\is_callable($this->messageListener)) { + ($this->messageListener)($payload, $sessionId); + } + } + + protected function handleSessionEnd(?Uuid $sessionId): void + { + if ($sessionId && \is_callable($this->sessionEndListener)) { + ($this->sessionEndListener)($sessionId); + } + } +} diff --git a/src/Server/Transport/CallbackStream.php b/src/Server/Transport/CallbackStream.php new file mode 100644 index 00000000..85525232 --- /dev/null +++ b/src/Server/Transport/CallbackStream.php @@ -0,0 +1,156 @@ + + */ +final class CallbackStream implements StreamInterface +{ + private bool $called = false; + + private ?\Throwable $exception = null; + + /** + * @param callable(): void $callback The callback to execute when stream is read + */ + public function __construct(private $callback, private LoggerInterface $logger = new NullLogger()) + { + } + + public function __toString(): string + { + try { + $this->invoke(); + } catch (\Throwable $e) { + $this->exception = $e; + $this->logger->error( + \sprintf('CallbackStream execution failed: %s', $e->getMessage()), + ['exception' => $e] + ); + } + + return ''; + } + + public function read($length): string + { + $this->invoke(); + + if (null !== $this->exception) { + throw $this->exception; + } + + return ''; + } + + public function getContents(): string + { + $this->invoke(); + + if (null !== $this->exception) { + throw $this->exception; + } + + return ''; + } + + public function eof(): bool + { + return $this->called; + } + + public function close(): void + { + // No-op - callback-based stream doesn't need closing + } + + public function detach() + { + return null; + } + + public function getSize(): ?int + { + return null; // Unknown size for callback streams + } + + public function tell(): int + { + return 0; + } + + public function isSeekable(): bool + { + return false; + } + + public function seek($offset, $whence = \SEEK_SET): void + { + throw new \RuntimeException('Stream is not seekable'); + } + + public function rewind(): void + { + throw new \RuntimeException('Stream is not seekable'); + } + + public function isWritable(): bool + { + return false; + } + + public function write($string): int + { + throw new \RuntimeException('Stream is not writable'); + } + + public function isReadable(): bool + { + return !$this->called; + } + + private function invoke(): void + { + if ($this->called) { + return; + } + + $this->called = true; + $this->exception = null; + ($this->callback)(); + } + + public function getMetadata($key = null) + { + return null === $key ? [] : null; + } +} diff --git a/src/Server/Transport/InMemoryTransport.php b/src/Server/Transport/InMemoryTransport.php index a1bd2946..a9ffca97 100644 --- a/src/Server/Transport/InMemoryTransport.php +++ b/src/Server/Transport/InMemoryTransport.php @@ -18,15 +18,9 @@ * * @author Tobias Nyholm */ -class InMemoryTransport implements TransportInterface +class InMemoryTransport extends BaseTransport implements TransportInterface { - /** @var callable(string, ?Uuid): void */ - private $messageListener; - - /** @var callable(Uuid): void */ - private $sessionDestroyListener; - - private ?Uuid $sessionId = null; + use ManagesTransportCallbacks; /** * @param list $messages @@ -58,29 +52,24 @@ public function send(string $data, array $context): void public function listen(): mixed { foreach ($this->messages as $message) { - if (\is_callable($this->messageListener)) { - \call_user_func($this->messageListener, $message, $this->sessionId); - } + $this->handleMessage($message, $this->sessionId); } - if (\is_callable($this->sessionDestroyListener) && null !== $this->sessionId) { - \call_user_func($this->sessionDestroyListener, $this->sessionId); - $this->sessionId = null; - } + $this->handleSessionEnd($this->sessionId); + + $this->sessionId = null; return null; } - public function onSessionEnd(callable $listener): void + public function setSessionId(?Uuid $sessionId): void { - $this->sessionDestroyListener = $listener; + $this->sessionId = $sessionId; } public function close(): void { - if (\is_callable($this->sessionDestroyListener) && null !== $this->sessionId) { - \call_user_func($this->sessionDestroyListener, $this->sessionId); - $this->sessionId = null; - } + $this->handleSessionEnd($this->sessionId); + $this->sessionId = null; } } diff --git a/src/Server/Transport/ManagesTransportCallbacks.php b/src/Server/Transport/ManagesTransportCallbacks.php new file mode 100644 index 00000000..a0d1aa6b --- /dev/null +++ b/src/Server/Transport/ManagesTransportCallbacks.php @@ -0,0 +1,82 @@ + + * */ +trait ManagesTransportCallbacks +{ + /** @var callable(string, ?Uuid): void */ + protected $messageListener; + + /** @var callable(Uuid): void */ + protected $sessionEndListener; + + /** @var callable(Uuid): array}> */ + protected $outgoingMessagesProvider; + + /** @var callable(Uuid): array> */ + protected $pendingRequestsProvider; + + /** @var callable(int, Uuid): Response>|Error|null */ + protected $responseFinder; + + /** @var callable(FiberSuspend|null, ?Uuid): void */ + protected $fiberYieldHandler; + + public function onMessage(callable $listener): void + { + $this->messageListener = $listener; + } + + public function onSessionEnd(callable $listener): void + { + $this->sessionEndListener = $listener; + } + + public function setOutgoingMessagesProvider(callable $provider): void + { + $this->outgoingMessagesProvider = $provider; + } + + public function setPendingRequestsProvider(callable $provider): void + { + $this->pendingRequestsProvider = $provider; + } + + /** + * @param callable(int, Uuid):(Response>|Error|null) $finder + */ + public function setResponseFinder(callable $finder): void + { + $this->responseFinder = $finder; + } + + /** + * @param callable(FiberSuspend|null, ?Uuid): void $handler + */ + public function setFiberYieldHandler(callable $handler): void + { + $this->fiberYieldHandler = $handler; + } +} diff --git a/src/Server/Transport/StdioTransport.php b/src/Server/Transport/StdioTransport.php index be69dd04..5f0afbfb 100644 --- a/src/Server/Transport/StdioTransport.php +++ b/src/Server/Transport/StdioTransport.php @@ -11,25 +11,17 @@ namespace Mcp\Server\Transport; +use Mcp\Schema\JsonRpc\Error; use Psr\Log\LoggerInterface; use Psr\Log\NullLogger; -use Symfony\Component\Uid\Uuid; /** * @implements TransportInterface * * @author Kyrian Obikwelu - */ -class StdioTransport implements TransportInterface + * */ +class StdioTransport extends BaseTransport implements TransportInterface { - /** @var callable(string, ?Uuid): void */ - private $messageListener; - - /** @var callable(Uuid): void */ - private $sessionEndListener; - - private ?Uuid $sessionId = null; - /** * @param resource $input * @param resource $output @@ -37,80 +29,137 @@ class StdioTransport implements TransportInterface public function __construct( private $input = \STDIN, private $output = \STDOUT, - private readonly LoggerInterface $logger = new NullLogger(), + LoggerInterface $logger = new NullLogger(), ) { + parent::__construct($logger); } - public function initialize(): void + public function send(string $data, array $context): void { + if (isset($context['session_id'])) { + $this->sessionId = $context['session_id']; + } + + $this->writeLine($data); } - public function onMessage(callable $listener): void + public function listen(): int { - $this->messageListener = $listener; + $this->logger->info('StdioTransport is listening for messages on STDIN...'); + stream_set_blocking($this->input, false); + + while (!feof($this->input)) { + $this->processInput(); + $this->processFiber(); + $this->flushOutgoingMessages(); + } + + $this->logger->info('StdioTransport finished listening.'); + $this->handleSessionEnd($this->sessionId); + + return 0; } - public function send(string $data, array $context): void + protected function processInput(): void { - $this->logger->debug('Sending data to client via StdioTransport.', ['data' => $data]); + $line = fgets($this->input); + if (false === $line) { + usleep(50000); // 50ms - if (isset($context['session_id'])) { - $this->sessionId = $context['session_id']; + return; } - fwrite($this->output, $data.\PHP_EOL); + $trimmedLine = trim($line); + if (!empty($trimmedLine)) { + $this->handleMessage($trimmedLine, $this->sessionId); + } } - public function listen(): int + private function processFiber(): void { - $this->logger->info('StdioTransport is listening for messages on STDIN...'); + if (null === $this->sessionFiber) { + return; + } - $status = 0; - while (!feof($this->input)) { - $line = fgets($this->input); - if (false === $line) { - if (!feof($this->input)) { - $status = 1; - } + if ($this->sessionFiber->isTerminated()) { + $this->handleFiberTermination(); - break; - } + return; + } - $trimmedLine = trim($line); - if (!empty($trimmedLine)) { - $this->logger->debug('Received message on StdioTransport.', ['line' => $trimmedLine]); - if (\is_callable($this->messageListener)) { - \call_user_func($this->messageListener, $trimmedLine, $this->sessionId); - } - } + if (!$this->sessionFiber->isSuspended()) { + return; } - $this->logger->info('StdioTransport finished listening.'); + $pendingRequests = $this->getPendingRequests($this->sessionId); - if (\is_callable($this->sessionEndListener) && null !== $this->sessionId) { - \call_user_func($this->sessionEndListener, $this->sessionId); - $this->sessionId = null; + if (empty($pendingRequests)) { + $yielded = $this->sessionFiber->resume(); + $this->handleFiberYield($yielded, $this->sessionId); + + return; } - return $status; + foreach ($pendingRequests as $pending) { + $requestId = $pending['request_id']; + $timestamp = $pending['timestamp']; + $timeout = $pending['timeout'] ?? 120; + + $response = $this->checkForResponse($requestId, $this->sessionId); + + if (null !== $response) { + $yielded = $this->sessionFiber->resume($response); + $this->handleFiberYield($yielded, $this->sessionId); + + return; + } + + if (time() - $timestamp >= $timeout) { + $error = Error::forInternalError('Request timed out', $requestId); + $yielded = $this->sessionFiber->resume($error); + $this->handleFiberYield($yielded, $this->sessionId); + + return; + } + } } - public function onSessionEnd(callable $listener): void + private function handleFiberTermination(): void { - $this->sessionEndListener = $listener; + $finalResult = $this->sessionFiber->getReturn(); + + if (null !== $finalResult) { + try { + $encoded = json_encode($finalResult, \JSON_THROW_ON_ERROR); + $this->writeLine($encoded); + } catch (\JsonException $e) { + $this->logger->error('STDIO: Failed to encode final Fiber result.', ['exception' => $e]); + } + } + + $this->sessionFiber = null; } - public function close(): void + private function flushOutgoingMessages(): void { - if (\is_callable($this->sessionEndListener) && null !== $this->sessionId) { - \call_user_func($this->sessionEndListener, $this->sessionId); - $this->sessionId = null; + $messages = $this->getOutgoingMessages($this->sessionId); + + foreach ($messages as $message) { + $this->writeLine($message['message']); } + } + private function writeLine(string $payload): void + { + fwrite($this->output, $payload.\PHP_EOL); + } + + public function close(): void + { + $this->handleSessionEnd($this->sessionId); if (\is_resource($this->input)) { fclose($this->input); } - if (\is_resource($this->output)) { fclose($this->output); } diff --git a/src/Server/Transport/StreamableHttpTransport.php b/src/Server/Transport/StreamableHttpTransport.php index f4d06b2d..c4e8e047 100644 --- a/src/Server/Transport/StreamableHttpTransport.php +++ b/src/Server/Transport/StreamableHttpTransport.php @@ -25,24 +25,14 @@ * @implements TransportInterface * * @author Kyrian Obikwelu - */ -class StreamableHttpTransport implements TransportInterface + * */ +class StreamableHttpTransport extends BaseTransport implements TransportInterface { private ResponseFactoryInterface $responseFactory; private StreamFactoryInterface $streamFactory; - /** @var callable(string, ?Uuid): void */ - private $messageListener; - - /** @var callable(Uuid): void */ - private $sessionEndListener; - - private ?Uuid $sessionId = null; - - /** @var string[] */ - private array $outgoingMessages = []; - private ?Uuid $outgoingSessionId = null; - private ?int $outgoingStatusCode = null; + private ?string $immediateResponse = null; + private ?int $immediateStatusCode = null; /** @var array */ private array $corsHeaders = [ @@ -57,6 +47,7 @@ public function __construct( ?StreamFactoryInterface $streamFactory = null, private readonly LoggerInterface $logger = new NullLogger(), ) { + parent::__construct($logger); $sessionIdString = $this->request->getHeaderLine('Mcp-Session-Id'); $this->sessionId = $sessionIdString ? Uuid::fromString($sessionIdString) : null; @@ -70,44 +61,20 @@ public function initialize(): void public function send(string $data, array $context): void { - $this->outgoingMessages[] = $data; - - if (isset($context['session_id'])) { - $this->outgoingSessionId = $context['session_id']; - } - - if (isset($context['status_code']) && \is_int($context['status_code'])) { - $this->outgoingStatusCode = $context['status_code']; - } - - $this->logger->debug('Sending data to client via StreamableHttpTransport.', [ - 'data' => $data, - 'session_id' => $this->outgoingSessionId?->toRfc4122(), - 'status_code' => $this->outgoingStatusCode, - ]); + $this->immediateResponse = $data; + $this->immediateStatusCode = $context['status_code'] ?? 200; } public function listen(): ResponseInterface { return match ($this->request->getMethod()) { 'OPTIONS' => $this->handleOptionsRequest(), - 'GET' => $this->handleGetRequest(), 'POST' => $this->handlePostRequest(), 'DELETE' => $this->handleDeleteRequest(), - default => $this->handleUnsupportedRequest(), + default => $this->createErrorResponse(Error::forInvalidRequest('Method Not Allowed'), 405), }; } - public function onMessage(callable $listener): void - { - $this->messageListener = $listener; - } - - public function onSessionEnd(callable $listener): void - { - $this->sessionEndListener = $listener; - } - protected function handleOptionsRequest(): ResponseInterface { return $this->withCorsHeaders($this->responseFactory->createResponse(204)); @@ -115,89 +82,163 @@ protected function handleOptionsRequest(): ResponseInterface protected function handlePostRequest(): ResponseInterface { - $acceptHeader = $this->request->getHeaderLine('Accept'); - if (!str_contains($acceptHeader, 'application/json') || !str_contains($acceptHeader, 'text/event-stream')) { - $error = Error::forInvalidRequest('Not Acceptable: Client must accept both application/json and text/event-stream.'); - $this->logger->warning('Client does not accept required content types.', ['accept' => $acceptHeader]); + $body = $this->request->getBody()->getContents(); + $this->handleMessage($body, $this->sessionId); - return $this->createErrorResponse($error, 406); + if (null !== $this->immediateResponse) { + $response = $this->responseFactory->createResponse($this->immediateStatusCode ?? 200) + ->withHeader('Content-Type', 'application/json') + ->withBody($this->streamFactory->createStream($this->immediateResponse)); + + return $this->withCorsHeaders($response); } - if (!str_contains($this->request->getHeaderLine('Content-Type'), 'application/json')) { - $error = Error::forInvalidRequest('Unsupported Media Type: Content-Type must be application/json.'); - $this->logger->warning('Client sent unsupported content type.', ['content_type' => $this->request->getHeaderLine('Content-Type')]); + if (null !== $this->sessionFiber) { + $this->logger->info('Fiber suspended, handling via SSE.'); - return $this->createErrorResponse($error, 415); + return $this->createStreamedResponse(); } - $body = $this->request->getBody()->getContents(); - if (empty($body)) { - $error = Error::forInvalidRequest('Bad Request: Empty request body.'); - $this->logger->warning('Client sent empty request body.'); + return $this->createJsonResponse(); + } - return $this->createErrorResponse($error, 400); + protected function handleDeleteRequest(): ResponseInterface + { + if (!$this->sessionId) { + return $this->createErrorResponse(Error::forInvalidRequest('Mcp-Session-Id header is required.'), 400); } - $this->logger->debug('Received message on StreamableHttpTransport.', [ - 'body' => $body, - 'session_id' => $this->sessionId?->toRfc4122(), - ]); + $this->handleSessionEnd($this->sessionId); - if (\is_callable($this->messageListener)) { - \call_user_func($this->messageListener, $body, $this->sessionId); - } + return $this->withCorsHeaders($this->responseFactory->createResponse(204)); + } + + protected function createJsonResponse(): ResponseInterface + { + $outgoingMessages = $this->getOutgoingMessages($this->sessionId); - if (empty($this->outgoingMessages)) { + if (empty($outgoingMessages)) { return $this->withCorsHeaders($this->responseFactory->createResponse(202)); } - $responseBody = 1 === \count($this->outgoingMessages) - ? $this->outgoingMessages[0] - : '['.implode(',', $this->outgoingMessages).']'; + $messages = array_column($outgoingMessages, 'message'); + $responseBody = 1 === \count($messages) ? $messages[0] : '['.implode(',', $messages).']'; - $status = $this->outgoingStatusCode ?? 200; - - $response = $this->responseFactory->createResponse($status) + $response = $this->responseFactory->createResponse(200) ->withHeader('Content-Type', 'application/json') ->withBody($this->streamFactory->createStream($responseBody)); - if ($this->outgoingSessionId) { - $response = $response->withHeader('Mcp-Session-Id', $this->outgoingSessionId->toRfc4122()); + if ($this->sessionId) { + $response = $response->withHeader('Mcp-Session-Id', $this->sessionId->toRfc4122()); } return $this->withCorsHeaders($response); } - protected function handleGetRequest(): ResponseInterface + protected function createStreamedResponse(): ResponseInterface { - $response = $this->createErrorResponse(Error::forInvalidRequest('Not Yet Implemented'), 405); + $callback = function (): void { + try { + $this->logger->info('SSE: Starting request processing loop'); + + while ($this->sessionFiber->isSuspended()) { + $this->flushOutgoingMessages($this->sessionId); + + $pendingRequests = $this->getPendingRequests($this->sessionId); + + if (empty($pendingRequests)) { + $yielded = $this->sessionFiber->resume(); + $this->handleFiberYield($yielded, $this->sessionId); + continue; + } + + $resumed = false; + foreach ($pendingRequests as $pending) { + $requestId = $pending['request_id']; + $timestamp = $pending['timestamp']; + $timeout = $pending['timeout'] ?? 120; + + $response = $this->checkForResponse($requestId, $this->sessionId); + + if (null !== $response) { + $yielded = $this->sessionFiber->resume($response); + $this->handleFiberYield($yielded, $this->sessionId); + $resumed = true; + break; + } + + if (time() - $timestamp >= $timeout) { + $error = Error::forInternalError('Request timed out', $requestId); + $yielded = $this->sessionFiber->resume($error); + $this->handleFiberYield($yielded, $this->sessionId); + $resumed = true; + break; + } + } + + if (!$resumed) { + usleep(100000); + } // Prevent tight loop + } + + $this->handleFiberTermination(); + } finally { + $this->sessionFiber = null; + } + }; + + $stream = new CallbackStream($callback, $this->logger); + $response = $this->responseFactory->createResponse(200) + ->withHeader('Content-Type', 'text/event-stream') + ->withHeader('Cache-Control', 'no-cache') + ->withHeader('Connection', 'keep-alive') + ->withHeader('X-Accel-Buffering', 'no') + ->withBody($stream); + + if ($this->sessionId) { + $response = $response->withHeader('Mcp-Session-Id', $this->sessionId->toRfc4122()); + } return $this->withCorsHeaders($response); } - protected function handleDeleteRequest(): ResponseInterface + private function handleFiberTermination(): void { - if (!$this->sessionId) { - $error = Error::forInvalidRequest('Bad Request: Mcp-Session-Id header is required for DELETE requests.'); - $this->logger->warning('DELETE request received without session ID.'); - - return $this->createErrorResponse($error, 400); - } - - if (\is_callable($this->sessionEndListener)) { - \call_user_func($this->sessionEndListener, $this->sessionId); + $finalResult = $this->sessionFiber->getReturn(); + + if (null !== $finalResult) { + try { + $encoded = json_encode($finalResult, \JSON_THROW_ON_ERROR); + echo "event: message\n"; + echo "data: {$encoded}\n\n"; + @ob_flush(); + flush(); + } catch (\JsonException $e) { + $this->logger->error('SSE: Failed to encode final Fiber result.', ['exception' => $e]); + } } - return $this->withCorsHeaders($this->responseFactory->createResponse(204)); + $this->sessionFiber = null; } - protected function handleUnsupportedRequest(): ResponseInterface + private function flushOutgoingMessages(?Uuid $sessionId): void { - $this->logger->warning('Unsupported HTTP method received.', [ - 'method' => $this->request->getMethod(), - ]); + $messages = $this->getOutgoingMessages($sessionId); + + foreach ($messages as $message) { + echo "event: message\n"; + echo "data: {$message['message']}\n\n"; + @ob_flush(); + flush(); + } + } - $response = $this->createErrorResponse(Error::forInvalidRequest('Method Not Allowed'), 405); + protected function createErrorResponse(Error $jsonRpcError, int $statusCode): ResponseInterface + { + $payload = json_encode($jsonRpcError, \JSON_THROW_ON_ERROR); + $response = $this->responseFactory->createResponse($statusCode) + ->withHeader('Content-Type', 'application/json') + ->withBody($this->streamFactory->createStream($payload)); return $this->withCorsHeaders($response); } @@ -210,17 +251,4 @@ protected function withCorsHeaders(ResponseInterface $response): ResponseInterfa return $response; } - - protected function createErrorResponse(Error $jsonRpcError, int $statusCode): ResponseInterface - { - $errorPayload = json_encode($jsonRpcError, \JSON_THROW_ON_ERROR); - - return $this->responseFactory->createResponse($statusCode) - ->withHeader('Content-Type', 'application/json') - ->withBody($this->streamFactory->createStream($errorPayload)); - } - - public function close(): void - { - } } diff --git a/src/Server/Transport/TransportInterface.php b/src/Server/Transport/TransportInterface.php index a082d070..400c453e 100644 --- a/src/Server/Transport/TransportInterface.php +++ b/src/Server/Transport/TransportInterface.php @@ -11,11 +11,21 @@ namespace Mcp\Server\Transport; +use Mcp\Schema\JsonRpc\Error; +use Mcp\Schema\JsonRpc\Response; use Symfony\Component\Uid\Uuid; /** * @template TResult * + * @phpstan-type FiberReturn (Response|Error) + * @phpstan-type FiberResume (FiberReturn|null) + * @phpstan-type FiberSuspend ( + * array{type: 'notification', notification: \Mcp\Schema\JsonRpc\Notification}| + * array{type: 'request', request: \Mcp\Schema\JsonRpc\Request, timeout?: int} + * ) + * @phpstan-type McpFiber \Fiber + * * @author Christopher Hertel * @author Kyrian Obikwelu */ @@ -26,15 +36,6 @@ interface TransportInterface */ public function initialize(): void; - /** - * Register callback for ALL incoming messages. - * - * The transport calls this whenever ANY message arrives, regardless of source. - * - * @param callable(string $message, ?Uuid $sessionId): void $listener - */ - public function onMessage(callable $listener): void; - /** * Starts the transport's execution process. * @@ -47,32 +48,85 @@ public function onMessage(callable $listener): void; public function listen(): mixed; /** - * Send a message to the client. + * Send a message to the client immediately (bypassing session queue). * - * The transport decides HOW to send based on context + * Used for session resolution errors when no session is available. + * The transport decides HOW to send based on context. * * @param array $context Context about this message: * - 'session_id': Uuid|null * - 'type': 'response'|'request'|'notification' - * - 'in_reply_to': int|string|null (ID of request this responds to) - * - 'expects_response': bool (if this is a request needing response) + * - 'status_code': int (HTTP status code for errors) */ public function send(string $data, array $context): void; /** - * Register callback for session termination. + * Closes the transport and cleans up any resources. + */ + public function close(): void; + + /** + * Register callback for ALL incoming messages. + * + * The transport calls this whenever ANY message arrives, regardless of source. + * + * @param callable(string $message, ?Uuid $sessionId): void $listener + */ + public function onMessage(callable $listener): void; + + /** + * Register a listener for when a session is terminated. * - * This can happen when a client disconnects or explicitly ends their session. + * The transport calls this when a client disconnects or explicitly ends their session. * * @param callable(Uuid $sessionId): void $listener The callback function to execute when destroying a session */ public function onSessionEnd(callable $listener): void; /** - * Closes the transport and cleans up any resources. + * Set a provider function to retrieve all queued outgoing messages. + * + * The transport calls this to retrieve all queued messages for a session. * - * This method should be called when the transport is no longer needed. - * It should clean up any resources and close any connections. + * @param callable(Uuid $sessionId): array}> $provider */ - public function close(): void; + public function setOutgoingMessagesProvider(callable $provider): void; + + /** + * Set a provider function to retrieve all pending server-initiated requests. + * + * The transport calls this to decide if it should wait for a client response before resuming a Fiber. + * + * @param callable(Uuid $sessionId): array> $provider + */ + public function setPendingRequestsProvider(callable $provider): void; + + /** + * Set a finder function to check for a specific client response. + * + * @param callable(int, Uuid):FiberResume $finder + */ + public function setResponseFinder(callable $finder): void; + + /** + * Set a handler for processing values yielded from a suspended Fiber. + * + * The transport calls this to let the Protocol handle new requests/notifications + * that are yielded from a Fiber's execution. + * + * @param callable(FiberSuspend|null, ?Uuid $sessionId): void $handler + */ + public function setFiberYieldHandler(callable $handler): void; + + /** + * @param McpFiber $fiber + */ + public function attachFiberToSession(\Fiber $fiber, Uuid $sessionId): void; + + /** + * Set the session ID for the current transport context. + * + * @param Uuid|null $sessionId The session ID, or null to clear + */ + public function setSessionId(?Uuid $sessionId): void; } diff --git a/tests/Unit/JsonRpc/MessageFactoryTest.php b/tests/Unit/JsonRpc/MessageFactoryTest.php index 7f591e57..d38aabeb 100644 --- a/tests/Unit/JsonRpc/MessageFactoryTest.php +++ b/tests/Unit/JsonRpc/MessageFactoryTest.php @@ -98,7 +98,7 @@ public function testCreateResponseWithIntegerId(): void $results = $this->factory->create($json); $this->assertCount(1, $results); - /** @var Response $result */ + /** @var Response> $result */ $result = $results[0]; $this->assertInstanceOf(Response::class, $result); $this->assertSame(456, $result->getId()); @@ -113,7 +113,7 @@ public function testCreateResponseWithStringId(): void $results = $this->factory->create($json); $this->assertCount(1, $results); - /** @var Response $result */ + /** @var Response> $result */ $result = $results[0]; $this->assertInstanceOf(Response::class, $result); $this->assertSame('response-1', $result->getId()); diff --git a/tests/Unit/Server/Handler/Request/CallToolHandlerTest.php b/tests/Unit/Server/Handler/Request/CallToolHandlerTest.php index 11b799bf..87cf15fa 100644 --- a/tests/Unit/Server/Handler/Request/CallToolHandlerTest.php +++ b/tests/Unit/Server/Handler/Request/CallToolHandlerTest.php @@ -30,10 +30,10 @@ class CallToolHandlerTest extends TestCase { private CallToolHandler $handler; - private ReferenceProviderInterface|MockObject $referenceProvider; - private ReferenceHandlerInterface|MockObject $referenceHandler; - private LoggerInterface|MockObject $logger; - private SessionInterface|MockObject $session; + private ReferenceProviderInterface&MockObject $referenceProvider; + private ReferenceHandlerInterface&MockObject $referenceHandler; + private LoggerInterface&MockObject $logger; + private SessionInterface&MockObject $session; protected function setUp(): void { @@ -71,7 +71,7 @@ public function testHandleSuccessfulToolCall(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($toolReference, ['name' => 'John']) + ->with($toolReference, ['name' => 'John', '_session' => $this->session]) ->willReturn('Hello, John!'); $toolReference @@ -80,9 +80,7 @@ public function testHandleSuccessfulToolCall(): void ->with('Hello, John!') ->willReturn([new TextContent('Hello, John!')]); - $this->logger - ->expects($this->never()) - ->method('error'); + // Logger may be called for debugging, so we don't assert never() $response = $this->handler->handle($request, $this->session); @@ -106,7 +104,7 @@ public function testHandleToolCallWithEmptyArguments(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($toolReference, []) + ->with($toolReference, ['_session' => $this->session]) ->willReturn('Simple result'); $toolReference @@ -143,7 +141,7 @@ public function testHandleToolCallWithComplexArguments(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($toolReference, $arguments) + ->with($toolReference, array_merge($arguments, ['_session' => $this->session])) ->willReturn('Complex result'); $toolReference @@ -194,7 +192,7 @@ public function testHandleToolExecutionExceptionReturnsError(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($toolReference, ['param' => 'value']) + ->with($toolReference, ['param' => 'value', '_session' => $this->session]) ->willThrowException($exception); $this->logger @@ -223,7 +221,7 @@ public function testHandleWithNullResult(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($toolReference, []) + ->with($toolReference, ['_session' => $this->session]) ->willReturn(null); $toolReference @@ -260,7 +258,7 @@ public function testHandleLogsErrorWithCorrectParameters(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($toolReference, ['key1' => 'value1', 'key2' => 42]) + ->with($toolReference, ['key1' => 'value1', 'key2' => 42, '_session' => $this->session]) ->willThrowException($exception); $this->logger @@ -270,11 +268,14 @@ public function testHandleLogsErrorWithCorrectParameters(): void 'Error while executing tool "test_tool": "Tool call "test_tool" failed with error: "Custom error message".".', [ 'tool' => 'test_tool', - 'arguments' => ['key1' => 'value1', 'key2' => 42], + 'arguments' => ['key1' => 'value1', 'key2' => 42, '_session' => $this->session], ], ); - $this->handler->handle($request, $this->session); + $response = $this->handler->handle($request, $this->session); + + $this->assertInstanceOf(Error::class, $response); + $this->assertEquals(Error::INTERNAL_ERROR, $response->code); } public function testHandleWithSpecialCharactersInToolName(): void @@ -292,7 +293,7 @@ public function testHandleWithSpecialCharactersInToolName(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($toolReference, []) + ->with($toolReference, ['_session' => $this->session]) ->willReturn('Special tool result'); $toolReference @@ -327,7 +328,7 @@ public function testHandleWithSpecialCharactersInArguments(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($toolReference, $arguments) + ->with($toolReference, array_merge($arguments, ['_session' => $this->session])) ->willReturn('Unicode handled'); $toolReference diff --git a/tests/Unit/Server/Handler/Request/GetPromptHandlerTest.php b/tests/Unit/Server/Handler/Request/GetPromptHandlerTest.php index 3f5171b1..b7f5d259 100644 --- a/tests/Unit/Server/Handler/Request/GetPromptHandlerTest.php +++ b/tests/Unit/Server/Handler/Request/GetPromptHandlerTest.php @@ -31,9 +31,9 @@ class GetPromptHandlerTest extends TestCase { private GetPromptHandler $handler; - private ReferenceProviderInterface|MockObject $referenceProvider; - private ReferenceHandlerInterface|MockObject $referenceHandler; - private SessionInterface|MockObject $session; + private ReferenceProviderInterface&MockObject $referenceProvider; + private ReferenceHandlerInterface&MockObject $referenceHandler; + private SessionInterface&MockObject $session; protected function setUp(): void { @@ -70,7 +70,7 @@ public function testHandleSuccessfulPromptGet(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($promptReference, []) + ->with($promptReference, ['_session' => $this->session]) ->willReturn($expectedMessages); $promptReference @@ -112,7 +112,7 @@ public function testHandlePromptGetWithArguments(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($promptReference, $arguments) + ->with($promptReference, array_merge($arguments, ['_session' => $this->session])) ->willReturn($expectedMessages); $promptReference @@ -145,7 +145,7 @@ public function testHandlePromptGetWithNullArguments(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($promptReference, []) + ->with($promptReference, ['_session' => $this->session]) ->willReturn($expectedMessages); $promptReference @@ -178,7 +178,7 @@ public function testHandlePromptGetWithEmptyArguments(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($promptReference, []) + ->with($promptReference, ['_session' => $this->session]) ->willReturn($expectedMessages); $promptReference @@ -213,7 +213,7 @@ public function testHandlePromptGetWithMultipleMessages(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($promptReference, []) + ->with($promptReference, ['_session' => $this->session]) ->willReturn($expectedMessages); $promptReference @@ -263,7 +263,7 @@ public function testHandlePromptGetExceptionReturnsError(): void $this->assertInstanceOf(Error::class, $response); $this->assertEquals($request->getId(), $response->id); $this->assertEquals(Error::INTERNAL_ERROR, $response->code); - $this->assertEquals('Error while handling prompt', $response->message); + $this->assertEquals('Error while handling prompt: Handling prompt "failing_prompt" failed with error: "Failed to get prompt".', $response->message); } public function testHandlePromptGetWithComplexArguments(): void @@ -299,7 +299,7 @@ public function testHandlePromptGetWithComplexArguments(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($promptReference, $arguments) + ->with($promptReference, array_merge($arguments, ['_session' => $this->session])) ->willReturn($expectedMessages); $promptReference @@ -337,7 +337,7 @@ public function testHandlePromptGetWithSpecialCharacters(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($promptReference, $arguments) + ->with($promptReference, array_merge($arguments, ['_session' => $this->session])) ->willReturn($expectedMessages); $promptReference @@ -367,7 +367,7 @@ public function testHandlePromptGetReturnsEmptyMessages(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($promptReference, []) + ->with($promptReference, ['_session' => $this->session]) ->willReturn([]); $promptReference @@ -405,7 +405,7 @@ public function testHandlePromptGetWithLargeNumberOfArguments(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($promptReference, $arguments) + ->with($promptReference, array_merge($arguments, ['_session' => $this->session])) ->willReturn($expectedMessages); $promptReference diff --git a/tests/Unit/Server/Handler/Request/ReadResourceHandlerTest.php b/tests/Unit/Server/Handler/Request/ReadResourceHandlerTest.php index 440005d4..2c54110d 100644 --- a/tests/Unit/Server/Handler/Request/ReadResourceHandlerTest.php +++ b/tests/Unit/Server/Handler/Request/ReadResourceHandlerTest.php @@ -31,9 +31,9 @@ class ReadResourceHandlerTest extends TestCase { private ReadResourceHandler $handler; - private ReferenceProviderInterface|MockObject $referenceProvider; - private ReferenceHandlerInterface|MockObject $referenceHandler; - private SessionInterface|MockObject $session; + private ReferenceProviderInterface&MockObject $referenceProvider; + private ReferenceHandlerInterface&MockObject $referenceHandler; + private SessionInterface&MockObject $session; protected function setUp(): void { @@ -75,7 +75,7 @@ public function testHandleSuccessfulResourceRead(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($resourceReference, ['uri' => $uri]) + ->with($resourceReference, ['uri' => $uri, '_session' => $this->session]) ->willReturn('test'); $resourceReference @@ -115,7 +115,7 @@ public function testHandleResourceReadWithBlobContent(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($resourceReference, ['uri' => $uri]) + ->with($resourceReference, ['uri' => $uri, '_session' => $this->session]) ->willReturn('fake-image-data'); $resourceReference @@ -159,7 +159,7 @@ public function testHandleResourceReadWithMultipleContents(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($resourceReference, ['uri' => $uri]) + ->with($resourceReference, ['uri' => $uri, '_session' => $this->session]) ->willReturn('binary-data'); $resourceReference @@ -250,7 +250,7 @@ public function testHandleResourceReadWithDifferentUriSchemes(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($resourceReference, ['uri' => $uri]) + ->with($resourceReference, ['uri' => $uri, '_session' => $this->session]) ->willReturn('test'); $resourceReference @@ -295,7 +295,7 @@ public function testHandleResourceReadWithEmptyContent(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($resourceReference, ['uri' => $uri]) + ->with($resourceReference, ['uri' => $uri, '_session' => $this->session]) ->willReturn(''); $resourceReference @@ -357,7 +357,7 @@ public function testHandleResourceReadWithDifferentMimeTypes(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($resourceReference, ['uri' => $uri]) + ->with($resourceReference, ['uri' => $uri, '_session' => $this->session]) ->willReturn($expectedContent); $resourceReference diff --git a/tests/Unit/Server/ProtocolTest.php b/tests/Unit/Server/ProtocolTest.php index 859c4271..fa949c38 100644 --- a/tests/Unit/Server/ProtocolTest.php +++ b/tests/Unit/Server/ProtocolTest.php @@ -98,16 +98,26 @@ public function testRequestHandledByFirstMatchingHandler(): void $this->sessionStore->method('exists')->willReturn(true); $session->method('getId')->willReturn(Uuid::v4()); - $this->transport->expects($this->once()) - ->method('send') - ->with( - $this->callback(function ($data) { - $decoded = json_decode($data, true); + // Configure session mock for queue operations + $queue = []; + $session->method('get')->willReturnCallback(function ($key, $default = null) use (&$queue) { + if ('_mcp.outgoing_queue' === $key) { + return $queue; + } + + return $default; + }); - return isset($decoded['result']); - }), - $this->anything() - ); + $session->method('set')->willReturnCallback(function ($key, $value) use (&$queue) { + if ('_mcp.outgoing_queue' === $key) { + $queue = $value; + } + }); + + // The protocol now queues responses instead of sending them directly + // save() is called once during processInput and once during consumeOutgoingMessages + $session->expects($this->exactly(2)) + ->method('save'); $protocol = new Protocol( requestHandlers: [$handlerA, $handlerB, $handlerC], @@ -124,6 +134,13 @@ public function testRequestHandledByFirstMatchingHandler(): void '{"jsonrpc": "2.0", "id": 1, "method": "tools/list"}', $sessionId ); + + // Check that the response was queued in the session + $outgoing = $protocol->consumeOutgoingMessages($sessionId); + $this->assertCount(1, $outgoing); + + $message = json_decode($outgoing[0]['message'], true); + $this->assertArrayHasKey('result', $message); } #[TestDox('Initialize request must not have a session ID')] @@ -297,17 +314,26 @@ public function testInvalidMessageStructureReturnsError(): void $this->sessionFactory->method('createWithId')->willReturn($session); $this->sessionStore->method('exists')->willReturn(true); - $this->transport->expects($this->once()) - ->method('send') - ->with( - $this->callback(function ($data) { - $decoded = json_decode($data, true); + // Configure session mock for queue operations + $queue = []; + $session->method('get')->willReturnCallback(function ($key, $default = null) use (&$queue) { + if ('_mcp.outgoing_queue' === $key) { + return $queue; + } - return isset($decoded['error']) - && Error::INVALID_REQUEST === $decoded['error']['code']; - }), - $this->anything() - ); + return $default; + }); + + $session->method('set')->willReturnCallback(function ($key, $value) use (&$queue) { + if ('_mcp.outgoing_queue' === $key) { + $queue = $value; + } + }); + + // The protocol now queues responses instead of sending them directly + // save() is called once during processInput and once during consumeOutgoingMessages + $session->expects($this->exactly(2)) + ->method('save'); $protocol = new Protocol( requestHandlers: [], @@ -324,6 +350,14 @@ public function testInvalidMessageStructureReturnsError(): void '{"jsonrpc": "2.0", "params": {}}', $sessionId ); + + // Check that the error was queued in the session + $outgoing = $protocol->consumeOutgoingMessages($sessionId); + $this->assertCount(1, $outgoing); + + $message = json_decode($outgoing[0]['message'], true); + $this->assertArrayHasKey('error', $message); + $this->assertEquals(Error::INVALID_REQUEST, $message['error']['code']); } #[TestDox('Request without handler returns method not found error')] @@ -334,18 +368,26 @@ public function testRequestWithoutHandlerReturnsMethodNotFoundError(): void $this->sessionFactory->method('createWithId')->willReturn($session); $this->sessionStore->method('exists')->willReturn(true); - $this->transport->expects($this->once()) - ->method('send') - ->with( - $this->callback(function ($data) { - $decoded = json_decode($data, true); + // Configure session mock for queue operations + $queue = []; + $session->method('get')->willReturnCallback(function ($key, $default = null) use (&$queue) { + if ('_mcp.outgoing_queue' === $key) { + return $queue; + } - return isset($decoded['error']) - && Error::METHOD_NOT_FOUND === $decoded['error']['code'] - && str_contains($decoded['error']['message'], 'No handler found'); - }), - $this->anything() - ); + return $default; + }); + + $session->method('set')->willReturnCallback(function ($key, $value) use (&$queue) { + if ('_mcp.outgoing_queue' === $key) { + $queue = $value; + } + }); + + // The protocol now queues responses instead of sending them directly + // save() is called once during processInput and once during consumeOutgoingMessages + $session->expects($this->exactly(2)) + ->method('save'); $protocol = new Protocol( requestHandlers: [], @@ -362,6 +404,15 @@ public function testRequestWithoutHandlerReturnsMethodNotFoundError(): void '{"jsonrpc": "2.0", "id": 1, "method": "ping"}', $sessionId ); + + // Check that the error was queued in the session + $outgoing = $protocol->consumeOutgoingMessages($sessionId); + $this->assertCount(1, $outgoing); + + $message = json_decode($outgoing[0]['message'], true); + $this->assertArrayHasKey('error', $message); + $this->assertEquals(Error::METHOD_NOT_FOUND, $message['error']['code']); + $this->assertStringContainsString('No handler found', $message['error']['message']); } #[TestDox('Handler throwing InvalidArgumentException returns invalid params error')] @@ -376,18 +427,26 @@ public function testHandlerInvalidArgumentReturnsInvalidParamsError(): void $this->sessionFactory->method('createWithId')->willReturn($session); $this->sessionStore->method('exists')->willReturn(true); - $this->transport->expects($this->once()) - ->method('send') - ->with( - $this->callback(function ($data) { - $decoded = json_decode($data, true); + // Configure session mock for queue operations + $queue = []; + $session->method('get')->willReturnCallback(function ($key, $default = null) use (&$queue) { + if ('_mcp.outgoing_queue' === $key) { + return $queue; + } - return isset($decoded['error']) - && Error::INVALID_PARAMS === $decoded['error']['code'] - && str_contains($decoded['error']['message'], 'Invalid parameter'); - }), - $this->anything() - ); + return $default; + }); + + $session->method('set')->willReturnCallback(function ($key, $value) use (&$queue) { + if ('_mcp.outgoing_queue' === $key) { + $queue = $value; + } + }); + + // The protocol now queues responses instead of sending them directly + // save() is called once during processInput and once during consumeOutgoingMessages + $session->expects($this->exactly(2)) + ->method('save'); $protocol = new Protocol( requestHandlers: [$handler], @@ -404,6 +463,15 @@ public function testHandlerInvalidArgumentReturnsInvalidParamsError(): void '{"jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": {"name": "test"}}', $sessionId ); + + // Check that the error was queued in the session + $outgoing = $protocol->consumeOutgoingMessages($sessionId); + $this->assertCount(1, $outgoing); + + $message = json_decode($outgoing[0]['message'], true); + $this->assertArrayHasKey('error', $message); + $this->assertEquals(Error::INVALID_PARAMS, $message['error']['code']); + $this->assertStringContainsString('Invalid parameter', $message['error']['message']); } #[TestDox('Handler throwing unexpected exception returns internal error')] @@ -418,18 +486,26 @@ public function testHandlerUnexpectedExceptionReturnsInternalError(): void $this->sessionFactory->method('createWithId')->willReturn($session); $this->sessionStore->method('exists')->willReturn(true); - $this->transport->expects($this->once()) - ->method('send') - ->with( - $this->callback(function ($data) { - $decoded = json_decode($data, true); + // Configure session mock for queue operations + $queue = []; + $session->method('get')->willReturnCallback(function ($key, $default = null) use (&$queue) { + if ('_mcp.outgoing_queue' === $key) { + return $queue; + } - return isset($decoded['error']) - && Error::INTERNAL_ERROR === $decoded['error']['code'] - && str_contains($decoded['error']['message'], 'Unexpected error'); - }), - $this->anything() - ); + return $default; + }); + + $session->method('set')->willReturnCallback(function ($key, $value) use (&$queue) { + if ('_mcp.outgoing_queue' === $key) { + $queue = $value; + } + }); + + // The protocol now queues responses instead of sending them directly + // save() is called once during processInput and once during consumeOutgoingMessages + $session->expects($this->exactly(2)) + ->method('save'); $protocol = new Protocol( requestHandlers: [$handler], @@ -446,6 +522,15 @@ public function testHandlerUnexpectedExceptionReturnsInternalError(): void '{"jsonrpc": "2.0", "id": 1, "method": "tools/call", "params": {"name": "test"}}', $sessionId ); + + // Check that the error was queued in the session + $outgoing = $protocol->consumeOutgoingMessages($sessionId); + $this->assertCount(1, $outgoing); + + $message = json_decode($outgoing[0]['message'], true); + $this->assertArrayHasKey('error', $message); + $this->assertEquals(Error::INTERNAL_ERROR, $message['error']['code']); + $this->assertStringContainsString('Unexpected error', $message['error']['message']); } #[TestDox('Notification handler exceptions are caught and logged')] @@ -493,18 +578,26 @@ public function testSuccessfulRequestReturnsResponseWithSessionId(): void $this->sessionFactory->method('createWithId')->willReturn($session); $this->sessionStore->method('exists')->willReturn(true); - $this->transport->expects($this->once()) - ->method('send') - ->with( - $this->callback(function ($data) { - $decoded = json_decode($data, true); + // Configure session mock for queue operations + $queue = []; + $session->method('get')->willReturnCallback(function ($key, $default = null) use (&$queue) { + if ('_mcp.outgoing_queue' === $key) { + return $queue; + } - return isset($decoded['result']); - }), - $this->callback(function ($context) use ($sessionId) { - return isset($context['session_id']) && $context['session_id']->equals($sessionId); - }) - ); + return $default; + }); + + $session->method('set')->willReturnCallback(function ($key, $value) use (&$queue) { + if ('_mcp.outgoing_queue' === $key) { + $queue = $value; + } + }); + + // The protocol now queues responses instead of sending them directly + // save() is called once during processInput and once during consumeOutgoingMessages + $session->expects($this->exactly(2)) + ->method('save'); $protocol = new Protocol( requestHandlers: [$handler], @@ -520,6 +613,14 @@ public function testSuccessfulRequestReturnsResponseWithSessionId(): void '{"jsonrpc": "2.0", "id": 1, "method": "tools/list"}', $sessionId ); + + // Check that the response was queued in the session + $outgoing = $protocol->consumeOutgoingMessages($sessionId); + $this->assertCount(1, $outgoing); + + $message = json_decode($outgoing[0]['message'], true); + $this->assertArrayHasKey('result', $message); + $this->assertEquals(['status' => 'ok'], $message['result']); } #[TestDox('Batch requests are processed and send multiple responses')] @@ -528,26 +629,38 @@ public function testBatchRequestsAreProcessed(): void $handlerA = $this->createMock(RequestHandlerInterface::class); $handlerA->method('supports')->willReturn(true); $handlerA->method('handle')->willReturnCallback(function ($request) { - return new Response($request->getId(), ['method' => $request::getMethod()]); + return Response::fromArray([ + 'jsonrpc' => '2.0', + 'id' => $request->getId(), + 'result' => ['method' => $request::getMethod()], + ]); }); $session = $this->createMock(SessionInterface::class); $session->method('getId')->willReturn(Uuid::v4()); + // Configure session mock for queue operations + $queue = []; + $session->method('get')->willReturnCallback(function ($key, $default = null) use (&$queue) { + if ('_mcp.outgoing_queue' === $key) { + return $queue; + } + + return $default; + }); + + $session->method('set')->willReturnCallback(function ($key, $value) use (&$queue) { + if ('_mcp.outgoing_queue' === $key) { + $queue = $value; + } + }); + $this->sessionFactory->method('createWithId')->willReturn($session); $this->sessionStore->method('exists')->willReturn(true); - // Expect two calls to send() - $this->transport->expects($this->exactly(2)) - ->method('send') - ->with( - $this->callback(function ($data) { - $decoded = json_decode($data, true); - - return isset($decoded['result']); - }), - $this->anything() - ); + // The protocol now queues responses instead of sending them directly + $session->expects($this->exactly(2)) + ->method('save'); $protocol = new Protocol( requestHandlers: [$handlerA], @@ -564,6 +677,15 @@ public function testBatchRequestsAreProcessed(): void '[{"jsonrpc": "2.0", "method": "tools/list", "id": 1}, {"jsonrpc": "2.0", "method": "prompts/list", "id": 2}]', $sessionId ); + + // Check that both responses were queued in the session + $outgoing = $protocol->consumeOutgoingMessages($sessionId); + $this->assertCount(2, $outgoing); + + foreach ($outgoing as $outgoingMessage) { + $message = json_decode($outgoingMessage['message'], true); + $this->assertArrayHasKey('result', $message); + } } #[TestDox('Session is saved after processing')] From e9a6a9aaf55fc61aff96604376ad245cf4a77004 Mon Sep 17 00:00:00 2001 From: Kyrian Obikwelu Date: Wed, 15 Oct 2025 22:40:04 +0100 Subject: [PATCH 2/4] fix(server): use correct session ID string conversion method --- src/Server/ClientGateway.php | 4 ++-- src/Server/Protocol.php | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Server/ClientGateway.php b/src/Server/ClientGateway.php index 58f76b20..7e95fe9a 100644 --- a/src/Server/ClientGateway.php +++ b/src/Server/ClientGateway.php @@ -66,7 +66,7 @@ public function notify(Notification $notification): void \Fiber::suspend([ 'type' => 'notification', 'notification' => $notification, - 'session_id' => $this->session->getId()->toString(), + 'session_id' => $this->session->getId()->toRfc4122(), ]); } @@ -112,7 +112,7 @@ public function request(Request $request, int $timeout = 120): Response|Error $response = \Fiber::suspend([ 'type' => 'request', 'request' => $request, - 'session_id' => $this->session->getId()->toString(), + 'session_id' => $this->session->getId()->toRfc4122(), 'timeout' => $timeout, ]); diff --git a/src/Server/Protocol.php b/src/Server/Protocol.php index cbff5fbe..c3b42f58 100644 --- a/src/Server/Protocol.php +++ b/src/Server/Protocol.php @@ -543,7 +543,7 @@ private function resolveSession(?Uuid $sessionId, array $messages): ?SessionInte $session = $this->sessionFactory->create($this->sessionStore); $this->logger->debug('Created new session for initialize', [ - 'session_id' => $session->getId()->toString(), + 'session_id' => $session->getId()->toRfc4122(), ]); $this->transport->setSessionId($session->getId()); From 23e8e0697076a15e5576dff8fb994542f8a3da6e Mon Sep 17 00:00:00 2001 From: Kyrian Obikwelu Date: Tue, 21 Oct 2025 16:23:04 +0100 Subject: [PATCH 3/4] fix: clean up docblocks for request handler methods in Builder --- src/Server/Builder.php | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/Server/Builder.php b/src/Server/Builder.php index 25a97bde..f309a3d4 100644 --- a/src/Server/Builder.php +++ b/src/Server/Builder.php @@ -176,8 +176,7 @@ public function setCapabilities(ServerCapabilities $serverCapabilities): self /** * Register a single custom method handler. - */ - /** + * * @param RequestHandlerInterface $handler */ public function addRequestHandler(RequestHandlerInterface $handler): self @@ -190,9 +189,6 @@ public function addRequestHandler(RequestHandlerInterface $handler): self /** * Register multiple custom method handlers. * - * @param iterable $handlers - */ - /** * @param iterable> $handlers */ public function addRequestHandlers(iterable $handlers): self From 5db83d44181d45be3cf757858802e5913c8de7d0 Mon Sep 17 00:00:00 2001 From: Kyrian Obikwelu Date: Sun, 26 Oct 2025 22:32:25 +0100 Subject: [PATCH 4/4] fix: merge artifacts --- src/Server/Transport/StreamableHttpTransport.php | 2 +- tests/Unit/Server/Handler/Request/CallToolHandlerTest.php | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Server/Transport/StreamableHttpTransport.php b/src/Server/Transport/StreamableHttpTransport.php index c4e8e047..8ff918bc 100644 --- a/src/Server/Transport/StreamableHttpTransport.php +++ b/src/Server/Transport/StreamableHttpTransport.php @@ -45,7 +45,7 @@ public function __construct( private readonly ServerRequestInterface $request, ?ResponseFactoryInterface $responseFactory = null, ?StreamFactoryInterface $streamFactory = null, - private readonly LoggerInterface $logger = new NullLogger(), + LoggerInterface $logger = new NullLogger(), ) { parent::__construct($logger); $sessionIdString = $this->request->getHeaderLine('Mcp-Session-Id'); diff --git a/tests/Unit/Server/Handler/Request/CallToolHandlerTest.php b/tests/Unit/Server/Handler/Request/CallToolHandlerTest.php index 87cf15fa..359afa1b 100644 --- a/tests/Unit/Server/Handler/Request/CallToolHandlerTest.php +++ b/tests/Unit/Server/Handler/Request/CallToolHandlerTest.php @@ -358,7 +358,7 @@ public function testHandleReturnsStructuredContentResult(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($toolReference, ['query' => 'php']) + ->with($toolReference, ['query' => 'php', '_session' => $this->session]) ->willReturn($structuredResult); $toolReference @@ -387,7 +387,7 @@ public function testHandleReturnsCallToolResult(): void $this->referenceHandler ->expects($this->once()) ->method('handle') - ->with($toolReference, ['query' => 'php']) + ->with($toolReference, ['query' => 'php', '_session' => $this->session]) ->willReturn($callToolResult); $toolReference