Skip to content

Commit

Permalink
vault: support allowing tokens to expire without refresh (#19691)
Browse files Browse the repository at this point in the history
Some users with batch workloads or short-lived prestart tasks want to derive a
Vault token, use it, and then allow it to expire without requiring a constant
refresh. Add the `vault.allow_token_expiration` field, which works only with the
Workload Identity workflow and not the legacy workflow.

When set to true, this disables the client's renewal loop in the
`vault_hook`. When Vault revokes the token lease, the token will no longer be
valid. The client will also now detect when the Vault auth
configuration does not allow renewals and will disable the renewal loop
automatically.

Note this should only be used when a secret is requested from Vault once at the
start of a task or in a short-lived prestart task. Long-running tasks should
never set `allow_token_expiration=true` if they obtain Vault secrets via
`template` blocks, as the Vault token will expire and the template runner will
continue to make failing requests to Vault until the `vault_retry` attempts are
exhausted.

Fixes: #8690
  • Loading branch information
tgross committed Jan 10, 2024
1 parent 5267eec commit 0935f44
Show file tree
Hide file tree
Showing 14 changed files with 284 additions and 121 deletions.
7 changes: 7 additions & 0 deletions .changelog/19691.txt
@@ -0,0 +1,7 @@
```release-note:improvement
vault: Add `allow_token_expiration` field to allow Vault tokens to expire without renewal for short-lived tasks
```

```release-note:improvement
vault: Nomad clients will no longer attempt to renew Vault tokens that cannot be renewed
```
20 changes: 12 additions & 8 deletions api/tasks.go
Expand Up @@ -937,14 +937,15 @@ func (tmpl *Template) Canonicalize() {
}

type Vault struct {
Policies []string `hcl:"policies,optional"`
Role string `hcl:"role,optional"`
Namespace *string `mapstructure:"namespace" hcl:"namespace,optional"`
Cluster string `hcl:"cluster,optional"`
Env *bool `hcl:"env,optional"`
DisableFile *bool `mapstructure:"disable_file" hcl:"disable_file,optional"`
ChangeMode *string `mapstructure:"change_mode" hcl:"change_mode,optional"`
ChangeSignal *string `mapstructure:"change_signal" hcl:"change_signal,optional"`
Policies []string `hcl:"policies,optional"`
Role string `hcl:"role,optional"`
Namespace *string `mapstructure:"namespace" hcl:"namespace,optional"`
Cluster string `hcl:"cluster,optional"`
Env *bool `hcl:"env,optional"`
DisableFile *bool `mapstructure:"disable_file" hcl:"disable_file,optional"`
ChangeMode *string `mapstructure:"change_mode" hcl:"change_mode,optional"`
ChangeSignal *string `mapstructure:"change_signal" hcl:"change_signal,optional"`
AllowTokenExpiration *bool `mapstructure:"allow_token_expiration" hcl:"allow_token_expiration,optional"`
}

func (v *Vault) Canonicalize() {
Expand All @@ -966,6 +967,9 @@ func (v *Vault) Canonicalize() {
if v.ChangeSignal == nil {
v.ChangeSignal = pointerOf("SIGHUP")
}
if v.AllowTokenExpiration == nil {
v.AllowTokenExpiration = pointerOf(false)
}
}

// NewTask creates and initializes a new Task.
Expand Down
13 changes: 7 additions & 6 deletions api/tasks_test.go
Expand Up @@ -459,12 +459,13 @@ func TestTask_Canonicalize_Vault(t *testing.T) {
name: "empty",
input: &Vault{},
expected: &Vault{
Env: pointerOf(true),
DisableFile: pointerOf(false),
Namespace: pointerOf(""),
Cluster: "default",
ChangeMode: pointerOf("restart"),
ChangeSignal: pointerOf("SIGHUP"),
Env: pointerOf(true),
DisableFile: pointerOf(false),
Namespace: pointerOf(""),
Cluster: "default",
ChangeMode: pointerOf("restart"),
ChangeSignal: pointerOf("SIGHUP"),
AllowTokenExpiration: pointerOf(false),
},
},
}
Expand Down
47 changes: 33 additions & 14 deletions client/allocrunner/taskrunner/vault_hook.go
Expand Up @@ -123,26 +123,30 @@ type vaultHook struct {
// deriveTokenFunc is the function used to derive Vault tokens.
deriveTokenFunc deriveTokenFunc

// allowTokenExpiration determines if a renew loop should be run
allowTokenExpiration bool

// future is used to wait on retrieving a Vault token
future *tokenFuture
}

func newVaultHook(config *vaultHookConfig) *vaultHook {
ctx, cancel := context.WithCancel(context.Background())
h := &vaultHook{
vaultBlock: config.vaultBlock,
vaultConfigsFunc: config.vaultConfigsFunc,
clientFunc: config.clientFunc,
eventEmitter: config.events,
lifecycle: config.lifecycle,
updater: config.updater,
alloc: config.alloc,
task: config.task,
firstRun: true,
ctx: ctx,
cancel: cancel,
future: newTokenFuture(),
widmgr: config.widmgr,
vaultBlock: config.vaultBlock,
vaultConfigsFunc: config.vaultConfigsFunc,
clientFunc: config.clientFunc,
eventEmitter: config.events,
lifecycle: config.lifecycle,
updater: config.updater,
alloc: config.alloc,
task: config.task,
firstRun: true,
ctx: ctx,
cancel: cancel,
future: newTokenFuture(),
widmgr: config.widmgr,
allowTokenExpiration: config.vaultBlock.AllowTokenExpiration,
}
h.logger = config.logger.Named(h.Name())

Expand Down Expand Up @@ -237,6 +241,9 @@ func (h *vaultHook) Shutdown() {
func (h *vaultHook) run(token string) {
// Helper for stopping token renewal
stopRenewal := func() {
if h.allowTokenExpiration {
return
}
if err := h.client.StopRenewToken(h.future.Get()); err != nil {
h.logger.Warn("failed to stop token renewal", "error", err)
}
Expand Down Expand Up @@ -280,6 +287,12 @@ OUTER:
}
}

if h.allowTokenExpiration {
h.future.Set(token)
h.logger.Debug("Vault token will not renew")
return
}

// Start the renewal process.
//
// This is the initial renew of the token which we derived from the
Expand Down Expand Up @@ -430,7 +443,7 @@ func (h *vaultHook) deriveVaultTokenJWT() (string, error) {
}

// Derive Vault token with signed identity.
token, err := h.client.DeriveTokenWithJWT(h.ctx, vaultclient.JWTLoginRequest{
token, renewable, err := h.client.DeriveTokenWithJWT(h.ctx, vaultclient.JWTLoginRequest{
JWT: signed.JWT,
Role: role,
Namespace: h.vaultBlock.Namespace,
Expand All @@ -442,6 +455,12 @@ func (h *vaultHook) deriveVaultTokenJWT() (string, error) {
)
}

// If the token cannot be renewed, it doesn't matter if the user set
// allow_token_expiration or not, so override the requested behavior
if !renewable {
h.allowTokenExpiration = true
}

return token, nil
}

Expand Down
100 changes: 77 additions & 23 deletions client/allocrunner/taskrunner/vault_hook_test.go
Expand Up @@ -112,11 +112,13 @@ func TestTaskRunner_VaultHook(t *testing.T) {
ci.Parallel(t)

testCases := []struct {
name string
task *structs.Task
configs map[string]*sconfig.VaultConfig
expectRole string
expectLegacy bool
name string
task *structs.Task
configs map[string]*sconfig.VaultConfig
configNonrenewable bool
expectRole string
expectLegacy bool
expectNoRenew bool
}{
{
name: "legacy flow",
Expand Down Expand Up @@ -205,14 +207,40 @@ func TestTaskRunner_VaultHook(t *testing.T) {
},
},
},
{
name: "job requests no renewal",
task: &structs.Task{
Vault: &structs.Vault{
Cluster: structs.VaultDefaultCluster,
AllowTokenExpiration: true,
},
Identities: []*structs.WorkloadIdentity{
{Name: "vault_default"},
},
},
expectNoRenew: true,
},
{
name: "tokens are not renewable",
task: &structs.Task{
Vault: &structs.Vault{
Cluster: structs.VaultDefaultCluster,
},
Identities: []*structs.WorkloadIdentity{
{Name: "vault_default"},
},
},
configNonrenewable: true,
expectNoRenew: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
alloc := mock.MinAlloc()
alloc.Job.TaskGroups[0].Tasks[0] = tc.task

hook := setupTestVaultHook(t, &vaultHookConfig{
hookConfig := &vaultHookConfig{
task: tc.task,
alloc: alloc,
vaultConfigsFunc: func(hclog.Logger) map[string]*sconfig.VaultConfig {
Expand All @@ -223,7 +251,17 @@ func TestTaskRunner_VaultHook(t *testing.T) {
"default": sconfig.DefaultVaultConfig(),
}
},
})
}

if tc.configNonrenewable {
hookConfig.clientFunc = func(cluster string) (vaultclient.VaultClient, error) {
client := &vaultclient.MockVaultClient{}
client.SetRenewable(false)
return client, nil
}
}

hook := setupTestVaultHook(t, hookConfig)

// Ensure Prestart() returns within a reasonable time.
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
Expand Down Expand Up @@ -293,8 +331,12 @@ func TestTaskRunner_VaultHook(t *testing.T) {
}

// Token must be set for renewal.
must.MapLen(t, 1, client.RenewTokens())
must.NotNil(t, client.RenewTokens()[updater.currentToken])
if tc.expectNoRenew {
must.MapEmpty(t, client.RenewTokens())
} else {
must.MapLen(t, 1, client.RenewTokens())
must.NotNil(t, client.RenewTokens()[updater.currentToken])
}

// PrestartDone must be false so we can recover tokens.
// firstRun is used to prevent multiple executions.
Expand All @@ -307,6 +349,14 @@ func TestTaskRunner_VaultHook(t *testing.T) {
must.Wait(t, wait.InitialSuccess(
wait.ErrorFunc(func() error {
tokens := client.StoppedTokens()

if tc.expectNoRenew {
if len(tokens) != 0 {
return fmt.Errorf("expected no stopped tokens when renewal is disabled, got %d", len(tokens))
}
return nil
}

if len(tokens) != 1 {
return fmt.Errorf("expected stopped tokens to be %d, got %d", 1, len(tokens))
}
Expand Down Expand Up @@ -424,11 +474,12 @@ func TestTaskRunner_VaultHook_deriveError(t *testing.T) {
t.Cleanup(cancel)

// Set unrecoverable error.
mockVaultClient.SetDeriveTokenWithJWTFn(func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, error) {
// Cancel the context to simulate the task being killed.
cancel()
return "", structs.NewRecoverableError(errors.New("unrecoverable test error"), false)
})
mockVaultClient.SetDeriveTokenWithJWTFn(
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) {
// Cancel the context to simulate the task being killed.
cancel()
return "", false, structs.NewRecoverableError(errors.New("unrecoverable test error"), false)
})

err := hook.Prestart(ctx, req, &resp)
must.NoError(t, err)
Expand Down Expand Up @@ -472,16 +523,18 @@ func TestTaskRunner_VaultHook_deriveError(t *testing.T) {
t.Cleanup(cancel)

// Set recoverable error.
mockVaultClient.SetDeriveTokenWithJWTFn(func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, error) {
return "", structs.NewRecoverableError(errors.New("recoverable test error"), true)
})
mockVaultClient.SetDeriveTokenWithJWTFn(
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) {
return "", false, structs.NewRecoverableError(errors.New("recoverable test error"), true)
})

go func() {
// Wait a bit for the first error then fix token renewal.
time.Sleep(time.Second)
mockVaultClient.SetDeriveTokenWithJWTFn(func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, error) {
return "secret", nil
})
mockVaultClient.SetDeriveTokenWithJWTFn(
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) {
return "secret", true, nil
})

}()
err := hook.Prestart(ctx, req, &resp)
Expand Down Expand Up @@ -516,9 +569,10 @@ func TestTaskRunner_VaultHook_deriveError(t *testing.T) {
t.Cleanup(cancel)

// Derive predictable token and fail renew request.
mockVaultClient.SetDeriveTokenWithJWTFn(func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, error) {
return "secret", nil
})
mockVaultClient.SetDeriveTokenWithJWTFn(
func(_ context.Context, _ vaultclient.JWTLoginRequest) (string, bool, error) {
return "secret", true, nil
})
mockVaultClient.SetRenewTokenError("secret", errors.New("test error"))

go func() {
Expand Down
18 changes: 9 additions & 9 deletions client/vaultclient/vaultclient.go
Expand Up @@ -62,8 +62,8 @@ type VaultClient interface {
DeriveToken(*structs.Allocation, []string) (map[string]string, error)

// DeriveTokenWithJWT returns a Vault ACL token using the JWT login
// endpoint.
DeriveTokenWithJWT(context.Context, JWTLoginRequest) (string, error)
// endpoint, along with whether or not the token is renewable.
DeriveTokenWithJWT(context.Context, JWTLoginRequest) (string, bool, error)

// GetConsulACL fetches the Consul ACL token required for the task
GetConsulACL(string, string) (*vaultapi.Secret, error)
Expand Down Expand Up @@ -293,12 +293,12 @@ func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string)
}

// DeriveTokenWithJWT returns a Vault ACL token using the JWT login endpoint,
// along with whether or not the token is renewable.
func (c *vaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, error) {
func (c *vaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginRequest) (string, bool, error) {
if !c.config.IsEnabled() {
return "", fmt.Errorf("vault client not enabled")
return "", false, fmt.Errorf("vault client not enabled")
}
if !c.isRunning() {
return "", fmt.Errorf("vault client is not running")
return "", false, fmt.Errorf("vault client is not running")
}

c.lock.Lock()
Expand All @@ -319,20 +319,20 @@ func (c *vaultClient) DeriveTokenWithJWT(ctx context.Context, req JWTLoginReques
},
)
if err != nil {
return "", fmt.Errorf("failed to login with JWT: %v", err)
return "", false, fmt.Errorf("failed to login with JWT: %v", err)
}
if s == nil {
return "", errors.New("JWT login returned an empty secret")
return "", false, errors.New("JWT login returned an empty secret")
}
if s.Auth == nil {
return "", errors.New("JWT login did not return a token")
return "", false, errors.New("JWT login did not return a token")
}

for _, w := range s.Warnings {
c.logger.Warn("JWT login warning", "warning", w)
}

return s.Auth.ClientToken, nil
return s.Auth.ClientToken, s.Auth.Renewable, nil
}

// GetConsulACL creates a vault API client and reads from vault a consul ACL
Expand Down
5 changes: 3 additions & 2 deletions client/vaultclient/vaultclient_test.go
Expand Up @@ -217,12 +217,13 @@ func TestVaultClient_DeriveTokenWithJWT(t *testing.T) {

// Derive Vault token using signed JWT.
jwtStr := signedWIDs[0].JWT
token, err := c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
token, renewable, err := c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
JWT: jwtStr,
Namespace: "default",
})
must.NoError(t, err)
must.NotEq(t, "", token)
must.True(t, renewable)

// Verify token has expected properties.
v.Client.SetToken(token)
Expand Down Expand Up @@ -257,7 +258,7 @@ func TestVaultClient_DeriveTokenWithJWT(t *testing.T) {
must.Eq(t, []any{"deny"}, (s.Data[pathDenied]).([]any))

// Derive Vault token with non-existing role.
token, err = c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
token, _, err = c.DeriveTokenWithJWT(context.Background(), JWTLoginRequest{
JWT: jwtStr,
Role: "test",
Namespace: "default",
Expand Down

0 comments on commit 0935f44

Please sign in to comment.