Skip to content

Commit

Permalink
Rewrite mergeRequestModifiers to operate on pointers rather than copi…
Browse files Browse the repository at this point in the history
…es. (#224)
  • Loading branch information
averche committed Jul 31, 2023
1 parent bbcc6be commit e6c4201
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 62 deletions.
26 changes: 10 additions & 16 deletions client_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func sendStructuredRequestParseResponse[ResponseT any](
path string,
body any,
parameters url.Values,
requestModifiersPerRequest requestModifiers,
requestModifiers requestModifiers,
) (*Response[ResponseT], error) {
var buf bytes.Buffer

Expand All @@ -152,7 +152,7 @@ func sendStructuredRequestParseResponse[ResponseT any](
path,
&buf,
parameters,
requestModifiersPerRequest,
requestModifiers,
)
}

Expand All @@ -164,7 +164,7 @@ func sendRequestParseResponse[ResponseT any](
path string,
body io.Reader,
parameters url.Values,
requestModifiersPerRequest requestModifiers,
requestModifiers requestModifiers,
) (*Response[ResponseT], error) {
// apply the client-level request timeout, if set
if client.configuration.RequestTimeout > 0 {
Expand All @@ -174,13 +174,10 @@ func sendRequestParseResponse[ResponseT any](
}

// clone the client-level request modifiers to prevent race conditions
requestModifiersClient := client.cloneClientRequestModifiers()
modifiers := client.cloneClientRequestModifiers()

// merge the client-level & request-level modifiers, preferring the later
modifiers := mergeRequestModifiers(
requestModifiersClient,
requestModifiersPerRequest,
)
// merge in the request-level request modifiers
mergeRequestModifiers(&modifiers, &requestModifiers)

req, err := client.newRequest(ctx, method, path, body, parameters, modifiers.headers)
if err != nil {
Expand Down Expand Up @@ -212,16 +209,13 @@ func sendRequestReturnRawResponse(
path string,
body io.Reader,
parameters url.Values,
requestModifiersPerRequest requestModifiers,
requestModifiers requestModifiers,
) (*http.Response, error) {
// clone the client-level request modifiers to prevent race conditions
requestModifiersClient := client.cloneClientRequestModifiers()
modifiers := client.cloneClientRequestModifiers()

// merge the client-level & request-level modifiers, preferring the later
modifiers := mergeRequestModifiers(
requestModifiersClient,
requestModifiersPerRequest,
)
// merge in the request-level request modifiers
mergeRequestModifiers(&modifiers, &requestModifiers)

req, err := client.newRequest(ctx, method, path, body, parameters, modifiers.headers)
if err != nil {
Expand Down
44 changes: 20 additions & 24 deletions request_modifiers.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,47 +245,43 @@ func (m *requestModifiers) additionalQueryParametersOrDefault() url.Values {
return m.additionalQueryParameters
}

// mergeRequestModifiers merges the two objects, preferring the per-request modifiers
func mergeRequestModifiers(perClient, perRequest requestModifiers) requestModifiers {
merged := perClient

if perRequest.headers.userAgent != "" {
merged.headers.userAgent = perRequest.headers.userAgent
// mergeRequestModifiers merges the values from *rhs into *lhs.
func mergeRequestModifiers(lhs, rhs *requestModifiers) {
if rhs.headers.userAgent != "" {
lhs.headers.userAgent = rhs.headers.userAgent
}

if perRequest.headers.token != "" {
merged.headers.token = perRequest.headers.token
if rhs.headers.token != "" {
lhs.headers.token = rhs.headers.token
}

if perRequest.headers.namespace != "" {
merged.headers.namespace = perRequest.headers.namespace
if rhs.headers.namespace != "" {
lhs.headers.namespace = rhs.headers.namespace
}

if len(perRequest.headers.mfaCredentials) != 0 {
merged.headers.mfaCredentials = perRequest.headers.mfaCredentials
if len(rhs.headers.mfaCredentials) != 0 {
lhs.headers.mfaCredentials = rhs.headers.mfaCredentials
}

if perRequest.headers.responseWrappingTTL != 0 {
merged.headers.responseWrappingTTL = perRequest.headers.responseWrappingTTL
if rhs.headers.responseWrappingTTL != 0 {
lhs.headers.responseWrappingTTL = rhs.headers.responseWrappingTTL
}

if perRequest.headers.replicationForwardingMode != ReplicationForwardNone {
merged.headers.replicationForwardingMode = perRequest.headers.replicationForwardingMode
if rhs.headers.replicationForwardingMode != ReplicationForwardNone {
lhs.headers.replicationForwardingMode = rhs.headers.replicationForwardingMode
}

if len(perRequest.headers.customHeaders) != 0 {
merged.headers.customHeaders = perRequest.headers.customHeaders
if len(rhs.headers.customHeaders) != 0 {
lhs.headers.customHeaders = rhs.headers.customHeaders
}

if len(perRequest.requestCallbacks) != 0 {
merged.requestCallbacks = perRequest.requestCallbacks
if len(rhs.requestCallbacks) != 0 {
lhs.requestCallbacks = rhs.requestCallbacks
}

if len(perRequest.responseCallbacks) != 0 {
merged.responseCallbacks = perRequest.responseCallbacks
if len(rhs.responseCallbacks) != 0 {
lhs.responseCallbacks = rhs.responseCallbacks
}

return merged
}

func validateToken(token string) error {
Expand Down
45 changes: 23 additions & 22 deletions request_modifiers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,43 +88,44 @@ func Test_validateCustomHeaders(t *testing.T) {
}
}

func Test_mergeRequestModifiers(t *testing.T) {
func Test_mergeRequestModifiers_overwrite(t *testing.T) {
cases := map[string]struct {
name string
a requestModifiers
b requestModifiers
lhs requestModifiers
rhs requestModifiers
expected requestModifiers
}{
"empty": {
a: requestModifiers{},
b: requestModifiers{},
lhs: requestModifiers{},
rhs: requestModifiers{},
expected: requestModifiers{},
},
"token-a": {
a: requestModifiers{headers: requestHeaders{token: "token-a"}},
b: requestModifiers{},
expected: requestModifiers{headers: requestHeaders{token: "token-a"}},
"token-in-lhs": {
lhs: requestModifiers{headers: requestHeaders{token: "token-lhs"}},
rhs: requestModifiers{},
expected: requestModifiers{headers: requestHeaders{token: "token-lhs"}},
},
"token-b": {
a: requestModifiers{},
b: requestModifiers{headers: requestHeaders{token: "token-b"}},
expected: requestModifiers{headers: requestHeaders{token: "token-b"}},
"token-in-rhs": {
lhs: requestModifiers{},
rhs: requestModifiers{headers: requestHeaders{token: "token-rhs"}},
expected: requestModifiers{headers: requestHeaders{token: "token-rhs"}},
},
"token-a-b": {
a: requestModifiers{headers: requestHeaders{token: "token-a"}},
b: requestModifiers{headers: requestHeaders{token: "token-b"}},
expected: requestModifiers{headers: requestHeaders{token: "token-b"}},
"token-in-both": {
lhs: requestModifiers{headers: requestHeaders{token: "token-lhs"}},
rhs: requestModifiers{headers: requestHeaders{token: "token-rhs"}},
expected: requestModifiers{headers: requestHeaders{token: "token-rhs"}},
},
"token-namespace": {
a: requestModifiers{headers: requestHeaders{token: "token-a"}},
b: requestModifiers{headers: requestHeaders{namespace: "namespace-b"}},
expected: requestModifiers{headers: requestHeaders{token: "token-a", namespace: "namespace-b"}},
"token-lhs-and-namespace-rhs": {
lhs: requestModifiers{headers: requestHeaders{token: "token-lhs"}},
rhs: requestModifiers{headers: requestHeaders{namespace: "namespace-rhs"}},
expected: requestModifiers{headers: requestHeaders{token: "token-lhs", namespace: "namespace-rhs"}},
},
}

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
require.Equal(t, tc.expected, mergeRequestModifiers(tc.a, tc.b))
mergeRequestModifiers(&tc.lhs, &tc.rhs)
require.Equal(t, tc.expected, tc.lhs)
})
}
}

0 comments on commit e6c4201

Please sign in to comment.