Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion cmd/root/a2a.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
package root

import (
"path/filepath"

"github.com/spf13/cobra"

"github.com/docker/docker-agent/pkg/a2a"
"github.com/docker/docker-agent/pkg/cli"
"github.com/docker/docker-agent/pkg/config"
"github.com/docker/docker-agent/pkg/paths"
"github.com/docker/docker-agent/pkg/telemetry"
)

type a2aFlags struct {
agentName string
listenAddr string
sessionDB string
runConfig config.RuntimeConfig
}

Expand All @@ -30,6 +34,7 @@ func newA2ACmd() *cobra.Command {

cmd.PersistentFlags().StringVarP(&flags.agentName, "agent", "a", "", "Name of the agent to run (defaults to the team's first agent)")
cmd.PersistentFlags().StringVarP(&flags.listenAddr, "listen", "l", "127.0.0.1:8082", "Address to listen on")
cmd.PersistentFlags().StringVarP(&flags.sessionDB, "session-db", "s", filepath.Join(paths.GetHomeDir(), ".cagent", "session.db"), "Path to the session database")
addRuntimeConfigFlags(cmd, &flags.runConfig)

return cmd
Expand All @@ -52,5 +57,5 @@ func (f *a2aFlags) runA2ACommand(cmd *cobra.Command, args []string) (commandErr
defer cleanup()

out.Println("Listening on", ln.Addr().String())
return a2a.Run(ctx, agentFilename, f.agentName, &f.runConfig, ln)
return a2a.Run(ctx, agentFilename, f.agentName, f.sessionDB, &f.runConfig, ln)
}
3 changes: 2 additions & 1 deletion e2e/a2a_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net"
"net/http"
"path/filepath"
"testing"

"github.com/a2aproject/a2a-go/a2a"
Expand Down Expand Up @@ -173,7 +174,7 @@ func startA2AServer(t *testing.T, agentFile string, runConfig *config.RuntimeCon
require.NoError(t, err)

go func() {
_ = a2aserver.Run(t.Context(), agentFile, "root", runConfig, ln)
_ = a2aserver.Run(t.Context(), agentFile, "root", filepath.Join(t.TempDir(), "session.db"), runConfig, ln)
}()

port := ln.Addr().(*net.TCPAddr).Port
Expand Down
42 changes: 30 additions & 12 deletions pkg/a2a/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"iter"
"log/slog"
"os"
"strings"

"go.opentelemetry.io/otel"
Expand All @@ -22,7 +23,7 @@ import (
// newDockerAgentAdapter creates a new ADK agent adapter from a docker agent team and agent name.
// When agentName is empty, the team's default agent (one explicitly named "root" if it
// exists, otherwise the first agent declared) is used.
func newDockerAgentAdapter(t *team.Team, agentName string) (agent.Agent, error) {
func newDockerAgentAdapter(t *team.Team, agentName string, sessStore session.Store) (agent.Agent, error) {
a, err := t.AgentOrDefault(agentName)
if err != nil {
return nil, fmt.Errorf("failed to get agent %s: %w", agentName, err)
Expand All @@ -35,31 +36,48 @@ func newDockerAgentAdapter(t *team.Team, agentName string) (agent.Agent, error)
Name: agentName,
Description: desc,
Run: func(ctx agent.InvocationContext) iter.Seq2[*adksession.Event, error] {
return runDockerAgent(ctx, t, agentName, a)
return runDockerAgent(ctx, t, agentName, a, sessStore)
},
})
}

// runDockerAgent executes a docker agent and returns ADK session events
func runDockerAgent(ctx agent.InvocationContext, t *team.Team, agentName string, a *dagent.Agent) iter.Seq2[*adksession.Event, error] {
func runDockerAgent(ctx agent.InvocationContext, t *team.Team, agentName string, a *dagent.Agent, sessStore session.Store) iter.Seq2[*adksession.Event, error] {
return func(yield func(*adksession.Event, error) bool) {
// Extract user message from the ADK context
userContent := ctx.UserContent()
message := contentToMessage(userContent)

// Create a session
sess := session.New(
session.WithUserMessage(message),
session.WithMaxIterations(a.MaxIterations()),
session.WithMaxConsecutiveToolCalls(a.MaxConsecutiveToolCalls()),
session.WithMaxOldToolCallTokens(a.MaxOldToolCallTokens()),
session.WithToolsApproved(true),
session.WithNonInteractive(true),
)
// Use the A2A contextID (exposed as the ADK session ID) as the
// docker-agent session ID so subsequent `run --session <id>`
// invocations can resume the same conversation.
sessionID := ctx.Session().ID()

var sess *session.Session
if existing, err := sessStore.GetSession(ctx, sessionID); err == nil && existing != nil {
sess = existing
sess.AddMessage(session.UserMessage(message))
sess.ToolsApproved = true
sess.NonInteractive = true
} else {
workingDir, _ := os.Getwd()
sess = session.New(
session.WithID(sessionID),
session.WithUserMessage(message),
session.WithMaxIterations(a.MaxIterations()),
session.WithMaxConsecutiveToolCalls(a.MaxConsecutiveToolCalls()),
session.WithMaxOldToolCallTokens(a.MaxOldToolCallTokens()),
session.WithToolsApproved(true),
session.WithNonInteractive(true),
session.WithWorkingDir(workingDir),
)
sess.Title = "A2A Session " + sessionID
}

// Create runtime
rt, err := runtime.New(t,
runtime.WithCurrentAgent(agentName),
runtime.WithSessionStore(sessStore),
runtime.WithTracer(otel.Tracer("cagent")),
)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions pkg/a2a/adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestNewDockerAgentAdapter(t *testing.T) {
require.NoError(t, team.StopToolSets(t.Context()))
}()

adapter, err := newDockerAgentAdapter(team, "root")
adapter, err := newDockerAgentAdapter(team, "root", nil)

require.NoError(t, err)
assert.Equal(t, "root", adapter.Name())
Expand All @@ -42,7 +42,7 @@ func TestNewCAgentAdapter_NonExistent(t *testing.T) {
require.NoError(t, team.StopToolSets(t.Context()))
}()

_, err = newDockerAgentAdapter(team, "nonexistent")
_, err = newDockerAgentAdapter(team, "nonexistent", nil)

assert.Contains(t, err.Error(), "failed to get agent")
}
Expand Down
19 changes: 15 additions & 4 deletions pkg/a2a/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ import (
"github.com/labstack/echo/v4/middleware"
"google.golang.org/adk/runner"
"google.golang.org/adk/server/adka2a"
"google.golang.org/adk/session"
adksession "google.golang.org/adk/session"

"github.com/docker/docker-agent/pkg/config"
pathx "github.com/docker/docker-agent/pkg/path"
"github.com/docker/docker-agent/pkg/session"
"github.com/docker/docker-agent/pkg/teamloader"
"github.com/docker/docker-agent/pkg/version"
)
Expand All @@ -36,7 +38,7 @@ func routableAddr(addr string) string {
return addr
}

func Run(ctx context.Context, agentFilename, agentName string, runConfig *config.RuntimeConfig, ln net.Listener) error {
func Run(ctx context.Context, agentFilename, agentName, sessionDB string, runConfig *config.RuntimeConfig, ln net.Listener) error {
slog.DebugContext(ctx, "Starting A2A server", "source", agentFilename, "agent", agentName, "addr", ln.Addr().String())

agentSource, err := config.Resolve(agentFilename, nil)
Expand All @@ -54,7 +56,16 @@ func Run(ctx context.Context, agentFilename, agentName string, runConfig *config
}
}()

adkAgent, err := newDockerAgentAdapter(t, agentName)
expandedSessionDB, err := pathx.ExpandHomeDir(sessionDB)
if err != nil {
return fmt.Errorf("failed to expand session db path: %w", err)
}
sessStore, err := session.NewSQLiteSessionStore(expandedSessionDB)
if err != nil {
return fmt.Errorf("failed to open session store: %w", err)
}

adkAgent, err := newDockerAgentAdapter(t, agentName, sessStore)
if err != nil {
return fmt.Errorf("failed to create ADK agent adapter: %w", err)
}
Expand Down Expand Up @@ -87,7 +98,7 @@ func Run(ctx context.Context, agentFilename, agentName string, runConfig *config
RunnerConfig: runner.Config{
AppName: name,
Agent: adkAgent,
SessionService: session.InMemoryService(),
SessionService: adksession.InMemoryService(),
},
})

Expand Down
Loading