Skip to content

Commit

Permalink
server: first pass at Shutdown and Close
Browse files Browse the repository at this point in the history
closes: #22
  • Loading branch information
mattatcha committed Apr 15, 2017
1 parent 1051a0d commit b3d709d
Show file tree
Hide file tree
Showing 2 changed files with 256 additions and 0 deletions.
154 changes: 154 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
package ssh

import (
"context"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"

gossh "golang.org/x/crypto/ssh"
)

// ErrServerClosed is returned by the Server's Serve method (and any
// serve helpers built on it) after a call to Shutdown or Close.
var ErrServerClosed = errors.New("ssh: Server closed")

// Server defines parameters for running an SSH server. The zero value for
// Server is a valid configuration. When both PasswordHandler and
// PublicKeyHandler are nil, no client authentication is performed.
Expand All @@ -22,6 +30,12 @@ type Server struct {
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil

channelHandlers map[string]channelHandler

mu sync.Mutex
inShutdown int32 // accessed atomically (non-zero means we're in Shutdown)
listeners map[net.Listener]struct{}
activeConn map[*gossh.ServerConn]struct{}
doneChan chan struct{}
}

// internal for now
Expand Down Expand Up @@ -79,6 +93,67 @@ func (srv *Server) Handle(fn Handler) {
srv.Handler = fn
}

// Close immediately closes all active listeners and all active
// connections.
//
// Close returns any error returned from closing the Server's
// underlying Listener(s).
func (srv *Server) Close() error {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	srv.closeDoneChanLocked()
	lnErr := srv.closeListenersLocked()

	// Tear down every tracked connection; per-connection close
	// errors are not reported.
	for conn := range srv.activeConn {
		conn.Close()
		delete(srv.activeConn, conn)
	}
	return lnErr
}

// shutdownPollInterval is how often Server.Shutdown polls for
// quiescence (no remaining active connections). It is a variable
// rather than a constant so it can be lowered to speed up tests.
// Polling is not ideal, but avoids contended mutexes and extra
// bookkeeping on the connection hot path.
var shutdownPollInterval = 500 * time.Millisecond

// Shutdown gracefully shuts down the server without interrupting any
// active connections. Shutdown works by first closing all open
// listeners, and then waiting indefinitely for connections to close.
// If the provided context expires before the shutdown is complete,
// then the context's error is returned.
func (srv *Server) Shutdown(ctx context.Context) error {
	atomic.AddInt32(&srv.inShutdown, 1)
	defer atomic.AddInt32(&srv.inShutdown, -1)

	// Stop accepting new connections and wake any Serve loops.
	srv.mu.Lock()
	lsErr := srv.closeListenersLocked()
	srv.closeDoneChanLocked()
	srv.mu.Unlock()

	// activeConns reports how many connections are still tracked.
	activeConns := func() int {
		srv.mu.Lock()
		defer srv.mu.Unlock()
		return len(srv.activeConn)
	}

	// Poll until idle, honoring context cancellation.
	ticker := time.NewTicker(shutdownPollInterval)
	defer ticker.Stop()
	for activeConns() > 0 {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
		}
	}
	return lsErr
}

// shuttingDown reports whether a Shutdown call is currently in
// progress.
func (s *Server) shuttingDown() bool {
	n := atomic.LoadInt32(&s.inShutdown)
	return n != 0
}

// Serve accepts incoming connections on the Listener l, creating a new
// connection goroutine for each. The connection goroutines read requests and then
// calls srv.Handler to handle sessions.
Expand All @@ -93,9 +168,17 @@ func (srv *Server) Serve(l net.Listener) error {
srv.Handler = DefaultHandler
}
var tempDelay time.Duration

