diff --git a/data_source.go b/data_source.go index 9d2e32e..33d4b77 100644 --- a/data_source.go +++ b/data_source.go @@ -28,9 +28,9 @@ const ( sourceStruct ) -// 0: common field -// 1: anonymous field -// 2: nonAnonymous field +// 0: top level field +// 1: field at anonymous struct +// 2: field at non-anonymous struct const ( fieldAtTopStruct int8 = iota fieldAtAnonymous @@ -257,7 +257,7 @@ func (d *StructData) Create(err ...error) *Validation { // parse and collect rules from struct tags. func (d *StructData) parseRulesFromTag(v *Validation) { - var recursiveFunc func(vt reflect.Type, preStrName string) + var recursiveFunc func(vt reflect.Type, preStrName string, parentIsAnonymous bool) if d.ValidateTag == "" { d.ValidateTag = gOpt.ValidateTag } @@ -269,46 +269,44 @@ func (d *StructData) parseRulesFromTag(v *Validation) { fMap := make(map[string]string, 0) vt := d.valueTpy - recursiveFunc = func(vt reflect.Type, preStrName string) { + recursiveFunc = func(vt reflect.Type, preStrName string, parentIsAnonymous bool) { for i := 0; i < vt.NumField(); i++ { + fv := vt.Field(i) ft := vt.Field(i).Type ft = removeTypePtr(ft) - if ft.Kind() == reflect.Struct && !strings.Contains(ft.Name(), "Time") { - name := vt.Field(i).Name - recursiveFunc(ft, name) - continue - } - - name := vt.Field(i).Name - // skip don't exported field + name := fv.Name if name[0] >= 'a' && name[0] <= 'z' { continue } - if preStrName != "" { - name = preStrName + "." + name - } else { - // 0:common field 1:anonymous field 2:nonAnonymous field + if preStrName == "" { d.fieldNames[name] = fieldAtTopStruct + } else { + name = preStrName + "." + name + if parentIsAnonymous { + d.fieldNames[name] = fieldAtAnonymous + } else { + d.fieldNames[name] = fieldAtSubStruct + } } // validate rule - vRule := vt.Field(i).Tag.Get(d.ValidateTag) + vRule := fv.Tag.Get(d.ValidateTag) if vRule != "" { v.StringRule(name, vRule) } // filter rule - fRule := vt.Field(i).Tag.Get(d.FilterTag) + fRule := fv.Tag.Get(d.FilterTag) if fRule != "" { v.FilterRule(name, fRule) } // load field translate name. eg: `json:"user_name"` if gOpt.FieldTag != "" { - fName := vt.Field(i).Tag.Get(gOpt.FieldTag) + fName := fv.Tag.Get(gOpt.FieldTag) if fName != "" { fMap[name] = fName } @@ -317,15 +315,21 @@ func (d *StructData) parseRulesFromTag(v *Validation) { // load custom error messages. // eg: `message:"required:name is required|minLen:name min len is %d"` if gOpt.MessageTag != "" { - errMsg := vt.Field(i).Tag.Get(gOpt.MessageTag) + errMsg := fv.Tag.Get(gOpt.MessageTag) if errMsg != "" { d.loadMessagesFromTag(v.trans, name, vRule, errMsg) } } + + // NEW: collect rules from sub-struct + // TODO should use ft == timeType check time.Time + if ft.Kind() == reflect.Struct && !strings.Contains(ft.Name(), "Time") { + recursiveFunc(ft, name, fv.Anonymous) + } } } - recursiveFunc(vt, "") + recursiveFunc(vt, "", false) if len(fMap) > 0 { v.trans.AddFieldMap(fMap) @@ -398,12 +402,12 @@ func (d *StructData) Get(field string) (interface{}, bool) { field = strutil.UpperFirst(field) if strings.ContainsRune(field, '.') { - // want get sub struct field + // want get sub struct field. NOTICE: current only support two level struct ss := strings.SplitN(field, ".", 2) - parentField, subField := ss[0], ss[1] + topField, subField := ss[0], ss[1] // check top field is an struct - tft, ok := d.valueTpy.FieldByName(parentField) + tft, ok := d.valueTpy.FieldByName(topField) if !ok { return nil, false } @@ -419,7 +423,7 @@ func (d *StructData) Get(field string) (interface{}, bool) { d.fieldNames[field] = fieldAtAnonymous } else { // get parent struct - fv = d.value.FieldByName(parentField) + fv = d.value.FieldByName(topField) // is it a pointer? if fv.Type().Kind() == reflect.Ptr { fv = removeValuePtr(fv) @@ -434,7 +438,6 @@ func (d *StructData) Get(field string) (interface{}, bool) { if d.HasField(field) { fv = d.value.FieldByName(field) fv = removeValuePtr(fv) - d.fieldNames[field] = fieldAtTopStruct } else { // not found field return nil, false diff --git a/issues_test.go b/issues_test.go index 4348cd1..df6e2e7 100644 --- a/issues_test.go +++ b/issues_test.go @@ -279,20 +279,20 @@ func TestStructNested(t *testing.T) { // anonymous struct nested type User struct { - Name string `validate:"required|string" filter:"trim|lower"` - *Info + *Info `validate:"required"` Org - Sex string `validate:"string"` + Name string `validate:"required|string" filter:"trim|lower"` + Sex string `validate:"string"` } - // non-anonymous struct nested + // non-anonymous struct nested type User2 struct { Name string `validate:"required|string" filter:"trim|lower"` In Info Sex string `validate:"string"` } - // anonymous field test + // anonymous field test age := 3 u := &User{ Name: "fish", @@ -303,7 +303,8 @@ func TestStructNested(t *testing.T) { Org: Org{Company: "E"}, Sex: "male", } - // anonymous field test + + // anonymous field test v := Struct(u) if v.Validate() { assert.True(t, v.Validate()) @@ -312,7 +313,8 @@ func TestStructNested(t *testing.T) { fmt.Println(v.Errors) assert.False(t, v.Validate()) } - // non-anonymous field test + + // non-anonymous field test age = 3 user2 := &User2{ Name: "fish", @@ -335,13 +337,12 @@ func TestStructNested(t *testing.T) { // https://github.com/gookit/validate/issues/78 func TestIssue78(t *testing.T) { - type UserDto struct { Name string `validate:"required"` Sex *bool `validate:"required"` } - //sex := true + // sex := true u := UserDto{ Name: "abc", Sex: nil, @@ -352,6 +353,7 @@ func TestIssue78(t *testing.T) { if !v.Validate() { fmt.Println(v.Errors) } else { + assert.True(t, v.Validate()) fmt.Println("Success...") } } diff --git a/util.go b/util.go index c16907e..fc77320 100644 --- a/util.go +++ b/util.go @@ -326,7 +326,7 @@ var ( func checkValidatorFunc(name string, fn interface{}) reflect.Value { if !goodName(name) { - panic(fmt.Errorf("validate name %s is not a valid identifier", name)) + panicf("validate name %s is not a valid identifier", name) } fv := reflect.ValueOf(fn) diff --git a/validate_test.go b/validate_test.go index 7d6319c..b220eba 100644 --- a/validate_test.go +++ b/validate_test.go @@ -7,11 +7,11 @@ import ( "github.com/stretchr/testify/assert" ) -func TestUtil_Func_valueToInt64(t *testing.T) { +func TestUtil_Func_valueToInt64(t *testing.T) { noErrTests := []struct { - val interface{} + val interface{} strict bool - want int64 + want int64 }{ {" 12", false, 12}, {float32(12.23), false, 12}, @@ -25,22 +25,23 @@ func TestUtil_Func_valueToInt64(t *testing.T) { } } -func TestUtil_Func_getVariadicKind(t *testing.T) { +func TestUtil_Func_getVariadicKind(t *testing.T) { noErrTests := []struct { - val interface{} + val interface{} want reflect.Kind }{ - {"invalid", reflect.Invalid}, - {[]int{1, 2}, reflect.Int}, - {[]int8{1, 2}, reflect.Int8}, - {[]int16{1, 2}, reflect.Int16}, - {[]int32{1, 2}, reflect.Int32}, - {[]int64{1, 2}, reflect.Int64}, - {[]uint{1, 2}, reflect.Uint}, - {[]uint8{1, 2}, reflect.Uint8}, - {[]uint16{1, 2}, reflect.Uint16}, - {[]uint32{1, 2}, reflect.Uint32}, - {[]uint64{1, 2}, reflect.Uint64}, + {"invalid", reflect.Invalid}, + {[]int{1, 2}, reflect.Int}, + {[]int8{1, 2}, reflect.Int8}, + {[]int16{1, 2}, reflect.Int16}, + {[]int32{1, 2}, reflect.Int32}, + {[]int64{1, 2}, reflect.Int64}, + {[]uint{1, 2}, reflect.Uint}, + {[]uint8{1, 2}, reflect.Uint8}, + {[]uint16{1, 2}, reflect.Uint16}, + {[]uint32{1, 2}, reflect.Uint32}, + {[]uint64{1, 2}, reflect.Uint64}, + {[]string{"a", "b"}, reflect.String}, } for _, item := range noErrTests { @@ -50,6 +51,22 @@ func TestUtil_Func_getVariadicKind(t *testing.T) { } } +func TestUtil_Func_goodName(t *testing.T) { + tests := []struct { + give string + want bool + }{ + {"ab", true}, + {"1234", false}, + {"01234", false}, + {"abc123", true}, + } + + for _, item := range tests { + assert.Equal(t, item.want, goodName(item.give)) + } +} + func TestMS_String(t *testing.T) { ms := MS{}