From 785748cacdcd9f10ff14ba9e01dce6f47c379a10 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Tue, 21 Oct 2025 21:42:31 +0200 Subject: [PATCH 1/5] chore: req logging Signed-off-by: Danny Kopping --- config.go | 2 + intercept_anthropic_messages_base.go | 4 +- intercept_anthropic_messages_blocking.go | 11 +++-- intercept_anthropic_messages_streaming.go | 11 +++-- intercept_openai_chat_base.go | 4 +- intercept_openai_chat_blocking.go | 11 +++-- intercept_openai_chat_streaming.go | 11 +++-- provider_anthropic.go | 51 ++++++++++++++++++----- provider_openai.go | 51 ++++++++++++++++++----- 9 files changed, 108 insertions(+), 48 deletions(-) diff --git a/config.go b/config.go index 4c99eb6..34e9198 100644 --- a/config.go +++ b/config.go @@ -2,6 +2,8 @@ package aibridge type ProviderConfig struct { BaseURL, Key string + // EnableUpstreamLogging enables logging of upstream API requests and responses to /tmp/$provider.log + EnableUpstreamLogging bool } type Config struct { diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 56c9744..beff307 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -14,8 +14,8 @@ type AnthropicMessagesInterceptionBase struct { id uuid.UUID req *MessageNewParamsWrapper - baseURL, key string - logger slog.Logger + cfg ProviderConfig + logger slog.Logger recorder Recorder mcpProxy mcp.ServerProxier diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index d80a25d..89436d7 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -22,12 +22,11 @@ type AnthropicMessagesBlockingInterception struct { AnthropicMessagesInterceptionBase } -func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, baseURL, key string) *AnthropicMessagesBlockingInterception { +func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg ProviderConfig) *AnthropicMessagesBlockingInterception { return &AnthropicMessagesBlockingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{ - id: id, - req: req, - baseURL: baseURL, - key: key, + id: id, + req: req, + cfg: cfg, }} } @@ -58,7 +57,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 60)} // TODO: configurable timeout - client := newAnthropicClient(i.baseURL, i.key, opts...) + client := newAnthropicClient(i.cfg, i.id.String(), opts...) messages := i.req.MessageNewParams logger := i.logger.With(slog.F("model", i.req.Model)) diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index c2ad7e0..1d28467 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -25,12 +25,11 @@ type AnthropicMessagesStreamingInterception struct { AnthropicMessagesInterceptionBase } -func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, baseURL, key string) *AnthropicMessagesStreamingInterception { +func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg ProviderConfig) *AnthropicMessagesStreamingInterception { return &AnthropicMessagesStreamingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{ - id: id, - req: req, - baseURL: baseURL, - key: key, + id: id, + req: req, + cfg: cfg, }} } @@ -95,7 +94,7 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. }() - client := newAnthropicClient(i.baseURL, i.key) + client := newAnthropicClient(i.cfg, i.id.String()) messages := i.req.MessageNewParams // Accumulate usage across the entire streaming interaction (including tool reinvocations). diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index 36b8ff0..4a1a487 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go @@ -17,8 +17,8 @@ type OpenAIChatInterceptionBase struct { id uuid.UUID req *ChatCompletionNewParamsWrapper - baseURL, key string - logger slog.Logger + cfg ProviderConfig + logger slog.Logger recorder Recorder mcpProxy mcp.ServerProxier diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index 3b1fa7e..b10de4c 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -23,12 +23,11 @@ type OpenAIBlockingChatInterception struct { OpenAIChatInterceptionBase } -func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string) *OpenAIBlockingChatInterception { +func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg ProviderConfig) *OpenAIBlockingChatInterception { return &OpenAIBlockingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{ - id: id, - req: req, - baseURL: baseURL, - key: key, + id: id, + req: req, + cfg: cfg, }} } @@ -42,7 +41,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r } ctx := r.Context() - client := newOpenAIClient(i.baseURL, i.key) + client := newOpenAIClient(i.cfg, i.id.String()) logger := i.logger.With(slog.F("model", i.req.Model)) var ( diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index 0c5f554..f3a5107 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -25,12 +25,11 @@ type OpenAIStreamingChatInterception struct { OpenAIChatInterceptionBase } -func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string) *OpenAIStreamingChatInterception { +func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg ProviderConfig) *OpenAIStreamingChatInterception { return &OpenAIStreamingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{ - id: id, - req: req, - baseURL: baseURL, - key: key, + id: id, + req: req, + cfg: cfg, }} } @@ -65,7 +64,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. - client := newOpenAIClient(i.baseURL, i.key) + client := newOpenAIClient(i.cfg, i.id.String()) logger := i.logger.With(slog.F("model", i.req.Model)) streamCtx, streamCancel := context.WithCancelCause(ctx) diff --git a/provider_anthropic.go b/provider_anthropic.go index 192b230..9391b66 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -5,7 +5,9 @@ import ( "errors" "fmt" "io" + "log" "net/http" + "net/http/httputil" "os" "github.com/anthropics/anthropic-sdk-go" @@ -19,7 +21,7 @@ var _ Provider = &AnthropicProvider{} // AnthropicProvider allows for interactions with the Anthropic API. type AnthropicProvider struct { - baseURL, key string + cfg ProviderConfig } const ( @@ -37,8 +39,7 @@ func NewAnthropicProvider(cfg ProviderConfig) *AnthropicProvider { } return &AnthropicProvider{ - baseURL: cfg.BaseURL, - key: cfg.Key, + cfg: cfg, } } @@ -74,17 +75,17 @@ func (p *AnthropicProvider) CreateInterceptor(w http.ResponseWriter, r *http.Req } if req.Stream { - return NewAnthropicMessagesStreamingInterception(id, &req, p.baseURL, p.key), nil + return NewAnthropicMessagesStreamingInterception(id, &req, p.cfg), nil } - return NewAnthropicMessagesBlockingInterception(id, &req, p.baseURL, p.key), nil + return NewAnthropicMessagesBlockingInterception(id, &req, p.cfg), nil } return nil, UnknownRoute } func (p *AnthropicProvider) BaseURL() string { - return p.baseURL + return p.cfg.BaseURL } func (p *AnthropicProvider) AuthHeader() string { @@ -96,12 +97,42 @@ func (p *AnthropicProvider) InjectAuthHeader(headers *http.Header) { headers = &http.Header{} } - headers.Set(p.AuthHeader(), p.key) + headers.Set(p.AuthHeader(), p.cfg.Key) } -func newAnthropicClient(baseURL, key string, opts ...option.RequestOption) anthropic.Client { - opts = append(opts, option.WithAPIKey(key)) - opts = append(opts, option.WithBaseURL(baseURL)) +func newAnthropicClient(cfg ProviderConfig, id string, opts ...option.RequestOption) anthropic.Client { + opts = append(opts, option.WithAPIKey(cfg.Key)) + opts = append(opts, option.WithBaseURL(cfg.BaseURL)) + + if cfg.EnableUpstreamLogging { + reqLogFile, err := os.OpenFile("/tmp/anthropic-req.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err == nil { + reqLogger := log.New(reqLogFile, "", log.LstdFlags) + + resLogFile, err := os.OpenFile("/tmp/anthropic-res.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err == nil { + resLogger := log.New(resLogFile, "", log.LstdFlags) + + opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + if reqDump, err := httputil.DumpRequest(req, true); err == nil { + reqLogger.Printf("[req] [%s] %s", id, reqDump) + } + + resp, err := next(req) + if err != nil { + resLogger.Printf("[res] [%s] Error: %v", id, err) + return resp, err + } + + if respDump, err := httputil.DumpResponse(resp, true); err == nil { + resLogger.Printf("[res] [%s] %s", id, respDump) + } + + return resp, err + })) + } + } + } return anthropic.NewClient(opts...) } diff --git a/provider_openai.go b/provider_openai.go index 5779e05..297e72f 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -4,7 +4,9 @@ import ( "encoding/json" "fmt" "io" + "log" "net/http" + "net/http/httputil" "os" "github.com/google/uuid" @@ -16,7 +18,7 @@ var _ Provider = &OpenAIProvider{} // OpenAIProvider allows for interactions with the OpenAI API. type OpenAIProvider struct { - baseURL, key string + cfg ProviderConfig } const ( @@ -35,8 +37,7 @@ func NewOpenAIProvider(cfg ProviderConfig) *OpenAIProvider { } return &OpenAIProvider{ - baseURL: cfg.BaseURL, - key: cfg.Key, + cfg: cfg, } } @@ -76,9 +77,9 @@ func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Reques } if req.Stream { - return NewOpenAIStreamingChatInterception(id, &req, p.baseURL, p.key), nil + return NewOpenAIStreamingChatInterception(id, &req, p.cfg), nil } else { - return NewOpenAIBlockingChatInterception(id, &req, p.baseURL, p.key), nil + return NewOpenAIBlockingChatInterception(id, &req, p.cfg), nil } } @@ -86,7 +87,7 @@ func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Reques } func (p *OpenAIProvider) BaseURL() string { - return p.baseURL + return p.cfg.BaseURL } func (p *OpenAIProvider) AuthHeader() string { @@ -98,13 +99,43 @@ func (p *OpenAIProvider) InjectAuthHeader(headers *http.Header) { headers = &http.Header{} } - headers.Set(p.AuthHeader(), "Bearer "+p.key) + headers.Set(p.AuthHeader(), "Bearer "+p.cfg.Key) } -func newOpenAIClient(baseURL, key string) openai.Client { +func newOpenAIClient(cfg ProviderConfig, id string) openai.Client { var opts []option.RequestOption - opts = append(opts, option.WithAPIKey(key)) - opts = append(opts, option.WithBaseURL(baseURL)) + opts = append(opts, option.WithAPIKey(cfg.Key)) + opts = append(opts, option.WithBaseURL(cfg.BaseURL)) + + if cfg.EnableUpstreamLogging { + reqLogFile, err := os.OpenFile("/tmp/openai-req.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err == nil { + reqLogger := log.New(reqLogFile, "", log.LstdFlags) + + resLogFile, err := os.OpenFile("/tmp/openai-res.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err == nil { + resLogger := log.New(resLogFile, "", log.LstdFlags) + + opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { + if reqDump, err := httputil.DumpRequest(req, true); err == nil { + reqLogger.Printf("[req] [%s] %s", id, reqDump) + } + + resp, err := next(req) + if err != nil { + resLogger.Printf("[res] [%s] Error: %v", id, err) + return resp, err + } + + if respDump, err := httputil.DumpResponse(resp, true); err == nil { + resLogger.Printf("[res] [%s] %s", id, respDump) + } + + return resp, err + })) + } + } + } return openai.NewClient(opts...) } From a8b8db23533e707d31795e668ddeca018c6bf57b Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Tue, 21 Oct 2025 22:52:11 +0200 Subject: [PATCH 2/5] chore: enable runtime request logging Signed-off-by: Danny Kopping --- config.go | 15 +++-- intercept_anthropic_messages_base.go | 2 +- intercept_anthropic_messages_blocking.go | 4 +- intercept_anthropic_messages_streaming.go | 4 +- intercept_openai_chat_base.go | 2 +- intercept_openai_chat_blocking.go | 4 +- intercept_openai_chat_streaming.go | 4 +- provider_anthropic.go | 38 ++---------- provider_openai.go | 38 ++---------- request_logger.go | 74 +++++++++++++++++++++++ 10 files changed, 107 insertions(+), 78 deletions(-) create mode 100644 request_logger.go diff --git a/config.go b/config.go index 34e9198..32be6fa 100644 --- a/config.go +++ b/config.go @@ -1,12 +1,19 @@ package aibridge +import "sync/atomic" + type ProviderConfig struct { BaseURL, Key string // EnableUpstreamLogging enables logging of upstream API requests and responses to /tmp/$provider.log - EnableUpstreamLogging bool + enableUpstreamLogging atomic.Bool +} + +// SetEnableUpstreamLogging enables or disables upstream logging at runtime. +func (c *ProviderConfig) SetEnableUpstreamLogging(enabled bool) { + c.enableUpstreamLogging.Store(enabled) } -type Config struct { - OpenAI ProviderConfig - Anthropic ProviderConfig +// EnableUpstreamLogging returns whether upstream logging is currently enabled. +func (c *ProviderConfig) EnableUpstreamLogging() bool { + return c.enableUpstreamLogging.Load() } diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index beff307..35e8642 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -14,7 +14,7 @@ type AnthropicMessagesInterceptionBase struct { id uuid.UUID req *MessageNewParamsWrapper - cfg ProviderConfig + cfg *ProviderConfig logger slog.Logger recorder Recorder diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index 89436d7..fb4907d 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -22,7 +22,7 @@ type AnthropicMessagesBlockingInterception struct { AnthropicMessagesInterceptionBase } -func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg ProviderConfig) *AnthropicMessagesBlockingInterception { +func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg *ProviderConfig) *AnthropicMessagesBlockingInterception { return &AnthropicMessagesBlockingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{ id: id, req: req, @@ -57,7 +57,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 60)} // TODO: configurable timeout - client := newAnthropicClient(i.cfg, i.id.String(), opts...) + client := newAnthropicClient(i.cfg, i.id.String(), i.Model(), opts...) messages := i.req.MessageNewParams logger := i.logger.With(slog.F("model", i.req.Model)) diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index 1d28467..5cc368a 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -25,7 +25,7 @@ type AnthropicMessagesStreamingInterception struct { AnthropicMessagesInterceptionBase } -func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg ProviderConfig) *AnthropicMessagesStreamingInterception { +func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg *ProviderConfig) *AnthropicMessagesStreamingInterception { return &AnthropicMessagesStreamingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{ id: id, req: req, @@ -94,7 +94,7 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. }() - client := newAnthropicClient(i.cfg, i.id.String()) + client := newAnthropicClient(i.cfg, i.id.String(), i.Model()) messages := i.req.MessageNewParams // Accumulate usage across the entire streaming interaction (including tool reinvocations). diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index 4a1a487..4be1c77 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go @@ -17,7 +17,7 @@ type OpenAIChatInterceptionBase struct { id uuid.UUID req *ChatCompletionNewParamsWrapper - cfg ProviderConfig + cfg *ProviderConfig logger slog.Logger recorder Recorder diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index b10de4c..bdde017 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -23,7 +23,7 @@ type OpenAIBlockingChatInterception struct { OpenAIChatInterceptionBase } -func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg ProviderConfig) *OpenAIBlockingChatInterception { +func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg *ProviderConfig) *OpenAIBlockingChatInterception { return &OpenAIBlockingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{ id: id, req: req, @@ -41,7 +41,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r } ctx := r.Context() - client := newOpenAIClient(i.cfg, i.id.String()) + client := newOpenAIClient(i.cfg, i.id.String(), i.Model()) logger := i.logger.With(slog.F("model", i.req.Model)) var ( diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index f3a5107..1146fe2 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -25,7 +25,7 @@ type OpenAIStreamingChatInterception struct { OpenAIChatInterceptionBase } -func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg ProviderConfig) *OpenAIStreamingChatInterception { +func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg *ProviderConfig) *OpenAIStreamingChatInterception { return &OpenAIStreamingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{ id: id, req: req, @@ -64,7 +64,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. - client := newOpenAIClient(i.cfg, i.id.String()) + client := newOpenAIClient(i.cfg, i.id.String(), i.Model()) logger := i.logger.With(slog.F("model", i.req.Model)) streamCtx, streamCancel := context.WithCancelCause(ctx) diff --git a/provider_anthropic.go b/provider_anthropic.go index 9391b66..b44ae59 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -5,9 +5,7 @@ import ( "errors" "fmt" "io" - "log" "net/http" - "net/http/httputil" "os" "github.com/anthropics/anthropic-sdk-go" @@ -21,7 +19,7 @@ var _ Provider = &AnthropicProvider{} // AnthropicProvider allows for interactions with the Anthropic API. type AnthropicProvider struct { - cfg ProviderConfig + cfg *ProviderConfig } const ( @@ -30,7 +28,7 @@ const ( routeMessages = "/anthropic/v1/messages" // https://docs.anthropic.com/en/api/messages ) -func NewAnthropicProvider(cfg ProviderConfig) *AnthropicProvider { +func NewAnthropicProvider(cfg *ProviderConfig) *AnthropicProvider { if cfg.BaseURL == "" { cfg.BaseURL = "https://api.anthropic.com/" } @@ -100,37 +98,13 @@ func (p *AnthropicProvider) InjectAuthHeader(headers *http.Header) { headers.Set(p.AuthHeader(), p.cfg.Key) } -func newAnthropicClient(cfg ProviderConfig, id string, opts ...option.RequestOption) anthropic.Client { +func newAnthropicClient(cfg *ProviderConfig, id, model string, opts ...option.RequestOption) anthropic.Client { opts = append(opts, option.WithAPIKey(cfg.Key)) opts = append(opts, option.WithBaseURL(cfg.BaseURL)) - if cfg.EnableUpstreamLogging { - reqLogFile, err := os.OpenFile("/tmp/anthropic-req.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) - if err == nil { - reqLogger := log.New(reqLogFile, "", log.LstdFlags) - - resLogFile, err := os.OpenFile("/tmp/anthropic-res.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) - if err == nil { - resLogger := log.New(resLogFile, "", log.LstdFlags) - - opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { - if reqDump, err := httputil.DumpRequest(req, true); err == nil { - reqLogger.Printf("[req] [%s] %s", id, reqDump) - } - - resp, err := next(req) - if err != nil { - resLogger.Printf("[res] [%s] Error: %v", id, err) - return resp, err - } - - if respDump, err := httputil.DumpResponse(resp, true); err == nil { - resLogger.Printf("[res] [%s] %s", id, respDump) - } - - return resp, err - })) - } + if cfg.EnableUpstreamLogging() { + if middleware := createLoggingMiddleware("anthropic", id, model); middleware != nil { + opts = append(opts, option.WithMiddleware(middleware)) } } diff --git a/provider_openai.go b/provider_openai.go index 297e72f..94b8d08 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -4,9 +4,7 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" - "net/http/httputil" "os" "github.com/google/uuid" @@ -18,7 +16,7 @@ var _ Provider = &OpenAIProvider{} // OpenAIProvider allows for interactions with the OpenAI API. type OpenAIProvider struct { - cfg ProviderConfig + cfg *ProviderConfig } const ( @@ -27,7 +25,7 @@ const ( routeChatCompletions = "/openai/v1/chat/completions" // https://platform.openai.com/docs/api-reference/chat ) -func NewOpenAIProvider(cfg ProviderConfig) *OpenAIProvider { +func NewOpenAIProvider(cfg *ProviderConfig) *OpenAIProvider { if cfg.BaseURL == "" { cfg.BaseURL = "https://api.openai.com/v1/" } @@ -102,38 +100,14 @@ func (p *OpenAIProvider) InjectAuthHeader(headers *http.Header) { headers.Set(p.AuthHeader(), "Bearer "+p.cfg.Key) } -func newOpenAIClient(cfg ProviderConfig, id string) openai.Client { +func newOpenAIClient(cfg *ProviderConfig, id, model string) openai.Client { var opts []option.RequestOption opts = append(opts, option.WithAPIKey(cfg.Key)) opts = append(opts, option.WithBaseURL(cfg.BaseURL)) - if cfg.EnableUpstreamLogging { - reqLogFile, err := os.OpenFile("/tmp/openai-req.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) - if err == nil { - reqLogger := log.New(reqLogFile, "", log.LstdFlags) - - resLogFile, err := os.OpenFile("/tmp/openai-res.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) - if err == nil { - resLogger := log.New(resLogFile, "", log.LstdFlags) - - opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) { - if reqDump, err := httputil.DumpRequest(req, true); err == nil { - reqLogger.Printf("[req] [%s] %s", id, reqDump) - } - - resp, err := next(req) - if err != nil { - resLogger.Printf("[res] [%s] Error: %v", id, err) - return resp, err - } - - if respDump, err := httputil.DumpResponse(resp, true); err == nil { - resLogger.Printf("[res] [%s] %s", id, respDump) - } - - return resp, err - })) - } + if cfg.EnableUpstreamLogging() { + if middleware := createLoggingMiddleware("openai", id, model); middleware != nil { + opts = append(opts, option.WithMiddleware(middleware)) } } diff --git a/request_logger.go b/request_logger.go new file mode 100644 index 0000000..16ab999 --- /dev/null +++ b/request_logger.go @@ -0,0 +1,74 @@ +package aibridge + +import ( + "log" + "net/http" + "net/http/httputil" + "os" +) + +// logUpstreamRequest logs an HTTP request with the given ID and model name. +// The prefix format is: [req] [id] [model] +func logUpstreamRequest(logger *log.Logger, id, model string, req *http.Request) { + if logger == nil { + return + } + + if reqDump, err := httputil.DumpRequest(req, true); err == nil { + logger.Printf("[req] [%s] [%s] %s", id, model, reqDump) + } +} + +// logUpstreamResponse logs an HTTP response with the given ID and model name. +// The prefix format is: [res] [id] [model] +func logUpstreamResponse(logger *log.Logger, id, model string, resp *http.Response) { + if logger == nil { + return + } + + if respDump, err := httputil.DumpResponse(resp, true); err == nil { + logger.Printf("[res] [%s] [%s] %s", id, model, respDump) + } +} + +// logUpstreamError logs an error that occurred during request/response processing. +// The prefix format is: [res] [id] [model] Error: +func logUpstreamError(logger *log.Logger, id, model string, err error) { + if logger == nil { + return + } + + logger.Printf("[res] [%s] [%s] Error: %v", id, model, err) +} + +// createLoggingMiddleware creates a middleware function that logs requests and responses. +// Returns nil if logging setup fails. +func createLoggingMiddleware(provider, id, model string) func(*http.Request, func(*http.Request) (*http.Response, error)) (*http.Response, error) { + reqLogFile, err := os.OpenFile("/tmp/"+provider+"-req.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + return nil + } + + resLogFile, err := os.OpenFile("/tmp/"+provider+"-res.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + reqLogFile.Close() + return nil + } + + reqLogger := log.New(reqLogFile, "", log.LstdFlags) + resLogger := log.New(resLogFile, "", log.LstdFlags) + + return func(req *http.Request, next func(*http.Request) (*http.Response, error)) (*http.Response, error) { + logUpstreamRequest(reqLogger, id, model, req) + + resp, err := next(req) + if err != nil { + logUpstreamError(resLogger, id, model, err) + return resp, err + } + + logUpstreamResponse(resLogger, id, model, resp) + + return resp, err + } +} From 0452dc848c991d8646ebd8a9276c3838fad8c25b Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Wed, 22 Oct 2025 14:31:25 +0200 Subject: [PATCH 3/5] chore: clean up Signed-off-by: Danny Kopping --- intercept_anthropic_messages_blocking.go | 2 +- intercept_anthropic_messages_streaming.go | 2 +- intercept_openai_chat_blocking.go | 2 +- intercept_openai_chat_streaming.go | 2 +- provider_anthropic.go | 5 ++- provider_openai.go | 5 ++- request_logger.go | 47 +++++++++++++++++++++-- 7 files changed, 53 insertions(+), 12 deletions(-) diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index fb4907d..ccfff75 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -57,7 +57,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 60)} // TODO: configurable timeout - client := newAnthropicClient(i.cfg, i.id.String(), i.Model(), opts...) + client := newAnthropicClient(i.logger, i.cfg, i.id.String(), i.Model(), opts...) messages := i.req.MessageNewParams logger := i.logger.With(slog.F("model", i.req.Model)) diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index 5cc368a..4437a44 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -94,7 +94,7 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. }() - client := newAnthropicClient(i.cfg, i.id.String(), i.Model()) + client := newAnthropicClient(i.logger, i.cfg, i.id.String(), i.Model()) messages := i.req.MessageNewParams // Accumulate usage across the entire streaming interaction (including tool reinvocations). diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index bdde017..f7321b9 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -41,7 +41,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r } ctx := r.Context() - client := newOpenAIClient(i.cfg, i.id.String(), i.Model()) + client := newOpenAIClient(i.logger, i.cfg, i.id.String(), i.Model()) logger := i.logger.With(slog.F("model", i.req.Model)) var ( diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index 1146fe2..b134930 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -64,7 +64,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. - client := newOpenAIClient(i.cfg, i.id.String(), i.Model()) + client := newOpenAIClient(i.logger, i.cfg, i.id.String(), i.Model()) logger := i.logger.With(slog.F("model", i.req.Model)) streamCtx, streamCancel := context.WithCancelCause(ctx) diff --git a/provider_anthropic.go b/provider_anthropic.go index b44ae59..9d883fc 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -8,6 +8,7 @@ import ( "net/http" "os" + "cdr.dev/slog" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/option" "github.com/anthropics/anthropic-sdk-go/shared" @@ -98,12 +99,12 @@ func (p *AnthropicProvider) InjectAuthHeader(headers *http.Header) { headers.Set(p.AuthHeader(), p.cfg.Key) } -func newAnthropicClient(cfg *ProviderConfig, id, model string, opts ...option.RequestOption) anthropic.Client { +func newAnthropicClient(logger slog.Logger, cfg *ProviderConfig, id, model string, opts ...option.RequestOption) anthropic.Client { opts = append(opts, option.WithAPIKey(cfg.Key)) opts = append(opts, option.WithBaseURL(cfg.BaseURL)) if cfg.EnableUpstreamLogging() { - if middleware := createLoggingMiddleware("anthropic", id, model); middleware != nil { + if middleware := createLoggingMiddleware(logger, "anthropic", id, model); middleware != nil { opts = append(opts, option.WithMiddleware(middleware)) } } diff --git a/provider_openai.go b/provider_openai.go index 94b8d08..ff66a7e 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -7,6 +7,7 @@ import ( "net/http" "os" + "cdr.dev/slog" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" @@ -100,13 +101,13 @@ func (p *OpenAIProvider) InjectAuthHeader(headers *http.Header) { headers.Set(p.AuthHeader(), "Bearer "+p.cfg.Key) } -func newOpenAIClient(cfg *ProviderConfig, id, model string) openai.Client { +func newOpenAIClient(logger slog.Logger, cfg *ProviderConfig, id, model string) openai.Client { var opts []option.RequestOption opts = append(opts, option.WithAPIKey(cfg.Key)) opts = append(opts, option.WithBaseURL(cfg.BaseURL)) if cfg.EnableUpstreamLogging() { - if middleware := createLoggingMiddleware("openai", id, model); middleware != nil { + if middleware := createLoggingMiddleware(logger, "openai", id, model); middleware != nil { opts = append(opts, option.WithMiddleware(middleware)) } } diff --git a/request_logger.go b/request_logger.go index 16ab999..3670a0a 100644 --- a/request_logger.go +++ b/request_logger.go @@ -1,12 +1,35 @@ package aibridge import ( + "context" + "fmt" "log" "net/http" "net/http/httputil" "os" + "path/filepath" + "strings" + + "cdr.dev/slog" ) +// sanitizeModelName makes a model name safe for use as a directory name. +// Replaces filesystem-unsafe characters with underscores. +func sanitizeModelName(model string) string { + replacer := strings.NewReplacer( + "/", "_", + "\\", "_", + ":", "_", + "*", "_", + "?", "_", + "\"", "_", + "<", "_", + ">", "_", + "|", "_", + ) + return replacer.Replace(model) +} + // logUpstreamRequest logs an HTTP request with the given ID and model name. // The prefix format is: [req] [id] [model] func logUpstreamRequest(logger *log.Logger, id, model string, req *http.Request) { @@ -42,16 +65,32 @@ func logUpstreamError(logger *log.Logger, id, model string, err error) { } // createLoggingMiddleware creates a middleware function that logs requests and responses. -// Returns nil if logging setup fails. -func createLoggingMiddleware(provider, id, model string) func(*http.Request, func(*http.Request) (*http.Response, error)) (*http.Response, error) { - reqLogFile, err := os.OpenFile("/tmp/"+provider+"-req.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) +// Logs are written to $TMPDIR/$provider/$model/$id.req.log and $TMPDIR/$provider/$model/$id.res.log +// Returns nil if logging setup fails, logging errors via the provided logger. +func createLoggingMiddleware(logger slog.Logger, provider, id, model string) func(*http.Request, func(*http.Request) (*http.Response, error)) (*http.Response, error) { + ctx := context.Background() + safeModel := sanitizeModelName(model) + logDir := filepath.Join(os.TempDir(), provider, safeModel) + + // Create the directory structure if it doesn't exist + if err := os.MkdirAll(logDir, 0755); err != nil { + logger.Warn(ctx, "failed to create log directory", slog.Error(err), slog.F("dir", logDir)) + return nil + } + + reqLogPath := filepath.Join(logDir, fmt.Sprintf("%s.req.log", id)) + resLogPath := filepath.Join(logDir, fmt.Sprintf("%s.res.log", id)) + + reqLogFile, err := os.OpenFile(reqLogPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) if err != nil { + logger.Warn(ctx, "failed to open request log file", slog.Error(err), slog.F("path", reqLogPath)) return nil } - resLogFile, err := os.OpenFile("/tmp/"+provider+"-res.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + resLogFile, err := os.OpenFile(resLogPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) if err != nil { reqLogFile.Close() + logger.Warn(ctx, "failed to open response log file", slog.Error(err), slog.F("path", resLogPath)) return nil } From a4c07afcf42be44e65b361ff780324cba82c0be3 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Wed, 22 Oct 2025 16:14:19 +0200 Subject: [PATCH 4/5] chore: refactoring, tests Signed-off-by: Danny Kopping --- bridge_integration_test.go | 4 +- config.go | 6 +- provider_anthropic.go | 2 +- provider_openai.go | 2 +- request_logger.go | 38 ++++---- request_logger_test.go | 175 +++++++++++++++++++++++++++++++++++++ 6 files changed, 207 insertions(+), 20 deletions(-) create mode 100644 request_logger_test.go diff --git a/bridge_integration_test.go b/bridge_integration_test.go index b0de5f8..a1bbb0c 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -1152,8 +1152,8 @@ func createMockMCPSrv(t *testing.T) http.Handler { return server.NewStreamableHTTPServer(s) } -func cfg(url, key string) aibridge.ProviderConfig { - return aibridge.ProviderConfig{ +func cfg(url, key string) *aibridge.ProviderConfig { + return &aibridge.ProviderConfig{ BaseURL: url, Key: key, } diff --git a/config.go b/config.go index 32be6fa..da14e61 100644 --- a/config.go +++ b/config.go @@ -4,7 +4,11 @@ import "sync/atomic" type ProviderConfig struct { BaseURL, Key string - // EnableUpstreamLogging enables logging of upstream API requests and responses to /tmp/$provider.log + // UpstreamLoggingDir specifies the base directory for upstream logging. + // If empty, os.TempDir() will be used. + // Logs are written to $UpstreamLoggingDir/$provider/$model/$id.{req,res}.log + UpstreamLoggingDir string + // enableUpstreamLogging enables logging of upstream API requests and responses. enableUpstreamLogging atomic.Bool } diff --git a/provider_anthropic.go b/provider_anthropic.go index 9d883fc..7d590b1 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -104,7 +104,7 @@ func newAnthropicClient(logger slog.Logger, cfg *ProviderConfig, id, model strin opts = append(opts, option.WithBaseURL(cfg.BaseURL)) if cfg.EnableUpstreamLogging() { - if middleware := createLoggingMiddleware(logger, "anthropic", id, model); middleware != nil { + if middleware := createLoggingMiddleware(logger, cfg, ProviderAnthropic, id, model); middleware != nil { opts = append(opts, option.WithMiddleware(middleware)) } } diff --git a/provider_openai.go b/provider_openai.go index ff66a7e..3952e7f 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -107,7 +107,7 @@ func newOpenAIClient(logger slog.Logger, cfg *ProviderConfig, id, model string) opts = append(opts, option.WithBaseURL(cfg.BaseURL)) if cfg.EnableUpstreamLogging() { - if middleware := createLoggingMiddleware(logger, "openai", id, model); middleware != nil { + if middleware := createLoggingMiddleware(logger, cfg, ProviderOpenAI, id, model); middleware != nil { opts = append(opts, option.WithMiddleware(middleware)) } } diff --git a/request_logger.go b/request_logger.go index 3670a0a..168721f 100644 --- a/request_logger.go +++ b/request_logger.go @@ -13,19 +13,20 @@ import ( "cdr.dev/slog" ) -// sanitizeModelName makes a model name safe for use as a directory name. +// SanitizeModelName makes a model name safe for use as a directory name. // Replaces filesystem-unsafe characters with underscores. -func sanitizeModelName(model string) string { +func SanitizeModelName(model string) string { + repl := "_" replacer := strings.NewReplacer( - "/", "_", - "\\", "_", - ":", "_", - "*", "_", - "?", "_", - "\"", "_", - "<", "_", - ">", "_", - "|", "_", + "/", repl, + "\\", repl, + ":", repl, + "*", repl, + "?", repl, + "\"", repl, + "<", repl, + ">", repl, + "|", repl, ) return replacer.Replace(model) } @@ -65,12 +66,19 @@ func logUpstreamError(logger *log.Logger, id, model string, err error) { } // createLoggingMiddleware creates a middleware function that logs requests and responses. -// Logs are written to $TMPDIR/$provider/$model/$id.req.log and $TMPDIR/$provider/$model/$id.res.log +// Logs are written to $baseDir/$provider/$model/$id.req.log and $baseDir/$provider/$model/$id.res.log +// where baseDir is from cfg.UpstreamLoggingDir or os.TempDir() if not specified. // Returns nil if logging setup fails, logging errors via the provided logger. -func createLoggingMiddleware(logger slog.Logger, provider, id, model string) func(*http.Request, func(*http.Request) (*http.Response, error)) (*http.Response, error) { +func createLoggingMiddleware(logger slog.Logger, cfg *ProviderConfig, provider, id, model string) func(*http.Request, func(*http.Request) (*http.Response, error)) (*http.Response, error) { ctx := context.Background() - safeModel := sanitizeModelName(model) - logDir := filepath.Join(os.TempDir(), provider, safeModel) + safeModel := SanitizeModelName(model) + + baseDir := cfg.UpstreamLoggingDir + if baseDir == "" { + baseDir = os.TempDir() + } + + logDir := filepath.Join(baseDir, provider, safeModel) // Create the directory structure if it doesn't exist if err := os.MkdirAll(logDir, 0755); err != nil { diff --git a/request_logger_test.go b/request_logger_test.go new file mode 100644 index 0000000..fa78509 --- /dev/null +++ b/request_logger_test.go @@ -0,0 +1,175 @@ +package aibridge_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/aibridge" + "github.com/coder/aibridge/mcp" + "github.com/stretchr/testify/require" + "golang.org/x/tools/txtar" +) + +func TestRequestLogging(t *testing.T) { + t.Parallel() + + testCases := []struct { + provider string + fixture []byte + route string + createProvider func(*aibridge.ProviderConfig) aibridge.Provider + }{ + { + provider: aibridge.ProviderAnthropic, + fixture: antSimple, + route: "/anthropic/v1/messages", + createProvider: func(cfg *aibridge.ProviderConfig) aibridge.Provider { + return aibridge.NewAnthropicProvider(cfg) + }, + }, + { + provider: aibridge.ProviderOpenAI, + fixture: oaiSimple, + route: "/openai/v1/chat/completions", + createProvider: func(cfg *aibridge.ProviderConfig) aibridge.Provider { + return aibridge.NewOpenAIProvider(cfg) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.provider, func(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + // Use a temp dir for this test + tmpDir := t.TempDir() + + // Parse fixture + arc := txtar.Parse(tc.fixture) + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureNonStreamingResponse) + + // Create mock server + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(files[fixtureNonStreamingResponse]) + })) + t.Cleanup(srv.Close) + + cfg := aibridge.ProviderConfig{ + BaseURL: srv.URL, + Key: apiKey, + UpstreamLoggingDir: tmpDir, + } + cfg.SetEnableUpstreamLogging(true) + + provider := tc.createProvider(&cfg) + client := &mockRecorderClient{} + mcpProxy := mcp.NewServerProxyManager(nil) + + bridge, err := aibridge.NewRequestBridge(context.Background(), []aibridge.Provider{provider}, logger, client, mcpProxy) + require.NoError(t, err) + t.Cleanup(func() { + _ = bridge.Shutdown(context.Background()) + }) + + // Make a request + req, err := http.NewRequestWithContext(t.Context(), "POST", tc.route, strings.NewReader(string(files[fixtureRequest]))) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req = req.WithContext(aibridge.AsActor(req.Context(), userID, nil)) + rec := httptest.NewRecorder() + bridge.ServeHTTP(rec, req) + require.Equal(t, 200, rec.Code) + + // Check that log files were created + // Parse the request to get the model name + var reqData map[string]any + require.NoError(t, json.Unmarshal(files[fixtureRequest], &reqData)) + model := reqData["model"].(string) + + logDir := filepath.Join(tmpDir, tc.provider, model) + entries, err := os.ReadDir(logDir) + require.NoError(t, err, "log directory should exist") + require.NotEmpty(t, entries, "log directory should contain files") + + // Should have at least one .req.log and one .res.log file + var hasReq, hasRes bool + for _, entry := range entries { + name := entry.Name() + if strings.HasSuffix(name, ".req.log") { + hasReq = true + // Verify the file has content + content, err := os.ReadFile(filepath.Join(logDir, name)) + require.NoError(t, err) + require.NotEmpty(t, content, "request log should have content") + require.Contains(t, string(content), "POST") + } else if strings.HasSuffix(name, ".res.log") { + hasRes = true + // Verify the file has content + content, err := os.ReadFile(filepath.Join(logDir, name)) + require.NoError(t, err) + require.NotEmpty(t, content, "response log should have content") + require.Contains(t, string(content), "200") + } + } + require.True(t, hasReq, "should have request log file") + require.True(t, hasRes, "should have response log file") + }) + } +} + +func TestSanitizeModelName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple model", + input: "gpt-4o", + expected: "gpt-4o", + }, + { + name: "model with slash", + input: "gpt-4o/mini", + expected: "gpt-4o_mini", + }, + { + name: "model with colon", + input: "o1:2024-12-17", + expected: "o1_2024-12-17", + }, + { + name: "model with backslash", + input: "model\\name", + expected: "model_name", + }, + { + name: "model with multiple special chars", + input: "model:name/version?", + expected: "model_name_version_", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := aibridge.SanitizeModelName(tt.input) + require.Equal(t, tt.expected, result) + }) + } +} From 7a3c4bf0a1964a85cf41e77c09c887e1d5a21afa Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Wed, 22 Oct 2025 20:12:23 +0200 Subject: [PATCH 5/5] chore: ensure all configs are concurrency-safe Signed-off-by: Danny Kopping --- bridge_integration_test.go | 55 ++++++++++++++++++++++++++++---------- config.go | 55 +++++++++++++++++++++++++++++++------- go.mod | 1 + go.sum | 2 ++ provider_anthropic.go | 26 ++++++++++-------- provider_openai.go | 26 ++++++++++-------- request_logger.go | 2 +- request_logger_test.go | 15 +++++------ 8 files changed, 127 insertions(+), 55 deletions(-) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index a1bbb0c..b8f0435 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -126,7 +126,9 @@ func TestAnthropicMessages(t *testing.T) { recorderClient := &mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(srv.URL, apiKey))}, logger, recorderClient, mcp.NewServerProxyManager(nil)) + provider, err := aibridge.NewAnthropicProvider(cfg(srv.URL, apiKey)) + require.NoError(t, err) + b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, logger, recorderClient, mcp.NewServerProxyManager(nil)) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -229,7 +231,9 @@ func TestOpenAIChatCompletions(t *testing.T) { recorderClient := &mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(srv.URL, apiKey))}, logger, recorderClient, mcp.NewServerProxyManager(nil)) + provider, err := aibridge.NewOpenAIProvider(cfg(srv.URL, apiKey)) + require.NoError(t, err) + b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, recorderClient, mcp.NewServerProxyManager(nil)) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -294,7 +298,11 @@ func TestSimple(t *testing.T) { fixture: antSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil)) + provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey)) + if err != nil { + return nil, err + } + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil)) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -332,7 +340,11 @@ func TestSimple(t *testing.T) { fixture: oaiSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil)) + provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey)) + if err != nil { + return nil, err + } + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil)) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -470,7 +482,8 @@ func TestFallthrough(t *testing.T) { fixture: antFallthrough, configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - provider := aibridge.NewAnthropicProvider(cfg(addr, apiKey)) + provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey)) + require.NoError(t, err) bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil)) require.NoError(t, err) return provider, bridge @@ -481,7 +494,8 @@ func TestFallthrough(t *testing.T) { fixture: oaiFallthrough, configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - provider := aibridge.NewOpenAIProvider(cfg(addr, apiKey)) + provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey)) + require.NoError(t, err) bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil)) require.NoError(t, err) return provider, bridge @@ -586,7 +600,11 @@ func TestAnthropicInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr) + provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey)) + if err != nil { + return nil, err + } + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr) } // Build the requirements & make the assertions which are common to all providers. @@ -667,7 +685,11 @@ func TestOpenAIInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr) + provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey)) + if err != nil { + return nil, err + } + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr) } // Build the requirements & make the assertions which are common to all providers. @@ -851,7 +873,11 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr) + provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey)) + if err != nil { + return nil, err + } + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr) }, responseHandlerFn: func(streaming bool, resp *http.Response) { if streaming { @@ -876,7 +902,11 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr) + provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey)) + if err != nil { + return nil, err + } + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr) }, responseHandlerFn: func(streaming bool, resp *http.Response) { if streaming { @@ -1153,8 +1183,5 @@ func createMockMCPSrv(t *testing.T) http.Handler { } func cfg(url, key string) *aibridge.ProviderConfig { - return &aibridge.ProviderConfig{ - BaseURL: url, - Key: key, - } + return aibridge.NewProviderConfig(url, key, "") } diff --git a/config.go b/config.go index da14e61..60c08aa 100644 --- a/config.go +++ b/config.go @@ -1,23 +1,60 @@ package aibridge -import "sync/atomic" +import "go.uber.org/atomic" type ProviderConfig struct { - BaseURL, Key string - // UpstreamLoggingDir specifies the base directory for upstream logging. - // If empty, os.TempDir() will be used. - // Logs are written to $UpstreamLoggingDir/$provider/$model/$id.{req,res}.log - UpstreamLoggingDir string - // enableUpstreamLogging enables logging of upstream API requests and responses. + baseURL, key atomic.String + upstreamLoggingDir atomic.String enableUpstreamLogging atomic.Bool } +// NewProviderConfig creates a new ProviderConfig with the given values. +func NewProviderConfig(baseURL, key, upstreamLoggingDir string) *ProviderConfig { + cfg := &ProviderConfig{} + cfg.baseURL.Store(baseURL) + cfg.key.Store(key) + cfg.upstreamLoggingDir.Store(upstreamLoggingDir) + return cfg +} + +// BaseURL returns the base URL for the provider. +func (c *ProviderConfig) BaseURL() string { + return c.baseURL.Load() +} + +// SetBaseURL sets the base URL for the provider. +func (c *ProviderConfig) SetBaseURL(baseURL string) { + c.baseURL.Store(baseURL) +} + +// Key returns the API key for the provider. +func (c *ProviderConfig) Key() string { + return c.key.Load() +} + +// SetKey sets the API key for the provider. +func (c *ProviderConfig) SetKey(key string) { + c.key.Store(key) +} + +// UpstreamLoggingDir returns the base directory for upstream logging. +// If empty, the OS's tempdir will be used. +// Logs are written to $UpstreamLoggingDir/$provider/$model/$interceptionID.{req,res}.log +func (c *ProviderConfig) UpstreamLoggingDir() string { + return c.upstreamLoggingDir.Load() +} + +// SetUpstreamLoggingDir sets the base directory for upstream logging. +func (c *ProviderConfig) SetUpstreamLoggingDir(dir string) { + c.upstreamLoggingDir.Store(dir) +} + // SetEnableUpstreamLogging enables or disables upstream logging at runtime. func (c *ProviderConfig) SetEnableUpstreamLogging(enabled bool) { c.enableUpstreamLogging.Store(enabled) } -// EnableUpstreamLogging returns whether upstream logging is currently enabled. -func (c *ProviderConfig) EnableUpstreamLogging() bool { +// IsUpstreamLoggingEnabled returns whether upstream logging is currently enabled. +func (c *ProviderConfig) IsUpstreamLoggingEnabled() bool { return c.enableUpstreamLogging.Load() } diff --git a/go.mod b/go.mod index 6c241fe..66c925f 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/stretchr/testify v1.10.0 github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/sjson v1.2.5 + go.uber.org/atomic v1.11.0 go.uber.org/goleak v1.3.0 go.uber.org/mock v0.6.0 golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b diff --git a/go.sum b/go.sum index 7815785..119563d 100644 --- a/go.sum +++ b/go.sum @@ -94,6 +94,8 @@ go.opentelemetry.io/otel/sdk v1.16.0 h1:Z1Ok1YsijYL0CSJpHt4cS3wDDh7p572grzNrBMiM go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF9QD68aP6p4= go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s= go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= diff --git a/provider_anthropic.go b/provider_anthropic.go index 7d590b1..e67eeb7 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -29,17 +29,21 @@ const ( routeMessages = "/anthropic/v1/messages" // https://docs.anthropic.com/en/api/messages ) -func NewAnthropicProvider(cfg *ProviderConfig) *AnthropicProvider { - if cfg.BaseURL == "" { - cfg.BaseURL = "https://api.anthropic.com/" +func NewAnthropicProvider(cfg *ProviderConfig) (*AnthropicProvider, error) { + if cfg == nil { + return nil, fmt.Errorf("ProviderConfig cannot be nil") } - if cfg.Key == "" { - cfg.Key = os.Getenv("ANTHROPIC_API_KEY") + + if cfg.BaseURL() == "" { + cfg.SetBaseURL("https://api.anthropic.com/") + } + if cfg.Key() == "" { + cfg.SetKey(os.Getenv("ANTHROPIC_API_KEY")) } return &AnthropicProvider{ cfg: cfg, - } + }, nil } func (p *AnthropicProvider) Name() string { @@ -84,7 +88,7 @@ func (p *AnthropicProvider) CreateInterceptor(w http.ResponseWriter, r *http.Req } func (p *AnthropicProvider) BaseURL() string { - return p.cfg.BaseURL + return p.cfg.BaseURL() } func (p *AnthropicProvider) AuthHeader() string { @@ -96,14 +100,14 @@ func (p *AnthropicProvider) InjectAuthHeader(headers *http.Header) { headers = &http.Header{} } - headers.Set(p.AuthHeader(), p.cfg.Key) + headers.Set(p.AuthHeader(), p.cfg.Key()) } func newAnthropicClient(logger slog.Logger, cfg *ProviderConfig, id, model string, opts ...option.RequestOption) anthropic.Client { - opts = append(opts, option.WithAPIKey(cfg.Key)) - opts = append(opts, option.WithBaseURL(cfg.BaseURL)) + opts = append(opts, option.WithAPIKey(cfg.Key())) + opts = append(opts, option.WithBaseURL(cfg.BaseURL())) - if cfg.EnableUpstreamLogging() { + if cfg.IsUpstreamLoggingEnabled() { if middleware := createLoggingMiddleware(logger, cfg, ProviderAnthropic, id, model); middleware != nil { opts = append(opts, option.WithMiddleware(middleware)) } diff --git a/provider_openai.go b/provider_openai.go index 3952e7f..8b645ca 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -26,18 +26,22 @@ const ( routeChatCompletions = "/openai/v1/chat/completions" // https://platform.openai.com/docs/api-reference/chat ) -func NewOpenAIProvider(cfg *ProviderConfig) *OpenAIProvider { - if cfg.BaseURL == "" { - cfg.BaseURL = "https://api.openai.com/v1/" +func NewOpenAIProvider(cfg *ProviderConfig) (*OpenAIProvider, error) { + if cfg == nil { + return nil, fmt.Errorf("ProviderConfig cannot be nil") } - if cfg.Key == "" { - cfg.Key = os.Getenv("OPENAI_API_KEY") + if cfg.BaseURL() == "" { + cfg.SetBaseURL("https://api.openai.com/v1/") + } + + if cfg.Key() == "" { + cfg.SetKey(os.Getenv("OPENAI_API_KEY")) } return &OpenAIProvider{ cfg: cfg, - } + }, nil } func (p *OpenAIProvider) Name() string { @@ -86,7 +90,7 @@ func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Reques } func (p *OpenAIProvider) BaseURL() string { - return p.cfg.BaseURL + return p.cfg.BaseURL() } func (p *OpenAIProvider) AuthHeader() string { @@ -98,15 +102,15 @@ func (p *OpenAIProvider) InjectAuthHeader(headers *http.Header) { headers = &http.Header{} } - headers.Set(p.AuthHeader(), "Bearer "+p.cfg.Key) + headers.Set(p.AuthHeader(), "Bearer "+p.cfg.Key()) } func newOpenAIClient(logger slog.Logger, cfg *ProviderConfig, id, model string) openai.Client { var opts []option.RequestOption - opts = append(opts, option.WithAPIKey(cfg.Key)) - opts = append(opts, option.WithBaseURL(cfg.BaseURL)) + opts = append(opts, option.WithAPIKey(cfg.Key())) + opts = append(opts, option.WithBaseURL(cfg.BaseURL())) - if cfg.EnableUpstreamLogging() { + if cfg.IsUpstreamLoggingEnabled() { if middleware := createLoggingMiddleware(logger, cfg, ProviderOpenAI, id, model); middleware != nil { opts = append(opts, option.WithMiddleware(middleware)) } diff --git a/request_logger.go b/request_logger.go index 168721f..0a71c4a 100644 --- a/request_logger.go +++ b/request_logger.go @@ -73,7 +73,7 @@ func createLoggingMiddleware(logger slog.Logger, cfg *ProviderConfig, provider, ctx := context.Background() safeModel := SanitizeModelName(model) - baseDir := cfg.UpstreamLoggingDir + baseDir := cfg.UpstreamLoggingDir() if baseDir == "" { baseDir = os.TempDir() } diff --git a/request_logger_test.go b/request_logger_test.go index fa78509..d451435 100644 --- a/request_logger_test.go +++ b/request_logger_test.go @@ -25,13 +25,13 @@ func TestRequestLogging(t *testing.T) { provider string fixture []byte route string - createProvider func(*aibridge.ProviderConfig) aibridge.Provider + createProvider func(*aibridge.ProviderConfig) (aibridge.Provider, error) }{ { provider: aibridge.ProviderAnthropic, fixture: antSimple, route: "/anthropic/v1/messages", - createProvider: func(cfg *aibridge.ProviderConfig) aibridge.Provider { + createProvider: func(cfg *aibridge.ProviderConfig) (aibridge.Provider, error) { return aibridge.NewAnthropicProvider(cfg) }, }, @@ -39,7 +39,7 @@ func TestRequestLogging(t *testing.T) { provider: aibridge.ProviderOpenAI, fixture: oaiSimple, route: "/openai/v1/chat/completions", - createProvider: func(cfg *aibridge.ProviderConfig) aibridge.Provider { + createProvider: func(cfg *aibridge.ProviderConfig) (aibridge.Provider, error) { return aibridge.NewOpenAIProvider(cfg) }, }, @@ -68,14 +68,11 @@ func TestRequestLogging(t *testing.T) { })) t.Cleanup(srv.Close) - cfg := aibridge.ProviderConfig{ - BaseURL: srv.URL, - Key: apiKey, - UpstreamLoggingDir: tmpDir, - } + cfg := aibridge.NewProviderConfig(srv.URL, apiKey, tmpDir) cfg.SetEnableUpstreamLogging(true) - provider := tc.createProvider(&cfg) + provider, err := tc.createProvider(cfg) + require.NoError(t, err) client := &mockRecorderClient{} mcpProxy := mcp.NewServerProxyManager(nil)