diff --git a/audit/logging_auditor.go b/audit/logging_auditor.go index 31de3d2..f2f81d8 100644 --- a/audit/logging_auditor.go +++ b/audit/logging_auditor.go @@ -15,7 +15,7 @@ func NewLoggingAuditor(logger *slog.Logger) *LoggingAuditor { } // AuditRequest logs the request using structured logging -func (a *LoggingAuditor) AuditRequest(req *Request) { +func (a *LoggingAuditor) AuditRequest(req Request) { if req.Allowed { a.logger.Info("ALLOW", "method", req.Method, diff --git a/audit/logging_auditor_test.go b/audit/logging_auditor_test.go index 3dac8ca..b72651f 100644 --- a/audit/logging_auditor_test.go +++ b/audit/logging_auditor_test.go @@ -11,13 +11,13 @@ import ( func TestLoggingAuditor(t *testing.T) { tests := []struct { name string - request *Request + request Request expectedLevel string expectedFields []string }{ { name: "allow request", - request: &Request{ + request: Request{ Method: "GET", URL: "https://github.com", Allowed: true, @@ -28,7 +28,7 @@ func TestLoggingAuditor(t *testing.T) { }, { name: "deny request", - request: &Request{ + request: Request{ Method: "POST", URL: "https://example.com", Allowed: false, @@ -38,7 +38,7 @@ func TestLoggingAuditor(t *testing.T) { }, { name: "allow with empty rule", - request: &Request{ + request: Request{ Method: "PUT", URL: "https://api.github.com/repos", Allowed: true, @@ -49,7 +49,7 @@ func TestLoggingAuditor(t *testing.T) { }, { name: "deny HTTPS request", - request: &Request{ + request: Request{ Method: "GET", URL: "https://malware.bad.com/payload", Allowed: false, @@ -59,7 +59,7 @@ func TestLoggingAuditor(t *testing.T) { }, { name: "allow with wildcard rule", - request: &Request{ + request: Request{ Method: "POST", URL: "https://api.github.com/graphql", Allowed: true, @@ -70,7 +70,7 @@ func TestLoggingAuditor(t *testing.T) { }, { name: "deny HTTP request", - request: &Request{ + request: Request{ Method: "GET", URL: "http://insecure.example.com", Allowed: false, @@ -80,7 +80,7 @@ func TestLoggingAuditor(t *testing.T) { }, { name: "allow HEAD request", - request: &Request{ + request: Request{ Method: "HEAD", URL: "https://cdn.jsdelivr.net/health", Allowed: true, @@ -91,7 +91,7 @@ func TestLoggingAuditor(t *testing.T) { }, { name: "deny OPTIONS request", - request: &Request{ + request: Request{ Method: "OPTIONS", URL: "https://restricted.api.com/cors", Allowed: false, @@ -101,7 +101,7 @@ func TestLoggingAuditor(t *testing.T) { }, { name: "allow with port number", - request: &Request{ + request: Request{ Method: "GET", URL: "https://localhost:3000/api/health", Allowed: true, @@ -112,7 +112,7 @@ func TestLoggingAuditor(t *testing.T) { }, { name: "deny DELETE request", - request: &Request{ + request: Request{ Method: "DELETE", URL: "https://api.production.com/users/admin", Allowed: false, @@ -153,13 +153,13 @@ func TestLoggingAuditor(t *testing.T) { func TestLoggingAuditor_EdgeCases(t *testing.T) { tests := []struct { name string - request *Request + request Request expectedLevel string expectedFields []string }{ { name: "empty fields", - request: &Request{ + request: Request{ Method: "", URL: "", Allowed: true, @@ -170,7 +170,7 @@ func TestLoggingAuditor_EdgeCases(t *testing.T) { }, { name: "special characters in URL", - request: &Request{ + request: Request{ Method: "POST", URL: "https://api.example.com/users?name=John%20Doe&id=123", Allowed: true, @@ -181,7 +181,7 @@ func TestLoggingAuditor_EdgeCases(t *testing.T) { }, { name: "very long URL", - request: &Request{ + request: Request{ Method: "GET", URL: "https://example.com/" + strings.Repeat("a", 1000), Allowed: false, @@ -191,7 +191,7 @@ func TestLoggingAuditor_EdgeCases(t *testing.T) { }, { name: "deny with custom URL", - request: &Request{ + request: Request{ Method: "DELETE", URL: "https://malicious.com", Allowed: false, @@ -233,13 +233,13 @@ func TestLoggingAuditor_DifferentLogLevels(t *testing.T) { tests := []struct { name string logLevel slog.Level - request *Request + request Request expectOutput bool }{ { name: "info level allows info logs", logLevel: slog.LevelInfo, - request: &Request{ + request: Request{ Method: "GET", URL: "https://github.com", Allowed: true, @@ -250,7 +250,7 @@ func TestLoggingAuditor_DifferentLogLevels(t *testing.T) { { name: "warn level blocks info logs", logLevel: slog.LevelWarn, - request: &Request{ + request: Request{ Method: "GET", URL: "https://github.com", Allowed: true, @@ -261,7 +261,7 @@ func TestLoggingAuditor_DifferentLogLevels(t *testing.T) { { name: "warn level allows warn logs", logLevel: slog.LevelWarn, - request: &Request{ + request: Request{ Method: "POST", URL: "https://example.com", Allowed: false, @@ -271,7 +271,7 @@ func TestLoggingAuditor_DifferentLogLevels(t *testing.T) { { name: "error level blocks warn logs", logLevel: slog.LevelError, - request: &Request{ + request: Request{ Method: "POST", URL: "https://example.com", Allowed: false, @@ -312,7 +312,7 @@ func TestLoggingAuditor_NilLogger(t *testing.T) { }() auditor := &LoggingAuditor{logger: nil} - req := &Request{ + req := Request{ Method: "GET", URL: "https://example.com", Allowed: true, @@ -331,7 +331,7 @@ func TestLoggingAuditor_JSONHandler(t *testing.T) { })) auditor := NewLoggingAuditor(logger) - req := &Request{ + req := Request{ Method: "GET", URL: "https://github.com", Allowed: true, @@ -364,7 +364,7 @@ func TestLoggingAuditor_DiscardHandler(t *testing.T) { logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) auditor := NewLoggingAuditor(logger) - req := &Request{ + req := Request{ Method: "GET", URL: "https://example.com", Allowed: true, diff --git a/audit/request.go b/audit/request.go index b0e7bae..6183c0d 100644 --- a/audit/request.go +++ b/audit/request.go @@ -1,8 +1,6 @@ package audit -import ( - "net/http" -) +import "net/http" // Request represents information about an HTTP request for auditing type Request struct { diff --git a/cli/cli.go b/cli/cli.go index 9ca1f4e..37afae7 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -11,8 +11,6 @@ import ( "github.com/coder/jail" "github.com/coder/jail/audit" - "github.com/coder/jail/namespace" - "github.com/coder/jail/proxy" "github.com/coder/jail/rules" "github.com/coder/jail/tls" "github.com/coder/serpent" @@ -115,20 +113,7 @@ func Run(config Config, args []string) error { ruleEngine := rules.NewRuleEngine(allowRules, logger) // Create auditor - auditor := audit.NewLoggingAuditor(logger) - - // Create network namespace configuration - nsConfig := namespace.Config{ - HTTPPort: 8040, - HTTPSPort: 8043, - } - - // Create commander - commander, err := namespace.New(nsConfig, logger) - if err != nil { - logger.Error("Failed to create network namespace", "error", err) - return fmt.Errorf("failed to create network namespace: %v", err) - } + // auditor := audit.NewLoggingAuditor(logger) // Create certificate manager certManager, err := tls.NewCertificateManager(logger) @@ -137,39 +122,17 @@ func Run(config Config, args []string) error { return fmt.Errorf("failed to create certificate manager: %v", err) } - // Setup TLS config and write CA certificate to file - var caCertPath, configDir string - tlsConfig, caCertPath, configDir, err := certManager.SetupTLSAndWriteCACert() - if err != nil { - logger.Error("Failed to setup TLS and CA certificate", "error", err) - return fmt.Errorf("failed to setup TLS and CA certificate: %v", err) - } - - // Set standard CA certificate environment variables for common tools - // This makes tools like curl, git, etc. trust our dynamically generated CA - commander.SetEnv("SSL_CERT_FILE", caCertPath) // OpenSSL/LibreSSL-based tools - commander.SetEnv("SSL_CERT_DIR", configDir) // OpenSSL certificate directory - commander.SetEnv("CURL_CA_BUNDLE", caCertPath) // curl - commander.SetEnv("GIT_SSL_CAINFO", caCertPath) // Git - commander.SetEnv("REQUESTS_CA_BUNDLE", caCertPath) // Python requests - commander.SetEnv("NODE_EXTRA_CA_CERTS", caCertPath) // Node.js - - // Create proxy server - proxyServer := proxy.NewProxyServer(proxy.Config{ - HTTPPort: 8040, - HTTPSPort: 8043, - RuleEngine: ruleEngine, - Auditor: auditor, - Logger: logger, - TLSConfig: tlsConfig, - }) - // Create jail instance - jailInstance := jail.New(jail.Config{ - Commander: commander, - ProxyServer: proxyServer, + jailInstance, err := jail.New(context.Background(), jail.Config{ + RuleEngine: ruleEngine, + Auditor: audit.NewLoggingAuditor(logger), Logger: logger, + CertManager: certManager, }) + if err != nil { + logger.Error("Failed to create jail instance", "error", err) + return fmt.Errorf("failed to create jail instance: %v", err) + } // Setup signal handling BEFORE any setup sigChan := make(chan os.Signal, 1) @@ -198,7 +161,7 @@ func Run(config Config, args []string) error { }() // Open jail (starts network namespace and proxy server) - err = jailInstance.Open() + err = jailInstance.Start() if err != nil { logger.Error("Failed to open jail", "error", err) return fmt.Errorf("failed to open jail: %v", err) diff --git a/jail.go b/jail.go index f0e5a84..6f27ee4 100644 --- a/jail.go +++ b/jail.go @@ -5,50 +5,87 @@ import ( "fmt" "log/slog" "os/exec" + "runtime" "time" + + "github.com/coder/jail/namespace" + "github.com/coder/jail/proxy" + "github.com/coder/jail/tls" ) type Commander interface { - Open() error - SetEnv(key string, value string) + Start() error Command(command []string) *exec.Cmd Close() error } -type ProxyServer interface { - Start(ctx context.Context) error - Stop() error -} - type Config struct { - Commander Commander - ProxyServer ProxyServer + RuleEngine proxy.RuleEvaluator + Auditor proxy.Auditor + CertManager *tls.CertificateManager Logger *slog.Logger } type Jail struct { commandExecutor Commander - proxyServer ProxyServer + proxyServer *proxy.ProxyServer logger *slog.Logger - cancel context.CancelFunc ctx context.Context + cancel context.CancelFunc } -func New(config Config) *Jail { - ctx, cancel := context.WithCancel(context.Background()) +func New(ctx context.Context, config Config) (*Jail, error) { + // Setup TLS config and write CA certificate to file + tlsConfig, caCertPath, configDir, err := config.CertManager.SetupTLSAndWriteCACert() + if err != nil { + return nil, fmt.Errorf("failed to setup TLS and CA certificate: %v", err) + } + + // Create proxy server + proxyServer := proxy.NewProxyServer(proxy.Config{ + HTTPPort: 8080, + HTTPSPort: 8443, + Auditor: config.Auditor, + RuleEngine: config.RuleEngine, + Logger: config.Logger, + TLSConfig: tlsConfig, + }) + + // Create commander + commander, err := newCommander(namespace.Config{ + Logger: config.Logger, + HttpProxyPort: 8080, + HttpsProxyPort: 8443, + Env: map[string]string{ + // Set standard CA certificate environment variables for common tools + // This makes tools like curl, git, etc. trust our dynamically generated CA + "SSL_CERT_FILE": caCertPath, // OpenSSL/LibreSSL-based tools + "SSL_CERT_DIR": configDir, // OpenSSL certificate directory + "CURL_CA_BUNDLE": caCertPath, // curl + "GIT_SSL_CAINFO": caCertPath, // Git + "REQUESTS_CA_BUNDLE": caCertPath, // Python requests + "NODE_EXTRA_CA_CERTS": caCertPath, // Node.js + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to create commander: %v", err) + } + + // Create cancellable context for jail + ctx, cancel := context.WithCancel(ctx) return &Jail{ - commandExecutor: config.Commander, - proxyServer: config.ProxyServer, + commandExecutor: commander, + proxyServer: proxyServer, logger: config.Logger, ctx: ctx, cancel: cancel, - } + }, nil } -func (j *Jail) Open() error { +func (j *Jail) Start() error { // Open the command executor (network namespace) - err := j.commandExecutor.Open() + err := j.commandExecutor.Start() if err != nil { return fmt.Errorf("failed to open command executor: %v", err) } @@ -72,11 +109,6 @@ func (j *Jail) Command(command []string) *exec.Cmd { } func (j *Jail) Close() error { - // Cancel context to stop proxy server - if j.cancel != nil { - j.cancel() - } - // Stop proxy server if j.proxyServer != nil { err := j.proxyServer.Stop() @@ -88,3 +120,15 @@ func (j *Jail) Close() error { // Close command executor return j.commandExecutor.Close() } + +// newCommander creates a new NetJail instance for the current platform +func newCommander(config namespace.Config) (Commander, error) { + switch runtime.GOOS { + case "darwin": + return namespace.NewMacOS(config) + case "linux": + return namespace.NewLinux(config) + default: + return nil, fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} diff --git a/namespace/linux.go b/namespace/linux.go index 92be549..accc168 100644 --- a/namespace/linux.go +++ b/namespace/linux.go @@ -16,26 +16,34 @@ import ( // Linux implements jail.Commander using Linux network namespaces type Linux struct { - config Config - namespace string - vethHost string // Host-side veth interface name for iptables rules - logger *slog.Logger - preparedEnv map[string]string - procAttr *syscall.SysProcAttr + namespace string + vethHost string // Host-side veth interface name for iptables rules + logger *slog.Logger + preparedEnv map[string]string + procAttr *syscall.SysProcAttr + httpProxyPort int + httpsProxyPort int } -// newLinux creates a new Linux network jail instance -func newLinux(config Config, logger *slog.Logger) (*Linux, error) { +// NewLinux creates a new Linux network jail instance +func NewLinux(config Config) (*Linux, error) { + // Initialize preparedEnv with config environment variables + preparedEnv := make(map[string]string) + for key, value := range config.Env { + preparedEnv[key] = value + } + return &Linux{ - config: config, - namespace: newNamespaceName(), - logger: logger, - preparedEnv: make(map[string]string), + namespace: newNamespaceName(), + logger: config.Logger, + preparedEnv: preparedEnv, + httpProxyPort: config.HttpProxyPort, + httpsProxyPort: config.HttpsProxyPort, }, nil } // Setup creates network namespace and configures iptables rules -func (l *Linux) Open() error { +func (l *Linux) Start() error { l.logger.Debug("Setup called") // Setup DNS configuration BEFORE creating namespace @@ -69,7 +77,7 @@ func (l *Linux) Open() error { // Start with current environment for _, envVar := range os.Environ() { if parts := strings.SplitN(envVar, "=", 2); len(parts) == 2 { - // Only set if not already set by SetEnv + // Only set if not already set by config if _, exists := l.preparedEnv[parts[0]]; !exists { l.preparedEnv[parts[0]] = parts[1] } @@ -119,11 +127,6 @@ func (l *Linux) Open() error { return nil } -// SetEnv sets an environment variable for commands run in the namespace -func (l *Linux) SetEnv(key string, value string) { - l.preparedEnv[key] = value -} - // Command returns an exec.Cmd configured to run within the network namespace func (l *Linux) Command(command []string) *exec.Cmd { l.logger.Debug("Command called", "command", command) @@ -269,20 +272,20 @@ func (l *Linux) setupIptables() error { // COMPREHENSIVE APPROACH: Intercept ALL TCP traffic from namespace // Use PREROUTING on host to catch traffic after it exits namespace but before routing // This ensures NO TCP traffic can bypass the proxy - cmd = exec.Command("iptables", "-t", "nat", "-A", "PREROUTING", "-i", l.vethHost, "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", l.config.HTTPSPort)) + cmd = exec.Command("iptables", "-t", "nat", "-A", "PREROUTING", "-i", l.vethHost, "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", l.httpsProxyPort)) err = cmd.Run() if err != nil { return fmt.Errorf("failed to add comprehensive TCP redirect rule: %v", err) } - l.logger.Debug("Comprehensive TCP jailing enabled", "interface", l.vethHost, "proxy_port", l.config.HTTPSPort) + l.logger.Debug("Comprehensive TCP jailing enabled", "interface", l.vethHost, "proxy_port", l.httpsProxyPort) return nil } // removeIptables removes iptables rules func (l *Linux) removeIptables() error { // Remove comprehensive TCP redirect rule - cmd := exec.Command("iptables", "-t", "nat", "-D", "PREROUTING", "-i", l.vethHost, "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", l.config.HTTPSPort)) + cmd := exec.Command("iptables", "-t", "nat", "-D", "PREROUTING", "-i", l.vethHost, "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", l.httpsProxyPort)) cmd.Run() // Ignore errors during cleanup // Remove NAT rule @@ -300,4 +303,4 @@ func (l *Linux) removeNamespace() error { return fmt.Errorf("failed to remove namespace: %v", err) } return nil -} \ No newline at end of file +} diff --git a/namespace/linux_stub.go b/namespace/linux_stub.go index 63d0d7d..29a304b 100644 --- a/namespace/linux_stub.go +++ b/namespace/linux_stub.go @@ -4,12 +4,9 @@ package namespace import ( "fmt" - "log/slog" - - "github.com/coder/jail" ) -// newLinux is not available on non-Linux platforms -func newLinux(_ Config, _ *slog.Logger) (jail.Commander, error) { +// NewLinux is not available on non-Linux platforms +func NewLinux(_ Config) (*noop, error) { return nil, fmt.Errorf("linux network jail not supported on this platform") } diff --git a/namespace/macos.go b/namespace/macos.go index 56b904e..a9db968 100644 --- a/namespace/macos.go +++ b/namespace/macos.go @@ -20,32 +20,40 @@ const ( // MacOSNetJail implements network jail using macOS PF (Packet Filter) and group-based isolation type MacOSNetJail struct { - config Config - groupID int - pfRulesPath string - mainRulesPath string - logger *slog.Logger - preparedEnv map[string]string - procAttr *syscall.SysProcAttr + groupID int + pfRulesPath string + mainRulesPath string + logger *slog.Logger + preparedEnv map[string]string + procAttr *syscall.SysProcAttr + httpProxyPort int + httpsProxyPort int } -// newMacOSJail creates a new macOS network jail instance -func newMacOSJail(config Config, logger *slog.Logger) (*MacOSNetJail, error) { +// NewMacOS creates a new macOS network jail instance +func NewMacOS(config Config) (*MacOSNetJail, error) { ns := newNamespaceName() pfRulesPath := fmt.Sprintf("/tmp/%s.pf", ns) mainRulesPath := fmt.Sprintf("/tmp/%s_main.pf", ns) + // Initialize preparedEnv with config environment variables + preparedEnv := make(map[string]string) + for key, value := range config.Env { + preparedEnv[key] = value + } + return &MacOSNetJail{ - config: config, - pfRulesPath: pfRulesPath, - mainRulesPath: mainRulesPath, - logger: logger, - preparedEnv: make(map[string]string), + pfRulesPath: pfRulesPath, + mainRulesPath: mainRulesPath, + logger: config.Logger, + preparedEnv: preparedEnv, + httpProxyPort: config.HttpProxyPort, + httpsProxyPort: config.HttpsProxyPort, }, nil } // Setup creates the network jail group and configures PF rules -func (m *MacOSNetJail) Open() error { +func (m *MacOSNetJail) Start() error { m.logger.Debug("Setup called") // Create or get network jail group @@ -68,7 +76,7 @@ func (m *MacOSNetJail) Open() error { // Start with current environment for _, envVar := range os.Environ() { if parts := strings.SplitN(envVar, "=", 2); len(parts) == 2 { - // Only set if not already set by SetEnv + // Only set if not already set by config if _, exists := m.preparedEnv[parts[0]]; !exists { m.preparedEnv[parts[0]] = parts[1] } @@ -123,12 +131,7 @@ func (m *MacOSNetJail) Open() error { return nil } -// SetEnv sets an environment variable for commands run in the namespace -func (m *MacOSNetJail) SetEnv(key string, value string) { - m.preparedEnv[key] = value -} - -// Execute runs the command with the network jail group membership +// Command runs the command with the network jail group membership func (m *MacOSNetJail) Command(command []string) *exec.Cmd { m.logger.Debug("Command called", "command", command) @@ -275,13 +278,13 @@ pass on lo0 all `, m.groupID, iface, - m.config.HTTPSPort, // Use HTTPS proxy port for all TCP traffic + m.httpsProxyPort, // Use HTTPS proxy port for all TCP traffic m.groupID, iface, m.groupID, ) - m.logger.Debug("Comprehensive TCP jailing enabled for macOS", "group_id", m.groupID, "proxy_port", m.config.HTTPSPort) + m.logger.Debug("Comprehensive TCP jailing enabled for macOS", "group_id", m.groupID, "proxy_port", m.httpsProxyPort) return rules, nil } @@ -367,4 +370,4 @@ func (m *MacOSNetJail) cleanupTempFiles() { if m.mainRulesPath != "" { os.Remove(m.mainRulesPath) } -} \ No newline at end of file +} diff --git a/namespace/macos_stub.go b/namespace/macos_stub.go index 9368ae3..224a9f8 100644 --- a/namespace/macos_stub.go +++ b/namespace/macos_stub.go @@ -2,13 +2,7 @@ package namespace -import ( - "log/slog" - - "github.com/coder/jail" -) - -// newMacOSJail is not available on non-macOS platforms -func newMacOSJail(config Config, logger *slog.Logger) (jail.Commander, error) { +// NewMacOS is not available on non-macOS platforms +func NewMacOS(_ Config) (*noop, error) { panic("macOS network jail not available on this platform") } diff --git a/namespace/namespace.go b/namespace/namespace.go index a71ca22..1cf01ef 100644 --- a/namespace/namespace.go +++ b/namespace/namespace.go @@ -3,10 +3,7 @@ package namespace import ( "fmt" "log/slog" - "runtime" "time" - - "github.com/coder/jail" ) const ( @@ -15,22 +12,24 @@ const ( // JailConfig holds configuration for network jail type Config struct { - HTTPPort int - HTTPSPort int + Logger *slog.Logger + HttpProxyPort int + HttpsProxyPort int + Env map[string]string } -// NewJail creates a new NetJail instance for the current platform -func New(config Config, logger *slog.Logger) (jail.Commander, error) { - switch runtime.GOOS { - case "darwin": - return newMacOSJail(config, logger) - case "linux": - return newLinux(config, logger) - default: - return nil, fmt.Errorf("unsupported platform: %s", runtime.GOOS) - } -} +// // NewJail creates a new NetJail instance for the current platform +// func New(config Config) (jail.Commander, error) { +// switch runtime.GOOS { +// case "darwin": +// return NewMacOS(config) +// case "linux": +// return NewLinux(config) +// default: +// return nil, fmt.Errorf("unsupported platform: %s", runtime.GOOS) +// } +// } func newNamespaceName() string { return fmt.Sprintf("%s_%d", namespacePrefix, time.Now().UnixNano()%10000000) -} \ No newline at end of file +} diff --git a/namespace/noop.go b/namespace/noop.go new file mode 100644 index 0000000..aed145c --- /dev/null +++ b/namespace/noop.go @@ -0,0 +1,23 @@ +package namespace + +import ( + "os/exec" +) + +type noop struct{} + +func newNoop(_ Config) (*noop, error) { + return &noop{}, nil +} + +func (n *noop) Command(_ []string) *exec.Cmd { + return exec.Command("true") +} + +func (n *noop) Start() error { + return nil +} + +func (n *noop) Close() error { + return nil +} diff --git a/proxy/proxy.go b/proxy/proxy.go index cedc223..5fce72c 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -14,12 +14,20 @@ import ( "github.com/coder/jail/rules" ) +type RuleEvaluator interface { + Evaluate(method, url string) rules.EvaluationResult +} + +type Auditor interface { + AuditRequest(req audit.Request) +} + // ProxyServer handles HTTP and HTTPS requests with rule-based filtering type ProxyServer struct { httpServer *http.Server httpsServer *http.Server - ruleEngine *rules.RuleEngine - auditor *audit.LoggingAuditor + ruleEngine RuleEvaluator + auditor Auditor logger *slog.Logger tlsConfig *tls.Config httpPort int @@ -30,8 +38,8 @@ type ProxyServer struct { type Config struct { HTTPPort int HTTPSPort int - RuleEngine *rules.RuleEngine - Auditor *audit.LoggingAuditor + RuleEngine RuleEvaluator + Auditor Auditor Logger *slog.Logger TLSConfig *tls.Config } @@ -111,10 +119,12 @@ func (p *ProxyServer) handleHTTP(w http.ResponseWriter, r *http.Request) { result := p.ruleEngine.Evaluate(r.Method, r.URL.String()) // Audit the request - auditReq := audit.HTTPRequestToAuditRequest(r) - auditReq.Allowed = result.Allowed - auditReq.Rule = result.Rule - p.auditRequest(auditReq) + p.auditor.AuditRequest(audit.Request{ + Method: r.Method, + URL: r.URL.String(), + Allowed: result.Allowed, + Rule: result.Rule, + }) if !result.Allowed { p.writeBlockedResponse(w, r) @@ -127,23 +137,16 @@ func (p *ProxyServer) handleHTTP(w http.ResponseWriter, r *http.Request) { // handleHTTPS handles HTTPS requests (after TLS termination) func (p *ProxyServer) handleHTTPS(w http.ResponseWriter, r *http.Request) { - // Reconstruct the full URL for HTTPS requests - fullURL := fmt.Sprintf("https://%s%s", r.Host, r.URL.Path) - if r.URL.RawQuery != "" { - fullURL += "?" + r.URL.RawQuery - } - // Check if request should be allowed - result := p.ruleEngine.Evaluate(r.Method, fullURL) + result := p.ruleEngine.Evaluate(r.Method, r.URL.String()) // Audit the request - auditReq := &audit.Request{ + p.auditor.AuditRequest(audit.Request{ Method: r.Method, - URL: fullURL, + URL: r.URL.String(), Allowed: result.Allowed, Rule: result.Rule, - } - p.auditRequest(auditReq) + }) if !result.Allowed { p.writeBlockedResponse(w, r) @@ -291,8 +294,3 @@ For more help: https://github.com/coder/jail `, r.Method, r.URL.Path, host, host, r.Method, host, r.Method) } - -// auditRequest handles auditing of requests -func (p *ProxyServer) auditRequest(req *audit.Request) { - p.auditor.AuditRequest(req) -}