Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions client/transport/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,6 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl
// Create a channel for this specific request
responseChan := make(chan *JSONRPCResponse, 1)

// Add timeout context for request processing if not already set
if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > 30*time.Second {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 30*time.Second)
defer cancel()
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()

Expand Down Expand Up @@ -601,8 +594,7 @@ func (c *StreamableHTTP) IsOAuthEnabled() bool {
func (c *StreamableHTTP) listenForever(ctx context.Context) {
c.logger.Infof("listening to server forever")
for {
// Add timeout for individual connection attempts
connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
connectCtx, cancel := context.WithCancel(ctx)
err := c.createGETConnectionToServer(connectCtx)
cancel()

Expand Down Expand Up @@ -771,7 +763,7 @@ func (c *StreamableHTTP) sendResponseToServer(ctx context.Context, response *JSO
ctx, cancel := c.contextAwareOfClientClose(ctx)
defer cancel()

resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json")
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json, text/event-stream")
if err != nil {
c.logger.Errorf("failed to send response to server: %v", err)
return
Expand Down
69 changes: 68 additions & 1 deletion client/transport/streamable_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,23 @@ func startMockStreamableWithGETSupport(getSupport bool) (string, func(), chan bo
return
}

// Handle client JSON-RPC responses (e.g., ping replies)
if request["jsonrpc"] == "2.0" && request["id"] != nil && request["method"] == nil {
if _, hasResult := request["result"]; hasResult {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusAccepted)
if err := json.NewEncoder(w).Encode(map[string]any{
"jsonrpc": "2.0",
"id": request["id"],
"result": "response received",
}); err != nil {
http.Error(w, "Failed to encode response acknowledgment", http.StatusInternalServerError)
return
}
return
}
}

method := request["method"]
if method == "initialize" {
// Generate a new session ID
Expand Down Expand Up @@ -627,7 +644,31 @@ func startMockStreamableWithGETSupport(getSupport bool) (string, func(), chan bo
fmt.Fprintf(w, "event: message\ndata: %s\n\n", notificationData)
flusher.Flush()
sendNotification()
return
}

// Keep the connection open, send periodic pings
pingTicker := time.NewTicker(3 * time.Second)
defer pingTicker.Stop()

for {
select {
case <-disconnectCh:
// Force disconnect
return
case <-r.Context().Done():
// Client disconnected
return
case <-pingTicker.C:
// Send ping message according to MCP specification
pingMessage := map[string]any{
"jsonrpc": "2.0",
"id": fmt.Sprintf("ping-%d", time.Now().UnixNano()),
"method": "ping",
}
pingData, _ := json.Marshal(pingMessage)
fmt.Fprintf(w, "event: message\ndata: %s\n\n", pingData)
flusher.Flush()
}
}
} else {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
Expand Down Expand Up @@ -667,6 +708,23 @@ func TestContinuousListening(t *testing.T) {
notificationReceived <- struct{}{}
})

// Setup ping handler
pingReceived := make(chan struct{}, 10)

// Setup request handler for ping requests
trans.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) {
if request.Method == "ping" {
pingReceived <- struct{}{}
// Return proper ping response according to MCP specification
return &JSONRPCResponse{
JSONRPC: "2.0",
ID: request.ID,
Result: json.RawMessage("{}"),
}, nil
}
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
})

// Start the transport - this will launch listenForever in a goroutine
if err := trans.Start(context.Background()); err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -714,6 +772,15 @@ func TestContinuousListening(t *testing.T) {
return
}
}

// Wait for at least one ping to be received (should happen within 3 seconds)
select {
case <-pingReceived:
t.Log("Received ping message successfully")
time.Sleep(10 * time.Millisecond) // Allow time for response
case <-time.After(5 * time.Second):
t.Errorf("Expected to receive ping message within 5 seconds, but didn't")
}
}

func TestContinuousListeningMethodNotAllowed(t *testing.T) {
Expand Down