Skip to content

Commit

Permalink
added support for aliased types. (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
wesdean committed Jan 6, 2022
1 parent 270ada0 commit a7abbb6
Show file tree
Hide file tree
Showing 2 changed files with 272 additions and 13 deletions.
187 changes: 174 additions & 13 deletions flagset.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{}

tagDefault, hasDefaultTag := tag.Lookup("default")

fieldType, _ := tag.Lookup("type")

var renamed string
if override, exists := tag.Lookup("flag"); exists {
if override == "" {
Expand All @@ -142,7 +144,7 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{}
err = f.processFloat64(fieldRef, hasDefaultTag, tagDefault, flagSet, renamed, usage)

// NOTE check time.Duration before int64 since it is aliased from int64
case t == durationType:
case t == durationType, fieldType == "duration":
err = f.processDuration(fieldRef, hasDefaultTag, tagDefault, flagSet, renamed, usage)

case t.Kind() == reflect.Int64:
Expand All @@ -157,7 +159,7 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{}
case t.Kind() == reflect.Uint:
err = f.processUint(fieldRef, hasDefaultTag, tagDefault, flagSet, renamed, usage)

case t == stringSliceType:
case t == stringSliceType, fieldType == "stringSlice":
var override bool
if overrideValue, exists := tag.Lookup("override-value"); exists {
if value, err := strconv.ParseBool(overrideValue); err == nil {
Expand All @@ -166,7 +168,7 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{}
}
f.processStringSlice(fieldRef, hasDefaultTag, tagDefault, flagSet, renamed, usage, override)

case t == stringToStringMapType:
case t == stringToStringMapType, fieldType == "stringMap":
f.processStringToStringMap(fieldRef, hasDefaultTag, tagDefault, flagSet, renamed, usage)

// ignore any other types
Expand All @@ -190,7 +192,21 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{}
}

func (f *FlagSetFiller) processStringToStringMap(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string) {
casted := fieldRef.(*map[string]string)
casted, ok := fieldRef.(*map[string]string)
if !ok {
_ = f.processCustom(
fieldRef,
func(s string) (interface{}, error) {
return parseStringToStringMap(s), nil
},
hasDefaultTag,
tagDefault,
flagSet,
renamed,
usage,
)
return
}
var val map[string]string
if hasDefaultTag {
val = parseStringToStringMap(tagDefault)
Expand All @@ -205,15 +221,43 @@ func (f *FlagSetFiller) processStringToStringMap(fieldRef interface{}, hasDefaul
}

func (f *FlagSetFiller) processStringSlice(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, override bool) {
casted := fieldRef.(*[]string)
casted, ok := fieldRef.(*[]string)
if !ok {
_ = f.processCustom(
fieldRef,
func(s string) (interface{}, error) {
return parseStringSlice(s), nil
},
hasDefaultTag,
tagDefault,
flagSet,
renamed,
usage,
)
return
}
if hasDefaultTag {
*casted = parseStringSlice(tagDefault)
}
flagSet.Var(&strSliceVar{ref: casted, override: override}, renamed, usage)
}

func (f *FlagSetFiller) processUint(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string) (err error) {
casted := fieldRef.(*uint)
casted, ok := fieldRef.(*uint)
if !ok {
return f.processCustom(
fieldRef,
func(s string) (interface{}, error) {
value, err := strconv.Atoi(s)
return value, err
},
hasDefaultTag,
tagDefault,
flagSet,
renamed,
usage,
)
}
var defaultVal uint
if hasDefaultTag {
var asInt int
Expand All @@ -230,7 +274,21 @@ func (f *FlagSetFiller) processUint(fieldRef interface{}, hasDefaultTag bool, ta
}

func (f *FlagSetFiller) processUint64(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string) (err error) {
casted := fieldRef.(*uint64)
casted, ok := fieldRef.(*uint64)
if !ok {
return f.processCustom(
fieldRef,
func(s string) (interface{}, error) {
value, err := strconv.ParseUint(s, 10, 64)
return value, err
},
hasDefaultTag,
tagDefault,
flagSet,
renamed,
usage,
)
}
var defaultVal uint64
if hasDefaultTag {
defaultVal, err = strconv.ParseUint(tagDefault, 10, 64)
Expand All @@ -245,7 +303,21 @@ func (f *FlagSetFiller) processUint64(fieldRef interface{}, hasDefaultTag bool,
}

func (f *FlagSetFiller) processInt(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string) (err error) {
casted := fieldRef.(*int)
casted, ok := fieldRef.(*int)
if !ok {
return f.processCustom(
fieldRef,
func(s string) (interface{}, error) {
value, err := strconv.Atoi(s)
return value, err
},
hasDefaultTag,
tagDefault,
flagSet,
renamed,
usage,
)
}
var defaultVal int
if hasDefaultTag {
defaultVal, err = strconv.Atoi(tagDefault)
Expand All @@ -260,7 +332,21 @@ func (f *FlagSetFiller) processInt(fieldRef interface{}, hasDefaultTag bool, tag
}

func (f *FlagSetFiller) processInt64(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string) (err error) {
casted := fieldRef.(*int64)
casted, ok := fieldRef.(*int64)
if !ok {
return f.processCustom(
fieldRef,
func(s string) (interface{}, error) {
value, err := strconv.ParseInt(s, 10, 64)
return value, err
},
hasDefaultTag,
tagDefault,
flagSet,
renamed,
usage,
)
}
var defaultVal int64
if hasDefaultTag {
defaultVal, err = strconv.ParseInt(tagDefault, 10, 64)
Expand All @@ -275,7 +361,21 @@ func (f *FlagSetFiller) processInt64(fieldRef interface{}, hasDefaultTag bool, t
}

func (f *FlagSetFiller) processDuration(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string) (err error) {
casted := fieldRef.(*time.Duration)
casted, ok := fieldRef.(*time.Duration)
if !ok {
return f.processCustom(
fieldRef,
func(s string) (interface{}, error) {
value, err := time.ParseDuration(s)
return value, err
},
hasDefaultTag,
tagDefault,
flagSet,
renamed,
usage,
)
}
var defaultVal time.Duration
if hasDefaultTag {
defaultVal, err = time.ParseDuration(tagDefault)
Expand All @@ -290,7 +390,21 @@ func (f *FlagSetFiller) processDuration(fieldRef interface{}, hasDefaultTag bool
}

func (f *FlagSetFiller) processFloat64(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string) (err error) {
casted := fieldRef.(*float64)
casted, ok := fieldRef.(*float64)
if !ok {
return f.processCustom(
fieldRef,
func(s string) (interface{}, error) {
value, err := strconv.ParseFloat(s, 64)
return value, err
},
hasDefaultTag,
tagDefault,
flagSet,
renamed,
usage,
)
}
var defaultVal float64
if hasDefaultTag {
defaultVal, err = strconv.ParseFloat(tagDefault, 64)
Expand All @@ -305,7 +419,21 @@ func (f *FlagSetFiller) processFloat64(fieldRef interface{}, hasDefaultTag bool,
}

func (f *FlagSetFiller) processBool(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string) (err error) {
casted := fieldRef.(*bool)
casted, ok := fieldRef.(*bool)
if !ok {
return f.processCustom(
fieldRef,
func(s string) (interface{}, error) {
value, err := strconv.ParseBool(s)
return value, err
},
hasDefaultTag,
tagDefault,
flagSet,
renamed,
usage,
)
}
var defaultVal bool
if hasDefaultTag {
defaultVal, err = strconv.ParseBool(tagDefault)
Expand All @@ -320,7 +448,21 @@ func (f *FlagSetFiller) processBool(fieldRef interface{}, hasDefaultTag bool, ta
}

func (f *FlagSetFiller) processString(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string) {
casted := fieldRef.(*string)
casted, ok := fieldRef.(*string)
if !ok {
_ = f.processCustom(
fieldRef,
func(s string) (interface{}, error) {
return s, nil
},
hasDefaultTag,
tagDefault,
flagSet,
renamed,
usage,
)
return
}
var defaultVal string
if hasDefaultTag {
defaultVal = tagDefault
Expand All @@ -330,6 +472,25 @@ func (f *FlagSetFiller) processString(fieldRef interface{}, hasDefaultTag bool,
flagSet.StringVar(casted, renamed, defaultVal, usage)
}

func (f *FlagSetFiller) processCustom(fieldRef interface{}, converter func(string) (interface{}, error), hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string) error {
if hasDefaultTag {
value, err := converter(tagDefault)
if err != nil {
return fmt.Errorf("failed to parse default into custom type: %w", err)
}
reflect.ValueOf(fieldRef).Elem().Set(reflect.ValueOf(value).Convert(reflect.TypeOf(fieldRef).Elem()))
}
flagSet.Func(renamed, usage, func(s string) error {
value, err := converter(s)
if err != nil {
return err
}
reflect.ValueOf(fieldRef).Elem().Set(reflect.ValueOf(value).Convert(reflect.TypeOf(fieldRef).Elem()))
return nil
})
return nil
}

type strSliceVar struct {
ref *[]string
override bool
Expand Down
98 changes: 98 additions & 0 deletions flagset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,104 @@ func TestStringFields(t *testing.T) {
assert.Equal(t, "val1", config.MultiWordName)
}

func TestCustomFields(t *testing.T) {
type CustomStringType string
type CustomBoolType bool
type CustomFloat64 float64
type CustomDuration time.Duration
type CustomInt64 int64
type CustomInt int
type CustomUint64 uint64
type CustomUint uint
type CustomStringSlice []string
type CustomStringMap map[string]string

t.Run("Default values", func(t *testing.T) {
type Config struct {
String CustomStringType `default:"stringValue"`
Bool CustomBoolType `default:"true"`
Float64 CustomFloat64 `default:"1.234"`
Duration CustomDuration `type:"duration" default:"2s"`
Int64 CustomInt64 `default:"-1"`
Int CustomInt `default:"-2"`
Uint64 CustomUint64 `default:"1"`
Uint CustomUint `default:"2"`
StringSlice CustomStringSlice `type:"stringSlice" default:"one,two"`
StringMap CustomStringMap `type:"stringMap" default:"one=value1,two=value2"`
}

var config Config

filler := flagsfiller.New()

var flagset flag.FlagSet
err := filler.Fill(&flagset, &config)
require.NoError(t, err)

err = flagset.Parse([]string{})
require.NoError(t, err)

assert.Equal(t, CustomStringType("stringValue"), config.String)
assert.Equal(t, CustomBoolType(true), config.Bool)
assert.Equal(t, CustomFloat64(1.234), config.Float64)
assert.Equal(t, CustomDuration(2*time.Second), config.Duration)
assert.Equal(t, CustomInt64(-1), config.Int64)
assert.Equal(t, CustomInt(-2), config.Int)
assert.Equal(t, CustomUint64(1), config.Uint64)
assert.Equal(t, CustomUint(2), config.Uint)
assert.Equal(t, CustomStringSlice{"one", "two"}, config.StringSlice)
assert.Equal(t, CustomStringMap{"one": "value1", "two": "value2"}, config.StringMap)
})

t.Run("Values set from arguments", func(t *testing.T) {
type Config struct {
String CustomStringType
Bool CustomBoolType
Float64 CustomFloat64
Duration CustomDuration `type:"duration"`
Int64 CustomInt64
Int CustomInt
Uint64 CustomUint64
Uint CustomUint
StringSlice CustomStringSlice `type:"stringSlice"`
StringMap CustomStringMap `type:"stringMap"`
}

var config Config

filler := flagsfiller.New()

var flagset flag.FlagSet
err := filler.Fill(&flagset, &config)
require.NoError(t, err)

err = flagset.Parse([]string{
"--string", "stringValue",
"--bool", "true",
"--float-64", "1.234",
"--duration", "2s",
"--int-64", "-1",
"--int", "-2",
"--uint-64", "1",
"--uint", "2",
"--string-slice", "one,two",
"--string-map", "one=value1,two=value2",
})
require.NoError(t, err)

assert.Equal(t, CustomStringType("stringValue"), config.String)
assert.Equal(t, CustomBoolType(true), config.Bool)
assert.Equal(t, CustomFloat64(1.234), config.Float64)
assert.Equal(t, CustomDuration(2*time.Second), config.Duration)
assert.Equal(t, CustomInt64(-1), config.Int64)
assert.Equal(t, CustomInt(-2), config.Int)
assert.Equal(t, CustomUint64(1), config.Uint64)
assert.Equal(t, CustomUint(2), config.Uint)
assert.Equal(t, CustomStringSlice{"one", "two"}, config.StringSlice)
assert.Equal(t, CustomStringMap{"one": "value1", "two": "value2"}, config.StringMap)
})
}

func TestUsage(t *testing.T) {
type Config struct {
MultiWordName string `usage:"usage goes here"`
Expand Down

0 comments on commit a7abbb6

Please sign in to comment.