diff --git a/README.md b/README.md index 0f7ab2f..0e81af9 100644 --- a/README.md +++ b/README.md @@ -17,9 +17,9 @@ This architecture allows language models to: Currently supports: - Claude 3.5 Sonnet (claude-3-5-sonnet-20240620) - Any Ollama-compatible model with function calling support +- Google Gemini models - Any OpenAI-compatible local or online model with function calling support - ## Features ✨ - Interactive conversations with support models @@ -35,6 +35,7 @@ Currently supports: - Go 1.23 or later - For Claude: An Anthropic API key - For Ollama: Local Ollama installation with desired models +- For Google/Gemini: Google API key (see https://aistudio.google.com/app/apikey) - One or more MCP-compatible tool servers ## Environment Setup 🔧 @@ -55,9 +56,13 @@ ollama pull mistral ollama serve ``` -3. OpenAI compatible online Setup -- Get your api server base url, api key and model name +3. Google API Key (for Gemini): +```bash +export GOOGLE_API_KEY='your-api-key' +``` +4. OpenAI compatible online Setup +- Get your api server base url, api key and model name ## Installation 📦 @@ -107,6 +112,7 @@ Models can be specified using the `--model` (`-m`) flag: - Anthropic Claude (default): `anthropic:claude-3-5-sonnet-latest` - OpenAI or OpenAI-compatible: `openai:gpt-4` - Ollama models: `ollama:modelname` +- Google: `google:gemini-2.0-flash` ### Examples ```bash @@ -131,6 +137,7 @@ mcphost --model openai: \ - `-m, --model string`: Model to use (format: provider:model) (default "anthropic:claude-3-5-sonnet-latest") - `--openai-url string`: Base URL for OpenAI API (defaults to api.openai.com) - `--openai-api-key string`: OpenAI API key (can also be set via OPENAI_API_KEY environment variable) +- `--google-api-key string`: Google API key (can also be set via GOOGLE_API_KEY environment variable) ### Interactive Commands diff --git a/cmd/root.go b/cmd/root.go index e7ceee0..55db13b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -3,6 +3,7 @@ package cmd import ( "context" "encoding/json" + "errors" "fmt" "os" "strings" @@ -19,6 +20,7 @@ import ( "github.com/mark3labs/mcphost/pkg/history" "github.com/mark3labs/mcphost/pkg/llm" "github.com/mark3labs/mcphost/pkg/llm/anthropic" + "github.com/mark3labs/mcphost/pkg/llm/google" "github.com/mark3labs/mcphost/pkg/llm/ollama" "github.com/mark3labs/mcphost/pkg/llm/openai" "github.com/spf13/cobra" @@ -34,6 +36,7 @@ var ( anthropicBaseURL string // Base URL for Anthropic API openaiAPIKey string anthropicAPIKey string + googleAPIKey string ) const ( @@ -53,12 +56,14 @@ Available models can be specified using the --model flag: - Anthropic Claude (default): anthropic:claude-3-5-sonnet-latest - OpenAI: openai:gpt-4 - Ollama models: ollama:modelname +- Google: google:modelname Example: mcphost -m ollama:qwen2.5:3b - mcphost -m openai:gpt-4`, + mcphost -m openai:gpt-4 + mcphost -m google:gemini-2.0-flash`, RunE: func(cmd *cobra.Command, args []string) error { - return runMCPHost() + return runMCPHost(context.Background()) }, } @@ -88,10 +93,11 @@ func init() { flags.StringVar(&anthropicBaseURL, "anthropic-url", "", "base URL for Anthropic API (defaults to api.anthropic.com)") flags.StringVar(&openaiAPIKey, "openai-api-key", "", "OpenAI API key") flags.StringVar(&anthropicAPIKey, "anthropic-api-key", "", "Anthropic API key") + flags.StringVar(&googleAPIKey, "google-api-key", "", "Google (Gemini) API key") } // Add new function to create provider -func createProvider(modelString string) (llm.Provider, error) { +func createProvider(ctx context.Context, modelString string) (llm.Provider, error) { parts := strings.SplitN(modelString, ":", 2) if len(parts) < 2 { return nil, fmt.Errorf( @@ -133,6 +139,17 @@ func createProvider(modelString string) (llm.Provider, error) { } return openai.NewProvider(apiKey, openaiBaseURL, model), nil + case "google": + apiKey := googleAPIKey + if apiKey == "" { + apiKey = os.Getenv("GOOGLE_API_KEY") + } + if apiKey == "" { + // The project structure is provider specific, but Google calls this GEMINI_API_KEY in e.g. AI Studio. Support both. + apiKey = os.Getenv("GEMINI_API_KEY") + } + return google.NewProvider(ctx, apiKey, model) + default: return nil, fmt.Errorf("unsupported provider: %s", provider) } @@ -219,6 +236,7 @@ func updateRenderer() error { // Method implementations for simpleMessage func runPrompt( + ctx context.Context, provider llm.Provider, mcpClients map[string]*mcpclient.StdioMCPClient, tools []llm.Tool, @@ -255,7 +273,7 @@ func runPrompt( for { action := func() { message, err = provider.CreateMessage( - context.Background(), + ctx, prompt, llmMessages, tools, @@ -294,7 +312,7 @@ func runPrompt( var messageContent []history.ContentBlock // Handle the message response - if str, err := renderer.Render("\nAssistant: "); err == nil { + if str, err := renderer.Render("\nAssistant: "); message.GetContent() != "" && err == nil { fmt.Print(str) } @@ -440,14 +458,14 @@ func runPrompt( Content: toolResults, }) // Make another call to get Claude's response to the tool results - return runPrompt(provider, mcpClients, tools, "", messages) + return runPrompt(ctx, provider, mcpClients, tools, "", messages) } fmt.Println() // Add spacing return nil } -func runMCPHost() error { +func runMCPHost(ctx context.Context) error { // Set up logging based on debug flag if debugMode { log.SetLevel(log.DebugLevel) @@ -459,7 +477,7 @@ func runMCPHost() error { } // Create the provider based on the model flag - provider, err := createProvider(modelFlag) + provider, err := createProvider(ctx, modelFlag) if err != nil { return fmt.Errorf("error creating provider: %v", err) } @@ -497,7 +515,7 @@ func runMCPHost() error { var allTools []llm.Tool for serverName, mcpClient := range mcpClients { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) toolsResult, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{}) cancel() @@ -531,28 +549,23 @@ func runMCPHost() error { // Main interaction loop for { - width := getTerminalWidth() var prompt string - form := huh.NewForm( - huh.NewGroup( - huh.NewText(). - Key("prompt"). - Title("Enter your prompt (Type /help for commands, Ctrl+C to quit)"). - Value(&prompt), - ), - ).WithWidth(width).WithTheme(huh.ThemeCharm()) - - err := form.Run() + err := huh.NewForm(huh.NewGroup(huh.NewText(). + Title("Enter your prompt (Type /help for commands, Ctrl+C to quit)"). + Value(&prompt)), + ).WithWidth(getTerminalWidth()). + WithTheme(huh.ThemeCharm()). + Run() + if err != nil { // Check if it's a user abort (Ctrl+C) - if err.Error() == "user aborted" { + if errors.Is(err, huh.ErrUserAborted) { fmt.Println("\nGoodbye!") return nil // Exit cleanly } return err // Return other errors normally } - prompt = form.GetString("prompt") if prompt == "" { continue } @@ -574,7 +587,7 @@ func runMCPHost() error { if len(messages) > 0 { messages = pruneMessages(messages) } - err = runPrompt(provider, mcpClients, allTools, prompt, &messages) + err = runPrompt(ctx, provider, mcpClients, allTools, prompt, &messages) if err != nil { return err } diff --git a/go.mod b/go.mod index b8e16e0..d66bb7b 100644 --- a/go.mod +++ b/go.mod @@ -1,30 +1,58 @@ module github.com/mark3labs/mcphost -go 1.23 +go 1.23.0 require ( github.com/charmbracelet/huh v0.3.0 github.com/charmbracelet/huh/spinner v0.0.0-20241127125741-aad810dfbce6 github.com/charmbracelet/lipgloss v1.0.0 github.com/charmbracelet/log v0.4.0 + github.com/google/generative-ai-go v0.19.0 github.com/mark3labs/mcp-go v0.8.2 github.com/ollama/ollama v0.5.1 github.com/spf13/cobra v1.8.1 - golang.org/x/term v0.22.0 + golang.org/x/term v0.30.0 + google.golang.org/api v0.228.0 ) require ( + cloud.google.com/go v0.115.0 // indirect + cloud.google.com/go/ai v0.8.0 // indirect + cloud.google.com/go/auth v0.15.0 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect + cloud.google.com/go/compute/metadata v0.6.0 // indirect + cloud.google.com/go/longrunning v0.5.7 // indirect github.com/alecthomas/chroma/v2 v2.14.0 // indirect github.com/atotto/clipboard v0.1.4 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/catppuccin/go v0.2.0 // indirect github.com/dlclark/regexp2 v1.11.0 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/google/s2a-go v0.1.9 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect + github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/gorilla/css v1.0.1 // indirect github.com/microcosm-cc/bluemonday v1.0.27 // indirect github.com/muesli/reflow v0.3.0 // indirect github.com/yuin/goldmark v1.7.4 // indirect github.com/yuin/goldmark-emoji v1.0.3 // indirect - golang.org/x/net v0.27.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 // indirect + go.opentelemetry.io/otel v1.34.0 // indirect + go.opentelemetry.io/otel/metric v1.34.0 // indirect + go.opentelemetry.io/otel/trace v1.34.0 // indirect + golang.org/x/crypto v0.36.0 // indirect + golang.org/x/net v0.37.0 // indirect + golang.org/x/oauth2 v0.28.0 // indirect + golang.org/x/time v0.11.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 // indirect + google.golang.org/grpc v1.71.0 // indirect + google.golang.org/protobuf v1.36.6 // indirect ) require ( @@ -47,7 +75,7 @@ require ( github.com/rivo/uniseg v0.4.7 // indirect github.com/spf13/pflag v1.0.5 // indirect golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect - golang.org/x/sync v0.9.0 // indirect - golang.org/x/sys v0.27.0 // indirect - golang.org/x/text v0.20.0 // indirect + golang.org/x/sync v0.12.0 // indirect + golang.org/x/sys v0.31.0 // indirect + golang.org/x/text v0.23.0 // indirect ) diff --git a/go.sum b/go.sum index 83291ba..5263c05 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,15 @@ +cloud.google.com/go v0.115.0 h1:CnFSK6Xo3lDYRoBKEcAtia6VSC837/ZkJuRduSFnr14= +cloud.google.com/go v0.115.0/go.mod h1:8jIM5vVgoAEoiVxQ/O4BFTfHqulPZgs/ufEzMcFMdWU= +cloud.google.com/go/ai v0.8.0 h1:rXUEz8Wp2OlrM8r1bfmpF2+VKqc1VJpafE3HgzRnD/w= +cloud.google.com/go/ai v0.8.0/go.mod h1:t3Dfk4cM61sytiggo2UyGsDVW3RF1qGZaUKDrZFyqkE= +cloud.google.com/go/auth v0.15.0 h1:Ly0u4aA5vG/fsSsxu98qCQBemXtAtJf+95z9HK+cxps= +cloud.google.com/go/auth v0.15.0/go.mod h1:WJDGqZ1o9E9wKIL+IwStfyn/+s59zl4Bi+1KQNVXLZ8= +cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= +cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= +cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= +cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= +cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU= +cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng= github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ= github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE= @@ -43,12 +55,29 @@ github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxK github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/generative-ai-go v0.19.0 h1:R71szggh8wHMCUlEMsW2A/3T+5LdEIkiaHSYgSpUgdg= +github.com/google/generative-ai-go v0.19.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4= +github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= +github.com/googleapis/gax-go/v2 v2.14.1 h1:hb0FFeiPaQskmvakKu5EbCbpntQn48jyHuvrkurSS/Q= +github.com/googleapis/gax-go/v2 v2.14.1/go.mod h1:Hb/NubMaVM88SrNkvl8X/o8XWwDJEPqouaLeN2IUxoA= github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= @@ -89,27 +118,59 @@ github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= github.com/yuin/goldmark-emoji v1.0.3 h1:aLRkLHOuBR2czCY4R8olwMjID+tENfhyFDMCRhbIQY4= github.com/yuin/goldmark-emoji v1.0.3/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 h1:rgMkmiGfix9vFJDcDi1PK8WEQP4FLQwLDfhp5ZLpFeE= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0/go.mod h1:ijPqXp5P6IRRByFVVg9DY8P5HkxkHE5ARIa+86aXPf4= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 h1:CV7UdSGJt/Ao6Gp4CXckLxVRRsRgDHoI8XjbL3PDl8s= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0/go.mod h1:FRmFuRJfag1IZ2dPkHnEoSFVgTVPUd2qf5Vi69hLb8I= +go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY= +go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI= +go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ= +go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE= +go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A= +go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU= +go.opentelemetry.io/otel/sdk/metric v1.34.0 h1:5CeK9ujjbFVL5c1PhLuStg1wxA7vQv7ce1EK0Gyvahk= +go.opentelemetry.io/otel/sdk/metric v1.34.0/go.mod h1:jQ/r8Ze28zRKoNRdkjCZxfs6YvBTG1+YIqyFVFYec5w= +go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k= +go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= -golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= -golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= -golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= -golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= +golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc= +golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= -golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= -golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= -golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= -golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= +golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= +golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +google.golang.org/api v0.228.0 h1:X2DJ/uoWGnY5obVjewbp8icSL5U4FzuCfy9OjbLSnLs= +google.golang.org/api v0.228.0/go.mod h1:wNvRS1Pbe8r4+IfBIniV8fwCpGwTrYa+kMUDiC5z5a4= +google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 h1:GVIKPyP/kLIyVOgOnTwFOrvQaQUzOzGMCxgFUOEmm24= +google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422/go.mod h1:b6h1vNKhxaSoEI+5jc3PJUCustfli/mRab7295pY7rw= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4 h1:iK2jbkWL86DXjEx0qiHcRE9dE4/Ahua5k6V8OWFb//c= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250313205543-e70fdf4c4cb4/go.mod h1:LuRYeWDFV6WOn90g357N17oMCaxpgCnbi/44qJvDn2I= +google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg= +google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/llm/google/provider.go b/pkg/llm/google/provider.go new file mode 100644 index 0000000..df874dd --- /dev/null +++ b/pkg/llm/google/provider.go @@ -0,0 +1,178 @@ +package google + +import ( + "context" + "fmt" + "strings" + + "github.com/google/generative-ai-go/genai" + "github.com/mark3labs/mcphost/pkg/history" + "github.com/mark3labs/mcphost/pkg/llm" + "google.golang.org/api/option" +) + +type Provider struct { + client *genai.Client + model *genai.GenerativeModel + chat *genai.ChatSession + + toolCallID int +} + +func NewProvider(ctx context.Context, apiKey string, model string) (*Provider, error) { + client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) + if err != nil { + return nil, err + } + m := client.GenerativeModel(model) + return &Provider{ + client: client, + model: m, + chat: m.StartChat(), + }, nil +} + +func (p *Provider) CreateMessage(ctx context.Context, prompt string, messages []llm.Message, tools []llm.Tool) (llm.Message, error) { + var hist []*genai.Content + for _, msg := range messages { + for _, call := range msg.GetToolCalls() { + hist = append(hist, &genai.Content{ + Role: msg.GetRole(), + Parts: []genai.Part{ + genai.FunctionCall{ + Name: call.GetName(), + Args: call.GetArguments(), + }, + }, + }) + } + + if msg.IsToolResponse() { + if historyMsg, ok := msg.(*history.HistoryMessage); ok { + for _, block := range historyMsg.Content { + if block.Type == "tool_result" { + hist = append(hist, &genai.Content{ + Role: msg.GetRole(), + Parts: []genai.Part{genai.Text(block.Text)}, + }) + } + } + } + } + + if text := strings.TrimSpace(msg.GetContent()); text != "" { + hist = append(hist, &genai.Content{ + Role: msg.GetRole(), + Parts: []genai.Part{genai.Text(text)}, + }) + } + } + + p.model.Tools = nil + for _, tool := range tools { + p.model.Tools = append(p.model.Tools, &genai.Tool{ + FunctionDeclarations: []*genai.FunctionDeclaration{ + { + Name: tool.Name, + Description: tool.Description, + Parameters: translateToGoogleSchema(tool.InputSchema), + }, + }, + }) + } + + p.chat.History = hist + // The provided messages slice (and thus history) already includes the new prompt, + // so we just call SendMessage with an empty string that will be trimmed by the server. + resp, err := p.chat.SendMessage(ctx, genai.Text("")) + if err != nil { + return nil, err + } + + if len(resp.Candidates) == 0 { + return nil, fmt.Errorf("no response from model") + } + + // The library enforces a generation config with 1 candidate. + m := &Message{ + Candidate: resp.Candidates[0], + toolCallID: p.toolCallID, + } + + p.toolCallID += len(m.Candidate.FunctionCalls()) + return m, nil +} + +func (p *Provider) CreateToolResponse(toolCallID string, content any) (llm.Message, error) { + // UNUSED: Nothing in root.go calls this. + return nil, nil +} + +func (p *Provider) SupportsTools() bool { + // UNUSED: Nothing in root.go calls this. + return true +} + +func (p *Provider) Name() string { + return "Google" +} + +func translateToGoogleSchema(schema llm.Schema) *genai.Schema { + s := &genai.Schema{ + Type: toType(schema.Type), + Required: schema.Required, + Properties: make(map[string]*genai.Schema), + } + + for name, prop := range schema.Properties { + s.Properties[name] = propertyToGoogleSchema(prop.(map[string]any)) + } + + if len(s.Properties) == 0 { + // Functions that don't take any arguments have an object-type schema with 0 properties. + // Google/Gemini does not like that: Error 400: * GenerateContentRequest properties: should be non-empty for OBJECT type. + // To work around this issue, we'll just inject some unused, nullable property with a primitive type. + s.Nullable = true + s.Properties["unused"] = &genai.Schema{ + Type: genai.TypeInteger, + Nullable: true, + } + } + return s +} + +func propertyToGoogleSchema(properties map[string]any) *genai.Schema { + s := &genai.Schema{Type: toType(properties["type"].(string))} + if desc, ok := properties["description"].(string); ok { + s.Description = desc + } + + // Objects and arrays need to have their properties recursively mapped. + if s.Type == genai.TypeObject { + objectProperties := properties["properties"].(map[string]any) + s.Properties = make(map[string]*genai.Schema) + for name, prop := range objectProperties { + s.Properties[name] = propertyToGoogleSchema(prop.(map[string]any)) + } + } else if s.Type == genai.TypeArray { + itemProperties := properties["items"].(map[string]any) + s.Items = propertyToGoogleSchema(itemProperties) + } + + return s +} + +func toType(typ string) genai.Type { + switch typ { + case "string": + return genai.TypeString + case "boolean": + return genai.TypeBoolean + case "object": + return genai.TypeObject + case "array": + return genai.TypeArray + default: + return genai.TypeUnspecified + } +} diff --git a/pkg/llm/google/types.go b/pkg/llm/google/types.go new file mode 100644 index 0000000..8bc9156 --- /dev/null +++ b/pkg/llm/google/types.go @@ -0,0 +1,72 @@ +package google + +import ( + "fmt" + "strings" + + "github.com/google/generative-ai-go/genai" + "github.com/mark3labs/mcphost/pkg/llm" +) + +type ToolCall struct { + genai.FunctionCall + + toolCallID int +} + +func (t *ToolCall) GetName() string { + return t.Name +} + +func (t *ToolCall) GetArguments() map[string]any { + return t.Args +} + +func (t *ToolCall) GetID() string { + return fmt.Sprintf("Tool<%d>", t.toolCallID) +} + +type Message struct { + *genai.Candidate + + toolCallID int +} + +func (m *Message) GetRole() string { + return m.Candidate.Content.Role +} + +func (m *Message) GetContent() string { + var sb strings.Builder + for _, part := range m.Candidate.Content.Parts { + if text, ok := part.(genai.Text); ok { + sb.WriteString(string(text)) + } + } + return sb.String() +} + +func (m *Message) GetToolCalls() []llm.ToolCall { + var calls []llm.ToolCall + for i, call := range m.Candidate.FunctionCalls() { + calls = append(calls, &ToolCall{call, m.toolCallID + i}) + } + return calls +} + +func (m *Message) IsToolResponse() bool { + for _, part := range m.Candidate.Content.Parts { + if _, ok := part.(*genai.FunctionResponse); ok { + return true + } + } + return false +} + +func (m *Message) GetToolResponseID() string { + return fmt.Sprintf("Tool<%d>", m.toolCallID) +} + +func (m *Message) GetUsage() (input int, output int) { + return 0, 0 +}