Skip to content

Commit

Permalink
Merge 6278658 into 6be04fd
Browse files Browse the repository at this point in the history
  • Loading branch information
jfbus committed Nov 12, 2018
2 parents 6be04fd + 6278658 commit 1d6d948
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 15 deletions.
26 changes: 11 additions & 15 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ type Server struct {
logger log.Logger

transports []Transport

// exit chan for graceful shutdown
exit chan chan error
}

type contextKey int
Expand All @@ -36,14 +33,17 @@ const (
func NewServer(t ...Transport) *Server {
return &Server{
transports: t,
exit: make(chan chan error),
logger: &nopLogger{},
middleware: nopMiddleware,
}
}

// Run starts the server.
func (s *Server) Run(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
stop := make(chan error)
exit := make(chan error)
ctx = s.addLoggerToContext(ctx, nil)
for _, t := range s.transports {
m := s.addLoggerToContextMiddleware(s.middleware, t)
Expand All @@ -55,19 +55,18 @@ func (s *Server) Run(ctx context.Context) error {
go func(t Transport) {
if err := t.Start(ctx); err != nil {
_ = s.logger.Log("msg", fmt.Sprintf("Shutting down due to server error: %s", err))
_ = s.stop()
stop <- err
}
}(t)
}

go func() {
exit := <-s.exit
err := <-stop
for _, fn := range s.shutdown {
fn()
}
var err error
for _, t := range s.transports {
if eerr := t.Shutdown(ctx); eerr != nil {
if eerr := t.Shutdown(ctx); eerr != nil && err == nil {
err = eerr
}
}
Expand All @@ -81,12 +80,9 @@ func (s *Server) Run(ctx context.Context) error {
_ = s.logger.Log("msg", "received signal", "signal", sig)
case <-ctx.Done():
_ = s.logger.Log("msg", "canceled context")
case err := <-exit:
return err
}
return s.stop()
}

func (s *Server) stop() error {
ch := make(chan error)
s.exit <- ch
return <-ch
stop <- nil
return <-exit
}
48 changes: 48 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"reflect"
"testing"
"time"

"github.com/go-kit/kit/endpoint"
)

func TestServer(t *testing.T) {
Expand Down Expand Up @@ -121,3 +123,49 @@ func goodDecoder(_ context.Context, r *http.Request) (interface{}, error) {
func badDecoder(_ context.Context, _ *http.Request) (interface{}, error) {
return nil, errors.New("decoding error")
}

type failingTransport struct {
Transport
}

func (*failingTransport) RegisterEndpoints(m endpoint.Middleware) error { return nil }
func (*failingTransport) Start(ctx context.Context) error { return errors.New("unable to start") }
func (*failingTransport) Shutdown(ctx context.Context) error { return nil }

type workingTransport struct {
running chan struct{}
Transport
}

func (*workingTransport) RegisterEndpoints(m endpoint.Middleware) error { return nil }
func (t *workingTransport) Start(ctx context.Context) error {
defer func() { close(t.running) }()
<-ctx.Done()
return nil
}
func (*workingTransport) Shutdown(ctx context.Context) error { return nil }

func TestStartError(t *testing.T) {
ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second)
defer cancel()
exitError := make(chan error)
wt := &workingTransport{running: make(chan struct{})}
srv := NewServer(&failingTransport{}, wt)
go func() {
exitError <- srv.Run(ctx)
}()

select {
case <-ctx.Done():
t.Error("Server.Run has not stopped after 5sec")
case err := <-exitError:
if err == nil || err.Error() != "unable to start" {
t.Errorf("Server.Run returned an invalid error : %v", err)
}
}
select {
case <-ctx.Done():
t.Error("Alternate transport has not stopped after 5sec")
case <-wt.running:
}
}

0 comments on commit 1d6d948

Please sign in to comment.