diff --git a/cmd/mcp.go b/cmd/mcp.go index f74efc2..c3686a8 100644 --- a/cmd/mcp.go +++ b/cmd/mcp.go @@ -145,18 +145,36 @@ func loadMCPConfig() (*MCPConfig, error) { func createMCPClients( config *MCPConfig, -) (map[string]*mcpclient.StdioMCPClient, error) { - clients := make(map[string]*mcpclient.StdioMCPClient) +) (map[string]mcpclient.MCPClient, error) { + clients := make(map[string]mcpclient.MCPClient) for name, server := range config.MCPServers { var env []string for k, v := range server.Env { env = append(env, fmt.Sprintf("%s=%s", k, v)) } - client, err := mcpclient.NewStdioMCPClient( - server.Command, - env, - server.Args...) + var client mcpclient.MCPClient + var err error + + if server.Command == "sse_server" { + if len(server.Args) == 0 { + return nil, fmt.Errorf( + "no arguments provided for sse command", + ) + } + + client, err = mcpclient.NewSSEMCPClient( + server.Args[0], + ) + if err == nil { + err = client.(*mcpclient.SSEMCPClient).Start(context.Background()) + } + } else { + client, err = mcpclient.NewStdioMCPClient( + server.Command, + env, + server.Args...) + } if err != nil { for _, c := range clients { c.Close() @@ -202,7 +220,7 @@ func createMCPClients( func handleSlashCommand( prompt string, mcpConfig *MCPConfig, - mcpClients map[string]*mcpclient.StdioMCPClient, + mcpClients map[string]mcpclient.MCPClient, messages interface{}, ) (bool, error) { if !strings.HasPrefix(prompt, "/") { @@ -329,7 +347,7 @@ func handleServersCommand(config *MCPConfig) { fmt.Print("\n" + containerStyle.Render(rendered) + "\n") } -func handleToolsCommand(mcpClients map[string]*mcpclient.StdioMCPClient) { +func handleToolsCommand(mcpClients map[string]mcpclient.MCPClient) { // Get terminal width for proper wrapping width := getTerminalWidth() diff --git a/cmd/root.go b/cmd/root.go index e7ceee0..c1936bc 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -220,7 +220,7 @@ func updateRenderer() error { // Method implementations for simpleMessage func runPrompt( provider llm.Provider, - mcpClients map[string]*mcpclient.StdioMCPClient, + mcpClients map[string]mcpclient.MCPClient, tools []llm.Tool, prompt string, messages *[]history.HistoryMessage,