diff --git a/enterprise/server/auth/BUILD b/enterprise/server/auth/BUILD index 97421fe2080..568f03021bf 100644 --- a/enterprise/server/auth/BUILD +++ b/enterprise/server/auth/BUILD @@ -18,7 +18,7 @@ go_library( "//server/tables", "//server/util/alert", "//server/util/capabilities", - "//server/util/flagutil", + "//server/util/flagutil/types", "//server/util/log", "//server/util/lru", "//server/util/random", diff --git a/enterprise/server/auth/auth.go b/enterprise/server/auth/auth.go index 953d9d4d16e..91383be7677 100644 --- a/enterprise/server/auth/auth.go +++ b/enterprise/server/auth/auth.go @@ -19,7 +19,6 @@ import ( "github.com/buildbuddy-io/buildbuddy/server/tables" "github.com/buildbuddy-io/buildbuddy/server/util/alert" "github.com/buildbuddy-io/buildbuddy/server/util/capabilities" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/log" "github.com/buildbuddy-io/buildbuddy/server/util/lru" "github.com/buildbuddy-io/buildbuddy/server/util/random" @@ -34,6 +33,7 @@ import ( "google.golang.org/grpc/peer" akpb "github.com/buildbuddy-io/buildbuddy/proto/api_key" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" burl "github.com/buildbuddy-io/buildbuddy/server/util/url" oidc "github.com/coreos/go-oidc" ) @@ -41,7 +41,7 @@ import ( var ( adminGroupID = flag.String("auth.admin_group_id", "", "ID of a group whose members can perform actions only accessible to server admins.") enableAnonymousUsage = flag.Bool("auth.enable_anonymous_usage", false, "If true, unauthenticated build uploads will still be allowed but won't be associated with your organization.") - oauthProviders = flagutil.Slice("auth.oauth_providers", []OauthProvider{}, "The list of oauth providers to use to authenticate.") + oauthProviders = flagtypes.Slice("auth.oauth_providers", []OauthProvider{}, "The list of oauth providers to use to authenticate.") jwtKey = flag.String("auth.jwt_key", "set_the_jwt_in_config", "The key to use when signing JWT tokens.") apiKeyGroupCacheTTL = flag.Duration("auth.api_key_group_cache_ttl", 5*time.Minute, "TTL for API Key to Group caching. Set to '0' to disable cache.") httpsOnlyCookies = flag.Bool("auth.https_only_cookies", false, "If true, cookies will only be set over https connections.") diff --git a/enterprise/server/backends/distributed/BUILD b/enterprise/server/backends/distributed/BUILD index 9225839f718..91ff0797401 100644 --- a/enterprise/server/backends/distributed/BUILD +++ b/enterprise/server/backends/distributed/BUILD @@ -20,7 +20,7 @@ go_library( "//server/resources", "//server/util/background", "//server/util/consistent_hash", - "//server/util/flagutil", + "//server/util/flagutil/types", "//server/util/log", "//server/util/peerset", "//server/util/status", diff --git a/enterprise/server/backends/distributed/distributed.go b/enterprise/server/backends/distributed/distributed.go index 1632f722685..1a835f013cf 100644 --- a/enterprise/server/backends/distributed/distributed.go +++ b/enterprise/server/backends/distributed/distributed.go @@ -20,7 +20,6 @@ import ( "github.com/buildbuddy-io/buildbuddy/server/resources" "github.com/buildbuddy-io/buildbuddy/server/util/background" "github.com/buildbuddy-io/buildbuddy/server/util/consistent_hash" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/log" "github.com/buildbuddy-io/buildbuddy/server/util/peerset" "github.com/buildbuddy-io/buildbuddy/server/util/status" @@ -29,13 +28,14 @@ import ( dcpb "github.com/buildbuddy-io/buildbuddy/proto/distributed_cache" repb "github.com/buildbuddy-io/buildbuddy/proto/remote_execution" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" ) var ( listenAddr = flag.String("cache.distributed_cache.listen_addr", "", "The address to listen for local BuildBuddy distributed cache traffic on.") redisTarget = flag.String("cache.distributed_cache.redis_target", "", "A redis target for improved Caching/RBE performance. Target can be provided as either a redis connection URI or a host:port pair. URI schemas supported: redis[s]://[[USER][:PASSWORD]@][HOST][:PORT][/DATABASE] or unix://[[USER][:PASSWORD]@]SOCKET_PATH[?db=DATABASE] ** Enterprise only **") groupName = flag.String("cache.distributed_cache.group_name", "", "A unique name for this distributed cache group. ** Enterprise only **") - nodes = flagutil.Slice("cache.distributed_cache.nodes", []string{}, "The hardcoded list of peer distributed cache nodes. If this is set, redis_target will be ignored. ** Enterprise only **") + nodes = flagtypes.Slice("cache.distributed_cache.nodes", []string{}, "The hardcoded list of peer distributed cache nodes. If this is set, redis_target will be ignored. ** Enterprise only **") replicationFactor = flag.Int("cache.distributed_cache.replication_factor", 0, "How many total servers the data should be replicated to. Must be >= 1. ** Enterprise only **") clusterSize = flag.Int("cache.distributed_cache.cluster_size", 0, "The total number of nodes in this cluster. Required for health checking. ** Enterprise only **") enableLocalWrites = flag.Bool("cache.distributed_cache.enable_local_writes", false, "If enabled, shortcuts distributed writes that belong to the local shard to local cache instead of making an RPC.") diff --git a/enterprise/server/backends/memcache/BUILD b/enterprise/server/backends/memcache/BUILD index f0c2c2679a1..db4049bcd86 100644 --- a/enterprise/server/backends/memcache/BUILD +++ b/enterprise/server/backends/memcache/BUILD @@ -14,7 +14,7 @@ go_library( "//server/environment", "//server/interfaces", "//server/remote_cache/digest", - "//server/util/flagutil", + "//server/util/flagutil/types", "//server/util/log", "//server/util/prefix", "//server/util/status", diff --git a/enterprise/server/backends/memcache/memcache.go b/enterprise/server/backends/memcache/memcache.go index 52031bba179..338b23b3ce2 100644 --- a/enterprise/server/backends/memcache/memcache.go +++ b/enterprise/server/backends/memcache/memcache.go @@ -12,16 +12,16 @@ import ( "github.com/buildbuddy-io/buildbuddy/server/environment" "github.com/buildbuddy-io/buildbuddy/server/interfaces" "github.com/buildbuddy-io/buildbuddy/server/remote_cache/digest" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/log" "github.com/buildbuddy-io/buildbuddy/server/util/prefix" "github.com/buildbuddy-io/buildbuddy/server/util/status" repb "github.com/buildbuddy-io/buildbuddy/proto/remote_execution" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" "golang.org/x/sync/errgroup" ) -var memcacheTargets = flagutil.Slice("cache.memcache_targets", []string{}, "Deprecated. Use Redis Target instead.") +var memcacheTargets = flagtypes.Slice("cache.memcache_targets", []string{}, "Deprecated. Use Redis Target instead.") const ( mcCutoffSizeBytes = 134217728 - 1 // 128 MB diff --git a/enterprise/server/backends/pebble_cache/BUILD b/enterprise/server/backends/pebble_cache/BUILD index f5ac9116d3c..4fec29aea3e 100644 --- a/enterprise/server/backends/pebble_cache/BUILD +++ b/enterprise/server/backends/pebble_cache/BUILD @@ -15,7 +15,7 @@ go_library( "//server/interfaces", "//server/remote_cache/digest", "//server/util/disk", - "//server/util/flagutil", + "//server/util/flagutil/types", "//server/util/log", "//server/util/status", "@com_github_cespare_xxhash_v2//:xxhash", diff --git a/enterprise/server/backends/pebble_cache/pebble_cache.go b/enterprise/server/backends/pebble_cache/pebble_cache.go index c006b0122ce..e04507f96f6 100644 --- a/enterprise/server/backends/pebble_cache/pebble_cache.go +++ b/enterprise/server/backends/pebble_cache/pebble_cache.go @@ -21,7 +21,6 @@ import ( "github.com/buildbuddy-io/buildbuddy/server/interfaces" "github.com/buildbuddy-io/buildbuddy/server/remote_cache/digest" "github.com/buildbuddy-io/buildbuddy/server/util/disk" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/log" "github.com/buildbuddy-io/buildbuddy/server/util/status" "github.com/cespare/xxhash/v2" @@ -32,13 +31,14 @@ import ( rfpb "github.com/buildbuddy-io/buildbuddy/proto/raft" repb "github.com/buildbuddy-io/buildbuddy/proto/remote_execution" cache_config "github.com/buildbuddy-io/buildbuddy/server/cache/config" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" ) var ( rootDirectory = flag.String("cache.pebble.root_directory", "", "The root directory to store the database in.") blockCacheSizeBytes = flag.Int64("cache.pebble.block_cache_size_bytes", 1000*megabyte, "How much ram to give the block cache") - partitions = flagutil.Slice("cache.pebble.partitions", []disk.Partition{}, "") - partitionMappings = flagutil.Slice("cache.pebble.partition_mappings", []disk.PartitionMapping{}, "") + partitions = flagtypes.Slice("cache.pebble.partitions", []disk.Partition{}, "") + partitionMappings = flagtypes.Slice("cache.pebble.partition_mappings", []disk.PartitionMapping{}, "") ) // TODO: diff --git a/enterprise/server/backends/redis_client/BUILD b/enterprise/server/backends/redis_client/BUILD index cbb2997a1b4..d3bdae2481b 100644 --- a/enterprise/server/backends/redis_client/BUILD +++ b/enterprise/server/backends/redis_client/BUILD @@ -9,7 +9,7 @@ go_library( "//enterprise/server/remote_execution/config", "//enterprise/server/util/redisutil", "//server/environment", - "//server/util/flagutil", + "//server/util/flagutil/types", "//server/util/status", ], ) diff --git a/enterprise/server/backends/redis_client/redis_client.go b/enterprise/server/backends/redis_client/redis_client.go index ac28b1b6ec5..db0756385a7 100644 --- a/enterprise/server/backends/redis_client/redis_client.go +++ b/enterprise/server/backends/redis_client/redis_client.go @@ -6,15 +6,15 @@ import ( "github.com/buildbuddy-io/buildbuddy/enterprise/server/util/redisutil" "github.com/buildbuddy-io/buildbuddy/server/environment" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/status" remote_execution_config "github.com/buildbuddy-io/buildbuddy/enterprise/server/remote_execution/config" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" ) var ( defaultRedisTarget = flag.String("app.default_redis_target", "", "A Redis target for storing remote shared state. To ease migration, the redis target from the remote execution config will be used if this value is not specified.") - defaultRedisShards = flagutil.Slice("app.default_sharded_redis.shards", []string{}, "Ordered list of Redis shard addresses.") + defaultRedisShards = flagtypes.Slice("app.default_sharded_redis.shards", []string{}, "Ordered list of Redis shard addresses.") defaultShardedRedisUsername = flag.String("app.default_sharded_redis.username", "", "Redis username") defaultShardedRedisPassword = flag.String("app.default_sharded_redis.password", "", "Redis password") @@ -22,13 +22,13 @@ var ( // TODO: We need to deprecate one of the redis targets here or distinguish them cacheRedisTargetFallback = flag.String("cache.redis_target", "", "A redis target for improved Caching/RBE performance. Target can be provided as either a redis connection URI or a host:port pair. URI schemas supported: redis[s]://[[USER][:PASSWORD]@][HOST][:PORT][/DATABASE] or unix://[[USER][:PASSWORD]@]SOCKET_PATH[?db=DATABASE] ** Enterprise only **") cacheRedisTarget = flag.String("cache.redis.redis_target", "", "A redis target for improved Caching/RBE performance. Target can be provided as either a redis connection URI or a host:port pair. URI schemas supported: redis[s]://[[USER][:PASSWORD]@][HOST][:PORT][/DATABASE] or unix://[[USER][:PASSWORD]@]SOCKET_PATH[?db=DATABASE] ** Enterprise only **") - cacheRedisShards = flagutil.Slice("cache.redis.sharded.shards", []string{}, "Ordered list of Redis shard addresses.") + cacheRedisShards = flagtypes.Slice("cache.redis.sharded.shards", []string{}, "Ordered list of Redis shard addresses.") cacheShardedRedisUsername = flag.String("cache.redis.sharded.username", "", "Redis username") cacheShardedRedisPassword = flag.String("cache.redis.sharded.password", "", "Redis password") // Remote Execution Redis remoteExecRedisTarget = flag.String("remote_execution.redis_target", "", "A Redis target for storing remote execution state. Falls back to app.default_redis_target if unspecified. Required for remote execution. To ease migration, the redis target from the cache config will be used if neither this value nor app.default_redis_target are specified.") - remoteExecRedisShards = flagutil.Slice("remote_execution.sharded_redis.shards", []string{}, "Ordered list of Redis shard addresses.") + remoteExecRedisShards = flagtypes.Slice("remote_execution.sharded_redis.shards", []string{}, "Ordered list of Redis shard addresses.") remoteExecShardedRedisUsername = flag.String("remote_execution.sharded_redis.username", "", "Redis username") remoteExecShardedRedisPassword = flag.String("remote_execution.sharded_redis.password", "", "Redis password") ) diff --git a/enterprise/server/cmd/ci_runner/BUILD b/enterprise/server/cmd/ci_runner/BUILD index 236c94c9916..425333ed6aa 100644 --- a/enterprise/server/cmd/ci_runner/BUILD +++ b/enterprise/server/cmd/ci_runner/BUILD @@ -22,7 +22,7 @@ go_library( "//server/remote_cache/digest", "//server/util/bazel", "//server/util/bazelisk", - "//server/util/flagutil", + "//server/util/flagutil/types", "//server/util/git", "//server/util/grpc_client", "//server/util/lockingbuffer", diff --git a/enterprise/server/cmd/ci_runner/main.go b/enterprise/server/cmd/ci_runner/main.go index 1792b96d872..185636d9808 100644 --- a/enterprise/server/cmd/ci_runner/main.go +++ b/enterprise/server/cmd/ci_runner/main.go @@ -25,7 +25,6 @@ import ( "github.com/buildbuddy-io/buildbuddy/server/remote_cache/digest" "github.com/buildbuddy-io/buildbuddy/server/util/bazel" "github.com/buildbuddy-io/buildbuddy/server/util/bazelisk" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/grpc_client" "github.com/buildbuddy-io/buildbuddy/server/util/lockingbuffer" "github.com/buildbuddy-io/buildbuddy/server/util/log" @@ -42,6 +41,7 @@ import ( bespb "github.com/buildbuddy-io/buildbuddy/proto/build_event_stream" repb "github.com/buildbuddy-io/buildbuddy/proto/remote_execution" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" gitutil "github.com/buildbuddy-io/buildbuddy/server/util/git" bspb "google.golang.org/genproto/googleapis/bytestream" gstatus "google.golang.org/grpc/status" @@ -113,7 +113,7 @@ var ( invocationID = flag.String("invocation_id", "", "If set, use the specified invocation ID for the workflow action. Ignored if action_name is not set.") visibility = flag.String("visibility", "", "If set, use the specified value for VISIBILITY build metadata for the workflow invocation.") bazelSubCommand = flag.String("bazel_sub_command", "", "If set, run the bazel command specified by these args and ignore all triggering and configured actions.") - patchDigests = flagutil.Slice("patch_digest", []string{}, "Digests of patches to apply to the repo after checkout. Can be specified multiple times to apply multiple patches.") + patchDigests = flagtypes.Slice("patch_digest", []string{}, "Digests of patches to apply to the repo after checkout. Can be specified multiple times to apply multiple patches.") recordRunMetadata = flag.Bool("record_run_metadata", false, "Instead of running a target, extract metadata about it and report it in the build event stream.") shutdownAndExit = flag.Bool("shutdown_and_exit", false, "If set, runs bazel shutdown with the configured bazel_command, and exits. No other commands are run.") diff --git a/enterprise/server/cmd/executor/BUILD b/enterprise/server/cmd/executor/BUILD index d3bab28337d..990f04d2685 100644 --- a/enterprise/server/cmd/executor/BUILD +++ b/enterprise/server/cmd/executor/BUILD @@ -31,7 +31,7 @@ go_library( "//server/remote_cache/content_addressable_storage_server", "//server/resources", "//server/util/fileresolver", - "//server/util/flagutil", + "//server/util/flagutil/yaml", "//server/util/grpc_client", "//server/util/grpc_server", "//server/util/healthcheck", diff --git a/enterprise/server/cmd/executor/executor.go b/enterprise/server/cmd/executor/executor.go index 2659a151fe5..2e624086944 100644 --- a/enterprise/server/cmd/executor/executor.go +++ b/enterprise/server/cmd/executor/executor.go @@ -29,7 +29,6 @@ import ( "github.com/buildbuddy-io/buildbuddy/server/remote_cache/content_addressable_storage_server" "github.com/buildbuddy-io/buildbuddy/server/resources" "github.com/buildbuddy-io/buildbuddy/server/util/fileresolver" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/grpc_client" "github.com/buildbuddy-io/buildbuddy/server/util/grpc_server" "github.com/buildbuddy-io/buildbuddy/server/util/healthcheck" @@ -48,6 +47,7 @@ import ( remote_executor "github.com/buildbuddy-io/buildbuddy/enterprise/server/remote_execution/executor" repb "github.com/buildbuddy-io/buildbuddy/proto/remote_execution" scpb "github.com/buildbuddy-io/buildbuddy/proto/scheduler" + flagyaml "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/yaml" bspb "google.golang.org/genproto/googleapis/bytestream" _ "google.golang.org/grpc/encoding/gzip" // imported for side effects; DO NOT REMOVE. ) @@ -200,7 +200,7 @@ func main() { rootContext := context.Background() flag.Parse() - if err := flagutil.PopulateFlagsFromFile(config.Path()); err != nil { + if err := flagyaml.PopulateFlagsFromFile(config.Path()); err != nil { log.Fatalf("Error loading config from file: %s", err) } diff --git a/enterprise/server/cmd/executor/yaml_doc/BUILD b/enterprise/server/cmd/executor/yaml_doc/BUILD index 2fd075635f4..510335c5e58 100644 --- a/enterprise/server/cmd/executor/yaml_doc/BUILD +++ b/enterprise/server/cmd/executor/yaml_doc/BUILD @@ -13,7 +13,7 @@ go_library( visibility = ["//visibility:private"], deps = [ "//enterprise/server/cmd/executor:executor_lib", - "//server/util/flagutil", + "//server/util/flagutil/yaml", ], ) diff --git a/enterprise/server/cmd/executor/yaml_doc/main.go b/enterprise/server/cmd/executor/yaml_doc/main.go index a9ee7da85d3..b78e0c3c8d7 100644 --- a/enterprise/server/cmd/executor/yaml_doc/main.go +++ b/enterprise/server/cmd/executor/yaml_doc/main.go @@ -6,7 +6,7 @@ import ( "os" _ "github.com/buildbuddy-io/buildbuddy/enterprise/server/cmd/executor" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" + flagyaml "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/yaml" ) const flagName = "yaml_documented_defaults_out_file" @@ -14,13 +14,13 @@ const flagName = "yaml_documented_defaults_out_file" var yamlDefaultsOutFile = flag.String(flagName, "buildbuddy_executor_documented_defaults.yaml", "Path to a file to write the default YAML config (with docs) to.") func init() { - flagutil.IgnoreFlagForYAML(flagName) + flagyaml.IgnoreFlagForYAML(flagName) } func main() { flag.Parse() - b, err := flagutil.SplitDocumentedYAMLFromFlags() + b, err := flagyaml.SplitDocumentedYAMLFromFlags() if err != nil { log.Fatalf("Encountered error generating documented default YAML file: %s", err) } diff --git a/enterprise/server/cmd/server/BUILD b/enterprise/server/cmd/server/BUILD index 9fa99c8d4e2..a2fcfee73e5 100644 --- a/enterprise/server/cmd/server/BUILD +++ b/enterprise/server/cmd/server/BUILD @@ -49,6 +49,7 @@ go_library( "//server/telemetry", "//server/util/fileresolver", "//server/util/flagutil", + "//server/util/flagutil/yaml", "//server/util/healthcheck", "//server/util/log", "//server/util/tracing", diff --git a/enterprise/server/cmd/server/main.go b/enterprise/server/cmd/server/main.go index 6d1bcda2303..69c87f928ec 100644 --- a/enterprise/server/cmd/server/main.go +++ b/enterprise/server/cmd/server/main.go @@ -51,6 +51,7 @@ import ( remote_execution_redis_client "github.com/buildbuddy-io/buildbuddy/enterprise/server/remote_execution/redis_client" telserver "github.com/buildbuddy-io/buildbuddy/enterprise/server/telemetry" workflow "github.com/buildbuddy-io/buildbuddy/enterprise/server/workflow/service" + flagyaml "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/yaml" ) var serverType = flag.String("server_type", "buildbuddy-server", "The server type to match on health checks") @@ -151,7 +152,7 @@ func main() { version.Print() flag.Parse() - if err := flagutil.PopulateFlagsFromFile(config.Path()); err != nil { + if err := flagyaml.PopulateFlagsFromFile(config.Path()); err != nil { log.Fatalf("Error loading config from file: %s", err) } healthChecker := healthcheck.NewHealthChecker(*serverType) diff --git a/enterprise/server/cmd/server/yaml_doc/BUILD b/enterprise/server/cmd/server/yaml_doc/BUILD index 6b68cf742db..5cc9fb79653 100644 --- a/enterprise/server/cmd/server/yaml_doc/BUILD +++ b/enterprise/server/cmd/server/yaml_doc/BUILD @@ -13,7 +13,7 @@ go_library( visibility = ["//visibility:private"], deps = [ "//enterprise/server/cmd/server:server_lib", - "//server/util/flagutil", + "//server/util/flagutil/yaml", ], ) diff --git a/enterprise/server/cmd/server/yaml_doc/main.go b/enterprise/server/cmd/server/yaml_doc/main.go index e3e3b4ea109..c8f8e46ef43 100644 --- a/enterprise/server/cmd/server/yaml_doc/main.go +++ b/enterprise/server/cmd/server/yaml_doc/main.go @@ -6,7 +6,7 @@ import ( "os" _ "github.com/buildbuddy-io/buildbuddy/enterprise/server/cmd/server" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" + flagyaml "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/yaml" ) const flagName = "yaml_documented_defaults_out_file" @@ -14,13 +14,13 @@ const flagName = "yaml_documented_defaults_out_file" var yamlDefaultsOutFile = flag.String(flagName, "buildbuddy_enterprise_server_documented_defaults.yaml", "Path to a file to write the default YAML config (with docs) to.") func init() { - flagutil.IgnoreFlagForYAML(flagName) + flagyaml.IgnoreFlagForYAML(flagName) } func main() { flag.Parse() - b, err := flagutil.SplitDocumentedYAMLFromFlags() + b, err := flagyaml.SplitDocumentedYAMLFromFlags() if err != nil { log.Fatalf("Encountered error generating documented default YAML file: %s", err) } diff --git a/enterprise/server/raft/cache/BUILD b/enterprise/server/raft/cache/BUILD index 702a5ed8b00..31428f23403 100644 --- a/enterprise/server/raft/cache/BUILD +++ b/enterprise/server/raft/cache/BUILD @@ -25,7 +25,7 @@ go_library( "//server/interfaces", "//server/remote_cache/digest", "//server/util/disk", - "//server/util/flagutil", + "//server/util/flagutil/types", "//server/util/log", "//server/util/network", "//server/util/status", diff --git a/enterprise/server/raft/cache/cache.go b/enterprise/server/raft/cache/cache.go index 392c1ed1ca9..47dd3c3c1d4 100644 --- a/enterprise/server/raft/cache/cache.go +++ b/enterprise/server/raft/cache/cache.go @@ -26,7 +26,6 @@ import ( "github.com/buildbuddy-io/buildbuddy/server/interfaces" "github.com/buildbuddy-io/buildbuddy/server/remote_cache/digest" "github.com/buildbuddy-io/buildbuddy/server/util/disk" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/log" "github.com/buildbuddy-io/buildbuddy/server/util/network" "github.com/buildbuddy-io/buildbuddy/server/util/status" @@ -37,13 +36,14 @@ import ( rfpb "github.com/buildbuddy-io/buildbuddy/proto/raft" rfspb "github.com/buildbuddy-io/buildbuddy/proto/raft_service" repb "github.com/buildbuddy-io/buildbuddy/proto/remote_execution" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" dbConfig "github.com/lni/dragonboat/v3/config" ) var ( rootDirectory = flag.String("cache.raft.root_directory", "", "The root directory to use for storing cached data.") listenAddr = flag.String("cache.raft.listen_addr", "", "The address to listen for local gossip traffic on. Ex. 'localhost:1991") - join = flagutil.Slice("cache.raft.join", []string{}, "The list of nodes to use when joining clusters Ex. '1.2.3.4:1991,2.3.4.5:1991...'") + join = flagtypes.Slice("cache.raft.join", []string{}, "The list of nodes to use when joining clusters Ex. '1.2.3.4:1991,2.3.4.5:1991...'") httpPort = flag.Int("cache.raft.http_port", 0, "The address to listen for HTTP raft traffic. Ex. '1992'") gRPCPort = flag.Int("cache.raft.grpc_port", 0, "The address to listen for internal API traffic on. Ex. '1993'") ) diff --git a/enterprise/server/registry/BUILD b/enterprise/server/registry/BUILD index 922ac9fe5f4..0ac14200804 100644 --- a/enterprise/server/registry/BUILD +++ b/enterprise/server/registry/BUILD @@ -19,7 +19,7 @@ go_library( "//server/remote_cache/cachetools", "//server/remote_cache/digest", "//server/ssl", - "//server/util/flagutil", + "//server/util/flagutil/yaml", "//server/util/grpc_client", "//server/util/grpc_server", "//server/util/healthcheck", diff --git a/enterprise/server/registry/registry.go b/enterprise/server/registry/registry.go index 33bcfa9e368..d61c23a302a 100644 --- a/enterprise/server/registry/registry.go +++ b/enterprise/server/registry/registry.go @@ -26,7 +26,6 @@ import ( "github.com/buildbuddy-io/buildbuddy/server/remote_cache/cachetools" "github.com/buildbuddy-io/buildbuddy/server/remote_cache/digest" "github.com/buildbuddy-io/buildbuddy/server/ssl" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/grpc_client" "github.com/buildbuddy-io/buildbuddy/server/util/grpc_server" "github.com/buildbuddy-io/buildbuddy/server/util/healthcheck" @@ -50,6 +49,7 @@ import ( regpb "github.com/buildbuddy-io/buildbuddy/proto/registry" repb "github.com/buildbuddy-io/buildbuddy/proto/remote_execution" + flagyaml "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/yaml" ctrname "github.com/google/go-containerregistry/pkg/name" bspb "google.golang.org/genproto/googleapis/bytestream" ) @@ -916,7 +916,7 @@ func (r *registry) Start(ctx context.Context, hc interfaces.HealthChecker, env e func main() { flag.Parse() - if err := flagutil.PopulateFlagsFromFile(config.Path()); err != nil { + if err := flagyaml.PopulateFlagsFromFile(config.Path()); err != nil { log.Fatalf("Error loading config from file: %s", err) } diff --git a/enterprise/server/remote_execution/container/BUILD b/enterprise/server/remote_execution/container/BUILD index bc1e02c0c14..40be9c6d3d2 100644 --- a/enterprise/server/remote_execution/container/BUILD +++ b/enterprise/server/remote_execution/container/BUILD @@ -10,7 +10,7 @@ go_library( "//proto:remote_execution_go_proto", "//server/environment", "//server/interfaces", - "//server/util/flagutil", + "//server/util/flagutil/types", "//server/util/log", "//server/util/perms", "//server/util/status", diff --git a/enterprise/server/remote_execution/container/container.go b/enterprise/server/remote_execution/container/container.go index 183a8802cda..2c19c3889d9 100644 --- a/enterprise/server/remote_execution/container/container.go +++ b/enterprise/server/remote_execution/container/container.go @@ -11,7 +11,6 @@ import ( "github.com/buildbuddy-io/buildbuddy/enterprise/server/remote_execution/platform" "github.com/buildbuddy-io/buildbuddy/server/environment" "github.com/buildbuddy-io/buildbuddy/server/interfaces" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/log" "github.com/buildbuddy-io/buildbuddy/server/util/perms" "github.com/buildbuddy-io/buildbuddy/server/util/status" @@ -21,6 +20,7 @@ import ( "go.opentelemetry.io/otel/trace" repb "github.com/buildbuddy-io/buildbuddy/proto/remote_execution" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" ) const ( @@ -30,7 +30,7 @@ const ( ) var ( - containerRegistries = flagutil.Slice("executor.container_registries", []ContainerRegistry{}, "") + containerRegistries = flagtypes.Slice("executor.container_registries", []ContainerRegistry{}, "") debugUseLocalImagesOnly = flag.Bool("debug_use_local_images_only", false, "Do not pull OCI images and only used locally cached images. This can be set to test local image builds during development without needing to push to a container registry. Not intended for production use.") ) diff --git a/enterprise/server/remote_execution/platform/BUILD b/enterprise/server/remote_execution/platform/BUILD index ea93d25abb4..081214f216a 100644 --- a/enterprise/server/remote_execution/platform/BUILD +++ b/enterprise/server/remote_execution/platform/BUILD @@ -8,7 +8,7 @@ go_library( deps = [ "//proto:remote_execution_go_proto", "//server/environment", - "//server/util/flagutil", + "//server/util/flagutil/types", "//server/util/log", "//server/util/status", "@org_golang_google_grpc//metadata", diff --git a/enterprise/server/remote_execution/platform/platform.go b/enterprise/server/remote_execution/platform/platform.go index 0808c43ee66..e2ca6962ef4 100644 --- a/enterprise/server/remote_execution/platform/platform.go +++ b/enterprise/server/remote_execution/platform/platform.go @@ -9,12 +9,12 @@ import ( "strings" "github.com/buildbuddy-io/buildbuddy/server/environment" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/log" "github.com/buildbuddy-io/buildbuddy/server/util/status" "google.golang.org/grpc/metadata" repb "github.com/buildbuddy-io/buildbuddy/proto/remote_execution" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" ) var ( @@ -27,7 +27,7 @@ var ( enableFirecracker = flag.Bool("executor.enable_firecracker", false, "Enables running execution commands inside of firecracker VMs") defaultImage = flag.String("executor.default_image", "gcr.io/flame-public/executor-docker-default:enterprise-v1.6.0", "The default docker image to use to warm up executors or if no platform property is set. Ex: gcr.io/flame-public/executor-docker-default:enterprise-v1.5.4") enableVFS = flag.Bool("executor.enable_vfs", false, "Whether FUSE based filesystem is enabled.") - extraEnvVars = flagutil.Slice("executor.extra_env_vars", []string{}, "Additional environment variables to pass to remotely executed actions. i.e. MY_ENV_VAR=foo") + extraEnvVars = flagtypes.Slice("executor.extra_env_vars", []string{}, "Additional environment variables to pass to remotely executed actions. i.e. MY_ENV_VAR=foo") ) const ( diff --git a/enterprise/server/remote_execution/runner/BUILD b/enterprise/server/remote_execution/runner/BUILD index 60b543e92b8..60a9777276c 100644 --- a/enterprise/server/remote_execution/runner/BUILD +++ b/enterprise/server/remote_execution/runner/BUILD @@ -34,7 +34,7 @@ go_library( "//server/resources", "//server/util/alert", "//server/util/background", - "//server/util/flagutil", + "//server/util/flagutil/types", "//server/util/lockingbuffer", "//server/util/log", "//server/util/perms", diff --git a/enterprise/server/remote_execution/runner/runner.go b/enterprise/server/remote_execution/runner/runner.go index 89851339e18..c50d55a7151 100644 --- a/enterprise/server/remote_execution/runner/runner.go +++ b/enterprise/server/remote_execution/runner/runner.go @@ -39,7 +39,6 @@ import ( "github.com/buildbuddy-io/buildbuddy/server/resources" "github.com/buildbuddy-io/buildbuddy/server/util/alert" "github.com/buildbuddy-io/buildbuddy/server/util/background" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/lockingbuffer" "github.com/buildbuddy-io/buildbuddy/server/util/log" "github.com/buildbuddy-io/buildbuddy/server/util/perms" @@ -57,6 +56,7 @@ import ( uidpb "github.com/buildbuddy-io/buildbuddy/proto/user_id" vfspb "github.com/buildbuddy-io/buildbuddy/proto/vfs" wkpb "github.com/buildbuddy-io/buildbuddy/proto/worker" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" dockerclient "github.com/docker/docker/client" ) @@ -67,8 +67,8 @@ var ( dockerNetHost = flag.Bool("executor.docker_net_host", false, "Sets --net=host on the docker command. Intended for local development only.") dockerCapAdd = flag.String("docker_cap_add", "", "Sets --cap-add= on the docker command. Comma separated.") dockerSiblingContainers = flag.Bool("executor.docker_sibling_containers", false, "If set, mount the configured Docker socket to containers spawned for each action, to enable Docker-out-of-Docker (DooD). Takes effect only if docker_socket is also set. Should not be set by executors that can run untrusted code.") - dockerDevices = flagutil.Slice("executor.docker_devices", []container.DockerDeviceMapping{}, `Configure (docker) devices that will be available inside the sandbox container. Format is --executor.docker_devices='[{"PathOnHost":"/dev/foo","PathInContainer":"/some/dest","CgroupPermissions":"see,docker,docs"}]'`) - dockerVolumes = flagutil.Slice("executor.docker_volumes", []string{}, "Additional --volume arguments to be passed to docker or podman.") + dockerDevices = flagtypes.Slice("executor.docker_devices", []container.DockerDeviceMapping{}, `Configure (docker) devices that will be available inside the sandbox container. Format is --executor.docker_devices='[{"PathOnHost":"/dev/foo","PathInContainer":"/some/dest","CgroupPermissions":"see,docker,docs"}]'`) + dockerVolumes = flagtypes.Slice("executor.docker_volumes", []string{}, "Additional --volume arguments to be passed to docker or podman.") dockerInheritUserIDs = flag.Bool("executor.docker_inherit_user_ids", false, "If set, run docker containers using the same uid and gid as the user running the executor process.") podmanRuntime = flag.String("podman_runtime", "", "Enables running podman with other runtimes, like gVisor (runsc).") warmupTimeoutSecs = flag.Int64("executor.warmup_timeout_secs", 120, "The default time (in seconds) to wait for an executor to warm up i.e. download the default docker image. Default is 120s") diff --git a/server/backends/disk_cache/BUILD b/server/backends/disk_cache/BUILD index d6f5f4b80dc..1af495978e1 100644 --- a/server/backends/disk_cache/BUILD +++ b/server/backends/disk_cache/BUILD @@ -14,7 +14,7 @@ go_library( "//server/remote_cache/digest", "//server/util/alert", "//server/util/disk", - "//server/util/flagutil", + "//server/util/flagutil/types", "//server/util/log", "//server/util/lru", "//server/util/prefix", diff --git a/server/backends/disk_cache/disk_cache.go b/server/backends/disk_cache/disk_cache.go index 13fe6af61ac..33b345ce130 100644 --- a/server/backends/disk_cache/disk_cache.go +++ b/server/backends/disk_cache/disk_cache.go @@ -22,7 +22,6 @@ import ( "github.com/buildbuddy-io/buildbuddy/server/remote_cache/digest" "github.com/buildbuddy-io/buildbuddy/server/util/alert" "github.com/buildbuddy-io/buildbuddy/server/util/disk" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/log" "github.com/buildbuddy-io/buildbuddy/server/util/lru" "github.com/buildbuddy-io/buildbuddy/server/util/prefix" @@ -33,6 +32,7 @@ import ( repb "github.com/buildbuddy-io/buildbuddy/proto/remote_execution" cache_config "github.com/buildbuddy-io/buildbuddy/server/cache/config" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" ) const ( @@ -52,8 +52,8 @@ const ( var ( rootDirectory = flag.String("cache.disk.root_directory", "/tmp/buildbuddy_cache", "The root directory to store all blobs in, if using disk based storage.") - partitions = flagutil.Slice("cache.disk.partitions", []disk.Partition{}, "") - partitionMappings = flagutil.Slice("cache.disk.partition_mappings", []disk.PartitionMapping{}, "") + partitions = flagtypes.Slice("cache.disk.partitions", []disk.Partition{}, "") + partitionMappings = flagtypes.Slice("cache.disk.partition_mappings", []disk.PartitionMapping{}, "") useV2Layout = flag.Bool("cache.disk.use_v2_layout", false, "If enabled, files will be stored using the v2 layout. See disk_cache.MigrateToV2Layout for a description.") migrateDiskCacheToV2AndExit = flag.Bool("migrate_disk_cache_to_v2_and_exit", false, "If true, attempt to migrate disk cache to v2 layout.") diff --git a/server/build_event_protocol/build_event_proxy/BUILD b/server/build_event_protocol/build_event_proxy/BUILD index 216d2572c0b..1a0d6a71ec9 100644 --- a/server/build_event_protocol/build_event_proxy/BUILD +++ b/server/build_event_protocol/build_event_proxy/BUILD @@ -8,7 +8,7 @@ go_library( deps = [ "//proto:publish_build_event_go_proto", "//server/environment", - "//server/util/flagutil", + "//server/util/flagutil/types", "//server/util/grpc_client", "//server/util/log", "@org_golang_google_grpc//:go_default_library", diff --git a/server/build_event_protocol/build_event_proxy/build_event_proxy.go b/server/build_event_protocol/build_event_proxy/build_event_proxy.go index dfa3e7f40a7..d85c5d0af9c 100644 --- a/server/build_event_protocol/build_event_proxy/build_event_proxy.go +++ b/server/build_event_protocol/build_event_proxy/build_event_proxy.go @@ -7,17 +7,17 @@ import ( "sync" "github.com/buildbuddy-io/buildbuddy/server/environment" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/grpc_client" "github.com/buildbuddy-io/buildbuddy/server/util/log" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/emptypb" pepb "github.com/buildbuddy-io/buildbuddy/proto/publish_build_event" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" ) var ( - hosts = flagutil.Slice("build_event_proxy.hosts", []string{}, "The list of hosts to pass build events onto.") + hosts = flagtypes.Slice("build_event_proxy.hosts", []string{}, "The list of hosts to pass build events onto.") bufferSize = flag.Int("build_event_proxy.buffer_size", 100, "The number of build events to buffer locally when proxying build events.") ) diff --git a/server/cmd/buildbuddy/BUILD b/server/cmd/buildbuddy/BUILD index 1bd512ae91b..92b6a1aadaf 100644 --- a/server/cmd/buildbuddy/BUILD +++ b/server/cmd/buildbuddy/BUILD @@ -53,7 +53,7 @@ go_library( "//server/janitor", "//server/libmain", "//server/telemetry", - "//server/util/flagutil", + "//server/util/flagutil/yaml", "//server/util/healthcheck", "//server/util/log", "//server/version", diff --git a/server/cmd/buildbuddy/main.go b/server/cmd/buildbuddy/main.go index e94a8021ea9..33685aa400a 100644 --- a/server/cmd/buildbuddy/main.go +++ b/server/cmd/buildbuddy/main.go @@ -7,10 +7,11 @@ import ( "github.com/buildbuddy-io/buildbuddy/server/janitor" "github.com/buildbuddy-io/buildbuddy/server/libmain" "github.com/buildbuddy-io/buildbuddy/server/telemetry" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/healthcheck" "github.com/buildbuddy-io/buildbuddy/server/util/log" "github.com/buildbuddy-io/buildbuddy/server/version" + + flag_yaml "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/yaml" ) var ( @@ -26,7 +27,7 @@ func main() { version.Print() flag.Parse() - if err := flagutil.PopulateFlagsFromFile(config.Path()); err != nil { + if err := flag_yaml.PopulateFlagsFromFile(config.Path()); err != nil { log.Fatalf("Error loading config from file: %s", err) } healthChecker := healthcheck.NewHealthChecker(*serverType) diff --git a/server/cmd/buildbuddy/yaml_doc/BUILD b/server/cmd/buildbuddy/yaml_doc/BUILD index bd61a37cb00..b7a7adb5eaf 100644 --- a/server/cmd/buildbuddy/yaml_doc/BUILD +++ b/server/cmd/buildbuddy/yaml_doc/BUILD @@ -13,7 +13,7 @@ go_library( visibility = ["//visibility:private"], deps = [ "//server/cmd/buildbuddy:buildbuddy_lib", - "//server/util/flagutil", + "//server/util/flagutil/yaml", ], ) diff --git a/server/cmd/buildbuddy/yaml_doc/main.go b/server/cmd/buildbuddy/yaml_doc/main.go index 17c56ee0605..23dfb449ffa 100644 --- a/server/cmd/buildbuddy/yaml_doc/main.go +++ b/server/cmd/buildbuddy/yaml_doc/main.go @@ -6,7 +6,7 @@ import ( "os" _ "github.com/buildbuddy-io/buildbuddy/server/cmd/buildbuddy" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" + flagyaml "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/yaml" ) const flagName = "yaml_documented_defaults_out_file" @@ -14,13 +14,13 @@ const flagName = "yaml_documented_defaults_out_file" var yamlDefaultsOutFile = flag.String(flagName, "buildbuddy_server_documented_defaults.yaml", "Path to a file to write the default YAML config (with docs) to.") func init() { - flagutil.IgnoreFlagForYAML(flagName) + flagyaml.IgnoreFlagForYAML(flagName) } func main() { flag.Parse() - b, err := flagutil.SplitDocumentedYAMLFromFlags() + b, err := flagyaml.SplitDocumentedYAMLFromFlags() if err != nil { log.Fatalf("Encountered error generating documented default YAML file: %s", err) } diff --git a/server/config/BUILD b/server/config/BUILD index 707ece22b34..9f4bb8a8bcb 100644 --- a/server/config/BUILD +++ b/server/config/BUILD @@ -5,5 +5,5 @@ go_library( srcs = ["config.go"], importpath = "github.com/buildbuddy-io/buildbuddy/server/config", visibility = ["//visibility:public"], - deps = ["//server/util/flagutil"], + deps = ["//server/util/flagutil/yaml"], ) diff --git a/server/config/config.go b/server/config/config.go index 08f2bc96190..a8ac5d26d72 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -3,7 +3,7 @@ package config import ( "flag" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" + flagyaml "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/yaml" ) const pathFlagName = "config_file" @@ -13,7 +13,7 @@ var configPath = flag.String(pathFlagName, "/config.yaml", "The path to a buildb func init() { // As this flag determines the YAML file we read the config from, it can't // meaningfully be specified in the YAML config file. - flagutil.IgnoreFlagForYAML(pathFlagName) + flagyaml.IgnoreFlagForYAML(pathFlagName) } func Path() string { diff --git a/server/endpoint_urls/build_buddy_url/BUILD b/server/endpoint_urls/build_buddy_url/BUILD index f316ccc3ca3..9136f18ef75 100644 --- a/server/endpoint_urls/build_buddy_url/BUILD +++ b/server/endpoint_urls/build_buddy_url/BUILD @@ -5,5 +5,5 @@ go_library( srcs = ["build_buddy_url.go"], importpath = "github.com/buildbuddy-io/buildbuddy/server/endpoint_urls/build_buddy_url", visibility = ["//visibility:public"], - deps = ["//server/util/flagutil"], + deps = ["//server/util/flagutil/types"], ) diff --git a/server/endpoint_urls/build_buddy_url/build_buddy_url.go b/server/endpoint_urls/build_buddy_url/build_buddy_url.go index a245272a080..32c24defba6 100644 --- a/server/endpoint_urls/build_buddy_url/build_buddy_url.go +++ b/server/endpoint_urls/build_buddy_url/build_buddy_url.go @@ -3,10 +3,10 @@ package build_buddy_url import ( "net/url" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" ) -var buildBuddyURL = flagutil.URLFromString("app.build_buddy_url", "http://localhost:8080", "The external URL where your BuildBuddy instance can be found.") +var buildBuddyURL = flagtypes.URLFromString("app.build_buddy_url", "http://localhost:8080", "The external URL where your BuildBuddy instance can be found.") func WithPath(path string) *url.URL { return buildBuddyURL.ResolveReference(&url.URL{Path: path}) diff --git a/server/endpoint_urls/cache_api_url/BUILD b/server/endpoint_urls/cache_api_url/BUILD index 79b7ad5317b..00f995bf855 100644 --- a/server/endpoint_urls/cache_api_url/BUILD +++ b/server/endpoint_urls/cache_api_url/BUILD @@ -5,5 +5,5 @@ go_library( srcs = ["cache_api_url.go"], importpath = "github.com/buildbuddy-io/buildbuddy/server/endpoint_urls/cache_api_url", visibility = ["//visibility:public"], - deps = ["//server/util/flagutil"], + deps = ["//server/util/flagutil/types"], ) diff --git a/server/endpoint_urls/cache_api_url/cache_api_url.go b/server/endpoint_urls/cache_api_url/cache_api_url.go index a862ea6cf16..a233b9a1896 100644 --- a/server/endpoint_urls/cache_api_url/cache_api_url.go +++ b/server/endpoint_urls/cache_api_url/cache_api_url.go @@ -3,10 +3,10 @@ package cache_api_url import ( "net/url" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" ) -var cacheAPIURL = flagutil.URLFromString("app.cache_api_url", "", "Overrides the default remote cache protocol gRPC address shown by BuildBuddy on the configuration screen.") +var cacheAPIURL = flagtypes.URLFromString("app.cache_api_url", "", "Overrides the default remote cache protocol gRPC address shown by BuildBuddy on the configuration screen.") func WithPath(path string) *url.URL { return cacheAPIURL.ResolveReference(&url.URL{Path: path}) diff --git a/server/endpoint_urls/events_api_url/BUILD b/server/endpoint_urls/events_api_url/BUILD index 34fdaa69829..00ba705c2fb 100644 --- a/server/endpoint_urls/events_api_url/BUILD +++ b/server/endpoint_urls/events_api_url/BUILD @@ -5,5 +5,5 @@ go_library( srcs = ["events_api_url.go"], importpath = "github.com/buildbuddy-io/buildbuddy/server/endpoint_urls/events_api_url", visibility = ["//visibility:public"], - deps = ["//server/util/flagutil"], + deps = ["//server/util/flagutil/types"], ) diff --git a/server/endpoint_urls/events_api_url/events_api_url.go b/server/endpoint_urls/events_api_url/events_api_url.go index 660b96cc0da..a95bc0c0270 100644 --- a/server/endpoint_urls/events_api_url/events_api_url.go +++ b/server/endpoint_urls/events_api_url/events_api_url.go @@ -3,10 +3,10 @@ package events_api_url import ( "net/url" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" ) -var eventsAPIURL = flagutil.URLFromString("app.events_api_url", "", "Overrides the default build event protocol gRPC address shown by BuildBuddy on the configuration screen.") +var eventsAPIURL = flagtypes.URLFromString("app.events_api_url", "", "Overrides the default build event protocol gRPC address shown by BuildBuddy on the configuration screen.") func WithPath(path string) *url.URL { return eventsAPIURL.ResolveReference(&url.URL{Path: path}) diff --git a/server/endpoint_urls/remote_exec_api_url/BUILD b/server/endpoint_urls/remote_exec_api_url/BUILD index 8d642e88e74..b9ae226b3b8 100644 --- a/server/endpoint_urls/remote_exec_api_url/BUILD +++ b/server/endpoint_urls/remote_exec_api_url/BUILD @@ -5,5 +5,5 @@ go_library( srcs = ["remote_exec_api_url.go"], importpath = "github.com/buildbuddy-io/buildbuddy/server/endpoint_urls/remote_exec_api_url", visibility = ["//visibility:public"], - deps = ["//server/util/flagutil"], + deps = ["//server/util/flagutil/types"], ) diff --git a/server/endpoint_urls/remote_exec_api_url/remote_exec_api_url.go b/server/endpoint_urls/remote_exec_api_url/remote_exec_api_url.go index 50128f367fe..bb69ceb5078 100644 --- a/server/endpoint_urls/remote_exec_api_url/remote_exec_api_url.go +++ b/server/endpoint_urls/remote_exec_api_url/remote_exec_api_url.go @@ -3,10 +3,10 @@ package remote_exec_api_url import ( "net/url" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" ) -var remoteExecAPIURL = flagutil.URLFromString("app.remote_execution_api_url", "", "Overrides the default remote execution protocol gRPC address shown by BuildBuddy on the configuration screen.") +var remoteExecAPIURL = flagtypes.URLFromString("app.remote_execution_api_url", "", "Overrides the default remote execution protocol gRPC address shown by BuildBuddy on the configuration screen.") func WithPath(path string) *url.URL { return remoteExecAPIURL.ResolveReference(&url.URL{Path: path}) diff --git a/server/ssl/BUILD b/server/ssl/BUILD index 6665128feae..c748d5a4cb3 100644 --- a/server/ssl/BUILD +++ b/server/ssl/BUILD @@ -11,7 +11,7 @@ go_library( "//server/endpoint_urls/events_api_url", "//server/environment", "//server/interfaces", - "//server/util/flagutil", + "//server/util/flagutil/types", "//server/util/status", "@org_golang_google_grpc//credentials", "@org_golang_x_crypto//acme", diff --git a/server/ssl/ssl.go b/server/ssl/ssl.go index 81f949547fd..99953f14802 100644 --- a/server/ssl/ssl.go +++ b/server/ssl/ssl.go @@ -20,11 +20,12 @@ import ( "github.com/buildbuddy-io/buildbuddy/server/endpoint_urls/events_api_url" "github.com/buildbuddy-io/buildbuddy/server/environment" "github.com/buildbuddy-io/buildbuddy/server/interfaces" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/status" "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" "google.golang.org/grpc/credentials" + + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" ) var ( @@ -33,7 +34,7 @@ var ( selfSigned = flag.Bool("ssl.self_signed", false, "If true, a self-signed cert will be generated for TLS termination.") clientCACertFile = flag.String("ssl.client_ca_cert_file", "", "Path to a PEM encoded certificate authority file used to issue client certificates for mTLS auth.") clientCAKeyFile = flag.String("ssl.client_ca_key_file", "", "Path to a PEM encoded certificate authority key file used to issue client certificates for mTLS auth.") - hostWhitelist = flagutil.Slice("ssl.host_whitelist", []string{}, "Cloud-Only") + hostWhitelist = flagtypes.Slice("ssl.host_whitelist", []string{}, "Cloud-Only") enableSSL = flag.Bool("ssl.enable_ssl", false, "Whether or not to enable SSL/TLS on gRPC connections (gRPCS).") useACME = flag.Bool("ssl.use_acme", false, "Whether or not to automatically configure SSL certs using ACME. If ACME is enabled, cert_file and key_file should not be set.") defaultHost = flag.String("ssl.default_host", "", "Host name to use for ACME generated cert if TLS request does not contain SNI.") diff --git a/server/util/flagutil/BUILD b/server/util/flagutil/BUILD index 7614c66a039..824d80c3f7e 100644 --- a/server/util/flagutil/BUILD +++ b/server/util/flagutil/BUILD @@ -1,27 +1,9 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("@io_bazel_rules_go//go:def.bzl", "go_library") go_library( name = "flagutil", srcs = ["flagutil.go"], importpath = "github.com/buildbuddy-io/buildbuddy/server/util/flagutil", visibility = ["//visibility:public"], - deps = [ - "//server/util/alert", - "//server/util/log", - "//server/util/status", - "@in_gopkg_yaml_v3//:yaml_v3", - ], -) - -go_test( - name = "flagutil_test", - srcs = ["flagutil_test.go"], - embed = [":flagutil"], - deps = [ - "@com_github_google_go_cmp//cmp", - "@com_github_stretchr_testify//assert", - "@com_github_stretchr_testify//require", - "@in_gopkg_yaml_v3//:yaml_v3", - "@org_golang_google_protobuf//types/known/timestamppb", - ], + deps = ["//server/util/flagutil/common"], ) diff --git a/server/util/flagutil/common/BUILD b/server/util/flagutil/common/BUILD new file mode 100644 index 00000000000..5d5628b17e9 --- /dev/null +++ b/server/util/flagutil/common/BUILD @@ -0,0 +1,26 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "common", + srcs = ["common.go"], + importpath = "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/common", + visibility = [ + "//server/util/flagutil:__subpackages__", + "//server/util/testing/flags:__subpackages__", + ], + deps = [ + "//server/util/status", + "@in_gopkg_yaml_v3//:yaml_v3", + ], +) + +go_test( + name = "common_test", + srcs = ["common_test.go"], + deps = [ + ":common", + "//server/util/flagutil/types", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + ], +) diff --git a/server/util/flagutil/common/common.go b/server/util/flagutil/common/common.go new file mode 100644 index 00000000000..d999e7c9b4f --- /dev/null +++ b/server/util/flagutil/common/common.go @@ -0,0 +1,140 @@ +package common + +import ( + "flag" + "reflect" + "time" + + "github.com/buildbuddy-io/buildbuddy/server/util/status" + "gopkg.in/yaml.v3" +) + +var ( + // Used for type conversions between flags and normal go types + flagTypeMap = map[reflect.Type]reflect.Type{ + flagTypeFromFlagFuncName("Bool"): reflect.TypeOf((*bool)(nil)), + flagTypeFromFlagFuncName("Duration"): reflect.TypeOf((*time.Duration)(nil)), + flagTypeFromFlagFuncName("Float64"): reflect.TypeOf((*float64)(nil)), + flagTypeFromFlagFuncName("Int"): reflect.TypeOf((*int)(nil)), + flagTypeFromFlagFuncName("Int64"): reflect.TypeOf((*int64)(nil)), + flagTypeFromFlagFuncName("Uint"): reflect.TypeOf((*uint)(nil)), + flagTypeFromFlagFuncName("Uint64"): reflect.TypeOf((*uint64)(nil)), + flagTypeFromFlagFuncName("String"): reflect.TypeOf((*string)(nil)), + } + + // Change only for testing purposes + DefaultFlagSet = flag.CommandLine +) + +func flagTypeFromFlagFuncName(name string) reflect.Type { + fs := flag.NewFlagSet("", flag.ContinueOnError) + ff := reflect.ValueOf(fs).MethodByName(name) + in := make([]reflect.Value, ff.Type().NumIn()) + for i := range in { + in[i] = reflect.New(ff.Type().In(i)).Elem() + } + ff.Call(in) + return reflect.TypeOf(fs.Lookup("").Value) +} + +type TypeAliased interface { + AliasedType() reflect.Type +} + +type IsNameAliasing interface { + AliasedName() string +} + +type Appendable interface { + AppendSlice(any) error +} + +type DocumentNodeOption interface { + Transform(in any, n *yaml.Node) + Passthrough() bool +} + +// GetTypeForFlag returns the (pointer) Type this flag aliases; this is the same +// type returned when defining the flag initially. +func GetTypeForFlag(flg *flag.Flag) (reflect.Type, error) { + if t, ok := flagTypeMap[reflect.TypeOf(flg.Value)]; ok { + return t, nil + } else if v, ok := flg.Value.(TypeAliased); ok { + return v.AliasedType(), nil + } + return nil, status.UnimplementedErrorf("Unsupported flag type at %s: %T", flg.Name, flg.Value) +} + +// SetValueForFlagName sets the value for a flag by name. +func SetValueForFlagName(name string, i any, setFlags map[string]struct{}, appendSlice bool, strict bool) error { + flg := DefaultFlagSet.Lookup(name) + if flg == nil { + if strict { + return status.NotFoundErrorf("Undefined flag: %s", name) + } + return nil + } + // For slice flags, append the YAML values to the existing values if appendSlice is true + if v, ok := flg.Value.(Appendable); ok && appendSlice { + if err := v.AppendSlice(i); err != nil { + return status.InternalErrorf("Error encountered appending to flag %s: %s", flg.Name, err) + } + return nil + } + if v, ok := flg.Value.(IsNameAliasing); ok { + return SetValueForFlagName(v.AliasedName(), i, setFlags, appendSlice, strict) + } + // For non-append flags, skip the YAML values if it was set on the command line + if _, ok := setFlags[name]; ok { + return nil + } + t, err := GetTypeForFlag(flg) + if err != nil { + return status.UnimplementedErrorf("Error encountered setting flag: %s", err) + } + if !reflect.ValueOf(i).CanConvert(t.Elem()) { + return status.FailedPreconditionErrorf("Cannot convert value %v of type %T into type %v for flag %s.", i, i, t.Elem(), flg.Name) + } + reflect.ValueOf(flg.Value).Convert(t).Elem().Set(reflect.ValueOf(i).Convert(t.Elem())) + return nil +} + +// GetDereferencedValue retypes and returns the dereferenced Value for +// a given flag name. +func GetDereferencedValue[T any](name string) (T, error) { + flg := DefaultFlagSet.Lookup(name) + zeroT := reflect.New(reflect.TypeOf((*T)(nil)).Elem()).Interface().(*T) + if flg == nil { + return *zeroT, status.NotFoundErrorf("Undefined flag: %s", name) + } + if v, ok := flg.Value.(IsNameAliasing); ok { + return GetDereferencedValue[T](v.AliasedName()) + } + t := reflect.TypeOf((*T)(nil)) + addr := reflect.ValueOf(flg.Value) + if t == reflect.TypeOf((*any)(nil)) { + var err error + t, err = GetTypeForFlag(flg) + if err != nil { + return *zeroT, status.InternalErrorf("Error dereferencing flag to unspecified type: %s.", err) + } + if !addr.CanConvert(t) { + return *zeroT, status.InvalidArgumentErrorf("Flag %s of type %T could not be converted to %s.", name, flg.Value, t) + } + return addr.Convert(t).Elem().Interface().(T), nil + } + if !addr.CanConvert(t) { + return *zeroT, status.InvalidArgumentErrorf("Flag %s of type %T could not be converted to %s.", name, flg.Value, t) + } + v, ok := addr.Convert(t).Interface().(*T) + if !ok { + return *zeroT, status.InternalErrorf("Failed to assert flag %s of type %T as type %s.", name, flg.Value, t) + } + return *v, nil +} + +// AddTestFlagTypeForTesting adds a type correspondence to the internal +// flagTypeMap. +func AddTestFlagTypeForTesting(flagValue, value any) { + flagTypeMap[reflect.TypeOf(flagValue)] = reflect.TypeOf(value) +} diff --git a/server/util/flagutil/common/common_test.go b/server/util/flagutil/common/common_test.go new file mode 100644 index 00000000000..337c0ec92d7 --- /dev/null +++ b/server/util/flagutil/common/common_test.go @@ -0,0 +1,260 @@ +package common_test + +import ( + "flag" + "net/url" + "testing" + + "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" +) + +type unsupportedFlagValue struct{} + +func (f *unsupportedFlagValue) Set(string) error { return nil } +func (f *unsupportedFlagValue) String() string { return "" } + +type testStruct struct { + Field int `json:"field"` + Meadow string `json:"meadow"` +} + +func replaceFlagsForTesting(t *testing.T) *flag.FlagSet { + flags := flag.NewFlagSet("test", flag.ContinueOnError) + common.DefaultFlagSet = flags + + t.Cleanup(func() { + common.DefaultFlagSet = flag.CommandLine + }) + + return flags +} + +func TestSetValueForFlagName(t *testing.T) { + flags := replaceFlagsForTesting(t) + flagBool := flags.Bool("bool", false, "") + err := common.SetValueForFlagName("bool", true, map[string]struct{}{}, true, true) + require.NoError(t, err) + assert.Equal(t, true, *flagBool) + + flags = replaceFlagsForTesting(t) + flagBool = flags.Bool("bool", false, "") + err = common.SetValueForFlagName("bool", true, map[string]struct{}{"bool": {}}, true, true) + require.NoError(t, err) + assert.Equal(t, false, *flagBool) + + flags = replaceFlagsForTesting(t) + err = common.SetValueForFlagName("bool", true, map[string]struct{}{}, true, false) + require.NoError(t, err) + + flags = replaceFlagsForTesting(t) + flagInt := flags.Int("int", 2, "") + err = common.SetValueForFlagName("int", 1, map[string]struct{}{}, true, true) + require.NoError(t, err) + assert.Equal(t, 1, *flagInt) + + flags = replaceFlagsForTesting(t) + flagInt = flags.Int("int", 2, "") + err = common.SetValueForFlagName("int", 1, map[string]struct{}{"int": {}}, true, true) + require.NoError(t, err) + assert.Equal(t, 2, *flagInt) + + flags = replaceFlagsForTesting(t) + flagInt64 := flags.Int64("int64", 2, "") + err = common.SetValueForFlagName("int64", 1, map[string]struct{}{}, true, true) + require.NoError(t, err) + assert.Equal(t, int64(1), *flagInt64) + + flags = replaceFlagsForTesting(t) + flagInt64 = flags.Int64("int64", 2, "") + err = common.SetValueForFlagName("int64", 1, map[string]struct{}{"int64": {}}, true, true) + require.NoError(t, err) + assert.Equal(t, int64(2), *flagInt64) + + flags = replaceFlagsForTesting(t) + flagUint := flags.Uint("uint", 2, "") + err = common.SetValueForFlagName("uint", 1, map[string]struct{}{}, true, true) + require.NoError(t, err) + assert.Equal(t, uint(1), *flagUint) + + flags = replaceFlagsForTesting(t) + flagUint = flags.Uint("uint", 2, "") + err = common.SetValueForFlagName("uint", 1, map[string]struct{}{"uint": {}}, true, true) + require.NoError(t, err) + assert.Equal(t, uint(2), *flagUint) + + flags = replaceFlagsForTesting(t) + flagUint64 := flags.Uint64("uint64", 2, "") + err = common.SetValueForFlagName("uint64", 1, map[string]struct{}{}, true, true) + require.NoError(t, err) + assert.Equal(t, uint64(1), *flagUint64) + + flags = replaceFlagsForTesting(t) + flagUint64 = flags.Uint64("uint64", 2, "") + err = common.SetValueForFlagName("uint64", 1, map[string]struct{}{"uint64": {}}, true, true) + require.NoError(t, err) + assert.Equal(t, uint64(2), *flagUint64) + + flags = replaceFlagsForTesting(t) + flagFloat64 := flags.Float64("float64", 2, "") + err = common.SetValueForFlagName("float64", 1, map[string]struct{}{}, true, true) + require.NoError(t, err) + assert.Equal(t, float64(1), *flagFloat64) + + flags = replaceFlagsForTesting(t) + flagFloat64 = flags.Float64("float64", 2, "") + err = common.SetValueForFlagName("float64", 1, map[string]struct{}{"float64": {}}, true, true) + require.NoError(t, err) + assert.Equal(t, float64(2), *flagFloat64) + + flags = replaceFlagsForTesting(t) + flagString := flags.String("string", "2", "") + err = common.SetValueForFlagName("string", "1", map[string]struct{}{}, true, true) + require.NoError(t, err) + assert.Equal(t, "1", *flagString) + + flags = replaceFlagsForTesting(t) + flagString = flags.String("string", "2", "") + err = common.SetValueForFlagName("string", "1", map[string]struct{}{"string": {}}, true, true) + require.NoError(t, err) + assert.Equal(t, "2", *flagString) + + flags = replaceFlagsForTesting(t) + flagURL := flagtypes.URLFromString("url", "https://www.example.com", "") + u, err := url.Parse("https://www.example.com:8080") + require.NoError(t, err) + err = common.SetValueForFlagName("url", *u, map[string]struct{}{}, true, true) + require.NoError(t, err) + assert.Equal(t, url.URL{Scheme: "https", Host: "www.example.com:8080"}, *flagURL) + + flags = replaceFlagsForTesting(t) + flagURL = flagtypes.URLFromString("url", "https://www.example.com", "") + u, err = url.Parse("https://www.example.com:8080") + require.NoError(t, err) + err = common.SetValueForFlagName("url", *u, map[string]struct{}{"url": {}}, true, true) + require.NoError(t, err) + assert.Equal(t, url.URL{Scheme: "https", Host: "www.example.com"}, *flagURL) + + flags = replaceFlagsForTesting(t) + string_slice := make([]string, 2) + string_slice[0] = "1" + string_slice[1] = "2" + flagtypes.SliceVar(&string_slice, "string_slice", "") + err = common.SetValueForFlagName("string_slice", []string{"3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, map[string]struct{}{}, true, true) + require.NoError(t, err) + assert.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, string_slice) + + flags = replaceFlagsForTesting(t) + flagStringSlice := flagtypes.Slice("string_slice", []string{"1", "2"}, "") + err = common.SetValueForFlagName("string_slice", []string{"3"}, map[string]struct{}{"string_slice": {}}, true, true) + require.NoError(t, err) + assert.Equal(t, []string{"1", "2", "3"}, *flagStringSlice) + + flags = replaceFlagsForTesting(t) + flagStringSlice = flagtypes.Slice("string_slice", []string{"1", "2"}, "") + err = common.SetValueForFlagName("string_slice", []string{"3"}, map[string]struct{}{}, false, true) + require.NoError(t, err) + assert.Equal(t, []string{"3"}, *flagStringSlice) + + flags = replaceFlagsForTesting(t) + flagStringSlice = flagtypes.Slice("string_slice", []string{"1", "2"}, "") + err = common.SetValueForFlagName("string_slice", []string{"3"}, map[string]struct{}{"string_slice": {}}, false, true) + require.NoError(t, err) + assert.Equal(t, []string{"1", "2"}, *flagStringSlice) + + flags = replaceFlagsForTesting(t) + flagStructSlice := []testStruct{{Field: 1}, {Field: 2}} + flagtypes.SliceVar(&flagStructSlice, "struct_slice", "") + err = common.SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{}, true, true) + require.NoError(t, err) + assert.Equal(t, []testStruct{{Field: 1}, {Field: 2}, {Field: 3}}, flagStructSlice) + + flags = replaceFlagsForTesting(t) + flagStructSlice = []testStruct{{Field: 1}, {Field: 2}} + flagtypes.SliceVar(&flagStructSlice, "struct_slice", "") + err = common.SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{"struct_slice": {}}, true, true) + require.NoError(t, err) + assert.Equal(t, []testStruct{{Field: 1}, {Field: 2}, {Field: 3}}, flagStructSlice) + + flags = replaceFlagsForTesting(t) + flagStructSlice = []testStruct{{Field: 1}, {Field: 2}} + flagtypes.SliceVar(&flagStructSlice, "struct_slice", "") + err = common.SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{}, false, true) + require.NoError(t, err) + assert.Equal(t, []testStruct{{Field: 3}}, flagStructSlice) + + flags = replaceFlagsForTesting(t) + flagStructSlice = []testStruct{{Field: 1}, {Field: 2}} + flagtypes.SliceVar(&flagStructSlice, "struct_slice", "") + err = common.SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{"struct_slice": struct{}{}}, false, true) + require.NoError(t, err) + assert.Equal(t, []testStruct{{Field: 1}, {Field: 2}}, flagStructSlice) +} + +func TestBadSetValueForFlagName(t *testing.T) { + flags := replaceFlagsForTesting(t) + _ = flags.Bool("bool", false, "") + err := common.SetValueForFlagName("bool", 0, map[string]struct{}{}, true, true) + require.Error(t, err) + + flags = replaceFlagsForTesting(t) + err = common.SetValueForFlagName("bool", false, map[string]struct{}{}, true, true) + require.Error(t, err) + + flags = replaceFlagsForTesting(t) + _ = flagtypes.Slice("string_slice", []string{"1", "2"}, "") + err = common.SetValueForFlagName("string_slice", "3", map[string]struct{}{}, true, true) + require.Error(t, err) + + flags = replaceFlagsForTesting(t) + _ = flagtypes.Slice("string_slice", []string{"1", "2"}, "") + err = common.SetValueForFlagName("string_slice", "3", map[string]struct{}{}, false, true) + require.Error(t, err) + + flags = replaceFlagsForTesting(t) + flagStructSlice := []testStruct{{Field: 1}, {Field: 2}} + flagtypes.SliceVar(&flagStructSlice, "struct_slice", "") + err = common.SetValueForFlagName("struct_slice", testStruct{Field: 3}, map[string]struct{}{}, true, true) + require.Error(t, err) + + flags = replaceFlagsForTesting(t) + flagStructSlice = []testStruct{{Field: 1}, {Field: 2}} + flagtypes.SliceVar(&flagStructSlice, "struct_slice", "") + err = common.SetValueForFlagName("struct_slice", testStruct{Field: 3}, map[string]struct{}{}, false, true) + require.Error(t, err) +} + +func TestDereferencedValueFromFlagName(t *testing.T) { + flags := replaceFlagsForTesting(t) + _ = flags.Bool("bool", false, "") + v, err := common.GetDereferencedValue[bool]("bool") + require.NoError(t, err) + assert.Equal(t, false, v) + + flags = replaceFlagsForTesting(t) + _ = flags.Bool("bool", true, "") + v, err = common.GetDereferencedValue[bool]("bool") + require.NoError(t, err) + assert.Equal(t, true, v) + + flags = replaceFlagsForTesting(t) + _ = flagtypes.Slice("string_slice", []string{"1", "2"}, "") + stringSlice, err := common.GetDereferencedValue[[]string]("string_slice") + require.NoError(t, err) + assert.Equal(t, []string{"1", "2"}, stringSlice) + + flags = replaceFlagsForTesting(t) + flagtypes.SliceVar(&[]testStruct{{Field: 1}, {Field: 2}}, "struct_slice", "") + structSlice, err := common.GetDereferencedValue[[]testStruct]("struct_slice") + require.NoError(t, err) + assert.Equal(t, []testStruct{{Field: 1}, {Field: 2}}, structSlice) +} + +func TestBadDereferencedValueFromFlagName(t *testing.T) { + _ = replaceFlagsForTesting(t) + _, err := common.GetDereferencedValue[any]("unknown_flag") + require.Error(t, err) +} diff --git a/server/util/flagutil/flagutil.go b/server/util/flagutil/flagutil.go index 93cab8e0818..c7e39ed02fe 100644 --- a/server/util/flagutil/flagutil.go +++ b/server/util/flagutil/flagutil.go @@ -1,834 +1,14 @@ package flagutil import ( - "bytes" - "encoding/json" - "flag" - "fmt" - "net/url" - "os" - "reflect" - "strings" - "time" - - "github.com/buildbuddy-io/buildbuddy/server/util/alert" - "github.com/buildbuddy-io/buildbuddy/server/util/log" - "github.com/buildbuddy-io/buildbuddy/server/util/status" - "gopkg.in/yaml.v3" + "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/common" ) -var ( - // Change only for testing purposes - defaultFlagSet = flag.CommandLine - - // Used for type conversions between flags and YAML - flagTypeMap = map[reflect.Type]reflect.Type{ - flagTypeFromFlagFuncName("Bool"): reflect.TypeOf((*bool)(nil)), - flagTypeFromFlagFuncName("Duration"): reflect.TypeOf((*time.Duration)(nil)), - flagTypeFromFlagFuncName("Float64"): reflect.TypeOf((*float64)(nil)), - flagTypeFromFlagFuncName("Int"): reflect.TypeOf((*int)(nil)), - flagTypeFromFlagFuncName("Int64"): reflect.TypeOf((*int64)(nil)), - flagTypeFromFlagFuncName("Uint"): reflect.TypeOf((*uint)(nil)), - flagTypeFromFlagFuncName("Uint64"): reflect.TypeOf((*uint64)(nil)), - flagTypeFromFlagFuncName("String"): reflect.TypeOf((*string)(nil)), - } - - // Flag names to ignore when generating a YAML map or populating flags (e. g., - // the flag specifying the path to the config file) - ignoreSet = make(map[string]struct{}) - - nilableKinds = map[reflect.Kind]struct{}{ - reflect.Chan: {}, - reflect.Func: {}, - reflect.Interface: {}, - reflect.Map: {}, - reflect.Ptr: {}, - reflect.Slice: {}, - } -) - -func flagTypeFromFlagFuncName(name string) reflect.Type { - fs := flag.NewFlagSet("", flag.ContinueOnError) - ff := reflect.ValueOf(fs).MethodByName(name) - in := make([]reflect.Value, ff.Type().NumIn()) - for i := range in { - in[i] = reflect.New(ff.Type().In(i)).Elem() - } - ff.Call(in) - return reflect.TypeOf(fs.Lookup("").Value) -} - -func IgnoreFilter(flg *flag.Flag) bool { - keys := strings.Split(flg.Name, ".") - for i := range keys { - if _, ok := ignoreSet[strings.Join(keys[:i+1], ".")]; ok { - return false - } - } - return true -} - -type YAMLTypeAliasable interface { - YAMLTypeAlias() reflect.Type -} - -type TypeAliased interface { - AliasedType() reflect.Type -} - -type isNameAliasing interface { - AliasedName() string -} - -type Appendable interface { - AppendSlice(any) error -} - -type DocumentedMarshaler interface { - DocumentNode(n *yaml.Node, opts ...DocumentNodeOption) error -} - -type SliceFlag[T any] []T - -func NewSliceFlag[T any](slice *[]T) *SliceFlag[T] { - return (*SliceFlag[T])(slice) -} - -func Slice[T any](name string, defaultValue []T, usage string) *[]T { - slice := make([]T, len(defaultValue)) - copy(slice, defaultValue) - defaultFlagSet.Var(NewSliceFlag(&slice), name, usage) - return &slice -} - -func SliceVar[T any](slice *[]T, name, usage string) { - defaultFlagSet.Var(NewSliceFlag(slice), name, usage) -} - -func (f *SliceFlag[T]) String() string { - switch v := any((*[]T)(f)).(type) { - case *[]string: - return strings.Join(*v, ",") - default: - b, err := json.Marshal(f) - if err != nil { - alert.UnexpectedEvent("config_cannot_marshal_struct", "err: %s", err) - return "[]" - } - return string(b) - } -} - -func (f *SliceFlag[T]) Set(values string) error { - if v, ok := any((*[]T)(f)).(*[]string); ok { - for _, val := range strings.Split(values, ",") { - *v = append(*v, val) - } - return nil - } - v := (*[]T)(f) - var a any - if err := json.Unmarshal([]byte(values), &a); err != nil { - return err - } - if _, ok := a.([]any); ok { - var dst []T - if err := json.Unmarshal([]byte(values), &dst); err != nil { - return err - } - *v = append(*v, dst...) - return nil - } - if _, ok := a.(map[string]any); ok { - var dst T - if err := json.Unmarshal([]byte(values), &dst); err != nil { - return err - } - *v = append(*v, dst) - return nil - } - return fmt.Errorf("Default Set for SliceFlag can only accept JSON objects or arrays, but type was %T", a) -} - -func (f *SliceFlag[T]) AppendSlice(slice any) error { - s, ok := slice.([]T) - if !ok { - return status.FailedPreconditionErrorf("Cannot append value %v of type %T to flag of type %T.", slice, slice, ([]T)(nil)) - } - v := (*[]T)(f) - *v = append(*v, s...) - return nil -} - -func (f *SliceFlag[T]) AliasedType() reflect.Type { - return reflect.TypeOf((*[]T)(nil)) -} - -func (f *SliceFlag[T]) YAMLTypeAlias() reflect.Type { - return f.AliasedType() -} - -type URLFlag url.URL - -func URL(name string, value url.URL, usage string) *url.URL { - u := &value - defaultFlagSet.Var((*URLFlag)(u), name, usage) - return u -} - -func URLVar(value *url.URL, name string, usage string) { - defaultFlagSet.Var((*URLFlag)(value), name, usage) -} - -func URLFromString(name, value, usage string) *url.URL { - u, err := url.Parse(value) - if err != nil { - log.Fatalf("Error parsing default URL value '%s' for flag: %v", value, err) - return nil - } - return URL(name, *u, usage) -} - -func (f *URLFlag) Set(value string) error { - u, err := url.Parse(value) - if err != nil { - return err - } - *(*url.URL)(f) = *u - return nil -} - -func (f *URLFlag) String() string { - return (*url.URL)(f).String() -} - -func (f *URLFlag) UnmarshalYAML(value *yaml.Node) error { - u, err := url.Parse(value.Value) - if err != nil { - return &yaml.TypeError{Errors: []string{err.Error()}} - } - *(*url.URL)(f) = *u - return nil -} - -func (f *URLFlag) MarshalYAML() (any, error) { - return f.String(), nil -} - -func (f *URLFlag) AliasedType() reflect.Type { - return reflect.TypeOf((*url.URL)(nil)) -} - -func (f *URLFlag) YAMLTypeAlias() reflect.Type { - return reflect.TypeOf((*URLFlag)(nil)) -} - -func (f *URLFlag) DocumentNode(n *yaml.Node, opts ...DocumentNodeOption) error { - for _, opt := range opts { - if _, ok := opt.(*addTypeToLineComment); ok { - if n.LineComment != "" { - n.LineComment += " " - } - n.LineComment += "type: URL" - continue - } - opt.Transform(f, n) - } - return nil -} - -type FlagAlias struct { - name string -} - -func Alias[T any](newName, name string) *T { - f := &FlagAlias{name: name} - var flg *flag.Flag - for aliaser, ok := isNameAliasing(f), true; ok; aliaser, ok = flg.Value.(isNameAliasing) { - if flg = defaultFlagSet.Lookup(aliaser.AliasedName()); flg == nil { - log.Fatalf("Error aliasing flag %s as %s: flag %s does not exist.", name, newName, aliaser.AliasedName()) - } - } - addr := reflect.ValueOf(flg.Value) - if t, err := getTypeForFlag(flg); err == nil { - if !addr.CanConvert(t) { - log.Fatalf("Error aliasing flag %s as %s: Flag %s of type %T could not be converted to %s.", name, newName, flg.Name, flg.Value, t) - } - addr = addr.Convert(t) - } - value, ok := addr.Interface().(*T) - if !ok { - log.Fatalf("Error aliasing flag %s as %s: Failed to assert flag %s of type %T as type %T.", name, newName, flg.Name, flg.Value, (*T)(nil)) - } - defaultFlagSet.Var(f, newName, "Alias for "+name) - return value -} - -func (f *FlagAlias) Set(value string) error { - return defaultFlagSet.Set(f.name, value) -} - -func (f *FlagAlias) String() string { - return defaultFlagSet.Lookup(f.name).Value.String() -} - -func (f *FlagAlias) AliasedName() string { - return f.name -} - -func (f *FlagAlias) AliasedType() reflect.Type { - flg := defaultFlagSet.Lookup(f.name) - t, err := getTypeForFlag(flg) - if err != nil { - return reflect.TypeOf(flg.Value) - } - return t -} - -func (f *FlagAlias) YAMLTypeAlias() reflect.Type { - flg := defaultFlagSet.Lookup(f.name) - t, err := getYAMLTypeForFlag(flg) - if err != nil { - return reflect.TypeOf(flg.Value) - } - return t -} - -// IgnoreFlagForYAML ignores the flag with this name when generating YAML and when -// populating flags from YAML input. -func IgnoreFlagForYAML(name string) { - ignoreSet[name] = struct{}{} -} - -func getTypeForFlag(flg *flag.Flag) (reflect.Type, error) { - if t, ok := flagTypeMap[reflect.TypeOf(flg.Value)]; ok { - return t, nil - } else if v, ok := flg.Value.(TypeAliased); ok { - return v.AliasedType(), nil - } - return nil, status.UnimplementedErrorf("Unsupported flag type at %s: %T", flg.Name, flg.Value) -} - -func getYAMLTypeForFlag(flg *flag.Flag) (reflect.Type, error) { - if t, ok := flagTypeMap[reflect.TypeOf(flg.Value)]; ok { - return t, nil - } else if v, ok := flg.Value.(YAMLTypeAliasable); ok { - return v.YAMLTypeAlias(), nil - } - return nil, status.UnimplementedErrorf("Unsupported flag type at %s: %T", flg.Name, flg.Value) -} - -type DocumentNodeOption interface { - Transform(in any, n *yaml.Node) - Passthrough() bool -} - -type headComment string - -func (h *headComment) Transform(in any, n *yaml.Node) { n.HeadComment = string(*h) } -func (h *headComment) Passthrough() bool { return false } - -// HeadComment sets the HeadComment of a yaml.Node to the specified string. -func HeadComment(s string) *headComment { return (*headComment)(&s) } - -type lineComment string - -func (l *lineComment) Transform(in any, n *yaml.Node) { n.LineComment = string(*l) } -func (l *lineComment) Passthrough() bool { return false } - -// LineComment sets the LineComment of a yaml.Node to the specified string. -func LineComment(s string) *lineComment { return (*lineComment)(&s) } - -type footComment string - -func (f *footComment) Transform(in any, n *yaml.Node) { n.FootComment = string(*f) } -func (f *footComment) Passthrough() bool { return false } - -// FootComment sets the FootComment of a yaml.Node to the specified string. -func FootComment(s string) *footComment { return (*footComment)(&s) } - -type addTypeToLineComment struct{} - -func (f *addTypeToLineComment) Transform(in any, n *yaml.Node) { - if n.LineComment != "" { - n.LineComment += " " - } - n.LineComment = fmt.Sprintf("%stype: %T", n.LineComment, in) -} - -func (f *addTypeToLineComment) Passthrough() bool { return true } - -// AddTypeToLineComment appends the type specification to the LineComment of the yaml.Node. -func AddTypeToLineComment() *addTypeToLineComment { return (*addTypeToLineComment)(&struct{}{}) } - -func FilterPassthrough(opts []DocumentNodeOption) []DocumentNodeOption { - ptOpts := []DocumentNodeOption{} - for _, opt := range opts { - if opt.Passthrough() { - ptOpts = append(ptOpts, opt) - } - } - return ptOpts -} - -// DocumentedNode returns a yaml.Node representing the input value with -// documentation in the comments. -func DocumentedNode(in any, opts ...DocumentNodeOption) (*yaml.Node, error) { - n := &yaml.Node{} - if err := n.Encode(in); err != nil { - return nil, err - } - if err := DocumentNode(in, n, opts...); err != nil { - return nil, err - } - return n, nil -} - -// DocumentNode fills the comments of a yaml.Node with documentation. -func DocumentNode(in any, n *yaml.Node, opts ...DocumentNodeOption) error { - switch m := in.(type) { - case DocumentedMarshaler: - return m.DocumentNode(n, opts...) - case yaml.Marshaler: - // pass - default: - v := reflect.ValueOf(in) - t := v.Type() - switch t.Kind() { - case reflect.Ptr: - // document based on the value pointed to - if !v.IsNil() { - return DocumentNode(v.Elem().Interface(), n, opts...) - } else { - return DocumentNode(reflect.New(reflect.TypeOf(t).Elem()).Elem().Interface(), n, opts...) - } - case reflect.Struct: - // yaml.Node stores mappings in Content as [key1, value1, key2, value2...] - contentIndex := make(map[string]int, len(n.Content)/2) - for i := 0; i < len(n.Content)/2; i++ { - contentIndex[n.Content[2*i].Value] = 2*i + 1 - } - for i := 0; i < t.NumField(); i++ { - ft := t.FieldByIndex([]int{i}) - name := strings.Split(ft.Tag.Get("yaml"), ",")[0] - if name == "" { - name = strings.ToLower(ft.Name) - } - idx, ok := contentIndex[name] - if !ok { - // field is not encoded by yaml - continue - } - if err := DocumentNode( - v.FieldByIndex([]int{i}).Interface(), - n.Content[idx], - append( - []DocumentNodeOption{LineComment(ft.Tag.Get("usage"))}, - FilterPassthrough(opts)..., - )..., - ); err != nil { - return err - } - } - case reflect.Slice: - // yaml.Node stores sequences in Content as [element1, element2...] - for i := range n.Content { - var err error - if err = DocumentNode(v.Index(i).Interface(), n.Content[i], FilterPassthrough(opts)...); err != nil { - return err - } - } - if len(n.Content) == 0 { - exampleNode, err := DocumentedNode(reflect.MakeSlice(t, 1, 1).Interface(), FilterPassthrough(opts)...) - if err != nil { - return err - } - if exampleNode.Content[0].Kind != yaml.ScalarNode { - example, err := yaml.Marshal(exampleNode) - if err != nil { - return err - } - n.FootComment = fmt.Sprintf("e.g.,\n%s", string(example)) - } - } - case reflect.Map: - // yaml.Node stores mappings in Content as [key1, value1, key2, value2...] - for i := 0; i < len(n.Content)/2; i++ { - k := reflect.ValueOf(n.Content[2*i].Value) - if err := DocumentNode( - v.MapIndex(k).Interface(), - n.Content[2*i+1], - FilterPassthrough(opts)..., - ); err != nil { - return err - } - } - } - } - for _, opt := range opts { - opt.Transform(in, n) - } - return nil -} - -// GenerateDocumentedYAMLNodeFromFlag produces a documented yaml.Node which -// represents the value contained in the flag. -func GenerateDocumentedYAMLNodeFromFlag(flg *flag.Flag) (*yaml.Node, error) { - t, err := getYAMLTypeForFlag(flg) - if err != nil { - return nil, status.InternalErrorf("Error encountered generating default YAML from flags: %s", err) - } - v, err := GetDereferencedValue[any](flg.Name) - if err != nil { - return nil, status.InternalErrorf("Error encountered generating default YAML from flags: %s", err) - } - value := reflect.New(reflect.TypeOf(v)) - value.Elem().Set(reflect.ValueOf(v)) - if !value.CanConvert(t) { - return nil, status.FailedPreconditionErrorf("Cannot convert value %v of type %T into type %v for flag %s.", value.Interface(), value.Type(), t, flg.Name) - } - return DocumentedNode(value.Convert(t).Interface(), LineComment(flg.Usage), AddTypeToLineComment()) -} - -// SplitDocumentedYAMLFromFlags produces marshaled YAML representing the flags, -// partitioned into two groups: structured (flags containing dots), and -// unstructured (flags not containing dots). -func SplitDocumentedYAMLFromFlags() ([]byte, error) { - b := bytes.NewBuffer([]byte{}) - - if _, err := b.Write([]byte("# Unstructured settings\n\n")); err != nil { - return nil, err - } - um, err := GenerateYAMLMapWithValuesFromFlags( - GenerateDocumentedYAMLNodeFromFlag, - func(flg *flag.Flag) bool { return !strings.Contains(flg.Name, ".") }, - IgnoreFilter, - ) - if err != nil { - return nil, err - } - ub, err := yaml.Marshal(um) - if err != nil { - return nil, err - } - if _, err := b.Write(ub); err != nil { - return nil, err - } - - if _, err := b.Write([]byte("\n# Structured settings\n\n")); err != nil { - return nil, err - } - sm, err := GenerateYAMLMapWithValuesFromFlags( - GenerateDocumentedYAMLNodeFromFlag, - func(flg *flag.Flag) bool { return strings.Contains(flg.Name, ".") }, - IgnoreFilter, - ) - if err != nil { - return nil, err - } - sb, err := yaml.Marshal(sm) - if err != nil { - return nil, err - } - if _, err := b.Write(sb); err != nil { - return nil, err - } - - return b.Bytes(), nil -} - -// GenerateYAMLMapWithValuesFromFlags generates a YAML map structure -// representing the flags, with values generated from the flags as per the -// generateValue function that has been passed in, and filtering out any flags -// for which any of the passed filter functions return false. Any nil generated -// values are not added to the map, and any empty maps are recursively removed -// such that the final map returned contains no empty maps at any point in its -// structure. -func GenerateYAMLMapWithValuesFromFlags[T any](generateValue func(*flag.Flag) (T, error), filters ...func(*flag.Flag) bool) (map[string]any, error) { - yamlMap := make(map[string]any) - var errors []error - defaultFlagSet.VisitAll(func(flg *flag.Flag) { - for _, f := range filters { - if !f(flg) { - return - } - } - keys := strings.Split(flg.Name, ".") - m := yamlMap - for i, k := range keys[:len(keys)-1] { - v, ok := m[k] - if !ok { - v := make(map[string]any) - m[k], m = v, v - continue - } - m, ok = v.(map[string]any) - if !ok { - errors = append(errors, status.FailedPreconditionErrorf("When trying to create YAML map hierarchy for %s, encountered non-map value %s of type %T at %s", flg.Name, v, v, strings.Join(keys[:i+1], "."))) - return - } - } - k := keys[len(keys)-1] - if v, ok := m[k]; ok { - errors = append(errors, status.FailedPreconditionErrorf("When generating value for %s for YAML map, encountered pre-existing value %s of type %T.", flg.Name, v, v)) - return - } - v, err := generateValue(flg) - if err != nil { - errors = append(errors, err) - return - } - value := reflect.ValueOf(v) - if _, ok := nilableKinds[value.Kind()]; ok && value.IsNil() { - return - } - m[k] = v - }) - if errors != nil { - return nil, status.InternalErrorf("Errors encountered when generating YAML map from flags: %v", errors) - } - - return RemoveEmptyMapsFromYAMLMap(yamlMap), nil -} - -// RemoveEmptyMapsFromYAMLMap recursively removes all empty maps, such that the -// returned map contains no empty maps at any point in its structure. The -// original map is returned unless it is empty after removal, in which case nil -// is returned. -func RemoveEmptyMapsFromYAMLMap(m map[string]any) map[string]any { - for k, v := range m { - mv, ok := v.(map[string]any) - if !ok { - continue - } - if m[k] = RemoveEmptyMapsFromYAMLMap(mv); m[k] == nil { - delete(m, k) - } - } - if len(m) == 0 { - return nil - } - return m -} - -// RetypeAndFilterYAMLMap un-marshals yaml from the input yamlMap and then -// re-marshals it into the types specified by the type map, replacing the -// original value in the input map. Filters out any values not specified by the -// flags. -func RetypeAndFilterYAMLMap(yamlMap map[string]any, typeMap map[string]any, prefix []string) error { - for k := range yamlMap { - label := append(prefix, k) - if _, ok := typeMap[k]; !ok { - // No flag corresponds to this, warn and delete. - log.Warningf("No flags correspond to YAML input at '%s'.", strings.Join(label, ".")) - delete(yamlMap, k) - continue - } - switch t := typeMap[k].(type) { - case reflect.Type: - // this is a value, populate it from the YAML - yamlData, err := yaml.Marshal(yamlMap[k]) - if err != nil { - return status.InternalErrorf("Encountered error marshaling %v to YAML at %s: %s", yamlMap[k], strings.Join(label, "."), err) - } - v := reflect.New(t.Elem()).Elem() - err = yaml.Unmarshal(yamlData, v.Addr().Interface()) - if err != nil { - return status.InternalErrorf("Encountered error marshaling %s to YAML for type %v at %s: %s", string(yamlData), v.Type(), strings.Join(label, "."), err) - } - if v.Type() != t.Elem() { - return status.InternalErrorf("Failed to unmarshal YAML to the specified type at %s: wanted %v, got %T", strings.Join(label, "."), t.Elem(), v.Type()) - } - yamlMap[k] = v.Interface() - case map[string]any: - yamlSubmap, ok := yamlMap[k].(map[string]any) - if !ok { - // this is a value, not a map, and there is no corresponding type - alert.UnexpectedEvent("Input YAML contained non-map value %v of type %T at label %s", yamlMap[k], yamlMap[k], strings.Join(label, ".")) - delete(yamlMap, k) - } - err := RetypeAndFilterYAMLMap(yamlSubmap, t, label) - if err != nil { - return err - } - default: - return status.InvalidArgumentErrorf("typeMap contained invalid type %T at %s.", typeMap[k], strings.Join(label, ".")) - } - } - return nil -} - -// PopulateFlagsFromData takes some YAML input and unmarshals it, then uses the -// umnarshaled data to populate the unset flags with names corresponding to the -// keys. -func PopulateFlagsFromData(data []byte) error { - // expand environment variables - expandedData := []byte(os.ExpandEnv(string(data))) - - yamlMap := make(map[string]any) - if err := yaml.Unmarshal([]byte(expandedData), yamlMap); err != nil { - return status.InternalErrorf("Error parsing config file: %s", err) - } - node := &yaml.Node{} - if err := yaml.Unmarshal([]byte(expandedData), node); err != nil { - return status.InternalErrorf("Error parsing config file: %s", err) - } - if len(node.Content) > 0 { - node = node.Content[0] - } else { - node = nil - } - typeMap, err := GenerateYAMLMapWithValuesFromFlags(getYAMLTypeForFlag, IgnoreFilter) - if err != nil { - return err - } - if err := RetypeAndFilterYAMLMap(yamlMap, typeMap, []string{}); err != nil { - return status.InternalErrorf("Error encountered retyping YAML map: %s", err) - } - - return PopulateFlagsFromYAMLMap(yamlMap, node) -} - -// PopulateFlagsFromData takes the path to some YAML file, reads it, and -// unmarshals it, then uses the umnarshaled data to populate the unset flags -// with names corresponding to the keys. -func PopulateFlagsFromFile(configFile string) error { - log.Infof("Reading buildbuddy config from '%s'", configFile) - - _, err := os.Stat(configFile) - - // If the file does not exist then skip it. - if os.IsNotExist(err) { - log.Warningf("No config file found at %s.", configFile) - return nil - } - - fileBytes, err := os.ReadFile(configFile) - if err != nil { - return fmt.Errorf("Error reading config file: %s", err) - } - - return PopulateFlagsFromData(fileBytes) -} - -// PopulateFlagsFromYAMLMap takes a map populated by YAML from some YAML input -// and a yaml.Node populated by YAML from the same input and iterates over it, -// finding flags with names corresponding to the keys and setting the flag to -// the YAML value if the flag was not set on the command line. The yaml.Node -// preserves order when setting the flag values, which is important for aliases. -// If Node is nil, the order values will be set in is random, as per go's -// implementation of map traversal. -func PopulateFlagsFromYAMLMap(m map[string]any, node *yaml.Node) error { - setFlags := make(map[string]struct{}) - defaultFlagSet.Visit(func(flg *flag.Flag) { - setFlags[flg.Name] = struct{}{} - }) - - return populateFlagsFromYAML(m, []string{}, node, setFlags) -} - -func populateFlagsFromYAML(a any, prefix []string, node *yaml.Node, setFlags map[string]struct{}) error { - if m, ok := a.(map[string]any); ok { - i := 0 - for k, v := range m { - var n *yaml.Node - if node != nil { - // Ensure that we populate flags in the order they are specified in the - // YAML data if the node structure data was provided. - for ok := false; node != nil && !ok; i++ { - k = node.Content[2*i].Value - n = node.Content[2*i+1] - v, ok = m[k] - } - } - p := append(prefix, k) - if _, ok := ignoreSet[strings.Join(p, ".")]; ok { - return nil - } - if err := populateFlagsFromYAML(v, p, n, setFlags); err != nil { - return err - } - } - return nil - } - name := strings.Join(prefix, ".") - if _, ok := ignoreSet[name]; ok { - return nil - } - return SetValueForFlagName(name, a, setFlags, true, false) -} - // SetValueForFlagName sets the value for a flag by name. -func SetValueForFlagName(name string, i any, setFlags map[string]struct{}, appendSlice bool, strict bool) error { - flg := defaultFlagSet.Lookup(name) - if flg == nil { - if strict { - return status.NotFoundErrorf("Undefined flag: %s", name) - } - return nil - } - // For slice flags, append the YAML values to the existing values if appendSlice is true - if v, ok := flg.Value.(Appendable); ok && appendSlice { - if err := v.AppendSlice(i); err != nil { - return status.InternalErrorf("Error encountered appending to flag %s: %s", flg.Name, err) - } - return nil - } - // For non-append flags, skip the YAML values if it was set on the command line - if _, ok := setFlags[name]; ok { - return nil - } - if v, ok := flg.Value.(isNameAliasing); ok { - return SetValueForFlagName(v.AliasedName(), i, setFlags, appendSlice, strict) - } - t, err := getTypeForFlag(flg) - if err != nil { - return status.UnimplementedErrorf("Error encountered setting flag: %s", err) - } - if !reflect.ValueOf(i).CanConvert(t.Elem()) { - return status.FailedPreconditionErrorf("Cannot convert value %v of type %T into type %v for flag %s.", i, i, t.Elem(), flg.Name) - } - reflect.ValueOf(flg.Value).Convert(t).Elem().Set(reflect.ValueOf(i).Convert(t.Elem())) - return nil -} +var SetValueForFlagName = common.SetValueForFlagName // GetDereferencedValue retypes and returns the dereferenced Value for // a given flag name. func GetDereferencedValue[T any](name string) (T, error) { - flg := defaultFlagSet.Lookup(name) - zeroT := reflect.New(reflect.TypeOf((*T)(nil)).Elem()).Interface().(*T) - if flg == nil { - return *zeroT, status.NotFoundErrorf("Undefined flag: %s", name) - } - if v, ok := flg.Value.(isNameAliasing); ok { - return GetDereferencedValue[T](v.AliasedName()) - } - t := reflect.TypeOf((*T)(nil)) - addr := reflect.ValueOf(flg.Value) - if t == reflect.TypeOf((*any)(nil)) { - var err error - t, err = getTypeForFlag(flg) - if err != nil { - return *zeroT, status.InternalErrorf("Error dereferencing flag to unspecified type: %s.", err) - } - if !addr.CanConvert(t) { - return *zeroT, status.InvalidArgumentErrorf("Flag %s of type %T could not be converted to %s.", name, flg.Value, t) - } - return addr.Convert(t).Elem().Interface().(T), nil - } - if !addr.CanConvert(t) { - return *zeroT, status.InvalidArgumentErrorf("Flag %s of type %T could not be converted to %s.", name, flg.Value, t) - } - v, ok := addr.Convert(t).Interface().(*T) - if !ok { - return *zeroT, status.InternalErrorf("Failed to assert flag %s of type %T as type %s.", name, flg.Value, t) - } - return *v, nil -} - -// FOR TESTING PURPOSES ONLY -// AddTestFlagTypeForTesting adds a type correspondence to the internal -// flagTypeMap. -func AddTestFlagTypeForTesting(flagValue, value any) { - flagTypeMap[reflect.TypeOf(flagValue)] = reflect.TypeOf(value) + return common.GetDereferencedValue[T](name) } diff --git a/server/util/flagutil/flagutil_test.go b/server/util/flagutil/flagutil_test.go deleted file mode 100644 index 1519e6d7a87..00000000000 --- a/server/util/flagutil/flagutil_test.go +++ /dev/null @@ -1,911 +0,0 @@ -package flagutil - -import ( - "flag" - "net/url" - "reflect" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "google.golang.org/protobuf/types/known/timestamppb" - "gopkg.in/yaml.v3" -) - -type unsupportedFlagValue struct{} - -func (f *unsupportedFlagValue) Set(string) error { return nil } -func (f *unsupportedFlagValue) String() string { return "" } - -type testStruct struct { - Field int `json:"field"` - Meadow string `json:"meadow"` -} - -func replaceFlagsForTesting(t *testing.T) *flag.FlagSet { - flags := flag.NewFlagSet("test", flag.ContinueOnError) - defaultFlagSet = flags - - t.Cleanup(func() { - defaultFlagSet = flag.CommandLine - }) - - return flags -} - -func TestStringSliceFlag(t *testing.T) { - var err error - - flags := replaceFlagsForTesting(t) - - foo := Slice("foo", []string{}, "A list of foos") - assert.Equal(t, []string{}, *foo) - assert.Equal(t, []string{}, *(*[]string)(flags.Lookup("foo").Value.(*SliceFlag[string]))) - err = flags.Set("foo", "foo0,foo1") - assert.NoError(t, err) - err = flags.Set("foo", "foo2") - assert.NoError(t, err) - err = flags.Set("foo", "foo3,foo4,foo5") - assert.NoError(t, err) - assert.Equal(t, []string{"foo0", "foo1", "foo2", "foo3", "foo4", "foo5"}, *foo) - assert.Equal(t, []string{"foo0", "foo1", "foo2", "foo3", "foo4", "foo5"}, *(*[]string)(flags.Lookup("foo").Value.(*SliceFlag[string]))) - - bar := Slice("bar", []string{"bar0", "bar1"}, "A list of bars") - assert.Equal(t, []string{"bar0", "bar1"}, *bar) - assert.Equal(t, []string{"bar0", "bar1"}, *(*[]string)(flags.Lookup("bar").Value.(*SliceFlag[string]))) - err = flags.Set("bar", "bar2") - assert.NoError(t, err) - err = flags.Set("bar", "bar3,bar4,bar5") - assert.NoError(t, err) - assert.Equal(t, []string{"bar0", "bar1", "bar2", "bar3", "bar4", "bar5"}, *bar) - assert.Equal(t, []string{"bar0", "bar1", "bar2", "bar3", "bar4", "bar5"}, *(*[]string)(flags.Lookup("bar").Value.(*SliceFlag[string]))) - - baz := Slice("baz", []string{}, "A list of bazs") - err = flags.Set("baz", flags.Lookup("bar").Value.String()) - assert.NoError(t, err) - assert.Equal(t, *bar, *baz) - - testSlice := []string{"yes", "si", "hai"} - testFlag := NewSliceFlag(&testSlice) - testFlag.AppendSlice(*(*[]string)(testFlag)) - assert.Equal(t, []string{"yes", "si", "hai", "yes", "si", "hai"}, testSlice) -} - -func TestStructSliceFlag(t *testing.T) { - var err error - - flags := replaceFlagsForTesting(t) - - fooFlag := Slice("foo", []testStruct{}, "A list of foos") - assert.Equal(t, []testStruct{}, *fooFlag) - assert.Equal(t, []testStruct{}, *(*[]testStruct)(flags.Lookup("foo").Value.(*SliceFlag[testStruct]))) - err = flags.Set("foo", `[{"field":3,"meadow":"watership down"}]`) - assert.NoError(t, err) - assert.Equal(t, []testStruct{{Field: 3, Meadow: "watership down"}}, *fooFlag) - assert.Equal(t, []testStruct{{Field: 3, Meadow: "watership down"}}, *(*[]testStruct)(flags.Lookup("foo").Value.(*SliceFlag[testStruct]))) - err = flags.Set("foo", `{"field":5,"meadow":"runnymede"}`) - assert.NoError(t, err) - assert.Equal(t, []testStruct{{Field: 3, Meadow: "watership down"}, {Field: 5, Meadow: "runnymede"}}, *fooFlag) - assert.Equal(t, []testStruct{{Field: 3, Meadow: "watership down"}, {Field: 5, Meadow: "runnymede"}}, *(*[]testStruct)(flags.Lookup("foo").Value.(*SliceFlag[testStruct]))) - - barFlag := []testStruct{{Field: 11, Meadow: "arcadia"}, {Field: 13, Meadow: "kingcombe"}} - SliceVar(&barFlag, "bar", "A list of bars") - assert.Equal(t, []testStruct{{Field: 11, Meadow: "arcadia"}, {Field: 13, Meadow: "kingcombe"}}, barFlag) - assert.Equal(t, []testStruct{{Field: 11, Meadow: "arcadia"}, {Field: 13, Meadow: "kingcombe"}}, *(*[]testStruct)(flags.Lookup("bar").Value.(*SliceFlag[testStruct]))) - - fooxFlag := Slice("foox", []testStruct{}, "A list of fooxes") - assert.Equal(t, []testStruct{}, *fooxFlag) - assert.Equal(t, []testStruct{}, *(*[]testStruct)(flags.Lookup("foox").Value.(*SliceFlag[testStruct]))) - err = flags.Set("foox", `[{"field":13,"meadow":"cors y llyn"},{},{"field":15}]`) - assert.NoError(t, err) - assert.Equal(t, []testStruct{{Field: 13, Meadow: "cors y llyn"}, {}, {Field: 15}}, *fooxFlag) - assert.Equal(t, []testStruct{{Field: 13, Meadow: "cors y llyn"}, {}, {Field: 15}}, *(*[]testStruct)(flags.Lookup("foox").Value.(*SliceFlag[testStruct]))) - err = flags.Set("foox", `[{"field":17,"meadow":"red hill"},{},{"field":19}]`) - assert.NoError(t, err) - assert.Equal(t, []testStruct{{Field: 13, Meadow: "cors y llyn"}, {}, {Field: 15}, {Field: 17, Meadow: "red hill"}, {}, {Field: 19}}, *fooxFlag) - assert.Equal(t, []testStruct{{Field: 13, Meadow: "cors y llyn"}, {}, {Field: 15}, {Field: 17, Meadow: "red hill"}, {}, {Field: 19}}, *(*[]testStruct)(flags.Lookup("foox").Value.(*SliceFlag[testStruct]))) - - bazFlag := []testStruct{} - SliceVar(&bazFlag, "baz", "A list of bazs") - err = flags.Set("baz", flags.Lookup("bar").Value.String()) - assert.NoError(t, err) - assert.Equal(t, barFlag, bazFlag) - - testSlice := []testStruct{{}, {Field: 1}, {Meadow: "Paradise"}} - testFlag := NewSliceFlag(&testSlice) - testFlag.AppendSlice(*(*[]testStruct)(testFlag)) - assert.Equal(t, []testStruct{{}, {Field: 1}, {Meadow: "Paradise"}, {}, {Field: 1}, {Meadow: "Paradise"}}, testSlice) -} - -func TestProtoSliceFlag(t *testing.T) { - var err error - - flags := replaceFlagsForTesting(t) - - fooFlag := Slice("foo", []*timestamppb.Timestamp{}, "A list of foos") - assert.Equal(t, []*timestamppb.Timestamp{}, *fooFlag) - assert.Equal(t, []*timestamppb.Timestamp{}, *(*[]*timestamppb.Timestamp)(flags.Lookup("foo").Value.(*SliceFlag[*timestamppb.Timestamp]))) - err = flags.Set("foo", `[{"seconds":3,"nanos":5}]`) - assert.NoError(t, err) - assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 3, Nanos: 5}}, *fooFlag) - assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 3, Nanos: 5}}, *(*[]*timestamppb.Timestamp)(flags.Lookup("foo").Value.(*SliceFlag[*timestamppb.Timestamp]))) - err = flags.Set("foo", `{"seconds":5,"nanos":9}`) - assert.NoError(t, err) - assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 3, Nanos: 5}, {Seconds: 5, Nanos: 9}}, *fooFlag) - assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 3, Nanos: 5}, {Seconds: 5, Nanos: 9}}, *(*[]*timestamppb.Timestamp)(flags.Lookup("foo").Value.(*SliceFlag[*timestamppb.Timestamp]))) - - barFlag := []*timestamppb.Timestamp{{Seconds: 11, Nanos: 100}, {Seconds: 13, Nanos: 256}} - SliceVar(&barFlag, "bar", "A list of bars") - assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 11, Nanos: 100}, {Seconds: 13, Nanos: 256}}, barFlag) - assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 11, Nanos: 100}, {Seconds: 13, Nanos: 256}}, *(*[]*timestamppb.Timestamp)(flags.Lookup("bar").Value.(*SliceFlag[*timestamppb.Timestamp]))) - - fooxFlag := Slice("foox", []*timestamppb.Timestamp{}, "A list of fooxes") - assert.Equal(t, []*timestamppb.Timestamp{}, *fooxFlag) - assert.Equal(t, []*timestamppb.Timestamp{}, *(*[]*timestamppb.Timestamp)(flags.Lookup("foox").Value.(*SliceFlag[*timestamppb.Timestamp]))) - err = flags.Set("foox", `[{"seconds":13,"nanos":64},{},{"seconds":15}]`) - assert.NoError(t, err) - assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 13, Nanos: 64}, {}, {Seconds: 15}}, *fooxFlag) - assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 13, Nanos: 64}, {}, {Seconds: 15}}, *(*[]*timestamppb.Timestamp)(flags.Lookup("foox").Value.(*SliceFlag[*timestamppb.Timestamp]))) - err = flags.Set("foox", `[{"seconds":17,"nanos":9001},{},{"seconds":19}]`) - assert.NoError(t, err) - assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 13, Nanos: 64}, {}, {Seconds: 15}, {Seconds: 17, Nanos: 9001}, {}, {Seconds: 19}}, *fooxFlag) - assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 13, Nanos: 64}, {}, {Seconds: 15}, {Seconds: 17, Nanos: 9001}, {}, {Seconds: 19}}, *(*[]*timestamppb.Timestamp)(flags.Lookup("foox").Value.(*SliceFlag[*timestamppb.Timestamp]))) - - bazFlag := []*timestamppb.Timestamp{} - SliceVar(&bazFlag, "baz", "A list of bazs") - err = flags.Set("baz", flags.Lookup("bar").Value.String()) - assert.NoError(t, err) - assert.Equal(t, barFlag, bazFlag) - - testSlice := []*timestamppb.Timestamp{{}, {Seconds: 1}, {Nanos: 99}} - testFlag := NewSliceFlag(&testSlice) - testFlag.AppendSlice(*(*[]*timestamppb.Timestamp)(testFlag)) - assert.Equal(t, []*timestamppb.Timestamp{{}, {Seconds: 1}, {Nanos: 99}, {}, {Seconds: 1}, {Nanos: 99}}, testSlice) - -} - -func TestGenerateYAMLTypeMapFromFlags(t *testing.T) { - flags := replaceFlagsForTesting(t) - - flags.Bool("bool", true, "") - flags.Int("one.two.int", 10, "") - Slice("one.two.string_slice", []string{"hi", "hello"}, "") - flags.Float64("one.two.two_and_a_half.float64", 5.2, "") - Slice("one.two.three.struct_slice", []testStruct{{Field: 4, Meadow: "Great"}}, "") - flags.String("a.b.string", "xxx", "") - URLFromString("a.b.url", "https://www.example.com", "") - actual, err := GenerateYAMLMapWithValuesFromFlags(getYAMLTypeForFlag, IgnoreFilter) - require.NoError(t, err) - expected := map[string]any{ - "bool": reflect.TypeOf((*bool)(nil)), - "one": map[string]any{ - "two": map[string]any{ - "int": reflect.TypeOf((*int)(nil)), - "string_slice": reflect.TypeOf((*[]string)(nil)), - "two_and_a_half": map[string]any{ - "float64": reflect.TypeOf((*float64)(nil)), - }, - "three": map[string]any{ - "struct_slice": reflect.TypeOf((*[]testStruct)(nil)), - }, - }, - }, - "a": map[string]any{ - "b": map[string]any{ - "string": reflect.TypeOf((*string)(nil)), - "url": reflect.TypeOf((*URLFlag)(nil)), - }, - }, - } - if diff := cmp.Diff(expected, actual, cmp.Comparer(func(x, y reflect.Type) bool { return x == y })); diff != "" { - t.Error(diff) - } -} - -func TestBadGenerateYAMLTypeMapFromFlags(t *testing.T) { - flags := replaceFlagsForTesting(t) - - flags.Int("one.two.int", 10, "") - flags.Int("one.two", 10, "") - _, err := GenerateYAMLMapWithValuesFromFlags(getYAMLTypeForFlag, IgnoreFilter) - require.Error(t, err) - - flags = replaceFlagsForTesting(t) - - flags.Int("one.two", 10, "") - flags.Int("one.two.int", 10, "") - _, err = GenerateYAMLMapWithValuesFromFlags(getYAMLTypeForFlag, IgnoreFilter) - require.Error(t, err) - - flags = replaceFlagsForTesting(t) - - flags.Var(&unsupportedFlagValue{}, "unsupported", "") - _, err = GenerateYAMLMapWithValuesFromFlags(getYAMLTypeForFlag, IgnoreFilter) - require.Error(t, err) - -} - -func TestRetypeAndFilterYAMLMap(t *testing.T) { - typeMap := map[string]any{ - "bool": reflect.TypeOf((*bool)(nil)), - "one": map[string]any{ - "two": map[string]any{ - "int": reflect.TypeOf((*int)(nil)), - "string_slice": reflect.TypeOf((*[]string)(nil)), - "two_and_a_half": map[string]any{ - "float64": reflect.TypeOf((*float64)(nil)), - }, - "three": map[string]any{ - "struct_slice": reflect.TypeOf((*[]testStruct)(nil)), - }, - }, - }, - "a": map[string]any{ - "b": map[string]any{ - "string": reflect.TypeOf((*string)(nil)), - "url": reflect.TypeOf((*URLFlag)(nil)), - }, - }, - "foo": map[string]any{ - "bar": reflect.TypeOf((*int64)(nil)), - }, - } - yamlData := ` -bool: true -one: - two: - int: 1 - string_slice: - - "string1" - - "string2" - two_and_a_half: - float64: 9.4 - three: - struct_slice: - - field: 9 - meadow: "Eternal" - - field: 5 -a: - b: - url: "http://www.example.com" -foo: 7 -first: - second: - unknown: 9009 - no: "definitely not" -` - yamlMap := make(map[string]any) - err := yaml.Unmarshal([]byte(yamlData), yamlMap) - require.NoError(t, err) - err = RetypeAndFilterYAMLMap(yamlMap, typeMap, []string{}) - require.NoError(t, err) - expected := map[string]any{ - "bool": true, - "one": map[string]any{ - "two": map[string]any{ - "int": int(1), - "string_slice": []string{"string1", "string2"}, - "two_and_a_half": map[string]any{ - "float64": float64(9.4), - }, - "three": map[string]any{ - "struct_slice": []testStruct{{Field: 9, Meadow: "Eternal"}, {Field: 5}}, - }, - }, - }, - "a": map[string]any{ - "b": map[string]any{ - "url": URLFlag(url.URL{Scheme: "http", Host: "www.example.com"}), - }, - }, - } - if diff := cmp.Diff(expected, yamlMap); diff != "" { - t.Error(diff) - } -} - -func TestBadRetypeAndFilterYAMLMap(t *testing.T) { - typeMap := map[string]any{ - "bool": reflect.TypeOf((*bool)(nil)), - } - yamlData := ` -bool: 7 -` - yamlMap := make(map[string]any) - err := yaml.Unmarshal([]byte(yamlData), yamlMap) - require.NoError(t, err) - err = RetypeAndFilterYAMLMap(yamlMap, typeMap, []string{}) - require.Error(t, err) - - typeMap = map[string]any{ - "bool": false, - } - yamlData = ` -bool: true -` - yamlMap = make(map[string]any) - err = yaml.Unmarshal([]byte(yamlData), yamlMap) - require.NoError(t, err) - err = RetypeAndFilterYAMLMap(yamlMap, typeMap, []string{}) - require.Error(t, err) -} - -func TestPopulateFlagsFromData(t *testing.T) { - flags := replaceFlagsForTesting(t) - - flagBool := flags.Bool("bool", true, "") - flagOneTwoInt := flags.Int("one.two.int", 10, "") - flagOneTwoStringSlice := Slice("one.two.string_slice", []string{"hi", "hello"}, "") - flagOneTwoTwoAndAHalfFloat := flags.Float64("one.two.two_and_a_half.float64", 5.2, "") - flagOneTwoThreeStructSlice := []testStruct{{Field: 4, Meadow: "Great"}} - SliceVar(&flagOneTwoThreeStructSlice, "one.two.three.struct_slice", "") - flagABString := flags.String("a.b.string", "xxx", "") - flagABStructSlice := []testStruct{{Field: 7, Meadow: "Chimney"}} - SliceVar(&flagABStructSlice, "a.b.struct_slice", "") - flagABURL := URLFromString("a.b.url", "https://www.example.com", "") - yamlData := ` -bool: true -one: - two: - int: 1 - string_slice: - - "string1" - - "string2" - two_and_a_half: - float64: 9.4 - three: - struct_slice: - - field: 9 - meadow: "Eternal" - - field: 5 -a: - b: - url: "http://www.example.com:8080" -foo: 7 -first: - second: - unknown: 9009 - no: "definitely not" -` - err := PopulateFlagsFromData([]byte(yamlData)) - require.NoError(t, err) - assert.Equal(t, true, *flagBool) - assert.Equal(t, int(1), *flagOneTwoInt) - assert.Equal(t, []string{"hi", "hello", "string1", "string2"}, *flagOneTwoStringSlice) - assert.Equal(t, float64(9.4), *flagOneTwoTwoAndAHalfFloat) - assert.Equal(t, []testStruct{{Field: 4, Meadow: "Great"}, {Field: 9, Meadow: "Eternal"}, {Field: 5}}, flagOneTwoThreeStructSlice) - assert.Equal(t, "xxx", *flagABString) - assert.Equal(t, []testStruct{{Field: 7, Meadow: "Chimney"}}, flagABStructSlice) - assert.Equal(t, url.URL{Scheme: "http", Host: "www.example.com:8080"}, *flagABURL) -} - -func TestBadPopulateFlagsFromData(t *testing.T) { - _ = replaceFlagsForTesting(t) - - yamlData := ` - bool: true -` - err := PopulateFlagsFromData([]byte(yamlData)) - require.Error(t, err) - - flags := replaceFlagsForTesting(t) - - flags.Var(&unsupportedFlagValue{}, "bad", "") - err = PopulateFlagsFromData([]byte{}) - require.Error(t, err) - - flags = replaceFlagsForTesting(t) - - _ = flags.Bool("bool", false, "") - yamlData = ` -bool: 7 -` - err = PopulateFlagsFromData([]byte(yamlData)) - require.Error(t, err) -} - -func TestPopulateFlagsFromYAML(t *testing.T) { - flags := replaceFlagsForTesting(t) - - flagBool := flags.Bool("bool", true, "") - flagOneTwoInt := flags.Int("one.two.int", 10, "") - flagOneTwoStringSlice := Slice("one.two.string_slice", []string{"hi", "hello"}, "") - flagOneTwoTwoAndAHalfFloat := flags.Float64("one.two.two_and_a_half.float64", 5.2, "") - flagOneTwoThreeStructSlice := []testStruct{{Field: 4, Meadow: "Great"}} - SliceVar(&flagOneTwoThreeStructSlice, "one.two.three.struct_slice", "") - flagABString := flags.String("a.b.string", "xxx", "") - flagABStructSlice := []testStruct{{Field: 7, Meadow: "Chimney"}} - SliceVar(&flagABStructSlice, "a.b.struct_slice", "") - flagABURL := URLFromString("a.b.url", "https://www.example.com", "") - input := map[string]any{ - "bool": false, - "one": map[string]any{ - "two": map[string]any{ - "string_slice": []string{"meow", "woof"}, - "two_and_a_half": map[string]any{ - "float64": float64(7), - }, - "three": map[string]any{ - "struct_slice": ([]testStruct)(nil), - }, - }, - }, - "a": map[string]any{ - "b": map[string]any{ - "string": "", - "struct_slice": []testStruct{{Field: 9}}, - "url": URLFlag(url.URL{Scheme: "https", Host: "www.example.com:8080"}), - }, - }, - "undefined": struct{}{}, // keys without with no corresponding flag name should be ignored. - } - node := &yaml.Node{} - err := node.Encode(input) - require.NoError(t, err) - err = PopulateFlagsFromYAMLMap(input, node) - require.NoError(t, err) - - assert.Equal(t, false, *flagBool) - assert.Equal(t, 10, *flagOneTwoInt) - assert.Equal(t, []string{"hi", "hello", "meow", "woof"}, *flagOneTwoStringSlice) - assert.Equal(t, float64(7), *flagOneTwoTwoAndAHalfFloat) - assert.Equal(t, []testStruct{{Field: 4, Meadow: "Great"}}, flagOneTwoThreeStructSlice) - assert.Equal(t, "", *flagABString) - assert.Equal(t, []testStruct{{Field: 7, Meadow: "Chimney"}, {Field: 9}}, flagABStructSlice) - assert.Equal(t, url.URL{Scheme: "https", Host: "www.example.com:8080"}, *flagABURL) -} - -func TestBadPopulateFlagsFromYAML(t *testing.T) { - _ = replaceFlagsForTesting(t) - - flags := replaceFlagsForTesting(t) - flags.Var(&unsupportedFlagValue{}, "unsupported", "") - input := map[string]any{ - "unsupported": 0, - } - node := &yaml.Node{} - err := node.Encode(input) - require.NoError(t, err) - err = PopulateFlagsFromYAMLMap(input, node) - require.Error(t, err) - - flags = replaceFlagsForTesting(t) - flags.Bool("bool", false, "") - input = map[string]any{ - "bool": 0, - } - node = &yaml.Node{} - err = node.Encode(input) - require.NoError(t, err) - err = PopulateFlagsFromYAMLMap(input, node) - require.Error(t, err) -} - -func TestSetValueForFlagName(t *testing.T) { - flags := replaceFlagsForTesting(t) - flagBool := flags.Bool("bool", false, "") - err := SetValueForFlagName("bool", true, map[string]struct{}{}, true, true) - require.NoError(t, err) - assert.Equal(t, true, *flagBool) - - flags = replaceFlagsForTesting(t) - flagBool = flags.Bool("bool", false, "") - err = SetValueForFlagName("bool", true, map[string]struct{}{"bool": struct{}{}}, true, true) - require.NoError(t, err) - assert.Equal(t, false, *flagBool) - - flags = replaceFlagsForTesting(t) - err = SetValueForFlagName("bool", true, map[string]struct{}{}, true, false) - require.NoError(t, err) - - flags = replaceFlagsForTesting(t) - flagInt := flags.Int("int", 2, "") - err = SetValueForFlagName("int", 1, map[string]struct{}{}, true, true) - require.NoError(t, err) - assert.Equal(t, 1, *flagInt) - - flags = replaceFlagsForTesting(t) - flagInt = flags.Int("int", 2, "") - err = SetValueForFlagName("int", 1, map[string]struct{}{"int": struct{}{}}, true, true) - require.NoError(t, err) - assert.Equal(t, 2, *flagInt) - - flags = replaceFlagsForTesting(t) - flagInt64 := flags.Int64("int64", 2, "") - err = SetValueForFlagName("int64", 1, map[string]struct{}{}, true, true) - require.NoError(t, err) - assert.Equal(t, int64(1), *flagInt64) - - flags = replaceFlagsForTesting(t) - flagInt64 = flags.Int64("int64", 2, "") - err = SetValueForFlagName("int64", 1, map[string]struct{}{"int64": struct{}{}}, true, true) - require.NoError(t, err) - assert.Equal(t, int64(2), *flagInt64) - - flags = replaceFlagsForTesting(t) - flagUint := flags.Uint("uint", 2, "") - err = SetValueForFlagName("uint", 1, map[string]struct{}{}, true, true) - require.NoError(t, err) - assert.Equal(t, uint(1), *flagUint) - - flags = replaceFlagsForTesting(t) - flagUint = flags.Uint("uint", 2, "") - err = SetValueForFlagName("uint", 1, map[string]struct{}{"uint": struct{}{}}, true, true) - require.NoError(t, err) - assert.Equal(t, uint(2), *flagUint) - - flags = replaceFlagsForTesting(t) - flagUint64 := flags.Uint64("uint64", 2, "") - err = SetValueForFlagName("uint64", 1, map[string]struct{}{}, true, true) - require.NoError(t, err) - assert.Equal(t, uint64(1), *flagUint64) - - flags = replaceFlagsForTesting(t) - flagUint64 = flags.Uint64("uint64", 2, "") - err = SetValueForFlagName("uint64", 1, map[string]struct{}{"uint64": struct{}{}}, true, true) - require.NoError(t, err) - assert.Equal(t, uint64(2), *flagUint64) - - flags = replaceFlagsForTesting(t) - flagFloat64 := flags.Float64("float64", 2, "") - err = SetValueForFlagName("float64", 1, map[string]struct{}{}, true, true) - require.NoError(t, err) - assert.Equal(t, float64(1), *flagFloat64) - - flags = replaceFlagsForTesting(t) - flagFloat64 = flags.Float64("float64", 2, "") - err = SetValueForFlagName("float64", 1, map[string]struct{}{"float64": struct{}{}}, true, true) - require.NoError(t, err) - assert.Equal(t, float64(2), *flagFloat64) - - flags = replaceFlagsForTesting(t) - flagString := flags.String("string", "2", "") - err = SetValueForFlagName("string", "1", map[string]struct{}{}, true, true) - require.NoError(t, err) - assert.Equal(t, "1", *flagString) - - flags = replaceFlagsForTesting(t) - flagString = flags.String("string", "2", "") - err = SetValueForFlagName("string", "1", map[string]struct{}{"string": struct{}{}}, true, true) - require.NoError(t, err) - assert.Equal(t, "2", *flagString) - - flags = replaceFlagsForTesting(t) - flagString = flags.String("string", "2", "") - Alias[string]("string_alias", "string") - err = SetValueForFlagName("string_alias", "1", map[string]struct{}{}, true, true) - require.NoError(t, err) - assert.Equal(t, "1", *flagString) - - flags = replaceFlagsForTesting(t) - flagString = flags.String("string", "2", "") - Alias[string]("string_alias", "string") - err = SetValueForFlagName("string_alias", "1", map[string]struct{}{"string": struct{}{}}, true, true) - require.NoError(t, err) - assert.Equal(t, "2", *flagString) - - flags = replaceFlagsForTesting(t) - flagURL := URLFromString("url", "https://www.example.com", "") - u, err := url.Parse("https://www.example.com:8080") - require.NoError(t, err) - err = SetValueForFlagName("url", *u, map[string]struct{}{}, true, true) - require.NoError(t, err) - assert.Equal(t, url.URL{Scheme: "https", Host: "www.example.com:8080"}, *flagURL) - - flags = replaceFlagsForTesting(t) - flagURL = URLFromString("url", "https://www.example.com", "") - u, err = url.Parse("https://www.example.com:8080") - require.NoError(t, err) - err = SetValueForFlagName("url", *u, map[string]struct{}{"url": struct{}{}}, true, true) - require.NoError(t, err) - assert.Equal(t, url.URL{Scheme: "https", Host: "www.example.com"}, *flagURL) - - flags = replaceFlagsForTesting(t) - string_slice := make([]string, 2) - string_slice[0] = "1" - string_slice[1] = "2" - SliceVar(&string_slice, "string_slice", "") - err = SetValueForFlagName("string_slice", []string{"3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, map[string]struct{}{}, true, true) - require.NoError(t, err) - assert.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, string_slice) - - flags = replaceFlagsForTesting(t) - flagStringSlice := Slice("string_slice", []string{"1", "2"}, "") - err = SetValueForFlagName("string_slice", []string{"3"}, map[string]struct{}{"string_slice": struct{}{}}, true, true) - require.NoError(t, err) - assert.Equal(t, []string{"1", "2", "3"}, *flagStringSlice) - - flags = replaceFlagsForTesting(t) - flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") - err = SetValueForFlagName("string_slice", []string{"3"}, map[string]struct{}{}, false, true) - require.NoError(t, err) - assert.Equal(t, []string{"3"}, *flagStringSlice) - - flags = replaceFlagsForTesting(t) - flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") - err = SetValueForFlagName("string_slice", []string{"3"}, map[string]struct{}{"string_slice": struct{}{}}, false, true) - require.NoError(t, err) - assert.Equal(t, []string{"1", "2"}, *flagStringSlice) - - flags = replaceFlagsForTesting(t) - string_slice = make([]string, 2) - string_slice[0] = "1" - string_slice[1] = "2" - SliceVar(&string_slice, "string_slice", "") - Alias[[]string]("string_slice_alias", "string_slice") - err = SetValueForFlagName("string_slice_alias", []string{"3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, map[string]struct{}{}, true, true) - require.NoError(t, err) - assert.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, string_slice) - - flags = replaceFlagsForTesting(t) - flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") - Alias[[]string]("string_slice_alias", "string_slice") - err = SetValueForFlagName("string_slice_alias", []string{"3"}, map[string]struct{}{"string_slice": struct{}{}}, true, true) - require.NoError(t, err) - assert.Equal(t, []string{"1", "2", "3"}, *flagStringSlice) - - flags = replaceFlagsForTesting(t) - flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") - Alias[[]string]("string_slice_alias", "string_slice") - err = SetValueForFlagName("string_slice_alias", []string{"3"}, map[string]struct{}{}, false, true) - require.NoError(t, err) - assert.Equal(t, []string{"3"}, *flagStringSlice) - - flags = replaceFlagsForTesting(t) - flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") - Alias[[]string]("string_slice_alias", "string_slice") - err = SetValueForFlagName("string_slice_alias", []string{"3"}, map[string]struct{}{"string_slice": struct{}{}}, false, true) - require.NoError(t, err) - assert.Equal(t, []string{"1", "2"}, *flagStringSlice) - - flags = replaceFlagsForTesting(t) - string_slice = make([]string, 2) - string_slice[0] = "1" - string_slice[1] = "2" - SliceVar(&string_slice, "string_slice", "") - Alias[[]string]("string_slice_alias", "string_slice") - Alias[[]string]("string_slice_alias_alias", "string_slice_alias") - err = SetValueForFlagName("string_slice_alias_alias", []string{"3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, map[string]struct{}{}, true, true) - require.NoError(t, err) - assert.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, string_slice) - - flags = replaceFlagsForTesting(t) - flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") - Alias[[]string]("string_slice_alias", "string_slice") - Alias[[]string]("string_slice_alias_alias", "string_slice_alias") - err = SetValueForFlagName("string_slice_alias_alias", []string{"3"}, map[string]struct{}{"string_slice": struct{}{}}, true, true) - require.NoError(t, err) - assert.Equal(t, []string{"1", "2", "3"}, *flagStringSlice) - - flags = replaceFlagsForTesting(t) - flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") - Alias[[]string]("string_slice_alias", "string_slice") - Alias[[]string]("string_slice_alias_alias", "string_slice_alias") - err = SetValueForFlagName("string_slice_alias_alias", []string{"3"}, map[string]struct{}{}, false, true) - require.NoError(t, err) - assert.Equal(t, []string{"3"}, *flagStringSlice) - - flags = replaceFlagsForTesting(t) - flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") - Alias[[]string]("string_slice_alias", "string_slice") - Alias[[]string]("string_slice_alias_alias", "string_slice_alias") - err = SetValueForFlagName("string_slice_alias_alias", []string{"3"}, map[string]struct{}{"string_slice": struct{}{}}, false, true) - require.NoError(t, err) - assert.Equal(t, []string{"1", "2"}, *flagStringSlice) - - flags = replaceFlagsForTesting(t) - flagStructSlice := []testStruct{{Field: 1}, {Field: 2}} - SliceVar(&flagStructSlice, "struct_slice", "") - err = SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{}, true, true) - require.NoError(t, err) - assert.Equal(t, []testStruct{{Field: 1}, {Field: 2}, {Field: 3}}, flagStructSlice) - - flags = replaceFlagsForTesting(t) - flagStructSlice = []testStruct{{Field: 1}, {Field: 2}} - SliceVar(&flagStructSlice, "struct_slice", "") - err = SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{"struct_slice": struct{}{}}, true, true) - require.NoError(t, err) - assert.Equal(t, []testStruct{{Field: 1}, {Field: 2}, {Field: 3}}, flagStructSlice) - - flags = replaceFlagsForTesting(t) - flagStructSlice = []testStruct{{Field: 1}, {Field: 2}} - SliceVar(&flagStructSlice, "struct_slice", "") - err = SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{}, false, true) - require.NoError(t, err) - assert.Equal(t, []testStruct{{Field: 3}}, flagStructSlice) - - flags = replaceFlagsForTesting(t) - flagStructSlice = []testStruct{{Field: 1}, {Field: 2}} - SliceVar(&flagStructSlice, "struct_slice", "") - err = SetValueForFlagName("struct_slice", []testStruct{{Field: 3}}, map[string]struct{}{"struct_slice": struct{}{}}, false, true) - require.NoError(t, err) - assert.Equal(t, []testStruct{{Field: 1}, {Field: 2}}, flagStructSlice) -} - -func TestBadSetValueForFlagName(t *testing.T) { - flags := replaceFlagsForTesting(t) - _ = flags.Bool("bool", false, "") - err := SetValueForFlagName("bool", 0, map[string]struct{}{}, true, true) - require.Error(t, err) - - flags = replaceFlagsForTesting(t) - err = SetValueForFlagName("bool", false, map[string]struct{}{}, true, true) - require.Error(t, err) - - flags = replaceFlagsForTesting(t) - _ = Slice("string_slice", []string{"1", "2"}, "") - err = SetValueForFlagName("string_slice", "3", map[string]struct{}{}, true, true) - require.Error(t, err) - - flags = replaceFlagsForTesting(t) - _ = Slice("string_slice", []string{"1", "2"}, "") - err = SetValueForFlagName("string_slice", "3", map[string]struct{}{}, false, true) - require.Error(t, err) - - flags = replaceFlagsForTesting(t) - flagStructSlice := []testStruct{{Field: 1}, {Field: 2}} - SliceVar(&flagStructSlice, "struct_slice", "") - err = SetValueForFlagName("struct_slice", testStruct{Field: 3}, map[string]struct{}{}, true, true) - require.Error(t, err) - - flags = replaceFlagsForTesting(t) - flagStructSlice = []testStruct{{Field: 1}, {Field: 2}} - SliceVar(&flagStructSlice, "struct_slice", "") - err = SetValueForFlagName("struct_slice", testStruct{Field: 3}, map[string]struct{}{}, false, true) - require.Error(t, err) -} - -func TestDereferencedValueFromFlagName(t *testing.T) { - flags := replaceFlagsForTesting(t) - _ = flags.Bool("bool", false, "") - v, err := GetDereferencedValue[bool]("bool") - require.NoError(t, err) - assert.Equal(t, false, v) - - flags = replaceFlagsForTesting(t) - _ = flags.Bool("bool", true, "") - v, err = GetDereferencedValue[bool]("bool") - require.NoError(t, err) - assert.Equal(t, true, v) - - flags = replaceFlagsForTesting(t) - _ = Slice("string_slice", []string{"1", "2"}, "") - stringSlice, err := GetDereferencedValue[[]string]("string_slice") - require.NoError(t, err) - assert.Equal(t, []string{"1", "2"}, stringSlice) - - _ = Alias[[]string]("string_slice_alias", "string_slice") - stringSliceAlias, err := GetDereferencedValue[[]string]("string_slice_alias") - require.NoError(t, err) - assert.Equal(t, []string{"1", "2"}, stringSliceAlias) - - _ = Alias[[]string]("string_slice_alias_alias", "string_slice_alias") - stringSliceAliasAlias, err := GetDereferencedValue[[]string]("string_slice_alias_alias") - require.NoError(t, err) - assert.Equal(t, []string{"1", "2"}, stringSliceAliasAlias) - - flags = replaceFlagsForTesting(t) - SliceVar(&[]testStruct{{Field: 1}, {Field: 2}}, "struct_slice", "") - structSlice, err := GetDereferencedValue[[]testStruct]("struct_slice") - require.NoError(t, err) - assert.Equal(t, []testStruct{{Field: 1}, {Field: 2}}, structSlice) -} - -func TestBadDereferencedValueFromFlagName(t *testing.T) { - _ = replaceFlagsForTesting(t) - _, err := GetDereferencedValue[any]("unknown_flag") - require.Error(t, err) -} - -func TestFlagAlias(t *testing.T) { - flags := replaceFlagsForTesting(t) - s := flags.String("string", "test", "") - as := Alias[string]("string_alias", "string") - aas := Alias[string]("string_alias_alias", "string_alias") - assert.Equal(t, *s, "test") - assert.Equal(t, s, as) - assert.Equal(t, as, aas) - flags.Lookup("string").Value.Set("moo") - assert.Equal(t, *s, "moo") - flags.Lookup("string_alias").Value.Set("woof") - assert.Equal(t, *s, "woof") - flags.Lookup("string_alias_alias").Value.Set("meow") - assert.Equal(t, *s, "meow") - - asf := flags.Lookup("string_alias").Value.(*FlagAlias) - assert.Equal(t, "meow", asf.String()) - assert.Equal(t, "string", asf.AliasedName()) - assert.Equal(t, reflect.TypeOf((*string)(nil)), asf.AliasedType()) - assert.Equal(t, reflect.TypeOf((*string)(nil)), asf.YAMLTypeAlias()) - - aasf := flags.Lookup("string_alias").Value.(*FlagAlias) - assert.Equal(t, "meow", aasf.String()) - assert.Equal(t, "string", aasf.AliasedName()) - assert.Equal(t, reflect.TypeOf((*string)(nil)), aasf.AliasedType()) - assert.Equal(t, reflect.TypeOf((*string)(nil)), aasf.YAMLTypeAlias()) - - flags = replaceFlagsForTesting(t) - - flagString := flags.String("string", "test", "") - Alias[string]("string_alias", "string") - Alias[string]("string_alias2", "string") - Alias[string]("string_alias3", "string") - yamlData := ` -string: "woof" -string_alias2: "moo" -string_alias3: "oink" -string_alias: "meow" -` - err := PopulateFlagsFromData([]byte(yamlData)) - require.NoError(t, err) - assert.Equal(t, "meow", *flagString) - - flags = replaceFlagsForTesting(t) - - flagStringSlice := Slice("string_slice", []string{"test"}, "") - Alias[[]string]("string_slice_alias", "string_slice") - Alias[[]string]("string_slice_alias2", "string_slice") - Alias[[]string]("string_slice_alias3", "string_slice") - yamlData = ` -string_slice: - - "woof" -string_slice_alias2: - - "moo" -string_slice_alias3: - - "oink" - - "ribbit" -string_slice_alias: - - "meow" -` - err = PopulateFlagsFromData([]byte(yamlData)) - require.NoError(t, err) - assert.Equal(t, []string{"test", "woof", "moo", "oink", "ribbit", "meow"}, *flagStringSlice) - - flags = replaceFlagsForTesting(t) - - flagString = flags.String("string", "test", "") - Alias[string]("string_alias", "string") - yamlData = ` -string_alias: "meow" -` - err = PopulateFlagsFromData([]byte(yamlData)) - require.NoError(t, err) - assert.Equal(t, "meow", *flagString) - - flags = replaceFlagsForTesting(t) - - flagString = flags.String("string", "test", "") - Alias[string]("string_alias", "string") - flags.Set("string", "moo") - yamlData = ` -string_alias: "meow" -` - err = PopulateFlagsFromData([]byte(yamlData)) - require.NoError(t, err) - assert.Equal(t, "moo", *flagString) - - flags = replaceFlagsForTesting(t) - - flagString = flags.String("string", "test", "") - Alias[string]("string_alias", "string") - flags.Set("string_alias", "moo") - yamlData = ` -string: "meow" -` - err = PopulateFlagsFromData([]byte(yamlData)) - require.NoError(t, err) - assert.Equal(t, "moo", *flagString) - - flags = replaceFlagsForTesting(t) - - flagString = flags.String("string", "test", "") - Alias[string]("string_alias", "string") - flags.Set("string_alias", "moo") - yamlData = ` -string_alias: "meow" -` - err = PopulateFlagsFromData([]byte(yamlData)) - require.NoError(t, err) - assert.Equal(t, "moo", *flagString) -} diff --git a/server/util/flagutil/types/BUILD b/server/util/flagutil/types/BUILD new file mode 100644 index 00000000000..21bb0dc0b39 --- /dev/null +++ b/server/util/flagutil/types/BUILD @@ -0,0 +1,29 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "types", + srcs = ["types.go"], + importpath = "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types", + visibility = ["//visibility:public"], + deps = [ + "//server/util/alert", + "//server/util/flagutil/common", + "//server/util/flagutil/yaml", + "//server/util/log", + "//server/util/status", + "@in_gopkg_yaml_v3//:yaml_v3", + ], +) + +go_test( + name = "types_test", + srcs = ["types_test.go"], + embed = [":types"], + deps = [ + "//server/util/flagutil/common", + "//server/util/flagutil/yaml", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@org_golang_google_protobuf//types/known/timestamppb", + ], +) diff --git a/server/util/flagutil/types/types.go b/server/util/flagutil/types/types.go new file mode 100644 index 00000000000..773d92ad30d --- /dev/null +++ b/server/util/flagutil/types/types.go @@ -0,0 +1,214 @@ +package types + +import ( + "encoding/json" + "flag" + "fmt" + "net/url" + "reflect" + "strings" + + "github.com/buildbuddy-io/buildbuddy/server/util/alert" + "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/common" + "github.com/buildbuddy-io/buildbuddy/server/util/log" + "github.com/buildbuddy-io/buildbuddy/server/util/status" + "gopkg.in/yaml.v3" + + flagyaml "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/yaml" +) + +type SliceFlag[T any] []T + +func NewSliceFlag[T any](slice *[]T) *SliceFlag[T] { + return (*SliceFlag[T])(slice) +} + +func Slice[T any](name string, defaultValue []T, usage string) *[]T { + slice := make([]T, len(defaultValue)) + copy(slice, defaultValue) + common.DefaultFlagSet.Var(NewSliceFlag(&slice), name, usage) + return &slice +} + +func SliceVar[T any](slice *[]T, name, usage string) { + common.DefaultFlagSet.Var(NewSliceFlag(slice), name, usage) +} + +func (f *SliceFlag[T]) String() string { + switch v := any((*[]T)(f)).(type) { + case *[]string: + return strings.Join(*v, ",") + default: + b, err := json.Marshal(f) + if err != nil { + alert.UnexpectedEvent("config_cannot_marshal_struct", "err: %s", err) + return "[]" + } + return string(b) + } +} + +func (f *SliceFlag[T]) Set(values string) error { + if v, ok := any((*[]T)(f)).(*[]string); ok { + for _, val := range strings.Split(values, ",") { + *v = append(*v, val) + } + return nil + } + v := (*[]T)(f) + var a any + if err := json.Unmarshal([]byte(values), &a); err != nil { + return err + } + if _, ok := a.([]any); ok { + var dst []T + if err := json.Unmarshal([]byte(values), &dst); err != nil { + return err + } + *v = append(*v, dst...) + return nil + } + if _, ok := a.(map[string]any); ok { + var dst T + if err := json.Unmarshal([]byte(values), &dst); err != nil { + return err + } + *v = append(*v, dst) + return nil + } + return fmt.Errorf("Default Set for SliceFlag can only accept JSON objects or arrays, but type was %T", a) +} + +func (f *SliceFlag[T]) AppendSlice(slice any) error { + s, ok := slice.([]T) + if !ok { + return status.FailedPreconditionErrorf("Cannot append value %v of type %T to flag of type %T.", slice, slice, ([]T)(nil)) + } + v := (*[]T)(f) + *v = append(*v, s...) + return nil +} + +func (f *SliceFlag[T]) AliasedType() reflect.Type { + return reflect.TypeOf((*[]T)(nil)) +} + +func (f *SliceFlag[T]) YAMLTypeAlias() reflect.Type { + return f.AliasedType() +} + +type URLFlag url.URL + +func URL(name string, value url.URL, usage string) *url.URL { + u := &value + common.DefaultFlagSet.Var((*URLFlag)(u), name, usage) + return u +} + +func URLVar(value *url.URL, name string, usage string) { + common.DefaultFlagSet.Var((*URLFlag)(value), name, usage) +} + +func URLFromString(name, value, usage string) *url.URL { + u, err := url.Parse(value) + if err != nil { + log.Fatalf("Error parsing default URL value '%s' for flag: %v", value, err) + return nil + } + return URL(name, *u, usage) +} + +func (f *URLFlag) Set(value string) error { + u, err := url.Parse(value) + if err != nil { + return err + } + *(*url.URL)(f) = *u + return nil +} + +func (f *URLFlag) String() string { + return (*url.URL)(f).String() +} + +func (f *URLFlag) UnmarshalYAML(value *yaml.Node) error { + u, err := url.Parse(value.Value) + if err != nil { + return &yaml.TypeError{Errors: []string{err.Error()}} + } + *(*url.URL)(f) = *u + return nil +} + +func (f *URLFlag) MarshalYAML() (any, error) { + return f.String(), nil +} + +func (f *URLFlag) AliasedType() reflect.Type { + return reflect.TypeOf((*url.URL)(nil)) +} + +func (f *URLFlag) YAMLTypeAlias() reflect.Type { + return reflect.TypeOf((*URLFlag)(nil)) +} + +func (f *URLFlag) YAMLTypeString() string { + return "URL" +} + +type FlagAlias struct { + name string +} + +func Alias[T any](newName, name string) *T { + f := &FlagAlias{name: name} + var flg *flag.Flag + for aliaser, ok := common.IsNameAliasing(f), true; ok; aliaser, ok = flg.Value.(common.IsNameAliasing) { + if flg = common.DefaultFlagSet.Lookup(aliaser.AliasedName()); flg == nil { + log.Fatalf("Error aliasing flag %s as %s: flag %s does not exist.", name, newName, aliaser.AliasedName()) + } + } + addr := reflect.ValueOf(flg.Value) + if t, err := common.GetTypeForFlag(flg); err == nil { + if !addr.CanConvert(t) { + log.Fatalf("Error aliasing flag %s as %s: Flag %s of type %T could not be converted to %s.", name, newName, flg.Name, flg.Value, t) + } + addr = addr.Convert(t) + } + value, ok := addr.Interface().(*T) + if !ok { + log.Fatalf("Error aliasing flag %s as %s: Failed to assert flag %s of type %T as type %T.", name, newName, flg.Name, flg.Value, (*T)(nil)) + } + common.DefaultFlagSet.Var(f, newName, "Alias for "+name) + return value +} + +func (f *FlagAlias) Set(value string) error { + return common.DefaultFlagSet.Set(f.name, value) +} + +func (f *FlagAlias) String() string { + return common.DefaultFlagSet.Lookup(f.name).Value.String() +} + +func (f *FlagAlias) AliasedName() string { + return f.name +} + +func (f *FlagAlias) AliasedType() reflect.Type { + flg := common.DefaultFlagSet.Lookup(f.name) + t, err := common.GetTypeForFlag(flg) + if err != nil { + return reflect.TypeOf(flg.Value) + } + return t +} + +func (f *FlagAlias) YAMLTypeAlias() reflect.Type { + flg := common.DefaultFlagSet.Lookup(f.name) + t, err := flagyaml.GetYAMLTypeForFlag(flg) + if err != nil { + return reflect.TypeOf(flg.Value) + } + return t +} diff --git a/server/util/flagutil/types/types_test.go b/server/util/flagutil/types/types_test.go new file mode 100644 index 00000000000..5adc630c7e1 --- /dev/null +++ b/server/util/flagutil/types/types_test.go @@ -0,0 +1,375 @@ +package types + +import ( + "flag" + "reflect" + "testing" + + "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" + + flagyaml "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/yaml" +) + +type testStruct struct { + Field int `json:"field"` + Meadow string `json:"meadow"` +} + +func replaceFlagsForTesting(t *testing.T) *flag.FlagSet { + flags := flag.NewFlagSet("test", flag.ContinueOnError) + common.DefaultFlagSet = flags + + t.Cleanup(func() { + common.DefaultFlagSet = flag.CommandLine + }) + + return flags +} + +func TestStringSliceFlag(t *testing.T) { + var err error + + flags := replaceFlagsForTesting(t) + + foo := Slice("foo", []string{}, "A list of foos") + assert.Equal(t, []string{}, *foo) + assert.Equal(t, []string{}, *(*[]string)(flags.Lookup("foo").Value.(*SliceFlag[string]))) + err = flags.Set("foo", "foo0,foo1") + assert.NoError(t, err) + err = flags.Set("foo", "foo2") + assert.NoError(t, err) + err = flags.Set("foo", "foo3,foo4,foo5") + assert.NoError(t, err) + assert.Equal(t, []string{"foo0", "foo1", "foo2", "foo3", "foo4", "foo5"}, *foo) + assert.Equal(t, []string{"foo0", "foo1", "foo2", "foo3", "foo4", "foo5"}, *(*[]string)(flags.Lookup("foo").Value.(*SliceFlag[string]))) + + bar := Slice("bar", []string{"bar0", "bar1"}, "A list of bars") + assert.Equal(t, []string{"bar0", "bar1"}, *bar) + assert.Equal(t, []string{"bar0", "bar1"}, *(*[]string)(flags.Lookup("bar").Value.(*SliceFlag[string]))) + err = flags.Set("bar", "bar2") + assert.NoError(t, err) + err = flags.Set("bar", "bar3,bar4,bar5") + assert.NoError(t, err) + assert.Equal(t, []string{"bar0", "bar1", "bar2", "bar3", "bar4", "bar5"}, *bar) + assert.Equal(t, []string{"bar0", "bar1", "bar2", "bar3", "bar4", "bar5"}, *(*[]string)(flags.Lookup("bar").Value.(*SliceFlag[string]))) + + baz := Slice("baz", []string{}, "A list of bazs") + err = flags.Set("baz", flags.Lookup("bar").Value.String()) + assert.NoError(t, err) + assert.Equal(t, *bar, *baz) + + testSlice := []string{"yes", "si", "hai"} + testFlag := NewSliceFlag(&testSlice) + testFlag.AppendSlice(*(*[]string)(testFlag)) + assert.Equal(t, []string{"yes", "si", "hai", "yes", "si", "hai"}, testSlice) +} + +func TestStructSliceFlag(t *testing.T) { + var err error + + flags := replaceFlagsForTesting(t) + + fooFlag := Slice("foo", []testStruct{}, "A list of foos") + assert.Equal(t, []testStruct{}, *fooFlag) + assert.Equal(t, []testStruct{}, *(*[]testStruct)(flags.Lookup("foo").Value.(*SliceFlag[testStruct]))) + err = flags.Set("foo", `[{"field":3,"meadow":"watership down"}]`) + assert.NoError(t, err) + assert.Equal(t, []testStruct{{Field: 3, Meadow: "watership down"}}, *fooFlag) + assert.Equal(t, []testStruct{{Field: 3, Meadow: "watership down"}}, *(*[]testStruct)(flags.Lookup("foo").Value.(*SliceFlag[testStruct]))) + err = flags.Set("foo", `{"field":5,"meadow":"runnymede"}`) + assert.NoError(t, err) + assert.Equal(t, []testStruct{{Field: 3, Meadow: "watership down"}, {Field: 5, Meadow: "runnymede"}}, *fooFlag) + assert.Equal(t, []testStruct{{Field: 3, Meadow: "watership down"}, {Field: 5, Meadow: "runnymede"}}, *(*[]testStruct)(flags.Lookup("foo").Value.(*SliceFlag[testStruct]))) + + barFlag := []testStruct{{Field: 11, Meadow: "arcadia"}, {Field: 13, Meadow: "kingcombe"}} + SliceVar(&barFlag, "bar", "A list of bars") + assert.Equal(t, []testStruct{{Field: 11, Meadow: "arcadia"}, {Field: 13, Meadow: "kingcombe"}}, barFlag) + assert.Equal(t, []testStruct{{Field: 11, Meadow: "arcadia"}, {Field: 13, Meadow: "kingcombe"}}, *(*[]testStruct)(flags.Lookup("bar").Value.(*SliceFlag[testStruct]))) + + fooxFlag := Slice("foox", []testStruct{}, "A list of fooxes") + assert.Equal(t, []testStruct{}, *fooxFlag) + assert.Equal(t, []testStruct{}, *(*[]testStruct)(flags.Lookup("foox").Value.(*SliceFlag[testStruct]))) + err = flags.Set("foox", `[{"field":13,"meadow":"cors y llyn"},{},{"field":15}]`) + assert.NoError(t, err) + assert.Equal(t, []testStruct{{Field: 13, Meadow: "cors y llyn"}, {}, {Field: 15}}, *fooxFlag) + assert.Equal(t, []testStruct{{Field: 13, Meadow: "cors y llyn"}, {}, {Field: 15}}, *(*[]testStruct)(flags.Lookup("foox").Value.(*SliceFlag[testStruct]))) + err = flags.Set("foox", `[{"field":17,"meadow":"red hill"},{},{"field":19}]`) + assert.NoError(t, err) + assert.Equal(t, []testStruct{{Field: 13, Meadow: "cors y llyn"}, {}, {Field: 15}, {Field: 17, Meadow: "red hill"}, {}, {Field: 19}}, *fooxFlag) + assert.Equal(t, []testStruct{{Field: 13, Meadow: "cors y llyn"}, {}, {Field: 15}, {Field: 17, Meadow: "red hill"}, {}, {Field: 19}}, *(*[]testStruct)(flags.Lookup("foox").Value.(*SliceFlag[testStruct]))) + + bazFlag := []testStruct{} + SliceVar(&bazFlag, "baz", "A list of bazs") + err = flags.Set("baz", flags.Lookup("bar").Value.String()) + assert.NoError(t, err) + assert.Equal(t, barFlag, bazFlag) + + testSlice := []testStruct{{}, {Field: 1}, {Meadow: "Paradise"}} + testFlag := NewSliceFlag(&testSlice) + testFlag.AppendSlice(*(*[]testStruct)(testFlag)) + assert.Equal(t, []testStruct{{}, {Field: 1}, {Meadow: "Paradise"}, {}, {Field: 1}, {Meadow: "Paradise"}}, testSlice) +} + +func TestProtoSliceFlag(t *testing.T) { + var err error + + flags := replaceFlagsForTesting(t) + + fooFlag := Slice("foo", []*timestamppb.Timestamp{}, "A list of foos") + assert.Equal(t, []*timestamppb.Timestamp{}, *fooFlag) + assert.Equal(t, []*timestamppb.Timestamp{}, *(*[]*timestamppb.Timestamp)(flags.Lookup("foo").Value.(*SliceFlag[*timestamppb.Timestamp]))) + err = flags.Set("foo", `[{"seconds":3,"nanos":5}]`) + assert.NoError(t, err) + assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 3, Nanos: 5}}, *fooFlag) + assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 3, Nanos: 5}}, *(*[]*timestamppb.Timestamp)(flags.Lookup("foo").Value.(*SliceFlag[*timestamppb.Timestamp]))) + err = flags.Set("foo", `{"seconds":5,"nanos":9}`) + assert.NoError(t, err) + assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 3, Nanos: 5}, {Seconds: 5, Nanos: 9}}, *fooFlag) + assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 3, Nanos: 5}, {Seconds: 5, Nanos: 9}}, *(*[]*timestamppb.Timestamp)(flags.Lookup("foo").Value.(*SliceFlag[*timestamppb.Timestamp]))) + + barFlag := []*timestamppb.Timestamp{{Seconds: 11, Nanos: 100}, {Seconds: 13, Nanos: 256}} + SliceVar(&barFlag, "bar", "A list of bars") + assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 11, Nanos: 100}, {Seconds: 13, Nanos: 256}}, barFlag) + assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 11, Nanos: 100}, {Seconds: 13, Nanos: 256}}, *(*[]*timestamppb.Timestamp)(flags.Lookup("bar").Value.(*SliceFlag[*timestamppb.Timestamp]))) + + fooxFlag := Slice("foox", []*timestamppb.Timestamp{}, "A list of fooxes") + assert.Equal(t, []*timestamppb.Timestamp{}, *fooxFlag) + assert.Equal(t, []*timestamppb.Timestamp{}, *(*[]*timestamppb.Timestamp)(flags.Lookup("foox").Value.(*SliceFlag[*timestamppb.Timestamp]))) + err = flags.Set("foox", `[{"seconds":13,"nanos":64},{},{"seconds":15}]`) + assert.NoError(t, err) + assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 13, Nanos: 64}, {}, {Seconds: 15}}, *fooxFlag) + assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 13, Nanos: 64}, {}, {Seconds: 15}}, *(*[]*timestamppb.Timestamp)(flags.Lookup("foox").Value.(*SliceFlag[*timestamppb.Timestamp]))) + err = flags.Set("foox", `[{"seconds":17,"nanos":9001},{},{"seconds":19}]`) + assert.NoError(t, err) + assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 13, Nanos: 64}, {}, {Seconds: 15}, {Seconds: 17, Nanos: 9001}, {}, {Seconds: 19}}, *fooxFlag) + assert.Equal(t, []*timestamppb.Timestamp{{Seconds: 13, Nanos: 64}, {}, {Seconds: 15}, {Seconds: 17, Nanos: 9001}, {}, {Seconds: 19}}, *(*[]*timestamppb.Timestamp)(flags.Lookup("foox").Value.(*SliceFlag[*timestamppb.Timestamp]))) + + bazFlag := []*timestamppb.Timestamp{} + SliceVar(&bazFlag, "baz", "A list of bazs") + err = flags.Set("baz", flags.Lookup("bar").Value.String()) + assert.NoError(t, err) + assert.Equal(t, barFlag, bazFlag) + + testSlice := []*timestamppb.Timestamp{{}, {Seconds: 1}, {Nanos: 99}} + testFlag := NewSliceFlag(&testSlice) + testFlag.AppendSlice(*(*[]*timestamppb.Timestamp)(testFlag)) + assert.Equal(t, []*timestamppb.Timestamp{{}, {Seconds: 1}, {Nanos: 99}, {}, {Seconds: 1}, {Nanos: 99}}, testSlice) + +} + +func TestFlagAlias(t *testing.T) { + flags := replaceFlagsForTesting(t) + s := flags.String("string", "test", "") + as := Alias[string]("string_alias", "string") + aas := Alias[string]("string_alias_alias", "string_alias") + assert.Equal(t, *s, "test") + assert.Equal(t, s, as) + assert.Equal(t, as, aas) + flags.Lookup("string").Value.Set("moo") + assert.Equal(t, *s, "moo") + flags.Lookup("string_alias").Value.Set("woof") + assert.Equal(t, *s, "woof") + flags.Lookup("string_alias_alias").Value.Set("meow") + assert.Equal(t, *s, "meow") + + asf := flags.Lookup("string_alias").Value.(*FlagAlias) + assert.Equal(t, "meow", asf.String()) + assert.Equal(t, "string", asf.AliasedName()) + assert.Equal(t, reflect.TypeOf((*string)(nil)), asf.AliasedType()) + assert.Equal(t, reflect.TypeOf((*string)(nil)), asf.YAMLTypeAlias()) + + aasf := flags.Lookup("string_alias").Value.(*FlagAlias) + assert.Equal(t, "meow", aasf.String()) + assert.Equal(t, "string", aasf.AliasedName()) + assert.Equal(t, reflect.TypeOf((*string)(nil)), aasf.AliasedType()) + assert.Equal(t, reflect.TypeOf((*string)(nil)), aasf.YAMLTypeAlias()) + + flags = replaceFlagsForTesting(t) + + flagString := flags.String("string", "test", "") + Alias[string]("string_alias", "string") + Alias[string]("string_alias2", "string") + Alias[string]("string_alias3", "string") + yamlData := ` +string: "woof" +string_alias2: "moo" +string_alias3: "oink" +string_alias: "meow" +` + err := flagyaml.PopulateFlagsFromData([]byte(yamlData)) + require.NoError(t, err) + assert.Equal(t, "meow", *flagString) + + flags = replaceFlagsForTesting(t) + + flagStringSlice := Slice("string_slice", []string{"test"}, "") + Alias[[]string]("string_slice_alias", "string_slice") + Alias[[]string]("string_slice_alias2", "string_slice") + Alias[[]string]("string_slice_alias3", "string_slice") + yamlData = ` +string_slice: + - "woof" +string_slice_alias2: + - "moo" +string_slice_alias3: + - "oink" + - "ribbit" +string_slice_alias: + - "meow" +` + err = flagyaml.PopulateFlagsFromData([]byte(yamlData)) + require.NoError(t, err) + assert.Equal(t, []string{"test", "woof", "moo", "oink", "ribbit", "meow"}, *flagStringSlice) + + flags = replaceFlagsForTesting(t) + + flagString = flags.String("string", "test", "") + Alias[string]("string_alias", "string") + yamlData = ` +string_alias: "meow" +` + err = flagyaml.PopulateFlagsFromData([]byte(yamlData)) + require.NoError(t, err) + assert.Equal(t, "meow", *flagString) + + flags = replaceFlagsForTesting(t) + + flagString = flags.String("string", "test", "") + Alias[string]("string_alias", "string") + flags.Set("string", "moo") + yamlData = ` +string_alias: "meow" +` + err = flagyaml.PopulateFlagsFromData([]byte(yamlData)) + require.NoError(t, err) + assert.Equal(t, "moo", *flagString) + + flags = replaceFlagsForTesting(t) + + flagString = flags.String("string", "test", "") + Alias[string]("string_alias", "string") + flags.Set("string_alias", "moo") + yamlData = ` +string: "meow" +` + err = flagyaml.PopulateFlagsFromData([]byte(yamlData)) + require.NoError(t, err) + assert.Equal(t, "moo", *flagString) + + flags = replaceFlagsForTesting(t) + + flagString = flags.String("string", "test", "") + Alias[string]("string_alias", "string") + flags.Set("string_alias", "moo") + yamlData = ` +string_alias: "meow" +` + err = flagyaml.PopulateFlagsFromData([]byte(yamlData)) + require.NoError(t, err) + assert.Equal(t, "moo", *flagString) + + flags = replaceFlagsForTesting(t) + flagString = flags.String("string", "2", "") + Alias[string]("string_alias", "string") + err = common.SetValueForFlagName("string_alias", "1", map[string]struct{}{}, true, true) + require.NoError(t, err) + assert.Equal(t, "1", *flagString) + + flags = replaceFlagsForTesting(t) + flagString = flags.String("string", "2", "") + Alias[string]("string_alias", "string") + err = common.SetValueForFlagName("string_alias", "1", map[string]struct{}{"string": {}}, true, true) + require.NoError(t, err) + assert.Equal(t, "2", *flagString) + + flags = replaceFlagsForTesting(t) + string_slice := make([]string, 2) + string_slice[0] = "1" + string_slice[1] = "2" + SliceVar(&string_slice, "string_slice", "") + Alias[[]string]("string_slice_alias", "string_slice") + err = common.SetValueForFlagName("string_slice_alias", []string{"3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, map[string]struct{}{}, true, true) + require.NoError(t, err) + assert.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, string_slice) + + flags = replaceFlagsForTesting(t) + flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") + Alias[[]string]("string_slice_alias", "string_slice") + err = common.SetValueForFlagName("string_slice_alias", []string{"3"}, map[string]struct{}{"string_slice": {}}, true, true) + require.NoError(t, err) + assert.Equal(t, []string{"1", "2", "3"}, *flagStringSlice) + + flags = replaceFlagsForTesting(t) + flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") + Alias[[]string]("string_slice_alias", "string_slice") + err = common.SetValueForFlagName("string_slice_alias", []string{"3"}, map[string]struct{}{}, false, true) + require.NoError(t, err) + assert.Equal(t, []string{"3"}, *flagStringSlice) + + flags = replaceFlagsForTesting(t) + flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") + Alias[[]string]("string_slice_alias", "string_slice") + err = common.SetValueForFlagName("string_slice_alias", []string{"3"}, map[string]struct{}{"string_slice": {}}, false, true) + require.NoError(t, err) + assert.Equal(t, []string{"1", "2"}, *flagStringSlice) + + flags = replaceFlagsForTesting(t) + string_slice = make([]string, 2) + string_slice[0] = "1" + string_slice[1] = "2" + SliceVar(&string_slice, "string_slice", "") + Alias[[]string]("string_slice_alias", "string_slice") + Alias[[]string]("string_slice_alias_alias", "string_slice_alias") + err = common.SetValueForFlagName("string_slice_alias_alias", []string{"3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, map[string]struct{}{}, true, true) + require.NoError(t, err) + assert.Equal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2"}, string_slice) + + flags = replaceFlagsForTesting(t) + flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") + Alias[[]string]("string_slice_alias", "string_slice") + Alias[[]string]("string_slice_alias_alias", "string_slice_alias") + err = common.SetValueForFlagName("string_slice_alias_alias", []string{"3"}, map[string]struct{}{"string_slice": {}}, true, true) + require.NoError(t, err) + assert.Equal(t, []string{"1", "2", "3"}, *flagStringSlice) + + flags = replaceFlagsForTesting(t) + flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") + Alias[[]string]("string_slice_alias", "string_slice") + Alias[[]string]("string_slice_alias_alias", "string_slice_alias") + err = common.SetValueForFlagName("string_slice_alias_alias", []string{"3"}, map[string]struct{}{}, false, true) + require.NoError(t, err) + assert.Equal(t, []string{"3"}, *flagStringSlice) + + flags = replaceFlagsForTesting(t) + flagStringSlice = Slice("string_slice", []string{"1", "2"}, "") + Alias[[]string]("string_slice_alias", "string_slice") + Alias[[]string]("string_slice_alias_alias", "string_slice_alias") + err = common.SetValueForFlagName("string_slice_alias_alias", []string{"3"}, map[string]struct{}{"string_slice": {}}, false, true) + require.NoError(t, err) + assert.Equal(t, []string{"1", "2"}, *flagStringSlice) + + flags = replaceFlagsForTesting(t) + _ = Slice("string_slice", []string{"1", "2"}, "") + stringSlice, err := common.GetDereferencedValue[[]string]("string_slice") + require.NoError(t, err) + assert.Equal(t, []string{"1", "2"}, stringSlice) + + _ = Alias[[]string]("string_slice_alias", "string_slice") + stringSliceAlias, err := common.GetDereferencedValue[[]string]("string_slice_alias") + require.NoError(t, err) + assert.Equal(t, []string{"1", "2"}, stringSliceAlias) + + _ = Alias[[]string]("string_slice_alias_alias", "string_slice_alias") + stringSliceAliasAlias, err := common.GetDereferencedValue[[]string]("string_slice_alias_alias") + require.NoError(t, err) + assert.Equal(t, []string{"1", "2"}, stringSliceAliasAlias) + + flags = replaceFlagsForTesting(t) + SliceVar(&[]testStruct{{Field: 1}, {Field: 2}}, "struct_slice", "") + structSlice, err := common.GetDereferencedValue[[]testStruct]("struct_slice") + require.NoError(t, err) + assert.Equal(t, []testStruct{{Field: 1}, {Field: 2}}, structSlice) +} diff --git a/server/util/flagutil/yaml/BUILD b/server/util/flagutil/yaml/BUILD new file mode 100644 index 00000000000..a653401f01d --- /dev/null +++ b/server/util/flagutil/yaml/BUILD @@ -0,0 +1,29 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "yaml", + srcs = ["yaml.go"], + importpath = "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/yaml", + visibility = ["//visibility:public"], + deps = [ + "//server/util/alert", + "//server/util/flagutil/common", + "//server/util/log", + "//server/util/status", + "@in_gopkg_yaml_v3//:yaml_v3", + ], +) + +go_test( + name = "yaml_test", + srcs = ["yaml_test.go"], + deps = [ + ":yaml", + "//server/util/flagutil/common", + "//server/util/flagutil/types", + "@com_github_google_go_cmp//cmp", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@in_gopkg_yaml_v3//:yaml_v3", + ], +) diff --git a/server/util/flagutil/yaml/yaml.go b/server/util/flagutil/yaml/yaml.go new file mode 100644 index 00000000000..255e09d288f --- /dev/null +++ b/server/util/flagutil/yaml/yaml.go @@ -0,0 +1,506 @@ +package yaml + +import ( + "bytes" + "flag" + "fmt" + "os" + "reflect" + "strings" + + "github.com/buildbuddy-io/buildbuddy/server/util/alert" + "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/common" + "github.com/buildbuddy-io/buildbuddy/server/util/log" + "github.com/buildbuddy-io/buildbuddy/server/util/status" + "gopkg.in/yaml.v3" +) + +var ( + // Flag names to ignore when generating a YAML map or populating flags (e. g., + // the flag specifying the path to the config file) + ignoreSet = make(map[string]struct{}) + + nilableKinds = map[reflect.Kind]struct{}{ + reflect.Chan: {}, + reflect.Func: {}, + reflect.Interface: {}, + reflect.Map: {}, + reflect.Ptr: {}, + reflect.Slice: {}, + } + + AppendTypeToLineComment = &appendTypeToLineComment{} +) + +// IgnoreFlagForYAML ignores the flag with this name when generating YAML and when +// populating flags from YAML input. +func IgnoreFlagForYAML(name string) { + ignoreSet[name] = struct{}{} +} + +// IgnoreFilter is a filter that checks flags against IgnoreSet. +func IgnoreFilter(flg *flag.Flag) bool { + keys := strings.Split(flg.Name, ".") + for i := range keys { + if _, ok := ignoreSet[strings.Join(keys[:i+1], ".")]; ok { + return false + } + } + return true +} + +type YAMLTypeAliasable interface { + YAMLTypeAlias() reflect.Type +} + +type YAMLTypeStringable interface { + YAMLTypeString() string +} + +type DocumentedMarshaler interface { + DocumentNode(n *yaml.Node, opts ...common.DocumentNodeOption) error +} + +// GetYAMLTypeForFlag returns the type alias to use in YAML contexts for the flag. +func GetYAMLTypeForFlag(flg *flag.Flag) (reflect.Type, error) { + if v, ok := flg.Value.(YAMLTypeAliasable); ok { + return v.YAMLTypeAlias(), nil + } else if t, err := common.GetTypeForFlag(flg); err == nil { + return t, nil + } + return nil, status.UnimplementedErrorf("Unsupported flag type at %s: %T", flg.Name, flg.Value) +} + +type HeadComment string + +func (h *HeadComment) Transform(in any, n *yaml.Node) { n.HeadComment = string(*h) } +func (h *HeadComment) Passthrough() bool { return false } + +// NewHeadComment returns a HeadComment for the specified string. +func NewHeadComment(s string) *HeadComment { return (*HeadComment)(&s) } + +type LineComment string + +func (l *LineComment) Transform(in any, n *yaml.Node) { n.LineComment = string(*l) } +func (l *LineComment) Passthrough() bool { return false } + +// NewLineComment returns a LineComment for the specified string. +func NewLineComment(s string) *LineComment { return (*LineComment)(&s) } + +type FootComment string + +func (f *FootComment) Transform(in any, n *yaml.Node) { n.FootComment = string(*f) } +func (f *FootComment) Passthrough() bool { return false } + +// NewFootComment returns a FootComment for the specified string. +func NewFootComment(s string) *FootComment { return (*FootComment)(&s) } + +type appendTypeToLineComment struct{} + +func (f *appendTypeToLineComment) Transform(in any, n *yaml.Node) { + typeString := fmt.Sprintf("%T", in) + if v, ok := in.(YAMLTypeStringable); ok { + typeString = v.YAMLTypeString() + } + if n.LineComment != "" { + n.LineComment += " " + } + n.LineComment += "type: " + typeString +} + +func (f *appendTypeToLineComment) Passthrough() bool { return true } + +func filterPassthrough(opts []common.DocumentNodeOption) []common.DocumentNodeOption { + ptOpts := []common.DocumentNodeOption{} + for _, opt := range opts { + if opt.Passthrough() { + ptOpts = append(ptOpts, opt) + } + } + return ptOpts +} + +// DocumentedNode returns a yaml.Node representing the input value with +// documentation in the comments. +func DocumentedNode(in any, opts ...common.DocumentNodeOption) (*yaml.Node, error) { + n := &yaml.Node{} + if err := n.Encode(in); err != nil { + return nil, err + } + if err := DocumentNode(in, n, opts...); err != nil { + return nil, err + } + return n, nil +} + +// DocumentNode fills the comments of a yaml.Node with documentation. +func DocumentNode(in any, n *yaml.Node, opts ...common.DocumentNodeOption) error { + switch m := in.(type) { + case DocumentedMarshaler: + return m.DocumentNode(n, opts...) + case yaml.Marshaler: + // pass + default: + v := reflect.ValueOf(in) + t := v.Type() + switch t.Kind() { + case reflect.Ptr: + // document based on the value pointed to + if !v.IsNil() { + return DocumentNode(v.Elem().Interface(), n, opts...) + } else { + return DocumentNode(reflect.New(reflect.TypeOf(t).Elem()).Elem().Interface(), n, opts...) + } + case reflect.Struct: + // yaml.Node stores mappings in Content as [key1, value1, key2, value2...] + contentIndex := make(map[string]int, len(n.Content)/2) + for i := 0; i < len(n.Content)/2; i++ { + contentIndex[n.Content[2*i].Value] = 2*i + 1 + } + for i := 0; i < t.NumField(); i++ { + ft := t.FieldByIndex([]int{i}) + name := strings.Split(ft.Tag.Get("yaml"), ",")[0] + if name == "" { + name = strings.ToLower(ft.Name) + } + idx, ok := contentIndex[name] + if !ok { + // field is not encoded by yaml + continue + } + if err := DocumentNode( + v.FieldByIndex([]int{i}).Interface(), + n.Content[idx], + append( + []common.DocumentNodeOption{NewLineComment(ft.Tag.Get("usage"))}, + filterPassthrough(opts)..., + )..., + ); err != nil { + return err + } + } + case reflect.Slice: + // yaml.Node stores sequences in Content as [element1, element2...] + for i := range n.Content { + var err error + if err = DocumentNode(v.Index(i).Interface(), n.Content[i], filterPassthrough(opts)...); err != nil { + return err + } + } + if len(n.Content) == 0 { + exampleNode, err := DocumentedNode(reflect.MakeSlice(t, 1, 1).Interface(), filterPassthrough(opts)...) + if err != nil { + return err + } + if exampleNode.Content[0].Kind != yaml.ScalarNode { + example, err := yaml.Marshal(exampleNode) + if err != nil { + return err + } + n.FootComment = fmt.Sprintf("e.g.,\n%s", string(example)) + } + } + case reflect.Map: + // yaml.Node stores mappings in Content as [key1, value1, key2, value2...] + for i := 0; i < len(n.Content)/2; i++ { + k := reflect.ValueOf(n.Content[2*i].Value) + if err := DocumentNode( + v.MapIndex(k).Interface(), + n.Content[2*i+1], + filterPassthrough(opts)..., + ); err != nil { + return err + } + } + } + } + for _, opt := range opts { + opt.Transform(in, n) + } + return nil +} + +// GenerateDocumentedYAMLNodeFromFlag produces a documented yaml.Node which +// represents the value contained in the flag. +func GenerateDocumentedYAMLNodeFromFlag(flg *flag.Flag) (*yaml.Node, error) { + t, err := GetYAMLTypeForFlag(flg) + if err != nil { + return nil, status.InternalErrorf("Error encountered generating default YAML from flags: %s", err) + } + v, err := common.GetDereferencedValue[any](flg.Name) + if err != nil { + return nil, status.InternalErrorf("Error encountered generating default YAML from flags: %s", err) + } + value := reflect.New(reflect.TypeOf(v)) + value.Elem().Set(reflect.ValueOf(v)) + if !value.CanConvert(t) { + return nil, status.FailedPreconditionErrorf("Cannot convert value %v of type %s into type %v for flag %s.", value.Interface(), value.Type(), t, flg.Name) + } + return DocumentedNode(value.Convert(t).Interface(), NewLineComment(flg.Usage), AppendTypeToLineComment) +} + +// SplitDocumentedYAMLFromFlags produces marshaled YAML representing the flags, +// partitioned into two groups: structured (flags containing dots), and +// unstructured (flags not containing dots). +func SplitDocumentedYAMLFromFlags() ([]byte, error) { + b := bytes.NewBuffer([]byte{}) + + if _, err := b.Write([]byte("# Unstructured settings\n\n")); err != nil { + return nil, err + } + um, err := GenerateYAMLMapWithValuesFromFlags( + GenerateDocumentedYAMLNodeFromFlag, + func(flg *flag.Flag) bool { return !strings.Contains(flg.Name, ".") }, + IgnoreFilter, + ) + if err != nil { + return nil, err + } + ub, err := yaml.Marshal(um) + if err != nil { + return nil, err + } + if _, err := b.Write(ub); err != nil { + return nil, err + } + + if _, err := b.Write([]byte("\n# Structured settings\n\n")); err != nil { + return nil, err + } + sm, err := GenerateYAMLMapWithValuesFromFlags( + GenerateDocumentedYAMLNodeFromFlag, + func(flg *flag.Flag) bool { return strings.Contains(flg.Name, ".") }, + IgnoreFilter, + ) + if err != nil { + return nil, err + } + sb, err := yaml.Marshal(sm) + if err != nil { + return nil, err + } + if _, err := b.Write(sb); err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +// GenerateYAMLMapWithValuesFromFlags generates a YAML map structure +// representing the flags, with values generated from the flags as per the +// generateValue function that has been passed in, and filtering out any flags +// for which any of the passed filter functions return false. Any nil generated +// values are not added to the map, and any empty maps are recursively removed +// such that the final map returned contains no empty maps at any point in its +// structure. +func GenerateYAMLMapWithValuesFromFlags[T any](generateValue func(*flag.Flag) (T, error), filters ...func(*flag.Flag) bool) (map[string]any, error) { + yamlMap := make(map[string]any) + var errors []error + common.DefaultFlagSet.VisitAll(func(flg *flag.Flag) { + for _, f := range filters { + if !f(flg) { + return + } + } + keys := strings.Split(flg.Name, ".") + m := yamlMap + for i, k := range keys[:len(keys)-1] { + v, ok := m[k] + if !ok { + v := make(map[string]any) + m[k], m = v, v + continue + } + m, ok = v.(map[string]any) + if !ok { + errors = append(errors, status.FailedPreconditionErrorf("When trying to create YAML map hierarchy for %s, encountered non-map value %s of type %T at %s", flg.Name, v, v, strings.Join(keys[:i+1], "."))) + return + } + } + k := keys[len(keys)-1] + if v, ok := m[k]; ok { + errors = append(errors, status.FailedPreconditionErrorf("When generating value for %s for YAML map, encountered pre-existing value %s of type %T.", flg.Name, v, v)) + return + } + v, err := generateValue(flg) + if err != nil { + errors = append(errors, err) + return + } + value := reflect.ValueOf(v) + if _, ok := nilableKinds[value.Kind()]; ok && value.IsNil() { + return + } + m[k] = v + }) + if errors != nil { + return nil, status.InternalErrorf("Errors encountered when generating YAML map from flags: %v", errors) + } + + return RemoveEmptyMapsFromYAMLMap(yamlMap), nil +} + +// RemoveEmptyMapsFromYAMLMap recursively removes all empty maps, such that the +// returned map contains no empty maps at any point in its structure. The +// original map is returned unless it is empty after removal, in which case nil +// is returned. +func RemoveEmptyMapsFromYAMLMap(m map[string]any) map[string]any { + for k, v := range m { + mv, ok := v.(map[string]any) + if !ok { + continue + } + if m[k] = RemoveEmptyMapsFromYAMLMap(mv); m[k] == nil { + delete(m, k) + } + } + if len(m) == 0 { + return nil + } + return m +} + +// RetypeAndFilterYAMLMap un-marshals yaml from the input yamlMap and then +// re-marshals it into the types specified by the type map, replacing the +// original value in the input map. Filters out any values not specified by the +// flags. +func RetypeAndFilterYAMLMap(yamlMap map[string]any, typeMap map[string]any, prefix []string) error { + for k := range yamlMap { + label := append(prefix, k) + if _, ok := typeMap[k]; !ok { + // No flag corresponds to this, warn and delete. + log.Warningf("No flags correspond to YAML input at '%s'.", strings.Join(label, ".")) + delete(yamlMap, k) + continue + } + switch t := typeMap[k].(type) { + case reflect.Type: + // this is a value, populate it from the YAML + yamlData, err := yaml.Marshal(yamlMap[k]) + if err != nil { + return status.InternalErrorf("Encountered error marshaling %v to YAML at %s: %s", yamlMap[k], strings.Join(label, "."), err) + } + v := reflect.New(t.Elem()).Elem() + err = yaml.Unmarshal(yamlData, v.Addr().Interface()) + if err != nil { + return status.InternalErrorf("Encountered error marshaling %s to YAML for type %v at %s: %s", string(yamlData), v.Type(), strings.Join(label, "."), err) + } + if v.Type() != t.Elem() { + return status.InternalErrorf("Failed to unmarshal YAML to the specified type at %s: wanted %v, got %T", strings.Join(label, "."), t.Elem(), v.Type()) + } + yamlMap[k] = v.Interface() + case map[string]any: + yamlSubmap, ok := yamlMap[k].(map[string]any) + if !ok { + // this is a value, not a map, and there is no corresponding type + alert.UnexpectedEvent("Input YAML contained non-map value %v of type %T at label %s", yamlMap[k], yamlMap[k], strings.Join(label, ".")) + delete(yamlMap, k) + } + err := RetypeAndFilterYAMLMap(yamlSubmap, t, label) + if err != nil { + return err + } + default: + return status.InvalidArgumentErrorf("typeMap contained invalid type %T at %s.", typeMap[k], strings.Join(label, ".")) + } + } + return nil +} + +// PopulateFlagsFromData takes some YAML input and unmarshals it, then uses the +// umnarshaled data to populate the unset flags with names corresponding to the +// keys. +func PopulateFlagsFromData(data []byte) error { + // expand environment variables + expandedData := []byte(os.ExpandEnv(string(data))) + + yamlMap := make(map[string]any) + if err := yaml.Unmarshal([]byte(expandedData), yamlMap); err != nil { + return status.InternalErrorf("Error parsing config file: %s", err) + } + node := &yaml.Node{} + if err := yaml.Unmarshal([]byte(expandedData), node); err != nil { + return status.InternalErrorf("Error parsing config file: %s", err) + } + if len(node.Content) > 0 { + node = node.Content[0] + } + typeMap, err := GenerateYAMLMapWithValuesFromFlags(GetYAMLTypeForFlag, IgnoreFilter) + if err != nil { + return err + } + if err := RetypeAndFilterYAMLMap(yamlMap, typeMap, []string{}); err != nil { + return status.InternalErrorf("Error encountered retyping YAML map: %s", err) + } + + return PopulateFlagsFromYAMLMap(yamlMap, node) +} + +// PopulateFlagsFromData takes the path to some YAML file, reads it, and +// unmarshals it, then uses the umnarshaled data to populate the unset flags +// with names corresponding to the keys. +func PopulateFlagsFromFile(configFile string) error { + log.Infof("Reading buildbuddy config from '%s'", configFile) + + _, err := os.Stat(configFile) + + // If the file does not exist then skip it. + if os.IsNotExist(err) { + log.Warningf("No config file found at %s.", configFile) + return nil + } + + fileBytes, err := os.ReadFile(configFile) + if err != nil { + return fmt.Errorf("Error reading config file: %s", err) + } + + return PopulateFlagsFromData(fileBytes) +} + +// PopulateFlagsFromYAMLMap takes a map populated by YAML from some YAML input +// and a yaml.Node populated by YAML from the same input and iterates over it, +// finding flags with names corresponding to the keys and setting the flag to +// the YAML value if the flag was not set on the command line. The yaml.Node +// preserves order when setting the flag values, which is important for aliases. +// If Node is nil, the order values will be set in is random, as per go's +// implementation of map traversal. +func PopulateFlagsFromYAMLMap(m map[string]any, node *yaml.Node) error { + setFlags := make(map[string]struct{}) + common.DefaultFlagSet.Visit(func(flg *flag.Flag) { + setFlags[flg.Name] = struct{}{} + }) + + return populateFlagsFromYAML(m, []string{}, node, setFlags) +} + +func populateFlagsFromYAML(a any, prefix []string, node *yaml.Node, setFlags map[string]struct{}) error { + if m, ok := a.(map[string]any); ok { + i := 0 + for k, v := range m { + var n *yaml.Node + if node != nil { + // Ensure that we populate flags in the order they are specified in the + // YAML data if the node structure data was provided. + for ok := false; node != nil && !ok; i++ { + k = node.Content[2*i].Value + n = node.Content[2*i+1] + v, ok = m[k] + } + } + p := append(prefix, k) + if _, ok := ignoreSet[strings.Join(p, ".")]; ok { + return nil + } + if err := populateFlagsFromYAML(v, p, n, setFlags); err != nil { + return err + } + } + return nil + } + name := strings.Join(prefix, ".") + if _, ok := ignoreSet[name]; ok { + return nil + } + return common.SetValueForFlagName(name, a, setFlags, true, false) +} diff --git a/server/util/flagutil/yaml/yaml_test.go b/server/util/flagutil/yaml/yaml_test.go new file mode 100644 index 00000000000..f416dbc5270 --- /dev/null +++ b/server/util/flagutil/yaml/yaml_test.go @@ -0,0 +1,355 @@ +package yaml_test + +import ( + "flag" + "net/url" + "reflect" + "testing" + + "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/common" + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" + + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" + flagyaml "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/yaml" +) + +type unsupportedFlagValue struct{} + +func (f *unsupportedFlagValue) Set(string) error { return nil } +func (f *unsupportedFlagValue) String() string { return "" } + +type testStruct struct { + Field int `json:"field"` + Meadow string `json:"meadow"` +} + +func replaceFlagsForTesting(t *testing.T) *flag.FlagSet { + flags := flag.NewFlagSet("test", flag.ContinueOnError) + common.DefaultFlagSet = flags + + t.Cleanup(func() { + common.DefaultFlagSet = flag.CommandLine + }) + + return flags +} + +func TestGenerateYAMLTypeMapFromFlags(t *testing.T) { + flags := replaceFlagsForTesting(t) + + flags.Bool("bool", true, "") + flags.Int("one.two.int", 10, "") + flagtypes.Slice("one.two.string_slice", []string{"hi", "hello"}, "") + flags.Float64("one.two.two_and_a_half.float64", 5.2, "") + flagtypes.Slice("one.two.three.struct_slice", []testStruct{{Field: 4, Meadow: "Great"}}, "") + flags.String("a.b.string", "xxx", "") + flagtypes.URLFromString("a.b.url", "https://www.example.com", "") + actual, err := flagyaml.GenerateYAMLMapWithValuesFromFlags(flagyaml.GetYAMLTypeForFlag, flagyaml.IgnoreFilter) + require.NoError(t, err) + expected := map[string]any{ + "bool": reflect.TypeOf((*bool)(nil)), + "one": map[string]any{ + "two": map[string]any{ + "int": reflect.TypeOf((*int)(nil)), + "string_slice": reflect.TypeOf((*[]string)(nil)), + "two_and_a_half": map[string]any{ + "float64": reflect.TypeOf((*float64)(nil)), + }, + "three": map[string]any{ + "struct_slice": reflect.TypeOf((*[]testStruct)(nil)), + }, + }, + }, + "a": map[string]any{ + "b": map[string]any{ + "string": reflect.TypeOf((*string)(nil)), + "url": reflect.TypeOf((*flagtypes.URLFlag)(nil)), + }, + }, + } + if diff := cmp.Diff(expected, actual, cmp.Comparer(func(x, y reflect.Type) bool { return x == y })); diff != "" { + t.Error(diff) + } +} + +func TestBadGenerateYAMLTypeMapFromFlags(t *testing.T) { + flags := replaceFlagsForTesting(t) + + flags.Int("one.two.int", 10, "") + flags.Int("one.two", 10, "") + _, err := flagyaml.GenerateYAMLMapWithValuesFromFlags(flagyaml.GetYAMLTypeForFlag, flagyaml.IgnoreFilter) + require.Error(t, err) + + flags = replaceFlagsForTesting(t) + + flags.Int("one.two", 10, "") + flags.Int("one.two.int", 10, "") + _, err = flagyaml.GenerateYAMLMapWithValuesFromFlags(flagyaml.GetYAMLTypeForFlag, flagyaml.IgnoreFilter) + require.Error(t, err) + + flags = replaceFlagsForTesting(t) + + flags.Var(&unsupportedFlagValue{}, "unsupported", "") + _, err = flagyaml.GenerateYAMLMapWithValuesFromFlags(flagyaml.GetYAMLTypeForFlag, flagyaml.IgnoreFilter) + require.Error(t, err) + +} + +func TestRetypeAndFilterYAMLMap(t *testing.T) { + typeMap := map[string]any{ + "bool": reflect.TypeOf((*bool)(nil)), + "one": map[string]any{ + "two": map[string]any{ + "int": reflect.TypeOf((*int)(nil)), + "string_slice": reflect.TypeOf((*[]string)(nil)), + "two_and_a_half": map[string]any{ + "float64": reflect.TypeOf((*float64)(nil)), + }, + "three": map[string]any{ + "struct_slice": reflect.TypeOf((*[]testStruct)(nil)), + }, + }, + }, + "a": map[string]any{ + "b": map[string]any{ + "string": reflect.TypeOf((*string)(nil)), + "url": reflect.TypeOf((*flagtypes.URLFlag)(nil)), + }, + }, + "foo": map[string]any{ + "bar": reflect.TypeOf((*int64)(nil)), + }, + } + yamlData := ` +bool: true +one: + two: + int: 1 + string_slice: + - "string1" + - "string2" + two_and_a_half: + float64: 9.4 + three: + struct_slice: + - field: 9 + meadow: "Eternal" + - field: 5 +a: + b: + url: "http://www.example.com" +foo: 7 +first: + second: + unknown: 9009 + no: "definitely not" +` + yamlMap := make(map[string]any) + err := yaml.Unmarshal([]byte(yamlData), yamlMap) + require.NoError(t, err) + err = flagyaml.RetypeAndFilterYAMLMap(yamlMap, typeMap, []string{}) + require.NoError(t, err) + expected := map[string]any{ + "bool": true, + "one": map[string]any{ + "two": map[string]any{ + "int": int(1), + "string_slice": []string{"string1", "string2"}, + "two_and_a_half": map[string]any{ + "float64": float64(9.4), + }, + "three": map[string]any{ + "struct_slice": []testStruct{{Field: 9, Meadow: "Eternal"}, {Field: 5}}, + }, + }, + }, + "a": map[string]any{ + "b": map[string]any{ + "url": flagtypes.URLFlag(url.URL{Scheme: "http", Host: "www.example.com"}), + }, + }, + } + if diff := cmp.Diff(expected, yamlMap); diff != "" { + t.Error(diff) + } +} + +func TestBadRetypeAndFilterYAMLMap(t *testing.T) { + typeMap := map[string]any{ + "bool": reflect.TypeOf((*bool)(nil)), + } + yamlData := ` +bool: 7 +` + yamlMap := make(map[string]any) + err := yaml.Unmarshal([]byte(yamlData), yamlMap) + require.NoError(t, err) + err = flagyaml.RetypeAndFilterYAMLMap(yamlMap, typeMap, []string{}) + require.Error(t, err) + + typeMap = map[string]any{ + "bool": false, + } + yamlData = ` +bool: true +` + yamlMap = make(map[string]any) + err = yaml.Unmarshal([]byte(yamlData), yamlMap) + require.NoError(t, err) + err = flagyaml.RetypeAndFilterYAMLMap(yamlMap, typeMap, []string{}) + require.Error(t, err) +} + +func TestPopulateFlagsFromData(t *testing.T) { + flags := replaceFlagsForTesting(t) + + flagBool := flags.Bool("bool", true, "") + flagOneTwoInt := flags.Int("one.two.int", 10, "") + flagOneTwoStringSlice := flagtypes.Slice("one.two.string_slice", []string{"hi", "hello"}, "") + flagOneTwoTwoAndAHalfFloat := flags.Float64("one.two.two_and_a_half.float64", 5.2, "") + flagOneTwoThreeStructSlice := []testStruct{{Field: 4, Meadow: "Great"}} + flagtypes.SliceVar(&flagOneTwoThreeStructSlice, "one.two.three.struct_slice", "") + flagABString := flags.String("a.b.string", "xxx", "") + flagABStructSlice := []testStruct{{Field: 7, Meadow: "Chimney"}} + flagtypes.SliceVar(&flagABStructSlice, "a.b.struct_slice", "") + flagABURL := flagtypes.URLFromString("a.b.url", "https://www.example.com", "") + yamlData := ` +bool: true +one: + two: + int: 1 + string_slice: + - "string1" + - "string2" + two_and_a_half: + float64: 9.4 + three: + struct_slice: + - field: 9 + meadow: "Eternal" + - field: 5 +a: + b: + url: "http://www.example.com:8080" +foo: 7 +first: + second: + unknown: 9009 + no: "definitely not" +` + err := flagyaml.PopulateFlagsFromData([]byte(yamlData)) + require.NoError(t, err) + assert.Equal(t, true, *flagBool) + assert.Equal(t, int(1), *flagOneTwoInt) + assert.Equal(t, []string{"hi", "hello", "string1", "string2"}, *flagOneTwoStringSlice) + assert.Equal(t, float64(9.4), *flagOneTwoTwoAndAHalfFloat) + assert.Equal(t, []testStruct{{Field: 4, Meadow: "Great"}, {Field: 9, Meadow: "Eternal"}, {Field: 5}}, flagOneTwoThreeStructSlice) + assert.Equal(t, "xxx", *flagABString) + assert.Equal(t, []testStruct{{Field: 7, Meadow: "Chimney"}}, flagABStructSlice) + assert.Equal(t, url.URL{Scheme: "http", Host: "www.example.com:8080"}, *flagABURL) +} + +func TestBadPopulateFlagsFromData(t *testing.T) { + _ = replaceFlagsForTesting(t) + + yamlData := ` + bool: true +` + err := flagyaml.PopulateFlagsFromData([]byte(yamlData)) + require.Error(t, err) + + flags := replaceFlagsForTesting(t) + + flags.Var(&unsupportedFlagValue{}, "bad", "") + err = flagyaml.PopulateFlagsFromData([]byte{}) + require.Error(t, err) + + flags = replaceFlagsForTesting(t) + + _ = flags.Bool("bool", false, "") + yamlData = ` +bool: 7 +` + err = flagyaml.PopulateFlagsFromData([]byte(yamlData)) + require.Error(t, err) +} + +func TestPopulateFlagsFromYAML(t *testing.T) { + flags := replaceFlagsForTesting(t) + + flagBool := flags.Bool("bool", true, "") + flagOneTwoInt := flags.Int("one.two.int", 10, "") + flagOneTwoStringSlice := flagtypes.Slice("one.two.string_slice", []string{"hi", "hello"}, "") + flagOneTwoTwoAndAHalfFloat := flags.Float64("one.two.two_and_a_half.float64", 5.2, "") + flagOneTwoThreeStructSlice := []testStruct{{Field: 4, Meadow: "Great"}} + flagtypes.SliceVar(&flagOneTwoThreeStructSlice, "one.two.three.struct_slice", "") + flagABString := flags.String("a.b.string", "xxx", "") + flagABStructSlice := []testStruct{{Field: 7, Meadow: "Chimney"}} + flagtypes.SliceVar(&flagABStructSlice, "a.b.struct_slice", "") + flagABURL := flagtypes.URLFromString("a.b.url", "https://www.example.com", "") + input := map[string]any{ + "bool": false, + "one": map[string]any{ + "two": map[string]any{ + "string_slice": []string{"meow", "woof"}, + "two_and_a_half": map[string]any{ + "float64": float64(7), + }, + "three": map[string]any{ + "struct_slice": ([]testStruct)(nil), + }, + }, + }, + "a": map[string]any{ + "b": map[string]any{ + "string": "", + "struct_slice": []testStruct{{Field: 9}}, + "url": flagtypes.URLFlag(url.URL{Scheme: "https", Host: "www.example.com:8080"}), + }, + }, + "undefined": struct{}{}, // keys without with no corresponding flag name should be ignored. + } + node := &yaml.Node{} + err := node.Encode(input) + require.NoError(t, err) + err = flagyaml.PopulateFlagsFromYAMLMap(input, node) + require.NoError(t, err) + + assert.Equal(t, false, *flagBool) + assert.Equal(t, 10, *flagOneTwoInt) + assert.Equal(t, []string{"hi", "hello", "meow", "woof"}, *flagOneTwoStringSlice) + assert.Equal(t, float64(7), *flagOneTwoTwoAndAHalfFloat) + assert.Equal(t, []testStruct{{Field: 4, Meadow: "Great"}}, flagOneTwoThreeStructSlice) + assert.Equal(t, "", *flagABString) + assert.Equal(t, []testStruct{{Field: 7, Meadow: "Chimney"}, {Field: 9}}, flagABStructSlice) + assert.Equal(t, url.URL{Scheme: "https", Host: "www.example.com:8080"}, *flagABURL) +} + +func TestBadPopulateFlagsFromYAML(t *testing.T) { + _ = replaceFlagsForTesting(t) + + flags := replaceFlagsForTesting(t) + flags.Var(&unsupportedFlagValue{}, "unsupported", "") + input := map[string]any{ + "unsupported": 0, + } + node := &yaml.Node{} + err := node.Encode(input) + require.NoError(t, err) + err = flagyaml.PopulateFlagsFromYAMLMap(input, node) + require.Error(t, err) + + flags = replaceFlagsForTesting(t) + flags.Bool("bool", false, "") + input = map[string]any{ + "bool": 0, + } + node = &yaml.Node{} + err = node.Encode(input) + require.NoError(t, err) + err = flagyaml.PopulateFlagsFromYAMLMap(input, node) + require.Error(t, err) +} diff --git a/server/util/testing/flags/BUILD b/server/util/testing/flags/BUILD index adcb63cb878..430b26db34b 100644 --- a/server/util/testing/flags/BUILD +++ b/server/util/testing/flags/BUILD @@ -8,6 +8,8 @@ go_library( visibility = ["//visibility:public"], deps = [ "//server/util/flagutil", + "//server/util/flagutil/common", + "//server/util/flagutil/yaml", "@com_github_stretchr_testify//require", ], ) diff --git a/server/util/testing/flags/flags.go b/server/util/testing/flags/flags.go index ee07ab0622a..d63103d7553 100644 --- a/server/util/testing/flags/flags.go +++ b/server/util/testing/flags/flags.go @@ -7,6 +7,9 @@ import ( "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/stretchr/testify/require" + + flagutil_common "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/common" + flagyaml "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/yaml" ) var populateFlagsOnce sync.Once @@ -14,8 +17,8 @@ var populateFlagsOnce sync.Once func PopulateFlagsFromData(t testing.TB, testConfigData []byte) { populateFlagsOnce.Do(func() { // add placeholder type for type adding by testing - flagutil.AddTestFlagTypeForTesting(flag.Lookup("test.benchtime").Value, &struct{}{}) - err := flagutil.PopulateFlagsFromData(testConfigData) + flagutil_common.AddTestFlagTypeForTesting(flag.Lookup("test.benchtime").Value, &struct{}{}) + err := flagyaml.PopulateFlagsFromData(testConfigData) require.NoError(t, err) }) } @@ -25,11 +28,11 @@ func PopulateFlagsFromData(t testing.TB, testConfigData []byte) { func Set(t testing.TB, name string, value any) { origValue, err := flagutil.GetDereferencedValue[any](name) require.NoError(t, err) - err = flagutil.SetValueForFlagName(name, value, nil, false, true) + err = flagutil_common.SetValueForFlagName(name, value, nil, false, true) require.NoError(t, err) t.Cleanup(func() { - err = flagutil.SetValueForFlagName(name, origValue, nil, false, true) + err = flagutil_common.SetValueForFlagName(name, origValue, nil, false, true) require.NoError(t, err) }) } diff --git a/server/util/tracing/BUILD b/server/util/tracing/BUILD index a6be9232f47..d0b68bb0d67 100644 --- a/server/util/tracing/BUILD +++ b/server/util/tracing/BUILD @@ -8,7 +8,7 @@ go_library( deps = [ "//proto:trace_go_proto", "//server/environment", - "//server/util/flagutil", + "//server/util/flagutil/types", "//server/util/log", "//server/util/status", "@io_opentelemetry_go_contrib_detectors_gcp//:gcp", diff --git a/server/util/tracing/tracing.go b/server/util/tracing/tracing.go index 877ae4b17d0..f8db1c70fec 100644 --- a/server/util/tracing/tracing.go +++ b/server/util/tracing/tracing.go @@ -15,7 +15,6 @@ import ( "time" "github.com/buildbuddy-io/buildbuddy/server/environment" - "github.com/buildbuddy-io/buildbuddy/server/util/flagutil" "github.com/buildbuddy-io/buildbuddy/server/util/log" "github.com/buildbuddy-io/buildbuddy/server/util/status" "go.opentelemetry.io/contrib/detectors/gcp" @@ -29,6 +28,7 @@ import ( "google.golang.org/grpc/metadata" tpb "github.com/buildbuddy-io/buildbuddy/proto/trace" + flagtypes "github.com/buildbuddy-io/buildbuddy/server/util/flagutil/types" sdktrace "go.opentelemetry.io/otel/sdk/trace" semconv "go.opentelemetry.io/otel/semconv/v1.4.0" ) @@ -39,7 +39,7 @@ var ( traceJaegerCollector = flag.String("app.trace_jaeger_collector", "", "Address of the Jager collector endpoint where traces will be sent.") traceServiceName = flag.String("app.trace_service_name", "", "Name of the service to associate with traces.") traceFraction = flag.Float64("app.trace_fraction", 0, "Fraction of requests to sample for tracing.") - traceFractionOverrides = flagutil.Slice("app.trace_fraction_overrides", []string{}, "Tracing fraction override based on name in format name=fraction.") + traceFractionOverrides = flagtypes.Slice("app.trace_fraction_overrides", []string{}, "Tracing fraction override based on name in format name=fraction.") ignoreForcedTracingHeader = flag.Bool("app.ignore_forced_tracing_header", false, "If set, we will not honor the forced tracing header.") // bound overrides are parsed from the traceFractionOverrides flag.