From b6ab1af817dbcb081463f2ac46531cebe64a6cde Mon Sep 17 00:00:00 2001 From: Andy Bonventre <365204+andybons@users.noreply.github.com> Date: Wed, 22 Apr 2026 19:13:56 +0000 Subject: [PATCH 1/8] test(proxy): add interception path coverage for ReverseProxy refactor Add 15 tests exercising handleConnectWithInterception through the full CONNECT+TLS interception path. These form the behavioral baseline for the upcoming ReverseProxy refactor that adds WebSocket support. Coverage: credential injection, canonical log fields, multi-request keepalive, network policy denial, transport errors, extra/remove headers, request body forwarding, large responses, status codes, response headers, X-Request-Id injection, Proxy-Authorization stripping, and HTTP methods. --- ...026-04-22-websocket-reverseproxy-design.md | 134 +++++ proxy/intercept_test.go | 472 ++++++++++++++++++ 2 files changed, 606 insertions(+) create mode 100644 docs/2026-04-22-websocket-reverseproxy-design.md create mode 100644 proxy/intercept_test.go diff --git a/docs/2026-04-22-websocket-reverseproxy-design.md b/docs/2026-04-22-websocket-reverseproxy-design.md new file mode 100644 index 0000000..5917bb4 --- /dev/null +++ b/docs/2026-04-22-websocket-reverseproxy-design.md @@ -0,0 +1,134 @@ +# WebSocket Support via ReverseProxy Refactor + +**Date:** 2026-04-22 +**Status:** Approved +**Scope:** `proxy/proxy.go` — `handleConnectWithInterception` + +## Problem + +Gatekeeper's TLS interception path manually reads HTTP requests in a loop (`http.ReadRequest` → `transport.RoundTrip` → `resp.Write`). After a WebSocket upgrade (HTTP 101 Switching Protocols), the client sends binary WebSocket frames which `http.ReadRequest` cannot parse, causing `"malformed HTTP request"` errors and connection drops. + +## Solution + +Replace the manual request loop in `handleConnectWithInterception` with an `http.Server` serving on the client-side TLS connection, using `httputil.ReverseProxy` as the handler. Go 1.25's `ReverseProxy` natively handles WebSocket upgrades — it detects `Upgrade` headers, preserves them through hop-by-hop removal, hijacks both sides on a `101` response, and does bidirectional `io.Copy`. + +## Architecture + +``` +Client ←TLS→ http.Server(tlsClientConn) → ReverseProxy → upstream +``` + +### Flow + +1. CONNECT arrives, proxy hijacks, sends `200 Connection Established` (unchanged) +2. TLS handshake with client using generated cert (unchanged) +3. **New:** Create a single-connection `http.Server` with `httputil.ReverseProxy` as handler +4. `http.Server.Serve()` manages the request loop (replaces manual `for` + `http.ReadRequest`) +5. For normal HTTP: `ReverseProxy` forwards via `Transport.RoundTrip`, credential injection in `Rewrite` +6. For WebSocket: `ReverseProxy` detects `101`, hijacks, bidirectional copy — no custom code needed + +### Feature Mapping + +Every feature in the current manual loop maps to a `ReverseProxy` hook: + +| Feature | Current location | New location | +|---|---|---| +| Network policy check | Loop body | Wrapping handler (before ReverseProxy) | +| Keep HTTP policy | Loop body | Wrapping handler (before ReverseProxy) | +| Credential injection (`injectCredentials`) | Loop body | `Rewrite` on `ProxyRequest.Out` | +| MCP credential injection | Loop body | `Rewrite` | +| Extra headers / remove headers | Loop body | `Rewrite` | +| Token substitution | Loop body | `Rewrite` | +| Request ID generation | Loop body | `Rewrite` | +| Host gateway IP rewrite | Loop body, modifies dial target | `Rewrite` (rewrite URL host) or custom `Transport.DialContext` | +| Proxy-Authorization stripping | Loop body | `Rewrite` (read from `ProxyRequest.In` before hop-by-hop removal) | +| Credential resolver (token-exchange) | Loop body | `Rewrite` (read subject from `In.Header`, resolve, set on `Out`) | +| LLM gateway policy | Loop body, post-response | `ModifyResponse` | +| Response transformers | Loop body, post-response | `ModifyResponse` | +| Body capture for logging | Loop body | `ModifyResponse` (response) and `Rewrite` (request) | +| Canonical log line | Loop body | `ModifyResponse` + `ErrorHandler` | +| OTel span/metrics | Loop body via callbacks | Wrapping handler or `ModifyResponse` | +| Transport error → 502 | Loop body | `ErrorHandler` | +| WebSocket upgrade | **Not supported** | Built-in `ReverseProxy.handleUpgradeResponse` | + +### Key Design Decisions + +**Proxy-Authorization before hop-by-hop removal:** `ReverseProxy` strips hop-by-hop headers (including `Proxy-Authorization`) before calling `Rewrite`. For `subject_from: proxy-auth` token exchange, the subject identity must be extracted from `ProxyRequest.In` (which preserves original headers) rather than `ProxyRequest.Out`. + +**Single-connection http.Server:** The `http.Server` serves on a `net.Listener` wrapping the single TLS connection. When the connection closes, `Serve` returns. This replaces the manual `for` loop and gets HTTP keepalive, pipelining, and protocol upgrade handling from the stdlib. + +**Per-connection transport:** The `http.Transport` is created per-CONNECT connection (same as today). `ForceAttemptHTTP2` remains disabled — the intercepted connection reads HTTP/1.1. + +**No behavioral changes:** All external APIs (`Proxy`, `RunContextData`, config) remain identical. This is purely an internal refactor of one function. + +### Handling Policy Denials in Rewrite + +The current loop writes error responses (407, 403, 502) directly to the TLS connection and continues the loop. With `ReverseProxy`, the `Rewrite` function cannot write responses directly. Two options: + +**Option A — Wrapping handler:** A handler that runs policy checks before delegating to `ReverseProxy`. On denial, it writes the error response itself and does not call `ReverseProxy.ServeHTTP`. This is the cleanest approach. + +**Option B — Rewrite sets a sentinel, ErrorHandler acts on it.** `Rewrite` stores a denial in the request context, `ModifyResponse` or a custom `RoundTripper` wrapper checks for it. More complex, less readable. + +**Decision:** Option A. The wrapping handler pattern is idiomatic and keeps policy logic separate from forwarding logic. + +```go +func (p *Proxy) interceptHandler(host string, rc *RunContextData, transport *http.Transport) http.Handler { + rp := &httputil.ReverseProxy{ + Rewrite: p.rewriteIntercepted(host, rc), + Transport: transport, + ModifyResponse: p.modifyInterceptedResponse(host, rc), + ErrorHandler: p.interceptErrorHandler(host, rc), + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Network policy, Keep HTTP policy checks here + // On denial: write error response, return + // On allow: rp.ServeHTTP(w, r) + }) +} +``` + +## Testing Strategy + +Tests are written first (TDD) against the current code to establish behavioral baselines, then the refactor must keep them passing. + +### New tests to add before refactor + +1. **Normal HTTPS through interception** — credential injection verified on upstream request +2. **WebSocket upgrade through interception** — upgrade succeeds, bidirectional frame exchange works (will fail against current code, pass after refactor) +3. **Multi-request keepalive** — multiple requests over single CONNECT tunnel +4. **Network policy denial on inner request** — 407 returned, connection stays alive +5. **Transport error** — unreachable upstream, 502 returned, canonical log line emitted +6. **Credential resolver via CONNECT** — token-exchange with `subject_from: proxy-auth` +7. **Host gateway through interception** — gateway hostname rewritten to actual IP + +### Existing tests that must keep passing + +All tests in `proxy/proxy_test.go`, particularly: +- `TestProxy_CanonicalLogLine_ConnectTransportError` +- `TestProxy_CanonicalLogLine_ConnectBlocked` +- All credential injection, policy, and logging tests + +## Implementation Plan + +### Phase 1: Test baseline (TDD) +Write the new tests listed above against the current code. All should pass except the WebSocket test. + +### Phase 2: Extract helpers +Extract the inline policy/credential/logging logic from the current loop into named methods that can be called from both the old loop and the new handler. This is a refactor-only step — no behavioral changes. + +### Phase 3: Build the ReverseProxy handler +Implement `interceptHandler` with `Rewrite`, `ModifyResponse`, `ErrorHandler`, and the wrapping handler for policy checks. Wire it into `handleConnectWithInterception` replacing the manual loop. + +### Phase 4: WebSocket test passes +The WebSocket upgrade test should now pass with zero additional code. + +### Phase 5: Verify and clean up +Run full test suite, remove dead code from the old loop, verify OTel instrumentation. + +## Out of Scope + +- Changing the non-interception tunnel path (`handleConnectTunnel`) +- Changing the HTTP relay path (`handleHTTP`) +- Changing the MCP relay handler +- Config schema changes +- New config options for WebSocket-specific behavior diff --git a/proxy/intercept_test.go b/proxy/intercept_test.go new file mode 100644 index 0000000..213c269 --- /dev/null +++ b/proxy/intercept_test.go @@ -0,0 +1,472 @@ +package proxy + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" +) + +// interceptTestSetup creates a proxy with TLS interception enabled and an HTTPS +// backend server. The proxy is configured to trust the backend's TLS cert and +// the returned client trusts the proxy's interception CA. +type interceptTestSetup struct { + Proxy *Proxy + ProxyServer *httptest.Server + Backend *httptest.Server + Client *http.Client + CA *CA + BackendHost string // hostname only (e.g., 127.0.0.1) — for credential matching + BackendHostPort string // host:port (e.g., 127.0.0.1:12345) — for extra/remove header matching +} + +func newInterceptTestSetup(t *testing.T, backendHandler http.Handler) *interceptTestSetup { + t.Helper() + + ca, err := generateCA() + if err != nil { + t.Fatal(err) + } + + backend := httptest.NewTLSServer(backendHandler) + + // Build a CA pool that trusts the backend's TLS cert. + upstreamCAs := x509.NewCertPool() + upstreamCAs.AddCert(backend.Certificate()) + + p := NewProxy() + p.SetCA(ca) + p.SetUpstreamCAs(upstreamCAs) + + proxyServer := httptest.NewServer(p) + + // Client trusts the interception CA and routes through the proxy. + clientCAs := x509.NewCertPool() + clientCAs.AppendCertsFromPEM(ca.certPEM) + + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(mustParseURL(proxyServer.URL)), + TLSClientConfig: &tls.Config{RootCAs: clientCAs}, + }, + } + + backendHost := mustParseURL(backend.URL).Host // host:port for extra header matching (uses r.Host) + backendHostname := mustParseURL(backend.URL).Hostname() // hostname only for credential matching + + t.Cleanup(func() { + proxyServer.Close() + backend.Close() + }) + + return &interceptTestSetup{ + Proxy: p, + ProxyServer: proxyServer, + Backend: backend, + Client: client, + CA: ca, + BackendHost: backendHostname, + BackendHostPort: backendHost, + } +} + +func TestIntercept_CredentialInjection(t *testing.T) { + var receivedAuth string + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + w.Write([]byte("ok")) + })) + + setup.Proxy.SetCredentialWithGrant(setup.BackendHost, "Authorization", "Bearer test-token-123", "test-grant") + + resp, err := setup.Client.Get(setup.Backend.URL + "/api/data") + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + io.ReadAll(resp.Body) + + if resp.StatusCode != 200 { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if receivedAuth != "Bearer test-token-123" { + t.Errorf("Authorization = %q, want %q", receivedAuth, "Bearer test-token-123") + } +} + +func TestIntercept_CredentialInjectionCanonicalLog(t *testing.T) { + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + })) + + setup.Proxy.SetCredentialWithGrant(setup.BackendHost, "Authorization", "Bearer granted-token", "my-grant") + + var logged RequestLogData + setup.Proxy.SetLogger(func(data RequestLogData) { + logged = data + }) + + resp, err := setup.Client.Get(setup.Backend.URL + "/resource") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if !logged.AuthInjected { + t.Error("expected AuthInjected=true") + } + if len(logged.Grants) == 0 || logged.Grants[0] != "my-grant" { + t.Errorf("Grants = %v, want [my-grant]", logged.Grants) + } + if logged.RequestType != "connect" { + t.Errorf("RequestType = %q, want connect", logged.RequestType) + } + if logged.RequestID == "" { + t.Error("expected non-empty RequestID") + } +} + +func TestIntercept_MultiRequestKeepalive(t *testing.T) { + var requestCount int + var mu sync.Mutex + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + requestCount++ + mu.Unlock() + w.Write([]byte("ok")) + })) + + for i := 0; i < 5; i++ { + resp, err := setup.Client.Get(setup.Backend.URL + "/ping") + if err != nil { + t.Fatalf("request %d: %v", i, err) + } + resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("request %d: status = %d, want 200", i, resp.StatusCode) + } + } + + mu.Lock() + defer mu.Unlock() + if requestCount != 5 { + t.Errorf("requestCount = %d, want 5", requestCount) + } +} + +func TestIntercept_NetworkPolicyDenial(t *testing.T) { + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("backend should not be reached on denied request") + })) + + // Strict policy with no allows — denies everything at the inner request level. + setup.Proxy.SetNetworkPolicy("strict", nil, nil) + + var logged RequestLogData + setup.Proxy.SetLogger(func(data RequestLogData) { + logged = data + }) + + // The CONNECT itself will be denied before TLS interception. + resp, err := setup.Client.Get(setup.Backend.URL + "/blocked") + if err == nil { + resp.Body.Close() + // Under strict policy with no allows, CONNECT is denied with 407. + if resp.StatusCode != http.StatusProxyAuthRequired { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusProxyAuthRequired) + } + } + // The client may get a transport error if CONNECT is blocked. + // Either way, the request should be denied. + if !logged.Denied { + t.Error("expected Denied=true in log") + } +} + +func TestIntercept_TransportError502(t *testing.T) { + ca, err := generateCA() + if err != nil { + t.Fatal(err) + } + + p := NewProxy() + p.SetCA(ca) + + var logged RequestLogData + p.SetLogger(func(data RequestLogData) { + logged = data + }) + + proxyServer := httptest.NewServer(p) + defer proxyServer.Close() + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(ca.certPEM) + + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(mustParseURL(proxyServer.URL)), + TLSClientConfig: &tls.Config{RootCAs: caCertPool}, + }, + } + + // Connect to a port nothing listens on. + resp, err := client.Get("https://127.0.0.1:1/nope") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusBadGateway { + t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusBadGateway) + } + if logged.Err == nil { + t.Error("expected error in canonical log") + } + if logged.RequestType != "connect" { + t.Errorf("RequestType = %q, want connect", logged.RequestType) + } +} + +func TestIntercept_CanonicalLogFields(t *testing.T) { + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte("hello")) + })) + + setup.Proxy.SetCredentialWithGrant(setup.BackendHost, "Authorization", "Bearer tok", "test-grant") + + var logged RequestLogData + setup.Proxy.SetLogger(func(data RequestLogData) { + logged = data + }) + + resp, err := setup.Client.Get(setup.Backend.URL + "/some/path") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if logged.Method != "GET" { + t.Errorf("Method = %q, want GET", logged.Method) + } + backendHostname := mustParseURL(setup.Backend.URL).Hostname() + if logged.Host != backendHostname { + t.Errorf("Host = %q, want %q", logged.Host, backendHostname) + } + if logged.Path != "/some/path" { + t.Errorf("Path = %q, want /some/path", logged.Path) + } + if logged.StatusCode != 200 { + t.Errorf("StatusCode = %d, want 200", logged.StatusCode) + } + if logged.RequestType != "connect" { + t.Errorf("RequestType = %q, want connect", logged.RequestType) + } + if !logged.AuthInjected { + t.Error("expected AuthInjected=true") + } +} + +func TestIntercept_ExtraHeaders(t *testing.T) { + var receivedHeaders http.Header + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.Write([]byte("ok")) + })) + + setup.Proxy.AddExtraHeader(setup.BackendHost, "X-Custom-Header", "custom-value") + + resp, err := setup.Client.Get(setup.Backend.URL + "/test") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != 200 { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if receivedHeaders.Get("X-Custom-Header") != "custom-value" { + t.Errorf("X-Custom-Header = %q, want custom-value", receivedHeaders.Get("X-Custom-Header")) + } +} + +func TestIntercept_RemoveHeaders(t *testing.T) { + var receivedAPIKey string + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAPIKey = r.Header.Get("X-Api-Key") + w.Write([]byte("ok")) + })) + + setup.Proxy.RemoveRequestHeader(setup.BackendHost, "X-Api-Key") + + req, _ := http.NewRequest("GET", setup.Backend.URL+"/test", nil) + req.Header.Set("X-Api-Key", "stale-key") + resp, err := setup.Client.Do(req) + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if receivedAPIKey != "" { + t.Errorf("X-Api-Key should be removed, got %q", receivedAPIKey) + } +} + +func TestIntercept_RequestBodyForwarded(t *testing.T) { + var receivedBody string + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + receivedBody = string(body) + w.Write([]byte("ok")) + })) + + reqBody := `{"key": "value"}` + resp, err := setup.Client.Post(setup.Backend.URL+"/submit", "application/json", strings.NewReader(reqBody)) + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != 200 { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + if receivedBody != reqBody { + t.Errorf("body = %q, want %q", receivedBody, reqBody) + } +} + +func TestIntercept_LargeResponseBody(t *testing.T) { + // 1MB response body to verify streaming works. + largeBody := bytes.Repeat([]byte("x"), 1<<20) + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write(largeBody) + })) + + resp, err := setup.Client.Get(setup.Backend.URL + "/large") + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body: %v", err) + } + if len(body) != len(largeBody) { + t.Errorf("body length = %d, want %d", len(body), len(largeBody)) + } +} + +func TestIntercept_ResponseStatusCodes(t *testing.T) { + codes := []int{200, 201, 204, 301, 400, 404, 500} + + for _, code := range codes { + t.Run(http.StatusText(code), func(t *testing.T) { + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(code) + })) + + resp, err := setup.Client.Get(setup.Backend.URL + "/status") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != code { + t.Errorf("status = %d, want %d", resp.StatusCode, code) + } + }) + } +} + +func TestIntercept_ResponseHeaders(t *testing.T) { + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Backend-Header", "backend-value") + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{}`)) + })) + + resp, err := setup.Client.Get(setup.Backend.URL + "/headers") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if resp.Header.Get("X-Backend-Header") != "backend-value" { + t.Errorf("X-Backend-Header = %q, want backend-value", resp.Header.Get("X-Backend-Header")) + } + if resp.Header.Get("Content-Type") != "application/json" { + t.Errorf("Content-Type = %q, want application/json", resp.Header.Get("Content-Type")) + } +} + +func TestIntercept_XRequestIdInjected(t *testing.T) { + var receivedRequestID string + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedRequestID = r.Header.Get("X-Request-Id") + w.Write([]byte("ok")) + })) + + resp, err := setup.Client.Get(setup.Backend.URL + "/rid") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if receivedRequestID == "" { + t.Error("expected X-Request-Id to be injected") + } + if !strings.HasPrefix(receivedRequestID, "req_") { + t.Errorf("X-Request-Id = %q, expected req_ prefix", receivedRequestID) + } +} + +func TestIntercept_ProxyAuthorizationStripped(t *testing.T) { + var receivedProxyAuth string + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedProxyAuth = r.Header.Get("Proxy-Authorization") + w.Write([]byte("ok")) + })) + + resp, err := setup.Client.Get(setup.Backend.URL + "/strip") + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + // Proxy-Authorization should be stripped before forwarding upstream. + if receivedProxyAuth != "" { + t.Errorf("Proxy-Authorization should be stripped, got %q", receivedProxyAuth) + } +} + +func TestIntercept_HTTPMethods(t *testing.T) { + methods := []string{"GET", "POST", "PUT", "DELETE", "PATCH"} + + for _, method := range methods { + t.Run(method, func(t *testing.T) { + var receivedMethod string + setup := newInterceptTestSetup(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedMethod = r.Method + w.Write([]byte("ok")) + })) + + req, _ := http.NewRequest(method, setup.Backend.URL+"/method", nil) + resp, err := setup.Client.Do(req) + if err != nil { + t.Fatalf("request: %v", err) + } + resp.Body.Close() + + if receivedMethod != method { + t.Errorf("method = %q, want %q", receivedMethod, method) + } + }) + } +} From ca5d29e0e327b2f46f6503139d21ede758b83af0 Mon Sep 17 00:00:00 2001 From: Andy Bonventre <365204+andybons@users.noreply.github.com> Date: Wed, 22 Apr 2026 19:20:29 +0000 Subject: [PATCH 2/8] test(proxy): add WebSocket upgrade test (expected to fail) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The test sends a WebSocket upgrade through the CONNECT+TLS interception path and verifies bidirectional byte exchange after the 101 response. Currently hangs because the manual request loop cannot handle protocol upgrades — resp.Write blocks on the 101 response body (the persistent WebSocket connection). --- proxy/intercept_test.go | 134 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) diff --git a/proxy/intercept_test.go b/proxy/intercept_test.go index 213c269..a3757bc 100644 --- a/proxy/intercept_test.go +++ b/proxy/intercept_test.go @@ -1,10 +1,13 @@ package proxy import ( + "bufio" "bytes" "crypto/tls" "crypto/x509" + "fmt" "io" + "net" "net/http" "net/http/httptest" "strings" @@ -470,3 +473,134 @@ func TestIntercept_HTTPMethods(t *testing.T) { }) } } + +func TestIntercept_WebSocketUpgrade(t *testing.T) { + // Backend that accepts WebSocket upgrades and echoes raw bytes. + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") != "websocket" { + http.Error(w, "expected websocket upgrade", 400) + return + } + + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Connection", "Upgrade") + w.WriteHeader(http.StatusSwitchingProtocols) + + hijacker, ok := w.(http.Hijacker) + if !ok { + return + } + conn, brw, err := hijacker.Hijack() + if err != nil { + return + } + defer conn.Close() + brw.Flush() + + // Echo: read up to 1024 bytes, write them back. + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + return + } + conn.Write(buf[:n]) + })) + defer backend.Close() + + ca, err := generateCA() + if err != nil { + t.Fatal(err) + } + + upstreamCAs := x509.NewCertPool() + upstreamCAs.AddCert(backend.Certificate()) + + p := NewProxy() + p.SetCA(ca) + p.SetUpstreamCAs(upstreamCAs) + + backendHost := mustParseURL(backend.URL).Hostname() + p.SetCredential(backendHost, "Bearer ws-token") + + var logged RequestLogData + p.SetLogger(func(data RequestLogData) { + logged = data + }) + + proxyServer := httptest.NewServer(p) + defer proxyServer.Close() + + // Dial through the proxy using raw CONNECT. + proxyURL := mustParseURL(proxyServer.URL) + proxyConn, err := net.Dial("tcp", proxyURL.Host) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer proxyConn.Close() + + // Send CONNECT. + backendAddr := mustParseURL(backend.URL).Host + fmt.Fprintf(proxyConn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", backendAddr, backendAddr) + br := bufio.NewReader(proxyConn) + connectResp, err := http.ReadResponse(br, nil) + if err != nil { + t.Fatalf("read CONNECT response: %v", err) + } + if connectResp.StatusCode != 200 { + t.Fatalf("CONNECT status = %d, want 200", connectResp.StatusCode) + } + + // TLS handshake with the proxy's interception cert. + clientCAs := x509.NewCertPool() + clientCAs.AppendCertsFromPEM(ca.certPEM) + tlsConn := tls.Client(proxyConn, &tls.Config{ + RootCAs: clientCAs, + ServerName: backendHost, + }) + if err := tlsConn.Handshake(); err != nil { + t.Fatalf("TLS handshake: %v", err) + } + defer tlsConn.Close() + + // Send WebSocket upgrade request. + upgradeReq := "GET /ws HTTP/1.1\r\n" + + "Host: " + backendAddr + "\r\n" + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + + "Sec-WebSocket-Version: 13\r\n" + + "\r\n" + if _, err := tlsConn.Write([]byte(upgradeReq)); err != nil { + t.Fatalf("write upgrade request: %v", err) + } + + // Read the 101 response. + tlsBr := bufio.NewReader(tlsConn) + upgradeResp, err := http.ReadResponse(tlsBr, nil) + if err != nil { + t.Fatalf("read upgrade response: %v", err) + } + if upgradeResp.StatusCode != http.StatusSwitchingProtocols { + t.Fatalf("upgrade status = %d, want 101", upgradeResp.StatusCode) + } + + // Send a raw message through the upgraded connection. + testMsg := []byte("hello websocket") + if _, err := tlsConn.Write(testMsg); err != nil { + t.Fatalf("write message: %v", err) + } + + // Read echoed message back. + echoBuf := make([]byte, len(testMsg)) + if _, err := io.ReadFull(tlsBr, echoBuf); err != nil { + t.Fatalf("read echo: %v", err) + } + if string(echoBuf) != string(testMsg) { + t.Errorf("echo = %q, want %q", echoBuf, testMsg) + } + + // Verify credential was injected on the upgrade request. + if !logged.AuthInjected { + t.Error("expected credential injection on upgrade request") + } +} From 940057ff4fbe433f07d3627bec9468918f75a334 Mon Sep 17 00:00:00 2001 From: Andy Bonventre <365204+andybons@users.noreply.github.com> Date: Wed, 22 Apr 2026 19:28:06 +0000 Subject: [PATCH 3/8] feat(proxy): replace interception loop with ReverseProxy for WebSocket support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the manual for { http.ReadRequest → transport.RoundTrip → resp.Write } loop in handleConnectWithInterception with http.Server + httputil.ReverseProxy. ReverseProxy natively handles WebSocket upgrades (101 Switching Protocols) by hijacking both sides and doing bidirectional io.Copy via its built-in switchProtocolCopier. All existing behaviors preserved: - Credential injection (Rewrite hook) - Network policy and Keep policy (wrapping handler) - LLM gateway policy (ModifyResponse via evaluateAndReplaceLLMResponse) - Response transformers (ModifyResponse) - Canonical log lines (ModifyResponse + ErrorHandler) - X-Request-Id injection (Rewrite) - Extra/remove headers, token substitution (Rewrite) - Host gateway rewrite (Rewrite) - Proxy-Authorization preserved from In request for token-exchange Also addresses PR #21 review feedback: - Fix unreachable 407 assertion in NetworkPolicyDenial test - Remove unused BackendHostPort field from test setup - Handle http.NewRequest errors in tests --- .../2026-04-22-websocket-reverseproxy.md | 627 ++++++++++++++++++ proxy/intercept_test.go | 53 +- proxy/proxy.go | 601 ++++++++--------- 3 files changed, 950 insertions(+), 331 deletions(-) create mode 100644 docs/plans/2026-04-22-websocket-reverseproxy.md diff --git a/docs/plans/2026-04-22-websocket-reverseproxy.md b/docs/plans/2026-04-22-websocket-reverseproxy.md new file mode 100644 index 0000000..51e66bb --- /dev/null +++ b/docs/plans/2026-04-22-websocket-reverseproxy.md @@ -0,0 +1,627 @@ +# WebSocket Support via ReverseProxy Refactor — Implementation Plan + +> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Replace the manual HTTP request loop in `handleConnectWithInterception` with `httputil.ReverseProxy`, enabling WebSocket upgrade support while preserving all existing behaviors. + +**Architecture:** After CONNECT hijack and TLS handshake (unchanged), create a single-connection `http.Server` with an `httputil.ReverseProxy` handler wrapped in a policy-checking middleware. The wrapping handler performs network policy, Keep policy, and credential resolution before delegating to `ReverseProxy`. Credential injection happens in `Rewrite`, response processing in `ModifyResponse`, transport errors in `ErrorHandler`. WebSocket upgrades work automatically via `ReverseProxy.handleUpgradeResponse`. + +**Tech Stack:** Go stdlib `net/http/httputil.ReverseProxy`, `net/http.Server`, `crypto/tls` + +--- + +## File Structure + +| File | Action | Responsibility | +|---|---|---| +| `proxy/proxy.go` | Modify | Replace `handleConnectWithInterception` loop (lines 1812-2177) with `http.Server` + `ReverseProxy` | +| `proxy/intercept_test.go` | Modify | Add WebSocket upgrade test | +| `proxy/proxy_test.go` | Verify | Existing tests must keep passing | + +The refactor is contained to one function in one file. No new files needed — the handler, rewrite, and modify-response logic are methods on `*Proxy` defined inline or as closures within the existing file. + +--- + +## Task 1: Add WebSocket upgrade test (will fail against current code) + +This test establishes the target behavior. It will fail now and pass after the refactor. + +**Files:** +- Modify: `proxy/intercept_test.go` + +- [ ] **Step 1: Write the WebSocket upgrade test** + +Add to `proxy/intercept_test.go`: + +```go +func TestIntercept_WebSocketUpgrade(t *testing.T) { + // Backend that accepts WebSocket upgrades and echoes messages. + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") != "websocket" { + http.Error(w, "expected websocket upgrade", 400) + return + } + // Minimal WebSocket handshake. + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Connection", "Upgrade") + w.WriteHeader(http.StatusSwitchingProtocols) + + // Hijack and echo bytes back. + hijacker, ok := w.(http.Hijacker) + if !ok { + return + } + conn, brw, err := hijacker.Hijack() + if err != nil { + return + } + defer conn.Close() + brw.Flush() + + // Simple echo: read up to 1024 bytes, write them back. + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + return + } + conn.Write(buf[:n]) + })) + defer backend.Close() + + ca, err := generateCA() + if err != nil { + t.Fatal(err) + } + + upstreamCAs := x509.NewCertPool() + upstreamCAs.AddCert(backend.Certificate()) + + p := NewProxy() + p.SetCA(ca) + p.SetUpstreamCAs(upstreamCAs) + + // Set credential to verify injection on the upgrade request. + backendHost := mustParseURL(backend.URL).Hostname() + p.SetCredential(backendHost, "Bearer ws-token") + + proxyServer := httptest.NewServer(p) + defer proxyServer.Close() + + // Dial through the proxy using CONNECT. + proxyURL := mustParseURL(proxyServer.URL) + proxyConn, err := net.Dial("tcp", proxyURL.Host) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer proxyConn.Close() + + // Send CONNECT. + backendAddr := mustParseURL(backend.URL).Host + fmt.Fprintf(proxyConn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", backendAddr, backendAddr) + br := bufio.NewReader(proxyConn) + connectResp, err := http.ReadResponse(br, nil) + if err != nil { + t.Fatalf("read CONNECT response: %v", err) + } + if connectResp.StatusCode != 200 { + t.Fatalf("CONNECT status = %d, want 200", connectResp.StatusCode) + } + + // TLS handshake with the proxy's interception cert. + clientCAs := x509.NewCertPool() + clientCAs.AppendCertsFromPEM(ca.certPEM) + tlsConn := tls.Client(proxyConn, &tls.Config{ + RootCAs: clientCAs, + ServerName: backendHost, + }) + if err := tlsConn.Handshake(); err != nil { + t.Fatalf("TLS handshake: %v", err) + } + defer tlsConn.Close() + + // Send WebSocket upgrade request. + upgradeReq := "GET /ws HTTP/1.1\r\n" + + "Host: " + backendAddr + "\r\n" + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + + "Sec-WebSocket-Version: 13\r\n" + + "\r\n" + if _, err := tlsConn.Write([]byte(upgradeReq)); err != nil { + t.Fatalf("write upgrade request: %v", err) + } + + // Read the 101 response. + tlsBr := bufio.NewReader(tlsConn) + upgradeResp, err := http.ReadResponse(tlsBr, nil) + if err != nil { + t.Fatalf("read upgrade response: %v", err) + } + if upgradeResp.StatusCode != http.StatusSwitchingProtocols { + t.Fatalf("upgrade status = %d, want 101", upgradeResp.StatusCode) + } + + // Send a raw message through the WebSocket tunnel. + testMsg := []byte("hello websocket") + if _, err := tlsConn.Write(testMsg); err != nil { + t.Fatalf("write message: %v", err) + } + + // Read echoed message back. + echoBuf := make([]byte, len(testMsg)) + if _, err := io.ReadFull(tlsBr, echoBuf); err != nil { + t.Fatalf("read echo: %v", err) + } + if string(echoBuf) != string(testMsg) { + t.Errorf("echo = %q, want %q", echoBuf, testMsg) + } +} +``` + +- [ ] **Step 2: Run test to verify it fails against current code** + +Run: `go test -run TestIntercept_WebSocketUpgrade -v -count=1 ./proxy/` +Expected: FAIL — the current code will either hang or error with "malformed HTTP request" after the 101. + +- [ ] **Step 3: Commit the failing test** + +```bash +git add proxy/intercept_test.go +git commit -m "test(proxy): add WebSocket upgrade test (expected to fail)" +``` + +--- + +## Task 2: Refactor handleConnectWithInterception to use ReverseProxy + +This is the core change. Replace lines 1812-2177 (the manual `for` loop) with an `http.Server` + `httputil.ReverseProxy`. + +**Files:** +- Modify: `proxy/proxy.go` (lines 1749-2178) + +- [ ] **Step 1: Replace the request loop with http.Server + ReverseProxy** + +Replace the code from line 1812 (`clientReader := bufio.NewReader(tlsClientConn)`) through line 2177 (closing `}` of the for loop) with: + +```go + // Create a reverse proxy that handles request forwarding, including + // WebSocket upgrades via the stdlib's built-in protocol switch support. + reverseProxy := &httputil.ReverseProxy{ + Rewrite: func(pr *httputil.ProxyRequest) { + // Preserve the original Proxy-Authorization from In before + // ReverseProxy strips hop-by-hop headers. + // token-exchange subject_from: proxy-auth needs this. + proxyAuth := pr.In.Header.Get("Proxy-Authorization") + + pr.Out.URL.Scheme = "https" + connectHost := r.Host + if rc := getRunContext(r); rc != nil && rc.HostGatewayIP != "" && isHostGateway(rc, host) { + connectHost = rewriteHostPort(r.Host, rc.HostGatewayIP) + } + pr.Out.URL.Host = connectHost + pr.Out.Host = pr.In.Host + + // Restore Proxy-Authorization so credential resolver can read it. + if proxyAuth != "" { + pr.Out.Header.Set("Proxy-Authorization", proxyAuth) + } + + // MCP credential injection. + p.injectMCPCredentialsWithContext(r, pr.Out) + + // Credential injection. + creds, credErr := p.getCredentialsForRequest(r, pr.Out, host) + if credErr != nil { + // Store error in context for ErrorHandler to pick up. + *pr.Out = *pr.Out.WithContext(context.WithValue(pr.Out.Context(), interceptCredErrKey{}, credErr)) + return + } + credResult := injectCredentials(pr.Out, creds, host, pr.Out.Method, pr.Out.URL.Path) + + // Store credential result in context for ModifyResponse/logging. + ctx := pr.Out.Context() + ctx = context.WithValue(ctx, interceptCredResultKey{}, credResult) + *pr.Out = *pr.Out.WithContext(ctx) + + // Extra headers. + mergeExtraHeaders(pr.Out, r.Host, p.getExtraHeadersForRequest(r, r.Host)) + + // Strip proxy headers. + pr.Out.Header.Del("Proxy-Connection") + pr.Out.Header.Del("Proxy-Authorization") + + // Remove configured headers (but not injected credential headers). + for _, headerName := range p.getRemoveHeadersForRequest(r, host) { + if credResult.InjectedHeaders[strings.ToLower(headerName)] { + continue + } + pr.Out.Header.Del(headerName) + } + + // Token substitution. + if sub := p.getTokenSubstitutionForRequest(r, host); sub != nil { + p.applyTokenSubstitution(pr.Out, sub) + } + + // Request ID. + if pr.Out.Header.Get("X-Request-Id") == "" { + pr.Out.Header.Set("X-Request-Id", newRequestID()) + } + }, + Transport: transport, + ModifyResponse: func(resp *http.Response) error { + req := resp.Request + + // LLM gateway policy evaluation (Anthropic API only). + if resp.StatusCode == http.StatusOK && host == "api.anthropic.com" { + if rc := getRunContext(r); rc != nil && rc.KeepEngines != nil { + if eng, ok := rc.KeepEngines["llm-gateway"]; ok { + p.evaluateAndReplaceLLMResponse(r, req, resp, eng) + } + } + } + + // Response transformers. + if transformers := p.getResponseTransformersForRequest(r, host); len(transformers) > 0 { + for _, transformer := range transformers { + if newRespInterface, transformed := transformer(req, resp); transformed { + if newResp, ok := newRespInterface.(*http.Response); ok { + *resp = *newResp + } + break + } + } + } + + // Canonical log line. + credResult, _ := req.Context().Value(interceptCredResultKey{}).(credentialInjectionResult) + var respBody []byte + respBody, resp.Body = captureBody(resp.Body, resp.Header.Get("Content-Type")) + var reqBody []byte + // Request body was already consumed by the transport; capture from context if available. + _ = reqBody + _ = respBody + + p.logRequest(r, RequestLogData{ + RequestID: req.Header.Get("X-Request-Id"), + Method: req.Method, + URL: req.URL.String(), + Host: host, + Path: req.URL.Path, + RequestType: "connect", + StatusCode: resp.StatusCode, + Duration: time.Since(reqStartFromContext(req.Context())), + RequestHeaders: req.Header.Clone(), + ResponseHeaders: resp.Header.Clone(), + ResponseBody: respBody, + RequestSize: req.ContentLength, + ResponseSize: resp.ContentLength, + AuthInjected: len(credResult.InjectedHeaders) > 0, + InjectedHeaders: credResult.InjectedHeaders, + Grants: credResult.Grants, + }) + + return nil + }, + ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { + // Check for credential resolution error from Rewrite. + if credErr, ok := req.Context().Value(interceptCredErrKey{}).(error); ok { + http.Error(rw, "credential resolution failed\n", http.StatusBadGateway) + p.logRequest(r, RequestLogData{ + RequestID: req.Header.Get("X-Request-Id"), + Method: req.Method, + URL: req.URL.String(), + Host: host, + Path: req.URL.Path, + RequestType: "connect", + StatusCode: http.StatusBadGateway, + Err: credErr, + }) + return + } + + rw.WriteHeader(http.StatusBadGateway) + credResult, _ := req.Context().Value(interceptCredResultKey{}).(credentialInjectionResult) + p.logRequest(r, RequestLogData{ + RequestID: req.Header.Get("X-Request-Id"), + Method: req.Method, + URL: req.URL.String(), + Host: host, + Path: req.URL.Path, + RequestType: "connect", + StatusCode: http.StatusBadGateway, + Err: err, + AuthInjected: len(credResult.InjectedHeaders) > 0, + InjectedHeaders: credResult.InjectedHeaders, + Grants: credResult.Grants, + }) + }, + } + + // Wrapping handler: policy checks before ReverseProxy. + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Network policy check. + if !p.checkNetworkPolicyForRequest(r, host, connectPort, req.Method, req.URL.Path) { + innerReqID := req.Header.Get("X-Request-Id") + if innerReqID == "" { + innerReqID = newRequestID() + } + p.logRequest(r, RequestLogData{ + RequestID: innerReqID, + Method: req.Method, + URL: "https://" + r.Host + req.URL.Path, + Host: host, + Path: req.URL.Path, + RequestType: "connect", + StatusCode: http.StatusProxyAuthRequired, + RequestSize: req.ContentLength, + ResponseSize: -1, + Denied: true, + DenyReason: "Request blocked by network policy: " + req.Method + " " + host + req.URL.Path, + }) + p.logPolicy(r, "network", "http.request", "", req.Method+" "+host+req.URL.Path) + w.Header().Set("X-Moat-Blocked", "request-rule") + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusProxyAuthRequired) + fmt.Fprintf(w, "Moat: request blocked by network policy.\nHost: %s\nTo allow this request, update network.rules in moat.yaml.\n", host) + return + } + + // Keep HTTP policy check. + if rc := getRunContext(r); rc != nil && rc.KeepEngines != nil { + if eng, ok := rc.KeepEngines["http"]; ok { + call := keeplib.NewHTTPCall(req.Method, host, req.URL.Path) + call.Context.Scope = "http-" + host + result, evalErr := keeplib.SafeEvaluate(eng, call, "http") + if evalErr != nil { + innerReqID := req.Header.Get("X-Request-Id") + if innerReqID == "" { + innerReqID = newRequestID() + } + p.logRequest(r, RequestLogData{ + RequestID: innerReqID, + Method: req.Method, + URL: "https://" + r.Host + req.URL.Path, + Host: host, + Path: req.URL.Path, + RequestType: "connect", + StatusCode: http.StatusForbidden, + RequestSize: req.ContentLength, + ResponseSize: -1, + Denied: true, + DenyReason: "Keep policy evaluation error", + Err: evalErr, + }) + p.logPolicy(r, "http", "http.request", "evaluation-error", "Policy evaluation failed") + w.Header().Set("X-Moat-Blocked", "keep-policy") + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusForbidden) + fmt.Fprintf(w, "Moat: request blocked — policy evaluation error.\nHost: %s\n", host) + return + } + if result.Decision == keeplib.Deny { + innerReqID := req.Header.Get("X-Request-Id") + if innerReqID == "" { + innerReqID = newRequestID() + } + p.logRequest(r, RequestLogData{ + RequestID: innerReqID, + Method: req.Method, + URL: "https://" + r.Host + req.URL.Path, + Host: host, + Path: req.URL.Path, + RequestType: "connect", + StatusCode: http.StatusForbidden, + RequestSize: req.ContentLength, + ResponseSize: -1, + Denied: true, + DenyReason: "Keep policy denied: " + result.Rule + " " + result.Message, + }) + p.logPolicy(r, "http", "http.request", result.Rule, result.Message) + w.Header().Set("X-Moat-Blocked", "keep-policy") + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusForbidden) + msg := fmt.Sprintf("Moat: request blocked by Keep policy.\nHost: %s\n", host) + if result.Message != "" { + msg += result.Message + "\n" + } + fmt.Fprint(w, msg) + return + } + } + } + + // Store request start time in context for duration calculation. + ctx := context.WithValue(req.Context(), interceptReqStartKey{}, time.Now()) + reverseProxy.ServeHTTP(w, req.WithContext(ctx)) + }) + + // Serve on a single-connection listener wrapping the TLS connection. + srv := &http.Server{ + Handler: handler, + IdleTimeout: 120 * time.Second, + ErrorLog: log.New(io.Discard, "", 0), // Suppress server-level errors (we handle them in ErrorHandler). + } + srv.Serve(newSingleConnListener(tlsClientConn)) +``` + +This requires several supporting types. Add before the function: + +```go +// Context keys for passing data between ReverseProxy hooks. +type interceptCredResultKey struct{} +type interceptCredErrKey struct{} +type interceptReqStartKey struct{} + +func reqStartFromContext(ctx context.Context) time.Time { + if t, ok := ctx.Value(interceptReqStartKey{}).(time.Time); ok { + return t + } + return time.Now() +} + +// singleConnListener wraps a single net.Conn as a net.Listener. +// Accept returns the connection once, then blocks until Close is called. +type singleConnListener struct { + conn net.Conn + once sync.Once + ch chan net.Conn +} + +func newSingleConnListener(conn net.Conn) *singleConnListener { + ch := make(chan net.Conn, 1) + ch <- conn + return &singleConnListener{conn: conn, ch: ch} +} + +func (l *singleConnListener) Accept() (net.Conn, error) { + conn, ok := <-l.ch + if !ok { + return nil, io.EOF + } + return conn, nil +} + +func (l *singleConnListener) Close() error { + l.once.Do(func() { close(l.ch) }) + return nil +} + +func (l *singleConnListener) Addr() net.Addr { + return l.conn.LocalAddr() +} +``` + +Also add `evaluateAndReplaceLLMResponse` as a method that encapsulates the LLM policy logic currently inline in the loop (lines 2024-2106). This keeps ModifyResponse readable: + +```go +// evaluateAndReplaceLLMResponse evaluates LLM gateway policy and replaces +// the response in-place if denied. Called from ModifyResponse. +func (p *Proxy) evaluateAndReplaceLLMResponse(ctxReq *http.Request, req *http.Request, resp *http.Response, eng *keeplib.Engine) { + respBodyBytes, readErr := io.ReadAll(io.LimitReader(resp.Body, maxLLMResponseSize+1)) + resp.Body.Close() + if readErr != nil { + p.logPolicy(ctxReq, "llm-gateway", "llm.read_error", "read-error", "Failed to read response body for policy evaluation") + errorBody := buildPolicyDeniedResponse("read-error", "Failed to read response body for policy evaluation.") + resp.StatusCode = http.StatusBadRequest + resp.Header = make(http.Header) + resp.Header.Set("Content-Type", "application/json") + resp.Header.Set("X-Moat-Blocked", "llm-policy") + resp.ContentLength = int64(len(errorBody)) + resp.Body = io.NopCloser(bytes.NewReader(errorBody)) + return + } + if int64(len(respBodyBytes)) > maxLLMResponseSize { + p.logPolicy(ctxReq, "llm-gateway", "llm.response_too_large", "size-limit", "Response too large for policy evaluation") + errorBody := buildPolicyDeniedResponse("size-limit", "Response too large for policy evaluation.") + resp.StatusCode = http.StatusBadRequest + resp.Header = make(http.Header) + resp.Header.Set("Content-Type", "application/json") + resp.Header.Set("X-Moat-Blocked", "llm-policy") + resp.ContentLength = int64(len(errorBody)) + resp.Body = io.NopCloser(bytes.NewReader(errorBody)) + return + } + result := evaluateLLMResponse(eng, respBodyBytes, resp) + if result.Denied { + p.logPolicy(ctxReq, "llm-gateway", "llm.tool_use", result.Rule, result.Message) + errorBody := buildPolicyDeniedResponse(result.Rule, result.Message) + resp.StatusCode = http.StatusBadRequest + resp.Header = make(http.Header) + resp.Header.Set("Content-Type", "application/json") + resp.Header.Set("X-Moat-Blocked", "llm-policy") + resp.ContentLength = int64(len(errorBody)) + resp.Body = io.NopCloser(bytes.NewReader(errorBody)) + } else if result.Events != nil { + var buf bytes.Buffer + for _, ev := range result.Events { + if ev.ID != "" { + fmt.Fprintf(&buf, "id: %s\n", ev.ID) + } + if ev.Type != "" { + fmt.Fprintf(&buf, "event: %s\n", ev.Type) + } + lines := strings.Split(ev.Data, "\n") + for _, line := range lines { + fmt.Fprintf(&buf, "data: %s\n", line) + } + buf.WriteByte('\n') + } + resp.Header.Del("Content-Encoding") + resp.Body = io.NopCloser(&buf) + resp.ContentLength = int64(buf.Len()) + } else { + resp.Body = io.NopCloser(bytes.NewReader(respBodyBytes)) + resp.ContentLength = int64(len(respBodyBytes)) + } +} +``` + +- [ ] **Step 2: Add required imports** + +Add to the import block in `proxy/proxy.go`: +- `"log"` (for `log.New` in http.Server ErrorLog) +- `"net/http/httputil"` (for ReverseProxy) + +- [ ] **Step 3: Verify compilation** + +Run: `go build ./proxy/` +Expected: compiles cleanly + +- [ ] **Step 4: Run the full test suite** + +Run: `go test -count=1 ./proxy/` +Expected: All existing `TestIntercept_*` and `TestProxy_*` tests pass + +- [ ] **Step 5: Run the WebSocket test** + +Run: `go test -run TestIntercept_WebSocketUpgrade -v -count=1 ./proxy/` +Expected: PASS + +- [ ] **Step 6: Run go vet** + +Run: `go vet ./...` +Expected: clean + +- [ ] **Step 7: Commit** + +```bash +git add proxy/proxy.go +git commit -m "feat(proxy): replace interception loop with ReverseProxy for WebSocket support + +Replace the manual for { http.ReadRequest → transport.RoundTrip → resp.Write } +loop in handleConnectWithInterception with http.Server + httputil.ReverseProxy. + +ReverseProxy natively handles WebSocket upgrades (101 Switching Protocols) +by hijacking both sides and doing bidirectional io.Copy. + +All existing behaviors preserved: credential injection, network policy, +Keep policy, LLM gateway policy, response transformers, canonical log +lines, X-Request-Id, extra/remove headers, token substitution, host +gateway rewrite." +``` + +--- + +## Task 3: Full verification + +- [ ] **Step 1: Run the complete test suite with race detector** + +Run: `go test -race -count=1 ./...` +Expected: All tests pass, no data races + +- [ ] **Step 2: Run go vet** + +Run: `go vet ./...` +Expected: clean + +- [ ] **Step 3: Clean up any dead code** + +Remove the `bufio` import from proxy.go if no longer used (the manual `bufio.NewReader` loop is gone). Check for any other dead code. + +- [ ] **Step 4: Final commit if cleanup needed** + +```bash +git add -A +git commit -m "refactor(proxy): remove dead code from interception loop replacement" +``` diff --git a/proxy/intercept_test.go b/proxy/intercept_test.go index a3757bc..2d0cf61 100644 --- a/proxy/intercept_test.go +++ b/proxy/intercept_test.go @@ -19,13 +19,12 @@ import ( // backend server. The proxy is configured to trust the backend's TLS cert and // the returned client trusts the proxy's interception CA. type interceptTestSetup struct { - Proxy *Proxy - ProxyServer *httptest.Server - Backend *httptest.Server - Client *http.Client - CA *CA - BackendHost string // hostname only (e.g., 127.0.0.1) — for credential matching - BackendHostPort string // host:port (e.g., 127.0.0.1:12345) — for extra/remove header matching + Proxy *Proxy + ProxyServer *httptest.Server + Backend *httptest.Server + Client *http.Client + CA *CA + BackendHost string // hostname only (e.g., 127.0.0.1) — for credential and header matching } func newInterceptTestSetup(t *testing.T, backendHandler http.Handler) *interceptTestSetup { @@ -59,8 +58,7 @@ func newInterceptTestSetup(t *testing.T, backendHandler http.Handler) *intercept }, } - backendHost := mustParseURL(backend.URL).Host // host:port for extra header matching (uses r.Host) - backendHostname := mustParseURL(backend.URL).Hostname() // hostname only for credential matching + backendHostname := mustParseURL(backend.URL).Hostname() t.Cleanup(func() { proxyServer.Close() @@ -68,13 +66,12 @@ func newInterceptTestSetup(t *testing.T, backendHandler http.Handler) *intercept }) return &interceptTestSetup{ - Proxy: p, - ProxyServer: proxyServer, - Backend: backend, - Client: client, - CA: ca, - BackendHost: backendHostname, - BackendHostPort: backendHost, + Proxy: p, + ProxyServer: proxyServer, + Backend: backend, + Client: client, + CA: ca, + BackendHost: backendHostname, } } @@ -175,17 +172,17 @@ func TestIntercept_NetworkPolicyDenial(t *testing.T) { logged = data }) - // The CONNECT itself will be denied before TLS interception. + // Strict policy denies the CONNECT request itself with 407. + // Go's transport returns this as an error (non-200 CONNECT response). resp, err := setup.Client.Get(setup.Backend.URL + "/blocked") if err == nil { resp.Body.Close() - // Under strict policy with no allows, CONNECT is denied with 407. - if resp.StatusCode != http.StatusProxyAuthRequired { - t.Errorf("status = %d, want %d", resp.StatusCode, http.StatusProxyAuthRequired) - } + t.Fatal("expected transport error from denied CONNECT, got nil") + } + // Verify the error message references the 407 status text. + if !strings.Contains(err.Error(), "407") && !strings.Contains(err.Error(), "Proxy Authentication Required") { + t.Errorf("expected 407/Proxy Authentication Required in error, got: %v", err) } - // The client may get a transport error if CONNECT is blocked. - // Either way, the request should be denied. if !logged.Denied { t.Error("expected Denied=true in log") } @@ -308,7 +305,10 @@ func TestIntercept_RemoveHeaders(t *testing.T) { setup.Proxy.RemoveRequestHeader(setup.BackendHost, "X-Api-Key") - req, _ := http.NewRequest("GET", setup.Backend.URL+"/test", nil) + req, err := http.NewRequest("GET", setup.Backend.URL+"/test", nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } req.Header.Set("X-Api-Key", "stale-key") resp, err := setup.Client.Do(req) if err != nil { @@ -460,7 +460,10 @@ func TestIntercept_HTTPMethods(t *testing.T) { w.Write([]byte("ok")) })) - req, _ := http.NewRequest(method, setup.Backend.URL+"/method", nil) + req, err := http.NewRequest(method, setup.Backend.URL+"/method", nil) + if err != nil { + t.Fatalf("NewRequest: %v", err) + } resp, err := setup.Client.Do(req) if err != nil { t.Fatalf("request: %v", err) diff --git a/proxy/proxy.go b/proxy/proxy.go index cac0512..87e719d 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -29,7 +29,6 @@ package proxy import ( - "bufio" "bytes" "context" "crypto/subtle" @@ -38,9 +37,11 @@ import ( "encoding/base64" "fmt" "io" + "log" "log/slog" "net" "net/http" + "net/http/httputil" "net/url" "strconv" "strings" @@ -1746,6 +1747,110 @@ func (p *Proxy) handleConnectTunnel(w http.ResponseWriter, r *http.Request) { }() } +// Context keys for passing data between ReverseProxy hooks in the interception path. +type interceptCredResultKey struct{} +type interceptCredErrKey struct{} +type interceptReqStartKey struct{} + +func reqStartFromContext(ctx context.Context) time.Time { + if t, ok := ctx.Value(interceptReqStartKey{}).(time.Time); ok { + return t + } + return time.Now() +} + +// singleConnListener wraps a single net.Conn as a net.Listener. +// Accept returns the connection once, then blocks until Close is called. +type singleConnListener struct { + conn net.Conn + once sync.Once + ch chan net.Conn +} + +func newSingleConnListener(conn net.Conn) *singleConnListener { + ch := make(chan net.Conn, 1) + ch <- conn + return &singleConnListener{conn: conn, ch: ch} +} + +func (l *singleConnListener) Accept() (net.Conn, error) { + conn, ok := <-l.ch + if !ok { + return nil, io.EOF + } + return conn, nil +} + +func (l *singleConnListener) Close() error { + l.once.Do(func() { close(l.ch) }) + return nil +} + +func (l *singleConnListener) Addr() net.Addr { + return l.conn.LocalAddr() +} + +// evaluateAndReplaceLLMResponse evaluates LLM gateway policy and replaces +// the response in-place if denied. Called from ModifyResponse. +func (p *Proxy) evaluateAndReplaceLLMResponse(ctxReq *http.Request, req *http.Request, resp *http.Response, eng *keeplib.Engine) { + respBodyBytes, readErr := io.ReadAll(io.LimitReader(resp.Body, maxLLMResponseSize+1)) + resp.Body.Close() + if readErr != nil { + p.logPolicy(ctxReq, "llm-gateway", "llm.read_error", "read-error", "Failed to read response body for policy evaluation") + errorBody := buildPolicyDeniedResponse("read-error", "Failed to read response body for policy evaluation.") + resp.StatusCode = http.StatusBadRequest + resp.Header = make(http.Header) + resp.Header.Set("Content-Type", "application/json") + resp.Header.Set("X-Moat-Blocked", "llm-policy") + resp.ContentLength = int64(len(errorBody)) + resp.Body = io.NopCloser(bytes.NewReader(errorBody)) + return + } + if int64(len(respBodyBytes)) > maxLLMResponseSize { + p.logPolicy(ctxReq, "llm-gateway", "llm.response_too_large", "size-limit", "Response too large for policy evaluation") + errorBody := buildPolicyDeniedResponse("size-limit", "Response too large for policy evaluation.") + resp.StatusCode = http.StatusBadRequest + resp.Header = make(http.Header) + resp.Header.Set("Content-Type", "application/json") + resp.Header.Set("X-Moat-Blocked", "llm-policy") + resp.ContentLength = int64(len(errorBody)) + resp.Body = io.NopCloser(bytes.NewReader(errorBody)) + return + } + result := evaluateLLMResponse(eng, respBodyBytes, resp) + if result.Denied { + p.logPolicy(ctxReq, "llm-gateway", "llm.tool_use", result.Rule, result.Message) + errorBody := buildPolicyDeniedResponse(result.Rule, result.Message) + resp.StatusCode = http.StatusBadRequest + resp.Header = make(http.Header) + resp.Header.Set("Content-Type", "application/json") + resp.Header.Set("X-Moat-Blocked", "llm-policy") + resp.ContentLength = int64(len(errorBody)) + resp.Body = io.NopCloser(bytes.NewReader(errorBody)) + } else if result.Events != nil { + var buf bytes.Buffer + for _, ev := range result.Events { + if ev.ID != "" { + fmt.Fprintf(&buf, "id: %s\n", ev.ID) + } + if ev.Type != "" { + fmt.Fprintf(&buf, "event: %s\n", ev.Type) + } + lines := strings.Split(ev.Data, "\n") + for _, line := range lines { + fmt.Fprintf(&buf, "data: %s\n", line) + } + buf.WriteByte('\n') + } + resp.Header.Del("Content-Encoding") + resp.Body = io.NopCloser(&buf) + resp.ContentLength = int64(buf.Len()) + } else { + resp.Body = io.NopCloser(bytes.NewReader(respBodyBytes)) + resp.ContentLength = int64(len(respBodyBytes)) + } +} + func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Request, host string) { hijacker, ok := w.(http.Hijacker) if !ok { @@ -1809,370 +1914,254 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req } } - clientReader := bufio.NewReader(tlsClientConn) - for { - req, err := http.ReadRequest(clientReader) - if err != nil { - if err != io.EOF { - slog.Debug("failed to read request from intercepted connection", - "subsystem", "proxy", "host", host, "error", err) + // Create a reverse proxy that handles request forwarding, including + // WebSocket upgrades via the stdlib's built-in protocol switch support. + reverseProxy := &httputil.ReverseProxy{ + Rewrite: func(pr *httputil.ProxyRequest) { + // Preserve the original Proxy-Authorization from In before + // ReverseProxy strips hop-by-hop headers. + // token-exchange subject_from: proxy-auth needs this. + proxyAuth := pr.In.Header.Get("Proxy-Authorization") + + pr.Out.URL.Scheme = "https" + connectHost := r.Host + if rc := getRunContext(r); rc != nil && rc.HostGatewayIP != "" && isHostGateway(rc, host) { + connectHost = rewriteHostPort(r.Host, rc.HostGatewayIP) + } + pr.Out.URL.Host = connectHost + pr.Out.Host = pr.In.Host + pr.Out.RequestURI = "" + + // Restore Proxy-Authorization so credential resolver can read it. + if proxyAuth != "" { + pr.Out.Header.Set("Proxy-Authorization", proxyAuth) + } + + // MCP credential injection. + p.injectMCPCredentialsWithContext(r, pr.Out) + + // Credential injection. + creds, credErr := p.getCredentialsForRequest(r, pr.Out, host) + if credErr != nil { + // Store error in context for ErrorHandler to pick up. + *pr.Out = *pr.Out.WithContext(context.WithValue(pr.Out.Context(), interceptCredErrKey{}, credErr)) + return + } + credResult := injectCredentials(pr.Out, creds, host, pr.Out.Method, pr.Out.URL.Path) + + // Store credential result in context for ModifyResponse/logging. + ctx := context.WithValue(pr.Out.Context(), interceptCredResultKey{}, credResult) + *pr.Out = *pr.Out.WithContext(ctx) + + // Extra headers. + mergeExtraHeaders(pr.Out, r.Host, p.getExtraHeadersForRequest(r, r.Host)) + + // Strip proxy headers. + pr.Out.Header.Del("Proxy-Connection") + pr.Out.Header.Del("Proxy-Authorization") + + // Remove configured headers (but not injected credential headers). + for _, headerName := range p.getRemoveHeadersForRequest(r, host) { + if credResult.InjectedHeaders[strings.ToLower(headerName)] { + continue + } + pr.Out.Header.Del(headerName) + } + + // Token substitution. + if sub := p.getTokenSubstitutionForRequest(r, host); sub != nil { + p.applyTokenSubstitution(pr.Out, sub) + } + + // Request ID. + if pr.Out.Header.Get("X-Request-Id") == "" { + pr.Out.Header.Set("X-Request-Id", newRequestID()) + } + }, + Transport: transport, + ModifyResponse: func(resp *http.Response) error { + req := resp.Request + + // LLM gateway policy evaluation (Anthropic API only). + if resp.StatusCode == http.StatusOK && host == "api.anthropic.com" { + if rc := getRunContext(r); rc != nil && rc.KeepEngines != nil { + if eng, ok := rc.KeepEngines["llm-gateway"]; ok { + p.evaluateAndReplaceLLMResponse(r, req, resp, eng) + } + } + } + + // Response transformers. + if transformers := p.getResponseTransformersForRequest(r, host); len(transformers) > 0 { + for _, transformer := range transformers { + if newRespInterface, transformed := transformer(req, resp); transformed { + if newResp, ok := newRespInterface.(*http.Response); ok { + *resp = *newResp + } + break + } + } + } + + // Canonical log line. + credResult, _ := req.Context().Value(interceptCredResultKey{}).(credentialInjectionResult) + var respBody []byte + respBody, resp.Body = captureBody(resp.Body, resp.Header.Get("Content-Type")) + + p.logRequest(r, RequestLogData{ + RequestID: req.Header.Get("X-Request-Id"), + Method: req.Method, + URL: req.URL.String(), + Host: host, + Path: req.URL.Path, + RequestType: "connect", + StatusCode: resp.StatusCode, + Duration: time.Since(reqStartFromContext(req.Context())), + RequestHeaders: req.Header.Clone(), + ResponseHeaders: resp.Header.Clone(), + ResponseBody: respBody, + RequestSize: req.ContentLength, + ResponseSize: resp.ContentLength, + AuthInjected: len(credResult.InjectedHeaders) > 0, + InjectedHeaders: credResult.InjectedHeaders, + Grants: credResult.Grants, + }) + + return nil + }, + ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { + // Check for credential resolution error from Rewrite. + if credErr, ok := req.Context().Value(interceptCredErrKey{}).(error); ok { + rw.Header().Set("Content-Type", "text/plain") + rw.WriteHeader(http.StatusBadGateway) + fmt.Fprint(rw, "credential resolution failed\n") + p.logRequest(r, RequestLogData{ + RequestID: req.Header.Get("X-Request-Id"), + Method: req.Method, + URL: req.URL.String(), + Host: host, + Path: req.URL.Path, + RequestType: "connect", + StatusCode: http.StatusBadGateway, + Err: credErr, + }) + return } - return - } + rw.WriteHeader(http.StatusBadGateway) + credResult, _ := req.Context().Value(interceptCredResultKey{}).(credentialInjectionResult) + p.logRequest(r, RequestLogData{ + RequestID: req.Header.Get("X-Request-Id"), + Method: req.Method, + URL: req.URL.String(), + Host: host, + Path: req.URL.Path, + RequestType: "connect", + StatusCode: http.StatusBadGateway, + Err: err, + AuthInjected: len(credResult.InjectedHeaders) > 0, + InjectedHeaders: credResult.InjectedHeaders, + Grants: credResult.Grants, + }) + }, + } + + // Wrapping handler: policy checks before ReverseProxy. + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { innerReqID := req.Header.Get("X-Request-Id") if innerReqID == "" { innerReqID = newRequestID() } - reqStart := time.Now() - req.URL.Scheme = "https" - // Rewrite synthetic host-gateway hostname to actual IP for forwarding. - connectHost := r.Host - if rc := getRunContext(r); rc != nil && rc.HostGatewayIP != "" && isHostGateway(rc, host) { - connectHost = rewriteHostPort(r.Host, rc.HostGatewayIP) - } - req.URL.Host = connectHost - req.RequestURI = "" - // Check request-level rules (method + path) for the inner HTTP request. - // The CONNECT request r carries the per-run context for rule lookup. + // Network policy check. if !p.checkNetworkPolicyForRequest(r, host, connectPort, req.Method, req.URL.Path) { p.logRequest(r, RequestLogData{ RequestID: innerReqID, Method: req.Method, - URL: req.URL.String(), + URL: "https://" + r.Host + req.URL.Path, Host: host, Path: req.URL.Path, RequestType: "connect", StatusCode: http.StatusProxyAuthRequired, - Duration: time.Since(reqStart), RequestSize: req.ContentLength, ResponseSize: -1, Denied: true, DenyReason: "Request blocked by network policy: " + req.Method + " " + host + req.URL.Path, }) p.logPolicy(r, "network", "http.request", "", req.Method+" "+host+req.URL.Path) - body := "Moat: request blocked by network policy.\nHost: " + host + "\nTo allow this request, update network.rules in moat.yaml.\n" - blockedResp := &http.Response{ - StatusCode: http.StatusProxyAuthRequired, - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - ContentLength: int64(len(body)), - Body: io.NopCloser(strings.NewReader(body)), - } - blockedResp.Header.Set("X-Moat-Blocked", "request-rule") - blockedResp.Header.Set("Content-Type", "text/plain") - _ = blockedResp.Write(tlsClientConn) - continue + w.Header().Set("X-Moat-Blocked", "request-rule") + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusProxyAuthRequired) + fmt.Fprintf(w, "Moat: request blocked by network policy.\nHost: %s\nTo allow this request, update network.rules in moat.yaml.\n", host) + return } - // Evaluate Keep policy for the inner HTTP request. - // Uses the global "http" engine from network.keep_policy. + // Keep HTTP policy check. if rc := getRunContext(r); rc != nil && rc.KeepEngines != nil { - scope := "http" - if eng, ok := rc.KeepEngines[scope]; ok { + if eng, ok := rc.KeepEngines["http"]; ok { call := keeplib.NewHTTPCall(req.Method, host, req.URL.Path) call.Context.Scope = "http-" + host - result, evalErr := keeplib.SafeEvaluate(eng, call, scope) + result, evalErr := keeplib.SafeEvaluate(eng, call, "http") if evalErr != nil { p.logRequest(r, RequestLogData{ RequestID: innerReqID, Method: req.Method, - URL: req.URL.String(), + URL: "https://" + r.Host + req.URL.Path, Host: host, Path: req.URL.Path, RequestType: "connect", StatusCode: http.StatusForbidden, - Duration: time.Since(reqStart), RequestSize: req.ContentLength, ResponseSize: -1, Denied: true, DenyReason: "Keep policy evaluation error", Err: evalErr, }) - p.logPolicy(r, scope, "http.request", "evaluation-error", "Policy evaluation failed") - msg := "Moat: request blocked — policy evaluation error.\nHost: " + host + "\n" - blockedResp := &http.Response{ - StatusCode: http.StatusForbidden, - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - ContentLength: int64(len(msg)), - Body: io.NopCloser(strings.NewReader(msg)), - } - blockedResp.Header.Set("X-Moat-Blocked", "keep-policy") - blockedResp.Header.Set("Content-Type", "text/plain") - _ = blockedResp.Write(tlsClientConn) - continue - } else if result.Decision == keeplib.Deny { + p.logPolicy(r, "http", "http.request", "evaluation-error", "Policy evaluation failed") + w.Header().Set("X-Moat-Blocked", "keep-policy") + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusForbidden) + fmt.Fprintf(w, "Moat: request blocked — policy evaluation error.\nHost: %s\n", host) + return + } + if result.Decision == keeplib.Deny { p.logRequest(r, RequestLogData{ RequestID: innerReqID, Method: req.Method, - URL: req.URL.String(), + URL: "https://" + r.Host + req.URL.Path, Host: host, Path: req.URL.Path, RequestType: "connect", StatusCode: http.StatusForbidden, - Duration: time.Since(reqStart), RequestSize: req.ContentLength, ResponseSize: -1, Denied: true, DenyReason: "Keep policy denied: " + result.Rule + " " + result.Message, }) - p.logPolicy(r, scope, "http.request", result.Rule, result.Message) - msg := "Moat: request blocked by Keep policy.\nHost: " + host + "\n" + p.logPolicy(r, "http", "http.request", result.Rule, result.Message) + w.Header().Set("X-Moat-Blocked", "keep-policy") + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusForbidden) + msg := fmt.Sprintf("Moat: request blocked by Keep policy.\nHost: %s\n", host) if result.Message != "" { msg += result.Message + "\n" } - blockedResp := &http.Response{ - StatusCode: http.StatusForbidden, - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - ContentLength: int64(len(msg)), - Body: io.NopCloser(strings.NewReader(msg)), - } - blockedResp.Header.Set("X-Moat-Blocked", "keep-policy") - blockedResp.Header.Set("Content-Type", "text/plain") - _ = blockedResp.Write(tlsClientConn) - continue - } - } - } - - // Inject MCP credentials if this is an MCP request. - // Use the CONNECT request r for context lookups since inner - // requests from the TLS stream don't carry the request context. - p.injectMCPCredentialsWithContext(r, req) - - creds, credErr := p.getCredentialsForRequest(r, req, host) - if credErr != nil { - body := "credential resolution failed\n" - errResp := &http.Response{ - StatusCode: http.StatusBadGateway, - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - ContentLength: int64(len(body)), - Body: io.NopCloser(strings.NewReader(body)), - } - errResp.Header.Set("Content-Type", "text/plain") - _ = errResp.Write(tlsClientConn) - p.logRequest(r, RequestLogData{ - RequestID: innerReqID, - Method: req.Method, - URL: req.URL.String(), - Host: host, - Path: req.URL.Path, - RequestType: "connect", - StatusCode: http.StatusBadGateway, - Duration: time.Since(reqStart), - RequestSize: req.ContentLength, - ResponseSize: -1, - Err: credErr, - }) - continue - } - - // Capture request body and headers after credential resolution - // so that resolver side effects (e.g., subject header stripping) - // are reflected and sensitive headers are not logged. - var reqBody []byte - reqBody, req.Body = captureBody(req.Body, req.Header.Get("Content-Type")) - originalReqHeaders := req.Header.Clone() - - credResult := injectCredentials(req, creds, host, req.Method, req.URL.Path) - - // Inject any additional headers configured for this host. - // Merges with existing values (comma-separated) to preserve client - // headers like anthropic-beta that support multiple flags. - mergeExtraHeaders(req, r.Host, p.getExtraHeadersForRequest(r, r.Host)) - req.Header.Del("Proxy-Connection") - req.Header.Del("Proxy-Authorization") - - // Remove headers that should be stripped for this host, but never - // remove a credential header the proxy just injected (see comment - // in handleHTTP for the multi-grant conflict scenario). - for _, headerName := range p.getRemoveHeadersForRequest(r, host) { - if credResult.InjectedHeaders[strings.ToLower(headerName)] { - continue - } - req.Header.Del(headerName) - } - // Apply token substitution if configured for this host. - // Capture the URL before substitution so logs don't contain real tokens. - logURL := req.URL.String() - if sub := p.getTokenSubstitutionForRequest(r, host); sub != nil { - p.applyTokenSubstitution(req, sub) - } - - if req.Header.Get("X-Request-Id") == "" { - req.Header.Set("X-Request-Id", innerReqID) - } - - resp, err := transport.RoundTrip(req) - - // Track LLM policy denials for the canonical log line. - var llmDenied bool - var llmDenyReason string - - // Evaluate LLM gateway policy on Anthropic API responses. - // NOTE: Only applies to the default Anthropic endpoint. Custom - // ANTHROPIC_BASE_URL endpoints bypass policy evaluation — this is - // mutually exclusive with llm-gateway (see config validation). - if resp != nil && resp.StatusCode == http.StatusOK && host == "api.anthropic.com" { - if rc := getRunContext(r); rc != nil && rc.KeepEngines != nil { - if eng, ok := rc.KeepEngines["llm-gateway"]; ok { - respBodyBytes, readErr := io.ReadAll(io.LimitReader(resp.Body, maxLLMResponseSize+1)) - resp.Body.Close() - if readErr != nil { - p.logPolicy(r, "llm-gateway", "llm.read_error", "read-error", "Failed to read response body for policy evaluation") - llmDenied = true - llmDenyReason = "LLM policy read error" - errorBody := buildPolicyDeniedResponse("read-error", "Failed to read response body for policy evaluation.") - resp = &http.Response{ - StatusCode: http.StatusBadRequest, - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - ContentLength: int64(len(errorBody)), - Body: io.NopCloser(bytes.NewReader(errorBody)), - } - resp.Header.Set("Content-Type", "application/json") - resp.Header.Set("X-Moat-Blocked", "llm-policy") - } else if int64(len(respBodyBytes)) > maxLLMResponseSize { - p.logPolicy(r, "llm-gateway", "llm.response_too_large", "size-limit", "Response too large for policy evaluation") - llmDenied = true - llmDenyReason = "LLM policy response too large" - errorBody := buildPolicyDeniedResponse("size-limit", "Response too large for policy evaluation.") - resp = &http.Response{ - StatusCode: http.StatusBadRequest, - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - ContentLength: int64(len(errorBody)), - Body: io.NopCloser(bytes.NewReader(errorBody)), - } - resp.Header.Set("Content-Type", "application/json") - resp.Header.Set("X-Moat-Blocked", "llm-policy") - } else { - result := evaluateLLMResponse(eng, respBodyBytes, resp) - if result.Denied { - p.logPolicy(r, "llm-gateway", "llm.tool_use", result.Rule, result.Message) - llmDenied = true - llmDenyReason = "LLM policy denied: " + result.Rule + " " + result.Message - errorBody := buildPolicyDeniedResponse(result.Rule, result.Message) - resp = &http.Response{ - StatusCode: http.StatusBadRequest, - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - ContentLength: int64(len(errorBody)), - Body: io.NopCloser(bytes.NewReader(errorBody)), - } - resp.Header.Set("Content-Type", "application/json") - resp.Header.Set("X-Moat-Blocked", "llm-policy") - } else if result.Events != nil { - // SSE response allowed — re-serialize evaluated events. - // Events were decompressed for evaluation, so the - // re-serialized body is plaintext — remove Content-Encoding. - var buf bytes.Buffer - for _, ev := range result.Events { - if ev.ID != "" { - fmt.Fprintf(&buf, "id: %s\n", ev.ID) - } - if ev.Type != "" { - fmt.Fprintf(&buf, "event: %s\n", ev.Type) - } - // Per SSE spec, multi-line data needs a `data:` prefix per line. - lines := strings.Split(ev.Data, "\n") - for _, line := range lines { - fmt.Fprintf(&buf, "data: %s\n", line) - } - buf.WriteByte('\n') // Event terminator. - } - resp.Header.Del("Content-Encoding") - resp.Body = io.NopCloser(&buf) - resp.ContentLength = int64(buf.Len()) - } else { - // JSON response allowed — restore original body. - resp.Body = io.NopCloser(bytes.NewReader(respBodyBytes)) - resp.ContentLength = int64(len(respBodyBytes)) - } - } + fmt.Fprint(w, msg) + return } } } - // Capture response - var respBody []byte - var respHeaders http.Header - statusCode := http.StatusBadGateway - var responseSize int64 = -1 - if resp != nil { - respHeaders = resp.Header.Clone() - statusCode = resp.StatusCode - responseSize = resp.ContentLength - - // Apply response transformers BEFORE capturing body - // so transformer can read the original response body. - // Only the first transformer that returns true is applied (transformers are not chained). - if transformers := p.getResponseTransformersForRequest(r, host); len(transformers) > 0 { - for _, transformer := range transformers { - if newRespInterface, transformed := transformer(req, resp); transformed { - if newResp, ok := newRespInterface.(*http.Response); ok { - resp = newResp - statusCode = resp.StatusCode - respHeaders = resp.Header.Clone() - } - break // Only apply first matching transformer - } - } - } - - // Capture body AFTER transformation - respBody, resp.Body = captureBody(resp.Body, resp.Header.Get("Content-Type")) - } - - p.logRequest(r, RequestLogData{ - RequestID: innerReqID, - Method: req.Method, - URL: logURL, - Host: host, - Path: req.URL.Path, - RequestType: "connect", - StatusCode: statusCode, - Duration: time.Since(reqStart), - Err: err, - RequestHeaders: originalReqHeaders, - ResponseHeaders: respHeaders, - RequestBody: reqBody, - ResponseBody: respBody, - RequestSize: req.ContentLength, - ResponseSize: responseSize, - InjectedHeaders: credResult.InjectedHeaders, - Grants: credResult.Grants, - Denied: llmDenied, - DenyReason: llmDenyReason, - }) - - if err != nil { - errResp := &http.Response{ - StatusCode: http.StatusBadGateway, - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - } - _ = errResp.Write(tlsClientConn) - continue - } - - _ = resp.Write(tlsClientConn) - resp.Body.Close() + // Store request start time in context for duration calculation. + ctx := context.WithValue(req.Context(), interceptReqStartKey{}, time.Now()) + reverseProxy.ServeHTTP(w, req.WithContext(ctx)) + }) - if resp.Close || req.Close { - return - } + // Serve on a single-connection listener wrapping the TLS connection. + srv := &http.Server{ + Handler: handler, + IdleTimeout: 120 * time.Second, + ErrorLog: log.New(io.Discard, "", 0), // Suppress server-level errors; handled in ErrorHandler. } + _ = srv.Serve(newSingleConnListener(tlsClientConn)) } From ca0b3e8c8622ae64cb37e62c6d1b1d165fc04d9a Mon Sep 17 00:00:00 2001 From: Andy Bonventre <365204+andybons@users.noreply.github.com> Date: Wed, 22 Apr 2026 19:30:35 +0000 Subject: [PATCH 4/8] docs: add v0.9.0 changelog entry, remove design docs --- CHANGELOG.md | 11 + ...026-04-22-websocket-reverseproxy-design.md | 134 ---- .../2026-04-22-websocket-reverseproxy.md | 627 ------------------ 3 files changed, 11 insertions(+), 761 deletions(-) delete mode 100644 docs/2026-04-22-websocket-reverseproxy-design.md delete mode 100644 docs/plans/2026-04-22-websocket-reverseproxy.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 21dc1f2..afaeb1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,17 @@ Gatekeeper is a standalone credential-injecting TLS-intercepting proxy. It trans Gatekeeper is pre-1.0. The configuration schema and credential source interface may change between minor versions. +## v0.9.0 — 2026-04-22 + +### Added + +- **WebSocket support through TLS interception** — WebSocket upgrades (101 Switching Protocols) now work through CONNECT+TLS intercepted connections; credentials are injected on the upgrade request, then the proxy switches to bidirectional byte tunneling for WebSocket frames ([#22](https://github.com/majorcontext/gatekeeper/pull/22)) + +### Changed + +- **Refactored `handleConnectWithInterception`** — replaced the manual `http.ReadRequest` → `transport.RoundTrip` → `resp.Write` loop with `httputil.ReverseProxy` served via a single-connection `http.Server`; all existing behaviors (credential injection, network/Keep policy, LLM gateway policy, response transformers, canonical logging) are preserved through `Rewrite`, `ModifyResponse`, and `ErrorHandler` hooks ([#22](https://github.com/majorcontext/gatekeeper/pull/22)) +- **Extracted `evaluateAndReplaceLLMResponse`** — LLM gateway policy evaluation logic moved from inline in the request loop to a standalone method for readability ([#22](https://github.com/majorcontext/gatekeeper/pull/22)) + ## v0.8.0 — 2026-04-22 ### Added diff --git a/docs/2026-04-22-websocket-reverseproxy-design.md b/docs/2026-04-22-websocket-reverseproxy-design.md deleted file mode 100644 index 5917bb4..0000000 --- a/docs/2026-04-22-websocket-reverseproxy-design.md +++ /dev/null @@ -1,134 +0,0 @@ -# WebSocket Support via ReverseProxy Refactor - -**Date:** 2026-04-22 -**Status:** Approved -**Scope:** `proxy/proxy.go` — `handleConnectWithInterception` - -## Problem - -Gatekeeper's TLS interception path manually reads HTTP requests in a loop (`http.ReadRequest` → `transport.RoundTrip` → `resp.Write`). After a WebSocket upgrade (HTTP 101 Switching Protocols), the client sends binary WebSocket frames which `http.ReadRequest` cannot parse, causing `"malformed HTTP request"` errors and connection drops. - -## Solution - -Replace the manual request loop in `handleConnectWithInterception` with an `http.Server` serving on the client-side TLS connection, using `httputil.ReverseProxy` as the handler. Go 1.25's `ReverseProxy` natively handles WebSocket upgrades — it detects `Upgrade` headers, preserves them through hop-by-hop removal, hijacks both sides on a `101` response, and does bidirectional `io.Copy`. - -## Architecture - -``` -Client ←TLS→ http.Server(tlsClientConn) → ReverseProxy → upstream -``` - -### Flow - -1. CONNECT arrives, proxy hijacks, sends `200 Connection Established` (unchanged) -2. TLS handshake with client using generated cert (unchanged) -3. **New:** Create a single-connection `http.Server` with `httputil.ReverseProxy` as handler -4. `http.Server.Serve()` manages the request loop (replaces manual `for` + `http.ReadRequest`) -5. For normal HTTP: `ReverseProxy` forwards via `Transport.RoundTrip`, credential injection in `Rewrite` -6. For WebSocket: `ReverseProxy` detects `101`, hijacks, bidirectional copy — no custom code needed - -### Feature Mapping - -Every feature in the current manual loop maps to a `ReverseProxy` hook: - -| Feature | Current location | New location | -|---|---|---| -| Network policy check | Loop body | Wrapping handler (before ReverseProxy) | -| Keep HTTP policy | Loop body | Wrapping handler (before ReverseProxy) | -| Credential injection (`injectCredentials`) | Loop body | `Rewrite` on `ProxyRequest.Out` | -| MCP credential injection | Loop body | `Rewrite` | -| Extra headers / remove headers | Loop body | `Rewrite` | -| Token substitution | Loop body | `Rewrite` | -| Request ID generation | Loop body | `Rewrite` | -| Host gateway IP rewrite | Loop body, modifies dial target | `Rewrite` (rewrite URL host) or custom `Transport.DialContext` | -| Proxy-Authorization stripping | Loop body | `Rewrite` (read from `ProxyRequest.In` before hop-by-hop removal) | -| Credential resolver (token-exchange) | Loop body | `Rewrite` (read subject from `In.Header`, resolve, set on `Out`) | -| LLM gateway policy | Loop body, post-response | `ModifyResponse` | -| Response transformers | Loop body, post-response | `ModifyResponse` | -| Body capture for logging | Loop body | `ModifyResponse` (response) and `Rewrite` (request) | -| Canonical log line | Loop body | `ModifyResponse` + `ErrorHandler` | -| OTel span/metrics | Loop body via callbacks | Wrapping handler or `ModifyResponse` | -| Transport error → 502 | Loop body | `ErrorHandler` | -| WebSocket upgrade | **Not supported** | Built-in `ReverseProxy.handleUpgradeResponse` | - -### Key Design Decisions - -**Proxy-Authorization before hop-by-hop removal:** `ReverseProxy` strips hop-by-hop headers (including `Proxy-Authorization`) before calling `Rewrite`. For `subject_from: proxy-auth` token exchange, the subject identity must be extracted from `ProxyRequest.In` (which preserves original headers) rather than `ProxyRequest.Out`. - -**Single-connection http.Server:** The `http.Server` serves on a `net.Listener` wrapping the single TLS connection. When the connection closes, `Serve` returns. This replaces the manual `for` loop and gets HTTP keepalive, pipelining, and protocol upgrade handling from the stdlib. - -**Per-connection transport:** The `http.Transport` is created per-CONNECT connection (same as today). `ForceAttemptHTTP2` remains disabled — the intercepted connection reads HTTP/1.1. - -**No behavioral changes:** All external APIs (`Proxy`, `RunContextData`, config) remain identical. This is purely an internal refactor of one function. - -### Handling Policy Denials in Rewrite - -The current loop writes error responses (407, 403, 502) directly to the TLS connection and continues the loop. With `ReverseProxy`, the `Rewrite` function cannot write responses directly. Two options: - -**Option A — Wrapping handler:** A handler that runs policy checks before delegating to `ReverseProxy`. On denial, it writes the error response itself and does not call `ReverseProxy.ServeHTTP`. This is the cleanest approach. - -**Option B — Rewrite sets a sentinel, ErrorHandler acts on it.** `Rewrite` stores a denial in the request context, `ModifyResponse` or a custom `RoundTripper` wrapper checks for it. More complex, less readable. - -**Decision:** Option A. The wrapping handler pattern is idiomatic and keeps policy logic separate from forwarding logic. - -```go -func (p *Proxy) interceptHandler(host string, rc *RunContextData, transport *http.Transport) http.Handler { - rp := &httputil.ReverseProxy{ - Rewrite: p.rewriteIntercepted(host, rc), - Transport: transport, - ModifyResponse: p.modifyInterceptedResponse(host, rc), - ErrorHandler: p.interceptErrorHandler(host, rc), - } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Network policy, Keep HTTP policy checks here - // On denial: write error response, return - // On allow: rp.ServeHTTP(w, r) - }) -} -``` - -## Testing Strategy - -Tests are written first (TDD) against the current code to establish behavioral baselines, then the refactor must keep them passing. - -### New tests to add before refactor - -1. **Normal HTTPS through interception** — credential injection verified on upstream request -2. **WebSocket upgrade through interception** — upgrade succeeds, bidirectional frame exchange works (will fail against current code, pass after refactor) -3. **Multi-request keepalive** — multiple requests over single CONNECT tunnel -4. **Network policy denial on inner request** — 407 returned, connection stays alive -5. **Transport error** — unreachable upstream, 502 returned, canonical log line emitted -6. **Credential resolver via CONNECT** — token-exchange with `subject_from: proxy-auth` -7. **Host gateway through interception** — gateway hostname rewritten to actual IP - -### Existing tests that must keep passing - -All tests in `proxy/proxy_test.go`, particularly: -- `TestProxy_CanonicalLogLine_ConnectTransportError` -- `TestProxy_CanonicalLogLine_ConnectBlocked` -- All credential injection, policy, and logging tests - -## Implementation Plan - -### Phase 1: Test baseline (TDD) -Write the new tests listed above against the current code. All should pass except the WebSocket test. - -### Phase 2: Extract helpers -Extract the inline policy/credential/logging logic from the current loop into named methods that can be called from both the old loop and the new handler. This is a refactor-only step — no behavioral changes. - -### Phase 3: Build the ReverseProxy handler -Implement `interceptHandler` with `Rewrite`, `ModifyResponse`, `ErrorHandler`, and the wrapping handler for policy checks. Wire it into `handleConnectWithInterception` replacing the manual loop. - -### Phase 4: WebSocket test passes -The WebSocket upgrade test should now pass with zero additional code. - -### Phase 5: Verify and clean up -Run full test suite, remove dead code from the old loop, verify OTel instrumentation. - -## Out of Scope - -- Changing the non-interception tunnel path (`handleConnectTunnel`) -- Changing the HTTP relay path (`handleHTTP`) -- Changing the MCP relay handler -- Config schema changes -- New config options for WebSocket-specific behavior diff --git a/docs/plans/2026-04-22-websocket-reverseproxy.md b/docs/plans/2026-04-22-websocket-reverseproxy.md deleted file mode 100644 index 51e66bb..0000000 --- a/docs/plans/2026-04-22-websocket-reverseproxy.md +++ /dev/null @@ -1,627 +0,0 @@ -# WebSocket Support via ReverseProxy Refactor — Implementation Plan - -> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. Steps use checkbox (`- [ ]`) syntax for tracking. - -**Goal:** Replace the manual HTTP request loop in `handleConnectWithInterception` with `httputil.ReverseProxy`, enabling WebSocket upgrade support while preserving all existing behaviors. - -**Architecture:** After CONNECT hijack and TLS handshake (unchanged), create a single-connection `http.Server` with an `httputil.ReverseProxy` handler wrapped in a policy-checking middleware. The wrapping handler performs network policy, Keep policy, and credential resolution before delegating to `ReverseProxy`. Credential injection happens in `Rewrite`, response processing in `ModifyResponse`, transport errors in `ErrorHandler`. WebSocket upgrades work automatically via `ReverseProxy.handleUpgradeResponse`. - -**Tech Stack:** Go stdlib `net/http/httputil.ReverseProxy`, `net/http.Server`, `crypto/tls` - ---- - -## File Structure - -| File | Action | Responsibility | -|---|---|---| -| `proxy/proxy.go` | Modify | Replace `handleConnectWithInterception` loop (lines 1812-2177) with `http.Server` + `ReverseProxy` | -| `proxy/intercept_test.go` | Modify | Add WebSocket upgrade test | -| `proxy/proxy_test.go` | Verify | Existing tests must keep passing | - -The refactor is contained to one function in one file. No new files needed — the handler, rewrite, and modify-response logic are methods on `*Proxy` defined inline or as closures within the existing file. - ---- - -## Task 1: Add WebSocket upgrade test (will fail against current code) - -This test establishes the target behavior. It will fail now and pass after the refactor. - -**Files:** -- Modify: `proxy/intercept_test.go` - -- [ ] **Step 1: Write the WebSocket upgrade test** - -Add to `proxy/intercept_test.go`: - -```go -func TestIntercept_WebSocketUpgrade(t *testing.T) { - // Backend that accepts WebSocket upgrades and echoes messages. - backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Upgrade") != "websocket" { - http.Error(w, "expected websocket upgrade", 400) - return - } - // Minimal WebSocket handshake. - w.Header().Set("Upgrade", "websocket") - w.Header().Set("Connection", "Upgrade") - w.WriteHeader(http.StatusSwitchingProtocols) - - // Hijack and echo bytes back. - hijacker, ok := w.(http.Hijacker) - if !ok { - return - } - conn, brw, err := hijacker.Hijack() - if err != nil { - return - } - defer conn.Close() - brw.Flush() - - // Simple echo: read up to 1024 bytes, write them back. - buf := make([]byte, 1024) - n, err := conn.Read(buf) - if err != nil { - return - } - conn.Write(buf[:n]) - })) - defer backend.Close() - - ca, err := generateCA() - if err != nil { - t.Fatal(err) - } - - upstreamCAs := x509.NewCertPool() - upstreamCAs.AddCert(backend.Certificate()) - - p := NewProxy() - p.SetCA(ca) - p.SetUpstreamCAs(upstreamCAs) - - // Set credential to verify injection on the upgrade request. - backendHost := mustParseURL(backend.URL).Hostname() - p.SetCredential(backendHost, "Bearer ws-token") - - proxyServer := httptest.NewServer(p) - defer proxyServer.Close() - - // Dial through the proxy using CONNECT. - proxyURL := mustParseURL(proxyServer.URL) - proxyConn, err := net.Dial("tcp", proxyURL.Host) - if err != nil { - t.Fatalf("dial proxy: %v", err) - } - defer proxyConn.Close() - - // Send CONNECT. - backendAddr := mustParseURL(backend.URL).Host - fmt.Fprintf(proxyConn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", backendAddr, backendAddr) - br := bufio.NewReader(proxyConn) - connectResp, err := http.ReadResponse(br, nil) - if err != nil { - t.Fatalf("read CONNECT response: %v", err) - } - if connectResp.StatusCode != 200 { - t.Fatalf("CONNECT status = %d, want 200", connectResp.StatusCode) - } - - // TLS handshake with the proxy's interception cert. - clientCAs := x509.NewCertPool() - clientCAs.AppendCertsFromPEM(ca.certPEM) - tlsConn := tls.Client(proxyConn, &tls.Config{ - RootCAs: clientCAs, - ServerName: backendHost, - }) - if err := tlsConn.Handshake(); err != nil { - t.Fatalf("TLS handshake: %v", err) - } - defer tlsConn.Close() - - // Send WebSocket upgrade request. - upgradeReq := "GET /ws HTTP/1.1\r\n" + - "Host: " + backendAddr + "\r\n" + - "Upgrade: websocket\r\n" + - "Connection: Upgrade\r\n" + - "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + - "Sec-WebSocket-Version: 13\r\n" + - "\r\n" - if _, err := tlsConn.Write([]byte(upgradeReq)); err != nil { - t.Fatalf("write upgrade request: %v", err) - } - - // Read the 101 response. - tlsBr := bufio.NewReader(tlsConn) - upgradeResp, err := http.ReadResponse(tlsBr, nil) - if err != nil { - t.Fatalf("read upgrade response: %v", err) - } - if upgradeResp.StatusCode != http.StatusSwitchingProtocols { - t.Fatalf("upgrade status = %d, want 101", upgradeResp.StatusCode) - } - - // Send a raw message through the WebSocket tunnel. - testMsg := []byte("hello websocket") - if _, err := tlsConn.Write(testMsg); err != nil { - t.Fatalf("write message: %v", err) - } - - // Read echoed message back. - echoBuf := make([]byte, len(testMsg)) - if _, err := io.ReadFull(tlsBr, echoBuf); err != nil { - t.Fatalf("read echo: %v", err) - } - if string(echoBuf) != string(testMsg) { - t.Errorf("echo = %q, want %q", echoBuf, testMsg) - } -} -``` - -- [ ] **Step 2: Run test to verify it fails against current code** - -Run: `go test -run TestIntercept_WebSocketUpgrade -v -count=1 ./proxy/` -Expected: FAIL — the current code will either hang or error with "malformed HTTP request" after the 101. - -- [ ] **Step 3: Commit the failing test** - -```bash -git add proxy/intercept_test.go -git commit -m "test(proxy): add WebSocket upgrade test (expected to fail)" -``` - ---- - -## Task 2: Refactor handleConnectWithInterception to use ReverseProxy - -This is the core change. Replace lines 1812-2177 (the manual `for` loop) with an `http.Server` + `httputil.ReverseProxy`. - -**Files:** -- Modify: `proxy/proxy.go` (lines 1749-2178) - -- [ ] **Step 1: Replace the request loop with http.Server + ReverseProxy** - -Replace the code from line 1812 (`clientReader := bufio.NewReader(tlsClientConn)`) through line 2177 (closing `}` of the for loop) with: - -```go - // Create a reverse proxy that handles request forwarding, including - // WebSocket upgrades via the stdlib's built-in protocol switch support. - reverseProxy := &httputil.ReverseProxy{ - Rewrite: func(pr *httputil.ProxyRequest) { - // Preserve the original Proxy-Authorization from In before - // ReverseProxy strips hop-by-hop headers. - // token-exchange subject_from: proxy-auth needs this. - proxyAuth := pr.In.Header.Get("Proxy-Authorization") - - pr.Out.URL.Scheme = "https" - connectHost := r.Host - if rc := getRunContext(r); rc != nil && rc.HostGatewayIP != "" && isHostGateway(rc, host) { - connectHost = rewriteHostPort(r.Host, rc.HostGatewayIP) - } - pr.Out.URL.Host = connectHost - pr.Out.Host = pr.In.Host - - // Restore Proxy-Authorization so credential resolver can read it. - if proxyAuth != "" { - pr.Out.Header.Set("Proxy-Authorization", proxyAuth) - } - - // MCP credential injection. - p.injectMCPCredentialsWithContext(r, pr.Out) - - // Credential injection. - creds, credErr := p.getCredentialsForRequest(r, pr.Out, host) - if credErr != nil { - // Store error in context for ErrorHandler to pick up. - *pr.Out = *pr.Out.WithContext(context.WithValue(pr.Out.Context(), interceptCredErrKey{}, credErr)) - return - } - credResult := injectCredentials(pr.Out, creds, host, pr.Out.Method, pr.Out.URL.Path) - - // Store credential result in context for ModifyResponse/logging. - ctx := pr.Out.Context() - ctx = context.WithValue(ctx, interceptCredResultKey{}, credResult) - *pr.Out = *pr.Out.WithContext(ctx) - - // Extra headers. - mergeExtraHeaders(pr.Out, r.Host, p.getExtraHeadersForRequest(r, r.Host)) - - // Strip proxy headers. - pr.Out.Header.Del("Proxy-Connection") - pr.Out.Header.Del("Proxy-Authorization") - - // Remove configured headers (but not injected credential headers). - for _, headerName := range p.getRemoveHeadersForRequest(r, host) { - if credResult.InjectedHeaders[strings.ToLower(headerName)] { - continue - } - pr.Out.Header.Del(headerName) - } - - // Token substitution. - if sub := p.getTokenSubstitutionForRequest(r, host); sub != nil { - p.applyTokenSubstitution(pr.Out, sub) - } - - // Request ID. - if pr.Out.Header.Get("X-Request-Id") == "" { - pr.Out.Header.Set("X-Request-Id", newRequestID()) - } - }, - Transport: transport, - ModifyResponse: func(resp *http.Response) error { - req := resp.Request - - // LLM gateway policy evaluation (Anthropic API only). - if resp.StatusCode == http.StatusOK && host == "api.anthropic.com" { - if rc := getRunContext(r); rc != nil && rc.KeepEngines != nil { - if eng, ok := rc.KeepEngines["llm-gateway"]; ok { - p.evaluateAndReplaceLLMResponse(r, req, resp, eng) - } - } - } - - // Response transformers. - if transformers := p.getResponseTransformersForRequest(r, host); len(transformers) > 0 { - for _, transformer := range transformers { - if newRespInterface, transformed := transformer(req, resp); transformed { - if newResp, ok := newRespInterface.(*http.Response); ok { - *resp = *newResp - } - break - } - } - } - - // Canonical log line. - credResult, _ := req.Context().Value(interceptCredResultKey{}).(credentialInjectionResult) - var respBody []byte - respBody, resp.Body = captureBody(resp.Body, resp.Header.Get("Content-Type")) - var reqBody []byte - // Request body was already consumed by the transport; capture from context if available. - _ = reqBody - _ = respBody - - p.logRequest(r, RequestLogData{ - RequestID: req.Header.Get("X-Request-Id"), - Method: req.Method, - URL: req.URL.String(), - Host: host, - Path: req.URL.Path, - RequestType: "connect", - StatusCode: resp.StatusCode, - Duration: time.Since(reqStartFromContext(req.Context())), - RequestHeaders: req.Header.Clone(), - ResponseHeaders: resp.Header.Clone(), - ResponseBody: respBody, - RequestSize: req.ContentLength, - ResponseSize: resp.ContentLength, - AuthInjected: len(credResult.InjectedHeaders) > 0, - InjectedHeaders: credResult.InjectedHeaders, - Grants: credResult.Grants, - }) - - return nil - }, - ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { - // Check for credential resolution error from Rewrite. - if credErr, ok := req.Context().Value(interceptCredErrKey{}).(error); ok { - http.Error(rw, "credential resolution failed\n", http.StatusBadGateway) - p.logRequest(r, RequestLogData{ - RequestID: req.Header.Get("X-Request-Id"), - Method: req.Method, - URL: req.URL.String(), - Host: host, - Path: req.URL.Path, - RequestType: "connect", - StatusCode: http.StatusBadGateway, - Err: credErr, - }) - return - } - - rw.WriteHeader(http.StatusBadGateway) - credResult, _ := req.Context().Value(interceptCredResultKey{}).(credentialInjectionResult) - p.logRequest(r, RequestLogData{ - RequestID: req.Header.Get("X-Request-Id"), - Method: req.Method, - URL: req.URL.String(), - Host: host, - Path: req.URL.Path, - RequestType: "connect", - StatusCode: http.StatusBadGateway, - Err: err, - AuthInjected: len(credResult.InjectedHeaders) > 0, - InjectedHeaders: credResult.InjectedHeaders, - Grants: credResult.Grants, - }) - }, - } - - // Wrapping handler: policy checks before ReverseProxy. - handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - // Network policy check. - if !p.checkNetworkPolicyForRequest(r, host, connectPort, req.Method, req.URL.Path) { - innerReqID := req.Header.Get("X-Request-Id") - if innerReqID == "" { - innerReqID = newRequestID() - } - p.logRequest(r, RequestLogData{ - RequestID: innerReqID, - Method: req.Method, - URL: "https://" + r.Host + req.URL.Path, - Host: host, - Path: req.URL.Path, - RequestType: "connect", - StatusCode: http.StatusProxyAuthRequired, - RequestSize: req.ContentLength, - ResponseSize: -1, - Denied: true, - DenyReason: "Request blocked by network policy: " + req.Method + " " + host + req.URL.Path, - }) - p.logPolicy(r, "network", "http.request", "", req.Method+" "+host+req.URL.Path) - w.Header().Set("X-Moat-Blocked", "request-rule") - w.Header().Set("Content-Type", "text/plain") - w.WriteHeader(http.StatusProxyAuthRequired) - fmt.Fprintf(w, "Moat: request blocked by network policy.\nHost: %s\nTo allow this request, update network.rules in moat.yaml.\n", host) - return - } - - // Keep HTTP policy check. - if rc := getRunContext(r); rc != nil && rc.KeepEngines != nil { - if eng, ok := rc.KeepEngines["http"]; ok { - call := keeplib.NewHTTPCall(req.Method, host, req.URL.Path) - call.Context.Scope = "http-" + host - result, evalErr := keeplib.SafeEvaluate(eng, call, "http") - if evalErr != nil { - innerReqID := req.Header.Get("X-Request-Id") - if innerReqID == "" { - innerReqID = newRequestID() - } - p.logRequest(r, RequestLogData{ - RequestID: innerReqID, - Method: req.Method, - URL: "https://" + r.Host + req.URL.Path, - Host: host, - Path: req.URL.Path, - RequestType: "connect", - StatusCode: http.StatusForbidden, - RequestSize: req.ContentLength, - ResponseSize: -1, - Denied: true, - DenyReason: "Keep policy evaluation error", - Err: evalErr, - }) - p.logPolicy(r, "http", "http.request", "evaluation-error", "Policy evaluation failed") - w.Header().Set("X-Moat-Blocked", "keep-policy") - w.Header().Set("Content-Type", "text/plain") - w.WriteHeader(http.StatusForbidden) - fmt.Fprintf(w, "Moat: request blocked — policy evaluation error.\nHost: %s\n", host) - return - } - if result.Decision == keeplib.Deny { - innerReqID := req.Header.Get("X-Request-Id") - if innerReqID == "" { - innerReqID = newRequestID() - } - p.logRequest(r, RequestLogData{ - RequestID: innerReqID, - Method: req.Method, - URL: "https://" + r.Host + req.URL.Path, - Host: host, - Path: req.URL.Path, - RequestType: "connect", - StatusCode: http.StatusForbidden, - RequestSize: req.ContentLength, - ResponseSize: -1, - Denied: true, - DenyReason: "Keep policy denied: " + result.Rule + " " + result.Message, - }) - p.logPolicy(r, "http", "http.request", result.Rule, result.Message) - w.Header().Set("X-Moat-Blocked", "keep-policy") - w.Header().Set("Content-Type", "text/plain") - w.WriteHeader(http.StatusForbidden) - msg := fmt.Sprintf("Moat: request blocked by Keep policy.\nHost: %s\n", host) - if result.Message != "" { - msg += result.Message + "\n" - } - fmt.Fprint(w, msg) - return - } - } - } - - // Store request start time in context for duration calculation. - ctx := context.WithValue(req.Context(), interceptReqStartKey{}, time.Now()) - reverseProxy.ServeHTTP(w, req.WithContext(ctx)) - }) - - // Serve on a single-connection listener wrapping the TLS connection. - srv := &http.Server{ - Handler: handler, - IdleTimeout: 120 * time.Second, - ErrorLog: log.New(io.Discard, "", 0), // Suppress server-level errors (we handle them in ErrorHandler). - } - srv.Serve(newSingleConnListener(tlsClientConn)) -``` - -This requires several supporting types. Add before the function: - -```go -// Context keys for passing data between ReverseProxy hooks. -type interceptCredResultKey struct{} -type interceptCredErrKey struct{} -type interceptReqStartKey struct{} - -func reqStartFromContext(ctx context.Context) time.Time { - if t, ok := ctx.Value(interceptReqStartKey{}).(time.Time); ok { - return t - } - return time.Now() -} - -// singleConnListener wraps a single net.Conn as a net.Listener. -// Accept returns the connection once, then blocks until Close is called. -type singleConnListener struct { - conn net.Conn - once sync.Once - ch chan net.Conn -} - -func newSingleConnListener(conn net.Conn) *singleConnListener { - ch := make(chan net.Conn, 1) - ch <- conn - return &singleConnListener{conn: conn, ch: ch} -} - -func (l *singleConnListener) Accept() (net.Conn, error) { - conn, ok := <-l.ch - if !ok { - return nil, io.EOF - } - return conn, nil -} - -func (l *singleConnListener) Close() error { - l.once.Do(func() { close(l.ch) }) - return nil -} - -func (l *singleConnListener) Addr() net.Addr { - return l.conn.LocalAddr() -} -``` - -Also add `evaluateAndReplaceLLMResponse` as a method that encapsulates the LLM policy logic currently inline in the loop (lines 2024-2106). This keeps ModifyResponse readable: - -```go -// evaluateAndReplaceLLMResponse evaluates LLM gateway policy and replaces -// the response in-place if denied. Called from ModifyResponse. -func (p *Proxy) evaluateAndReplaceLLMResponse(ctxReq *http.Request, req *http.Request, resp *http.Response, eng *keeplib.Engine) { - respBodyBytes, readErr := io.ReadAll(io.LimitReader(resp.Body, maxLLMResponseSize+1)) - resp.Body.Close() - if readErr != nil { - p.logPolicy(ctxReq, "llm-gateway", "llm.read_error", "read-error", "Failed to read response body for policy evaluation") - errorBody := buildPolicyDeniedResponse("read-error", "Failed to read response body for policy evaluation.") - resp.StatusCode = http.StatusBadRequest - resp.Header = make(http.Header) - resp.Header.Set("Content-Type", "application/json") - resp.Header.Set("X-Moat-Blocked", "llm-policy") - resp.ContentLength = int64(len(errorBody)) - resp.Body = io.NopCloser(bytes.NewReader(errorBody)) - return - } - if int64(len(respBodyBytes)) > maxLLMResponseSize { - p.logPolicy(ctxReq, "llm-gateway", "llm.response_too_large", "size-limit", "Response too large for policy evaluation") - errorBody := buildPolicyDeniedResponse("size-limit", "Response too large for policy evaluation.") - resp.StatusCode = http.StatusBadRequest - resp.Header = make(http.Header) - resp.Header.Set("Content-Type", "application/json") - resp.Header.Set("X-Moat-Blocked", "llm-policy") - resp.ContentLength = int64(len(errorBody)) - resp.Body = io.NopCloser(bytes.NewReader(errorBody)) - return - } - result := evaluateLLMResponse(eng, respBodyBytes, resp) - if result.Denied { - p.logPolicy(ctxReq, "llm-gateway", "llm.tool_use", result.Rule, result.Message) - errorBody := buildPolicyDeniedResponse(result.Rule, result.Message) - resp.StatusCode = http.StatusBadRequest - resp.Header = make(http.Header) - resp.Header.Set("Content-Type", "application/json") - resp.Header.Set("X-Moat-Blocked", "llm-policy") - resp.ContentLength = int64(len(errorBody)) - resp.Body = io.NopCloser(bytes.NewReader(errorBody)) - } else if result.Events != nil { - var buf bytes.Buffer - for _, ev := range result.Events { - if ev.ID != "" { - fmt.Fprintf(&buf, "id: %s\n", ev.ID) - } - if ev.Type != "" { - fmt.Fprintf(&buf, "event: %s\n", ev.Type) - } - lines := strings.Split(ev.Data, "\n") - for _, line := range lines { - fmt.Fprintf(&buf, "data: %s\n", line) - } - buf.WriteByte('\n') - } - resp.Header.Del("Content-Encoding") - resp.Body = io.NopCloser(&buf) - resp.ContentLength = int64(buf.Len()) - } else { - resp.Body = io.NopCloser(bytes.NewReader(respBodyBytes)) - resp.ContentLength = int64(len(respBodyBytes)) - } -} -``` - -- [ ] **Step 2: Add required imports** - -Add to the import block in `proxy/proxy.go`: -- `"log"` (for `log.New` in http.Server ErrorLog) -- `"net/http/httputil"` (for ReverseProxy) - -- [ ] **Step 3: Verify compilation** - -Run: `go build ./proxy/` -Expected: compiles cleanly - -- [ ] **Step 4: Run the full test suite** - -Run: `go test -count=1 ./proxy/` -Expected: All existing `TestIntercept_*` and `TestProxy_*` tests pass - -- [ ] **Step 5: Run the WebSocket test** - -Run: `go test -run TestIntercept_WebSocketUpgrade -v -count=1 ./proxy/` -Expected: PASS - -- [ ] **Step 6: Run go vet** - -Run: `go vet ./...` -Expected: clean - -- [ ] **Step 7: Commit** - -```bash -git add proxy/proxy.go -git commit -m "feat(proxy): replace interception loop with ReverseProxy for WebSocket support - -Replace the manual for { http.ReadRequest → transport.RoundTrip → resp.Write } -loop in handleConnectWithInterception with http.Server + httputil.ReverseProxy. - -ReverseProxy natively handles WebSocket upgrades (101 Switching Protocols) -by hijacking both sides and doing bidirectional io.Copy. - -All existing behaviors preserved: credential injection, network policy, -Keep policy, LLM gateway policy, response transformers, canonical log -lines, X-Request-Id, extra/remove headers, token substitution, host -gateway rewrite." -``` - ---- - -## Task 3: Full verification - -- [ ] **Step 1: Run the complete test suite with race detector** - -Run: `go test -race -count=1 ./...` -Expected: All tests pass, no data races - -- [ ] **Step 2: Run go vet** - -Run: `go vet ./...` -Expected: clean - -- [ ] **Step 3: Clean up any dead code** - -Remove the `bufio` import from proxy.go if no longer used (the manual `bufio.NewReader` loop is gone). Check for any other dead code. - -- [ ] **Step 4: Final commit if cleanup needed** - -```bash -git add -A -git commit -m "refactor(proxy): remove dead code from interception loop replacement" -``` From d722d4b9fb2890dbcbc0a7452a8dc66c086920b0 Mon Sep 17 00:00:00 2001 From: Andy Bonventre <365204+andybons@users.noreply.github.com> Date: Wed, 22 Apr 2026 19:50:59 +0000 Subject: [PATCH 5/8] fix(proxy): address code review feedback on ReverseProxy refactor 1. Fix goroutine/FD leak: singleConnListener now uses ConnState to detect StateClosed/StateHijacked and close the listener, allowing Serve to exit. For WebSocket (hijacked), defer skips closing the underlying connections since ReverseProxy owns them. 2. Fix credential error silently dropped: moved getCredentialsForRequest and injectMCPCredentialsWithContext to the wrapping handler (before ReverseProxy.ServeHTTP) so errors get an early 502 return. Resolved credentials are passed to Rewrite via context. 3. Fix token-substituted URLs in logs: capture pre-substitution URL in Rewrite via interceptLogURLKey context key; ModifyResponse and ErrorHandler use it for canonical log lines instead of req.URL which contains real tokens after substitution. 4. Fix zero Duration in policy denial logs: moved time.Now() to the top of the wrapping handler, before any policy checks. --- proxy/proxy.go | 151 +++++++++++++++++++++++++++++++------------------ 1 file changed, 95 insertions(+), 56 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index 87e719d..d9f81eb 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -46,6 +46,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" keeplib "github.com/majorcontext/keep" @@ -1749,8 +1750,9 @@ func (p *Proxy) handleConnectTunnel(w http.ResponseWriter, r *http.Request) { // Context keys for passing data between ReverseProxy hooks in the interception path. type interceptCredResultKey struct{} -type interceptCredErrKey struct{} +type interceptCredsKey struct{} type interceptReqStartKey struct{} +type interceptLogURLKey struct{} func reqStartFromContext(ctx context.Context) time.Time { if t, ok := ctx.Value(interceptReqStartKey{}).(time.Time); ok { @@ -1761,28 +1763,34 @@ func reqStartFromContext(ctx context.Context) time.Time { // singleConnListener wraps a single net.Conn as a net.Listener. // Accept returns the connection once, then blocks until Close is called. +// This keeps http.Server.Serve alive for the lifetime of the connection. type singleConnListener struct { - conn net.Conn - once sync.Once - ch chan net.Conn + conn net.Conn + connCh chan net.Conn + closeCh chan struct{} } func newSingleConnListener(conn net.Conn) *singleConnListener { ch := make(chan net.Conn, 1) ch <- conn - return &singleConnListener{conn: conn, ch: ch} + return &singleConnListener{conn: conn, connCh: ch, closeCh: make(chan struct{})} } func (l *singleConnListener) Accept() (net.Conn, error) { - conn, ok := <-l.ch - if !ok { - return nil, io.EOF + select { + case conn := <-l.connCh: + return conn, nil + case <-l.closeCh: + return nil, net.ErrClosed } - return conn, nil } func (l *singleConnListener) Close() error { - l.once.Do(func() { close(l.ch) }) + select { + case <-l.closeCh: + default: + close(l.closeCh) + } return nil } @@ -1863,7 +1871,16 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req http.Error(w, err.Error(), http.StatusInternalServerError) return } - defer clientConn.Close() + + // Track whether the inner http.Server's connection was hijacked + // (e.g., for WebSocket upgrade). If hijacked, ReverseProxy owns the + // TLS conn and will close it; we must not close clientConn ourselves. + var hijacked atomic.Bool + defer func() { + if !hijacked.Load() { + clientConn.Close() + } + }() _, _ = clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) @@ -1884,7 +1901,11 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req "subsystem", "proxy", "host", host, "error", err) return } - defer tlsClientConn.Close() + defer func() { + if !hijacked.Load() { + tlsClientConn.Close() + } + }() transport := &http.Transport{ Proxy: nil, @@ -1918,11 +1939,6 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req // WebSocket upgrades via the stdlib's built-in protocol switch support. reverseProxy := &httputil.ReverseProxy{ Rewrite: func(pr *httputil.ProxyRequest) { - // Preserve the original Proxy-Authorization from In before - // ReverseProxy strips hop-by-hop headers. - // token-exchange subject_from: proxy-auth needs this. - proxyAuth := pr.In.Header.Get("Proxy-Authorization") - pr.Out.URL.Scheme = "https" connectHost := r.Host if rc := getRunContext(r); rc != nil && rc.HostGatewayIP != "" && isHostGateway(rc, host) { @@ -1932,21 +1948,8 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req pr.Out.Host = pr.In.Host pr.Out.RequestURI = "" - // Restore Proxy-Authorization so credential resolver can read it. - if proxyAuth != "" { - pr.Out.Header.Set("Proxy-Authorization", proxyAuth) - } - - // MCP credential injection. - p.injectMCPCredentialsWithContext(r, pr.Out) - - // Credential injection. - creds, credErr := p.getCredentialsForRequest(r, pr.Out, host) - if credErr != nil { - // Store error in context for ErrorHandler to pick up. - *pr.Out = *pr.Out.WithContext(context.WithValue(pr.Out.Context(), interceptCredErrKey{}, credErr)) - return - } + // Credentials were resolved in the wrapping handler and passed via context. + creds, _ := pr.Out.Context().Value(interceptCredsKey{}).([]credentialHeader) credResult := injectCredentials(pr.Out, creds, host, pr.Out.Method, pr.Out.URL.Path) // Store credential result in context for ModifyResponse/logging. @@ -1968,6 +1971,11 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req pr.Out.Header.Del(headerName) } + // Capture URL before token substitution so logs don't contain real tokens. + logURL := pr.Out.URL.String() + ctx = context.WithValue(pr.Out.Context(), interceptLogURLKey{}, logURL) + *pr.Out = *pr.Out.WithContext(ctx) + // Token substitution. if sub := p.getTokenSubstitutionForRequest(r, host); sub != nil { p.applyTokenSubstitution(pr.Out, sub) @@ -2005,13 +2013,18 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req // Canonical log line. credResult, _ := req.Context().Value(interceptCredResultKey{}).(credentialInjectionResult) + // Use pre-substitution URL so logs don't contain real tokens. + logURL, _ := req.Context().Value(interceptLogURLKey{}).(string) + if logURL == "" { + logURL = req.URL.String() + } var respBody []byte respBody, resp.Body = captureBody(resp.Body, resp.Header.Get("Content-Type")) p.logRequest(r, RequestLogData{ RequestID: req.Header.Get("X-Request-Id"), Method: req.Method, - URL: req.URL.String(), + URL: logURL, Host: host, Path: req.URL.Path, RequestType: "connect", @@ -2030,34 +2043,21 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req return nil }, ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { - // Check for credential resolution error from Rewrite. - if credErr, ok := req.Context().Value(interceptCredErrKey{}).(error); ok { - rw.Header().Set("Content-Type", "text/plain") - rw.WriteHeader(http.StatusBadGateway) - fmt.Fprint(rw, "credential resolution failed\n") - p.logRequest(r, RequestLogData{ - RequestID: req.Header.Get("X-Request-Id"), - Method: req.Method, - URL: req.URL.String(), - Host: host, - Path: req.URL.Path, - RequestType: "connect", - StatusCode: http.StatusBadGateway, - Err: credErr, - }) - return - } - rw.WriteHeader(http.StatusBadGateway) credResult, _ := req.Context().Value(interceptCredResultKey{}).(credentialInjectionResult) + logURL, _ := req.Context().Value(interceptLogURLKey{}).(string) + if logURL == "" { + logURL = req.URL.String() + } p.logRequest(r, RequestLogData{ RequestID: req.Header.Get("X-Request-Id"), Method: req.Method, - URL: req.URL.String(), + URL: logURL, Host: host, Path: req.URL.Path, RequestType: "connect", StatusCode: http.StatusBadGateway, + Duration: time.Since(reqStartFromContext(req.Context())), Err: err, AuthInjected: len(credResult.InjectedHeaders) > 0, InjectedHeaders: credResult.InjectedHeaders, @@ -2066,8 +2066,10 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req }, } - // Wrapping handler: policy checks before ReverseProxy. + // Wrapping handler: policy checks and credential resolution before ReverseProxy. handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + reqStart := time.Now() + innerReqID := req.Header.Get("X-Request-Id") if innerReqID == "" { innerReqID = newRequestID() @@ -2083,6 +2085,7 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req Path: req.URL.Path, RequestType: "connect", StatusCode: http.StatusProxyAuthRequired, + Duration: time.Since(reqStart), RequestSize: req.ContentLength, ResponseSize: -1, Denied: true, @@ -2111,6 +2114,7 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req Path: req.URL.Path, RequestType: "connect", StatusCode: http.StatusForbidden, + Duration: time.Since(reqStart), RequestSize: req.ContentLength, ResponseSize: -1, Denied: true, @@ -2133,6 +2137,7 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req Path: req.URL.Path, RequestType: "connect", StatusCode: http.StatusForbidden, + Duration: time.Since(reqStart), RequestSize: req.ContentLength, ResponseSize: -1, Denied: true, @@ -2152,16 +2157,50 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req } } - // Store request start time in context for duration calculation. - ctx := context.WithValue(req.Context(), interceptReqStartKey{}, time.Now()) + // MCP credential injection. + p.injectMCPCredentialsWithContext(r, req) + + // Resolve credentials before forwarding so errors are caught early. + creds, credErr := p.getCredentialsForRequest(r, req, host) + if credErr != nil { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusBadGateway) + fmt.Fprint(w, "credential resolution failed\n") + p.logRequest(r, RequestLogData{ + RequestID: innerReqID, + Method: req.Method, + URL: "https://" + r.Host + req.URL.Path, + Host: host, + Path: req.URL.Path, + RequestType: "connect", + StatusCode: http.StatusBadGateway, + Duration: time.Since(reqStart), + Err: credErr, + }) + return + } + + // Pass resolved credentials and start time to Rewrite via context. + ctx := req.Context() + ctx = context.WithValue(ctx, interceptReqStartKey{}, reqStart) + ctx = context.WithValue(ctx, interceptCredsKey{}, creds) reverseProxy.ServeHTTP(w, req.WithContext(ctx)) }) // Serve on a single-connection listener wrapping the TLS connection. + ln := newSingleConnListener(tlsClientConn) srv := &http.Server{ Handler: handler, IdleTimeout: 120 * time.Second, ErrorLog: log.New(io.Discard, "", 0), // Suppress server-level errors; handled in ErrorHandler. + ConnState: func(conn net.Conn, state http.ConnState) { + if state == http.StateHijacked { + hijacked.Store(true) + } + if state == http.StateClosed || state == http.StateHijacked { + ln.Close() + } + }, } - _ = srv.Serve(newSingleConnListener(tlsClientConn)) + _ = srv.Serve(ln) } From c990bc6faad1ea002d0d13ac0137d8bc805eae97 Mon Sep 17 00:00:00 2001 From: Andy Bonventre <365204+andybons@users.noreply.github.com> Date: Wed, 22 Apr 2026 20:00:10 +0000 Subject: [PATCH 6/8] fix(proxy): address third round of review feedback 1. Security: snapshot pre-injection headers in Rewrite so credential values don't appear in canonical log RequestHeaders. Uses interceptPreInjHeadersKey context key. 2. Bug: evaluateAndReplaceLLMResponse now returns (denied, reason) so ModifyResponse can set Denied/DenyReason in the canonical log line. Restores parity with the old loop's llmDenied/llmDenyReason tracking. 3. Bug: capture RequestBody in wrapping handler via captureBody before ReverseProxy consumes it. Passed to ModifyResponse via interceptReqBodyKey context. 4. Minor: denial/error log URLs now use req.URL.RequestURI() instead of req.URL.Path to preserve query parameters. --- proxy/proxy.go | 56 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 13 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index d9f81eb..27f9a91 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -1753,6 +1753,8 @@ type interceptCredResultKey struct{} type interceptCredsKey struct{} type interceptReqStartKey struct{} type interceptLogURLKey struct{} +type interceptPreInjHeadersKey struct{} +type interceptReqBodyKey struct{} func reqStartFromContext(ctx context.Context) time.Time { if t, ok := ctx.Value(interceptReqStartKey{}).(time.Time); ok { @@ -1799,8 +1801,8 @@ func (l *singleConnListener) Addr() net.Addr { } // evaluateAndReplaceLLMResponse evaluates LLM gateway policy and replaces -// the response in-place if denied. Called from ModifyResponse. -func (p *Proxy) evaluateAndReplaceLLMResponse(ctxReq *http.Request, req *http.Request, resp *http.Response, eng *keeplib.Engine) { +// the response in-place if denied. Returns whether a denial occurred and the reason. +func (p *Proxy) evaluateAndReplaceLLMResponse(ctxReq *http.Request, req *http.Request, resp *http.Response, eng *keeplib.Engine) (denied bool, reason string) { respBodyBytes, readErr := io.ReadAll(io.LimitReader(resp.Body, maxLLMResponseSize+1)) resp.Body.Close() if readErr != nil { @@ -1812,7 +1814,7 @@ func (p *Proxy) evaluateAndReplaceLLMResponse(ctxReq *http.Request, req *http.Re resp.Header.Set("X-Moat-Blocked", "llm-policy") resp.ContentLength = int64(len(errorBody)) resp.Body = io.NopCloser(bytes.NewReader(errorBody)) - return + return true, "LLM policy read error" } if int64(len(respBodyBytes)) > maxLLMResponseSize { p.logPolicy(ctxReq, "llm-gateway", "llm.response_too_large", "size-limit", "Response too large for policy evaluation") @@ -1823,7 +1825,7 @@ func (p *Proxy) evaluateAndReplaceLLMResponse(ctxReq *http.Request, req *http.Re resp.Header.Set("X-Moat-Blocked", "llm-policy") resp.ContentLength = int64(len(errorBody)) resp.Body = io.NopCloser(bytes.NewReader(errorBody)) - return + return true, "LLM policy response too large" } result := evaluateLLMResponse(eng, respBodyBytes, resp) if result.Denied { @@ -1835,6 +1837,7 @@ func (p *Proxy) evaluateAndReplaceLLMResponse(ctxReq *http.Request, req *http.Re resp.Header.Set("X-Moat-Blocked", "llm-policy") resp.ContentLength = int64(len(errorBody)) resp.Body = io.NopCloser(bytes.NewReader(errorBody)) + return true, "LLM policy denied: " + result.Rule + " " + result.Message } else if result.Events != nil { var buf bytes.Buffer for _, ev := range result.Events { @@ -1857,6 +1860,7 @@ func (p *Proxy) evaluateAndReplaceLLMResponse(ctxReq *http.Request, req *http.Re resp.Body = io.NopCloser(bytes.NewReader(respBodyBytes)) resp.ContentLength = int64(len(respBodyBytes)) } + return false, "" } func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Request, host string) { @@ -1950,10 +1954,17 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req // Credentials were resolved in the wrapping handler and passed via context. creds, _ := pr.Out.Context().Value(interceptCredsKey{}).([]credentialHeader) + + // Snapshot headers before credential injection so logs don't + // contain raw credential values (CLAUDE.md: never log credential values). + preInjectionHeaders := pr.Out.Header.Clone() + credResult := injectCredentials(pr.Out, creds, host, pr.Out.Method, pr.Out.URL.Path) - // Store credential result in context for ModifyResponse/logging. - ctx := context.WithValue(pr.Out.Context(), interceptCredResultKey{}, credResult) + // Store credential result and pre-injection headers in context. + ctx := pr.Out.Context() + ctx = context.WithValue(ctx, interceptCredResultKey{}, credResult) + ctx = context.WithValue(ctx, interceptPreInjHeadersKey{}, preInjectionHeaders) *pr.Out = *pr.Out.WithContext(ctx) // Extra headers. @@ -1990,11 +2001,15 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req ModifyResponse: func(resp *http.Response) error { req := resp.Request + // Track LLM policy denials for the canonical log line. + var llmDenied bool + var llmDenyReason string + // LLM gateway policy evaluation (Anthropic API only). if resp.StatusCode == http.StatusOK && host == "api.anthropic.com" { if rc := getRunContext(r); rc != nil && rc.KeepEngines != nil { if eng, ok := rc.KeepEngines["llm-gateway"]; ok { - p.evaluateAndReplaceLLMResponse(r, req, resp, eng) + llmDenied, llmDenyReason = p.evaluateAndReplaceLLMResponse(r, req, resp, eng) } } } @@ -2021,6 +2036,13 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req var respBody []byte respBody, resp.Body = captureBody(resp.Body, resp.Header.Get("Content-Type")) + // Use pre-injection headers so credential values don't appear in logs. + preHeaders, _ := req.Context().Value(interceptPreInjHeadersKey{}).(http.Header) + if preHeaders == nil { + preHeaders = req.Header.Clone() + } + reqBody, _ := req.Context().Value(interceptReqBodyKey{}).([]byte) + p.logRequest(r, RequestLogData{ RequestID: req.Header.Get("X-Request-Id"), Method: req.Method, @@ -2030,14 +2052,17 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req RequestType: "connect", StatusCode: resp.StatusCode, Duration: time.Since(reqStartFromContext(req.Context())), - RequestHeaders: req.Header.Clone(), + RequestHeaders: preHeaders, ResponseHeaders: resp.Header.Clone(), + RequestBody: reqBody, ResponseBody: respBody, RequestSize: req.ContentLength, ResponseSize: resp.ContentLength, AuthInjected: len(credResult.InjectedHeaders) > 0, InjectedHeaders: credResult.InjectedHeaders, Grants: credResult.Grants, + Denied: llmDenied, + DenyReason: llmDenyReason, }) return nil @@ -2080,7 +2105,7 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req p.logRequest(r, RequestLogData{ RequestID: innerReqID, Method: req.Method, - URL: "https://" + r.Host + req.URL.Path, + URL: "https://" + r.Host + req.URL.RequestURI(), Host: host, Path: req.URL.Path, RequestType: "connect", @@ -2109,7 +2134,7 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req p.logRequest(r, RequestLogData{ RequestID: innerReqID, Method: req.Method, - URL: "https://" + r.Host + req.URL.Path, + URL: "https://" + r.Host + req.URL.RequestURI(), Host: host, Path: req.URL.Path, RequestType: "connect", @@ -2132,7 +2157,7 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req p.logRequest(r, RequestLogData{ RequestID: innerReqID, Method: req.Method, - URL: "https://" + r.Host + req.URL.Path, + URL: "https://" + r.Host + req.URL.RequestURI(), Host: host, Path: req.URL.Path, RequestType: "connect", @@ -2169,7 +2194,7 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req p.logRequest(r, RequestLogData{ RequestID: innerReqID, Method: req.Method, - URL: "https://" + r.Host + req.URL.Path, + URL: "https://" + r.Host + req.URL.RequestURI(), Host: host, Path: req.URL.Path, RequestType: "connect", @@ -2180,10 +2205,15 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req return } - // Pass resolved credentials and start time to Rewrite via context. + // Capture request body for logging before ReverseProxy consumes it. + var reqBody []byte + reqBody, req.Body = captureBody(req.Body, req.Header.Get("Content-Type")) + + // Pass resolved credentials, start time, and captured body to Rewrite via context. ctx := req.Context() ctx = context.WithValue(ctx, interceptReqStartKey{}, reqStart) ctx = context.WithValue(ctx, interceptCredsKey{}, creds) + ctx = context.WithValue(ctx, interceptReqBodyKey{}, reqBody) reverseProxy.ServeHTTP(w, req.WithContext(ctx)) }) From 12737cfe7d301debd9c1d6544361e24374b3bfec Mon Sep 17 00:00:00 2001 From: Andy Bonventre <365204+andybons@users.noreply.github.com> Date: Wed, 22 Apr 2026 20:08:54 +0000 Subject: [PATCH 7/8] fix(proxy): add credential leak regression test, fix cred error log fields - Add assertion in TestIntercept_CredentialInjectionCanonicalLog that logged RequestHeaders does NOT contain the injected Authorization value - Add RequestSize/ResponseSize to credential error log path for consistency with other early-return paths --- proxy/intercept_test.go | 4 ++++ proxy/proxy.go | 20 +++++++++++--------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/proxy/intercept_test.go b/proxy/intercept_test.go index 2d0cf61..5697eae 100644 --- a/proxy/intercept_test.go +++ b/proxy/intercept_test.go @@ -129,6 +129,10 @@ func TestIntercept_CredentialInjectionCanonicalLog(t *testing.T) { if logged.RequestID == "" { t.Error("expected non-empty RequestID") } + // Verify credential value is NOT present in logged request headers. + if v := logged.RequestHeaders.Get("Authorization"); v != "" { + t.Errorf("logged RequestHeaders contains injected Authorization %q; credential values must not appear in logs", v) + } } func TestIntercept_MultiRequestKeepalive(t *testing.T) { diff --git a/proxy/proxy.go b/proxy/proxy.go index 27f9a91..8963226 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -2192,15 +2192,17 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req w.WriteHeader(http.StatusBadGateway) fmt.Fprint(w, "credential resolution failed\n") p.logRequest(r, RequestLogData{ - RequestID: innerReqID, - Method: req.Method, - URL: "https://" + r.Host + req.URL.RequestURI(), - Host: host, - Path: req.URL.Path, - RequestType: "connect", - StatusCode: http.StatusBadGateway, - Duration: time.Since(reqStart), - Err: credErr, + RequestID: innerReqID, + Method: req.Method, + URL: "https://" + r.Host + req.URL.RequestURI(), + Host: host, + Path: req.URL.Path, + RequestType: "connect", + StatusCode: http.StatusBadGateway, + Duration: time.Since(reqStart), + RequestSize: req.ContentLength, + ResponseSize: -1, + Err: credErr, }) return } From 9e3db0b2cf6ca63391aab773cc85eb772a44151a Mon Sep 17 00:00:00 2001 From: Andy Bonventre <365204+andybons@users.noreply.github.com> Date: Wed, 22 Apr 2026 20:17:32 +0000 Subject: [PATCH 8/8] fix(proxy): propagate request ID, add missing log fields in ErrorHandler - Set X-Request-Id on req.Header in wrapping handler before calling ReverseProxy so Rewrite preserves the same ID used in denial logs - Add RequestSize/ResponseSize to ErrorHandler log for consistency with other early-return paths --- proxy/proxy.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/proxy/proxy.go b/proxy/proxy.go index 8963226..17c768e 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -2083,6 +2083,8 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req RequestType: "connect", StatusCode: http.StatusBadGateway, Duration: time.Since(reqStartFromContext(req.Context())), + RequestSize: req.ContentLength, + ResponseSize: -1, Err: err, AuthInjected: len(credResult.InjectedHeaders) > 0, InjectedHeaders: credResult.InjectedHeaders, @@ -2211,6 +2213,9 @@ func (p *Proxy) handleConnectWithInterception(w http.ResponseWriter, r *http.Req var reqBody []byte reqBody, req.Body = captureBody(req.Body, req.Header.Get("Content-Type")) + // Propagate request ID so Rewrite preserves it (instead of generating a new one). + req.Header.Set("X-Request-Id", innerReqID) + // Pass resolved credentials, start time, and captured body to Rewrite via context. ctx := req.Context() ctx = context.WithValue(ctx, interceptReqStartKey{}, reqStart)