Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pkg/httpclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,17 @@ func (u *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error
r2 := req.Clone(req.Context())
maps.Copy(r2.Header, u.httpOptions.Header)

// Forward the agent session ID only on gateway-bound calls. The
// gating on `X-Cagent-Forward` keeps the identifier out of direct
// provider requests and unrelated outbound HTTP made through this
// transport, even though `SessionIDFromContext` is populated for
// every call originating in the run loop.
if r2.Header.Get("X-Cagent-Forward") != "" {
if sid := SessionIDFromContext(r2.Context()); sid != "" {
r2.Header.Set("X-Cagent-Session-Id", sid)
}
}

if u.httpOptions.Query != nil {
q := r2.URL.Query()
for k, vs := range u.httpOptions.Query {
Expand Down
69 changes: 67 additions & 2 deletions pkg/httpclient/client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package httpclient

import (
"context"
"net/http"
"net/http/httptest"
"testing"
Expand Down Expand Up @@ -68,15 +69,23 @@ func TestHeaders(t *testing.T) {
// to a test server, and returns the headers the server received.
func doRequest(t *testing.T, opts ...Opt) http.Header {
t.Helper()
return doRequestWithCtx(t, t.Context(), opts...)
}

// doRequestWithCtx is like doRequest but uses the supplied context for
// the outbound request, so callers can exercise context-derived header
// injection (e.g. session ID propagation).
func doRequestWithCtx(t *testing.T, ctx context.Context, opts ...Opt) http.Header {
t.Helper()

var capturedHeaders http.Header
srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
capturedHeaders = r.Header
}))
defer srv.Close()

client := NewHTTPClient(t.Context(), opts...)
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, srv.URL, http.NoBody)
client := NewHTTPClient(ctx, opts...)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)

resp, err := client.Do(req)
Expand All @@ -85,3 +94,59 @@ func doRequest(t *testing.T, opts ...Opt) http.Header {

return capturedHeaders
}

func TestSessionIDHeader_GatewayBoundOnly(t *testing.T) {
t.Parallel()

tests := []struct {
name string
ctxSessionID string
opts []Opt
wantHeaderSent bool
}{
{
name: "session ID present, gateway-bound (X-Cagent-Forward set) → header sent",
ctxSessionID: "sess-abc",
opts: []Opt{WithProxiedBaseURL("https://gateway.example/v1")},
wantHeaderSent: true,
},
{
name: "session ID present, no X-Cagent-Forward → header skipped",
ctxSessionID: "sess-abc",
opts: nil,
wantHeaderSent: false,
},
{
name: "no session ID on context, gateway-bound → header skipped",
ctxSessionID: "",
opts: []Opt{WithProxiedBaseURL("https://gateway.example/v1")},
wantHeaderSent: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

ctx := t.Context()
if tt.ctxSessionID != "" {
ctx = ContextWithSessionID(ctx, tt.ctxSessionID)
}
headers := doRequestWithCtx(t, ctx, tt.opts...)

if tt.wantHeaderSent {
assert.Equal(t, tt.ctxSessionID, headers.Get("X-Cagent-Session-Id"))
} else {
assert.Empty(t, headers.Get("X-Cagent-Session-Id"))
}
})
}
}

func TestContextWithSessionID_RoundTrip(t *testing.T) {
t.Parallel()

assert.Empty(t, SessionIDFromContext(t.Context()), "empty context yields empty session ID")
ctx := ContextWithSessionID(t.Context(), "sess-xyz")
assert.Equal(t, "sess-xyz", SessionIDFromContext(ctx))
}
21 changes: 21 additions & 0 deletions pkg/httpclient/session.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package httpclient

import "context"

type sessionIDKey struct{}

// ContextWithSessionID returns a new context carrying the given session ID.
// When set, [userAgentTransport.RoundTrip] forwards it as the
// `X-Cagent-Session-Id` header — but only on gateway-bound requests
// (those already carrying `X-Cagent-Forward`), to keep the identifier
// out of direct provider calls and unrelated outbound HTTP.
func ContextWithSessionID(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, sessionIDKey{}, id)
}

