diff --git a/jail.go b/jail.go index 6f27ee4..9b77aa7 100644 --- a/jail.go +++ b/jail.go @@ -2,6 +2,7 @@ package jail import ( "context" + cryptotls "crypto/tls" "fmt" "log/slog" "os/exec" @@ -10,7 +11,6 @@ import ( "github.com/coder/jail/namespace" "github.com/coder/jail/proxy" - "github.com/coder/jail/tls" ) type Commander interface { @@ -19,19 +19,23 @@ type Commander interface { Close() error } +type CertificateManager interface { + SetupTLSAndWriteCACert() (*cryptotls.Config, string, string, error) +} + type Config struct { RuleEngine proxy.RuleEvaluator Auditor proxy.Auditor - CertManager *tls.CertificateManager + CertManager CertificateManager Logger *slog.Logger } type Jail struct { - commandExecutor Commander - proxyServer *proxy.ProxyServer - logger *slog.Logger - ctx context.Context - cancel context.CancelFunc + commander Commander + proxyServer *proxy.ProxyServer + logger *slog.Logger + ctx context.Context + cancel context.CancelFunc } func New(ctx context.Context, config Config) (*Jail, error) { @@ -75,17 +79,17 @@ func New(ctx context.Context, config Config) (*Jail, error) { ctx, cancel := context.WithCancel(ctx) return &Jail{ - commandExecutor: commander, - proxyServer: proxyServer, - logger: config.Logger, - ctx: ctx, - cancel: cancel, + commander: commander, + proxyServer: proxyServer, + logger: config.Logger, + ctx: ctx, + cancel: cancel, }, nil } func (j *Jail) Start() error { // Open the command executor (network namespace) - err := j.commandExecutor.Start() + err := j.commander.Start() if err != nil { return fmt.Errorf("failed to open command executor: %v", err) } @@ -105,7 +109,7 @@ func (j *Jail) Start() error { } func (j *Jail) Command(command []string) *exec.Cmd { - return j.commandExecutor.Command(command) + return j.commander.Command(command) } func (j *Jail) Close() error { @@ -118,7 +122,7 @@ func (j *Jail) Close() error { } // Close command executor - return j.commandExecutor.Close() + return j.commander.Close() } // newCommander creates a new NetJail instance for the current platform diff --git a/tls/tls.go b/tls/tls.go index b529a54..002c9d5 100644 --- a/tls/tls.go +++ b/tls/tls.go @@ -51,22 +51,6 @@ func NewCertificateManager(logger *slog.Logger) (*CertificateManager, error) { return cm, nil } -// GetTLSConfig returns a TLS config that generates certificates on-demand -func (cm *CertificateManager) GetTLSConfig() *tls.Config { - return &tls.Config{ - GetCertificate: cm.getCertificate, - MinVersion: tls.VersionTLS12, - } -} - -// GetCACertPEM returns the CA certificate in PEM format -func (cm *CertificateManager) GetCACertPEM() ([]byte, error) { - return pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: cm.caCert.Raw, - }), nil -} - // SetupTLSAndWriteCACert sets up TLS config and writes CA certificate to file // Returns the TLS config, CA cert path, and config directory func (cm *CertificateManager) SetupTLSAndWriteCACert() (*tls.Config, string, string, error) { @@ -77,10 +61,10 @@ func (cm *CertificateManager) SetupTLSAndWriteCACert() (*tls.Config, string, str } // Get TLS config - tlsConfig := cm.GetTLSConfig() + tlsConfig := cm.getTLSConfig() // Get CA certificate PEM - caCertPEM, err := cm.GetCACertPEM() + caCertPEM, err := cm.getCACertPEM() if err != nil { return nil, "", "", fmt.Errorf("failed to get CA certificate: %v", err) } @@ -111,6 +95,22 @@ func (cm *CertificateManager) loadOrGenerateCA() error { return cm.generateCA(caKeyPath, caCertPath) } +// getTLSConfig returns a TLS config that generates certificates on-demand +func (cm *CertificateManager) getTLSConfig() *tls.Config { + return &tls.Config{ + GetCertificate: cm.getCertificate, + MinVersion: tls.VersionTLS12, + } +} + +// getCACertPEM returns the CA certificate in PEM format +func (cm *CertificateManager) getCACertPEM() ([]byte, error) { + return pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cm.caCert.Raw, + }), nil +} + // loadExistingCA attempts to load existing CA files func (cm *CertificateManager) loadExistingCA(keyPath, certPath string) bool { // Check if files exist