Skip to content

Commit

Permalink
close.go: Rewrite how the library handles closing
Browse files Browse the repository at this point in the history
Far simpler now. Sorry this took a while.

Closes #427
Closes #429
Closes #434
Closes #436
Closes #437
  • Loading branch information
nhooyr committed Apr 5, 2024
1 parent 0b3912f commit db18a31
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 136 deletions.
154 changes: 103 additions & 51 deletions close.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,82 +97,106 @@ func CloseStatus(err error) StatusCode {
//
// Close will unblock all goroutines interacting with the connection once
// complete.
func (c *Conn) Close(code StatusCode, reason string) error {
defer c.wg.Wait()
return c.closeHandshake(code, reason)
func (c *Conn) Close(code StatusCode, reason string) (err error) {
defer errd.Wrap(&err, "failed to close WebSocket")

if !c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()

err = c.closeHandshake(code, reason)

err2 := c.close()
if err == nil && err2 != nil {
err = err2
}

err2 = c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}

return err
}

// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake.
func (c *Conn) CloseNow() (err error) {
defer c.wg.Wait()
defer errd.Wrap(&err, "failed to close WebSocket")

if c.isClosed() {
if !c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()

c.close(nil)
c.closeMu.Lock()
defer c.closeMu.Unlock()
return c.closeErr
}

func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
defer errd.Wrap(&err, "failed to close WebSocket")

writeErr := c.writeClose(code, reason)
closeHandshakeErr := c.waitCloseHandshake()
err = c.close()

if writeErr != nil {
return writeErr
err2 := c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return err
}

if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) {
return closeHandshakeErr
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
err := c.writeClose(code, reason)
if err != nil {
return err
}

err = c.waitCloseHandshake()
if CloseStatus(err) != code {
return err
}
return nil
}

func (c *Conn) writeClose(code StatusCode, reason string) error {
c.closeMu.Lock()
wroteClose := c.wroteClose
c.wroteClose = true
c.closeMu.Unlock()
if wroteClose {
return net.ErrClosed
}

ce := CloseError{
Code: code,
Reason: reason,
}

var p []byte
var marshalErr error
var err error
if ce.Code != StatusNoStatusRcvd {
p, marshalErr = ce.bytes()
}

writeErr := c.writeControl(context.Background(), opClose, p)
if CloseStatus(writeErr) != -1 {
// Not a real error if it's due to a close frame being received.
writeErr = nil
p, err = ce.bytes()
if err != nil {
return err
}
}

// We do this after in case there was an error writing the close frame.
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

if marshalErr != nil {
return marshalErr
err = c.writeControl(ctx, opClose, p)
// If the connection closed as we're writing we ignore the error as we might
// have written the close frame, the peer responded and then someone else read it
// and closed the connection.
if err != nil && !errors.Is(err, net.ErrClosed) {
return err
}
return writeErr
return nil
}

func (c *Conn) waitCloseHandshake() error {
defer c.close(nil)

ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

Expand Down Expand Up @@ -208,6 +232,36 @@ func (c *Conn) waitCloseHandshake() error {
}
}

func (c *Conn) waitGoroutines() error {
t := time.NewTimer(time.Second * 15)
defer t.Stop()

select {
case <-c.timeoutLoopDone:
case <-t.C:
return errors.New("failed to wait for timeoutLoop goroutine to exit")
}

c.closeReadMu.Lock()
ctx := c.closeReadCtx
c.closeReadMu.Unlock()
if ctx != nil {
select {
case <-ctx.Done():
case <-t.C:
return errors.New("failed to wait for close read goroutine to exit")
}
}

select {
case <-c.closed:
case <-t.C:
return errors.New("failed to wait for connection to be closed")
}

return nil
}

func parseClosePayload(p []byte) (CloseError, error) {
if len(p) == 0 {
return CloseError{
Expand Down Expand Up @@ -278,16 +332,14 @@ func (ce CloseError) bytesErr() ([]byte, error) {
return buf, nil
}

func (c *Conn) setCloseErr(err error) {
func (c *Conn) casClosing() bool {
c.closeMu.Lock()
c.setCloseErrLocked(err)
c.closeMu.Unlock()
}

func (c *Conn) setCloseErrLocked(err error) {
if c.closeErr == nil && err != nil {
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
defer c.closeMu.Unlock()
if !c.closing {
c.closing = true
return true
}
return false
}

func (c *Conn) isClosed() bool {
Expand Down
74 changes: 29 additions & 45 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package websocket
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -53,8 +52,9 @@ type Conn struct {
br *bufio.Reader
bw *bufio.Writer

readTimeout chan context.Context
writeTimeout chan context.Context
readTimeout chan context.Context
writeTimeout chan context.Context
timeoutLoopDone chan struct{}

// Read state.
readMu *mu
Expand All @@ -70,11 +70,12 @@ type Conn struct {
writeHeaderBuf [8]byte
writeHeader header

wg sync.WaitGroup
closed chan struct{}
closeMu sync.Mutex
closeErr error
wroteClose bool
closeReadMu sync.Mutex
closeReadCtx context.Context

closed chan struct{}
closeMu sync.Mutex
closing bool

pingCounter int32
activePingsMu sync.Mutex
Expand Down Expand Up @@ -103,8 +104,9 @@ func newConn(cfg connConfig) *Conn {
br: cfg.br,
bw: cfg.bw,

readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
timeoutLoopDone: make(chan struct{}),

closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
Expand All @@ -128,14 +130,10 @@ func newConn(cfg connConfig) *Conn {
}

runtime.SetFinalizer(c, func(c *Conn) {
c.close(errors.New("connection garbage collected"))
c.close()
})

c.wg.Add(1)
go func() {
defer c.wg.Done()
c.timeoutLoop()
}()
go c.timeoutLoop()

return c
}
Expand All @@ -146,35 +144,29 @@ func (c *Conn) Subprotocol() string {
return c.subprotocol
}

func (c *Conn) close(err error) {
func (c *Conn) close() error {
c.closeMu.Lock()
defer c.closeMu.Unlock()

if c.isClosed() {
return
}
if err == nil {
err = c.rwc.Close()
return net.ErrClosed
}
c.setCloseErrLocked(err)

close(c.closed)
runtime.SetFinalizer(c, nil)
close(c.closed)

// Have to close after c.closed is closed to ensure any goroutine that wakes up
// from the connection being closed also sees that c.closed is closed and returns
// closeErr.
c.rwc.Close()

c.wg.Add(1)
go func() {
defer c.wg.Done()
c.msgWriter.close()
c.msgReader.close()
}()
err := c.rwc.Close()
// With the close of rwc, these become safe to close.
c.msgWriter.close()
c.msgReader.close()
return err
}

func (c *Conn) timeoutLoop() {
defer close(c.timeoutLoopDone)

readCtx := context.Background()
writeCtx := context.Background()

Expand All @@ -187,14 +179,10 @@ func (c *Conn) timeoutLoop() {
case readCtx = <-c.readTimeout:

case <-readCtx.Done():
c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
c.wg.Add(1)
go func() {
defer c.wg.Done()
c.writeError(StatusPolicyViolation, errors.New("read timed out"))
}()
c.close()
return
case <-writeCtx.Done():
c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
c.close()
return
}
}
Expand Down Expand Up @@ -243,9 +231,7 @@ func (c *Conn) ping(ctx context.Context, p string) error {
case <-c.closed:
return net.ErrClosed
case <-ctx.Done():
err := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
c.close(err)
return err
return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
case <-pong:
return nil
}
Expand Down Expand Up @@ -281,9 +267,7 @@ func (m *mu) lock(ctx context.Context) error {
case <-m.c.closed:
return net.ErrClosed
case <-ctx.Done():
err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
m.c.close(err)
return err
return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
case m.ch <- struct{}{}:
// To make sure the connection is certainly alive.
// As it's possible the send on m.ch was selected
Expand Down
3 changes: 3 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ func TestConn(t *testing.T) {

func TestWasm(t *testing.T) {
t.Parallel()
if os.Getenv("CI") == "" {
t.Skip()
}

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := echoServer(w, r, &websocket.AcceptOptions{
Expand Down
Loading

0 comments on commit db18a31

Please sign in to comment.