diff --git a/cmd/root.go b/cmd/root.go index 8413a94..9566602 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -182,8 +182,9 @@ func init() { flags.StringSliceVar(&stopSequences, "stop-sequences", nil, "custom stop sequences (comma-separated)") // Ollama-specific parameters - flags.Int32Var(&numGPU, "num-gpu", 1, "number of GPUs to use for Ollama models") - flags.Int32Var(&mainGPU, "main-gpu", 0, "main GPU to use for Ollama models") + flags.Int32Var(&numGPU, "num-gpu-layers", -1, "number of model layers to offload to GPU for Ollama models (-1 for auto-detect)") + flags.MarkHidden("num-gpu-layers") // Advanced option, hidden from help + flags.Int32Var(&mainGPU, "main-gpu", 0, "main GPU device to use for Ollama models") // Bind flags to viper for config file support viper.BindPFlag("system-prompt", rootCmd.PersistentFlags().Lookup("system-prompt")) @@ -198,7 +199,7 @@ func init() { viper.BindPFlag("top-p", rootCmd.PersistentFlags().Lookup("top-p")) viper.BindPFlag("top-k", rootCmd.PersistentFlags().Lookup("top-k")) viper.BindPFlag("stop-sequences", rootCmd.PersistentFlags().Lookup("stop-sequences")) - viper.BindPFlag("num-gpu", rootCmd.PersistentFlags().Lookup("num-gpu")) + viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers")) viper.BindPFlag("main-gpu", rootCmd.PersistentFlags().Lookup("main-gpu")) // Defaults are already set in flag definitions, no need to duplicate in viper @@ -265,7 +266,7 @@ func runNormalMode(ctx context.Context) error { temperature := float32(viper.GetFloat64("temperature")) topP := float32(viper.GetFloat64("top-p")) topK := int32(viper.GetInt("top-k")) - numGPU := int32(viper.GetInt("num-gpu")) + numGPU := int32(viper.GetInt("num-gpu-layers")) mainGPU := int32(viper.GetInt("main-gpu")) modelConfig := &models.ProviderConfig{ @@ -290,8 +291,27 @@ func runNormalMode(ctx context.Context) error { MaxSteps: viper.GetInt("max-steps"), // Pass 0 for infinite, agent will handle it } - // Create the agent - mcpAgent, err := 
agent.NewAgent(ctx, agentConfig) + // Create the agent with spinner for Ollama models + var mcpAgent *agent.Agent + + if strings.HasPrefix(viper.GetString("model"), "ollama:") && !quietFlag { + // Create a temporary CLI for the spinner + tempCli, tempErr := ui.NewCLI(viper.GetBool("debug")) + if tempErr == nil { + err = tempCli.ShowSpinner("Loading Ollama model...", func() error { + var agentErr error + mcpAgent, agentErr = agent.NewAgent(ctx, agentConfig) + return agentErr + }) + } else { + // Fallback without spinner + mcpAgent, err = agent.NewAgent(ctx, agentConfig) + } + } else { + // No spinner for other providers + mcpAgent, err = agent.NewAgent(ctx, agentConfig) + } + if err != nil { return fmt.Errorf("failed to create agent: %v", err) } @@ -344,8 +364,13 @@ func runNormalMode(ctx context.Context) error { if len(parts) == 2 { cli.DisplayInfo(fmt.Sprintf("Model loaded: %s (%s)", parts[0], parts[1])) } - cli.DisplayInfo(fmt.Sprintf("Loaded %d tools from MCP servers", len(tools))) + // Display loading message if available (e.g., GPU fallback info) + if loadingMessage := mcpAgent.GetLoadingMessage(); loadingMessage != "" { + cli.DisplayInfo(loadingMessage) + } + + cli.DisplayInfo(fmt.Sprintf("Loaded %d tools from MCP servers", len(tools))) // Display debug configuration if debug mode is enabled if viper.GetBool("debug") { debugConfig := map[string]any{ @@ -361,7 +386,7 @@ func runNormalMode(ctx context.Context) error { // Add Ollama-specific parameters if using Ollama if strings.HasPrefix(viper.GetString("model"), "ollama:") { - debugConfig["num-gpu"] = viper.GetInt("num-gpu") + debugConfig["num-gpu-layers"] = viper.GetInt("num-gpu-layers") debugConfig["main-gpu"] = viper.GetInt("main-gpu") } diff --git a/go.mod b/go.mod index 7bf70f0..779af6b 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,6 @@ require ( github.com/cloudwego/eino-ext/components/model/claude v0.0.0-20250609074000-b7f307dffa18 github.com/cloudwego/eino-ext/components/model/ollama 
v0.0.0-20250609074000-b7f307dffa18 github.com/cloudwego/eino-ext/components/model/openai v0.0.0-20250609074000-b7f307dffa18 - github.com/getkin/kin-openapi v0.131.0 github.com/mark3labs/mcp-filesystem-server v0.11.1 github.com/mark3labs/mcp-go v0.32.0 github.com/ollama/ollama v0.5.12 @@ -25,6 +24,8 @@ require ( gopkg.in/yaml.v3 v3.0.1 ) +require github.com/getkin/kin-openapi v0.118.0 + require ( cloud.google.com/go v0.116.0 // indirect cloud.google.com/go/auth v0.15.0 // indirect @@ -51,6 +52,7 @@ require ( github.com/bytedance/sonic/loader v0.2.4 // indirect github.com/catppuccin/go v0.2.0 // indirect github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect + github.com/charmbracelet/harmonica v0.2.0 // indirect github.com/charmbracelet/x/cellbuf v0.0.13 // indirect github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect github.com/cloudwego/base64x v0.1.5 // indirect @@ -76,6 +78,7 @@ require ( github.com/goph/emperror v0.17.2 // indirect github.com/gorilla/css v1.0.1 // indirect github.com/gorilla/websocket v1.5.3 // indirect + github.com/invopop/yaml v0.1.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect @@ -87,8 +90,6 @@ require ( github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect github.com/muesli/reflow v0.3.0 // indirect github.com/nikolalohinski/gonja v1.5.3 // indirect - github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 // indirect - github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/pkg/errors v0.9.1 // indirect @@ -127,8 +128,8 @@ require ( require ( github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect - github.com/charmbracelet/bubbles v0.20.0 - github.com/charmbracelet/bubbletea v1.2.4 + github.com/charmbracelet/bubbles 
v0.21.0 + github.com/charmbracelet/bubbletea v1.3.5 github.com/charmbracelet/glamour v0.10.0 github.com/charmbracelet/x/ansi v0.8.0 // indirect github.com/charmbracelet/x/term v0.2.1 // indirect diff --git a/go.sum b/go.sum index 6e8411c..d411e82 100644 --- a/go.sum +++ b/go.sum @@ -73,14 +73,16 @@ github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFos github.com/catppuccin/go v0.2.0 h1:ktBeIrIP42b/8FGiScP9sgrWOss3lw0Z5SktRoithGA= github.com/catppuccin/go v0.2.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc= github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= -github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE= -github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU= -github.com/charmbracelet/bubbletea v1.2.4 h1:KN8aCViA0eps9SCOThb2/XPIlea3ANJLUkv3KnQRNCE= -github.com/charmbracelet/bubbletea v1.2.4/go.mod h1:Qr6fVQw+wX7JkWWkVyXYk/ZUQ92a6XNekLXa3rR18MM= +github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= +github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg= +github.com/charmbracelet/bubbletea v1.3.5 h1:JAMNLTbqMOhSwoELIr0qyP4VidFq72/6E9j7HHmRKQc= +github.com/charmbracelet/bubbletea v1.3.5/go.mod h1:TkCnmH+aBd4LrXhXcqrKiYwRs7qyQx5rBgH5fVY3v54= github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= github.com/charmbracelet/glamour v0.10.0 h1:MtZvfwsYCx8jEPFJm3rIBFIMZUfUJ765oX8V6kXldcY= github.com/charmbracelet/glamour v0.10.0/go.mod h1:f+uf+I/ChNmqo087elLnVdCiVgjSKWuXa/l6NU2ndYk= +github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG4pgaUBiQ= +github.com/charmbracelet/harmonica v0.2.0/go.mod 
h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao= github.com/charmbracelet/huh v0.3.0 h1:CxPplWkgW2yUTDDG0Z4S5HH8SJOosWHd4LxCvi0XsKE= github.com/charmbracelet/huh v0.3.0/go.mod h1:fujUdKX8tC45CCSaRQdw789O6uaCRwx8l2NDyKfC4jA= github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE= @@ -89,8 +91,8 @@ github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2ll github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q= github.com/charmbracelet/x/cellbuf v0.0.13 h1:/KBBKHuVRbq1lYx5BzEHBAFBP8VcQzJejZ/IA3iR28k= github.com/charmbracelet/x/cellbuf v0.0.13/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= -github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b h1:MnAMdlwSltxJyULnrYbkZpp4k58Co7Tah3ciKhSNo0Q= -github.com/charmbracelet/x/exp/golden v0.0.0-20240815200342-61de596daa2b/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= +github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ= +github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf h1:rLG0Yb6MQSDKdB52aGX55JT1oi0P0Kuaj7wi1bLUpnI= github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf/go.mod h1:B3UgsnsBZS/eX42BlaNiJkD1pPOUa+oF1IYC6Yd2CEU= github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= @@ -131,8 +133,8 @@ github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/ github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY= github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok= -github.com/getkin/kin-openapi v0.131.0 
h1:NO2UeHnFKRYhZ8wg6Nyh5Cq7dHk4suQQr72a4pMrDxE= -github.com/getkin/kin-openapi v0.131.0/go.mod h1:3OlG51PCYNsPByuiMB0t4fjnNlIDnaEDsjiKUV8nL58= +github.com/getkin/kin-openapi v0.118.0 h1:z43njxPmJ7TaPpMSCQb7PN0dEYno4tyBPQcrFdHoLuM= +github.com/getkin/kin-openapi v0.118.0/go.mod h1:l5e9PaFUo9fyLJCPGQeXI2ML8c3P8BHOEV2VaAVf/pc= github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= @@ -141,8 +143,10 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= github.com/go-test/deep v1.0.8 h1:TDsG77qcSprGbC6vTN8OuXp5g+J+b5Pcguhf7Zt61VM= @@ -173,6 +177,7 @@ github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25d github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= 
+github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= @@ -180,6 +185,8 @@ github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSo github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/invopop/yaml v0.1.0 h1:YW3WGUoJEXYfzWBjn00zIlrw7brGVD0fUKRYDPAPhrc= +github.com/invopop/yaml v0.1.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= @@ -202,6 +209,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mark3labs/mcp-filesystem-server v0.11.1 
h1:7uKIZRMaKWfgvtDj/uLAvo0+7Mwb8gxo5DJywhqFW88= @@ -240,10 +249,6 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/nikolalohinski/gonja v1.5.3 h1:GsA+EEaZDZPGJ8JtpeGN78jidhOlxeJROpqMT9fTj9c= github.com/nikolalohinski/gonja v1.5.3/go.mod h1:RmjwxNiXAEqcq1HeK5SSMmqFJvKOfTfXhkJv6YBtPa4= -github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 h1:G7ERwszslrBzRxj//JalHPu/3yz+De2J+4aLtSRlHiY= -github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037/go.mod h1:2bpvgLBZEtENV5scfDFEtB/5+1M4hkQhDQrccEJ/qGw= -github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 h1:bQx3WeLcUWy+RletIKwUIt4x3t8n2SxavmoclizMb8c= -github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o= github.com/ollama/ollama v0.5.12 h1:qM+k/ozyHLJzEQoAEPrUQ0qXqsgDEEdpIVwuwScrd2U= github.com/ollama/ollama v0.5.12/go.mod h1:ibdmDvb/TjKY1OArBWIazL3pd1DHTk8eG2MMjEkWhiI= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -251,6 +256,7 @@ github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= +github.com/perimeterx/marshmallow v1.1.4/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -323,6 +329,9 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod 
h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go v1.2.7 h1:qYhyWUUd6WbiM+C6JZAUkIJt/1WrjzNHY9+KCIjVqTo= +github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= +github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= @@ -463,6 +472,7 @@ google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= @@ -473,6 +483,7 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0/go.mod 
h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 567d36a..cfcc076 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -41,16 +41,17 @@ type ToolCallContentHandler func(content string) // Agent is the agent with real-time tool call display. type Agent struct { - toolManager *tools.MCPToolManager - model model.ToolCallingChatModel - maxSteps int - systemPrompt string + toolManager *tools.MCPToolManager + model model.ToolCallingChatModel + maxSteps int + systemPrompt string + loadingMessage string // Message from provider loading (e.g., GPU fallback info) } // NewAgent creates an agent with MCP tool integration and real-time tool call display func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) { // Create the LLM provider - model, err := models.CreateProvider(ctx, config.ModelConfig) + providerResult, err := models.CreateProvider(ctx, config.ModelConfig) if err != nil { return nil, fmt.Errorf("failed to create model provider: %v", err) } @@ -62,10 +63,11 @@ func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) { } return &Agent{ - toolManager: toolManager, - model: model, - maxSteps: config.MaxSteps, // Keep 0 for infinite, handle in loop - systemPrompt: config.SystemPrompt, + toolManager: toolManager, + model: providerResult.Model, + maxSteps: config.MaxSteps, // Keep 0 for infinite, handle in loop + systemPrompt: config.SystemPrompt, + loadingMessage: providerResult.Message, }, nil } @@ -220,6 +222,11 @@ func (a *Agent) GetTools() []tool.BaseTool { return a.toolManager.GetTools() } +// GetLoadingMessage returns the loading message from provider creation (e.g., GPU fallback info) +func (a *Agent) 
GetLoadingMessage() string { + return a.loadingMessage +} + // generateWithCancellation calls the LLM with ESC key cancellation support func (a *Agent) generateWithCancellation(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo) (*schema.Message, error) { // Create a cancellable context for just this LLM call diff --git a/internal/models/providers.go b/internal/models/providers.go index 859e400..dccfb4e 100644 --- a/internal/models/providers.go +++ b/internal/models/providers.go @@ -9,11 +9,13 @@ import ( "net/http" "os" "strings" + "time" "github.com/cloudwego/eino-ext/components/model/claude" "github.com/cloudwego/eino-ext/components/model/ollama" "github.com/cloudwego/eino-ext/components/model/openai" "github.com/cloudwego/eino/components/model" + "github.com/mark3labs/mcphost/internal/ui/progress" "github.com/ollama/ollama/api" "google.golang.org/genai" @@ -76,8 +78,14 @@ type ProviderConfig struct { MainGPU *int32 } +// ProviderResult contains the result of provider creation +type ProviderResult struct { + Model model.ToolCallingChatModel + Message string // Optional message for user feedback (e.g., GPU fallback info) +} + // CreateProvider creates an eino ToolCallingChatModel based on the provider configuration -func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCallingChatModel, error) { +func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResult, error) { parts := strings.SplitN(config.ModelString, ":", 2) if len(parts) < 2 { return nil, fmt.Errorf("invalid model format. 
Expected provider:model, got %s", config.ModelString) @@ -119,15 +127,31 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCall switch provider { case "anthropic": - return createAnthropicProvider(ctx, config, modelName) + model, err := createAnthropicProvider(ctx, config, modelName) + if err != nil { + return nil, err + } + return &ProviderResult{Model: model, Message: ""}, nil case "openai": - return createOpenAIProvider(ctx, config, modelName) + model, err := createOpenAIProvider(ctx, config, modelName) + if err != nil { + return nil, err + } + return &ProviderResult{Model: model, Message: ""}, nil case "google": - return createGoogleProvider(ctx, config, modelName) + model, err := createGoogleProvider(ctx, config, modelName) + if err != nil { + return nil, err + } + return &ProviderResult{Model: model, Message: ""}, nil case "ollama": - return createOllamaProvider(ctx, config, modelName) + return createOllamaProviderWithResult(ctx, config, modelName) case "azure": - return createAzureOpenAIProvider(ctx, config, modelName) + model, err := createAzureOpenAIProvider(ctx, config, modelName) + if err != nil { + return nil, err + } + return &ProviderResult{Model: model, Message: ""}, nil default: return nil, fmt.Errorf("unsupported provider: %s", provider) } @@ -353,7 +377,175 @@ func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName return gemini.NewChatModel(ctx, geminiConfig) } -func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName string) (model.ToolCallingChatModel, error) { +// OllamaLoadingResult contains the result of model loading with actual settings used +type OllamaLoadingResult struct { + Options *api.Options + Message string +} + +// loadOllamaModelWithFallback loads an Ollama model with GPU settings and automatic CPU fallback +func loadOllamaModelWithFallback(ctx context.Context, baseURL, modelName string, options *api.Options) (*OllamaLoadingResult, error) { + client := 
&http.Client{} + + // Phase 1: Check if model exists locally + if err := checkOllamaModelExists(client, baseURL, modelName); err != nil { + // Phase 2: Pull model if not found + if err := pullOllamaModel(ctx, client, baseURL, modelName); err != nil { + return nil, fmt.Errorf("failed to pull model %s: %v", modelName, err) + } + } + + // Phase 3: Load model with GPU settings + _, err := loadOllamaModelWithOptions(ctx, client, baseURL, modelName, options) + if err != nil { + // Phase 4: Fallback to CPU if GPU memory insufficient + if isGPUMemoryError(err) { + cpuOptions := *options + cpuOptions.NumGPU = 0 + + _, cpuErr := loadOllamaModelWithOptions(ctx, client, baseURL, modelName, &cpuOptions) + if cpuErr != nil { + return nil, fmt.Errorf("failed to load model on GPU (%v) and CPU fallback failed (%v)", err, cpuErr) + } + + return &OllamaLoadingResult{ + Options: &cpuOptions, + Message: "Insufficient GPU memory, falling back to CPU inference", + }, nil + } + return nil, err + } + + return &OllamaLoadingResult{ + Options: options, + Message: "Model loaded successfully on GPU", + }, nil +} + +// checkOllamaModelExists checks if a model exists locally +func checkOllamaModelExists(client *http.Client, baseURL, modelName string) error { + reqBody := map[string]string{"model": modelName} + jsonBody, _ := json.Marshal(reqBody) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/api/show", bytes.NewBuffer(jsonBody)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("model not found locally") + } + + return nil +} + +// pullOllamaModel pulls a model from the registry +func pullOllamaModel(ctx context.Context, client *http.Client, baseURL, modelName string) error { + return 
pullOllamaModelWithProgress(ctx, client, baseURL, modelName, true) +} + +// pullOllamaModelWithProgress pulls a model from the registry with optional progress display +func pullOllamaModelWithProgress(ctx context.Context, client *http.Client, baseURL, modelName string, showProgress bool) error { + reqBody := map[string]string{"name": modelName} + jsonBody, _ := json.Marshal(reqBody) + + // Use a longer timeout for pulling models (5 minutes) + pullCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + + req, err := http.NewRequestWithContext(pullCtx, "POST", baseURL+"/api/pull", bytes.NewBuffer(jsonBody)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("failed to pull model (status %d): %s", resp.StatusCode, string(body)) + } + + // Read the streaming response with optional progress display + if showProgress { + progressReader := progress.NewProgressReader(resp.Body) + defer progressReader.Close() + _, err = io.ReadAll(progressReader) + } else { + _, err = io.ReadAll(resp.Body) + } + return err +} + +// loadOllamaModelWithOptions loads a model with specific options using a warmup request +func loadOllamaModelWithOptions(ctx context.Context, client *http.Client, baseURL, modelName string, options *api.Options) (*api.Options, error) { + // Create a copy of options for warmup to avoid modifying the original + warmupOptions := *options + warmupOptions.NumPredict = 1 // Limit response length for warmup + + reqBody := map[string]interface{}{ + "model": modelName, + "prompt": "Hello", + "stream": false, + "options": &warmupOptions, + } + + jsonBody, _ := json.Marshal(reqBody) + + // Use medium timeout for warmup (30 seconds) + warmupCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + req, err := 
http.NewRequestWithContext(warmupCtx, "POST", baseURL+"/api/generate", bytes.NewBuffer(jsonBody)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("warmup request failed (status %d): %s", resp.StatusCode, string(body)) + } + + // Read response to completion + _, err = io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return options, nil +} + +// isGPUMemoryError checks if an error indicates insufficient GPU memory +func isGPUMemoryError(err error) bool { + errStr := strings.ToLower(err.Error()) + return strings.Contains(errStr, "out of memory") || + strings.Contains(errStr, "insufficient memory") || + strings.Contains(errStr, "cuda out of memory") || + strings.Contains(errStr, "gpu memory") +} + +func createOllamaProviderWithResult(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) { baseURL := "http://localhost:11434" // Default Ollama URL // Check for custom Ollama host from environment @@ -366,11 +558,6 @@ func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName baseURL = config.ProviderURL } - ollamaConfig := &ollama.ChatModelConfig{ - BaseURL: baseURL, - Model: modelName, - } - // Set up options for Ollama using the api.Options struct options := &api.Options{} @@ -403,9 +590,40 @@ func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName options.MainGPU = int(*config.MainGPU) } - ollamaConfig.Options = options + // Create a clean copy of options for the final model + finalOptions := &api.Options{} + *finalOptions = *options // Copy all fields + + // Try to pre-load the model with GPU settings and automatic CPU fallback + // If this fails, fall back to the original behavior + loadingResult, err := 
loadOllamaModelWithFallback(ctx, baseURL, modelName, options) + var loadingMessage string + + if err != nil { + // Pre-loading failed, use original options and no message + loadingMessage = "" + } else { + // Pre-loading succeeded, update GPU settings that worked + finalOptions.NumGPU = loadingResult.Options.NumGPU + finalOptions.MainGPU = loadingResult.Options.MainGPU + loadingMessage = loadingResult.Message + } + + ollamaConfig := &ollama.ChatModelConfig{ + BaseURL: baseURL, + Model: modelName, + Options: finalOptions, + } + + chatModel, err := ollama.NewChatModel(ctx, ollamaConfig) + if err != nil { + return nil, err + } - return ollama.NewChatModel(ctx, ollamaConfig) + return &ProviderResult{ + Model: chatModel, + Message: loadingMessage, + }, nil } // createOAuthHTTPClient creates an HTTP client that adds OAuth headers for Anthropic API diff --git a/internal/ui/progress/ollama.go b/internal/ui/progress/ollama.go new file mode 100644 index 0000000..7df4ca6 --- /dev/null +++ b/internal/ui/progress/ollama.go @@ -0,0 +1,267 @@ +package progress + +import ( + "context" + "encoding/json" + "fmt" + "io" + "strings" + "sync" + "time" + + "github.com/charmbracelet/bubbles/progress" + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" +) + +var helpStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#626262")).Render + +const ( + padding = 2 + maxWidth = 80 +) + +// OllamaPullProgress represents the progress information from Ollama pull API +type OllamaPullProgress struct { + Status string `json:"status"` + Digest string `json:"digest,omitempty"` + Total int64 `json:"total,omitempty"` + Completed int64 `json:"completed,omitempty"` +} + +// progressMsg represents progress updates +type progressMsg struct { + percent float64 + status string +} + +// progressErrMsg represents errors during progress +type progressErrMsg struct{ err error } + +// progressCompleteMsg indicates completion +type progressCompleteMsg struct{} + +// ProgressModel 
// ProgressModel represents the progress bar model shown while a model is pulled.
type ProgressModel struct {
	progress progress.Model // the animated bubbles progress bar
	status   string         // current human-readable status line
	err      error          // terminal error, if any
	complete bool           // true once the pull has finished
}

// NewProgressModel creates a new progress model with a default-gradient bar
// and an initial status line.
func NewProgressModel() ProgressModel {
	return ProgressModel{
		progress: progress.New(progress.WithDefaultGradient()),
		status:   "Initializing...",
	}
}

// Init initializes the progress model; no startup command is needed.
func (m ProgressModel) Init() tea.Cmd {
	return nil
}

// Update handles progress updates, keyboard cancellation, window resizing,
// and animation frames for the embedded progress bar.
func (m ProgressModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	switch msg := msg.(type) {
	case tea.KeyMsg:
		// Allow the user to abort the display with 'q' or Ctrl+C.
		if msg.String() == "q" || msg.String() == "ctrl+c" {
			return m, tea.Quit
		}
		return m, nil

	case tea.WindowSizeMsg:
		// Fit the bar to the terminal, bounded by maxWidth.
		m.progress.Width = msg.Width - padding*2 - 4
		if m.progress.Width > maxWidth {
			m.progress.Width = maxWidth
		}
		return m, nil

	case progressErrMsg:
		m.err = msg.err
		return m, tea.Quit

	case progressCompleteMsg:
		m.complete = true
		return m, tea.Quit

	case progressMsg:
		var cmds []tea.Cmd
		m.status = msg.status

		if msg.percent >= 1.0 {
			m.complete = true
			// NOTE(review): tea.Quit is batched together with SetPercent, so
			// the bar may exit before animating all the way to 100% — confirm
			// this final frame is intended.
			cmds = append(cmds, tea.Quit)
		}

		cmds = append(cmds, m.progress.SetPercent(msg.percent))
		return m, tea.Batch(cmds...)

	case progress.FrameMsg:
		// Forward animation frames to the embedded progress bar.
		progressModel, cmd := m.progress.Update(msg)
		m.progress = progressModel.(progress.Model)
		return m, cmd

	default:
		return m, nil
	}
}

// View renders the progress bar plus either an error, a completion banner, or
// the current status line with a cancellation hint.
func (m ProgressModel) View() string {
	if m.err != nil {
		return fmt.Sprintf("Error: %s\n", m.err.Error())
	}

	if m.complete {
		return fmt.Sprintf("\n%s%s\n\n%sComplete!\n",
			strings.Repeat(" ", padding),
			m.progress.View(),
			strings.Repeat(" ", padding))
	}

	pad := strings.Repeat(" ", padding)
	return fmt.Sprintf("\n%s%s\n%s%s\n\n%s",
		pad, m.progress.View(),
		pad, m.status,
		pad+helpStyle("Press 'q' or Ctrl+C to cancel"))
}

// ProgressReader wraps an io.Reader to provide progress updates for Ollama
// pull operations. It parses the JSON stream as the bytes pass through and
// drives a Bubble Tea progress display running in a background goroutine.
type ProgressReader struct {
	reader   io.Reader
	program  *tea.Program
	model    ProgressModel
	lastLine string        // buffer for a partial JSON line spanning Read calls
	done     chan struct{} // closed when the TUI goroutine exits
	wg       sync.WaitGroup
}

// NewProgressReader creates a new progress reader for Ollama pull operations
// and immediately starts the TUI goroutine; callers must Close it.
func NewProgressReader(reader io.Reader) *ProgressReader {
	model := NewProgressModel()
	// Create program with standard settings
	program := tea.NewProgram(model)

	pr := &ProgressReader{
		reader:  reader,
		program: program,
		model:   model,
		done:    make(chan struct{}),
	}

	// Start the TUI in a goroutine
	pr.wg.Add(1)
	go func() {
		defer pr.wg.Done()
		if _, err := program.Run(); err != nil {
			// Handle error silently for now
			// NOTE(review): a Run failure is swallowed here — the pull still
			// proceeds, only the display is lost. Confirm this is acceptable.
		}
		close(pr.done)
	}()

	return pr
}

// Read implements io.Reader and parses Ollama streaming responses, forwarding
// progress for each complete newline-terminated JSON line it sees.
func (pr *ProgressReader) Read(p []byte) (n int, err error) {
	n, err = pr.reader.Read(p)
	if n > 0 {
		// Parse the JSON lines for progress information
		data := string(p[:n])
		pr.lastLine += data

		// Process complete lines
		for {
			lineEnd := strings.Index(pr.lastLine, "\n")
			if lineEnd == -1 {
				break
			}

			line := strings.TrimSpace(pr.lastLine[:lineEnd])
			pr.lastLine = pr.lastLine[lineEnd+1:]

			if line != "" {
				pr.parseProgressLine(line)
			}
		}
	}

	if err == io.EOF {
		// Send completion message and ensure program quits
		// NOTE(review): any partial line still buffered in lastLine at EOF is
		// never parsed — confirm the stream always ends with a newline.
		pr.program.Send(progressCompleteMsg{})
	}

	return n, err
}

// parseProgressLine parses a single JSON line from Ollama pull response and
// sends a progressMsg into the TUI. Malformed JSON is ignored.
func (pr *ProgressReader) parseProgressLine(line string) {
	// NOTE(review): this local variable shadows the imported bubbles
	// "progress" package within the function body — consider renaming.
	var progress OllamaPullProgress
	if err := json.Unmarshal([]byte(line), &progress); err != nil {
		return // Ignore malformed JSON
	}

	var percent float64
	status := progress.Status

	// Calculate progress percentage if we have total and completed
	if progress.Total > 0 && progress.Completed >= 0 {
		percent = float64(progress.Completed) / float64(progress.Total)

		// Format status with progress info
		if progress.Digest != "" {
			// assumes digests are at least 12 characters (e.g. "sha256:…");
			// anything shorter would panic on this slice — TODO confirm format
			status = fmt.Sprintf("%s (%s)", progress.Status, progress.Digest[:12])
		}

		// Add size information
		if progress.Total > 0 {
			totalMB := float64(progress.Total) / (1024 * 1024)
			completedMB := float64(progress.Completed) / (1024 * 1024)
			status = fmt.Sprintf("%s - %.1f/%.1f MB", status, completedMB, totalMB)
		}
	} else {
		// For status-only updates (like "pulling manifest"), show indeterminate progress
		if strings.Contains(strings.ToLower(progress.Status), "pulling") ||
			strings.Contains(strings.ToLower(progress.Status), "downloading") {
			// Keep current progress or show small progress for activity
			percent = 0.1
		} else if strings.Contains(strings.ToLower(progress.Status), "success") ||
			strings.Contains(strings.ToLower(progress.Status), "complete") {
			percent = 1.0
		}
	}

	pr.program.Send(progressMsg{
		percent: percent,
		status:  status,
	})
}

// Close stops the progress display and waits for cleanup, force-killing the
// TUI if it does not exit within two seconds. Always returns nil.
func (pr *ProgressReader) Close() error {
	// Send completion message to trigger quit
	pr.program.Send(progressCompleteMsg{})

	// Wait for the program to finish with timeout
	done := make(chan struct{})
	go func() {
		pr.wg.Wait()
		close(done)
	}()

	// Wait for completion or timeout after 2 seconds
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	select {
	case <-done:
		// Program finished normally
	case <-ctx.Done():
		// Timeout - force kill the program
		pr.program.Kill()
	}

	return nil
}

// ---- fragment of internal/ui/spinner.go (spinnerModel.View, incomplete) ----
		Foreground(theme.Text).
		Italic(true)

	return fmt.Sprintf(" %s %s",
		spinnerStyle.Render(m.spinner.View()),
		messageStyle.Render(m.message))
}