diff --git a/httphelpers/handlers_sse.go b/httphelpers/handlers_sse.go index 9256219..d33a943 100644 --- a/httphelpers/handlers_sse.go +++ b/httphelpers/handlers_sse.go @@ -106,6 +106,23 @@ func SSEHandler(initialEvent *SSEEvent) (http.Handler, SSEStreamControl) { return handler, &sseStreamControlImpl{streamControl} } +// SSEHandlerWithEnvironmentID creates an HTTP handler that streams Server-Sent Events data. +// +// The behavior is exactly the same as SSEHandler except environmentID will be returned in +// the response header X-Ld-Envid. +func SSEHandlerWithEnvironmentID(initialEvent *SSEEvent, environmentID string) (http.Handler, SSEStreamControl) { + var initialData []byte + if initialEvent != nil { + initialData = initialEvent.Bytes() + } + handler, streamControl := ChunkedStreamingHandler( + initialData, + "text/event-stream; charset=utf-8", + ChunkedStreamingHandlerOptionEnvironmentID(environmentID), + ) + return handler, &sseStreamControlImpl{streamControl} +} + func (s *sseStreamControlImpl) Enqueue(event SSEEvent) { s.streamControl.Enqueue(event.Bytes()) } diff --git a/httphelpers/handlers_sse_test.go b/httphelpers/handlers_sse_test.go index 18bda27..a25a553 100644 --- a/httphelpers/handlers_sse_test.go +++ b/httphelpers/handlers_sse_test.go @@ -50,3 +50,29 @@ data: data3 `, string(data)) }) } + +func TestSSEHandlerWithEnvironmentID(t *testing.T) { + initialEvent := SSEEvent{"id1", "event1", "data1", 0} + handler, stream := SSEHandlerWithEnvironmentID(&initialEvent, "env-id") + defer stream.Close() + + WithServer(handler, func(server *httptest.Server) { + resp1, err := http.DefaultClient.Get(server.URL) + require.NoError(t, err) + defer resp1.Body.Close() + + assert.Equal(t, 200, resp1.StatusCode) + assert.Equal(t, "text/event-stream; charset=utf-8", resp1.Header.Get("Content-Type")) + assert.Equal(t, "env-id", resp1.Header.Get("X-Ld-Envid")) + + stream.EndAll() + + data, err := io.ReadAll(resp1.Body) + assert.NoError(t, err) + assert.Equal(t, `id: id1 +event: event1 +data: data1 + +`, string(data)) + }) +} diff --git a/httphelpers/handlers_streaming.go b/httphelpers/handlers_streaming.go index 10acd8e..71d7a33 100644 --- a/httphelpers/handlers_streaming.go +++ b/httphelpers/handlers_streaming.go @@ -25,6 +25,25 @@ type StreamControl interface { Close() error } +// ChunkedStreamingHandlerOption is a common interface for optional configuration parameters that +// can be used in creating a ChunkedStreamingHandler. +type ChunkedStreamingHandlerOption interface { + apply(h *chunkedStreamingHandlerImpl) +} + +type environmentIDChunkedStreamingHandlerOption string + +func (o environmentIDChunkedStreamingHandlerOption) apply(h *chunkedStreamingHandlerImpl) { + h.environmentID = string(o) +} + +// ChunkedStreamingHandlerOptionEnvironmentID returns an option that sets the environment ID +// for a ChunkedStreamingHandler when the handler is created. The environment ID will be +// returned in the response header X-Ld-Envid. +func ChunkedStreamingHandlerOptionEnvironmentID(environmentID string) ChunkedStreamingHandlerOption { + return environmentIDChunkedStreamingHandlerOption(environmentID) +} + // ChunkedStreamingHandler creates an HTTP handler that streams arbitrary data using chunked encoding. // // The initialData parameter, if not nil, specifies a starting chunk that should always be sent to any @@ -52,21 +71,29 @@ type StreamControl interface { // } // } // }() -func ChunkedStreamingHandler(initialChunk []byte, contentType string) (http.Handler, StreamControl) { +func ChunkedStreamingHandler( + initialChunk []byte, + contentType string, + options ...ChunkedStreamingHandlerOption, +) (http.Handler, StreamControl) { sh := &chunkedStreamingHandlerImpl{ initialChunk: initialChunk, contentType: contentType, } + for _, o := range options { + o.apply(sh) + } return sh, sh } type chunkedStreamingHandlerImpl struct { - initialChunk []byte - contentType string - queued [][]byte - channels []chan []byte - closed bool - lock sync.Mutex + initialChunk []byte + contentType string + queued [][]byte + channels []chan []byte + closed bool + lock sync.Mutex + environmentID string } func (s *chunkedStreamingHandlerImpl) Enqueue(data []byte) { @@ -173,6 +200,9 @@ func (s *chunkedStreamingHandlerImpl) ServeHTTP(w http.ResponseWriter, r *http.R h := w.Header() h.Set("Content-Type", s.contentType) h.Set("Cache-Control", "no-cache, no-store, must-revalidate") + if len(s.environmentID) > 0 { + h.Set("X-Ld-Envid", s.environmentID) + } if s.initialChunk != nil { _, _ = w.Write(s.initialChunk) diff --git a/httphelpers/handlers_streaming_test.go b/httphelpers/handlers_streaming_test.go index bd8243c..59f6f68 100644 --- a/httphelpers/handlers_streaming_test.go +++ b/httphelpers/handlers_streaming_test.go @@ -123,3 +123,29 @@ func TestChunkedStreamingHandlerClose(t *testing.T) { assert.Equal(t, 500, resp2.StatusCode) }) } + +func TestChunkedStreamingHandlerWithEnvironmentID(t *testing.T) { + initialData := []byte("hello") + handler, stream := ChunkedStreamingHandler( + initialData, + "text/plain", + ChunkedStreamingHandlerOptionEnvironmentID("env-id"), + ) + defer stream.Close() + + WithServer(handler, func(server *httptest.Server) { + resp1, err := http.DefaultClient.Get(server.URL) + require.NoError(t, err) + defer resp1.Body.Close() + + assert.Equal(t, 200, resp1.StatusCode) + assert.Equal(t, "text/plain", resp1.Header.Get("Content-Type")) + assert.Equal(t, "env-id", resp1.Header.Get("X-Ld-Envid")) + + stream.EndAll() + + data, err := io.ReadAll(resp1.Body) + require.NoError(t, err) + assert.Equal(t, "hello", string(data)) + }) +}