Skip to content

Commit

Permalink
refactor: leverage Go 1.20's http.ResponseController
Browse files Browse the repository at this point in the history
  • Loading branch information
dunglas committed Aug 3, 2023
1 parent 37326a0 commit 29c40fe
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 89 deletions.
4 changes: 3 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
"MERCURE_PUBLISHER_JWT_KEY": "!ChangeThisMercureHubJWTSecretKey!",
"MERCURE_SUBSCRIBER_JWT_KEY": "!ChangeThisMercureHubJWTSecretKey!",
"MERCURE_EXTRA_DIRECTIVES": "anonymous",
"GLOBAL_OPTIONS": "debug"
"GLOBAL_OPTIONS": "debug",
"SERVER_NAME": "localhost, host.docker.internal",
"EXTRA_DIRECTIVES": "tls internal"
},
"args": ["run", "--config", "../../Caddyfile.dev"]
}
Expand Down
2 changes: 1 addition & 1 deletion hub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func TestStop(t *testing.T) {
defer wg.Done()
req := httptest.NewRequest(http.MethodGet, defaultHubURL+"?topic=http://example.com/foo", nil)

w := httptest.NewRecorder()
w := newSubscribeRecorder()
hub.SubscribeHandler(w, req)

r := w.Result()
Expand Down
179 changes: 102 additions & 77 deletions subscribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,98 @@ package mercure

import (
"encoding/json"
"fmt"
"io"
"errors"
"net/http"
"time"

"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)

type responseController struct {
http.ResponseController
rw http.ResponseWriter
end time.Time
hub *Hub
subscriber *Subscriber
}

func (rc *responseController) setDispatchWriteDeadline() bool {
if rc.hub.dispatchTimeout == 0 {
return true
}

deadline := time.Now().Add(rc.hub.dispatchTimeout)
if deadline.After(rc.end) {
return true
}

if err := rc.SetWriteDeadline(deadline); err != nil {
if c := rc.hub.logger.Check(zap.ErrorLevel, "Unable to set dispatch write deadline"); c != nil {
c.Write(zap.Object("subscriber", rc.subscriber), zap.Error(err))
}

return false
}

return true
}

func (rc *responseController) setDefaultWriteDeadline() bool {
if err := rc.SetWriteDeadline(rc.end); err != nil {
if errors.Is(err, http.ErrNotSupported) {
panic(err)
}

if c := rc.hub.logger.Check(zap.InfoLevel, "Error while setting default write deadline"); c != nil {
c.Write(zap.Object("subscriber", rc.subscriber), zap.Error(err))
}

return false
}

return true
}

func (rc *responseController) flush() bool {
if err := rc.Flush(); err != nil {
if errors.Is(err, http.ErrNotSupported) {
panic(err)
}

if c := rc.hub.logger.Check(zap.InfoLevel, "Unable to flush"); c != nil {
c.Write(zap.Object("subscriber", rc.subscriber), zap.Error(err))
}

return false
}

return true
}

func newResponseController(w http.ResponseWriter, h *Hub, s *Subscriber) *responseController {
var end time.Time
if h.writeTimeout != 0 {
end = time.Now().Add(h.writeTimeout)
}

if s.Claims != nil && s.Claims.ExpiresAt != nil && (end == time.Time{} || s.Claims.ExpiresAt.Time.Before(end)) {
end = s.Claims.ExpiresAt.Time
}

return &responseController{*http.NewResponseController(w), w, end, h, s} // nolint:bodyclose
}

// SubscribeHandler creates a keep alive connection and sends the events to the subscribers.
//
//nolint:funlen,gocognit
func (h *Hub) SubscribeHandler(w http.ResponseWriter, r *http.Request) {
assertFlusher(w)

s := h.registerSubscriber(w, r)
s, rc := h.registerSubscriber(w, r)
if s == nil {
return
}
defer h.shutdown(s)

var expireTimer *time.Timer
var expireTimerC <-chan time.Time
if s.Claims != nil && s.Claims.ExpiresAt != nil {
expireTimer = time.NewTimer(time.Until(s.Claims.ExpiresAt.Time))
defer expireTimer.Stop()
expireTimerC = expireTimer.C
}
rc.setDefaultWriteDeadline()

var heartbeatTimer *time.Timer
var heartbeatTimerC <-chan time.Time
Expand All @@ -39,42 +103,22 @@ func (h *Hub) SubscribeHandler(w http.ResponseWriter, r *http.Request) {
heartbeatTimerC = heartbeatTimer.C
}

var writeTimer *time.Timer
var writeTimerC <-chan time.Time
if h.writeTimeout != 0 {
writeTimer = time.NewTimer(h.writeTimeout - h.dispatchTimeout)
defer writeTimer.Stop()
writeTimerC = writeTimer.C
}

for {
select {
case <-r.Context().Done():
if c := h.logger.Check(zap.DebugLevel, "connection closed by the client"); c != nil {
c.Write(zap.Object("subscriber", s))
}

return
case <-expireTimerC:
if c := h.logger.Check(zap.DebugLevel, "JWT expired: close the connection"); c != nil {
c.Write(zap.Object("subscriber", s))
}

return
case <-writeTimerC:
if c := h.logger.Check(zap.DebugLevel, "write timeout: close the connection"); c != nil {
if c := h.logger.Check(zap.DebugLevel, "Connection closed by the client"); c != nil {
c.Write(zap.Object("subscriber", s))
}

return
case <-heartbeatTimerC:
// Send a SSE comment as a heartbeat, to prevent issues with some proxies and old browsers
if !h.write(w, s, ":\n") {
if !h.write(rc, ":\n") {
return
}
heartbeatTimer.Reset(h.heartbeat)
case update, ok := <-s.Receive():
if !ok || !h.write(w, s, newSerializedUpdate(update).event) {
if !ok || !h.write(rc, newSerializedUpdate(update).event) {
return
}
if heartbeatTimer != nil {
Expand All @@ -91,7 +135,7 @@ func (h *Hub) SubscribeHandler(w http.ResponseWriter, r *http.Request) {
}

// registerSubscriber initializes the connection.
func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request) *Subscriber {
func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request) (*Subscriber, *responseController) {
s := NewSubscriber(retrieveLastEventID(r, h.opt, h.logger), h.logger)
s.Debug = h.debug
s.RemoteAddr = r.RemoteAddr
Expand All @@ -111,15 +155,15 @@ func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request) *Subscr
c.Write(zap.Object("subscriber", s), zap.Error(err))
}

return nil
return nil, nil
}
}

topics := r.URL.Query()["topic"]
if len(topics) == 0 {
http.Error(w, "Missing \"topic\" parameter.", http.StatusBadRequest)

return nil
return nil, nil
}
s.SetTopics(topics, privateTopics)

Expand All @@ -131,10 +175,10 @@ func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request) *Subscr
c.Write(zap.Object("subscriber", s), zap.Error(err))
}

return nil
return nil, nil
}

sendHeaders(w, s)
rc := h.sendHeaders(w, s)

if c := h.logger.Check(zap.InfoLevel, "New subscriber"); c != nil {
fields := []LogField{zap.Object("subscriber", s)}
Expand All @@ -146,11 +190,11 @@ func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request) *Subscr
}
h.metrics.SubscriberConnected(s)

return s
return s, rc
}

