Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/root/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (f *debugFlags) runDebugConfigCommand(cmd *cobra.Command, args []string) er
return err
}

cfg, err := config.LoadConfigFrom(ctx, agentSource)
cfg, err := config.Load(ctx, agentSource)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/build/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func DockerImage(ctx context.Context, out Printer, agentFilename string, fs file
return err
}

cfg, err := config.LoadConfigFrom(ctx, agentSource)
cfg, err := config.Load(ctx, agentSource)
if err != nil {
return err
}
Expand Down
32 changes: 20 additions & 12 deletions pkg/config/commands_test.go
Original file line number Diff line number Diff line change
@@ -1,52 +1,60 @@
package config

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/require"
)

func TestV2Commands_AllForms(t *testing.T) {
cfg, err := LoadConfig(t.Context(), "commands_v2.yaml", openRoot(t, "testdata"))
cfg, err := Load(t.Context(), fileSource("testdata/commands_v2.yaml"))
require.NoError(t, err)
// map form

cmdsMap := cfg.Agents["root"].Commands
require.Equal(t, "check disk", cmdsMap["df"])
require.Equal(t, "list files", cmdsMap["ls"])
// list form

cmdsList := cfg.Agents["another_agent"].Commands
require.Equal(t, "check disk", cmdsList["df"])
require.Equal(t, "list files", cmdsList["ls"])
// none

require.Empty(t, cfg.Agents["none_agent"].Commands)
}

func TestMigrate_v1_Commands_AllForms(t *testing.T) {
cfg, err := LoadConfig(t.Context(), "commands_v1.yaml", openRoot(t, "testdata"))
cfg, err := Load(t.Context(), fileSource("testdata/commands_v1.yaml"))
require.NoError(t, err)
// map form

cmdsMap := cfg.Agents["root"].Commands
require.Equal(t, "check disk", cmdsMap["df"])
require.Equal(t, "list files", cmdsMap["ls"])
// list form

cmdsList := cfg.Agents["another_agent"].Commands
require.Equal(t, "check disk", cmdsList["df"])
require.Equal(t, "list files", cmdsList["ls"])
// none

require.Empty(t, cfg.Agents["yet_another_agent"].Commands)
}

func TestMigrate_v0_Commands_AllForms(t *testing.T) {
cfg, err := LoadConfig(t.Context(), "commands_v0.yaml", openRoot(t, "testdata"))
cfg, err := Load(t.Context(), fileSource("testdata/commands_v0.yaml"))
require.NoError(t, err)
// map form

cmdsMap := cfg.Agents["root"].Commands
require.Equal(t, "check disk", cmdsMap["df"])
require.Equal(t, "list files", cmdsMap["ls"])
// list form

cmdsList := cfg.Agents["another_agent"].Commands
require.Equal(t, "check disk", cmdsList["df"])
require.Equal(t, "list files", cmdsList["ls"])
// none

require.Empty(t, cfg.Agents["yet_another_agent"].Commands)
}

type fileSource string

func (s fileSource) Read(context.Context) ([]byte, error) {
return os.ReadFile(string(s))
}
16 changes: 1 addition & 15 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,18 @@ import (

"github.com/docker/cagent/pkg/config/latest"
"github.com/docker/cagent/pkg/environment"
"github.com/docker/cagent/pkg/filesystem"
)

type Source interface {
Read(ctx context.Context) ([]byte, error)
}

func LoadConfigFrom(ctx context.Context, source Source) (*latest.Config, error) {
func Load(ctx context.Context, source Source) (*latest.Config, error) {
data, err := source.Read(ctx)
if err != nil {
return nil, err
}

return loadConfigBytes(data)
}

func LoadConfig(ctx context.Context, path string, fs filesystem.FS) (*latest.Config, error) {
data, err := fs.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("reading config file %s: %w", path, err)
}

return loadConfigBytes(data)
}

