From efb5f16a4f3a90f78028c011d45bff1142b08b92 Mon Sep 17 00:00:00 2001 From: Garrett Delfosse Date: Thu, 11 Sep 2025 13:04:31 -0400 Subject: [PATCH 1/2] style --- audit/logging_auditor.go | 2 +- audit/request.go | 10 -------- audit/request_test.go | 5 +++- cli/cli.go | 49 ++++++++++++---------------------------- namespace/namespace.go | 12 ---------- 5 files changed, 20 insertions(+), 58 deletions(-) diff --git a/audit/logging_auditor.go b/audit/logging_auditor.go index f2f81d8..28d4612 100644 --- a/audit/logging_auditor.go +++ b/audit/logging_auditor.go @@ -2,7 +2,7 @@ package audit import "log/slog" -// LoggingAuditor implements Auditor by logging to slog +// LoggingAuditor implements proxy.Auditor by logging to slog type LoggingAuditor struct { logger *slog.Logger } diff --git a/audit/request.go b/audit/request.go index 6183c0d..1ffd146 100644 --- a/audit/request.go +++ b/audit/request.go @@ -1,7 +1,5 @@ package audit -import "net/http" - // Request represents information about an HTTP request for auditing type Request struct { Method string @@ -9,11 +7,3 @@ type Request struct { Allowed bool Rule string // The rule that matched (if any) } - -// HTTPRequestToAuditRequest converts an http.Request to an audit.Request -func HTTPRequestToAuditRequest(httpReq *http.Request) *Request { - return &Request{ - Method: httpReq.Method, - URL: httpReq.URL.String(), - } -} diff --git a/audit/request_test.go b/audit/request_test.go index b8b6a5a..5a6ced1 100644 --- a/audit/request_test.go +++ b/audit/request_test.go @@ -94,7 +94,10 @@ func TestHTTPRequestToAuditRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - auditReq := HTTPRequestToAuditRequest(tt.request) + auditReq := Request{ + Method: tt.request.Method, + URL: tt.request.URL.String(), + } if auditReq.Method != tt.expectedMethod { t.Errorf("expected method %s, got %s", tt.expectedMethod, auditReq.Method) diff --git a/cli/cli.go b/cli/cli.go index 37afae7..0554f6d 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -59,7 +59,7 @@ Examples: }, }, Handler: func(inv *serpent.Invocation) error { - return Run(config, inv.Args) + return Run(inv.Context(), config, inv.Args) }, } } @@ -89,7 +89,9 @@ func setupLogging(logLevel string) *slog.Logger { } // Run executes the jail command with the given configuration and arguments -func Run(config Config, args []string) error { +func Run(ctx context.Context, config Config, args []string) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() logger := setupLogging(config.LogLevel) // Get command arguments @@ -113,7 +115,7 @@ func Run(config Config, args []string) error { ruleEngine := rules.NewRuleEngine(allowRules, logger) // Create auditor - // auditor := audit.NewLoggingAuditor(logger) + auditor := audit.NewLoggingAuditor(logger) // Create certificate manager certManager, err := tls.NewCertificateManager(logger) @@ -123,14 +125,13 @@ func Run(config Config, args []string) error { } // Create jail instance - jailInstance, err := jail.New(context.Background(), jail.Config{ + jailInstance, err := jail.New(ctx, jail.Config{ RuleEngine: ruleEngine, - Auditor: audit.NewLoggingAuditor(logger), - Logger: logger, + Auditor: auditor, CertManager: certManager, + Logger: logger, }) if err != nil { - logger.Error("Failed to create jail instance", "error", err) return fmt.Errorf("failed to create jail instance: %v", err) } @@ -138,38 +139,18 @@ func Run(config Config, args []string) error { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - // Handle signals immediately in background - go func() { - sig := <-sigChan - logger.Info("Received signal during setup, cleaning up...", "signal", sig) - err := jailInstance.Close() - if err != nil { - logger.Error("Emergency cleanup failed", "error", err) - } - os.Exit(1) - }() - - // Ensure cleanup happens no matter what - defer func() { - logger.Debug("Starting cleanup process") - err := jailInstance.Close() - if err != nil { - logger.Error("Failed to cleanup jail", "error", err) - } else { - logger.Debug("Cleanup completed successfully") - } - }() - // Open jail (starts network namespace and proxy server) err = jailInstance.Start() if err != nil { - logger.Error("Failed to open jail", "error", err) return fmt.Errorf("failed to open jail: %v", err) } - - // Create context for graceful shutdown - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + defer func() { + logger.Info("Closing jail...") + err := jailInstance.Close() + if err != nil { + logger.Error("Failed to close jail", "error", err) + } + }() // Execute command in jail go func() { diff --git a/namespace/namespace.go b/namespace/namespace.go index 1cf01ef..8fd365c 100644 --- a/namespace/namespace.go +++ b/namespace/namespace.go @@ -18,18 +18,6 @@ type Config struct { Env map[string]string } -// // NewJail creates a new NetJail instance for the current platform -// func New(config Config) (jail.Commander, error) { -// switch runtime.GOOS { -// case "darwin": -// return NewMacOS(config) -// case "linux": -// return NewLinux(config) -// default: -// return nil, fmt.Errorf("unsupported platform: %s", runtime.GOOS) -// } -// } - func newNamespaceName() string { return fmt.Sprintf("%s_%d", namespacePrefix, time.Now().UnixNano()%10000000) } From 209bd68519d9e50112a229ca3232164548b47305 Mon Sep 17 00:00:00 2001 From: Garrett Delfosse Date: Thu, 11 Sep 2025 13:17:06 -0400 Subject: [PATCH 2/2] style --- audit/request.go | 4 ++ audit/request_test.go | 120 ----------------------------------------- jail.go | 30 ++++------- namespace/namespace.go | 7 +++ namespace/noop.go | 4 -- proxy/proxy.go | 45 +++++++--------- rules/rules.go | 14 +++-- tls/tls.go | 4 ++ 8 files changed, 54 insertions(+), 174 deletions(-) delete mode 100644 audit/request_test.go diff --git a/audit/request.go b/audit/request.go index 1ffd146..54f4c4e 100644 --- a/audit/request.go +++ b/audit/request.go @@ -1,5 +1,9 @@ package audit +type Auditor interface { + AuditRequest(req Request) +} + // Request represents information about an HTTP request for auditing type Request struct { Method string diff --git a/audit/request_test.go b/audit/request_test.go deleted file mode 100644 index 5a6ced1..0000000 --- a/audit/request_test.go +++ /dev/null @@ -1,120 +0,0 @@ -package audit - -import ( - "net/http" - "net/url" - "strings" - "testing" -) - -func TestHTTPRequestToAuditRequest(t *testing.T) { - tests := []struct { - name string - request *http.Request - expectedMethod string - expectedURL string - }{ - { - name: "basic GET request", - request: func() *http.Request { - req, _ := http.NewRequest("GET", "https://example.com/path?query=value", nil) - return req - }(), - expectedMethod: "GET", - expectedURL: "https://example.com/path?query=value", - }, - { - name: "POST request with body", - request: func() *http.Request { - req, _ := http.NewRequest("POST", "https://api.example.com/users", strings.NewReader("data")) - return req - }(), - expectedMethod: "POST", - expectedURL: "https://api.example.com/users", - }, - { - name: "request with port", - request: func() *http.Request { - req, _ := http.NewRequest("GET", "https://example.com:8443/api", nil) - return req - }(), - expectedMethod: "GET", - expectedURL: "https://example.com:8443/api", - }, - { - name: "request with complex query parameters", - request: func() *http.Request { - req, _ := http.NewRequest("GET", "https://search.example.com/api?q=hello%20world&limit=10&offset=0", nil) - return req - }(), - expectedMethod: "GET", - expectedURL: "https://search.example.com/api?q=hello%20world&limit=10&offset=0", - }, - { - name: "request with fragment (should be ignored)", - request: func() *http.Request { - u, _ := url.Parse("https://example.com/page#section") - req := &http.Request{ - Method: "GET", - URL: u, - } - return req - }(), - expectedMethod: "GET", - expectedURL: "https://example.com/page#section", - }, - { - name: "HTTP request (not HTTPS)", - request: func() *http.Request { - req, _ := http.NewRequest("GET", "http://insecure.example.com/data", nil) - return req - }(), - expectedMethod: "GET", - expectedURL: "http://insecure.example.com/data", - }, - { - name: "PUT request", - request: func() *http.Request { - req, _ := http.NewRequest("PUT", "https://api.example.com/users/123", strings.NewReader("updated data")) - return req - }(), - expectedMethod: "PUT", - expectedURL: "https://api.example.com/users/123", - }, - { - name: "DELETE request", - request: func() *http.Request { - req, _ := http.NewRequest("DELETE", "https://api.example.com/users/123", nil) - return req - }(), - expectedMethod: "DELETE", - expectedURL: "https://api.example.com/users/123", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - auditReq := Request{ - Method: tt.request.Method, - URL: tt.request.URL.String(), - } - - if auditReq.Method != tt.expectedMethod { - t.Errorf("expected method %s, got %s", tt.expectedMethod, auditReq.Method) - } - - if auditReq.URL != tt.expectedURL { - t.Errorf("expected URL %s, got %s", tt.expectedURL, auditReq.URL) - } - - // Verify that fields not set by HTTPRequestToAuditRequest have zero values - if auditReq.Allowed != false { - t.Errorf("expected Allowed to be false (zero value), got %v", auditReq.Allowed) - } - - if auditReq.Rule != "" { - t.Errorf("expected Rule to be empty (zero value), got %q", auditReq.Rule) - } - }) - } -} diff --git a/jail.go b/jail.go index 9b77aa7..0b2a4b5 100644 --- a/jail.go +++ b/jail.go @@ -2,37 +2,29 @@ package jail import ( "context" - cryptotls "crypto/tls" "fmt" "log/slog" "os/exec" "runtime" "time" + "github.com/coder/jail/audit" "github.com/coder/jail/namespace" "github.com/coder/jail/proxy" + "github.com/coder/jail/rules" + "github.com/coder/jail/tls" ) -type Commander interface { - Start() error - Command(command []string) *exec.Cmd - Close() error -} - -type CertificateManager interface { - SetupTLSAndWriteCACert() (*cryptotls.Config, string, string, error) -} - type Config struct { - RuleEngine proxy.RuleEvaluator - Auditor proxy.Auditor - CertManager CertificateManager + RuleEngine rules.Evaluator + Auditor audit.Auditor + CertManager tls.Manager Logger *slog.Logger } type Jail struct { - commander Commander - proxyServer *proxy.ProxyServer + commander namespace.Commander + proxyServer *proxy.Server logger *slog.Logger ctx context.Context cancel context.CancelFunc @@ -56,7 +48,7 @@ func New(ctx context.Context, config Config) (*Jail, error) { }) // Create commander - commander, err := newCommander(namespace.Config{ + commander, err := newNamespaceCommander(namespace.Config{ Logger: config.Logger, HttpProxyPort: 8080, HttpsProxyPort: 8443, @@ -125,8 +117,8 @@ func (j *Jail) Close() error { return j.commander.Close() } -// newCommander creates a new NetJail instance for the current platform -func newCommander(config namespace.Config) (Commander, error) { +// newNamespaceCommander creates a new namespace instance for the current platform +func newNamespaceCommander(config namespace.Config) (namespace.Commander, error) { switch runtime.GOOS { case "darwin": return namespace.NewMacOS(config) diff --git a/namespace/namespace.go b/namespace/namespace.go index 8fd365c..6436782 100644 --- a/namespace/namespace.go +++ b/namespace/namespace.go @@ -3,6 +3,7 @@ package namespace import ( "fmt" "log/slog" + "os/exec" "time" ) @@ -10,6 +11,12 @@ const ( namespacePrefix = "coder_jail" ) +type Commander interface { + Start() error + Command(command []string) *exec.Cmd + Close() error +} + // JailConfig holds configuration for network jail type Config struct { Logger *slog.Logger diff --git a/namespace/noop.go b/namespace/noop.go index aed145c..64445eb 100644 --- a/namespace/noop.go +++ b/namespace/noop.go @@ -6,10 +6,6 @@ import ( type noop struct{} -func newNoop(_ Config) (*noop, error) { - return &noop{}, nil -} - func (n *noop) Command(_ []string) *exec.Cmd { return exec.Command("true") } diff --git a/proxy/proxy.go b/proxy/proxy.go index 5fce72c..944b8aa 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -14,39 +14,32 @@ import ( "github.com/coder/jail/rules" ) -type RuleEvaluator interface { - Evaluate(method, url string) rules.EvaluationResult -} - -type Auditor interface { - AuditRequest(req audit.Request) -} +// Server handles HTTP and HTTPS requests with rule-based filtering +type Server struct { + ruleEngine rules.Evaluator + auditor audit.Auditor + logger *slog.Logger + tlsConfig *tls.Config + httpPort int + httpsPort int -// ProxyServer handles HTTP and HTTPS requests with rule-based filtering -type ProxyServer struct { httpServer *http.Server httpsServer *http.Server - ruleEngine RuleEvaluator - auditor Auditor - logger *slog.Logger - tlsConfig *tls.Config - httpPort int - httpsPort int } // Config holds configuration for the proxy server type Config struct { HTTPPort int HTTPSPort int - RuleEngine RuleEvaluator - Auditor Auditor + RuleEngine rules.Evaluator + Auditor audit.Auditor Logger *slog.Logger TLSConfig *tls.Config } // NewProxyServer creates a new proxy server instance -func NewProxyServer(config Config) *ProxyServer { - return &ProxyServer{ +func NewProxyServer(config Config) *Server { + return &Server{ ruleEngine: config.RuleEngine, auditor: config.Auditor, logger: config.Logger, @@ -57,7 +50,7 @@ func NewProxyServer(config Config) *ProxyServer { } // Start starts both HTTP and HTTPS proxy servers -func (p *ProxyServer) Start(ctx context.Context) error { +func (p *Server) Start(ctx context.Context) error { // Create HTTP server p.httpServer = &http.Server{ Addr: fmt.Sprintf(":%d", p.httpPort), @@ -95,7 +88,7 @@ func (p *ProxyServer) Start(ctx context.Context) error { } // Stop stops both proxy servers -func (p *ProxyServer) Stop() error { +func (p *Server) Stop() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -114,7 +107,7 @@ func (p *ProxyServer) Stop() error { } // handleHTTP handles regular HTTP requests -func (p *ProxyServer) handleHTTP(w http.ResponseWriter, r *http.Request) { +func (p *Server) handleHTTP(w http.ResponseWriter, r *http.Request) { // Check if request should be allowed result := p.ruleEngine.Evaluate(r.Method, r.URL.String()) @@ -136,7 +129,7 @@ func (p *ProxyServer) handleHTTP(w http.ResponseWriter, r *http.Request) { } // handleHTTPS handles HTTPS requests (after TLS termination) -func (p *ProxyServer) handleHTTPS(w http.ResponseWriter, r *http.Request) { +func (p *Server) handleHTTPS(w http.ResponseWriter, r *http.Request) { // Check if request should be allowed result := p.ruleEngine.Evaluate(r.Method, r.URL.String()) @@ -158,7 +151,7 @@ func (p *ProxyServer) handleHTTPS(w http.ResponseWriter, r *http.Request) { } // forwardHTTPRequest forwards a regular HTTP request -func (p *ProxyServer) forwardHTTPRequest(w http.ResponseWriter, r *http.Request) { +func (p *Server) forwardHTTPRequest(w http.ResponseWriter, r *http.Request) { // Create a new request to the target server targetURL := r.URL if targetURL.Scheme == "" { @@ -212,7 +205,7 @@ func (p *ProxyServer) forwardHTTPRequest(w http.ResponseWriter, r *http.Request) } // forwardHTTPSRequest forwards an HTTPS request -func (p *ProxyServer) forwardHTTPSRequest(w http.ResponseWriter, r *http.Request) { +func (p *Server) forwardHTTPSRequest(w http.ResponseWriter, r *http.Request) { // Create target URL targetURL := &url.URL{ Scheme: "https", @@ -271,7 +264,7 @@ func (p *ProxyServer) forwardHTTPSRequest(w http.ResponseWriter, r *http.Request } // writeBlockedResponse writes a blocked response -func (p *ProxyServer) writeBlockedResponse(w http.ResponseWriter, r *http.Request) { +func (p *Server) writeBlockedResponse(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusForbidden) diff --git a/rules/rules.go b/rules/rules.go index 6ad64fe..ba099a2 100644 --- a/rules/rules.go +++ b/rules/rules.go @@ -6,6 +6,10 @@ import ( "strings" ) +type Evaluator interface { + Evaluate(method, url string) Result +} + // Rule represents an allow rule with optional HTTP method restrictions type Rule struct { Pattern string // wildcard pattern for matching @@ -119,18 +123,18 @@ func NewRuleEngine(rules []*Rule, logger *slog.Logger) *RuleEngine { } } -// EvaluationResult contains the result of rule evaluation -type EvaluationResult struct { +// Result contains the result of rule evaluation +type Result struct { Allowed bool Rule string // The rule that matched (if any) } // Evaluate evaluates a request and returns both result and matching rule -func (re *RuleEngine) Evaluate(method, url string) EvaluationResult { +func (re *RuleEngine) Evaluate(method, url string) Result { // Check if any allow rule matches for _, rule := range re.rules { if rule.Matches(method, url) { - return EvaluationResult{ + return Result{ Allowed: true, Rule: rule.Raw, } @@ -138,7 +142,7 @@ func (re *RuleEngine) Evaluate(method, url string) EvaluationResult { } // Default deny if no allow rules match - return EvaluationResult{ + return Result{ Allowed: false, Rule: "", } diff --git a/tls/tls.go b/tls/tls.go index 002c9d5..75f34aa 100644 --- a/tls/tls.go +++ b/tls/tls.go @@ -19,6 +19,10 @@ import ( "time" ) +type Manager interface { + SetupTLSAndWriteCACert() (*tls.Config, string, string, error) +} + // CertificateManager manages TLS certificates for the proxy type CertificateManager struct { caKey *rsa.PrivateKey