diff --git a/config/config.go b/config/config.go index 33d7b4e..d286ad0 100644 --- a/config/config.go +++ b/config/config.go @@ -2,8 +2,14 @@ package config import ( "context" + "fmt" + "reflect" + "strings" + "time" "github.com/openmfp/golang-commons/context/keys" + "github.com/spf13/pflag" + "github.com/spf13/viper" ) func SetConfigInContext(ctx context.Context, config any) context.Context { @@ -13,3 +19,119 @@ func SetConfigInContext(ctx context.Context, config any) context.Context { func LoadConfigFromContext(ctx context.Context) any { return ctx.Value(keys.ConfigCtxKey) } + +type CommonServiceConfig struct { + DebugLabelValue string `mapstructure:"debug-label-value"` + MaxConcurrentReconciles int `mapstructure:"max-concurrent-reconciles"` + Environment string `mapstructure:"environment"` + Region string `mapstructure:"region"` + Kubeconfig string `mapstructure:"kubeconfig"` + Image struct { + Name string `mapstructure:"image-name"` + Tag string `mapstructure:"image-tag"` + } `mapstructure:",squash"` + Log struct { + Level string `mapstructure:"log-level"` + + NoJson bool `mapstructure:"no-json"` + } `mapstructure:",squash"` + ShutdownTimeout time.Duration `mapstructure:"shutdown-timeout"` + Probes struct { + BindAddress string `mapstructure:"probes-bind-address"` + } `mapstructure:",squash"` + LeaderElection struct { + Enabled bool `mapstructure:"leader-election-enabled"` + } `mapstructure:",squash"` + Sentry struct { + Dsn string `mapstructure:"sentry-dsn"` + } `mapstructure:",squash"` +} + +func CommonFlags() *pflag.FlagSet { + flagSet := pflag.NewFlagSet("common", pflag.ContinueOnError) + + flagSet.String("debug-label-value", "", "Debug label value") + flagSet.Int("max-concurrent-reconciles", 1, "Max concurrent reconciles") + flagSet.String("environment", "local", "Environment") + flagSet.String("region", "local", "Region") + flagSet.String("image-name", "", "Image name") + flagSet.String("image-tag", "latest", "Image tag") + flagSet.String("log-level", "info", "Log level") + flagSet.Bool("log-no-json", false, "Log in JSON format") + flagSet.Duration("shutdown-timeout", 1, "Shutdown timeout") + flagSet.String("probes-bind-address", ":8090", "Probes bind address") + flagSet.Bool("leader-election-enabled", false, "Enable leader election") + flagSet.String("sentry-dsn", "", "Sentry DSN") + + return flagSet +} + +// generateFlagSet generates a pflag.FlagSet from a struct based on its `mapstructure` tags. +func generateFlagSet(config any) *pflag.FlagSet { + flagSet := pflag.NewFlagSet("generated", pflag.ContinueOnError) + traverseStruct(reflect.ValueOf(config), flagSet, "") + return flagSet +} + +// traverseStruct recursively traverses a struct and adds flags to the FlagSet. +func traverseStruct(value reflect.Value, flagSet *pflag.FlagSet, prefix string) { + if value.Kind() == reflect.Ptr { + value = value.Elem() + } + + if value.Kind() != reflect.Struct { + return + } + + typ := value.Type() + for i := 0; i < value.NumField(); i++ { + field := typ.Field(i) + fieldValue := value.Field(i) + + // Get the `mapstructure` tag + tag := field.Tag.Get("mapstructure") + if tag == "" || tag == "-" { + continue + } + + // Handle nested structs + if fieldValue.Kind() == reflect.Struct { + if tag == ",squash" { + traverseStruct(fieldValue, flagSet, "") + } else { + traverseStruct(fieldValue, flagSet, prefix+tag+".") + } + continue + } + + // Add flags based on the field type + switch fieldValue.Kind() { + case reflect.String: + flagSet.String(prefix+tag, "", fmt.Sprintf("Set the %s", tag)) + case reflect.Int, reflect.Int64: + if fieldValue.Type() == reflect.TypeOf(time.Duration(0)) { + flagSet.Duration(prefix+tag, 0, fmt.Sprintf("Set the %s", tag)) + } else { + flagSet.Int(prefix+tag, 0, fmt.Sprintf("Set the %s", tag)) + } + case reflect.Bool: + flagSet.Bool(prefix+tag, false, fmt.Sprintf("Set the %s", tag)) + } + } +} + +func NewConfigFor(serviceConfig any) (*viper.Viper, error) { + v := viper.NewWithOptions( + viper.EnvKeyReplacer(strings.NewReplacer("-", "_")), + ) + + v.AutomaticEnv() + + err := v.BindPFlags(CommonFlags()) + if err != nil { + return nil, err + } + err = v.BindPFlags(generateFlagSet(serviceConfig)) + + return v, err +} diff --git a/config/config_test.go b/config/config_test.go index c8b6248..c3bd84f 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,17 +1,44 @@ -package config +package config_test import ( "context" "testing" + "time" + "github.com/openmfp/golang-commons/config" "github.com/stretchr/testify/assert" ) func TestSetConfigInContext(t *testing.T) { ctx := context.Background() - config := "test" - ctx = SetConfigInContext(ctx, config) + configStr := "test" + ctx = config.SetConfigInContext(ctx, configStr) + + retrievedConfig := config.LoadConfigFromContext(ctx) + assert.Equal(t, configStr, retrievedConfig) +} + +func TestNewConfigFor(t *testing.T) { + + type test struct { + config.CommonServiceConfig + CustomFlag string `mapstructure:"custom-flag"` + CustomFlagInt int `mapstructure:"custom-flag-int"` + CustomFlagBool bool `mapstructure:"custom-flag-bool"` + CustomFlagStruct struct { + CustomFlagDuration time.Duration `mapstructure:"custom-flag-duration"` + } `mapstructure:",squash"` + CustomFlagStruct2 struct { + CustomFlagDuration time.Duration `mapstructure:"custom-flag-duration-2"` + } `mapstructure:"le-flag"` + } + + testStruct := test{} + + v, err := config.NewConfigFor(&testStruct) + assert.NoError(t, err) + + err = v.Unmarshal(&testStruct) + assert.NoError(t, err) - retrievedConfig := LoadConfigFromContext(ctx) - assert.Equal(t, config, retrievedConfig) } diff --git a/go.mod b/go.mod index 2a07dc9..f8bd828 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,8 @@ require ( github.com/openfga/openfga v1.7.0 github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.34.0 + github.com/spf13/pflag v1.0.5 + github.com/spf13/viper v1.19.0 github.com/stretchr/testify v1.10.0 github.com/vektah/gqlparser/v2 v2.5.23 go.opentelemetry.io/otel v1.35.0 @@ -89,8 +91,6 @@ require ( github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect - github.com/spf13/pflag v1.0.5 // indirect - github.com/spf13/viper v1.19.0 // indirect github.com/stoewer/go-strcase v1.3.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect