From dad3575408197fce7bfa2a0387030e01cc885833 Mon Sep 17 00:00:00 2001 From: Valient Gough Date: Mon, 13 Oct 2025 11:36:20 -0700 Subject: [PATCH 1/2] add option for required fields --- README.md | 31 ++++ example_test.go | 40 ++++- flagset.go | 126 +++++++++++----- flagset_test.go | 386 +++++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 539 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 2711f87..234acb5 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ import "github.com/itzg/go-flagsfiller" - Allows defaults to be given via struct tag `default` - Falls back to using instance field values as declared default - Declare flag usage via struct tag `usage` +- Mark flags as required via struct tag `required` and validate with `Verify()` method (cannot be combined with `default` tag) - Can be combined with other modules, such as [google/subcommands](https://github.com/google/subcommands) for sub-command processing. Can also be integrated with [spf13/cobra](https://github.com/spf13/cobra) by using pflag's [AddGoFlagSet](https://godoc.org/github.com/spf13/pflag#FlagSet.AddGoFlagSet) - Beyond the standard types supported by flag.FlagSet also includes support for: - `[]string` where repetition of the argument appends to the slice and/or an argument value can contain a comma or newline-separated list of values. For example: `--arg one --arg two,three` @@ -93,6 +94,36 @@ The following shows an example of the usage provided when passing `--help`: How long to wait (default 5s) ``` +## Required flags + +Flags can be marked as required using the `required:"true"` struct tag. After parsing command-line arguments, call the `Verify()` method to ensure all required flags have been provided: + +```go +type Config struct { + Host string `required:"true" usage:"The remote host"` + Port int `default:"8080" usage:"The port"` + Username string `required:"true" usage:"Username for authentication"` +} + +var config Config + +filler := flagsfiller.New() +err := filler.Fill(flag.CommandLine, &config) +if err != nil { + log.Fatal(err) +} + +flag.Parse() + +// Verify all required fields are set +err = filler.Verify() +if err != nil { + log.Fatal(err) // Will fail if Host or Username not provided +} +``` + +**Note:** A field cannot be both required and have a default value. Attempting to use both tags will result in an error during `Fill()`. + ## Real world example [saml-auth-proxy](https://github.com/itzg/saml-auth-proxy) shows an end-to-end usage of flagsfiller where the main function fills the flags, maps those to environment variables with [envy](https://github.com/jamiealquiza/envy), and parses the command line: diff --git a/example_test.go b/example_test.go index 28cdd4a..2b1fed1 100644 --- a/example_test.go +++ b/example_test.go @@ -3,9 +3,10 @@ package flagsfiller_test import ( "flag" "fmt" - "github.com/itzg/go-flagsfiller" "log" "time" + + "github.com/itzg/go-flagsfiller" ) func Example() { @@ -13,7 +14,7 @@ func Example() { Host string `default:"localhost" usage:"The remote host"` Enabled bool `default:"true" usage:"Turn it on"` Automatic bool `default:"false" usage:"Make it automatic" aliases:"a"` - Retries int `default:"1" usage:"Retry" aliases:"r,t"` + Retries int `default:"1" usage:"Retry" aliases:"r,t"` Timeout time.Duration `default:"5s" usage:"How long to wait"` } @@ -36,3 +37,38 @@ func Example() { // Output: // {Host:external.svc Enabled:true Automatic:true Retries:2 Timeout:10m0s} } + +func ExampleFlagSetFiller_Verify() { + type Config struct { + Host string `required:"true" usage:"The remote host"` + Port int `default:"8080" usage:"The port"` + Username string `required:"true" usage:"Username for authentication"` + } + + var config Config + + flagset := flag.NewFlagSet("ExampleVerify", flag.ContinueOnError) + + filler := flagsfiller.New() + err := filler.Fill(flagset, &config) + if err != nil { + log.Fatal(err) + } + + // Parse with required fields provided + err = flagset.Parse([]string{"--host", "example.com", "--username", "admin"}) + if err != nil { + log.Fatal(err) + } + + // Verify all required fields are set + err = filler.Verify() + if err != nil { + fmt.Printf("Error: %v\n", err) + return + } + + fmt.Printf("Config validated: Host=%s, Port=%d, Username=%s\n", config.Host, config.Port, config.Username) + // Output: + // Config validated: Host=example.com, Port=8080, Username=admin +} diff --git a/flagset.go b/flagset.go index 5d4723b..0831804 100644 --- a/flagset.go +++ b/flagset.go @@ -24,19 +24,21 @@ const ( TagFlag = "flag" TagFlatten = "flatten" TagOverrideValue = "override-value" + TagRequired = "required" TagType = "type" ) // FlagSetFiller is used to map the fields of a struct into flags of a flag.FlagSet type FlagSetFiller struct { - options *fillerOptions + options *fillerOptions + requiredFields map[string]any // tracks required field references by flag name } // Parse is a convenience function that creates a FlagSetFiller with the given options, // fills and maps the flags from the given struct reference into flag.CommandLine, and uses // flag.Parse to parse the os.Args. // Returns an error if the given struct could not be used for filling flags. -func Parse(from interface{}, options ...FillerOption) error { +func Parse(from any, options ...FillerOption) error { filler := New(options...) err := filler.Fill(flag.CommandLine, from) if err != nil { @@ -49,23 +51,61 @@ func Parse(from interface{}, options ...FillerOption) error { // New creates a new FlagSetFiller with zero or more of the given FillerOption's func New(options ...FillerOption) *FlagSetFiller { - return &FlagSetFiller{options: newFillerOptions(options...)} + return &FlagSetFiller{ + options: newFillerOptions(options...), + requiredFields: make(map[string]any), + } } // Fill populates the flagSet with a flag for each field in given struct passed in the 'from' // argument which must be a struct reference. // Fill returns an error when a non-struct reference is passed as 'from' or a field has a // default tag which could not converted to the field's type. -func (f *FlagSetFiller) Fill(flagSet *flag.FlagSet, from interface{}) error { +func (f *FlagSetFiller) Fill(flagSet *flag.FlagSet, from any) error { v := reflect.ValueOf(from) t := v.Type() - if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct { + if t.Kind() == reflect.Pointer && t.Elem().Kind() == reflect.Struct { return f.walkFields(flagSet, "", v.Elem(), t.Elem()) } else { return fmt.Errorf("can only fill from struct pointer, but it was %s", t.Kind()) } } +// Verify checks that all required fields have been set on the struct. +// Returns an error listing any required fields that were not provided. +func (f *FlagSetFiller) Verify() error { + var missingFlags []string + + for flagName, fieldRef := range f.requiredFields { + v := reflect.ValueOf(fieldRef) + if v.Kind() == reflect.Pointer { + v = v.Elem() + } + + // Check if the field is unset + isUnset := false + if v.IsZero() { + isUnset = true + } else if v.Kind() == reflect.Slice && v.Len() == 0 { + // Empty slices are considered unset + isUnset = true + } else if v.Kind() == reflect.Map && v.Len() == 0 { + // Empty maps are considered unset + isUnset = true + } + + if isUnset { + missingFlags = append(missingFlags, flagName) + } + } + + if len(missingFlags) > 0 { + return fmt.Errorf("required flags not set: %s", strings.Join(missingFlags, ", ")) + } + + return nil +} + // isSupportedStruct checks if the given field reference is a registered extended type or implements // encoding.TextUnmarshaler func isSupportedStruct(in any) bool { @@ -93,8 +133,8 @@ func getTypeName(t reflect.Type) string { } func (f *FlagSetFiller) walkFields(flagSet *flag.FlagSet, prefix string, - structVal reflect.Value, structType reflect.Type) error { - + structVal reflect.Value, structType reflect.Type, +) error { if prefix != "" { prefix += "-" } @@ -102,7 +142,7 @@ func (f *FlagSetFiller) walkFields(flagSet *flag.FlagSet, prefix string, addr := fieldValue.Addr() // make sure it is exported/public ftype := field.Type - if field.Type.Kind() == reflect.Ptr { + if field.Type.Kind() == reflect.Pointer { ftype = field.Type.Elem() } if addr.CanInterface() { @@ -143,7 +183,7 @@ func (f *FlagSetFiller) walkFields(flagSet *flag.FlagSet, prefix string, return fmt.Errorf("failed to process %s of %s: %w", field.Name, structType.String(), err) } - case reflect.Ptr: + case reflect.Pointer: if fieldValue.CanSet() && field.Type.Elem().Kind() == reflect.Struct { // fieldTypeName := getTypeName(field.Type.Elem()) // fill the pointer with a new struct of their type if it is nil @@ -196,9 +236,9 @@ func shouldFlatten(field reflect.StructField) bool { return value == "true" } -func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{}, - name string, t reflect.Type, tag reflect.StructTag) (err error) { - +func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef any, + name string, t reflect.Type, tag reflect.StructTag, +) (err error) { var envName string if override, exists := tag.Lookup(TagEnv); exists { envName = override @@ -219,6 +259,12 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{} fieldType, _ := tag.Lookup(TagType) + // Check for required tag and validate it doesn't conflict with default + _, hasRequiredTag := tag.Lookup(TagRequired) + if hasRequiredTag && hasDefaultTag { + return fmt.Errorf("field cannot be both required and have a default value") + } + var renamed string if override, exists := tag.Lookup(TagFlag); exists { if override == "" { @@ -233,7 +279,14 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{} if isSupportedStruct(fieldRef) { handler := extendedTypes[getTypeName(t)] err = handler(tag, fieldRef, hasDefaultTag, tagDefault, flagSet, renamed, usage, aliases) - return err + if err != nil { + return err + } + // Record required fields for extended types + if hasRequiredTag { + f.requiredFields[renamed] = fieldRef + } + return nil } switch { @@ -281,6 +334,11 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{} return err } + // Record required fields + if hasRequiredTag { + f.requiredFields[renamed] = fieldRef + } + if !f.options.noSetFromEnv && envName != "" { if val, exists := os.LookupEnv(envName); exists { err := flagSet.Lookup(renamed).Value.Set(val) @@ -294,12 +352,12 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{} return nil } -func (f *FlagSetFiller) processStringToStringMap(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) { +func (f *FlagSetFiller) processStringToStringMap(fieldRef any, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) { casted, ok := fieldRef.(*map[string]string) if !ok { _ = f.processCustom( fieldRef, - func(s string) (interface{}, error) { + func(s string) (any, error) { return parseStringToStringMap(s), nil }, hasDefaultTag, @@ -329,12 +387,12 @@ func (f *FlagSetFiller) processStringToStringMap(fieldRef interface{}, hasDefaul } } -func (f *FlagSetFiller) processStringSlice(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, override bool, aliases string) { +func (f *FlagSetFiller) processStringSlice(fieldRef any, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, override bool, aliases string) { casted, ok := fieldRef.(*[]string) if !ok { _ = f.processCustom( fieldRef, - func(s string) (interface{}, error) { + func(s string) (any, error) { return parseStringSlice(s, f.options.valueSplitPattern), nil }, hasDefaultTag, @@ -365,12 +423,12 @@ func (f *FlagSetFiller) processStringSlice(fieldRef interface{}, hasDefaultTag b } } -func (f *FlagSetFiller) processUint(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { +func (f *FlagSetFiller) processUint(fieldRef any, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { casted, ok := fieldRef.(*uint) if !ok { return f.processCustom( fieldRef, - func(s string) (interface{}, error) { + func(s string) (any, error) { value, err := strconv.Atoi(s) return value, err }, @@ -402,12 +460,12 @@ func (f *FlagSetFiller) processUint(fieldRef interface{}, hasDefaultTag bool, ta return err } -func (f *FlagSetFiller) processUint64(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { +func (f *FlagSetFiller) processUint64(fieldRef any, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { casted, ok := fieldRef.(*uint64) if !ok { return f.processCustom( fieldRef, - func(s string) (interface{}, error) { + func(s string) (any, error) { value, err := strconv.ParseUint(s, 10, 64) return value, err }, @@ -437,12 +495,12 @@ func (f *FlagSetFiller) processUint64(fieldRef interface{}, hasDefaultTag bool, return err } -func (f *FlagSetFiller) processInt(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { +func (f *FlagSetFiller) processInt(fieldRef any, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { casted, ok := fieldRef.(*int) if !ok { return f.processCustom( fieldRef, - func(s string) (interface{}, error) { + func(s string) (any, error) { value, err := strconv.Atoi(s) return value, err }, @@ -472,12 +530,12 @@ func (f *FlagSetFiller) processInt(fieldRef interface{}, hasDefaultTag bool, tag return err } -func (f *FlagSetFiller) processInt64(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { +func (f *FlagSetFiller) processInt64(fieldRef any, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { casted, ok := fieldRef.(*int64) if !ok { return f.processCustom( fieldRef, - func(s string) (interface{}, error) { + func(s string) (any, error) { value, err := strconv.ParseInt(s, 10, 64) return value, err }, @@ -507,12 +565,12 @@ func (f *FlagSetFiller) processInt64(fieldRef interface{}, hasDefaultTag bool, t return nil } -func (f *FlagSetFiller) processDuration(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { +func (f *FlagSetFiller) processDuration(fieldRef any, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { casted, ok := fieldRef.(*time.Duration) if !ok { return f.processCustom( fieldRef, - func(s string) (interface{}, error) { + func(s string) (any, error) { value, err := time.ParseDuration(s) return value, err }, @@ -542,12 +600,12 @@ func (f *FlagSetFiller) processDuration(fieldRef interface{}, hasDefaultTag bool return nil } -func (f *FlagSetFiller) processFloat64(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { +func (f *FlagSetFiller) processFloat64(fieldRef any, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { casted, ok := fieldRef.(*float64) if !ok { return f.processCustom( fieldRef, - func(s string) (interface{}, error) { + func(s string) (any, error) { value, err := strconv.ParseFloat(s, 64) return value, err }, @@ -577,12 +635,12 @@ func (f *FlagSetFiller) processFloat64(fieldRef interface{}, hasDefaultTag bool, return nil } -func (f *FlagSetFiller) processBool(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { +func (f *FlagSetFiller) processBool(fieldRef any, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) { casted, ok := fieldRef.(*bool) if !ok { return f.processCustom( fieldRef, - func(s string) (interface{}, error) { + func(s string) (any, error) { value, err := strconv.ParseBool(s) return value, err }, @@ -612,12 +670,12 @@ func (f *FlagSetFiller) processBool(fieldRef interface{}, hasDefaultTag bool, ta return nil } -func (f *FlagSetFiller) processString(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) { +func (f *FlagSetFiller) processString(fieldRef any, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) { casted, ok := fieldRef.(*string) if !ok { _ = f.processCustom( fieldRef, - func(s string) (interface{}, error) { + func(s string) (any, error) { return s, nil }, hasDefaultTag, @@ -643,7 +701,7 @@ func (f *FlagSetFiller) processString(fieldRef interface{}, hasDefaultTag bool, } } -func (f *FlagSetFiller) processCustom(fieldRef interface{}, converter func(string) (interface{}, error), hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) error { +func (f *FlagSetFiller) processCustom(fieldRef any, converter func(string) (any, error), hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) error { if hasDefaultTag { value, err := converter(tagDefault) if err != nil { diff --git a/flagset_test.go b/flagset_test.go index 2d79b8c..c5fcbef 100644 --- a/flagset_test.go +++ b/flagset_test.go @@ -291,9 +291,11 @@ func TestNestedStructPtr(t *testing.T) { err := filler.Fill(&flagset, &config) require.NoError(t, err) - err = flagset.Parse([]string{"--host", "h1", + err = flagset.Parse([]string{ + "--host", "h1", "--some-grouping-some-field", "val1", - "--inner-deeper-some-field", "val2"}) + "--inner-deeper-some-field", "val2", + }) require.NoError(t, err) require.NoError(t, err) @@ -414,7 +416,7 @@ func TestDefaultsViaLiteral(t *testing.T) { Nested *Nested } - var config = Config{ + config := Config{ Host: "h1", Enabled: true, Timeout: 5 * time.Second, @@ -539,7 +541,6 @@ func TestBadFieldErrorMessage(t *testing.T) { err := filler.Fill(&flagset, &config) require.Error(t, err) assert.Equal(t, "failed to process Enabled of flagsfiller_test.BadBoolConfig: failed to parse default into bool: strconv.ParseBool: parsing \"wrong\": invalid syntax", err.Error()) - } func TestHiddenFields(t *testing.T) { @@ -660,7 +661,8 @@ func TestStringToStringMap(t *testing.T) { \(default (veggie=carrot,fruit=apple|fruit=apple,veggie=carrot)\) `, buf.String()) - err = flagset.Parse([]string{"--no-default", + err = flagset.Parse([]string{ + "--no-default", "k1=v1", "--no-default", "k2=v2,k3=v3\nk4=v4\n", @@ -884,7 +886,6 @@ func TestFlagNameOverride(t *testing.T) { -server_address string address of server `, buf.String()) - } func TestFlatten(t *testing.T) { @@ -905,8 +906,10 @@ func TestFlatten(t *testing.T) { err := filler.Fill(&flagset, &config) require.NoError(t, err) - err = flagset.Parse([]string{"--flattened-field", "val1", - "--ptr-flattened-field", "val2"}) + err = flagset.Parse([]string{ + "--flattened-field", "val1", + "--ptr-flattened-field", "val2", + }) require.NoError(t, err) require.NoError(t, err) @@ -986,3 +989,370 @@ func TestTypeNamesWithPFlag(t *testing.T) { --timeout duration (default 0s) `, buf.String()) } + +func TestRequiredFieldSet(t *testing.T) { + type Config struct { + Host string `required:"true" usage:"the host"` + Optional string `usage:"optional field"` + } + + var config Config + + filler := flagsfiller.New() + + var flagset flag.FlagSet + err := filler.Fill(&flagset, &config) + require.NoError(t, err) + + err = flagset.Parse([]string{"--host", "example.com"}) + require.NoError(t, err) + + err = filler.Verify() + assert.NoError(t, err) + assert.Equal(t, "example.com", config.Host) +} + +func TestRequiredFieldNotSet(t *testing.T) { + type Config struct { + Host string `required:"true" usage:"the host"` + Optional string `usage:"optional field"` + } + + var config Config + + filler := flagsfiller.New() + + var flagset flag.FlagSet + err := filler.Fill(&flagset, &config) + require.NoError(t, err) + + err = flagset.Parse([]string{"--optional", "value"}) + require.NoError(t, err) + + err = filler.Verify() + assert.Error(t, err) + assert.Contains(t, err.Error(), "required flags not set") + assert.Contains(t, err.Error(), "host") +} + +func TestRequiredFieldWithDefault(t *testing.T) { + type Config struct { + Host string `required:"true" default:"localhost" usage:"the host"` + } + + var config Config + + filler := flagsfiller.New() + + var flagset flag.FlagSet + err := filler.Fill(&flagset, &config) + assert.Error(t, err) + assert.Contains(t, err.Error(), "cannot be both required and have a default") +} + +func TestMultipleRequiredFields(t *testing.T) { + type Config struct { + Host string `required:"true" usage:"the host"` + Port int `required:"true" usage:"the port"` + Optional string `usage:"optional field"` + } + + var config Config + + filler := flagsfiller.New() + + var flagset flag.FlagSet + err := filler.Fill(&flagset, &config) + require.NoError(t, err) + + // Set only host, not port + err = flagset.Parse([]string{"--host", "example.com"}) + require.NoError(t, err) + + err = filler.Verify() + assert.Error(t, err) + assert.Contains(t, err.Error(), "required flags not set") + assert.Contains(t, err.Error(), "port") + assert.NotContains(t, err.Error(), "host") +} + +func TestMultipleRequiredFieldsAllMissing(t *testing.T) { + type Config struct { + Host string `required:"true" usage:"the host"` + Port int `required:"true" usage:"the port"` + } + + var config Config + + filler := flagsfiller.New() + + var flagset flag.FlagSet + err := filler.Fill(&flagset, &config) + require.NoError(t, err) + + err = flagset.Parse([]string{}) + require.NoError(t, err) + + err = filler.Verify() + assert.Error(t, err) + assert.Contains(t, err.Error(), "required flags not set") + assert.Contains(t, err.Error(), "host") + assert.Contains(t, err.Error(), "port") +} + +func TestRequiredNestedField(t *testing.T) { + type Config struct { + Database struct { + Host string `required:"true" usage:"database host"` + Port int `usage:"database port"` + } + } + + var config Config + + filler := flagsfiller.New() + + var flagset flag.FlagSet + err := filler.Fill(&flagset, &config) + require.NoError(t, err) + + err = flagset.Parse([]string{"--database-port", "5432"}) + require.NoError(t, err) + + err = filler.Verify() + assert.Error(t, err) + assert.Contains(t, err.Error(), "required flags not set") + assert.Contains(t, err.Error(), "database-host") +} + +func TestRequiredFieldWithEnv(t *testing.T) { + type Config struct { + Host string `required:"true" usage:"the host"` + } + + var config Config + + assert.NoError(t, os.Setenv("APP_HOST", "env-host")) + defer os.Unsetenv("APP_HOST") + + filler := flagsfiller.New(flagsfiller.WithEnv("App")) + + var flagset flag.FlagSet + err := filler.Fill(&flagset, &config) + require.NoError(t, err) + + // Don't pass --host flag, should be set from env + err = flagset.Parse([]string{}) + require.NoError(t, err) + + // Should pass because env var set the value + err = filler.Verify() + assert.NoError(t, err) + assert.Equal(t, "env-host", config.Host) +} + +func TestRequiredBoolField(t *testing.T) { + type Config struct { + EnableFeature bool `required:"true" usage:"enable the feature"` + } + + var config Config + + filler := flagsfiller.New() + + var flagset flag.FlagSet + err := filler.Fill(&flagset, &config) + require.NoError(t, err) + + // Bool flag not set - should be false (zero value) + err = flagset.Parse([]string{}) + require.NoError(t, err) + + err = filler.Verify() + assert.Error(t, err) + assert.Contains(t, err.Error(), "enable-feature") + + // Now set it + config.EnableFeature = false // reset + err = flagset.Parse([]string{"--enable-feature"}) + require.NoError(t, err) + + err = filler.Verify() + assert.NoError(t, err) + assert.True(t, config.EnableFeature) +} + +func TestRequiredDurationField(t *testing.T) { + type Config struct { + Timeout time.Duration `required:"true" usage:"timeout duration"` + } + + var config Config + + filler := flagsfiller.New() + + var flagset flag.FlagSet + err := filler.Fill(&flagset, &config) + require.NoError(t, err) + + // Duration not set - should be zero + err = flagset.Parse([]string{}) + require.NoError(t, err) + + err = filler.Verify() + assert.Error(t, err) + assert.Contains(t, err.Error(), "timeout") + + // Now set it + config.Timeout = 0 // reset + err = flagset.Parse([]string{"--timeout", "5s"}) + require.NoError(t, err) + + err = filler.Verify() + assert.NoError(t, err) + assert.Equal(t, 5*time.Second, config.Timeout) +} + +func TestRequiredFieldWithInstanceDefault(t *testing.T) { + type Config struct { + Host string `required:"true" usage:"the host"` + Port int `usage:"the port"` + } + + // Initialize with non-zero value + config := Config{ + Host: "preset-host", + Port: 8080, + } + + filler := flagsfiller.New() + + var flagset flag.FlagSet + err := filler.Fill(&flagset, &config) + require.NoError(t, err) + + // Parse without providing --host flag + err = flagset.Parse([]string{"--port", "9000"}) + require.NoError(t, err) + + // Should pass because Host has instance default (non-zero value) + err = filler.Verify() + assert.NoError(t, err) + assert.Equal(t, "preset-host", config.Host) + assert.Equal(t, 9000, config.Port) +} + +func TestRequiredSliceAndMapFields(t *testing.T) { + type Config struct { + Hosts []string `required:"true" usage:"list of hosts"` + Settings map[string]string `required:"true" usage:"settings map"` + } + + t.Run("not set", func(t *testing.T) { + var config Config + + filler := flagsfiller.New() + + var flagset flag.FlagSet + err := filler.Fill(&flagset, &config) + require.NoError(t, err) + + err = flagset.Parse([]string{}) + require.NoError(t, err) + + err = filler.Verify() + assert.Error(t, err) + assert.Contains(t, err.Error(), "hosts") + assert.Contains(t, err.Error(), "settings") + }) + + t.Run("set", func(t *testing.T) { + var config Config + + filler := flagsfiller.New() + + var flagset flag.FlagSet + err := filler.Fill(&flagset, &config) + require.NoError(t, err) + + err = flagset.Parse([]string{"--hosts", "host1,host2", "--settings", "key1=val1"}) + require.NoError(t, err) + + err = filler.Verify() + assert.NoError(t, err) + assert.Equal(t, []string{"host1", "host2"}, config.Hosts) + assert.Equal(t, map[string]string{"key1": "val1"}, config.Settings) + }) +} + +func TestMultipleFillCallsWithDifferentStructs(t *testing.T) { + type Config1 struct { + Host string `required:"true" usage:"the host"` + } + type Config2 struct { + Port int `required:"true" usage:"the port"` + } + + var config1 Config1 + var config2 Config2 + + filler := flagsfiller.New() + + // Fill with first struct + var flagset1 flag.FlagSet + err := filler.Fill(&flagset1, &config1) + require.NoError(t, err) + + // Fill with second struct (should not interfere with first) + var flagset2 flag.FlagSet + err = filler.Fill(&flagset2, &config2) + require.NoError(t, err) + + // Parse both + err = flagset1.Parse([]string{"--host", "example.com"}) + require.NoError(t, err) + + err = flagset2.Parse([]string{"--port", "8080"}) + require.NoError(t, err) + + // Verify should check all required fields from both Fill calls + err = filler.Verify() + assert.NoError(t, err) + assert.Equal(t, "example.com", config1.Host) + assert.Equal(t, 8080, config2.Port) +} + +func TestMultipleFillCallsWithMissingRequired(t *testing.T) { + type Config1 struct { + Host string `required:"true" usage:"the host"` + } + type Config2 struct { + Port int `required:"true" usage:"the port"` + } + + var config1 Config1 + var config2 Config2 + + filler := flagsfiller.New() + + var flagset1 flag.FlagSet + err := filler.Fill(&flagset1, &config1) + require.NoError(t, err) + + var flagset2 flag.FlagSet + err = filler.Fill(&flagset2, &config2) + require.NoError(t, err) + + // Parse only config1, not config2 + err = flagset1.Parse([]string{"--host", "example.com"}) + require.NoError(t, err) + + err = flagset2.Parse([]string{}) + require.NoError(t, err) + + // Verify should fail because config2.Port is missing + err = filler.Verify() + assert.Error(t, err) + assert.Contains(t, err.Error(), "port") + assert.NotContains(t, err.Error(), "host") +} From 48e8b207fec90203e8e2fb14e3e1170691ccc9fb Mon Sep 17 00:00:00 2001 From: Valient Gough Date: Tue, 14 Oct 2025 11:35:32 -0700 Subject: [PATCH 2/2] show field name for required and default check --- flagset.go | 2 +- flagset_test.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/flagset.go b/flagset.go index 0831804..b3145e3 100644 --- a/flagset.go +++ b/flagset.go @@ -262,7 +262,7 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef any, // Check for required tag and validate it doesn't conflict with default _, hasRequiredTag := tag.Lookup(TagRequired) if hasRequiredTag && hasDefaultTag { - return fmt.Errorf("field cannot be both required and have a default value") + return fmt.Errorf("field %q cannot be both required and have a default value", name) } var renamed string diff --git a/flagset_test.go b/flagset_test.go index c5fcbef..c6c3249 100644 --- a/flagset_test.go +++ b/flagset_test.go @@ -1048,6 +1048,7 @@ func TestRequiredFieldWithDefault(t *testing.T) { err := filler.Fill(&flagset, &config) assert.Error(t, err) assert.Contains(t, err.Error(), "cannot be both required and have a default") + assert.Contains(t, err.Error(), "\"Host\"") } func TestMultipleRequiredFields(t *testing.T) {