Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support CDI devices in --device flag #4084

Merged
merged 2 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion cli/command/container/opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strings"
"time"

cdi "github.com/container-orchestrated-devices/container-device-interface/pkg/parser"
"github.com/docker/cli/cli/compose/loader"
"github.com/docker/cli/opts"
"github.com/docker/docker/api/types/container"
Expand Down Expand Up @@ -449,12 +450,17 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
// parsing flags, we haven't yet sent a _ping to the daemon to determine
// what operating system it is.
deviceMappings := []container.DeviceMapping{}
var cdiDeviceNames []string
for _, device := range copts.devices.GetAll() {
var (
validated string
deviceMapping container.DeviceMapping
err error
)
if cdi.IsQualifiedName(device) {
cdiDeviceNames = append(cdiDeviceNames, device)
continue
}
validated, err = validateDevice(device, serverOS)
if err != nil {
return nil, err
Expand Down Expand Up @@ -559,6 +565,15 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
}
}

deviceRequests := copts.gpus.Value()
if len(cdiDeviceNames) > 0 {
cdiDeviceRequest := container.DeviceRequest{
Driver: "cdi",
DeviceIDs: cdiDeviceNames,
}
deviceRequests = append(deviceRequests, cdiDeviceRequest)
}

resources := container.Resources{
CgroupParent: copts.cgroupParent,
Memory: copts.memory.Value(),
Expand Down Expand Up @@ -589,7 +604,7 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
Ulimits: copts.ulimits.GetList(),
DeviceCgroupRules: copts.deviceCgroupRules.GetAll(),
Devices: deviceMappings,
DeviceRequests: copts.gpus.Value(),
DeviceRequests: deviceRequests,
}

config := &container.Config{
Expand Down
118 changes: 85 additions & 33 deletions cli/command/container/opts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,39 +417,91 @@ func TestParseWithExpose(t *testing.T) {

func TestParseDevice(t *testing.T) {
skip.If(t, runtime.GOOS != "linux") // Windows and macOS validate server-side
valids := map[string]container.DeviceMapping{
"/dev/snd": {
PathOnHost: "/dev/snd",
PathInContainer: "/dev/snd",
CgroupPermissions: "rwm",
},
"/dev/snd:rw": {
PathOnHost: "/dev/snd",
PathInContainer: "/dev/snd",
CgroupPermissions: "rw",
},
"/dev/snd:/something": {
PathOnHost: "/dev/snd",
PathInContainer: "/something",
CgroupPermissions: "rwm",
},
"/dev/snd:/something:rw": {
PathOnHost: "/dev/snd",
PathInContainer: "/something",
CgroupPermissions: "rw",
},
}
for device, deviceMapping := range valids {
_, hostconfig, _, err := parseRun([]string{fmt.Sprintf("--device=%v", device), "img", "cmd"})
if err != nil {
t.Fatal(err)
}
if len(hostconfig.Devices) != 1 {
t.Fatalf("Expected 1 devices, got %v", hostconfig.Devices)
}
if hostconfig.Devices[0] != deviceMapping {
t.Fatalf("Expected %v, got %v", deviceMapping, hostconfig.Devices)
}
testCases := []struct {
devices []string
deviceMapping *container.DeviceMapping
deviceRequests []container.DeviceRequest
}{
{
devices: []string{"/dev/snd"},
deviceMapping: &container.DeviceMapping{
PathOnHost: "/dev/snd",
PathInContainer: "/dev/snd",
CgroupPermissions: "rwm",
},
},
{
devices: []string{"/dev/snd:rw"},
deviceMapping: &container.DeviceMapping{
PathOnHost: "/dev/snd",
PathInContainer: "/dev/snd",
CgroupPermissions: "rw",
},
},
{
devices: []string{"/dev/snd:/something"},
deviceMapping: &container.DeviceMapping{
PathOnHost: "/dev/snd",
PathInContainer: "/something",
CgroupPermissions: "rwm",
},
},
{
devices: []string{"/dev/snd:/something:rw"},
deviceMapping: &container.DeviceMapping{
PathOnHost: "/dev/snd",
PathInContainer: "/something",
CgroupPermissions: "rw",
},
},
{
devices: []string{"vendor.com/class=name"},
deviceMapping: nil,
deviceRequests: []container.DeviceRequest{
{
Driver: "cdi",
DeviceIDs: []string{"vendor.com/class=name"},
},
},
},
{
devices: []string{"vendor.com/class=name", "/dev/snd:/something:rw"},
deviceMapping: &container.DeviceMapping{
PathOnHost: "/dev/snd",
PathInContainer: "/something",
CgroupPermissions: "rw",
},
deviceRequests: []container.DeviceRequest{
{
Driver: "cdi",
DeviceIDs: []string{"vendor.com/class=name"},
},
},
},
}

for _, tc := range testCases {
t.Run(fmt.Sprintf("%s", tc.devices), func(t *testing.T) {
var args []string
for _, d := range tc.devices {
args = append(args, fmt.Sprintf("--device=%v", d))
}
args = append(args, "img", "cmd")

_, hostconfig, _, err := parseRun(args)

assert.NilError(t, err)

if tc.deviceMapping != nil {
if assert.Check(t, is.Len(hostconfig.Devices, 1)) {
assert.Check(t, is.DeepEqual(*tc.deviceMapping, hostconfig.Devices[0]))
}
} else {
assert.Check(t, is.Len(hostconfig.Devices, 0))
}

assert.Check(t, is.DeepEqual(tc.deviceRequests, hostconfig.DeviceRequests))
})
}
}

Expand Down
1 change: 1 addition & 0 deletions vendor.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ go 1.18

require (
dario.cat/mergo v1.0.0
github.com/container-orchestrated-devices/container-device-interface v0.5.5-0.20230516140309-1e6752771dc5
github.com/containerd/containerd v1.6.21
github.com/creack/pty v1.1.18
github.com/docker/distribution v2.8.2+incompatible
Expand Down
2 changes: 2 additions & 0 deletions vendor.sum
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWH
github.com/cockroachdb/datadriven v0.0.0-20200714090401-bf6692d28da5/go.mod h1:h6jFvWxBdQXxjopDMZyH2UVceIRfR84bdzbkoKrsWNo=
github.com/cockroachdb/errors v1.2.4/go.mod h1:rQD95gz6FARkaKkQXUksEje/d9a6wBJoCr5oaCLELYA=
github.com/cockroachdb/logtags v0.0.0-20190617123548-eb05cc24525f/go.mod h1:i/u985jwjWRlyHXQbwatDASoW0RMlZ/3i9yJHE2xLkI=
github.com/container-orchestrated-devices/container-device-interface v0.5.5-0.20230516140309-1e6752771dc5 h1:qJTKvM6AD0paTZodYyHV540e01I6+Uhq5vkKBi4byvQ=
github.com/container-orchestrated-devices/container-device-interface v0.5.5-0.20230516140309-1e6752771dc5/go.mod h1:OQlgtJtDrOxSQ1BWODC8OZK1tzi9W69wek+Jy17ndzo=
github.com/containerd/containerd v1.6.21 h1:eSTAmnvDKRPWan+MpSSfNyrtleXd86ogK9X8fMWpe/Q=
github.com/containerd/containerd v1.6.21/go.mod h1:apei1/i5Ux2FzrK6+DM/suEsGuK/MeVOfy8tR2q7Wnw=
github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading