diff --git a/client.go b/client.go index b9857d7..21eb1e2 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package fetch import ( "context" + "errors" "fmt" "io" "net/http" @@ -30,6 +31,10 @@ func New(opts ...ClientOpts) Client { // WithProtocol returns function for setting http.Client. func WithHTTPClient(cli *http.Client) ClientOpts { return func(c *client) { + if cli == nil { + c.client = &http.Client{} + return + } c.client = cli } } @@ -44,11 +49,14 @@ func (c *client) Get(u *url.URL, opts ...RequestOpts) (Response, error) { func (c *client) GetWithContext(ctx context.Context, u *url.URL, opts ...RequestOpts) (Response, error) { req, err := request(ctx, http.MethodGet, u, nil, opts...) if err != nil { - return nil, errs.Wrap(ErrInvalidRequest, errs.WithCause(err), errs.WithContext("url", u.String())) + if errors.Is(err, ErrInvalidURL) { + return nil, errs.Wrap(ErrInvalidURL, errs.WithCause(err), errs.WithContext("url", urlText(u))) + } + return nil, errs.Wrap(ErrInvalidRequest, errs.WithCause(err), errs.WithContext("url", urlText(u))) } resp, err := c.fetch(req) if err != nil { - return nil, errs.Wrap(ErrInvalidRequest, errs.WithCause(err), errs.WithContext("url", u.String())) + return nil, errs.Wrap(ErrInvalidRequest, errs.WithCause(err), errs.WithContext("url", urlText(u))) } return resp, nil } @@ -63,11 +71,14 @@ func (c *client) Post(u *url.URL, payload io.Reader, opts ...RequestOpts) (Respo func (c *client) PostWithContext(ctx context.Context, u *url.URL, payload io.Reader, opts ...RequestOpts) (Response, error) { req, err := request(ctx, http.MethodPost, u, payload, opts...) if err != nil { - return nil, errs.Wrap(ErrInvalidRequest, errs.WithCause(err), errs.WithContext("url", u.String())) + if errors.Is(err, ErrInvalidURL) { + return nil, errs.Wrap(ErrInvalidURL, errs.WithCause(err), errs.WithContext("url", urlText(u))) + } + return nil, errs.Wrap(ErrInvalidRequest, errs.WithCause(err), errs.WithContext("url", urlText(u))) } resp, err := c.fetch(req) if err != nil { - return nil, errs.Wrap(ErrInvalidRequest, errs.WithCause(err), errs.WithContext("url", u.String())) + return nil, errs.Wrap(ErrInvalidRequest, errs.WithCause(err), errs.WithContext("url", urlText(u))) } return resp, nil } @@ -100,6 +111,9 @@ func WithRequestHeaderSet(name, value string) RequestOpts { } func request(ctx context.Context, method string, u *url.URL, payload io.Reader, opts ...RequestOpts) (*http.Request, error) { + if u == nil { + return nil, errs.Wrap(ErrInvalidURL) + } req, err := http.NewRequestWithContext(ctx, method, u.String(), payload) if err != nil { return nil, errs.Wrap(err) @@ -114,6 +128,9 @@ func (c *client) fetch(request *http.Request) (Response, error) { if c == nil { c = New().(*client) } + if c.client == nil { + c.client = &http.Client{} + } r, err := c.client.Do(request) if err != nil { return nil, errs.Wrap(err) @@ -121,7 +138,7 @@ func (c *client) fetch(request *http.Request) (Response, error) { resp := &response{r} if resp.StatusCode == 0 || resp.StatusCode >= http.StatusBadRequest { err := ErrHTTPStatus - if cerr := resp.Close(); cerr != nil && !errs.Is(err, os.ErrClosed) { + if cerr := resp.Close(); cerr != nil && !errs.Is(cerr, os.ErrClosed) { err = errs.Join(cerr, err) } return nil, errs.Wrap(fmt.Errorf("%w: status %d", err, resp.StatusCode), errs.WithContext("status", resp.StatusCode)) @@ -129,6 +146,13 @@ func (c *client) fetch(request *http.Request) (Response, error) { return resp, nil } +func urlText(u *url.URL) string { + if u == nil { + return "" + } + return u.String() +} + /* Copyright 2021-2025 Spiegel * * Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/dependency.png b/dependency.png index a7f564c..2f44745 100644 Binary files a/dependency.png and b/dependency.png differ diff --git a/fetch_test.go b/fetch_test.go index 2198128..6bacdb9 100644 --- a/fetch_test.go +++ b/fetch_test.go @@ -4,7 +4,9 @@ import ( "context" "errors" "fmt" + "io" "net/http" + "net/http/httptest" "testing" "github.com/goark/fetch" @@ -44,6 +46,43 @@ func TestGet(t *testing.T) { } } +func TestGetWithNilURL(t *testing.T) { + resp, err := fetch.New().GetWithContext(context.Background(), nil) + if err == nil { + t.Fatal("error is nil, want ErrInvalidURL") + } + if !errors.Is(err, fetch.ErrInvalidURL) { + t.Fatalf("GetWithContext(nil) is %v, want ErrInvalidURL", err) + } + if resp != nil { + t.Fatal("response is not nil, want nil") + } +} + +func TestWithHTTPClientNilFallback(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "ok") + })) + defer ts.Close() + + u, err := fetch.URL(ts.URL) + if err != nil { + t.Fatalf("fetch.URL() error = %v", err) + } + + resp, err := fetch.New(fetch.WithHTTPClient(nil)).GetWithContext(context.Background(), u) + if err != nil { + t.Fatalf("GetWithContext() error = %v", err) + } + if resp == nil { + t.Fatal("response is nil") + } + if cerr := resp.Close(); cerr != nil { + t.Fatalf("resp.Close() error = %v", cerr) + } +} + /* Copyright 2023-2025 Spiegel * * Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/go.sum b/go.sum index 4e63d82..19eb1da 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,2 @@ -github.com/goark/errs v1.3.3 h1:vqzm/1aqvh4Ha8JDqIqIU5ZoBtjaRezFcMC55B5ljAM= -github.com/goark/errs v1.3.3/go.mod h1:4xM7rorwYQlqh9kUhfKpC5P7VAJW2KfvuQpYnTaU0ek= github.com/goark/errs v1.3.4 h1:/+/xwF3UwXGxGGLurzBTaMMoryTBeaPfheJ1aW9cglA= github.com/goark/errs v1.3.4/go.mod h1:4xM7rorwYQlqh9kUhfKpC5P7VAJW2KfvuQpYnTaU0ek= diff --git a/response.go b/response.go index d3a52a2..03a6ddf 100644 --- a/response.go +++ b/response.go @@ -64,7 +64,7 @@ func (resp *response) DumpBodyAndClose() (b []byte, err error) { if resp == nil { return } - if cerr := resp.Body().Close(); cerr != nil && !errs.Is(err, os.ErrClosed) { + if cerr := resp.Body().Close(); cerr != nil && !errs.Is(cerr, os.ErrClosed) { err = errs.Join(cerr, err) } }()