Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions baked_in.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ func contains(v *Validate, topStruct reflect.Value, currentStructOrField reflect

func isNeField(v *Validate, topStruct reflect.Value, currentStructOrField reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {

currentField, currentKind, ok := v.getStructFieldOK(currentStructOrField, param)
currentField, currentKind, ok := v.GetStructFieldOK(currentStructOrField, param)

if !ok || currentKind != fieldKind {
return true
Expand Down Expand Up @@ -307,7 +307,7 @@ func isNe(v *Validate, topStruct reflect.Value, currentStructOrField reflect.Val

func isLteCrossStructField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {

topField, topKind, ok := v.getStructFieldOK(topStruct, param)
topField, topKind, ok := v.GetStructFieldOK(topStruct, param)
if !ok || topKind != fieldKind {
return false
}
Expand Down Expand Up @@ -348,7 +348,7 @@ func isLteCrossStructField(v *Validate, topStruct reflect.Value, current reflect

func isLtCrossStructField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {

topField, topKind, ok := v.getStructFieldOK(topStruct, param)
topField, topKind, ok := v.GetStructFieldOK(topStruct, param)
if !ok || topKind != fieldKind {
return false
}
Expand Down Expand Up @@ -389,7 +389,7 @@ func isLtCrossStructField(v *Validate, topStruct reflect.Value, current reflect.

func isGteCrossStructField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {

topField, topKind, ok := v.getStructFieldOK(topStruct, param)
topField, topKind, ok := v.GetStructFieldOK(topStruct, param)
if !ok || topKind != fieldKind {
return false
}
Expand Down Expand Up @@ -430,7 +430,7 @@ func isGteCrossStructField(v *Validate, topStruct reflect.Value, current reflect

func isGtCrossStructField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {

topField, topKind, ok := v.getStructFieldOK(topStruct, param)
topField, topKind, ok := v.GetStructFieldOK(topStruct, param)
if !ok || topKind != fieldKind {
return false
}
Expand Down Expand Up @@ -471,7 +471,7 @@ func isGtCrossStructField(v *Validate, topStruct reflect.Value, current reflect.

func isNeCrossStructField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {

topField, currentKind, ok := v.getStructFieldOK(topStruct, param)
topField, currentKind, ok := v.GetStructFieldOK(topStruct, param)
if !ok || currentKind != fieldKind {
return true
}
Expand Down Expand Up @@ -512,7 +512,7 @@ func isNeCrossStructField(v *Validate, topStruct reflect.Value, current reflect.

func isEqCrossStructField(v *Validate, topStruct reflect.Value, current reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {

topField, topKind, ok := v.getStructFieldOK(topStruct, param)
topField, topKind, ok := v.GetStructFieldOK(topStruct, param)
if !ok || topKind != fieldKind {
return false
}
Expand Down Expand Up @@ -553,7 +553,7 @@ func isEqCrossStructField(v *Validate, topStruct reflect.Value, current reflect.

func isEqField(v *Validate, topStruct reflect.Value, currentStructOrField reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {

currentField, currentKind, ok := v.getStructFieldOK(currentStructOrField, param)
currentField, currentKind, ok := v.GetStructFieldOK(currentStructOrField, param)
if !ok || currentKind != fieldKind {
return false
}
Expand Down Expand Up @@ -718,7 +718,7 @@ func hasValue(v *Validate, topStruct reflect.Value, currentStructOrField reflect

func isGteField(v *Validate, topStruct reflect.Value, currentStructOrField reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {

currentField, currentKind, ok := v.getStructFieldOK(currentStructOrField, param)
currentField, currentKind, ok := v.GetStructFieldOK(currentStructOrField, param)
if !ok || currentKind != fieldKind {
return false
}
Expand Down Expand Up @@ -759,7 +759,7 @@ func isGteField(v *Validate, topStruct reflect.Value, currentStructOrField refle

func isGtField(v *Validate, topStruct reflect.Value, currentStructOrField reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {

currentField, currentKind, ok := v.getStructFieldOK(currentStructOrField, param)
currentField, currentKind, ok := v.GetStructFieldOK(currentStructOrField, param)
if !ok || currentKind != fieldKind {
return false
}
Expand Down Expand Up @@ -927,7 +927,7 @@ func hasMinOf(v *Validate, topStruct reflect.Value, currentStructOrField reflect

func isLteField(v *Validate, topStruct reflect.Value, currentStructOrField reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {

currentField, currentKind, ok := v.getStructFieldOK(currentStructOrField, param)
currentField, currentKind, ok := v.GetStructFieldOK(currentStructOrField, param)
if !ok || currentKind != fieldKind {
return false
}
Expand Down Expand Up @@ -968,7 +968,7 @@ func isLteField(v *Validate, topStruct reflect.Value, currentStructOrField refle

func isLtField(v *Validate, topStruct reflect.Value, currentStructOrField reflect.Value, field reflect.Value, fieldType reflect.Type, fieldKind reflect.Kind, param string) bool {

currentField, currentKind, ok := v.getStructFieldOK(currentStructOrField, param)
currentField, currentKind, ok := v.GetStructFieldOK(currentStructOrField, param)
if !ok || currentKind != fieldKind {
return false
}
Expand Down
52 changes: 30 additions & 22 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ var (
}
)

func (v *Validate) extractType(current reflect.Value) (reflect.Value, reflect.Kind) {
// ExtractType gets the actual underlying type of field value.
// It will dive into pointers, customTypes and return you the
// underlying value and it's kind.
// it is exposed for use within you Custom Functions
func (v *Validate) ExtractType(current reflect.Value) (reflect.Value, reflect.Kind) {

switch current.Kind() {
case reflect.Ptr:
Expand All @@ -38,15 +42,15 @@ func (v *Validate) extractType(current reflect.Value) (reflect.Value, reflect.Ki
return current, reflect.Ptr
}

return v.extractType(current.Elem())
return v.ExtractType(current.Elem())

case reflect.Interface:

if current.IsNil() {
return current, reflect.Interface
}

return v.extractType(current.Elem())
return v.ExtractType(current.Elem())

case reflect.Invalid:
return current, reflect.Invalid
Expand All @@ -55,17 +59,21 @@ func (v *Validate) extractType(current reflect.Value) (reflect.Value, reflect.Ki

if v.hasCustomFuncs {
if fn, ok := v.customTypeFuncs[current.Type()]; ok {
return v.extractType(reflect.ValueOf(fn(current)))
return v.ExtractType(reflect.ValueOf(fn(current)))
}
}

return current, current.Kind()
}
}

func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool) {
// GetStructFieldOK traverses a struct to retrieve a specific field denoted by the provided namespace and
// returns the field, field kind and whether is was successful in retrieving the field at all.
// NOTE: when not successful ok will be false, this can happen when a nested struct is nil and so the field
// could not be retrived because it didnt exist.
func (v *Validate) GetStructFieldOK(current reflect.Value, namespace string) (reflect.Value, reflect.Kind, bool) {

current, kind := v.extractType(current)
current, kind := v.ExtractType(current)

if kind == reflect.Invalid {
return current, kind, false
Expand Down Expand Up @@ -108,7 +116,7 @@ func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (re

current = current.FieldByName(fld)

return v.getStructFieldOK(current, ns)
return v.GetStructFieldOK(current, ns)
}

case reflect.Array, reflect.Slice:
Expand All @@ -129,7 +137,7 @@ func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (re
}
}

return v.getStructFieldOK(current.Index(arrIdx), namespace[startIdx:])
return v.GetStructFieldOK(current.Index(arrIdx), namespace[startIdx:])

case reflect.Map:
idx := strings.Index(namespace, leftBracket) + 1
Expand All @@ -148,47 +156,47 @@ func (v *Validate) getStructFieldOK(current reflect.Value, namespace string) (re
switch current.Type().Key().Kind() {
case reflect.Int:
i, _ := strconv.Atoi(key)
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:])
return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:])
case reflect.Int8:
i, _ := strconv.ParseInt(key, 10, 8)
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(int8(i))), namespace[endIdx+1:])
return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(int8(i))), namespace[endIdx+1:])
case reflect.Int16:
i, _ := strconv.ParseInt(key, 10, 16)
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(int16(i))), namespace[endIdx+1:])
return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(int16(i))), namespace[endIdx+1:])
case reflect.Int32:
i, _ := strconv.ParseInt(key, 10, 32)
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(int32(i))), namespace[endIdx+1:])
return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(int32(i))), namespace[endIdx+1:])
case reflect.Int64:
i, _ := strconv.ParseInt(key, 10, 64)
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:])
return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:])
case reflect.Uint:
i, _ := strconv.ParseUint(key, 10, 0)
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(uint(i))), namespace[endIdx+1:])
return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(uint(i))), namespace[endIdx+1:])
case reflect.Uint8:
i, _ := strconv.ParseUint(key, 10, 8)
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(uint8(i))), namespace[endIdx+1:])
return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(uint8(i))), namespace[endIdx+1:])
case reflect.Uint16:
i, _ := strconv.ParseUint(key, 10, 16)
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(uint16(i))), namespace[endIdx+1:])
return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(uint16(i))), namespace[endIdx+1:])
case reflect.Uint32:
i, _ := strconv.ParseUint(key, 10, 32)
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(uint32(i))), namespace[endIdx+1:])
return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(uint32(i))), namespace[endIdx+1:])
case reflect.Uint64:
i, _ := strconv.ParseUint(key, 10, 64)
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:])
return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(i)), namespace[endIdx+1:])
case reflect.Float32:
f, _ := strconv.ParseFloat(key, 32)
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(float32(f))), namespace[endIdx+1:])
return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(float32(f))), namespace[endIdx+1:])
case reflect.Float64:
f, _ := strconv.ParseFloat(key, 64)
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(f)), namespace[endIdx+1:])
return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(f)), namespace[endIdx+1:])
case reflect.Bool:
b, _ := strconv.ParseBool(key)
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(b)), namespace[endIdx+1:])
return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(b)), namespace[endIdx+1:])

