Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ func (s *tagCacheMap) Set(key string, value *cachedTag) {
// Validate contains the validator settings passed in using the Config struct
type Validate struct {
tagName string
fieldNameTag string
validationFuncs map[string]Func
customTypeFuncs map[reflect.Type]CustomTypeFunc
aliasValidators map[string]string
Expand All @@ -96,7 +97,8 @@ func (v *Validate) initCheck() {
// Config contains the options that a Validator instance will use.
// It is passed to the New() function
type Config struct {
TagName string
TagName string
FieldNameTag string
}

// CustomTypeFunc allows for overriding or adding custom field type handler functions
Expand Down Expand Up @@ -137,6 +139,7 @@ func (ve ValidationErrors) Error() string {
// with other properties that may be needed for error message creation
type FieldError struct {
Field string
Name string
Tag string
ActualTag string
Kind reflect.Kind
Expand All @@ -149,8 +152,9 @@ type FieldError struct {
func New(config *Config) *Validate {

v := &Validate{
tagName: config.TagName,
tagsCache: &tagCacheMap{m: map[string]*cachedTag{}},
tagName: config.TagName,
fieldNameTag: config.FieldNameTag,
tagsCache: &tagCacheMap{m: map[string]*cachedTag{}},
errsPool: &sync.Pool{New: func() interface{} {
return ValidationErrors{}
}}}
Expand Down Expand Up @@ -245,7 +249,7 @@ func (v *Validate) Field(field interface{}, tag string) error {
errs := v.errsPool.Get().(ValidationErrors)
fieldVal := reflect.ValueOf(field)

v.traverseField(fieldVal, fieldVal, fieldVal, blank, errs, false, tag, blank, false, false, nil)
v.traverseField(fieldVal, fieldVal, fieldVal, blank, errs, false, tag, blank, blank, false, false, nil)

if len(errs) == 0 {
v.errsPool.Put(errs)
Expand All @@ -265,7 +269,7 @@ func (v *Validate) FieldWithValue(val interface{}, field interface{}, tag string
errs := v.errsPool.Get().(ValidationErrors)
topVal := reflect.ValueOf(val)

v.traverseField(topVal, topVal, reflect.ValueOf(field), blank, errs, false, tag, blank, false, false, nil)
v.traverseField(topVal, topVal, reflect.ValueOf(field), blank, errs, false, tag, blank, blank, false, false, nil)

if len(errs) == 0 {
v.errsPool.Put(errs)
Expand Down Expand Up @@ -417,12 +421,20 @@ func (v *Validate) tranverseStruct(topStruct reflect.Value, currentStruct reflec
}
}

v.traverseField(topStruct, currentStruct, current.Field(i), errPrefix, errs, true, fld.Tag.Get(v.tagName), fld.Name, partial, exclude, includeExclude)
customName := fld.Name
if v.fieldNameTag != "" {
name := strings.Split(fld.Tag.Get(v.fieldNameTag), ",")[0]
if name != "" {
customName = name
}
}

v.traverseField(topStruct, currentStruct, current.Field(i), errPrefix, errs, true, fld.Tag.Get(v.tagName), fld.Name, customName, partial, exclude, includeExclude)
}
}

// traverseField validates any field, be it a struct or single field, ensures it's validity and passes it along to be validated via it's tag options
func (v *Validate) traverseField(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, isStructField bool, tag string, name string, partial bool, exclude bool, includeExclude map[string]*struct{}) {
func (v *Validate) traverseField(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, isStructField bool, tag, name, customName string, partial bool, exclude bool, includeExclude map[string]*struct{}) {

if tag == skipValidationTag {
return
Expand All @@ -448,6 +460,7 @@ func (v *Validate) traverseField(topStruct reflect.Value, currentStruct reflect.

if kind == reflect.Invalid {
errs[errPrefix+name] = &FieldError{
Name: customName,
Field: name,
Tag: cTag.tags[0].tag,
ActualTag: cTag.tags[0].tagVals[0][0],
Expand All @@ -458,6 +471,7 @@ func (v *Validate) traverseField(topStruct reflect.Value, currentStruct reflect.
}

errs[errPrefix+name] = &FieldError{
Name: customName,
Field: name,
Tag: cTag.tags[0].tag,
ActualTag: cTag.tags[0].tagVals[0][0],
Expand Down Expand Up @@ -520,7 +534,7 @@ func (v *Validate) traverseField(topStruct reflect.Value, currentStruct reflect.
continue
}

if v.validateField(topStruct, currentStruct, current, typ, kind, errPrefix, errs, valTag, name) {
if v.validateField(topStruct, currentStruct, current, typ, kind, errPrefix, errs, valTag, name, customName) {
return
}
}
Expand All @@ -530,9 +544,9 @@ func (v *Validate) traverseField(topStruct reflect.Value, currentStruct reflect.
// or panic ;)
switch kind {
case reflect.Slice, reflect.Array:
v.traverseSlice(topStruct, currentStruct, current, errPrefix, errs, diveSubTag, name, partial, exclude, includeExclude)
v.traverseSlice(topStruct, currentStruct, current, errPrefix, errs, diveSubTag, name, customName, partial, exclude, includeExclude)
case reflect.Map:
v.traverseMap(topStruct, currentStruct, current, errPrefix, errs, diveSubTag, name, partial, exclude, includeExclude)
v.traverseMap(topStruct, currentStruct, current, errPrefix, errs, diveSubTag, name, customName, partial, exclude, includeExclude)
default:
// throw error, if not a slice or map then should not have gotten here
// bad dive tag
Expand All @@ -542,23 +556,23 @@ func (v *Validate) traverseField(topStruct reflect.Value, currentStruct reflect.
}

// traverseSlice traverses a Slice or Array's elements and passes them to traverseField for validation
func (v *Validate) traverseSlice(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, tag string, name string, partial bool, exclude bool, includeExclude map[string]*struct{}) {
func (v *Validate) traverseSlice(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, tag, name, customName string, partial bool, exclude bool, includeExclude map[string]*struct{}) {

for i := 0; i < current.Len(); i++ {
v.traverseField(topStruct, currentStruct, current.Index(i), errPrefix, errs, false, tag, fmt.Sprintf(arrayIndexFieldName, name, i), partial, exclude, includeExclude)
v.traverseField(topStruct, currentStruct, current.Index(i), errPrefix, errs, false, tag, fmt.Sprintf(arrayIndexFieldName, name, i), fmt.Sprintf(arrayIndexFieldName, customName, i), partial, exclude, includeExclude)
}
}

// traverseMap traverses a map's elements and passes them to traverseField for validation
func (v *Validate) traverseMap(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, tag string, name string, partial bool, exclude bool, includeExclude map[string]*struct{}) {
func (v *Validate) traverseMap(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, errPrefix string, errs ValidationErrors, tag, name, customName string, partial bool, exclude bool, includeExclude map[string]*struct{}) {

for _, key := range current.MapKeys() {
v.traverseField(topStruct, currentStruct, current.MapIndex(key), errPrefix, errs, false, tag, fmt.Sprintf(mapIndexFieldName, name, key.Interface()), partial, exclude, includeExclude)
v.traverseField(topStruct, currentStruct, current.MapIndex(key), errPrefix, errs, false, tag, fmt.Sprintf(mapIndexFieldName, name, key.Interface()), fmt.Sprintf(mapIndexFieldName, customName, key.Interface()), partial, exclude, includeExclude)
}
}

// validateField validates a field based on the provided tag's key and param values and returns true if there is an error or false if all ok
func (v *Validate) validateField(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, currentType reflect.Type, currentKind reflect.Kind, errPrefix string, errs ValidationErrors, valTag *tagVals, name string) bool {
func (v *Validate) validateField(topStruct reflect.Value, currentStruct reflect.Value, current reflect.Value, currentType reflect.Type, currentKind reflect.Kind, errPrefix string, errs ValidationErrors, valTag *tagVals, name, customName string) bool {

var valFunc Func
var ok bool
Expand All @@ -583,6 +597,7 @@ func (v *Validate) validateField(topStruct reflect.Value, currentStruct reflect.

if valTag.isAlias {
errs[errPrefix+name] = &FieldError{
Name: customName,
Field: name,
Tag: valTag.tag,
ActualTag: errTag[1:],
Expand All @@ -592,6 +607,7 @@ func (v *Validate) validateField(topStruct reflect.Value, currentStruct reflect.
}
} else {
errs[errPrefix+name] = &FieldError{
Name: customName,
Field: name,
Tag: errTag[1:],
ActualTag: errTag[1:],
Expand All @@ -614,6 +630,7 @@ func (v *Validate) validateField(topStruct reflect.Value, currentStruct reflect.
}

errs[errPrefix+name] = &FieldError{
Name: customName,
Field: name,
Tag: valTag.tag,
ActualTag: valTag.tagVals[0][0],
Expand Down
18 changes: 18 additions & 0 deletions validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4930,3 +4930,21 @@ func TestInvalidValidatorFunction(t *testing.T) {

PanicMatches(t, func() { validate.Field(s.Test, "zzxxBadFunction") }, "Undefined validation function on field")
}

func TestCustomFieldName(t *testing.T) {
type A struct {
B string `schema:"b" validate:"required"`
C string `schema:"c" validate:"required"`
D []bool `schema:"d" validate:"required"`
}

a := &A{}

errs := New(&Config{TagName: "validate", FieldNameTag: "schema"}).Struct(a).(ValidationErrors)
Equal(t, errs["A.B"].Name, "b")
Equal(t, errs["A.C"].Name, "c")
Equal(t, errs["A.D"].Name, "d")

errs = New(&Config{TagName: "validate"}).Struct(a).(ValidationErrors)
Equal(t, errs["A.B"].Name, "B")
}