diff --git a/common/viperutil/config_util.go b/common/viperutil/config_util.go index 8aab468fb81..6c6c02109d0 100644 --- a/common/viperutil/config_util.go +++ b/common/viperutil/config_util.go @@ -109,163 +109,155 @@ func unmarshalJSON(val interface{}) (map[string]string, bool) { // customDecodeHook adds the additional functions of parsing durations from strings // as well as parsing strings of the format "[thing1, thing2, thing3]" into string slices // Note that whitespace around slice elements is removed -func customDecodeHook() mapstructure.DecodeHookFunc { +func customDecodeHook(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { durationHook := mapstructure.StringToTimeDurationHookFunc() - return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { - dur, err := mapstructure.DecodeHookExec(durationHook, f, t, data) - if err == nil { - if _, ok := dur.(time.Duration); ok { - return dur, nil - } + dur, err := mapstructure.DecodeHookExec(durationHook, f, t, data) + if err == nil { + if _, ok := dur.(time.Duration); ok { + return dur, nil } + } - if f.Kind() != reflect.String { - return data, nil - } + if f.Kind() != reflect.String { + return data, nil + } - raw := data.(string) - l := len(raw) - if l > 1 && raw[0] == '[' && raw[l-1] == ']' { - slice := strings.Split(raw[1:l-1], ",") - for i, v := range slice { - slice[i] = strings.TrimSpace(v) - } - return slice, nil + raw := data.(string) + l := len(raw) + if l > 1 && raw[0] == '[' && raw[l-1] == ']' { + slice := strings.Split(raw[1:l-1], ",") + for i, v := range slice { + slice[i] = strings.TrimSpace(v) } - - return data, nil + return slice, nil } + + return data, nil } -func byteSizeDecodeHook() mapstructure.DecodeHookFunc { - return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) { - if f != reflect.String || t != reflect.Uint32 { +func byteSizeDecodeHook(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) { + if f != reflect.String || t != reflect.Uint32 { + return data, nil + } + raw := data.(string) + if raw == "" { + return data, nil + } + var re = regexp.MustCompile(`^(?P[0-9]+)\s*(?i)(?P(k|m|g))b?$`) + if re.MatchString(raw) { + size, err := strconv.ParseUint(re.ReplaceAllString(raw, "${size}"), 0, 64) + if err != nil { return data, nil } - raw := data.(string) - if raw == "" { - return data, nil + unit := re.ReplaceAllString(raw, "${unit}") + switch strings.ToLower(unit) { + case "g": + size = size << 10 + fallthrough + case "m": + size = size << 10 + fallthrough + case "k": + size = size << 10 } - var re = regexp.MustCompile(`^(?P[0-9]+)\s*(?i)(?P(k|m|g))b?$`) - if re.MatchString(raw) { - size, err := strconv.ParseUint(re.ReplaceAllString(raw, "${size}"), 0, 64) - if err != nil { - return data, nil - } - unit := re.ReplaceAllString(raw, "${unit}") - switch strings.ToLower(unit) { - case "g": - size = size << 10 - fallthrough - case "m": - size = size << 10 - fallthrough - case "k": - size = size << 10 - } - if size > math.MaxUint32 { - return size, fmt.Errorf("value '%s' overflows uint32", raw) - } - return size, nil + if size > math.MaxUint32 { + return size, fmt.Errorf("value '%s' overflows uint32", raw) } - return data, nil + return size, nil } + return data, nil } -func stringFromFileDecodeHook() mapstructure.DecodeHookFunc { - return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) { - // "to" type should be string - if t != reflect.String { - return data, nil +func stringFromFileDecodeHook(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) { + // "to" type should be string + if t != reflect.String { + return data, nil + } + // "from" type should be map + if f != reflect.Map { + return data, nil + } + v := reflect.ValueOf(data) + switch v.Kind() { + case reflect.String: + return data, nil + case reflect.Map: + d := data.(map[string]interface{}) + fileName, ok := d["File"] + if !ok { + fileName, ok = d["file"] } - // "from" type should be map - if f != reflect.Map { - return data, nil + switch { + case ok && fileName != nil: + bytes, err := ioutil.ReadFile(fileName.(string)) + if err != nil { + return data, err + } + return string(bytes), nil + case ok: + // fileName was nil + return nil, fmt.Errorf("Value of File: was nil") } - v := reflect.ValueOf(data) - switch v.Kind() { - case reflect.String: - return data, nil - case reflect.Map: - d := data.(map[string]interface{}) - fileName, ok := d["File"] + } + return data, nil +} + +func pemBlocksFromFileDecodeHook(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) { + // "to" type should be string + if t != reflect.Slice { + return data, nil + } + // "from" type should be map + if f != reflect.Map { + return data, nil + } + v := reflect.ValueOf(data) + switch v.Kind() { + case reflect.String: + return data, nil + case reflect.Map: + var fileName string + var ok bool + switch d := data.(type) { + case map[string]string: + fileName, ok = d["File"] if !ok { fileName, ok = d["file"] } - switch { - case ok && fileName != nil: - bytes, err := ioutil.ReadFile(fileName.(string)) - if err != nil { - return data, err - } - return string(bytes), nil - case ok: - // fileName was nil - return nil, fmt.Errorf("Value of File: was nil") + case map[string]interface{}: + var fileI interface{} + fileI, ok = d["File"] + if !ok { + fileI = d["file"] } + fileName, ok = fileI.(string) } - return data, nil - } -} -func pemBlocksFromFileDecodeHook() mapstructure.DecodeHookFunc { - return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) { - // "to" type should be string - if t != reflect.Slice { - return data, nil - } - // "from" type should be map - if f != reflect.Map { - return data, nil - } - v := reflect.ValueOf(data) - switch v.Kind() { - case reflect.String: - return data, nil - case reflect.Map: - var fileName string - var ok bool - switch d := data.(type) { - case map[string]string: - fileName, ok = d["File"] - if !ok { - fileName, ok = d["file"] - } - case map[string]interface{}: - var fileI interface{} - fileI, ok = d["File"] - if !ok { - fileI = d["file"] - } - fileName, ok = fileI.(string) + switch { + case ok && fileName != "": + var result []string + bytes, err := ioutil.ReadFile(fileName) + if err != nil { + return data, err } - - switch { - case ok && fileName != "": - var result []string - bytes, err := ioutil.ReadFile(fileName) - if err != nil { - return data, err + for len(bytes) > 0 { + var block *pem.Block + block, bytes = pem.Decode(bytes) + if block == nil { + break } - for len(bytes) > 0 { - var block *pem.Block - block, bytes = pem.Decode(bytes) - if block == nil { - break - } - if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { - continue - } - result = append(result, string(pem.EncodeToMemory(block))) + if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { + continue } - return result, nil - case ok: - // fileName was nil - return nil, fmt.Errorf("Value of File: was nil") + result = append(result, string(pem.EncodeToMemory(block))) } + return result, nil + case ok: + // fileName was nil + return nil, fmt.Errorf("Value of File: was nil") } - return data, nil } + return data, nil } var kafkaVersionConstraints map[sarama.KafkaVersion]version.Constraints @@ -285,25 +277,23 @@ func init() { kafkaVersionConstraints[sarama.V1_0_0_0], _ = version.NewConstraint(">=1.0.0") } -func kafkaVersionDecodeHook() mapstructure.DecodeHookFunc { - return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { - if f.Kind() != reflect.String || t != reflect.TypeOf(sarama.KafkaVersion{}) { - return data, nil - } +func kafkaVersionDecodeHook(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { + if f.Kind() != reflect.String || t != reflect.TypeOf(sarama.KafkaVersion{}) { + return data, nil + } - v, err := version.NewVersion(data.(string)) - if err != nil { - return nil, fmt.Errorf("Unable to parse Kafka version: %s", err) - } + v, err := version.NewVersion(data.(string)) + if err != nil { + return nil, fmt.Errorf("Unable to parse Kafka version: %s", err) + } - for kafkaVersion, constraints := range kafkaVersionConstraints { - if constraints.Check(v) { - return kafkaVersion, nil - } + for kafkaVersion, constraints := range kafkaVersionConstraints { + if constraints.Check(v) { + return kafkaVersion, nil } - - return nil, fmt.Errorf("Unsupported Kafka version: '%s'", data) } + + return nil, fmt.Errorf("Unsupported Kafka version: '%s'", data) } func bccspHook(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) { @@ -347,11 +337,11 @@ func EnhancedExactUnmarshal(v *viper.Viper, output interface{}) error { WeaklyTypedInput: true, DecodeHook: mapstructure.ComposeDecodeHookFunc( bccspHook, - customDecodeHook(), - byteSizeDecodeHook(), - stringFromFileDecodeHook(), - pemBlocksFromFileDecodeHook(), - kafkaVersionDecodeHook(), + customDecodeHook, + byteSizeDecodeHook, + stringFromFileDecodeHook, + pemBlocksFromFileDecodeHook, + kafkaVersionDecodeHook, ), }