Skip to content

Commit

Permalink
Add Grace to gracefully close WebSocket connections
Browse files Browse the repository at this point in the history
Closes #199
  • Loading branch information
nhooyr committed Feb 26, 2020
1 parent fa720b9 commit b0c36b9
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 12 deletions.
20 changes: 18 additions & 2 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
defer errd.Wrap(&err, "failed to accept WebSocket connection")

g := graceFromRequest(r)
if g != nil && g.isClosing() {
err := errors.New("server closing")
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return nil, err
}

if opts == nil {
opts = &AcceptOptions{}
}
Expand Down Expand Up @@ -120,7 +127,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
b, _ := brw.Reader.Peek(brw.Reader.Buffered())
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))

return newConn(connConfig{
c := newConn(connConfig{
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
rwc: netConn,
client: false,
Expand All @@ -129,7 +136,16 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con

br: brw.Reader,
bw: brw.Writer,
}), nil
})

if g != nil {
err = g.addConn(c)
if err != nil {
return nil, err
}
}

return c, nil
}

func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
Expand Down
5 changes: 5 additions & 0 deletions conn_notjs.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type Conn struct {
flateThreshold int
br *bufio.Reader
bw *bufio.Writer
g *Grace

readTimeout chan context.Context
writeTimeout chan context.Context
Expand Down Expand Up @@ -138,6 +139,10 @@ func (c *Conn) close(err error) {
// closeErr.
c.rwc.Close()

if c.g != nil {
c.g.delConn(c)
}

go func() {
c.msgWriterState.close()

Expand Down
12 changes: 4 additions & 8 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"os"
"os/exec"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -272,11 +271,9 @@ func TestWasm(t *testing.T) {
t.Skip("skipping on CI")
}

var wg sync.WaitGroup
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wg.Add(1)
defer wg.Done()

var g websocket.Grace
defer g.Close()
s := httptest.NewServer(g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
Subprotocols: []string{"echo"},
InsecureSkipVerify: true,
Expand All @@ -294,8 +291,7 @@ func TestWasm(t *testing.T) {
t.Errorf("echo server failed: %v", err)
return
}
}))
defer wg.Wait()
})))
defer s.Close()

ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
Expand Down
6 changes: 4 additions & 2 deletions example_echo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ func Example_echo() {
}
defer l.Close()

var g websocket.Grace
defer g.Close()
s := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Handler: g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := echoServer(w, r)
if err != nil {
log.Printf("echo server: %v", err)
}
}),
})),
ReadTimeout: time.Second * 15,
WriteTimeout: time.Second * 15,
}
Expand Down
46 changes: 46 additions & 0 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"log"
"net/http"
"net/url"
"os"
"os/signal"
"time"

"nhooyr.io/websocket"
Expand Down Expand Up @@ -143,3 +145,47 @@ func Example_crossOrigin() {
err := http.ListenAndServe("localhost:8080", fn)
log.Fatal(err)
}

// This example demonstrates how to create a WebSocket server
// that gracefully exits when sent a signal.
//
// It starts a WebSocket server that keeps every connection open
// for 10 seconds.
// If you CTRL+C while a connection is open, it will wait at most 30s
// for all connections to terminate before shutting down.
func ExampleGrace() {
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, nil)
if err != nil {
log.Println(err)
return
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")

ctx := c.CloseRead(r.Context())
select {
case <-ctx.Done():
case <-time.After(time.Second * 10):
}

c.Close(websocket.StatusNormalClosure, "")
})

var g websocket.Grace
s := &http.Server{
Handler: g.Handler(fn),
ReadTimeout: time.Second * 15,
WriteTimeout: time.Second * 15,
}
go s.ListenAndServe()

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, os.Interrupt)
sig := <-sigs
log.Printf("recieved %v, shutting down", sig)

ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()
s.Shutdown(ctx)
g.Shutdown(ctx)
}
123 changes: 123 additions & 0 deletions grace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package websocket

import (
"context"
"errors"
"fmt"
"net/http"
"sync"
"time"
)

// Grace enables graceful shutdown of accepted WebSocket connections.
//
// Use Handler to wrap WebSocket handlers to record accepted connections
// and then use Close or Shutdown to gracefully close these connections.
//
// Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods.
type Grace struct {
mu sync.Mutex
closing bool
conns map[*Conn]struct{}
}

// Handler returns a handler that wraps around h to record
// all WebSocket connections accepted.
//
// Use Close or Shutdown to gracefully close recorded connections.
func (g *Grace) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), gracefulContextKey{}, g)
r = r.WithContext(ctx)
h.ServeHTTP(w, r)
})
}

func (g *Grace) isClosing() bool {
g.mu.Lock()
defer g.mu.Unlock()
return g.closing
}

func graceFromRequest(r *http.Request) *Grace {
g, _ := r.Context().Value(gracefulContextKey{}).(*Grace)
return g
}

func (g *Grace) addConn(c *Conn) error {
g.mu.Lock()
defer g.mu.Unlock()
if g.closing {
c.Close(StatusGoingAway, "server shutting down")
return errors.New("server shutting down")
}
if g.conns == nil {
g.conns = make(map[*Conn]struct{})
}
g.conns[c] = struct{}{}
c.g = g
return nil
}

func (g *Grace) delConn(c *Conn) {
g.mu.Lock()
defer g.mu.Unlock()
delete(g.conns, c)
}

type gracefulContextKey struct{}

// Close prevents the acceptance of new connections with
// http.StatusServiceUnavailable and closes all accepted
// connections with StatusGoingAway.
func (g *Grace) Close() error {
g.mu.Lock()
g.closing = true
var wg sync.WaitGroup
for c := range g.conns {
wg.Add(1)
go func(c *Conn) {
defer wg.Done()
c.Close(StatusGoingAway, "server shutting down")
}(c)

delete(g.conns, c)
}
g.mu.Unlock()

wg.Wait()

return nil
}

// Shutdown prevents the acceptance of new connections and waits until
// all connections close. If the context is cancelled before that, it
// calls Close to close all connections immediately.
func (g *Grace) Shutdown(ctx context.Context) error {
defer g.Close()

g.mu.Lock()
g.closing = true
g.mu.Unlock()

// Same poll period used by net/http.
t := time.NewTicker(500 * time.Millisecond)
defer t.Stop()
for {
if g.zeroConns() {
return nil
}

select {
case <-t.C:
case <-ctx.Done():
return fmt.Errorf("failed to shutdown WebSockets: %w", ctx.Err())
}
}
}

func (g *Grace) zeroConns() bool {
g.mu.Lock()
defer g.mu.Unlock()
return len(g.conns) == 0
}
2 changes: 2 additions & 0 deletions ws_js.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ type Conn struct {
readSignal chan struct{}
readBufMu sync.Mutex
readBuf []wsjs.MessageEvent

g *Grace
}

func (c *Conn) close(err error, wasClean bool) {
Expand Down

0 comments on commit b0c36b9

Please sign in to comment.