Skip to content

Commit

Permalink
undo append operations
Browse files Browse the repository at this point in the history
  • Loading branch information
averche committed Jul 30, 2023
1 parent dbd0e73 commit ad8d007
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 67 deletions.
33 changes: 13 additions & 20 deletions request_modifiers.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,7 @@ func (m *requestModifiers) additionalQueryParametersOrDefault() url.Values {
return m.additionalQueryParameters
}

// mergeRequestModifiers merges the values in *rhs into *lhs. The merging is
// done according the following rules:
//
// - for scalaras : the rhs values, if present, will overwrite the lhs values
// - for slices : the rhs values will be appended to the lhs values
// - for maps : the rhs values will be copied into the lhs using maps.Copy
// mergeRequestModifiers merges the values from *rhs into *lhs.
func mergeRequestModifiers(lhs, rhs *requestModifiers) {
if rhs.headers.userAgent != "" {
lhs.headers.userAgent = rhs.headers.userAgent
Expand All @@ -265,10 +260,9 @@ func mergeRequestModifiers(lhs, rhs *requestModifiers) {
lhs.headers.namespace = rhs.headers.namespace
}

lhs.headers.mfaCredentials = append(
lhs.headers.mfaCredentials,
rhs.headers.mfaCredentials...,
)
if len(rhs.headers.mfaCredentials) != 0 {
lhs.headers.mfaCredentials = rhs.headers.mfaCredentials
}

if rhs.headers.responseWrappingTTL != 0 {
lhs.headers.responseWrappingTTL = rhs.headers.responseWrappingTTL
Expand All @@ -278,18 +272,17 @@ func mergeRequestModifiers(lhs, rhs *requestModifiers) {
lhs.headers.replicationForwardingMode = rhs.headers.replicationForwardingMode
}

// in case of key collisions, the rhs keys will take precedence
maps.Copy(lhs.headers.customHeaders, rhs.headers.customHeaders)
if len(rhs.headers.customHeaders) != 0 {
lhs.headers.customHeaders = rhs.headers.customHeaders
}

lhs.requestCallbacks = append(
lhs.requestCallbacks,
rhs.requestCallbacks...,
)
if len(rhs.requestCallbacks) != 0 {
lhs.requestCallbacks = rhs.requestCallbacks
}

lhs.responseCallbacks = append(
lhs.responseCallbacks,
rhs.responseCallbacks...,
)
if len(rhs.responseCallbacks) != 0 {
lhs.responseCallbacks = rhs.responseCallbacks
}
}

func validateToken(token string) error {
Expand Down
47 changes: 0 additions & 47 deletions request_modifiers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,50 +130,3 @@ func Test_mergeRequestModifiers_overwrite(t *testing.T) {
})
}
}

func Test_mergeRequestModifiers_append(t *testing.T) {
requestCallback1 := func(r *http.Request) {
t.Logf("callback 1: %v", *r)
}
requestCallback2 := func(r *http.Request) {
t.Logf("callback 2: %v", *r)
}
responseCallback := func(req *http.Request, resp *http.Response) {
t.Logf("callback 3: %v, %v", *req, *resp)
}

cases := map[string]struct {
name string
lhs requestModifiers
rhs requestModifiers
expected requestModifiers
}{
"custom-headers": {
lhs: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"a": []string{"hi"}, "b": nil}}},
rhs: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"b": nil, "c": nil}}},
expected: requestModifiers{headers: requestHeaders{customHeaders: http.Header{"a": []string{"hi"}, "b": nil, "c": nil}}},
},
"request-callbacks": {
lhs: requestModifiers{requestCallbacks: []RequestCallback{requestCallback1}},
rhs: requestModifiers{requestCallbacks: []RequestCallback{requestCallback2}},
expected: requestModifiers{requestCallbacks: []RequestCallback{requestCallback1, requestCallback2}},
},
"response-callbacks": {
lhs: requestModifiers{responseCallbacks: []ResponseCallback{}},
rhs: requestModifiers{responseCallbacks: []ResponseCallback{responseCallback}},
expected: requestModifiers{responseCallbacks: []ResponseCallback{responseCallback}},
},
}

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
mergeRequestModifiers(&tc.lhs, &tc.rhs)

// testify doesn't currently work with func types; stringify instead
require.Equal(t, fmt.Sprintf("%v", tc.expected.requestCallbacks), fmt.Sprintf("%v", tc.lhs.requestCallbacks))
require.Equal(t, fmt.Sprintf("%v", tc.expected.responseCallbacks), fmt.Sprintf("%v", tc.lhs.responseCallbacks))

require.Equal(t, tc.expected.headers.customHeaders, tc.lhs.headers.customHeaders)
})
}
}

0 comments on commit ad8d007

Please sign in to comment.