Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ type InterceptionRecord struct {
StartedAt time.Time
}

type InterceptionRecordEnded struct {
ID string
EndedAt time.Time
}

type TokenUsageRecord struct {
InterceptionID string
MsgID string
Expand Down Expand Up @@ -48,6 +53,8 @@ type ToolUsageRecord struct {
type Recorder interface {
// RecordInterception records metadata about an interception with an upstream AI provider.
RecordInterception(ctx context.Context, req *InterceptionRecord) error
// RecordInterceptionEnded records that given interception has completed.
RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) error
// RecordTokenUsage records the tokens used in an interception with an upstream AI provider.
RecordTokenUsage(ctx context.Context, req *TokenUsageRecord) error
// RecordPromptUsage records the prompts used in an interception with an upstream AI provider.
Expand Down
48 changes: 42 additions & 6 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"slices"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -174,6 +175,8 @@ func TestAnthropicMessages(t *testing.T) {

require.Len(t, recorderClient.userPrompts, 1)
assert.Equal(t, "read the foo file", recorderClient.userPrompts[0].Prompt)

recorderClient.verifyAllInterceptionsEnded(t)
})
}
})
Expand Down Expand Up @@ -273,6 +276,8 @@ func TestOpenAIChatCompletions(t *testing.T) {

require.Len(t, recorderClient.userPrompts, 1)
assert.Equal(t, "how large is the README.md file in my current path", recorderClient.userPrompts[0].Prompt)

recorderClient.verifyAllInterceptionsEnded(t)
})
}
})
Expand Down Expand Up @@ -437,6 +442,8 @@ func TestSimple(t *testing.T) {

require.GreaterOrEqual(t, len(recorderClient.tokenUsages), 1)
require.Equal(t, recorderClient.tokenUsages[0].MsgID, tc.expectedMsgID)

recorderClient.verifyAllInterceptionsEnded(t)
})
}
})
Expand Down Expand Up @@ -574,8 +581,10 @@ func setupMCPServerProxiesForTest(t *testing.T) map[string]mcp.ServerProxier {
return map[string]mcp.ServerProxier{proxy.Name(): proxy}
}

type configureFunc func(string, aibridge.Recorder, *mcp.ServerProxyManager) (*aibridge.RequestBridge, error)
type createRequestFunc func(*testing.T, string, []byte) *http.Request
type (
configureFunc func(string, aibridge.Recorder, *mcp.ServerProxyManager) (*aibridge.RequestBridge, error)
createRequestFunc func(*testing.T, string, []byte) *http.Request
)

func TestAnthropicInjectedTools(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -953,6 +962,7 @@ func TestErrorHandling(t *testing.T) {
require.NoError(t, err)

tc.responseHandlerFn(streaming, resp)
recorderClient.verifyAllInterceptionsEnded(t)
})
}
})
Expand Down Expand Up @@ -1097,10 +1107,11 @@ var _ aibridge.Recorder = &mockRecorderClient{}
type mockRecorderClient struct {
mu sync.Mutex

interceptions []*aibridge.InterceptionRecord
tokenUsages []*aibridge.TokenUsageRecord
userPrompts []*aibridge.PromptUsageRecord
toolUsages []*aibridge.ToolUsageRecord
interceptions []*aibridge.InterceptionRecord
tokenUsages []*aibridge.TokenUsageRecord
userPrompts []*aibridge.PromptUsageRecord
toolUsages []*aibridge.ToolUsageRecord
interceptionsEnd map[string]time.Time
}

func (m *mockRecorderClient) RecordInterception(ctx context.Context, req *aibridge.InterceptionRecord) error {
Expand All @@ -1110,6 +1121,19 @@ func (m *mockRecorderClient) RecordInterception(ctx context.Context, req *aibrid
return nil
}

func (m *mockRecorderClient) RecordInterceptionEnded(ctx context.Context, req *aibridge.InterceptionRecordEnded) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.interceptionsEnd == nil {
m.interceptionsEnd = make(map[string]time.Time)
}
if !slices.ContainsFunc(m.interceptions, func(intc *aibridge.InterceptionRecord) bool { return intc.ID == req.ID }) {
return fmt.Errorf("id not found")
}
m.interceptionsEnd[req.ID] = req.EndedAt
return nil
}

