/
utils.go
59 lines (51 loc) · 1.17 KB
/
utils.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// Copyright (c) 2015-2016, NVIDIA CORPORATION. All rights reserved.
package main
import (
"fmt"
"strings"
"github.com/KamranAlipour/nvidia-docker/src/docker"
)
// Image label keys consulted by the helpers in this file.
const (
	// labelVolumesNeeded is the label whose value lists, separated by
	// whitespace, the volumes an image needs.
	labelVolumesNeeded = "com.nvidia.volumes.needed"

	// labelCUDAVersion is the label whose value is the CUDA version an image
	// requires, written as "major.minor".
	labelCUDAVersion = "com.nvidia.cuda.version"
)
// VolumesNeeded returns the volume names listed in the image's
// com.nvidia.volumes.needed label. If the image is not present locally it is
// pulled first.
//
// Names in the label are separated by whitespace. Repeated, leading or
// trailing whitespace never yields empty names. A nil slice with a nil error
// is returned when the label is absent or contains no names. Errors from
// checking, pulling or inspecting the image are returned unchanged.
func VolumesNeeded(image string) ([]string, error) {
	ok, err := docker.ImageExists(image)
	if err != nil {
		return nil, err
	}
	if !ok {
		// Presumably docker.Label inspects the local image, hence the pull
		// beforehand — TODO confirm against the docker package.
		if err = docker.ImagePull(image); err != nil {
			return nil, err
		}
	}
	label, err := docker.Label(image, labelVolumesNeeded)
	if err != nil {
		return nil, err
	}
	// strings.Fields (unlike Split on a single space) drops empty entries
	// caused by extra whitespace in a hand-written label.
	volumes := strings.Fields(label)
	if len(volumes) == 0 {
		return nil, nil
	}
	return volumes, nil
}
// cudaSupported checks whether the CUDA version required by image is
// satisfied by version, the CUDA version supported by the driver.
//
// The requirement is read from the image's com.nvidia.cuda.version label.
// Both the label and version are parsed as "major.minor". A nil error means
// the image is supported, which includes images that carry no such label.
// An error is returned in three cases:
//   - the label cannot be read;
//   - either version string does not parse;
//   - the image requires a newer version than the driver provides.
func cudaSupported(image, version string) error {
	label, err := docker.Label(image, labelCUDAVersion)
	if err != nil {
		return err
	}
	if label == "" {
		// No declared requirement: nothing to check.
		return nil
	}

	var vmaj, vmin, lmaj, lmin int
	// Wrap parse failures with the offending value; the bare Sscanf error
	// (e.g. "expected integer") does not say which input was malformed.
	if _, err := fmt.Sscanf(version, "%d.%d", &vmaj, &vmin); err != nil {
		return fmt.Errorf("invalid driver CUDA version %q: %v", version, err)
	}
	if _, err := fmt.Sscanf(label, "%d.%d", &lmaj, &lmin); err != nil {
		return fmt.Errorf("invalid %s label %q on image %s: %v", labelCUDAVersion, label, image, err)
	}

	// Compare (major, minor) lexicographically: reject only when the image
	// needs a strictly newer version than the driver supports.
	if lmaj > vmaj || (lmaj == vmaj && lmin > vmin) {
		return fmt.Errorf("unsupported CUDA version: driver %s < image %s", version, label)
	}
	return nil
}