diff --git a/Taskfile.yml b/Taskfile.yml index 4a6d1b41..129a8af8 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -234,7 +234,6 @@ tasks: - go run github.com/maxbrunsfeld/counterfeiter/v6 -o tests/mocks/domain internal/domain WebFetchService - go run github.com/maxbrunsfeld/counterfeiter/v6 -o tests/mocks/domain internal/domain AgentService - go run github.com/maxbrunsfeld/counterfeiter/v6 -o tests/mocks/domain internal/domain ModelService - - go run github.com/maxbrunsfeld/counterfeiter/v6 -o tests/mocks/domain internal/domain ConfigService - go run github.com/maxbrunsfeld/counterfeiter/v6 -o tests/mocks/domain internal/domain ThemeService - go run github.com/maxbrunsfeld/counterfeiter/v6 -o tests/mocks/domain internal/domain ConversationRepository - go run github.com/maxbrunsfeld/counterfeiter/v6 -o tests/mocks/domain internal/domain SDKClient diff --git a/cmd/agent.go b/cmd/agent.go index 86217753..5eeb29a6 100644 --- a/cmd/agent.go +++ b/cmd/agent.go @@ -52,11 +52,7 @@ Examples: noSave, _ := cmd.Flags().GetBool("no-save") sessionID, _ := cmd.Flags().GetString("session-id") requireApproval, _ := cmd.Flags().GetBool("require-approval") - cfg, err := getConfigFromViper() - if err != nil { - return fmt.Errorf("failed to load config: %w", err) - } - return RunAgentCommand(cfg, model, args[0], files, noSave, sessionID, requireApproval) + return RunAgentCommand(Cfg, model, args[0], files, noSave, sessionID, requireApproval) }, } @@ -96,7 +92,7 @@ type AgentSession struct { } func RunAgentCommand(cfg *config.Config, modelFlag, taskDescription string, files []string, noSave bool, sessionID string, requireApproval bool) error { - svc := container.NewServiceContainer(cfg, V) + svc := container.NewServiceContainer(cfg) defer func() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/cmd/agents.go b/cmd/agents.go index ccd1fe8a..9c97f15f 100644 --- a/cmd/agents.go +++ b/cmd/agents.go @@ -8,7 +8,6 @@ import ( "strings" config "github.com/inference-gateway/cli/config" - services "github.com/inference-gateway/cli/internal/services" icons "github.com/inference-gateway/cli/internal/ui/styles/icons" cobra "github.com/spf13/cobra" ) @@ -218,21 +217,17 @@ var agentsDisableCmd = &cobra.Command{ }, } -func getAgentsConfigService(cmd *cobra.Command) (*services.AgentsConfigService, error) { - userspace := GetUserspaceFlag(cmd) - - var agentsPath string - if userspace { +// agentsConfigPath returns the agents.yaml path for the current command, +// honouring --userspace. +func agentsConfigPath(cmd *cobra.Command) (string, error) { + if GetUserspaceFlag(cmd) { homeDir, err := os.UserHomeDir() if err != nil { - return nil, fmt.Errorf("failed to get user home directory: %w", err) + return "", fmt.Errorf("failed to get user home directory: %w", err) } - agentsPath = filepath.Join(homeDir, config.ConfigDirName, config.AgentsFileName) - } else { - agentsPath = config.DefaultAgentsPath + return filepath.Join(homeDir, config.ConfigDirName, config.AgentsFileName), nil } - - return services.NewAgentsConfigService(agentsPath), nil + return config.DefaultAgentsPath, nil } // ExternalAgent represents an agent configured via INFER_A2A_AGENTS @@ -241,15 +236,6 @@ type ExternalAgent struct { URL string } -// getConfig loads the configuration from viper -func getConfig(_ *cobra.Command) (*config.Config, error) { - cfg, err := getConfigFromViper() - if err != nil { - return nil, fmt.Errorf("failed to load config: %w", err) - } - return cfg, nil -} - // extractExternalAgents extracts agent names and URLs from INFER_A2A_AGENTS func extractExternalAgents(cfg *config.Config) []ExternalAgent { if len(cfg.A2A.Agents) == 0 { @@ -288,7 +274,7 @@ func addAgent(cmd *cobra.Command, name, url, artifactsURL, oci string, run bool, return fmt.Errorf("--model is required when --run is enabled. Specify a model in the format provider/model (e.g., openai/gpt-5, anthropic/claude-4-5-sonnet)") } - svc, err := getAgentsConfigService(cmd) + path, err := agentsConfigPath(cmd) if err != nil { return err } @@ -303,7 +289,7 @@ func addAgent(cmd *cobra.Command, name, url, artifactsURL, oci string, run bool, Environment: environment, } - if err := svc.AddAgent(agent); err != nil { + if err := config.AddAgent(path, agent); err != nil { return err } @@ -329,12 +315,12 @@ func addAgent(cmd *cobra.Command, name, url, artifactsURL, oci string, run bool, } func updateAgent(cmd *cobra.Command, name, url, artifactsURL, oci string, run bool, model string, environment map[string]string) error { - svc, err := getAgentsConfigService(cmd) + path, err := agentsConfigPath(cmd) if err != nil { return err } - existing, err := svc.GetAgent(name) + existing, err := config.GetAgent(path, name) if err != nil { return err } @@ -363,7 +349,7 @@ func updateAgent(cmd *cobra.Command, name, url, artifactsURL, oci string, run bo return fmt.Errorf("--model is required when --run is enabled. Specify a model in the format provider/model (e.g., openai/gpt-5, anthropic/claude-4-5-sonnet)") } - if err := svc.UpdateAgent(agent); err != nil { + if err := config.UpdateAgent(path, agent); err != nil { return err } @@ -389,12 +375,12 @@ func updateAgent(cmd *cobra.Command, name, url, artifactsURL, oci string, run bo } func removeAgent(cmd *cobra.Command, name string) error { - svc, err := getAgentsConfigService(cmd) + path, err := agentsConfigPath(cmd) if err != nil { return err } - if err := svc.RemoveAgent(name); err != nil { + if err := config.RemoveAgent(path, name); err != nil { return err } @@ -403,22 +389,17 @@ func removeAgent(cmd *cobra.Command, name string) error { } func listAgents(cmd *cobra.Command, args []string) error { - svc, err := getAgentsConfigService(cmd) - if err != nil { - return err - } - - localAgents, err := svc.ListAgents() + path, err := agentsConfigPath(cmd) if err != nil { return err } - cfg, err := getConfig(cmd) + localAgents, err := config.ListAgents(path) if err != nil { return err } - externalAgents := extractExternalAgents(cfg) + externalAgents := extractExternalAgents(Cfg) totalAgents := len(localAgents) + len(externalAgents) @@ -504,12 +485,12 @@ func listAgents(cmd *cobra.Command, args []string) error { } func showAgent(cmd *cobra.Command, name string) error { - svc, err := getAgentsConfigService(cmd) + path, err := agentsConfigPath(cmd) if err != nil { return err } - agent, err := svc.GetAgent(name) + agent, err := config.GetAgent(path, name) if err != nil { return err } @@ -573,13 +554,12 @@ func showAgent(cmd *cobra.Command, name string) error { } func initAgents(cmd *cobra.Command, args []string) error { - svc, err := getAgentsConfigService(cmd) + path, err := agentsConfigPath(cmd) if err != nil { return err } - cfg := config.DefaultAgentsConfig() - if err := svc.Save(cfg); err != nil { + if err := config.SaveAgents(path, config.DefaultAgentsConfig()); err != nil { return err } @@ -594,18 +574,18 @@ func initAgents(cmd *cobra.Command, args []string) error { } func enableAgent(cmd *cobra.Command, name string) error { - svc, err := getAgentsConfigService(cmd) + path, err := agentsConfigPath(cmd) if err != nil { return err } - agent, err := svc.GetAgent(name) + agent, err := config.GetAgent(path, name) if err != nil { return fmt.Errorf("failed to find agent: %w", err) } agent.Enabled = true - if err := svc.UpdateAgent(*agent); err != nil { + if err := config.UpdateAgent(path, *agent); err != nil { return fmt.Errorf("failed to enable agent: %w", err) } @@ -617,18 +597,18 @@ func enableAgent(cmd *cobra.Command, name string) error { } func disableAgent(cmd *cobra.Command, name string) error { - svc, err := getAgentsConfigService(cmd) + path, err := agentsConfigPath(cmd) if err != nil { return err } - agent, err := svc.GetAgent(name) + agent, err := config.GetAgent(path, name) if err != nil { return fmt.Errorf("failed to find agent: %w", err) } agent.Enabled = false - if err := svc.UpdateAgent(*agent); err != nil { + if err := config.UpdateAgent(path, *agent); err != nil { return fmt.Errorf("failed to disable agent: %w", err) } diff --git a/cmd/channels.go b/cmd/channels.go index 11a630e0..794025ad 100644 --- a/cmd/channels.go +++ b/cmd/channels.go @@ -40,11 +40,7 @@ Examples: INFER_CHANNELS_TELEGRAM_ALLOWED_USERS="123456789" \ infer channels-manager`, RunE: func(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return fmt.Errorf("failed to load config: %w", err) - } - return RunChannelsCommand(cfg) + return RunChannelsCommand(Cfg) }, } diff --git a/cmd/chat.go b/cmd/chat.go index ee579ef2..f7f8fff4 100644 --- a/cmd/chat.go +++ b/cmd/chat.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "os/signal" - "path/filepath" "strings" "sync" "syscall" @@ -15,7 +14,8 @@ import ( tea "github.com/charmbracelet/bubbletea" uuid "github.com/google/uuid" cobra "github.com/spf13/cobra" - viper "github.com/spf13/viper" + + sdk "github.com/inference-gateway/sdk" config "github.com/inference-gateway/cli/config" tools "github.com/inference-gateway/cli/internal/agent/tools" @@ -26,7 +26,6 @@ import ( logger "github.com/inference-gateway/cli/internal/logger" screenshotsvc "github.com/inference-gateway/cli/internal/services" web "github.com/inference-gateway/cli/internal/web" - sdk "github.com/inference-gateway/sdk" ) var chatCmd = &cobra.Command{ @@ -35,10 +34,7 @@ var chatCmd = &cobra.Command{ Long: `Start an interactive chat session where you can select a model from a dropdown and have a conversational interface with the inference gateway.`, RunE: func(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return fmt.Errorf("failed to load config: %w", err) - } + cfg := Cfg if os.Getenv("INFER_WEB_MODE") == "true" { cfg.Web.Enabled = true @@ -79,24 +75,24 @@ and have a conversational interface with the inference gateway.`, } } - return StartWebChatSession(cfg, V) + return StartWebChatSession(cfg) } if !isInteractiveTerminal() { - return runNonInteractiveChat(cfg, V) + return runNonInteractiveChat(cfg) } - return StartChatSession(cfg, V) + return StartChatSession(cfg) }, } // StartChatSession starts a chat session // //nolint:funlen // Chat session initialization requires multiple setup steps -func StartChatSession(cfg *config.Config, v *viper.Viper) error { +func StartChatSession(cfg *config.Config) error { _ = clipboard.Init() - services := container.NewServiceContainer(cfg, v) + services := container.NewServiceContainer(cfg) sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP) @@ -150,8 +146,6 @@ func StartChatSession(cfg *config.Config, v *viper.Viper) error { agentService := services.GetAgentService() conversationRepo := services.GetConversationRepository() modelService := services.GetModelService() - config := services.GetConfig() - configService := services.GetConfigService() toolService := services.GetToolService() fileService := services.GetFileService() imageService := services.GetImageService() @@ -170,8 +164,8 @@ func StartChatSession(cfg *config.Config, v *viper.Viper) error { var screenshotServer *screenshotsvc.ScreenshotServer - if config.ComputerUse.Enabled && config.ComputerUse.Screenshot.StreamingEnabled { - screenshotServer = startScreenshotServer(config, imageService, toolRegistry) + if cfg.ComputerUse.Enabled && cfg.ComputerUse.Screenshot.StreamingEnabled { + screenshotServer = startScreenshotServer(cfg, imageService, toolRegistry) if screenshotServer != nil { defer func() { if err := screenshotServer.Stop(); err != nil { @@ -181,7 +175,7 @@ func StartChatSession(cfg *config.Config, v *viper.Viper) error { } } - floatingWindowMgr, err := initFloatingWindow(config, stateManager, agentService) + floatingWindowMgr, err := initFloatingWindow(cfg, stateManager, agentService) if err != nil { return fmt.Errorf("failed to initialize floating window: %w", err) } @@ -193,31 +187,29 @@ func StartChatSession(cfg *config.Config, v *viper.Viper) error { }() } - versionInfo := GetVersionInfo() application := app.NewChatApplication( + cfg, models, defaultModel, + GetVersionInfo(), + agentManager, agentService, - conversationRepo, + backgroundTaskService, conversationOptimizer, - sessionRolloverManager, - modelService, - configService, - toolService, + conversationRepo, fileService, imageService, + mcpManager, + messageQueue, + modelService, pricingService, - shortcutRegistry, + sessionRolloverManager, stateManager, - messageQueue, + taskRetentionService, themeService, + toolService, + shortcutRegistry, toolRegistry, - mcpManager, - taskRetentionService, - backgroundTaskService, - agentManager, - getEffectiveConfigPath(), - versionInfo, ) program := tea.NewProgram( @@ -244,8 +236,8 @@ func StartChatSession(cfg *config.Config, v *viper.Viper) error { } // StartWebChatSession starts a web-based chat session with PTY and WebSocket -func StartWebChatSession(cfg *config.Config, v *viper.Viper) error { - server := web.NewWebTerminalServer(cfg, v) +func StartWebChatSession(cfg *config.Config) error { + server := web.NewWebTerminalServer(cfg) return server.Start() } @@ -272,31 +264,6 @@ func validateAndSetDefaultModel(modelService domain.ModelService, models []strin return defaultModel } -// getEffectiveConfigPath returns the actual config file path that should be displayed -// It follows Viper's search order and returns the first existing config file -func getEffectiveConfigPath() string { - searchPaths := []string{ - ".infer/config.yaml", - } - - if homeDir, err := os.UserHomeDir(); err == nil { - homePath := filepath.Join(homeDir, ".infer", "config.yaml") - searchPaths = append(searchPaths, homePath) - } - - for _, path := range searchPaths { - if _, err := os.Stat(path); err == nil { - return path - } - } - - if configFile := V.ConfigFileUsed(); configFile != "" { - return configFile - } - - return ".infer/config.yaml" -} - // isInteractiveTerminal checks if we're running in an interactive terminal func isInteractiveTerminal() bool { if fileInfo, _ := os.Stdin.Stat(); (fileInfo.Mode() & os.ModeCharDevice) == 0 { @@ -309,8 +276,8 @@ func isInteractiveTerminal() bool { } // runNonInteractiveChat handles non-interactive chat mode (stdin/stdout) -func runNonInteractiveChat(cfg *config.Config, v *viper.Viper) error { - services := container.NewServiceContainer(cfg, v) +func runNonInteractiveChat(cfg *config.Config) error { + services := container.NewServiceContainer(cfg) defer func() { ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() diff --git a/cmd/claude_code.go b/cmd/claude_code.go index 12fd8e9b..2a81c3eb 100644 --- a/cmd/claude_code.go +++ b/cmd/claude_code.go @@ -38,11 +38,7 @@ authentication process. Example: infer claude-code setup`, RunE: func(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return err - } - claudePath := cfg.ClaudeCode.CLIPath + claudePath := Cfg.ClaudeCode.CLIPath if _, err := exec.LookPath(claudePath); err != nil { return fmt.Errorf("claude Code CLI not found at '%s'.\n\nMake sure Claude Code CLI is installed and in your PATH,\nor set custom path in .infer/config.yaml:\n claude_code:\n cli_path: /path/to/claude", claudePath) @@ -80,11 +76,7 @@ var claudeCodeTestCmd = &cobra.Command{ Short: "Test Claude Code CLI integration", Long: `Test if Claude Code CLI is properly configured and authenticated.`, RunE: func(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return err - } - claudePath := cfg.ClaudeCode.CLIPath + claudePath := Cfg.ClaudeCode.CLIPath if _, err := exec.LookPath(claudePath); err != nil { fmt.Printf("✗ Claude Code CLI not found at '%s'\n", claudePath) diff --git a/cmd/config.go b/cmd/config.go index 343b9eb9..639ef611 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -168,11 +168,7 @@ var configToolsValidateCmd = &cobra.Command{ Long: `Check if a specific command would be allowed to execute without actually running it.`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return fmt.Errorf("failed to load config: %w", err) - } - return ValidateTool(cfg, args[0]) + return ValidateTool(Cfg, args[0]) }, } @@ -203,12 +199,8 @@ Examples: This uses the exact same argument parsing as LLMs, ensuring consistency.`, Args: cobra.MinimumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return fmt.Errorf("failed to load config: %w", err) - } format, _ := cmd.Flags().GetString("format") - return ExecTool(cfg, args, format) + return ExecTool(Cfg, args, format) }, } @@ -630,8 +622,11 @@ func getKeybindingsConfigWritePath(userspace bool) (string, error) { return config.DefaultKeybindingsPath, nil } -// getConfigFromViper creates a config object from current Viper settings -func getConfigFromViper() (*config.Config, error) { +// loadConfigFromViper assembles the in-memory Config by unmarshalling +// viper, then layering on the per-file YAML overlays (mcp, keybindings, +// prompts) and finally honouring INFER_* env overrides. It runs once at +// startup (initConfig); commands afterwards read the cached cmd.Cfg. +func loadConfigFromViper() (*config.Config, error) { cfg := &config.Config{} if err := V.Unmarshal(cfg); err != nil { return nil, fmt.Errorf("failed to unmarshal config from Viper: %w", err) @@ -640,29 +635,25 @@ func getConfigFromViper() (*config.Config, error) { resolveViperEnvironmentVariables(cfg, "") mcpConfigPath := getEffectiveMCPConfigPath() - mcpConfigService := services.NewMCPConfigService(mcpConfigPath) - mcpConfig, err := mcpConfigService.Load() + mcpConfig, err := config.LoadMCP(mcpConfigPath) if err != nil { logger.Warn("Failed to load MCP config, using defaults", "error", err, "path", mcpConfigPath) mcpConfig = config.DefaultMCPConfig() } - cfg.MCP = *mcpConfig kbPath := getEffectiveKeybindingsConfigPath() - kbService := services.NewKeybindingsConfigService(kbPath) - kbConfig, err := kbService.Load() + kbConfig, err := config.LoadKeybindings(kbPath) if err != nil { logger.Warn("Failed to load keybindings config, using defaults", "error", err, "path", kbPath) - kbConfig = services.DefaultKeybindingsConfig() + kbConfig = config.DefaultKeybindingsConfig() } cfg.Chat.Keybindings = *kbConfig applyKeybindingEnvOverrides(cfg) promptsPath := getEffectivePromptsConfigPath() - promptsService := services.NewPromptsConfigService(promptsPath) - prompts, err := promptsService.Load() + prompts, err := config.LoadPrompts(promptsPath) if err != nil { logger.Warn("Failed to load prompts config, using defaults", "error", err, "path", promptsPath) prompts = config.DefaultPromptsConfig() @@ -1117,7 +1108,7 @@ func ValidateTool(cfg *config.Config, command string) error { return nil } - services := container.NewServiceContainer(cfg, V) + services := container.NewServiceContainer(cfg) toolService := services.GetToolService() toolArgs := map[string]any{ "command": command, @@ -1140,7 +1131,7 @@ func ExecTool(cfg *config.Config, args []string, format string) error { return fmt.Errorf("tools are not enabled") } - serviceContainer := container.NewServiceContainer(cfg, V) + serviceContainer := container.NewServiceContainer(cfg) toolService := serviceContainer.GetToolService() toolRegistry := serviceContainer.GetToolRegistry() @@ -1352,10 +1343,7 @@ func getConfigPath() string { } func listSandboxDirectories(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return err - } + cfg := Cfg if len(cfg.Tools.Sandbox.Directories) == 0 { fmt.Println("No sandbox directories are currently configured.") @@ -1445,10 +1433,7 @@ func disableFetch(cmd *cobra.Command, args []string) error { } func listFetchDomains(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return err - } + cfg := Cfg format, _ := cmd.Flags().GetString("format") @@ -1566,10 +1551,7 @@ func removeFetchDomain(cmd *cobra.Command, args []string) error { } func fetchCacheStatus(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return err - } + cfg := Cfg fmt.Printf("Cache Status: ") if cfg.Tools.WebFetch.Cache.Enabled { @@ -1701,10 +1683,7 @@ func setGrepBackend(cmd *cobra.Command, args []string) error { } func grepStatus(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return err - } + cfg := Cfg fmt.Printf("Grep Tool Status: ") if cfg.Tools.Grep.Enabled { diff --git a/cmd/conversation_title.go b/cmd/conversation_title.go index e5e85344..e5667319 100644 --- a/cmd/conversation_title.go +++ b/cmd/conversation_title.go @@ -26,12 +26,7 @@ var generateTitlesCmd = &cobra.Command{ Long: `Generate AI-powered titles for conversations that either don't have generated titles or have invalidated titles due to being resumed or modified.`, RunE: func(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return fmt.Errorf("failed to load config: %w", err) - } - - services := container.NewServiceContainer(cfg, V) + services := container.NewServiceContainer(Cfg) backgroundJobManager := services.GetBackgroundJobManager() if backgroundJobManager == nil { @@ -60,12 +55,7 @@ var statusTitlesCmd = &cobra.Command{ Short: "Show conversation title generation status", Long: `Show the status of conversation title generation including configuration and pending conversations.`, RunE: func(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return fmt.Errorf("failed to load config: %w", err) - } - - services := container.NewServiceContainer(cfg, V) + services := container.NewServiceContainer(Cfg) storage := services.GetStorage() backgroundJobManager := services.GetBackgroundJobManager() @@ -75,9 +65,9 @@ var statusTitlesCmd = &cobra.Command{ fmt.Printf("📝 Conversation Title Generation Status\n\n") fmt.Printf("Configuration:\n") - fmt.Printf(" Enabled: %v\n", cfg.Conversation.TitleGeneration.Enabled) - fmt.Printf(" Model: %s\n", cfg.Conversation.TitleGeneration.Model) - fmt.Printf(" Batch Size: %d\n", cfg.Conversation.TitleGeneration.BatchSize) + fmt.Printf(" Enabled: %v\n", Cfg.Conversation.TitleGeneration.Enabled) + fmt.Printf(" Model: %s\n", Cfg.Conversation.TitleGeneration.Model) + fmt.Printf(" Batch Size: %d\n", Cfg.Conversation.TitleGeneration.BatchSize) fmt.Printf(" Background Jobs Running: %v\n", backgroundJobManager != nil && backgroundJobManager.IsRunning()) if storage != nil { @@ -114,12 +104,7 @@ var daemonCmd = &cobra.Command{ Short: "Run conversation title generation daemon", Long: `Run the background job manager as a daemon to continuously generate titles for conversations.`, RunE: func(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return fmt.Errorf("failed to load config: %w", err) - } - - services := container.NewServiceContainer(cfg, V) + services := container.NewServiceContainer(Cfg) backgroundJobManager := services.GetBackgroundJobManager() if backgroundJobManager == nil { diff --git a/cmd/conversations.go b/cmd/conversations.go index a92b46e9..448e303b 100644 --- a/cmd/conversations.go +++ b/cmd/conversations.go @@ -6,10 +6,11 @@ import ( "fmt" "strings" + cobra "github.com/spf13/cobra" + container "github.com/inference-gateway/cli/internal/container" formatting "github.com/inference-gateway/cli/internal/formatting" storage "github.com/inference-gateway/cli/internal/infra/storage" - cobra "github.com/spf13/cobra" ) var conversationsCmd = &cobra.Command{ @@ -45,47 +46,33 @@ Examples: } func init() { - // Register subcommands conversationsCmd.AddCommand(conversationsListCmd) - // Add flags to list command conversationsListCmd.Flags().IntP("limit", "l", 50, "Maximum number of conversations to display") conversationsListCmd.Flags().Int("offset", 0, "Number of conversations to skip (for pagination)") conversationsListCmd.Flags().StringP("format", "f", "text", "Output format (text, json)") - // Register parent command with root rootCmd.AddCommand(conversationsCmd) } func listConversations(cmd *cobra.Command, args []string) error { - // Get config - cfg, err := getConfigFromViper() - if err != nil { - return fmt.Errorf("failed to load config: %w", err) - } - - // Create service container - services := container.NewServiceContainer(cfg, V) + services := container.NewServiceContainer(Cfg) - // Get storage store := services.GetStorage() if store == nil { return fmt.Errorf("storage is not configured") } - // Parse flags limit, _ := cmd.Flags().GetInt("limit") offset, _ := cmd.Flags().GetInt("offset") format, _ := cmd.Flags().GetString("format") - // Fetch conversations ctx := context.Background() conversations, err := store.ListConversations(ctx, limit, offset) if err != nil { return fmt.Errorf("failed to list conversations: %w", err) } - // Render output if format == "json" { return renderConversationsJSON(conversations) } @@ -122,11 +109,9 @@ func renderConversationsTable(conversations []storage.ConversationSummary, limit var md strings.Builder fmt.Fprintf(&md, "**SAVED CONVERSATIONS:** %d total\n\n", len(conversations)) - // Table header md.WriteString("| ID | Summary | Messages | Requests | Input Tokens | Output Tokens | Cost |\n") md.WriteString("|--------------------------------------|--------------------------|----------|----------|--------------|---------------|---------|" + "\n") - // Table rows for _, conv := range conversations { id := conv.ID summary := formatting.TruncateText(conv.Title, 25) @@ -140,7 +125,6 @@ func renderConversationsTable(conversations []storage.ConversationSummary, limit id, summary, messages, requests, inputTokens, outputTokens, cost) } - // Footer with pagination info if len(conversations) >= limit { fmt.Fprintf(&md, "\nShowing %d-%d conversations (use --limit and --offset for pagination)\n", offset+1, offset+len(conversations)) @@ -149,10 +133,8 @@ func renderConversationsTable(conversations []storage.ConversationSummary, limit offset+1, offset+len(conversations)) } - // Render markdown with glamour rendered, err := renderMarkdown(md.String()) if err != nil { - // Fallback to plain text if glamour fails fmt.Print(md.String()) return nil } diff --git a/cmd/export.go b/cmd/export.go index 9454ed6a..b80b204b 100644 --- a/cmd/export.go +++ b/cmd/export.go @@ -36,10 +36,7 @@ func init() { } func runExport(sessionID string) error { - cfg, err := getConfigFromViper() - if err != nil { - return fmt.Errorf("failed to load config: %w", err) - } + cfg := Cfg storageConfig := storage.NewStorageFromConfig(cfg) storageBackend, err := storage.NewStorage(storageConfig) @@ -47,8 +44,7 @@ func runExport(sessionID string) error { return fmt.Errorf("failed to initialize storage: %w", err) } - configService := services.NewConfigService(V, cfg) - toolRegistry := tools.NewRegistry(configService, nil, nil, nil, nil, nil, nil) + toolRegistry := tools.NewRegistry(cfg, nil, nil, nil, nil, nil, nil) themeService := domain.NewThemeProvider() styleProvider := styles.NewProvider(themeService) toolFormatterService := services.NewToolFormatterService(toolRegistry, styleProvider) diff --git a/cmd/init.go b/cmd/init.go index 36b85875..539ac6a4 100644 --- a/cmd/init.go +++ b/cmd/init.go @@ -10,7 +10,6 @@ import ( yaml "gopkg.in/yaml.v3" config "github.com/inference-gateway/cli/config" - services "github.com/inference-gateway/cli/internal/services" icons "github.com/inference-gateway/cli/internal/ui/styles/icons" ) @@ -415,8 +414,7 @@ func createKeybindingsConfigFile(path string) error { return fmt.Errorf("failed to create config directory: %w", err) } - service := services.NewKeybindingsConfigService(path) - return service.Save(services.DefaultKeybindingsConfig()) + return config.SaveKeybindings(path, config.DefaultKeybindingsConfig()) } // createPromptsConfigFile writes a fresh prompts.yaml seeded from the @@ -427,8 +425,7 @@ func createPromptsConfigFile(path string) error { return fmt.Errorf("failed to create config directory: %w", err) } - service := services.NewPromptsConfigService(path) - return service.Save(config.DefaultPromptsConfig()) + return config.SavePrompts(path, config.DefaultPromptsConfig()) } // createMCPConfigFile creates the MCP configuration YAML file diff --git a/cmd/keybindings.go b/cmd/keybindings.go index b049a8bc..3e72a425 100644 --- a/cmd/keybindings.go +++ b/cmd/keybindings.go @@ -8,7 +8,6 @@ import ( cobra "github.com/spf13/cobra" config "github.com/inference-gateway/cli/config" - services "github.com/inference-gateway/cli/internal/services" keybinding "github.com/inference-gateway/cli/internal/ui/keybinding" icons "github.com/inference-gateway/cli/internal/ui/styles/icons" ) @@ -82,12 +81,7 @@ func init() { } func listKeybindings(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return fmt.Errorf("failed to load config: %w", err) - } - - registry := keybinding.NewRegistry(cfg) + registry := keybinding.NewRegistry(Cfg) actions := registry.ListAllActions() if len(actions) == 0 { @@ -132,8 +126,7 @@ func resetKeybindings(cmd *cobra.Command, args []string) error { return err } - service := services.NewKeybindingsConfigService(path) - if err := service.Save(services.DefaultKeybindingsConfig()); err != nil { + if err := config.SaveKeybindings(path, config.DefaultKeybindingsConfig()); err != nil { return fmt.Errorf("failed to save keybindings: %w", err) } @@ -144,10 +137,7 @@ func resetKeybindings(cmd *cobra.Command, args []string) error { } func validateKeybindings(cmd *cobra.Command, args []string) error { - cfg, err := getConfigFromViper() - if err != nil { - return fmt.Errorf("failed to load config: %w", err) - } + cfg := Cfg if !cfg.Chat.Keybindings.Enabled { fmt.Println("Note: Keybindings are currently disabled in config") @@ -284,7 +274,7 @@ func setKeybinding(cmd *cobra.Command, args []string) error { return fmt.Errorf("unknown action '%s'. Run 'infer keybindings list' to see available actions", actionID) } - path, kbConfig, service, err := loadKeybindingsForWrite(cmd) + path, kbConfig, err := loadKeybindingsForWrite(cmd) if err != nil { return err } @@ -293,7 +283,7 @@ func setKeybinding(cmd *cobra.Command, args []string) error { entry.Keys = keys kbConfig.Bindings[actionID] = entry - if err := service.Save(kbConfig); err != nil { + if err := config.SaveKeybindings(path, kbConfig); err != nil { return fmt.Errorf("failed to save keybindings: %w", err) } @@ -314,7 +304,7 @@ func enableKeybinding(cmd *cobra.Command, args []string) error { return fmt.Errorf("unknown action '%s'. Run 'infer keybindings list' to see available actions", actionID) } - path, kbConfig, service, err := loadKeybindingsForWrite(cmd) + path, kbConfig, err := loadKeybindingsForWrite(cmd) if err != nil { return err } @@ -324,7 +314,7 @@ func enableKeybinding(cmd *cobra.Command, args []string) error { entry.Enabled = &enabled kbConfig.Bindings[actionID] = entry - if err := service.Save(kbConfig); err != nil { + if err := config.SaveKeybindings(path, kbConfig); err != nil { return fmt.Errorf("failed to save keybindings: %w", err) } @@ -344,7 +334,7 @@ func disableKeybinding(cmd *cobra.Command, args []string) error { return fmt.Errorf("unknown action '%s'. Run 'infer keybindings list' to see available actions", actionID) } - path, kbConfig, service, err := loadKeybindingsForWrite(cmd) + path, kbConfig, err := loadKeybindingsForWrite(cmd) if err != nil { return err } @@ -354,7 +344,7 @@ func disableKeybinding(cmd *cobra.Command, args []string) error { entry.Enabled = &disabled kbConfig.Bindings[actionID] = entry - if err := service.Save(kbConfig); err != nil { + if err := config.SaveKeybindings(path, kbConfig); err != nil { return fmt.Errorf("failed to save keybindings: %w", err) } @@ -369,23 +359,22 @@ func disableKeybinding(cmd *cobra.Command, args []string) error { // loadKeybindingsForWrite resolves the destination keybindings.yaml path // (honouring --userspace), loads the existing config (or defaults if the // file is absent), and returns everything callers need to mutate-and-save. -func loadKeybindingsForWrite(cmd *cobra.Command) (string, *config.KeybindingsConfig, *services.KeybindingsConfigService, error) { +func loadKeybindingsForWrite(cmd *cobra.Command) (string, *config.KeybindingsConfig, error) { path, err := getKeybindingsConfigWritePath(GetUserspaceFlag(cmd)) if err != nil { - return "", nil, nil, err + return "", nil, err } - service := services.NewKeybindingsConfigService(path) - kbConfig, err := service.Load() + kbConfig, err := config.LoadKeybindings(path) if err != nil { - return "", nil, nil, fmt.Errorf("failed to load keybindings: %w", err) + return "", nil, fmt.Errorf("failed to load keybindings: %w", err) } if kbConfig.Bindings == nil { kbConfig.Bindings = make(map[string]config.KeyBindingEntry) } - return path, kbConfig, service, nil + return path, kbConfig, nil } // getValidActionIDs returns all valid action IDs by creating a temporary registry diff --git a/cmd/mcp.go b/cmd/mcp.go index 78189259..5ee0ace7 100644 --- a/cmd/mcp.go +++ b/cmd/mcp.go @@ -7,7 +7,6 @@ import ( glamour "github.com/charmbracelet/glamour" config "github.com/inference-gateway/cli/config" - services "github.com/inference-gateway/cli/internal/services" icons "github.com/inference-gateway/cli/internal/ui/styles/icons" cobra "github.com/spf13/cobra" ) @@ -136,9 +135,8 @@ func getMCPConfigPath(cmd *cobra.Command) string { func listMCPServers(cmd *cobra.Command, args []string) error { configPath := getMCPConfigPath(cmd) - mcpConfigService := services.NewMCPConfigService(configPath) - cfg, err := mcpConfigService.Load() + cfg, err := config.LoadMCP(configPath) if err != nil { return fmt.Errorf("failed to load MCP config: %w", err) } @@ -299,8 +297,7 @@ func addMCPServer(cmd *cobra.Command, args []string) error { basePort := 3000 configPath := getMCPConfigPath(cmd) - mcpConfigService := services.NewMCPConfigService(configPath) - existingConfig, _ := mcpConfigService.Load() + existingConfig, _ := config.LoadMCP(configPath) for _, existing := range existingConfig.Servers { if existing.Port > basePort { @@ -316,9 +313,8 @@ func addMCPServer(cmd *cobra.Command, args []string) error { } configPath := getMCPConfigPath(cmd) - mcpConfigService := services.NewMCPConfigService(configPath) - if err := mcpConfigService.AddServer(server); err != nil { + if err := config.AddMCPServer(configPath, server); err != nil { return fmt.Errorf("failed to add MCP server: %w", err) } @@ -344,9 +340,8 @@ func removeMCPServer(cmd *cobra.Command, args []string) error { name := args[0] configPath := getMCPConfigPath(cmd) - mcpConfigService := services.NewMCPConfigService(configPath) - if err := mcpConfigService.RemoveServer(name); err != nil { + if err := config.RemoveMCPServer(configPath, name); err != nil { return fmt.Errorf("failed to remove MCP server: %w", err) } @@ -362,9 +357,8 @@ func updateMCPServer(cmd *cobra.Command, args []string) error { name := args[0] configPath := getMCPConfigPath(cmd) - mcpConfigService := services.NewMCPConfigService(configPath) - existing, err := mcpConfigService.GetServer(name) + existing, err := config.GetMCPServer(configPath, name) if err != nil { return fmt.Errorf("failed to get MCP server: %w", err) } @@ -405,7 +399,7 @@ func updateMCPServer(cmd *cobra.Command, args []string) error { existing.Enabled = enabled } - if err := mcpConfigService.UpdateServer(*existing); err != nil { + if err := config.UpdateMCPServer(configPath, *existing); err != nil { return fmt.Errorf("failed to update MCP server: %w", err) } @@ -421,16 +415,15 @@ func enableMCPServer(cmd *cobra.Command, args []string) error { name := args[0] configPath := getMCPConfigPath(cmd) - mcpConfigService := services.NewMCPConfigService(configPath) - server, err := mcpConfigService.GetServer(name) + server, err := config.GetMCPServer(configPath, name) if err != nil { return fmt.Errorf("failed to get MCP server: %w", err) } server.Enabled = true - if err := mcpConfigService.UpdateServer(*server); err != nil { + if err := config.UpdateMCPServer(configPath, *server); err != nil { return fmt.Errorf("failed to enable MCP server: %w", err) } @@ -447,16 +440,15 @@ func disableMCPServer(cmd *cobra.Command, args []string) error { name := args[0] configPath := getMCPConfigPath(cmd) - mcpConfigService := services.NewMCPConfigService(configPath) - server, err := mcpConfigService.GetServer(name) + server, err := config.GetMCPServer(configPath, name) if err != nil { return fmt.Errorf("failed to get MCP server: %w", err) } server.Enabled = false - if err := mcpConfigService.UpdateServer(*server); err != nil { + if err := config.UpdateMCPServer(configPath, *server); err != nil { return fmt.Errorf("failed to disable MCP server: %w", err) } @@ -471,16 +463,15 @@ func disableMCPServer(cmd *cobra.Command, args []string) error { func enableMCPGlobal(cmd *cobra.Command, args []string) error { configPath := getMCPConfigPath(cmd) - mcpConfigService := services.NewMCPConfigService(configPath) - cfg, err := mcpConfigService.Load() + cfg, err := config.LoadMCP(configPath) if err != nil { return fmt.Errorf("failed to load MCP config: %w", err) } cfg.Enabled = true - if err := mcpConfigService.Save(cfg); err != nil { + if err := config.SaveMCP(configPath, cfg); err != nil { return fmt.Errorf("failed to enable MCP globally: %w", err) } @@ -494,16 +485,15 @@ func enableMCPGlobal(cmd *cobra.Command, args []string) error { func disableMCPGlobal(cmd *cobra.Command, args []string) error { configPath := getMCPConfigPath(cmd) - mcpConfigService := services.NewMCPConfigService(configPath) - cfg, err := mcpConfigService.Load() + cfg, err := config.LoadMCP(configPath) if err != nil { return fmt.Errorf("failed to load MCP config: %w", err) } cfg.Enabled = false - if err := mcpConfigService.Save(cfg); err != nil { + if err := config.SaveMCP(configPath, cfg); err != nil { return fmt.Errorf("failed to disable MCP globally: %w", err) } diff --git a/cmd/migrate.go b/cmd/migrate.go index c4d00f0e..af50dd82 100644 --- a/cmd/migrate.go +++ b/cmd/migrate.go @@ -39,12 +39,7 @@ func init() { // runMigrations executes pending database migrations func runMigrations() error { - cfg, err := getConfigFromViper() - if err != nil { - return fmt.Errorf("failed to get config: %w", err) - } - - serviceContainer := container.NewServiceContainer(cfg, V) + serviceContainer := container.NewServiceContainer(Cfg) conversationStorage := serviceContainer.GetStorage() @@ -79,12 +74,7 @@ func runMigrations() error { // showMigrationStatus displays the current migration status func showMigrationStatus() error { - cfg, err := getConfigFromViper() - if err != nil { - return fmt.Errorf("failed to get config: %w", err) - } - - serviceContainer := container.NewServiceContainer(cfg, V) + serviceContainer := container.NewServiceContainer(Cfg) conversationStorage := serviceContainer.GetStorage() diff --git a/cmd/prompts_overlay_test.go b/cmd/prompts_overlay_test.go index c5dc2d6d..20a3fcfb 100644 --- a/cmd/prompts_overlay_test.go +++ b/cmd/prompts_overlay_test.go @@ -8,18 +8,16 @@ import ( require "github.com/stretchr/testify/require" config "github.com/inference-gateway/cli/config" - services "github.com/inference-gateway/cli/internal/services" ) -// TestGetConfigFromViper_PromptsDefaultsWhenFileAbsent confirms the +// TestLoadConfigFromViper_PromptsDefaultsWhenFileAbsent confirms the // overlay falls back to in-code defaults when no prompts.yaml exists, // so freshly-cloned repos still get a working agent prompt. -func TestGetConfigFromViper_PromptsDefaultsWhenFileAbsent(t *testing.T) { +func TestLoadConfigFromViper_PromptsDefaultsWhenFileAbsent(t *testing.T) { withHermeticEnv(t) initConfig() - cfg, err := getConfigFromViper() - require.NoError(t, err) + cfg := Cfg defaults := config.DefaultPromptsConfig() require.Equal(t, defaults.Agent.SystemPrompt, cfg.Agent.SystemPrompt) @@ -31,12 +29,12 @@ func TestGetConfigFromViper_PromptsDefaultsWhenFileAbsent(t *testing.T) { require.Equal(t, defaults.Init.Prompt, cfg.Init.Prompt) } -// TestGetConfigFromViper_PromptsPartialFileFallsBackForUnsetFields +// TestLoadConfigFromViper_PromptsPartialFileFallsBackForUnsetFields // guards the partial-overlay rule: if a user blanks out (or never sets) // a single prompt key, the others must still resolve to defaults instead // of becoming empty strings. Empty prompts at runtime would cause the // LLM to receive no system instructions. -func TestGetConfigFromViper_PromptsPartialFileFallsBackForUnsetFields(t *testing.T) { +func TestLoadConfigFromViper_PromptsPartialFileFallsBackForUnsetFields(t *testing.T) { withHermeticEnv(t) homeDir := os.Getenv("HOME") @@ -46,11 +44,10 @@ func TestGetConfigFromViper_PromptsPartialFileFallsBackForUnsetFields(t *testing SystemPrompt: "USER OVERRIDE: only this is set", }, } - require.NoError(t, services.NewPromptsConfigService(promptsPath).Save(custom)) + require.NoError(t, config.SavePrompts(promptsPath, custom)) initConfig() - cfg, err := getConfigFromViper() - require.NoError(t, err) + cfg := Cfg defaults := config.DefaultPromptsConfig() require.Equal(t, "USER OVERRIDE: only this is set", cfg.Agent.SystemPrompt) @@ -59,23 +56,22 @@ func TestGetConfigFromViper_PromptsPartialFileFallsBackForUnsetFields(t *testing require.Equal(t, defaults.Init.Prompt, cfg.Init.Prompt, "unset init prompt should fall back to default") } -// TestGetConfigFromViper_PromptsEnvOverridesFile pins the precedence +// TestLoadConfigFromViper_PromptsEnvOverridesFile pins the precedence // order: env > file > in-code defaults. Without this guarantee, ops // teams cannot inject a prompt at deploy time without editing the file // inside the container image. -func TestGetConfigFromViper_PromptsEnvOverridesFile(t *testing.T) { +func TestLoadConfigFromViper_PromptsEnvOverridesFile(t *testing.T) { withHermeticEnv(t) homeDir := os.Getenv("HOME") promptsPath := filepath.Join(homeDir, config.ConfigDirName, config.PromptsFileName) - require.NoError(t, services.NewPromptsConfigService(promptsPath).Save(&config.PromptsConfig{ + require.NoError(t, config.SavePrompts(promptsPath, &config.PromptsConfig{ Agent: config.PromptsAgentConfig{SystemPrompt: "from-file"}, })) t.Setenv("INFER_PROMPTS_AGENT_SYSTEM_PROMPT", "from-env") initConfig() - cfg, err := getConfigFromViper() - require.NoError(t, err) + cfg := Cfg require.Equal(t, "from-env", cfg.Agent.SystemPrompt) } diff --git a/cmd/root.go b/cmd/root.go index ba6f633d..5ed19815 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -13,8 +13,13 @@ import ( logger "github.com/inference-gateway/cli/internal/logger" ) -// Global Viper instance for commands to use -var V *viper.Viper +// Global Viper instance and resolved Config used by every command. Both +// are populated exactly once by initConfig() at startup. Commands read +// Cfg directly instead of re-unmarshalling viper. +var ( + V *viper.Viper + Cfg *config.Config +) var rootCmd = &cobra.Command{ Use: "infer", @@ -118,6 +123,13 @@ func initConfig() { } } + cfg, err := loadConfigFromViper() + if err != nil { + fmt.Fprintf(os.Stderr, "Error loading config: %v\n", err) + os.Exit(1) + } + Cfg = cfg + verbose := v.GetBool("verbose") debug := v.GetBool("logging.debug") logDir := v.GetString("logging.dir") diff --git a/cmd/status.go b/cmd/status.go index 2b700159..9362380f 100644 --- a/cmd/status.go +++ b/cmd/status.go @@ -59,7 +59,7 @@ var statusCmd = &cobra.Command{ func fetchModels(cfg *config.Config) (*struct { Data []string `json:"data"` }, error) { - services := container.NewServiceContainer(cfg, V) + services := container.NewServiceContainer(cfg) timeout := time.Duration(cfg.Gateway.Timeout) * time.Second ctx, cancel := context.WithTimeout(context.Background(), timeout) diff --git a/config/agents.go b/config/agents.go index f913b02d..3fa7753c 100644 --- a/config/agents.go +++ b/config/agents.go @@ -1,6 +1,14 @@ package config -import "strings" +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "strings" + + yaml "gopkg.in/yaml.v3" +) // AgentsConfig represents the agents.yaml configuration file type AgentsConfig struct { @@ -60,3 +68,150 @@ func (a *AgentEntry) GetEnvironmentWithModel() map[string]string { return env } + +// LoadAgents reads agents.yaml from disk. When the file is missing it +// returns the in-code defaults so callers can treat absence as "use +// defaults" without special-casing. +func LoadAgents(path string) (*AgentsConfig, error) { + if _, err := os.Stat(path); os.IsNotExist(err) { + return DefaultAgentsConfig(), nil + } + + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read agents config: %w", err) + } + + expandedData := os.ExpandEnv(string(data)) + + var cfg AgentsConfig + if err := yaml.Unmarshal([]byte(expandedData), &cfg); err != nil { + return nil, fmt.Errorf("failed to parse agents config: %w", err) + } + + return &cfg, nil +} + +// SaveAgents writes the agents configuration to disk, creating any +// missing parent directories. +func SaveAgents(path string, cfg *AgentsConfig) error { + var buf bytes.Buffer + encoder := yaml.NewEncoder(&buf) + encoder.SetIndent(2) + + if err := encoder.Encode(cfg); err != nil { + return fmt.Errorf("failed to marshal agents config: %w", err) + } + + if err := encoder.Close(); err != nil { + return fmt.Errorf("failed to close encoder: %w", err) + } + + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return fmt.Errorf("failed to create config directory: %w", err) + } + + if err := os.WriteFile(path, buf.Bytes(), 0644); err != nil { + return fmt.Errorf("failed to write agents config: %w", err) + } + + return nil +} + +// AddAgent adds a new agent to the configuration file. +func AddAgent(path string, agent AgentEntry) error { + cfg, err := LoadAgents(path) + if err != nil { + return err + } + + for _, existing := range cfg.Agents { + if existing.Name == agent.Name { + return fmt.Errorf("agent with name '%s' already exists", agent.Name) + } + } + + cfg.Agents = append(cfg.Agents, agent) + return SaveAgents(path, cfg) +} + +// UpdateAgent updates an existing agent by name. +func UpdateAgent(path string, agent AgentEntry) error { + cfg, err := LoadAgents(path) + if err != nil { + return err + } + + for i, existing := range cfg.Agents { + if existing.Name == agent.Name { + cfg.Agents[i] = agent + return SaveAgents(path, cfg) + } + } + + return fmt.Errorf("agent with name '%s' not found", agent.Name) +} + +// RemoveAgent removes an agent by name. +func RemoveAgent(path, name string) error { + cfg, err := LoadAgents(path) + if err != nil { + return err + } + + found := false + newAgents := make([]AgentEntry, 0, len(cfg.Agents)) + for _, agent := range cfg.Agents { + if agent.Name != name { + newAgents = append(newAgents, agent) + } else { + found = true + } + } + + if !found { + return fmt.Errorf("agent with name '%s' not found", name) + } + + cfg.Agents = newAgents + return SaveAgents(path, cfg) +} + +// ListAgents returns all configured agents. +func ListAgents(path string) ([]AgentEntry, error) { + cfg, err := LoadAgents(path) + if err != nil { + return nil, err + } + return cfg.Agents, nil +} + +// GetAgent returns a single agent entry by name. +func GetAgent(path, name string) (*AgentEntry, error) { + cfg, err := LoadAgents(path) + if err != nil { + return nil, err + } + + for _, agent := range cfg.Agents { + if agent.Name == name { + return &agent, nil + } + } + + return nil, fmt.Errorf("agent with name '%s' not found", name) +} + +// GetAgentURLs returns URLs of all configured agents. +func GetAgentURLs(path string) ([]string, error) { + agents, err := ListAgents(path) + if err != nil { + return nil, err + } + + urls := make([]string, 0, len(agents)) + for _, agent := range agents { + urls = append(urls, agent.URL) + } + return urls, nil +} diff --git a/internal/services/agents_config_test.go b/config/agents_persistence_test.go similarity index 70% rename from internal/services/agents_config_test.go rename to config/agents_persistence_test.go index e8fdef22..b91c09db 100644 --- a/internal/services/agents_config_test.go +++ b/config/agents_persistence_test.go @@ -1,4 +1,4 @@ -package services +package config_test import ( "os" @@ -9,10 +9,9 @@ import ( require "github.com/stretchr/testify/require" ) -func TestAgentsConfigService_AddAgent(t *testing.T) { +func TestAddAgent(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - svc := NewAgentsConfigService(agentsPath) agent := config.AgentEntry{ Name: "test-agent", @@ -25,195 +24,152 @@ func TestAgentsConfigService_AddAgent(t *testing.T) { }, } - err := svc.AddAgent(agent) - if err != nil { + if err := config.AddAgent(agentsPath, agent); err != nil { t.Fatalf("Failed to add agent: %v", err) } - if _, err := os.Stat(agentsPath); os.IsNotExist(err) { t.Fatal("Agents config file was not created") } - cfg, err := svc.Load() + cfg, err := config.LoadAgents(agentsPath) if err != nil { t.Fatalf("Failed to load config: %v", err) } - if len(cfg.Agents) != 1 { t.Fatalf("Expected 1 agent, got %d", len(cfg.Agents)) } - if cfg.Agents[0].Name != agent.Name { t.Errorf("Expected name %s, got %s", agent.Name, cfg.Agents[0].Name) } - if cfg.Agents[0].URL != agent.URL { t.Errorf("Expected URL %s, got %s", agent.URL, cfg.Agents[0].URL) } - if cfg.Agents[0].OCI != agent.OCI { t.Errorf("Expected OCI %s, got %s", agent.OCI, cfg.Agents[0].OCI) } - if cfg.Agents[0].Run != agent.Run { t.Errorf("Expected Run %v, got %v", agent.Run, cfg.Agents[0].Run) } - if len(cfg.Agents[0].Environment) != 2 { t.Errorf("Expected 2 environment variables, got %d", len(cfg.Agents[0].Environment)) } } -func TestAgentsConfigService_AddDuplicateAgent(t *testing.T) { +func TestAddAgent_Duplicate(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - svc := NewAgentsConfigService(agentsPath) - agent := config.AgentEntry{ - Name: "test-agent", - URL: "https://agent.example.com", - } + agent := config.AgentEntry{Name: "test-agent", URL: "https://agent.example.com"} - err := svc.AddAgent(agent) - if err != nil { + if err := config.AddAgent(agentsPath, agent); err != nil { t.Fatalf("Failed to add agent: %v", err) } - - err = svc.AddAgent(agent) - if err == nil { + if err := config.AddAgent(agentsPath, agent); err == nil { t.Fatal("Expected error when adding duplicate agent, got nil") } } -func TestAgentsConfigService_RemoveAgent(t *testing.T) { +func TestRemoveAgent(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - svc := NewAgentsConfigService(agentsPath) - - agent1 := config.AgentEntry{ - Name: "agent1", - URL: "https://agent1.example.com", - } - agent2 := config.AgentEntry{ - Name: "agent2", - URL: "https://agent2.example.com", - } + agent1 := config.AgentEntry{Name: "agent1", URL: "https://agent1.example.com"} + agent2 := config.AgentEntry{Name: "agent2", URL: "https://agent2.example.com"} - require.NoError(t, svc.AddAgent(agent1)) - require.NoError(t, svc.AddAgent(agent2)) + require.NoError(t, config.AddAgent(agentsPath, agent1)) + require.NoError(t, config.AddAgent(agentsPath, agent2)) - err := svc.RemoveAgent("agent1") - if err != nil { + if err := config.RemoveAgent(agentsPath, "agent1"); err != nil { t.Fatalf("Failed to remove agent: %v", err) } - agents, err := svc.ListAgents() + agents, err := config.ListAgents(agentsPath) if err != nil { t.Fatalf("Failed to list agents: %v", err) } - if len(agents) != 1 { t.Fatalf("Expected 1 agent, got %d", len(agents)) } - if agents[0].Name != "agent2" { t.Errorf("Expected agent2 to remain, got %s", agents[0].Name) } } -func TestAgentsConfigService_RemoveNonexistentAgent(t *testing.T) { +func TestRemoveAgent_Nonexistent(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - svc := NewAgentsConfigService(agentsPath) - err := svc.RemoveAgent("nonexistent") - if err == nil { + if err := config.RemoveAgent(agentsPath, "nonexistent"); err == nil { t.Fatal("Expected error when removing nonexistent agent, got nil") } } -func TestAgentsConfigService_ListAgents(t *testing.T) { +func TestListAgents(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - svc := NewAgentsConfigService(agentsPath) agents := []config.AgentEntry{ {Name: "agent1", URL: "https://agent1.example.com"}, {Name: "agent2", URL: "https://agent2.example.com"}, {Name: "agent3", URL: "https://agent3.example.com"}, } - for _, agent := range agents { - require.NoError(t, svc.AddAgent(agent)) + require.NoError(t, config.AddAgent(agentsPath, agent)) } - listed, err := svc.ListAgents() + listed, err := config.ListAgents(agentsPath) if err != nil { t.Fatalf("Failed to list agents: %v", err) } - if len(listed) != 3 { t.Fatalf("Expected 3 agents, got %d", len(listed)) } } -func TestAgentsConfigService_GetAgent(t *testing.T) { +func TestGetAgent(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - svc := NewAgentsConfigService(agentsPath) - agent := config.AgentEntry{ - Name: "test-agent", - URL: "https://agent.example.com", - } - - require.NoError(t, svc.AddAgent(agent)) + agent := config.AgentEntry{Name: "test-agent", URL: "https://agent.example.com"} + require.NoError(t, config.AddAgent(agentsPath, agent)) - retrieved, err := svc.GetAgent("test-agent") + retrieved, err := config.GetAgent(agentsPath, "test-agent") if err != nil { t.Fatalf("Failed to get agent: %v", err) } - if retrieved.Name != agent.Name { t.Errorf("Expected name %s, got %s", agent.Name, retrieved.Name) } - if retrieved.URL != agent.URL { t.Errorf("Expected URL %s, got %s", agent.URL, retrieved.URL) } } -func TestAgentsConfigService_GetNonexistentAgent(t *testing.T) { +func TestGetAgent_Nonexistent(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - svc := NewAgentsConfigService(agentsPath) - _, err := svc.GetAgent("nonexistent") - if err == nil { + if _, err := config.GetAgent(agentsPath, "nonexistent"); err == nil { t.Fatal("Expected error when getting nonexistent agent, got nil") } } -func TestAgentsConfigService_GetAgentURLs(t *testing.T) { +func TestGetAgentURLs(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - svc := NewAgentsConfigService(agentsPath) agents := []config.AgentEntry{ {Name: "agent1", URL: "https://agent1.example.com"}, {Name: "agent2", URL: "https://agent2.example.com"}, } - for _, agent := range agents { - require.NoError(t, svc.AddAgent(agent)) + require.NoError(t, config.AddAgent(agentsPath, agent)) } - urls, err := svc.GetAgentURLs() + urls, err := config.GetAgentURLs(agentsPath) if err != nil { t.Fatalf("Failed to get agent URLs: %v", err) } - if len(urls) != 2 { t.Fatalf("Expected 2 URLs, got %d", len(urls)) } @@ -222,14 +178,12 @@ func TestAgentsConfigService_GetAgentURLs(t *testing.T) { "https://agent1.example.com": false, "https://agent2.example.com": false, } - for _, url := range urls { if _, exists := expectedURLs[url]; !exists { t.Errorf("Unexpected URL: %s", url) } expectedURLs[url] = true } - for url, found := range expectedURLs { if !found { t.Errorf("Expected URL not found: %s", url) @@ -237,25 +191,22 @@ func TestAgentsConfigService_GetAgentURLs(t *testing.T) { } } -func TestAgentsConfigService_LoadNonexistentFile(t *testing.T) { +func TestLoadAgents_NonexistentFile(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "nonexistent.yaml") - svc := NewAgentsConfigService(agentsPath) - cfg, err := svc.Load() + cfg, err := config.LoadAgents(agentsPath) if err != nil { t.Fatalf("Expected no error for nonexistent file, got: %v", err) } - if len(cfg.Agents) != 0 { t.Errorf("Expected empty agents list, got %d agents", len(cfg.Agents)) } } -func TestAgentsConfigService_UpdateAgent(t *testing.T) { +func TestUpdateAgent(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - svc := NewAgentsConfigService(agentsPath) agent := config.AgentEntry{ Name: "test-agent", @@ -267,8 +218,7 @@ func TestAgentsConfigService_UpdateAgent(t *testing.T) { "API_KEY": "secret", }, } - - require.NoError(t, svc.AddAgent(agent)) + require.NoError(t, config.AddAgent(agentsPath, agent)) updatedAgent := config.AgentEntry{ Name: "test-agent", @@ -280,67 +230,52 @@ func TestAgentsConfigService_UpdateAgent(t *testing.T) { "DEBUG": "true", }, } + require.NoError(t, config.UpdateAgent(agentsPath, updatedAgent)) - err := svc.UpdateAgent(updatedAgent) - require.NoError(t, err) - - retrieved, err := svc.GetAgent("test-agent") + retrieved, err := config.GetAgent(agentsPath, "test-agent") require.NoError(t, err) if retrieved.URL != updatedAgent.URL { t.Errorf("Expected URL %s, got %s", updatedAgent.URL, retrieved.URL) } - if retrieved.OCI != updatedAgent.OCI { t.Errorf("Expected OCI %s, got %s", updatedAgent.OCI, retrieved.OCI) } - if retrieved.Run != updatedAgent.Run { t.Errorf("Expected Run %v, got %v", updatedAgent.Run, retrieved.Run) } - if retrieved.Model != updatedAgent.Model { t.Errorf("Expected Model %s, got %s", updatedAgent.Model, retrieved.Model) } - if len(retrieved.Environment) != 1 || retrieved.Environment["DEBUG"] != "true" { t.Errorf("Expected Environment to be updated, got %v", retrieved.Environment) } - agents, err := svc.ListAgents() + agents, err := config.ListAgents(agentsPath) require.NoError(t, err) - if len(agents) != 1 { t.Errorf("Expected 1 agent after update, got %d", len(agents)) } } -func TestAgentsConfigService_UpdateNonexistentAgent(t *testing.T) { +func TestUpdateAgent_Nonexistent(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - svc := NewAgentsConfigService(agentsPath) - agent := config.AgentEntry{ - Name: "nonexistent", - URL: "https://agent.example.com", - } - - err := svc.UpdateAgent(agent) - if err == nil { + agent := config.AgentEntry{Name: "nonexistent", URL: "https://agent.example.com"} + if err := config.UpdateAgent(agentsPath, agent); err == nil { t.Fatal("Expected error when updating nonexistent agent, got nil") } } -func TestAgentsConfigService_EnvironmentVariableExpansion(t *testing.T) { +func TestLoadAgents_EnvironmentVariableExpansion(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - // Set test environment variables t.Setenv("TEST_API_KEY", "secret-key-123") t.Setenv("TEST_MODEL", "gpt-4") t.Setenv("TEST_DEBUG", "true") - // Write config file with env var placeholders configContent := `agents: - name: test-agent url: http://localhost:8080 @@ -354,12 +289,10 @@ func TestAgentsConfigService_EnvironmentVariableExpansion(t *testing.T) { ` require.NoError(t, os.WriteFile(agentsPath, []byte(configContent), 0644)) - svc := NewAgentsConfigService(agentsPath) - cfg, err := svc.Load() + cfg, err := config.LoadAgents(agentsPath) require.NoError(t, err) require.Len(t, cfg.Agents, 1, "Expected 1 agent") - agent := cfg.Agents[0] require.Equal(t, "test-agent", agent.Name) require.Equal(t, "openai/gpt-4", agent.Model) @@ -370,11 +303,10 @@ func TestAgentsConfigService_EnvironmentVariableExpansion(t *testing.T) { require.Equal(t, "static-value", agent.Environment["STATIC_VAR"], "STATIC_VAR should remain unchanged") } -func TestAgentsConfigService_EnvironmentVariableExpansion_UndefinedVar(t *testing.T) { +func TestLoadAgents_EnvironmentVariableExpansion_UndefinedVar(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - // Write config file with undefined env var configContent := `agents: - name: test-agent url: http://localhost:8080 @@ -385,23 +317,20 @@ func TestAgentsConfigService_EnvironmentVariableExpansion_UndefinedVar(t *testin ` require.NoError(t, os.WriteFile(agentsPath, []byte(configContent), 0644)) - svc := NewAgentsConfigService(agentsPath) - cfg, err := svc.Load() + cfg, err := config.LoadAgents(agentsPath) require.NoError(t, err) require.Len(t, cfg.Agents, 1) - // Undefined env vars expand to empty string require.Equal(t, "", cfg.Agents[0].Environment["UNDEFINED"]) } -func TestAgentsConfigService_EnvironmentVariableExpansion_MixedSyntax(t *testing.T) { +func TestLoadAgents_EnvironmentVariableExpansion_MixedSyntax(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") t.Setenv("VAR1", "value1") t.Setenv("VAR2", "value2") - // Test both $VAR and ${VAR} syntax configContent := `agents: - name: test-agent url: http://localhost:8080 @@ -414,8 +343,7 @@ func TestAgentsConfigService_EnvironmentVariableExpansion_MixedSyntax(t *testing ` require.NoError(t, os.WriteFile(agentsPath, []byte(configContent), 0644)) - svc := NewAgentsConfigService(agentsPath) - cfg, err := svc.Load() + cfg, err := config.LoadAgents(agentsPath) require.NoError(t, err) agent := cfg.Agents[0] diff --git a/config/config.go b/config/config.go index 468a8fe4..a9cc7614 100644 --- a/config/config.go +++ b/config/config.go @@ -1049,7 +1049,6 @@ func DefaultConfig() *Config { //nolint:funlen // IsApprovalRequired checks if approval is required for a specific tool // It returns true if tool-specific approval is set to true, or if global approval is true and tool-specific is not set to false -// ConfigService interface implementation func (c *Config) IsApprovalRequired(toolName string) bool { // nolint:gocyclo,cyclop globalApproval := c.Tools.Safety.RequireApproval @@ -1200,6 +1199,22 @@ func (c *Config) GetConfigDir() string { return c.configDir } +// ResolveConfigDir searches the standard project then userspace locations +// for an existing config.yaml and returns its directory. Falls back to the +// default project directory name when nothing is found on disk. +func ResolveConfigDir() string { + candidates := []string{DefaultConfigPath} + if homeDir, err := os.UserHomeDir(); err == nil { + candidates = append(candidates, filepath.Join(homeDir, ConfigDirName, ConfigFileName)) + } + for _, path := range candidates { + if _, err := os.Stat(path); err == nil { + return filepath.Dir(path) + } + } + return ConfigDirName +} + // IsBashCommandWhitelisted checks if a specific bash command is whitelisted func (c *Config) IsBashCommandWhitelisted(command string) bool { command = strings.TrimSpace(command) diff --git a/config/keybindings.go b/config/keybindings.go index 1811c19b..509e7287 100644 --- a/config/keybindings.go +++ b/config/keybindings.go @@ -1,10 +1,77 @@ package config +import ( + "bytes" + "fmt" + "os" + "path/filepath" + + yaml "gopkg.in/yaml.v3" +) + const ( KeybindingsFileName = "keybindings.yaml" DefaultKeybindingsPath = ConfigDirName + "/" + KeybindingsFileName ) +// DefaultKeybindingsConfig returns the default keybindings config used when +// no file exists. Callers (init, reset) use it to seed a fresh file. +func DefaultKeybindingsConfig() *KeybindingsConfig { + return &KeybindingsConfig{ + Enabled: true, + Bindings: GetDefaultKeybindings(), + } +} + +// LoadKeybindings reads keybindings.yaml from disk. When the file is +// missing it returns the in-code defaults so callers can treat absence +// as "use defaults" without special-casing. +func LoadKeybindings(path string) (*KeybindingsConfig, error) { + if _, err := os.Stat(path); os.IsNotExist(err) { + return DefaultKeybindingsConfig(), nil + } + + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read keybindings config: %w", err) + } + + var cfg KeybindingsConfig + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse keybindings config: %w", err) + } + + return &cfg, nil +} + +// SaveKeybindings writes the keybindings configuration to disk, creating +// any missing parent directories. +func SaveKeybindings(path string, cfg *KeybindingsConfig) error { + var buf bytes.Buffer + buf.WriteString("---\n") + + encoder := yaml.NewEncoder(&buf) + encoder.SetIndent(2) + + if err := encoder.Encode(cfg); err != nil { + return fmt.Errorf("failed to marshal keybindings config: %w", err) + } + + if err := encoder.Close(); err != nil { + return fmt.Errorf("failed to close encoder: %w", err) + } + + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return fmt.Errorf("failed to create config directory: %w", err) + } + + if err := os.WriteFile(path, buf.Bytes(), 0644); err != nil { + return fmt.Errorf("failed to write keybindings config: %w", err) + } + + return nil +} + // GetDefaultKeybindings returns the default keybinding configuration // Users can override these in their config file, and any missing entries // will fall back to these defaults diff --git a/internal/services/keybindings_config_test.go b/config/keybindings_persistence_test.go similarity index 70% rename from internal/services/keybindings_config_test.go rename to config/keybindings_persistence_test.go index 0283be65..5c5aea0f 100644 --- a/internal/services/keybindings_config_test.go +++ b/config/keybindings_persistence_test.go @@ -1,4 +1,4 @@ -package services +package config_test import ( "os" @@ -8,18 +8,16 @@ import ( config "github.com/inference-gateway/cli/config" ) -func TestKeybindingsConfigService_Load_NonExistentFile(t *testing.T) { +func TestLoadKeybindings_NonExistentFile(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "non-existent.yaml") - service := NewKeybindingsConfigService(configPath) - cfg, err := service.Load() - + cfg, err := config.LoadKeybindings(configPath) if err != nil { - t.Fatalf("Load() should not error for non-existent file, got: %v", err) + t.Fatalf("LoadKeybindings() should not error for non-existent file, got: %v", err) } if cfg == nil { - t.Fatal("Load() returned nil config") + t.Fatal("LoadKeybindings() returned nil config") } if !cfg.Enabled { t.Error("Default keybindings config should be enabled") @@ -33,7 +31,7 @@ func TestKeybindingsConfigService_Load_NonExistentFile(t *testing.T) { } } -func TestKeybindingsConfigService_Load_ValidYAML(t *testing.T) { +func TestLoadKeybindings_ValidYAML(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "keybindings.yaml") @@ -52,10 +50,9 @@ bindings: t.Fatalf("Failed to write test config file: %v", err) } - service := NewKeybindingsConfigService(configPath) - cfg, err := service.Load() + cfg, err := config.LoadKeybindings(configPath) if err != nil { - t.Fatalf("Load() failed: %v", err) + t.Fatalf("LoadKeybindings() failed: %v", err) } if !cfg.Enabled { t.Error("Expected Enabled to be true") @@ -69,10 +66,9 @@ bindings: } } -func TestKeybindingsConfigService_Save_RoundTrip(t *testing.T) { +func TestSaveKeybindings_RoundTrip(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "keybindings.yaml") - service := NewKeybindingsConfigService(configPath) enabled := true original := &config.KeybindingsConfig{ @@ -87,16 +83,16 @@ func TestKeybindingsConfigService_Save_RoundTrip(t *testing.T) { }, } - if err := service.Save(original); err != nil { - t.Fatalf("Save() failed: %v", err) + if err := config.SaveKeybindings(configPath, original); err != nil { + t.Fatalf("SaveKeybindings() failed: %v", err) } if _, err := os.Stat(configPath); err != nil { - t.Fatalf("Save() did not create file: %v", err) + t.Fatalf("SaveKeybindings() did not create file: %v", err) } - loaded, err := service.Load() + loaded, err := config.LoadKeybindings(configPath) if err != nil { - t.Fatalf("Load() after save failed: %v", err) + t.Fatalf("LoadKeybindings() after save failed: %v", err) } binding, ok := loaded.Bindings["global_quit"] if !ok { @@ -110,13 +106,12 @@ func TestKeybindingsConfigService_Save_RoundTrip(t *testing.T) { } } -func TestKeybindingsConfigService_Save_CreatesParentDirectory(t *testing.T) { +func TestSaveKeybindings_CreatesParentDirectory(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "nested", "deep", "keybindings.yaml") - service := NewKeybindingsConfigService(configPath) - if err := service.Save(DefaultKeybindingsConfig()); err != nil { - t.Fatalf("Save() failed to create nested dirs: %v", err) + if err := config.SaveKeybindings(configPath, config.DefaultKeybindingsConfig()); err != nil { + t.Fatalf("SaveKeybindings() failed to create nested dirs: %v", err) } if _, err := os.Stat(configPath); err != nil { t.Fatalf("File not created at nested path: %v", err) @@ -124,7 +119,7 @@ func TestKeybindingsConfigService_Save_CreatesParentDirectory(t *testing.T) { } func TestDefaultKeybindingsConfig(t *testing.T) { - cfg := DefaultKeybindingsConfig() + cfg := config.DefaultKeybindingsConfig() if cfg == nil { t.Fatal("DefaultKeybindingsConfig returned nil") } diff --git a/config/mcp.go b/config/mcp.go index 3ad45988..c7699cee 100644 --- a/config/mcp.go +++ b/config/mcp.go @@ -1,8 +1,13 @@ package config import ( + "bytes" "fmt" + "os" + "path/filepath" "strings" + + yaml "gopkg.in/yaml.v3" ) // MCPConfig represents the mcp.yaml configuration file @@ -152,3 +157,187 @@ const ( MCPFileName = "mcp.yaml" DefaultMCPPath = ConfigDirName + "/" + MCPFileName ) + +// LoadMCP reads mcp.yaml from disk. When the file is missing it returns +// the in-code defaults so callers can treat absence as "use defaults" +// without special-casing. +func LoadMCP(path string) (*MCPConfig, error) { + if _, err := os.Stat(path); os.IsNotExist(err) { + return DefaultMCPConfig(), nil + } + + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read MCP config: %w", err) + } + + expandedData := os.ExpandEnv(string(data)) + + var cfg MCPConfig + if err := yaml.Unmarshal([]byte(expandedData), &cfg); err != nil { + return nil, fmt.Errorf("failed to parse MCP config: %w", err) + } + + return &cfg, nil +} + +// SaveMCP writes the MCP configuration to disk, creating any missing +// parent directories. +func SaveMCP(path string, cfg *MCPConfig) error { + var buf bytes.Buffer + buf.WriteString("---\n") + + encoder := yaml.NewEncoder(&buf) + encoder.SetIndent(2) + + if err := encoder.Encode(cfg); err != nil { + return fmt.Errorf("failed to marshal MCP config: %w", err) + } + + if err := encoder.Close(); err != nil { + return fmt.Errorf("failed to close encoder: %w", err) + } + + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return fmt.Errorf("failed to create config directory: %w", err) + } + + if err := os.WriteFile(path, buf.Bytes(), 0644); err != nil { + return fmt.Errorf("failed to write MCP config: %w", err) + } + + return nil +} + +// AddMCPServer adds a new MCP server entry to the configuration file. +func AddMCPServer(path string, server MCPServerEntry) error { + cfg, err := LoadMCP(path) + if err != nil { + return err + } + + for _, existing := range cfg.Servers { + if existing.Name == server.Name { + return fmt.Errorf("MCP server with name '%s' already exists", server.Name) + } + } + + cfg.Servers = append(cfg.Servers, server) + return SaveMCP(path, cfg) +} + +// UpdateMCPServer updates an existing MCP server entry by name. +func UpdateMCPServer(path string, server MCPServerEntry) error { + cfg, err := LoadMCP(path) + if err != nil { + return err + } + + for i, existing := range cfg.Servers { + if existing.Name == server.Name { + cfg.Servers[i] = server + return SaveMCP(path, cfg) + } + } + + return fmt.Errorf("MCP server with name '%s' not found", server.Name) +} + +// RemoveMCPServer removes an MCP server entry by name. +func RemoveMCPServer(path, name string) error { + cfg, err := LoadMCP(path) + if err != nil { + return err + } + + found := false + newServers := make([]MCPServerEntry, 0, len(cfg.Servers)) + for _, server := range cfg.Servers { + if server.Name != name { + newServers = append(newServers, server) + } else { + found = true + } + } + + if !found { + return fmt.Errorf("MCP server with name '%s' not found", name) + } + + cfg.Servers = newServers + return SaveMCP(path, cfg) +} + +// ListMCPServers returns all configured MCP servers. +func ListMCPServers(path string) ([]MCPServerEntry, error) { + cfg, err := LoadMCP(path) + if err != nil { + return nil, err + } + return cfg.Servers, nil +} + +// GetMCPServer returns a single MCP server entry by name. +func GetMCPServer(path, name string) (*MCPServerEntry, error) { + cfg, err := LoadMCP(path) + if err != nil { + return nil, err + } + + for _, server := range cfg.Servers { + if server.Name == name { + return &server, nil + } + } + + return nil, fmt.Errorf("MCP server with name '%s' not found", name) +} + +// MergeMCP merges an optional mcp.yaml config on top of a base config. +// Optional values take precedence; servers from both are combined and +// optional entries override base entries with the same name. +func MergeMCP(base *MCPConfig, optional *MCPConfig) *MCPConfig { + if optional == nil { + return base + } + + merged := &MCPConfig{ + Enabled: optional.Enabled || base.Enabled, + ConnectionTimeout: base.ConnectionTimeout, + DiscoveryTimeout: base.DiscoveryTimeout, + LivenessProbeEnabled: base.LivenessProbeEnabled, + LivenessProbeInterval: base.LivenessProbeInterval, + MaxRetries: base.MaxRetries, + Servers: make([]MCPServerEntry, 0), + } + + if optional.ConnectionTimeout > 0 { + merged.ConnectionTimeout = optional.ConnectionTimeout + } + if optional.DiscoveryTimeout > 0 { + merged.DiscoveryTimeout = optional.DiscoveryTimeout + } + if optional.LivenessProbeInterval > 0 { + merged.LivenessProbeInterval = optional.LivenessProbeInterval + } + if optional.MaxRetries > 0 { + merged.MaxRetries = optional.MaxRetries + } + if optional.LivenessProbeEnabled { + merged.LivenessProbeEnabled = true + } + + serverMap := make(map[string]MCPServerEntry) + for _, server := range base.Servers { + serverMap[server.Name] = server + } + for _, server := range optional.Servers { + serverMap[server.Name] = server + } + + for _, server := range serverMap { + merged.Servers = append(merged.Servers, server) + } + + return merged +} diff --git a/internal/services/mcp_config_test.go b/config/mcp_persistence_test.go similarity index 62% rename from internal/services/mcp_config_test.go rename to config/mcp_persistence_test.go index 29408c1c..2e8634ab 100644 --- a/internal/services/mcp_config_test.go +++ b/config/mcp_persistence_test.go @@ -1,4 +1,4 @@ -package services +package config_test import ( "os" @@ -8,19 +8,16 @@ import ( config "github.com/inference-gateway/cli/config" ) -func TestMCPConfigService_Load_NonExistentFile(t *testing.T) { +func TestLoadMCP_NonExistentFile(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "non-existent.yaml") - service := NewMCPConfigService(configPath) - cfg, err := service.Load() - + cfg, err := config.LoadMCP(configPath) if err != nil { - t.Fatalf("Load() should not error for non-existent file, got: %v", err) + t.Fatalf("LoadMCP() should not error for non-existent file, got: %v", err) } - if cfg == nil { - t.Fatal("Load() returned nil config") + t.Fatal("LoadMCP() returned nil config") } defaultCfg := config.DefaultMCPConfig() @@ -29,7 +26,7 @@ func TestMCPConfigService_Load_NonExistentFile(t *testing.T) { } } -func TestMCPConfigService_Load_ValidYAML(t *testing.T) { +func TestLoadMCP_ValidYAML(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "mcp.yaml") @@ -53,30 +50,24 @@ servers: - tool3 ` - err := os.WriteFile(configPath, []byte(yamlContent), 0644) - if err != nil { + if err := os.WriteFile(configPath, []byte(yamlContent), 0644); err != nil { t.Fatalf("Failed to write test config file: %v", err) } - service := NewMCPConfigService(configPath) - cfg, err := service.Load() - + cfg, err := config.LoadMCP(configPath) if err != nil { - t.Fatalf("Load() failed: %v", err) + t.Fatalf("LoadMCP() failed: %v", err) } if !cfg.Enabled { t.Error("Expected Enabled to be true") } - if cfg.ConnectionTimeout != 60 { t.Errorf("Expected ConnectionTimeout=60, got %d", cfg.ConnectionTimeout) } - if cfg.DiscoveryTimeout != 45 { t.Errorf("Expected DiscoveryTimeout=45, got %d", cfg.DiscoveryTimeout) } - if len(cfg.Servers) != 1 { t.Fatalf("Expected 1 server, got %d", len(cfg.Servers)) } @@ -90,36 +81,25 @@ servers: if server.GetURL() != expectedURL { t.Errorf("Expected server URL %q, got %q", expectedURL, server.GetURL()) } - if !server.Enabled { t.Error("Expected server to be enabled") } - if server.Timeout != 30 { t.Errorf("Expected server timeout=30, got %d", server.Timeout) } - if len(server.IncludeTools) != 2 { t.Errorf("Expected 2 include tools, got %d", len(server.IncludeTools)) } - if len(server.ExcludeTools) != 1 { t.Errorf("Expected 1 exclude tool, got %d", len(server.ExcludeTools)) } } -func TestMCPConfigService_Load_EnvironmentVariableExpansion(t *testing.T) { +func TestLoadMCP_EnvironmentVariableExpansion(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "mcp.yaml") - if err := os.Setenv("TEST_MCP_URL", "http://env-server:8080/sse"); err != nil { - t.Fatalf("Failed to set environment variable: %v", err) - } - defer func() { - if err := os.Unsetenv("TEST_MCP_URL"); err != nil { - t.Logf("Failed to unset environment variable: %v", err) - } - }() + t.Setenv("TEST_MCP_URL", "http://env-server:8080/sse") yamlContent := `enabled: true servers: @@ -132,18 +112,14 @@ servers: enabled: true ` - err := os.WriteFile(configPath, []byte(yamlContent), 0644) - if err != nil { + if err := os.WriteFile(configPath, []byte(yamlContent), 0644); err != nil { t.Fatalf("Failed to write test config file: %v", err) } - service := NewMCPConfigService(configPath) - cfg, err := service.Load() - + cfg, err := config.LoadMCP(configPath) if err != nil { - t.Fatalf("Load() failed: %v", err) + t.Fatalf("LoadMCP() failed: %v", err) } - if len(cfg.Servers) != 1 { t.Fatalf("Expected 1 server, got %d", len(cfg.Servers)) } @@ -155,12 +131,10 @@ servers: } } -func TestMCPConfigService_Save(t *testing.T) { +func TestSaveMCP(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "subdir", "mcp.yaml") - service := NewMCPConfigService(configPath) - cfg := &config.MCPConfig{ Enabled: true, ConnectionTimeout: 60, @@ -179,35 +153,29 @@ func TestMCPConfigService_Save(t *testing.T) { }, } - err := service.Save(cfg) - if err != nil { - t.Fatalf("Save() failed: %v", err) + if err := config.SaveMCP(configPath, cfg); err != nil { + t.Fatalf("SaveMCP() failed: %v", err) } - if _, err := os.Stat(configPath); os.IsNotExist(err) { t.Fatal("Config file was not created") } - loadedCfg, err := service.Load() + loadedCfg, err := config.LoadMCP(configPath) if err != nil { - t.Fatalf("Load() after Save() failed: %v", err) + t.Fatalf("LoadMCP() after SaveMCP() failed: %v", err) } - if loadedCfg.Enabled != cfg.Enabled { t.Errorf("Expected Enabled=%v, got %v", cfg.Enabled, loadedCfg.Enabled) } - if len(loadedCfg.Servers) != len(cfg.Servers) { t.Errorf("Expected %d servers, got %d", len(cfg.Servers), len(loadedCfg.Servers)) } } -func TestMCPConfigService_AddServer(t *testing.T) { +func TestAddMCPServer(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "mcp.yaml") - service := NewMCPConfigService(configPath) - newServer := config.MCPServerEntry{ Name: "new-server", Scheme: "http", @@ -218,31 +186,26 @@ func TestMCPConfigService_AddServer(t *testing.T) { Description: "New server", } - err := service.AddServer(newServer) - if err != nil { - t.Fatalf("AddServer() failed: %v", err) + if err := config.AddMCPServer(configPath, newServer); err != nil { + t.Fatalf("AddMCPServer() failed: %v", err) } - cfg, err := service.Load() + cfg, err := config.LoadMCP(configPath) if err != nil { - t.Fatalf("Load() after AddServer() failed: %v", err) + t.Fatalf("LoadMCP() after AddMCPServer() failed: %v", err) } - if len(cfg.Servers) != 1 { t.Fatalf("Expected 1 server, got %d", len(cfg.Servers)) } - if cfg.Servers[0].Name != newServer.Name { t.Errorf("Expected server name %q, got %q", newServer.Name, cfg.Servers[0].Name) } } -func TestMCPConfigService_AddServer_DuplicateName(t *testing.T) { +func TestAddMCPServer_DuplicateName(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "mcp.yaml") - service := NewMCPConfigService(configPath) - server := config.MCPServerEntry{ Name: "duplicate-server", Scheme: "http", @@ -252,23 +215,18 @@ func TestMCPConfigService_AddServer_DuplicateName(t *testing.T) { Enabled: true, } - err := service.AddServer(server) - if err != nil { - t.Fatalf("First AddServer() failed: %v", err) + if err := config.AddMCPServer(configPath, server); err != nil { + t.Fatalf("First AddMCPServer() failed: %v", err) } - - err = service.AddServer(server) - if err == nil { + if err := config.AddMCPServer(configPath, server); err == nil { t.Fatal("Expected error when adding duplicate server, got nil") } } -func TestMCPConfigService_UpdateServer(t *testing.T) { +func TestUpdateMCPServer(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "mcp.yaml") - service := NewMCPConfigService(configPath) - initialServer := config.MCPServerEntry{ Name: "update-server", Scheme: "http", @@ -277,10 +235,8 @@ func TestMCPConfigService_UpdateServer(t *testing.T) { Path: "/sse", Enabled: true, } - - err := service.AddServer(initialServer) - if err != nil { - t.Fatalf("AddServer() failed: %v", err) + if err := config.AddMCPServer(configPath, initialServer); err != nil { + t.Fatalf("AddMCPServer() failed: %v", err) } updatedServer := config.MCPServerEntry{ @@ -292,17 +248,14 @@ func TestMCPConfigService_UpdateServer(t *testing.T) { Enabled: false, Description: "Updated description", } - - err = service.UpdateServer(updatedServer) - if err != nil { - t.Fatalf("UpdateServer() failed: %v", err) + if err := config.UpdateMCPServer(configPath, updatedServer); err != nil { + t.Fatalf("UpdateMCPServer() failed: %v", err) } - cfg, err := service.Load() + cfg, err := config.LoadMCP(configPath) if err != nil { - t.Fatalf("Load() after UpdateServer() failed: %v", err) + t.Fatalf("LoadMCP() after UpdateMCPServer() failed: %v", err) } - if len(cfg.Servers) != 1 { t.Fatalf("Expected 1 server, got %d", len(cfg.Servers)) } @@ -311,22 +264,18 @@ func TestMCPConfigService_UpdateServer(t *testing.T) { if server.GetURL() != updatedServer.GetURL() { t.Errorf("Expected URL %q, got %q", updatedServer.GetURL(), server.GetURL()) } - if server.Enabled != updatedServer.Enabled { t.Errorf("Expected Enabled=%v, got %v", updatedServer.Enabled, server.Enabled) } - if server.Description != updatedServer.Description { t.Errorf("Expected Description %q, got %q", updatedServer.Description, server.Description) } } -func TestMCPConfigService_UpdateServer_NotFound(t *testing.T) { +func TestUpdateMCPServer_NotFound(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "mcp.yaml") - service := NewMCPConfigService(configPath) - server := config.MCPServerEntry{ Name: "nonexistent-server", Scheme: "http", @@ -335,129 +284,85 @@ func TestMCPConfigService_UpdateServer_NotFound(t *testing.T) { Path: "/sse", Enabled: true, } - - err := service.UpdateServer(server) - if err == nil { + if err := config.UpdateMCPServer(configPath, server); err == nil { t.Fatal("Expected error when updating non-existent server, got nil") } } -func TestMCPConfigService_RemoveServer(t *testing.T) { +func TestRemoveMCPServer(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "mcp.yaml") - service := NewMCPConfigService(configPath) - - server1 := config.MCPServerEntry{ - Name: "server1", - Scheme: "http", - Host: "localhost", - Ports: []string{"3000:8080"}, - Path: "/sse", - Enabled: true, - } - server2 := config.MCPServerEntry{ - Name: "server2", - Scheme: "http", - Host: "localhost", - Ports: []string{"4000:8080"}, - Path: "/sse", - Enabled: true, - } + server1 := config.MCPServerEntry{Name: "server1", Scheme: "http", Host: "localhost", Ports: []string{"3000:8080"}, Path: "/sse", Enabled: true} + server2 := config.MCPServerEntry{Name: "server2", Scheme: "http", Host: "localhost", Ports: []string{"4000:8080"}, Path: "/sse", Enabled: true} - if err := service.AddServer(server1); err != nil { + if err := config.AddMCPServer(configPath, server1); err != nil { t.Fatalf("Failed to add server1: %v", err) } - if err := service.AddServer(server2); err != nil { + if err := config.AddMCPServer(configPath, server2); err != nil { t.Fatalf("Failed to add server2: %v", err) } - err := service.RemoveServer("server1") - if err != nil { - t.Fatalf("RemoveServer() failed: %v", err) + if err := config.RemoveMCPServer(configPath, "server1"); err != nil { + t.Fatalf("RemoveMCPServer() failed: %v", err) } - cfg, err := service.Load() + cfg, err := config.LoadMCP(configPath) if err != nil { - t.Fatalf("Load() after RemoveServer() failed: %v", err) + t.Fatalf("LoadMCP() after RemoveMCPServer() failed: %v", err) } - if len(cfg.Servers) != 1 { t.Fatalf("Expected 1 server after removal, got %d", len(cfg.Servers)) } - if cfg.Servers[0].Name != "server2" { t.Errorf("Expected remaining server to be 'server2', got %q", cfg.Servers[0].Name) } } -func TestMCPConfigService_RemoveServer_NotFound(t *testing.T) { +func TestRemoveMCPServer_NotFound(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "mcp.yaml") - service := NewMCPConfigService(configPath) - - err := service.RemoveServer("nonexistent-server") - if err == nil { + if err := config.RemoveMCPServer(configPath, "nonexistent-server"); err == nil { t.Fatal("Expected error when removing non-existent server, got nil") } } -func TestMCPConfigService_ListServers(t *testing.T) { +func TestListMCPServers(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "mcp.yaml") - service := NewMCPConfigService(configPath) - - servers, err := service.ListServers() + servers, err := config.ListMCPServers(configPath) if err != nil { - t.Fatalf("ListServers() failed: %v", err) + t.Fatalf("ListMCPServers() failed: %v", err) } - if len(servers) != 0 { t.Errorf("Expected 0 servers initially, got %d", len(servers)) } - server1 := config.MCPServerEntry{ - Name: "server1", - Scheme: "http", - Host: "localhost", - Ports: []string{"3000:8080"}, - Path: "/sse", - Enabled: true, - } - server2 := config.MCPServerEntry{ - Name: "server2", - Scheme: "http", - Host: "localhost", - Ports: []string{"4000:8080"}, - Path: "/sse", - Enabled: false, - } + server1 := config.MCPServerEntry{Name: "server1", Scheme: "http", Host: "localhost", Ports: []string{"3000:8080"}, Path: "/sse", Enabled: true} + server2 := config.MCPServerEntry{Name: "server2", Scheme: "http", Host: "localhost", Ports: []string{"4000:8080"}, Path: "/sse", Enabled: false} - if err := service.AddServer(server1); err != nil { + if err := config.AddMCPServer(configPath, server1); err != nil { t.Fatalf("Failed to add server1: %v", err) } - if err := service.AddServer(server2); err != nil { + if err := config.AddMCPServer(configPath, server2); err != nil { t.Fatalf("Failed to add server2: %v", err) } - servers, err = service.ListServers() + servers, err = config.ListMCPServers(configPath) if err != nil { - t.Fatalf("ListServers() failed: %v", err) + t.Fatalf("ListMCPServers() failed: %v", err) } - if len(servers) != 2 { t.Fatalf("Expected 2 servers, got %d", len(servers)) } } -func TestMCPConfigService_GetServer(t *testing.T) { +func TestGetMCPServer(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "mcp.yaml") - service := NewMCPConfigService(configPath) - expectedServer := config.MCPServerEntry{ Name: "get-server", Scheme: "http", @@ -467,33 +372,27 @@ func TestMCPConfigService_GetServer(t *testing.T) { Enabled: true, Description: "Test server", } - - if err := service.AddServer(expectedServer); err != nil { + if err := config.AddMCPServer(configPath, expectedServer); err != nil { t.Fatalf("Failed to add server: %v", err) } - server, err := service.GetServer("get-server") + server, err := config.GetMCPServer(configPath, "get-server") if err != nil { - t.Fatalf("GetServer() failed: %v", err) + t.Fatalf("GetMCPServer() failed: %v", err) } - if server.Name != expectedServer.Name { t.Errorf("Expected name %q, got %q", expectedServer.Name, server.Name) } - if server.GetURL() != expectedServer.GetURL() { t.Errorf("Expected URL %q, got %q", expectedServer.GetURL(), server.GetURL()) } } -func TestMCPConfigService_GetServer_NotFound(t *testing.T) { +func TestGetMCPServer_NotFound(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "mcp.yaml") - service := NewMCPConfigService(configPath) - - _, err := service.GetServer("nonexistent-server") - if err == nil { + if _, err := config.GetMCPServer(configPath, "nonexistent-server"); err == nil { t.Fatal("Expected error when getting non-existent server, got nil") } } diff --git a/config/prompts.go b/config/prompts.go index 2a26a21e..fa5e7a36 100644 --- a/config/prompts.go +++ b/config/prompts.go @@ -1,10 +1,68 @@ package config +import ( + "bytes" + "fmt" + "os" + "path/filepath" + + yaml "gopkg.in/yaml.v3" +) + const ( PromptsFileName = "prompts.yaml" DefaultPromptsPath = ConfigDirName + "/" + PromptsFileName ) +// LoadPrompts reads prompts.yaml from disk. When the file is missing it +// returns the in-code defaults so callers can treat absence as "use +// defaults" without special-casing. +func LoadPrompts(path string) (*PromptsConfig, error) { + if _, err := os.Stat(path); os.IsNotExist(err) { + return DefaultPromptsConfig(), nil + } + + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read prompts config: %w", err) + } + + var cfg PromptsConfig + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("failed to parse prompts config: %w", err) + } + + return &cfg, nil +} + +// SavePrompts writes the prompts configuration to disk, creating any +// missing parent directories. +func SavePrompts(path string, cfg *PromptsConfig) error { + var buf bytes.Buffer + buf.WriteString("---\n") + + encoder := yaml.NewEncoder(&buf) + encoder.SetIndent(2) + + if err := encoder.Encode(cfg); err != nil { + return fmt.Errorf("failed to marshal prompts config: %w", err) + } + + if err := encoder.Close(); err != nil { + return fmt.Errorf("failed to close encoder: %w", err) + } + + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return fmt.Errorf("failed to create config directory: %w", err) + } + + if err := os.WriteFile(path, buf.Bytes(), 0644); err != nil { + return fmt.Errorf("failed to write prompts config: %w", err) + } + + return nil +} + // PromptsConfig holds every customisable LLM prompt the CLI ships with. // It mirrors the nested key structure those prompts had when they lived // under .infer/config.yaml so users can move existing values verbatim. diff --git a/internal/services/prompts_config_test.go b/config/prompts_persistence_test.go similarity index 61% rename from internal/services/prompts_config_test.go rename to config/prompts_persistence_test.go index dc243fe0..860c4a91 100644 --- a/internal/services/prompts_config_test.go +++ b/config/prompts_persistence_test.go @@ -1,4 +1,4 @@ -package services +package config_test import ( "os" @@ -9,18 +9,16 @@ import ( config "github.com/inference-gateway/cli/config" ) -func TestPromptsConfigService_Load_NonExistentFile(t *testing.T) { +func TestLoadPrompts_NonExistentFile(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "non-existent.yaml") - service := NewPromptsConfigService(configPath) - cfg, err := service.Load() - + cfg, err := config.LoadPrompts(configPath) if err != nil { - t.Fatalf("Load() should not error for non-existent file, got: %v", err) + t.Fatalf("LoadPrompts() should not error for non-existent file, got: %v", err) } if cfg == nil { - t.Fatal("Load() returned nil config") + t.Fatal("LoadPrompts() returned nil config") } if cfg.Agent.SystemPrompt == "" { t.Error("Default prompts config should populate agent.system_prompt") @@ -33,7 +31,7 @@ func TestPromptsConfigService_Load_NonExistentFile(t *testing.T) { } } -func TestPromptsConfigService_Load_ValidYAML(t *testing.T) { +func TestLoadPrompts_ValidYAML(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "prompts.yaml") @@ -52,10 +50,9 @@ init: t.Fatalf("Failed to write test config file: %v", err) } - service := NewPromptsConfigService(configPath) - cfg, err := service.Load() + cfg, err := config.LoadPrompts(configPath) if err != nil { - t.Fatalf("Load() failed: %v", err) + t.Fatalf("LoadPrompts() failed: %v", err) } if cfg.Agent.SystemPrompt != "custom agent prompt" { t.Errorf("Expected custom system_prompt, got %q", cfg.Agent.SystemPrompt) @@ -71,7 +68,7 @@ init: } } -func TestPromptsConfigService_Load_PartialYAML(t *testing.T) { +func TestLoadPrompts_PartialYAML(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "prompts.yaml") @@ -84,16 +81,15 @@ agent: t.Fatalf("Failed to write test config file: %v", err) } - service := NewPromptsConfigService(configPath) - cfg, err := service.Load() + cfg, err := config.LoadPrompts(configPath) if err != nil { - t.Fatalf("Load() failed: %v", err) + t.Fatalf("LoadPrompts() failed: %v", err) } if cfg.Agent.SystemPrompt != "only this field is set" { t.Errorf("Expected partial value, got %q", cfg.Agent.SystemPrompt) } - // Other prompt fields stay zero — getConfigFromViper's overlay is what - // fills them back in from defaults at runtime. + // Other prompt fields stay zero — the runtime overlay (cmd applies it + // on top of LoadPrompts) is what fills them back in from defaults. if cfg.Agent.SystemPromptPlan != "" { t.Errorf("Expected unset plan prompt to be empty, got %q", cfg.Agent.SystemPromptPlan) } @@ -102,10 +98,9 @@ agent: } } -func TestPromptsConfigService_Save_RoundTrip(t *testing.T) { +func TestSavePrompts_RoundTrip(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "prompts.yaml") - service := NewPromptsConfigService(configPath) original := &config.PromptsConfig{ Agent: config.PromptsAgentConfig{ @@ -121,13 +116,13 @@ func TestPromptsConfigService_Save_RoundTrip(t *testing.T) { }, } - if err := service.Save(original); err != nil { - t.Fatalf("Save() failed: %v", err) + if err := config.SavePrompts(configPath, original); err != nil { + t.Fatalf("SavePrompts() failed: %v", err) } - loaded, err := service.Load() + loaded, err := config.LoadPrompts(configPath) if err != nil { - t.Fatalf("Load() after save failed: %v", err) + t.Fatalf("LoadPrompts() after save failed: %v", err) } if loaded.Agent.SystemPrompt != original.Agent.SystemPrompt { t.Errorf("agent.system_prompt not preserved, got %q", loaded.Agent.SystemPrompt) @@ -140,26 +135,24 @@ func TestPromptsConfigService_Save_RoundTrip(t *testing.T) { } } -func TestPromptsConfigService_Save_CreatesParentDirectory(t *testing.T) { +func TestSavePrompts_CreatesParentDirectory(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "nested", "deep", "prompts.yaml") - service := NewPromptsConfigService(configPath) - if err := service.Save(DefaultPromptsConfig()); err != nil { - t.Fatalf("Save() failed to create nested dirs: %v", err) + if err := config.SavePrompts(configPath, config.DefaultPromptsConfig()); err != nil { + t.Fatalf("SavePrompts() failed to create nested dirs: %v", err) } if _, err := os.Stat(configPath); err != nil { t.Fatalf("File not created at nested path: %v", err) } } -func TestPromptsConfigService_Save_StartsWithYAMLDocumentMarker(t *testing.T) { +func TestSavePrompts_StartsWithYAMLDocumentMarker(t *testing.T) { tempDir := t.TempDir() configPath := filepath.Join(tempDir, "prompts.yaml") - service := NewPromptsConfigService(configPath) - if err := service.Save(DefaultPromptsConfig()); err != nil { - t.Fatalf("Save() failed: %v", err) + if err := config.SavePrompts(configPath, config.DefaultPromptsConfig()); err != nil { + t.Fatalf("SavePrompts() failed: %v", err) } data, err := os.ReadFile(configPath) @@ -170,31 +163,3 @@ func TestPromptsConfigService_Save_StartsWithYAMLDocumentMarker(t *testing.T) { t.Errorf("Saved file should start with YAML document marker, got: %q", string(data[:min(20, len(data))])) } } - -func TestDefaultPromptsConfig(t *testing.T) { - cfg := DefaultPromptsConfig() - if cfg == nil { - t.Fatal("DefaultPromptsConfig returned nil") - } - if cfg.Agent.SystemPrompt == "" { - t.Error("agent.system_prompt should be non-empty") - } - if cfg.Agent.SystemPromptPlan == "" { - t.Error("agent.system_prompt_plan should be non-empty") - } - if cfg.Agent.SystemPromptRemote == "" { - t.Error("agent.system_prompt_remote should be non-empty") - } - if cfg.Agent.SystemReminders.ReminderText == "" { - t.Error("agent.system_reminders.reminder_text should be non-empty") - } - if cfg.Git.CommitMessage.SystemPrompt == "" { - t.Error("git.commit_message.system_prompt should be non-empty") - } - if cfg.Conversation.TitleGeneration.SystemPrompt == "" { - t.Error("conversation.title_generation.system_prompt should be non-empty") - } - if cfg.Init.Prompt == "" { - t.Error("init.prompt should be non-empty") - } -} diff --git a/internal/agent/agent.go b/internal/agent/agent.go index aab8d2aa..14c7beeb 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -9,6 +9,7 @@ import ( sdk "github.com/inference-gateway/sdk" + config "github.com/inference-gateway/cli/config" constants "github.com/inference-gateway/cli/internal/constants" domain "github.com/inference-gateway/cli/internal/domain" logger "github.com/inference-gateway/cli/internal/logger" @@ -19,7 +20,7 @@ import ( type AgentServiceImpl struct { client domain.SDKClient toolService domain.ToolService - config domain.ConfigService + config *config.Config conversationRepo domain.ConversationRepository a2aAgentService domain.A2AAgentService messageQueue domain.MessageQueue @@ -235,7 +236,7 @@ func (p *eventPublisher) publishToolExecutionCompleted(results []domain.Conversa func NewAgent( client domain.SDKClient, toolService domain.ToolService, - config domain.ConfigService, + cfg *config.Config, conversationRepo domain.ConversationRepository, a2aAgentService domain.A2AAgentService, messageQueue domain.MessageQueue, @@ -246,18 +247,18 @@ func NewAgent( ) *AgentServiceImpl { tokenizer := services.NewTokenizerService(services.DefaultTokenizerConfig()) - approvalPolicy := services.NewStandardApprovalPolicy(config.GetConfig(), stateManager) + approvalPolicy := services.NewStandardApprovalPolicy(cfg, stateManager) return &AgentServiceImpl{ client: client, toolService: toolService, - config: config, + config: cfg, conversationRepo: conversationRepo, a2aAgentService: a2aAgentService, messageQueue: messageQueue, stateManager: stateManager, timeoutSeconds: timeoutSeconds, - maxTokens: config.GetAgentConfig().MaxTokens, + maxTokens: cfg.GetAgentConfig().MaxTokens, optimizer: optimizer, tokenizer: tokenizer, approvalPolicy: approvalPolicy, diff --git a/internal/agent/agent_helpers_test.go b/internal/agent/agent_helpers_test.go index ee3dbf74..4dd4f403 100644 --- a/internal/agent/agent_helpers_test.go +++ b/internal/agent/agent_helpers_test.go @@ -89,15 +89,16 @@ func TestShouldInjectSystemReminder(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - fakeConfig := &domainmocks.FakeConfigService{} - fakeConfig.GetAgentConfigReturns(&config.AgentConfig{ - SystemReminders: config.SystemRemindersConfig{ - Enabled: tt.enabled, - Interval: tt.interval, + cfg := &config.Config{ + Agent: config.AgentConfig{ + SystemReminders: config.SystemRemindersConfig{ + Enabled: tt.enabled, + Interval: tt.interval, + }, }, - }) + } - agentService := &AgentServiceImpl{config: fakeConfig} + agentService := &AgentServiceImpl{config: cfg} result := agentService.shouldInjectSystemReminder(tt.turns) assert.Equal(t, tt.expected, result) @@ -121,17 +122,18 @@ func TestGetSystemPromptForMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - fakeConfig := &domainmocks.FakeConfigService{} - fakeConfig.GetAgentConfigReturns(&config.AgentConfig{ - SystemPrompt: tt.systemPrompt, - SystemPromptPlan: tt.planPrompt, - }) + cfg := &config.Config{ + Agent: config.AgentConfig{ + SystemPrompt: tt.systemPrompt, + SystemPromptPlan: tt.planPrompt, + }, + } fakeStateManager := &domainmocks.FakeStateManager{} fakeStateManager.GetAgentModeReturns(tt.mode) agentService := &AgentServiceImpl{ - config: fakeConfig, + config: cfg, stateManager: fakeStateManager, } diff --git a/internal/agent/agent_test.go b/internal/agent/agent_test.go index c5e8450c..14788ce3 100644 --- a/internal/agent/agent_test.go +++ b/internal/agent/agent_test.go @@ -245,19 +245,20 @@ func TestAgentServiceImpl_StreamingDeltaAccumulation(t *testing.T) { func TestNewAgentService(t *testing.T) { fakeToolService := &domainmocks.FakeToolService{} - fakeConfig := &domainmocks.FakeConfigService{} fakeConversationRepo := &domainmocks.FakeConversationRepository{} fakeStateManager := &domainmocks.FakeStateManager{} - fakeConfig.GetAgentConfigReturns(&config.AgentConfig{ - MaxTokens: 4096, - MaxTurns: 10, - }) + cfg := &config.Config{ + Agent: config.AgentConfig{ + MaxTokens: 4096, + MaxTurns: 10, + }, + } agentService := NewAgent( nil, fakeToolService, - fakeConfig, + cfg, fakeConversationRepo, nil, nil, @@ -269,7 +270,7 @@ func TestNewAgentService(t *testing.T) { assert.NotNil(t, agentService) assert.Equal(t, fakeToolService, agentService.toolService) - assert.Equal(t, fakeConfig, agentService.config) + assert.Equal(t, cfg, agentService.config) assert.Equal(t, fakeConversationRepo, agentService.conversationRepo) assert.Equal(t, fakeStateManager, agentService.stateManager) assert.Equal(t, 120, agentService.timeoutSeconds) @@ -419,13 +420,14 @@ func TestAgentServiceImpl_ShouldInjectSystemReminder(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - fakeConfig := &domainmocks.FakeConfigService{} - fakeConfig.GetAgentConfigReturns(&config.AgentConfig{ - SystemReminders: tt.remindersConfig, - }) + cfg := &config.Config{ + Agent: config.AgentConfig{ + SystemReminders: tt.remindersConfig, + }, + } agentService := &AgentServiceImpl{ - config: fakeConfig, + config: cfg, } result := agentService.shouldInjectSystemReminder(tt.turns) @@ -458,15 +460,16 @@ func TestAgentServiceImpl_GetSystemReminderMessage(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - fakeConfig := &domainmocks.FakeConfigService{} - fakeConfig.GetAgentConfigReturns(&config.AgentConfig{ - SystemReminders: config.SystemRemindersConfig{ - ReminderText: tt.reminderText, + cfg := &config.Config{ + Agent: config.AgentConfig{ + SystemReminders: config.SystemRemindersConfig{ + ReminderText: tt.reminderText, + }, }, - }) + } agentService := &AgentServiceImpl{ - config: fakeConfig, + config: cfg, } message := agentService.getSystemReminderMessage() @@ -532,12 +535,17 @@ func TestAgentServiceImpl_BuildSandboxInfo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - fakeConfig := &domainmocks.FakeConfigService{} - fakeConfig.GetSandboxDirectoriesReturns(tt.sandboxDirs) - fakeConfig.GetProtectedPathsReturns(tt.protectedPaths) + cfg := &config.Config{ + Tools: config.ToolsConfig{ + Sandbox: config.SandboxConfig{ + Directories: tt.sandboxDirs, + ProtectedPaths: tt.protectedPaths, + }, + }, + } agentService := &AgentServiceImpl{ - config: fakeConfig, + config: cfg, } result := agentService.buildSandboxInfo() @@ -665,10 +673,7 @@ func TestAgentServiceImpl_ShouldRequireApproval(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - fakeConfig := &domainmocks.FakeConfigService{} - fakeConfig.IsApprovalRequiredReturns(tt.isApprovalRequired) - fakeConfig.IsBashCommandWhitelistedReturns(tt.isBashCommandWhitelisted) - fakeConfig.GetConfigReturns(&config.Config{ + cfg := &config.Config{ Tools: config.ToolsConfig{ Safety: config.SafetyConfig{ RequireApproval: tt.isApprovalRequired, @@ -684,12 +689,12 @@ func TestAgentServiceImpl_ShouldRequireApproval(t *testing.T) { }, }, }, - }) + } fakeStateManager := &domainmocks.FakeStateManager{} fakeStateManager.GetAgentModeReturns(tt.agentMode) - approvalPolicy := services.NewStandardApprovalPolicy(fakeConfig.GetConfig(), fakeStateManager) + approvalPolicy := services.NewStandardApprovalPolicy(cfg, fakeStateManager) result := approvalPolicy.ShouldRequireApproval(context.Background(), tt.toolCall, tt.isChatMode) @@ -1172,14 +1177,15 @@ func TestAgentServiceImpl_GetSystemPromptForMode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - fakeConfig := &domainmocks.FakeConfigService{} - fakeConfig.GetAgentConfigReturns(&config.AgentConfig{ - SystemPrompt: tt.systemPrompt, - SystemPromptPlan: tt.planPrompt, - }) + cfg := &config.Config{ + Agent: config.AgentConfig{ + SystemPrompt: tt.systemPrompt, + SystemPromptPlan: tt.planPrompt, + }, + } agentService := &AgentServiceImpl{ - config: fakeConfig, + config: cfg, } if !tt.nilStateManager { @@ -1196,16 +1202,21 @@ func TestAgentServiceImpl_GetSystemPromptForMode(t *testing.T) { } func TestAgentServiceImpl_AddSystemPrompt(t *testing.T) { - fakeConfig := &domainmocks.FakeConfigService{} - fakeConfig.GetAgentConfigReturns(&config.AgentConfig{ - SystemPrompt: "You are a helpful assistant.", - SystemPromptWithDefaults: true, - }) - fakeConfig.GetSandboxDirectoriesReturns([]string{"/home/user"}) - fakeConfig.GetProtectedPathsReturns([]string{"/etc"}) + cfg := &config.Config{ + Agent: config.AgentConfig{ + SystemPrompt: "You are a helpful assistant.", + SystemPromptWithDefaults: true, + }, + Tools: config.ToolsConfig{ + Sandbox: config.SandboxConfig{ + Directories: []string{"/home/user"}, + ProtectedPaths: []string{"/etc"}, + }, + }, + } agentService := &AgentServiceImpl{ - config: fakeConfig, + config: cfg, } inputMessages := []sdk.Message{ @@ -1228,15 +1239,14 @@ func TestAgentServiceImpl_AddSystemPrompt(t *testing.T) { } func TestAgentServiceImpl_AddSystemPrompt_EmptyPrompt(t *testing.T) { - fakeConfig := &domainmocks.FakeConfigService{} - fakeConfig.GetAgentConfigReturns(&config.AgentConfig{ - SystemPrompt: "", - }) - fakeConfig.GetSandboxDirectoriesReturns([]string{}) - fakeConfig.GetProtectedPathsReturns([]string{}) + cfg := &config.Config{ + Agent: config.AgentConfig{ + SystemPrompt: "", + }, + } agentService := &AgentServiceImpl{ - config: fakeConfig, + config: cfg, } inputMessages := []sdk.Message{ diff --git a/internal/agent/tools/activate_app.go b/internal/agent/tools/activate_app.go index 350dbd95..7731416a 100644 --- a/internal/agent/tools/activate_app.go +++ b/internal/agent/tools/activate_app.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + config "github.com/inference-gateway/cli/config" display "github.com/inference-gateway/cli/internal/display" domain "github.com/inference-gateway/cli/internal/domain" logger "github.com/inference-gateway/cli/internal/logger" @@ -13,13 +14,13 @@ import ( // ActivateAppTool switches focus to a specific application type ActivateAppTool struct { - config domain.ConfigService + config *config.Config } // NewActivateAppTool creates a new ActivateApp tool -func NewActivateAppTool(config domain.ConfigService) *ActivateAppTool { +func NewActivateAppTool(cfg *config.Config) *ActivateAppTool { return &ActivateAppTool{ - config: config, + config: cfg, } } @@ -115,7 +116,7 @@ func (t *ActivateAppTool) Execute(ctx context.Context, args map[string]any) (*do // IsEnabled returns whether the tool is enabled func (t *ActivateAppTool) IsEnabled() bool { - return t.config.GetConfig().ComputerUse.Enabled + return t.config.ComputerUse.Enabled } // FormatPreview formats the result for display preview diff --git a/internal/agent/tools/get_focused_app.go b/internal/agent/tools/get_focused_app.go index c38493a9..61976f91 100644 --- a/internal/agent/tools/get_focused_app.go +++ b/internal/agent/tools/get_focused_app.go @@ -4,6 +4,7 @@ import ( "context" "fmt" + config "github.com/inference-gateway/cli/config" display "github.com/inference-gateway/cli/internal/display" domain "github.com/inference-gateway/cli/internal/domain" logger "github.com/inference-gateway/cli/internal/logger" @@ -12,13 +13,13 @@ import ( // GetFocusedAppTool gets the currently focused application type GetFocusedAppTool struct { - config domain.ConfigService + config *config.Config } // NewGetFocusedAppTool creates a new GetFocusedApp tool -func NewGetFocusedAppTool(config domain.ConfigService) *GetFocusedAppTool { +func NewGetFocusedAppTool(cfg *config.Config) *GetFocusedAppTool { return &GetFocusedAppTool{ - config: config, + config: cfg, } } @@ -91,7 +92,7 @@ func (t *GetFocusedAppTool) Execute(ctx context.Context, args map[string]any) (* // IsEnabled returns whether the tool is enabled func (t *GetFocusedAppTool) IsEnabled() bool { - return t.config.GetConfig().ComputerUse.Enabled + return t.config.ComputerUse.Enabled } // FormatPreview formats the result for display preview diff --git a/internal/agent/tools/registry.go b/internal/agent/tools/registry.go index fb36e16b..8cf8910e 100644 --- a/internal/agent/tools/registry.go +++ b/internal/agent/tools/registry.go @@ -6,6 +6,7 @@ import ( "strings" "time" + config "github.com/inference-gateway/cli/config" display "github.com/inference-gateway/cli/internal/display" domain "github.com/inference-gateway/cli/internal/domain" logger "github.com/inference-gateway/cli/internal/logger" @@ -19,7 +20,7 @@ import ( // Registry manages all available tools type Registry struct { - config domain.ConfigService + config *config.Config tools map[string]domain.Tool readToolUsed bool taskTracker domain.A2ATaskTracker @@ -34,7 +35,7 @@ type Registry struct { // taskTracker must be provided by the caller (typically the container, which // constructs the unified BackgroundTaskRegistry and passes its A2A view in // here so all tools observe the same tracker the agent's wait loop does). -func NewRegistry(cfg domain.ConfigService, imageService domain.ImageService, mcpManager domain.MCPManager, shellService domain.BackgroundShellService, stateManager domain.StateManager, screenshotProvider domain.ScreenshotProvider, taskTracker domain.A2ATaskTracker) *Registry { +func NewRegistry(cfg *config.Config, imageService domain.ImageService, mcpManager domain.MCPManager, shellService domain.BackgroundShellService, stateManager domain.StateManager, screenshotProvider domain.ScreenshotProvider, taskTracker domain.A2ATaskTracker) *Registry { if taskTracker == nil { taskTracker = utils.NewA2ATaskTracker() } @@ -58,7 +59,7 @@ func NewRegistry(cfg domain.ConfigService, imageService domain.ImageService, mcp func (r *Registry) SetScreenshotProvider(provider domain.ScreenshotProvider) { r.screenshotProvider = provider - cfg := r.config.GetConfig() + cfg := r.config if cfg.ComputerUse.Enabled { displayProvider, err := display.DetectDisplay() if err == nil { @@ -70,7 +71,7 @@ func (r *Registry) SetScreenshotProvider(provider domain.ScreenshotProvider) { // registerTools initializes and registers all available tools func (r *Registry) registerTools() { - cfg := r.config.GetConfig() + cfg := r.config r.tools["Bash"] = NewBashTool(cfg, r.shellService) @@ -134,7 +135,7 @@ func (r *Registry) registerTools() { // registerMCPTools discovers and registers tools from enabled MCP servers func (r *Registry) registerMCPTools() { - cfg := r.config.GetConfig() + cfg := r.config ctx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.MCP.DiscoveryTimeout)*time.Second) defer cancel() @@ -230,7 +231,7 @@ func (r *Registry) RegisterMCPServerTools(serverName string, tools []domain.MCPD } toolCount := 0 - cfg := r.config.GetConfig() + cfg := r.config for _, tool := range tools { fullToolName := fmt.Sprintf("MCP_%s_%s", serverName, tool.Name) @@ -281,7 +282,7 @@ func (r *Registry) UnregisterMCPServerTools(serverName string) int { // SetScreenshotServer dynamically registers the GetLatestScreenshot tool // This should be called after the screenshot server is started func (r *Registry) SetScreenshotServer(provider domain.ScreenshotProvider) { - cfg := r.config.GetConfig() + cfg := r.config if !cfg.ComputerUse.Enabled || !cfg.ComputerUse.Screenshot.StreamingEnabled { logger.Debug("Screenshot streaming not enabled, skipping GetLatestScreenshot tool registration") return diff --git a/internal/agent/tools/registry_test.go b/internal/agent/tools/registry_test.go index 609af3e0..ebad8c94 100644 --- a/internal/agent/tools/registry_test.go +++ b/internal/agent/tools/registry_test.go @@ -10,63 +10,6 @@ import ( sdk "github.com/inference-gateway/sdk" ) -// testConfigService is a minimal mock implementation of domain.ConfigService for testing -type testConfigService struct { - config *config.Config -} - -func newTestConfigService(cfg *config.Config) domain.ConfigService { - return &testConfigService{config: cfg} -} - -func (t *testConfigService) GetConfig() *config.Config { - return t.config -} - -func (t *testConfigService) Reload() (*config.Config, error) { - return t.config, nil -} - -func (t *testConfigService) SetValue(key, value string) error { - return nil -} - -func (t *testConfigService) IsApprovalRequired(toolName string) bool { - return t.config.IsApprovalRequired(toolName) -} - -func (t *testConfigService) IsBashCommandWhitelisted(command string) bool { - return t.config.IsBashCommandWhitelisted(command) -} - -func (t *testConfigService) GetOutputDirectory() string { - return t.config.GetOutputDirectory() -} - -func (t *testConfigService) GetGatewayURL() string { - return t.config.Gateway.URL -} - -func (t *testConfigService) GetAPIKey() string { - return t.config.Gateway.APIKey -} - -func (t *testConfigService) GetTimeout() int { - return t.config.Gateway.Timeout -} - -func (t *testConfigService) GetAgentConfig() *config.AgentConfig { - return t.config.GetAgentConfig() -} - -func (t *testConfigService) GetSandboxDirectories() []string { - return t.config.GetSandboxDirectories() -} - -func (t *testConfigService) GetProtectedPaths() []string { - return t.config.GetProtectedPaths() -} - func createTestRegistry() *Registry { cfg := &config.Config{ Tools: config.ToolsConfig{ @@ -90,7 +33,7 @@ func createTestRegistry() *Registry { }, } - return NewRegistry(newTestConfigService(cfg), nil, nil, nil, nil, nil, nil) + return NewRegistry(cfg, nil, nil, nil, nil, nil, nil) } func TestRegistry_GetTool_Unknown(t *testing.T) { @@ -122,7 +65,7 @@ func TestRegistry_DisabledTools(t *testing.T) { }, } - registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil, nil, nil, nil) + registry := NewRegistry(cfg, nil, nil, nil, nil, nil, nil) tools := registry.ListAvailableTools() @@ -174,14 +117,13 @@ func TestRegistry_NewRegistry(t *testing.T) { }, } - configService := newTestConfigService(cfg) - registry := NewRegistry(configService, nil, nil, nil, nil, nil, nil) + registry := NewRegistry(cfg, nil, nil, nil, nil, nil, nil) if registry == nil { t.Fatal("Expected non-nil registry") } - if registry.config.GetConfig() != cfg { + if registry.config != cfg { t.Error("Expected config to be set correctly") } @@ -210,7 +152,7 @@ func TestRegistry_GetTool(t *testing.T) { }, } - registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil, nil, nil, nil) + registry := NewRegistry(cfg, nil, nil, nil, nil, nil, nil) tests := []struct { name string @@ -361,7 +303,7 @@ func TestRegistry_ListAvailableTools(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - registry := NewRegistry(newTestConfigService(tt.config), nil, nil, nil, nil, nil, nil) + registry := NewRegistry(tt.config, nil, nil, nil, nil, nil, nil) tools := registry.ListAvailableTools() if len(tools) < tt.expectedMin || len(tools) > tt.expectedMax { @@ -414,7 +356,7 @@ func TestRegistry_GetToolDefinitions(t *testing.T) { }, } - registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil, nil, nil, nil) + registry := NewRegistry(cfg, nil, nil, nil, nil, nil, nil) definitions := registry.GetToolDefinitions() if len(definitions) < 5 || len(definitions) > 15 { @@ -464,7 +406,7 @@ func TestRegistry_IsToolEnabled(t *testing.T) { }, } - registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil, nil, nil, nil) + registry := NewRegistry(cfg, nil, nil, nil, nil, nil, nil) tests := []struct { name string @@ -517,7 +459,7 @@ func TestRegistry_WithMockedTool(t *testing.T) { }, } - registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil, nil, nil, nil) + registry := NewRegistry(cfg, nil, nil, nil, nil, nil, nil) fakeTool := &mocks.FakeTool{} fakeTool.IsEnabledReturns(true) diff --git a/internal/app/chat.go b/internal/app/chat.go index b4c5d90c..0fc81c82 100644 --- a/internal/app/chat.go +++ b/internal/app/chat.go @@ -31,7 +31,7 @@ import ( // ChatApplication represents the main application model using state management type ChatApplication struct { // Dependencies - configService domain.ConfigService + config *config.Config agentService domain.AgentService conversationRepo domain.ConversationRepository conversationOptimizer domain.ConversationOptimizer @@ -98,29 +98,28 @@ type ChatApplication struct { // nolint: funlen // NewChatApplication creates a new chat application func NewChatApplication( + cfg *config.Config, models []string, defaultModel string, + versionInfo domain.VersionInfo, + agentManager domain.AgentManager, agentService domain.AgentService, - conversationRepo domain.ConversationRepository, + backgroundTaskService domain.BackgroundTaskService, conversationOptimizer domain.ConversationOptimizer, - sessionRolloverManager *services.SessionRolloverManager, - modelService domain.ModelService, - configService domain.ConfigService, - toolService domain.ToolService, + conversationRepo domain.ConversationRepository, fileService domain.FileService, imageService domain.ImageService, + mcpManager domain.MCPManager, + messageQueue domain.MessageQueue, + modelService domain.ModelService, pricingService domain.PricingService, - shortcutRegistry *shortcuts.Registry, + sessionRolloverManager *services.SessionRolloverManager, stateManager domain.StateManager, - messageQueue domain.MessageQueue, + taskRetentionService domain.TaskRetentionService, themeService domain.ThemeService, + toolService domain.ToolService, + shortcutRegistry *shortcuts.Registry, toolRegistry *tools.Registry, - mcpManager domain.MCPManager, - taskRetentionService domain.TaskRetentionService, - backgroundTaskService domain.BackgroundTaskService, - agentManager domain.AgentManager, - configPath string, - versionInfo domain.VersionInfo, ) *ChatApplication { initialView := domain.ViewStateModelSelection if defaultModel != "" { @@ -133,7 +132,7 @@ func NewChatApplication( conversationOptimizer: conversationOptimizer, sessionRolloverManager: sessionRolloverManager, modelService: modelService, - configService: configService, + config: cfg, toolService: toolService, fileService: fileService, imageService: imageService, @@ -160,26 +159,23 @@ func NewChatApplication( app.conversationView = factory.CreateConversationView(app.themeService) toolFormatterService := services.NewToolFormatterService(app.toolRegistry, styleProvider) + configDir := cfg.GetConfigDir() + app.configDir = configDir + if cv, ok := app.conversationView.(*components.ConversationView); ok { cv.SetToolFormatter(toolFormatterService) - cv.SetConfigPath(configPath) + cv.SetConfigPath(filepath.Join(configDir, config.ConfigFileName)) cv.SetVersionInfo(versionInfo) cv.SetToolCallRenderer(app.toolCallRenderer) cv.SetStateManager(app.stateManager) } - configDir := ".infer" - if configPath != "" { - configDir = filepath.Dir(configPath) - } - app.configDir = configDir - app.inputView = factory.CreateInputViewWithConfigDir(app.modelService, configDir) if iv, ok := app.inputView.(*components.InputView); ok { iv.SetThemeService(app.themeService) iv.SetStateManager(app.stateManager) iv.SetImageService(app.imageService) - iv.SetConfigService(app.configService.GetConfig()) + iv.SetConfig(app.config) iv.SetConversationRepo(app.conversationRepo) } @@ -193,7 +189,7 @@ func NewChatApplication( isb.SetModelService(app.modelService) isb.SetThemeService(app.themeService) isb.SetStateManager(app.stateManager) - isb.SetConfigService(app.configService.GetConfig()) + isb.SetConfig(app.config) isb.SetConversationRepo(app.conversationRepo) isb.SetToolService(app.toolService) isb.SetTokenEstimator(services.NewTokenizerService(services.DefaultTokenizerConfig())) @@ -214,7 +210,7 @@ func NewChatApplication( app.applicationViewRenderer = components.NewApplicationViewRenderer(styleProvider) app.fileSelectionHandler = components.NewFileSelectionHandler(styleProvider) - app.keyBindingManager = keybinding.NewKeyBindingManager(app, app.configService.GetConfig()) + app.keyBindingManager = keybinding.NewKeyBindingManager(app, app.config) app.updateHelpBarShortcuts() keyHintFormatter := app.keyBindingManager.GetHintFormatter() @@ -226,7 +222,7 @@ func NewChatApplication( } app.toolCallRenderer.SetKeyHintFormatter(keyHintFormatter) - app.modelSelector = components.NewModelSelector(models, app.modelService, app.pricingService, app.configService, styleProvider) + app.modelSelector = components.NewModelSelector(models, app.modelService, app.pricingService, app.config, styleProvider) app.themeSelector = components.NewThemeSelector(app.themeService, styleProvider) app.initGithubActionView = components.NewInitGithubActionView(styleProvider) @@ -267,7 +263,6 @@ func NewChatApplication( app.conversationOptimizer, app.sessionRolloverManager, app.modelService, - app.configService, app.toolService, app.fileService, app.imageService, @@ -278,7 +273,7 @@ func NewChatApplication( app.backgroundTaskService, app.toolRegistry.GetBackgroundShellService(), agentManager, - app.configService.GetConfig(), + app.config, ) app.messageHistoryHandler = handlers.NewMessageHistoryHandler( @@ -1068,7 +1063,7 @@ func (app *ChatApplication) handleA2ATaskManagementView(msg tea.Msg) []tea.Cmd { var cmds []tea.Cmd if app.taskManager == nil { - if !app.configService.GetConfig().A2A.Enabled { + if !app.config.A2A.Enabled { cmds = append(cmds, func() tea.Msg { return domain.ShowErrorEvent{ Error: "Task management requires A2A to be enabled in configuration.", @@ -1197,7 +1192,7 @@ func (app *ChatApplication) updateAllComponentsWithNewTheme() { } styleProvider := styles.NewProvider(app.themeService) - app.modelSelector = components.NewModelSelector(app.availableModels, app.modelService, app.pricingService, app.configService, styleProvider) + app.modelSelector = components.NewModelSelector(app.availableModels, app.modelService, app.pricingService, app.config, styleProvider) } func (app *ChatApplication) renderThemeSelection() string { @@ -1618,7 +1613,7 @@ func (app *ChatApplication) GetImageService() domain.ImageService { // GetConfig returns the configuration for keybinding context func (app *ChatApplication) GetConfig() *config.Config { - return app.configService.GetConfig() + return app.config } // GetConfigDir returns the configuration directory path diff --git a/internal/container/container.go b/internal/container/container.go index f3f68bdd..a196e81d 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -7,7 +7,6 @@ import ( "strings" "time" - viper "github.com/spf13/viper" zap "go.uber.org/zap" sdk "github.com/inference-gateway/sdk" @@ -16,12 +15,10 @@ import ( agent "github.com/inference-gateway/cli/internal/agent" tools "github.com/inference-gateway/cli/internal/agent/tools" domain "github.com/inference-gateway/cli/internal/domain" - filewriterdomain "github.com/inference-gateway/cli/internal/domain/filewriter" adapters "github.com/inference-gateway/cli/internal/infra/adapters" storage "github.com/inference-gateway/cli/internal/infra/storage" logger "github.com/inference-gateway/cli/internal/logger" services "github.com/inference-gateway/cli/internal/services" - filewriterservice "github.com/inference-gateway/cli/internal/services/filewriter" shortcuts "github.com/inference-gateway/cli/internal/shortcuts" styles "github.com/inference-gateway/cli/internal/ui/styles" ) @@ -38,16 +35,13 @@ type ServiceContainer struct { log *zap.Logger // Configuration - viper *viper.Viper - config *config.Config - configService *services.ConfigService + config *config.Config // Domain services conversationRepo domain.ConversationRepository conversationOptimizer domain.ConversationOptimizer sessionRolloverManager *services.SessionRolloverManager modelService domain.ModelService - chatService domain.ChatService agent domain.AgentService toolService domain.ToolService fileService domain.FileService @@ -72,7 +66,6 @@ type ServiceContainer struct { backgroundJobManager *services.BackgroundJobManager backgroundShellService *services.BackgroundShellService storage storage.ConversationStorage - agentsConfigService *services.AgentsConfigService // UI components themeService domain.ThemeService @@ -83,17 +76,10 @@ type ServiceContainer struct { // Tool registry toolRegistry *tools.Registry mcpManager domain.MCPManager - - // File writing services - pathValidator filewriterdomain.PathValidator - backupManager filewriterdomain.BackupManager - fileWriter filewriterdomain.FileWriter - chunkManager filewriterdomain.ChunkManager - paramExtractor *tools.ParameterExtractor } // NewServiceContainer creates a new service container with all dependencies -func NewServiceContainer(cfg *config.Config, v ...*viper.Viper) *ServiceContainer { +func NewServiceContainer(cfg *config.Config) *ServiceContainer { sessionID := domain.GenerateSessionID() log := logger.GetGlobalLogger() @@ -113,15 +99,9 @@ func NewServiceContainer(cfg *config.Config, v ...*viper.Viper) *ServiceContaine log: log, } - if len(v) > 0 && v[0] != nil { - container.viper = v[0] - container.configService = services.NewConfigService(v[0], cfg) - } - - cfg.SetConfigDir(container.determineConfigDirectory()) + cfg.SetConfigDir(config.ResolveConfigDir()) container.initializeGatewayManager() - container.initializeFileWriterServices() container.initializeStateManager() container.initializeDomainServices() container.initializeAgentManager() @@ -140,14 +120,12 @@ func (c *ServiceContainer) initializeGatewayManager() { // initializeAgentManager creates and starts the agent manager if A2A is enabled func (c *ServiceContainer) initializeAgentManager() { - agentsPath := filepath.Join(config.ConfigDirName, config.AgentsFileName) - c.agentsConfigService = services.NewAgentsConfigService(agentsPath) - if !c.config.IsA2AToolsEnabled() { return } - agentsConfig, err := c.agentsConfigService.Load() + agentsPath := filepath.Join(config.ConfigDirName, config.AgentsFileName) + agentsConfig, err := config.LoadAgents(agentsPath) if err != nil { logger.Warn("Failed to load agents configuration", "error", err) return @@ -180,15 +158,6 @@ func (c *ServiceContainer) initializeAgentManager() { } } -// initializeFileWriterServices creates the new file writer architecture services -func (c *ServiceContainer) initializeFileWriterServices() { - c.pathValidator = filewriterservice.NewPathValidator(c.config) - c.backupManager = filewriterservice.NewBackupManager(".") - c.fileWriter = filewriterservice.NewSafeFileWriter(c.pathValidator, c.backupManager) - c.chunkManager = filewriterservice.NewStreamingChunkManager("./.infer/tmp", c.fileWriter) - c.paramExtractor = tools.NewParameterExtractor() -} - // initializeMCPManager creates and starts MCP manager if enabled func (c *ServiceContainer) initializeMCPManager() { if !c.config.MCP.Enabled { @@ -232,7 +201,7 @@ func (c *ServiceContainer) initializeDomainServices() { c.initializeMCPManager() c.ensureBackgroundTaskRegistry() - c.toolRegistry = tools.NewRegistry(c.configService, c.imageService, c.mcpManager, c.BackgroundShellService(), c.stateManager, nil, c.backgroundTaskRegistry) + c.toolRegistry = tools.NewRegistry(c.config, c.imageService, c.mcpManager, c.BackgroundShellService(), c.stateManager, nil, c.backgroundTaskRegistry) styleProvider := styles.NewProvider(c.themeService) toolFormatterService := services.NewToolFormatterService(c.toolRegistry, styleProvider) @@ -316,7 +285,7 @@ func (c *ServiceContainer) initializeDomainServices() { c.agent = agent.NewAgent( agentClient, c.toolService, - c.configService, + c.config, c.conversationRepo, c.a2aAgentService, c.messageQueue, @@ -325,8 +294,6 @@ func (c *ServiceContainer) initializeDomainServices() { c.conversationOptimizer, c.backgroundTaskRegistry, ) - - c.chatService = services.NewStreamingChatService(c.agent) } // initializeStateManager creates the state manager before domain services need it @@ -389,28 +356,13 @@ func (c *ServiceContainer) registerDefaultCommands() { c.shortcutRegistry.Register(shortcuts.NewA2ATaskManagementShortcut(c.config)) } - configDir := c.determineConfigDirectory() + configDir := c.config.GetConfigDir() customShortcutClient := c.createRawSDKClient() if err := c.shortcutRegistry.LoadCustomShortcuts(configDir, customShortcutClient, c.modelService, c.imageService, c.toolService); err != nil { logger.Error("Failed to load custom shortcuts", "error", err, "config_dir", configDir) } } -// determineConfigDirectory returns the directory where configuration and related files should be stored -func (c *ServiceContainer) determineConfigDirectory() string { - configDir := ".infer" - if c.viper != nil { - if configFile := c.viper.ConfigFileUsed(); configFile != "" { - configDir = filepath.Dir(configFile) - } - } - return configDir -} - -func (c *ServiceContainer) GetConfig() *config.Config { - return c.config -} - // Logger returns the logger instance for this container func (c *ServiceContainer) Logger() *zap.Logger { return c.log @@ -432,10 +384,6 @@ func (c *ServiceContainer) GetModelService() domain.ModelService { return c.modelService } -func (c *ServiceContainer) GetChatService() domain.ChatService { - return c.chatService -} - func (c *ServiceContainer) GetToolService() domain.ToolService { return c.toolService } @@ -463,10 +411,6 @@ func (c *ServiceContainer) PricingService() domain.PricingService { return c.pricingService } -func (c *ServiceContainer) GetTheme() domain.Theme { - return c.themeService.GetCurrentTheme() -} - func (c *ServiceContainer) GetThemeService() domain.ThemeService { return c.themeService } @@ -475,12 +419,6 @@ func (c *ServiceContainer) GetShortcutRegistry() *shortcuts.Registry { return c.shortcutRegistry } -// GetA2AAgentService returns the A2A agent service -func (c *ServiceContainer) GetA2AAgentService() domain.A2AAgentService { - return c.a2aAgentService -} - -// New service getters func (c *ServiceContainer) GetStateManager() domain.StateManager { return c.stateManager } @@ -603,42 +541,6 @@ func (c *ServiceContainer) RegisterShortcut(shortcut shortcuts.Shortcut) { c.shortcutRegistry.Register(shortcut) } -// File writer service getters -func (c *ServiceContainer) GetPathValidator() filewriterdomain.PathValidator { - return c.pathValidator -} - -func (c *ServiceContainer) GetBackupManager() filewriterdomain.BackupManager { - return c.backupManager -} - -func (c *ServiceContainer) GetFileWriter() filewriterdomain.FileWriter { - return c.fileWriter -} - -func (c *ServiceContainer) GetChunkManager() filewriterdomain.ChunkManager { - return c.chunkManager -} - -func (c *ServiceContainer) GetParameterExtractor() *tools.ParameterExtractor { - return c.paramExtractor -} - -// GetViper returns the Viper instance -func (c *ServiceContainer) GetViper() *viper.Viper { - return c.viper -} - -// GetConfigService returns the config service -func (c *ServiceContainer) GetConfigService() *services.ConfigService { - return c.configService -} - -// GetTitleGenerator returns the conversation title generator -func (c *ServiceContainer) GetTitleGenerator() *services.ConversationTitleGenerator { - return c.titleGenerator -} - // GetBackgroundJobManager returns the background job manager func (c *ServiceContainer) GetBackgroundJobManager() *services.BackgroundJobManager { return c.backgroundJobManager diff --git a/internal/domain/config_service.go b/internal/domain/config_service.go deleted file mode 100644 index 4904bb16..00000000 --- a/internal/domain/config_service.go +++ /dev/null @@ -1,28 +0,0 @@ -package domain - -import "github.com/inference-gateway/cli/config" - -// ConfigService provides configuration-related functionality -type ConfigService interface { - // Tool approval configuration - IsApprovalRequired(toolName string) bool - IsBashCommandWhitelisted(command string) bool - - // Debug and output configuration - GetOutputDirectory() string - - // Gateway configuration - GetGatewayURL() string - GetAPIKey() string - GetTimeout() int - - // Chat configuration - GetAgentConfig() *config.AgentConfig - - // Sandbox configuration - GetSandboxDirectories() []string - GetProtectedPaths() []string - - // Full configuration access - GetConfig() *config.Config -} diff --git a/internal/handlers/chat_handler.go b/internal/handlers/chat_handler.go index 27e69e82..a0779309 100644 --- a/internal/handlers/chat_handler.go +++ b/internal/handlers/chat_handler.go @@ -31,7 +31,6 @@ type ChatHandler struct { conversationOptimizer domain.ConversationOptimizer sessionRolloverManager *services.SessionRolloverManager modelService domain.ModelService - configService domain.ConfigService toolService domain.ToolService fileService domain.FileService imageService domain.ImageService @@ -63,7 +62,6 @@ func NewChatHandler( conversationOptimizer domain.ConversationOptimizer, sessionRolloverManager *services.SessionRolloverManager, modelService domain.ModelService, - configService domain.ConfigService, toolService domain.ToolService, fileService domain.FileService, imageService domain.ImageService, @@ -82,7 +80,6 @@ func NewChatHandler( conversationOptimizer: conversationOptimizer, sessionRolloverManager: sessionRolloverManager, modelService: modelService, - configService: configService, toolService: toolService, fileService: fileService, imageService: imageService, diff --git a/internal/handlers/chat_handler_test.go b/internal/handlers/chat_handler_test.go index baeaaa08..ed1f6853 100644 --- a/internal/handlers/chat_handler_test.go +++ b/internal/handlers/chat_handler_test.go @@ -547,7 +547,6 @@ func TestFormatMetricsWithoutSessionTokens(t *testing.T) { nil, // conversationOptimizer nil, // sessionRolloverManager nil, // modelService - nil, // configService nil, // toolService nil, // fileService nil, // imageService @@ -624,7 +623,7 @@ func TestChatHandler_Handle(t *testing.T) { type chatHandlerTestCase struct { name string msg tea.Msg - setupMocks func(*mocks.FakeAgentService, *mocks.FakeModelService, *mocks.FakeToolService, *mocks.FakeFileService, *mocks.FakeConfigService) + setupMocks func(*mocks.FakeAgentService, *mocks.FakeModelService, *mocks.FakeToolService, *mocks.FakeFileService, *config.Config) expectedCmd bool validateResult func(*testing.T, tea.Cmd) } @@ -651,7 +650,7 @@ func getUserInputTestCases() []chatHandlerTestCase { msg: domain.UserInputEvent{ Content: "Hello, how are you?", }, - setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, config *mocks.FakeConfigService) { + setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, cfg *config.Config) { model.GetCurrentModelReturns("test-model") eventCh := make(chan domain.ChatEvent, 1) close(eventCh) @@ -664,7 +663,7 @@ func getUserInputTestCases() []chatHandlerTestCase { msg: domain.UserInputEvent{ Content: "/help", }, - setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, config *mocks.FakeConfigService) { + setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, cfg *config.Config) { }, expectedCmd: true, }, @@ -673,7 +672,7 @@ func getUserInputTestCases() []chatHandlerTestCase { msg: domain.UserInputEvent{ Content: "!ls -la", }, - setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, config *mocks.FakeConfigService) { + setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, cfg *config.Config) { tool.IsToolEnabledReturns(true) }, expectedCmd: true, @@ -683,7 +682,7 @@ func getUserInputTestCases() []chatHandlerTestCase { msg: domain.UserInputEvent{ Content: "!!Read(file_path=\"test.txt\")", }, - setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, config *mocks.FakeConfigService) { + setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, cfg *config.Config) { tool.IsToolEnabledReturns(true) }, expectedCmd: true, @@ -696,7 +695,7 @@ func getFileSelectionTestCases() []chatHandlerTestCase { { name: "FileSelectionRequestEvent - with files", msg: domain.FileSelectionRequestEvent{}, - setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, config *mocks.FakeConfigService) { + setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, cfg *config.Config) { file.ListProjectFilesReturns([]string{"file1.go", "file2.go"}, nil) }, expectedCmd: true, @@ -704,7 +703,7 @@ func getFileSelectionTestCases() []chatHandlerTestCase { { name: "FileSelectionRequestEvent - no files", msg: domain.FileSelectionRequestEvent{}, - setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, config *mocks.FakeConfigService) { + setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, cfg *config.Config) { file.ListProjectFilesReturns([]string{}, nil) }, expectedCmd: true, @@ -720,7 +719,7 @@ func getChatEventTestCases() []chatHandlerTestCase { RequestID: "test-123", Timestamp: time.Now(), }, - setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, config *mocks.FakeConfigService) { + setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, cfg *config.Config) { }, expectedCmd: true, }, @@ -731,7 +730,7 @@ func getChatEventTestCases() []chatHandlerTestCase { Content: "Response chunk", Timestamp: time.Now(), }, - setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, config *mocks.FakeConfigService) { + setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, cfg *config.Config) { model.GetCurrentModelReturns("test-model") }, expectedCmd: false, @@ -743,7 +742,7 @@ func getChatEventTestCases() []chatHandlerTestCase { ReasoningContent: "Thinking...", Timestamp: time.Now(), }, - setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, config *mocks.FakeConfigService) { + setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, cfg *config.Config) { }, expectedCmd: true, }, @@ -761,7 +760,7 @@ func getChatEventTestCases() []chatHandlerTestCase { }, }, }, - setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, config *mocks.FakeConfigService) { + setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, cfg *config.Config) { }, expectedCmd: true, }, @@ -772,7 +771,7 @@ func getChatEventTestCases() []chatHandlerTestCase { Error: assert.AnError, Timestamp: time.Now(), }, - setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, config *mocks.FakeConfigService) { + setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, cfg *config.Config) { }, expectedCmd: true, }, @@ -787,7 +786,7 @@ func getToolExecutionTestCases() []chatHandlerTestCase { SessionID: "test-123", TotalTools: 2, }, - setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, config *mocks.FakeConfigService) { + setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, cfg *config.Config) { }, expectedCmd: true, }, @@ -801,7 +800,7 @@ func getToolExecutionTestCases() []chatHandlerTestCase { Status: "executing", Message: "Read tool executing", }, - setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, config *mocks.FakeConfigService) { + setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, cfg *config.Config) { }, expectedCmd: false, }, @@ -812,7 +811,7 @@ func getToolExecutionTestCases() []chatHandlerTestCase { TotalExecuted: 2, SuccessCount: 2, }, - setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, config *mocks.FakeConfigService) { + setupMocks: func(agent *mocks.FakeAgentService, model *mocks.FakeModelService, tool *mocks.FakeToolService, file *mocks.FakeFileService, cfg *config.Config) { model.GetCurrentModelReturns("test-model") eventCh := make(chan domain.ChatEvent, 1) close(eventCh) @@ -823,18 +822,15 @@ func getToolExecutionTestCases() []chatHandlerTestCase { } } -func setupTestChatHandler(_ *testing.T, setupMocks func(*mocks.FakeAgentService, *mocks.FakeModelService, *mocks.FakeToolService, *mocks.FakeFileService, *mocks.FakeConfigService), stateManager domain.StateManager) *ChatHandler { +func setupTestChatHandler(_ *testing.T, setupMocks func(*mocks.FakeAgentService, *mocks.FakeModelService, *mocks.FakeToolService, *mocks.FakeFileService, *config.Config), stateManager domain.StateManager) *ChatHandler { mockAgent := &mocks.FakeAgentService{} mockModel := &mocks.FakeModelService{} mockTool := &mocks.FakeToolService{} mockFile := &mocks.FakeFileService{} - mockConfig := &mocks.FakeConfigService{} - - mockConfig.IsApprovalRequiredReturns(false) - mockConfig.GetOutputDirectoryReturns("/tmp") + cfg := config.DefaultConfig() if setupMocks != nil { - setupMocks(mockAgent, mockModel, mockTool, mockFile, mockConfig) + setupMocks(mockAgent, mockModel, mockTool, mockFile, cfg) } conversationRepo := services.NewInMemoryConversationRepository(nil, nil) @@ -847,7 +843,6 @@ func setupTestChatHandler(_ *testing.T, setupMocks func(*mocks.FakeAgentService, nil, // conversationOptimizer nil, // sessionRolloverManager mockModel, - mockConfig, mockTool, mockFile, nil, @@ -858,7 +853,7 @@ func setupTestChatHandler(_ *testing.T, setupMocks func(*mocks.FakeAgentService, nil, nil, nil, - config.DefaultConfig(), + cfg, ) } @@ -928,7 +923,7 @@ func TestChatEventHandler_handleChatComplete(t *testing.T) { withToolCalls bool metricsProvided bool shouldInjectReminder bool - setupMocks func(*mocks.FakeAgentService, *mocks.FakeModelService, *mocks.FakeConfigService) + setupMocks func(*mocks.FakeAgentService, *mocks.FakeModelService, *config.Config) }{ { name: "Complete without tools or metrics", @@ -981,7 +976,6 @@ func TestChatEventHandler_handleChatComplete(t *testing.T) { t.Run(tt.name, func(t *testing.T) { mockAgent := &mocks.FakeAgentService{} mockModel := &mocks.FakeModelService{} - mockConfig := &mocks.FakeConfigService{} mockTool := &mocks.FakeToolService{} mockFile := &mocks.FakeFileService{} @@ -996,7 +990,6 @@ func TestChatEventHandler_handleChatComplete(t *testing.T) { nil, // conversationOptimizer nil, // sessionRolloverManager mockModel, - mockConfig, mockTool, mockFile, nil, diff --git a/internal/handlers/chat_message_processor_test.go b/internal/handlers/chat_message_processor_test.go index 0a886c51..0ba140fa 100644 --- a/internal/handlers/chat_message_processor_test.go +++ b/internal/handlers/chat_message_processor_test.go @@ -84,7 +84,6 @@ func TestChatMessageProcessor_handleUserInput(t *testing.T) { mockFile := &mocks.FakeFileService{} mockAgent := &mocks.FakeAgentService{} mockModel := &mocks.FakeModelService{} - mockConfig := &mocks.FakeConfigService{} mockTool := &mocks.FakeToolService{} if tt.setupMocks != nil { @@ -102,7 +101,6 @@ func TestChatMessageProcessor_handleUserInput(t *testing.T) { nil, // conversationOptimizer nil, // sessionRolloverManager mockModel, - mockConfig, mockTool, mockFile, nil, diff --git a/internal/services/agents.go b/internal/services/agents.go index 763d909e..3a34182f 100644 --- a/internal/services/agents.go +++ b/internal/services/agents.go @@ -15,10 +15,10 @@ import ( ) type A2AAgentService struct { - config *config.Config - agentsConfigSvc *AgentsConfigService - cache map[string]*domain.CachedAgentCard - cacheMutex sync.RWMutex + config *config.Config + agentsPath string + cache map[string]*domain.CachedAgentCard + cacheMutex sync.RWMutex } func NewA2AAgentService(cfg *config.Config) *A2AAgentService { @@ -35,12 +35,10 @@ func NewA2AAgentService(cfg *config.Config) *A2AAgentService { agentsPath = config.DefaultAgentsPath } - agentsConfigSvc := NewAgentsConfigService(agentsPath) - return &A2AAgentService{ - config: cfg, - agentsConfigSvc: agentsConfigSvc, - cache: make(map[string]*domain.CachedAgentCard), + config: cfg, + agentsPath: agentsPath, + cache: make(map[string]*domain.CachedAgentCard), } } @@ -99,7 +97,7 @@ func (s *A2AAgentService) GetConfiguredAgents() []string { return s.config.A2A.Agents } - urls, err := s.agentsConfigSvc.GetAgentURLs() + urls, err := config.GetAgentURLs(s.agentsPath) if err != nil { logger.Error("Failed to load agents from agents.yaml", "error", err) return []string{} diff --git a/internal/services/agents_config.go b/internal/services/agents_config.go deleted file mode 100644 index 02d5f5a2..00000000 --- a/internal/services/agents_config.go +++ /dev/null @@ -1,179 +0,0 @@ -package services - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - - config "github.com/inference-gateway/cli/config" - logger "github.com/inference-gateway/cli/internal/logger" - yaml "gopkg.in/yaml.v3" -) - -// AgentsConfigService manages the agents.yaml configuration -type AgentsConfigService struct { - configPath string -} - -// NewAgentsConfigService creates a new agents config service -func NewAgentsConfigService(configPath string) *AgentsConfigService { - return &AgentsConfigService{ - configPath: configPath, - } -} - -// Load loads the agents configuration from the file -func (s *AgentsConfigService) Load() (*config.AgentsConfig, error) { - if _, err := os.Stat(s.configPath); os.IsNotExist(err) { - logger.Info("Agents config file does not exist, returning default", "path", s.configPath) - return config.DefaultAgentsConfig(), nil - } - - data, err := os.ReadFile(s.configPath) - if err != nil { - return nil, fmt.Errorf("failed to read agents config: %w", err) - } - - expandedData := os.ExpandEnv(string(data)) - - var agentsConfig config.AgentsConfig - if err := yaml.Unmarshal([]byte(expandedData), &agentsConfig); err != nil { - return nil, fmt.Errorf("failed to parse agents config: %w", err) - } - - return &agentsConfig, nil -} - -// Save saves the agents configuration to the file -func (s *AgentsConfigService) Save(cfg *config.AgentsConfig) error { - var buf bytes.Buffer - encoder := yaml.NewEncoder(&buf) - encoder.SetIndent(2) - - if err := encoder.Encode(cfg); err != nil { - return fmt.Errorf("failed to marshal agents config: %w", err) - } - - if err := encoder.Close(); err != nil { - return fmt.Errorf("failed to close encoder: %w", err) - } - - data := buf.Bytes() - - configDir := filepath.Dir(s.configPath) - if err := os.MkdirAll(configDir, 0755); err != nil { - return fmt.Errorf("failed to create config directory: %w", err) - } - - if err := os.WriteFile(s.configPath, data, 0644); err != nil { - return fmt.Errorf("failed to write agents config: %w", err) - } - - logger.Info("Agents config saved", "path", s.configPath) - return nil -} - -// AddAgent adds a new agent to the configuration -func (s *AgentsConfigService) AddAgent(agent config.AgentEntry) error { - cfg, err := s.Load() - if err != nil { - return err - } - - for _, existing := range cfg.Agents { - if existing.Name == agent.Name { - return fmt.Errorf("agent with name '%s' already exists", agent.Name) - } - } - - cfg.Agents = append(cfg.Agents, agent) - return s.Save(cfg) -} - -// UpdateAgent updates an existing agent in the configuration -func (s *AgentsConfigService) UpdateAgent(agent config.AgentEntry) error { - cfg, err := s.Load() - if err != nil { - return err - } - - found := false - for i, existing := range cfg.Agents { - if existing.Name == agent.Name { - cfg.Agents[i] = agent - found = true - break - } - } - - if !found { - return fmt.Errorf("agent with name '%s' not found", agent.Name) - } - - return s.Save(cfg) -} - -// RemoveAgent removes an agent by name -func (s *AgentsConfigService) RemoveAgent(name string) error { - cfg, err := s.Load() - if err != nil { - return err - } - - found := false - newAgents := make([]config.AgentEntry, 0, len(cfg.Agents)) - for _, agent := range cfg.Agents { - if agent.Name != name { - newAgents = append(newAgents, agent) - } else { - found = true - } - } - - if !found { - return fmt.Errorf("agent with name '%s' not found", name) - } - - cfg.Agents = newAgents - return s.Save(cfg) -} - -// ListAgents returns all configured agents -func (s *AgentsConfigService) ListAgents() ([]config.AgentEntry, error) { - cfg, err := s.Load() - if err != nil { - return nil, err - } - return cfg.Agents, nil -} - -// GetAgent returns a specific agent by name -func (s *AgentsConfigService) GetAgent(name string) (*config.AgentEntry, error) { - cfg, err := s.Load() - if err != nil { - return nil, err - } - - for _, agent := range cfg.Agents { - if agent.Name == name { - return &agent, nil - } - } - - return nil, fmt.Errorf("agent with name '%s' not found", name) -} - -// GetAgentURLs returns URLs of all configured agents -func (s *AgentsConfigService) GetAgentURLs() ([]string, error) { - agents, err := s.ListAgents() - if err != nil { - return nil, err - } - - urls := make([]string, 0, len(agents)) - for _, agent := range agents { - urls = append(urls, agent.URL) - } - return urls, nil -} diff --git a/internal/services/agents_test.go b/internal/services/agents_test.go index 1bd1a47e..b776f34e 100644 --- a/internal/services/agents_test.go +++ b/internal/services/agents_test.go @@ -13,13 +13,11 @@ func TestA2AAgentService_GetConfiguredAgents_EnvVarPrecedence(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - agentsConfigSvc := NewAgentsConfigService(agentsPath) - err := agentsConfigSvc.AddAgent(config.AgentEntry{ + require.NoError(t, config.AddAgent(agentsPath, config.AgentEntry{ Name: "yaml-agent", URL: "http://yaml-agent:8080", Run: false, - }) - require.NoError(t, err) + })) t.Run("environment variable takes precedence", func(t *testing.T) { cfg := config.DefaultConfig() @@ -29,9 +27,9 @@ func TestA2AAgentService_GetConfiguredAgents_EnvVarPrecedence(t *testing.T) { } svc := &A2AAgentService{ - config: cfg, - agentsConfigSvc: agentsConfigSvc, - cache: make(map[string]*domain.CachedAgentCard), + config: cfg, + agentsPath: agentsPath, + cache: make(map[string]*domain.CachedAgentCard), } agents := svc.GetConfiguredAgents() @@ -46,9 +44,9 @@ func TestA2AAgentService_GetConfiguredAgents_EnvVarPrecedence(t *testing.T) { cfg.A2A.Agents = []string{} svc := &A2AAgentService{ - config: cfg, - agentsConfigSvc: agentsConfigSvc, - cache: make(map[string]*domain.CachedAgentCard), + config: cfg, + agentsPath: agentsPath, + cache: make(map[string]*domain.CachedAgentCard), } agents := svc.GetConfiguredAgents() @@ -62,9 +60,9 @@ func TestA2AAgentService_GetConfiguredAgents_EnvVarPrecedence(t *testing.T) { cfg.A2A.Agents = nil svc := &A2AAgentService{ - config: cfg, - agentsConfigSvc: agentsConfigSvc, - cache: make(map[string]*domain.CachedAgentCard), + config: cfg, + agentsPath: agentsPath, + cache: make(map[string]*domain.CachedAgentCard), } agents := svc.GetConfiguredAgents() @@ -78,15 +76,13 @@ func TestA2AAgentService_GetConfiguredAgents_NoAgentsConfigured(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - agentsConfigSvc := NewAgentsConfigService(agentsPath) - cfg := config.DefaultConfig() cfg.A2A.Agents = nil svc := &A2AAgentService{ - config: cfg, - agentsConfigSvc: agentsConfigSvc, - cache: make(map[string]*domain.CachedAgentCard), + config: cfg, + agentsPath: agentsPath, + cache: make(map[string]*domain.CachedAgentCard), } agents := svc.GetConfiguredAgents() diff --git a/internal/services/config_service.go b/internal/services/config_service.go deleted file mode 100644 index 307b9e37..00000000 --- a/internal/services/config_service.go +++ /dev/null @@ -1,101 +0,0 @@ -package services - -import ( - "fmt" - - viper "github.com/spf13/viper" - - config "github.com/inference-gateway/cli/config" - utils "github.com/inference-gateway/cli/internal/utils" -) - -// ConfigService handles configuration management and reloading -type ConfigService struct { - viper *viper.Viper - config *config.Config -} - -// NewConfigService creates a new config service -func NewConfigService(v *viper.Viper, cfg *config.Config) *ConfigService { - return &ConfigService{ - viper: v, - config: cfg, - } -} - -// Reload reloads configuration from disk -func (cs *ConfigService) Reload() (*config.Config, error) { - if err := cs.viper.ReadInConfig(); err != nil { - return nil, fmt.Errorf("failed to re-read config file: %w", err) - } - - newConfig := &config.Config{} - if err := cs.viper.Unmarshal(newConfig); err != nil { - return nil, fmt.Errorf("failed to unmarshal reloaded config: %w", err) - } - - cs.config = newConfig - - return newConfig, nil -} - -// GetConfig returns the current config -func (cs *ConfigService) GetConfig() *config.Config { - return cs.config -} - -// SetValue sets a configuration value using dot notation and saves it to disk -func (cs *ConfigService) SetValue(key, value string) error { - cs.viper.Set(key, value) - - if err := utils.WriteViperConfigWithIndent(cs.viper, 2); err != nil { - return fmt.Errorf("failed to save config: %w", err) - } - - newConfig, err := cs.Reload() - if err != nil { - return fmt.Errorf("failed to reload config after setting: %w", err) - } - - cs.config = newConfig - - return nil -} - -// Domain ConfigService interface implementation (delegates to underlying config) - -func (cs *ConfigService) IsApprovalRequired(toolName string) bool { - return cs.config.IsApprovalRequired(toolName) -} - -func (cs *ConfigService) IsBashCommandWhitelisted(command string) bool { - return cs.config.IsBashCommandWhitelisted(command) -} - -func (cs *ConfigService) GetOutputDirectory() string { - return cs.config.GetOutputDirectory() -} - -func (cs *ConfigService) GetGatewayURL() string { - return cs.config.Gateway.URL -} - -func (cs *ConfigService) GetAPIKey() string { - return cs.config.Gateway.APIKey -} - -func (cs *ConfigService) GetTimeout() int { - return cs.config.Gateway.Timeout -} - -func (cs *ConfigService) GetAgentConfig() *config.AgentConfig { - return cs.config.GetAgentConfig() -} - -func (cs *ConfigService) GetSandboxDirectories() []string { - return cs.config.GetSandboxDirectories() -} - -func (cs *ConfigService) GetProtectedPaths() []string { - return cs.config.GetProtectedPaths() -} diff --git a/internal/services/keybindings_config.go b/internal/services/keybindings_config.go deleted file mode 100644 index 8be31a0d..00000000 --- a/internal/services/keybindings_config.go +++ /dev/null @@ -1,90 +0,0 @@ -package services - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - - config "github.com/inference-gateway/cli/config" - logger "github.com/inference-gateway/cli/internal/logger" - yaml "gopkg.in/yaml.v3" -) - -// KeybindingsConfigService manages the keybindings.yaml configuration -type KeybindingsConfigService struct { - configPath string -} - -// NewKeybindingsConfigService creates a new keybindings config service -func NewKeybindingsConfigService(configPath string) *KeybindingsConfigService { - return &KeybindingsConfigService{ - configPath: configPath, - } -} - -// Load reads the keybindings configuration from disk. When the file is -// missing it returns the in-code defaults so callers can treat absence as -// "use defaults" without special-casing. -func (s *KeybindingsConfigService) Load() (*config.KeybindingsConfig, error) { - if _, err := os.Stat(s.configPath); os.IsNotExist(err) { - logger.Info("Keybindings config file does not exist, returning defaults", "path", s.configPath) - return defaultKeybindingsConfig(), nil - } - - data, err := os.ReadFile(s.configPath) - if err != nil { - return nil, fmt.Errorf("failed to read keybindings config: %w", err) - } - - var cfg config.KeybindingsConfig - if err := yaml.Unmarshal(data, &cfg); err != nil { - return nil, fmt.Errorf("failed to parse keybindings config: %w", err) - } - - return &cfg, nil -} - -// Save writes the keybindings configuration to disk, creating any missing -// parent directories. -func (s *KeybindingsConfigService) Save(cfg *config.KeybindingsConfig) error { - var buf bytes.Buffer - - buf.WriteString("---\n") - - encoder := yaml.NewEncoder(&buf) - encoder.SetIndent(2) - - if err := encoder.Encode(cfg); err != nil { - return fmt.Errorf("failed to marshal keybindings config: %w", err) - } - - if err := encoder.Close(); err != nil { - return fmt.Errorf("failed to close encoder: %w", err) - } - - configDir := filepath.Dir(s.configPath) - if err := os.MkdirAll(configDir, 0755); err != nil { - return fmt.Errorf("failed to create config directory: %w", err) - } - - if err := os.WriteFile(s.configPath, buf.Bytes(), 0644); err != nil { - return fmt.Errorf("failed to write keybindings config: %w", err) - } - - logger.Info("Keybindings config saved", "path", s.configPath) - return nil -} - -func defaultKeybindingsConfig() *config.KeybindingsConfig { - return &config.KeybindingsConfig{ - Enabled: true, - Bindings: config.GetDefaultKeybindings(), - } -} - -// DefaultKeybindingsConfig exposes the default keybindings config used when -// no file exists. Callers (init, reset) use it to seed a fresh file. -func DefaultKeybindingsConfig() *config.KeybindingsConfig { - return defaultKeybindingsConfig() -} diff --git a/internal/services/mcp_config.go b/internal/services/mcp_config.go deleted file mode 100644 index 5924eb56..00000000 --- a/internal/services/mcp_config.go +++ /dev/null @@ -1,218 +0,0 @@ -package services - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - - config "github.com/inference-gateway/cli/config" - logger "github.com/inference-gateway/cli/internal/logger" - yaml "gopkg.in/yaml.v3" -) - -// MCPConfigService manages the mcp.yaml configuration -type MCPConfigService struct { - configPath string -} - -// NewMCPConfigService creates a new MCP config service -func NewMCPConfigService(configPath string) *MCPConfigService { - return &MCPConfigService{ - configPath: configPath, - } -} - -// Load loads the MCP configuration from the file -func (s *MCPConfigService) Load() (*config.MCPConfig, error) { - if _, err := os.Stat(s.configPath); os.IsNotExist(err) { - logger.Info("MCP config file does not exist, returning default", "path", s.configPath) - return config.DefaultMCPConfig(), nil - } - - data, err := os.ReadFile(s.configPath) - if err != nil { - return nil, fmt.Errorf("failed to read MCP config: %w", err) - } - - expandedData := os.ExpandEnv(string(data)) - - var mcpConfig config.MCPConfig - if err := yaml.Unmarshal([]byte(expandedData), &mcpConfig); err != nil { - return nil, fmt.Errorf("failed to parse MCP config: %w", err) - } - - return &mcpConfig, nil -} - -// Save saves the MCP configuration to the file -func (s *MCPConfigService) Save(cfg *config.MCPConfig) error { - var buf bytes.Buffer - - buf.WriteString("---\n") - - encoder := yaml.NewEncoder(&buf) - encoder.SetIndent(2) - - if err := encoder.Encode(cfg); err != nil { - return fmt.Errorf("failed to marshal MCP config: %w", err) - } - - if err := encoder.Close(); err != nil { - return fmt.Errorf("failed to close encoder: %w", err) - } - - data := buf.Bytes() - - configDir := filepath.Dir(s.configPath) - if err := os.MkdirAll(configDir, 0755); err != nil { - return fmt.Errorf("failed to create config directory: %w", err) - } - - if err := os.WriteFile(s.configPath, data, 0644); err != nil { - return fmt.Errorf("failed to write MCP config: %w", err) - } - - logger.Info("MCP config saved", "path", s.configPath) - return nil -} - -// AddServer adds a new MCP server to the configuration -func (s *MCPConfigService) AddServer(server config.MCPServerEntry) error { - cfg, err := s.Load() - if err != nil { - return err - } - - for _, existing := range cfg.Servers { - if existing.Name == server.Name { - return fmt.Errorf("MCP server with name '%s' already exists", server.Name) - } - } - - cfg.Servers = append(cfg.Servers, server) - return s.Save(cfg) -} - -// UpdateServer updates an existing MCP server in the configuration -func (s *MCPConfigService) UpdateServer(server config.MCPServerEntry) error { - cfg, err := s.Load() - if err != nil { - return err - } - - found := false - for i, existing := range cfg.Servers { - if existing.Name == server.Name { - cfg.Servers[i] = server - found = true - break - } - } - - if !found { - return fmt.Errorf("MCP server with name '%s' not found", server.Name) - } - - return s.Save(cfg) -} - -// RemoveServer removes an MCP server by name -func (s *MCPConfigService) RemoveServer(name string) error { - cfg, err := s.Load() - if err != nil { - return err - } - - found := false - newServers := make([]config.MCPServerEntry, 0, len(cfg.Servers)) - for _, server := range cfg.Servers { - if server.Name != name { - newServers = append(newServers, server) - } else { - found = true - } - } - - if !found { - return fmt.Errorf("MCP server with name '%s' not found", name) - } - - cfg.Servers = newServers - return s.Save(cfg) -} - -// ListServers returns all configured MCP servers -func (s *MCPConfigService) ListServers() ([]config.MCPServerEntry, error) { - cfg, err := s.Load() - if err != nil { - return nil, err - } - return cfg.Servers, nil -} - -// GetServer returns a specific MCP server by name -func (s *MCPConfigService) GetServer(name string) (*config.MCPServerEntry, error) { - cfg, err := s.Load() - if err != nil { - return nil, err - } - - for _, server := range cfg.Servers { - if server.Name == name { - return &server, nil - } - } - - return nil, fmt.Errorf("MCP server with name '%s' not found", name) -} - -// Merge merges the optional mcp.yaml config with the base config -// The optional config takes precedence for global settings -// Servers from both configs are combined (optional servers can override base servers by name) -func Merge(base *config.MCPConfig, optional *config.MCPConfig) *config.MCPConfig { - if optional == nil { - return base - } - - merged := &config.MCPConfig{ - Enabled: optional.Enabled || base.Enabled, - ConnectionTimeout: base.ConnectionTimeout, - DiscoveryTimeout: base.DiscoveryTimeout, - LivenessProbeEnabled: base.LivenessProbeEnabled, - LivenessProbeInterval: base.LivenessProbeInterval, - MaxRetries: base.MaxRetries, - Servers: make([]config.MCPServerEntry, 0), - } - - if optional.ConnectionTimeout > 0 { - merged.ConnectionTimeout = optional.ConnectionTimeout - } - if optional.DiscoveryTimeout > 0 { - merged.DiscoveryTimeout = optional.DiscoveryTimeout - } - if optional.LivenessProbeInterval > 0 { - merged.LivenessProbeInterval = optional.LivenessProbeInterval - } - if optional.MaxRetries > 0 { - merged.MaxRetries = optional.MaxRetries - } - if optional.LivenessProbeEnabled { - merged.LivenessProbeEnabled = true - } - - serverMap := make(map[string]config.MCPServerEntry) - for _, server := range base.Servers { - serverMap[server.Name] = server - } - - for _, server := range optional.Servers { - serverMap[server.Name] = server - } - - for _, server := range serverMap { - merged.Servers = append(merged.Servers, server) - } - - return merged -} diff --git a/internal/services/prompts_config.go b/internal/services/prompts_config.go deleted file mode 100644 index 07cfb594..00000000 --- a/internal/services/prompts_config.go +++ /dev/null @@ -1,83 +0,0 @@ -package services - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - - config "github.com/inference-gateway/cli/config" - logger "github.com/inference-gateway/cli/internal/logger" - yaml "gopkg.in/yaml.v3" -) - -// PromptsConfigService manages the prompts.yaml configuration where every -// LLM-facing prompt the CLI uses lives. It mirrors the keybindings and MCP -// service contracts so callers can treat all sibling YAML configs the same -// way. -type PromptsConfigService struct { - configPath string -} - -func NewPromptsConfigService(configPath string) *PromptsConfigService { - return &PromptsConfigService{configPath: configPath} -} - -// Load reads prompts.yaml from disk. When the file is missing it returns -// the in-code defaults so callers can treat absence as "use defaults" -// without special-casing. -func (s *PromptsConfigService) Load() (*config.PromptsConfig, error) { - if _, err := os.Stat(s.configPath); os.IsNotExist(err) { - logger.Info("Prompts config file does not exist, returning defaults", "path", s.configPath) - return config.DefaultPromptsConfig(), nil - } - - data, err := os.ReadFile(s.configPath) - if err != nil { - return nil, fmt.Errorf("failed to read prompts config: %w", err) - } - - var cfg config.PromptsConfig - if err := yaml.Unmarshal(data, &cfg); err != nil { - return nil, fmt.Errorf("failed to parse prompts config: %w", err) - } - - return &cfg, nil -} - -// Save writes the prompts configuration to disk, creating any missing -// parent directories. -func (s *PromptsConfigService) Save(cfg *config.PromptsConfig) error { - var buf bytes.Buffer - - buf.WriteString("---\n") - - encoder := yaml.NewEncoder(&buf) - encoder.SetIndent(2) - - if err := encoder.Encode(cfg); err != nil { - return fmt.Errorf("failed to marshal prompts config: %w", err) - } - - if err := encoder.Close(); err != nil { - return fmt.Errorf("failed to close encoder: %w", err) - } - - configDir := filepath.Dir(s.configPath) - if err := os.MkdirAll(configDir, 0755); err != nil { - return fmt.Errorf("failed to create config directory: %w", err) - } - - if err := os.WriteFile(s.configPath, buf.Bytes(), 0644); err != nil { - return fmt.Errorf("failed to write prompts config: %w", err) - } - - logger.Info("Prompts config saved", "path", s.configPath) - return nil -} - -// DefaultPromptsConfig exposes the default prompts config used when -// no file exists. Callers (init, reset) use it to seed a fresh file. -func DefaultPromptsConfig() *config.PromptsConfig { - return config.DefaultPromptsConfig() -} diff --git a/internal/shortcuts/interfaces.go b/internal/shortcuts/interfaces.go index 728b302e..09cc4309 100644 --- a/internal/shortcuts/interfaces.go +++ b/internal/shortcuts/interfaces.go @@ -3,7 +3,6 @@ package shortcuts import ( "context" - config "github.com/inference-gateway/cli/config" domain "github.com/inference-gateway/cli/internal/domain" ) @@ -94,13 +93,3 @@ type TokenStats struct { TotalTokens int RequestCount int } - -// AgentsConfigService interface for managing agent configurations -type AgentsConfigService interface { - AddAgent(agent config.AgentEntry) error - UpdateAgent(agent config.AgentEntry) error - RemoveAgent(name string) error - ListAgents() ([]config.AgentEntry, error) - GetAgent(name string) (*config.AgentEntry, error) - GetAgentURLs() ([]string, error) -} diff --git a/internal/shortcuts/task_management.go b/internal/shortcuts/task_management.go index 3d26f13f..5e98ba8e 100644 --- a/internal/shortcuts/task_management.go +++ b/internal/shortcuts/task_management.go @@ -8,12 +8,12 @@ import ( // A2ATaskManagementShortcut shows the A2A task management dropdown type A2ATaskManagementShortcut struct { - configService *config.Config + config *config.Config } // NewA2ATaskManagementShortcut creates a new A2A task management shortcut -func NewA2ATaskManagementShortcut(configService *config.Config) *A2ATaskManagementShortcut { - return &A2ATaskManagementShortcut{configService: configService} +func NewA2ATaskManagementShortcut(cfg *config.Config) *A2ATaskManagementShortcut { + return &A2ATaskManagementShortcut{config: cfg} } func (t *A2ATaskManagementShortcut) GetName() string { return "tasks" } @@ -24,7 +24,7 @@ func (t *A2ATaskManagementShortcut) GetUsage() string { return "/ta func (t *A2ATaskManagementShortcut) CanExecute(args []string) bool { return len(args) == 0 } func (t *A2ATaskManagementShortcut) Execute(ctx context.Context, args []string) (ShortcutResult, error) { - if !t.configService.A2A.Enabled { + if !t.config.A2A.Enabled { return ShortcutResult{ Output: "Task management requires A2A to be enabled in configuration.", Success: false, diff --git a/internal/ui/components/input_status_bar.go b/internal/ui/components/input_status_bar.go index cd0e72f8..29e3c2f0 100644 --- a/internal/ui/components/input_status_bar.go +++ b/internal/ui/components/input_status_bar.go @@ -23,7 +23,7 @@ type InputStatusBar struct { modelService domain.ModelService themeService domain.ThemeService stateManager domain.StateManager - configService *config.Config + config *config.Config conversationRepo domain.ConversationRepository toolService domain.ToolService tokenEstimator domain.TokenEstimator @@ -61,9 +61,9 @@ func (isb *InputStatusBar) SetStateManager(stateManager domain.StateManager) { isb.stateManager = stateManager } -// SetConfigService sets the config service -func (isb *InputStatusBar) SetConfigService(configService *config.Config) { - isb.configService = configService +// SetConfig sets the config for the status bar +func (isb *InputStatusBar) SetConfig(cfg *config.Config) { + isb.config = cfg } // SetConversationRepo sets the conversation repository @@ -110,7 +110,7 @@ func (isb *InputStatusBar) SetHeight(height int) { } func (isb *InputStatusBar) Render() string { - if isb.configService != nil && !isb.configService.Chat.StatusBar.Enabled { + if isb.config != nil && !isb.config.Chat.StatusBar.Enabled { return "" } @@ -370,11 +370,11 @@ func (isb *InputStatusBar) buildModelDisplayText(currentModel string) string { // shouldShowIndicator checks if a specific indicator should be shown func (isb *InputStatusBar) shouldShowIndicator(indicator string) bool { - if isb.configService == nil { + if isb.config == nil { return true } - indicators := isb.configService.Chat.StatusBar.Indicators + indicators := isb.config.Chat.StatusBar.Indicators switch indicator { case "model": return indicators.Model @@ -416,10 +416,10 @@ func (isb *InputStatusBar) buildThemeIndicator() string { // buildMaxOutputIndicator builds the max output tokens indicator text func (isb *InputStatusBar) buildMaxOutputIndicator() string { - if isb.configService == nil { + if isb.config == nil { return "" } - maxTokens := isb.configService.Agent.MaxTokens + maxTokens := isb.config.Agent.MaxTokens if maxTokens > 0 { return fmt.Sprintf("Max Output: %d", maxTokens) } @@ -439,7 +439,7 @@ func (isb *InputStatusBar) buildA2AAgentsIndicator() string { // buildMCPIndicator builds the MCP server status indicator text func (isb *InputStatusBar) buildMCPIndicator() string { - if isb.mcpStatus == nil || isb.configService == nil || len(isb.configService.MCP.Servers) == 0 { + if isb.mcpStatus == nil || isb.config == nil || len(isb.config.MCP.Servers) == 0 { return "" } if isb.mcpStatus.TotalTools > 0 { diff --git a/internal/ui/components/input_status_bar_test.go b/internal/ui/components/input_status_bar_test.go index 197c5ab2..8e3d4894 100644 --- a/internal/ui/components/input_status_bar_test.go +++ b/internal/ui/components/input_status_bar_test.go @@ -49,7 +49,7 @@ func TestInputStatusBar_MasterToggle(t *testing.T) { modelService: modelService, themeService: themeService, stateManager: stateManager, - configService: cfg, + config: cfg, } output := statusBar.Render() @@ -127,7 +127,7 @@ func TestInputStatusBar_ShouldShowIndicator(t *testing.T) { } statusBar := &InputStatusBar{ - configService: cfg, + config: cfg, } result := statusBar.shouldShowIndicator(tt.indicator) @@ -141,7 +141,7 @@ func TestInputStatusBar_ShouldShowIndicator(t *testing.T) { func TestInputStatusBar_ShouldShowIndicator_NilConfig(t *testing.T) { statusBar := &InputStatusBar{ - configService: nil, + config: nil, } result := statusBar.shouldShowIndicator("model") @@ -193,7 +193,7 @@ func TestInputStatusBar_BuildMaxOutputIndicator(t *testing.T) { cfg.Agent.MaxTokens = tt.maxTokens statusBar := &InputStatusBar{ - configService: cfg, + config: cfg, } result := statusBar.buildMaxOutputIndicator() @@ -310,8 +310,8 @@ func TestInputStatusBar_BuildMCPIndicator(t *testing.T) { } statusBar := &InputStatusBar{ - mcpStatus: tt.mcpStatus, - configService: cfg, + mcpStatus: tt.mcpStatus, + config: cfg, } result := statusBar.buildMCPIndicator() @@ -412,7 +412,7 @@ func TestInputStatusBar_ShouldShowIndicator_SessionTokens(t *testing.T) { cfg.Chat.StatusBar.Indicators.SessionTokens = tt.configEnabled statusBar := &InputStatusBar{ - configService: cfg, + config: cfg, } result := statusBar.shouldShowIndicator("session_tokens") @@ -438,7 +438,7 @@ func TestInputStatusBar_BuildModelDisplayText_WithSessionTokens(t *testing.T) { themeService.GetCurrentThemeNameReturns("tokyo-night") statusBar := &InputStatusBar{ - configService: cfg, + config: cfg, themeService: themeService, conversationRepo: mockRepo, } @@ -528,7 +528,7 @@ func TestInputStatusBar_ShouldShowIndicator_GitBranch(t *testing.T) { cfg.Chat.StatusBar.Indicators.GitBranch = tt.configEnabled statusBar := &InputStatusBar{ - configService: cfg, + config: cfg, } result := statusBar.shouldShowIndicator("git_branch") @@ -554,7 +554,7 @@ func TestInputStatusBar_BuildModelDisplayText(t *testing.T) { cfg.Chat.StatusBar.Indicators.GitBranch = false statusBar := &InputStatusBar{ - configService: cfg, + config: cfg, } result := statusBar.buildModelDisplayText("test-model") @@ -572,8 +572,8 @@ func TestInputStatusBar_BuildModelDisplayText_AllEnabled(t *testing.T) { themeService.GetCurrentThemeNameReturns("tokyo-night") statusBar := &InputStatusBar{ - configService: cfg, - themeService: themeService, + config: cfg, + themeService: themeService, } result := statusBar.buildModelDisplayText("test-model") diff --git a/internal/ui/components/input_view.go b/internal/ui/components/input_view.go index 3eddf24a..9b91afe5 100644 --- a/internal/ui/components/input_view.go +++ b/internal/ui/components/input_view.go @@ -24,7 +24,7 @@ type InputView struct { modelService domain.ModelService imageService domain.ImageService stateManager domain.StateManager - configService *config.Config + config *config.Config conversationRepo domain.ConversationRepository historyManager *history.HistoryManager disabled bool @@ -80,9 +80,9 @@ func (iv *InputView) SetStateManager(stateManager domain.StateManager) { iv.stateManager = stateManager } -// SetConfigService sets the config service for this input view -func (iv *InputView) SetConfigService(configService *config.Config) { - iv.configService = configService +// SetConfig sets the config for this input view +func (iv *InputView) SetConfig(cfg *config.Config) { + iv.config = cfg } // SetImageService sets the image service for this input view diff --git a/internal/ui/components/model_selection_view.go b/internal/ui/components/model_selection_view.go index 7cb01615..e59ee82b 100644 --- a/internal/ui/components/model_selection_view.go +++ b/internal/ui/components/model_selection_view.go @@ -6,6 +6,7 @@ import ( tea "github.com/charmbracelet/bubbletea" + config "github.com/inference-gateway/cli/config" domain "github.com/inference-gateway/cli/internal/domain" styles "github.com/inference-gateway/cli/internal/ui/styles" ) @@ -31,14 +32,14 @@ type ModelSelectorImpl struct { cancelled bool modelService domain.ModelService pricingService domain.PricingService - configService domain.ConfigService + config *config.Config searchQuery string searchMode bool currentView ModelViewMode } // NewModelSelector creates a new model selector -func NewModelSelector(models []string, modelService domain.ModelService, pricingService domain.PricingService, configService domain.ConfigService, styleProvider *styles.Provider) *ModelSelectorImpl { +func NewModelSelector(models []string, modelService domain.ModelService, pricingService domain.PricingService, cfg *config.Config, styleProvider *styles.Provider) *ModelSelectorImpl { m := &ModelSelectorImpl{ models: models, filteredModels: make([]string, len(models)), @@ -48,7 +49,7 @@ func NewModelSelector(models []string, modelService domain.ModelService, pricing styleProvider: styleProvider, modelService: modelService, pricingService: pricingService, - configService: configService, + config: cfg, searchQuery: "", searchMode: false, currentView: ModelViewAll, @@ -181,13 +182,10 @@ func (m *ModelSelectorImpl) View() string { accentColor := m.styleProvider.GetThemeColor("accent") b.WriteString(m.styleProvider.RenderWithColor("Select a Model", accentColor)) - if m.configService != nil { - cfg := m.configService.GetConfig() - if cfg.ClaudeCode.Enabled { - successColor := m.styleProvider.GetThemeColor("success") - b.WriteString(" ") - b.WriteString(m.styleProvider.RenderWithColor("● Claude Subscription", successColor)) - } + if m.config != nil && m.config.ClaudeCode.Enabled { + successColor := m.styleProvider.GetThemeColor("success") + b.WriteString(" ") + b.WriteString(m.styleProvider.RenderWithColor("● Claude Subscription", successColor)) } b.WriteString("\n\n") diff --git a/internal/web/pty_manager.go b/internal/web/pty_manager.go index 55fcfc81..3b5c14c8 100644 --- a/internal/web/pty_manager.go +++ b/internal/web/pty_manager.go @@ -12,7 +12,6 @@ import ( pty "github.com/creack/pty" websocket "github.com/gorilla/websocket" - viper "github.com/spf13/viper" config "github.com/inference-gateway/cli/config" logger "github.com/inference-gateway/cli/internal/logger" @@ -30,13 +29,13 @@ type SessionHandler interface { type Session = SessionHandler // CreateSessionHandler creates either a local PTY session or remote SSH session -func CreateSessionHandler(webCfg *config.WebConfig, serverCfg *config.SSHServerConfig, cfg *config.Config, v *viper.Viper, sessionID string, sessionManager *SessionManager, progressCh chan<- string) (SessionHandler, error) { +func CreateSessionHandler(webCfg *config.WebConfig, serverCfg *config.SSHServerConfig, cfg *config.Config, sessionID string, sessionManager *SessionManager, progressCh chan<- string) (SessionHandler, error) { if serverCfg != nil { return createRemoteSSHSession(webCfg, serverCfg, cfg.Gateway.URL, sessionID, sessionManager, progressCh) } logger.Info("Creating local PTY session") - return NewLocalPTYSession(cfg, v), nil + return NewLocalPTYSession(cfg), nil } // createRemoteSSHSession creates a remote SSH session with optional auto-install @@ -156,17 +155,15 @@ func ensureRemoteConfig(client *SSHClient, serverCfg *config.SSHServerConfig, _ // LocalPTYSession represents a single local terminal session type LocalPTYSession struct { - cfg *config.Config - viper *viper.Viper - pty *os.File - cmd *exec.Cmd - mu sync.Mutex + cfg *config.Config + pty *os.File + cmd *exec.Cmd + mu sync.Mutex } -func NewLocalPTYSession(cfg *config.Config, v *viper.Viper) *LocalPTYSession { +func NewLocalPTYSession(cfg *config.Config) *LocalPTYSession { return &LocalPTYSession{ - cfg: cfg, - viper: v, + cfg: cfg, } } diff --git a/internal/web/server.go b/internal/web/server.go index 61fafcae..53ab6101 100644 --- a/internal/web/server.go +++ b/internal/web/server.go @@ -16,7 +16,6 @@ import ( uuid "github.com/google/uuid" websocket "github.com/gorilla/websocket" - viper "github.com/spf13/viper" config "github.com/inference-gateway/cli/config" logger "github.com/inference-gateway/cli/internal/logger" @@ -30,16 +29,14 @@ var templates embed.FS type WebTerminalServer struct { cfg *config.Config - viper *viper.Viper server *http.Server upgrader websocket.Upgrader sessionManager *SessionManager } -func NewWebTerminalServer(cfg *config.Config, v *viper.Viper) *WebTerminalServer { +func NewWebTerminalServer(cfg *config.Config) *WebTerminalServer { return &WebTerminalServer{ - cfg: cfg, - viper: v, + cfg: cfg, upgrader: websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, @@ -49,7 +46,7 @@ func NewWebTerminalServer(cfg *config.Config, v *viper.Viper) *WebTerminalServer } func (s *WebTerminalServer) Start() error { - s.sessionManager = NewSessionManager(s.cfg, s.viper) + s.sessionManager = NewSessionManager(s.cfg) logger.Info("Checking embedded static files...") if err := fs.WalkDir(staticFiles, ".", func(path string, d fs.DirEntry, err error) error { @@ -315,7 +312,7 @@ func (s *WebTerminalServer) handleWebSocket(w http.ResponseWriter, r *http.Reque done := make(chan struct{}) go func() { defer close(done) - handler, handlerErr = CreateSessionHandler(&s.cfg.Web, serverCfg, s.cfg, s.viper, sessionID, s.sessionManager, progressCh) + handler, handlerErr = CreateSessionHandler(&s.cfg.Web, serverCfg, s.cfg, sessionID, s.sessionManager, progressCh) close(progressCh) }() diff --git a/internal/web/session_manager.go b/internal/web/session_manager.go index c414cb3f..2064117e 100644 --- a/internal/web/session_manager.go +++ b/internal/web/session_manager.go @@ -5,7 +5,6 @@ import ( "time" websocket "github.com/gorilla/websocket" - viper "github.com/spf13/viper" config "github.com/inference-gateway/cli/config" logger "github.com/inference-gateway/cli/internal/logger" @@ -14,7 +13,6 @@ import ( // SessionManager tracks and manages all active sessions type SessionManager struct { cfg *config.Config - viper *viper.Viper sessions map[string]*SessionEntry mu sync.RWMutex done chan struct{} @@ -28,10 +26,9 @@ type SessionEntry struct { mu sync.Mutex } -func NewSessionManager(cfg *config.Config, v *viper.Viper) *SessionManager { +func NewSessionManager(cfg *config.Config) *SessionManager { sm := &SessionManager{ cfg: cfg, - viper: v, sessions: make(map[string]*SessionEntry), done: make(chan struct{}), } @@ -46,7 +43,7 @@ func (sm *SessionManager) CreateSession(sessionID string) Session { sm.mu.Lock() defer sm.mu.Unlock() - session := NewLocalPTYSession(sm.cfg, sm.viper) + session := NewLocalPTYSession(sm.cfg) entry := &SessionEntry{ session: session, lastActive: time.Now(), diff --git a/tests/mocks/domain/fake_config_service.go b/tests/mocks/domain/fake_config_service.go deleted file mode 100644 index 368bac6f..00000000 --- a/tests/mocks/domain/fake_config_service.go +++ /dev/null @@ -1,686 +0,0 @@ -// Code generated by counterfeiter. DO NOT EDIT. -package domain - -import ( - "sync" - - "github.com/inference-gateway/cli/config" - "github.com/inference-gateway/cli/internal/domain" -) - -type FakeConfigService struct { - GetAPIKeyStub func() string - getAPIKeyMutex sync.RWMutex - getAPIKeyArgsForCall []struct { - } - getAPIKeyReturns struct { - result1 string - } - getAPIKeyReturnsOnCall map[int]struct { - result1 string - } - GetAgentConfigStub func() *config.AgentConfig - getAgentConfigMutex sync.RWMutex - getAgentConfigArgsForCall []struct { - } - getAgentConfigReturns struct { - result1 *config.AgentConfig - } - getAgentConfigReturnsOnCall map[int]struct { - result1 *config.AgentConfig - } - GetConfigStub func() *config.Config - getConfigMutex sync.RWMutex - getConfigArgsForCall []struct { - } - getConfigReturns struct { - result1 *config.Config - } - getConfigReturnsOnCall map[int]struct { - result1 *config.Config - } - GetGatewayURLStub func() string - getGatewayURLMutex sync.RWMutex - getGatewayURLArgsForCall []struct { - } - getGatewayURLReturns struct { - result1 string - } - getGatewayURLReturnsOnCall map[int]struct { - result1 string - } - GetOutputDirectoryStub func() string - getOutputDirectoryMutex sync.RWMutex - getOutputDirectoryArgsForCall []struct { - } - getOutputDirectoryReturns struct { - result1 string - } - getOutputDirectoryReturnsOnCall map[int]struct { - result1 string - } - GetProtectedPathsStub func() []string - getProtectedPathsMutex sync.RWMutex - getProtectedPathsArgsForCall []struct { - } - getProtectedPathsReturns struct { - result1 []string - } - getProtectedPathsReturnsOnCall map[int]struct { - result1 []string - } - GetSandboxDirectoriesStub func() []string - getSandboxDirectoriesMutex sync.RWMutex - getSandboxDirectoriesArgsForCall []struct { - } - getSandboxDirectoriesReturns struct { - result1 []string - } - getSandboxDirectoriesReturnsOnCall map[int]struct { - result1 []string - } - GetTimeoutStub func() int - getTimeoutMutex sync.RWMutex - getTimeoutArgsForCall []struct { - } - getTimeoutReturns struct { - result1 int - } - getTimeoutReturnsOnCall map[int]struct { - result1 int - } - IsApprovalRequiredStub func(string) bool - isApprovalRequiredMutex sync.RWMutex - isApprovalRequiredArgsForCall []struct { - arg1 string - } - isApprovalRequiredReturns struct { - result1 bool - } - isApprovalRequiredReturnsOnCall map[int]struct { - result1 bool - } - IsBashCommandWhitelistedStub func(string) bool - isBashCommandWhitelistedMutex sync.RWMutex - isBashCommandWhitelistedArgsForCall []struct { - arg1 string - } - isBashCommandWhitelistedReturns struct { - result1 bool - } - isBashCommandWhitelistedReturnsOnCall map[int]struct { - result1 bool - } - invocations map[string][][]interface{} - invocationsMutex sync.RWMutex -} - -func (fake *FakeConfigService) GetAPIKey() string { - fake.getAPIKeyMutex.Lock() - ret, specificReturn := fake.getAPIKeyReturnsOnCall[len(fake.getAPIKeyArgsForCall)] - fake.getAPIKeyArgsForCall = append(fake.getAPIKeyArgsForCall, struct { - }{}) - stub := fake.GetAPIKeyStub - fakeReturns := fake.getAPIKeyReturns - fake.recordInvocation("GetAPIKey", []interface{}{}) - fake.getAPIKeyMutex.Unlock() - if stub != nil { - return stub() - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeConfigService) GetAPIKeyCallCount() int { - fake.getAPIKeyMutex.RLock() - defer fake.getAPIKeyMutex.RUnlock() - return len(fake.getAPIKeyArgsForCall) -} - -func (fake *FakeConfigService) GetAPIKeyCalls(stub func() string) { - fake.getAPIKeyMutex.Lock() - defer fake.getAPIKeyMutex.Unlock() - fake.GetAPIKeyStub = stub -} - -func (fake *FakeConfigService) GetAPIKeyReturns(result1 string) { - fake.getAPIKeyMutex.Lock() - defer fake.getAPIKeyMutex.Unlock() - fake.GetAPIKeyStub = nil - fake.getAPIKeyReturns = struct { - result1 string - }{result1} -} - -func (fake *FakeConfigService) GetAPIKeyReturnsOnCall(i int, result1 string) { - fake.getAPIKeyMutex.Lock() - defer fake.getAPIKeyMutex.Unlock() - fake.GetAPIKeyStub = nil - if fake.getAPIKeyReturnsOnCall == nil { - fake.getAPIKeyReturnsOnCall = make(map[int]struct { - result1 string - }) - } - fake.getAPIKeyReturnsOnCall[i] = struct { - result1 string - }{result1} -} - -func (fake *FakeConfigService) GetAgentConfig() *config.AgentConfig { - fake.getAgentConfigMutex.Lock() - ret, specificReturn := fake.getAgentConfigReturnsOnCall[len(fake.getAgentConfigArgsForCall)] - fake.getAgentConfigArgsForCall = append(fake.getAgentConfigArgsForCall, struct { - }{}) - stub := fake.GetAgentConfigStub - fakeReturns := fake.getAgentConfigReturns - fake.recordInvocation("GetAgentConfig", []interface{}{}) - fake.getAgentConfigMutex.Unlock() - if stub != nil { - return stub() - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeConfigService) GetAgentConfigCallCount() int { - fake.getAgentConfigMutex.RLock() - defer fake.getAgentConfigMutex.RUnlock() - return len(fake.getAgentConfigArgsForCall) -} - -func (fake *FakeConfigService) GetAgentConfigCalls(stub func() *config.AgentConfig) { - fake.getAgentConfigMutex.Lock() - defer fake.getAgentConfigMutex.Unlock() - fake.GetAgentConfigStub = stub -} - -func (fake *FakeConfigService) GetAgentConfigReturns(result1 *config.AgentConfig) { - fake.getAgentConfigMutex.Lock() - defer fake.getAgentConfigMutex.Unlock() - fake.GetAgentConfigStub = nil - fake.getAgentConfigReturns = struct { - result1 *config.AgentConfig - }{result1} -} - -func (fake *FakeConfigService) GetAgentConfigReturnsOnCall(i int, result1 *config.AgentConfig) { - fake.getAgentConfigMutex.Lock() - defer fake.getAgentConfigMutex.Unlock() - fake.GetAgentConfigStub = nil - if fake.getAgentConfigReturnsOnCall == nil { - fake.getAgentConfigReturnsOnCall = make(map[int]struct { - result1 *config.AgentConfig - }) - } - fake.getAgentConfigReturnsOnCall[i] = struct { - result1 *config.AgentConfig - }{result1} -} - -func (fake *FakeConfigService) GetConfig() *config.Config { - fake.getConfigMutex.Lock() - ret, specificReturn := fake.getConfigReturnsOnCall[len(fake.getConfigArgsForCall)] - fake.getConfigArgsForCall = append(fake.getConfigArgsForCall, struct { - }{}) - stub := fake.GetConfigStub - fakeReturns := fake.getConfigReturns - fake.recordInvocation("GetConfig", []interface{}{}) - fake.getConfigMutex.Unlock() - if stub != nil { - return stub() - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeConfigService) GetConfigCallCount() int { - fake.getConfigMutex.RLock() - defer fake.getConfigMutex.RUnlock() - return len(fake.getConfigArgsForCall) -} - -func (fake *FakeConfigService) GetConfigCalls(stub func() *config.Config) { - fake.getConfigMutex.Lock() - defer fake.getConfigMutex.Unlock() - fake.GetConfigStub = stub -} - -func (fake *FakeConfigService) GetConfigReturns(result1 *config.Config) { - fake.getConfigMutex.Lock() - defer fake.getConfigMutex.Unlock() - fake.GetConfigStub = nil - fake.getConfigReturns = struct { - result1 *config.Config - }{result1} -} - -func (fake *FakeConfigService) GetConfigReturnsOnCall(i int, result1 *config.Config) { - fake.getConfigMutex.Lock() - defer fake.getConfigMutex.Unlock() - fake.GetConfigStub = nil - if fake.getConfigReturnsOnCall == nil { - fake.getConfigReturnsOnCall = make(map[int]struct { - result1 *config.Config - }) - } - fake.getConfigReturnsOnCall[i] = struct { - result1 *config.Config - }{result1} -} - -func (fake *FakeConfigService) GetGatewayURL() string { - fake.getGatewayURLMutex.Lock() - ret, specificReturn := fake.getGatewayURLReturnsOnCall[len(fake.getGatewayURLArgsForCall)] - fake.getGatewayURLArgsForCall = append(fake.getGatewayURLArgsForCall, struct { - }{}) - stub := fake.GetGatewayURLStub - fakeReturns := fake.getGatewayURLReturns - fake.recordInvocation("GetGatewayURL", []interface{}{}) - fake.getGatewayURLMutex.Unlock() - if stub != nil { - return stub() - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeConfigService) GetGatewayURLCallCount() int { - fake.getGatewayURLMutex.RLock() - defer fake.getGatewayURLMutex.RUnlock() - return len(fake.getGatewayURLArgsForCall) -} - -func (fake *FakeConfigService) GetGatewayURLCalls(stub func() string) { - fake.getGatewayURLMutex.Lock() - defer fake.getGatewayURLMutex.Unlock() - fake.GetGatewayURLStub = stub -} - -func (fake *FakeConfigService) GetGatewayURLReturns(result1 string) { - fake.getGatewayURLMutex.Lock() - defer fake.getGatewayURLMutex.Unlock() - fake.GetGatewayURLStub = nil - fake.getGatewayURLReturns = struct { - result1 string - }{result1} -} - -func (fake *FakeConfigService) GetGatewayURLReturnsOnCall(i int, result1 string) { - fake.getGatewayURLMutex.Lock() - defer fake.getGatewayURLMutex.Unlock() - fake.GetGatewayURLStub = nil - if fake.getGatewayURLReturnsOnCall == nil { - fake.getGatewayURLReturnsOnCall = make(map[int]struct { - result1 string - }) - } - fake.getGatewayURLReturnsOnCall[i] = struct { - result1 string - }{result1} -} - -func (fake *FakeConfigService) GetOutputDirectory() string { - fake.getOutputDirectoryMutex.Lock() - ret, specificReturn := fake.getOutputDirectoryReturnsOnCall[len(fake.getOutputDirectoryArgsForCall)] - fake.getOutputDirectoryArgsForCall = append(fake.getOutputDirectoryArgsForCall, struct { - }{}) - stub := fake.GetOutputDirectoryStub - fakeReturns := fake.getOutputDirectoryReturns - fake.recordInvocation("GetOutputDirectory", []interface{}{}) - fake.getOutputDirectoryMutex.Unlock() - if stub != nil { - return stub() - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeConfigService) GetOutputDirectoryCallCount() int { - fake.getOutputDirectoryMutex.RLock() - defer fake.getOutputDirectoryMutex.RUnlock() - return len(fake.getOutputDirectoryArgsForCall) -} - -func (fake *FakeConfigService) GetOutputDirectoryCalls(stub func() string) { - fake.getOutputDirectoryMutex.Lock() - defer fake.getOutputDirectoryMutex.Unlock() - fake.GetOutputDirectoryStub = stub -} - -func (fake *FakeConfigService) GetOutputDirectoryReturns(result1 string) { - fake.getOutputDirectoryMutex.Lock() - defer fake.getOutputDirectoryMutex.Unlock() - fake.GetOutputDirectoryStub = nil - fake.getOutputDirectoryReturns = struct { - result1 string - }{result1} -} - -func (fake *FakeConfigService) GetOutputDirectoryReturnsOnCall(i int, result1 string) { - fake.getOutputDirectoryMutex.Lock() - defer fake.getOutputDirectoryMutex.Unlock() - fake.GetOutputDirectoryStub = nil - if fake.getOutputDirectoryReturnsOnCall == nil { - fake.getOutputDirectoryReturnsOnCall = make(map[int]struct { - result1 string - }) - } - fake.getOutputDirectoryReturnsOnCall[i] = struct { - result1 string - }{result1} -} - -func (fake *FakeConfigService) GetProtectedPaths() []string { - fake.getProtectedPathsMutex.Lock() - ret, specificReturn := fake.getProtectedPathsReturnsOnCall[len(fake.getProtectedPathsArgsForCall)] - fake.getProtectedPathsArgsForCall = append(fake.getProtectedPathsArgsForCall, struct { - }{}) - stub := fake.GetProtectedPathsStub - fakeReturns := fake.getProtectedPathsReturns - fake.recordInvocation("GetProtectedPaths", []interface{}{}) - fake.getProtectedPathsMutex.Unlock() - if stub != nil { - return stub() - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeConfigService) GetProtectedPathsCallCount() int { - fake.getProtectedPathsMutex.RLock() - defer fake.getProtectedPathsMutex.RUnlock() - return len(fake.getProtectedPathsArgsForCall) -} - -func (fake *FakeConfigService) GetProtectedPathsCalls(stub func() []string) { - fake.getProtectedPathsMutex.Lock() - defer fake.getProtectedPathsMutex.Unlock() - fake.GetProtectedPathsStub = stub -} - -func (fake *FakeConfigService) GetProtectedPathsReturns(result1 []string) { - fake.getProtectedPathsMutex.Lock() - defer fake.getProtectedPathsMutex.Unlock() - fake.GetProtectedPathsStub = nil - fake.getProtectedPathsReturns = struct { - result1 []string - }{result1} -} - -func (fake *FakeConfigService) GetProtectedPathsReturnsOnCall(i int, result1 []string) { - fake.getProtectedPathsMutex.Lock() - defer fake.getProtectedPathsMutex.Unlock() - fake.GetProtectedPathsStub = nil - if fake.getProtectedPathsReturnsOnCall == nil { - fake.getProtectedPathsReturnsOnCall = make(map[int]struct { - result1 []string - }) - } - fake.getProtectedPathsReturnsOnCall[i] = struct { - result1 []string - }{result1} -} - -func (fake *FakeConfigService) GetSandboxDirectories() []string { - fake.getSandboxDirectoriesMutex.Lock() - ret, specificReturn := fake.getSandboxDirectoriesReturnsOnCall[len(fake.getSandboxDirectoriesArgsForCall)] - fake.getSandboxDirectoriesArgsForCall = append(fake.getSandboxDirectoriesArgsForCall, struct { - }{}) - stub := fake.GetSandboxDirectoriesStub - fakeReturns := fake.getSandboxDirectoriesReturns - fake.recordInvocation("GetSandboxDirectories", []interface{}{}) - fake.getSandboxDirectoriesMutex.Unlock() - if stub != nil { - return stub() - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeConfigService) GetSandboxDirectoriesCallCount() int { - fake.getSandboxDirectoriesMutex.RLock() - defer fake.getSandboxDirectoriesMutex.RUnlock() - return len(fake.getSandboxDirectoriesArgsForCall) -} - -func (fake *FakeConfigService) GetSandboxDirectoriesCalls(stub func() []string) { - fake.getSandboxDirectoriesMutex.Lock() - defer fake.getSandboxDirectoriesMutex.Unlock() - fake.GetSandboxDirectoriesStub = stub -} - -func (fake *FakeConfigService) GetSandboxDirectoriesReturns(result1 []string) { - fake.getSandboxDirectoriesMutex.Lock() - defer fake.getSandboxDirectoriesMutex.Unlock() - fake.GetSandboxDirectoriesStub = nil - fake.getSandboxDirectoriesReturns = struct { - result1 []string - }{result1} -} - -func (fake *FakeConfigService) GetSandboxDirectoriesReturnsOnCall(i int, result1 []string) { - fake.getSandboxDirectoriesMutex.Lock() - defer fake.getSandboxDirectoriesMutex.Unlock() - fake.GetSandboxDirectoriesStub = nil - if fake.getSandboxDirectoriesReturnsOnCall == nil { - fake.getSandboxDirectoriesReturnsOnCall = make(map[int]struct { - result1 []string - }) - } - fake.getSandboxDirectoriesReturnsOnCall[i] = struct { - result1 []string - }{result1} -} - -func (fake *FakeConfigService) GetTimeout() int { - fake.getTimeoutMutex.Lock() - ret, specificReturn := fake.getTimeoutReturnsOnCall[len(fake.getTimeoutArgsForCall)] - fake.getTimeoutArgsForCall = append(fake.getTimeoutArgsForCall, struct { - }{}) - stub := fake.GetTimeoutStub - fakeReturns := fake.getTimeoutReturns - fake.recordInvocation("GetTimeout", []interface{}{}) - fake.getTimeoutMutex.Unlock() - if stub != nil { - return stub() - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeConfigService) GetTimeoutCallCount() int { - fake.getTimeoutMutex.RLock() - defer fake.getTimeoutMutex.RUnlock() - return len(fake.getTimeoutArgsForCall) -} - -func (fake *FakeConfigService) GetTimeoutCalls(stub func() int) { - fake.getTimeoutMutex.Lock() - defer fake.getTimeoutMutex.Unlock() - fake.GetTimeoutStub = stub -} - -func (fake *FakeConfigService) GetTimeoutReturns(result1 int) { - fake.getTimeoutMutex.Lock() - defer fake.getTimeoutMutex.Unlock() - fake.GetTimeoutStub = nil - fake.getTimeoutReturns = struct { - result1 int - }{result1} -} - -func (fake *FakeConfigService) GetTimeoutReturnsOnCall(i int, result1 int) { - fake.getTimeoutMutex.Lock() - defer fake.getTimeoutMutex.Unlock() - fake.GetTimeoutStub = nil - if fake.getTimeoutReturnsOnCall == nil { - fake.getTimeoutReturnsOnCall = make(map[int]struct { - result1 int - }) - } - fake.getTimeoutReturnsOnCall[i] = struct { - result1 int - }{result1} -} - -func (fake *FakeConfigService) IsApprovalRequired(arg1 string) bool { - fake.isApprovalRequiredMutex.Lock() - ret, specificReturn := fake.isApprovalRequiredReturnsOnCall[len(fake.isApprovalRequiredArgsForCall)] - fake.isApprovalRequiredArgsForCall = append(fake.isApprovalRequiredArgsForCall, struct { - arg1 string - }{arg1}) - stub := fake.IsApprovalRequiredStub - fakeReturns := fake.isApprovalRequiredReturns - fake.recordInvocation("IsApprovalRequired", []interface{}{arg1}) - fake.isApprovalRequiredMutex.Unlock() - if stub != nil { - return stub(arg1) - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeConfigService) IsApprovalRequiredCallCount() int { - fake.isApprovalRequiredMutex.RLock() - defer fake.isApprovalRequiredMutex.RUnlock() - return len(fake.isApprovalRequiredArgsForCall) -} - -func (fake *FakeConfigService) IsApprovalRequiredCalls(stub func(string) bool) { - fake.isApprovalRequiredMutex.Lock() - defer fake.isApprovalRequiredMutex.Unlock() - fake.IsApprovalRequiredStub = stub -} - -func (fake *FakeConfigService) IsApprovalRequiredArgsForCall(i int) string { - fake.isApprovalRequiredMutex.RLock() - defer fake.isApprovalRequiredMutex.RUnlock() - argsForCall := fake.isApprovalRequiredArgsForCall[i] - return argsForCall.arg1 -} - -func (fake *FakeConfigService) IsApprovalRequiredReturns(result1 bool) { - fake.isApprovalRequiredMutex.Lock() - defer fake.isApprovalRequiredMutex.Unlock() - fake.IsApprovalRequiredStub = nil - fake.isApprovalRequiredReturns = struct { - result1 bool - }{result1} -} - -func (fake *FakeConfigService) IsApprovalRequiredReturnsOnCall(i int, result1 bool) { - fake.isApprovalRequiredMutex.Lock() - defer fake.isApprovalRequiredMutex.Unlock() - fake.IsApprovalRequiredStub = nil - if fake.isApprovalRequiredReturnsOnCall == nil { - fake.isApprovalRequiredReturnsOnCall = make(map[int]struct { - result1 bool - }) - } - fake.isApprovalRequiredReturnsOnCall[i] = struct { - result1 bool - }{result1} -} - -func (fake *FakeConfigService) IsBashCommandWhitelisted(arg1 string) bool { - fake.isBashCommandWhitelistedMutex.Lock() - ret, specificReturn := fake.isBashCommandWhitelistedReturnsOnCall[len(fake.isBashCommandWhitelistedArgsForCall)] - fake.isBashCommandWhitelistedArgsForCall = append(fake.isBashCommandWhitelistedArgsForCall, struct { - arg1 string - }{arg1}) - stub := fake.IsBashCommandWhitelistedStub - fakeReturns := fake.isBashCommandWhitelistedReturns - fake.recordInvocation("IsBashCommandWhitelisted", []interface{}{arg1}) - fake.isBashCommandWhitelistedMutex.Unlock() - if stub != nil { - return stub(arg1) - } - if specificReturn { - return ret.result1 - } - return fakeReturns.result1 -} - -func (fake *FakeConfigService) IsBashCommandWhitelistedCallCount() int { - fake.isBashCommandWhitelistedMutex.RLock() - defer fake.isBashCommandWhitelistedMutex.RUnlock() - return len(fake.isBashCommandWhitelistedArgsForCall) -} - -func (fake *FakeConfigService) IsBashCommandWhitelistedCalls(stub func(string) bool) { - fake.isBashCommandWhitelistedMutex.Lock() - defer fake.isBashCommandWhitelistedMutex.Unlock() - fake.IsBashCommandWhitelistedStub = stub -} - -func (fake *FakeConfigService) IsBashCommandWhitelistedArgsForCall(i int) string { - fake.isBashCommandWhitelistedMutex.RLock() - defer fake.isBashCommandWhitelistedMutex.RUnlock() - argsForCall := fake.isBashCommandWhitelistedArgsForCall[i] - return argsForCall.arg1 -} - -func (fake *FakeConfigService) IsBashCommandWhitelistedReturns(result1 bool) { - fake.isBashCommandWhitelistedMutex.Lock() - defer fake.isBashCommandWhitelistedMutex.Unlock() - fake.IsBashCommandWhitelistedStub = nil - fake.isBashCommandWhitelistedReturns = struct { - result1 bool - }{result1} -} - -func (fake *FakeConfigService) IsBashCommandWhitelistedReturnsOnCall(i int, result1 bool) { - fake.isBashCommandWhitelistedMutex.Lock() - defer fake.isBashCommandWhitelistedMutex.Unlock() - fake.IsBashCommandWhitelistedStub = nil - if fake.isBashCommandWhitelistedReturnsOnCall == nil { - fake.isBashCommandWhitelistedReturnsOnCall = make(map[int]struct { - result1 bool - }) - } - fake.isBashCommandWhitelistedReturnsOnCall[i] = struct { - result1 bool - }{result1} -} - -func (fake *FakeConfigService) Invocations() map[string][][]interface{} { - fake.invocationsMutex.RLock() - defer fake.invocationsMutex.RUnlock() - copiedInvocations := map[string][][]interface{}{} - for key, value := range fake.invocations { - copiedInvocations[key] = value - } - return copiedInvocations -} - -func (fake *FakeConfigService) recordInvocation(key string, args []interface{}) { - fake.invocationsMutex.Lock() - defer fake.invocationsMutex.Unlock() - if fake.invocations == nil { - fake.invocations = map[string][][]interface{}{} - } - if fake.invocations[key] == nil { - fake.invocations[key] = [][]interface{}{} - } - fake.invocations[key] = append(fake.invocations[key], args) -} - -var _ domain.ConfigService = new(FakeConfigService)