Skip to content
This repository has been archived by the owner on Feb 27, 2023. It is now read-only.

Commit

Permalink
add test for RegisterProtocol
Browse files Browse the repository at this point in the history
Signed-off-by: 楚贤 <chuxian.mjj@antfin.com>
  • Loading branch information
jim3ma committed Mar 16, 2020
1 parent 92de56e commit 523e98d
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 2 deletions.
35 changes: 35 additions & 0 deletions pkg/httputils/http_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
package httputils

import (
"crypto/tls"
"encoding/json"
"fmt"
"math/rand"
"net"
"net/http"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -257,3 +259,36 @@ type testJSONReq struct {
type testJSONRes struct {
Sum int
}

type testTransport struct {
}

func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return &http.Response{
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Body: http.NoBody,
Status: http.StatusText(http.StatusOK),
StatusCode: http.StatusOK,
ContentLength: -1,
}, nil
}

func (s *HTTPUtilTestSuite) TestRegisterProtocol(c *check.C) {
protocol := "test"
RegisterProtocol(protocol, &testTransport{})
resp, err := HTTPWithHeaders(http.MethodGet,
protocol+"://test/test",
map[string]string{
"test": "test",
},
time.Second,
&tls.Config{},
)
c.Assert(err, check.IsNil)
defer resp.Body.Close()

c.Assert(resp, check.NotNil)
c.Assert(resp.ContentLength, check.Equals, int64(-1))
}
22 changes: 20 additions & 2 deletions supernode/httpclient/origin_http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,30 @@ type OriginHTTPClient interface {

// OriginClient is an implementation of the interface of OriginHTTPClient.
type OriginClient struct {
clientMap *sync.Map
clientMap *sync.Map
defaultHTTPClient *http.Client
}

// NewOriginClient returns a new OriginClient.
func NewOriginClient() OriginHTTPClient {
defaultTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 3 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
httputils.RegisterProtocolOnTransport(defaultTransport)
return &OriginClient{
clientMap: &sync.Map{},
defaultHTTPClient: &http.Client{
Transport: defaultTransport,
},
}
}

Expand Down Expand Up @@ -195,7 +212,8 @@ func (client *OriginClient) HTTPWithHeaders(method, url string, headers map[stri

httpClientObject, existed := client.clientMap.Load(req.Host)
if !existed {
httpClientObject = http.DefaultClient
// use client.defaultHTTPClient to support custom protocols
httpClientObject = client.defaultHTTPClient
}

httpClient, ok := httpClientObject.(*http.Client)
Expand Down
100 changes: 100 additions & 0 deletions supernode/httpclient/origin_http_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright The Dragonfly Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package httpclient

import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/go-check/check"

"github.com/dragonflyoss/Dragonfly/pkg/httputils"
)

func init() {
check.Suite(&OriginHTTPClientTestSuite{})
}

func Test(t *testing.T) {
check.TestingT(t)
}

type OriginHTTPClientTestSuite struct {
client *OriginClient
}

func (s *OriginHTTPClientTestSuite) SetUpSuite(c *check.C) {
s.client = NewOriginClient().(*OriginClient)
}

func (s *OriginHTTPClientTestSuite) TearDownSuite(c *check.C) {
}

func (s *OriginHTTPClientTestSuite) TestHTTPWithHeaders(c *check.C) {
testString := "test bytes"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(testString))
if r.Method != "GET" {
c.Errorf("Expected 'GET' request, got '%s'", r.Method)
}
}))
defer ts.Close()

httptest.NewRecorder()
resp, err := s.client.HTTPWithHeaders(http.MethodGet, ts.URL, map[string]string{}, time.Second)
c.Check(err, check.IsNil)
defer resp.Body.Close()

testBytes, err := ioutil.ReadAll(resp.Body)
c.Check(err, check.IsNil)
c.Check(string(testBytes), check.Equals, testString)
}

type testTransport struct {
}

func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return &http.Response{
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Body: http.NoBody,
Status: http.StatusText(http.StatusOK),
StatusCode: http.StatusOK,
ContentLength: -1,
}, nil
}

func (s *OriginHTTPClientTestSuite) TestRegisterTLSConfig(c *check.C) {
protocol := "test"
httputils.RegisterProtocol(protocol, &testTransport{})
s.client.RegisterTLSConfig(protocol+"://test/test", true, nil)
httpClientInterface, ok := s.client.clientMap.Load("test")
c.Check(ok, check.Equals, true)
httpClient, ok := httpClientInterface.(*http.Client)
c.Check(ok, check.Equals, true)

resp, err := httpClient.Get(protocol + "://test/test")
c.Assert(err, check.IsNil)
defer resp.Body.Close()
c.Assert(resp, check.NotNil)
c.Assert(resp.ContentLength, check.Equals, int64(-1))
}

0 comments on commit 523e98d

Please sign in to comment.