Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change how request modifiers behave for slices and maps. #225

Merged
merged 18 commits into from
Aug 2, 2023
12 changes: 8 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,17 @@ for engine := range resp.Data {
### Modifying Requests

You can modify the requests in one of two ways, either at the client level or by
decorating individual requests:
decorating individual requests. For scalar values (such as the strings in the
example below), request-level decorators take precedence over the client-level
settings. For maps and slices (e.g. `vault.WithCustomHeaders`), the
request-level modifiers are appended to the client-level ones.

```go
// all subsequent requests will use the given token & namespace
_ = client.SetToken("my-token")
_ = client.SetNamespace("my-namespace")

// per-request decorators take precedence over the client-level settings
// for scalar settings, request-specific decorators take precedence
resp, err := client.Secrets.KvV2Read(
ctx,
"my-secret",
Expand Down Expand Up @@ -382,9 +385,10 @@ client.SetResponseCallbacks(func(req *http.Request, resp *http.Response) {
})
```

Alternatively, `vault.WithRequestCallbacks(..)` /
Additionally, `vault.WithRequestCallbacks(..)` /
`vault.WithResponseCallbacks(..)` can be used to inject callbacks for individual
requests:
requests. These request-level callbacks will be appended to the list of the
respective client-level callbacks for the given request.

```go
resp, err := client.Secrets.KvV2Read(
Expand Down
26 changes: 10 additions & 16 deletions client_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func sendStructuredRequestParseResponse[ResponseT any](
path string,
body any,
parameters url.Values,
requestModifiersPerRequest requestModifiers,
requestModifiers requestModifiers,
) (*Response[ResponseT], error) {
var buf bytes.Buffer

Expand All @@ -155,7 +155,7 @@ func sendStructuredRequestParseResponse[ResponseT any](
path,
&buf,
parameters,
requestModifiersPerRequest,
requestModifiers,
)
}

Expand All @@ -167,7 +167,7 @@ func sendRequestParseResponse[ResponseT any](
path string,
body io.Reader,
parameters url.Values,
requestModifiersPerRequest requestModifiers,
requestModifiers requestModifiers,
) (*Response[ResponseT], error) {
// apply the client-level request timeout, if set
if client.configuration.RequestTimeout > 0 {
Expand All @@ -177,13 +177,10 @@ func sendRequestParseResponse[ResponseT any](
}

// clone the client-level request modifiers to prevent race conditions
requestModifiersClient := client.cloneClientRequestModifiers()
modifiers := client.cloneClientRequestModifiers()

// merge the client-level & request-level modifiers, preferring the later
modifiers := mergeRequestModifiers(
requestModifiersClient,
requestModifiersPerRequest,
)
// merge in the request-level request modifiers
mergeRequestModifiers(&modifiers, &requestModifiers)

req, err := client.newRequest(ctx, method, path, body, parameters, modifiers.headers)
if err != nil {
Expand Down Expand Up @@ -215,16 +212,13 @@ func sendRequestReturnRawResponse(
path string,
body io.Reader,
parameters url.Values,
requestModifiersPerRequest requestModifiers,
requestModifiers requestModifiers,
) (*http.Response, error) {
// clone the client-level request modifiers to prevent race conditions
requestModifiersClient := client.cloneClientRequestModifiers()
modifiers := client.cloneClientRequestModifiers()

// merge the client-level & request-level modifiers, preferring the later
modifiers := mergeRequestModifiers(
requestModifiersClient,
requestModifiersPerRequest,
)
// merge in the request-level request modifiers
mergeRequestModifiers(&modifiers, &requestModifiers)

req, err := client.newRequest(ctx, method, path, body, parameters, modifiers.headers)
if err != nil {
Expand Down
60 changes: 32 additions & 28 deletions request_modifiers.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"unicode"

"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
)

type (
Expand Down Expand Up @@ -245,47 +246,50 @@ func (m *requestModifiers) additionalQueryParametersOrDefault() url.Values {
return m.additionalQueryParameters
}

// mergeRequestModifiers merges the two objects, preferring the per-request modifiers
func mergeRequestModifiers(perClient, perRequest requestModifiers) requestModifiers {
merged := perClient

if perRequest.headers.userAgent != "" {
merged.headers.userAgent = perRequest.headers.userAgent
// mergeRequestModifiers merges the values in *rhs into *lhs. The merging is
// done according the following rules:
//
// - for scalaras : the rhs values, if present, will overwrite the lhs values
averche marked this conversation as resolved.
Show resolved Hide resolved
// - for slices : the rhs values will be appended to the lhs values
// - for maps : the rhs values will be copied into the lhs using maps.Copy
func mergeRequestModifiers(lhs, rhs *requestModifiers) {
if rhs.headers.userAgent != "" {
lhs.headers.userAgent = rhs.headers.userAgent
}

if perRequest.headers.token != "" {
merged.headers.token = perRequest.headers.token
if rhs.headers.token != "" {
lhs.headers.token = rhs.headers.token
}

if perRequest.headers.namespace != "" {
merged.headers.namespace = perRequest.headers.namespace
if rhs.headers.namespace != "" {
lhs.headers.namespace = rhs.headers.namespace
}

if len(perRequest.headers.mfaCredentials) != 0 {
merged.headers.mfaCredentials = perRequest.headers.mfaCredentials
}
lhs.headers.mfaCredentials = append(
lhs.headers.mfaCredentials,
rhs.headers.mfaCredentials...,
)

if perRequest.headers.responseWrappingTTL != 0 {
merged.headers.responseWrappingTTL = perRequest.headers.responseWrappingTTL
if rhs.headers.responseWrappingTTL != 0 {
lhs.headers.responseWrappingTTL = rhs.headers.responseWrappingTTL
}

if perRequest.headers.replicationForwardingMode != ReplicationForwardNone {
merged.headers.replicationForwardingMode = perRequest.headers.replicationForwardingMode
if rhs.headers.replicationForwardingMode != ReplicationForwardNone {
lhs.headers.replicationForwardingMode = rhs.headers.replicationForwardingMode
}

if len(perRequest.headers.customHeaders) != 0 {
merged.headers.customHeaders = perRequest.headers.customHeaders
}
// in case of key collisions, the rhs keys will take precedence
maps.Copy(lhs.headers.customHeaders, rhs.headers.customHeaders)

if len(perRequest.requestCallbacks) != 0 {
merged.requestCallbacks = perRequest.requestCallbacks
}

if len(perRequest.responseCallbacks) != 0 {
merged.responseCallbacks = perRequest.responseCallbacks
}
lhs.requestCallbacks = append(
lhs.requestCallbacks,
rhs.requestCallbacks...,
)

return merged
lhs.responseCallbacks = append(
lhs.responseCallbacks,
rhs.responseCallbacks...,
)
}

func validateToken(token string) error {
Expand Down
93 changes: 71 additions & 22 deletions request_modifiers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package vault

import (
"fmt"
"net/http"
"testing"

Expand Down Expand Up @@ -88,43 +89,91 @@ func Test_validateCustomHeaders(t *testing.T) {
}
}

func Test_mergeRequestModifiers(t *testing.T) {
func Test_mergeRequestModifiers_overwrite(t *testing.T) {
cases := map[string]struct {
name string
a requestModifiers
b requestModifiers
lhs requestModifiers
rhs requestModifiers
expected requestModifiers
}{
"empty": {
a: requestModifiers{},
b: requestModifiers{},
lhs: requestModifiers{},
rhs: requestModifiers{},
expected: requestModifiers{},
},
"token-a": {
a: requestModifiers{headers: requestHeaders{token: "token-a"}},
b: requestModifiers{},
expected: requestModifiers{headers: requestHeaders{token: "token-a"}},
"token-in-lhs": {
lhs: requestModifiers{headers: requestHeaders{token: "token-lhs"}},
rhs: requestModifiers{},
expected: requestModifiers{headers: requestHeaders{token: "token-lhs"}},
},
"token-b": {
a: requestModifiers{},
b: requestModifiers{headers: requestHeaders{token: "token-b"}},
expected: requestModifiers{headers: requestHeaders{token: "token-b"}},
"token-in-rhs": {
lhs: requestModifiers{},
rhs: requestModifiers{headers: requestHeaders{token: "token-rhs"}},
expected: requestModifiers{headers: requestHeaders{token: "token-rhs"}},
},
"token-a-b": {
a: requestModifiers{headers: requestHeaders{token: "token-a"}},
b: requestModifiers{headers: requestHeaders{token: "token-b"}},
expected: requestModifiers{headers: requestHeaders{token: "token-b"}},
"token-in-both": {
lhs: requestModifiers{headers: requestHeaders{token: "token-lhs"}},
rhs: requestModifiers{headers: requestHeaders{token: "token-rhs"}},
expected: requestModifiers{headers: requestHeaders{token: "token-rhs"}},
},
"token-namespace": {
a: requestModifiers{headers: requestHeaders{token: "token-a"}},
b: requestModifiers{headers: requestHeaders{namespace: "namespace-b"}},
expected: requestModifiers{headers: requestHeaders{token: "token-a", namespace: "namespace-b"}},
"token-lhs-and-namespace-rhs": {
lhs: requestModifiers{headers: requestHeaders{token: "token-lhs"}},
rhs: requestModifiers{headers: requestHeaders{namespace: "namespace-rhs"}},
expected: requestModifiers{headers: requestHeaders{token: "token-lhs", namespace: "namespace-rhs"}},
},
}

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
require.Equal(t, tc.expected, mergeRequestModifiers(tc.a, tc.b))
mergeRequestModifiers(&tc.lhs, &tc.rhs)
require.Equal(t, tc.expected, tc.lhs)
})
}
}

func Test_mergeRequestModifiers_append(t *testing.T) {
requestCallback1 := func(r *http.Request) {
t.Logf("callback 1: %v", *r)
}
requestCallback2 := func(r *http.Request) {
t.Logf("callback 2: %v", *r)
}
responseCallback := func(req *http.Request, resp *http.Response) {
t.Logf("callback 3: %v, %v", *req, *resp)
}

cases := map[string]struct {
name string
lhs requestModifiers
rhs requestModifiers
expected requestModifiers
}{
"custom-headers": {
lhs: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"a": []string{"hi"}, "b": nil}}},
rhs: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"b": nil, "c": nil}}},
expected: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"a": []string{"hi"}, "b": nil, "c": nil}}},
},
"request-callbacks": {
lhs: requestModifiers{requestCallbacks: []RequestCallback{requestCallback1}},
rhs: requestModifiers{requestCallbacks: []RequestCallback{requestCallback2}},
expected: requestModifiers{requestCallbacks: []RequestCallback{requestCallback1, requestCallback2}},
},
"response-callbacks": {
lhs: requestModifiers{responseCallbacks: []ResponseCallback{}},
rhs: requestModifiers{responseCallbacks: []ResponseCallback{responseCallback}},
expected: requestModifiers{responseCallbacks: []ResponseCallback{responseCallback}},
},
}

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
mergeRequestModifiers(&tc.lhs, &tc.rhs)

// testify doesn't currently work with func types; stringify instead
require.Equal(t, fmt.Sprintf("%v", tc.expected.requestCallbacks), fmt.Sprintf("%v", tc.lhs.requestCallbacks))
require.Equal(t, fmt.Sprintf("%v", tc.expected.responseCallbacks), fmt.Sprintf("%v", tc.lhs.responseCallbacks))

require.Equal(t, tc.expected.headers.customHeaders, tc.lhs.headers.customHeaders)
})
}
}
11 changes: 6 additions & 5 deletions request_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ func WithResponseWrapping(ttl time.Duration) RequestOption {
}

// WithCustomHeaders sets custom headers for the next request; these headers
// take precedence over the client-level custom headers. The internal prefix
// 'X-Vault-' is not permitted for the header keys.
// will be appended to any custom headers set at the client level. The internal
// prefix 'X-Vault-' is not permitted for the header keys.
func WithCustomHeaders(headers http.Header) RequestOption {
return func(m *requestModifiers) error {
if err := validateCustomHeaders(headers); err != nil {
Expand All @@ -87,7 +87,8 @@ func WithCustomHeaders(headers http.Header) RequestOption {
}

// WithRequestCallbacks sets callbacks which will be invoked before the next
// request; these take precedence over the client-level request callbacks.
// request; these callbacks will be appended to the list of the callbacks set
// at the client-level.
func WithRequestCallbacks(callbacks ...RequestCallback) RequestOption {
return func(m *requestModifiers) error {
m.requestCallbacks = callbacks
Expand All @@ -96,8 +97,8 @@ func WithRequestCallbacks(callbacks ...RequestCallback) RequestOption {
}

// WithResponseCallbacks sets callbacks which will be invoked after a
// successful response within the next request; these take precedence over the
// client-level response callbacks.
// successful response within the next request; these callbacks will be
// appended to the list of the callbacks set at the client-level.
func WithResponseCallbacks(callbacks ...ResponseCallback) RequestOption {
return func(m *requestModifiers) error {
m.responseCallbacks = callbacks
Expand Down