diff --git a/context.go b/context.go index baa4b0f9c9..e61bb73705 100644 --- a/context.go +++ b/context.go @@ -1137,10 +1137,9 @@ func (c *Context) SSEvent(name string, message any) { // indicates "Is client disconnected in middle of stream" func (c *Context) Stream(step func(w io.Writer) bool) bool { w := c.Writer - clientGone := w.CloseNotify() for { select { - case <-clientGone: + case <-c.Request.Context().Done(): return true default: keepOpen := step(w) diff --git a/context_test.go b/context_test.go index 8bbf270086..36522fb514 100644 --- a/context_test.go +++ b/context_test.go @@ -2517,10 +2517,6 @@ func (r *TestResponseRecorder) CloseNotify() <-chan bool { return r.closeChannel } -func (r *TestResponseRecorder) closeClient() { - r.closeChannel <- true -} - func CreateTestResponseRecorder() *TestResponseRecorder { return &TestResponseRecorder{ httptest.NewRecorder(), @@ -2531,6 +2527,7 @@ func CreateTestResponseRecorder() *TestResponseRecorder { func TestContextStream(t *testing.T) { w := CreateTestResponseRecorder() c, _ := CreateTestContext(w) + c.Request, _ = http.NewRequest(http.MethodGet, "", nil) stopStream := true c.Stream(func(w io.Writer) bool { @@ -2550,10 +2547,12 @@ func TestContextStream(t *testing.T) { func TestContextStreamWithClientGone(t *testing.T) { w := CreateTestResponseRecorder() c, _ := CreateTestContext(w) + done, cancel := context.WithCancel(context.Background()) + c.Request, _ = http.NewRequestWithContext(done, http.MethodGet, "", nil) c.Stream(func(writer io.Writer) bool { defer func() { - w.closeClient() + cancel() }() _, err := writer.Write([]byte("test"))