Skip to content

Commit d94d846

Browse files
committed
chore: fixing race
Signed-off-by: Danny Kopping <danny@coder.com>
1 parent 7e07b1d commit d94d846

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

intercept_anthropic_messages_streaming.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ newStream:
414414
prompt = nil
415415
}
416416

417-
if events.hasInitiated() {
417+
if events.isStreaming() {
418418
// Check if the stream encountered any errors.
419419
if streamErr := stream.Err(); streamErr != nil {
420420
if isUnrecoverableError(streamErr) {

intercept_openai_chat_streaming.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
172172
})
173173
}
174174

175-
if events.hasInitiated() {
175+
if events.isStreaming() {
176176
// Check if the stream encountered any errors.
177177
if streamErr := stream.Err(); streamErr != nil {
178178
if isUnrecoverableError(streamErr) {

streaming.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ type eventStream struct {
2828

2929
pingPayload []byte
3030

31-
initiated atomic.Bool
32-
31+
initiated atomic.Bool
3332
initiateOnce sync.Once
33+
3434
closeOnce sync.Once
3535
shutdownOnce sync.Once
3636
eventsCh chan event
@@ -79,14 +79,22 @@ func (s *eventStream) start(w http.ResponseWriter, r *http.Request) {
7979
return
8080
case ev, open = <-s.eventsCh: // Once closed, the buffered channel will drain all buffered values before showing as closed.
8181
if !open {
82+
s.logger.Debug(ctx, "events channel closed")
8283
return
8384
}
8485

8586
// Initiate the stream once the first event is received.
8687
s.initiateOnce.Do(func() {
8788
s.initiated.Store(true)
89+
s.logger.Debug(ctx, "stream initiated")
8890

8991
// Send headers for Server-Sent Event stream.
92+
//
93+
// We only send these once an event is processed because an error can occur in the upstream
94+
// request prior to the stream starting, in which case the SSE headers are inappropriate to
95+
// send to the client.
96+
//
97+
// See use of isStreaming().
9098
w.Header().Set("Content-Type", "text/event-stream")
9199
w.Header().Set("Cache-Control", "no-cache")
92100
w.Header().Set("Connection", "keep-alive")
@@ -185,8 +193,10 @@ func (s *eventStream) Shutdown(shutdownCtx context.Context) error {
185193
return err
186194
}
187195

188-
func (s *eventStream) hasInitiated() bool {
189-
return s.initiated.Load()
196+
// isStreaming checks if the stream has been initiated, or
197+
// when events are buffered which - when processed - will initiate the stream.
198+
func (s *eventStream) isStreaming() bool {
199+
return s.initiated.Load() || len(s.eventsCh) > 0
190200
}
191201

192202
// isConnError checks if an error is related to client disconnection or context cancellation.

0 commit comments

Comments
 (0)