Skip to content
This repository has been archived by the owner on Dec 1, 2023. It is now read-only.

Commit

Permalink
feat: support converting a structure defining the configuration into …
Browse files Browse the repository at this point in the history
…flags
  • Loading branch information
mjpitz committed Sep 6, 2021
1 parent fbd32bc commit 0e0bab5
Show file tree
Hide file tree
Showing 2 changed files with 295 additions and 0 deletions.
168 changes: 168 additions & 0 deletions internal/flagset/extract.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package flagset

import (
"fmt"
"reflect"
"strconv"
"strings"
"time"

"github.com/urfave/cli/v2"
)

type ref struct {
t reflect.Type
v reflect.Value
}

func (r *ref) Set(value string) error {
if r.v.CanInterface() {
switch v := r.v.Interface().(type) {
case time.Duration:
duration, err := time.ParseDuration(value)
if err != nil {
return err
}
r.v.Set(reflect.ValueOf(duration))
return nil
case cli.Generic:
return v.Set(value)
}
}

switch r.t.Kind() {
case reflect.Bool:
v, err := strconv.ParseBool(value)
if err != nil {
return err
}
r.v.SetBool(v)

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
v, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return err
}
r.v.SetInt(v)

case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
v, err := strconv.ParseUint(value, 10, 64)
if err != nil {
return err
}
r.v.SetUint(v)

case reflect.Float32, reflect.Float64:
v, err := strconv.ParseFloat(value, 64)
if err != nil {
return err
}
r.v.SetFloat(v)

case reflect.String:
r.v.SetString(value)

default:
return fmt.Errorf("unsupported kind: %s", r.t.Kind().String())

}

return nil
}

func (r *ref) String() string {
if r.v.CanInterface() {
switch v := r.v.Interface().(type) {
case time.Duration:
return v.String()
case cli.Generic:
return v.String()
}
}

switch r.t.Kind() {
case reflect.Bool:
return strconv.FormatBool(r.v.Bool())

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(r.v.Int(), 10)

case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return strconv.FormatUint(r.v.Uint(), 10)

case reflect.Float32, reflect.Float64:
return strconv.FormatFloat(r.v.Float(), 'f', 7, 64)

}

return r.v.String()
}

var _ cli.Generic = &ref{}

func format(prefix []string, name string) string {
envVar := name
for i := 1; i <= len(prefix); i++ {
envVar = prefix[len(prefix)-i] + "_" + envVar
}
return envVar
}

func extract(prefix []string, v interface{}) []cli.Flag {
flags := make([]cli.Flag, 0)

value := reflect.ValueOf(v)

// unbox pointers
if value.Kind() == reflect.Ptr {
value = value.Elem()
}

for i := 0; i < value.NumField(); i++ {
fieldValue := value.Field(i)
field := value.Type().Field(i)

name := strings.Split(field.Tag.Get("json"), ",")[0]
if name == "-" {
continue
}

// recursive field types
switch fieldValue.Kind() {
case reflect.Ptr, reflect.Struct:
pre := prefix
if name != "" {
pre = append(pre, name)
}

flags = append(flags, extract(pre, fieldValue.Interface())...)
continue
}

// all other data types
var aliases []string
if alias := field.Tag.Get("aliases"); alias != "" {
aliases = strings.Split(alias, ",")
}

flagName := format(prefix, name)

flags = append(flags, &cli.GenericFlag{
Name: flagName,
Aliases: aliases,
Usage: field.Tag.Get("usage"),
EnvVars: []string{strings.ToUpper(flagName)},
Value: &ref{
t: field.Type,
v: fieldValue,
},
})
}

return flags
}

// Extract parses the provided object to create a flagset.
func Extract(v interface{}) []cli.Flag {
return extract([]string{}, v)
}
127 changes: 127 additions & 0 deletions internal/flagset/extract_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package flagset_test

import (
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"

"github.com/mjpitz/aetherfs/internal/flagset"
)

