Skip to content

Commit

Permalink
Fix writeEvents race condition.
Browse files Browse the repository at this point in the history
This required removing the compress middleware from the /events route.
  • Loading branch information
deluan committed Apr 2, 2023
1 parent 83ae2ba commit 1c7fb74
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 195 deletions.
6 changes: 3 additions & 3 deletions cmd/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion reflex.conf
@@ -1 +1 @@
-s -r "(\.go$$|\.cpp$$|\.h$$|navidrome.toml|resources|token_received.html)" -R "(^ui|^data|^db/migration)" -- go run -tags netgo .
-s -r "(\.go$$|\.cpp$$|\.h$$|navidrome.toml|resources|token_received.html)" -R "(^ui|^data|^db/migration)" -- go run -race -tags netgo .
60 changes: 28 additions & 32 deletions server/events/sse.go
Expand Up @@ -3,7 +3,6 @@ package events

import (
"context"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -93,38 +92,35 @@ func (b *broker) prepareMessage(ctx context.Context, event Event) message {
return msg
}

var errWriteTimeOut = errors.New("write timeout")

// writeEvent writes a message to the given io.Writer, formatted as a Server-Sent Event.
// If the writer is an http.Flusher, it flushes the data immediately instead of buffering it.
// The function waits for the message to be written or times out after the specified timeout.
func writeEvent(w io.Writer, event message, timeout time.Duration) error {
// Create a context with a timeout based on the event's sender context.
ctx, cancel := context.WithTimeout(event.senderCtx, timeout)
defer cancel()

// Create a channel to signal the completion of writing.
errC := make(chan error, 1)

// Start a goroutine to write the event and optionally flush the writer.
go func() {
_, err := fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data)

// If the writer is an http.Flusher, flush the data immediately.
if flusher, ok := w.(http.Flusher); ok && flusher != nil {
flusher.Flush()
}

// Signal that writing is complete.
errC <- err
}()
func writeEvent(ctx context.Context, w io.Writer, event message, timeout time.Duration) error {
if err := setWriteTimeout(w, timeout); err != nil {
log.Debug(ctx, "Error setting write timeout", err)
}

// Wait for either the write completion or the context to time out.
select {
case err := <-errC:
_, err := fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data)
if err != nil {
return err
case <-ctx.Done():
return errWriteTimeOut
}

// If the writer is an http.Flusher, flush the data immediately.
if flusher, ok := w.(http.Flusher); ok && flusher != nil {
flusher.Flush()
}
return nil
}

func setWriteTimeout(rw io.Writer, timeout time.Duration) error {
for {
switch t := rw.(type) {
case interface{ SetWriteDeadline(time.Time) error }:
return t.SetWriteDeadline(time.Now().Add(timeout))
case interface{ Unwrap() http.ResponseWriter }:
rw = t.Unwrap()
default:
return fmt.Errorf("%T - %w", rw, http.ErrNotSupported)
}
}
}

Expand Down Expand Up @@ -160,9 +156,9 @@ func (b *broker) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
log.Trace(ctx, "Sending event to client", "event", *event, "client", c.String())
if err := writeEvent(w, *event, writeTimeOut); errors.Is(err, errWriteTimeOut) {
log.Debug(ctx, "Timeout sending event to client", "event", *event, "client", c.String())
return
err := writeEvent(ctx, w, *event, writeTimeOut)
if err != nil {
log.Debug(ctx, "Error sending event to client", "event", *event, "client", c.String(), err)
}
}
}
Expand Down
127 changes: 0 additions & 127 deletions server/events/sse_test.go
@@ -1,12 +1,7 @@
package events

import (
"bytes"
"context"
"fmt"
"io"
"sync/atomic"
"time"

"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
Expand Down Expand Up @@ -63,126 +58,4 @@ var _ = Describe("Broker", func() {
})
})
})

