Skip to content

Commit

Permalink
Merge pull request #6 from k1LoW/tls
Browse files Browse the repository at this point in the history
Add NewTLSServer and *Router.TLSServer
  • Loading branch information
k1LoW committed Feb 1, 2023
2 parents 50f3679 + 658a926 commit 3aa640b
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 3 deletions.
28 changes: 25 additions & 3 deletions httpstub.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ func NewServer(t *testing.T) *Router {
return rt
}

// NewTLSServer returns a new router including TLS *httptest.Server.
func NewTLSServer(t *testing.T) *Router {
t.Helper()
rt := &Router{t: t}
_ = rt.TLSServer()
return rt
}

// Client returns *http.Client which requests *httptest.Server.
func (rt *Router) Client() *http.Client {
if rt.server == nil {
Expand All @@ -101,7 +109,19 @@ func (rt *Router) Server() *httptest.Server {
rt.server = httptest.NewServer(rt)
}
client := rt.server.Client()
client.Transport = newTransport(rt.server.URL)
tp := client.Transport.(*http.Transport)
client.Transport = newTransport(rt.server.URL, tp)
return rt.server
}

// TLSServer returns TLS *httptest.Server with *Router set.
func (rt *Router) TLSServer() *httptest.Server {
if rt.server == nil {
rt.server = httptest.NewTLSServer(rt)
}
client := rt.server.Client()
tp := client.Transport.(*http.Transport)
client.Transport = newTransport(rt.server.URL, tp)
return rt.server
}

Expand Down Expand Up @@ -274,17 +294,19 @@ func pathMatchFunc(path string) matchFunc {

type transport struct {
URL *url.URL
tp *http.Transport
}

func newTransport(rawURL string) http.RoundTripper {
func newTransport(rawURL string, tp *http.Transport) http.RoundTripper {
u, _ := url.Parse(rawURL)
return &transport{
URL: u,
tp: tp,
}
}

func (t *transport) transport() http.RoundTripper {
return http.DefaultTransport
return t.tp
}

func (t *transport) CancelRequest(r *http.Request) {
Expand Down
43 changes: 43 additions & 0 deletions httpstub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,46 @@ func TestRequests(t *testing.T) {
t.Errorf("got %v\nwant %v", len(m.Requests()), 1)
}
}

func TestTLSServer(t *testing.T) {
r := NewRouter(t)
r.Method(http.MethodGet).Path("/api/v1/users/1").Header("Content-Type", "application/json").ResponseString(http.StatusOK, `{"name":"alice"}`)
ts := r.TLSServer()
t.Cleanup(func() {
ts.Close()
})
tc := ts.Client()
res, err := tc.Get("https://example.com/api/v1/users/1")
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
res.Body.Close()
})
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatal(err)
}

{
got := res.StatusCode
want := http.StatusOK
if got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
{
got := res.Header.Get("Content-Type")
want := "application/json"
if got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
{
got := string(body)
want := `{"name":"alice"}`
if got != want {
t.Errorf("got %v\nwant %v", got, want)
}
}
}

0 comments on commit 3aa640b

Please sign in to comment.