Skip to content

Commit

Permalink
fix: parsing config input and convert to match the type
Browse files Browse the repository at this point in the history
  • Loading branch information
akoserwal committed May 30, 2024
1 parent fbac5fa commit d2baffa
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 15 deletions.
87 changes: 73 additions & 14 deletions config/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package config
import (
"fmt"
"regexp"
"strconv"
"strings"

"github.com/go-kratos/kratos/v2/encoding"
Expand Down Expand Up @@ -45,6 +46,14 @@ func WithDecoder(d Decoder) Option {
}
}

// WithNewResolver with config resolver.
// bool input will enable conversion of config to data types
func WithNewResolver(enableConvertToType bool) Option {
return func(o *options) {
o.resolver = newDefaultResolver(enableConvertToType)
}
}

// WithResolver with config resolver.
func WithResolver(r Resolver) Option {
return func(o *options) {
Expand Down Expand Up @@ -82,26 +91,27 @@ func defaultDecoder(src *KeyValue, target map[string]interface{}) error {
return fmt.Errorf("unsupported key: %s format: %s", src.Key, src.Format)
}

func newDefaultResolver(enableConvertToType bool) func(map[string]interface{}) error {
return func(input map[string]interface{}) error {
mapper := mapper(input)
return resolver(input, mapper, enableConvertToType)
}
}

// defaultResolver resolve placeholder in map value,
// placeholder format in ${key:default}.
func defaultResolver(input map[string]interface{}) error {
mapper := func(name string) string {
args := strings.SplitN(strings.TrimSpace(name), ":", 2) //nolint:gomnd
if v, has := readValue(input, args[0]); has {
s, _ := v.String()
return s
} else if len(args) > 1 { // default value
return args[1]
}
return ""
}
mapper := mapper(input)
return resolver(input, mapper, false)
}

func resolver(input map[string]interface{}, mapper func(name string) string, toType bool) error {
var resolve func(map[string]interface{}) error
resolve = func(sub map[string]interface{}) error {
for k, v := range sub {
switch vt := v.(type) {
case string:
sub[k] = expand(vt, mapper)
sub[k] = expand(vt, mapper, toType)
case map[string]interface{}:
if err := resolve(vt); err != nil {
return err
Expand All @@ -110,7 +120,7 @@ func defaultResolver(input map[string]interface{}) error {
for i, iface := range vt {
switch it := iface.(type) {
case string:
vt[i] = expand(it, mapper)
vt[i] = expand(it, mapper, toType)
case map[string]interface{}:
if err := resolve(it); err != nil {
return err
Expand All @@ -125,12 +135,61 @@ func defaultResolver(input map[string]interface{}) error {
return resolve(input)
}

func expand(s string, mapping func(string) string) string {
func mapper(input map[string]interface{}) func(name string) string {
mapper := func(name string) string {
args := strings.SplitN(strings.TrimSpace(name), ":", 2) //nolint:gomnd
if v, has := readValue(input, args[0]); has {
s, _ := v.String()
return s
} else if len(args) > 1 { // default value
return args[1]
}
return ""
}
return mapper
}

func convertToType(input string) interface{} {
// Check if the input is a string with quotes
if strings.HasPrefix(input, "\"") && strings.HasSuffix(input, "\"") {
// Trim the quotes and return the string value
return strings.Trim(input, "\"")
}

// Try converting to bool
if input == "true" || input == "false" {
b, _ := strconv.ParseBool(input)
return b
}

// Try converting to float64
if strings.Contains(input, ".") {
if f, err := strconv.ParseFloat(input, 64); err == nil {
return f
}
}

// Try converting to int64
if i, err := strconv.ParseInt(input, 10, 64); err == nil {
return i
}

// Default to string if no other conversion succeeds
return input
}

func expand(s string, mapping func(string) string, toType bool) interface{} {
r := regexp.MustCompile(`\${(.*?)}`)
re := r.FindAllStringSubmatch(s, -1)
var ct interface{}
for _, i := range re {
if len(i) == 2 { //nolint:gomnd
s = strings.ReplaceAll(s, i[0], mapping(i[1]))
m := mapping(i[1])
if toType {
ct = convertToType(m)
return ct
}
s = strings.ReplaceAll(s, i[0], m)
}
}
return s
Expand Down
160 changes: 159 additions & 1 deletion config/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,164 @@ func TestDefaultResolver(t *testing.T) {
}
}

func TestNewDefaultResolver(t *testing.T) {
var (
portString = "8080"
countInt = 10
rateFloat = 0.9
)

data := map[string]interface{}{
"foo": map[string]interface{}{
"bar": map[string]interface{}{
"notexist": "${NOTEXIST:100}",
"port": "${PORT:\"8081\"}",
"count": "${COUNT:\"0\"}",
"enable": "${ENABLE:false}",
"rate": "${RATE}",
"empty": "${EMPTY:foobar}",
"url": "${URL:\"http://example.com\"}",
"array": []interface{}{
"${PORT}",
map[string]interface{}{"foobar": "${NOTEXIST:\"8081\"}"},
},
"value1": "${test.value}",
"value2": "$PORT",
"value3": "abc${PORT}foo${COUNT}bar",
"value4": "${foo${bar}}",
},
},
"test": map[string]interface{}{
"value": "foobar",
},
"PORT": "\"8080\"",
"COUNT": "\"10\"",
"ENABLE": "true",
"RATE": "0.9",
"EMPTY": "",
}

tests := []struct {
name string
path string
expect interface{}
}{
{
name: "test not exist int env with default",
path: "foo.bar.notexist",
expect: 100,
},
{
name: "test string with default",
path: "foo.bar.port",
expect: portString,
},
{
name: "test int with default",
path: "foo.bar.count",
expect: countInt,
},
{
name: "test bool with default",
path: "foo.bar.enable",
expect: true,
},
{
name: "test float without default",
path: "foo.bar.rate",
expect: rateFloat,
},
{
name: "test empty value with default",
path: "foo.bar.empty",
expect: "",
},
{
name: "test url with default",
path: "foo.bar.url",
expect: "http://example.com",
},
{
name: "test array",
path: "foo.bar.array",
expect: []interface{}{portString, map[string]interface{}{"foobar": "8081"}},
},
{
name: "test ${test.value}",
path: "foo.bar.value1",
expect: "foobar",
},
{
name: "test $PORT",
path: "foo.bar.value2",
expect: "$PORT",
},
//{
// name: "test abc${PORT}foo${COUNT}bar",
// path: "foo.bar.value3",
// expect: "abc8080foo10bar",
//},
{
name: "test ${foo${bar}}",
path: "foo.bar.value4",
expect: "",
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
fn := newDefaultResolver(true)
err := fn(data)
if err != nil {
t.Fatal(err)
}
rd := reader{
values: data,
}
if v, ok := rd.Value(test.path); ok {
var actual interface{}
switch test.expect.(type) {
case int:
if actual, err = v.Int(); err == nil {
if !reflect.DeepEqual(test.expect.(int), int(actual.(int64))) {
t.Fatal("expect is not equal to actual")
}
}
case string:
if actual, err = v.String(); err == nil {
if !reflect.DeepEqual(test.expect, actual) {
t.Fatal("expect is not equal to actual")
}
}
case bool:
if actual, err = v.Bool(); err == nil {
if !reflect.DeepEqual(test.expect, actual) {
t.Fatal("expect is not equal to actual")
}
}
case float64:
if actual, err = v.Float(); err == nil {
if !reflect.DeepEqual(test.expect, actual) {
t.Fatal("expect is not equal to actual")
}
}
default:
actual = v.Load()
if !reflect.DeepEqual(test.expect, actual) {
t.Logf("expect: %#v, actural: %#v", test.expect, actual)
t.Fail()
}
}
if err != nil {
t.Error(err)
}
} else {
t.Error("value path not found")
}
})
}
}

func TestExpand(t *testing.T) {
tests := []struct {
input string
Expand All @@ -221,7 +379,7 @@ func TestExpand(t *testing.T) {
},
}
for _, tt := range tests {
if got := expand(tt.input, tt.mapping); got != tt.want {
if got := expand(tt.input, tt.mapping, false); got != tt.want {
t.Errorf("expand() want: %s, got: %s", tt.want, got)
}
}
Expand Down

0 comments on commit d2baffa

Please sign in to comment.