Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 13 additions & 24 deletions pkg/gateway/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,36 @@ package gateway

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

"github.com/goccy/go-yaml"
"github.com/docker/cagent/pkg/sync"
)

const DockerCatalogURL = "https://desktop.docker.com/mcp/catalog/v3/catalog.yaml"

func ParseServerRef(ref string) string {
return strings.TrimPrefix(ref, "docker:")
}

func RequiredEnvVars(ctx context.Context, serverName, catalogURL string) ([]Secret, error) {
catalog, err := readCatalog(ctx, catalogURL)
func RequiredEnvVars(ctx context.Context, serverName string) ([]Secret, error) {
catalog, err := readCatalogOnce()
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to fetch MCP catalog: %w", err)
}

server, ok := catalog[serverName]
if !ok {
return nil, fmt.Errorf("MCP server %q not found in catalog %q", serverName, catalogURL)
return nil, fmt.Errorf("MCP server %q not found in MCP catalog", serverName)
}

return server.Secrets, nil
}

// TODO(dga): cache the catalog.
func readCatalog(ctx context.Context, url string) (Catalog, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
return nil, err
}
// Read the MCP Catalog only once and cache the result.
var readCatalogOnce = sync.OnceErr(func() (Catalog, error) {
// Use the JSON version because it's 3x time faster to parse than YAML.
url := strings.Replace(DockerCatalogURL, ".yaml", ".json", 1)

resp, err := http.DefaultClient.Do(req)
resp, err := http.Get(url)
if err != nil {
return nil, err
}
Expand All @@ -47,15 +41,10 @@ func readCatalog(ctx context.Context, url string) (Catalog, error) {
return nil, fmt.Errorf("failed to fetch URL: %s, status: %s", url, resp.Status)
}

buf, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}

var topLevel topLevel
if err := yaml.Unmarshal(buf, &topLevel); err != nil {
if err := json.NewDecoder(resp.Body).Decode(&topLevel); err != nil {
return nil, err
}

return topLevel.Catalog, nil
}
})
17 changes: 17 additions & 0 deletions pkg/gateway/catalog_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package gateway

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestRequiredEnvVars(t *testing.T) {
secrets, err := RequiredEnvVars(t.Context(), "github-official")
require.NoError(t, err)

assert.Len(t, secrets, 1)
assert.Equal(t, "GITHUB_PERSONAL_ACCESS_TOKEN", secrets[0].Env)
assert.Equal(t, "github.personal_access_token", secrets[0].Name)
}
9 changes: 9 additions & 0 deletions pkg/gateway/servers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package gateway

import (
"strings"
)

func ParseServerRef(ref string) string {
return strings.TrimPrefix(ref, "docker:")
}
10 changes: 5 additions & 5 deletions pkg/gateway/types.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package gateway

type topLevel struct {
Catalog Catalog `json:"registry" yaml:"registry"`
Catalog Catalog `json:"registry"`
}

type Catalog map[string]Server

type Server struct {
Secrets []Secret `json:"secrets,omitempty" yaml:"secrets,omitempty"`
Secrets []Secret `json:"secrets,omitempty"`
}

type Secret struct {
Name string `json:"name" yaml:"name"`
Env string `json:"env" yaml:"env"`
Example string `json:"example" yaml:"example"`
Name string `json:"name"`
Env string `json:"env"`
Example string `json:"example"`
}
2 changes: 1 addition & 1 deletion pkg/secrets/gather.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func GatherEnvVarsForTools(ctx context.Context, cfg *latest.Config) ([]string, e
for _, ref := range gatherMCPServerReferences(cfg) {
mcpServerName := gateway.ParseServerRef(ref)

secrets, err := gateway.RequiredEnvVars(ctx, mcpServerName, gateway.DockerCatalogURL)
secrets, err := gateway.RequiredEnvVars(ctx, mcpServerName)
if err != nil {
return nil, fmt.Errorf("reading which secrets the MCP server needs: %w", err)
}
Expand Down
16 changes: 16 additions & 0 deletions pkg/sync/oncerr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package sync

import "sync"

func OnceErr[T any](fn func() (T, error)) func() (T, error) {
var once sync.Once
var result T
var err error

return func() (T, error) {
once.Do(func() {
result, err = fn()
})
return result, err
}
}
52 changes: 52 additions & 0 deletions pkg/sync/oncerr_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package sync

import (
"errors"
"testing"

"github.com/stretchr/testify/require"
)

func TestOnceErr(t *testing.T) {
t.Parallel()

called := 0
fn := func() (int, error) {
called++
return 42, nil
}

memoizedFn := OnceErr(fn)

value, err := memoizedFn()
require.NoError(t, err)
require.Equal(t, 42, value)
require.Equal(t, 1, called)

value, err = memoizedFn()
require.NoError(t, err)
require.Equal(t, 42, value)
require.Equal(t, 1, called) // Didn't have to call the inner fn
}

func TestOnceErr_Error(t *testing.T) {
t.Parallel()

called := 0
fn := func() (int, error) {
called++
return 1337, errors.New("test error")
}

memoizedFn := OnceErr(fn)

value, err := memoizedFn()
require.Error(t, err)
require.Equal(t, 1337, value)
require.Equal(t, 1, called)

value, err = memoizedFn()
require.Error(t, err)
require.Equal(t, 1337, value)
require.Equal(t, 1, called) // Didn't have to call the inner fn
}
6 changes: 4 additions & 2 deletions pkg/teamloader/teamloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,11 @@ func TestLoadExamples(t *testing.T) {
// Collect the missing env vars.
missingEnvs := map[string]bool{}

var runtimeConfig config.RuntimeConfig

for _, file := range collectExamples(t) {
t.Run(file, func(t *testing.T) {
_, err := Load(t.Context(), file, config.RuntimeConfig{})
_, err := Load(t.Context(), file, runtimeConfig)
if err != nil {
envErr := &environment.RequiredEnvError{}
require.ErrorAs(t, err, &envErr)
Expand All @@ -123,7 +125,7 @@ func TestLoadExamples(t *testing.T) {
t.Run(file, func(t *testing.T) {
t.Parallel()

teams, err := Load(t.Context(), file, config.RuntimeConfig{})
teams, err := Load(t.Context(), file, runtimeConfig)
require.NoError(t, err)
require.NotEmpty(t, teams)
})
Expand Down
11 changes: 6 additions & 5 deletions pkg/tools/mcp/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ func NewGatewayToolset(mcpServerName string, config any, toolFilter []string, en
slog.Debug("Creating MCP Gateway toolset", "name", mcpServerName, "toolFilter", toolFilter)

return &GatewayToolset{
mcpServerName: mcpServerName,
config: config,
toolFilter: toolFilter,
envProvider: envProvider,
mcpServerName: mcpServerName,
config: config,
toolFilter: toolFilter,
envProvider: envProvider,

cleanUpConfig: func() error { return nil },
cleanUpSecrets: func() error { return nil },
}
Expand All @@ -50,7 +51,7 @@ func (t *GatewayToolset) Instructions() string {

func (t *GatewayToolset) configureOnce(ctx context.Context) error {
// Check which secrets (env vars) are required by the MCP server.
secrets, err := gateway.RequiredEnvVars(ctx, t.mcpServerName, gateway.DockerCatalogURL)
secrets, err := gateway.RequiredEnvVars(ctx, t.mcpServerName)
if err != nil {
return fmt.Errorf("reading which secrets the MCP server needs: %w", err)
}
Expand Down
Loading