diff --git a/.changeset/fresh-buses-auth.md b/.changeset/fresh-buses-auth.md new file mode 100644 index 0000000000..6727dd8faa --- /dev/null +++ b/.changeset/fresh-buses-auth.md @@ -0,0 +1,5 @@ +--- +'@modelcontextprotocol/client': patch +--- + +Let auth provider headers override requestInit headers in client transports. diff --git a/packages/client/src/client/sse.ts b/packages/client/src/client/sse.ts index bf554aba29..51d96e777d 100644 --- a/packages/client/src/client/sse.ts +++ b/packages/client/src/client/sse.ts @@ -120,8 +120,8 @@ export class SSEClientTransport implements Transport { const extraHeaders = normalizeHeaders(this._requestInit?.headers); return new Headers({ - ...headers, - ...extraHeaders + ...extraHeaders, + ...headers }); } diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index 3b8ddafe5a..aca6126159 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -226,8 +226,8 @@ export class StreamableHTTPClientTransport implements Transport { const extraHeaders = normalizeHeaders(this._requestInit?.headers); return new Headers({ - ...headers, - ...extraHeaders + ...extraHeaders, + ...headers }); } diff --git a/packages/client/test/client/sse.test.ts b/packages/client/test/client/sse.test.ts index 6948d9a4e0..384416ca97 100644 --- a/packages/client/test/client/sse.test.ts +++ b/packages/client/test/client/sse.test.ts @@ -344,6 +344,44 @@ describe('SSEClientTransport', () => { } }); + it('lets auth provider headers override custom Authorization headers', async () => { + const authProvider: AuthProvider = { + token: async () => 'fresh-token' + }; + + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider, + requestInit: { + headers: { + Authorization: 'Bearer stale-token', + 'X-Custom-Header': 'custom-value' + } + } + }); + + await transport.start(); + + const originalFetch = globalThis.fetch; + try { + globalThis.fetch = vi.fn().mockResolvedValue({ ok: true }); + + const message: JSONRPCMessage = { + jsonrpc: '2.0', + id: '1', + method: 'test', + params: {} + }; + + await transport.send(message); + + const calledHeaders = (globalThis.fetch as Mock).mock.calls[0]![1].headers; + expect(calledHeaders.get('Authorization')).toBe('Bearer fresh-token'); + expect(calledHeaders.get('X-Custom-Header')).toBe('custom-value'); + } finally { + globalThis.fetch = originalFetch; + } + }); + it('passes custom headers to fetch requests (Headers class)', async () => { const customHeaders = new Headers({ Authorization: 'Bearer test-token', diff --git a/packages/client/test/client/streamableHttp.test.ts b/packages/client/test/client/streamableHttp.test.ts index 0edf8b75ac..08ed5678be 100644 --- a/packages/client/test/client/streamableHttp.test.ts +++ b/packages/client/test/client/streamableHttp.test.ts @@ -2,7 +2,7 @@ import type { JSONRPCMessage, JSONRPCRequest } from '@modelcontextprotocol/core' import { OAuthError, OAuthErrorCode, SdkErrorCode, SdkHttpError } from '@modelcontextprotocol/core'; import type { Mock, Mocked } from 'vitest'; -import type { OAuthClientProvider } from '../../src/client/auth.js'; +import type { AuthProvider, OAuthClientProvider } from '../../src/client/auth.js'; import { UnauthorizedError } from '../../src/client/auth.js'; import type { ReconnectionScheduler, StartSSEOptions, StreamableHTTPReconnectionOptions } from '../../src/client/streamableHttp.js'; import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; @@ -573,6 +573,36 @@ describe('StreamableHTTPClientTransport', () => { expect(globalThis.fetch).toHaveBeenCalledTimes(2); }); + it('should let auth provider headers override custom Authorization headers', async () => { + const authProvider: AuthProvider = { + token: async () => 'fresh-token' + }; + + transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), { + authProvider, + requestInit: { + headers: { + Authorization: 'Bearer stale-token', + 'X-Custom-Header': 'CustomValue' + } + } + }); + + let actualReqInit: RequestInit = {}; + + (globalThis.fetch as Mock).mockImplementation(async (_url, reqInit) => { + actualReqInit = reqInit; + return new Response(null, { status: 200, headers: { 'content-type': 'text/event-stream' } }); + }); + + await transport.start(); + + await transport['_startOrAuthSse']({}); + + expect((actualReqInit.headers as Headers).get('authorization')).toBe('Bearer fresh-token'); + expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('CustomValue'); + }); + it('should always send specified custom headers (Headers class)', async () => { const requestInit = { headers: new Headers({