From 8cf491811f20f81cbd2adea3804ff9ff3cae5fad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Mon, 27 Oct 2025 11:29:56 +0000 Subject: [PATCH 1/2] feat: mark interceptions as completed --- api.go | 7 ++++++ bridge_integration_test.go | 49 +++++++++++++++++++++++++++++++++----- interception.go | 6 +++-- mcp/client_info.go | 2 +- mcp/client_info_test.go | 2 +- recorder.go | 31 ++++++++++++++++++++++++ 6 files changed, 87 insertions(+), 10 deletions(-) diff --git a/api.go b/api.go index 94fcba2..5c152e9 100644 --- a/api.go +++ b/api.go @@ -16,6 +16,11 @@ type InterceptionRecord struct { StartedAt time.Time } +type InterceptionRecordEnded struct { + ID string + EndedAt time.Time +} + type TokenUsageRecord struct { InterceptionID string MsgID string @@ -48,6 +53,8 @@ type ToolUsageRecord struct { type Recorder interface { // RecordInterception records metadata about an interception with an upstream AI provider. RecordInterception(ctx context.Context, req *InterceptionRecord) error + // RecordInterceptionEnded records that given interception has completed. + RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) error // RecordTokenUsage records the tokens used in an interception with an upstream AI provider. RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) error // RecordPromptUsage records the prompts used in an interception with an upstream AI provider. diff --git a/bridge_integration_test.go b/bridge_integration_test.go index b0de5f8..d0adb16 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -11,6 +11,7 @@ import ( "net" "net/http" "net/http/httptest" + "slices" "sync" "sync/atomic" "testing" @@ -174,6 +175,8 @@ func TestAnthropicMessages(t *testing.T) { require.Len(t, recorderClient.userPrompts, 1) assert.Equal(t, "read the foo file", recorderClient.userPrompts[0].Prompt) + + recorderClient.verifyAllInterceptionsEnded(t) }) } }) @@ -273,6 +276,8 @@ func TestOpenAIChatCompletions(t *testing.T) { require.Len(t, recorderClient.userPrompts, 1) assert.Equal(t, "how large is the README.md file in my current path", recorderClient.userPrompts[0].Prompt) + + recorderClient.verifyAllInterceptionsEnded(t) }) } }) @@ -437,6 +442,8 @@ func TestSimple(t *testing.T) { require.GreaterOrEqual(t, len(recorderClient.tokenUsages), 1) require.Equal(t, recorderClient.tokenUsages[0].MsgID, tc.expectedMsgID) + + recorderClient.verifyAllInterceptionsEnded(t) }) } }) @@ -574,8 +581,10 @@ func setupMCPServerProxiesForTest(t *testing.T) map[string]mcp.ServerProxier { return map[string]mcp.ServerProxier{proxy.Name(): proxy} } -type configureFunc func(string, aibridge.Recorder, *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) -type createRequestFunc func(*testing.T, string, []byte) *http.Request +type ( + configureFunc func(string, aibridge.Recorder, *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) + createRequestFunc func(*testing.T, string, []byte) *http.Request +) func TestAnthropicInjectedTools(t *testing.T) { t.Parallel() @@ -953,6 +962,7 @@ func TestErrorHandling(t *testing.T) { require.NoError(t, err) tc.responseHandlerFn(streaming, resp) + recorderClient.verifyAllInterceptionsEnded(t) }) } }) @@ -1097,10 +1107,11 @@ var _ aibridge.Recorder = &mockRecorderClient{} type mockRecorderClient struct { mu sync.Mutex - interceptions []*aibridge.InterceptionRecord - tokenUsages []*aibridge.TokenUsageRecord - userPrompts []*aibridge.PromptUsageRecord - toolUsages []*aibridge.ToolUsageRecord + interceptions []*aibridge.InterceptionRecord + tokenUsages []*aibridge.TokenUsageRecord + userPrompts []*aibridge.PromptUsageRecord + toolUsages []*aibridge.ToolUsageRecord + interceptionsEnd map[string]time.Time } func (m *mockRecorderClient) RecordInterception(ctx context.Context, req *aibridge.InterceptionRecord) error { @@ -1110,6 +1121,19 @@ func (m *mockRecorderClient) RecordInterception(ctx context.Context, req *aibrid return nil } +func (m *mockRecorderClient) RecordInterceptionEnded(ctx context.Context, req *aibridge.InterceptionRecordEnded) error { + m.mu.Lock() + defer m.mu.Unlock() + if m.interceptionsEnd == nil { + m.interceptionsEnd = make(map[string]time.Time) + } + if !slices.ContainsFunc(m.interceptions, func(intc *aibridge.InterceptionRecord) bool { return intc.ID == req.ID }) { + return fmt.Errorf("id not found") + } + m.interceptionsEnd[req.ID] = req.EndedAt + return nil +} + func (m *mockRecorderClient) RecordPromptUsage(ctx context.Context, req *aibridge.PromptUsageRecord) error { m.mu.Lock() defer m.mu.Unlock() @@ -1131,6 +1155,19 @@ func (m *mockRecorderClient) RecordToolUsage(ctx context.Context, req *aibridge. return nil } +// verify all interceptions has been marked as completed +func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) { + m.mu.Lock() + defer m.mu.Unlock() + if len(m.interceptions) == 0 { + t.Errorf("HMMM") + } + require.Equal(t, len(m.interceptions), len(m.interceptionsEnd)) + for _, intc := range m.interceptions { + require.Contains(t, m.interceptionsEnd, intc.ID) + } +} + const mockToolName = "coder_list_workspaces" func createMockMCPSrv(t *testing.T) http.Handler { diff --git a/interception.go b/interception.go index ef871b9..9422d01 100644 --- a/interception.go +++ b/interception.go @@ -1,7 +1,6 @@ package aibridge import ( - "context" "errors" "fmt" "net/http" @@ -67,10 +66,13 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder, log := logger.With(slog.F("route", r.URL.Path), slog.F("provider", p.Name()), slog.F("interception_id", interceptor.ID())) - log.Debug(context.Background(), "started interception") + log.Debug(r.Context(), "interception started") if err := interceptor.ProcessRequest(w, r); err != nil { log.Warn(r.Context(), "interception failed", slog.Error(err)) + } else { + log.Debug(r.Context(), "interception ended") } + asyncRecorder.RecordInterceptionEnded(r.Context(), &InterceptionRecordEnded{ID: interceptor.ID().String()}) // Ensure all recording have completed before completing request. asyncRecorder.Wait() diff --git a/mcp/client_info.go b/mcp/client_info.go index c573621..84b33d0 100644 --- a/mcp/client_info.go +++ b/mcp/client_info.go @@ -12,4 +12,4 @@ func GetClientInfo() mcp.Implementation { Name: "coder/aibridge", Version: buildinfo.Version(), } -} \ No newline at end of file +} diff --git a/mcp/client_info_test.go b/mcp/client_info_test.go index 730b832..d273d10 100644 --- a/mcp/client_info_test.go +++ b/mcp/client_info_test.go @@ -14,4 +14,4 @@ func TestGetClientInfo(t *testing.T) { assert.NotEmpty(t, info.Version) // Version will either be a git revision, a semantic version, or a combination assert.NotEqual(t, "", info.Version) -} \ No newline at end of file +} diff --git a/recorder.go b/recorder.go index 6d9a2e6..cf28387 100644 --- a/recorder.go +++ b/recorder.go @@ -33,6 +33,21 @@ func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *Intercept return err } +func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) error { + client, err := r.clientFn() + if err != nil { + return fmt.Errorf("acquire client: %w", err) + } + + req.EndedAt = time.Now().UTC() + if err = client.RecordInterceptionEnded(ctx, req); err == nil { + return nil + } + + r.logger.Warn(ctx, "failed to record that interception ended", slog.Error(err), slog.F("interception_id", req.ID)) + return err +} + func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) error { client, err := r.clientFn() if err != nil { @@ -103,6 +118,22 @@ func (a *AsyncRecorder) RecordInterception(ctx context.Context, req *Interceptio panic("RecordInterception must not be called asynchronously") } +func (a *AsyncRecorder) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) error { + a.wg.Add(1) + go func() { + defer a.wg.Done() + timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout) + defer cancel() + + err := a.wrapped.RecordInterceptionEnded(timedCtx, req) + if err != nil { + a.logger.Warn(timedCtx, "failed to record interception end", slog.F("type", "prompt"), slog.Error(err), slog.F("payload", req)) + } + }() + + return nil // Caller is not interested in error. +} + func (a *AsyncRecorder) RecordPromptUsage(_ context.Context, req *PromptUsageRecord) error { a.wg.Add(1) go func() { From bb3c8d67931f4557f60c2452a6d70d1f8ce412b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Banaszewski?= Date: Mon, 27 Oct 2025 11:48:14 +0000 Subject: [PATCH 2/2] review 1 --- bridge_integration_test.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index d0adb16..0f81c8b 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -1155,16 +1155,15 @@ func (m *mockRecorderClient) RecordToolUsage(ctx context.Context, req *aibridge. return nil } -// verify all interceptions has been marked as completed +// verify all recorded interceptions has been marked as completed func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) { + t.Helper() + m.mu.Lock() defer m.mu.Unlock() - if len(m.interceptions) == 0 { - t.Errorf("HMMM") - } - require.Equal(t, len(m.interceptions), len(m.interceptionsEnd)) + require.Equalf(t, len(m.interceptions), len(m.interceptionsEnd), "got %v interception ended calls, want: %v", len(m.interceptionsEnd), len(m.interceptions)) for _, intc := range m.interceptions { - require.Contains(t, m.interceptionsEnd, intc.ID) + require.Containsf(t, m.interceptionsEnd, intc.ID, "interception with id: %v has not been ended", intc.ID) } }