Skip to content

Commit

Permalink
Change how request modifiers behave for slices and maps. (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
averche committed Aug 2, 2023
1 parent e6c4201 commit 27cd603
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 22 deletions.
20 changes: 16 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,25 @@ for engine := range resp.Data {
### Modifying Requests

You can modify the requests in one of two ways, either at the client level or by
decorating individual requests:
decorating individual requests. In case both client-level and request-specific
modifiers are present, the following rules will apply:

- For scalar values (such as `vault.WithToken` example below), the
request-specific decorators will take precedence over the client-level
settings.
- For slices (e.g. `vault.WithResponseCallbacks`), the request-specific
decorators will be appended to the client-level settings for the given
request.
- For maps (e.g. `vault.WithCustomHeaders`), the request-specific decorators
will be merged into the client-level settings using `maps.Copy` semantics
(appended, overwriting the existing keys) for the given request.

```go
// all subsequent requests will use the given token & namespace
_ = client.SetToken("my-token")
_ = client.SetNamespace("my-namespace")

// per-request decorators take precedence over the client-level settings
// for scalar settings, request-specific decorators take precedence
resp, err := client.Secrets.KvV2Read(
ctx,
"my-secret",
Expand Down Expand Up @@ -382,9 +393,10 @@ client.SetResponseCallbacks(func(req *http.Request, resp *http.Response) {
})
```

Alternatively, `vault.WithRequestCallbacks(..)` /
Additionally, `vault.WithRequestCallbacks(..)` /
`vault.WithResponseCallbacks(..)` can be used to inject callbacks for individual
requests:
requests. These request-level callbacks will be appended to the list of the
respective client-level callbacks for the given request.

```go
resp, err := client.Secrets.KvV2Read(
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/hashicorp/go-rootcerts v1.0.2
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2
github.com/stretchr/testify v1.8.0
golang.org/x/exp v0.0.0-20230801115018-d63ba01acd4b
golang.org/x/sys v0.4.0
golang.org/x/time v0.0.0-20220922220347-f3bd1da661af
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
golang.org/x/exp v0.0.0-20230801115018-d63ba01acd4b h1:r+vk0EmXNmekl0S0BascoeeoHk/L7wmaW2QF90K+kYI=
golang.org/x/exp v0.0.0-20230801115018-d63ba01acd4b/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20220922220347-f3bd1da661af h1:Yx9k8YCG3dvF87UAn2tu2HQLf2dt/eR1bXxpLMWeH+Y=
Expand Down
36 changes: 23 additions & 13 deletions request_modifiers.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"unicode"

"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
)

type (
Expand Down Expand Up @@ -245,7 +246,14 @@ func (m *requestModifiers) additionalQueryParametersOrDefault() url.Values {
return m.additionalQueryParameters
}

// mergeRequestModifiers merges the values from *rhs into *lhs.
// mergeRequestModifiers merges the values in *rhs into *lhs. The merging is
// done according the following rules:
//
// - for scalars : the rhs values, if present, will overwrite the lhs values
// - for slices : the rhs values will be appended to the lhs values
// - for maps
// -- new keys : the rhs values will be appended to the lhs values
// -- existing keys : the rhs values will overwrite the corresponding lhs values
func mergeRequestModifiers(lhs, rhs *requestModifiers) {
if rhs.headers.userAgent != "" {
lhs.headers.userAgent = rhs.headers.userAgent
Expand All @@ -259,9 +267,10 @@ func mergeRequestModifiers(lhs, rhs *requestModifiers) {
lhs.headers.namespace = rhs.headers.namespace
}

if len(rhs.headers.mfaCredentials) != 0 {
lhs.headers.mfaCredentials = rhs.headers.mfaCredentials
}
lhs.headers.mfaCredentials = append(
lhs.headers.mfaCredentials,
rhs.headers.mfaCredentials...,
)

if rhs.headers.responseWrappingTTL != 0 {
lhs.headers.responseWrappingTTL = rhs.headers.responseWrappingTTL
Expand All @@ -271,17 +280,18 @@ func mergeRequestModifiers(lhs, rhs *requestModifiers) {
lhs.headers.replicationForwardingMode = rhs.headers.replicationForwardingMode
}

if len(rhs.headers.customHeaders) != 0 {
lhs.headers.customHeaders = rhs.headers.customHeaders
}
// in case of key collisions, the rhs keys will take precedence
maps.Copy(lhs.headers.customHeaders, rhs.headers.customHeaders)

if len(rhs.requestCallbacks) != 0 {
lhs.requestCallbacks = rhs.requestCallbacks
}
lhs.requestCallbacks = append(
lhs.requestCallbacks,
rhs.requestCallbacks...,
)

if len(rhs.responseCallbacks) != 0 {
lhs.responseCallbacks = rhs.responseCallbacks
}
lhs.responseCallbacks = append(
lhs.responseCallbacks,
rhs.responseCallbacks...,
)
}

func validateToken(token string) error {
Expand Down
48 changes: 48 additions & 0 deletions request_modifiers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package vault

import (
"fmt"
"net/http"
"testing"

Expand Down Expand Up @@ -129,3 +130,50 @@ func Test_mergeRequestModifiers_overwrite(t *testing.T) {
})
}
}

func Test_mergeRequestModifiers_append(t *testing.T) {
requestCallback1 := func(r *http.Request) {
t.Logf("callback 1: %v", *r)
}
requestCallback2 := func(r *http.Request) {
t.Logf("callback 2: %v", *r)
}
responseCallback := func(req *http.Request, resp *http.Response) {
t.Logf("callback 3: %v, %v", *req, *resp)
}

cases := map[string]struct {
name string
lhs requestModifiers
rhs requestModifiers
expected requestModifiers
}{
"custom-headers": {
lhs: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"a": []string{"hi"}, "b": nil}}},
rhs: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"b": nil, "c": nil}}},
expected: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"a": []string{"hi"}, "b": nil, "c": nil}}},
},
"request-callbacks": {
lhs: requestModifiers{requestCallbacks: []RequestCallback{requestCallback1}},
rhs: requestModifiers{requestCallbacks: []RequestCallback{requestCallback2}},
expected: requestModifiers{requestCallbacks: []RequestCallback{requestCallback1, requestCallback2}},
},
"response-callbacks": {
lhs: requestModifiers{responseCallbacks: []ResponseCallback{}},
rhs: requestModifiers{responseCallbacks: []ResponseCallback{responseCallback}},
expected: requestModifiers{responseCallbacks: []ResponseCallback{responseCallback}},
},
}

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
mergeRequestModifiers(&tc.lhs, &tc.rhs)

// testify doesn't currently work with func types; stringify instead
require.Equal(t, fmt.Sprintf("%v", tc.expected.requestCallbacks), fmt.Sprintf("%v", tc.lhs.requestCallbacks))
require.Equal(t, fmt.Sprintf("%v", tc.expected.responseCallbacks), fmt.Sprintf("%v", tc.lhs.responseCallbacks))

require.Equal(t, tc.expected.headers.customHeaders, tc.lhs.headers.customHeaders)
})
}
}
11 changes: 6 additions & 5 deletions request_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ func WithResponseWrapping(ttl time.Duration) RequestOption {
}

// WithCustomHeaders sets custom headers for the next request; these headers
// take precedence over the client-level custom headers. The internal prefix
// 'X-Vault-' is not permitted for the header keys.
// will be appended to any custom headers set at the client level. The internal
// prefix 'X-Vault-' is not permitted for the header keys.
func WithCustomHeaders(headers http.Header) RequestOption {
return func(m *requestModifiers) error {
if err := validateCustomHeaders(headers); err != nil {
Expand All @@ -87,7 +87,8 @@ func WithCustomHeaders(headers http.Header) RequestOption {
}

// WithRequestCallbacks sets callbacks which will be invoked before the next
// request; these take precedence over the client-level request callbacks.
// request; these callbacks will be appended to the list of the callbacks set
// at the client-level.
func WithRequestCallbacks(callbacks ...RequestCallback) RequestOption {
return func(m *requestModifiers) error {
m.requestCallbacks = callbacks
Expand All @@ -96,8 +97,8 @@ func WithRequestCallbacks(callbacks ...RequestCallback) RequestOption {
}

// WithResponseCallbacks sets callbacks which will be invoked after a
// successful response within the next request; these take precedence over the
// client-level response callbacks.
// successful response within the next request; these callbacks will be
// appended to the list of the callbacks set at the client-level.
func WithResponseCallbacks(callbacks ...ResponseCallback) RequestOption {
return func(m *requestModifiers) error {
m.responseCallbacks = callbacks
Expand Down

0 comments on commit 27cd603

Please sign in to comment.