Skip to content

Commit

Permalink
agent: add -exit-after-auth flag (#7920) (#7983)
Browse files Browse the repository at this point in the history
* agent: add -exit-after-auth flag

* use short timeout for tests to prevent long test runs on error

* revert sdk/go.mod
  • Loading branch information
calvn authored and briankassouf committed Dec 9, 2019
1 parent 9d7f637 commit 7f04144
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 15 deletions.
25 changes: 21 additions & 4 deletions command/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ type AgentCommand struct {

startedCh chan (struct{}) // for tests

flagConfigs []string
flagLogLevel string
flagConfigs []string
flagLogLevel string
flagExitAfterAuth bool

flagTestVerifyOnly bool
flagCombineLogs bool
Expand Down Expand Up @@ -115,6 +116,15 @@ func (c *AgentCommand) Flags() *FlagSets {
"\"trace\", \"debug\", \"info\", \"warn\", and \"err\".",
})

f.BoolVar(&BoolVar{
Name: "exit-after-auth",
Target: &c.flagExitAfterAuth,
Default: false,
Usage: "If set to true, the agent will exit with code 0 after a single " +
"successful auth, where success means that a token was retrieved and " +
"all sinks successfully wrote it",
})

// Internal-only flags to follow.
//
// Why hello there little source code reader! Welcome to the Vault source
Expand Down Expand Up @@ -223,6 +233,13 @@ func (c *AgentCommand) Run(args []string) int {
config.Vault = new(agentConfig.Vault)
}

exitAfterAuth := config.ExitAfterAuth
f.Visit(func(fl *flag.Flag) {
if fl.Name == "exit-after-auth" {
exitAfterAuth = c.flagExitAfterAuth
}
})

c.setStringFlag(f, config.Vault.Address, &StringVar{
Name: flagNameAddress,
Target: &c.flagAddress,
Expand Down Expand Up @@ -524,7 +541,7 @@ func (c *AgentCommand) Run(args []string) int {
ss := sink.NewSinkServer(&sink.SinkServerConfig{
Logger: c.logger.Named("sink.server"),
Client: client,
ExitAfterAuth: config.ExitAfterAuth,
ExitAfterAuth: exitAfterAuth,
})
ssDoneCh = ss.DoneCh

Expand All @@ -534,7 +551,7 @@ func (c *AgentCommand) Run(args []string) int {
LogWriter: c.logWriter,
VaultConf: config.Vault,
Namespace: namespace,
ExitAfterAuth: config.ExitAfterAuth,
ExitAfterAuth: exitAfterAuth,
})
tsDoneCh = ts.DoneCh

Expand Down
51 changes: 40 additions & 11 deletions command/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,16 @@ cache {
*/

func TestAgent_ExitAfterAuth(t *testing.T) {
t.Run("via_config", func(t *testing.T) {
testAgentExitAfterAuth(t, false)
})

t.Run("via_flag", func(t *testing.T) {
testAgentExitAfterAuth(t, true)
})
}

func testAgentExitAfterAuth(t *testing.T, viaFlag bool) {
logger := logging.NewVaultLogger(hclog.Trace)
coreConfig := &vault.CoreConfig{
Logger: logger,
Expand Down Expand Up @@ -313,8 +323,13 @@ func TestAgent_ExitAfterAuth(t *testing.T) {
logger.Trace("wrote test jwt", "path", in)
}

exitAfterAuthTemplText := "exit_after_auth = true"
if viaFlag {
exitAfterAuthTemplText = ""
}

config := `
exit_after_auth = true
%s
auto_auth {
method {
Expand All @@ -340,23 +355,37 @@ auto_auth {
}
`

config = fmt.Sprintf(config, in, sink1, sink2)
config = fmt.Sprintf(config, exitAfterAuthTemplText, in, sink1, sink2)
if err := ioutil.WriteFile(conf, []byte(config), 0600); err != nil {
t.Fatal(err)
} else {
logger.Trace("wrote test config", "path", conf)
}

// If this hangs forever until the test times out, exit-after-auth isn't
// working
ui, cmd := testAgentCommand(t, logger)
cmd.client = client
doneCh := make(chan struct{})
go func() {
ui, cmd := testAgentCommand(t, logger)
cmd.client = client

code := cmd.Run([]string{"-config", conf})
if code != 0 {
t.Errorf("expected %d to be %d", code, 0)
t.Logf("output from agent:\n%s", ui.OutputWriter.String())
t.Logf("error from agent:\n%s", ui.ErrorWriter.String())
args := []string{"-config", conf}
if viaFlag {
args = append(args, "-exit-after-auth")
}

code := cmd.Run(args)
if code != 0 {
t.Errorf("expected %d to be %d", code, 0)
t.Logf("output from agent:\n%s", ui.OutputWriter.String())
t.Logf("error from agent:\n%s", ui.ErrorWriter.String())
}
close(doneCh)
}()

select {
case <-doneCh:
break
case <-time.After(1 * time.Minute):
t.Fatal("timeout reached while waiting for agent to exit")
}

sink1Bytes, err := ioutil.ReadFile(sink1)
Expand Down

0 comments on commit 7f04144

Please sign in to comment.