Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: improve hertz/pkg/routeut unit test coverage #992

Merged
merged 8 commits into from Nov 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/app/server/binding/internal/decoder/decoder.go
Expand Up @@ -103,7 +103,7 @@ func GetReqDecoder(rt reflect.Type, byTag string, config *DecodeConfig) (Decoder
}, needValidate, nil
}

func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName string, byTag string, config *DecodeConfig) ([]fieldDecoder, bool, error) {
func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName, byTag string, config *DecodeConfig) ([]fieldDecoder, bool, error) {
for field.Type.Kind() == reflect.Ptr {
field.Type = field.Type.Elem()
}
Expand Down
175 changes: 175 additions & 0 deletions pkg/route/engine_test.go
Expand Up @@ -55,6 +55,7 @@ import (

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server/binding"
"github.com/cloudwego/hertz/pkg/app/server/registry"
"github.com/cloudwego/hertz/pkg/common/config"
errs "github.com/cloudwego/hertz/pkg/common/errors"
"github.com/cloudwego/hertz/pkg/common/test/assert"
Expand All @@ -63,6 +64,7 @@ import (
"github.com/cloudwego/hertz/pkg/network/standard"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/protocol/consts"
"github.com/cloudwego/hertz/pkg/protocol/suite"
"github.com/cloudwego/hertz/pkg/route/param"
)

Expand Down Expand Up @@ -854,3 +856,176 @@ func TestCustomValidator(t *testing.T) {
})
performRequest(e, "GET", "/validate?a=2")
}

var errTestDeregsitry = fmt.Errorf("test deregsitry error")

type mockDeregsitryErr struct{}

var _ registry.Registry = &mockDeregsitryErr{}

func (e mockDeregsitryErr) Register(*registry.Info) error {
return nil
}

func (e mockDeregsitryErr) Deregister(*registry.Info) error {
return errTestDeregsitry
}

func TestEngineShutdown(t *testing.T) {
defaultTransporter = standard.NewTransporter
mockCtxCallback := func(ctx context.Context) {}
// Test case 1: serve not running error
engine := NewEngine(config.NewOptions(nil))
ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second)
defer cancel1()
err := engine.Shutdown(ctx1)
assert.DeepEqual(t, errStatusNotRunning, err)

// Test case 2: serve successfully running and shutdown
engine = NewEngine(config.NewOptions(nil))
engine.OnShutdown = []CtxCallback{mockCtxCallback}
go func() {
engine.Run()
}()
// wait for engine to start
time.Sleep(100 * time.Millisecond)

ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
defer cancel2()
err = engine.Shutdown(ctx2)
assert.Nil(t, err)
assert.DeepEqual(t, statusClosed, atomic.LoadUint32(&engine.status))

// Test case 3: serve successfully running and shutdown with deregistry error
engine = NewEngine(config.NewOptions(nil))
engine.OnShutdown = []CtxCallback{mockCtxCallback}
engine.options.Registry = &mockDeregsitryErr{}
go func() {
engine.Run()
}()
// wait for engine to start
time.Sleep(100 * time.Millisecond)

ctx3, cancel3 := context.WithTimeout(context.Background(), time.Second)
defer cancel3()
err = engine.Shutdown(ctx3)
assert.DeepEqual(t, errTestDeregsitry, err)
assert.DeepEqual(t, statusShutdown, atomic.LoadUint32(&engine.status))
}

type mockStreamer struct{}

type mockProtocolServer struct{}

func (s *mockStreamer) Serve(c context.Context, conn network.StreamConn) error {
return nil
}

func (s *mockProtocolServer) Serve(c context.Context, conn network.Conn) error {
return nil
}

type mockStreamConn struct {
network.StreamConn
version string
}

var _ network.StreamConn = &mockStreamConn{}

func (m *mockStreamConn) GetVersion() uint32 {
HzTTT marked this conversation as resolved.
Show resolved Hide resolved
return network.Version1
}

func TestEngineServeStream(t *testing.T) {
engine := &Engine{
options: &config.Options{
ALPN: true,
TLS: &tls.Config{},
},
protocolStreamServers: map[string]protocol.StreamServer{
suite.HTTP3: &mockStreamer{},
},
}

// Test ALPN path
conn := &mockStreamConn{version: suite.HTTP3}
err := engine.ServeStream(context.Background(), conn)
assert.Nil(t, err)

// Test default path
engine.options.ALPN = false
conn = &mockStreamConn{}
err = engine.ServeStream(context.Background(), conn)
assert.Nil(t, err)

// Test unsupported protocol
engine.protocolStreamServers = map[string]protocol.StreamServer{}
conn = &mockStreamConn{}
err = engine.ServeStream(context.Background(), conn)
assert.DeepEqual(t, errs.ErrNotSupportProtocol, err)
}

func TestEngineServe(t *testing.T) {
engine := NewEngine(config.NewOptions(nil))
engine.protocolServers[suite.HTTP1] = &mockProtocolServer{}
engine.protocolServers[suite.HTTP2] = &mockProtocolServer{}

// test H2C path
ctx := context.Background()
conn := mock.NewConn("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
engine.options.H2C = true
err := engine.Serve(ctx, conn)
assert.Nil(t, err)

// test ALPN path
ctx = context.Background()
conn = mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
engine.options.H2C = false
engine.options.ALPN = true
engine.options.TLS = &tls.Config{}
err = engine.Serve(ctx, conn)
assert.Nil(t, err)

// test HTTP1 path
engine.options.ALPN = false
err = engine.Serve(ctx, conn)
assert.Nil(t, err)
}

func TestOndata(t *testing.T) {
ctx := context.Background()
engine := NewEngine(config.NewOptions(nil))

// test stream conn
streamConn := &mockStreamConn{version: suite.HTTP3}
engine.protocolStreamServers[suite.HTTP3] = &mockStreamer{}
err := engine.onData(ctx, streamConn)
assert.Nil(t, err)

// test conn
conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
engine.protocolServers[suite.HTTP1] = &mockProtocolServer{}
err = engine.onData(ctx, conn)
assert.Nil(t, err)
}

func TestAcquireHijackConn(t *testing.T) {
engine := &Engine{
NoHijackConnPool: false,
}
// test conn pool
conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
hijackConn := engine.acquireHijackConn(conn)
assert.NotNil(t, hijackConn)
assert.NotNil(t, hijackConn.Conn)
assert.DeepEqual(t, engine, hijackConn.e)
assert.DeepEqual(t, conn, hijackConn.Conn)

// test no conn pool
engine.NoHijackConnPool = true
hijackConn = engine.acquireHijackConn(conn)
assert.NotNil(t, hijackConn)
assert.NotNil(t, hijackConn.Conn)
assert.DeepEqual(t, engine, hijackConn.e)
assert.DeepEqual(t, conn, hijackConn.Conn)
}