Describe("writeEvent", func() {
var (
timeout time.Duration
buffer *bytes.Buffer
event message
senderCtx context.Context
cancel context.CancelFunc
)

BeforeEach(func() {
buffer = &bytes.Buffer{}
senderCtx, cancel = context.WithCancel(context.Background())
DeferCleanup(cancel)
})

Context("with an HTTP flusher", func() {
var flusher *fakeFlusher

BeforeEach(func() {
flusher = &fakeFlusher{Writer: buffer}
event = message{
senderCtx: senderCtx,
id: 1,
event: "test",
data: "testdata",
}
})

Context("when the write completes before the timeout", func() {
BeforeEach(func() {
timeout = 1 * time.Second
})
It("should successfully write the event", func() {
err := writeEvent(flusher, event, timeout)
Expect(err).NotTo(HaveOccurred())
Expect(buffer.String()).To(Equal(fmt.Sprintf("id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data)))
Expect(flusher.flushed.Load()).To(BeTrue())
})
})

Context("when the write does not complete before the timeout", func() {
BeforeEach(func() {
timeout = 1 * time.Millisecond
flusher.delay = 2 * time.Second
})

It("should return an errWriteTimeOut error", func() {
err := writeEvent(flusher, event, timeout)
Expect(err).To(MatchError(errWriteTimeOut))
Expect(flusher.flushed.Load()).To(BeFalse())
})
})

Context("without an HTTP flusher", func() {
var writer *fakeWriter

BeforeEach(func() {
writer = &fakeWriter{Writer: buffer}
event = message{
senderCtx: senderCtx,
id: 1,
event: "test",
data: "testdata",
}
})

Context("when the write completes before the timeout", func() {
BeforeEach(func() {
timeout = 1 * time.Second
})

It("should successfully write the event", func() {
err := writeEvent(writer, event, timeout)
Expect(err).NotTo(HaveOccurred())
Eventually(writer.done.Load).Should(BeTrue())
Expect(buffer.String()).To(Equal(fmt.Sprintf("id: %d\nevent: %s\ndata: %s\n\n", event.id, event.event, event.data)))
})
})

Context("when the write does not complete before the timeout", func() {
BeforeEach(func() {
timeout = 1 * time.Millisecond
writer.delay = 2 * time.Second
})

It("should return an errWriteTimeOut error", func() {
err := writeEvent(writer, event, timeout)
Expect(err).To(MatchError(errWriteTimeOut))
Expect(writer.done.Load()).To(BeFalse())
})
})
})
})
})
})

type fakeWriter struct {
io.Writer
delay time.Duration
done atomic.Bool
}

func (f *fakeWriter) Write(p []byte) (n int, err error) {
time.Sleep(f.delay)
f.done.Store(true)
return f.Writer.Write(p)
}

type fakeFlusher struct {
io.Writer
delay time.Duration
flushed atomic.Bool
}

func (f *fakeFlusher) Write(p []byte) (n int, err error) {
time.Sleep(f.delay)
return f.Writer.Write(p)
}

func (f *fakeFlusher) Flush() {
f.flushed.Store(true)
}
14 changes: 4 additions & 10 deletions server/nativeapi/native_api.go
Expand Up @@ -10,18 +10,16 @@ import (
"github.com/navidrome/navidrome/core"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/server"
"github.com/navidrome/navidrome/server/events"
)

type Router struct {
http.Handler
ds model.DataStore
broker events.Broker
share core.Share
ds model.DataStore
share core.Share
}

func New(ds model.DataStore, broker events.Broker, share core.Share) *Router {
r := &Router{ds: ds, broker: broker, share: share}
func New(ds model.DataStore, share core.Share) *Router {
r := &Router{ds: ds, share: share}
r.Handler = r.routes()
return r
}
Expand Down Expand Up @@ -55,10 +53,6 @@ func (n *Router) routes() http.Handler {
r.Get("/keepalive/*", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"response":"ok", "id":"keepalive"}`))
})

if conf.Server.DevActivityPanel {
r.Handle("/events", n.broker)
}
})

return r
Expand Down

0 comments on commit 1c7fb74

Please sign in to comment.