Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 24 additions & 47 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"syscall"
"time"

"github.com/docker/model-runner/pkg/envconfig"
"github.com/docker/model-runner/pkg/inference"
"github.com/docker/model-runner/pkg/inference/backends/diffusers"
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
Expand All @@ -27,15 +28,9 @@ import (
modeltls "github.com/docker/model-runner/pkg/tls"
)

const (
// DefaultTLSPort is the default TLS port for Moby
DefaultTLSPort = "12444"
)

// initLogger creates the application logger based on LOG_LEVEL env var.
func initLogger() *slog.Logger {
level := logging.ParseLevel(os.Getenv("LOG_LEVEL"))
return logging.NewLogger(level)
return logging.NewLogger(envconfig.LogLevel())
}

var log = initLogger()
Expand All @@ -47,45 +42,29 @@ func main() {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()

sockName := os.Getenv("MODEL_RUNNER_SOCK")
if sockName == "" {
sockName = "model-runner.sock"
}

userHomeDir, err := os.UserHomeDir()
sockName := envconfig.SocketPath()
modelPath, err := envconfig.ModelsPath()
if err != nil {
log.Error("Failed to get user home directory", "error", err)
log.Error("Failed to get models path", "error", err)
exitFunc(1)
}

modelPath := os.Getenv("MODELS_PATH")
if modelPath == "" {
modelPath = filepath.Join(userHomeDir, ".docker", "models")
}

_, disableServerUpdate := os.LookupEnv("DISABLE_SERVER_UPDATE")
if disableServerUpdate {
if envconfig.DisableServerUpdate() {
llamacpp.ShouldUpdateServerLock.Lock()
llamacpp.ShouldUpdateServer = false
llamacpp.ShouldUpdateServerLock.Unlock()
}

desiredServerVersion, ok := os.LookupEnv("LLAMA_SERVER_VERSION")
if ok {
llamacpp.SetDesiredServerVersion(desiredServerVersion)
}

llamaServerPath := os.Getenv("LLAMA_SERVER_PATH")
if llamaServerPath == "" {
llamaServerPath = "/Applications/Docker.app/Contents/Resources/model-runner/bin"
if v := envconfig.LlamaServerVersion(); v != "" {
llamacpp.SetDesiredServerVersion(v)
}

// Get optional custom paths for other backends
vllmServerPath := os.Getenv("VLLM_SERVER_PATH")
sglangServerPath := os.Getenv("SGLANG_SERVER_PATH")
mlxServerPath := os.Getenv("MLX_SERVER_PATH")
diffusersServerPath := os.Getenv("DIFFUSERS_SERVER_PATH")
vllmMetalServerPath := os.Getenv("VLLM_METAL_SERVER_PATH")
llamaServerPath := envconfig.LlamaServerPath()
vllmServerPath := envconfig.VLLMServerPath()
sglangServerPath := envconfig.SGLangServerPath()
mlxServerPath := envconfig.MLXServerPath()
diffusersServerPath := envconfig.DiffusersServerPath()
vllmMetalServerPath := envconfig.VLLMMetalServerPath()

// Create a proxy-aware HTTP transport
// Use a safe type assertion with fallback, and explicitly set Proxy to http.ProxyFromEnvironment
Expand Down Expand Up @@ -169,6 +148,7 @@ func main() {
"",
false,
),
AllowedOrigins: envconfig.AllowedOrigins(),
IncludeResponsesAPI: true,
ExtraRoutes: func(r *routing.NormalizedServeMux, s *routing.Service) {
// Root handler – only catches exact "/" requests
Expand All @@ -190,7 +170,7 @@ func main() {
})

// Metrics endpoint
if os.Getenv("DISABLE_METRICS") != "1" {
if !envconfig.DisableMetrics() {
metricsHandler := metrics.NewAggregatedMetricsHandler(
log.With("component", "metrics"),
s.SchedulerHTTP,
Expand Down Expand Up @@ -218,7 +198,7 @@ func main() {
tlsServerErrors := make(chan error, 1)

// Check if we should use TCP port instead of Unix socket
tcpPort := os.Getenv("MODEL_RUNNER_PORT")
tcpPort := envconfig.TCPPort()
if tcpPort != "" {
// Use TCP port
addr := ":" + tcpPort
Expand Down Expand Up @@ -246,19 +226,16 @@ func main() {
}

// Start TLS server if enabled
if os.Getenv("MODEL_RUNNER_TLS_ENABLED") == "true" {
tlsPort := os.Getenv("MODEL_RUNNER_TLS_PORT")
if tlsPort == "" {
tlsPort = DefaultTLSPort // Default TLS port for Moby
}
if envconfig.TLSEnabled() {
tlsPort := envconfig.TLSPort()

// Get certificate paths
certPath := os.Getenv("MODEL_RUNNER_TLS_CERT")
keyPath := os.Getenv("MODEL_RUNNER_TLS_KEY")
certPath := envconfig.TLSCert()
keyPath := envconfig.TLSKey()

// Auto-generate certificates if not provided and auto-cert is not disabled
if certPath == "" || keyPath == "" {
if os.Getenv("MODEL_RUNNER_TLS_AUTO_CERT") != "false" {
if envconfig.TLSAutoCert(true) {
log.Info("Auto-generating TLS certificates...")
var err error
certPath, keyPath, err = modeltls.EnsureCertificates("", "")
Expand Down Expand Up @@ -306,7 +283,7 @@ func main() {
}()

var tlsServerErrorsChan <-chan error
if os.Getenv("MODEL_RUNNER_TLS_ENABLED") == "true" {
if envconfig.TLSEnabled() {
tlsServerErrorsChan = tlsServerErrors
} else {
// Use a nil channel which will block forever when TLS is disabled
Expand Down Expand Up @@ -346,7 +323,7 @@ func main() {
// Returns nil config (use defaults) when LLAMA_ARGS is unset, or an error if
// the args contain disallowed flags.
func createLlamaCppConfigFromEnv() (config.BackendConfig, error) {
argsStr := os.Getenv("LLAMA_ARGS")
argsStr := envconfig.LlamaArgs()
if argsStr == "" {
return nil, nil
}
Expand Down
227 changes: 227 additions & 0 deletions pkg/envconfig/envconfig.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
package envconfig

import (
"fmt"
"log/slog"
"net"
"os"
"path/filepath"
"strconv"
"strings"

"github.com/docker/model-runner/pkg/logging"
)

// Var returns an environment variable stripped of leading/trailing quotes and spaces.
func Var(key string) string {
return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
}

// String returns a lazy string accessor for the given environment variable.
func String(key string) func() string {
return func() string {
return Var(key)
}
}

// BoolWithDefault returns a lazy bool accessor for the given environment variable,
// allowing a caller-specified default. If the variable is set but cannot be parsed
// as a bool, the defaultValue is returned.
func BoolWithDefault(key string) func(defaultValue bool) bool {
return func(defaultValue bool) bool {
if s := Var(key); s != "" {
b, err := strconv.ParseBool(s)
if err != nil {
return defaultValue
}
return b
}
return defaultValue
}
}

// Bool returns a lazy bool accessor that defaults to false when the variable is unset.
func Bool(key string) func() bool {
withDefault := BoolWithDefault(key)
return func() bool {
return withDefault(false)
}
}

// LogLevel reads LOG_LEVEL and returns the corresponding slog.Level.
func LogLevel() slog.Level {
return logging.ParseLevel(Var("LOG_LEVEL"))
}

// AllowedOrigins returns a list of CORS-allowed origins. It reads DMR_ORIGINS
// and always appends default localhost/127.0.0.1/0.0.0.0 entries on http and
// https with wildcard ports.
func AllowedOrigins() (origins []string) {
if s := Var("DMR_ORIGINS"); s != "" {
for _, o := range strings.Split(s, ",") {
if trimmed := strings.TrimSpace(o); trimmed != "" {
origins = append(origins, trimmed)
}
}
}

for _, host := range []string{"localhost", "127.0.0.1", "0.0.0.0"} {
origins = append(origins,
fmt.Sprintf("http://%s", host),
fmt.Sprintf("https://%s", host),
fmt.Sprintf("http://%s", net.JoinHostPort(host, "*")),
fmt.Sprintf("https://%s", net.JoinHostPort(host, "*")),
)
}

return origins
}

// SocketPath returns the Unix socket path for the model runner.
// Configured via MODEL_RUNNER_SOCK; defaults to "model-runner.sock".
func SocketPath() string {
if s := Var("MODEL_RUNNER_SOCK"); s != "" {
return s
}
return "model-runner.sock"
}

// ModelsPath returns the directory where models are stored.
// Configured via MODELS_PATH; defaults to ~/.docker/models.
func ModelsPath() (string, error) {
if s := Var("MODELS_PATH"); s != "" {
return s, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".docker", "models"), nil
}

// TCPPort returns the optional TCP port for the model runner HTTP server.
// Configured via MODEL_RUNNER_PORT; empty string means use Unix socket.
func TCPPort() string {
return Var("MODEL_RUNNER_PORT")
}

// LlamaServerPath returns the path to the llama.cpp server binary.
// Configured via LLAMA_SERVER_PATH; defaults to the Docker Desktop bundle location.
func LlamaServerPath() string {
if s := Var("LLAMA_SERVER_PATH"); s != "" {
return s
}
return "/Applications/Docker.app/Contents/Resources/model-runner/bin"
}

// LlamaArgs returns custom arguments to pass to the llama.cpp server.
// Configured via LLAMA_ARGS.
func LlamaArgs() string {
return Var("LLAMA_ARGS")
}

// DisableServerUpdate is true when DISABLE_SERVER_UPDATE is set to a truthy value.
var DisableServerUpdate = Bool("DISABLE_SERVER_UPDATE")

// LlamaServerVersion returns a specific llama.cpp server version to pin.
// Configured via LLAMA_SERVER_VERSION; empty string means use the bundled version.
func LlamaServerVersion() string {
return Var("LLAMA_SERVER_VERSION")
}

// VLLMServerPath returns the optional path to the vLLM server binary.
// Configured via VLLM_SERVER_PATH.
func VLLMServerPath() string {
return Var("VLLM_SERVER_PATH")
}

// SGLangServerPath returns the optional path to the SGLang server binary.
// Configured via SGLANG_SERVER_PATH.
func SGLangServerPath() string {
return Var("SGLANG_SERVER_PATH")
}

// MLXServerPath returns the optional path to the MLX server binary.
// Configured via MLX_SERVER_PATH.
func MLXServerPath() string {
return Var("MLX_SERVER_PATH")
}

// DiffusersServerPath returns the optional path to the Diffusers server binary.
// Configured via DIFFUSERS_SERVER_PATH.
func DiffusersServerPath() string {
return Var("DIFFUSERS_SERVER_PATH")
}

// VLLMMetalServerPath returns the optional path to the vLLM Metal server binary.
// Configured via VLLM_METAL_SERVER_PATH.
func VLLMMetalServerPath() string {
return Var("VLLM_METAL_SERVER_PATH")
}

// DisableMetrics is true when DISABLE_METRICS is set to a truthy value (e.g. "1").
var DisableMetrics = Bool("DISABLE_METRICS")

// TLSEnabled is true when MODEL_RUNNER_TLS_ENABLED is set to a truthy value.
var TLSEnabled = Bool("MODEL_RUNNER_TLS_ENABLED")

// TLSPort returns the TLS listener port.
// Configured via MODEL_RUNNER_TLS_PORT; defaults to "12444".
func TLSPort() string {
if s := Var("MODEL_RUNNER_TLS_PORT"); s != "" {
return s
}
return "12444"
}

// TLSCert returns the path to the TLS certificate file.
// Configured via MODEL_RUNNER_TLS_CERT.
func TLSCert() string {
return Var("MODEL_RUNNER_TLS_CERT")
}

// TLSKey returns the path to the TLS private key file.
// Configured via MODEL_RUNNER_TLS_KEY.
func TLSKey() string {
return Var("MODEL_RUNNER_TLS_KEY")
}

// TLSAutoCert is true (default) unless MODEL_RUNNER_TLS_AUTO_CERT is set to a falsy value.
// Call as TLSAutoCert(true) to get the default-true behaviour.
var TLSAutoCert = BoolWithDefault("MODEL_RUNNER_TLS_AUTO_CERT")

// EnvVar describes a single environment variable with its current value
// and a human-readable description.
type EnvVar struct {
Name string
Value any
Description string
}

// AsMap returns a map of all model-runner environment variables with their
// current values and descriptions. Useful for introspection and documentation.
func AsMap() map[string]EnvVar {
modelsPath, _ := ModelsPath()
return map[string]EnvVar{
"MODEL_RUNNER_SOCK": {"MODEL_RUNNER_SOCK", SocketPath(), "Unix socket path (default: model-runner.sock)"},
"MODELS_PATH": {"MODELS_PATH", modelsPath, "Directory for model storage (default: ~/.docker/models)"},
"MODEL_RUNNER_PORT": {"MODEL_RUNNER_PORT", TCPPort(), "TCP port; overrides Unix socket when set"},
"LLAMA_SERVER_PATH": {"LLAMA_SERVER_PATH", LlamaServerPath(), "Path to llama.cpp server binary"},
"LLAMA_ARGS": {"LLAMA_ARGS", LlamaArgs(), "Extra arguments passed to the llama.cpp server"},
"DISABLE_SERVER_UPDATE": {"DISABLE_SERVER_UPDATE", DisableServerUpdate(), "Skip automatic llama.cpp server updates (any truthy value)"},
"LLAMA_SERVER_VERSION": {"LLAMA_SERVER_VERSION", LlamaServerVersion(), "Pin a specific llama.cpp server version"},
"VLLM_SERVER_PATH": {"VLLM_SERVER_PATH", VLLMServerPath(), "Path to vLLM server binary"},
"SGLANG_SERVER_PATH": {"SGLANG_SERVER_PATH", SGLangServerPath(), "Path to SGLang server binary"},
"MLX_SERVER_PATH": {"MLX_SERVER_PATH", MLXServerPath(), "Path to MLX server binary"},
"DIFFUSERS_SERVER_PATH": {"DIFFUSERS_SERVER_PATH", DiffusersServerPath(), "Path to Diffusers server binary"},
"VLLM_METAL_SERVER_PATH": {"VLLM_METAL_SERVER_PATH", VLLMMetalServerPath(), "Path to vLLM Metal server binary"},
"DISABLE_METRICS": {"DISABLE_METRICS", DisableMetrics(), "Disable Prometheus metrics endpoint (any truthy value, e.g. 1)"},
"LOG_LEVEL": {"LOG_LEVEL", LogLevel(), "Log verbosity: debug, info, warn, error (default: info)"},
"DMR_ORIGINS": {"DMR_ORIGINS", AllowedOrigins(), "Comma-separated CORS allowed origins (defaults plus any env-provided origins)"},
"MODEL_RUNNER_TLS_ENABLED": {"MODEL_RUNNER_TLS_ENABLED", TLSEnabled(), "Enable TLS listener"},
"MODEL_RUNNER_TLS_PORT": {"MODEL_RUNNER_TLS_PORT", TLSPort(), "TLS listener port (default: 12444)"},
"MODEL_RUNNER_TLS_CERT": {"MODEL_RUNNER_TLS_CERT", TLSCert(), "Path to TLS certificate file"},
"MODEL_RUNNER_TLS_KEY": {"MODEL_RUNNER_TLS_KEY", TLSKey(), "Path to TLS private key file"},
"MODEL_RUNNER_TLS_AUTO_CERT": {"MODEL_RUNNER_TLS_AUTO_CERT", TLSAutoCert(true), "Auto-generate TLS certificates (default: true)"},
}
}
Loading