diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 3bdeb65e2..8d78fb95a 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -308,14 +308,52 @@ describe('SSEClientTransport', () => { await transport.start(); - // Store original fetch const originalFetch = global.fetch; + try { + global.fetch = vi.fn().mockResolvedValue({ ok: true }); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + id: '1', + method: 'test', + params: {} + }; + + await transport.send(message); + + const calledHeaders = (global.fetch as Mock).mock.calls[0][1].headers; + expect(calledHeaders.get('Authorization')).toBe('Bearer test-token'); + expect(calledHeaders.get('X-Custom-Header')).toBe('custom-value'); + expect(calledHeaders.get('content-type')).toBe('application/json'); + + customHeaders['X-Custom-Header'] = 'updated-value'; + + await transport.send(message); + + const updatedHeaders = (global.fetch as Mock).mock.calls[1][1].headers; + expect(updatedHeaders.get('X-Custom-Header')).toBe('updated-value'); + } finally { + global.fetch = originalFetch; + } + }); + + it('passes custom headers to fetch requests (Headers class)', async () => { + const customHeaders = new Headers({ + Authorization: 'Bearer test-token', + 'X-Custom-Header': 'custom-value' + }); + + transport = new SSEClientTransport(resourceBaseUrl, { + requestInit: { + headers: customHeaders + } + }); + + await transport.start(); + const originalFetch = global.fetch; try { - // Mock fetch for the message sending test - global.fetch = vi.fn().mockResolvedValue({ - ok: true - }); + global.fetch = vi.fn().mockResolvedValue({ ok: true }); const message: JSONRPCMessage = { jsonrpc: '2.0', @@ -326,20 +364,45 @@ describe('SSEClientTransport', () => { await transport.send(message); - // Verify fetch was called with correct headers - expect(global.fetch).toHaveBeenCalledWith( - expect.any(URL), - expect.objectContaining({ - headers: expect.any(Headers) - }) - ); + const calledHeaders = (global.fetch as Mock).mock.calls[0][1].headers; + expect(calledHeaders.get('Authorization')).toBe('Bearer test-token'); + expect(calledHeaders.get('X-Custom-Header')).toBe('custom-value'); + expect(calledHeaders.get('content-type')).toBe('application/json'); + + customHeaders.set('X-Custom-Header', 'updated-value'); + + await transport.send(message); + + const updatedHeaders = (global.fetch as Mock).mock.calls[1][1].headers; + expect(updatedHeaders.get('X-Custom-Header')).toBe('updated-value'); + } finally { + global.fetch = originalFetch; + } + }); + + it('passes custom headers to fetch requests (array of tuples)', async () => { + transport = new SSEClientTransport(resourceBaseUrl, { + requestInit: { + headers: [ + ['Authorization', 'Bearer test-token'], + ['X-Custom-Header', 'custom-value'] + ] + } + }); + + await transport.start(); + + const originalFetch = global.fetch; + try { + global.fetch = vi.fn().mockResolvedValue({ ok: true }); + + await transport.send({ jsonrpc: '2.0', id: '1', method: 'test', params: {} }); const calledHeaders = (global.fetch as Mock).mock.calls[0][1].headers; - expect(calledHeaders.get('Authorization')).toBe(customHeaders.Authorization); - expect(calledHeaders.get('X-Custom-Header')).toBe(customHeaders['X-Custom-Header']); + expect(calledHeaders.get('Authorization')).toBe('Bearer test-token'); + expect(calledHeaders.get('X-Custom-Header')).toBe('custom-value'); expect(calledHeaders.get('content-type')).toBe('application/json'); } finally { - // Restore original fetch global.fetch = originalFetch; } }); diff --git a/src/client/sse.ts b/src/client/sse.ts index 94eafe1b1..2b0661958 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -1,5 +1,5 @@ import { EventSource, type ErrorEvent, type EventSourceInit } from 'eventsource'; -import { Transport, FetchLike, createFetchWithInit } from '../shared/transport.js'; +import { Transport, FetchLike, createFetchWithInit, normalizeHeaders } from '../shared/transport.js'; import { JSONRPCMessage, JSONRPCMessageSchema } from '../types.js'; import { auth, AuthResult, extractWWWAuthenticateParams, OAuthClientProvider, UnauthorizedError } from './auth.js'; @@ -114,7 +114,7 @@ export class SSEClientTransport implements Transport { } private async _commonHeaders(): Promise { - const headers: HeadersInit = {}; + const headers: HeadersInit & Record = {}; if (this._authProvider) { const tokens = await this._authProvider.tokens(); if (tokens) { @@ -125,7 +125,12 @@ export class SSEClientTransport implements Transport { headers['mcp-protocol-version'] = this._protocolVersion; } - return new Headers({ ...headers, ...this._requestInit?.headers }); + const extraHeaders = normalizeHeaders(this._requestInit?.headers); + + return new Headers({ + ...headers, + ...extraHeaders + }); } private _startOrAuth(): Promise { diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index 596ad2310..db836d127 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -480,6 +480,7 @@ describe('StreamableHTTPClientTransport', () => { it('should always send specified custom headers', async () => { const requestInit = { headers: { + Authorization: 'Bearer test-token', 'X-Custom-Header': 'CustomValue' } }; @@ -497,6 +498,7 @@ describe('StreamableHTTPClientTransport', () => { await transport.start(); await transport['_startOrAuthSse']({}); + expect((actualReqInit.headers as Headers).get('authorization')).toBe('Bearer test-token'); expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('CustomValue'); requestInit.headers['X-Custom-Header'] = 'SecondCustomValue'; @@ -510,6 +512,7 @@ describe('StreamableHTTPClientTransport', () => { it('should always send specified custom headers (Headers class)', async () => { const requestInit = { headers: new Headers({ + Authorization: 'Bearer test-token', 'X-Custom-Header': 'CustomValue' }) }; @@ -527,6 +530,7 @@ describe('StreamableHTTPClientTransport', () => { await transport.start(); await transport['_startOrAuthSse']({}); + expect((actualReqInit.headers as Headers).get('authorization')).toBe('Bearer test-token'); expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('CustomValue'); (requestInit.headers as Headers).set('X-Custom-Header', 'SecondCustomValue'); @@ -537,6 +541,30 @@ describe('StreamableHTTPClientTransport', () => { expect(global.fetch).toHaveBeenCalledTimes(2); }); + it('should always send specified custom headers (array of tuples)', async () => { + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + requestInit: { + headers: [ + ['Authorization', 'Bearer test-token'], + ['X-Custom-Header', 'CustomValue'] + ] + } + }); + + let actualReqInit: RequestInit = {}; + + (global.fetch as Mock).mockImplementation(async (_url, reqInit) => { + actualReqInit = reqInit; + return new Response(null, { status: 200, headers: { 'content-type': 'text/event-stream' } }); + }); + + await transport.start(); + + await transport['_startOrAuthSse']({}); + expect((actualReqInit.headers as Headers).get('authorization')).toBe('Bearer test-token'); + expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('CustomValue'); + }); + it('should have exponential backoff with configurable maxRetries', () => { // This test verifies the maxRetries and backoff calculation directly