From 523e98d08315368c66e5979e296ffd55991b52f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E8=B4=A4?= Date: Fri, 6 Mar 2020 16:53:30 +0800 Subject: [PATCH] add test for RegisterProtocol MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 楚贤 --- pkg/httputils/http_util_test.go | 35 ++++++ supernode/httpclient/origin_http_client.go | 22 +++- .../httpclient/origin_http_client_test.go | 100 ++++++++++++++++++ 3 files changed, 155 insertions(+), 2 deletions(-) create mode 100644 supernode/httpclient/origin_http_client_test.go diff --git a/pkg/httputils/http_util_test.go b/pkg/httputils/http_util_test.go index e31a9677b..6f7213964 100644 --- a/pkg/httputils/http_util_test.go +++ b/pkg/httputils/http_util_test.go @@ -17,10 +17,12 @@ package httputils import ( + "crypto/tls" "encoding/json" "fmt" "math/rand" "net" + "net/http" "sync" "testing" "time" @@ -257,3 +259,36 @@ type testJSONReq struct { type testJSONRes struct { Sum int } + +type testTransport struct { +} + +func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Body: http.NoBody, + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + ContentLength: -1, + }, nil +} + +func (s *HTTPUtilTestSuite) TestRegisterProtocol(c *check.C) { + protocol := "test" + RegisterProtocol(protocol, &testTransport{}) + resp, err := HTTPWithHeaders(http.MethodGet, + protocol+"://test/test", + map[string]string{ + "test": "test", + }, + time.Second, + &tls.Config{}, + ) + c.Assert(err, check.IsNil) + defer resp.Body.Close() + + c.Assert(resp, check.NotNil) + c.Assert(resp.ContentLength, check.Equals, int64(-1)) +} \ No newline at end of file diff --git a/supernode/httpclient/origin_http_client.go b/supernode/httpclient/origin_http_client.go index 8d458b0b9..2b3fc1600 100644 --- a/supernode/httpclient/origin_http_client.go +++ b/supernode/httpclient/origin_http_client.go @@ -49,13 +49,30 @@ type OriginHTTPClient interface { // OriginClient is an implementation of the interface of OriginHTTPClient. type OriginClient struct { - clientMap *sync.Map + clientMap *sync.Map + defaultHTTPClient *http.Client } // NewOriginClient returns a new OriginClient. func NewOriginClient() OriginHTTPClient { + defaultTransport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 3 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + httputils.RegisterProtocolOnTransport(defaultTransport) return &OriginClient{ clientMap: &sync.Map{}, + defaultHTTPClient: &http.Client{ + Transport: defaultTransport, + }, } } @@ -195,7 +212,8 @@ func (client *OriginClient) HTTPWithHeaders(method, url string, headers map[stri httpClientObject, existed := client.clientMap.Load(req.Host) if !existed { - httpClientObject = http.DefaultClient + // use client.defaultHTTPClient to support custom protocols + httpClientObject = client.defaultHTTPClient } httpClient, ok := httpClientObject.(*http.Client) diff --git a/supernode/httpclient/origin_http_client_test.go b/supernode/httpclient/origin_http_client_test.go new file mode 100644 index 000000000..971f04466 --- /dev/null +++ b/supernode/httpclient/origin_http_client_test.go @@ -0,0 +1,100 @@ +/* + * Copyright The Dragonfly Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package httpclient + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-check/check" + + "github.com/dragonflyoss/Dragonfly/pkg/httputils" +) + +func init() { + check.Suite(&OriginHTTPClientTestSuite{}) +} + +func Test(t *testing.T) { + check.TestingT(t) +} + +type OriginHTTPClientTestSuite struct { + client *OriginClient +} + +func (s *OriginHTTPClientTestSuite) SetUpSuite(c *check.C) { + s.client = NewOriginClient().(*OriginClient) +} + +func (s *OriginHTTPClientTestSuite) TearDownSuite(c *check.C) { +} + +func (s *OriginHTTPClientTestSuite) TestHTTPWithHeaders(c *check.C) { + testString := "test bytes" + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(testString)) + if r.Method != "GET" { + c.Errorf("Expected 'GET' request, got '%s'", r.Method) + } + })) + defer ts.Close() + + httptest.NewRecorder() + resp, err := s.client.HTTPWithHeaders(http.MethodGet, ts.URL, map[string]string{}, time.Second) + c.Check(err, check.IsNil) + defer resp.Body.Close() + + testBytes, err := ioutil.ReadAll(resp.Body) + c.Check(err, check.IsNil) + c.Check(string(testBytes), check.Equals, testString) +} + +type testTransport struct { +} + +func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Body: http.NoBody, + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + ContentLength: -1, + }, nil +} + +func (s *OriginHTTPClientTestSuite) TestRegisterTLSConfig(c *check.C) { + protocol := "test" + httputils.RegisterProtocol(protocol, &testTransport{}) + s.client.RegisterTLSConfig(protocol+"://test/test", true, nil) + httpClientInterface, ok := s.client.clientMap.Load("test") + c.Check(ok, check.Equals, true) + httpClient, ok := httpClientInterface.(*http.Client) + c.Check(ok, check.Equals, true) + + resp, err := httpClient.Get(protocol + "://test/test") + c.Assert(err, check.IsNil) + defer resp.Body.Close() + c.Assert(resp, check.NotNil) + c.Assert(resp.ContentLength, check.Equals, int64(-1)) +}