diff --git a/httpstub.go b/httpstub.go index bd54d92..4109dc1 100644 --- a/httpstub.go +++ b/httpstub.go @@ -86,6 +86,14 @@ func NewServer(t *testing.T) *Router { return rt } +// NewTLSServer returns a new router including TLS *httptest.Server. +func NewTLSServer(t *testing.T) *Router { + t.Helper() + rt := &Router{t: t} + _ = rt.TLSServer() + return rt +} + // Client returns *http.Client which requests *httptest.Server. func (rt *Router) Client() *http.Client { if rt.server == nil { @@ -101,7 +109,19 @@ func (rt *Router) Server() *httptest.Server { rt.server = httptest.NewServer(rt) } client := rt.server.Client() - client.Transport = newTransport(rt.server.URL) + tp := client.Transport.(*http.Transport) + client.Transport = newTransport(rt.server.URL, tp) + return rt.server +} + +// TLSServer returns TLS *httptest.Server with *Router set. +func (rt *Router) TLSServer() *httptest.Server { + if rt.server == nil { + rt.server = httptest.NewTLSServer(rt) + } + client := rt.server.Client() + tp := client.Transport.(*http.Transport) + client.Transport = newTransport(rt.server.URL, tp) return rt.server } @@ -274,17 +294,19 @@ func pathMatchFunc(path string) matchFunc { type transport struct { URL *url.URL + tp *http.Transport } -func newTransport(rawURL string) http.RoundTripper { +func newTransport(rawURL string, tp *http.Transport) http.RoundTripper { u, _ := url.Parse(rawURL) return &transport{ URL: u, + tp: tp, } } func (t *transport) transport() http.RoundTripper { - return http.DefaultTransport + return t.tp } func (t *transport) CancelRequest(r *http.Request) { diff --git a/httpstub_test.go b/httpstub_test.go index d62129e..1ed82d9 100644 --- a/httpstub_test.go +++ b/httpstub_test.go @@ -291,3 +291,46 @@ func TestRequests(t *testing.T) { t.Errorf("got %v\nwant %v", len(m.Requests()), 1) } } + +func TestTLSServer(t *testing.T) { + r := NewRouter(t) + r.Method(http.MethodGet).Path("/api/v1/users/1").Header("Content-Type", "application/json").ResponseString(http.StatusOK, `{"name":"alice"}`) + ts := r.TLSServer() + t.Cleanup(func() { + ts.Close() + }) + tc := ts.Client() + res, err := tc.Get("https://example.com/api/v1/users/1") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + res.Body.Close() + }) + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + + { + got := res.StatusCode + want := http.StatusOK + if got != want { + t.Errorf("got %v\nwant %v", got, want) + } + } + { + got := res.Header.Get("Content-Type") + want := "application/json" + if got != want { + t.Errorf("got %v\nwant %v", got, want) + } + } + { + got := string(body) + want := `{"name":"alice"}` + if got != want { + t.Errorf("got %v\nwant %v", got, want) + } + } +}