func (m *mockRecorderClient) RecordPromptUsage(ctx context.Context, req *aibridge.PromptUsageRecord) error {
m.mu.Lock()
defer m.mu.Unlock()
Expand All @@ -1131,6 +1155,18 @@ func (m *mockRecorderClient) RecordToolUsage(ctx context.Context, req *aibridge.
return nil
}

// verify all recorded interceptions has been marked as completed
func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) {
t.Helper()

m.mu.Lock()
defer m.mu.Unlock()
require.Equalf(t, len(m.interceptions), len(m.interceptionsEnd), "got %v interception ended calls, want: %v", len(m.interceptionsEnd), len(m.interceptions))
for _, intc := range m.interceptions {
require.Containsf(t, m.interceptionsEnd, intc.ID, "interception with id: %v has not been ended", intc.ID)
}
}

const mockToolName = "coder_list_workspaces"

func createMockMCPSrv(t *testing.T) http.Handler {
Expand Down
6 changes: 4 additions & 2 deletions interception.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package aibridge

import (
"context"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -67,10 +66,13 @@ func newInterceptionProcessor(p Provider, logger slog.Logger, recorder Recorder,

log := logger.With(slog.F("route", r.URL.Path), slog.F("provider", p.Name()), slog.F("interception_id", interceptor.ID()))

log.Debug(context.Background(), "started interception")
log.Debug(r.Context(), "interception started")
if err := interceptor.ProcessRequest(w, r); err != nil {
log.Warn(r.Context(), "interception failed", slog.Error(err))
} else {
log.Debug(r.Context(), "interception ended")
}
asyncRecorder.RecordInterceptionEnded(r.Context(), &InterceptionRecordEnded{ID: interceptor.ID().String()})

// Ensure all recording have completed before completing request.
asyncRecorder.Wait()
Expand Down
2 changes: 1 addition & 1 deletion mcp/client_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ func GetClientInfo() mcp.Implementation {
Name: "coder/aibridge",
Version: buildinfo.Version(),
}
}
}
2 changes: 1 addition & 1 deletion mcp/client_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ func TestGetClientInfo(t *testing.T) {
assert.NotEmpty(t, info.Version)
// Version will either be a git revision, a semantic version, or a combination
assert.NotEqual(t, "", info.Version)
}
}
31 changes: 31 additions & 0 deletions recorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@ func (r *RecorderWrapper) RecordInterception(ctx context.Context, req *Intercept
return err
}

func (r *RecorderWrapper) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) error {
client, err := r.clientFn()
if err != nil {
return fmt.Errorf("acquire client: %w", err)
}

req.EndedAt = time.Now().UTC()
if err = client.RecordInterceptionEnded(ctx, req); err == nil {
return nil
}

r.logger.Warn(ctx, "failed to record that interception ended", slog.Error(err), slog.F("interception_id", req.ID))
return err
}

func (r *RecorderWrapper) RecordPromptUsage(ctx context.Context, req *PromptUsageRecord) error {
client, err := r.clientFn()
if err != nil {
Expand Down Expand Up @@ -103,6 +118,22 @@ func (a *AsyncRecorder) RecordInterception(ctx context.Context, req *Interceptio
panic("RecordInterception must not be called asynchronously")
}

func (a *AsyncRecorder) RecordInterceptionEnded(ctx context.Context, req *InterceptionRecordEnded) error {
a.wg.Add(1)
go func() {
defer a.wg.Done()
timedCtx, cancel := context.WithTimeout(context.Background(), a.timeout)
defer cancel()

err := a.wrapped.RecordInterceptionEnded(timedCtx, req)
if err != nil {
a.logger.Warn(timedCtx, "failed to record interception end", slog.F("type", "prompt"), slog.Error(err), slog.F("payload", req))
}
}()

return nil // Caller is not interested in error.
}

func (a *AsyncRecorder) RecordPromptUsage(_ context.Context, req *PromptUsageRecord) error {
a.wg.Add(1)
go func() {
Expand Down