Skip to content

Commit

Permalink
CLOUDP-110989: Adds OnResponseProcessed (#298)
Browse files Browse the repository at this point in the history
* CLOUDP-110989: Adds OnAfterRequestCompleted

* Fix interface

* Rename OnAfterRequestCompleted
  • Loading branch information
fmenezes committed May 31, 2022
1 parent 6256d9f commit 551edbf
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 28 deletions.
8 changes: 4 additions & 4 deletions auth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,6 @@ func (c *Config) Do(ctx context.Context, req *http.Request, v interface{}) (*atl
defer resp.Body.Close()

r := &atlas.Response{Response: resp}
if err2 := atlas.CheckResponse(resp); err2 != nil {
return r, err2
}

body := resp.Body

if c.withRaw {
Expand All @@ -159,6 +155,10 @@ func (c *Config) Do(ctx context.Context, req *http.Request, v interface{}) (*atl
body = io.NopCloser(raw)
}

if err2 := r.CheckResponse(body); err2 != nil {
return r, err2
}

if v != nil {
if w, ok := v.(io.Writer); ok {
_, err = io.Copy(w, body)
Expand Down
38 changes: 29 additions & 9 deletions mongodbatlas/mongodbatlas.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ type Completer interface {
OnRequestCompleted(RequestCompletionCallback)
}

// ResponseProcessor interface for clients with callback.
type ResponseProcessor interface {
OnResponseProcessed(ResponseProcessedCallback)
}

// RequestDoer minimum interface for any service of the client.
type RequestDoer interface {
Doer
Expand Down Expand Up @@ -147,12 +152,16 @@ type Client struct {
CloudProviderSnapshotExportJobs CloudProviderSnapshotExportJobsService
FederatedSettings FederatedSettingsService

onRequestCompleted RequestCompletionCallback
onRequestCompleted RequestCompletionCallback
onResponseProcessed ResponseProcessedCallback
}

// RequestCompletionCallback defines the type of the request callback function.
type RequestCompletionCallback func(*http.Request, *http.Response)

// ResponseProcessedCallback defines the type of the after request completion callback function.
type ResponseProcessedCallback func(*Response)

type service struct {
Client RequestDoer
}
Expand Down Expand Up @@ -434,6 +443,11 @@ func (c *Client) OnRequestCompleted(rc RequestCompletionCallback) {
c.onRequestCompleted = rc
}

// OnResponseProcessed sets the DO API request completion callback after it has been processed.
func (c *Client) OnResponseProcessed(rc ResponseProcessedCallback) {
c.onResponseProcessed = rc
}

// Do sends an API request and returns the API response. The API response is JSON decoded and stored in the value
// pointed to by v, or returned as an error if an API error has occurred. If v implements the io.Writer interface,
// the raw response will be written to v, without attempting to decode it.
Expand All @@ -458,10 +472,11 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res

response := &Response{Response: resp}

err = CheckResponse(resp)
if err != nil {
return response, err
}
defer func() {
if c.onResponseProcessed != nil {
c.onResponseProcessed(response)
}
}()

body := resp.Body

Expand All @@ -476,6 +491,11 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res
body = io.NopCloser(raw)
}

err = response.CheckResponse(body)
if err != nil {
return response, err
}

if v != nil {
if w, ok := v.(io.Writer); ok {
_, err = io.Copy(w, body)
Expand Down Expand Up @@ -528,13 +548,13 @@ func (r *ErrorResponse) Is(target error) bool {
// CheckResponse checks the API response for errors, and returns them if present. A response is considered an
// error if it has a status code outside the 200 range. API error responses are expected to have either no response
// body, or a JSON response body that maps to ErrorResponse. Any other response body will be silently ignored.
func CheckResponse(r *http.Response) error {
if c := r.StatusCode; c >= 200 && c <= 299 {
func (resp *Response) CheckResponse(body io.ReadCloser) error {
if c := resp.StatusCode; c >= 200 && c <= 299 {
return nil
}

errorResponse := &ErrorResponse{Response: r}
data, err := io.ReadAll(r.Body)
errorResponse := &ErrorResponse{Response: resp.Response}
data, err := io.ReadAll(body)
if err == nil && len(data) > 0 {
err := json.Unmarshal(data, errorResponse)
if err != nil {
Expand Down
72 changes: 57 additions & 15 deletions mongodbatlas/mongodbatlas_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -450,22 +450,24 @@ func TestDo_noContent(t *testing.T) {
}

func TestCheckResponse(t *testing.T) {
res := &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(
strings.NewReader(
`{"error":409, "errorCode": "GROUP_ALREADY_EXISTS", "reason":"Conflict", "detail":"A group with name \"Test\" already exists"}`,
res := &Response{
Response: &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(
strings.NewReader(
`{"error":409, "errorCode": "GROUP_ALREADY_EXISTS", "reason":"Conflict", "detail":"A group with name \"Test\" already exists"}`,
),
),
),
},
}
var target *ErrorResponse
if !errors.As(CheckResponse(res), &target) {
if !errors.As(res.CheckResponse(res.Body), &target) {
t.Fatalf("Expected error response.")
}

expected := &ErrorResponse{
Response: res,
Response: res.Response,
HTTPCode: 409,
ErrorCode: "GROUP_ALREADY_EXISTS",
Reason: "Conflict",
Expand All @@ -479,18 +481,20 @@ func TestCheckResponse(t *testing.T) {
// ensure that we properly handle API errors that do not contain a response
// body.
func TestCheckResponse_noBody(t *testing.T) {
res := &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(strings.NewReader("")),
res := &Response{
Response: &http.Response{
Request: &http.Request{},
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(strings.NewReader("")),
},
}
var target *ErrorResponse
if !errors.As(CheckResponse(res), &target) {
if !errors.As(res.CheckResponse(res.Body), &target) {
t.Errorf("Expected error response.")
}

expected := &ErrorResponse{
Response: res,
Response: res.Response,
}
if !errors.Is(target, expected) {
t.Errorf("Got = %#v, expected %#v", target, expected)
Expand Down Expand Up @@ -545,6 +549,44 @@ func TestDo_completion_callback(t *testing.T) {
}
}

func TestDo_response_processed_callback(t *testing.T) {
client, mux, teardown := setup()
defer teardown()

type foo struct {
A string
}

mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if m := http.MethodGet; m != r.Method {
t.Errorf("Request method = %v, expected %v", r.Method, m)
}
fmt.Fprint(w, `{"A":"a"}`)
})

client.withRaw = true
req, _ := client.NewRequest(ctx, http.MethodGet, ".", nil)
body := new(foo)
var completedReq *http.Request
var completedResp string
client.OnResponseProcessed(func(resp *Response) {
completedReq = req
completedResp = string(resp.Raw)
})
_, err := client.Do(context.Background(), req, body)
client.withRaw = false
if err != nil {
t.Fatalf("Do(): %v", err)
}
if !reflect.DeepEqual(req, completedReq) {
t.Errorf("Completed request = %v, expected %v", completedReq, req)
}
const expected = `{"A":"a"}`
if !strings.Contains(completedResp, expected) {
t.Errorf("expected response to contain %v, Response = %v", expected, completedResp)
}
}

func TestCustomUserAgent(t *testing.T) {
ua := fmt.Sprintf("testing/%s", Version)
c, err := New(nil, SetUserAgent(ua))
Expand Down

0 comments on commit 551edbf

Please sign in to comment.