type Options struct {
Endpoint string `json:"endpoint" aliases:"e" usage:"the endpoint of the server we're speaking to"`
EnableSSL bool `json:"enable_ssl" aliases:"s" usage:"enable encryption between processes"`
ValidFor time.Duration `json:"valid_for" aliases:"v" usage:"how long tokens are good for before expiring"`
Temperature int `json:"temperature" aliases:"t"`
BlockSize uint `json:"block_size"`
}

type Nested struct {
Options *Options `json:"options"`
Repeated []string `json:"repeated"`
}

func TestExtract(t *testing.T) {
opts := &Options{
Endpoint: "default-endpoint",
ValidFor: time.Minute,
}
flags := flagset.Extract(opts)

require.Len(t, flags, 5)

{
flag := flags[0].(*cli.GenericFlag)
require.Equal(t, "endpoint", flag.Name)
require.Equal(t, "e", flag.Aliases[0])
require.Equal(t, "ENDPOINT", flag.EnvVars[0])
require.Equal(t, "the endpoint of the server we're speaking to", flag.Usage)
require.Equal(t, "default-endpoint", flag.GetValue())
}

{
flag := flags[1].(*cli.GenericFlag)
require.Equal(t, "enable_ssl", flag.Name)
require.Equal(t, "s", flag.Aliases[0])
require.Equal(t, "ENABLE_SSL", flag.EnvVars[0])
require.Equal(t, "enable encryption between processes", flag.Usage)
require.Equal(t, "false", flag.GetValue())
}

{
flag := flags[2].(*cli.GenericFlag)
require.Equal(t, "valid_for", flag.Name)
require.Equal(t, "v", flag.Aliases[0])
require.Equal(t, "VALID_FOR", flag.EnvVars[0])
require.Equal(t, "how long tokens are good for before expiring", flag.Usage)
require.Equal(t, "1m0s", flag.GetValue())
}

{
flag := flags[3].(*cli.GenericFlag)
require.Equal(t, "temperature", flag.Name)
require.Equal(t, "t", flag.Aliases[0])
require.Equal(t, "TEMPERATURE", flag.EnvVars[0])
require.Equal(t, "", flag.Usage)
require.Equal(t, "0", flag.GetValue())
}

{
flag := flags[4].(*cli.GenericFlag)
require.Equal(t, "block_size", flag.Name)
require.Equal(t, "BLOCK_SIZE", flag.EnvVars[0])
require.Equal(t, "", flag.Usage)
require.Equal(t, "0", flag.GetValue())
}
}

func TestExtract_Nested(t *testing.T) {
nested := &Nested{
Options: &Options{
Endpoint: "default-endpoint",
ValidFor: time.Minute,
},
}

flags := flagset.Extract(nested)

require.Len(t, flags, 6)

{
flag := flags[0].(*cli.GenericFlag)
require.Equal(t, "options_endpoint", flag.Name)
require.Equal(t, "OPTIONS_ENDPOINT", flag.EnvVars[0])
}

{
flag := flags[1].(*cli.GenericFlag)
require.Equal(t, "options_enable_ssl", flag.Name)
require.Equal(t, "OPTIONS_ENABLE_SSL", flag.EnvVars[0])
}

{
flag := flags[2].(*cli.GenericFlag)
require.Equal(t, "options_valid_for", flag.Name)
require.Equal(t, "OPTIONS_VALID_FOR", flag.EnvVars[0])
}

{
flag := flags[3].(*cli.GenericFlag)
require.Equal(t, "options_temperature", flag.Name)
require.Equal(t, "OPTIONS_TEMPERATURE", flag.EnvVars[0])
}

{
flag := flags[4].(*cli.GenericFlag)
require.Equal(t, "options_block_size", flag.Name)
require.Equal(t, "OPTIONS_BLOCK_SIZE", flag.EnvVars[0])
}

{
flag := flags[5].(*cli.GenericFlag)
require.Equal(t, "repeated", flag.Name)
require.Equal(t, "REPEATED", flag.EnvVars[0])
}
}

0 comments on commit 0e0bab5

Please sign in to comment.