diff --git a/cli/cli.go b/cli/cli.go index 7d1a20c..1777e60 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -8,7 +8,7 @@ import ( "github.com/coder/boundary/config" "github.com/coder/boundary/log" - "github.com/coder/boundary/nsjail_manager" + "github.com/coder/boundary/run" "github.com/coder/serpent" ) @@ -119,9 +119,25 @@ func BaseCommand() *serpent.Command { Value: &cliConfig.ConfigureDNSForLocalStubResolver, YAML: "configure_dns_for_local_stub_resolver", }, + { + Flag: "jail-type", + Env: "BOUNDARY_JAIL_TYPE", + Description: "Jail type to use for network isolation. Options: nsjail (default), landjail.", + Default: "nsjail", + Value: &cliConfig.JailType, + YAML: "jail_type", + }, }, Handler: func(inv *serpent.Invocation) error { - appConfig := config.NewAppConfigFromCliConfig(cliConfig) + appConfig, err := config.NewAppConfigFromCliConfig(cliConfig, inv.Args) + if err != nil { + return fmt.Errorf("failed to parse cli config file: %v", err) + } + + // Get command arguments + if len(appConfig.TargetCMD) == 0 { + return fmt.Errorf("no command specified") + } logger, err := log.SetupLogging(appConfig) if err != nil { @@ -134,7 +150,7 @@ func BaseCommand() *serpent.Command { } logger.Debug("Application config", "config", appConfigInJSON) - return nsjail_manager.Run(inv.Context(), logger, appConfig, inv.Args) + return run.Run(inv.Context(), logger, appConfig) }, } } diff --git a/config/config.go b/config/config.go index 76176b4..716ef05 100644 --- a/config/config.go +++ b/config/config.go @@ -1,9 +1,30 @@ package config import ( + "fmt" + "github.com/coder/serpent" ) +// JailType represents the type of jail to use for network isolation +type JailType string + +const ( + NSJailType JailType = "nsjail" + LandjailType JailType = "landjail" +) + +func NewJailTypeFromString(str string) (JailType, error) { + switch str { + case "nsjail": + return NSJailType, nil + case "landjail": + return LandjailType, nil + default: + return NSJailType, fmt.Errorf("invalid JailType: %s", str) + } +} + type CliConfig struct { Config serpent.YAMLConfigPath `yaml:"-"` AllowListStrings serpent.StringArray `yaml:"allowlist"` // From config file @@ -14,6 +35,7 @@ type CliConfig struct { PprofEnabled serpent.Bool `yaml:"pprof_enabled"` PprofPort serpent.Int64 `yaml:"pprof_port"` ConfigureDNSForLocalStubResolver serpent.Bool `yaml:"configure_dns_for_local_stub_resolver"` + JailType serpent.String `yaml:"jail_type"` } type AppConfig struct { @@ -24,9 +46,12 @@ type AppConfig struct { PprofEnabled bool PprofPort int64 ConfigureDNSForLocalStubResolver bool + JailType JailType + TargetCMD []string + UserInfo *UserInfo } -func NewAppConfigFromCliConfig(cfg CliConfig) AppConfig { +func NewAppConfigFromCliConfig(cfg CliConfig, targetCMD []string) (AppConfig, error) { // Merge allowlist from config file with allow from CLI flags allowListStrings := cfg.AllowListStrings.Value() allowStrings := cfg.AllowStrings.Value() @@ -34,6 +59,13 @@ func NewAppConfigFromCliConfig(cfg CliConfig) AppConfig { // Combine allowlist (config file) with allow (CLI flags) allAllowStrings := append(allowListStrings, allowStrings...) + jailType, err := NewJailTypeFromString(cfg.JailType.Value()) + if err != nil { + return AppConfig{}, err + } + + userInfo := GetUserInfo() + return AppConfig{ AllowRules: allAllowStrings, LogLevel: cfg.LogLevel.Value(), @@ -42,5 +74,8 @@ func NewAppConfigFromCliConfig(cfg CliConfig) AppConfig { PprofEnabled: cfg.PprofEnabled.Value(), PprofPort: cfg.PprofPort.Value(), ConfigureDNSForLocalStubResolver: cfg.ConfigureDNSForLocalStubResolver.Value(), - } + JailType: jailType, + TargetCMD: targetCMD, + UserInfo: userInfo, + }, nil } diff --git a/util/user.go b/config/user_info.go similarity index 71% rename from util/user.go rename to config/user_info.go index dcb15bb..4ce8125 100644 --- a/util/user.go +++ b/config/user_info.go @@ -1,4 +1,4 @@ -package util +package config import ( "os" @@ -7,8 +7,21 @@ import ( "strconv" ) +const ( + CAKeyName = "ca-key.pem" + CACertName = "ca-cert.pem" +) + +type UserInfo struct { + SudoUser string + Uid int + Gid int + HomeDir string + ConfigDir string +} + // GetUserInfo returns information about the current user, handling sudo scenarios -func GetUserInfo() (string, int, int, string, string) { +func GetUserInfo() *UserInfo { // Only consider SUDO_USER if we're actually running with elevated privileges // In environments like Coder workspaces, SUDO_USER may be set to 'root' // but we're not actually running under sudo @@ -36,7 +49,13 @@ func GetUserInfo() (string, int, int, string, string) { configDir := getConfigDir(user.HomeDir) - return sudoUser, uid, gid, user.HomeDir, configDir + return &UserInfo{ + SudoUser: sudoUser, + Uid: uid, + Gid: gid, + HomeDir: user.HomeDir, + ConfigDir: configDir, + } } // Not actually running under sudo, use current user @@ -44,11 +63,11 @@ func GetUserInfo() (string, int, int, string, string) { } // getCurrentUserInfo gets information for the current user -func getCurrentUserInfo() (string, int, int, string, string) { +func getCurrentUserInfo() *UserInfo { currentUser, err := user.Current() if err != nil { // Fallback with empty values if we can't get user info - return "", 0, 0, "", "" + return &UserInfo{} } uid, _ := strconv.Atoi(currentUser.Uid) @@ -56,7 +75,13 @@ func getCurrentUserInfo() (string, int, int, string, string) { configDir := getConfigDir(currentUser.HomeDir) - return currentUser.Username, uid, gid, currentUser.HomeDir, configDir + return &UserInfo{ + SudoUser: currentUser.Username, + Uid: uid, + Gid: gid, + HomeDir: currentUser.HomeDir, + ConfigDir: configDir, + } } // getConfigDir determines the config directory based on XDG_CONFIG_HOME or fallback @@ -67,3 +92,11 @@ func getConfigDir(homeDir string) string { } return filepath.Join(homeDir, ".config", "coder_boundary") } + +func (u *UserInfo) CAKeyPath() string { + return filepath.Join(u.ConfigDir, CAKeyName) +} + +func (u *UserInfo) CACertPath() string { + return filepath.Join(u.ConfigDir, CACertName) +} diff --git a/e2e_tests/boundary_test.go b/e2e_tests/boundary_test.go deleted file mode 100644 index 4f6236d..0000000 --- a/e2e_tests/boundary_test.go +++ /dev/null @@ -1,241 +0,0 @@ -package e2e_tests - -import ( - "bytes" - "fmt" - "io" - "os" - "os/exec" - "path/filepath" - "strconv" - "strings" - "testing" - "time" - - "github.com/coder/boundary/util" - "github.com/stretchr/testify/require" -) - -// BoundaryTest is a high-level test framework for boundary e2e tests -type BoundaryTest struct { - t *testing.T - projectRoot string - binaryPath string - allowedDomains []string - logLevel string - cmd *exec.Cmd - pid int - startupDelay time.Duration -} - -// BoundaryTestOption is a function that configures BoundaryTest -type BoundaryTestOption func(*BoundaryTest) - -// NewBoundaryTest creates a new BoundaryTest instance -func NewBoundaryTest(t *testing.T, opts ...BoundaryTestOption) *BoundaryTest { - projectRoot := findProjectRoot(t) - binaryPath := "/tmp/boundary-test" - - bt := &BoundaryTest{ - t: t, - projectRoot: projectRoot, - binaryPath: binaryPath, - allowedDomains: []string{}, - logLevel: "warn", - startupDelay: 2 * time.Second, - } - - // Apply options - for _, opt := range opts { - opt(bt) - } - - return bt -} - -// WithAllowedDomain adds an allowed domain rule -func WithAllowedDomain(domain string) BoundaryTestOption { - return func(bt *BoundaryTest) { - bt.allowedDomains = append(bt.allowedDomains, fmt.Sprintf("domain=%s", domain)) - } -} - -// WithAllowedRule adds a full allow rule (e.g., "method=GET domain=example.com path=/api/*") -func WithAllowedRule(rule string) BoundaryTestOption { - return func(bt *BoundaryTest) { - bt.allowedDomains = append(bt.allowedDomains, rule) - } -} - -// WithLogLevel sets the log level -func WithLogLevel(level string) BoundaryTestOption { - return func(bt *BoundaryTest) { - bt.logLevel = level - } -} - -// WithStartupDelay sets how long to wait after starting boundary before making requests -func WithStartupDelay(delay time.Duration) BoundaryTestOption { - return func(bt *BoundaryTest) { - bt.startupDelay = delay - } -} - -// Build builds the boundary binary -func (bt *BoundaryTest) Build() *BoundaryTest { - buildCmd := exec.Command("go", "build", "-o", bt.binaryPath, "./cmd/...") - buildCmd.Dir = bt.projectRoot - err := buildCmd.Run() - require.NoError(bt.t, err, "Failed to build boundary binary") - return bt -} - -// Start starts the boundary process with a long-running command -func (bt *BoundaryTest) Start(command ...string) *BoundaryTest { - if len(command) == 0 { - // Default: sleep for a long time to keep the process alive - command = []string{"/bin/bash", "-c", "/usr/bin/sleep 100 && /usr/bin/echo 'Root boundary process exited'"} - } - - // Build command args - args := []string{ - "--log-level", bt.logLevel, - } - for _, domain := range bt.allowedDomains { - args = append(args, "--allow", domain) - } - args = append(args, "--") - args = append(args, command...) - - bt.cmd = exec.Command(bt.binaryPath, args...) - bt.cmd.Stdin = os.Stdin - - stdout, _ := bt.cmd.StdoutPipe() - stderr, _ := bt.cmd.StderrPipe() - go io.Copy(os.Stdout, stdout) //nolint:errcheck - go io.Copy(os.Stderr, stderr) //nolint:errcheck - - err := bt.cmd.Start() - require.NoError(bt.t, err, "Failed to start boundary process") - - // Wait for boundary to start - time.Sleep(bt.startupDelay) - - // Get the child process PID - bt.pid = getTargetProcessPID(bt.t) - - return bt -} - -// Stop gracefully stops the boundary process -func (bt *BoundaryTest) Stop() { - if bt.cmd == nil || bt.cmd.Process == nil { - return - } - - // Send interrupt signal - err := bt.cmd.Process.Signal(os.Interrupt) - if err != nil { - bt.t.Logf("Failed to interrupt boundary process: %v", err) - } - - time.Sleep(1 * time.Second) - - // Wait for process to finish - if bt.cmd != nil { - err = bt.cmd.Wait() - if err != nil { - bt.t.Logf("Boundary process finished with error: %v", err) - } - } - - // Clean up binary - err = os.Remove(bt.binaryPath) - if err != nil { - bt.t.Logf("Failed to remove boundary binary: %v", err) - } -} - -// ExpectAllowed makes an HTTP/HTTPS request and expects it to be allowed with the given response body -func (bt *BoundaryTest) ExpectAllowed(url string, expectedBody string) { - bt.t.Helper() - output := bt.makeRequest(url) - require.Equal(bt.t, expectedBody, string(output), "Expected response body does not match") -} - -// ExpectAllowedContains makes an HTTP/HTTPS request and expects it to be allowed, checking that response contains the given text -func (bt *BoundaryTest) ExpectAllowedContains(url string, containsText string) { - bt.t.Helper() - output := bt.makeRequest(url) - require.Contains(bt.t, string(output), containsText, "Response does not contain expected text") -} - -// ExpectDeny makes an HTTP/HTTPS request and expects it to be denied -func (bt *BoundaryTest) ExpectDeny(url string) { - bt.t.Helper() - output := bt.makeRequest(url) - require.Contains(bt.t, string(output), "Request Blocked by Boundary", "Expected request to be blocked") -} - -// makeRequest makes an HTTP/HTTPS request from inside the namespace -// Always sets SSL_CERT_FILE for HTTPS support (harmless for HTTP requests) -func (bt *BoundaryTest) makeRequest(url string) []byte { - bt.t.Helper() - - pid := fmt.Sprintf("%v", bt.pid) - _, _, _, _, configDir := util.GetUserInfo() - certPath := fmt.Sprintf("%v/ca-cert.pem", configDir) - - args := []string{"nsenter", "-t", pid, "-n", "--", - "env", fmt.Sprintf("SSL_CERT_FILE=%v", certPath), "curl", "-sS", url} - - curlCmd := exec.Command("sudo", args...) - - var stderr bytes.Buffer - curlCmd.Stderr = &stderr - output, err := curlCmd.Output() - - if err != nil { - bt.t.Fatalf("curl command failed: %v, stderr: %s, output: %s", err, stderr.String(), string(output)) - } - - return output -} - -// getTargetProcessPID gets the PID of the boundary target process. -// Target process is associated with a network namespace, so you can exec into it, using this PID. -// pgrep -f boundary-test -n is doing two things: -// -f = match against the full command line -// -n = return the newest (most recently started) matching process -func getTargetProcessPID(t *testing.T) int { - cmd := exec.Command("pgrep", "-f", "boundary-test", "-n") - output, err := cmd.Output() - require.NoError(t, err) - - pidStr := strings.TrimSpace(string(output)) - pid, err := strconv.Atoi(pidStr) - require.NoError(t, err) - return pid -} - -// findProjectRoot finds the project root by looking for go.mod file -func findProjectRoot(t *testing.T) string { - cwd, err := os.Getwd() - require.NoError(t, err, "Failed to get current working directory") - - // Start from current directory and walk up until we find go.mod - dir := cwd - for { - goModPath := filepath.Join(dir, "go.mod") - if _, err := os.Stat(goModPath); err == nil { - return dir - } - - parent := filepath.Dir(dir) - if parent == dir { - // Reached filesystem root - t.Fatalf("Could not find go.mod file starting from %s", cwd) - } - dir = parent - } -} diff --git a/e2e_tests/iptables_cleanup_test.go b/e2e_tests/iptables_cleanup_test.go index bb80900..ec65520 100644 --- a/e2e_tests/iptables_cleanup_test.go +++ b/e2e_tests/iptables_cleanup_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/coder/boundary/e2e_tests/util" "github.com/stretchr/testify/require" ) @@ -36,7 +37,7 @@ func TestIPTablesCleanup(t *testing.T) { // Step 2: Run Boundary // Find project root by looking for go.mod file - projectRoot := findProjectRoot(t) + projectRoot := util.FindProjectRoot(t) // Build the boundary binary buildCmd := exec.Command("go", "build", "-o", "/tmp/boundary-test", "./cmd/...") diff --git a/e2e_tests/landjail/landjail_framework_test.go b/e2e_tests/landjail/landjail_framework_test.go new file mode 100644 index 0000000..2de8e44 --- /dev/null +++ b/e2e_tests/landjail/landjail_framework_test.go @@ -0,0 +1,258 @@ +package landjail + +import ( + "bytes" + "fmt" + "io" + "os" + "os/exec" + "testing" + "time" + + "github.com/coder/boundary/config" + "github.com/coder/boundary/e2e_tests/util" + "github.com/stretchr/testify/require" +) + +// LandjailTest is a high-level test framework for boundary e2e tests using landjail +type LandjailTest struct { + t *testing.T + projectRoot string + binaryPath string + allowedDomains []string + logLevel string + cmd *exec.Cmd + startupDelay time.Duration + // Pipes to communicate with the bash process + bashStdin io.WriteCloser + bashStdout io.ReadCloser + bashStderr io.ReadCloser +} + +// LandjailTestOption is a function that configures LandjailTest +type LandjailTestOption func(*LandjailTest) + +// NewLandjailTest creates a new LandjailTest instance +func NewLandjailTest(t *testing.T, opts ...LandjailTestOption) *LandjailTest { + projectRoot := util.FindProjectRoot(t) + binaryPath := "/tmp/boundary-landjail-test" + + lt := &LandjailTest{ + t: t, + projectRoot: projectRoot, + binaryPath: binaryPath, + allowedDomains: []string{}, + logLevel: "warn", + startupDelay: 2 * time.Second, + } + + // Apply options + for _, opt := range opts { + opt(lt) + } + + return lt +} + +// WithAllowedDomain adds an allowed domain rule +func WithLandjailAllowedDomain(domain string) LandjailTestOption { + return func(lt *LandjailTest) { + lt.allowedDomains = append(lt.allowedDomains, fmt.Sprintf("domain=%s", domain)) + } +} + +// WithAllowedRule adds a full allow rule (e.g., "method=GET domain=example.com path=/api/*") +func WithLandjailAllowedRule(rule string) LandjailTestOption { + return func(lt *LandjailTest) { + lt.allowedDomains = append(lt.allowedDomains, rule) + } +} + +// WithLogLevel sets the log level +func WithLandjailLogLevel(level string) LandjailTestOption { + return func(lt *LandjailTest) { + lt.logLevel = level + } +} + +// WithStartupDelay sets how long to wait after starting boundary before making requests +func WithLandjailStartupDelay(delay time.Duration) LandjailTestOption { + return func(lt *LandjailTest) { + lt.startupDelay = delay + } +} + +// Build builds the boundary binary +func (lt *LandjailTest) Build() *LandjailTest { + buildCmd := exec.Command("go", "build", "-o", lt.binaryPath, "./cmd/...") + buildCmd.Dir = lt.projectRoot + err := buildCmd.Run() + require.NoError(lt.t, err, "Failed to build boundary binary") + return lt +} + +// Start starts the boundary process with a bash process that reads commands from stdin +func (lt *LandjailTest) Start(command ...string) *LandjailTest { + // Build command args + args := []string{ + "--log-level", lt.logLevel, + "--jail-type", "landjail", + } + for _, domain := range lt.allowedDomains { + args = append(args, "--allow", domain) + } + args = append(args, "--") + + // Bash command that reads and executes commands from stdin + // Each command should end with a newline, and we use a marker to detect completion + // Using a unique marker to avoid conflicts with command output + if len(command) == 0 { + command = []string{"/bin/bash", "-c", "while IFS= read -r cmd; do if [ \"$cmd\" = \"exit\" ]; then exit 0; fi; eval \"$cmd\"; echo \"__BOUNDARY_CMD_DONE__\"; done"} + } + args = append(args, command...) + + lt.cmd = exec.Command(lt.binaryPath, args...) + + // Capture pipes for communication with bash + var err error + lt.bashStdin, err = lt.cmd.StdinPipe() + require.NoError(lt.t, err, "Failed to create stdin pipe for landjail") + + lt.bashStdout, err = lt.cmd.StdoutPipe() + require.NoError(lt.t, err, "Failed to create stdout pipe for landjail") + + lt.bashStderr, err = lt.cmd.StderrPipe() + require.NoError(lt.t, err, "Failed to create stderr pipe for landjail") + + // Forward stderr to os.Stderr for debugging + go io.Copy(os.Stderr, lt.bashStderr) //nolint:errcheck + + err = lt.cmd.Start() + require.NoError(lt.t, err, "Failed to start boundary process with landjail") + + // Wait for boundary to start + time.Sleep(lt.startupDelay) + + return lt +} + +// Stop gracefully stops the boundary process +func (lt *LandjailTest) Stop() { + if lt.cmd == nil || lt.cmd.Process == nil { + return + } + + // Send "exit" command to bash, then close stdin + if lt.bashStdin != nil { + _, _ = lt.bashStdin.Write([]byte("exit\n")) + lt.bashStdin.Close() //nolint:errcheck + } + + time.Sleep(1 * time.Second) + + // Wait for process to finish + if lt.cmd != nil { + err := lt.cmd.Wait() + if err != nil { + lt.t.Logf("Boundary process finished with error: %v", err) + } + } + + // Close pipes if they're still open + if lt.bashStdout != nil { + lt.bashStdout.Close() //nolint:errcheck + } + if lt.bashStderr != nil { + lt.bashStderr.Close() //nolint:errcheck + } + + // Clean up binary + err := os.Remove(lt.binaryPath) + if err != nil { + lt.t.Logf("Failed to remove boundary binary: %v", err) + } +} + +// ExpectAllowed makes an HTTP/HTTPS request and expects it to be allowed with the given response body +func (lt *LandjailTest) ExpectAllowed(url string, expectedBody string) { + lt.t.Helper() + output := lt.makeRequest(url) + require.Equal(lt.t, expectedBody, string(output), "Expected response body does not match") +} + +// ExpectAllowedContains makes an HTTP/HTTPS request and expects it to be allowed, checking that response contains the given text +func (lt *LandjailTest) ExpectAllowedContains(url string, containsText string) { + lt.t.Helper() + output := lt.makeRequest(url) + require.Contains(lt.t, string(output), containsText, "Response does not contain expected text") +} + +// ExpectDeny makes an HTTP/HTTPS request and expects it to be denied +func (lt *LandjailTest) ExpectDeny(url string) { + lt.t.Helper() + output := lt.makeRequest(url) + require.Contains(lt.t, string(output), "Request Blocked by Boundary", "Expected request to be blocked") +} + +// ExpectDenyContains makes an HTTP/HTTPS request and expects it to be denied, checking that response contains the given text +func (lt *LandjailTest) ExpectDenyContains(url string, containsText string) { + lt.t.Helper() + output := lt.makeRequest(url) + require.Contains(lt.t, string(output), containsText, "Response does not contain expected denial text") +} + +// makeRequest executes a curl command in the landjail bash process +// Always sets SSL_CERT_FILE for HTTPS support (harmless for HTTP requests) +func (lt *LandjailTest) makeRequest(url string) []byte { + lt.t.Helper() + + if lt.bashStdin == nil || lt.bashStdout == nil { + lt.t.Fatalf("landjail pipes not initialized") + } + + userInfo := config.GetUserInfo() + proxyURL := fmt.Sprintf("http://localhost:%d", 8080) // Default proxy port + + // Build curl command with SSL_CERT_FILE and proxy environment variables + curlCmd := fmt.Sprintf("env SSL_CERT_FILE=%s HTTP_PROXY=%s HTTPS_PROXY=%s http_proxy=%s https_proxy=%s curl -sS %s\n", + userInfo.CACertPath(), proxyURL, proxyURL, proxyURL, proxyURL, url) + + // Write command to stdin + _, err := lt.bashStdin.Write([]byte(curlCmd)) + require.NoError(lt.t, err, "Failed to write command to landjail stdin") + + // Read output until we see the completion marker + var output bytes.Buffer + doneMarker := []byte("__BOUNDARY_CMD_DONE__") + buf := make([]byte, 4096) + + for { + n, err := lt.bashStdout.Read(buf) + if n > 0 { + // Check if we've received the completion marker + data := buf[:n] + if idx := bytes.Index(data, doneMarker); idx != -1 { + // Found the marker, add everything before it to output + output.Write(data[:idx]) + // Remove the marker and newline + remaining := data[idx+len(doneMarker):] + if len(remaining) > 0 && remaining[0] == '\n' { + remaining = remaining[1:] + } + if len(remaining) > 0 { + output.Write(remaining) + } + break + } + output.Write(data) + } + if err == io.EOF { + break + } + if err != nil { + lt.t.Fatalf("Failed to read from landjail stdout: %v", err) + } + } + + return output.Bytes() +} diff --git a/e2e_tests/landjail/landjail_test.go b/e2e_tests/landjail/landjail_test.go new file mode 100644 index 0000000..900a85e --- /dev/null +++ b/e2e_tests/landjail/landjail_test.go @@ -0,0 +1,47 @@ +package landjail + +import ( + "testing" +) + +func TestLandjail(t *testing.T) { + // Create and configure landjail test + lt := NewLandjailTest(t, + WithLandjailAllowedDomain("dev.coder.com"), + WithLandjailAllowedDomain("jsonplaceholder.typicode.com"), + WithLandjailLogLevel("debug"), + ). + Build(). + Start() + + // Ensure cleanup + defer lt.Stop() + + // Test allowed HTTP request + t.Run("HTTPRequestThroughBoundary", func(t *testing.T) { + expectedResponse := `{ + "userId": 1, + "id": 1, + "title": "delectus aut autem", + "completed": false +}` + lt.ExpectAllowed("http://jsonplaceholder.typicode.com/todos/1", expectedResponse) + }) + + // Test allowed HTTPS request + // t.Run("HTTPSRequestThroughBoundary", func(t *testing.T) { + // expectedResponse := `{"message":"👋"} + //` + // lt.ExpectAllowed("https://dev.coder.com/api/v2", expectedResponse) + // }) + + // Test blocked HTTP request + t.Run("HTTPBlockedDomainTest", func(t *testing.T) { + lt.ExpectDeny("http://example.com") + }) + + // Test blocked HTTPS request + //t.Run("HTTPSBlockedDomainTest", func(t *testing.T) { + // lt.ExpectDeny("https://example.com") + //}) +} diff --git a/e2e_tests/nsjail/ns_jail_framework_test.go b/e2e_tests/nsjail/ns_jail_framework_test.go new file mode 100644 index 0000000..e1181a2 --- /dev/null +++ b/e2e_tests/nsjail/ns_jail_framework_test.go @@ -0,0 +1,226 @@ +package nsjail + +import ( + "bytes" + "fmt" + "io" + "os" + "os/exec" + "strconv" + "strings" + "testing" + "time" + + "github.com/coder/boundary/config" + "github.com/coder/boundary/e2e_tests/util" + "github.com/stretchr/testify/require" +) + +// NSJailTest is a high-level test framework for boundary e2e tests using nsjail +type NSJailTest struct { + t *testing.T + projectRoot string + binaryPath string + allowedDomains []string + logLevel string + cmd *exec.Cmd + pid int + startupDelay time.Duration +} + +// NSJailTestOption is a function that configures NSJailTest +type NSJailTestOption func(*NSJailTest) + +// NewNSJailTest creates a new NSJailTest instance +func NewNSJailTest(t *testing.T, opts ...NSJailTestOption) *NSJailTest { + projectRoot := util.FindProjectRoot(t) + binaryPath := "/tmp/boundary-test" + + nt := &NSJailTest{ + t: t, + projectRoot: projectRoot, + binaryPath: binaryPath, + allowedDomains: []string{}, + logLevel: "warn", + startupDelay: 2 * time.Second, + } + + // Apply options + for _, opt := range opts { + opt(nt) + } + + return nt +} + +// WithNSJailAllowedDomain adds an allowed domain rule +func WithNSJailAllowedDomain(domain string) NSJailTestOption { + return func(nt *NSJailTest) { + nt.allowedDomains = append(nt.allowedDomains, fmt.Sprintf("domain=%s", domain)) + } +} + +// WithNSJailAllowedRule adds a full allow rule (e.g., "method=GET domain=example.com path=/api/*") +func WithNSJailAllowedRule(rule string) NSJailTestOption { + return func(nt *NSJailTest) { + nt.allowedDomains = append(nt.allowedDomains, rule) + } +} + +// WithNSJailLogLevel sets the log level +func WithNSJailLogLevel(level string) NSJailTestOption { + return func(nt *NSJailTest) { + nt.logLevel = level + } +} + +// WithNSJailStartupDelay sets how long to wait after starting boundary before making requests +func WithNSJailStartupDelay(delay time.Duration) NSJailTestOption { + return func(nt *NSJailTest) { + nt.startupDelay = delay + } +} + +// Build builds the boundary binary +func (nt *NSJailTest) Build() *NSJailTest { + buildCmd := exec.Command("go", "build", "-o", nt.binaryPath, "./cmd/...") + buildCmd.Dir = nt.projectRoot + err := buildCmd.Run() + require.NoError(nt.t, err, "Failed to build boundary binary") + return nt +} + +// Start starts the boundary process with a long-running command +func (nt *NSJailTest) Start(command ...string) *NSJailTest { + if len(command) == 0 { + // Default: sleep for a long time to keep the process alive + command = []string{"/bin/bash", "-c", "/usr/bin/sleep 100 && /usr/bin/echo 'Root boundary process exited'"} + } + + // Build command args + args := []string{ + "--log-level", nt.logLevel, + "--jail-type", "nsjail", + } + for _, domain := range nt.allowedDomains { + args = append(args, "--allow", domain) + } + args = append(args, "--") + args = append(args, command...) + + nt.cmd = exec.Command(nt.binaryPath, args...) + nt.cmd.Stdin = os.Stdin + + stdout, _ := nt.cmd.StdoutPipe() + stderr, _ := nt.cmd.StderrPipe() + go io.Copy(os.Stdout, stdout) //nolint:errcheck + go io.Copy(os.Stderr, stderr) //nolint:errcheck + + err := nt.cmd.Start() + require.NoError(nt.t, err, "Failed to start boundary process") + + // Wait for boundary to start + time.Sleep(nt.startupDelay) + + // Get the child process PID + nt.pid = getTargetProcessPID(nt.t) + + return nt +} + +// Stop gracefully stops the boundary process +func (nt *NSJailTest) Stop() { + if nt.cmd == nil || nt.cmd.Process == nil { + return + } + + // Send interrupt signal + err := nt.cmd.Process.Signal(os.Interrupt) + if err != nil { + nt.t.Logf("Failed to interrupt boundary process: %v", err) + } + + time.Sleep(1 * time.Second) + + // Wait for process to finish + if nt.cmd != nil { + err = nt.cmd.Wait() + if err != nil { + nt.t.Logf("Boundary process finished with error: %v", err) + } + } + + // Clean up binary + err = os.Remove(nt.binaryPath) + if err != nil { + nt.t.Logf("Failed to remove boundary binary: %v", err) + } +} + +// ExpectAllowed makes an HTTP/HTTPS request and expects it to be allowed with the given response body +func (nt *NSJailTest) ExpectAllowed(url string, expectedBody string) { + nt.t.Helper() + output := nt.makeRequest(url) + require.Equal(nt.t, expectedBody, string(output), "Expected response body does not match") +} + +// ExpectAllowedContains makes an HTTP/HTTPS request and expects it to be allowed, checking that response contains the given text +func (nt *NSJailTest) ExpectAllowedContains(url string, containsText string) { + nt.t.Helper() + output := nt.makeRequest(url) + require.Contains(nt.t, string(output), containsText, "Response does not contain expected text") +} + +// ExpectDeny makes an HTTP/HTTPS request and expects it to be denied +func (nt *NSJailTest) ExpectDeny(url string) { + nt.t.Helper() + output := nt.makeRequest(url) + require.Contains(nt.t, string(output), "Request Blocked by Boundary", "Expected request to be blocked") +} + +// makeRequest makes an HTTP/HTTPS request from inside the namespace +// Always sets SSL_CERT_FILE for HTTPS support (harmless for HTTP requests) +func (nt *NSJailTest) makeRequest(url string) []byte { + nt.t.Helper() + + pid := fmt.Sprintf("%v", nt.pid) + userInfo := config.GetUserInfo() + + args := []string{"nsenter", "-t", pid, "-n", "--", + "env", fmt.Sprintf("SSL_CERT_FILE=%v", userInfo.CACertPath()), "curl", "-sS", url} + + curlCmd := exec.Command("sudo", args...) + + var stderr bytes.Buffer + curlCmd.Stderr = &stderr + output, err := curlCmd.Output() + + if err != nil { + nt.t.Fatalf("curl command failed: %v, stderr: %s, output: %s", err, stderr.String(), string(output)) + } + + return output +} + +// ExpectDenyContains makes an HTTP/HTTPS request and expects it to be denied, checking that response contains the given text +func (nt *NSJailTest) ExpectDenyContains(url string, containsText string) { + nt.t.Helper() + output := nt.makeRequest(url) + require.Contains(nt.t, string(output), containsText, "Response does not contain expected denial text") +} + +// getTargetProcessPID gets the PID of the boundary target process. +// Target process is associated with a network namespace, so you can exec into it, using this PID. +// pgrep -f boundary-test -n is doing two things: +// -f = match against the full command line +// -n = return the newest (most recently started) matching process +func getTargetProcessPID(t *testing.T) int { + cmd := exec.Command("pgrep", "-f", "boundary-test", "-n") + output, err := cmd.Output() + require.NoError(t, err) + + pidStr := strings.TrimSpace(string(output)) + pid, err := strconv.Atoi(pidStr) + require.NoError(t, err) + return pid +} diff --git a/e2e_tests/e2e_boundary_test.go b/e2e_tests/nsjail/ns_jail_test.go similarity index 56% rename from e2e_tests/e2e_boundary_test.go rename to e2e_tests/nsjail/ns_jail_test.go index 7d2e6ac..befa293 100644 --- a/e2e_tests/e2e_boundary_test.go +++ b/e2e_tests/nsjail/ns_jail_test.go @@ -1,19 +1,19 @@ -package e2e_tests +package nsjail import "testing" -func TestE2EBoundary(t *testing.T) { - // Create and configure boundary test - bt := NewBoundaryTest(t, - WithAllowedDomain("dev.coder.com"), - WithAllowedDomain("jsonplaceholder.typicode.com"), - WithLogLevel("debug"), +func TestNamespaceJail(t *testing.T) { + // Create and configure nsjail test + nt := NewNSJailTest(t, + WithNSJailAllowedDomain("dev.coder.com"), + WithNSJailAllowedDomain("jsonplaceholder.typicode.com"), + WithNSJailLogLevel("debug"), ). Build(). Start() // Ensure cleanup - defer bt.Stop() + defer nt.Stop() // Test allowed HTTP request t.Run("HTTPRequestThroughBoundary", func(t *testing.T) { @@ -23,23 +23,23 @@ func TestE2EBoundary(t *testing.T) { "title": "delectus aut autem", "completed": false }` - bt.ExpectAllowed("http://jsonplaceholder.typicode.com/todos/1", expectedResponse) + nt.ExpectAllowed("http://jsonplaceholder.typicode.com/todos/1", expectedResponse) }) // Test allowed HTTPS request t.Run("HTTPSRequestThroughBoundary", func(t *testing.T) { expectedResponse := `{"message":"👋"} ` - bt.ExpectAllowed("https://dev.coder.com/api/v2", expectedResponse) + nt.ExpectAllowed("https://dev.coder.com/api/v2", expectedResponse) }) // Test blocked HTTP request t.Run("HTTPBlockedDomainTest", func(t *testing.T) { - bt.ExpectDeny("http://example.com") + nt.ExpectDeny("http://example.com") }) // Test blocked HTTPS request t.Run("HTTPSBlockedDomainTest", func(t *testing.T) { - bt.ExpectDeny("https://example.com") + nt.ExpectDeny("https://example.com") }) } diff --git a/e2e_tests/util/util.go b/e2e_tests/util/util.go new file mode 100644 index 0000000..c1b5639 --- /dev/null +++ b/e2e_tests/util/util.go @@ -0,0 +1,31 @@ +package util + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +// FindProjectRoot finds the project root by looking for go.mod file +func FindProjectRoot(t *testing.T) string { + cwd, err := os.Getwd() + require.NoError(t, err, "Failed to get current working directory") + + // Start from current directory and walk up until we find go.mod + dir := cwd + for { + goModPath := filepath.Join(dir, "go.mod") + if _, err := os.Stat(goModPath); err == nil { + return dir + } + + parent := filepath.Dir(dir) + if parent == dir { + // Reached filesystem root + t.Fatalf("Could not find go.mod file starting from %s", cwd) + } + dir = parent + } +} diff --git a/landjail/child.go b/landjail/child.go new file mode 100644 index 0000000..9fa547a --- /dev/null +++ b/landjail/child.go @@ -0,0 +1,101 @@ +package landjail + +import ( + "fmt" + "log/slog" + "os" + "os/exec" + + "github.com/coder/boundary/config" + "github.com/coder/boundary/util" + "github.com/landlock-lsm/go-landlock/landlock" +) + +type LandlockConfig struct { + // TODO(yevhenii): + // - should it be able to bind to any port? + // - should it be able to connect to any port on localhost? + // BindTCPPorts []int + ConnectTCPPorts []int +} + +func ApplyLandlockRestrictions(logger *slog.Logger, cfg LandlockConfig) error { + // Get the Landlock version which works for Kernel 6.7+ + llCfg := landlock.V4 + + // Collect our rules + var netRules []landlock.Rule + + // Add rules for TCP connections + for _, port := range cfg.ConnectTCPPorts { + logger.Debug("Adding TCP connect port", "port", port) + netRules = append(netRules, landlock.ConnectTCP(uint16(port))) + } + + err := llCfg.RestrictNet(netRules...) + if err != nil { + return fmt.Errorf("failed to apply Landlock network restrictions: %w", err) + } + + return nil +} + +func RunChild(logger *slog.Logger, config config.AppConfig) error { + landjailCfg := LandlockConfig{ + ConnectTCPPorts: []int{int(config.ProxyPort)}, + } + + err := ApplyLandlockRestrictions(logger, landjailCfg) + if err != nil { + return fmt.Errorf("failed to apply Landlock network restrictions: %v", err) + } + + // Build command + cmd := exec.Command(config.TargetCMD[0], config.TargetCMD[1:]...) + cmd.Env = getEnvsForTargetProcess(config.UserInfo.ConfigDir, config.UserInfo.CACertPath(), int(config.ProxyPort)) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + logger.Info("Executing target command", "command", config.TargetCMD) + + // Run the command - this will block until it completes + err = cmd.Run() + if err != nil { + // Check if this is a normal exit with non-zero status code + if exitError, ok := err.(*exec.ExitError); ok { + exitCode := exitError.ExitCode() + logger.Debug("Command exited with non-zero status", "exit_code", exitCode) + return fmt.Errorf("command exited with code %d", exitCode) + } + // This is an unexpected error + logger.Error("Command execution failed", "error", err) + return fmt.Errorf("command execution failed: %v", err) + } + + logger.Debug("Command completed successfully") + return nil +} + +func getEnvsForTargetProcess(configDir string, caCertPath string, httpProxyPort int) []string { + e := os.Environ() + + proxyAddr := fmt.Sprintf("http://localhost:%d", httpProxyPort) + e = util.MergeEnvs(e, map[string]string{ + //Set standard CA certificate environment variables for common tools + //This makes tools like curl, git, etc. trust our dynamically generated CA + "SSL_CERT_FILE": caCertPath, // OpenSSL/LibreSSL-based tools + "SSL_CERT_DIR": configDir, // OpenSSL certificate directory + "CURL_CA_BUNDLE": caCertPath, // curl + "GIT_SSL_CAINFO": caCertPath, // Git + "REQUESTS_CA_BUNDLE": caCertPath, // Python requests + "NODE_EXTRA_CA_CERTS": caCertPath, // Node.js + + "HTTP_PROXY": proxyAddr, + "HTTPS_PROXY": proxyAddr, + "http_proxy": proxyAddr, + "https_proxy": proxyAddr, + }) + + return e +} diff --git a/landjail/manager.go b/landjail/manager.go new file mode 100644 index 0000000..4a22ccf --- /dev/null +++ b/landjail/manager.go @@ -0,0 +1,163 @@ +package landjail + +import ( + "context" + "crypto/tls" + "fmt" + "log/slog" + "os" + "os/exec" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" + "github.com/coder/boundary/proxy" + "github.com/coder/boundary/rulesengine" +) + +type LandJail struct { + proxyServer *proxy.Server + logger *slog.Logger + config config.AppConfig +} + +func NewLandJail( + ruleEngine rulesengine.Engine, + auditor audit.Auditor, + tlsConfig *tls.Config, + logger *slog.Logger, + config config.AppConfig, +) (*LandJail, error) { + // Create proxy server + proxyServer := proxy.NewProxyServer(proxy.Config{ + HTTPPort: int(config.ProxyPort), + RuleEngine: ruleEngine, + Auditor: auditor, + Logger: logger, + TLSConfig: tlsConfig, + PprofEnabled: config.PprofEnabled, + PprofPort: int(config.PprofPort), + }) + + return &LandJail{ + config: config, + proxyServer: proxyServer, + logger: logger, + }, nil +} + +func (b *LandJail) Run(ctx context.Context) error { + b.logger.Info("Start landjail manager") + err := b.startProxy() + if err != nil { + return fmt.Errorf("failed to start landjail manager: %v", err) + } + + defer func() { + b.logger.Info("Stop landjail manager") + err := b.stopProxy() + if err != nil { + b.logger.Error("Failed to stop landjail manager", "error", err) + } + }() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + defer cancel() + err := b.RunChildProcess(os.Args) + if err != nil { + b.logger.Error("Failed to run child process", "error", err) + } + }() + + // Setup signal handling BEFORE any setup + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Wait for signal or context cancellation + select { + case sig := <-sigChan: + b.logger.Info("Received signal, shutting down...", "signal", sig) + cancel() + case <-ctx.Done(): + // Context canceled by command completion + b.logger.Info("Command completed, shutting down...") + } + + return nil +} + +func (b *LandJail) RunChildProcess(command []string) error { + childCmd := b.getChildCommand(command) + + b.logger.Debug("Executing command in boundary", "command", strings.Join(os.Args, " ")) + err := childCmd.Start() + if err != nil { + b.logger.Error("Command failed to start", "error", err) + return err + } + + b.logger.Debug("waiting on a child process to finish") + err = childCmd.Wait() + if err != nil { + // Check if this is a normal exit with non-zero status code + if exitError, ok := err.(*exec.ExitError); ok { + exitCode := exitError.ExitCode() + // Log at debug level for non-zero exits (normal behavior) + b.logger.Debug("Command exited with non-zero status", "exit_code", exitCode) + return err + } + + // This is an unexpected error (not just a non-zero exit) + b.logger.Error("Command execution failed", "error", err) + return err + } + b.logger.Debug("Command completed successfully") + + return nil +} + +func (b *LandJail) getChildCommand(command []string) *exec.Cmd { + cmd := exec.Command(command[0], command[1:]...) + cmd.Env = append(cmd.Env, "CHILD=true") + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + cmd.Stdin = os.Stdin + + cmd.SysProcAttr = &syscall.SysProcAttr{ + Pdeathsig: syscall.SIGTERM, + } + + return cmd +} + +func (b *LandJail) startProxy() error { + // Start proxy server in background + err := b.proxyServer.Start() + if err != nil { + b.logger.Error("Proxy server error", "error", err) + return err + } + + // Give proxy time to start + time.Sleep(100 * time.Millisecond) + + return nil +} + +func (b *LandJail) stopProxy() error { + // Stop proxy server + if b.proxyServer != nil { + err := b.proxyServer.Stop() + if err != nil { + b.logger.Error("Failed to stop proxy server", "error", err) + } + } + + return nil +} diff --git a/landjail/parent.go b/landjail/parent.go new file mode 100644 index 0000000..1f67e14 --- /dev/null +++ b/landjail/parent.go @@ -0,0 +1,56 @@ +package landjail + +import ( + "context" + "fmt" + "log/slog" + + "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" + "github.com/coder/boundary/rulesengine" + "github.com/coder/boundary/tls" +) + +func RunParent(ctx context.Context, logger *slog.Logger, config config.AppConfig) error { + if len(config.AllowRules) == 0 { + logger.Warn("No allow rules specified; all network traffic will be denied by default") + } + + // Parse allow rules + allowRules, err := rulesengine.ParseAllowSpecs(config.AllowRules) + if err != nil { + logger.Error("Failed to parse allow rules", "error", err) + return fmt.Errorf("failed to parse allow rules: %v", err) + } + + // Create rule engine + ruleEngine := rulesengine.NewRuleEngine(allowRules, logger) + + // Create auditor + auditor := audit.NewLogAuditor(logger) + + // Create TLS certificate manager + certManager, err := tls.NewCertificateManager(tls.Config{ + Logger: logger, + ConfigDir: config.UserInfo.ConfigDir, + Uid: config.UserInfo.Uid, + Gid: config.UserInfo.Gid, + }) + if err != nil { + logger.Error("Failed to create certificate manager", "error", err) + return fmt.Errorf("failed to create certificate manager: %v", err) + } + + // Setup TLS to get cert path for jailer + tlsConfig, err := certManager.SetupTLSAndWriteCACert() + if err != nil { + return fmt.Errorf("failed to setup TLS and CA certificate: %v", err) + } + + landjail, err := NewLandJail(ruleEngine, auditor, tlsConfig, logger, config) + if err != nil { + return fmt.Errorf("failed to create landjail: %v", err) + } + + return landjail.Run(ctx) +} diff --git a/landjail/run.go b/landjail/run.go new file mode 100644 index 0000000..c25e7f9 --- /dev/null +++ b/landjail/run.go @@ -0,0 +1,25 @@ +package landjail + +import ( + "context" + "log/slog" + "os" + + "github.com/coder/boundary/config" +) + +func isChild() bool { + return os.Getenv("CHILD") == "true" +} + +// Run is the main entry point that determines whether to execute as a parent or child process. +// If running as a child (CHILD env var is set), it applies landlock restrictions +// and executes the target command. Otherwise, it runs as the parent process, sets up the proxy server, +// and manages the child process lifecycle. +func Run(ctx context.Context, logger *slog.Logger, config config.AppConfig) error { + if isChild() { + return RunChild(logger, config) + } + + return RunParent(ctx, logger, config) +} diff --git a/nsjail_manager/child.go b/nsjail_manager/child.go index 78f49a7..824816f 100644 --- a/nsjail_manager/child.go +++ b/nsjail_manager/child.go @@ -47,7 +47,7 @@ func waitForInterface(interfaceName string, timeout time.Duration) error { return nil } -func RunChild(logger *slog.Logger, args []string) error { +func RunChild(logger *slog.Logger, targetCMD []string) error { logger.Info("boundary CHILD process is started") vethNetJail := os.Getenv("VETH_JAIL_NAME") @@ -75,8 +75,8 @@ func RunChild(logger *slog.Logger, args []string) error { } // Program to run - bin := args[0] - args = args[1:] + bin := targetCMD[0] + args := targetCMD[1:] cmd := exec.Command(bin, args...) cmd.Stdin = os.Stdin diff --git a/nsjail_manager/nsjail/util.go b/nsjail_manager/nsjail/env.go similarity index 50% rename from nsjail_manager/nsjail/util.go rename to nsjail_manager/nsjail/env.go index 7c03e15..5533229 100644 --- a/nsjail_manager/nsjail/util.go +++ b/nsjail_manager/nsjail/env.go @@ -2,13 +2,14 @@ package nsjail import ( "os" - "strings" + + "github.com/coder/boundary/util" ) -func getEnvs(configDir string, caCertPath string) []string { +func getEnvsForTargetProcess(configDir string, caCertPath string) []string { e := os.Environ() - e = mergeEnvs(e, map[string]string{ + e = util.MergeEnvs(e, map[string]string{ // Set standard CA certificate environment variables for common tools // This makes tools like curl, git, etc. trust our dynamically generated CA "SSL_CERT_FILE": caCertPath, // OpenSSL/LibreSSL-based tools @@ -21,24 +22,3 @@ func getEnvs(configDir string, caCertPath string) []string { return e } - -func mergeEnvs(base []string, extra map[string]string) []string { - envMap := make(map[string]string) - for _, env := range base { - parts := strings.SplitN(env, "=", 2) - if len(parts) == 2 { - envMap[parts[0]] = parts[1] - } - } - - for key, value := range extra { - envMap[key] = value - } - - merged := make([]string, 0, len(envMap)) - for key, value := range envMap { - merged = append(merged, key+"="+value) - } - - return merged -} diff --git a/nsjail_manager/nsjail/jail.go b/nsjail_manager/nsjail/jail.go index ebb8664..3d1cf15 100644 --- a/nsjail_manager/nsjail/jail.go +++ b/nsjail_manager/nsjail/jail.go @@ -31,7 +31,6 @@ type LinuxJail struct { logger *slog.Logger vethHostName string // Host-side veth interface name for iptables rules vethJailName string // Jail-side veth interface name for iptables rules - commandEnv []string httpProxyPort int configDir string caCertPath string @@ -53,8 +52,6 @@ func NewLinuxJail(config Config) (*LinuxJail, error) { // installs iptables rules on the host. At this stage, the target PID and its netns // are not yet known. func (l *LinuxJail) ConfigureHost() error { - l.commandEnv = getEnvs(l.configDir, l.caCertPath) - if err := l.configureHostNetworkBeforeCmdExec(); err != nil { return err } @@ -70,7 +67,7 @@ func (l *LinuxJail) Command(command []string) *exec.Cmd { l.logger.Debug("Creating command with namespace") cmd := exec.Command(command[0], command[1:]...) - cmd.Env = l.commandEnv + cmd.Env = getEnvsForTargetProcess(l.configDir, l.caCertPath) cmd.Env = append(cmd.Env, "CHILD=true") cmd.Env = append(cmd.Env, fmt.Sprintf("VETH_JAIL_NAME=%v", l.vethJailName)) if l.configureDNSForLocalStubResolver { diff --git a/nsjail_manager/parent.go b/nsjail_manager/parent.go index a289fdf..78b95be 100644 --- a/nsjail_manager/parent.go +++ b/nsjail_manager/parent.go @@ -10,17 +10,9 @@ import ( "github.com/coder/boundary/nsjail_manager/nsjail" "github.com/coder/boundary/rulesengine" "github.com/coder/boundary/tls" - "github.com/coder/boundary/util" ) -func RunParent(ctx context.Context, logger *slog.Logger, args []string, config config.AppConfig) error { - _, uid, gid, homeDir, configDir := util.GetUserInfo() - - // Get command arguments - if len(args) == 0 { - return fmt.Errorf("no command specified") - } - +func RunParent(ctx context.Context, logger *slog.Logger, config config.AppConfig) error { if len(config.AllowRules) == 0 { logger.Warn("No allow rules specified; all network traffic will be denied by default") } @@ -41,9 +33,9 @@ func RunParent(ctx context.Context, logger *slog.Logger, args []string, config c // Create TLS certificate manager certManager, err := tls.NewCertificateManager(tls.Config{ Logger: logger, - ConfigDir: configDir, - Uid: uid, - Gid: gid, + ConfigDir: config.UserInfo.ConfigDir, + Uid: config.UserInfo.Uid, + Gid: config.UserInfo.Gid, }) if err != nil { logger.Error("Failed to create certificate manager", "error", err) @@ -51,7 +43,7 @@ func RunParent(ctx context.Context, logger *slog.Logger, args []string, config c } // Setup TLS to get cert path for jailer - tlsConfig, caCertPath, configDir, err := certManager.SetupTLSAndWriteCACert() + tlsConfig, err := certManager.SetupTLSAndWriteCACert() if err != nil { return fmt.Errorf("failed to setup TLS and CA certificate: %v", err) } @@ -60,9 +52,9 @@ func RunParent(ctx context.Context, logger *slog.Logger, args []string, config c jailer, err := nsjail.NewLinuxJail(nsjail.Config{ Logger: logger, HttpProxyPort: int(config.ProxyPort), - HomeDir: homeDir, - ConfigDir: configDir, - CACertPath: caCertPath, + HomeDir: config.UserInfo.HomeDir, + ConfigDir: config.UserInfo.ConfigDir, + CACertPath: config.UserInfo.CACertPath(), ConfigureDNSForLocalStubResolver: config.ConfigureDNSForLocalStubResolver, }) if err != nil { diff --git a/nsjail_manager/run.go b/nsjail_manager/run.go index f8efb8d..e38e431 100644 --- a/nsjail_manager/run.go +++ b/nsjail_manager/run.go @@ -16,10 +16,10 @@ func isChild() bool { // If running as a child (CHILD env var is set), it sets up networking in the namespace // and executes the target command. Otherwise, it runs as the parent process, setting up the jail, // proxy server, and managing the child process lifecycle. -func Run(ctx context.Context, logger *slog.Logger, config config.AppConfig, args []string) error { +func Run(ctx context.Context, logger *slog.Logger, config config.AppConfig) error { if isChild() { - return RunChild(logger, args) + return RunChild(logger, config.TargetCMD) } - return RunParent(ctx, logger, args, config) + return RunParent(ctx, logger, config) } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 10d61d4..1dab16c 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -145,9 +145,8 @@ func TestProxyServerBasicHTTPS(t *testing.T) { require.NoError(t, err) // Setup TLS to get cert path for jailer - tlsConfig, caCertPath, configDir, err := certManager.SetupTLSAndWriteCACert() + tlsConfig, err := certManager.SetupTLSAndWriteCACert() require.NoError(t, err) - _, _ = caCertPath, configDir // Create proxy server server := NewProxyServer(Config{ @@ -242,9 +241,8 @@ func TestProxyServerCONNECT(t *testing.T) { require.NoError(t, err) // Setup TLS to get cert path for proxy - tlsConfig, caCertPath, configDir, err := certManager.SetupTLSAndWriteCACert() + tlsConfig, err := certManager.SetupTLSAndWriteCACert() require.NoError(t, err) - _, _ = caCertPath, configDir // Create proxy server server := NewProxyServer(Config{ diff --git a/run/run.go b/run/run.go new file mode 100644 index 0000000..9353090 --- /dev/null +++ b/run/run.go @@ -0,0 +1,22 @@ +package run + +import ( + "context" + "fmt" + "log/slog" + + "github.com/coder/boundary/config" + "github.com/coder/boundary/landjail" + "github.com/coder/boundary/nsjail_manager" +) + +func Run(ctx context.Context, logger *slog.Logger, cfg config.AppConfig) error { + switch cfg.JailType { + case config.NSJailType: + return nsjail_manager.Run(ctx, logger, cfg) + case config.LandjailType: + return landjail.Run(ctx, logger, cfg) + default: + return fmt.Errorf("unknown jail type: %s", cfg.JailType) + } +} diff --git a/tls/tls.go b/tls/tls.go index d717dde..f43e3a2 100644 --- a/tls/tls.go +++ b/tls/tls.go @@ -15,6 +15,8 @@ import ( "path/filepath" "sync" "time" + + "github.com/coder/boundary/config" ) type Manager interface { @@ -61,30 +63,30 @@ func NewCertificateManager(config Config) (*CertificateManager, error) { // SetupTLSAndWriteCACert sets up TLS config and writes CA certificate to file // Returns the TLS config, CA cert path, and config directory -func (cm *CertificateManager) SetupTLSAndWriteCACert() (*tls.Config, string, string, error) { +func (cm *CertificateManager) SetupTLSAndWriteCACert() (*tls.Config, error) { // Get TLS config tlsConfig := cm.getTLSConfig() // Get CA certificate PEM caCertPEM, err := cm.getCACertPEM() if err != nil { - return nil, "", "", fmt.Errorf("failed to get CA certificate: %v", err) + return nil, fmt.Errorf("failed to get CA certificate: %v", err) } // Write CA certificate to file - caCertPath := filepath.Join(cm.configDir, "ca-cert.pem") + caCertPath := filepath.Join(cm.configDir, config.CACertName) err = os.WriteFile(caCertPath, caCertPEM, 0644) if err != nil { - return nil, "", "", fmt.Errorf("failed to write CA certificate file: %v", err) + return nil, fmt.Errorf("failed to write CA certificate file: %v", err) } - return tlsConfig, caCertPath, cm.configDir, nil + return tlsConfig, nil } // loadOrGenerateCA loads existing CA or generates a new one func (cm *CertificateManager) loadOrGenerateCA() error { - caKeyPath := filepath.Join(cm.configDir, "ca-key.pem") - caCertPath := filepath.Join(cm.configDir, "ca-cert.pem") + caKeyPath := filepath.Join(cm.configDir, config.CAKeyName) + caCertPath := filepath.Join(cm.configDir, config.CACertName) cm.logger.Debug("paths", "cm.configDir", cm.configDir, "caCertPath", caCertPath) diff --git a/util/env.go b/util/env.go new file mode 100644 index 0000000..15f2836 --- /dev/null +++ b/util/env.go @@ -0,0 +1,24 @@ +package util + +import "strings" + +func MergeEnvs(base []string, extra map[string]string) []string { + envMap := make(map[string]string) + for _, env := range base { + parts := strings.SplitN(env, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = parts[1] + } + } + + for key, value := range extra { + envMap[key] = value + } + + merged := make([]string, 0, len(envMap)) + for key, value := range envMap { + merged = append(merged, key+"="+value) + } + + return merged +}