diff --git a/pkg/gateway/catalog.go b/pkg/gateway/catalog.go index aa054f88f..290635dae 100644 --- a/pkg/gateway/catalog.go +++ b/pkg/gateway/catalog.go @@ -2,42 +2,36 @@ package gateway import ( "context" + "encoding/json" "fmt" - "io" "net/http" "strings" - "github.com/goccy/go-yaml" + "github.com/docker/cagent/pkg/sync" ) const DockerCatalogURL = "https://desktop.docker.com/mcp/catalog/v3/catalog.yaml" -func ParseServerRef(ref string) string { - return strings.TrimPrefix(ref, "docker:") -} - -func RequiredEnvVars(ctx context.Context, serverName, catalogURL string) ([]Secret, error) { - catalog, err := readCatalog(ctx, catalogURL) +func RequiredEnvVars(ctx context.Context, serverName string) ([]Secret, error) { + catalog, err := readCatalogOnce() if err != nil { - return nil, err + return nil, fmt.Errorf("failed to fetch MCP catalog: %w", err) } server, ok := catalog[serverName] if !ok { - return nil, fmt.Errorf("MCP server %q not found in catalog %q", serverName, catalogURL) + return nil, fmt.Errorf("MCP server %q not found in MCP catalog", serverName) } return server.Secrets, nil } -// TODO(dga): cache the catalog. -func readCatalog(ctx context.Context, url string) (Catalog, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) - if err != nil { - return nil, err - } +// Read the MCP Catalog only once and cache the result. +var readCatalogOnce = sync.OnceErr(func() (Catalog, error) { + // Use the JSON version because it's 3x time faster to parse than YAML. + url := strings.Replace(DockerCatalogURL, ".yaml", ".json", 1) - resp, err := http.DefaultClient.Do(req) + resp, err := http.Get(url) if err != nil { return nil, err } @@ -47,15 +41,10 @@ func readCatalog(ctx context.Context, url string) (Catalog, error) { return nil, fmt.Errorf("failed to fetch URL: %s, status: %s", url, resp.Status) } - buf, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - var topLevel topLevel - if err := yaml.Unmarshal(buf, &topLevel); err != nil { + if err := json.NewDecoder(resp.Body).Decode(&topLevel); err != nil { return nil, err } return topLevel.Catalog, nil -} +}) diff --git a/pkg/gateway/catalog_test.go b/pkg/gateway/catalog_test.go new file mode 100644 index 000000000..1319290f2 --- /dev/null +++ b/pkg/gateway/catalog_test.go @@ -0,0 +1,17 @@ +package gateway + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRequiredEnvVars(t *testing.T) { + secrets, err := RequiredEnvVars(t.Context(), "github-official") + require.NoError(t, err) + + assert.Len(t, secrets, 1) + assert.Equal(t, "GITHUB_PERSONAL_ACCESS_TOKEN", secrets[0].Env) + assert.Equal(t, "github.personal_access_token", secrets[0].Name) +} diff --git a/pkg/gateway/servers.go b/pkg/gateway/servers.go new file mode 100644 index 000000000..458891777 --- /dev/null +++ b/pkg/gateway/servers.go @@ -0,0 +1,9 @@ +package gateway + +import ( + "strings" +) + +func ParseServerRef(ref string) string { + return strings.TrimPrefix(ref, "docker:") +} diff --git a/pkg/gateway/types.go b/pkg/gateway/types.go index 3c6102aca..44d178e42 100644 --- a/pkg/gateway/types.go +++ b/pkg/gateway/types.go @@ -1,17 +1,17 @@ package gateway type topLevel struct { - Catalog Catalog `json:"registry" yaml:"registry"` + Catalog Catalog `json:"registry"` } type Catalog map[string]Server type Server struct { - Secrets []Secret `json:"secrets,omitempty" yaml:"secrets,omitempty"` + Secrets []Secret `json:"secrets,omitempty"` } type Secret struct { - Name string `json:"name" yaml:"name"` - Env string `json:"env" yaml:"env"` - Example string `json:"example" yaml:"example"` + Name string `json:"name"` + Env string `json:"env"` + Example string `json:"example"` } diff --git a/pkg/secrets/gather.go b/pkg/secrets/gather.go index ec1f2b390..643d52165 100644 --- a/pkg/secrets/gather.go +++ b/pkg/secrets/gather.go @@ -79,7 +79,7 @@ func GatherEnvVarsForTools(ctx context.Context, cfg *latest.Config) ([]string, e for _, ref := range gatherMCPServerReferences(cfg) { mcpServerName := gateway.ParseServerRef(ref) - secrets, err := gateway.RequiredEnvVars(ctx, mcpServerName, gateway.DockerCatalogURL) + secrets, err := gateway.RequiredEnvVars(ctx, mcpServerName) if err != nil { return nil, fmt.Errorf("reading which secrets the MCP server needs: %w", err) } diff --git a/pkg/sync/oncerr.go b/pkg/sync/oncerr.go new file mode 100644 index 000000000..1959db6aa --- /dev/null +++ b/pkg/sync/oncerr.go @@ -0,0 +1,16 @@ +package sync + +import "sync" + +func OnceErr[T any](fn func() (T, error)) func() (T, error) { + var once sync.Once + var result T + var err error + + return func() (T, error) { + once.Do(func() { + result, err = fn() + }) + return result, err + } +} diff --git a/pkg/sync/oncerr_test.go b/pkg/sync/oncerr_test.go new file mode 100644 index 000000000..cf1c0e00a --- /dev/null +++ b/pkg/sync/oncerr_test.go @@ -0,0 +1,52 @@ +package sync + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestOnceErr(t *testing.T) { + t.Parallel() + + called := 0 + fn := func() (int, error) { + called++ + return 42, nil + } + + memoizedFn := OnceErr(fn) + + value, err := memoizedFn() + require.NoError(t, err) + require.Equal(t, 42, value) + require.Equal(t, 1, called) + + value, err = memoizedFn() + require.NoError(t, err) + require.Equal(t, 42, value) + require.Equal(t, 1, called) // Didn't have to call the inner fn +} + +func TestOnceErr_Error(t *testing.T) { + t.Parallel() + + called := 0 + fn := func() (int, error) { + called++ + return 1337, errors.New("test error") + } + + memoizedFn := OnceErr(fn) + + value, err := memoizedFn() + require.Error(t, err) + require.Equal(t, 1337, value) + require.Equal(t, 1, called) + + value, err = memoizedFn() + require.Error(t, err) + require.Equal(t, 1337, value) + require.Equal(t, 1, called) // Didn't have to call the inner fn +} diff --git a/pkg/teamloader/teamloader_test.go b/pkg/teamloader/teamloader_test.go index 61c0f3b27..8e9102aaa 100644 --- a/pkg/teamloader/teamloader_test.go +++ b/pkg/teamloader/teamloader_test.go @@ -100,9 +100,11 @@ func TestLoadExamples(t *testing.T) { // Collect the missing env vars. missingEnvs := map[string]bool{} + var runtimeConfig config.RuntimeConfig + for _, file := range collectExamples(t) { t.Run(file, func(t *testing.T) { - _, err := Load(t.Context(), file, config.RuntimeConfig{}) + _, err := Load(t.Context(), file, runtimeConfig) if err != nil { envErr := &environment.RequiredEnvError{} require.ErrorAs(t, err, &envErr) @@ -123,7 +125,7 @@ func TestLoadExamples(t *testing.T) { t.Run(file, func(t *testing.T) { t.Parallel() - teams, err := Load(t.Context(), file, config.RuntimeConfig{}) + teams, err := Load(t.Context(), file, runtimeConfig) require.NoError(t, err) require.NotEmpty(t, teams) }) diff --git a/pkg/tools/mcp/gateway.go b/pkg/tools/mcp/gateway.go index 1d97b63f5..34447b46b 100644 --- a/pkg/tools/mcp/gateway.go +++ b/pkg/tools/mcp/gateway.go @@ -35,10 +35,11 @@ func NewGatewayToolset(mcpServerName string, config any, toolFilter []string, en slog.Debug("Creating MCP Gateway toolset", "name", mcpServerName, "toolFilter", toolFilter) return &GatewayToolset{ - mcpServerName: mcpServerName, - config: config, - toolFilter: toolFilter, - envProvider: envProvider, + mcpServerName: mcpServerName, + config: config, + toolFilter: toolFilter, + envProvider: envProvider, + cleanUpConfig: func() error { return nil }, cleanUpSecrets: func() error { return nil }, } @@ -50,7 +51,7 @@ func (t *GatewayToolset) Instructions() string { func (t *GatewayToolset) configureOnce(ctx context.Context) error { // Check which secrets (env vars) are required by the MCP server. - secrets, err := gateway.RequiredEnvVars(ctx, t.mcpServerName, gateway.DockerCatalogURL) + secrets, err := gateway.RequiredEnvVars(ctx, t.mcpServerName) if err != nil { return fmt.Errorf("reading which secrets the MCP server needs: %w", err) }