srv.trackListener(l, true)
defer srv.trackListener(l, false)
for {
conn, e := l.Accept()
if e != nil {
select {
case <-srv.getDoneChan():
return ErrServerClosed
default:
}
if ne, ok := e.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
Expand All @@ -122,6 +205,10 @@ func (srv *Server) handleConn(conn net.Conn) {
// TODO: trigger event callback
return
}

srv.trackConn(sshConn, true)
defer srv.trackConn(sshConn, false)

ctx.SetValue(ContextKeyConn, sshConn)
ctx.applyConnMetadata(sshConn)
go gossh.DiscardRequests(reqs)
Expand Down Expand Up @@ -163,3 +250,70 @@ func (srv *Server) AddHostKey(key Signer) {
// SetOption applies the given functional option to the server and
// returns whatever error the option itself reports.
func (srv *Server) SetOption(option Option) error {
	return option(srv)
}

// getDoneChan returns the server's done channel, creating it on first
// use. It is safe for concurrent use; srv.mu guards the lazy init.
func (srv *Server) getDoneChan() <-chan struct{} {
	srv.mu.Lock()
	ch := srv.getDoneChanLocked()
	srv.mu.Unlock()
	return ch
}

// getDoneChanLocked lazily allocates and returns doneChan. The caller
// must hold srv.mu.
func (srv *Server) getDoneChanLocked() chan struct{} {
	if ch := srv.doneChan; ch != nil {
		return ch
	}
	srv.doneChan = make(chan struct{})
	return srv.doneChan
}

// closeDoneChanLocked closes doneChan at most once, signaling any
// Serve loops to return ErrServerClosed. The caller must hold srv.mu.
func (srv *Server) closeDoneChanLocked() {
	ch := srv.getDoneChanLocked()
	select {
	case <-ch:
		// Already closed. Don't close again.
	default:
		// Safe to close here. We're the only closer, guarded
		// by srv.mu.
		close(ch)
	}
}

// closeListenersLocked closes and forgets every tracked listener,
// returning the first close error encountered (if any). The caller
// must hold srv.mu.
func (srv *Server) closeListenersLocked() error {
	var firstErr error
	for ln := range srv.listeners {
		if cerr := ln.Close(); cerr != nil && firstErr == nil {
			firstErr = cerr
		}
		delete(srv.listeners, ln)
	}
	return firstErr
}

// trackListener adds (add=true) or removes (add=false) a listener
// from the server's tracked set.
func (srv *Server) trackListener(ln net.Listener, add bool) {
	srv.mu.Lock()
	defer srv.mu.Unlock()
	if srv.listeners == nil {
		srv.listeners = make(map[net.Listener]struct{})
	}
	if !add {
		delete(srv.listeners, ln)
		return
	}
	// If the *Server is being reused after a previous
	// Close or Shutdown, reset its doneChan:
	if len(srv.listeners) == 0 && len(srv.activeConn) == 0 {
		srv.doneChan = nil
	}
	srv.listeners[ln] = struct{}{}
}

// trackConn adds (add=true) or removes (add=false) a server
// connection from the active set consulted by Shutdown and Close.
func (srv *Server) trackConn(c *gossh.ServerConn, add bool) {
	srv.mu.Lock()
	defer srv.mu.Unlock()
	if srv.activeConn == nil {
		srv.activeConn = make(map[*gossh.ServerConn]struct{})
	}
	if add {
		srv.activeConn[c] = struct{}{}
		return
	}
	delete(srv.activeConn, c)
}
102 changes: 102 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package ssh

import (
"bytes"
"context"
"io"
"testing"
"time"
)

// TestServerShutdown verifies that Shutdown waits for an in-flight
// session to finish and that the handler's output is still delivered.
func TestServerShutdown(t *testing.T) {
	l := newLocalListener()
	testBytes := []byte("Hello world\n")
	s := &Server{
		Handler: func(s Session) {
			s.Write(testBytes)
			time.Sleep(50 * time.Millisecond)
		},
	}
	go func() {
		// t.Fatal/FailNow must only run on the test goroutine;
		// use t.Error from spawned goroutines instead.
		err := s.Serve(l)
		if err != nil && err != ErrServerClosed {
			t.Error(err)
		}
	}()
	sessDone := make(chan struct{})
	sess, cleanup := newClientSession(t, l.Addr().String(), nil)
	go func() {
		defer cleanup()
		defer close(sessDone)
		var stdout bytes.Buffer
		sess.Stdout = &stdout
		if err := sess.Run(""); err != nil {
			t.Error(err)
			return
		}
		if !bytes.Equal(stdout.Bytes(), testBytes) {
			t.Errorf("expected = %s; got %s", testBytes, stdout.Bytes())
		}
	}()

	srvDone := make(chan struct{})
	go func() {
		defer close(srvDone)
		if err := s.Shutdown(context.Background()); err != nil {
			t.Error(err)
		}
	}()

	timeout := time.After(time.Second * 2)
	select {
	case <-timeout:
		t.Fatal("timeout")
	case <-srvDone:
		// Bound the wait for the client session as well, so a hung
		// session fails the test instead of deadlocking it.
		select {
		case <-sessDone:
		case <-time.After(time.Second):
			t.Fatal("timeout waiting for session")
		}
	}
}

// TestServerClose verifies that Close interrupts an active session
// immediately rather than waiting for the handler to finish.
func TestServerClose(t *testing.T) {
	l := newLocalListener()
	s := &Server{
		Handler: func(s Session) {
			time.Sleep(5 * time.Second)
		},
	}
	go func() {
		// t.Fatal/FailNow must only run on the test goroutine;
		// use t.Error from spawned goroutines instead.
		err := s.Serve(l)
		if err != nil && err != ErrServerClosed {
			t.Error(err)
		}
	}()

	doneCh := make(chan struct{})
	sess, cleanup := newClientSession(t, l.Addr().String(), nil)
	go func() {
		defer cleanup()
		defer close(doneCh)
		// EOF is expected: Close tears the connection down under
		// the running session.
		if err := sess.Run(""); err != nil && err != io.EOF {
			t.Error(err)
		}
	}()

	go func() {
		if err := s.Close(); err != nil {
			t.Error(err)
		}
	}()

	timeout := time.After(time.Millisecond * 100)
	select {
	case <-timeout:
		t.Error("timeout")
	case <-s.getDoneChan():
		<-doneCh
	}
}

0 comments on commit b3d709d

Please sign in to comment.