Skip to content

Commit

Permalink
update some logic for struct validate
Browse files Browse the repository at this point in the history
  • Loading branch information
inhere committed Jan 20, 2021
1 parent 129d6cc commit 64f8af5
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 53 deletions.
57 changes: 30 additions & 27 deletions data_source.go
Expand Up @@ -28,9 +28,9 @@ const (
sourceStruct
)

// 0: common field
// 1: anonymous field
// 2: nonAnonymous field
// 0: top level field
// 1: field at anonymous struct
// 2: field at non-anonymous struct
const (
fieldAtTopStruct int8 = iota
fieldAtAnonymous
Expand Down Expand Up @@ -257,7 +257,7 @@ func (d *StructData) Create(err ...error) *Validation {

// parse and collect rules from struct tags.
func (d *StructData) parseRulesFromTag(v *Validation) {
var recursiveFunc func(vt reflect.Type, preStrName string)
var recursiveFunc func(vt reflect.Type, preStrName string, parentIsAnonymous bool)
if d.ValidateTag == "" {
d.ValidateTag = gOpt.ValidateTag
}
Expand All @@ -269,46 +269,44 @@ func (d *StructData) parseRulesFromTag(v *Validation) {
fMap := make(map[string]string, 0)

vt := d.valueTpy
recursiveFunc = func(vt reflect.Type, preStrName string) {
recursiveFunc = func(vt reflect.Type, preStrName string, parentIsAnonymous bool) {
for i := 0; i < vt.NumField(); i++ {
fv := vt.Field(i)
ft := vt.Field(i).Type
ft = removeTypePtr(ft)

if ft.Kind() == reflect.Struct && !strings.Contains(ft.Name(), "Time") {
name := vt.Field(i).Name
recursiveFunc(ft, name)
continue
}

name := vt.Field(i).Name

// skip don't exported field
name := fv.Name
if name[0] >= 'a' && name[0] <= 'z' {
continue
}

if preStrName != "" {
name = preStrName + "." + name
} else {
// 0:common field 1:anonymous field 2:nonAnonymous field
if preStrName == "" {
d.fieldNames[name] = fieldAtTopStruct
} else {
name = preStrName + "." + name
if parentIsAnonymous {
d.fieldNames[name] = fieldAtAnonymous
} else {
d.fieldNames[name] = fieldAtSubStruct
}
}

// validate rule
vRule := vt.Field(i).Tag.Get(d.ValidateTag)
vRule := fv.Tag.Get(d.ValidateTag)
if vRule != "" {
v.StringRule(name, vRule)
}

// filter rule
fRule := vt.Field(i).Tag.Get(d.FilterTag)
fRule := fv.Tag.Get(d.FilterTag)
if fRule != "" {
v.FilterRule(name, fRule)
}

// load field translate name. eg: `json:"user_name"`
if gOpt.FieldTag != "" {
fName := vt.Field(i).Tag.Get(gOpt.FieldTag)
fName := fv.Tag.Get(gOpt.FieldTag)
if fName != "" {
fMap[name] = fName
}
Expand All @@ -317,15 +315,21 @@ func (d *StructData) parseRulesFromTag(v *Validation) {
// load custom error messages.
// eg: `message:"required:name is required|minLen:name min len is %d"`
if gOpt.MessageTag != "" {
errMsg := vt.Field(i).Tag.Get(gOpt.MessageTag)
errMsg := fv.Tag.Get(gOpt.MessageTag)
if errMsg != "" {
d.loadMessagesFromTag(v.trans, name, vRule, errMsg)
}
}

// NEW: collect rules from sub-struct
// TODO should use ft == timeType check time.Time
if ft.Kind() == reflect.Struct && !strings.Contains(ft.Name(), "Time") {
recursiveFunc(ft, name, fv.Anonymous)
}
}
}

recursiveFunc(vt, "")
recursiveFunc(vt, "", false)

if len(fMap) > 0 {
v.trans.AddFieldMap(fMap)
Expand Down Expand Up @@ -398,12 +402,12 @@ func (d *StructData) Get(field string) (interface{}, bool) {
field = strutil.UpperFirst(field)

if strings.ContainsRune(field, '.') {
// want get sub struct field
// want get sub struct field. NOTICE: current only support two level struct
ss := strings.SplitN(field, ".", 2)
parentField, subField := ss[0], ss[1]
topField, subField := ss[0], ss[1]

// check top field is an struct
tft, ok := d.valueTpy.FieldByName(parentField)
tft, ok := d.valueTpy.FieldByName(topField)
if !ok {
return nil, false
}
Expand All @@ -419,7 +423,7 @@ func (d *StructData) Get(field string) (interface{}, bool) {
d.fieldNames[field] = fieldAtAnonymous
} else {
// get parent struct
fv = d.value.FieldByName(parentField)
fv = d.value.FieldByName(topField)
// is it a pointer?
if fv.Type().Kind() == reflect.Ptr {
fv = removeValuePtr(fv)
Expand All @@ -434,7 +438,6 @@ func (d *StructData) Get(field string) (interface{}, bool) {
if d.HasField(field) {
fv = d.value.FieldByName(field)
fv = removeValuePtr(fv)
d.fieldNames[field] = fieldAtTopStruct
} else {
// not found field
return nil, false
Expand Down
20 changes: 11 additions & 9 deletions issues_test.go
Expand Up @@ -279,20 +279,20 @@ func TestStructNested(t *testing.T) {

// anonymous struct nested
type User struct {
Name string `validate:"required|string" filter:"trim|lower"`
*Info
*Info `validate:"required"`
Org
Sex string `validate:"string"`
Name string `validate:"required|string" filter:"trim|lower"`
Sex string `validate:"string"`
}

// non-anonymous struct nested
// non-anonymous struct nested
type User2 struct {
Name string `validate:"required|string" filter:"trim|lower"`
In Info
Sex string `validate:"string"`
}

// anonymous field test
// anonymous field test
age := 3
u := &User{
Name: "fish",
Expand All @@ -303,7 +303,8 @@ func TestStructNested(t *testing.T) {
Org: Org{Company: "E"},
Sex: "male",
}
// anonymous field test

// anonymous field test
v := Struct(u)
if v.Validate() {
assert.True(t, v.Validate())
Expand All @@ -312,7 +313,8 @@ func TestStructNested(t *testing.T) {
fmt.Println(v.Errors)
assert.False(t, v.Validate())
}
// non-anonymous field test

// non-anonymous field test
age = 3
user2 := &User2{
Name: "fish",
Expand All @@ -335,13 +337,12 @@ func TestStructNested(t *testing.T) {

// https://github.com/gookit/validate/issues/78
func TestIssue78(t *testing.T) {

type UserDto struct {
Name string `validate:"required"`
Sex *bool `validate:"required"`
}

//sex := true
// sex := true
u := UserDto{
Name: "abc",
Sex: nil,
Expand All @@ -352,6 +353,7 @@ func TestIssue78(t *testing.T) {
if !v.Validate() {
fmt.Println(v.Errors)
} else {
assert.True(t, v.Validate())
fmt.Println("Success...")
}
}
2 changes: 1 addition & 1 deletion util.go
Expand Up @@ -326,7 +326,7 @@ var (

func checkValidatorFunc(name string, fn interface{}) reflect.Value {
if !goodName(name) {
panic(fmt.Errorf("validate name %s is not a valid identifier", name))
panicf("validate name %s is not a valid identifier", name)
}

fv := reflect.ValueOf(fn)
Expand Down
49 changes: 33 additions & 16 deletions validate_test.go
Expand Up @@ -7,11 +7,11 @@ import (
"github.com/stretchr/testify/assert"
)

func TestUtil_Func_valueToInt64(t *testing.T) {
func TestUtil_Func_valueToInt64(t *testing.T) {
noErrTests := []struct {
val interface{}
val interface{}
strict bool
want int64
want int64
}{
{" 12", false, 12},
{float32(12.23), false, 12},
Expand All @@ -25,22 +25,23 @@ func TestUtil_Func_valueToInt64(t *testing.T) {
}
}

func TestUtil_Func_getVariadicKind(t *testing.T) {
func TestUtil_Func_getVariadicKind(t *testing.T) {
noErrTests := []struct {
val interface{}
val interface{}
want reflect.Kind
}{
{"invalid", reflect.Invalid},
{[]int{1, 2}, reflect.Int},
{[]int8{1, 2}, reflect.Int8},
{[]int16{1, 2}, reflect.Int16},
{[]int32{1, 2}, reflect.Int32},
{[]int64{1, 2}, reflect.Int64},
{[]uint{1, 2}, reflect.Uint},
{[]uint8{1, 2}, reflect.Uint8},
{[]uint16{1, 2}, reflect.Uint16},
{[]uint32{1, 2}, reflect.Uint32},
{[]uint64{1, 2}, reflect.Uint64},
{"invalid", reflect.Invalid},
{[]int{1, 2}, reflect.Int},
{[]int8{1, 2}, reflect.Int8},
{[]int16{1, 2}, reflect.Int16},
{[]int32{1, 2}, reflect.Int32},
{[]int64{1, 2}, reflect.Int64},
{[]uint{1, 2}, reflect.Uint},
{[]uint8{1, 2}, reflect.Uint8},
{[]uint16{1, 2}, reflect.Uint16},
{[]uint32{1, 2}, reflect.Uint32},
{[]uint64{1, 2}, reflect.Uint64},
{[]string{"a", "b"}, reflect.String},
}

for _, item := range noErrTests {
Expand All @@ -50,6 +51,22 @@ func TestUtil_Func_getVariadicKind(t *testing.T) {
}
}

func TestUtil_Func_goodName(t *testing.T) {
tests := []struct {
give string
want bool
}{
{"ab", true},
{"1234", false},
{"01234", false},
{"abc123", true},
}

for _, item := range tests {
assert.Equal(t, item.want, goodName(item.give))
}
}

func TestMS_String(t *testing.T) {
ms := MS{}

Expand Down

0 comments on commit 64f8af5

Please sign in to comment.