diff --git a/boundary.go b/boundary.go index c60aa61..f8c4dee 100644 --- a/boundary.go +++ b/boundary.go @@ -21,6 +21,8 @@ type Config struct { Logger *slog.Logger Jailer jail.Jailer ProxyPort int + PprofEnabled bool + PprofPort int } type Boundary struct { @@ -35,11 +37,13 @@ type Boundary struct { func New(ctx context.Context, config Config) (*Boundary, error) { // Create proxy server proxyServer := proxy.NewProxyServer(proxy.Config{ - HTTPPort: config.ProxyPort, - RuleEngine: config.RuleEngine, - Auditor: config.Auditor, - Logger: config.Logger, - TLSConfig: config.TLSConfig, + HTTPPort: config.ProxyPort, + RuleEngine: config.RuleEngine, + Auditor: config.Auditor, + Logger: config.Logger, + TLSConfig: config.TLSConfig, + PprofEnabled: config.PprofEnabled, + PprofPort: config.PprofPort, }) // Create cancellable context for boundary diff --git a/cli/cli.go b/cli/cli.go index bfd6894..e5c4a74 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -29,6 +29,8 @@ type Config struct { LogDir string Unprivileged bool ProxyPort int64 + PprofEnabled bool + PprofPort int64 } // NewCommand creates and returns the root serpent command @@ -94,6 +96,19 @@ func BaseCommand() *serpent.Command { Default: "8080", Value: serpent.Int64Of(&config.ProxyPort), }, + { + Flag: "pprof", + Env: "BOUNDARY_PPROF", + Description: "Enable pprof profiling server.", + Value: serpent.BoolOf(&config.PprofEnabled), + }, + { + Flag: "pprof-port", + Env: "BOUNDARY_PPROF_PORT", + Description: "Set port for pprof profiling server.", + Default: "6060", + Value: serpent.Int64Of(&config.PprofPort), + }, }, Handler: func(inv *serpent.Invocation) error { args := inv.Args @@ -203,12 +218,14 @@ func Run(ctx context.Context, config Config, args []string) error { // Create boundary instance boundaryInstance, err := boundary.New(ctx, boundary.Config{ - RuleEngine: ruleEngine, - Auditor: auditor, - TLSConfig: tlsConfig, - Logger: logger, - Jailer: jailer, - ProxyPort: int(config.ProxyPort), + RuleEngine: ruleEngine, + Auditor: auditor, + TLSConfig: tlsConfig, + Logger: logger, + Jailer: jailer, + ProxyPort: int(config.ProxyPort), + PprofEnabled: config.PprofEnabled, + PprofPort: int(config.PprofPort), }) if err != nil { return fmt.Errorf("failed to create boundary instance: %v", err) diff --git a/proxy/proxy.go b/proxy/proxy.go index 10c6e59..709f3a9 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -3,6 +3,7 @@ package proxy import ( "bufio" "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -10,10 +11,12 @@ import ( "log/slog" "net" "net/http" + _ "net/http/pprof" "net/url" "strconv" "strings" "sync/atomic" + "time" "github.com/coder/boundary/audit" "github.com/coder/boundary/rules" @@ -28,26 +31,33 @@ type Server struct { httpPort int started atomic.Bool - listener net.Listener + listener net.Listener + pprofServer *http.Server + pprofEnabled bool + pprofPort int } // Config holds configuration for the proxy server type Config struct { - HTTPPort int - RuleEngine rules.Evaluator - Auditor audit.Auditor - Logger *slog.Logger - TLSConfig *tls.Config + HTTPPort int + RuleEngine rules.Evaluator + Auditor audit.Auditor + Logger *slog.Logger + TLSConfig *tls.Config + PprofEnabled bool + PprofPort int } // NewProxyServer creates a new proxy server instance func NewProxyServer(config Config) *Server { return &Server{ - ruleEngine: config.RuleEngine, - auditor: config.Auditor, - logger: config.Logger, - tlsConfig: config.TLSConfig, - httpPort: config.HTTPPort, + ruleEngine: config.RuleEngine, + auditor: config.Auditor, + logger: config.Logger, + tlsConfig: config.TLSConfig, + httpPort: config.HTTPPort, + pprofEnabled: config.PprofEnabled, + pprofPort: config.PprofPort, } } @@ -58,6 +68,29 @@ func (p *Server) Start() error { } p.logger.Info("Starting HTTP proxy with TLS termination", "port", p.httpPort) + + // Start pprof server if enabled + if p.pprofEnabled { + p.pprofServer = &http.Server{ + Addr: fmt.Sprintf(":%d", p.pprofPort), + Handler: http.DefaultServeMux, + } + + ln, err := net.Listen("tcp", fmt.Sprintf(":%d", p.pprofPort)) + if err != nil { + p.logger.Error("failed to listen on port for pprof server", "port", p.pprofPort, "error", err) + return fmt.Errorf("failed to listen on port %v for pprof server: %v", p.pprofPort, err) + } + + go func() { + p.logger.Info("Serving pprof on existing listener", "port", p.pprofPort) + if err := p.pprofServer.Serve(ln); err != nil && errors.Is(err, http.ErrServerClosed) { + p.logger.Error("pprof server error", "error", err) + } + }() + + } + var err error p.listener, err = net.Listen("tcp", fmt.Sprintf(":%d", p.httpPort)) if err != nil { @@ -105,6 +138,15 @@ func (p *Server) Stop() error { return err } + // Close pprof server + if p.pprofServer != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := p.pprofServer.Shutdown(ctx); err != nil { + p.logger.Error("Failed to shutdown pprof server", "error", err) + } + } + return nil }