diff --git a/cmd/root/debug.go b/cmd/root/debug.go index 9ec506933..c23e8e0ad 100644 --- a/cmd/root/debug.go +++ b/cmd/root/debug.go @@ -56,7 +56,7 @@ func (f *debugFlags) runDebugConfigCommand(cmd *cobra.Command, args []string) er return err } - cfg, err := config.LoadConfigFrom(ctx, agentSource) + cfg, err := config.Load(ctx, agentSource) if err != nil { return err } diff --git a/pkg/build/build.go b/pkg/build/build.go index 3cabeda4d..d1edfb4f0 100644 --- a/pkg/build/build.go +++ b/pkg/build/build.go @@ -39,7 +39,7 @@ func DockerImage(ctx context.Context, out Printer, agentFilename string, fs file return err } - cfg, err := config.LoadConfigFrom(ctx, agentSource) + cfg, err := config.Load(ctx, agentSource) if err != nil { return err } diff --git a/pkg/config/commands_test.go b/pkg/config/commands_test.go index c0051e73a..be4d41d3b 100644 --- a/pkg/config/commands_test.go +++ b/pkg/config/commands_test.go @@ -1,52 +1,60 @@ package config import ( + "context" + "os" "testing" "github.com/stretchr/testify/require" ) func TestV2Commands_AllForms(t *testing.T) { - cfg, err := LoadConfig(t.Context(), "commands_v2.yaml", openRoot(t, "testdata")) + cfg, err := Load(t.Context(), fileSource("testdata/commands_v2.yaml")) require.NoError(t, err) - // map form + cmdsMap := cfg.Agents["root"].Commands require.Equal(t, "check disk", cmdsMap["df"]) require.Equal(t, "list files", cmdsMap["ls"]) - // list form + cmdsList := cfg.Agents["another_agent"].Commands require.Equal(t, "check disk", cmdsList["df"]) require.Equal(t, "list files", cmdsList["ls"]) - // none + require.Empty(t, cfg.Agents["none_agent"].Commands) } func TestMigrate_v1_Commands_AllForms(t *testing.T) { - cfg, err := LoadConfig(t.Context(), "commands_v1.yaml", openRoot(t, "testdata")) + cfg, err := Load(t.Context(), fileSource("testdata/commands_v1.yaml")) require.NoError(t, err) - // map form + cmdsMap := cfg.Agents["root"].Commands require.Equal(t, "check disk", cmdsMap["df"]) require.Equal(t, "list files", cmdsMap["ls"]) - // list form + cmdsList := cfg.Agents["another_agent"].Commands require.Equal(t, "check disk", cmdsList["df"]) require.Equal(t, "list files", cmdsList["ls"]) - // none + require.Empty(t, cfg.Agents["yet_another_agent"].Commands) } func TestMigrate_v0_Commands_AllForms(t *testing.T) { - cfg, err := LoadConfig(t.Context(), "commands_v0.yaml", openRoot(t, "testdata")) + cfg, err := Load(t.Context(), fileSource("testdata/commands_v0.yaml")) require.NoError(t, err) - // map form + cmdsMap := cfg.Agents["root"].Commands require.Equal(t, "check disk", cmdsMap["df"]) require.Equal(t, "list files", cmdsMap["ls"]) - // list form + cmdsList := cfg.Agents["another_agent"].Commands require.Equal(t, "check disk", cmdsList["df"]) require.Equal(t, "list files", cmdsList["ls"]) - // none + require.Empty(t, cfg.Agents["yet_another_agent"].Commands) } + +type fileSource string + +func (s fileSource) Read(context.Context) ([]byte, error) { + return os.ReadFile(string(s)) +} diff --git a/pkg/config/config.go b/pkg/config/config.go index eacf2ce67..9fad6e617 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -9,32 +9,18 @@ import ( "github.com/docker/cagent/pkg/config/latest" "github.com/docker/cagent/pkg/environment" - "github.com/docker/cagent/pkg/filesystem" ) type Source interface { Read(ctx context.Context) ([]byte, error) } -func LoadConfigFrom(ctx context.Context, source Source) (*latest.Config, error) { +func Load(ctx context.Context, source Source) (*latest.Config, error) { data, err := source.Read(ctx) if err != nil { return nil, err } - return loadConfigBytes(data) -} - -func LoadConfig(ctx context.Context, path string, fs filesystem.FS) (*latest.Config, error) { - data, err := fs.ReadFile(path) - if err != nil { - return nil, fmt.Errorf("reading config file %s: %w", path, err) - } - - return loadConfigBytes(data) -} - -func loadConfigBytes(data []byte) (*latest.Config, error) { var raw struct { Version string `yaml:"version,omitempty"` } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 2ded43fab..7bb33f1ec 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -15,9 +15,7 @@ import ( func TestAutoRegisterModels(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata") - - cfg, err := LoadConfig(t.Context(), "autoregister.yaml", root) + cfg, err := Load(t.Context(), fileSource("testdata/autoregister.yaml")) require.NoError(t, err) assert.Len(t, cfg.Models, 2) @@ -30,9 +28,7 @@ func TestAutoRegisterModels(t *testing.T) { func TestAutoRegisterAlloy(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata") - - cfg, err := LoadConfig(t.Context(), "autoregister_alloy.yaml", root) + cfg, err := Load(t.Context(), fileSource("testdata/autoregister_alloy.yaml")) require.NoError(t, err) assert.Len(t, cfg.Models, 2) @@ -45,9 +41,7 @@ func TestAutoRegisterAlloy(t *testing.T) { func TestMigrate_v0_v1_provider(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata") - - cfg, err := LoadConfig(t.Context(), "provider_v0.yaml", root) + cfg, err := Load(t.Context(), fileSource("testdata/provider_v0.yaml")) require.NoError(t, err) assert.Equal(t, "openai", cfg.Models["gpt"].Provider) @@ -56,9 +50,7 @@ func TestMigrate_v0_v1_provider(t *testing.T) { func TestMigrate_v1_provider(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata") - - cfg, err := LoadConfig(t.Context(), "provider_v1.yaml", root) + cfg, err := Load(t.Context(), fileSource("testdata/provider_v1.yaml")) require.NoError(t, err) assert.Equal(t, "openai", cfg.Models["gpt"].Provider) @@ -67,9 +59,7 @@ func TestMigrate_v1_provider(t *testing.T) { func TestMigrate_v0_v1_todo(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata") - - cfg, err := LoadConfig(t.Context(), "todo_v0.yaml", root) + cfg, err := Load(t.Context(), fileSource("testdata/todo_v0.yaml")) require.NoError(t, err) assert.Len(t, cfg.Agents["root"].Toolsets, 2) @@ -81,9 +71,7 @@ func TestMigrate_v0_v1_todo(t *testing.T) { func TestMigrate_v1_todo(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata") - - cfg, err := LoadConfig(t.Context(), "todo_v1.yaml", root) + cfg, err := Load(t.Context(), fileSource("testdata/todo_v1.yaml")) require.NoError(t, err) assert.Len(t, cfg.Agents["root"].Toolsets, 2) @@ -95,9 +83,7 @@ func TestMigrate_v1_todo(t *testing.T) { func TestMigrate_v0_v1_shared_todo(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata") - - cfg, err := LoadConfig(t.Context(), "shared_todo_v0.yaml", root) + cfg, err := Load(t.Context(), fileSource("testdata/shared_todo_v0.yaml")) require.NoError(t, err) assert.Len(t, cfg.Agents["root"].Toolsets, 2) @@ -109,9 +95,7 @@ func TestMigrate_v0_v1_shared_todo(t *testing.T) { func TestMigrate_v1_shared_todo(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata") - - cfg, err := LoadConfig(t.Context(), "shared_todo_v1.yaml", root) + cfg, err := Load(t.Context(), fileSource("testdata/shared_todo_v1.yaml")) require.NoError(t, err) assert.Len(t, cfg.Agents["root"].Toolsets, 2) @@ -123,9 +107,7 @@ func TestMigrate_v1_shared_todo(t *testing.T) { func TestMigrate_v0_v1_think(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata") - - cfg, err := LoadConfig(t.Context(), "think_v0.yaml", root) + cfg, err := Load(t.Context(), fileSource("testdata/think_v0.yaml")) require.NoError(t, err) assert.Len(t, cfg.Agents["root"].Toolsets, 2) @@ -136,9 +118,7 @@ func TestMigrate_v0_v1_think(t *testing.T) { func TestMigrate_v1_think(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata") - - cfg, err := LoadConfig(t.Context(), "think_v1.yaml", root) + cfg, err := Load(t.Context(), fileSource("testdata/think_v1.yaml")) require.NoError(t, err) assert.Len(t, cfg.Agents["root"].Toolsets, 2) @@ -149,9 +129,7 @@ func TestMigrate_v1_think(t *testing.T) { func TestMigrate_v0_v1_memory(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata") - - cfg, err := LoadConfig(t.Context(), "memory_v0.yaml", root) + cfg, err := Load(t.Context(), fileSource("testdata/memory_v0.yaml")) require.NoError(t, err) assert.Len(t, cfg.Agents["root"].Toolsets, 2) @@ -163,9 +141,7 @@ func TestMigrate_v0_v1_memory(t *testing.T) { func TestMigrate_v1_memory(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata") - - cfg, err := LoadConfig(t.Context(), "memory_v1.yaml", root) + cfg, err := Load(t.Context(), fileSource("testdata/memory_v1.yaml")) require.NoError(t, err) assert.Len(t, cfg.Agents["root"].Toolsets, 2) @@ -177,9 +153,7 @@ func TestMigrate_v1_memory(t *testing.T) { func TestMigrate_v1(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata") - - _, err := LoadConfig(t.Context(), "v1.yaml", root) + _, err := Load(t.Context(), fileSource("testdata/v1.yaml")) require.NoError(t, err) } @@ -257,9 +231,7 @@ func TestCheckRequiredEnvVars(t *testing.T) { t.Run(test.yaml, func(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata/env") - - cfg, err := LoadConfig(t.Context(), test.yaml, root) + cfg, err := Load(t.Context(), fileSource("testdata/env/"+test.yaml)) require.NoError(t, err) err = CheckRequiredEnvVars(t.Context(), cfg, "", &noEnvProvider{}) @@ -277,9 +249,7 @@ func TestCheckRequiredEnvVars(t *testing.T) { func TestCheckRequiredEnvVarsWithModelGateway(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata/env") - - cfg, err := LoadConfig(t.Context(), "all.yaml", root) + cfg, err := Load(t.Context(), fileSource("testdata/env/all.yaml")) require.NoError(t, err) err = CheckRequiredEnvVars(t.Context(), cfg, "gateway:8080", &noEnvProvider{}) diff --git a/pkg/config/examples_test.go b/pkg/config/examples_test.go index 5ac9e9b8f..ff5421c8b 100644 --- a/pkg/config/examples_test.go +++ b/pkg/config/examples_test.go @@ -12,7 +12,6 @@ import ( "github.com/xeipuuv/gojsonschema" "github.com/docker/cagent/pkg/config/latest" - "github.com/docker/cagent/pkg/filesystem" "github.com/docker/cagent/pkg/modelsdev" ) @@ -43,7 +42,7 @@ func TestParseExamples(t *testing.T) { t.Run(file, func(t *testing.T) { t.Parallel() - cfg, err := LoadConfig(t.Context(), file, filesystem.AllowAll) + cfg, err := Load(t.Context(), fileSource(file)) require.NoError(t, err) require.Equal(t, latest.Version, cfg.Version, "Version should be %d in %s", latest.Version, file) diff --git a/pkg/config/validation_test.go b/pkg/config/validation_test.go index edeedfc05..481eaec1f 100644 --- a/pkg/config/validation_test.go +++ b/pkg/config/validation_test.go @@ -1,13 +1,15 @@ package config import ( + "path/filepath" "testing" "github.com/stretchr/testify/require" ) func TestLoadConfig_InvalidPath(t *testing.T) { - tmp := openRoot(t, t.TempDir()) + tmp := t.TempDir() + tmpRoot := openRoot(t, tmp) validConfig := `version: 1 agents: @@ -15,18 +17,20 @@ agents: model: "openai/gpt-4" ` - err := tmp.WriteFile("valid.yaml", []byte(validConfig), 0o644) + err := tmpRoot.WriteFile("valid.yaml", []byte(validConfig), 0o644) require.NoError(t, err) - cfg, err := LoadConfig(t.Context(), "valid.yaml", tmp) + cfg, err := Load(t.Context(), fileSource(filepath.Join(tmp, "valid.yaml"))) require.NoError(t, err) require.NotNil(t, cfg) - _, err = LoadConfig(t.Context(), "../../../etc/passwd", tmp) + _, err = Load(t.Context(), fileSource(filepath.Join(tmp, "../../../etc/passwd"))) //nolint: gocritic // testing invalid path require.Error(t, err) } func TestValidationErrors(t *testing.T) { + t.Parallel() + tests := []struct { name string path string @@ -49,9 +53,7 @@ func TestValidationErrors(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - root := openRoot(t, "testdata") - - _, err := LoadConfig(t.Context(), tt.path, root) + _, err := Load(t.Context(), fileSource(filepath.Join("testdata", tt.path))) require.Error(t, err) }) } diff --git a/pkg/oci/package.go b/pkg/oci/package.go index ab09a94be..53d6334c2 100644 --- a/pkg/oci/package.go +++ b/pkg/oci/package.go @@ -29,7 +29,7 @@ func PackageFileAsOCIToStore(ctx context.Context, agentFilename, artifactRef str agentFilename = filepath.Clean(agentFilename) source := teamloader.NewFileSource(agentFilename) - cfg, err := config.LoadConfigFrom(ctx, source) + cfg, err := config.Load(ctx, source) if err != nil { return "", fmt.Errorf("loading config: %w", err) } diff --git a/pkg/server/server.go b/pkg/server/server.go index 0d419db31..efce8559f 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -195,7 +195,12 @@ func (s *Server) getAgentConfig(c echo.Context) error { agentID := c.Param("id") p := addYamlExt(agentID) - cfg, err := config.LoadConfig(c.Request().Context(), p, s.rootFS) + data, err := s.rootFS.ReadFile(p) + if err != nil { + return fmt.Errorf("reading config file %s: %w", p, err) + } + + cfg, err := config.Load(c.Request().Context(), teamloader.NewBytesSource(data)) if err != nil { slog.Error("Failed to load config", "error", err) return echo.NewHTTPError(http.StatusNotFound, "agent not found") diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 73de0ae0e..0e7ce8142 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -130,7 +130,7 @@ func LoadFrom(ctx context.Context, source AgentSource, runtimeConfig *config.Run } // Load the agent's configuration - cfg, err := config.LoadConfigFrom(ctx, source) + cfg, err := config.Load(ctx, source) if err != nil { return nil, err }