Skip to content

Commit

Permalink
Improve chat example test
Browse files Browse the repository at this point in the history
  • Loading branch information
nhooyr committed Feb 27, 2020
1 parent 7329b27 commit 91b7f61
Show file tree
Hide file tree
Showing 6 changed files with 284 additions and 98 deletions.
6 changes: 6 additions & 0 deletions chat-example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,9 @@ assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoin

The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by
`index.html` and then `index.js`.

There are two automated tests for the server included in `chat_test.go`. The first is a simple one
client echo test. It publishes a single message and ensures it's received.

The second is a complex concurrency test where 10 clients send 128 unique messages
of max 128 bytes concurrently. The test ensures all messages are seen by every client.
67 changes: 47 additions & 20 deletions chat-example/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,57 @@ package main
import (
"context"
"errors"
"io"
"io/ioutil"
"log"
"net/http"
"sync"
"time"

"golang.org/x/time/rate"

"nhooyr.io/websocket"
)

// chatServer enables broadcasting to a set of subscribers.
type chatServer struct {
registerOnce sync.Once
m http.ServeMux

subscribersMu sync.RWMutex
// subscriberMessageBuffer controls the max number
// of messages that can be queued for a subscriber
// before it is kicked.
//
// Defaults to 16.
subscriberMessageBuffer int

// publishLimiter controls the rate limit applied to the publish endpoint.
//
// Defaults to one publish every 100ms with a burst of 8.
publishLimiter *rate.Limiter

// logf controls where logs are sent.
// Defaults to log.Printf.
logf func(f string, v ...interface{})

// serveMux routes the various endpoints to the appropriate handler.
serveMux http.ServeMux

subscribersMu sync.Mutex
subscribers map[*subscriber]struct{}
}

// newChatServer constructs a chatServer with the defaults.
func newChatServer() *chatServer {
cs := &chatServer{
subscriberMessageBuffer: 16,
logf: log.Printf,
subscribers: make(map[*subscriber]struct{}),
publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8),
}
cs.serveMux.Handle("/", http.FileServer(http.Dir(".")))
cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler)
cs.serveMux.HandleFunc("/publish", cs.publishHandler)

return cs
}

// subscriber represents a subscriber.
// Messages are sent on the msgs channel and if the client
// cannot keep up with the messages, closeSlow is called.
Expand All @@ -31,20 +63,15 @@ type subscriber struct {
}

func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
cs.registerOnce.Do(func() {
cs.m.Handle("/", http.FileServer(http.Dir(".")))
cs.m.HandleFunc("/subscribe", cs.subscribeHandler)
cs.m.HandleFunc("/publish", cs.publishHandler)
})
cs.m.ServeHTTP(w, r)
cs.serveMux.ServeHTTP(w, r)
}

// subscribeHandler accepts the WebSocket connection and then subscribes
// it to all future messages.
func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, nil)
if err != nil {
log.Print(err)
cs.logf("%v", err)
return
}
defer c.Close(websocket.StatusInternalError, "")
Expand All @@ -58,7 +85,8 @@ func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
return
}
if err != nil {
log.Print(err)
cs.logf("%v", err)
return
}
}

Expand All @@ -69,7 +97,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
body := io.LimitReader(r.Body, 8192)
body := http.MaxBytesReader(w, r.Body, 8192)
msg, err := ioutil.ReadAll(body)
if err != nil {
http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge)
Expand All @@ -93,7 +121,7 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
ctx = c.CloseRead(ctx)

s := &subscriber{
msgs: make(chan []byte, 16),
msgs: make(chan []byte, cs.subscriberMessageBuffer),
closeSlow: func() {
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
},
Expand All @@ -118,8 +146,10 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
// It never blocks and so messages to slow subscribers
// are dropped.
func (cs *chatServer) publish(msg []byte) {
cs.subscribersMu.RLock()
defer cs.subscribersMu.RUnlock()
cs.subscribersMu.Lock()
defer cs.subscribersMu.Unlock()

cs.publishLimiter.Wait(context.Background())

for s := range cs.subscribers {
select {
Expand All @@ -133,9 +163,6 @@ func (cs *chatServer) publish(msg []byte) {
// addSubscriber registers a subscriber.
func (cs *chatServer) addSubscriber(s *subscriber) {
cs.subscribersMu.Lock()
if cs.subscribers == nil {
cs.subscribers = make(map[*subscriber]struct{})
}
cs.subscribers[s] = struct{}{}
cs.subscribersMu.Unlock()
}
Expand Down
Loading

0 comments on commit 91b7f61

Please sign in to comment.