From ac130ffe69275e6b188b239da28e5c4cb843c2ec Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Nov 2025 14:48:45 +0100 Subject: [PATCH 1/6] feat(mcp): add LocalAI endpoint to stream live results of the agent Signed-off-by: Ettore Di Giacinto --- core/config/model_config.go | 66 +++++++ core/http/app.go | 2 +- core/http/endpoints/localai/mcp.go | 268 +++++++++++++++++++++++++++++ core/http/endpoints/openai/mcp.go | 7 + core/http/routes/localai.go | 23 ++- core/http/static/chat.js | 181 ++++++++++++++----- core/http/views/chat.html | 45 ++++- 7 files changed, 540 insertions(+), 52 deletions(-) create mode 100644 core/http/endpoints/localai/mcp.go diff --git a/core/config/model_config.go b/core/config/model_config.go index 41664f2a3dd4..1fcc13fa9754 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -1,14 +1,17 @@ package config import ( + "context" "os" "regexp" "slices" "strings" + "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/functions" + "github.com/mudler/cogito" "gopkg.in/yaml.v3" ) @@ -668,3 +671,66 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool { return true } + +// BuildCogitoOptions generates cogito options from the model configuration +// It accepts a context, MCP sessions, and optional callback functions for status, reasoning, tool calls, and tool results +func (c *ModelConfig) BuildCogitoOptions( + ctx context.Context, + sessions []*mcp.ClientSession, + statusCallback func(string), + reasoningCallback func(string), + toolCallCallback func(*cogito.ToolChoice) bool, + toolCallResultCallback func(cogito.ToolStatus), +) []cogito.Option { + cogitoOpts := []cogito.Option{ + cogito.WithContext(ctx), + cogito.WithMCPs(sessions...), + cogito.WithIterations(3), // default to 3 iterations + cogito.WithMaxAttempts(3), // default to 3 attempts + cogito.WithForceReasoning(), + } + + // Add optional callbacks if provided + if statusCallback != nil { + cogitoOpts = append(cogitoOpts, cogito.WithStatusCallback(statusCallback)) + } + + if reasoningCallback != nil { + cogitoOpts = append(cogitoOpts, cogito.WithReasoningCallback(reasoningCallback)) + } + + if toolCallCallback != nil { + cogitoOpts = append(cogitoOpts, cogito.WithToolCallBack(toolCallCallback)) + } + + if toolCallResultCallback != nil { + cogitoOpts = append(cogitoOpts, cogito.WithToolCallResultCallback(toolCallResultCallback)) + } + + // Apply agent configuration options + if c.Agent.EnableReasoning { + cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner) + } + + if c.Agent.EnablePlanning { + cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan) + } + + if c.Agent.EnableMCPPrompts { + cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts) + } + + if c.Agent.EnablePlanReEvaluator { + cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator) + } + + if c.Agent.MaxIterations != 0 { + cogitoOpts = append(cogitoOpts, cogito.WithIterations(c.Agent.MaxIterations)) + } + + if c.Agent.MaxAttempts != 0 { + cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(c.Agent.MaxAttempts)) + } + + return cogitoOpts +} diff --git a/core/http/app.go b/core/http/app.go index 731e69df565c..7497a5d611fa 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -205,7 +205,7 @@ func API(application *application.Application) (*echo.Echo, error) { opcache = services.NewOpCache(application.GalleryService()) } - routes.RegisterLocalAIRoutes(e, requestExtractor, 
application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache) + routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator()) routes.RegisterOpenAIRoutes(e, requestExtractor, application) if !application.ApplicationConfig().DisableWebUI { routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache) diff --git a/core/http/endpoints/localai/mcp.go b/core/http/endpoints/localai/mcp.go new file mode 100644 index 000000000000..b4d86af74918 --- /dev/null +++ b/core/http/endpoints/localai/mcp.go @@ -0,0 +1,268 @@ +package localai + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/templates" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/cogito" + "github.com/rs/zerolog/log" +) + +// MCP SSE Event Types +type MCPReasoningEvent struct { + Type string `json:"type"` + Content string `json:"content"` +} + +type MCPToolCallEvent struct { + Type string `json:"type"` + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` + Reasoning string `json:"reasoning"` +} + +type MCPToolResultEvent struct { + Type string `json:"type"` + Name string `json:"name"` + Result string `json:"result"` +} + +type MCPStatusEvent struct { + Type string `json:"type"` + Message string `json:"message"` +} + +type MCPAssistantEvent struct { + Type string `json:"type"` + Content string `json:"content"` +} + +type MCPErrorEvent struct { + Type string `json:"type"` + Message string `json:"message"` +} + +// MCPStreamEndpoint is the SSE streaming endpoint for MCP chat completions +// @Summary Stream MCP chat completions with reasoning, tool calls, and results +// @Param request body schema.OpenAIRequest true "query params" +// @Success 200 {object} schema.OpenAIResponse "Response" +// @Router /v1/mcp/chat/completions [post] +func MCPStreamEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + + // Handle Correlation + id := c.Request().Header.Get("X-Correlation-ID") + if id == "" { + id = fmt.Sprintf("mcp-%d", time.Now().UnixNano()) + } + + input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + if !ok || input.Model == "" { + return echo.ErrBadRequest + } + + config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) + if !ok || config == nil { + return echo.ErrBadRequest + } + + if config.MCP.Servers == "" && config.MCP.Stdio == "" { + return fmt.Errorf("no MCP servers configured") + } + + // Get MCP config from model config + remote, stdio, err := config.MCP.MCPConfigFromYAML() + if err != nil { + return fmt.Errorf("failed to get MCP config: %w", err) + } + + // Check if we have tools in cache, or we have to have an initial connection + sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio) + if err != nil { + return fmt.Errorf("failed to get MCP sessions: %w", 
err) + } + + if len(sessions) == 0 { + return fmt.Errorf("no working MCP servers found") + } + + // Set up SSE headers + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Connection", "keep-alive") + c.Response().Header().Set("X-Correlation-ID", id) + + // Create channel for streaming events + events := make(chan interface{}) + ended := make(chan error, 1) + + ctxWithCancellation, cancel := context.WithCancel(ctx) + defer cancel() + + // Build fragment from messages + fragment := cogito.NewEmptyFragment() + for _, message := range input.Messages { + fragment = fragment.AddMessage(message.Role, message.StringContent) + } + + port := appConfig.APIAddress[strings.LastIndex(appConfig.APIAddress, ":")+1:] + apiKey := "" + if appConfig.ApiKeys != nil && len(appConfig.ApiKeys) > 0 { + apiKey = appConfig.ApiKeys[0] + } + + // TODO: instead of connecting to the API, we should just wire this internally + // and act like completion.go. + // We can do this as cogito expects an interface and we can create one that + // we satisfy to just call internally ComputeChoices + defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port) + + // Set up callbacks for streaming + statusCallback := func(s string) { + events <- MCPStatusEvent{ + Type: "status", + Message: s, + } + } + + reasoningCallback := func(s string) { + events <- MCPReasoningEvent{ + Type: "reasoning", + Content: s, + } + } + + toolCallCallback := func(t *cogito.ToolChoice) bool { + events <- MCPToolCallEvent{ + Type: "tool_call", + Name: t.Name, + Arguments: t.Arguments, + Reasoning: t.Reasoning, + } + return true + } + + toolCallResultCallback := func(t cogito.ToolStatus) { + events <- MCPToolResultEvent{ + Type: "tool_result", + Name: t.Name, + Result: t.Result, + } + } + + // Build cogito options using the consolidated method + cogitoOpts := config.BuildCogitoOptions( + ctxWithCancellation, + sessions, + statusCallback, + reasoningCallback, + toolCallCallback, + toolCallResultCallback, + ) + + // Execute tools in a goroutine + go func() { + defer close(events) + + f, err := cogito.ExecuteTools( + defaultLLM, fragment, + cogitoOpts..., + ) + if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) { + events <- MCPErrorEvent{ + Type: "error", + Message: fmt.Sprintf("Failed to execute tools: %v", err), + } + ended <- err + return + } + + // Get final response + f, err = defaultLLM.Ask(ctx, f) + if err != nil { + events <- MCPErrorEvent{ + Type: "error", + Message: fmt.Sprintf("Failed to get response: %v", err), + } + ended <- err + return + } + + // Stream final assistant response + content := f.LastMessage().Content + events <- MCPAssistantEvent{ + Type: "assistant", + Content: content, + } + + ended <- nil + }() + + // Stream events to client + LOOP: + for { + select { + case <-ctx.Done(): + // Context was cancelled (client disconnected or request cancelled) + log.Debug().Msgf("Request context cancelled, stopping stream") + cancel() + break LOOP + case event := <-events: + if event == nil { + // Channel closed + break LOOP + } + eventData, err := json.Marshal(event) + if err != nil { + log.Debug().Msgf("Failed to marshal event: %v", err) + continue + } + log.Debug().Msgf("Sending event: %s", string(eventData)) + _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(eventData)) + if err != nil { + log.Debug().Msgf("Sending event failed: %v", err) + cancel() + return err + } + c.Response().Flush() + case err := 
<-ended: + if err == nil { + // Send done signal + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + break LOOP + } + log.Error().Msgf("Stream ended with error: %v", err) + errorEvent := MCPErrorEvent{ + Type: "error", + Message: err.Error(), + } + errorData, marshalErr := json.Marshal(errorEvent) + if marshalErr != nil { + fmt.Fprintf(c.Response().Writer, "data: {\"type\":\"error\",\"message\":\"Internal error\"}\n\n") + } else { + fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData)) + } + fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") + c.Response().Flush() + return nil + } + } + + log.Debug().Msgf("Stream ended") + return nil + } +} + diff --git a/core/http/endpoints/openai/mcp.go b/core/http/endpoints/openai/mcp.go index a91706f51d10..6bff942c486a 100644 --- a/core/http/endpoints/openai/mcp.go +++ b/core/http/endpoints/openai/mcp.go @@ -95,6 +95,13 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s) }), cogito.WithContext(ctxWithCancellation), + cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool { + log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", t.Name, t.Reasoning, t.Arguments) + return true + }), + cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) { + log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, tool arguments: %+v", t.Name, t.Result, t.ToolArguments) + }), cogito.WithMCPs(sessions...), cogito.WithIterations(3), // default to 3 iterations cogito.WithMaxAttempts(3), // default to 3 attempts diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index 7b1c003ca021..bf8a7bfb8f16 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -7,6 +7,7 @@ import ( "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/model" echoswagger "github.com/swaggo/echo-swagger" @@ -18,7 +19,8 @@ func RegisterLocalAIRoutes(router *echo.Echo, ml *model.ModelLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, - opcache *services.OpCache) { + opcache *services.OpCache, + evaluator *templates.Evaluator) { router.GET("/swagger/*", echoswagger.WrapHandler) // default @@ -133,4 +135,23 @@ func RegisterLocalAIRoutes(router *echo.Echo, requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TOKENIZE)), requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) })) + // MCP Stream endpoint + if evaluator != nil { + mcpStreamHandler := localai.MCPStreamEndpoint(cl, ml, evaluator, appConfig) + mcpStreamMiddleware := []echo.MiddlewareFunc{ + requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)), + requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), + func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if err := requestExtractor.SetOpenAIRequest(c); err != nil { + return err + } + return next(c) + } + }, + } + router.POST("/v1/mcp/chat/completions", mcpStreamHandler, mcpStreamMiddleware...) + router.POST("/mcp/v1/chat/completions", mcpStreamHandler, mcpStreamMiddleware...) 
+ } + } diff --git a/core/http/static/chat.js b/core/http/static/chat.js index 993c956ac91a..634255227380 100644 --- a/core/http/static/chat.js +++ b/core/http/static/chat.js @@ -379,16 +379,14 @@ async function promptGPT(systemPrompt, input) { document.getElementById("fileName").innerHTML = ""; // Choose endpoint based on MCP mode - const endpoint = mcpMode ? "mcp/v1/chat/completions" : "v1/chat/completions"; + const endpoint = mcpMode ? "v1/mcp/chat/completions" : "v1/chat/completions"; const requestBody = { model: model, messages: messages, }; - // Only add stream parameter for regular chat (MCP doesn't support streaming) - if (!mcpMode) { - requestBody.stream = true; - } + // Add stream parameter for both regular chat and MCP (MCP now supports SSE streaming) + requestBody.stream = true; let response; try { @@ -444,64 +442,153 @@ async function promptGPT(systemPrompt, input) { return; } + // Handle streaming response (both regular and MCP mode now use SSE) if (mcpMode) { - // Handle MCP non-streaming response + // Handle MCP SSE streaming with new event types + const reader = response.body + ?.pipeThrough(new TextDecoderStream()) + .getReader(); + + if (!reader) { + Alpine.store("chat").add( + "assistant", + `Error: Failed to decode MCP API response`, + ); + toggleLoader(false); + return; + } + + // Store reader globally so stop button can cancel it + currentReader = reader; + + let buffer = ""; + let assistantContent = ""; + let lastAssistantMessageIndex = -1; + try { - const data = await response.json(); - - // Update token usage if present - if (data.usage) { - Alpine.store("chat").updateTokenUsage(data.usage); - } - - // MCP endpoint returns content in choices[0].message.content (chat completion format) - // Fallback to choices[0].text for backward compatibility (completion format) - const content = data.choices[0]?.message?.content || data.choices[0]?.text || ""; - - if (!content && (!data.choices || data.choices.length === 0)) { - Alpine.store("chat").add( - "assistant", - `Error: Empty response from MCP endpoint`, - ); - toggleLoader(false); - return; + while (true) { + const { value, done } = await reader.read(); + if (done) break; + + buffer += value; + + let lines = buffer.split("\n"); + buffer = lines.pop(); // Retain any incomplete line in the buffer + + lines.forEach((line) => { + if (line.length === 0 || line.startsWith(":")) return; + if (line === "data: [DONE]") { + return; + } + + if (line.startsWith("data: ")) { + try { + const eventData = JSON.parse(line.substring(6)); + + // Handle different event types + switch (eventData.type) { + case "reasoning": + if (eventData.content) { + Alpine.store("chat").add("reasoning", eventData.content); + } + break; + + case "tool_call": + if (eventData.name) { + const toolCallContent = `**Tool:** ${eventData.name}\n\n` + + (eventData.reasoning ? 
`**Reasoning:** ${eventData.reasoning}\n\n` : '') + + `**Arguments:**\n\`\`\`json\n${JSON.stringify(eventData.arguments, null, 2)}\n\`\`\``; + Alpine.store("chat").add("tool_call", toolCallContent); + } + break; + + case "tool_result": + if (eventData.name) { + const toolResultContent = `**Tool:** ${eventData.name}\n\n` + + `**Result:**\n\`\`\`\n${eventData.result}\n\`\`\``; + Alpine.store("chat").add("tool_result", toolResultContent); + } + break; + + case "status": + // Status messages can be logged but not necessarily displayed + console.log("[MCP Status]", eventData.message); + break; + + case "assistant": + if (eventData.content) { + assistantContent += eventData.content; + // Count tokens for rate calculation + tokensReceived += Math.ceil(eventData.content.length / 4); + updateTokensPerSecond(); + + // Process thinking tags in assistant content + const { regularContent, thinkingContent } = processThinkingTags(assistantContent); + + // Update or create assistant message + if (lastAssistantMessageIndex === -1) { + Alpine.store("chat").add("assistant", regularContent || assistantContent); + lastAssistantMessageIndex = Alpine.store("chat").history.length - 1; + } else { + const chatStore = Alpine.store("chat"); + const lastMessage = chatStore.history[lastAssistantMessageIndex]; + if (lastMessage && lastMessage.role === "assistant") { + lastMessage.content = regularContent || assistantContent; + lastMessage.html = DOMPurify.sanitize(marked.parse(lastMessage.content)); + } + } + + // Add thinking content if present + if (thinkingContent) { + Alpine.store("chat").add("thinking", thinkingContent); + } + } + break; + + case "error": + Alpine.store("chat").add( + "assistant", + `MCP Error: ${eventData.message}`, + ); + break; + } + } catch (error) { + console.error("Failed to parse MCP event:", line, error); + } + } + }); } - - if (content) { - // Count tokens for rate calculation (MCP mode - full content at once) - // Prefer actual token count from API if available - if (data.usage && data.usage.completion_tokens) { - tokensReceived = data.usage.completion_tokens; - } else { - tokensReceived += Math.ceil(content.length / 4); + + // Final assistant content flush if any data remains + if (assistantContent.trim() && lastAssistantMessageIndex !== -1) { + const { regularContent, thinkingContent } = processThinkingTags(assistantContent); + const chatStore = Alpine.store("chat"); + const lastMessage = chatStore.history[lastAssistantMessageIndex]; + if (lastMessage && lastMessage.role === "assistant") { + lastMessage.content = regularContent || assistantContent; + lastMessage.html = DOMPurify.sanitize(marked.parse(lastMessage.content)); } - updateTokensPerSecond(); - - // Process thinking tags using shared function - const { regularContent, thinkingContent } = processThinkingTags(content); - - // Add thinking content if present if (thinkingContent) { Alpine.store("chat").add("thinking", thinkingContent); } - - // Add regular content if present - if (regularContent) { - Alpine.store("chat").add("assistant", regularContent); - } } - - // Highlight all code blocks + + // Highlight all code blocks once at the end hljs.highlightAll(); } catch (error) { // Don't show error if request was aborted by user - if (error.name !== 'AbortError' || currentAbortController) { + if (error.name !== 'AbortError' || !currentAbortController) { Alpine.store("chat").add( "assistant", - `Error: Failed to parse MCP response`, + `Error: Failed to process MCP stream`, ); } } finally { + // Perform any cleanup if necessary + if 
(reader) { + reader.releaseLock(); + } + currentReader = null; currentAbortController = null; } } else { diff --git a/core/http/views/chat.html b/core/http/views/chat.html index 5aa45e3e30c5..acb98f1bca35 100644 --- a/core/http/views/chat.html +++ b/core/http/views/chat.html @@ -111,8 +111,8 @@ }, add(role, content, image, audio) { const N = this.history.length - 1; - // For thinking messages, always create a new message - if (role === "thinking") { + // For thinking, reasoning, tool_call, and tool_result messages, always create a new message + if (role === "thinking" || role === "reasoning" || role === "tool_call" || role === "tool_result") { let c = ""; const lines = content.split("\n"); lines.forEach((line) => { @@ -527,7 +527,46 @@

-
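
Example client for the new streaming endpoint, as a minimal sketch rather than part of the patch: the route (`/v1/mcp/chat/completions`), the `data: ...` SSE framing, the `[DONE]` terminator, and the event `type` values (`status`, `reasoning`, `tool_call`, `tool_result`, `assistant`, `error`) are taken from the diff above, while the host, port, model name, and prompt are placeholder assumptions that will differ per deployment. If API keys are configured, an `Authorization: Bearer <key>` header is presumably needed, as on the other OpenAI-compatible routes.

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
)

// Minimal event envelope: every event emitted by the endpoint carries a
// "type" field; the remaining fields are populated depending on that type.
type mcpEvent struct {
	Type    string `json:"type"`
	Content string `json:"content,omitempty"`
	Message string `json:"message,omitempty"`
	Name    string `json:"name,omitempty"`
	Result  string `json:"result,omitempty"`
}

func main() {
	// Placeholder model and prompt; adjust to your LocalAI instance.
	body, _ := json.Marshal(map[string]any{
		"model": "my-agent-model",
		"messages": []map[string]string{
			{"role": "user", "content": "What is the weather in Rome today?"},
		},
	})

	resp, err := http.Post("http://localhost:8080/v1/mcp/chat/completions",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue // skip blank separator lines between events
		}
		payload := strings.TrimPrefix(line, "data: ")
		if payload == "[DONE]" {
			break
		}
		var ev mcpEvent
		if err := json.Unmarshal([]byte(payload), &ev); err != nil {
			continue
		}
		switch ev.Type {
		case "assistant":
			fmt.Print(ev.Content)
		case "tool_call", "tool_result":
			fmt.Printf("\n[%s] %s\n", ev.Type, ev.Name)
		case "error":
			fmt.Printf("\n[error] %s\n", ev.Message)
		}
	}
	fmt.Println()
}
```

The per-event handling above mirrors what chat.js in this patch does in the web UI, just reduced to printing the assistant text and noting tool activity.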