Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions pkg/inference/backends/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ type RunnerConfig struct {
// ErrorTransformer is an optional function to transform error output
// into a more user-friendly message. If nil, the raw output is used.
ErrorTransformer ErrorTransformer
// Env is an optional list of extra environment variables for the backend
// process, each in "KEY=VALUE" form. These are appended to the current
// process environment. If nil or empty, the backend inherits the parent
// env as-is (an empty non-nil slice is treated the same as nil).
Env []string
}

// Logger interface for backend logging
Expand All @@ -50,6 +55,18 @@ type Logger interface {
Warn(msg string, args ...any)
}

// ValidateEnv checks that every entry in env is in "KEY=VALUE" form with a
// non-empty key. It returns an error describing the first malformed entry.
func ValidateEnv(env []string) error {
for _, e := range env {
k, _, ok := strings.Cut(e, "=")
if !ok || k == "" {
return fmt.Errorf("invalid env var format (expected KEY=VALUE): %q", e)
}
}
return nil
}

// RunBackend runs a backend process with common error handling and logging.
// It handles:
// - Socket cleanup
Expand All @@ -58,6 +75,11 @@ type Logger interface {
// - Error channel handling
// - Context cancellation
func RunBackend(ctx context.Context, config RunnerConfig) error {
// Validate env var format early to catch misconfiguration before spawning.
if err := ValidateEnv(config.Env); err != nil {
return fmt.Errorf("invalid %s configuration: %w", config.BackendName, err)
}

// Remove old socket file
if err := os.RemoveAll(config.Socket); err != nil && !errors.Is(err, fs.ErrNotExist) {
config.Logger.Warn("failed to remove socket file", "socket", config.Socket, "error", err)
Expand Down Expand Up @@ -88,6 +110,9 @@ func RunBackend(ctx context.Context, config RunnerConfig) error {
}
command.Stdout = config.ServerLogWriter
command.Stderr = out
if len(config.Env) > 0 {
command.Env = append(os.Environ(), config.Env...)
}
},
config.SandboxPath,
config.BinaryPath,
Expand Down
75 changes: 75 additions & 0 deletions pkg/inference/backends/runner_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package backends_test

import (
"testing"

"github.com/docker/model-runner/pkg/inference/backends"
)

func TestValidateEnv(t *testing.T) {
tests := []struct {
name string
env []string
wantErr bool
}{
{
name: "nil slice is valid",
env: nil,
wantErr: false,
},
{
name: "empty slice is valid",
env: []string{},
wantErr: false,
},
{
name: "single valid entry",
env: []string{"KEY=value"},
wantErr: false,
},
{
name: "multiple valid entries",
env: []string{"A=1", "B=2", "FOO=bar"},
wantErr: false,
},
{
name: "value with equals sign",
env: []string{"KEY=val=ue"},
wantErr: false,
},
{
name: "empty value is valid",
env: []string{"KEY="},
wantErr: false,
},
{
name: "missing equals sign",
env: []string{"NOEQUALS"},
wantErr: true,
},
{
name: "empty key",
env: []string{"=value"},
wantErr: true,
},
{
name: "empty string",
env: []string{""},
wantErr: true,
},
{
name: "valid then invalid",
env: []string{"GOOD=ok", "BAD"},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := backends.ValidateEnv(tt.env)
if (err != nil) != tt.wantErr {
t.Errorf("ValidateEnv(%v) error = %v, wantErr %v", tt.env, err, tt.wantErr)
}
})
}
}
12 changes: 12 additions & 0 deletions pkg/inference/backends/testmain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package backends_test

import (
"testing"

"go.uber.org/goleak"
)

// TestMain runs goleak after the test suite to detect goroutine leaks.
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
1 change: 1 addition & 0 deletions pkg/inference/backends/vllm/vllm_metal.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ func (v *vllmMetal) Run(ctx context.Context, socket, model string, modelRef stri
Args: args,
Logger: v.log,
ServerLogWriter: logging.NewWriter(v.serverLog),
Env: []string{"VLLM_HOST_IP=127.0.0.1"},
})
}

Expand Down
12 changes: 8 additions & 4 deletions pkg/sandbox/sandbox_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ const ConfigurationPython = `(version 1)
;;; Also allow Unix domain socket binding in the system temp directory
;;; (/private/var/folders) which vllm-metal uses for internal ZMQ IPC sockets.
(deny network*)
(allow network-bind network-inbound
(allow network-bind network-inbound network-outbound
(regex #"inference.*-[0-9]+\.sock$")
(local tcp "localhost:*"))
(allow network-bind
(allow network-bind network-inbound network-outbound
(regex #"^/private/var/folders/"))

;;; Deny access to the camera and microphone.
Expand Down Expand Up @@ -76,12 +76,16 @@ const ConfigurationPython = `(version 1)
(allow file-write*
(literal "/dev/null")
(subpath "/private/var")
(subpath "/private/tmp")
(subpath "[HOMEDIR]/Library/Containers/com.docker.docker/Data")
(subpath "[WORKDIR]"))
(subpath "[WORKDIR]")
(subpath "[HOMEDIR]/.cache/vllm"))
(allow file-read*
(subpath "[HOMEDIR]/.docker/models")
(subpath "[HOMEDIR]/Library/Containers/com.docker.docker/Data")
(subpath "[WORKDIR]"))
(subpath "[WORKDIR]")
(subpath "[HOMEDIR]/.cache/vllm")
(subpath "/private/tmp"))
`

// ConfigurationLlamaCpp is the sandbox configuration for llama.cpp processes.
Expand Down
Loading