Skip to content

Commit

Permalink
[FAB-17176] Make decode hooks consistent in viperutil (#424)
Browse files Browse the repository at this point in the history
Signed-off-by: Tiffany Harris <tiffany.harris@ibm.com>
  • Loading branch information
stephyee authored and Jason Yellick committed Dec 17, 2019
1 parent 28c6efd commit 2fbd83d
Showing 1 changed file with 138 additions and 148 deletions.
286 changes: 138 additions & 148 deletions common/viperutil/config_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,163 +109,155 @@ func unmarshalJSON(val interface{}) (map[string]string, bool) {
// customDecodeHook adds the additional functions of parsing durations from strings
// as well as parsing strings of the format "[thing1, thing2, thing3]" into string slices
// Note that whitespace around slice elements is removed
func customDecodeHook() mapstructure.DecodeHookFunc {
func customDecodeHook(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
durationHook := mapstructure.StringToTimeDurationHookFunc()
return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
dur, err := mapstructure.DecodeHookExec(durationHook, f, t, data)
if err == nil {
if _, ok := dur.(time.Duration); ok {
return dur, nil
}
dur, err := mapstructure.DecodeHookExec(durationHook, f, t, data)
if err == nil {
if _, ok := dur.(time.Duration); ok {
return dur, nil
}
}

if f.Kind() != reflect.String {
return data, nil
}
if f.Kind() != reflect.String {
return data, nil
}

raw := data.(string)
l := len(raw)
if l > 1 && raw[0] == '[' && raw[l-1] == ']' {
slice := strings.Split(raw[1:l-1], ",")
for i, v := range slice {
slice[i] = strings.TrimSpace(v)
}
return slice, nil
raw := data.(string)
l := len(raw)
if l > 1 && raw[0] == '[' && raw[l-1] == ']' {
slice := strings.Split(raw[1:l-1], ",")
for i, v := range slice {
slice[i] = strings.TrimSpace(v)
}

return data, nil
return slice, nil
}

return data, nil
}

func byteSizeDecodeHook() mapstructure.DecodeHookFunc {
return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
if f != reflect.String || t != reflect.Uint32 {
func byteSizeDecodeHook(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
if f != reflect.String || t != reflect.Uint32 {
return data, nil
}
raw := data.(string)
if raw == "" {
return data, nil
}
var re = regexp.MustCompile(`^(?P<size>[0-9]+)\s*(?i)(?P<unit>(k|m|g))b?$`)
if re.MatchString(raw) {
size, err := strconv.ParseUint(re.ReplaceAllString(raw, "${size}"), 0, 64)
if err != nil {
return data, nil
}
raw := data.(string)
if raw == "" {
return data, nil
unit := re.ReplaceAllString(raw, "${unit}")
switch strings.ToLower(unit) {
case "g":
size = size << 10
fallthrough
case "m":
size = size << 10
fallthrough
case "k":
size = size << 10
}
var re = regexp.MustCompile(`^(?P<size>[0-9]+)\s*(?i)(?P<unit>(k|m|g))b?$`)
if re.MatchString(raw) {
size, err := strconv.ParseUint(re.ReplaceAllString(raw, "${size}"), 0, 64)
if err != nil {
return data, nil
}
unit := re.ReplaceAllString(raw, "${unit}")
switch strings.ToLower(unit) {
case "g":
size = size << 10
fallthrough
case "m":
size = size << 10
fallthrough
case "k":
size = size << 10
}
if size > math.MaxUint32 {
return size, fmt.Errorf("value '%s' overflows uint32", raw)
}
return size, nil
if size > math.MaxUint32 {
return size, fmt.Errorf("value '%s' overflows uint32", raw)
}
return data, nil
return size, nil
}
return data, nil
}

func stringFromFileDecodeHook() mapstructure.DecodeHookFunc {
return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
// "to" type should be string
if t != reflect.String {
return data, nil
func stringFromFileDecodeHook(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
// "to" type should be string
if t != reflect.String {
return data, nil
}
// "from" type should be map
if f != reflect.Map {
return data, nil
}
v := reflect.ValueOf(data)
switch v.Kind() {
case reflect.String:
return data, nil
case reflect.Map:
d := data.(map[string]interface{})
fileName, ok := d["File"]
if !ok {
fileName, ok = d["file"]
}
// "from" type should be map
if f != reflect.Map {
return data, nil
switch {
case ok && fileName != nil:
bytes, err := ioutil.ReadFile(fileName.(string))
if err != nil {
return data, err
}
return string(bytes), nil
case ok:
// fileName was nil
return nil, fmt.Errorf("Value of File: was nil")
}
v := reflect.ValueOf(data)
switch v.Kind() {
case reflect.String:
return data, nil
case reflect.Map:
d := data.(map[string]interface{})
fileName, ok := d["File"]
}
return data, nil
}

func pemBlocksFromFileDecodeHook(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
// "to" type should be string
if t != reflect.Slice {
return data, nil
}
// "from" type should be map
if f != reflect.Map {
return data, nil
}
v := reflect.ValueOf(data)
switch v.Kind() {
case reflect.String:
return data, nil
case reflect.Map:
var fileName string
var ok bool
switch d := data.(type) {
case map[string]string:
fileName, ok = d["File"]
if !ok {
fileName, ok = d["file"]
}
switch {
case ok && fileName != nil:
bytes, err := ioutil.ReadFile(fileName.(string))
if err != nil {
return data, err
}
return string(bytes), nil
case ok:
// fileName was nil
return nil, fmt.Errorf("Value of File: was nil")
case map[string]interface{}:
var fileI interface{}
fileI, ok = d["File"]
if !ok {
fileI = d["file"]
}
fileName, ok = fileI.(string)
}
return data, nil
}
}

func pemBlocksFromFileDecodeHook() mapstructure.DecodeHookFunc {
return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
// "to" type should be string
if t != reflect.Slice {
return data, nil
}
// "from" type should be map
if f != reflect.Map {
return data, nil
}
v := reflect.ValueOf(data)
switch v.Kind() {
case reflect.String:
return data, nil
case reflect.Map:
var fileName string
var ok bool
switch d := data.(type) {
case map[string]string:
fileName, ok = d["File"]
if !ok {
fileName, ok = d["file"]
}
case map[string]interface{}:
var fileI interface{}
fileI, ok = d["File"]
if !ok {
fileI = d["file"]
}
fileName, ok = fileI.(string)
switch {
case ok && fileName != "":
var result []string
bytes, err := ioutil.ReadFile(fileName)
if err != nil {
return data, err
}

switch {
case ok && fileName != "":
var result []string
bytes, err := ioutil.ReadFile(fileName)
if err != nil {
return data, err
for len(bytes) > 0 {
var block *pem.Block
block, bytes = pem.Decode(bytes)
if block == nil {
break
}
for len(bytes) > 0 {
var block *pem.Block
block, bytes = pem.Decode(bytes)
if block == nil {
break
}
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
continue
}
result = append(result, string(pem.EncodeToMemory(block)))
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
continue
}
return result, nil
case ok:
// fileName was nil
return nil, fmt.Errorf("Value of File: was nil")
result = append(result, string(pem.EncodeToMemory(block)))
}
return result, nil
case ok:
// fileName was nil
return nil, fmt.Errorf("Value of File: was nil")
}
return data, nil
}
return data, nil
}

var kafkaVersionConstraints map[sarama.KafkaVersion]version.Constraints
Expand All @@ -285,25 +277,23 @@ func init() {
kafkaVersionConstraints[sarama.V1_0_0_0], _ = version.NewConstraint(">=1.0.0")
}

func kafkaVersionDecodeHook() mapstructure.DecodeHookFunc {
return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
if f.Kind() != reflect.String || t != reflect.TypeOf(sarama.KafkaVersion{}) {
return data, nil
}
func kafkaVersionDecodeHook(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
if f.Kind() != reflect.String || t != reflect.TypeOf(sarama.KafkaVersion{}) {
return data, nil
}

v, err := version.NewVersion(data.(string))
if err != nil {
return nil, fmt.Errorf("Unable to parse Kafka version: %s", err)
}
v, err := version.NewVersion(data.(string))
if err != nil {
return nil, fmt.Errorf("Unable to parse Kafka version: %s", err)
}

for kafkaVersion, constraints := range kafkaVersionConstraints {
if constraints.Check(v) {
return kafkaVersion, nil
}
for kafkaVersion, constraints := range kafkaVersionConstraints {
if constraints.Check(v) {
return kafkaVersion, nil
}

return nil, fmt.Errorf("Unsupported Kafka version: '%s'", data)
}

return nil, fmt.Errorf("Unsupported Kafka version: '%s'", data)
}

func bccspHook(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
Expand Down Expand Up @@ -347,11 +337,11 @@ func EnhancedExactUnmarshal(v *viper.Viper, output interface{}) error {
WeaklyTypedInput: true,
DecodeHook: mapstructure.ComposeDecodeHookFunc(
bccspHook,
customDecodeHook(),
byteSizeDecodeHook(),
stringFromFileDecodeHook(),
pemBlocksFromFileDecodeHook(),
kafkaVersionDecodeHook(),
customDecodeHook,
byteSizeDecodeHook,
stringFromFileDecodeHook,
pemBlocksFromFileDecodeHook,
kafkaVersionDecodeHook,
),
}

Expand Down

0 comments on commit 2fbd83d

Please sign in to comment.