func loadConfigBytes(data []byte) (*latest.Config, error) {
var raw struct {
Version string `yaml:"version,omitempty"`
}
Expand Down
60 changes: 15 additions & 45 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ import (
func TestAutoRegisterModels(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata")

cfg, err := LoadConfig(t.Context(), "autoregister.yaml", root)
cfg, err := Load(t.Context(), fileSource("testdata/autoregister.yaml"))
require.NoError(t, err)

assert.Len(t, cfg.Models, 2)
Expand All @@ -30,9 +28,7 @@ func TestAutoRegisterModels(t *testing.T) {
func TestAutoRegisterAlloy(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata")

cfg, err := LoadConfig(t.Context(), "autoregister_alloy.yaml", root)
cfg, err := Load(t.Context(), fileSource("testdata/autoregister_alloy.yaml"))
require.NoError(t, err)

assert.Len(t, cfg.Models, 2)
Expand All @@ -45,9 +41,7 @@ func TestAutoRegisterAlloy(t *testing.T) {
func TestMigrate_v0_v1_provider(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata")

cfg, err := LoadConfig(t.Context(), "provider_v0.yaml", root)
cfg, err := Load(t.Context(), fileSource("testdata/provider_v0.yaml"))
require.NoError(t, err)

assert.Equal(t, "openai", cfg.Models["gpt"].Provider)
Expand All @@ -56,9 +50,7 @@ func TestMigrate_v0_v1_provider(t *testing.T) {
func TestMigrate_v1_provider(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata")

cfg, err := LoadConfig(t.Context(), "provider_v1.yaml", root)
cfg, err := Load(t.Context(), fileSource("testdata/provider_v1.yaml"))
require.NoError(t, err)

assert.Equal(t, "openai", cfg.Models["gpt"].Provider)
Expand All @@ -67,9 +59,7 @@ func TestMigrate_v1_provider(t *testing.T) {
func TestMigrate_v0_v1_todo(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata")

cfg, err := LoadConfig(t.Context(), "todo_v0.yaml", root)
cfg, err := Load(t.Context(), fileSource("testdata/todo_v0.yaml"))
require.NoError(t, err)

assert.Len(t, cfg.Agents["root"].Toolsets, 2)
Expand All @@ -81,9 +71,7 @@ func TestMigrate_v0_v1_todo(t *testing.T) {
func TestMigrate_v1_todo(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata")

cfg, err := LoadConfig(t.Context(), "todo_v1.yaml", root)
cfg, err := Load(t.Context(), fileSource("testdata/todo_v1.yaml"))
require.NoError(t, err)

assert.Len(t, cfg.Agents["root"].Toolsets, 2)
Expand All @@ -95,9 +83,7 @@ func TestMigrate_v1_todo(t *testing.T) {
func TestMigrate_v0_v1_shared_todo(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata")

cfg, err := LoadConfig(t.Context(), "shared_todo_v0.yaml", root)
cfg, err := Load(t.Context(), fileSource("testdata/shared_todo_v0.yaml"))
require.NoError(t, err)

assert.Len(t, cfg.Agents["root"].Toolsets, 2)
Expand All @@ -109,9 +95,7 @@ func TestMigrate_v0_v1_shared_todo(t *testing.T) {
func TestMigrate_v1_shared_todo(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata")

cfg, err := LoadConfig(t.Context(), "shared_todo_v1.yaml", root)
cfg, err := Load(t.Context(), fileSource("testdata/shared_todo_v1.yaml"))
require.NoError(t, err)

assert.Len(t, cfg.Agents["root"].Toolsets, 2)
Expand All @@ -123,9 +107,7 @@ func TestMigrate_v1_shared_todo(t *testing.T) {
func TestMigrate_v0_v1_think(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata")

cfg, err := LoadConfig(t.Context(), "think_v0.yaml", root)
cfg, err := Load(t.Context(), fileSource("testdata/think_v0.yaml"))
require.NoError(t, err)

assert.Len(t, cfg.Agents["root"].Toolsets, 2)
Expand All @@ -136,9 +118,7 @@ func TestMigrate_v0_v1_think(t *testing.T) {
func TestMigrate_v1_think(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata")

cfg, err := LoadConfig(t.Context(), "think_v1.yaml", root)
cfg, err := Load(t.Context(), fileSource("testdata/think_v1.yaml"))
require.NoError(t, err)

assert.Len(t, cfg.Agents["root"].Toolsets, 2)
Expand All @@ -149,9 +129,7 @@ func TestMigrate_v1_think(t *testing.T) {
func TestMigrate_v0_v1_memory(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata")

cfg, err := LoadConfig(t.Context(), "memory_v0.yaml", root)
cfg, err := Load(t.Context(), fileSource("testdata/memory_v0.yaml"))
require.NoError(t, err)

assert.Len(t, cfg.Agents["root"].Toolsets, 2)
Expand All @@ -163,9 +141,7 @@ func TestMigrate_v0_v1_memory(t *testing.T) {
func TestMigrate_v1_memory(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata")

cfg, err := LoadConfig(t.Context(), "memory_v1.yaml", root)
cfg, err := Load(t.Context(), fileSource("testdata/memory_v1.yaml"))
require.NoError(t, err)

assert.Len(t, cfg.Agents["root"].Toolsets, 2)
Expand All @@ -177,9 +153,7 @@ func TestMigrate_v1_memory(t *testing.T) {
func TestMigrate_v1(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata")

_, err := LoadConfig(t.Context(), "v1.yaml", root)
_, err := Load(t.Context(), fileSource("testdata/v1.yaml"))
require.NoError(t, err)
}

Expand Down Expand Up @@ -257,9 +231,7 @@ func TestCheckRequiredEnvVars(t *testing.T) {
t.Run(test.yaml, func(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata/env")

cfg, err := LoadConfig(t.Context(), test.yaml, root)
cfg, err := Load(t.Context(), fileSource("testdata/env/"+test.yaml))
require.NoError(t, err)

err = CheckRequiredEnvVars(t.Context(), cfg, "", &noEnvProvider{})
Expand All @@ -277,9 +249,7 @@ func TestCheckRequiredEnvVars(t *testing.T) {
func TestCheckRequiredEnvVarsWithModelGateway(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata/env")

cfg, err := LoadConfig(t.Context(), "all.yaml", root)
cfg, err := Load(t.Context(), fileSource("testdata/env/all.yaml"))
require.NoError(t, err)

err = CheckRequiredEnvVars(t.Context(), cfg, "gateway:8080", &noEnvProvider{})
Expand Down
3 changes: 1 addition & 2 deletions pkg/config/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/xeipuuv/gojsonschema"

"github.com/docker/cagent/pkg/config/latest"
"github.com/docker/cagent/pkg/filesystem"
"github.com/docker/cagent/pkg/modelsdev"
)

Expand Down Expand Up @@ -43,7 +42,7 @@ func TestParseExamples(t *testing.T) {
t.Run(file, func(t *testing.T) {
t.Parallel()

cfg, err := LoadConfig(t.Context(), file, filesystem.AllowAll)
cfg, err := Load(t.Context(), fileSource(file))

require.NoError(t, err)
require.Equal(t, latest.Version, cfg.Version, "Version should be %d in %s", latest.Version, file)
Expand Down
16 changes: 9 additions & 7 deletions pkg/config/validation_test.go
Original file line number Diff line number Diff line change
@@ -1,32 +1,36 @@
package config

import (
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
)

func TestLoadConfig_InvalidPath(t *testing.T) {
tmp := openRoot(t, t.TempDir())
tmp := t.TempDir()
tmpRoot := openRoot(t, tmp)

validConfig := `version: 1
agents:
root:
model: "openai/gpt-4"
`

err := tmp.WriteFile("valid.yaml", []byte(validConfig), 0o644)
err := tmpRoot.WriteFile("valid.yaml", []byte(validConfig), 0o644)
require.NoError(t, err)

cfg, err := LoadConfig(t.Context(), "valid.yaml", tmp)
cfg, err := Load(t.Context(), fileSource(filepath.Join(tmp, "valid.yaml")))
require.NoError(t, err)
require.NotNil(t, cfg)

_, err = LoadConfig(t.Context(), "../../../etc/passwd", tmp)
_, err = Load(t.Context(), fileSource(filepath.Join(tmp, "../../../etc/passwd"))) //nolint: gocritic // testing invalid path
require.Error(t, err)
}

func TestValidationErrors(t *testing.T) {
t.Parallel()

tests := []struct {
name string
path string
Expand All @@ -49,9 +53,7 @@ func TestValidationErrors(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

root := openRoot(t, "testdata")

_, err := LoadConfig(t.Context(), tt.path, root)
_, err := Load(t.Context(), fileSource(filepath.Join("testdata", tt.path)))
require.Error(t, err)
})
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/oci/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func PackageFileAsOCIToStore(ctx context.Context, agentFilename, artifactRef str

agentFilename = filepath.Clean(agentFilename)
source := teamloader.NewFileSource(agentFilename)
cfg, err := config.LoadConfigFrom(ctx, source)
cfg, err := config.Load(ctx, source)
if err != nil {
return "", fmt.Errorf("loading config: %w", err)
}
Expand Down
7 changes: 6 additions & 1 deletion pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,12 @@ func (s *Server) getAgentConfig(c echo.Context) error {
agentID := c.Param("id")
p := addYamlExt(agentID)

cfg, err := config.LoadConfig(c.Request().Context(), p, s.rootFS)
data, err := s.rootFS.ReadFile(p)
if err != nil {
return fmt.Errorf("reading config file %s: %w", p, err)
}

cfg, err := config.Load(c.Request().Context(), teamloader.NewBytesSource(data))
if err != nil {
slog.Error("Failed to load config", "error", err)
return echo.NewHTTPError(http.StatusNotFound, "agent not found")
Expand Down
2 changes: 1 addition & 1 deletion pkg/teamloader/teamloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func LoadFrom(ctx context.Context, source AgentSource, runtimeConfig *config.Run
}

// Load the agent's configuration
cfg, err := config.LoadConfigFrom(ctx, source)
cfg, err := config.Load(ctx, source)
if err != nil {
return nil, err
}
Expand Down
Loading