Skip to content

Commit

Permalink
Merge pull request #46115 from thaJeztah/24.0_backport_fix_filter_errors
Browse files Browse the repository at this point in the history
[24.0 backport] api/types/filters: fix errors not being matched by errors.Is()
  • Loading branch information
thaJeztah committed Aug 1, 2023
2 parents e426ae0 + b6568d2 commit 2ef88a3
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 24 deletions.
10 changes: 5 additions & 5 deletions api/types/filters/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func FromJSON(p string) (Args, error) {
// Fallback to parsing arguments in the legacy slice format
deprecated := map[string][]string{}
if legacyErr := json.Unmarshal(raw, &deprecated); legacyErr != nil {
return args, invalidFilter{}
return args, &invalidFilter{}
}

args.fields = deprecatedArgs(deprecated)
Expand Down Expand Up @@ -206,7 +206,7 @@ func (args Args) GetBoolOrDefault(key string, defaultValue bool) (bool, error) {
}

if len(fieldValues) == 0 {
return defaultValue, invalidFilter{key, nil}
return defaultValue, &invalidFilter{key, nil}
}

isFalse := fieldValues["0"] || fieldValues["false"]
Expand All @@ -216,15 +216,15 @@ func (args Args) GetBoolOrDefault(key string, defaultValue bool) (bool, error) {
invalid := !isFalse && !isTrue

if conflicting || invalid {
return defaultValue, invalidFilter{key, args.Get(key)}
return defaultValue, &invalidFilter{key, args.Get(key)}
} else if isFalse {
return false, nil
} else if isTrue {
return true, nil
}

// This code shouldn't be reached.
return defaultValue, unreachableCode{Filter: key, Value: args.Get(key)}
return defaultValue, &unreachableCode{Filter: key, Value: args.Get(key)}
}

// ExactMatch returns true if the source matches exactly one of the values.
Expand Down Expand Up @@ -282,7 +282,7 @@ func (args Args) Contains(field string) bool {
func (args Args) Validate(accepted map[string]bool) error {
for name := range args.fields {
if !accepted[name] {
return invalidFilter{name, nil}
return &invalidFilter{name, nil}
}
}
return nil
Expand Down
50 changes: 31 additions & 19 deletions api/types/filters/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package filters // import "github.com/docker/docker/api/types/filters"
import (
"encoding/json"
"errors"
"fmt"
"sort"
"testing"

Expand Down Expand Up @@ -95,15 +96,19 @@ func TestFromJSON(t *testing.T) {
if err == nil {
t.Fatalf("Expected an error with %v, got nothing", invalid)
}
var invalidFilterError invalidFilter
var invalidFilterError *invalidFilter
if !errors.As(err, &invalidFilterError) {
t.Fatalf("Expected an invalidFilter error, got %T", err)
}
wrappedErr := fmt.Errorf("something went wrong: %w", err)
if !errors.Is(wrappedErr, err) {
t.Errorf("Expected a wrapped error to be detected as invalidFilter")
}
}

for expectedArgs, matchers := range valid {
for _, json := range matchers {
args, err := FromJSON(json)
for _, jsonString := range matchers {
args, err := FromJSON(jsonString)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -358,9 +363,13 @@ func TestValidate(t *testing.T) {
if err == nil {
t.Fatal("Expected to return an error, got nil")
}
var invalidFilterError invalidFilter
var invalidFilterError *invalidFilter
if !errors.As(err, &invalidFilterError) {
t.Fatalf("Expected an invalidFilter error, got %T", err)
t.Errorf("Expected an invalidFilter error, got %T", err)
}
wrappedErr := fmt.Errorf("something went wrong: %w", err)
if !errors.Is(wrappedErr, err) {
t.Errorf("Expected a wrapped error to be detected as invalidFilter")
}
}

Expand Down Expand Up @@ -421,7 +430,7 @@ func TestClone(t *testing.T) {
}

func TestGetBoolOrDefault(t *testing.T) {
for _, tC := range []struct {
for _, tc := range []struct {
name string
args map[string][]string
defValue bool
Expand Down Expand Up @@ -452,7 +461,7 @@ func TestGetBoolOrDefault(t *testing.T) {
"dangling": {"potato"},
},
defValue: true,
expectedErr: invalidFilter{Filter: "dangling", Value: []string{"potato"}},
expectedErr: &invalidFilter{Filter: "dangling", Value: []string{"potato"}},
expectedValue: true,
},
{
Expand All @@ -461,7 +470,7 @@ func TestGetBoolOrDefault(t *testing.T) {
"dangling": {"banana", "potato"},
},
defValue: true,
expectedErr: invalidFilter{Filter: "dangling", Value: []string{"banana", "potato"}},
expectedErr: &invalidFilter{Filter: "dangling", Value: []string{"banana", "potato"}},
expectedValue: true,
},
{
Expand All @@ -470,7 +479,7 @@ func TestGetBoolOrDefault(t *testing.T) {
"dangling": {"false", "true"},
},
defValue: false,
expectedErr: invalidFilter{Filter: "dangling", Value: []string{"false", "true"}},
expectedErr: &invalidFilter{Filter: "dangling", Value: []string{"false", "true"}},
expectedValue: false,
},
{
Expand All @@ -479,7 +488,7 @@ func TestGetBoolOrDefault(t *testing.T) {
"dangling": {"false", "true", "1"},
},
defValue: true,
expectedErr: invalidFilter{Filter: "dangling", Value: []string{"false", "true", "1"}},
expectedErr: &invalidFilter{Filter: "dangling", Value: []string{"false", "true", "1"}},
expectedValue: true,
},
{
Expand All @@ -501,35 +510,38 @@ func TestGetBoolOrDefault(t *testing.T) {
expectedValue: false,
},
} {
tC := tC
t.Run(tC.name, func(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
a := NewArgs()

for key, values := range tC.args {
for key, values := range tc.args {
for _, value := range values {
a.Add(key, value)
}
}

value, err := a.GetBoolOrDefault("dangling", tC.defValue)
value, err := a.GetBoolOrDefault("dangling", tc.defValue)

if tC.expectedErr == nil {
if tc.expectedErr == nil {
assert.Check(t, is.Nil(err))
} else {
assert.Check(t, is.ErrorType(err, tC.expectedErr))
assert.Check(t, is.ErrorType(err, tc.expectedErr))

// Check if error is the same.
expected := tC.expectedErr.(invalidFilter)
actual := err.(invalidFilter)
expected := tc.expectedErr.(*invalidFilter)
actual := err.(*invalidFilter)

assert.Check(t, is.Equal(expected.Filter, actual.Filter))

sort.Strings(expected.Value)
sort.Strings(actual.Value)
assert.Check(t, is.DeepEqual(expected.Value, actual.Value))

wrappedErr := fmt.Errorf("something went wrong: %w", err)
assert.Check(t, errors.Is(wrappedErr, err), "Expected a wrapped error to be detected as invalidFilter")
}

assert.Check(t, is.Equal(tC.expectedValue, value))
assert.Check(t, is.Equal(tc.expectedValue, value))
})
}

Expand Down

0 comments on commit 2ef88a3

Please sign in to comment.