Skip to content

Commit

Permalink
decode DeviceRequest.Count using DecodeMapstructure
Browse files Browse the repository at this point in the history
Signed-off-by: Nicolas De Loof <nicolas.deloof@gmail.com>
  • Loading branch information
ndeloof committed Oct 2, 2023
1 parent 88eac1d commit 49eefaf
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 39 deletions.
30 changes: 0 additions & 30 deletions loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,6 @@ func createTransformHook(additionalTransformers ...Transformer) mapstructure.Dec
reflect.TypeOf(types.BuildConfig{}): transformBuildConfig,
reflect.TypeOf(types.DependsOnConfig{}): transformDependsOnConfig,
reflect.TypeOf(types.ExtendsConfig{}): transformExtendsConfig,
reflect.TypeOf(types.DeviceRequest{}): transformServiceDeviceRequest,
reflect.TypeOf(types.SSHConfig{}): transformSSHConfig,
reflect.TypeOf(types.IncludeConfig{}): transformIncludeConfig,
}
Expand Down Expand Up @@ -1087,35 +1086,6 @@ var transformServicePort TransformerFunc = func(data interface{}) (interface{},
}
}

var transformServiceDeviceRequest TransformerFunc = func(data interface{}) (interface{}, error) {
switch value := data.(type) {
case map[string]interface{}:
count, ok := value["count"]
if ok {
switch val := count.(type) {
case int:
return value, nil
case string:
if strings.ToLower(val) == "all" {
value["count"] = -1
return value, nil
}
i, err := strconv.ParseInt(val, 10, 64)
if err == nil {
value["count"] = i
return value, nil
}
return data, errors.Errorf("invalid string value for 'count' (the only value allowed is 'all' or a number)")
default:
return data, errors.Errorf("invalid type %T for device count", val)
}
}
return data, nil
default:
return data, errors.Errorf("invalid type %T for resource reservation", value)
}
}

var transformFileReferenceConfig TransformerFunc = func(data interface{}) (interface{}, error) {
switch value := data.(type) {
case string:
Expand Down
4 changes: 2 additions & 2 deletions loader/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2129,9 +2129,9 @@ services:
devices:
- driver: nvidia
capabilities: [gpu]
count: somestring
count: some_string
`)
assert.ErrorContains(t, err, "invalid string value for 'count' (the only value allowed is 'all' or a number)")
assert.ErrorContains(t, err, `invalid value "some_string", the only value allowed is 'all' or a number`)
}

func TestServicePullPolicy(t *testing.T) {
Expand Down
53 changes: 53 additions & 0 deletions types/device.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
Copyright 2020 The Compose Specification Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package types

import (
"strconv"
"strings"

"github.com/pkg/errors"
)

type DeviceRequest struct {
Capabilities []string `yaml:"capabilities,omitempty" json:"capabilities,omitempty"`
Driver string `yaml:"driver,omitempty" json:"driver,omitempty"`
Count DeviceCount `yaml:"count,omitempty" json:"count,omitempty"`
IDs []string `yaml:"device_ids,omitempty" json:"device_ids,omitempty"`
}

type DeviceCount int64

func (c *DeviceCount) DecodeMapstructure(value interface{}) error {
switch v := value.(type) {
case int:
*c = DeviceCount(v)
case string:
if strings.ToLower(v) == "all" {
*c = -1
return nil
}
i, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return errors.Errorf("invalid value %q, the only value allowed is 'all' or a number", v)
}
*c = DeviceCount(i)
default:
return errors.Errorf("invalid type %T for device count", v)
}
return nil
}
7 changes: 0 additions & 7 deletions types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -584,13 +584,6 @@ type Resource struct {
Extensions Extensions `yaml:"#extensions,inline" json:"-"`
}

type DeviceRequest struct {
Capabilities []string `yaml:"capabilities,omitempty" json:"capabilities,omitempty"`
Driver string `yaml:"driver,omitempty" json:"driver,omitempty"`
Count int64 `yaml:"count,omitempty" json:"count,omitempty"`
IDs []string `yaml:"device_ids,omitempty" json:"device_ids,omitempty"`
}

// GenericResource represents a "user defined" resource which can
// only be an integer (e.g: SSD=3) for a service
type GenericResource struct {
Expand Down

0 comments on commit 49eefaf

Please sign in to comment.