Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions jail.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jail

import (
"context"
cryptotls "crypto/tls"
"fmt"
"log/slog"
"os/exec"
Expand All @@ -10,7 +11,6 @@ import (

"github.com/coder/jail/namespace"
"github.com/coder/jail/proxy"
"github.com/coder/jail/tls"
)

type Commander interface {
Expand All @@ -19,19 +19,23 @@ type Commander interface {
Close() error
}

type CertificateManager interface {
SetupTLSAndWriteCACert() (*cryptotls.Config, string, string, error)
}

type Config struct {
RuleEngine proxy.RuleEvaluator
Auditor proxy.Auditor
CertManager *tls.CertificateManager
CertManager CertificateManager
Logger *slog.Logger
}

type Jail struct {
commandExecutor Commander
proxyServer *proxy.ProxyServer
logger *slog.Logger
ctx context.Context
cancel context.CancelFunc
commander Commander
proxyServer *proxy.ProxyServer
logger *slog.Logger
ctx context.Context
cancel context.CancelFunc
}

func New(ctx context.Context, config Config) (*Jail, error) {
Expand Down Expand Up @@ -75,17 +79,17 @@ func New(ctx context.Context, config Config) (*Jail, error) {
ctx, cancel := context.WithCancel(ctx)

return &Jail{
commandExecutor: commander,
proxyServer: proxyServer,
logger: config.Logger,
ctx: ctx,
cancel: cancel,
commander: commander,
proxyServer: proxyServer,
logger: config.Logger,
ctx: ctx,
cancel: cancel,
}, nil
}

func (j *Jail) Start() error {
// Open the command executor (network namespace)
err := j.commandExecutor.Start()
err := j.commander.Start()
if err != nil {
return fmt.Errorf("failed to open command executor: %v", err)
}
Expand All @@ -105,7 +109,7 @@ func (j *Jail) Start() error {
}

func (j *Jail) Command(command []string) *exec.Cmd {
return j.commandExecutor.Command(command)
return j.commander.Command(command)
}

func (j *Jail) Close() error {
Expand All @@ -118,7 +122,7 @@ func (j *Jail) Close() error {
}

// Close command executor
return j.commandExecutor.Close()
return j.commander.Close()
}

// newCommander creates a new NetJail instance for the current platform
Expand Down
36 changes: 18 additions & 18 deletions tls/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,6 @@ func NewCertificateManager(logger *slog.Logger) (*CertificateManager, error) {
return cm, nil
}

// GetTLSConfig returns a TLS config that generates certificates on-demand
func (cm *CertificateManager) GetTLSConfig() *tls.Config {
return &tls.Config{
GetCertificate: cm.getCertificate,
MinVersion: tls.VersionTLS12,
}
}

// GetCACertPEM returns the CA certificate in PEM format
func (cm *CertificateManager) GetCACertPEM() ([]byte, error) {
return pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: cm.caCert.Raw,
}), nil
}

// SetupTLSAndWriteCACert sets up TLS config and writes CA certificate to file
// Returns the TLS config, CA cert path, and config directory
func (cm *CertificateManager) SetupTLSAndWriteCACert() (*tls.Config, string, string, error) {
Expand All @@ -77,10 +61,10 @@ func (cm *CertificateManager) SetupTLSAndWriteCACert() (*tls.Config, string, str
}

// Get TLS config
tlsConfig := cm.GetTLSConfig()
tlsConfig := cm.getTLSConfig()

// Get CA certificate PEM
caCertPEM, err := cm.GetCACertPEM()
caCertPEM, err := cm.getCACertPEM()
if err != nil {
return nil, "", "", fmt.Errorf("failed to get CA certificate: %v", err)
}
Expand Down Expand Up @@ -111,6 +95,22 @@ func (cm *CertificateManager) loadOrGenerateCA() error {
return cm.generateCA(caKeyPath, caCertPath)
}

// getTLSConfig returns a TLS config that generates certificates on-demand
func (cm *CertificateManager) getTLSConfig() *tls.Config {
return &tls.Config{
GetCertificate: cm.getCertificate,
MinVersion: tls.VersionTLS12,
}
}

// getCACertPEM returns the CA certificate in PEM format
func (cm *CertificateManager) getCACertPEM() ([]byte, error) {
return pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: cm.caCert.Raw,
}), nil
}

// loadExistingCA attempts to load existing CA files
func (cm *CertificateManager) loadExistingCA(keyPath, certPath string) bool {
// Check if files exist
Expand Down
Loading