Skip to content

Commit

Permalink
Expand API
Browse files Browse the repository at this point in the history
- Closes #1 (Ping API)
- Closes #75 (Read/Write convienence methods)
- Closes #83 (SetReadLimit)
  • Loading branch information
nhooyr committed May 30, 2019
1 parent 027e6af commit 36d5ce8
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 51 deletions.
2 changes: 0 additions & 2 deletions example_echo_test.go
Expand Up @@ -94,7 +94,6 @@ func echoServer(w http.ResponseWriter, r *http.Request) error {
// echo reads from the websocket connection and then writes
// the received message back to it.
// The entire function has 10s to complete.
// The received message is limited to 32768 bytes.
func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
Expand All @@ -108,7 +107,6 @@ func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error {
if err != nil {
return err
}
r = io.LimitReader(r, 32768)

w, err := c.Writer(ctx, typ)
if err != nil {
Expand Down
18 changes: 0 additions & 18 deletions export_test.go

This file was deleted.

99 changes: 92 additions & 7 deletions websocket.go
Expand Up @@ -5,9 +5,13 @@ import (
"context"
"fmt"
"io"
"io/ioutil"
"math/rand"
"os"
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"

"golang.org/x/xerrors"
Expand All @@ -25,6 +29,8 @@ type Conn struct {
closer io.Closer
client bool

msgReadLimit int64

closeOnce sync.Once
closeErr error
closed chan struct{}
Expand All @@ -41,14 +47,16 @@ type Conn struct {
setWriteTimeout chan context.Context
setConnContext chan context.Context
getConnContext chan context.Context

pingListener map[string]chan<- struct{}
}

// Context returns a context derived from parent that will be cancelled
// when the connection is closed.
// when the connection is closed or broken.
// If the parent context is cancelled, the connection will be closed.
//
// This is an experimental API that may be remove in the future.
// Please let me know how you feel about it.
// This is an experimental API that may be removed in the future.
// Please let me know how you feel about it in https://github.com/nhooyr/websocket/issues/79
func (c *Conn) Context(parent context.Context) context.Context {
select {
case <-c.closed:
Expand Down Expand Up @@ -105,6 +113,8 @@ func (c *Conn) Subprotocol() string {
func (c *Conn) init() {
c.closed = make(chan struct{})

c.msgReadLimit = 32768

c.writeDataLock = make(chan struct{}, 1)
c.writeFrameLock = make(chan struct{}, 1)

Expand All @@ -118,6 +128,8 @@ func (c *Conn) init() {
c.setConnContext = make(chan context.Context)
c.getConnContext = make(chan context.Context)

c.pingListener = make(map[string]chan struct{})

runtime.SetFinalizer(c, func(c *Conn) {
c.close(xerrors.New("connection garbage collected"))
})
Expand Down Expand Up @@ -242,6 +254,10 @@ func (c *Conn) handleControl(h header) {
case opPing:
c.writePong(b)
case opPong:
listener, ok := c.pingListener[string(b)]
if ok {
close(listener)
}
case opClose:
ce, err := parseClosePayload(b)
if err != nil {
Expand Down Expand Up @@ -321,7 +337,7 @@ func (c *Conn) writePong(p []byte) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

err := c.writeCompleteMessage(ctx, opPong, p)
err := c.writeMessage(ctx, opPong, p)
return err
}

Expand Down Expand Up @@ -369,7 +385,7 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

err := c.writeCompleteMessage(ctx, opClose, p)
err := c.writeMessage(ctx, opClose, p)

c.close(cerr)

Expand Down Expand Up @@ -399,7 +415,7 @@ func (c *Conn) releaseLock(lock chan struct{}) {
<-lock
}

func (c *Conn) writeCompleteMessage(ctx context.Context, opcode opcode, p []byte) error {
func (c *Conn) writeMessage(ctx context.Context, opcode opcode, p []byte) error {
if !opcode.controlOp() {
err := c.acquireLock(ctx, c.writeDataLock)
if err != nil {
Expand Down Expand Up @@ -445,6 +461,30 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
}, nil
}

// Read is a convenience method to read a single message from the connection.
//
// See the Reader method if you want to be able to reuse buffers or want to stream a message.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
typ, r, err := c.Reader(ctx)
if err != nil {
return 0, nil, err
}

b, err := ioutil.ReadAll(r)
if err != nil {
return typ, b, err
}

return typ, b, nil
}

// Write is a convenience method to write a message to the connection.
//
// See the Writer method if you want to stream a message.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
return c.writeMessage(ctx, opcode(typ), p)
}

// messageWriter enables writing to a WebSocket connection.
type messageWriter struct {
ctx context.Context
Expand Down Expand Up @@ -519,7 +559,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
if err != nil {
return 0, nil, xerrors.Errorf("failed to get reader: %w", err)
}
return typ, r, nil
return typ, io.LimitReader(r, c.msgReadLimit), nil
}

func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
Expand Down Expand Up @@ -640,3 +680,48 @@ func (r *messageReader) read(p []byte) (int, error) {

return n, nil
}

// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
func (c *Conn) SetReadLimit(n int64) {
atomic.StoreInt64(&c.msgReadLimit, n)
}

func init() {
rand.Seed(time.Now().UnixNano())
}

// Ping sends a ping to the peer and waits for a pong.
// Use this to measure latency or ensure the peer is responsive.
//
// This API is experimental and subject to change.
// Please provide feedback in https://github.com/nhooyr/websocket/issues/1.
func (c *Conn) Ping(ctx context.Context) error {
err := c.ping(ctx)
if err != nil {
return xerrors.Errorf("failed to ping: %w", err)
}
return nil
}

func (c *Conn) ping(ctx context.Context) error {
id := rand.Uint64()
p := strconv.FormatUint(id, 10)

pong := make(chan struct{})
c.pingListener[p] = pong

err := c.writeMessage(ctx, opPing, []byte(p))
if err != nil {
return err
}

select {
case <-ctx.Done():
return ctx.Err()
case <-pong:
return nil
}
}
2 changes: 2 additions & 0 deletions websocket_test.go
Expand Up @@ -489,6 +489,8 @@ func TestAutobahnServer(t *testing.T) {
func echoLoop(ctx context.Context, c *websocket.Conn) {
defer c.Close(websocket.StatusInternalError, "")

c.SetReadLimit(1 << 30)

ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()

Expand Down
4 changes: 0 additions & 4 deletions wsjson/wsjson.go
Expand Up @@ -12,8 +12,6 @@ import (
)

// Read reads a json message from c into v.
// For security reasons, it will not read messages
// larger than 32768 bytes.
func Read(ctx context.Context, c *websocket.Conn, v interface{}) error {
err := read(ctx, c, v)
if err != nil {
Expand All @@ -33,8 +31,6 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error {
return xerrors.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ)
}

r = io.LimitReader(r, 32768)

d := json.NewDecoder(r)
err = d.Decode(v)
if err != nil {
Expand Down
21 changes: 1 addition & 20 deletions wspb/wspb.go
Expand Up @@ -3,7 +3,6 @@ package wspb

import (
"context"
"io"
"io/ioutil"

"github.com/golang/protobuf/proto"
Expand All @@ -13,8 +12,6 @@ import (
)

// Read reads a protobuf message from c into v.
// For security reasons, it will not read messages
// larger than 32768 bytes.
func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
err := read(ctx, c, v)
if err != nil {
Expand All @@ -34,8 +31,6 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
return xerrors.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ)
}

r = io.LimitReader(r, 32768)

b, err := ioutil.ReadAll(r)
if err != nil {
return xerrors.Errorf("failed to read message: %w", err)
Expand Down Expand Up @@ -64,19 +59,5 @@ func write(ctx context.Context, c *websocket.Conn, v proto.Message) error {
return xerrors.Errorf("failed to marshal protobuf: %w", err)
}

w, err := c.Writer(ctx, websocket.MessageBinary)
if err != nil {
return err
}

_, err = w.Write(b)
if err != nil {
return err
}

err = w.Close()
if err != nil {
return err
}
return nil
return c.Write(ctx, websocket.MessageBinary, b)
}

0 comments on commit 36d5ce8

Please sign in to comment.