diff --git a/cli.go b/cli.go index af50dab..985a9cb 100644 --- a/cli.go +++ b/cli.go @@ -7,8 +7,10 @@ import ( "log" "os" "os/signal" + "strings" "sync" "syscall" + "time" "github.com/hashicorp/consul-template/logging" "github.com/hashicorp/consul-template/watch" @@ -23,10 +25,10 @@ const ( ExitCodeOK int = 0 ExitCodeError = 10 + iota + ExitCodeInterrupt ExitCodeParseFlagsError - ExitCodeLoggingError ExitCodeRunnerError - ExitCodeInterrupt + ExitCodeConfigError ) /// ------------------------- /// @@ -61,23 +63,24 @@ func (cli *CLI) Run(args []string) int { return cli.handleError(err, ExitCodeParseFlagsError) } - // Setup the logging - if err := logging.Setup(&logging.Config{ - Name: Name, - Level: config.LogLevel, - Syslog: config.Syslog.Enabled, - SyslogFacility: config.Syslog.Facility, - Writer: cli.errStream, - }); err != nil { - return cli.handleError(err, ExitCodeLoggingError) + // Save original config (defaults + parsed flags) for handling reloads + baseConfig := config.Copy() + + // Setup the config and logging + config, err = cli.setup(config) + if err != nil { + return cli.handleError(err, ExitCodeConfigError) } + // Print version information for debugging + log.Printf("[INFO] %s", formattedVersion()) + // If the version was requested, return an "error" containing the version // information. This might sound weird, but most *nix applications actually // print their version on stderr anyway. if version { log.Printf("[DEBUG] (cli) version flag was given, exiting now") - fmt.Fprintf(cli.errStream, "%s v%s\n", Name, Version) + fmt.Fprintf(cli.errStream, "%s\n", formattedVersion()) return ExitCodeOK } @@ -112,6 +115,13 @@ func (cli *CLI) Run(args []string) int { case syscall.SIGHUP: fmt.Fprintf(cli.errStream, "Received HUP, reloading configuration...\n") runner.Stop() + + // Load the new configuration from disk + config, err = cli.setup(baseConfig) + if err != nil { + return cli.handleError(err, ExitCodeConfigError) + } + runner, err = NewRunner(config, once) if err != nil { return cli.handleError(err, ExitCodeRunnerError) @@ -143,95 +153,162 @@ func (cli *CLI) stop() { // much easier and cleaner. func (cli *CLI) parseFlags(args []string) (*Config, bool, bool, error) { var once, version bool - var config = DefaultConfig() + config := DefaultConfig() // Parse the flags and options flags := flag.NewFlagSet(Name, flag.ContinueOnError) flags.SetOutput(cli.errStream) - flags.Usage = func() { - fmt.Fprintf(cli.errStream, usage, Name) - } - flags.StringVar(&config.Consul, "consul", config.Consul, "") - flags.StringVar(&config.Token, "token", config.Token, "") - flags.Var((*authVar)(config.Auth), "auth", "") - flags.BoolVar(&config.SSL.Enabled, "ssl", config.SSL.Enabled, "") - flags.BoolVar(&config.SSL.Verify, "ssl-verify", config.SSL.Verify, "") - flags.DurationVar(&config.MaxStale, "max-stale", config.MaxStale, "") - flags.BoolVar(&config.Syslog.Enabled, "syslog", config.Syslog.Enabled, "") - flags.StringVar(&config.Syslog.Facility, "syslog-facility", config.Syslog.Facility, "") - flags.Var((*prefixVar)(&config.Prefixes), "prefix", "") - flags.Var((*watch.WaitVar)(config.Wait), "wait", "") - flags.DurationVar(&config.Retry, "retry", config.Retry, "") - flags.StringVar(&config.Path, "config", config.Path, "") - flags.StringVar(&config.LogLevel, "log-level", config.LogLevel, "") + flags.Usage = func() { fmt.Fprintf(cli.errStream, usage, Name) } + + flags.Var((funcVar)(func(s string) error { + config.Consul = s + config.set("consul") + return nil + }), "consul", "") + + flags.Var((funcVar)(func(s string) error { + config.Token = s + config.set("token") + return nil + }), "token", "") + + flags.Var((funcVar)(func(s string) error { + config.Auth.Enabled = true + config.set("auth.enabled") + if strings.Contains(s, ":") { + split := strings.SplitN(s, ":", 2) + config.Auth.Username = split[0] + config.set("auth.username") + config.Auth.Password = split[1] + config.set("auth.password") + } else { + config.Auth.Username = s + config.set("auth.username") + } + return nil + }), "auth", "") + + flags.Var((funcBoolVar)(func(b bool) error { + config.SSL.Enabled = b + config.set("ssl") + config.set("ssl.enabled") + return nil + }), "ssl", "") + + flags.Var((funcBoolVar)(func(b bool) error { + config.SSL.Verify = b + config.set("ssl") + config.set("ssl.verify") + return nil + }), "ssl-verify", "") + + flags.Var((funcVar)(func(s string) error { + config.SSL.Cert = s + config.set("ssl") + config.set("ssl.cert") + return nil + }), "ssl-cert", "") + + flags.Var((funcVar)(func(s string) error { + config.SSL.Key = s + config.set("ssl") + config.set("ssl.key") + return nil + }), "ssl-key", "") + + flags.Var((funcVar)(func(s string) error { + config.SSL.CaCert = s + config.set("ssl") + config.set("ssl.ca_cert") + return nil + }), "ssl-ca-cert", "") + + flags.Var((funcDurationVar)(func(d time.Duration) error { + config.MaxStale = d + config.set("max_stale") + return nil + }), "max-stale", "") + + flags.Var((funcVar)(func(s string) error { + p, err := ParsePrefix(s) + if err != nil { + return err + } + if config.Prefixes == nil { + config.Prefixes = make([]*Prefix, 0, 1) + } + config.Prefixes = append(config.Prefixes, p) + return nil + }), "prefix", "") + + flags.Var((funcBoolVar)(func(b bool) error { + config.Syslog.Enabled = b + config.set("syslog") + config.set("syslog.enabled") + return nil + }), "syslog", "") + + flags.Var((funcVar)(func(s string) error { + config.Syslog.Facility = s + config.set("syslog.facility") + return nil + }), "syslog-facility", "") + + flags.Var((funcVar)(func(s string) error { + w, err := watch.ParseWait(s) + if err != nil { + return err + } + config.Wait.Min = w.Min + config.Wait.Max = w.Max + config.set("wait") + return nil + }), "wait", "") + + flags.Var((funcDurationVar)(func(d time.Duration) error { + config.Retry = d + config.set("retry") + return nil + }), "retry", "") + + flags.Var((funcVar)(func(s string) error { + config.Path = s + config.set("path") + return nil + }), "config", "") + + flags.Var((funcVar)(func(s string) error { + config.PidFile = s + config.set("pid_file") + return nil + }), "pid-file", "") + + flags.Var((funcVar)(func(s string) error { + config.StatusDir = s + config.set("status_dir") + return nil + }), "status-dir", "") + + flags.Var((funcVar)(func(s string) error { + config.LogLevel = s + config.set("log_level") + return nil + }), "log-level", "") + flags.BoolVar(&once, "once", false, "") + flags.BoolVar(&version, "v", false, "") flags.BoolVar(&version, "version", false, "") - // Advanced options - flags.StringVar(&config.StatusDir, "status-dir", config.StatusDir, "") - - // Deprecated options - var deprecatedAddr string - flags.StringVar(&deprecatedAddr, "addr", "", "") - var deprecatedDest string - flags.StringVar(&deprecatedDest, "dst-prefix", "", "") - var deprecatedSrc string - flags.StringVar(&deprecatedSrc, "src", "", "") - var deprecatedLock string - flags.StringVar(&deprecatedLock, "lock", "", "") - var deprecatedStatus string - flags.StringVar(&deprecatedStatus, "status", "", "") - var deprecatedService string - flags.StringVar(&deprecatedService, "service", "", "") - // If there was a parser error, stop if err := flags.Parse(args); err != nil { return nil, false, false, err } - // Handle deprecations - if deprecatedAddr != "" { - log.Printf("[WARN] -addr is deprecated - please use -consul=<...> instead") - config.Consul = deprecatedAddr - } - if deprecatedDest != "" { - log.Printf("[WARN] -dst-prefix is deprecated - please use -prefix= instead") - - // If there are no prefixes, we cannot reasonably continue - if len(config.Prefixes) < 1 { - return nil, false, false, fmt.Errorf("must specify at least one prefix") - } - - config.Prefixes[0].Destination = deprecatedDest - } - if deprecatedSrc != "" { - log.Printf("[WARN] -src is deprecated - please use -prefix= instead") - - // If there are no prefixes, we cannot reasonably continue - if len(config.Prefixes) < 1 { - return nil, false, false, fmt.Errorf("must specify at least one prefix") - } - - // This is pretty jank, but build the thing into a string so we can convert - // it back into a prefix. Good times. Good times. - prefix := config.Prefixes[0] - raw := fmt.Sprintf("%s@%s:%s", prefix.Source.Prefix, deprecatedSrc, prefix.Destination) - newPrefix, err := ParsePrefix(raw) - if err != nil { - return nil, false, false, fmt.Errorf("error parsing source datacenter: %s", err) - } - - config.Prefixes[0] = newPrefix - } - if deprecatedLock != "" { - log.Printf("[WARN] -lock is deprecated - please use consul lock instead") - } - if deprecatedStatus != "" { - log.Printf("[WARN] -status is deprecated - please use -status-dir= instead") - config.StatusDir = deprecatedStatus - } - if deprecatedService != "" { - log.Printf("[WARN] -service is deprecated - please use consul lock instead") + // Error if extra arguments are present + args = flags.Args() + if len(args) > 0 { + return nil, false, false, fmt.Errorf("cli: extra argument(s): %q", + args) } return config, once, version, nil @@ -244,6 +321,33 @@ func (cli *CLI) handleError(err error, status int) int { return status } +// setup initializes the CLI. +func (cli *CLI) setup(config *Config) (*Config, error) { + if config.Path != "" { + newConfig, err := ConfigFromPath(config.Path) + if err != nil { + return nil, err + } + + // Merge ensuring that the CLI options still take precedence + newConfig.Merge(config) + config = newConfig + } + + // Setup the logging + if err := logging.Setup(&logging.Config{ + Name: Name, + Level: config.LogLevel, + Syslog: config.Syslog.Enabled, + SyslogFacility: config.Syslog.Facility, + Writer: cli.errStream, + }); err != nil { + return nil, err + } + + return config, nil +} + const usage = ` Usage: %s [options] @@ -254,12 +358,18 @@ Options: -auth= Set the basic authentication username (and password) -consul=
Sets the address of the Consul instance + -token= Sets the Consul API token -max-stale= Set the maximum staleness and allow stale queries to Consul which will distribute work among all servers instead of just the leader + -ssl Use SSL when connecting to Consul -ssl-verify Verify certificates when connecting via SSL - -token= Sets the Consul API token + -ssl-cert SSL client certificate to send to server + -ssl-key SSL/TLS private key for use in client authentication + key exchange + -ssl-ca-cert Validate server certificate against this CA + certificate file list -syslog Send the output to syslog instead of standard error and standard out. The syslog facility defaults to @@ -273,17 +383,20 @@ Options: the destination datacenters - if the destination is omitted, it is assumed to be the same as the source -wait= Sets the 'minumum(:maximum)' amount of time to wait - for stability before replicating + before replicating -retry= The amount of time to wait if Consul returns an error when communicating with the API -config= Sets the path to a configuration file on disk + + -pid-file= Path on disk to write the PID of the process -log-level= Set the logging level - valid values are "debug", "info", "warn" (default), and "err" -once Do not run the process as a daemon - -version Print the version of this daemon + + -v, -version Print the version of this daemon Advanced Options: diff --git a/cli_test.go b/cli_test.go index 668a3be..35b86bf 100644 --- a/cli_test.go +++ b/cli_test.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "fmt" "io/ioutil" "reflect" @@ -10,149 +9,9 @@ import ( "time" "github.com/hashicorp/consul-template/watch" + "github.com/hashicorp/go-gatedio" ) -// Deprecated CLI options -// TODO: Remove in the next release - -func TestParseFlags_addr(t *testing.T) { - cli := NewCLI(ioutil.Discard, ioutil.Discard) - config, _, _, err := cli.parseFlags([]string{ - "-addr", "1.2.3.4", - }) - if err != nil { - t.Fatal(err) - } - - expected := "1.2.3.4" - if config.Consul != expected { - t.Errorf("expected %q to be %q", config.Consul, expected) - } -} - -func TestParseFlags_dstPrefix(t *testing.T) { - cli := NewCLI(ioutil.Discard, ioutil.Discard) - config, _, _, err := cli.parseFlags([]string{ - "-prefix", "global", - "-dst-prefix", "backup", - }) - if err != nil { - t.Fatal(err) - } - - if len(config.Prefixes) < 1 { - t.Errorf("no prefixes") - } - - prefix := config.Prefixes[0].Source - if prefix.Prefix != "global" { - t.Errorf("expected %q to be %q", prefix.Prefix, "global") - } -} - -func TestParseFlags_dstPrefixNoPrefix(t *testing.T) { - cli := NewCLI(ioutil.Discard, ioutil.Discard) - _, _, _, err := cli.parseFlags([]string{ - "-dst-prefix", "backup", - }) - if err == nil { - t.Fatal("expected error, but nothing was returned") - } - - expected := "must specify at least one prefix" - if err.Error() != expected { - t.Errorf("expected %q to be %q", err.Error(), expected) - } -} - -func TestParseFlags_src(t *testing.T) { - cli := NewCLI(ioutil.Discard, ioutil.Discard) - config, _, _, err := cli.parseFlags([]string{ - "-prefix", "global", - "-src", "nyc2", - }) - if err != nil { - t.Fatal(err) - } - - if len(config.Prefixes) < 1 { - t.Errorf("no prefixes") - } - - prefix := config.Prefixes[0].Source - if prefix.Prefix != "global" { - t.Errorf("expected %q to be %q", prefix.Prefix, "global") - } -} - -func TestParseFlags_srcNoPrefix(t *testing.T) { - cli := NewCLI(ioutil.Discard, ioutil.Discard) - _, _, _, err := cli.parseFlags([]string{ - "-src", "nyc2", - }) - if err == nil { - t.Fatal("expected error, but nothing was returned") - } - - expected := "must specify at least one prefix" - if err.Error() != expected { - t.Errorf("expected %q to be %q", err.Error(), expected) - } -} - -func TestParseFlags_srcBadPrefix(t *testing.T) { - cli := NewCLI(ioutil.Discard, ioutil.Discard) - _, _, _, err := cli.parseFlags([]string{ - "-prefix", "global", - "-src", "n((*y@#c@!2", - }) - if err == nil { - t.Fatal("expected error, but nothing was returned") - } - - expected := "invalid key prefix dependency format" - if !strings.Contains(err.Error(), expected) { - t.Errorf("expected %q to be %q", err.Error(), expected) - } -} - -func TestParseFlags_lock(t *testing.T) { - cli := NewCLI(ioutil.Discard, ioutil.Discard) - _, _, _, err := cli.parseFlags([]string{ - "-lock", "service/locks/consul-replicate", - }) - if err != nil { - t.Fatal(err) - } -} - -func TestParseFlags_status(t *testing.T) { - cli := NewCLI(ioutil.Discard, ioutil.Discard) - config, _, _, err := cli.parseFlags([]string{ - "-status", "service/statuses/consul-replicate", - }) - if err != nil { - t.Fatal(err) - } - - expected := "service/statuses/consul-replicate" - if config.StatusDir != expected { - t.Errorf("expected %q to be %q", config.StatusDir, expected) - } -} - -func TestParseFlags_service(t *testing.T) { - cli := NewCLI(ioutil.Discard, ioutil.Discard) - _, _, _, err := cli.parseFlags([]string{ - "-service", "replicator", - }) - if err != nil { - t.Fatal(err) - } -} - -// End deprecated CLI options - func TestParseFlags_consul(t *testing.T) { cli := NewCLI(ioutil.Discard, ioutil.Discard) config, _, _, err := cli.parseFlags([]string{ @@ -166,6 +25,9 @@ func TestParseFlags_consul(t *testing.T) { if config.Consul != expected { t.Errorf("expected %q to be %q", config.Consul, expected) } + if !config.WasSet("consul") { + t.Errorf("expected consul to be set") + } } func TestParseFlags_token(t *testing.T) { @@ -181,6 +43,9 @@ func TestParseFlags_token(t *testing.T) { if config.Token != expected { t.Errorf("expected %q to be %q", config.Token, expected) } + if !config.WasSet("token") { + t.Errorf("expected token to be set") + } } func TestParseFlags_authUsername(t *testing.T) { @@ -195,11 +60,17 @@ func TestParseFlags_authUsername(t *testing.T) { if config.Auth.Enabled != true { t.Errorf("expected auth to be enabled") } + if !config.WasSet("auth.enabled") { + t.Errorf("expected auth.enabled to be set") + } expected := "test" if config.Auth.Username != expected { t.Errorf("expected %v to be %v", config.Auth.Username, expected) } + if !config.WasSet("auth.username") { + t.Errorf("expected auth.username to be set") + } } func TestParseFlags_authUsernamePassword(t *testing.T) { @@ -214,14 +85,23 @@ func TestParseFlags_authUsernamePassword(t *testing.T) { if config.Auth.Enabled != true { t.Errorf("expected auth to be enabled") } + if !config.WasSet("auth.enabled") { + t.Errorf("expected auth.enabled to be set") + } expected := "test" if config.Auth.Username != expected { t.Errorf("expected %v to be %v", config.Auth.Username, expected) } + if !config.WasSet("auth.username") { + t.Errorf("expected auth.username to be set") + } if config.Auth.Password != expected { t.Errorf("expected %v to be %v", config.Auth.Password, expected) } + if !config.WasSet("auth.password") { + t.Errorf("expected auth.password to be set") + } } func TestParseFlags_SSL(t *testing.T) { @@ -237,6 +117,12 @@ func TestParseFlags_SSL(t *testing.T) { if config.SSL.Enabled != expected { t.Errorf("expected %v to be %v", config.SSL.Enabled, expected) } + if !config.WasSet("ssl") { + t.Errorf("expected ssl to be set") + } + if !config.WasSet("ssl.enabled") { + t.Errorf("expected ssl.enabled to be set") + } } func TestParseFlags_noSSL(t *testing.T) { @@ -252,6 +138,12 @@ func TestParseFlags_noSSL(t *testing.T) { if config.SSL.Enabled != expected { t.Errorf("expected %v to be %v", config.SSL.Enabled, expected) } + if !config.WasSet("ssl") { + t.Errorf("expected ssl to be set") + } + if !config.WasSet("ssl.enabled") { + t.Errorf("expected ssl.enabled to be set") + } } func TestParseFlags_SSLVerify(t *testing.T) { @@ -267,6 +159,12 @@ func TestParseFlags_SSLVerify(t *testing.T) { if config.SSL.Verify != expected { t.Errorf("expected %v to be %v", config.SSL.Verify, expected) } + if !config.WasSet("ssl") { + t.Errorf("expected ssl to be set") + } + if !config.WasSet("ssl.verify") { + t.Errorf("expected ssl.verify to be set") + } } func TestParseFlags_noSSLVerify(t *testing.T) { @@ -282,6 +180,75 @@ func TestParseFlags_noSSLVerify(t *testing.T) { if config.SSL.Verify != expected { t.Errorf("expected %v to be %v", config.SSL.Verify, expected) } + if !config.WasSet("ssl") { + t.Errorf("expected ssl to be set") + } + if !config.WasSet("ssl.verify") { + t.Errorf("expected ssl.verify to be set") + } +} + +func TestParseFlags_SSLCert(t *testing.T) { + cli := NewCLI(ioutil.Discard, ioutil.Discard) + config, _, _, err := cli.parseFlags([]string{ + "-ssl-cert", "/path/to/c1.pem", + }) + if err != nil { + t.Fatal(err) + } + + expected := "/path/to/c1.pem" + if config.SSL.Cert != expected { + t.Errorf("expected %v to be %v", config.SSL.Cert, expected) + } + if !config.WasSet("ssl") { + t.Errorf("expected ssl to be set") + } + if !config.WasSet("ssl.cert") { + t.Errorf("expected ssl.cert to be set") + } +} + +func TestParseFlags_SSLKey(t *testing.T) { + cli := NewCLI(ioutil.Discard, ioutil.Discard) + config, _, _, err := cli.parseFlags([]string{ + "-ssl-key", "/path/to/client-key.pem", + }) + if err != nil { + t.Fatal(err) + } + + expected := "/path/to/client-key.pem" + if config.SSL.Key != expected { + t.Errorf("expected %v to be %v", config.SSL.Key, expected) + } + if !config.WasSet("ssl") { + t.Errorf("expected ssl to be set") + } + if !config.WasSet("ssl.key") { + t.Errorf("expected ssl.key to be set") + } +} + +func TestParseFlags_SSLCaCert(t *testing.T) { + cli := NewCLI(ioutil.Discard, ioutil.Discard) + config, _, _, err := cli.parseFlags([]string{ + "-ssl-ca-cert", "/path/to/c2.pem", + }) + if err != nil { + t.Fatal(err) + } + + expected := "/path/to/c2.pem" + if config.SSL.CaCert != expected { + t.Errorf("expected %v to be %v", config.SSL.CaCert, expected) + } + if !config.WasSet("ssl") { + t.Errorf("expected ssl to be set") + } + if !config.WasSet("ssl.ca_cert") { + t.Errorf("expected ssl.ca_cert to be set") + } } func TestParseFlags_maxStale(t *testing.T) { @@ -297,6 +264,9 @@ func TestParseFlags_maxStale(t *testing.T) { if config.MaxStale != expected { t.Errorf("expected %q to be %q", config.MaxStale, expected) } + if !config.WasSet("max_stale") { + t.Errorf("expected max_stale to be set") + } } func TestParseFlags_prefixes(t *testing.T) { @@ -313,9 +283,6 @@ func TestParseFlags_prefixes(t *testing.T) { } prefix := config.Prefixes[0] - if prefix.SourceRaw != "global@nyc1" { - t.Errorf("expected %q to be %q", prefix.SourceRaw, "global@nyc1") - } if prefix.Destination != "backup" { t.Errorf("expected %q to be %q", prefix.Destination, "backup") } @@ -334,6 +301,12 @@ func TestParseFlags_syslog(t *testing.T) { if config.Syslog.Enabled != expected { t.Errorf("expected %v to be %v", config.Syslog.Enabled, expected) } + if !config.WasSet("syslog") { + t.Errorf("expected syslog to be set") + } + if !config.WasSet("syslog.enabled") { + t.Errorf("expected syslog.enabled to be set") + } } func TestParseFlags_syslogFacility(t *testing.T) { @@ -349,6 +322,9 @@ func TestParseFlags_syslogFacility(t *testing.T) { if config.Syslog.Facility != expected { t.Errorf("expected %v to be %v", config.Syslog.Facility, expected) } + if !config.WasSet("syslog.facility") { + t.Errorf("expected syslog.facility to be set") + } } func TestParseFlags_wait(t *testing.T) { @@ -367,6 +343,9 @@ func TestParseFlags_wait(t *testing.T) { if !reflect.DeepEqual(config.Wait, expected) { t.Errorf("expected %v to be %v", config.Wait, expected) } + if !config.WasSet("wait") { + t.Errorf("expected wait to be set") + } } func TestParseFlags_waitError(t *testing.T) { @@ -397,6 +376,9 @@ func TestParseFlags_config(t *testing.T) { if config.Path != expected { t.Errorf("expected %v to be %v", config.Path, expected) } + if !config.WasSet("path") { + t.Errorf("expected path to be set") + } } func TestParseFlags_retry(t *testing.T) { @@ -412,6 +394,63 @@ func TestParseFlags_retry(t *testing.T) { if config.Retry != expected { t.Errorf("expected %v to be %v", config.Retry, expected) } + if !config.WasSet("retry") { + t.Errorf("expected retry to be set") + } +} + +func TestParseFlags_logLevel(t *testing.T) { + cli := NewCLI(ioutil.Discard, ioutil.Discard) + config, _, _, err := cli.parseFlags([]string{ + "-log-level", "debug", + }) + if err != nil { + t.Fatal(err) + } + + expected := "debug" + if config.LogLevel != expected { + t.Errorf("expected %v to be %v", config.LogLevel, expected) + } + if !config.WasSet("log_level") { + t.Errorf("expected log_level to be set") + } +} + +func TestParseFlags_pidFile(t *testing.T) { + cli := NewCLI(ioutil.Discard, ioutil.Discard) + config, _, _, err := cli.parseFlags([]string{ + "-pid-file", "/path/to/pid", + }) + if err != nil { + t.Fatal(err) + } + + expected := "/path/to/pid" + if config.PidFile != expected { + t.Errorf("expected %v to be %v", config.PidFile, expected) + } + if !config.WasSet("pid_file") { + t.Errorf("expected pid_file to be set") + } +} + +func TestParseFlags_statusDir(t *testing.T) { + cli := NewCLI(ioutil.Discard, ioutil.Discard) + config, _, _, err := cli.parseFlags([]string{ + "-status-dir", "consul/status/dir", + }) + if err != nil { + t.Fatal(err) + } + + expected := "consul/status/dir" + if config.StatusDir != expected { + t.Errorf("expected %v to be %v", config.StatusDir, expected) + } + if !config.WasSet("status_dir") { + t.Errorf("expected status_dir to be set") + } } func TestParseFlags_once(t *testing.T) { @@ -442,40 +481,35 @@ func TestParseFlags_version(t *testing.T) { } } -func TestParseFlags_logLevel(t *testing.T) { +func TestParseFlags_v(t *testing.T) { cli := NewCLI(ioutil.Discard, ioutil.Discard) - config, _, _, err := cli.parseFlags([]string{ - "-log-level", "debug", + _, _, version, err := cli.parseFlags([]string{ + "-v", }) if err != nil { t.Fatal(err) } - expected := "debug" - if config.LogLevel != expected { - t.Errorf("expected %v to be %v", config.LogLevel, expected) + if version != true { + t.Errorf("expected version to be true") } } -func TestParseFlags_statusDir(t *testing.T) { +func TestParseFlags_errors(t *testing.T) { cli := NewCLI(ioutil.Discard, ioutil.Discard) - config, _, _, err := cli.parseFlags([]string{ - "-status-dir", "custom-status-dir", + _, _, _, err := cli.parseFlags([]string{ + "-totally", "-not", "-valid", }) - if err != nil { - t.Fatal(err) - } - expected := "custom-status-dir" - if config.StatusDir != expected { - t.Errorf("expected %v to be %v", config.StatusDir, expected) + if err == nil { + t.Fatal("expected error, but nothing was returned") } } -func TestParseFlags_errors(t *testing.T) { +func TestParseFlags_badArgs(t *testing.T) { cli := NewCLI(ioutil.Discard, ioutil.Discard) _, _, _, err := cli.parseFlags([]string{ - "-totally", "-not", "-valid", + "foo", "bar", }) if err == nil { @@ -484,7 +518,7 @@ func TestParseFlags_errors(t *testing.T) { } func TestRun_printsErrors(t *testing.T) { - outStream, errStream := new(bytes.Buffer), new(bytes.Buffer) + outStream, errStream := gatedio.NewByteBuffer(), gatedio.NewByteBuffer() cli := NewCLI(outStream, errStream) args := strings.Split("consul-replicate -bacon delicious", " ") @@ -500,7 +534,7 @@ func TestRun_printsErrors(t *testing.T) { } func TestRun_versionFlag(t *testing.T) { - outStream, errStream := new(bytes.Buffer), new(bytes.Buffer) + outStream, errStream := gatedio.NewByteBuffer(), gatedio.NewByteBuffer() cli := NewCLI(outStream, errStream) args := strings.Split("consul-replicate -version", " ") @@ -516,7 +550,7 @@ func TestRun_versionFlag(t *testing.T) { } func TestRun_parseError(t *testing.T) { - outStream, errStream := new(bytes.Buffer), new(bytes.Buffer) + outStream, errStream := gatedio.NewByteBuffer(), gatedio.NewByteBuffer() cli := NewCLI(outStream, errStream) args := strings.Split("consul-replicate -bacon delicious", " ") @@ -530,28 +564,3 @@ func TestRun_parseError(t *testing.T) { t.Fatalf("expected %q to contain %q", errStream.String(), expected) } } - -func TestRun_onceFlag(t *testing.T) { - t.Skip("pending a rewrite of the runner") - - outStream, errStream := new(bytes.Buffer), new(bytes.Buffer) - cli := NewCLI(outStream, errStream) - - command := fmt.Sprintf("consul-replicate -consul demo.consul.io -prefix global@nyc1 -once") - args := strings.Split(command, " ") - - ch := make(chan int, 1) - go func() { - ch <- cli.Run(args) - }() - - select { - case status := <-ch: - if status != ExitCodeOK { - t.Errorf("expected %d to eq %d", status, ExitCodeOK) - t.Errorf("stderr: %s", errStream.String()) - } - case <-time.After(2 * time.Second): - t.Errorf("expected exit, did not exit after 2 seconds") - } -} diff --git a/config.go b/config.go index d2996cc..a74d4a6 100644 --- a/config.go +++ b/config.go @@ -4,6 +4,7 @@ import ( "fmt" "io/ioutil" "os" + "path/filepath" "regexp" "strings" "time" @@ -33,51 +34,182 @@ type Config struct { Prefixes []*Prefix `mapstructure:"prefix"` // Auth is the HTTP basic authentication for communicating with Consul. - Auth *Auth `mapstructure:"-"` - AuthRaw []*Auth `mapstructure:"auth"` + Auth *AuthConfig `mapstructure:"auth"` + + // PidFile is the path on disk where a PID file should be written containing + // this processes PID. + PidFile string `mapstructure:"pid_file"` // SSL indicates we should use a secure connection while talking to // Consul. This requires Consul to be configured to serve HTTPS. - SSL *SSL `mapstructure:"-"` - SSLRaw []*SSL `mapstructure:"ssl"` + SSL *SSLConfig `mapstructure:"ssl"` // Syslog is the configuration for syslog. - Syslog *Syslog `mapstructure:"-"` - SyslogRaw []*Syslog `mapstructure:"syslog"` + Syslog *SyslogConfig `mapstructure:"syslog"` // MaxStale is the maximum amount of time for staleness from Consul as given // by LastContact. If supplied, Consul Replicate will query all servers // instead of just the leader. - MaxStale time.Duration `mapstructure:"-"` - MaxStaleRaw string `mapstructure:"max_stale"` + MaxStale time.Duration `mapstructure:"max_stale"` // Retry is the duration of time to wait between Consul failures. - Retry time.Duration `mapstructure:"-"` - RetryRaw string `mapstructure:"retry"` + Retry time.Duration `mapstructure:"retry"` // Wait is the quiescence timers. - Wait *watch.Wait `mapstructure:"-"` - WaitRaw string `mapstructure:"wait"` + Wait *watch.Wait `mapstructure:"wait"` // LogLevel is the level with which to log for this config. LogLevel string `mapstructure:"log_level"` // StatusDir is the path in the KV store that is used to store the // replication statuses (default: "service/consul-replicate/statuses"). - StatusDir string `mapstructure:"status_path"` + StatusDir string `mapstructure:"status_dir"` + + // setKeys is the list of config keys that were set by the user. + setKeys map[string]struct{} +} + +// Copy returns a deep copy of the current configuration. This is useful because +// the nested data structures may be shared. +func (c *Config) Copy() *Config { + config := new(Config) + config.Path = c.Path + config.Consul = c.Consul + config.Token = c.Token + + if c.Auth != nil { + config.Auth = &AuthConfig{ + Enabled: c.Auth.Enabled, + Username: c.Auth.Username, + Password: c.Auth.Password, + } + } + + config.PidFile = c.PidFile + + if c.SSL != nil { + config.SSL = &SSLConfig{ + Enabled: c.SSL.Enabled, + Verify: c.SSL.Verify, + Cert: c.SSL.Cert, + Key: c.SSL.Key, + CaCert: c.SSL.CaCert, + } + } + + if c.Syslog != nil { + config.Syslog = &SyslogConfig{ + Enabled: c.Syslog.Enabled, + Facility: c.Syslog.Facility, + } + } + + config.MaxStale = c.MaxStale + + config.Prefixes = make([]*Prefix, len(c.Prefixes)) + for i, p := range c.Prefixes { + config.Prefixes[i] = &Prefix{ + Source: p.Source, + SourceRaw: p.SourceRaw, + Destination: p.Destination, + } + } + + config.Retry = c.Retry + + if c.Wait != nil { + config.Wait = &watch.Wait{ + Min: c.Wait.Min, + Max: c.Wait.Max, + } + } + + config.LogLevel = c.LogLevel + config.StatusDir = c.StatusDir + + config.setKeys = c.setKeys + + return config } // Merge merges the values in config into this config object. Values in the // config object overwrite the values in c. func (c *Config) Merge(config *Config) { - if config.Consul != "" { + if config.WasSet("path") { + c.Path = config.Path + } + + if config.WasSet("consul") { c.Consul = config.Consul } - if config.Token != "" { + if config.WasSet("token") { c.Token = config.Token } + if config.WasSet("auth") { + if c.Auth == nil { + c.Auth = &AuthConfig{} + } + if config.WasSet("auth.username") { + c.Auth.Username = config.Auth.Username + c.Auth.Enabled = true + } + if config.WasSet("auth.password") { + c.Auth.Password = config.Auth.Password + c.Auth.Enabled = true + } + if config.WasSet("auth.enabled") { + c.Auth.Enabled = config.Auth.Enabled + } + } + + if config.WasSet("pid_file") { + c.PidFile = config.PidFile + } + + if config.WasSet("ssl") { + if c.SSL == nil { + c.SSL = &SSLConfig{} + } + if config.WasSet("ssl.verify") { + c.SSL.Verify = config.SSL.Verify + c.SSL.Enabled = true + } + if config.WasSet("ssl.cert") { + c.SSL.Cert = config.SSL.Cert + c.SSL.Enabled = true + } + if config.WasSet("ssl.key") { + c.SSL.Key = config.SSL.Key + c.SSL.Enabled = true + } + if config.WasSet("ssl.ca_cert") { + c.SSL.CaCert = config.SSL.CaCert + c.SSL.Enabled = true + } + if config.WasSet("ssl.enabled") { + c.SSL.Enabled = config.SSL.Enabled + } + } + + if config.WasSet("syslog") { + if c.Syslog == nil { + c.Syslog = &SyslogConfig{} + } + if config.WasSet("syslog.facility") { + c.Syslog.Facility = config.Syslog.Facility + c.Syslog.Enabled = true + } + if config.WasSet("syslog.enabled") { + c.Syslog.Enabled = config.Syslog.Enabled + } + } + + if config.WasSet("max_stale") { + c.MaxStale = config.MaxStale + } + if config.Prefixes != nil { if c.Prefixes == nil { c.Prefixes = make([]*Prefix, 0) @@ -92,56 +224,53 @@ func (c *Config) Merge(config *Config) { } } - if config.Auth != nil { - c.Auth = &Auth{ - Enabled: config.Auth.Enabled, - Username: config.Auth.Username, - Password: config.Auth.Password, - } + if config.WasSet("retry") { + c.Retry = config.Retry } - if config.SSL != nil { - c.SSL = &SSL{ - Enabled: config.SSL.Enabled, - Verify: config.SSL.Verify, + if config.WasSet("wait") { + c.Wait = &watch.Wait{ + Min: config.Wait.Min, + Max: config.Wait.Max, } } - if config.Syslog != nil { - c.Syslog = &Syslog{ - Enabled: config.Syslog.Enabled, - Facility: config.Syslog.Facility, - } + if config.WasSet("log_level") { + c.LogLevel = config.LogLevel } - if config.MaxStale != 0 { - c.MaxStale = config.MaxStale - c.MaxStaleRaw = config.MaxStaleRaw + if config.WasSet("status_dir") { + c.StatusDir = config.StatusDir } - if config.Retry != 0 { - c.Retry = config.Retry - c.RetryRaw = config.RetryRaw + if c.setKeys == nil { + c.setKeys = make(map[string]struct{}) } - if config.Wait != nil { - c.Wait = &watch.Wait{ - Min: config.Wait.Min, - Max: config.Wait.Max, + for k := range config.setKeys { + if _, ok := c.setKeys[k]; !ok { + c.setKeys[k] = struct{}{} } - c.WaitRaw = config.WaitRaw } +} - if config.LogLevel != "" { - c.LogLevel = config.LogLevel +// WasSet determines if the given key was set in the config (as opposed to just +// having the default value). +func (c *Config) WasSet(key string) bool { + if _, ok := c.setKeys[key]; ok { + return true } + return false +} - if config.StatusDir != "" { - c.StatusDir = config.StatusDir +// set is a helper function for marking a key as set. +func (c *Config) set(key string) { + if _, ok := c.setKeys[key]; !ok { + c.setKeys[key] = struct{}{} } } -// ParseConfig reads the configuration file at the given path and returns a new +// g reads the configuration file at the given path and returns a new // Config struct with the data populated. func ParseConfig(path string) (*Config, error) { var errs *multierror.Error @@ -149,31 +278,41 @@ func ParseConfig(path string) (*Config, error) { // Read the contents of the file contents, err := ioutil.ReadFile(path) if err != nil { - errs = multierror.Append(errs, err) - return nil, errs.ErrorOrNil() + return nil, fmt.Errorf("error reading config at %q: %s", path, err) } // Parse the file (could be HCL or JSON) - var parsed interface{} - if err := hcl.Decode(&parsed, string(contents)); err != nil { - errs = multierror.Append(errs, err) - return nil, errs.ErrorOrNil() + var shadow interface{} + if err := hcl.Decode(&shadow, string(contents)); err != nil { + return nil, fmt.Errorf("error decoding config at %q: %s", path, err) } + // Convert to a map and flatten the keys we want to flatten + parsed, ok := shadow.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("error converting config at %q", path) + } + flattenKeys(parsed, []string{"auth", "ssl", "syslog"}) + // Create a new, empty config - config := &Config{} + config := new(Config) // Use mapstructure to populate the basic config fields + metadata := new(mapstructure.Metadata) decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + DecodeHook: mapstructure.ComposeDecodeHookFunc( + watch.StringToWaitDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + mapstructure.StringToTimeDurationHookFunc(), + ), ErrorUnused: true, - Metadata: nil, + Metadata: metadata, Result: config, }) if err != nil { errs = multierror.Append(errs, err) return nil, errs.ErrorOrNil() } - if err := decoder.Decode(parsed); err != nil { errs = multierror.Append(errs, err) return nil, errs.ErrorOrNil() @@ -197,53 +336,20 @@ func ParseConfig(path string) (*Config, error) { } } - // Parse the MaxStale component - if raw := config.MaxStaleRaw; raw != "" { - stale, err := time.ParseDuration(raw) - - if err == nil { - config.MaxStale = stale - } else { - errs = multierror.Append(errs, fmt.Errorf("max_stale invalid: %v", err)) - } - } - - // Extract the last Auth block - if len(config.AuthRaw) > 0 { - config.Auth = config.AuthRaw[len(config.AuthRaw)-1] - } - - // Extract the last SSL block - if len(config.SSLRaw) > 0 { - config.SSL = config.SSLRaw[len(config.SSLRaw)-1] - } - - // Extract the last Syslog block - if len(config.SyslogRaw) > 0 { - config.Syslog = config.SyslogRaw[len(config.SyslogRaw)-1] + // Update the list of set keys + if config.setKeys == nil { + config.setKeys = make(map[string]struct{}) } - - // Parse the Retry component - if raw := config.RetryRaw; raw != "" { - retry, err := time.ParseDuration(raw) - - if err == nil { - config.Retry = retry - } else { - errs = multierror.Append(errs, fmt.Errorf("retry invalid: %v", err)) + for _, key := range metadata.Keys { + if _, ok := config.setKeys[key]; !ok { + config.setKeys[key] = struct{}{} } } + config.setKeys["path"] = struct{}{} - // Parse the Wait component - if raw := config.WaitRaw; raw != "" { - wait, err := watch.ParseWait(raw) - - if err == nil { - config.Wait = wait - } else { - errs = multierror.Append(errs, fmt.Errorf("wait invalid: %v", err)) - } - } + d := DefaultConfig() + d.Merge(config) + config = d return config, errs.ErrorOrNil() } @@ -256,43 +362,106 @@ func DefaultConfig() *Config { } return &Config{ - Auth: &Auth{ + Auth: &AuthConfig{ Enabled: false, }, - SSL: &SSL{ + SSL: &SSLConfig{ Enabled: false, Verify: true, }, - Syslog: &Syslog{ + Syslog: &SyslogConfig{ Enabled: false, Facility: "LOCAL0", }, - Prefixes: []*Prefix{}, - Retry: 5 * time.Second, + LogLevel: logLevel, + Prefixes: []*Prefix{}, + Retry: 5 * time.Second, + StatusDir: "service/consul-replicate/statuses", Wait: &watch.Wait{ Min: 150 * time.Millisecond, Max: 400 * time.Millisecond, }, - LogLevel: logLevel, - StatusDir: "service/consul-replicate/statuses", + setKeys: make(map[string]struct{}), } } -// Auth is the HTTP basic authentication data. -type Auth struct { +// ConfigFromPath iterates and merges all configuration files in a given +// directory, returning the resulting config. +func ConfigFromPath(path string) (*Config, error) { + // Ensure the given filepath exists + if _, err := os.Stat(path); os.IsNotExist(err) { + return nil, fmt.Errorf("config: missing file/folder: %s", path) + } + + // Check if a file was given or a path to a directory + stat, err := os.Stat(path) + if err != nil { + return nil, fmt.Errorf("config: error stating file: %s", err) + } + + // Recursively parse directories, single load files + if stat.Mode().IsDir() { + // Ensure the given filepath has at least one config file + _, err := ioutil.ReadDir(path) + if err != nil { + return nil, fmt.Errorf("config: error listing directory: %s", err) + } + + // Create a blank config to merge off of + config := DefaultConfig() + + // Potential bug: Walk does not follow symlinks! + err = filepath.Walk(path, func(path string, info os.FileInfo, err error) error { + // If WalkFunc had an error, just return it + if err != nil { + return err + } + + // Do nothing for directories + if info.IsDir() { + return nil + } + + // Parse and merge the config + newConfig, err := ParseConfig(path) + if err != nil { + return err + } + config.Merge(newConfig) + + return nil + }) + + if err != nil { + return nil, fmt.Errorf("config: walk error: %s", err) + } + + return config, nil + } else if stat.Mode().IsRegular() { + return ParseConfig(path) + } + + return nil, fmt.Errorf("config: unknown filetype: %q", stat.Mode().String()) +} + +// AuthConfig is the HTTP basic authentication data. +type AuthConfig struct { Enabled bool `mapstructure:"enabled"` Username string `mapstructure:"username"` Password string `mapstructure:"password"` } -// SSL is the configuration for SSL. -type SSL struct { - Enabled bool `mapstructure:"enabled"` - Verify bool `mapstructure:"verify"` +// SSLConfig is the configuration for SSL. +type SSLConfig struct { + Enabled bool `mapstructure:"enabled"` + Verify bool `mapstructure:"verify"` + Cert string `mapstructure:"cert"` + Key string `mapstructure:"key"` + CaCert string `mapstructure:"ca_cert"` } -// Syslog is the configuration for syslog. -type Syslog struct { +// SyslogConfig is the configuration for syslog. +type SyslogConfig struct { Enabled bool `mapstructure:"enabled"` Facility string `mapstructure:"facility"` } @@ -343,3 +512,40 @@ func ParsePrefix(s string) (*Prefix, error) { Destination: destination, }, nil } + +// flattenKeys is a function that takes a map[string]interface{} and recursively +// flattens any keys that are a []map[string]interface{} where the key is in the +// given list of keys. +func flattenKeys(m map[string]interface{}, keys []string) { + keyMap := make(map[string]struct{}) + for _, key := range keys { + keyMap[key] = struct{}{} + } + + var flatten func(map[string]interface{}) + flatten = func(m map[string]interface{}) { + for k, v := range m { + if _, ok := keyMap[k]; !ok { + continue + } + + switch typed := v.(type) { + case []map[string]interface{}: + if len(typed) > 0 { + last := typed[len(typed)-1] + flatten(last) + m[k] = last + } else { + m[k] = nil + } + case map[string]interface{}: + flatten(typed) + m[k] = typed + default: + m[k] = v + } + } + } + + flatten(m) +} diff --git a/config_test.go b/config_test.go index 2203090..8b2113b 100644 --- a/config_test.go +++ b/config_test.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "io/ioutil" "os" "path" "reflect" @@ -13,165 +14,203 @@ import ( "github.com/hashicorp/consul-template/watch" ) -// Test that an empty config does nothing -func TestMerge_emptyConfig(t *testing.T) { - consul := "consul.io:8500" - config := &Config{Consul: consul} - config.Merge(&Config{}) +func testConfig(contents string, t *testing.T) *Config { + f, err := ioutil.TempFile(os.TempDir(), "") + if err != nil { + t.Fatal(err) + } + + _, err = f.Write([]byte(contents)) + if err != nil { + t.Fatal(err) + } - if config.Consul != consul { - t.Fatalf("expected %q to equal %q", config.Consul, consul) + config, err := ParseConfig(f.Name()) + if err != nil { + t.Fatal(err) } + return config } -// Test that simple values are merged -func TestMerge_simpleConfig(t *testing.T) { - config, newConsul := &Config{Consul: "consul.io:8500"}, "packer.io:7300" - config.Merge(&Config{Consul: newConsul}) +func TestMerge_emptyConfig(t *testing.T) { + config := DefaultConfig() + config.Merge(&Config{}) - if config.Consul != newConsul { - t.Fatalf("expected %q to equal %q", config.Consul, newConsul) + expected := DefaultConfig() + if !reflect.DeepEqual(config, expected) { + t.Errorf("expected \n\n%#v\n\n to be \n\n%#v\n\n", config, expected) } } -// Test that the flags for HTTPS are properly merged -func TestMerge_HttpsOptions(t *testing.T) { - config := &Config{ - SSL: &SSL{ - Enabled: false, - Verify: false, - }, - } - otherConfig := &Config{ - SSL: &SSL{ - Enabled: true, - Verify: true, - }, +func TestMerge_topLevel(t *testing.T) { + config1 := testConfig(` + consul = "consul-1" + token = "token-1" + max_stale = "1s" + retry = "1s" + wait = "1s" + pid_file = "/pid-1" + status_dir = "service/consul/foo" + log_level = "log_level-1" + `, t) + config2 := testConfig(` + consul = "consul-2" + token = "token-2" + max_stale = "2s" + retry = "2s" + wait = "2s" + pid_file = "/pid-2" + status_dir = "service/consul/bar" + log_level = "log_level-2" + `, t) + config1.Merge(config2) + + if !reflect.DeepEqual(config1, config2) { + t.Errorf("expected \n\n%#v\n\n to be \n\n%#v\n\n", config1, config2) } - config.Merge(otherConfig) +} - if config.SSL.Enabled != true { - t.Errorf("expected enabled to be true") +func TestMerge_auth(t *testing.T) { + config := testConfig(` + auth { + enabled = true + username = "1" + password = "1" + } + `, t) + config.Merge(testConfig(` + auth { + password = "2" + } + `, t)) + + expected := &AuthConfig{ + Enabled: true, + Username: "1", + Password: "2", + } + + if !reflect.DeepEqual(config.Auth, expected) { + t.Errorf("expected \n\n%#v\n\n to be \n\n%#v\n\n", config.Auth, expected) } +} - if config.SSL.Verify != true { - t.Errorf("expected SSL verify to be true") +func TestMerge_SSL(t *testing.T) { + config := testConfig(` + ssl { + enabled = true + verify = true + cert = "1.pem" + ca_cert = "ca-1.pem" + } + `, t) + config.Merge(testConfig(` + ssl { + enabled = false + } + `, t)) + + expected := &SSLConfig{ + Enabled: false, + Verify: true, + Cert: "1.pem", + CaCert: "ca-1.pem", + } + + if !reflect.DeepEqual(config.SSL, expected) { + t.Errorf("expected \n\n%#v\n\n to be \n\n%#v\n\n", config.SSL, expected) } +} - config = &Config{ - SSL: &SSL{ - Enabled: true, - Verify: true, - }, - } - otherConfig = &Config{ - SSL: &SSL{ - Enabled: false, - Verify: false, - }, - } - config.Merge(otherConfig) +func TestMerge_syslog(t *testing.T) { + config := testConfig(` + syslog { + enabled = true + facility = "1" + } + `, t) + config.Merge(testConfig(` + syslog { + facility = "2" + } + `, t)) - if config.SSL.Enabled != false { - t.Errorf("expected enabled to be false") + expected := &SyslogConfig{ + Enabled: true, + Facility: "2", } - if config.SSL.Verify != false { - t.Errorf("expected SSL verify to be false") + if !reflect.DeepEqual(config.Syslog, expected) { + t.Errorf("expected \n\n%#v\n\n to be \n\n%#v\n\n", config.Syslog, expected) } } func TestMerge_Prefixes(t *testing.T) { - global := &Prefix{SourceRaw: "global/config"} - redis := &Prefix{SourceRaw: "redis/config"} - - config := &Config{Prefixes: []*Prefix{global}} - otherConfig := &Config{Prefixes: []*Prefix{redis}} - config.Merge(otherConfig) + config1 := testConfig(` + prefix { + source = "foo" + destination = "bar" + } + `, t) + config2 := testConfig(` + prefix { + source = "foo-2" + destination = "bar-2" + } + `, t) + config1.Merge(config2) - expected := []*Prefix{global, redis} - if !reflect.DeepEqual(config.Prefixes, expected) { - t.Errorf("expected %#v to be %#v", config.Prefixes, expected) + if len(config1.Prefixes) != 2 { + t.Fatalf("bad prefixes %d", len(config1.Prefixes)) } -} -func TestMerge_AuthOptions(t *testing.T) { - config := &Config{ - Auth: &Auth{Username: "user", Password: "pass"}, + if config1.Prefixes[0].Source == nil { + t.Errorf("bad source: %#v", config1.Prefixes[0].Source) } - otherConfig := &Config{ - Auth: &Auth{Username: "newUser", Password: ""}, + if config1.Prefixes[0].SourceRaw != "foo" { + t.Errorf("bad source_raw: %s", config1.Prefixes[0].SourceRaw) } - config.Merge(otherConfig) - - if config.Auth.Username != "newUser" { - t.Errorf("expected %q to be %q", config.Auth.Username, "newUser") + if config1.Prefixes[0].Destination != "bar" { + t.Errorf("bad destination: %s", config1.Prefixes[0].Destination) } -} -func TestMerge_SyslogOptions(t *testing.T) { - config := &Config{ - Syslog: &Syslog{Enabled: false, Facility: "LOCAL0"}, - } - otherConfig := &Config{ - Syslog: &Syslog{Enabled: true, Facility: "LOCAL1"}, + if config1.Prefixes[1].Source == nil { + t.Errorf("bad source: %#v", config1.Prefixes[1].Source) } - config.Merge(otherConfig) - - if config.Syslog.Enabled != true { - t.Errorf("expected %t to be %t", config.Syslog.Enabled, true) + if config1.Prefixes[1].SourceRaw != "foo-2" { + t.Errorf("bad source_raw: %s", config1.Prefixes[1].SourceRaw) } - - if config.Syslog.Facility != "LOCAL1" { - t.Errorf("expected %q to be %q", config.Syslog.Facility, "LOCAL1") + if config1.Prefixes[1].Destination != "bar-2" { + t.Errorf("bad destination: %s", config1.Prefixes[1].Destination) } } -// Test that file read errors are propagated up -func TestParseConfig_readFileError(t *testing.T) { - _, err := ParseConfig(path.Join(os.TempDir(), "config.json")) - if err == nil { - t.Fatal("expected error, but nothing was returned") - } +func TestMerge_wait(t *testing.T) { + config1 := testConfig(` + wait = "1s:1s" + `, t) + config2 := testConfig(` + wait = "2s:2s" + `, t) + config1.Merge(config2) - expectedErr := "no such file or directory" - if !strings.Contains(err.Error(), expectedErr) { - t.Fatalf("expected error %q to contain %q", err.Error(), expectedErr) + if !reflect.DeepEqual(config1.Wait, config2.Wait) { + t.Errorf("expected \n\n%#v\n\n to be \n\n%#v\n\n", config1.Wait, config2.Wait) } } -// Test that parser errors are propagated up -func TestParseConfig_parseFileError(t *testing.T) { - configFile := test.CreateTempfile([]byte(` - invalid - `), t) - defer test.DeleteTempfile(configFile, t) - - _, err := ParseConfig(configFile.Name()) - if err == nil { - t.Fatal("expected error, but nothing was returned") - } -} - -// Test that mapstructure errors are propagated up -func TestParseConfig_mapstructureError(t *testing.T) { - configFile := test.CreateTempfile([]byte(` - consul = true - `), t) - defer test.DeleteTempfile(configFile, t) - - _, err := ParseConfig(configFile.Name()) +func TestParseConfig_readFileError(t *testing.T) { + _, err := ParseConfig(path.Join(os.TempDir(), "config.json")) if err == nil { t.Fatal("expected error, but nothing was returned") } - expectedErr := "nconvertible type 'bool'" - if !strings.Contains(err.Error(), expectedErr) { - t.Fatalf("expected error %q to contain %q", err.Error(), expectedErr) + expected := "no such file or directory" + if !strings.Contains(err.Error(), expected) { + t.Fatalf("expected %q to include %q", err.Error(), expected) } } -// Test that the config is parsed correctly func TestParseConfig_correctValues(t *testing.T) { configFile := test.CreateTempfile([]byte(` consul = "nyc1.demo.consul.io" @@ -179,9 +218,10 @@ func TestParseConfig_correctValues(t *testing.T) { token = "abcd1234" wait = "5s:10s" retry = "10s" + pid_file = "/var/run/ct" log_level = "warn" - status_path = "global/statuses/replicators" + status_dir = "global/statuses/replicators" auth { enabled = true @@ -207,52 +247,33 @@ func TestParseConfig_correctValues(t *testing.T) { } expected := &Config{ - Path: configFile.Name(), - Consul: "nyc1.demo.consul.io", - Token: "abcd1234", - MaxStale: time.Second * 5, - MaxStaleRaw: "5s", - Auth: &Auth{ + Path: configFile.Name(), + PidFile: "/var/run/ct", + Consul: "nyc1.demo.consul.io", + Token: "abcd1234", + MaxStale: time.Second * 5, + Auth: &AuthConfig{ Enabled: true, Username: "test", Password: "test", }, - AuthRaw: []*Auth{ - &Auth{ - Enabled: true, - Username: "test", - Password: "test", - }, - }, - SSL: &SSL{ + Prefixes: []*Prefix{}, + SSL: &SSLConfig{ Enabled: true, Verify: false, }, - SSLRaw: []*SSL{ - &SSL{ - Enabled: true, - Verify: false, - }, - }, - Syslog: &Syslog{ + Syslog: &SyslogConfig{ Enabled: true, Facility: "LOCAL5", }, - SyslogRaw: []*Syslog{ - &Syslog{ - Enabled: true, - Facility: "LOCAL5", - }, - }, Wait: &watch.Wait{ Min: time.Second * 5, Max: time.Second * 10, }, - WaitRaw: "5s:10s", Retry: 10 * time.Second, - RetryRaw: "10s", LogLevel: "warn", StatusDir: "global/statuses/replicators", + setKeys: config.setKeys, } if !reflect.DeepEqual(config, expected) { @@ -279,40 +300,6 @@ func TestParseConfig_parseStoreKeyPrefixError(t *testing.T) { } } -func TestParseConfig_parseRetryError(t *testing.T) { - configFile := test.CreateTempfile([]byte(` - retry = "bacon pants" - `), t) - defer test.DeleteTempfile(configFile, t) - - _, err := ParseConfig(configFile.Name()) - if err == nil { - t.Fatal("expected error, but nothing was returned") - } - - expectedErr := "retry invalid" - if !strings.Contains(err.Error(), expectedErr) { - t.Fatalf("expected error %q to contain %q", err.Error(), expectedErr) - } -} - -func TestParseConfig_parseWaitError(t *testing.T) { - configFile := test.CreateTempfile([]byte(` - wait = "not_valid:duration" - `), t) - defer test.DeleteTempfile(configFile, t) - - _, err := ParseConfig(configFile.Name()) - if err == nil { - t.Fatal("expected error, but nothing was returned") - } - - expectedErr := "wait invalid" - if !strings.Contains(err.Error(), expectedErr) { - t.Fatalf("expected error %q to contain %q", err.Error(), expectedErr) - } -} - func TestParsePrefix_emptyStringArgs(t *testing.T) { _, err := ParsePrefix("") if err == nil { @@ -356,9 +343,6 @@ func TestParsePrefix_source(t *testing.T) { t.Fatal(err) } - if prefix.SourceRaw != source { - t.Errorf("expected %q to be %q", prefix.SourceRaw, source) - } if prefix.Source.Prefix != source { t.Errorf("expected %q to be %q", prefix.Source.Prefix, source) } @@ -372,9 +356,6 @@ func TestParsePrefix_sourceSlash(t *testing.T) { } expected := "global" - if prefix.SourceRaw != expected { - t.Errorf("expected %q to be %q", prefix.SourceRaw, expected) - } if prefix.Source.Prefix != expected { t.Errorf("expected %q to be %q", prefix.Source.Prefix, expected) } @@ -387,9 +368,6 @@ func TestParsePrefix_destination(t *testing.T) { t.Fatal(err) } - if prefix.SourceRaw != "global@nyc4" { - t.Errorf("expected %q to be %q", prefix.SourceRaw, "global@nyc4") - } if prefix.Destination != destination { t.Errorf("expected %q to be %q", prefix.Destination, destination) } diff --git a/flags.go b/flags.go index 19f4ff2..d7967d7 100644 --- a/flags.go +++ b/flags.go @@ -1,63 +1,42 @@ package main import ( - "fmt" - "strings" + "strconv" + "time" ) -// prefixVar implements the Flag.Value interface and allows the user -// to specify multiple -prefix keys in the CLI where each option is parsed -// as a dependency. -type prefixVar []*Prefix +// funcVar is a type of flag that accepts a function that is the string given +// by the user. +type funcVar func(s string) error -func (pv *prefixVar) Set(value string) error { - prefix, err := ParsePrefix(value) - if err != nil { - return err - } +func (f funcVar) Set(s string) error { return f(s) } +func (f funcVar) String() string { return "" } +func (f funcVar) IsBoolFlag() bool { return false } - if *pv == nil { - *pv = make([]*Prefix, 0, 1) - } - *pv = append(*pv, prefix) +// funcBoolVar is a type of flag that accepts a function, converts the user's +// value to a bool, and then calls the given function. +type funcBoolVar func(b bool) error - return nil -} - -func (pv *prefixVar) String() string { - list := make([]string, 0, len(*pv)) - for _, prefix := range *pv { - list = append(list, fmt.Sprintf("%s:%s", prefix.SourceRaw, prefix.Destination)) +func (f funcBoolVar) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err } - return strings.Join(list, ", ") + return f(v) } +func (f funcBoolVar) String() string { return "" } +func (f funcBoolVar) IsBoolFlag() bool { return true } -/// ------------------------- /// - -// authVar implements the Flag.Value interface and allows the user to specify -// authentication in the username[:password] form. -type authVar Auth +// funcDurationVar is a type of flag that accepts a function, converts the +// user's value to a duration, and then calls the given function. +type funcDurationVar func(d time.Duration) error -// Set sets the value for this authentication. -func (a *authVar) Set(value string) error { - a.Enabled = true - - if strings.Contains(value, ":") { - split := strings.SplitN(value, ":", 2) - a.Username = split[0] - a.Password = split[1] - } else { - a.Username = value - } - - return nil -} - -// String returns the string representation of this authentication. -func (a *authVar) String() string { - if a.Password == "" { - return a.Username +func (f funcDurationVar) Set(s string) error { + v, err := time.ParseDuration(s) + if err != nil { + return err } - - return fmt.Sprintf("%s:%s", a.Username, a.Password) + return f(v) } +func (f funcDurationVar) String() string { return "" } +func (f funcDurationVar) IsBoolFlag() bool { return false } diff --git a/main.go b/main.go index 5aa2e79..be3c758 100644 --- a/main.go +++ b/main.go @@ -1,16 +1,35 @@ package main // import "github.com/hashicorp/consul-replicate" import ( + "bytes" + "fmt" "os" ) -// Name is the exported name of this application. -const Name = "consul-replicate" +// The git commit that was compiled. This will be filled in by the compiler. +var GitCommit string -// Version is the current version of this application. -const Version = "0.2.0.dev" +const Name = "consul-replicate" +const Version = "0.2.0" +const VersionPrerelease = "dev" func main() { cli := NewCLI(os.Stdout, os.Stderr) os.Exit(cli.Run(os.Args)) } + +// formattedVersion returns a formatted version string which includes the git +// commit and development information. +func formattedVersion() string { + var versionString bytes.Buffer + fmt.Fprintf(&versionString, "%s v%s", Name, Version) + + if VersionPrerelease != "" { + fmt.Fprintf(&versionString, "-%s", VersionPrerelease) + + if GitCommit != "" { + fmt.Fprintf(&versionString, " (%s)", GitCommit) + } + } + return versionString.String() +} diff --git a/runner.go b/runner.go index 204b12f..92d2b02 100644 --- a/runner.go +++ b/runner.go @@ -96,6 +96,12 @@ func NewRunner(config *Config, once bool) (*Runner, error) { func (r *Runner) Start() { log.Printf("[INFO] (runner) starting") + // Create the pid before doing anything. + if err := r.storePid(); err != nil { + r.ErrCh <- err + return + } + // Add the dependencies to the watcher for _, prefix := range r.config.Prefixes { r.watcher.Add(prefix.Source) @@ -162,6 +168,10 @@ func (r *Runner) Start() { func (r *Runner) Stop() { log.Printf("[INFO] (runner) stopping") r.watcher.Stop() + if err := r.deletePid(); err != nil { + log.Printf("[WARN] (runner) could not remove pid at %q: %s", + r.config.PidFile, err) + } close(r.DoneCh) } @@ -403,6 +413,53 @@ func (r *Runner) statusPath(prefix *Prefix) string { return filepath.Join(r.config.StatusDir, enc) } +// storePid is used to write out a PID file to disk. +func (r *Runner) storePid() error { + path := r.config.PidFile + if path == "" { + return nil + } + + log.Printf("[INFO] creating pid file at %q", path) + + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666) + if err != nil { + return fmt.Errorf("runner: could not open pid file: %s", err) + } + defer f.Close() + + pid := os.Getpid() + _, err = f.WriteString(fmt.Sprintf("%d", pid)) + if err != nil { + return fmt.Errorf("runner: could not write to pid file: %s", err) + } + return nil +} + +// deletePid is used to remove the PID on exit. +func (r *Runner) deletePid() error { + path := r.config.PidFile + if path == "" { + return nil + } + + log.Printf("[DEBUG] removing pid file at %q", path) + + stat, err := os.Stat(path) + if err != nil { + return fmt.Errorf("runner: could not remove pid file: %s", err) + } + if stat.IsDir() { + return fmt.Errorf("runner: specified pid file path is directory") + } + + err = os.Remove(path) + if err != nil { + return fmt.Errorf("runner: could not remove pid file: %s", err) + } + return nil +} + // newAPIClient creates a new API client from the given config and func newAPIClient(config *Config) (*api.Client, error) { log.Printf("[INFO] (runner) creating consul/api client") diff --git a/runner_test.go b/runner_test.go index c6b780e..eea6648 100644 --- a/runner_test.go +++ b/runner_test.go @@ -3,7 +3,6 @@ package main import ( "io/ioutil" "os" - "path/filepath" "reflect" "strings" "testing" @@ -65,19 +64,6 @@ func TestNewRunner_initialize(t *testing.T) { } } -func TestConfigDefaultOverrides(t *testing.T) { - expected := "test/statuses" - - config := &Config{ - StatusDir: expected, - } - - r, _ := NewRunner(config, true) - if r.config.StatusDir != expected { - t.Errorf("expected StatusDir %q to be %q", r.config.StatusDir, expected) - } -} - func TestBuildConfig_singleFile(t *testing.T) { configFile := test.CreateTempfile([]byte(` consul = "127.0.0.1" @@ -136,23 +122,3 @@ func TestBuildConfig_EmptyDirectory(t *testing.T) { t.Fatalf("expected %q to contain %q", err.Error(), expected) } } - -func TestBuildConfig_BadConfigs(t *testing.T) { - configFile := test.CreateTempfile([]byte(` - totally not a vaild config - `), t) - defer test.DeleteTempfile(configFile, t) - - configDir := filepath.Dir(configFile.Name()) - - config := new(Config) - err := buildConfig(config, configDir) - if err == nil { - t.Fatalf("expected error, but nothing was returned") - } - - expected := "1 error(s) occurred" - if !strings.Contains(err.Error(), expected) { - t.Fatalf("expected %q to contain %q", err.Error(), expected) - } -}