diff --git a/cmd/cli/commands/install-runner.go b/cmd/cli/commands/install-runner.go index c830333c..a3640c1b 100644 --- a/cmd/cli/commands/install-runner.go +++ b/cmd/cli/commands/install-runner.go @@ -143,7 +143,9 @@ func ensureStandaloneRunnerAvailable(ctx context.Context, printer standalone.Sta port = standalone.DefaultControllerPortCloud environment = "cloud" } - if err := standalone.CreateControllerContainer(ctx, dockerClient, port, host, environment, false, gpu, "", modelStorageVolume, printer, engineKind, debug, false, ""); err != nil { + // TLS is disabled by default for auto-installation + tlsOpts := standalone.TLSOptions{Enabled: false} + if err := standalone.CreateControllerContainer(ctx, dockerClient, port, host, environment, false, gpu, "", modelStorageVolume, printer, engineKind, debug, false, "", tlsOpts); err != nil { return nil, fmt.Errorf("unable to initialize standalone model runner container: %w", err) } @@ -225,6 +227,10 @@ type runnerOptions struct { pullImage bool pruneContainers bool proxyCert string + tls bool + tlsPort uint16 + tlsCert string + tlsKey string } // runInstallOrStart is shared logic for install-runner and start-runner commands @@ -337,8 +343,17 @@ func runInstallOrStart(cmd *cobra.Command, opts runnerOptions, debug bool) error if err != nil { return fmt.Errorf("unable to initialize standalone model storage: %w", err) } + + // Build TLS options + tlsOpts := standalone.TLSOptions{ + Enabled: opts.tls, + Port: opts.tlsPort, + CertPath: opts.tlsCert, + KeyPath: opts.tlsKey, + } + // Create the model runner container. - if err := standalone.CreateControllerContainer(cmd.Context(), dockerClient, port, opts.host, environment, opts.doNotTrack, gpu, opts.backend, modelStorageVolume, asPrinter(cmd), engineKind, debug, vllmOnWSL, opts.proxyCert); err != nil { + if err := standalone.CreateControllerContainer(cmd.Context(), dockerClient, port, opts.host, environment, opts.doNotTrack, gpu, opts.backend, modelStorageVolume, asPrinter(cmd), engineKind, debug, vllmOnWSL, opts.proxyCert, tlsOpts); err != nil { return fmt.Errorf("unable to initialize standalone model runner container: %w", err) } @@ -354,6 +369,10 @@ func newInstallRunner() *cobra.Command { var doNotTrack bool var debug bool var proxyCert string + var tlsEnabled bool + var tlsPort uint16 + var tlsCert string + var tlsKey string c := &cobra.Command{ Use: "install-runner", Short: "Install Docker Model Runner (Docker Engine only)", @@ -367,6 +386,10 @@ func newInstallRunner() *cobra.Command { pullImage: true, pruneContainers: false, proxyCert: proxyCert, + tls: tlsEnabled, + tlsPort: tlsPort, + tlsCert: tlsCert, + tlsKey: tlsKey, }, debug) }, ValidArgsFunction: completion.NoComplete, @@ -379,6 +402,10 @@ func newInstallRunner() *cobra.Command { DoNotTrack: &doNotTrack, Debug: &debug, ProxyCert: &proxyCert, + TLS: &tlsEnabled, + TLSPort: &tlsPort, + TLSCert: &tlsCert, + TLSKey: &tlsKey, }) return c } diff --git a/cmd/cli/commands/reinstall-runner.go b/cmd/cli/commands/reinstall-runner.go index 01798f5a..cec28b38 100644 --- a/cmd/cli/commands/reinstall-runner.go +++ b/cmd/cli/commands/reinstall-runner.go @@ -13,6 +13,10 @@ func newReinstallRunner() *cobra.Command { var doNotTrack bool var debug bool var proxyCert string + var tlsEnabled bool + var tlsPort uint16 + var tlsCert string + var tlsKey string c := &cobra.Command{ Use: "reinstall-runner", Short: "Reinstall Docker Model Runner (Docker Engine only)", @@ -26,6 +30,10 @@ func newReinstallRunner() *cobra.Command { pullImage: true, pruneContainers: true, proxyCert: proxyCert, + tls: tlsEnabled, + tlsPort: tlsPort, + tlsCert: tlsCert, + tlsKey: tlsKey, }, debug) }, ValidArgsFunction: completion.NoComplete, @@ -38,6 +46,10 @@ func newReinstallRunner() *cobra.Command { DoNotTrack: &doNotTrack, Debug: &debug, ProxyCert: &proxyCert, + TLS: &tlsEnabled, + TLSPort: &tlsPort, + TLSCert: &tlsCert, + TLSKey: &tlsKey, }) return c } diff --git a/cmd/cli/commands/start-runner.go b/cmd/cli/commands/start-runner.go index 8cab9749..61bb4b0a 100644 --- a/cmd/cli/commands/start-runner.go +++ b/cmd/cli/commands/start-runner.go @@ -7,33 +7,48 @@ import ( func newStartRunner() *cobra.Command { var port uint16 + var host string var gpuMode string var backend string var doNotTrack bool var debug bool var proxyCert string + var tlsEnabled bool + var tlsPort uint16 + var tlsCert string + var tlsKey string c := &cobra.Command{ Use: "start-runner", Short: "Start Docker Model Runner (Docker Engine only)", RunE: func(cmd *cobra.Command, args []string) error { return runInstallOrStart(cmd, runnerOptions{ port: port, + host: host, gpuMode: gpuMode, backend: backend, doNotTrack: doNotTrack, pullImage: false, proxyCert: proxyCert, + tls: tlsEnabled, + tlsPort: tlsPort, + tlsCert: tlsCert, + tlsKey: tlsKey, }, debug) }, ValidArgsFunction: completion.NoComplete, } addRunnerFlags(c, runnerFlagOptions{ Port: &port, + Host: &host, GpuMode: &gpuMode, Backend: &backend, DoNotTrack: &doNotTrack, Debug: &debug, ProxyCert: &proxyCert, + TLS: &tlsEnabled, + TLSPort: &tlsPort, + TLSCert: &tlsCert, + TLSKey: &tlsKey, }) return c } diff --git a/cmd/cli/commands/utils.go b/cmd/cli/commands/utils.go index c136d673..6844887a 100644 --- a/cmd/cli/commands/utils.go +++ b/cmd/cli/commands/utils.go @@ -195,6 +195,10 @@ type runnerFlagOptions struct { DoNotTrack *bool Debug *bool ProxyCert *string + TLS *bool + TLSPort *uint16 + TLSCert *string + TLSKey *string } // addRunnerFlags adds common runner flags to a command @@ -221,6 +225,19 @@ func addRunnerFlags(cmd *cobra.Command, opts runnerFlagOptions) { if opts.ProxyCert != nil { cmd.Flags().StringVar(opts.ProxyCert, "proxy-cert", "", "Path to a CA certificate file for proxy SSL inspection") } + if opts.TLS != nil { + cmd.Flags().BoolVar(opts.TLS, "tls", false, "Enable TLS/HTTPS for Docker Model Runner API") + } + if opts.TLSPort != nil { + cmd.Flags().Uint16Var(opts.TLSPort, "tls-port", 0, + "TLS port for Docker Model Runner (default: 12444 for Docker Engine, 12445 for Cloud mode)") + } + if opts.TLSCert != nil { + cmd.Flags().StringVar(opts.TLSCert, "tls-cert", "", "Path to TLS certificate file (auto-generated if not provided)") + } + if opts.TLSKey != nil { + cmd.Flags().StringVar(opts.TLSKey, "tls-key", "", "Path to TLS private key file (auto-generated if not provided)") + } } // newTable creates a new table with Docker CLI-style formatting: diff --git a/cmd/cli/desktop/context.go b/cmd/cli/desktop/context.go index c92121a8..badb40f7 100644 --- a/cmd/cli/desktop/context.go +++ b/cmd/cli/desktop/context.go @@ -19,6 +19,7 @@ import ( "github.com/docker/model-runner/cmd/cli/pkg/standalone" "github.com/docker/model-runner/cmd/cli/pkg/types" "github.com/docker/model-runner/pkg/inference" + modeltls "github.com/docker/model-runner/pkg/tls" "github.com/moby/moby/client" ) @@ -102,6 +103,12 @@ type ModelRunnerContext struct { // For internal Docker Model Runner, this is "/engines/v1". // For external OpenAI-compatible endpoints, this is empty (the URL already includes the version path). openaiPathPrefix string + // useTLS indicates whether TLS is being used for connections. + useTLS bool + // tlsURLPrefix is the TLS URL prefix (if TLS is enabled). + tlsURLPrefix *url.URL + // tlsClient is the TLS-enabled HTTP client (if TLS is enabled). + tlsClient DockerHttpClient } // NewContextForMock is a ModelRunnerContext constructor exposed only for the @@ -216,6 +223,11 @@ func DetectContext(ctx context.Context, cli *command.DockerCli, printer standalo // testing purposes. treatDesktopAsMoby := os.Getenv("_MODEL_RUNNER_TREAT_DESKTOP_AS_MOBY") == "1" + // Check if TLS should be used + useTLS := os.Getenv("MODEL_RUNNER_TLS") == "true" + tlsSkipVerify := os.Getenv("MODEL_RUNNER_TLS_SKIP_VERIFY") == "true" + tlsCACert := os.Getenv("MODEL_RUNNER_TLS_CA_CERT") + // Detect the associated engine type. kind := types.ModelRunnerEngineKindMoby if modelRunnerHost != "" { @@ -238,15 +250,42 @@ func DetectContext(ctx context.Context, cli *command.DockerCli, printer standalo // Compute the URL prefix based on the associated engine kind. var rawURLPrefix string + var rawTLSURLPrefix string switch kind { case types.ModelRunnerEngineKindMoby: rawURLPrefix = "http://localhost:" + strconv.Itoa(standalone.DefaultControllerPortMoby) + rawTLSURLPrefix = "https://localhost:" + strconv.Itoa(standalone.DefaultTLSPortMoby) case types.ModelRunnerEngineKindCloud: rawURLPrefix = "http://localhost:" + strconv.Itoa(standalone.DefaultControllerPortCloud) + rawTLSURLPrefix = "https://localhost:" + strconv.Itoa(standalone.DefaultTLSPortCloud) case types.ModelRunnerEngineKindMobyManual: - rawURLPrefix = modelRunnerHost + normalizedHost := modelRunnerHost + + // Ensure the manual host has a scheme. + // Default to https when TLS is requested, otherwise http. + if !strings.HasPrefix(normalizedHost, "http://") && !strings.HasPrefix(normalizedHost, "https://") { + if useTLS { + normalizedHost = "https://" + normalizedHost + } else { + normalizedHost = "http://" + normalizedHost + } + } + + rawURLPrefix = normalizedHost + + // Derive TLS URL from the normalized host, ensuring https when TLS is enabled. + if useTLS { + if strings.HasPrefix(normalizedHost, "http://") { + rawTLSURLPrefix = "https://" + strings.TrimPrefix(normalizedHost, "http://") + } else { + rawTLSURLPrefix = normalizedHost + } + } else { + rawTLSURLPrefix = normalizedHost + } case types.ModelRunnerEngineKindDesktop: rawURLPrefix = "http://localhost" + inference.ExperimentalEndpointsPrefix + rawTLSURLPrefix = rawURLPrefix // TLS not typically used with Desktop if IsDesktopWSLContext(ctx, cli) { dockerClient, err := DockerClientForContext(cli, cli.CurrentContext()) if err != nil { @@ -257,6 +296,7 @@ func DetectContext(ctx context.Context, cli *command.DockerCli, printer standalo containerID, _, _, err := standalone.FindControllerContainer(ctx, dockerClient) if err == nil && containerID != "" { rawURLPrefix = "http://localhost:" + strconv.Itoa(standalone.DefaultControllerPortMoby) + rawTLSURLPrefix = "https://localhost:" + strconv.Itoa(standalone.DefaultTLSPortMoby) kind = types.ModelRunnerEngineKindMoby } } @@ -266,28 +306,72 @@ func DetectContext(ctx context.Context, cli *command.DockerCli, printer standalo return nil, fmt.Errorf("invalid model runner URL (%s): %w", rawURLPrefix, err) } + var tlsURLPrefix *url.URL + if useTLS { + tlsURLPrefix, err = url.Parse(rawTLSURLPrefix) + if err != nil { + return nil, fmt.Errorf("invalid model runner TLS URL (%s): %w", rawTLSURLPrefix, err) + } + + // Validate that TLS URL uses HTTPS when TLS is enabled + if tlsURLPrefix.Scheme != "https" { + return nil, fmt.Errorf("TLS requested but URL scheme is not HTTPS: %s", rawTLSURLPrefix) + } + } + // Construct the HTTP client. - var client DockerHttpClient + var httpClient DockerHttpClient if kind == types.ModelRunnerEngineKindDesktop { dockerClient, err := DockerClientForContext(cli, cli.CurrentContext()) if err != nil { return nil, fmt.Errorf("unable to create model runner client: %w", err) } - client = dockerClient.HTTPClient() + httpClient = dockerClient.HTTPClient() } else { - client = http.DefaultClient + httpClient = http.DefaultClient } if userAgent := os.Getenv("USER_AGENT"); userAgent != "" { - setUserAgent(client, userAgent) + setUserAgent(httpClient, userAgent) + } + + // Construct TLS client if TLS is enabled + var tlsClient DockerHttpClient + if useTLS { + if kind == types.ModelRunnerEngineKindDesktop { + // For Desktop context, if TLS is enabled, we should either fully support it or fail fast + // Since Desktop context uses Docker client, we need to handle TLS differently + // For now, we'll fail fast to make the behavior clear + return nil, fmt.Errorf("TLS is not supported for Desktop contexts") + } + + tlsConfig, err := modeltls.LoadClientTLSConfig(tlsCACert, tlsSkipVerify) + if err != nil { + return nil, fmt.Errorf("unable to load TLS configuration: %w", err) + } + + tlsTransport := &http.Transport{ + TLSClientConfig: tlsConfig, + Proxy: http.ProxyFromEnvironment, + } + tlsClient = &http.Client{ + Transport: tlsTransport, + } + + if userAgent := os.Getenv("USER_AGENT"); userAgent != "" { + setUserAgent(tlsClient, userAgent) + } } // Success. return &ModelRunnerContext{ kind: kind, urlPrefix: urlPrefix, - client: client, + client: httpClient, openaiPathPrefix: inference.InferencePrefix + "/v1", + useTLS: useTLS, + tlsURLPrefix: tlsURLPrefix, + tlsClient: tlsClient, }, nil } @@ -297,9 +381,45 @@ func (c *ModelRunnerContext) EngineKind() types.ModelRunnerEngineKind { } // URL constructs a URL string appropriate for the model runner. +// If TLS is enabled, returns the TLS URL. func (c *ModelRunnerContext) URL(path string) string { + prefix := c.urlPrefix + if c.useTLS && c.tlsURLPrefix != nil { + prefix = c.tlsURLPrefix + } + return c.buildURL(prefix, path) +} + +// Client returns an HTTP client appropriate for accessing the model runner. +// If TLS is enabled, returns the TLS client. +func (c *ModelRunnerContext) Client() DockerHttpClient { + if c.useTLS && c.tlsClient != nil { + return c.tlsClient + } + return c.client +} + +// UseTLS returns whether TLS is enabled for this context. +func (c *ModelRunnerContext) UseTLS() bool { + return c.useTLS +} + +// TLSURL constructs a TLS URL string for the model runner. +// Returns an empty string if TLS is not enabled. +func (c *ModelRunnerContext) TLSURL(path string) string { + if c.tlsURLPrefix == nil { + return "" + } + return c.buildURL(c.tlsURLPrefix, path) +} + +// buildURL constructs a URL string from a prefix and path, handling query parameters. +func (c *ModelRunnerContext) buildURL(prefix *url.URL, path string) string { + if prefix == nil { + return "" + } components := strings.Split(path, "?") - result := c.urlPrefix.JoinPath(components[0]).String() + result := prefix.JoinPath(components[0]).String() if len(components) > 1 { components[0] = result result = strings.Join(components, "?") @@ -307,9 +427,9 @@ func (c *ModelRunnerContext) URL(path string) string { return result } -// Client returns an HTTP client appropriate for accessing the model runner. -func (c *ModelRunnerContext) Client() DockerHttpClient { - return c.client +// TLSClient returns the TLS HTTP client, or nil if TLS is not enabled. +func (c *ModelRunnerContext) TLSClient() DockerHttpClient { + return c.tlsClient } // OpenAIPathPrefix returns the path prefix for OpenAI-compatible endpoints. diff --git a/cmd/cli/docs/reference/docker_model_install-runner.yaml b/cmd/cli/docs/reference/docker_model_install-runner.yaml index 0da2321c..d8640fd8 100644 --- a/cmd/cli/docs/reference/docker_model_install-runner.yaml +++ b/cmd/cli/docs/reference/docker_model_install-runner.yaml @@ -75,6 +75,45 @@ options: experimentalcli: false kubernetes: false swarm: false + - option: tls + value_type: bool + default_value: "false" + description: Enable TLS/HTTPS for Docker Model Runner API + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false + - option: tls-cert + value_type: string + description: Path to TLS certificate file (auto-generated if not provided) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false + - option: tls-key + value_type: string + description: Path to TLS private key file (auto-generated if not provided) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false + - option: tls-port + value_type: uint16 + default_value: "0" + description: | + TLS port for Docker Model Runner (default: 12444 for Docker Engine, 12445 for Cloud mode) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false deprecated: false hidden: false experimental: false diff --git a/cmd/cli/docs/reference/docker_model_reinstall-runner.yaml b/cmd/cli/docs/reference/docker_model_reinstall-runner.yaml index 213bf27a..28b56666 100644 --- a/cmd/cli/docs/reference/docker_model_reinstall-runner.yaml +++ b/cmd/cli/docs/reference/docker_model_reinstall-runner.yaml @@ -75,6 +75,45 @@ options: experimentalcli: false kubernetes: false swarm: false + - option: tls + value_type: bool + default_value: "false" + description: Enable TLS/HTTPS for Docker Model Runner API + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false + - option: tls-cert + value_type: string + description: Path to TLS certificate file (auto-generated if not provided) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false + - option: tls-key + value_type: string + description: Path to TLS private key file (auto-generated if not provided) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false + - option: tls-port + value_type: uint16 + default_value: "0" + description: | + TLS port for Docker Model Runner (default: 12444 for Docker Engine, 12445 for Cloud mode) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false deprecated: false hidden: false experimental: false diff --git a/cmd/cli/docs/reference/docker_model_start-runner.yaml b/cmd/cli/docs/reference/docker_model_start-runner.yaml index 646e055f..5fa9df42 100644 --- a/cmd/cli/docs/reference/docker_model_start-runner.yaml +++ b/cmd/cli/docs/reference/docker_model_start-runner.yaml @@ -47,6 +47,16 @@ options: experimentalcli: false kubernetes: false swarm: false + - option: host + value_type: string + default_value: 127.0.0.1 + description: Host address to bind Docker Model Runner + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false - option: port value_type: uint16 default_value: "0" @@ -67,6 +77,45 @@ options: experimentalcli: false kubernetes: false swarm: false + - option: tls + value_type: bool + default_value: "false" + description: Enable TLS/HTTPS for Docker Model Runner API + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false + - option: tls-cert + value_type: string + description: Path to TLS certificate file (auto-generated if not provided) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false + - option: tls-key + value_type: string + description: Path to TLS private key file (auto-generated if not provided) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false + - option: tls-port + value_type: uint16 + default_value: "0" + description: | + TLS port for Docker Model Runner (default: 12444 for Docker Engine, 12445 for Cloud mode) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false deprecated: false hidden: false experimental: false diff --git a/cmd/cli/docs/reference/model_install-runner.md b/cmd/cli/docs/reference/model_install-runner.md index eb4ee1b6..2d9e1c01 100644 --- a/cmd/cli/docs/reference/model_install-runner.md +++ b/cmd/cli/docs/reference/model_install-runner.md @@ -14,6 +14,10 @@ Install Docker Model Runner (Docker Engine only) | `--host` | `string` | `127.0.0.1` | Host address to bind Docker Model Runner | | `--port` | `uint16` | `0` | Docker container port for Docker Model Runner (default: 12434 for Docker Engine, 12435 for Cloud mode) | | `--proxy-cert` | `string` | | Path to a CA certificate file for proxy SSL inspection | +| `--tls` | `bool` | | Enable TLS/HTTPS for Docker Model Runner API | +| `--tls-cert` | `string` | | Path to TLS certificate file (auto-generated if not provided) | +| `--tls-key` | `string` | | Path to TLS private key file (auto-generated if not provided) | +| `--tls-port` | `uint16` | `0` | TLS port for Docker Model Runner (default: 12444 for Docker Engine, 12445 for Cloud mode) | diff --git a/cmd/cli/docs/reference/model_reinstall-runner.md b/cmd/cli/docs/reference/model_reinstall-runner.md index e711f710..2ec74431 100644 --- a/cmd/cli/docs/reference/model_reinstall-runner.md +++ b/cmd/cli/docs/reference/model_reinstall-runner.md @@ -14,6 +14,10 @@ Reinstall Docker Model Runner (Docker Engine only) | `--host` | `string` | `127.0.0.1` | Host address to bind Docker Model Runner | | `--port` | `uint16` | `0` | Docker container port for Docker Model Runner (default: 12434 for Docker Engine, 12435 for Cloud mode) | | `--proxy-cert` | `string` | | Path to a CA certificate file for proxy SSL inspection | +| `--tls` | `bool` | | Enable TLS/HTTPS for Docker Model Runner API | +| `--tls-cert` | `string` | | Path to TLS certificate file (auto-generated if not provided) | +| `--tls-key` | `string` | | Path to TLS private key file (auto-generated if not provided) | +| `--tls-port` | `uint16` | `0` | TLS port for Docker Model Runner (default: 12444 for Docker Engine, 12445 for Cloud mode) | diff --git a/cmd/cli/docs/reference/model_start-runner.md b/cmd/cli/docs/reference/model_start-runner.md index 2e43a92c..4ca1acc3 100644 --- a/cmd/cli/docs/reference/model_start-runner.md +++ b/cmd/cli/docs/reference/model_start-runner.md @@ -5,14 +5,19 @@ Start Docker Model Runner (Docker Engine only) ### Options -| Name | Type | Default | Description | -|:-----------------|:---------|:--------|:-------------------------------------------------------------------------------------------------------| -| `--backend` | `string` | | Specify backend (llama.cpp\|vllm). Default: llama.cpp | -| `--debug` | `bool` | | Enable debug logging | -| `--do-not-track` | `bool` | | Do not track models usage in Docker Model Runner | -| `--gpu` | `string` | `auto` | Specify GPU support (none\|auto\|cuda\|rocm\|musa\|cann) | -| `--port` | `uint16` | `0` | Docker container port for Docker Model Runner (default: 12434 for Docker Engine, 12435 for Cloud mode) | -| `--proxy-cert` | `string` | | Path to a CA certificate file for proxy SSL inspection | +| Name | Type | Default | Description | +|:-----------------|:---------|:------------|:-------------------------------------------------------------------------------------------------------| +| `--backend` | `string` | | Specify backend (llama.cpp\|vllm). Default: llama.cpp | +| `--debug` | `bool` | | Enable debug logging | +| `--do-not-track` | `bool` | | Do not track models usage in Docker Model Runner | +| `--gpu` | `string` | `auto` | Specify GPU support (none\|auto\|cuda\|rocm\|musa\|cann) | +| `--host` | `string` | `127.0.0.1` | Host address to bind Docker Model Runner | +| `--port` | `uint16` | `0` | Docker container port for Docker Model Runner (default: 12434 for Docker Engine, 12435 for Cloud mode) | +| `--proxy-cert` | `string` | | Path to a CA certificate file for proxy SSL inspection | +| `--tls` | `bool` | | Enable TLS/HTTPS for Docker Model Runner API | +| `--tls-cert` | `string` | | Path to TLS certificate file (auto-generated if not provided) | +| `--tls-key` | `string` | | Path to TLS private key file (auto-generated if not provided) | +| `--tls-port` | `uint16` | `0` | TLS port for Docker Model Runner (default: 12444 for Docker Engine, 12445 for Cloud mode) | diff --git a/cmd/cli/pkg/standalone/containers.go b/cmd/cli/pkg/standalone/containers.go index a6ed8afa..7bcddad6 100644 --- a/cmd/cli/pkg/standalone/containers.go +++ b/cmd/cli/pkg/standalone/containers.go @@ -274,8 +274,23 @@ func tryGetBindAscendMounts(printer StatusPrinter, debug bool) []mount.Mount { // This location is used by update-ca-certificates to add the cert to the system trust store. const proxyCertContainerPath = "/usr/local/share/ca-certificates/proxy-ca.crt" +// TLSOptions holds TLS configuration for the controller container. +type TLSOptions struct { + // Enabled indicates whether TLS is enabled. + Enabled bool + // Port is the TLS port (0 to use default). + Port uint16 + // CertPath is the path to the TLS certificate file. + CertPath string + // KeyPath is the path to the TLS key file. + KeyPath string +} + +// tlsCertContainerPath is the path where TLS certificates will be mounted in the container. +const tlsCertContainerPath = "/etc/model-runner/certs" + // CreateControllerContainer creates and starts a controller container. -func CreateControllerContainer(ctx context.Context, dockerClient *client.Client, port uint16, host string, environment string, doNotTrack bool, gpu gpupkg.GPUSupport, backend string, modelStorageVolume string, printer StatusPrinter, engineKind types.ModelRunnerEngineKind, debug bool, vllmOnWSL bool, proxyCert string) error { +func CreateControllerContainer(ctx context.Context, dockerClient *client.Client, port uint16, host string, environment string, doNotTrack bool, gpu gpupkg.GPUSupport, backend string, modelStorageVolume string, printer StatusPrinter, engineKind types.ModelRunnerEngineKind, debug bool, vllmOnWSL bool, proxyCert string, tlsOpts TLSOptions) error { imageName := controllerImageName(gpu, backend) // Set up the container configuration. @@ -296,12 +311,51 @@ func CreateControllerContainer(ctx context.Context, dockerClient *client.Client, } } + // Determine TLS port + tlsPort := tlsOpts.Port + if tlsOpts.Enabled && tlsPort == 0 { + if engineKind == types.ModelRunnerEngineKindCloud { + tlsPort = DefaultTLSPortCloud + } else { + tlsPort = DefaultTLSPortMoby + } + } + + // Add TLS environment variables if TLS is enabled + if tlsOpts.Enabled { + env = append(env, "MODEL_RUNNER_TLS_ENABLED=true") + env = append(env, "MODEL_RUNNER_TLS_PORT="+strconv.Itoa(int(tlsPort))) + if tlsOpts.CertPath != "" && tlsOpts.KeyPath != "" { + // Determine the actual file names in the container + certContainerPath := tlsCertContainerPath + "/server.crt" + keyContainerPath := tlsCertContainerPath + "/server.key" + + // If cert and key are in the same directory, use their actual file names + certDir := filepath.Dir(tlsOpts.CertPath) + keyDir := filepath.Dir(tlsOpts.KeyPath) + if certDir == keyDir { + certContainerPath = tlsCertContainerPath + "/" + filepath.Base(tlsOpts.CertPath) + keyContainerPath = tlsCertContainerPath + "/" + filepath.Base(tlsOpts.KeyPath) + } + + // Use mounted certificates + env = append(env, "MODEL_RUNNER_TLS_CERT="+certContainerPath) + env = append(env, "MODEL_RUNNER_TLS_KEY="+keyContainerPath) + } + // If no cert paths, auto-cert will be used inside the container + } + + exposedPorts := nat.PortSet{ + nat.Port(portStr + "/tcp"): struct{}{}, + } + if tlsOpts.Enabled { + exposedPorts[nat.Port(strconv.Itoa(int(tlsPort))+"/tcp")] = struct{}{} + } + config := &container.Config{ - Image: imageName, - Env: env, - ExposedPorts: nat.PortSet{ - nat.Port(portStr + "/tcp"): struct{}{}, - }, + Image: imageName, + Env: env, + ExposedPorts: exposedPorts, Labels: map[string]string{ labelDesktopService: serviceModelRunner, labelRole: roleController, @@ -333,18 +387,74 @@ func CreateControllerContainer(ctx context.Context, dockerClient *client.Client, }) } - portBindings := []nat.PortBinding{{HostIP: host, HostPort: portStr}} - if os.Getenv("_MODEL_RUNNER_TREAT_DESKTOP_AS_MOBY") != "1" { - // Don't bind the bridge gateway IP if we're treating Docker Desktop as Moby. - // Only add bridge gateway IP binding if host is 127.0.0.1 and not in rootless mode - if host == "127.0.0.1" && !isRootless(ctx, dockerClient) && !vllmOnWSL { - if bridgeGatewayIP, err := determineBridgeGatewayIP(ctx, dockerClient); err == nil && bridgeGatewayIP != "" { - portBindings = append(portBindings, nat.PortBinding{HostIP: bridgeGatewayIP, HostPort: portStr}) + // Mount TLS certificates if custom paths are provided + if tlsOpts.Enabled && tlsOpts.CertPath != "" && tlsOpts.KeyPath != "" { + // Get the directory containing the certificates + certDir := filepath.Dir(tlsOpts.CertPath) + keyDir := filepath.Dir(tlsOpts.KeyPath) + + if certDir == keyDir { + // Both files in same directory, mount each file individually with their actual names + certFileName := filepath.Base(tlsOpts.CertPath) + keyFileName := filepath.Base(tlsOpts.KeyPath) + + hostConfig.Mounts = append(hostConfig.Mounts, + mount.Mount{ + Type: mount.TypeBind, + Source: tlsOpts.CertPath, + Target: tlsCertContainerPath + "/" + certFileName, + ReadOnly: true, + }, + mount.Mount{ + Type: mount.TypeBind, + Source: tlsOpts.KeyPath, + Target: tlsCertContainerPath + "/" + keyFileName, + ReadOnly: true, + }, + ) + } else { + // Files in different directories, mount each file individually + hostConfig.Mounts = append(hostConfig.Mounts, + mount.Mount{ + Type: mount.TypeBind, + Source: tlsOpts.CertPath, + Target: tlsCertContainerPath + "/server.crt", + ReadOnly: true, + }, + mount.Mount{ + Type: mount.TypeBind, + Source: tlsOpts.KeyPath, + Target: tlsCertContainerPath + "/server.key", + ReadOnly: true, + }, + ) + } + } + + // Helper function to create port bindings with optional bridge gateway IP + createPortBindings := func(port string) []nat.PortBinding { + portBindings := []nat.PortBinding{{HostIP: host, HostPort: port}} + if os.Getenv("_MODEL_RUNNER_TREAT_DESKTOP_AS_MOBY") != "1" { + // Don't bind the bridge gateway IP if we're treating Docker Desktop as Moby. + // Only add bridge gateway IP binding if host is 127.0.0.1 and not in rootless mode + if host == "127.0.0.1" && !isRootless(ctx, dockerClient) && !vllmOnWSL { + if bridgeGatewayIP, err := determineBridgeGatewayIP(ctx, dockerClient); err == nil && bridgeGatewayIP != "" { + portBindings = append(portBindings, nat.PortBinding{HostIP: bridgeGatewayIP, HostPort: port}) + } } } + return portBindings } + + // Create port bindings for the main port hostConfig.PortBindings = nat.PortMap{ - nat.Port(portStr + "/tcp"): portBindings, + nat.Port(portStr + "/tcp"): createPortBindings(portStr), + } + + // Add TLS port bindings if TLS is enabled + if tlsOpts.Enabled { + tlsPortStr := strconv.Itoa(int(tlsPort)) + hostConfig.PortBindings[nat.Port(tlsPortStr+"/tcp")] = createPortBindings(tlsPortStr) } switch gpu { case gpupkg.GPUSupportNone: diff --git a/cmd/cli/pkg/standalone/ports.go b/cmd/cli/pkg/standalone/ports.go index eea85732..92cd5f7a 100644 --- a/cmd/cli/pkg/standalone/ports.go +++ b/cmd/cli/pkg/standalone/ports.go @@ -7,4 +7,11 @@ const ( // DefaultControllerPortCloud is the default TCP port on which the // standalone controller will listen for requests in Cloud environments. DefaultControllerPortCloud = 12435 + + // DefaultTLSPortMoby is the default TCP port on which the standalone + // controller will listen for TLS requests in Moby environments. + DefaultTLSPortMoby = 12444 + // DefaultTLSPortCloud is the default TCP port on which the standalone + // controller will listen for TLS requests in Cloud environments. + DefaultTLSPortCloud = 12445 ) diff --git a/main.go b/main.go index 17826d41..08180b3c 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "crypto/tls" "net" "net/http" "os" @@ -26,9 +27,15 @@ import ( "github.com/docker/model-runner/pkg/ollama" "github.com/docker/model-runner/pkg/responses" "github.com/docker/model-runner/pkg/routing" + modeltls "github.com/docker/model-runner/pkg/tls" "github.com/sirupsen/logrus" ) +const ( + // DefaultTLSPort is the default TLS port for Moby + DefaultTLSPort = "12444" +) + var log = logrus.New() // Log is the logger used by the application, exported for testing purposes. @@ -256,6 +263,10 @@ func main() { } serverErrors := make(chan error, 1) + // TLS server (optional) + var tlsServer *http.Server + tlsServerErrors := make(chan error, 1) + // Check if we should use TCP port instead of Unix socket tcpPort := os.Getenv("MODEL_RUNNER_PORT") if tcpPort != "" { @@ -282,22 +293,92 @@ func main() { }() } + // Start TLS server if enabled + if os.Getenv("MODEL_RUNNER_TLS_ENABLED") == "true" { + tlsPort := os.Getenv("MODEL_RUNNER_TLS_PORT") + if tlsPort == "" { + tlsPort = DefaultTLSPort // Default TLS port for Moby + } + + // Get certificate paths + certPath := os.Getenv("MODEL_RUNNER_TLS_CERT") + keyPath := os.Getenv("MODEL_RUNNER_TLS_KEY") + + // Auto-generate certificates if not provided and auto-cert is not disabled + if certPath == "" || keyPath == "" { + if os.Getenv("MODEL_RUNNER_TLS_AUTO_CERT") != "false" { + log.Info("Auto-generating TLS certificates...") + var err error + certPath, keyPath, err = modeltls.EnsureCertificates("", "") + if err != nil { + log.Fatalf("Failed to ensure TLS certificates: %v", err) + } + log.Infof("Using TLS certificate: %s", certPath) + log.Infof("Using TLS key: %s", keyPath) + } else { + log.Fatal("TLS enabled but no certificate provided and auto-cert is disabled") + } + } + + // Load TLS configuration + tlsConfig, err := modeltls.LoadTLSConfig(certPath, keyPath) + if err != nil { + log.Fatalf("Failed to load TLS configuration: %v", err) + } + + tlsServer = &http.Server{ + Addr: ":" + tlsPort, + Handler: router, + TLSConfig: tlsConfig, + ReadHeaderTimeout: 10 * time.Second, + } + + log.Infof("Listening on TLS port %s", tlsPort) + go func() { + // Use ListenAndServeTLS with empty strings since TLSConfig already has the certs + ln, err := tls.Listen("tcp", tlsServer.Addr, tlsConfig) + if err != nil { + tlsServerErrors <- err + return + } + tlsServerErrors <- tlsServer.Serve(ln) + }() + } + schedulerErrors := make(chan error, 1) go func() { schedulerErrors <- scheduler.Run(ctx) }() + var tlsServerErrorsChan <-chan error + if os.Getenv("MODEL_RUNNER_TLS_ENABLED") == "true" { + tlsServerErrorsChan = tlsServerErrors + } else { + // Use a nil channel which will block forever when TLS is disabled + tlsServerErrorsChan = nil + } + select { case err := <-serverErrors: if err != nil { log.Errorf("Server error: %v", err) } + case err := <-tlsServerErrorsChan: + if err != nil { + log.Errorf("TLS server error: %v", err) + } case <-ctx.Done(): log.Infoln("Shutdown signal received") log.Infoln("Shutting down the server") if err := server.Close(); err != nil { log.Errorf("Server shutdown error: %v", err) } + if tlsServer != nil { + log.Infoln("Shutting down the TLS server") + if err := tlsServer.Close(); err != nil { + log.Errorf("TLS server shutdown error: %v", err) + } + } log.Infoln("Waiting for the scheduler to stop") if err := <-schedulerErrors; err != nil { log.Errorf("Scheduler error: %v", err) diff --git a/pkg/tls/certs.go b/pkg/tls/certs.go new file mode 100644 index 00000000..c8f69799 --- /dev/null +++ b/pkg/tls/certs.go @@ -0,0 +1,349 @@ +// Package tls provides TLS certificate generation and management utilities +// for the Model Runner API server. +package tls + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "os" + "path/filepath" + "time" +) + +const ( + // DefaultCertsDir is the default directory for storing certificates. + DefaultCertsDir = ".docker/model-runner/certs" + + // CACertFile is the filename for the CA certificate. + CACertFile = "ca.crt" + // CAKeyFile is the filename for the CA private key. + CAKeyFile = "ca.key" + // ServerCertFile is the filename for the server certificate. + ServerCertFile = "server.crt" + // ServerKeyFile is the filename for the server private key. + ServerKeyFile = "server.key" + + // DefaultCertValidityDays is the default validity period for certificates. + DefaultCertValidityDays = 365 + // DefaultCAValidityDays is the default validity period for CA certificates. + DefaultCAValidityDays = 3650 // 10 years +) + +// CertPaths holds the paths to certificate and key files. +type CertPaths struct { + CACert string + CAKey string + ServerCert string + ServerKey string +} + +// DefaultCertPaths returns the default certificate paths in the user's home directory. +func DefaultCertPaths() (*CertPaths, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get user home directory: %w", err) + } + + certsDir := filepath.Join(homeDir, DefaultCertsDir) + return &CertPaths{ + CACert: filepath.Join(certsDir, CACertFile), + CAKey: filepath.Join(certsDir, CAKeyFile), + ServerCert: filepath.Join(certsDir, ServerCertFile), + ServerKey: filepath.Join(certsDir, ServerKeyFile), + }, nil +} + +// EnsureCertificates checks for existing certificates or generates new ones. +// If certPath and keyPath are provided, they are used directly. +// Otherwise, auto-generated certificates are checked/created in the default location. +// Returns the paths to the certificate and key files. +func EnsureCertificates(certPath, keyPath string) (cert, key string, err error) { + // If custom paths are provided, use them directly + if certPath != "" && keyPath != "" { + if _, err := os.Stat(certPath); err != nil { + return "", "", fmt.Errorf("certificate file not found: %s", certPath) + } + if _, err := os.Stat(keyPath); err != nil { + return "", "", fmt.Errorf("key file not found: %s", keyPath) + } + return certPath, keyPath, nil + } + + // Use default paths for auto-generated certificates + paths, err := DefaultCertPaths() + if err != nil { + return "", "", err + } + + // Check if certificates already exist and are valid + if certsExistAndValid(paths) { + return paths.ServerCert, paths.ServerKey, nil + } + + // Generate new certificates + if err := GenerateCertificates(paths); err != nil { + return "", "", fmt.Errorf("failed to generate certificates: %w", err) + } + + return paths.ServerCert, paths.ServerKey, nil +} + +// certsExistAndValid checks if certificate files exist and are not expired. +func certsExistAndValid(paths *CertPaths) bool { + // Check if files exist + for _, path := range []string{paths.CACert, paths.CAKey, paths.ServerCert, paths.ServerKey} { + if _, err := os.Stat(path); err != nil { + return false + } + } + + // Check if server certificate is still valid + certPEM, err := os.ReadFile(paths.ServerCert) + if err != nil { + return false + } + + block, _ := pem.Decode(certPEM) + if block == nil { + return false + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return false + } + + // Check if certificate expires within 30 days + if time.Until(cert.NotAfter) < 30*24*time.Hour { + return false + } + + return true +} + +// GenerateCertificates generates a CA certificate and a server certificate signed by the CA. +func GenerateCertificates(paths *CertPaths) error { + // Ensure the certificates directory exists + certsDir := filepath.Dir(paths.CACert) + if err := os.MkdirAll(certsDir, 0700); err != nil { + return fmt.Errorf("failed to create certificates directory: %w", err) + } + + // Generate CA certificate + caKey, caCert, err := GenerateSelfSignedCA() + if err != nil { + return fmt.Errorf("failed to generate CA certificate: %w", err) + } + + // Save CA certificate and key + if err := saveCertAndKey(paths.CACert, paths.CAKey, caCert, caKey); err != nil { + return fmt.Errorf("failed to save CA certificate: %w", err) + } + + // Generate server certificate signed by CA + serverKey, serverCert, err := GenerateServerCert(caKey, caCert) + if err != nil { + return fmt.Errorf("failed to generate server certificate: %w", err) + } + + // Save server certificate and key + if err := saveCertAndKey(paths.ServerCert, paths.ServerKey, serverCert, serverKey); err != nil { + return fmt.Errorf("failed to save server certificate: %w", err) + } + + return nil +} + +// GenerateSelfSignedCA creates a self-signed CA certificate. +func GenerateSelfSignedCA() (*ecdsa.PrivateKey, *x509.Certificate, error) { + // Generate private key + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate private key: %w", err) + } + + // Generate serial number + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate serial number: %w", err) + } + + // Create certificate template + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Docker Model Runner"}, + OrganizationalUnit: []string{"Self-Signed CA"}, + CommonName: "Docker Model Runner CA", + }, + NotBefore: time.Now().Add(-1 * time.Hour), // Allow for clock skew + NotAfter: time.Now().AddDate(0, 0, DefaultCAValidityDays), + IsCA: true, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + MaxPathLen: 1, + } + + // Create the certificate + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey) + if err != nil { + return nil, nil, fmt.Errorf("failed to create certificate: %w", err) + } + + // Parse the certificate back + cert, err := x509.ParseCertificate(certDER) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + return privateKey, cert, nil +} + +// GenerateServerCert creates a server certificate signed by the given CA. +func GenerateServerCert(caKey *ecdsa.PrivateKey, caCert *x509.Certificate) (*ecdsa.PrivateKey, *x509.Certificate, error) { + // Generate private key for server + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate private key: %w", err) + } + + // Generate serial number + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate serial number: %w", err) + } + + // Create certificate template + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Docker Model Runner"}, + OrganizationalUnit: []string{"Server"}, + CommonName: "localhost", + }, + NotBefore: time.Now().Add(-1 * time.Hour), // Allow for clock skew + NotAfter: time.Now().AddDate(0, 0, DefaultCertValidityDays), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + DNSNames: []string{"localhost", "docker-model-runner"}, + IPAddresses: []net.IP{ + net.ParseIP("127.0.0.1"), + net.ParseIP("::1"), + net.IPv4zero, // 0.0.0.0 + net.IPv6zero, // :: + }, + } + + // Create the certificate signed by CA + certDER, err := x509.CreateCertificate(rand.Reader, template, caCert, &privateKey.PublicKey, caKey) + if err != nil { + return nil, nil, fmt.Errorf("failed to create certificate: %w", err) + } + + // Parse the certificate back + cert, err := x509.ParseCertificate(certDER) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + return privateKey, cert, nil +} + +// saveCertAndKey saves a certificate and private key to PEM files. +func saveCertAndKey(certPath, keyPath string, cert *x509.Certificate, key *ecdsa.PrivateKey) error { + // Encode and save the certificate + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + }) + if err := os.WriteFile(certPath, certPEM, 0644); err != nil { + return fmt.Errorf("failed to write certificate file: %w", err) + } + + // Encode and save the private key + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + return fmt.Errorf("failed to marshal private key: %w", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: keyDER, + }) + if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil { + return fmt.Errorf("failed to write key file: %w", err) + } + + return nil +} + +// LoadTLSConfig loads certificates and returns a TLS configuration for the server. +func LoadTLSConfig(certPath, keyPath string) (*tls.Config, error) { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, fmt.Errorf("failed to load certificate and key: %w", err) + } + + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + }, nil +} + +// LoadClientTLSConfig loads CA certificates and returns a TLS configuration for clients. +// If caCertPath is empty, it uses the default CA certificate location. +// If skipVerify is true, certificate verification is skipped (for development only). +func LoadClientTLSConfig(caCertPath string, skipVerify bool) (*tls.Config, error) { + if skipVerify { + return &tls.Config{ + InsecureSkipVerify: true, //nolint:gosec // Intentional for development/testing + MinVersion: tls.VersionTLS12, + }, nil + } + + // Determine CA certificate path + if caCertPath == "" { + paths, err := DefaultCertPaths() + if err != nil { + return nil, err + } + caCertPath = paths.CACert + } + + // Load CA certificate + caCertPEM, err := os.ReadFile(caCertPath) + if err != nil { + return nil, fmt.Errorf("failed to read CA certificate: %w", err) + } + + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(caCertPEM) { + return nil, fmt.Errorf("failed to parse CA certificate") + } + + return &tls.Config{ + RootCAs: certPool, + MinVersion: tls.VersionTLS12, + }, nil +} + +// GetCACertPath returns the path to the CA certificate file. +// Returns the custom path if provided, otherwise returns the default path. +func GetCACertPath(customPath string) (string, error) { + if customPath != "" { + return customPath, nil + } + + paths, err := DefaultCertPaths() + if err != nil { + return "", err + } + return paths.CACert, nil +} diff --git a/pkg/tls/certs_test.go b/pkg/tls/certs_test.go new file mode 100644 index 00000000..f09b3163 --- /dev/null +++ b/pkg/tls/certs_test.go @@ -0,0 +1,537 @@ +package tls + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "math/big" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" +) + +func TestGenerateSelfSignedCA(t *testing.T) { + key, cert, err := GenerateSelfSignedCA() + if err != nil { + t.Fatalf("GenerateSelfSignedCA() error = %v", err) + } + + if key == nil { + t.Error("GenerateSelfSignedCA() returned nil key") + } + + if cert == nil { + t.Error("GenerateSelfSignedCA() returned nil cert") + return // Return early to avoid nil dereference + } + + // Verify the certificate is a CA + if !cert.IsCA { + t.Error("Generated certificate is not a CA") + } + + // Verify the certificate has proper key usage + if cert.KeyUsage&x509.KeyUsageCertSign == 0 { + t.Error("CA certificate missing KeyUsageCertSign") + } + + // Verify validity period + if cert.NotAfter.Before(time.Now().AddDate(0, 0, DefaultCAValidityDays-1)) { + t.Error("CA certificate validity period is too short") + } +} + +func TestGenerateServerCert(t *testing.T) { + // First generate a CA + caKey, caCert, err := GenerateSelfSignedCA() + if err != nil { + t.Fatalf("GenerateSelfSignedCA() error = %v", err) + } + + // Generate server certificate + serverKey, serverCert, err := GenerateServerCert(caKey, caCert) + if err != nil { + t.Fatalf("GenerateServerCert() error = %v", err) + } + + if serverKey == nil { + t.Error("GenerateServerCert() returned nil key") + } + + if serverCert == nil { + t.Error("GenerateServerCert() returned nil cert") + return // Return early to avoid nil dereference + } + + // Verify the certificate is NOT a CA + if serverCert.IsCA { + t.Error("Server certificate should not be a CA") + } + + // Verify server authentication extended key usage + hasServerAuth := false + for _, usage := range serverCert.ExtKeyUsage { + if usage == x509.ExtKeyUsageServerAuth { + hasServerAuth = true + break + } + } + if !hasServerAuth { + t.Error("Server certificate missing ExtKeyUsageServerAuth") + } + + // Verify DNS names + expectedDNS := []string{"localhost", "docker-model-runner"} + for _, expected := range expectedDNS { + found := false + for _, dns := range serverCert.DNSNames { + if dns == expected { + found = true + break + } + } + if !found { + t.Errorf("Server certificate missing DNS name: %s", expected) + } + } + + // Verify the certificate was signed by the CA + err = serverCert.CheckSignatureFrom(caCert) + if err != nil { + t.Errorf("Server certificate not signed by CA: %v", err) + } +} + +func TestGenerateCertificates(t *testing.T) { + // Create a temporary directory + tmpDir, err := os.MkdirTemp("", "certs-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + paths := &CertPaths{ + CACert: filepath.Join(tmpDir, CACertFile), + CAKey: filepath.Join(tmpDir, CAKeyFile), + ServerCert: filepath.Join(tmpDir, ServerCertFile), + ServerKey: filepath.Join(tmpDir, ServerKeyFile), + } + + err = GenerateCertificates(paths) + if err != nil { + t.Fatalf("GenerateCertificates() error = %v", err) + } + + // Verify all files exist + for _, path := range []string{paths.CACert, paths.CAKey, paths.ServerCert, paths.ServerKey} { + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Errorf("Expected file not created: %s", path) + } + } + + // Verify certificate permissions + if info, err := os.Stat(paths.CAKey); err == nil { + if info.Mode().Perm() != 0600 { + t.Errorf("CA key has wrong permissions: %v, expected 0600", info.Mode().Perm()) + } + } + + if info, err := os.Stat(paths.ServerKey); err == nil { + if info.Mode().Perm() != 0600 { + t.Errorf("Server key has wrong permissions: %v, expected 0600", info.Mode().Perm()) + } + } +} + +func TestLoadTLSConfig(t *testing.T) { + // Create a temporary directory + tmpDir, err := os.MkdirTemp("", "certs-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + paths := &CertPaths{ + CACert: filepath.Join(tmpDir, CACertFile), + CAKey: filepath.Join(tmpDir, CAKeyFile), + ServerCert: filepath.Join(tmpDir, ServerCertFile), + ServerKey: filepath.Join(tmpDir, ServerKeyFile), + } + + err = GenerateCertificates(paths) + if err != nil { + t.Fatalf("GenerateCertificates() error = %v", err) + } + + // Load TLS config + tlsConfig, err := LoadTLSConfig(paths.ServerCert, paths.ServerKey) + if err != nil { + t.Fatalf("LoadTLSConfig() error = %v", err) + } + + if tlsConfig == nil { + t.Error("LoadTLSConfig() returned nil config") + return // Return early to avoid nil dereference + } + + if len(tlsConfig.Certificates) != 1 { + t.Errorf("Expected 1 certificate, got %d", len(tlsConfig.Certificates)) + } + + if tlsConfig.MinVersion != tls.VersionTLS12 { + t.Error("TLS minimum version should be TLS 1.2") + } +} + +func TestLoadClientTLSConfig(t *testing.T) { + // Create a temporary directory + tmpDir, err := os.MkdirTemp("", "certs-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + paths := &CertPaths{ + CACert: filepath.Join(tmpDir, CACertFile), + CAKey: filepath.Join(tmpDir, CAKeyFile), + ServerCert: filepath.Join(tmpDir, ServerCertFile), + ServerKey: filepath.Join(tmpDir, ServerKeyFile), + } + + err = GenerateCertificates(paths) + if err != nil { + t.Fatalf("GenerateCertificates() error = %v", err) + } + + // Test loading with CA cert + tlsConfig, err := LoadClientTLSConfig(paths.CACert, false) + if err != nil { + t.Fatalf("LoadClientTLSConfig() error = %v", err) + } + + if tlsConfig == nil { + t.Error("LoadClientTLSConfig() returned nil config") + return // Return early to avoid nil dereference + } + + if tlsConfig.RootCAs == nil { + t.Error("LoadClientTLSConfig() should have RootCAs set") + } + + if tlsConfig.InsecureSkipVerify { + t.Error("InsecureSkipVerify should be false when skipVerify=false") + } + + // Test with skipVerify=true + tlsConfig, err = LoadClientTLSConfig("", true) + if err != nil { + t.Fatalf("LoadClientTLSConfig() with skipVerify error = %v", err) + } + + if !tlsConfig.InsecureSkipVerify { + t.Error("InsecureSkipVerify should be true when skipVerify=true") + } +} + +func TestLoadClientTLSConfig_CAErrors(t *testing.T) { + tmpDir := t.TempDir() + + t.Run("non-existent CA path", func(t *testing.T) { + paths := &CertPaths{ + CACert: filepath.Join(tmpDir, "does-not-exist-ca.pem"), + } + + tlsConfig, err := LoadClientTLSConfig(paths.CACert, false) + if err == nil { + t.Fatalf("expected error when loading client TLS config with non-existent CA path, got nil") + } + if tlsConfig != nil { + t.Fatalf("expected nil TLS config on CA load error, got non-nil") + } + }) + + t.Run("invalid CA PEM", func(t *testing.T) { + caPath := filepath.Join(tmpDir, "invalid-ca.pem") + if err := os.WriteFile(caPath, []byte("this is not a valid PEM"), 0o600); err != nil { + t.Fatalf("failed to write invalid CA file: %v", err) + } + + tlsConfig, err := LoadClientTLSConfig(caPath, false) + if err == nil { + t.Fatalf("expected error when loading client TLS config with invalid CA PEM, got nil") + } + if tlsConfig != nil { + t.Fatalf("expected nil TLS config on CA load error, got non-nil") + } + }) +} + +func TestEnsureCertificates_CustomPaths(t *testing.T) { + // Create a temporary directory with existing certs + tmpDir, err := os.MkdirTemp("", "certs-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Generate certs first + paths := &CertPaths{ + CACert: filepath.Join(tmpDir, CACertFile), + CAKey: filepath.Join(tmpDir, CAKeyFile), + ServerCert: filepath.Join(tmpDir, ServerCertFile), + ServerKey: filepath.Join(tmpDir, ServerKeyFile), + } + err = GenerateCertificates(paths) + if err != nil { + t.Fatalf("GenerateCertificates() error = %v", err) + } + + // Test with custom paths + cert, key, err := EnsureCertificates(paths.ServerCert, paths.ServerKey) + if err != nil { + t.Fatalf("EnsureCertificates() error = %v", err) + } + + if cert != paths.ServerCert { + t.Errorf("Expected cert path %s, got %s", paths.ServerCert, cert) + } + + if key != paths.ServerKey { + t.Errorf("Expected key path %s, got %s", paths.ServerKey, key) + } +} + +func TestEnsureCertificates_MissingFile(t *testing.T) { + _, _, err := EnsureCertificates("/nonexistent/cert.pem", "/nonexistent/key.pem") + if err == nil { + t.Error("EnsureCertificates() should return error for missing files") + } +} + +func TestCertsExistAndValid(t *testing.T) { + // Create a temporary directory + tmpDir, err := os.MkdirTemp("", "certs-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + paths := &CertPaths{ + CACert: filepath.Join(tmpDir, CACertFile), + CAKey: filepath.Join(tmpDir, CAKeyFile), + ServerCert: filepath.Join(tmpDir, ServerCertFile), + ServerKey: filepath.Join(tmpDir, ServerKeyFile), + } + + // Should return false when files don't exist + if certsExistAndValid(paths) { + t.Error("certsExistAndValid() should return false for non-existent files") + } + + // Generate certs + err = GenerateCertificates(paths) + if err != nil { + t.Fatalf("GenerateCertificates() error = %v", err) + } + + // Should return true for valid certs + if !certsExistAndValid(paths) { + t.Error("certsExistAndValid() should return true for valid certs") + } + + t.Run("near-expiry certs treated as invalid", func(t *testing.T) { + // Load CA certificate + caCertPEM, err := os.ReadFile(paths.CACert) + if err != nil { + t.Fatalf("failed to read CA cert: %v", err) + } + caCertBlock, _ := pem.Decode(caCertPEM) + if caCertBlock == nil { + t.Fatal("failed to decode CA cert PEM") + } + caCert, err := x509.ParseCertificate(caCertBlock.Bytes) + if err != nil { + t.Fatalf("failed to parse CA cert: %v", err) + } + + // Load CA private key + caKeyPEM, err := os.ReadFile(paths.CAKey) + if err != nil { + t.Fatalf("failed to read CA key: %v", err) + } + caKeyBlock, _ := pem.Decode(caKeyPEM) + if caKeyBlock == nil { + t.Fatal("failed to decode CA key PEM") + } + caKey, err := x509.ParseECPrivateKey(caKeyBlock.Bytes) + if err != nil { + // Try parsing as PKCS#8 + caKeyI, err := x509.ParsePKCS8PrivateKey(caKeyBlock.Bytes) + if err != nil { + t.Fatalf("failed to parse CA key: %v", err) + } + var ok bool + caKey, ok = caKeyI.(*ecdsa.PrivateKey) + if !ok { + t.Fatal("not an ECDSA private key") + } + } + + // Create a server certificate that expires in 15 days + serverKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("failed to generate server key: %v", err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(2), + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(15 * 24 * time.Hour), // within 30 days, should be treated as invalid + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + DNSNames: []string{"localhost", "docker-model-runner"}, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, template, caCert, &serverKey.PublicKey, caKey) + if err != nil { + t.Fatalf("failed to create near-expiry server cert: %v", err) + } + + // Overwrite server cert and key on disk with the near-expiry ones + serverCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + if err := os.WriteFile(paths.ServerCert, serverCertPEM, 0o600); err != nil { + t.Fatalf("failed to write near-expiry server cert: %v", err) + } + + serverKeyBytes, err := x509.MarshalECPrivateKey(serverKey) + if err != nil { + t.Fatalf("failed to marshal server key: %v", err) + } + serverKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: serverKeyBytes}) + if err := os.WriteFile(paths.ServerKey, serverKeyPEM, 0o600); err != nil { + t.Fatalf("failed to write near-expiry server key: %v", err) + } + + // Now certsExistAndValid should treat these near-expiry certs as invalid + if certsExistAndValid(paths) { + t.Error("certsExistAndValid() should return false for near-expiry certificates") + } + }) +} + +func TestTLSServerIntegration(t *testing.T) { + // Create a temporary directory + tmpDir, err := os.MkdirTemp("", "certs-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + paths := &CertPaths{ + CACert: filepath.Join(tmpDir, CACertFile), + CAKey: filepath.Join(tmpDir, CAKeyFile), + ServerCert: filepath.Join(tmpDir, ServerCertFile), + ServerKey: filepath.Join(tmpDir, ServerKeyFile), + } + + err = GenerateCertificates(paths) + if err != nil { + t.Fatalf("GenerateCertificates() error = %v", err) + } + + // Load server TLS config + serverTLSConfig, err := LoadTLSConfig(paths.ServerCert, paths.ServerKey) + if err != nil { + t.Fatalf("LoadTLSConfig() error = %v", err) + } + + // Create a test server + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + })) + server.TLS = serverTLSConfig + server.StartTLS() + defer server.Close() + + // Load client TLS config + clientTLSConfig, err := LoadClientTLSConfig(paths.CACert, false) + if err != nil { + t.Fatalf("LoadClientTLSConfig() error = %v", err) + } + + // Create HTTP client with TLS config + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: clientTLSConfig, + }, + } + + // Make a request + resp, err := client.Get(server.URL) + if err != nil { + t.Fatalf("HTTP request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } +} + +func TestCertificatePEMEncoding(t *testing.T) { + // Create a temporary directory + tmpDir, err := os.MkdirTemp("", "certs-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + paths := &CertPaths{ + CACert: filepath.Join(tmpDir, CACertFile), + CAKey: filepath.Join(tmpDir, CAKeyFile), + ServerCert: filepath.Join(tmpDir, ServerCertFile), + ServerKey: filepath.Join(tmpDir, ServerKeyFile), + } + + err = GenerateCertificates(paths) + if err != nil { + t.Fatalf("GenerateCertificates() error = %v", err) + } + + // Read and verify PEM encoding of certificate + certPEM, err := os.ReadFile(paths.ServerCert) + if err != nil { + t.Fatalf("Failed to read cert file: %v", err) + } + + block, _ := pem.Decode(certPEM) + if block == nil { + t.Error("Failed to decode certificate PEM") + return // Return early to avoid nil dereference + } + if block.Type != "CERTIFICATE" { + t.Errorf("Expected PEM type CERTIFICATE, got %s", block.Type) + } + + // Read and verify PEM encoding of key + keyPEM, err := os.ReadFile(paths.ServerKey) + if err != nil { + t.Fatalf("Failed to read key file: %v", err) + } + + block, _ = pem.Decode(keyPEM) + if block == nil { + t.Error("Failed to decode key PEM") + return // Return early to avoid nil dereference + } + if block.Type != "EC PRIVATE KEY" { + t.Errorf("Expected PEM type EC PRIVATE KEY, got %s", block.Type) + } +}