Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: reauth #17219

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion api/client.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
package api

import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"path"
"regexp"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -1176,7 +1180,7 @@ START:
}

if checkRetry == nil {
checkRetry = DefaultRetryPolicy
checkRetry = TokenMismatchedRetryPolicy
}

client := &retryablehttp.Client{
Expand Down Expand Up @@ -1525,6 +1529,35 @@ func ForwardAlways() RequestCallback {
}
}

var MismatchedTokenErrorRe = regexp.MustCompile(`token mac for token_version:1 hmac:.*is incorrect: err.*`)

func TokenMismatchedRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
retry, err := DefaultRetryPolicy(ctx, resp, err)
if err != nil {
return retry, err
}
if retry && resp != nil && resp.StatusCode == 500 {
// We have an error. Let's look into the body to find out
// whether uptream server has changed
bodyBuf := &bytes.Buffer{}
if _, err := io.Copy(bodyBuf, resp.Body); err != nil {
// don't propagate other errors
return true, nil
}

resp.Body.Close()
resp.Body = ioutil.NopCloser(bodyBuf)

if MismatchedTokenErrorRe.MatchString(bodyBuf.String()) {
return false, nil
}
}
if retry {
return true, nil
}
return false, nil
}

// DefaultRetryPolicy is the default retry policy used by new Client objects.
// It is the same as retryablehttp.DefaultRetryPolicy except that it also retries
// 412 requests, which are returned by Vault when a X-Vault-Index header isn't
Expand Down
9 changes: 8 additions & 1 deletion command/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,9 @@ func (c *AgentCommand) Run(args []string) int {
c.UI.Output("==> Vault agent started! Log data will stream in below:\n")
}