// reflect.Type = string
default:
return v.getStructFieldOK(current.MapIndex(reflect.ValueOf(key)), namespace[endIdx+1:])
return v.GetStructFieldOK(current.MapIndex(reflect.ValueOf(key)), namespace[endIdx+1:])
}
}

Expand Down
6 changes: 3 additions & 3 deletions validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (v *Validate) FieldWithValue(val interface{}, field interface{}, tag string
func (v *Validate) StructPartial(current interface{}, fields ...string) error {
v.initCheck()

sv, _ := v.extractType(reflect.ValueOf(current))
sv, _ := v.ExtractType(reflect.ValueOf(current))
name := sv.Type().Name()
m := map[string]*struct{}{}

Expand Down Expand Up @@ -340,7 +340,7 @@ func (v *Validate) StructPartial(current interface{}, fields ...string) error {
func (v *Validate) StructExcept(current interface{}, fields ...string) error {
v.initCheck()

sv, _ := v.extractType(reflect.ValueOf(current))
sv, _ := v.ExtractType(reflect.ValueOf(current))
name := sv.Type().Name()
m := map[string]*struct{}{}

Expand Down Expand Up @@ -435,7 +435,7 @@ func (v *Validate) traverseField(topStruct reflect.Value, currentStruct reflect.
v.tagsCache.Set(tag, cTag)
}

current, kind := v.extractType(current)
current, kind := v.ExtractType(current)
var typ reflect.Type

switch kind {
Expand Down
Loading