// sendHeaders sends correct HTTP headers to create a keep-alive connection.
func sendHeaders(w http.ResponseWriter, s *Subscriber) {
func (h *Hub) sendHeaders(w http.ResponseWriter, s *Subscriber) *responseController {
// Keep alive, useful only for HTTP 1 clients https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Keep-Alive
w.Header().Set("Connection", "keep-alive")

Expand All @@ -171,8 +215,12 @@ func sendHeaders(w http.ResponseWriter, s *Subscriber) {

// Write a comment in the body
// Go currently doesn't provide a better way to flush the headers
fmt.Fprint(w, ":\n")
w.(http.Flusher).Flush()
w.Write([]byte{':', '\n'})

rc := newResponseController(w, h, s)
rc.flush()

return rc
}

// retrieveLastEventID extracts the Last-Event-ID from the corresponding HTTP header with a fallback on the query parameter.
Expand Down Expand Up @@ -202,38 +250,21 @@ func retrieveLastEventID(r *http.Request, opt *opt, logger Logger) string {
}

// Write sends the given string to the client.
// It returns false if the dispatch timed out.
// The current write cannot be cancelled because of https://github.com/golang/go/issues/16100
func (h *Hub) write(w io.Writer, s zapcore.ObjectMarshaler, data string) bool {
if h.dispatchTimeout == 0 {
fmt.Fprint(w, data)
w.(http.Flusher).Flush()

return true
// It returns false if the subscriber has been disconnected (e.g. timeout).
func (h *Hub) write(rc *responseController, data string) bool {
if !rc.setDispatchWriteDeadline() {
return false
}

done := make(chan struct{})
go func() {
fmt.Fprint(w, data)
w.(http.Flusher).Flush()
close(done)
}()

timeout := time.NewTimer(h.dispatchTimeout)
defer timeout.Stop()
select {
case <-done:
return true
case <-timeout.C:
if c := h.logger.Check(zap.WarnLevel, "Dispatch timeout reached"); c != nil {
c.Write(zap.Object("subscriber", s))
if _, err := rc.rw.Write([]byte(data)); err != nil {
if c := h.logger.Check(zap.DebugLevel, "Error writing to client"); c != nil {
c.Write(zap.Object("subscriber", rc.subscriber), zap.Error(err))
}

// wait for the dispatch goroutine to finish
<-done

return false
}

return rc.flush() && rc.setDefaultWriteDeadline()
}

func (h *Hub) shutdown(s *Subscriber) {
Expand Down Expand Up @@ -267,9 +298,3 @@ func (h *Hub) dispatchSubscriptionUpdate(s *Subscriber, active bool) {
h.transport.Dispatch(u)
}
}

func assertFlusher(w http.ResponseWriter) {
if _, ok := w.(http.Flusher); !ok {
panic("http.ResponseWriter must be an instance of http.Flusher")
}
}
Loading

0 comments on commit 29c40fe

Please sign in to comment.