Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 78 additions & 15 deletions src/client/sse.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,52 @@ describe('SSEClientTransport', () => {

await transport.start();

// Store original fetch
const originalFetch = global.fetch;
try {
global.fetch = vi.fn().mockResolvedValue({ ok: true });

const message: JSONRPCMessage = {
jsonrpc: '2.0',
id: '1',
method: 'test',
params: {}
};

await transport.send(message);

const calledHeaders = (global.fetch as Mock).mock.calls[0][1].headers;
expect(calledHeaders.get('Authorization')).toBe('Bearer test-token');
expect(calledHeaders.get('X-Custom-Header')).toBe('custom-value');
expect(calledHeaders.get('content-type')).toBe('application/json');

customHeaders['X-Custom-Header'] = 'updated-value';

await transport.send(message);

const updatedHeaders = (global.fetch as Mock).mock.calls[1][1].headers;
expect(updatedHeaders.get('X-Custom-Header')).toBe('updated-value');
} finally {
global.fetch = originalFetch;
}
});

it('passes custom headers to fetch requests (Headers class)', async () => {
const customHeaders = new Headers({
Authorization: 'Bearer test-token',
'X-Custom-Header': 'custom-value'
});

transport = new SSEClientTransport(resourceBaseUrl, {
requestInit: {
headers: customHeaders
}
});

await transport.start();

const originalFetch = global.fetch;
try {
// Mock fetch for the message sending test
global.fetch = vi.fn().mockResolvedValue({
ok: true
});
global.fetch = vi.fn().mockResolvedValue({ ok: true });

const message: JSONRPCMessage = {
jsonrpc: '2.0',
Expand All @@ -326,20 +364,45 @@ describe('SSEClientTransport', () => {

await transport.send(message);

// Verify fetch was called with correct headers
expect(global.fetch).toHaveBeenCalledWith(
expect.any(URL),
expect.objectContaining({
headers: expect.any(Headers)
})
);
const calledHeaders = (global.fetch as Mock).mock.calls[0][1].headers;
expect(calledHeaders.get('Authorization')).toBe('Bearer test-token');
expect(calledHeaders.get('X-Custom-Header')).toBe('custom-value');
expect(calledHeaders.get('content-type')).toBe('application/json');

customHeaders.set('X-Custom-Header', 'updated-value');

await transport.send(message);

const updatedHeaders = (global.fetch as Mock).mock.calls[1][1].headers;
expect(updatedHeaders.get('X-Custom-Header')).toBe('updated-value');
} finally {
global.fetch = originalFetch;
}
});

it('passes custom headers to fetch requests (array of tuples)', async () => {
transport = new SSEClientTransport(resourceBaseUrl, {
requestInit: {
headers: [
['Authorization', 'Bearer test-token'],
['X-Custom-Header', 'custom-value']
]
}
});

await transport.start();

const originalFetch = global.fetch;
try {
global.fetch = vi.fn().mockResolvedValue({ ok: true });

await transport.send({ jsonrpc: '2.0', id: '1', method: 'test', params: {} });

const calledHeaders = (global.fetch as Mock).mock.calls[0][1].headers;
expect(calledHeaders.get('Authorization')).toBe(customHeaders.Authorization);
expect(calledHeaders.get('X-Custom-Header')).toBe(customHeaders['X-Custom-Header']);
expect(calledHeaders.get('Authorization')).toBe('Bearer test-token');
expect(calledHeaders.get('X-Custom-Header')).toBe('custom-value');
expect(calledHeaders.get('content-type')).toBe('application/json');
} finally {
// Restore original fetch
global.fetch = originalFetch;
}
});
Expand Down
11 changes: 8 additions & 3 deletions src/client/sse.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { EventSource, type ErrorEvent, type EventSourceInit } from 'eventsource';
import { Transport, FetchLike, createFetchWithInit } from '../shared/transport.js';
import { Transport, FetchLike, createFetchWithInit, normalizeHeaders } from '../shared/transport.js';
import { JSONRPCMessage, JSONRPCMessageSchema } from '../types.js';
import { auth, AuthResult, extractWWWAuthenticateParams, OAuthClientProvider, UnauthorizedError } from './auth.js';

Expand Down Expand Up @@ -114,7 +114,7 @@ export class SSEClientTransport implements Transport {
}

private async _commonHeaders(): Promise<Headers> {
const headers: HeadersInit = {};
const headers: HeadersInit & Record<string, string> = {};
if (this._authProvider) {
const tokens = await this._authProvider.tokens();
if (tokens) {
Expand All @@ -125,7 +125,12 @@ export class SSEClientTransport implements Transport {
headers['mcp-protocol-version'] = this._protocolVersion;
}

return new Headers({ ...headers, ...this._requestInit?.headers });
const extraHeaders = normalizeHeaders(this._requestInit?.headers);

return new Headers({
...headers,
...extraHeaders
});
}

private _startOrAuth(): Promise<void> {
Expand Down
28 changes: 28 additions & 0 deletions src/client/streamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ describe('StreamableHTTPClientTransport', () => {
it('should always send specified custom headers', async () => {
const requestInit = {
headers: {
Authorization: 'Bearer test-token',
'X-Custom-Header': 'CustomValue'
}
};
Expand All @@ -497,6 +498,7 @@ describe('StreamableHTTPClientTransport', () => {
await transport.start();

await transport['_startOrAuthSse']({});
expect((actualReqInit.headers as Headers).get('authorization')).toBe('Bearer test-token');
expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('CustomValue');

requestInit.headers['X-Custom-Header'] = 'SecondCustomValue';
Expand All @@ -510,6 +512,7 @@ describe('StreamableHTTPClientTransport', () => {
it('should always send specified custom headers (Headers class)', async () => {
const requestInit = {
headers: new Headers({
Authorization: 'Bearer test-token',
'X-Custom-Header': 'CustomValue'
})
};
Expand All @@ -527,6 +530,7 @@ describe('StreamableHTTPClientTransport', () => {
await transport.start();

await transport['_startOrAuthSse']({});
expect((actualReqInit.headers as Headers).get('authorization')).toBe('Bearer test-token');
expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('CustomValue');

(requestInit.headers as Headers).set('X-Custom-Header', 'SecondCustomValue');
Expand All @@ -537,6 +541,30 @@ describe('StreamableHTTPClientTransport', () => {
expect(global.fetch).toHaveBeenCalledTimes(2);
});

it('should always send specified custom headers (array of tuples)', async () => {
transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), {
requestInit: {
headers: [
['Authorization', 'Bearer test-token'],
['X-Custom-Header', 'CustomValue']
]
}
});

let actualReqInit: RequestInit = {};

(global.fetch as Mock).mockImplementation(async (_url, reqInit) => {
actualReqInit = reqInit;
return new Response(null, { status: 200, headers: { 'content-type': 'text/event-stream' } });
});

await transport.start();

await transport['_startOrAuthSse']({});
expect((actualReqInit.headers as Headers).get('authorization')).toBe('Bearer test-token');
expect((actualReqInit.headers as Headers).get('x-custom-header')).toBe('CustomValue');
});

it('should have exponential backoff with configurable maxRetries', () => {
// This test verifies the maxRetries and backoff calculation directly

Expand Down
Loading