var reauthCh = make(chan struct{})
m := sync.Mutex{}
cond := sync.NewCond(&m)
var leaseCache *cache.LeaseCache
var previousToken string
// Parse agent listener configurations
Expand Down Expand Up @@ -680,7 +683,7 @@ func (c *AgentCommand) Run(args []string) int {
proxyVaultToken := !config.Cache.ForceAutoAuthToken

// Create the request handler
cacheHandler := cache.Handler(ctx, cacheLogger, leaseCache, inmemSink, proxyVaultToken)
cacheHandler := cache.Handler(ctx, cacheLogger, leaseCache, inmemSink, proxyVaultToken, reauthCh, cond)

var listeners []net.Listener

Expand Down Expand Up @@ -791,6 +794,7 @@ func (c *AgentCommand) Run(args []string) int {
if method != nil {
enableTokenCh := len(config.Templates) > 0
ah := auth.NewAuthHandler(&auth.AuthHandlerConfig{
ReauthCh: reauthCh,
Logger: c.logger.Named("auth.handler"),
Client: c.client,
WrapTTL: config.AutoAuth.Method.WrapTTL,
Expand All @@ -804,6 +808,7 @@ func (c *AgentCommand) Run(args []string) int {
Logger: c.logger.Named("sink.server"),
Client: client,
ExitAfterAuth: exitAfterAuth,
Cond: cond,
})

ts := template.NewServer(&template.ServerConfig{
Expand Down Expand Up @@ -833,11 +838,13 @@ func (c *AgentCommand) Run(args []string) int {
// Start goroutine to drain from ah.OutputCh from this point onward
// to prevent ah.Run from being blocked.
go func() {
c.logger.Info("start goroutine to drain from ah.OutputCh")
for {
select {
case <-ctx.Done():
return
case <-ah.OutputCh:
c.logger.Info("drain from ah.OutputCh")
}
}
}()
Expand Down
24 changes: 20 additions & 4 deletions command/agent/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,17 @@ type AuthConfig struct {
Config map[string]interface{}
}

type OutputInfo struct {
IsReauth bool
Data string
}

// AuthHandler is responsible for keeping a token alive and renewed and passing
// new tokens to the sink server
type AuthHandler struct {
OutputCh chan string
reauthCh chan struct{}
isReauth bool
OutputCh chan OutputInfo
TemplateTokenCh chan string
token string
logger hclog.Logger
Expand All @@ -60,6 +67,7 @@ type AuthHandler struct {
}

type AuthHandlerConfig struct {
ReauthCh chan struct{}
Logger hclog.Logger
Client *api.Client
WrapTTL time.Duration
Expand All @@ -73,7 +81,7 @@ func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler {
ah := &AuthHandler{
// This is buffered so that if we try to output after the sink server
// has been shut down, during agent shutdown, we won't block
OutputCh: make(chan string, 1),
OutputCh: make(chan OutputInfo, 1),
TemplateTokenCh: make(chan string, 1),
token: conf.Token,
logger: conf.Logger,
Expand All @@ -83,6 +91,7 @@ func NewAuthHandler(conf *AuthHandlerConfig) *AuthHandler {
maxBackoff: conf.MaxBackoff,
enableReauthOnNewCredentials: conf.EnableReauthOnNewCredentials,
enableTemplateTokenCh: conf.EnableTemplateTokenCh,
reauthCh: conf.ReauthCh,
}

return ah
Expand Down Expand Up @@ -252,7 +261,8 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
continue
}
ah.logger.Info("authentication successful, sending wrapped token to sinks and pausing")
ah.OutputCh <- string(wrappedResp)
ah.OutputCh <- OutputInfo{IsReauth: ah.isReauth, Data: string(wrappedResp)}
ah.isReauth = false
if ah.enableTemplateTokenCh {
ah.TemplateTokenCh <- string(wrappedResp)
}
Expand Down Expand Up @@ -284,7 +294,8 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
continue
}
ah.logger.Info("authentication successful, sending token to sinks")
ah.OutputCh <- secret.Auth.ClientToken
ah.OutputCh <- OutputInfo{IsReauth: ah.isReauth, Data: secret.Auth.ClientToken}
ah.isReauth = false
if ah.enableTemplateTokenCh {
ah.TemplateTokenCh <- secret.Auth.ClientToken
}
Expand Down Expand Up @@ -335,6 +346,11 @@ func (ah *AuthHandler) Run(ctx context.Context, am AuthMethod) error {
case <-credCh:
ah.logger.Info("auth method found new credentials, re-authenticating")
break LifetimeWatcherLoop

case <-ah.reauthCh:
ah.logger.Info("upstream is switched, re-authenticating")
ah.isReauth = true
break LifetimeWatcherLoop
}
}
}
Expand Down
18 changes: 15 additions & 3 deletions command/agent/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,11 @@ func setupClusterAndAgentCommon(ctx context.Context, t *testing.T, coreConfig *v
mux := http.NewServeMux()
mux.Handle("/agent/v1/cache-clear", leaseCache.HandleCacheClear(ctx))

mux.Handle("/", Handler(ctx, cacheLogger, leaseCache, nil, true))
var reauthCh = make(chan struct{})
m := sync.Mutex{}
cond := sync.NewCond(&m)

mux.Handle("/", Handler(ctx, cacheLogger, leaseCache, nil, true, reauthCh, cond))
server := &http.Server{
Handler: mux,
ReadHeaderTimeout: 10 * time.Second,
Expand Down Expand Up @@ -248,7 +252,11 @@ func TestCache_AutoAuthTokenStripping(t *testing.T) {
mux := http.NewServeMux()
mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))

mux.Handle("/", Handler(ctx, cacheLogger, leaseCache, mock.NewSink("testid"), true))
var reauthCh = make(chan struct{})
m := sync.Mutex{}
cond := sync.NewCond(&m)

mux.Handle("/", Handler(ctx, cacheLogger, leaseCache, mock.NewSink("testid"), true, reauthCh, cond))
server := &http.Server{
Handler: mux,
ReadHeaderTimeout: 10 * time.Second,
Expand Down Expand Up @@ -337,7 +345,11 @@ func TestCache_AutoAuthClientTokenProxyStripping(t *testing.T) {
mux := http.NewServeMux()
// mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))

mux.Handle("/", Handler(ctx, cacheLogger, leaseCache, mock.NewSink(realToken), false))
var reauthCh = make(chan struct{})
m := sync.Mutex{}
cond := sync.NewCond(&m)

mux.Handle("/", Handler(ctx, cacheLogger, leaseCache, mock.NewSink(realToken), false, reauthCh, cond))
server := &http.Server{
Handler: mux,
ReadHeaderTimeout: 10 * time.Second,
Expand Down
21 changes: 19 additions & 2 deletions command/agent/cache/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"io/ioutil"
"net/http"
"sync"
"time"

"github.com/armon/go-metrics"
Expand All @@ -20,7 +21,7 @@ import (
"github.com/hashicorp/vault/sdk/logical"
)

func Handler(ctx context.Context, logger hclog.Logger, proxier Proxier, inmemSink sink.Sink, proxyVaultToken bool) http.Handler {
func Handler(ctx context.Context, logger hclog.Logger, proxier Proxier, inmemSink sink.Sink, proxyVaultToken bool, reauthCh chan struct{}, cond *sync.Cond) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Info("received request", "method", r.Method, "path", r.URL.Path)

Expand Down Expand Up @@ -51,8 +52,24 @@ func Handler(ctx context.Context, logger hclog.Logger, proxier Proxier, inmemSin
Request: r,
RequestBody: reqBody,
}

resp, err := proxier.Send(ctx, req)
for resp !=nil && api.MismatchedTokenErrorRe.MatchString(string(resp.ResponseBody)) {
logger.Trace("token prior to reauth", "token", inmemSink.(sink.SinkReader).Token())
cond.L.Lock()
select {
case reauthCh <- struct{}{}:
logger.Trace("trigger reauthentication")
default:
}
logger.Trace("waiting for new valid token")
cond.Wait()
cond.L.Unlock()
logger.Trace("request woken up")
token = inmemSink.(sink.SinkReader).Token()
req.Token = token
resp, err = proxier.Send(ctx, req)
logger.Trace("token after reauth", "token", req.Token)
}
if err != nil {
// If this is a api.Response error, don't wrap the response.
if resp != nil && resp.Response.Error() != nil {
Expand Down
7 changes: 6 additions & 1 deletion command/agent/cache_end_to_end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net"
"net/http"
"os"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -314,8 +315,12 @@ func TestCache_UsingAutoAuthToken(t *testing.T) {
mux := http.NewServeMux()
mux.Handle(consts.AgentPathCacheClear, leaseCache.HandleCacheClear(ctx))

var reauthCh = make(chan struct{})
m := sync.Mutex{}
cond := sync.NewCond(&m)

// Passing a non-nil inmemsink tells the agent to use the auto-auth token
mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, inmemSink, true))
mux.Handle("/", cache.Handler(ctx, cacheLogger, leaseCache, inmemSink, true, reauthCh, cond))
server := &http.Server{
Handler: mux,
ReadHeaderTimeout: 10 * time.Second,
Expand Down
11 changes: 6 additions & 5 deletions command/agent/sink/file/sink_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

hclog "github.com/hashicorp/go-hclog"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/command/agent/auth"
"github.com/hashicorp/vault/command/agent/sink"
"github.com/hashicorp/vault/sdk/helper/logging"
)
Expand All @@ -31,15 +32,15 @@ func TestSinkServer(t *testing.T) {
})

uuidStr, _ := uuid.GenerateUUID()
in := make(chan string)
in := make(chan auth.OutputInfo)
sinks := []*sink.SinkConfig{fs1, fs2}
errCh := make(chan error)
go func() {
errCh <- ss.Run(ctx, in, sinks)
}()

// Seed a token
in <- uuidStr
in <- auth.OutputInfo{IsReauth: false, Data: uuidStr}

// Tell it to shut down and give it time to do so
timer := time.AfterFunc(3*time.Second, func() {
Expand Down Expand Up @@ -98,15 +99,15 @@ func TestSinkServerRetry(t *testing.T) {
Logger: log.Named("sink.server"),
})

in := make(chan string)
in := make(chan auth.OutputInfo)
sinks := []*sink.SinkConfig{{Sink: b1}, {Sink: b2}}
errCh := make(chan error)
go func() {
errCh <- ss.Run(ctx, in, sinks)
}()

// Seed a token
in <- "bad"
in <- auth.OutputInfo{IsReauth: false, Data: "bad"}

// During this time we should see it retry multiple times
time.Sleep(10 * time.Second)
Expand All @@ -117,7 +118,7 @@ func TestSinkServerRetry(t *testing.T) {
t.Fatal("bad try count")
}

in <- "good"
in <- auth.OutputInfo{IsReauth: false, Data: "good"}

time.Sleep(2 * time.Second)
if atomic.LoadUint32(&b1.tryCount) != 0 {
Expand Down
Loading