// SessionIDFromContext returns the session ID stored on ctx by
// [ContextWithSessionID], or the empty string if none is set.
func SessionIDFromContext(ctx context.Context) string {
id, _ := ctx.Value(sessionIDKey{}).(string)
return id
}
6 changes: 6 additions & 0 deletions pkg/runtime/loop.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/docker/docker-agent/pkg/agent"
"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/compaction"
"github.com/docker/docker-agent/pkg/httpclient"
"github.com/docker/docker-agent/pkg/model/provider"
"github.com/docker/docker-agent/pkg/modelsdev"
"github.com/docker/docker-agent/pkg/runtime/toolexec"
Expand Down Expand Up @@ -166,6 +167,11 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
// goroutine so it has a real name in stack traces and is easier to navigate
// in editors.
func (r *LocalRuntime) runStreamLoop(ctx context.Context, sess *session.Session, events chan Event) {
// Seed the cagent session ID at the run-loop boundary so any
// gateway-bound HTTP call originating from this loop can correlate
// back to the originating session. Plumbing happens in
// pkg/httpclient/userAgentTransport, gated on `X-Cagent-Forward`.
ctx = httpclient.ContextWithSessionID(ctx, sess.ID)
r.telemetry.RecordSessionStart(ctx, r.CurrentAgentName(), sess.ID)

ctx, sessionSpan := r.startSpan(ctx, "runtime.session", trace.WithAttributes(
Expand Down
7 changes: 7 additions & 0 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/docker/docker-agent/pkg/config/types"
"github.com/docker/docker-agent/pkg/hooks"
"github.com/docker/docker-agent/pkg/hooks/builtins"
"github.com/docker/docker-agent/pkg/httpclient"
"github.com/docker/docker-agent/pkg/modelsdev"
"github.com/docker/docker-agent/pkg/session"
"github.com/docker/docker-agent/pkg/sessiontitle"
Expand Down Expand Up @@ -1176,6 +1177,12 @@ func (r *LocalRuntime) Summarize(ctx context.Context, sess *session.Session, add
// hooks then fire inside [LocalRuntime.doCompact] with [Input.CompactionReason]
// set to the canonical reason; they may veto or supply a custom summary.
func (r *LocalRuntime) compactWithReason(ctx context.Context, sess *session.Session, additionalPrompt, reason string, events chan Event) {
// Stamp the session ID on ctx so the compaction LLM call carries
// `X-Cagent-Session-Id` to the gateway. Manual compaction
// (via `Summarize` from the App) bypasses `runStreamLoop`'s seed;
// internal callers (proactive threshold, overflow recovery) already
// run with a stamped ctx, but re-stamping is idempotent.
ctx = httpclient.ContextWithSessionID(ctx, sess.ID)
a := r.resolveSessionAgent(sess)

source := preCompactSourceFor(reason)
Expand Down
7 changes: 7 additions & 0 deletions pkg/sessiontitle/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"time"

"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/httpclient"
"github.com/docker/docker-agent/pkg/model/provider"
"github.com/docker/docker-agent/pkg/model/provider/options"
)
Expand Down Expand Up @@ -60,6 +61,12 @@ func (g *Generator) Generate(ctx context.Context, sessionID string, userMessages
return "", nil
}

// Title generation runs outside the run loop, so the session ID
// is not yet on ctx. Stamp it here so the gateway-bound LLM calls
// below carry `X-Cagent-Session-Id` and remain attributable to
// the originating session.
ctx = httpclient.ContextWithSessionID(ctx, sessionID)

// Apply timeout to prevent hanging on slow or unresponsive models.
ctx, cancel := context.WithTimeout(ctx, titleGenerationTimeout)
defer cancel()
Expand Down
Loading