From 50b49a746ebc9a076a9ebd0407690c58678719c7 Mon Sep 17 00:00:00 2001 From: maxcleme Date: Mon, 11 May 2026 20:05:27 +0200 Subject: [PATCH] feat(a2a): persist sessions to SQLite and resume by contextID Add a --session-db flag to the a2a command and back the A2A server with a SQLite session store. The A2A contextID is reused as the docker-agent session ID so subsequent invocations resume the same conversation. Signed-off-by: maxcleme --- cmd/root/a2a.go | 7 ++++++- e2e/a2a_test.go | 3 ++- pkg/a2a/adapter.go | 42 +++++++++++++++++++++++++++++------------ pkg/a2a/adapter_test.go | 4 ++-- pkg/a2a/server.go | 19 +++++++++++++++---- 5 files changed, 55 insertions(+), 20 deletions(-) diff --git a/cmd/root/a2a.go b/cmd/root/a2a.go index c7b8442a9..f9fa46c2b 100644 --- a/cmd/root/a2a.go +++ b/cmd/root/a2a.go @@ -1,17 +1,21 @@ package root import ( + "path/filepath" + "github.com/spf13/cobra" "github.com/docker/docker-agent/pkg/a2a" "github.com/docker/docker-agent/pkg/cli" "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/paths" "github.com/docker/docker-agent/pkg/telemetry" ) type a2aFlags struct { agentName string listenAddr string + sessionDB string runConfig config.RuntimeConfig } @@ -30,6 +34,7 @@ func newA2ACmd() *cobra.Command { cmd.PersistentFlags().StringVarP(&flags.agentName, "agent", "a", "", "Name of the agent to run (defaults to the team's first agent)") cmd.PersistentFlags().StringVarP(&flags.listenAddr, "listen", "l", "127.0.0.1:8082", "Address to listen on") + cmd.PersistentFlags().StringVarP(&flags.sessionDB, "session-db", "s", filepath.Join(paths.GetHomeDir(), ".cagent", "session.db"), "Path to the session database") addRuntimeConfigFlags(cmd, &flags.runConfig) return cmd @@ -52,5 +57,5 @@ func (f *a2aFlags) runA2ACommand(cmd *cobra.Command, args []string) (commandErr defer cleanup() out.Println("Listening on", ln.Addr().String()) - return a2a.Run(ctx, agentFilename, f.agentName, &f.runConfig, ln) + return a2a.Run(ctx, agentFilename, f.agentName, f.sessionDB, &f.runConfig, ln) } diff --git a/e2e/a2a_test.go b/e2e/a2a_test.go index 3e4ed67ec..b2bbbfaf4 100644 --- a/e2e/a2a_test.go +++ b/e2e/a2a_test.go @@ -7,6 +7,7 @@ import ( "io" "net" "net/http" + "path/filepath" "testing" "github.com/a2aproject/a2a-go/a2a" @@ -173,7 +174,7 @@ func startA2AServer(t *testing.T, agentFile string, runConfig *config.RuntimeCon require.NoError(t, err) go func() { - _ = a2aserver.Run(t.Context(), agentFile, "root", runConfig, ln) + _ = a2aserver.Run(t.Context(), agentFile, "root", filepath.Join(t.TempDir(), "session.db"), runConfig, ln) }() port := ln.Addr().(*net.TCPAddr).Port diff --git a/pkg/a2a/adapter.go b/pkg/a2a/adapter.go index 333083dc6..205cf2a0a 100644 --- a/pkg/a2a/adapter.go +++ b/pkg/a2a/adapter.go @@ -5,6 +5,7 @@ import ( "fmt" "iter" "log/slog" + "os" "strings" "go.opentelemetry.io/otel" @@ -22,7 +23,7 @@ import ( // newDockerAgentAdapter creates a new ADK agent adapter from a docker agent team and agent name. // When agentName is empty, the team's default agent (one explicitly named "root" if it // exists, otherwise the first agent declared) is used. -func newDockerAgentAdapter(t *team.Team, agentName string) (agent.Agent, error) { +func newDockerAgentAdapter(t *team.Team, agentName string, sessStore session.Store) (agent.Agent, error) { a, err := t.AgentOrDefault(agentName) if err != nil { return nil, fmt.Errorf("failed to get agent %s: %w", agentName, err) @@ -35,31 +36,48 @@ func newDockerAgentAdapter(t *team.Team, agentName string) (agent.Agent, error) Name: agentName, Description: desc, Run: func(ctx agent.InvocationContext) iter.Seq2[*adksession.Event, error] { - return runDockerAgent(ctx, t, agentName, a) + return runDockerAgent(ctx, t, agentName, a, sessStore) }, }) } // runDockerAgent executes a docker agent and returns ADK session events -func runDockerAgent(ctx agent.InvocationContext, t *team.Team, agentName string, a *dagent.Agent) iter.Seq2[*adksession.Event, error] { +func runDockerAgent(ctx agent.InvocationContext, t *team.Team, agentName string, a *dagent.Agent, sessStore session.Store) iter.Seq2[*adksession.Event, error] { return func(yield func(*adksession.Event, error) bool) { // Extract user message from the ADK context userContent := ctx.UserContent() message := contentToMessage(userContent) - // Create a session - sess := session.New( - session.WithUserMessage(message), - session.WithMaxIterations(a.MaxIterations()), - session.WithMaxConsecutiveToolCalls(a.MaxConsecutiveToolCalls()), - session.WithMaxOldToolCallTokens(a.MaxOldToolCallTokens()), - session.WithToolsApproved(true), - session.WithNonInteractive(true), - ) + // Use the A2A contextID (exposed as the ADK session ID) as the + // docker-agent session ID so subsequent `run --session ` + // invocations can resume the same conversation. + sessionID := ctx.Session().ID() + + var sess *session.Session + if existing, err := sessStore.GetSession(ctx, sessionID); err == nil && existing != nil { + sess = existing + sess.AddMessage(session.UserMessage(message)) + sess.ToolsApproved = true + sess.NonInteractive = true + } else { + workingDir, _ := os.Getwd() + sess = session.New( + session.WithID(sessionID), + session.WithUserMessage(message), + session.WithMaxIterations(a.MaxIterations()), + session.WithMaxConsecutiveToolCalls(a.MaxConsecutiveToolCalls()), + session.WithMaxOldToolCallTokens(a.MaxOldToolCallTokens()), + session.WithToolsApproved(true), + session.WithNonInteractive(true), + session.WithWorkingDir(workingDir), + ) + sess.Title = "A2A Session " + sessionID + } // Create runtime rt, err := runtime.New(t, runtime.WithCurrentAgent(agentName), + runtime.WithSessionStore(sessStore), runtime.WithTracer(otel.Tracer("cagent")), ) if err != nil { diff --git a/pkg/a2a/adapter_test.go b/pkg/a2a/adapter_test.go index 814e2f059..e83624366 100644 --- a/pkg/a2a/adapter_test.go +++ b/pkg/a2a/adapter_test.go @@ -23,7 +23,7 @@ func TestNewDockerAgentAdapter(t *testing.T) { require.NoError(t, team.StopToolSets(t.Context())) }() - adapter, err := newDockerAgentAdapter(team, "root") + adapter, err := newDockerAgentAdapter(team, "root", nil) require.NoError(t, err) assert.Equal(t, "root", adapter.Name()) @@ -42,7 +42,7 @@ func TestNewCAgentAdapter_NonExistent(t *testing.T) { require.NoError(t, team.StopToolSets(t.Context())) }() - _, err = newDockerAgentAdapter(team, "nonexistent") + _, err = newDockerAgentAdapter(team, "nonexistent", nil) assert.Contains(t, err.Error(), "failed to get agent") } diff --git a/pkg/a2a/server.go b/pkg/a2a/server.go index 64cde880a..f0d0712db 100644 --- a/pkg/a2a/server.go +++ b/pkg/a2a/server.go @@ -16,9 +16,11 @@ import ( "github.com/labstack/echo/v4/middleware" "google.golang.org/adk/runner" "google.golang.org/adk/server/adka2a" - "google.golang.org/adk/session" + adksession "google.golang.org/adk/session" "github.com/docker/docker-agent/pkg/config" + pathx "github.com/docker/docker-agent/pkg/path" + "github.com/docker/docker-agent/pkg/session" "github.com/docker/docker-agent/pkg/teamloader" "github.com/docker/docker-agent/pkg/version" ) @@ -36,7 +38,7 @@ func routableAddr(addr string) string { return addr } -func Run(ctx context.Context, agentFilename, agentName string, runConfig *config.RuntimeConfig, ln net.Listener) error { +func Run(ctx context.Context, agentFilename, agentName, sessionDB string, runConfig *config.RuntimeConfig, ln net.Listener) error { slog.DebugContext(ctx, "Starting A2A server", "source", agentFilename, "agent", agentName, "addr", ln.Addr().String()) agentSource, err := config.Resolve(agentFilename, nil) @@ -54,7 +56,16 @@ func Run(ctx context.Context, agentFilename, agentName string, runConfig *config } }() - adkAgent, err := newDockerAgentAdapter(t, agentName) + expandedSessionDB, err := pathx.ExpandHomeDir(sessionDB) + if err != nil { + return fmt.Errorf("failed to expand session db path: %w", err) + } + sessStore, err := session.NewSQLiteSessionStore(expandedSessionDB) + if err != nil { + return fmt.Errorf("failed to open session store: %w", err) + } + + adkAgent, err := newDockerAgentAdapter(t, agentName, sessStore) if err != nil { return fmt.Errorf("failed to create ADK agent adapter: %w", err) } @@ -87,7 +98,7 @@ func Run(ctx context.Context, agentFilename, agentName string, runConfig *config RunnerConfig: runner.Config{ AppName: name, Agent: adkAgent, - SessionService: session.InMemoryService(), + SessionService: adksession.InMemoryService(), }, })