diff --git a/.infer/agents.yaml b/.infer/agents.yaml index f05f9992..5347d41d 100644 --- a/.infer/agents.yaml +++ b/.infer/agents.yaml @@ -1,3 +1,4 @@ +--- agents: - name: mock-agent url: http://localhost:8081 diff --git a/.infer/config.yaml b/.infer/config.yaml index ecb0835f..acb40288 100644 --- a/.infer/config.yaml +++ b/.infer/config.yaml @@ -1,3 +1,4 @@ +--- container_runtime: type: docker gateway: diff --git a/cmd/agents.go b/cmd/agents.go index 9c97f15f..8cb27375 100644 --- a/cmd/agents.go +++ b/cmd/agents.go @@ -289,7 +289,11 @@ func addAgent(cmd *cobra.Command, name, url, artifactsURL, oci string, run bool, Environment: environment, } - if err := config.AddAgent(path, agent); err != nil { + cfg, err := config.LoadAgents(path) + if err != nil { + return err + } + if err := cfg.CreateEntry(agent); err != nil { return err } @@ -320,7 +324,11 @@ func updateAgent(cmd *cobra.Command, name, url, artifactsURL, oci string, run bo return err } - existing, err := config.GetAgent(path, name) + cfg, err := config.LoadAgents(path) + if err != nil { + return err + } + existing, err := cfg.ReadEntry(name) if err != nil { return err } @@ -349,7 +357,7 @@ func updateAgent(cmd *cobra.Command, name, url, artifactsURL, oci string, run bo return fmt.Errorf("--model is required when --run is enabled. Specify a model in the format provider/model (e.g., openai/gpt-5, anthropic/claude-4-5-sonnet)") } - if err := config.UpdateAgent(path, agent); err != nil { + if err := cfg.UpdateEntry(agent); err != nil { return err } @@ -380,7 +388,11 @@ func removeAgent(cmd *cobra.Command, name string) error { return err } - if err := config.RemoveAgent(path, name); err != nil { + cfg, err := config.LoadAgents(path) + if err != nil { + return err + } + if err := cfg.DeleteEntry(name); err != nil { return err } @@ -394,10 +406,11 @@ func listAgents(cmd *cobra.Command, args []string) error { return err } - localAgents, err := config.ListAgents(path) + cfg, err := config.LoadAgents(path) if err != nil { return err } + localAgents := cfg.ListEntries() externalAgents := extractExternalAgents(Cfg) @@ -490,7 +503,11 @@ func showAgent(cmd *cobra.Command, name string) error { return err } - agent, err := config.GetAgent(path, name) + cfg, err := config.LoadAgents(path) + if err != nil { + return err + } + agent, err := cfg.ReadEntry(name) if err != nil { return err } @@ -579,13 +596,17 @@ func enableAgent(cmd *cobra.Command, name string) error { return err } - agent, err := config.GetAgent(path, name) + cfg, err := config.LoadAgents(path) + if err != nil { + return fmt.Errorf("failed to load agents: %w", err) + } + agent, err := cfg.ReadEntry(name) if err != nil { return fmt.Errorf("failed to find agent: %w", err) } agent.Enabled = true - if err := config.UpdateAgent(path, *agent); err != nil { + if err := cfg.UpdateEntry(*agent); err != nil { return fmt.Errorf("failed to enable agent: %w", err) } @@ -602,13 +623,17 @@ func disableAgent(cmd *cobra.Command, name string) error { return err } - agent, err := config.GetAgent(path, name) + cfg, err := config.LoadAgents(path) + if err != nil { + return fmt.Errorf("failed to load agents: %w", err) + } + agent, err := cfg.ReadEntry(name) if err != nil { return fmt.Errorf("failed to find agent: %w", err) } agent.Enabled = false - if err := config.UpdateAgent(path, *agent); err != nil { + if err := cfg.UpdateEntry(*agent); err != nil { return fmt.Errorf("failed to disable agent: %w", err) } diff --git a/cmd/config.go b/cmd/config.go index 83b1bb62..9917c646 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -17,6 +17,7 @@ import ( sdk "github.com/inference-gateway/sdk" config "github.com/inference-gateway/cli/config" + configutils "github.com/inference-gateway/cli/config/utils" container "github.com/inference-gateway/cli/internal/container" formatting "github.com/inference-gateway/cli/internal/formatting" logger "github.com/inference-gateway/cli/internal/logger" @@ -108,11 +109,7 @@ For complete project initialization, use 'infer init' instead.`, } } - if err := os.MkdirAll(filepath.Dir(configPath), 0755); err != nil { - return fmt.Errorf("failed to create config directory: %w", err) - } - - if err := writeConfigAsYAMLWithIndent(configPath, 2); err != nil { + if err := configutils.SaveYAML(configPath, "config", config.DefaultConfig()); err != nil { return fmt.Errorf("failed to create config file: %w", err) } diff --git a/cmd/init.go b/cmd/init.go index 5a88a452..13a5b6a3 100644 --- a/cmd/init.go +++ b/cmd/init.go @@ -1,15 +1,14 @@ package cmd import ( - "bytes" "fmt" "os" "path/filepath" cobra "github.com/spf13/cobra" - yaml "gopkg.in/yaml.v3" config "github.com/inference-gateway/cli/config" + utils "github.com/inference-gateway/cli/config/utils" icons "github.com/inference-gateway/cli/internal/ui/styles/icons" ) @@ -40,7 +39,7 @@ func initializeProject(cmd *cobra.Command) error { //nolint:funlen userspace, _ := cmd.Flags().GetBool("userspace") skipMigrations, _ := cmd.Flags().GetBool("skip-migrations") - var configPath, gitignorePath, scmShortcutsPath, gitShortcutsPath, mcpShortcutsPath, shellsShortcutsPath, exportShortcutsPath, a2aShortcutsPath, mcpPath, keybindingsPath, promptsPath, channelsPath string + var configPath, gitignorePath, scmShortcutsPath, gitShortcutsPath, mcpShortcutsPath, shellsShortcutsPath, exportShortcutsPath, a2aShortcutsPath, mcpPath, keybindingsPath, promptsPath, channelsPath, agentsPath string if userspace { homeDir, err := os.UserHomeDir() @@ -59,6 +58,7 @@ func initializeProject(cmd *cobra.Command) error { //nolint:funlen keybindingsPath = filepath.Join(homeDir, config.ConfigDirName, config.KeybindingsFileName) promptsPath = filepath.Join(homeDir, config.ConfigDirName, config.PromptsFileName) channelsPath = filepath.Join(homeDir, config.ConfigDirName, config.ChannelsFileName) + agentsPath = filepath.Join(homeDir, config.ConfigDirName, config.AgentsFileName) } else { configPath = config.DefaultConfigPath gitignorePath = filepath.Join(config.ConfigDirName, config.GitignoreFileName) @@ -72,15 +72,16 @@ func initializeProject(cmd *cobra.Command) error { //nolint:funlen keybindingsPath = config.DefaultKeybindingsPath promptsPath = config.DefaultPromptsPath channelsPath = config.DefaultChannelsPath + agentsPath = config.DefaultAgentsPath } if !overwrite { - if err := validateFilesNotExist(configPath, gitignorePath, scmShortcutsPath, gitShortcutsPath, mcpShortcutsPath, shellsShortcutsPath, exportShortcutsPath, a2aShortcutsPath, mcpPath, keybindingsPath, promptsPath, channelsPath); err != nil { + if err := validateFilesNotExist(configPath, gitignorePath, scmShortcutsPath, gitShortcutsPath, mcpShortcutsPath, shellsShortcutsPath, exportShortcutsPath, a2aShortcutsPath, mcpPath, keybindingsPath, promptsPath, channelsPath, agentsPath); err != nil { return err } } - if err := writeConfigAsYAMLWithIndent(configPath, 2); err != nil { + if err := utils.SaveYAML(configPath, "config", config.DefaultConfig()); err != nil { return fmt.Errorf("failed to create config file: %w", err) } @@ -139,6 +140,10 @@ tmp/ return fmt.Errorf("failed to create channels config file: %w", err) } + if err := createAgentsConfigFile(agentsPath); err != nil { + return fmt.Errorf("failed to create agents config file: %w", err) + } + var scopeDesc string if userspace { scopeDesc = "userspace" @@ -159,6 +164,7 @@ tmp/ fmt.Printf(" Created: %s\n", keybindingsPath) fmt.Printf(" Created: %s\n", promptsPath) fmt.Printf(" Created: %s\n", channelsPath) + fmt.Printf(" Created: %s\n", agentsPath) if migrated { fmt.Printf("\n%s Migrated legacy `channels:` block from config.yaml into %s.\n", icons.CheckMarkStyle.Render(icons.CheckMark), channelsPath) fmt.Printf(" You can now remove the `channels:` block from %s.\n", configPath) @@ -184,29 +190,6 @@ tmp/ return nil } -// writeConfigAsYAMLWithIndent writes the default configuration to a YAML file with specified indentation -func writeConfigAsYAMLWithIndent(filename string, indent int) error { - defaultConfig := config.DefaultConfig() - - if err := os.MkdirAll(filepath.Dir(filename), 0755); err != nil { - return fmt.Errorf("failed to create directory: %w", err) - } - - var buf bytes.Buffer - yamlEncoder := yaml.NewEncoder(&buf) - yamlEncoder.SetIndent(indent) - - if err := yamlEncoder.Encode(defaultConfig); err != nil { - return fmt.Errorf("failed to marshal config to YAML: %w", err) - } - - if err := yamlEncoder.Close(); err != nil { - return fmt.Errorf("failed to close YAML encoder: %w", err) - } - - return os.WriteFile(filename, buf.Bytes(), 0644) -} - // checkFileExists checks if a file exists and returns an error if it does func checkFileExists(path, description string) error { if _, err := os.Stat(path); err == nil { @@ -466,6 +449,16 @@ func createChannelsConfigFile(path string) (bool, error) { return migrated, nil } +// createAgentsConfigFile writes a fresh agents.yaml seeded from the in-code +// defaults so users can manage A2A agents via `infer agents` commands. +func createAgentsConfigFile(path string) error { + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return fmt.Errorf("failed to create config directory: %w", err) + } + + return config.SaveAgents(path, config.DefaultAgentsConfig()) +} + // createMCPConfigFile creates the MCP configuration YAML file func createMCPConfigFile(path string) error { if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { diff --git a/cmd/init_test.go b/cmd/init_test.go index 07b9d7bd..3042b1f1 100644 --- a/cmd/init_test.go +++ b/cmd/init_test.go @@ -7,6 +7,9 @@ import ( "testing" cobra "github.com/spf13/cobra" + + config "github.com/inference-gateway/cli/config" + configutils "github.com/inference-gateway/cli/config/utils" ) func TestInitializeProject(t *testing.T) { @@ -96,7 +99,7 @@ func TestInitializeProject(t *testing.T) { } } -func TestWriteConfigAsYAMLWithIndent(t *testing.T) { +func TestInitWritesConfigYAMLWithDocMarker(t *testing.T) { tmpDir, err := os.MkdirTemp("", "infer-config-test-*") if err != nil { t.Fatalf("failed to create temp dir: %v", err) @@ -105,23 +108,22 @@ func TestWriteConfigAsYAMLWithIndent(t *testing.T) { configPath := tmpDir + "/.infer/config.yaml" - err = writeConfigAsYAMLWithIndent(configPath, 2) - if err != nil { - t.Errorf("writeConfigAsYAMLWithIndent() error = %v", err) - return + if err := configutils.SaveYAML(configPath, "config", config.DefaultConfig()); err != nil { + t.Fatalf("SaveYAML() error = %v", err) } if _, err := os.Stat(configPath); os.IsNotExist(err) { - t.Errorf("expected config file to be created") - return + t.Fatal("expected config file to be created") } content, err := os.ReadFile(configPath) if err != nil { - t.Errorf("failed to read config file: %v", err) - return + t.Fatalf("failed to read config file: %v", err) } + if !strings.HasPrefix(string(content), "---\n") { + t.Errorf("config file should start with `---\\n`, got %q", string(content[:min(8, len(content))])) + } if !strings.Contains(string(content), "gateway:") { t.Errorf("config file does not contain expected gateway section") } diff --git a/cmd/mcp.go b/cmd/mcp.go index 5ee0ace7..bd11fde0 100644 --- a/cmd/mcp.go +++ b/cmd/mcp.go @@ -314,7 +314,11 @@ func addMCPServer(cmd *cobra.Command, args []string) error { configPath := getMCPConfigPath(cmd) - if err := config.AddMCPServer(configPath, server); err != nil { + cfg, err := config.LoadMCP(configPath) + if err != nil { + return fmt.Errorf("failed to load MCP config: %w", err) + } + if err := cfg.CreateEntry(server); err != nil { return fmt.Errorf("failed to add MCP server: %w", err) } @@ -341,7 +345,11 @@ func removeMCPServer(cmd *cobra.Command, args []string) error { configPath := getMCPConfigPath(cmd) - if err := config.RemoveMCPServer(configPath, name); err != nil { + cfg, err := config.LoadMCP(configPath) + if err != nil { + return fmt.Errorf("failed to load MCP config: %w", err) + } + if err := cfg.DeleteEntry(name); err != nil { return fmt.Errorf("failed to remove MCP server: %w", err) } @@ -358,7 +366,11 @@ func updateMCPServer(cmd *cobra.Command, args []string) error { configPath := getMCPConfigPath(cmd) - existing, err := config.GetMCPServer(configPath, name) + cfg, err := config.LoadMCP(configPath) + if err != nil { + return fmt.Errorf("failed to load MCP config: %w", err) + } + existing, err := cfg.ReadEntry(name) if err != nil { return fmt.Errorf("failed to get MCP server: %w", err) } @@ -399,7 +411,7 @@ func updateMCPServer(cmd *cobra.Command, args []string) error { existing.Enabled = enabled } - if err := config.UpdateMCPServer(configPath, *existing); err != nil { + if err := cfg.UpdateEntry(*existing); err != nil { return fmt.Errorf("failed to update MCP server: %w", err) } @@ -416,14 +428,18 @@ func enableMCPServer(cmd *cobra.Command, args []string) error { configPath := getMCPConfigPath(cmd) - server, err := config.GetMCPServer(configPath, name) + cfg, err := config.LoadMCP(configPath) + if err != nil { + return fmt.Errorf("failed to load MCP config: %w", err) + } + server, err := cfg.ReadEntry(name) if err != nil { return fmt.Errorf("failed to get MCP server: %w", err) } server.Enabled = true - if err := config.UpdateMCPServer(configPath, *server); err != nil { + if err := cfg.UpdateEntry(*server); err != nil { return fmt.Errorf("failed to enable MCP server: %w", err) } @@ -441,14 +457,18 @@ func disableMCPServer(cmd *cobra.Command, args []string) error { configPath := getMCPConfigPath(cmd) - server, err := config.GetMCPServer(configPath, name) + cfg, err := config.LoadMCP(configPath) + if err != nil { + return fmt.Errorf("failed to load MCP config: %w", err) + } + server, err := cfg.ReadEntry(name) if err != nil { return fmt.Errorf("failed to get MCP server: %w", err) } server.Enabled = false - if err := config.UpdateMCPServer(configPath, *server); err != nil { + if err := cfg.UpdateEntry(*server); err != nil { return fmt.Errorf("failed to disable MCP server: %w", err) } diff --git a/config/agents.go b/config/agents.go index 3fa7753c..46aa1728 100644 --- a/config/agents.go +++ b/config/agents.go @@ -1,20 +1,20 @@ package config import ( - "bytes" - "fmt" - "os" - "path/filepath" "strings" - yaml "gopkg.in/yaml.v3" + utils "github.com/inference-gateway/cli/config/utils" ) // AgentsConfig represents the agents.yaml configuration file type AgentsConfig struct { Agents []AgentEntry `yaml:"agents" mapstructure:"agents"` + + path string } +var _ CollectionConfig[AgentEntry] = (*AgentsConfig)(nil) + // AgentEntry represents a single A2A agent configuration type AgentEntry struct { Name string `yaml:"name" mapstructure:"name"` @@ -73,142 +73,69 @@ func (a *AgentEntry) GetEnvironmentWithModel() map[string]string { // returns the in-code defaults so callers can treat absence as "use // defaults" without special-casing. func LoadAgents(path string) (*AgentsConfig, error) { - if _, err := os.Stat(path); os.IsNotExist(err) { - return DefaultAgentsConfig(), nil - } - - data, err := os.ReadFile(path) + cfg, err := utils.LoadYAML(path, "agents", DefaultAgentsConfig) if err != nil { - return nil, fmt.Errorf("failed to read agents config: %w", err) - } - - expandedData := os.ExpandEnv(string(data)) - - var cfg AgentsConfig - if err := yaml.Unmarshal([]byte(expandedData), &cfg); err != nil { - return nil, fmt.Errorf("failed to parse agents config: %w", err) + return nil, err } - - return &cfg, nil + cfg.path = path + return cfg, nil } // SaveAgents writes the agents configuration to disk, creating any // missing parent directories. func SaveAgents(path string, cfg *AgentsConfig) error { - var buf bytes.Buffer - encoder := yaml.NewEncoder(&buf) - encoder.SetIndent(2) - - if err := encoder.Encode(cfg); err != nil { - return fmt.Errorf("failed to marshal agents config: %w", err) - } - - if err := encoder.Close(); err != nil { - return fmt.Errorf("failed to close encoder: %w", err) - } - - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { - return fmt.Errorf("failed to create config directory: %w", err) - } - - if err := os.WriteFile(path, buf.Bytes(), 0644); err != nil { - return fmt.Errorf("failed to write agents config: %w", err) - } - - return nil + return utils.SaveYAML(path, "agents", cfg) } -// AddAgent adds a new agent to the configuration file. -func AddAgent(path string, agent AgentEntry) error { - cfg, err := LoadAgents(path) - if err != nil { - return err - } - - for _, existing := range cfg.Agents { - if existing.Name == agent.Name { - return fmt.Errorf("agent with name '%s' already exists", agent.Name) - } - } +func agentName(e AgentEntry) string { return e.Name } - cfg.Agents = append(cfg.Agents, agent) - return SaveAgents(path, cfg) -} +const agentKind = "agent" -// UpdateAgent updates an existing agent by name. -func UpdateAgent(path string, agent AgentEntry) error { - cfg, err := LoadAgents(path) +// CreateEntry implements CollectionConfig. +func (c *AgentsConfig) CreateEntry(entry AgentEntry) error { + next, err := appendEntry(c.Agents, entry, entry.Name, agentName, agentKind) if err != nil { return err } + c.Agents = next + return SaveAgents(c.path, c) +} - for i, existing := range cfg.Agents { - if existing.Name == agent.Name { - cfg.Agents[i] = agent - return SaveAgents(path, cfg) - } - } - - return fmt.Errorf("agent with name '%s' not found", agent.Name) +// ReadEntry implements CollectionConfig. +func (c *AgentsConfig) ReadEntry(name string) (*AgentEntry, error) { + return findEntry(c.Agents, name, agentName, agentKind) } -// RemoveAgent removes an agent by name. -func RemoveAgent(path, name string) error { - cfg, err := LoadAgents(path) +// UpdateEntry implements CollectionConfig. +func (c *AgentsConfig) UpdateEntry(entry AgentEntry) error { + next, err := replaceEntry(c.Agents, entry, entry.Name, agentName, agentKind) if err != nil { return err } - - found := false - newAgents := make([]AgentEntry, 0, len(cfg.Agents)) - for _, agent := range cfg.Agents { - if agent.Name != name { - newAgents = append(newAgents, agent) - } else { - found = true - } - } - - if !found { - return fmt.Errorf("agent with name '%s' not found", name) - } - - cfg.Agents = newAgents - return SaveAgents(path, cfg) + c.Agents = next + return SaveAgents(c.path, c) } -// ListAgents returns all configured agents. -func ListAgents(path string) ([]AgentEntry, error) { - cfg, err := LoadAgents(path) +// DeleteEntry implements CollectionConfig. +func (c *AgentsConfig) DeleteEntry(name string) error { + next, err := removeEntry(c.Agents, name, agentName, agentKind) if err != nil { - return nil, err + return err } - return cfg.Agents, nil + c.Agents = next + return SaveAgents(c.path, c) } -// GetAgent returns a single agent entry by name. -func GetAgent(path, name string) (*AgentEntry, error) { - cfg, err := LoadAgents(path) - if err != nil { - return nil, err - } - - for _, agent := range cfg.Agents { - if agent.Name == name { - return &agent, nil - } - } - - return nil, fmt.Errorf("agent with name '%s' not found", name) -} +// ListEntries implements CollectionConfig. +func (c *AgentsConfig) ListEntries() []AgentEntry { return c.Agents } // GetAgentURLs returns URLs of all configured agents. func GetAgentURLs(path string) ([]string, error) { - agents, err := ListAgents(path) + cfg, err := LoadAgents(path) if err != nil { return nil, err } - + agents := cfg.ListEntries() urls := make([]string, 0, len(agents)) for _, agent := range agents { urls = append(urls, agent.URL) diff --git a/config/agents_persistence_test.go b/config/agents_test.go similarity index 54% rename from config/agents_persistence_test.go rename to config/agents_test.go index b91c09db..3dcaca76 100644 --- a/config/agents_persistence_test.go +++ b/config/agents_test.go @@ -9,7 +9,7 @@ import ( require "github.com/stretchr/testify/require" ) -func TestAddAgent(t *testing.T) { +func TestCreateEntry_Agent(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") @@ -24,170 +24,138 @@ func TestAddAgent(t *testing.T) { }, } - if err := config.AddAgent(agentsPath, agent); err != nil { - t.Fatalf("Failed to add agent: %v", err) - } + cfg, err := config.LoadAgents(agentsPath) + require.NoError(t, err) + require.NoError(t, cfg.CreateEntry(agent)) + if _, err := os.Stat(agentsPath); os.IsNotExist(err) { t.Fatal("Agents config file was not created") } - cfg, err := config.LoadAgents(agentsPath) - if err != nil { - t.Fatalf("Failed to load config: %v", err) - } - if len(cfg.Agents) != 1 { - t.Fatalf("Expected 1 agent, got %d", len(cfg.Agents)) - } - if cfg.Agents[0].Name != agent.Name { - t.Errorf("Expected name %s, got %s", agent.Name, cfg.Agents[0].Name) - } - if cfg.Agents[0].URL != agent.URL { - t.Errorf("Expected URL %s, got %s", agent.URL, cfg.Agents[0].URL) - } - if cfg.Agents[0].OCI != agent.OCI { - t.Errorf("Expected OCI %s, got %s", agent.OCI, cfg.Agents[0].OCI) - } - if cfg.Agents[0].Run != agent.Run { - t.Errorf("Expected Run %v, got %v", agent.Run, cfg.Agents[0].Run) - } - if len(cfg.Agents[0].Environment) != 2 { - t.Errorf("Expected 2 environment variables, got %d", len(cfg.Agents[0].Environment)) - } + reloaded, err := config.LoadAgents(agentsPath) + require.NoError(t, err) + require.Len(t, reloaded.Agents, 1) + require.Equal(t, agent.Name, reloaded.Agents[0].Name) + require.Equal(t, agent.URL, reloaded.Agents[0].URL) + require.Equal(t, agent.OCI, reloaded.Agents[0].OCI) + require.Equal(t, agent.Run, reloaded.Agents[0].Run) + require.Len(t, reloaded.Agents[0].Environment, 2) } -func TestAddAgent_Duplicate(t *testing.T) { +func TestCreateEntry_Agent_Duplicate(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") agent := config.AgentEntry{Name: "test-agent", URL: "https://agent.example.com"} - if err := config.AddAgent(agentsPath, agent); err != nil { - t.Fatalf("Failed to add agent: %v", err) - } - if err := config.AddAgent(agentsPath, agent); err == nil { - t.Fatal("Expected error when adding duplicate agent, got nil") - } + cfg, err := config.LoadAgents(agentsPath) + require.NoError(t, err) + require.NoError(t, cfg.CreateEntry(agent)) + require.Error(t, cfg.CreateEntry(agent), "expected duplicate-name error") } -func TestRemoveAgent(t *testing.T) { +func TestDeleteEntry_Agent(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") agent1 := config.AgentEntry{Name: "agent1", URL: "https://agent1.example.com"} agent2 := config.AgentEntry{Name: "agent2", URL: "https://agent2.example.com"} - require.NoError(t, config.AddAgent(agentsPath, agent1)) - require.NoError(t, config.AddAgent(agentsPath, agent2)) - - if err := config.RemoveAgent(agentsPath, "agent1"); err != nil { - t.Fatalf("Failed to remove agent: %v", err) - } + cfg, err := config.LoadAgents(agentsPath) + require.NoError(t, err) + require.NoError(t, cfg.CreateEntry(agent1)) + require.NoError(t, cfg.CreateEntry(agent2)) + require.NoError(t, cfg.DeleteEntry("agent1")) - agents, err := config.ListAgents(agentsPath) - if err != nil { - t.Fatalf("Failed to list agents: %v", err) - } - if len(agents) != 1 { - t.Fatalf("Expected 1 agent, got %d", len(agents)) - } - if agents[0].Name != "agent2" { - t.Errorf("Expected agent2 to remain, got %s", agents[0].Name) - } + reloaded, err := config.LoadAgents(agentsPath) + require.NoError(t, err) + agents := reloaded.ListEntries() + require.Len(t, agents, 1) + require.Equal(t, "agent2", agents[0].Name) } -func TestRemoveAgent_Nonexistent(t *testing.T) { +func TestDeleteEntry_Agent_Nonexistent(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - if err := config.RemoveAgent(agentsPath, "nonexistent"); err == nil { - t.Fatal("Expected error when removing nonexistent agent, got nil") - } + cfg, err := config.LoadAgents(agentsPath) + require.NoError(t, err) + require.Error(t, cfg.DeleteEntry("nonexistent")) } -func TestListAgents(t *testing.T) { +func TestListEntries_Agents(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - agents := []config.AgentEntry{ + cfg, err := config.LoadAgents(agentsPath) + require.NoError(t, err) + for _, agent := range []config.AgentEntry{ {Name: "agent1", URL: "https://agent1.example.com"}, {Name: "agent2", URL: "https://agent2.example.com"}, {Name: "agent3", URL: "https://agent3.example.com"}, - } - for _, agent := range agents { - require.NoError(t, config.AddAgent(agentsPath, agent)) + } { + require.NoError(t, cfg.CreateEntry(agent)) } - listed, err := config.ListAgents(agentsPath) - if err != nil { - t.Fatalf("Failed to list agents: %v", err) - } - if len(listed) != 3 { - t.Fatalf("Expected 3 agents, got %d", len(listed)) - } + reloaded, err := config.LoadAgents(agentsPath) + require.NoError(t, err) + require.Len(t, reloaded.ListEntries(), 3) } -func TestGetAgent(t *testing.T) { +func TestReadEntry_Agent(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") agent := config.AgentEntry{Name: "test-agent", URL: "https://agent.example.com"} - require.NoError(t, config.AddAgent(agentsPath, agent)) - retrieved, err := config.GetAgent(agentsPath, "test-agent") - if err != nil { - t.Fatalf("Failed to get agent: %v", err) - } - if retrieved.Name != agent.Name { - t.Errorf("Expected name %s, got %s", agent.Name, retrieved.Name) - } - if retrieved.URL != agent.URL { - t.Errorf("Expected URL %s, got %s", agent.URL, retrieved.URL) - } + cfg, err := config.LoadAgents(agentsPath) + require.NoError(t, err) + require.NoError(t, cfg.CreateEntry(agent)) + + retrieved, err := cfg.ReadEntry("test-agent") + require.NoError(t, err) + require.Equal(t, agent.Name, retrieved.Name) + require.Equal(t, agent.URL, retrieved.URL) } -func TestGetAgent_Nonexistent(t *testing.T) { +func TestReadEntry_Agent_Nonexistent(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - if _, err := config.GetAgent(agentsPath, "nonexistent"); err == nil { - t.Fatal("Expected error when getting nonexistent agent, got nil") - } + cfg, err := config.LoadAgents(agentsPath) + require.NoError(t, err) + _, err = cfg.ReadEntry("nonexistent") + require.Error(t, err) } func TestGetAgentURLs(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - agents := []config.AgentEntry{ + cfg, err := config.LoadAgents(agentsPath) + require.NoError(t, err) + for _, agent := range []config.AgentEntry{ {Name: "agent1", URL: "https://agent1.example.com"}, {Name: "agent2", URL: "https://agent2.example.com"}, - } - for _, agent := range agents { - require.NoError(t, config.AddAgent(agentsPath, agent)) + } { + require.NoError(t, cfg.CreateEntry(agent)) } urls, err := config.GetAgentURLs(agentsPath) - if err != nil { - t.Fatalf("Failed to get agent URLs: %v", err) - } - if len(urls) != 2 { - t.Fatalf("Expected 2 URLs, got %d", len(urls)) - } + require.NoError(t, err) + require.Len(t, urls, 2) expectedURLs := map[string]bool{ "https://agent1.example.com": false, "https://agent2.example.com": false, } for _, url := range urls { - if _, exists := expectedURLs[url]; !exists { - t.Errorf("Unexpected URL: %s", url) - } + _, exists := expectedURLs[url] + require.True(t, exists, "unexpected URL: %s", url) expectedURLs[url] = true } for url, found := range expectedURLs { - if !found { - t.Errorf("Expected URL not found: %s", url) - } + require.True(t, found, "expected URL not found: %s", url) } } @@ -196,15 +164,11 @@ func TestLoadAgents_NonexistentFile(t *testing.T) { agentsPath := filepath.Join(tmpDir, "nonexistent.yaml") cfg, err := config.LoadAgents(agentsPath) - if err != nil { - t.Fatalf("Expected no error for nonexistent file, got: %v", err) - } - if len(cfg.Agents) != 0 { - t.Errorf("Expected empty agents list, got %d agents", len(cfg.Agents)) - } + require.NoError(t, err) + require.Empty(t, cfg.Agents) } -func TestUpdateAgent(t *testing.T) { +func TestUpdateEntry_Agent(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") @@ -218,9 +182,11 @@ func TestUpdateAgent(t *testing.T) { "API_KEY": "secret", }, } - require.NoError(t, config.AddAgent(agentsPath, agent)) + cfg, err := config.LoadAgents(agentsPath) + require.NoError(t, err) + require.NoError(t, cfg.CreateEntry(agent)) - updatedAgent := config.AgentEntry{ + updated := config.AgentEntry{ Name: "test-agent", URL: "https://new-agent.example.com", OCI: "ghcr.io/org/test-agent:v2", @@ -230,42 +196,30 @@ func TestUpdateAgent(t *testing.T) { "DEBUG": "true", }, } - require.NoError(t, config.UpdateAgent(agentsPath, updatedAgent)) + require.NoError(t, cfg.UpdateEntry(updated)) - retrieved, err := config.GetAgent(agentsPath, "test-agent") + reloaded, err := config.LoadAgents(agentsPath) require.NoError(t, err) - - if retrieved.URL != updatedAgent.URL { - t.Errorf("Expected URL %s, got %s", updatedAgent.URL, retrieved.URL) - } - if retrieved.OCI != updatedAgent.OCI { - t.Errorf("Expected OCI %s, got %s", updatedAgent.OCI, retrieved.OCI) - } - if retrieved.Run != updatedAgent.Run { - t.Errorf("Expected Run %v, got %v", updatedAgent.Run, retrieved.Run) - } - if retrieved.Model != updatedAgent.Model { - t.Errorf("Expected Model %s, got %s", updatedAgent.Model, retrieved.Model) - } - if len(retrieved.Environment) != 1 || retrieved.Environment["DEBUG"] != "true" { - t.Errorf("Expected Environment to be updated, got %v", retrieved.Environment) - } - - agents, err := config.ListAgents(agentsPath) + retrieved, err := reloaded.ReadEntry("test-agent") require.NoError(t, err) - if len(agents) != 1 { - t.Errorf("Expected 1 agent after update, got %d", len(agents)) - } + + require.Equal(t, updated.URL, retrieved.URL) + require.Equal(t, updated.OCI, retrieved.OCI) + require.Equal(t, updated.Run, retrieved.Run) + require.Equal(t, updated.Model, retrieved.Model) + require.Len(t, retrieved.Environment, 1) + require.Equal(t, "true", retrieved.Environment["DEBUG"]) + require.Len(t, reloaded.ListEntries(), 1) } -func TestUpdateAgent_Nonexistent(t *testing.T) { +func TestUpdateEntry_Agent_Nonexistent(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") agent := config.AgentEntry{Name: "nonexistent", URL: "https://agent.example.com"} - if err := config.UpdateAgent(agentsPath, agent); err == nil { - t.Fatal("Expected error when updating nonexistent agent, got nil") - } + cfg, err := config.LoadAgents(agentsPath) + require.NoError(t, err) + require.Error(t, cfg.UpdateEntry(agent)) } func TestLoadAgents_EnvironmentVariableExpansion(t *testing.T) { diff --git a/config/channels.go b/config/channels.go index fd7d3ada..9d70ac64 100644 --- a/config/channels.go +++ b/config/channels.go @@ -1,12 +1,7 @@ package config import ( - "bytes" - "fmt" - "os" - "path/filepath" - - yaml "gopkg.in/yaml.v3" + utils "github.com/inference-gateway/cli/config/utils" ) const ( @@ -46,23 +41,7 @@ func DefaultChannelsConfig() *ChannelsConfig { // os.ExpandEnv so `${BOT_TOKEN}`-style references resolve from the // environment. func LoadChannels(path string) (*ChannelsConfig, error) { - if _, err := os.Stat(path); os.IsNotExist(err) { - return DefaultChannelsConfig(), nil - } - - data, err := os.ReadFile(path) - if err != nil { - return nil, fmt.Errorf("failed to read channels config: %w", err) - } - - expandedData := os.ExpandEnv(string(data)) - - var cfg ChannelsConfig - if err := yaml.Unmarshal([]byte(expandedData), &cfg); err != nil { - return nil, fmt.Errorf("failed to parse channels config: %w", err) - } - - return &cfg, nil + return utils.LoadYAML(path, "channels", DefaultChannelsConfig) } // SaveChannels writes the channels configuration to disk, creating any @@ -70,27 +49,5 @@ func LoadChannels(path string) (*ChannelsConfig, error) { // so callers should ensure it is also listed in // tools.sandbox.protected_paths. func SaveChannels(path string, cfg *ChannelsConfig) error { - var buf bytes.Buffer - buf.WriteString("---\n") - - encoder := yaml.NewEncoder(&buf) - encoder.SetIndent(2) - - if err := encoder.Encode(cfg); err != nil { - return fmt.Errorf("failed to marshal channels config: %w", err) - } - - if err := encoder.Close(); err != nil { - return fmt.Errorf("failed to close encoder: %w", err) - } - - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { - return fmt.Errorf("failed to create config directory: %w", err) - } - - if err := os.WriteFile(path, buf.Bytes(), 0644); err != nil { - return fmt.Errorf("failed to write channels config: %w", err) - } - - return nil + return utils.SaveYAML(path, "channels", cfg) } diff --git a/config/collection.go b/config/collection.go new file mode 100644 index 00000000..90c7beea --- /dev/null +++ b/config/collection.go @@ -0,0 +1,73 @@ +package config + +import "fmt" + +// CollectionConfig is implemented by file-backed configs whose payload is a +// named, mutable collection of entries. MCPConfig (collection of +// MCPServerEntry) and AgentsConfig (collection of AgentEntry) implement it. +// +// Methods mutate the file on disk; the in-memory receiver stays in sync. +// Implementations must remember the file path they were loaded from so +// each call can persist back to the same location. +type CollectionConfig[E any] interface { + CreateEntry(entry E) error + ReadEntry(name string) (*E, error) + UpdateEntry(entry E) error + DeleteEntry(name string) error + ListEntries() []E +} + +// appendEntry returns slice with entry appended, after rejecting duplicates +// by name. kind is used in the duplicate-error message (e.g. "MCP server"). +func appendEntry[E any](slice []E, entry E, name string, nameOf func(E) string, kind string) ([]E, error) { + for _, existing := range slice { + if nameOf(existing) == name { + return nil, fmt.Errorf("%s with name '%s' already exists", kind, name) + } + } + return append(slice, entry), nil +} + +// replaceEntry returns slice with the entry whose name matches replaced +// by the new entry. Returns an error if no entry with that name exists. +func replaceEntry[E any](slice []E, entry E, name string, nameOf func(E) string, kind string) ([]E, error) { + for i, existing := range slice { + if nameOf(existing) == name { + next := make([]E, len(slice)) + copy(next, slice) + next[i] = entry + return next, nil + } + } + return nil, fmt.Errorf("%s with name '%s' not found", kind, name) +} + +// removeEntry returns slice with the named entry removed. Returns an error +// if no entry with that name exists. +func removeEntry[E any](slice []E, name string, nameOf func(E) string, kind string) ([]E, error) { + out := make([]E, 0, len(slice)) + found := false + for _, existing := range slice { + if nameOf(existing) == name { + found = true + continue + } + out = append(out, existing) + } + if !found { + return nil, fmt.Errorf("%s with name '%s' not found", kind, name) + } + return out, nil +} + +// findEntry returns a copy of the entry whose name matches, or an error if +// no entry with that name exists. +func findEntry[E any](slice []E, name string, nameOf func(E) string, kind string) (*E, error) { + for _, existing := range slice { + if nameOf(existing) == name { + e := existing + return &e, nil + } + } + return nil, fmt.Errorf("%s with name '%s' not found", kind, name) +} diff --git a/config/collection_test.go b/config/collection_test.go new file mode 100644 index 00000000..d7f38957 --- /dev/null +++ b/config/collection_test.go @@ -0,0 +1,91 @@ +package config + +import ( + "strings" + "testing" +) + +type collectionFixture struct{ Name, Value string } + +func nameOfFixture(e collectionFixture) string { return e.Name } + +const fixtureKind = "fixture" + +func TestAppendEntry_AppendsNew(t *testing.T) { + in := []collectionFixture{{Name: "a"}} + out, err := appendEntry(in, collectionFixture{Name: "b"}, "b", nameOfFixture, fixtureKind) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + if len(out) != 2 || out[1].Name != "b" { + t.Errorf("got %+v", out) + } +} + +func TestAppendEntry_RejectsDuplicate(t *testing.T) { + in := []collectionFixture{{Name: "a"}} + _, err := appendEntry(in, collectionFixture{Name: "a"}, "a", nameOfFixture, fixtureKind) + if err == nil { + t.Fatal("expected duplicate error") + } + if !strings.Contains(err.Error(), "already exists") || !strings.Contains(err.Error(), fixtureKind) { + t.Errorf("error should mention duplicate and kind: %v", err) + } +} + +func TestReplaceEntry_ReplacesExisting(t *testing.T) { + in := []collectionFixture{{Name: "a", Value: "1"}, {Name: "b", Value: "2"}} + out, err := replaceEntry(in, collectionFixture{Name: "b", Value: "X"}, "b", nameOfFixture, fixtureKind) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + if out[1].Value != "X" { + t.Errorf("expected replacement, got %+v", out) + } + if in[1].Value != "2" { + t.Errorf("input slice mutated: got %+v", in) + } +} + +func TestReplaceEntry_NotFound(t *testing.T) { + _, err := replaceEntry([]collectionFixture{{Name: "a"}}, collectionFixture{Name: "b"}, "b", nameOfFixture, fixtureKind) + if err == nil { + t.Fatal("expected not-found error") + } +} + +func TestRemoveEntry_RemovesExisting(t *testing.T) { + in := []collectionFixture{{Name: "a"}, {Name: "b"}, {Name: "c"}} + out, err := removeEntry(in, "b", nameOfFixture, fixtureKind) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + if len(out) != 2 || out[0].Name != "a" || out[1].Name != "c" { + t.Errorf("got %+v", out) + } +} + +func TestRemoveEntry_NotFound(t *testing.T) { + _, err := removeEntry([]collectionFixture{{Name: "a"}}, "z", nameOfFixture, fixtureKind) + if err == nil { + t.Fatal("expected not-found error") + } +} + +func TestFindEntry_FindsExisting(t *testing.T) { + in := []collectionFixture{{Name: "a", Value: "1"}, {Name: "b", Value: "2"}} + got, err := findEntry(in, "b", nameOfFixture, fixtureKind) + if err != nil { + t.Fatalf("unexpected: %v", err) + } + if got.Value != "2" { + t.Errorf("got %+v", got) + } +} + +func TestFindEntry_NotFound(t *testing.T) { + _, err := findEntry([]collectionFixture{{Name: "a"}}, "z", nameOfFixture, fixtureKind) + if err == nil { + t.Fatal("expected not-found error") + } +} diff --git a/config/keybindings.go b/config/keybindings.go index 509e7287..cd450382 100644 --- a/config/keybindings.go +++ b/config/keybindings.go @@ -1,12 +1,7 @@ package config import ( - "bytes" - "fmt" - "os" - "path/filepath" - - yaml "gopkg.in/yaml.v3" + utils "github.com/inference-gateway/cli/config/utils" ) const ( @@ -25,51 +20,17 @@ func DefaultKeybindingsConfig() *KeybindingsConfig { // LoadKeybindings reads keybindings.yaml from disk. When the file is // missing it returns the in-code defaults so callers can treat absence -// as "use defaults" without special-casing. +// as "use defaults" without special-casing. The file body is run through +// os.ExpandEnv — any literal `${…}` token in a customised binding must be +// escaped as `$$…`. func LoadKeybindings(path string) (*KeybindingsConfig, error) { - if _, err := os.Stat(path); os.IsNotExist(err) { - return DefaultKeybindingsConfig(), nil - } - - data, err := os.ReadFile(path) - if err != nil { - return nil, fmt.Errorf("failed to read keybindings config: %w", err) - } - - var cfg KeybindingsConfig - if err := yaml.Unmarshal(data, &cfg); err != nil { - return nil, fmt.Errorf("failed to parse keybindings config: %w", err) - } - - return &cfg, nil + return utils.LoadYAML(path, "keybindings", DefaultKeybindingsConfig) } // SaveKeybindings writes the keybindings configuration to disk, creating // any missing parent directories. func SaveKeybindings(path string, cfg *KeybindingsConfig) error { - var buf bytes.Buffer - buf.WriteString("---\n") - - encoder := yaml.NewEncoder(&buf) - encoder.SetIndent(2) - - if err := encoder.Encode(cfg); err != nil { - return fmt.Errorf("failed to marshal keybindings config: %w", err) - } - - if err := encoder.Close(); err != nil { - return fmt.Errorf("failed to close encoder: %w", err) - } - - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { - return fmt.Errorf("failed to create config directory: %w", err) - } - - if err := os.WriteFile(path, buf.Bytes(), 0644); err != nil { - return fmt.Errorf("failed to write keybindings config: %w", err) - } - - return nil + return utils.SaveYAML(path, "keybindings", cfg) } // GetDefaultKeybindings returns the default keybinding configuration diff --git a/config/keybindings_persistence_test.go b/config/keybindings_test.go similarity index 100% rename from config/keybindings_persistence_test.go rename to config/keybindings_test.go diff --git a/config/mcp.go b/config/mcp.go index c7699cee..2eabbc57 100644 --- a/config/mcp.go +++ b/config/mcp.go @@ -1,13 +1,10 @@ package config import ( - "bytes" "fmt" - "os" - "path/filepath" "strings" - yaml "gopkg.in/yaml.v3" + utils "github.com/inference-gateway/cli/config/utils" ) // MCPConfig represents the mcp.yaml configuration file @@ -19,8 +16,12 @@ type MCPConfig struct { LivenessProbeInterval int `yaml:"liveness_probe_interval,omitempty" mapstructure:"liveness_probe_interval,omitempty"` MaxRetries int `yaml:"max_retries,omitempty" mapstructure:"max_retries,omitempty"` Servers []MCPServerEntry `yaml:"servers" mapstructure:"servers"` + + path string } +var _ CollectionConfig[MCPServerEntry] = (*MCPConfig)(nil) + // MCPServerEntry represents a single MCP server configuration type MCPServerEntry struct { Name string `yaml:"name" mapstructure:"name"` @@ -162,136 +163,61 @@ const ( // the in-code defaults so callers can treat absence as "use defaults" // without special-casing. func LoadMCP(path string) (*MCPConfig, error) { - if _, err := os.Stat(path); os.IsNotExist(err) { - return DefaultMCPConfig(), nil - } - - data, err := os.ReadFile(path) + cfg, err := utils.LoadYAML(path, "MCP", DefaultMCPConfig) if err != nil { - return nil, fmt.Errorf("failed to read MCP config: %w", err) - } - - expandedData := os.ExpandEnv(string(data)) - - var cfg MCPConfig - if err := yaml.Unmarshal([]byte(expandedData), &cfg); err != nil { - return nil, fmt.Errorf("failed to parse MCP config: %w", err) + return nil, err } - - return &cfg, nil + cfg.path = path + return cfg, nil } // SaveMCP writes the MCP configuration to disk, creating any missing // parent directories. func SaveMCP(path string, cfg *MCPConfig) error { - var buf bytes.Buffer - buf.WriteString("---\n") - - encoder := yaml.NewEncoder(&buf) - encoder.SetIndent(2) - - if err := encoder.Encode(cfg); err != nil { - return fmt.Errorf("failed to marshal MCP config: %w", err) - } - - if err := encoder.Close(); err != nil { - return fmt.Errorf("failed to close encoder: %w", err) - } - - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { - return fmt.Errorf("failed to create config directory: %w", err) - } - - if err := os.WriteFile(path, buf.Bytes(), 0644); err != nil { - return fmt.Errorf("failed to write MCP config: %w", err) - } - - return nil + return utils.SaveYAML(path, "MCP", cfg) } -// AddMCPServer adds a new MCP server entry to the configuration file. -func AddMCPServer(path string, server MCPServerEntry) error { - cfg, err := LoadMCP(path) - if err != nil { - return err - } - - for _, existing := range cfg.Servers { - if existing.Name == server.Name { - return fmt.Errorf("MCP server with name '%s' already exists", server.Name) - } - } +func mcpServerName(e MCPServerEntry) string { return e.Name } - cfg.Servers = append(cfg.Servers, server) - return SaveMCP(path, cfg) -} +const mcpServerKind = "MCP server" -// UpdateMCPServer updates an existing MCP server entry by name. -func UpdateMCPServer(path string, server MCPServerEntry) error { - cfg, err := LoadMCP(path) +// CreateEntry implements CollectionConfig. +func (c *MCPConfig) CreateEntry(entry MCPServerEntry) error { + next, err := appendEntry(c.Servers, entry, entry.Name, mcpServerName, mcpServerKind) if err != nil { return err } + c.Servers = next + return SaveMCP(c.path, c) +} - for i, existing := range cfg.Servers { - if existing.Name == server.Name { - cfg.Servers[i] = server - return SaveMCP(path, cfg) - } - } - - return fmt.Errorf("MCP server with name '%s' not found", server.Name) +// ReadEntry implements CollectionConfig. +func (c *MCPConfig) ReadEntry(name string) (*MCPServerEntry, error) { + return findEntry(c.Servers, name, mcpServerName, mcpServerKind) } -// RemoveMCPServer removes an MCP server entry by name. -func RemoveMCPServer(path, name string) error { - cfg, err := LoadMCP(path) +// UpdateEntry implements CollectionConfig. +func (c *MCPConfig) UpdateEntry(entry MCPServerEntry) error { + next, err := replaceEntry(c.Servers, entry, entry.Name, mcpServerName, mcpServerKind) if err != nil { return err } - - found := false - newServers := make([]MCPServerEntry, 0, len(cfg.Servers)) - for _, server := range cfg.Servers { - if server.Name != name { - newServers = append(newServers, server) - } else { - found = true - } - } - - if !found { - return fmt.Errorf("MCP server with name '%s' not found", name) - } - - cfg.Servers = newServers - return SaveMCP(path, cfg) + c.Servers = next + return SaveMCP(c.path, c) } -// ListMCPServers returns all configured MCP servers. -func ListMCPServers(path string) ([]MCPServerEntry, error) { - cfg, err := LoadMCP(path) +// DeleteEntry implements CollectionConfig. +func (c *MCPConfig) DeleteEntry(name string) error { + next, err := removeEntry(c.Servers, name, mcpServerName, mcpServerKind) if err != nil { - return nil, err + return err } - return cfg.Servers, nil + c.Servers = next + return SaveMCP(c.path, c) } -// GetMCPServer returns a single MCP server entry by name. -func GetMCPServer(path, name string) (*MCPServerEntry, error) { - cfg, err := LoadMCP(path) - if err != nil { - return nil, err - } - - for _, server := range cfg.Servers { - if server.Name == name { - return &server, nil - } - } - - return nil, fmt.Errorf("MCP server with name '%s' not found", name) -} +// ListEntries implements CollectionConfig. +func (c *MCPConfig) ListEntries() []MCPServerEntry { return c.Servers } // MergeMCP merges an optional mcp.yaml config on top of a base config. // Optional values take precedence; servers from both are combined and diff --git a/config/mcp_persistence_test.go b/config/mcp_persistence_test.go deleted file mode 100644 index 2e8634ab..00000000 --- a/config/mcp_persistence_test.go +++ /dev/null @@ -1,398 +0,0 @@ -package config_test - -import ( - "os" - "path/filepath" - "testing" - - config "github.com/inference-gateway/cli/config" -) - -func TestLoadMCP_NonExistentFile(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "non-existent.yaml") - - cfg, err := config.LoadMCP(configPath) - if err != nil { - t.Fatalf("LoadMCP() should not error for non-existent file, got: %v", err) - } - if cfg == nil { - t.Fatal("LoadMCP() returned nil config") - } - - defaultCfg := config.DefaultMCPConfig() - if cfg.Enabled != defaultCfg.Enabled { - t.Errorf("Expected Enabled=%v, got %v", defaultCfg.Enabled, cfg.Enabled) - } -} - -func TestLoadMCP_ValidYAML(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "mcp.yaml") - - yamlContent := `enabled: true -connection_timeout: 60 -discovery_timeout: 45 -servers: - - name: test-server - scheme: http - host: localhost - ports: - - "3000:8080" - path: /sse - enabled: true - timeout: 30 - description: Test MCP server - include_tools: - - tool1 - - tool2 - exclude_tools: - - tool3 -` - - if err := os.WriteFile(configPath, []byte(yamlContent), 0644); err != nil { - t.Fatalf("Failed to write test config file: %v", err) - } - - cfg, err := config.LoadMCP(configPath) - if err != nil { - t.Fatalf("LoadMCP() failed: %v", err) - } - - if !cfg.Enabled { - t.Error("Expected Enabled to be true") - } - if cfg.ConnectionTimeout != 60 { - t.Errorf("Expected ConnectionTimeout=60, got %d", cfg.ConnectionTimeout) - } - if cfg.DiscoveryTimeout != 45 { - t.Errorf("Expected DiscoveryTimeout=45, got %d", cfg.DiscoveryTimeout) - } - if len(cfg.Servers) != 1 { - t.Fatalf("Expected 1 server, got %d", len(cfg.Servers)) - } - - server := cfg.Servers[0] - if server.Name != "test-server" { - t.Errorf("Expected server name 'test-server', got %q", server.Name) - } - - expectedURL := "http://localhost:3000/sse" - if server.GetURL() != expectedURL { - t.Errorf("Expected server URL %q, got %q", expectedURL, server.GetURL()) - } - if !server.Enabled { - t.Error("Expected server to be enabled") - } - if server.Timeout != 30 { - t.Errorf("Expected server timeout=30, got %d", server.Timeout) - } - if len(server.IncludeTools) != 2 { - t.Errorf("Expected 2 include tools, got %d", len(server.IncludeTools)) - } - if len(server.ExcludeTools) != 1 { - t.Errorf("Expected 1 exclude tool, got %d", len(server.ExcludeTools)) - } -} - -func TestLoadMCP_EnvironmentVariableExpansion(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "mcp.yaml") - - t.Setenv("TEST_MCP_URL", "http://env-server:8080/sse") - - yamlContent := `enabled: true -servers: - - name: env-server - scheme: http - host: env-server - ports: - - "8080:8080" - path: /sse - enabled: true -` - - if err := os.WriteFile(configPath, []byte(yamlContent), 0644); err != nil { - t.Fatalf("Failed to write test config file: %v", err) - } - - cfg, err := config.LoadMCP(configPath) - if err != nil { - t.Fatalf("LoadMCP() failed: %v", err) - } - if len(cfg.Servers) != 1 { - t.Fatalf("Expected 1 server, got %d", len(cfg.Servers)) - } - - server := cfg.Servers[0] - expectedURL := "http://env-server:8080/sse" - if server.GetURL() != expectedURL { - t.Errorf("Expected URL %q, got %q", expectedURL, server.GetURL()) - } -} - -func TestSaveMCP(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "subdir", "mcp.yaml") - - cfg := &config.MCPConfig{ - Enabled: true, - ConnectionTimeout: 60, - DiscoveryTimeout: 45, - Servers: []config.MCPServerEntry{ - { - Name: "test-server", - Scheme: "http", - Host: "localhost", - Ports: []string{"3000:8080"}, - Path: "/sse", - Enabled: true, - Timeout: 30, - Description: "Test server", - }, - }, - } - - if err := config.SaveMCP(configPath, cfg); err != nil { - t.Fatalf("SaveMCP() failed: %v", err) - } - if _, err := os.Stat(configPath); os.IsNotExist(err) { - t.Fatal("Config file was not created") - } - - loadedCfg, err := config.LoadMCP(configPath) - if err != nil { - t.Fatalf("LoadMCP() after SaveMCP() failed: %v", err) - } - if loadedCfg.Enabled != cfg.Enabled { - t.Errorf("Expected Enabled=%v, got %v", cfg.Enabled, loadedCfg.Enabled) - } - if len(loadedCfg.Servers) != len(cfg.Servers) { - t.Errorf("Expected %d servers, got %d", len(cfg.Servers), len(loadedCfg.Servers)) - } -} - -func TestAddMCPServer(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "mcp.yaml") - - newServer := config.MCPServerEntry{ - Name: "new-server", - Scheme: "http", - Host: "localhost", - Ports: []string{"4000:8080"}, - Path: "/sse", - Enabled: true, - Description: "New server", - } - - if err := config.AddMCPServer(configPath, newServer); err != nil { - t.Fatalf("AddMCPServer() failed: %v", err) - } - - cfg, err := config.LoadMCP(configPath) - if err != nil { - t.Fatalf("LoadMCP() after AddMCPServer() failed: %v", err) - } - if len(cfg.Servers) != 1 { - t.Fatalf("Expected 1 server, got %d", len(cfg.Servers)) - } - if cfg.Servers[0].Name != newServer.Name { - t.Errorf("Expected server name %q, got %q", newServer.Name, cfg.Servers[0].Name) - } -} - -func TestAddMCPServer_DuplicateName(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "mcp.yaml") - - server := config.MCPServerEntry{ - Name: "duplicate-server", - Scheme: "http", - Host: "localhost", - Ports: []string{"3000:8080"}, - Path: "/sse", - Enabled: true, - } - - if err := config.AddMCPServer(configPath, server); err != nil { - t.Fatalf("First AddMCPServer() failed: %v", err) - } - if err := config.AddMCPServer(configPath, server); err == nil { - t.Fatal("Expected error when adding duplicate server, got nil") - } -} - -func TestUpdateMCPServer(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "mcp.yaml") - - initialServer := config.MCPServerEntry{ - Name: "update-server", - Scheme: "http", - Host: "localhost", - Ports: []string{"3000:8080"}, - Path: "/sse", - Enabled: true, - } - if err := config.AddMCPServer(configPath, initialServer); err != nil { - t.Fatalf("AddMCPServer() failed: %v", err) - } - - updatedServer := config.MCPServerEntry{ - Name: "update-server", - Scheme: "http", - Host: "localhost", - Ports: []string{"5000:8080"}, - Path: "/sse", - Enabled: false, - Description: "Updated description", - } - if err := config.UpdateMCPServer(configPath, updatedServer); err != nil { - t.Fatalf("UpdateMCPServer() failed: %v", err) - } - - cfg, err := config.LoadMCP(configPath) - if err != nil { - t.Fatalf("LoadMCP() after UpdateMCPServer() failed: %v", err) - } - if len(cfg.Servers) != 1 { - t.Fatalf("Expected 1 server, got %d", len(cfg.Servers)) - } - - server := cfg.Servers[0] - if server.GetURL() != updatedServer.GetURL() { - t.Errorf("Expected URL %q, got %q", updatedServer.GetURL(), server.GetURL()) - } - if server.Enabled != updatedServer.Enabled { - t.Errorf("Expected Enabled=%v, got %v", updatedServer.Enabled, server.Enabled) - } - if server.Description != updatedServer.Description { - t.Errorf("Expected Description %q, got %q", updatedServer.Description, server.Description) - } -} - -func TestUpdateMCPServer_NotFound(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "mcp.yaml") - - server := config.MCPServerEntry{ - Name: "nonexistent-server", - Scheme: "http", - Host: "localhost", - Ports: []string{"3000:8080"}, - Path: "/sse", - Enabled: true, - } - if err := config.UpdateMCPServer(configPath, server); err == nil { - t.Fatal("Expected error when updating non-existent server, got nil") - } -} - -func TestRemoveMCPServer(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "mcp.yaml") - - server1 := config.MCPServerEntry{Name: "server1", Scheme: "http", Host: "localhost", Ports: []string{"3000:8080"}, Path: "/sse", Enabled: true} - server2 := config.MCPServerEntry{Name: "server2", Scheme: "http", Host: "localhost", Ports: []string{"4000:8080"}, Path: "/sse", Enabled: true} - - if err := config.AddMCPServer(configPath, server1); err != nil { - t.Fatalf("Failed to add server1: %v", err) - } - if err := config.AddMCPServer(configPath, server2); err != nil { - t.Fatalf("Failed to add server2: %v", err) - } - - if err := config.RemoveMCPServer(configPath, "server1"); err != nil { - t.Fatalf("RemoveMCPServer() failed: %v", err) - } - - cfg, err := config.LoadMCP(configPath) - if err != nil { - t.Fatalf("LoadMCP() after RemoveMCPServer() failed: %v", err) - } - if len(cfg.Servers) != 1 { - t.Fatalf("Expected 1 server after removal, got %d", len(cfg.Servers)) - } - if cfg.Servers[0].Name != "server2" { - t.Errorf("Expected remaining server to be 'server2', got %q", cfg.Servers[0].Name) - } -} - -func TestRemoveMCPServer_NotFound(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "mcp.yaml") - - if err := config.RemoveMCPServer(configPath, "nonexistent-server"); err == nil { - t.Fatal("Expected error when removing non-existent server, got nil") - } -} - -func TestListMCPServers(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "mcp.yaml") - - servers, err := config.ListMCPServers(configPath) - if err != nil { - t.Fatalf("ListMCPServers() failed: %v", err) - } - if len(servers) != 0 { - t.Errorf("Expected 0 servers initially, got %d", len(servers)) - } - - server1 := config.MCPServerEntry{Name: "server1", Scheme: "http", Host: "localhost", Ports: []string{"3000:8080"}, Path: "/sse", Enabled: true} - server2 := config.MCPServerEntry{Name: "server2", Scheme: "http", Host: "localhost", Ports: []string{"4000:8080"}, Path: "/sse", Enabled: false} - - if err := config.AddMCPServer(configPath, server1); err != nil { - t.Fatalf("Failed to add server1: %v", err) - } - if err := config.AddMCPServer(configPath, server2); err != nil { - t.Fatalf("Failed to add server2: %v", err) - } - - servers, err = config.ListMCPServers(configPath) - if err != nil { - t.Fatalf("ListMCPServers() failed: %v", err) - } - if len(servers) != 2 { - t.Fatalf("Expected 2 servers, got %d", len(servers)) - } -} - -func TestGetMCPServer(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "mcp.yaml") - - expectedServer := config.MCPServerEntry{ - Name: "get-server", - Scheme: "http", - Host: "localhost", - Ports: []string{"3000:8080"}, - Path: "/sse", - Enabled: true, - Description: "Test server", - } - if err := config.AddMCPServer(configPath, expectedServer); err != nil { - t.Fatalf("Failed to add server: %v", err) - } - - server, err := config.GetMCPServer(configPath, "get-server") - if err != nil { - t.Fatalf("GetMCPServer() failed: %v", err) - } - if server.Name != expectedServer.Name { - t.Errorf("Expected name %q, got %q", expectedServer.Name, server.Name) - } - if server.GetURL() != expectedServer.GetURL() { - t.Errorf("Expected URL %q, got %q", expectedServer.GetURL(), server.GetURL()) - } -} - -func TestGetMCPServer_NotFound(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "mcp.yaml") - - if _, err := config.GetMCPServer(configPath, "nonexistent-server"); err == nil { - t.Fatal("Expected error when getting non-existent server, got nil") - } -} diff --git a/config/mcp_test.go b/config/mcp_test.go index 3efe4554..8a6b1650 100644 --- a/config/mcp_test.go +++ b/config/mcp_test.go @@ -1,11 +1,26 @@ -package config +package config_test import ( + "os" + "path/filepath" "testing" + + config "github.com/inference-gateway/cli/config" ) +func TestMCPConstants(t *testing.T) { + if config.MCPFileName != "mcp.yaml" { + t.Errorf("Expected MCPFileName to be 'mcp.yaml', got %q", config.MCPFileName) + } + + expectedPath := config.ConfigDirName + "/" + config.MCPFileName + if config.DefaultMCPPath != expectedPath { + t.Errorf("Expected DefaultMCPPath to be %q, got %q", expectedPath, config.DefaultMCPPath) + } +} + func TestDefaultMCPConfig(t *testing.T) { - cfg := DefaultMCPConfig() + cfg := config.DefaultMCPConfig() if cfg == nil { t.Fatal("DefaultMCPConfig() returned nil") @@ -35,13 +50,13 @@ func TestDefaultMCPConfig(t *testing.T) { func TestMCPServerEntry_ShouldIncludeTool(t *testing.T) { tests := []struct { name string - entry MCPServerEntry + entry config.MCPServerEntry toolName string shouldInclude bool }{ { name: "no include or exclude lists - include all", - entry: MCPServerEntry{ + entry: config.MCPServerEntry{ Name: "test-server", }, toolName: "any_tool", @@ -49,7 +64,7 @@ func TestMCPServerEntry_ShouldIncludeTool(t *testing.T) { }, { name: "include list only - tool in list", - entry: MCPServerEntry{ + entry: config.MCPServerEntry{ Name: "test-server", IncludeTools: []string{"tool1", "tool2", "tool3"}, }, @@ -58,7 +73,7 @@ func TestMCPServerEntry_ShouldIncludeTool(t *testing.T) { }, { name: "include list only - tool not in list", - entry: MCPServerEntry{ + entry: config.MCPServerEntry{ Name: "test-server", IncludeTools: []string{"tool1", "tool2", "tool3"}, }, @@ -67,7 +82,7 @@ func TestMCPServerEntry_ShouldIncludeTool(t *testing.T) { }, { name: "exclude list only - tool in list", - entry: MCPServerEntry{ + entry: config.MCPServerEntry{ Name: "test-server", ExcludeTools: []string{"dangerous_tool", "risky_tool"}, }, @@ -76,7 +91,7 @@ func TestMCPServerEntry_ShouldIncludeTool(t *testing.T) { }, { name: "exclude list only - tool not in list", - entry: MCPServerEntry{ + entry: config.MCPServerEntry{ Name: "test-server", ExcludeTools: []string{"dangerous_tool", "risky_tool"}, }, @@ -85,17 +100,17 @@ func TestMCPServerEntry_ShouldIncludeTool(t *testing.T) { }, { name: "both lists - tool in include list", - entry: MCPServerEntry{ + entry: config.MCPServerEntry{ Name: "test-server", IncludeTools: []string{"tool1", "tool2"}, - ExcludeTools: []string{"tool2", "tool3"}, // tool2 in both - include takes precedence + ExcludeTools: []string{"tool2", "tool3"}, }, toolName: "tool2", shouldInclude: true, }, { name: "both lists - tool not in include list", - entry: MCPServerEntry{ + entry: config.MCPServerEntry{ Name: "test-server", IncludeTools: []string{"tool1", "tool2"}, ExcludeTools: []string{"tool3"}, @@ -105,7 +120,7 @@ func TestMCPServerEntry_ShouldIncludeTool(t *testing.T) { }, { name: "empty lists - include all", - entry: MCPServerEntry{ + entry: config.MCPServerEntry{ Name: "test-server", IncludeTools: []string{}, ExcludeTools: []string{}, @@ -166,7 +181,7 @@ func TestMCPServerEntry_GetTimeout(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - entry := MCPServerEntry{ + entry := config.MCPServerEntry{ Timeout: tt.serverTimeout, } result := entry.GetTimeout(tt.globalTimeout) @@ -177,13 +192,393 @@ func TestMCPServerEntry_GetTimeout(t *testing.T) { } } -func TestMCPConstants(t *testing.T) { - if MCPFileName != "mcp.yaml" { - t.Errorf("Expected MCPFileName to be 'mcp.yaml', got %q", MCPFileName) +func TestLoadMCP_NonExistentFile(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "non-existent.yaml") + + cfg, err := config.LoadMCP(configPath) + if err != nil { + t.Fatalf("LoadMCP() should not error for non-existent file, got: %v", err) + } + if cfg == nil { + t.Fatal("LoadMCP() returned nil config") + } + + defaultCfg := config.DefaultMCPConfig() + if cfg.Enabled != defaultCfg.Enabled { + t.Errorf("Expected Enabled=%v, got %v", defaultCfg.Enabled, cfg.Enabled) + } +} + +func TestLoadMCP_ValidYAML(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "mcp.yaml") + + yamlContent := `enabled: true +connection_timeout: 60 +discovery_timeout: 45 +servers: + - name: test-server + scheme: http + host: localhost + ports: + - "3000:8080" + path: /sse + enabled: true + timeout: 30 + description: Test MCP server + include_tools: + - tool1 + - tool2 + exclude_tools: + - tool3 +` + + if err := os.WriteFile(configPath, []byte(yamlContent), 0644); err != nil { + t.Fatalf("Failed to write test config file: %v", err) + } + + cfg, err := config.LoadMCP(configPath) + if err != nil { + t.Fatalf("LoadMCP() failed: %v", err) + } + + if !cfg.Enabled { + t.Error("Expected Enabled to be true") + } + if cfg.ConnectionTimeout != 60 { + t.Errorf("Expected ConnectionTimeout=60, got %d", cfg.ConnectionTimeout) + } + if cfg.DiscoveryTimeout != 45 { + t.Errorf("Expected DiscoveryTimeout=45, got %d", cfg.DiscoveryTimeout) + } + if len(cfg.Servers) != 1 { + t.Fatalf("Expected 1 server, got %d", len(cfg.Servers)) + } + + server := cfg.Servers[0] + if server.Name != "test-server" { + t.Errorf("Expected server name 'test-server', got %q", server.Name) + } + + expectedURL := "http://localhost:3000/sse" + if server.GetURL() != expectedURL { + t.Errorf("Expected server URL %q, got %q", expectedURL, server.GetURL()) + } + if !server.Enabled { + t.Error("Expected server to be enabled") + } + if server.Timeout != 30 { + t.Errorf("Expected server timeout=30, got %d", server.Timeout) + } + if len(server.IncludeTools) != 2 { + t.Errorf("Expected 2 include tools, got %d", len(server.IncludeTools)) + } + if len(server.ExcludeTools) != 1 { + t.Errorf("Expected 1 exclude tool, got %d", len(server.ExcludeTools)) + } +} + +func TestLoadMCP_EnvironmentVariableExpansion(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "mcp.yaml") + + t.Setenv("TEST_MCP_URL", "http://env-server:8080/sse") + + yamlContent := `enabled: true +servers: + - name: env-server + scheme: http + host: env-server + ports: + - "8080:8080" + path: /sse + enabled: true +` + + if err := os.WriteFile(configPath, []byte(yamlContent), 0644); err != nil { + t.Fatalf("Failed to write test config file: %v", err) + } + + cfg, err := config.LoadMCP(configPath) + if err != nil { + t.Fatalf("LoadMCP() failed: %v", err) + } + if len(cfg.Servers) != 1 { + t.Fatalf("Expected 1 server, got %d", len(cfg.Servers)) + } + + server := cfg.Servers[0] + expectedURL := "http://env-server:8080/sse" + if server.GetURL() != expectedURL { + t.Errorf("Expected URL %q, got %q", expectedURL, server.GetURL()) + } +} + +func TestSaveMCP(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "subdir", "mcp.yaml") + + cfg := &config.MCPConfig{ + Enabled: true, + ConnectionTimeout: 60, + DiscoveryTimeout: 45, + Servers: []config.MCPServerEntry{ + { + Name: "test-server", + Scheme: "http", + Host: "localhost", + Ports: []string{"3000:8080"}, + Path: "/sse", + Enabled: true, + Timeout: 30, + Description: "Test server", + }, + }, + } + + if err := config.SaveMCP(configPath, cfg); err != nil { + t.Fatalf("SaveMCP() failed: %v", err) + } + if _, err := os.Stat(configPath); os.IsNotExist(err) { + t.Fatal("Config file was not created") + } + + loadedCfg, err := config.LoadMCP(configPath) + if err != nil { + t.Fatalf("LoadMCP() after SaveMCP() failed: %v", err) + } + if loadedCfg.Enabled != cfg.Enabled { + t.Errorf("Expected Enabled=%v, got %v", cfg.Enabled, loadedCfg.Enabled) + } + if len(loadedCfg.Servers) != len(cfg.Servers) { + t.Errorf("Expected %d servers, got %d", len(cfg.Servers), len(loadedCfg.Servers)) + } +} + +func loadMCPForTest(t *testing.T, path string) *config.MCPConfig { + t.Helper() + cfg, err := config.LoadMCP(path) + if err != nil { + t.Fatalf("LoadMCP() failed: %v", err) + } + return cfg +} + +func TestCreateEntry_MCPServer(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "mcp.yaml") + + newServer := config.MCPServerEntry{ + Name: "new-server", + Scheme: "http", + Host: "localhost", + Ports: []string{"4000:8080"}, + Path: "/sse", + Enabled: true, + Description: "New server", + } + + cfg := loadMCPForTest(t, configPath) + if err := cfg.CreateEntry(newServer); err != nil { + t.Fatalf("CreateEntry() failed: %v", err) } - expectedPath := ConfigDirName + "/" + MCPFileName - if DefaultMCPPath != expectedPath { - t.Errorf("Expected DefaultMCPPath to be %q, got %q", expectedPath, DefaultMCPPath) + reloaded := loadMCPForTest(t, configPath) + if len(reloaded.Servers) != 1 { + t.Fatalf("Expected 1 server, got %d", len(reloaded.Servers)) + } + if reloaded.Servers[0].Name != newServer.Name { + t.Errorf("Expected server name %q, got %q", newServer.Name, reloaded.Servers[0].Name) + } +} + +func TestCreateEntry_MCPServer_DuplicateName(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "mcp.yaml") + + server := config.MCPServerEntry{ + Name: "duplicate-server", + Scheme: "http", + Host: "localhost", + Ports: []string{"3000:8080"}, + Path: "/sse", + Enabled: true, + } + + cfg := loadMCPForTest(t, configPath) + if err := cfg.CreateEntry(server); err != nil { + t.Fatalf("First CreateEntry() failed: %v", err) + } + if err := cfg.CreateEntry(server); err == nil { + t.Fatal("Expected error when adding duplicate server, got nil") + } +} + +func TestUpdateEntry_MCPServer(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "mcp.yaml") + + initialServer := config.MCPServerEntry{ + Name: "update-server", + Scheme: "http", + Host: "localhost", + Ports: []string{"3000:8080"}, + Path: "/sse", + Enabled: true, + } + cfg := loadMCPForTest(t, configPath) + if err := cfg.CreateEntry(initialServer); err != nil { + t.Fatalf("CreateEntry() failed: %v", err) + } + + updatedServer := config.MCPServerEntry{ + Name: "update-server", + Scheme: "http", + Host: "localhost", + Ports: []string{"5000:8080"}, + Path: "/sse", + Enabled: false, + Description: "Updated description", + } + if err := cfg.UpdateEntry(updatedServer); err != nil { + t.Fatalf("UpdateEntry() failed: %v", err) + } + + reloaded := loadMCPForTest(t, configPath) + if len(reloaded.Servers) != 1 { + t.Fatalf("Expected 1 server, got %d", len(reloaded.Servers)) + } + + server := reloaded.Servers[0] + if server.GetURL() != updatedServer.GetURL() { + t.Errorf("Expected URL %q, got %q", updatedServer.GetURL(), server.GetURL()) + } + if server.Enabled != updatedServer.Enabled { + t.Errorf("Expected Enabled=%v, got %v", updatedServer.Enabled, server.Enabled) + } + if server.Description != updatedServer.Description { + t.Errorf("Expected Description %q, got %q", updatedServer.Description, server.Description) + } +} + +func TestUpdateEntry_MCPServer_NotFound(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "mcp.yaml") + + server := config.MCPServerEntry{ + Name: "nonexistent-server", + Scheme: "http", + Host: "localhost", + Ports: []string{"3000:8080"}, + Path: "/sse", + Enabled: true, + } + cfg := loadMCPForTest(t, configPath) + if err := cfg.UpdateEntry(server); err == nil { + t.Fatal("Expected error when updating non-existent server, got nil") + } +} + +func TestDeleteEntry_MCPServer(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "mcp.yaml") + + server1 := config.MCPServerEntry{Name: "server1", Scheme: "http", Host: "localhost", Ports: []string{"3000:8080"}, Path: "/sse", Enabled: true} + server2 := config.MCPServerEntry{Name: "server2", Scheme: "http", Host: "localhost", Ports: []string{"4000:8080"}, Path: "/sse", Enabled: true} + + cfg := loadMCPForTest(t, configPath) + if err := cfg.CreateEntry(server1); err != nil { + t.Fatalf("Failed to add server1: %v", err) + } + if err := cfg.CreateEntry(server2); err != nil { + t.Fatalf("Failed to add server2: %v", err) + } + + if err := cfg.DeleteEntry("server1"); err != nil { + t.Fatalf("DeleteEntry() failed: %v", err) + } + + reloaded := loadMCPForTest(t, configPath) + if len(reloaded.Servers) != 1 { + t.Fatalf("Expected 1 server after removal, got %d", len(reloaded.Servers)) + } + if reloaded.Servers[0].Name != "server2" { + t.Errorf("Expected remaining server to be 'server2', got %q", reloaded.Servers[0].Name) + } +} + +func TestDeleteEntry_MCPServer_NotFound(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "mcp.yaml") + + cfg := loadMCPForTest(t, configPath) + if err := cfg.DeleteEntry("nonexistent-server"); err == nil { + t.Fatal("Expected error when removing non-existent server, got nil") + } +} + +func TestListEntries_MCPServers(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "mcp.yaml") + + cfg := loadMCPForTest(t, configPath) + if servers := cfg.ListEntries(); len(servers) != 0 { + t.Errorf("Expected 0 servers initially, got %d", len(servers)) + } + + server1 := config.MCPServerEntry{Name: "server1", Scheme: "http", Host: "localhost", Ports: []string{"3000:8080"}, Path: "/sse", Enabled: true} + server2 := config.MCPServerEntry{Name: "server2", Scheme: "http", Host: "localhost", Ports: []string{"4000:8080"}, Path: "/sse", Enabled: false} + + if err := cfg.CreateEntry(server1); err != nil { + t.Fatalf("Failed to add server1: %v", err) + } + if err := cfg.CreateEntry(server2); err != nil { + t.Fatalf("Failed to add server2: %v", err) + } + + reloaded := loadMCPForTest(t, configPath) + if servers := reloaded.ListEntries(); len(servers) != 2 { + t.Fatalf("Expected 2 servers, got %d", len(servers)) + } +} + +func TestReadEntry_MCPServer(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "mcp.yaml") + + expectedServer := config.MCPServerEntry{ + Name: "get-server", + Scheme: "http", + Host: "localhost", + Ports: []string{"3000:8080"}, + Path: "/sse", + Enabled: true, + Description: "Test server", + } + cfg := loadMCPForTest(t, configPath) + if err := cfg.CreateEntry(expectedServer); err != nil { + t.Fatalf("Failed to add server: %v", err) + } + + server, err := cfg.ReadEntry("get-server") + if err != nil { + t.Fatalf("ReadEntry() failed: %v", err) + } + if server.Name != expectedServer.Name { + t.Errorf("Expected name %q, got %q", expectedServer.Name, server.Name) + } + if server.GetURL() != expectedServer.GetURL() { + t.Errorf("Expected URL %q, got %q", expectedServer.GetURL(), server.GetURL()) + } +} + +func TestReadEntry_MCPServer_NotFound(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "mcp.yaml") + + cfg := loadMCPForTest(t, configPath) + if _, err := cfg.ReadEntry("nonexistent-server"); err == nil { + t.Fatal("Expected error when getting non-existent server, got nil") } } diff --git a/config/prompts.go b/config/prompts.go index fa5e7a36..28cfffa4 100644 --- a/config/prompts.go +++ b/config/prompts.go @@ -1,12 +1,7 @@ package config import ( - "bytes" - "fmt" - "os" - "path/filepath" - - yaml "gopkg.in/yaml.v3" + utils "github.com/inference-gateway/cli/config/utils" ) const ( @@ -16,51 +11,17 @@ const ( // LoadPrompts reads prompts.yaml from disk. When the file is missing it // returns the in-code defaults so callers can treat absence as "use -// defaults" without special-casing. +// defaults" without special-casing. The file body is run through +// os.ExpandEnv — any literal `${…}` token in a customised prompt must be +// escaped as `$$…`. func LoadPrompts(path string) (*PromptsConfig, error) { - if _, err := os.Stat(path); os.IsNotExist(err) { - return DefaultPromptsConfig(), nil - } - - data, err := os.ReadFile(path) - if err != nil { - return nil, fmt.Errorf("failed to read prompts config: %w", err) - } - - var cfg PromptsConfig - if err := yaml.Unmarshal(data, &cfg); err != nil { - return nil, fmt.Errorf("failed to parse prompts config: %w", err) - } - - return &cfg, nil + return utils.LoadYAML(path, "prompts", DefaultPromptsConfig) } // SavePrompts writes the prompts configuration to disk, creating any // missing parent directories. func SavePrompts(path string, cfg *PromptsConfig) error { - var buf bytes.Buffer - buf.WriteString("---\n") - - encoder := yaml.NewEncoder(&buf) - encoder.SetIndent(2) - - if err := encoder.Encode(cfg); err != nil { - return fmt.Errorf("failed to marshal prompts config: %w", err) - } - - if err := encoder.Close(); err != nil { - return fmt.Errorf("failed to close encoder: %w", err) - } - - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { - return fmt.Errorf("failed to create config directory: %w", err) - } - - if err := os.WriteFile(path, buf.Bytes(), 0644); err != nil { - return fmt.Errorf("failed to write prompts config: %w", err) - } - - return nil + return utils.SaveYAML(path, "prompts", cfg) } // PromptsConfig holds every customisable LLM prompt the CLI ships with. diff --git a/config/prompts_persistence_test.go b/config/prompts_persistence_test.go deleted file mode 100644 index 860c4a91..00000000 --- a/config/prompts_persistence_test.go +++ /dev/null @@ -1,165 +0,0 @@ -package config_test - -import ( - "os" - "path/filepath" - "strings" - "testing" - - config "github.com/inference-gateway/cli/config" -) - -func TestLoadPrompts_NonExistentFile(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "non-existent.yaml") - - cfg, err := config.LoadPrompts(configPath) - if err != nil { - t.Fatalf("LoadPrompts() should not error for non-existent file, got: %v", err) - } - if cfg == nil { - t.Fatal("LoadPrompts() returned nil config") - } - if cfg.Agent.SystemPrompt == "" { - t.Error("Default prompts config should populate agent.system_prompt") - } - if cfg.Agent.SystemPromptPlan == "" { - t.Error("Default prompts config should populate agent.system_prompt_plan") - } - if cfg.Git.CommitMessage.SystemPrompt == "" { - t.Error("Default prompts config should populate git.commit_message.system_prompt") - } -} - -func TestLoadPrompts_ValidYAML(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "prompts.yaml") - - yamlContent := `--- -agent: - system_prompt: custom agent prompt - system_prompt_plan: custom plan prompt -git: - commit_message: - system_prompt: custom commit prompt -init: - prompt: custom init prompt -` - - if err := os.WriteFile(configPath, []byte(yamlContent), 0644); err != nil { - t.Fatalf("Failed to write test config file: %v", err) - } - - cfg, err := config.LoadPrompts(configPath) - if err != nil { - t.Fatalf("LoadPrompts() failed: %v", err) - } - if cfg.Agent.SystemPrompt != "custom agent prompt" { - t.Errorf("Expected custom system_prompt, got %q", cfg.Agent.SystemPrompt) - } - if cfg.Agent.SystemPromptPlan != "custom plan prompt" { - t.Errorf("Expected custom plan prompt, got %q", cfg.Agent.SystemPromptPlan) - } - if cfg.Git.CommitMessage.SystemPrompt != "custom commit prompt" { - t.Errorf("Expected custom commit prompt, got %q", cfg.Git.CommitMessage.SystemPrompt) - } - if cfg.Init.Prompt != "custom init prompt" { - t.Errorf("Expected custom init prompt, got %q", cfg.Init.Prompt) - } -} - -func TestLoadPrompts_PartialYAML(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "prompts.yaml") - - yamlContent := `--- -agent: - system_prompt: only this field is set -` - - if err := os.WriteFile(configPath, []byte(yamlContent), 0644); err != nil { - t.Fatalf("Failed to write test config file: %v", err) - } - - cfg, err := config.LoadPrompts(configPath) - if err != nil { - t.Fatalf("LoadPrompts() failed: %v", err) - } - if cfg.Agent.SystemPrompt != "only this field is set" { - t.Errorf("Expected partial value, got %q", cfg.Agent.SystemPrompt) - } - // Other prompt fields stay zero — the runtime overlay (cmd applies it - // on top of LoadPrompts) is what fills them back in from defaults. - if cfg.Agent.SystemPromptPlan != "" { - t.Errorf("Expected unset plan prompt to be empty, got %q", cfg.Agent.SystemPromptPlan) - } - if cfg.Git.CommitMessage.SystemPrompt != "" { - t.Errorf("Expected unset commit prompt to be empty, got %q", cfg.Git.CommitMessage.SystemPrompt) - } -} - -func TestSavePrompts_RoundTrip(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "prompts.yaml") - - original := &config.PromptsConfig{ - Agent: config.PromptsAgentConfig{ - SystemPrompt: "round trip system prompt", - SystemReminders: config.PromptsAgentRemindersConfig{ - ReminderText: "round trip reminder", - }, - }, - Git: config.PromptsGitConfig{ - CommitMessage: config.PromptsGitCommitMessageConfig{ - SystemPrompt: "round trip commit prompt", - }, - }, - } - - if err := config.SavePrompts(configPath, original); err != nil { - t.Fatalf("SavePrompts() failed: %v", err) - } - - loaded, err := config.LoadPrompts(configPath) - if err != nil { - t.Fatalf("LoadPrompts() after save failed: %v", err) - } - if loaded.Agent.SystemPrompt != original.Agent.SystemPrompt { - t.Errorf("agent.system_prompt not preserved, got %q", loaded.Agent.SystemPrompt) - } - if loaded.Agent.SystemReminders.ReminderText != original.Agent.SystemReminders.ReminderText { - t.Errorf("reminder_text not preserved, got %q", loaded.Agent.SystemReminders.ReminderText) - } - if loaded.Git.CommitMessage.SystemPrompt != original.Git.CommitMessage.SystemPrompt { - t.Errorf("git.commit_message.system_prompt not preserved, got %q", loaded.Git.CommitMessage.SystemPrompt) - } -} - -func TestSavePrompts_CreatesParentDirectory(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "nested", "deep", "prompts.yaml") - - if err := config.SavePrompts(configPath, config.DefaultPromptsConfig()); err != nil { - t.Fatalf("SavePrompts() failed to create nested dirs: %v", err) - } - if _, err := os.Stat(configPath); err != nil { - t.Fatalf("File not created at nested path: %v", err) - } -} - -func TestSavePrompts_StartsWithYAMLDocumentMarker(t *testing.T) { - tempDir := t.TempDir() - configPath := filepath.Join(tempDir, "prompts.yaml") - - if err := config.SavePrompts(configPath, config.DefaultPromptsConfig()); err != nil { - t.Fatalf("SavePrompts() failed: %v", err) - } - - data, err := os.ReadFile(configPath) - if err != nil { - t.Fatalf("ReadFile failed: %v", err) - } - if !strings.HasPrefix(string(data), "---\n") { - t.Errorf("Saved file should start with YAML document marker, got: %q", string(data[:min(20, len(data))])) - } -} diff --git a/config/prompts_test.go b/config/prompts_test.go index 81bb6003..9e2a0adc 100644 --- a/config/prompts_test.go +++ b/config/prompts_test.go @@ -1,12 +1,19 @@ -package config +package config_test -import "testing" +import ( + "os" + "path/filepath" + "strings" + "testing" + + config "github.com/inference-gateway/cli/config" +) // Guards against accidental deletions of default prompts. Every prompt // field surfaced through prompts.yaml must ship a non-empty default so // the runtime overlay can fall back to it when a user blanks a key. func TestDefaultPromptsConfig_AllPromptsPopulated(t *testing.T) { - cfg := DefaultPromptsConfig() + cfg := config.DefaultPromptsConfig() cases := map[string]string{ "agent.system_prompt": cfg.Agent.SystemPrompt, @@ -29,9 +36,164 @@ func TestDefaultPromptsConfig_AllPromptsPopulated(t *testing.T) { // opt-in. This guards it in the opposite direction so a future "fill in // a default" change is intentional. func TestDefaultPromptsConfig_OptionalPromptsBlank(t *testing.T) { - cfg := DefaultPromptsConfig() + cfg := config.DefaultPromptsConfig() if cfg.Agent.CustomInstructions != "" { t.Errorf("agent.custom_instructions should ship empty, got %q", cfg.Agent.CustomInstructions) } } + +func TestLoadPrompts_NonExistentFile(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "non-existent.yaml") + + cfg, err := config.LoadPrompts(configPath) + if err != nil { + t.Fatalf("LoadPrompts() should not error for non-existent file, got: %v", err) + } + if cfg == nil { + t.Fatal("LoadPrompts() returned nil config") + } + if cfg.Agent.SystemPrompt == "" { + t.Error("Default prompts config should populate agent.system_prompt") + } + if cfg.Agent.SystemPromptPlan == "" { + t.Error("Default prompts config should populate agent.system_prompt_plan") + } + if cfg.Git.CommitMessage.SystemPrompt == "" { + t.Error("Default prompts config should populate git.commit_message.system_prompt") + } +} + +func TestLoadPrompts_ValidYAML(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "prompts.yaml") + + yamlContent := `--- +agent: + system_prompt: custom agent prompt + system_prompt_plan: custom plan prompt +git: + commit_message: + system_prompt: custom commit prompt +init: + prompt: custom init prompt +` + + if err := os.WriteFile(configPath, []byte(yamlContent), 0644); err != nil { + t.Fatalf("Failed to write test config file: %v", err) + } + + cfg, err := config.LoadPrompts(configPath) + if err != nil { + t.Fatalf("LoadPrompts() failed: %v", err) + } + if cfg.Agent.SystemPrompt != "custom agent prompt" { + t.Errorf("Expected custom system_prompt, got %q", cfg.Agent.SystemPrompt) + } + if cfg.Agent.SystemPromptPlan != "custom plan prompt" { + t.Errorf("Expected custom plan prompt, got %q", cfg.Agent.SystemPromptPlan) + } + if cfg.Git.CommitMessage.SystemPrompt != "custom commit prompt" { + t.Errorf("Expected custom commit prompt, got %q", cfg.Git.CommitMessage.SystemPrompt) + } + if cfg.Init.Prompt != "custom init prompt" { + t.Errorf("Expected custom init prompt, got %q", cfg.Init.Prompt) + } +} + +func TestLoadPrompts_PartialYAML(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "prompts.yaml") + + yamlContent := `--- +agent: + system_prompt: only this field is set +` + + if err := os.WriteFile(configPath, []byte(yamlContent), 0644); err != nil { + t.Fatalf("Failed to write test config file: %v", err) + } + + cfg, err := config.LoadPrompts(configPath) + if err != nil { + t.Fatalf("LoadPrompts() failed: %v", err) + } + if cfg.Agent.SystemPrompt != "only this field is set" { + t.Errorf("Expected partial value, got %q", cfg.Agent.SystemPrompt) + } + // Other prompt fields stay zero — the runtime overlay (cmd applies it + // on top of LoadPrompts) is what fills them back in from defaults. + if cfg.Agent.SystemPromptPlan != "" { + t.Errorf("Expected unset plan prompt to be empty, got %q", cfg.Agent.SystemPromptPlan) + } + if cfg.Git.CommitMessage.SystemPrompt != "" { + t.Errorf("Expected unset commit prompt to be empty, got %q", cfg.Git.CommitMessage.SystemPrompt) + } +} + +func TestSavePrompts_RoundTrip(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "prompts.yaml") + + original := &config.PromptsConfig{ + Agent: config.PromptsAgentConfig{ + SystemPrompt: "round trip system prompt", + SystemReminders: config.PromptsAgentRemindersConfig{ + ReminderText: "round trip reminder", + }, + }, + Git: config.PromptsGitConfig{ + CommitMessage: config.PromptsGitCommitMessageConfig{ + SystemPrompt: "round trip commit prompt", + }, + }, + } + + if err := config.SavePrompts(configPath, original); err != nil { + t.Fatalf("SavePrompts() failed: %v", err) + } + + loaded, err := config.LoadPrompts(configPath) + if err != nil { + t.Fatalf("LoadPrompts() after save failed: %v", err) + } + if loaded.Agent.SystemPrompt != original.Agent.SystemPrompt { + t.Errorf("agent.system_prompt not preserved, got %q", loaded.Agent.SystemPrompt) + } + if loaded.Agent.SystemReminders.ReminderText != original.Agent.SystemReminders.ReminderText { + t.Errorf("reminder_text not preserved, got %q", loaded.Agent.SystemReminders.ReminderText) + } + if loaded.Git.CommitMessage.SystemPrompt != original.Git.CommitMessage.SystemPrompt { + t.Errorf("git.commit_message.system_prompt not preserved, got %q", loaded.Git.CommitMessage.SystemPrompt) + } +} + +func TestSavePrompts_CreatesParentDirectory(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "nested", "deep", "prompts.yaml") + + if err := config.SavePrompts(configPath, config.DefaultPromptsConfig()); err != nil { + t.Fatalf("SavePrompts() failed to create nested dirs: %v", err) + } + if _, err := os.Stat(configPath); err != nil { + t.Fatalf("File not created at nested path: %v", err) + } +} + +func TestSavePrompts_StartsWithYAMLDocumentMarker(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "prompts.yaml") + + if err := config.SavePrompts(configPath, config.DefaultPromptsConfig()); err != nil { + t.Fatalf("SavePrompts() failed: %v", err) + } + + data, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + if !strings.HasPrefix(string(data), "---\n") { + t.Errorf("Saved file should start with YAML document marker, got: %q", string(data[:min(20, len(data))])) + } +} diff --git a/config/utils/yamlfile.go b/config/utils/yamlfile.go new file mode 100644 index 00000000..6b04233d --- /dev/null +++ b/config/utils/yamlfile.go @@ -0,0 +1,69 @@ +// Package utils provides generic file-IO helpers shared by every sub-config +// in package config. They are deliberately domain-agnostic - anything that +// knows about a specific config type belongs in package config itself. +package utils + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + + yaml "gopkg.in/yaml.v3" +) + +// LoadYAML reads path. If the file does not exist, defaults() is returned so +// callers can treat absence as "use defaults" without special-casing. The +// file body is run through os.ExpandEnv so ${VAR} references resolve from +// the environment before unmarshalling — any future content that needs a +// literal `${…}` token must escape it as `$$…`. +// +// label scopes error messages, e.g. "channels" produces +// "failed to read channels config: …". +func LoadYAML[T any](path, label string, defaults func() *T) (*T, error) { + if _, err := os.Stat(path); os.IsNotExist(err) { + return defaults(), nil + } + + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read %s config: %w", label, err) + } + + expanded := os.ExpandEnv(string(data)) + + cfg := new(T) + if err := yaml.Unmarshal([]byte(expanded), cfg); err != nil { + return nil, fmt.Errorf("failed to parse %s config: %w", label, err) + } + + return cfg, nil +} + +// SaveYAML writes cfg to path, creating any missing parent directories. +// It always emits the YAML document marker `---\n` and uses 2-space indent. +func SaveYAML[T any](path, label string, cfg *T) error { + var buf bytes.Buffer + buf.WriteString("---\n") + + encoder := yaml.NewEncoder(&buf) + encoder.SetIndent(2) + + if err := encoder.Encode(cfg); err != nil { + return fmt.Errorf("failed to marshal %s config: %w", label, err) + } + + if err := encoder.Close(); err != nil { + return fmt.Errorf("failed to close %s encoder: %w", label, err) + } + + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return fmt.Errorf("failed to create %s config directory: %w", label, err) + } + + if err := os.WriteFile(path, buf.Bytes(), 0644); err != nil { + return fmt.Errorf("failed to write %s config: %w", label, err) + } + + return nil +} diff --git a/config/utils/yamlfile_test.go b/config/utils/yamlfile_test.go new file mode 100644 index 00000000..217a9fed --- /dev/null +++ b/config/utils/yamlfile_test.go @@ -0,0 +1,120 @@ +package utils_test + +import ( + "bytes" + "os" + "path/filepath" + "strings" + "testing" + + utils "github.com/inference-gateway/cli/config/utils" +) + +type yamlFileFixture struct { + Name string `yaml:"name"` + Value string `yaml:"value"` +} + +func defaultsForFixture() *yamlFileFixture { + return &yamlFileFixture{Name: "default", Value: "x"} +} + +func TestLoadYAML_MissingFileReturnsDefaults(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "missing.yaml") + + cfg, err := utils.LoadYAML(path, "fixture", defaultsForFixture) + if err != nil { + t.Fatalf("LoadYAML() error = %v", err) + } + if cfg.Name != "default" || cfg.Value != "x" { + t.Errorf("expected defaults, got %+v", cfg) + } +} + +func TestLoadYAML_ParsesYAML(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "f.yaml") + if err := os.WriteFile(path, []byte("name: hello\nvalue: world\n"), 0644); err != nil { + t.Fatal(err) + } + + cfg, err := utils.LoadYAML(path, "fixture", defaultsForFixture) + if err != nil { + t.Fatalf("LoadYAML() error = %v", err) + } + if cfg.Name != "hello" || cfg.Value != "world" { + t.Errorf("got %+v", cfg) + } +} + +func TestLoadYAML_ExpandsEnv(t *testing.T) { + t.Setenv("YAMLFILE_TEST_VAR", "expanded") + dir := t.TempDir() + path := filepath.Join(dir, "f.yaml") + if err := os.WriteFile(path, []byte("name: ${YAMLFILE_TEST_VAR}\nvalue: v\n"), 0644); err != nil { + t.Fatal(err) + } + + cfg, err := utils.LoadYAML(path, "fixture", defaultsForFixture) + if err != nil { + t.Fatalf("LoadYAML() error = %v", err) + } + if cfg.Name != "expanded" { + t.Errorf("expected env-expanded name, got %q", cfg.Name) + } +} + +func TestLoadYAML_MalformedYAML(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "f.yaml") + if err := os.WriteFile(path, []byte("not: valid: yaml: ["), 0644); err != nil { + t.Fatal(err) + } + + _, err := utils.LoadYAML(path, "fixture", defaultsForFixture) + if err == nil { + t.Fatal("expected parse error") + } + if !strings.Contains(err.Error(), "fixture") { + t.Errorf("error should be label-scoped, got %v", err) + } +} + +func TestSaveYAML_WritesDocMarkerAndRoundTrips(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "nested", "f.yaml") + original := &yamlFileFixture{Name: "n", Value: "v"} + + if err := utils.SaveYAML(path, "fixture", original); err != nil { + t.Fatalf("SaveYAML() error = %v", err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read after save: %v", err) + } + if !bytes.HasPrefix(data, []byte("---\n")) { + t.Errorf("expected `---\\n` doc marker prefix, got %q", string(data[:min(8, len(data))])) + } + + loaded, err := utils.LoadYAML(path, "fixture", defaultsForFixture) + if err != nil { + t.Fatalf("load after save: %v", err) + } + if *loaded != *original { + t.Errorf("round-trip mismatch: got %+v want %+v", loaded, original) + } +} + +func TestSaveYAML_CreatesParentDirectory(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "a", "b", "c", "f.yaml") + + if err := utils.SaveYAML(path, "fixture", &yamlFileFixture{Name: "n"}); err != nil { + t.Fatalf("SaveYAML() error = %v", err) + } + if _, err := os.Stat(path); err != nil { + t.Errorf("file not created: %v", err) + } +} diff --git a/internal/services/agents_test.go b/internal/services/agents_test.go index b776f34e..6d29e0c3 100644 --- a/internal/services/agents_test.go +++ b/internal/services/agents_test.go @@ -13,7 +13,9 @@ func TestA2AAgentService_GetConfiguredAgents_EnvVarPrecedence(t *testing.T) { tmpDir := t.TempDir() agentsPath := filepath.Join(tmpDir, "agents.yaml") - require.NoError(t, config.AddAgent(agentsPath, config.AgentEntry{ + agentsCfg, err := config.LoadAgents(agentsPath) + require.NoError(t, err) + require.NoError(t, agentsCfg.CreateEntry(config.AgentEntry{ Name: "yaml-agent", URL: "http://yaml-agent:8080", Run: false,