Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix boolean support for required_if, required_unless and eqfield #754

Merged
merged 2 commits into from
May 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions baked_in.go
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,9 @@ func isNeCrossStructField(fl FieldLevel) bool {
case reflect.Slice, reflect.Map, reflect.Array:
return int64(topField.Len()) != int64(field.Len())

case reflect.Bool:
return topField.Bool() != field.Bool()

case reflect.Struct:

fieldType := field.Type()
Expand Down Expand Up @@ -1085,6 +1088,9 @@ func isEqCrossStructField(fl FieldLevel) bool {
case reflect.Slice, reflect.Map, reflect.Array:
return int64(topField.Len()) == int64(field.Len())

case reflect.Bool:
return topField.Bool() == field.Bool()

case reflect.Struct:

fieldType := field.Type()
Expand Down Expand Up @@ -1132,6 +1138,9 @@ func isEqField(fl FieldLevel) bool {
case reflect.Slice, reflect.Map, reflect.Array:
return int64(field.Len()) == int64(currentField.Len())

case reflect.Bool:
return field.Bool() == currentField.Bool()

case reflect.Struct:

fieldType := field.Type()
Expand Down Expand Up @@ -1446,6 +1455,9 @@ func requireCheckFieldValue(fl FieldLevel, param string, value string, defaultNo

case reflect.Slice, reflect.Map, reflect.Array:
return int64(field.Len()) == asInt(value)

case reflect.Bool:
return field.Bool() == asBool(value)
}

// default reflect.String:
Expand Down
24 changes: 23 additions & 1 deletion validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1670,12 +1670,14 @@ func TestCrossStructNeFieldValidation(t *testing.T) {
i := 1
j = 1
k = 1.543
b := true
arr := []string{"test"}

s2 := "abcd"
i2 := 1
j2 = 1
k2 = 1.543
b2 := true
arr2 := []string{"test"}
arr3 := []string{"test", "test2"}
now2 := now
Expand All @@ -1696,6 +1698,10 @@ func TestCrossStructNeFieldValidation(t *testing.T) {
NotEqual(t, errs, nil)
AssertError(t, errs, "", "", "", "", "necsfield")

errs = validate.VarWithValue(b2, b, "necsfield")
NotEqual(t, errs, nil)
AssertError(t, errs, "", "", "", "", "necsfield")

errs = validate.VarWithValue(arr2, arr, "necsfield")
NotEqual(t, errs, nil)
AssertError(t, errs, "", "", "", "", "necsfield")
Expand Down Expand Up @@ -1834,6 +1840,7 @@ func TestCrossStructEqFieldValidation(t *testing.T) {
i := 1
j = 1
k = 1.543
b := true
arr := []string{"test"}

var j2 uint64
Expand All @@ -1842,6 +1849,7 @@ func TestCrossStructEqFieldValidation(t *testing.T) {
i2 := 1
j2 = 1
k2 = 1.543
b2 := true
arr2 := []string{"test"}
arr3 := []string{"test", "test2"}
now2 := now
Expand All @@ -1858,6 +1866,9 @@ func TestCrossStructEqFieldValidation(t *testing.T) {
errs = validate.VarWithValue(k2, k, "eqcsfield")
Equal(t, errs, nil)

errs = validate.VarWithValue(b2, b, "eqcsfield")
Equal(t, errs, nil)

errs = validate.VarWithValue(arr2, arr, "eqcsfield")
Equal(t, errs, nil)

Expand Down Expand Up @@ -4829,6 +4840,7 @@ func TestIsEqFieldValidation(t *testing.T) {
i := 1
j = 1
k = 1.543
b := true
arr := []string{"test"}
now := time.Now().UTC()

Expand All @@ -4838,6 +4850,7 @@ func TestIsEqFieldValidation(t *testing.T) {
i2 := 1
j2 = 1
k2 = 1.543
b2 := true
arr2 := []string{"test"}
arr3 := []string{"test", "test2"}
now2 := now
Expand All @@ -4854,6 +4867,9 @@ func TestIsEqFieldValidation(t *testing.T) {
errs = validate.VarWithValue(k2, k, "eqfield")
Equal(t, errs, nil)

errs = validate.VarWithValue(b2, b, "eqfield")
Equal(t, errs, nil)

errs = validate.VarWithValue(arr2, arr, "eqfield")
Equal(t, errs, nil)

Expand Down Expand Up @@ -10065,12 +10081,15 @@ func TestRequiredUnless(t *testing.T) {
Field6 uint `validate:"required_unless=Field5 2" json:"field_6"`
Field7 float32 `validate:"required_unless=Field6 0" json:"field_7"`
Field8 float64 `validate:"required_unless=Field7 0.0" json:"field_8"`
Field9 bool `validate:"omitempty" json:"field_9"`
Field10 string `validate:"required_unless=Field9 true" json:"field_10"`
}{
FieldE: "test",
Field2: &fieldVal,
Field3: map[string]string{"key": "val"},
Field4: "test",
Field5: 2,
Field9: true,
}

validate := New()
Expand All @@ -10090,6 +10109,8 @@ func TestRequiredUnless(t *testing.T) {
Field5 string `validate:"required_unless=Field3 0" json:"field_5"`
Field6 string `validate:"required_unless=Inner.Field test" json:"field_6"`
Field7 string `validate:"required_unless=Inner2.Field test" json:"field_7"`
Field8 bool `validate:"omitempty" json:"field_8"`
Field9 string `validate:"required_unless=Field8 true" json:"field_9"`
}{
Inner: &Inner{Field: &fieldVal},
FieldE: "test",
Expand All @@ -10100,10 +10121,11 @@ func TestRequiredUnless(t *testing.T) {
NotEqual(t, errs, nil)

ve := errs.(ValidationErrors)
Equal(t, len(ve), 3)
Equal(t, len(ve), 4)
AssertError(t, errs, "Field3", "Field3", "Field3", "Field3", "required_unless")
AssertError(t, errs, "Field4", "Field4", "Field4", "Field4", "required_unless")
AssertError(t, errs, "Field7", "Field7", "Field7", "Field7", "required_unless")
AssertError(t, errs, "Field9", "Field9", "Field9", "Field9", "required_unless")

defer func() {
if r := recover(); r == nil {
Expand Down