Skip to content

Commit

Permalink
Merge pull request #130 from austinvazquez/fix-server-shutdown
Browse files Browse the repository at this point in the history
Fix server shutdown logic
  • Loading branch information
dmcgowan committed Mar 8, 2023
2 parents 39515bd + 19445fd commit 7e006e7
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 34 deletions.
46 changes: 32 additions & 14 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,18 @@ func (s *Server) Serve(ctx context.Context, l net.Listener) error {

approved, handshake, err := handshaker.Handshake(ctx, conn)
if err != nil {
logrus.WithError(err).Errorf("ttrpc: refusing connection after handshake")
logrus.WithError(err).Error("ttrpc: refusing connection after handshake")
conn.Close()
continue
}

sc, err := s.newConn(approved, handshake)
if err != nil {
logrus.WithError(err).Error("ttrpc: create connection failed")
conn.Close()
continue
}

sc := s.newConn(approved, handshake)
go sc.run(ctx)
}
}
Expand All @@ -145,15 +151,20 @@ func (s *Server) Shutdown(ctx context.Context) error {
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
if s.closeIdleConns() {
return lnerr
s.closeIdleConns()

if s.countConnection() == 0 {
break
}

select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}

return lnerr
}

// Close the server without waiting for active connections.
Expand Down Expand Up @@ -205,11 +216,18 @@ func (s *Server) closeListeners() error {
return err
}

func (s *Server) addConnection(c *serverConn) {
func (s *Server) addConnection(c *serverConn) error {
s.mu.Lock()
defer s.mu.Unlock()

select {
case <-s.done:
return ErrServerClosed
default:
}

s.connections[c] = struct{}{}
return nil
}

func (s *Server) delConnection(c *serverConn) {
Expand All @@ -226,20 +244,17 @@ func (s *Server) countConnection() int {
return len(s.connections)
}

func (s *Server) closeIdleConns() bool {
func (s *Server) closeIdleConns() {
s.mu.Lock()
defer s.mu.Unlock()
quiescent := true

for c := range s.connections {
st, ok := c.getState()
if !ok || st != connStateIdle {
quiescent = false
if st, ok := c.getState(); !ok || st == connStateActive {
continue
}
c.close()
delete(s.connections, c)
}
return quiescent
}

type connState int
Expand All @@ -263,16 +278,19 @@ func (cs connState) String() string {
}
}

func (s *Server) newConn(conn net.Conn, handshake interface{}) *serverConn {
func (s *Server) newConn(conn net.Conn, handshake interface{}) (*serverConn, error) {
c := &serverConn{
server: s,
conn: conn,
handshake: handshake,
shutdown: make(chan struct{}),
}
c.setState(connStateIdle)
s.addConnection(c)
return c
if err := s.addConnection(c); err != nil {
c.close()
return nil, err
}
return c, nil
}

type serverConn struct {
Expand Down
36 changes: 16 additions & 20 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,20 +201,18 @@ func TestServerListenerClosed(t *testing.T) {
func TestServerShutdown(t *testing.T) {
const ncalls = 5
var (
ctx = context.Background()
server = mustServer(t)(NewServer())
addr, listener = newTestListener(t)
shutdownStarted = make(chan struct{})
shutdownFinished = make(chan struct{})
handlersStarted = make(chan struct{})
handlersStartedCloseOnce sync.Once
proceed = make(chan struct{})
serveErrs = make(chan error, 1)
callwg sync.WaitGroup
callErrs = make(chan error, ncalls)
shutdownErrs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
_, cleanup2 = newTestClient(t, addr) // secondary connection
ctx = context.Background()
server = mustServer(t)(NewServer())
addr, listener = newTestListener(t)
shutdownStarted = make(chan struct{})
shutdownFinished = make(chan struct{})
handlersStarted sync.WaitGroup
proceed = make(chan struct{})
serveErrs = make(chan error, 1)
callErrs = make(chan error, ncalls)
shutdownErrs = make(chan error, 1)
client, cleanup = newTestClient(t, addr)
_, cleanup2 = newTestClient(t, addr) // secondary connection
)
defer cleanup()
defer cleanup2()
Expand All @@ -227,7 +225,7 @@ func TestServerShutdown(t *testing.T) {
return nil, err
}

handlersStartedCloseOnce.Do(func() { close(handlersStarted) })
handlersStarted.Done()
<-proceed
return &internal.TestPayload{Foo: "waited"}, nil
},
Expand All @@ -238,20 +236,18 @@ func TestServerShutdown(t *testing.T) {
}()

// send a series of requests that will get blocked
for i := 0; i < 5; i++ {
callwg.Add(1)
for i := 0; i < ncalls; i++ {
handlersStarted.Add(1)
go func(i int) {
callwg.Done()
tp := internal.TestPayload{Foo: "half" + fmt.Sprint(i)}
callErrs <- client.Call(ctx, serviceName, "Test", &tp, &tp)
}(i)
}

<-handlersStarted
handlersStarted.Wait()
go func() {
close(shutdownStarted)
shutdownErrs <- server.Shutdown(ctx)
// server.Close()
close(shutdownFinished)
}()

Expand Down

0 comments on commit 7e006e7

Please sign in to comment.