diff --git a/internal/compose/compose.go b/internal/compose/compose.go index 3cd329662c..07e61b0948 100644 --- a/internal/compose/compose.go +++ b/internal/compose/compose.go @@ -54,6 +54,20 @@ type portMapping struct { Protocol string } +type intOrStringYaml int + +func (p *intOrStringYaml) UnmarshalYAML(node *yaml.Node) error { + var s string + err := node.Decode(&s) + if err == nil { + i, err := strconv.Atoi(s) + *p = intOrStringYaml(i) + return err + } + + return node.Decode(p) +} + // UnmarshalYAML unmarshals a Docker Compose port mapping in YAML to // a portMapping. func (p *portMapping) UnmarshalYAML(node *yaml.Node) error { @@ -67,9 +81,9 @@ func (p *portMapping) UnmarshalYAML(node *yaml.Node) error { } var s struct { - HostIP string `yaml:"host_ip"` - Target int - Published int + HostIP string `yaml:"host_ip"` + Target intOrStringYaml // Docker compose v2 can define ports as strings. + Published intOrStringYaml // Docker compose v2 can define ports as strings. Protocol string } @@ -77,8 +91,8 @@ func (p *portMapping) UnmarshalYAML(node *yaml.Node) error { return errors.Wrap(err, "could not unmarshal YAML map node") } - p.InternalPort = s.Target - p.ExternalPort = s.Published + p.InternalPort = int(s.Target) + p.ExternalPort = int(s.Published) p.Protocol = s.Protocol p.ExternalIP = s.HostIP return nil diff --git a/internal/compose/compose_test.go b/internal/compose/compose_test.go new file mode 100644 index 0000000000..4d893a1d67 --- /dev/null +++ b/internal/compose/compose_test.go @@ -0,0 +1,33 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package compose + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestIntOrStringYaml(t *testing.T) { + cases := []struct { + yaml string + expected int + }{ + {`"9200"`, 9200}, + {`'9200'`, 9200}, + {`9200`, 9200}, + } + + for _, c := range cases { + t.Run(c.yaml, func(t *testing.T) { + var n intOrStringYaml + err := yaml.Unmarshal([]byte(c.yaml), &n) + require.NoError(t, err) + assert.Equal(t, c.expected, int(n)) + }) + } +}