diff --git a/audit/log_auditor.go b/audit/log_auditor.go index 562fe2f..100a58b 100644 --- a/audit/log_auditor.go +++ b/audit/log_auditor.go @@ -20,10 +20,13 @@ func (a *LogAuditor) AuditRequest(req Request) { a.logger.Info("ALLOW", "method", req.Method, "url", req.URL, + "host", req.Host, "rule", req.Rule) } else { a.logger.Warn("DENY", "method", req.Method, - "url", req.URL) + "url", req.URL, + "host", req.Host, + ) } } diff --git a/audit/request.go b/audit/request.go index 54f4c4e..b0d2e06 100644 --- a/audit/request.go +++ b/audit/request.go @@ -8,6 +8,7 @@ type Auditor interface { type Request struct { Method string URL string + Host string Allowed bool Rule string // The rule that matched (if any) } diff --git a/boundary.go b/boundary.go index 4993f32..c60aa61 100644 --- a/boundary.go +++ b/boundary.go @@ -20,6 +20,7 @@ type Config struct { TLSConfig *tls.Config Logger *slog.Logger Jailer jail.Jailer + ProxyPort int } type Boundary struct { @@ -34,7 +35,7 @@ type Boundary struct { func New(ctx context.Context, config Config) (*Boundary, error) { // Create proxy server proxyServer := proxy.NewProxyServer(proxy.Config{ - HTTPPort: 8080, + HTTPPort: config.ProxyPort, RuleEngine: config.RuleEngine, Auditor: config.Auditor, Logger: config.Logger, diff --git a/cli/cli.go b/cli/cli.go index 3935595..bfd6894 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -28,6 +28,7 @@ type Config struct { LogLevel string LogDir string Unprivileged bool + ProxyPort int64 } // NewCommand creates and returns the root serpent command @@ -86,6 +87,13 @@ func BaseCommand() *serpent.Command { Description: "Run in unprivileged mode (no network isolation, uses proxy environment variables).", Value: serpent.BoolOf(&config.Unprivileged), }, + { + Flag: "proxy-port", + Env: "PROXY_PORT", + Description: "Set a port for HTTP proxy.", + Default: "8080", + Value: serpent.Int64Of(&config.ProxyPort), + }, }, Handler: func(inv *serpent.Invocation) error { args := inv.Args @@ -100,15 +108,20 @@ func isChild() bool { // Run executes the boundary command with the given configuration and arguments func Run(ctx context.Context, config Config, args []string) error { + logger, err := setupLogging(config) + if err != nil { + return fmt.Errorf("could not set up logging: %v", err) + } + if isChild() { - log.Printf("boundary CHILD process is started") + logger.Info("boundary CHILD process is started") vethNetJail := os.Getenv("VETH_JAIL_NAME") err := jail.SetupChildNetworking(vethNetJail) if err != nil { return fmt.Errorf("failed to setup child networking: %v", err) } - log.Printf("child networking is successfully configured") + logger.Info("child networking is successfully configured") // Program to run bin := args[0] @@ -130,10 +143,6 @@ func Run(ctx context.Context, config Config, args []string) error { ctx, cancel := context.WithCancel(ctx) defer cancel() - logger, err := setupLogging(config) - if err != nil { - return fmt.Errorf("could not set up logging: %v", err) - } username, uid, gid, homeDir, configDir := util.GetUserInfo() // Get command arguments @@ -180,7 +189,7 @@ func Run(ctx context.Context, config Config, args []string) error { // Create jailer with cert path from TLS setup jailer, err := createJailer(jail.Config{ Logger: logger, - HttpProxyPort: 8080, + HttpProxyPort: int(config.ProxyPort), Username: username, Uid: uid, Gid: gid, @@ -199,6 +208,7 @@ func Run(ctx context.Context, config Config, args []string) error { TLSConfig: tlsConfig, Logger: logger, Jailer: jailer, + ProxyPort: int(config.ProxyPort), }) if err != nil { return fmt.Errorf("failed to create boundary instance: %v", err) diff --git a/proxy/proxy.go b/proxy/proxy.go index ea0bde9..10c6e59 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -187,6 +187,7 @@ func (p *Server) handleHTTPConnection(conn net.Conn) { p.auditor.AuditRequest(audit.Request{ Method: req.Method, URL: req.URL.String(), + Host: req.Host, Allowed: result.Allowed, Rule: result.Rule, }) @@ -237,6 +238,7 @@ func (p *Server) handleTLSConnection(conn net.Conn) { p.auditor.AuditRequest(audit.Request{ Method: req.Method, URL: req.URL.String(), + Host: req.Host, Allowed: result.Allowed, Rule: result.Rule, }) @@ -270,6 +272,23 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) { Path: req.URL.Path, RawQuery: req.URL.RawQuery, } + + //var requestBodyBytes []byte + //{ + // var err error + // requestBodyBytes, err = io.ReadAll(req.Body) + // if err != nil { + // p.logger.Error("can't read response body", "error", err) + // return + // } + // err = req.Body.Close() + // if err != nil { + // p.logger.Error("Failed to close HTTP response body", "error", err) + // return + // } + // req.Body = io.NopCloser(bytes.NewBuffer(requestBodyBytes)) + //} + var body = req.Body if req.Method == http.MethodGet || req.Method == http.MethodHead { body = nil @@ -300,6 +319,16 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) { p.logger.Debug("🔒 HTTPS Response", "status code", resp.StatusCode, "status", resp.Status) + p.logger.Debug("Forwarded Request", + "method", newReq.Method, + "host", newReq.Host, + //"requestBodyBytes", string(requestBodyBytes), + "URL", newReq.URL, + ) + //for hKey, hVal := range newReq.Header { + // p.logger.Debug("Forwarded Request Header", hKey, hVal) + //} + // Read the body and explicitly set Content-Length header, otherwise client can hung up on the request. bodyBytes, err := io.ReadAll(resp.Body) if err != nil { @@ -315,10 +344,26 @@ func (p *Server) forwardRequest(conn net.Conn, req *http.Request, https bool) { } resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + // The downstream client (Claude) always communicates over HTTP/1.1. + // However, Go's default HTTP client may negotiate an HTTP/2 connection + // with the upstream server via ALPN during TLS handshake. + // This can cause the response's Proto field to be set to "HTTP/2.0", + // which would produce an invalid response for an HTTP/1.1 client. + // To prevent this mismatch, we explicitly normalize the response + // to HTTP/1.1 before writing it back to the client. + resp.Proto = "HTTP/1.1" + resp.ProtoMajor = 1 + resp.ProtoMinor = 1 + // Copy response back to client err = resp.Write(conn) if err != nil { - p.logger.Error("Failed to forward HTTP request", "error", err) + p.logger.Error("Failed to forward back HTTP response", + "error", err, + "host", req.Host, + "method", req.Method, + //"bodyBytes", string(bodyBytes), + ) return }