From f4edb52b87bf801f8052c92346bbed4845dd85c3 Mon Sep 17 00:00:00 2001 From: Taku Amano Date: Mon, 13 Apr 2026 20:56:53 +0900 Subject: [PATCH] fix: ensure close handler is attached for Blob/ReadableStream cacheable responses --- src/listener.ts | 32 ++++++-- test/listener.test.ts | 165 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 160 insertions(+), 37 deletions(-) diff --git a/src/listener.ts b/src/listener.ts index 28aa3c2..e91cd47 100644 --- a/src/listener.ts +++ b/src/listener.ts @@ -110,6 +110,16 @@ const makeCloseHandler = } } +const isImmediateCacheableResponse = (res: Response): boolean => { + if (!(cacheKey in res)) { + return false + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const body = ((res as any)[cacheKey] as InternalCache)[1] + return body === null || typeof body === 'string' || body instanceof Uint8Array +} + const handleRequestError = (): Response => new Response(null, { status: 400, @@ -358,6 +368,16 @@ export const getRequestListener = ( // eslint-disable-next-line @typescript-eslint/no-explicit-any let res, req: any let needsBodyCleanup = false + let closeHandlerAttached = false + + const ensureCloseHandler = () => { + if (!req || closeHandlerAttached) { + return + } + + closeHandlerAttached = true + outgoing.on('close', makeCloseHandler(req, incoming, outgoing, needsBodyCleanup)) + } try { // `fetchCallback()` requests a Request object, but global.Request is expensive to generate, @@ -396,7 +416,7 @@ export const getRequestListener = ( res = fetchCallback(req, { incoming, outgoing } as HttpBindings) as | Response | Promise - if (cacheKey in res) { + if (!isPromise(res) && isImmediateCacheableResponse(res)) { // Synchronous cacheable response — no close listener needed. // No I/O events can fire between fetchCallback returning and responseViaCache // completing, so abort detection is not needed here. @@ -410,18 +430,14 @@ export const getRequestListener = ( } }) } - return responseViaCache(res as Response, outgoing) + return responseViaCache(res, outgoing) } - // Async response — create and register close listener only now, avoiding - // closure allocation on the synchronous hot path. - outgoing.on('close', makeCloseHandler(req, incoming, outgoing, needsBodyCleanup)) + ensureCloseHandler() } catch (e: unknown) { if (!res) { if (options.errorHandler) { // Async error handler — register close listener so client disconnect aborts the signal. - if (req) { - outgoing.on('close', makeCloseHandler(req, incoming, outgoing, needsBodyCleanup)) - } + ensureCloseHandler() res = await options.errorHandler(req ? e : toRequestError(e)) if (!res) { return diff --git a/test/listener.test.ts b/test/listener.test.ts index b27dc17..2b608a2 100644 --- a/test/listener.test.ts +++ b/test/listener.test.ts @@ -4,6 +4,54 @@ import { getRequestListener } from '../src/listener' import { GlobalRequest, Request as LightweightRequest, RequestError } from '../src/request' import { GlobalResponse, Response as LightweightResponse } from '../src/response' +const withTimeout = async (promise: Promise, message: string): Promise => { + let timeoutId: ReturnType | undefined + + try { + return await Promise.race([ + promise, + new Promise((_, reject) => { + timeoutId = setTimeout(() => { + reject(new Error(message)) + }, 1_000) + }), + ]) + } finally { + if (timeoutId) { + clearTimeout(timeoutId) + } + } +} + +const runRequestAndCollectOutgoingEvents = async ( + fetchCallback: Parameters[0] +): Promise<{ + closeListenerCount: number + response: request.Response +}> => { + let closeListenerCount = 0 + const requestListener = getRequestListener(fetchCallback) + const server = createServer(async (req, res) => { + const originalOn = res.on.bind(res) + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ;(res as any).on = ((event: string, listener: (...args: any[]) => void) => { + if (event === 'close') { + closeListenerCount++ + } + return originalOn(event, listener) + }) as typeof res.on + + await requestListener(req, res) + }) + + try { + const response = await request(server).get('/') + return { closeListenerCount, response } + } finally { + server.close() + } +} + describe('Invalid request', () => { describe('default error handler', () => { const requestListener = getRequestListener(vi.fn()) @@ -303,24 +351,34 @@ describe('Abort request', () => { }) describe('Abort request - error path', () => { - it('should abort request signal when client disconnects while async error handler is running after sync throw', async () => { + const runAbortDuringErrorHandlerCase = async (mode: 'sync' | 'async') => { let capturedReq: Request | undefined - let resolveAborted: () => void + let resolveAborted!: () => void const abortedPromise = new Promise((r) => { resolveAborted = r }) - const fetchCallback = (req: Request) => { - capturedReq = req - req.signal.addEventListener('abort', () => resolveAborted()) - throw new Error('sync error') - } - - let resolveErrorHandlerStarted: () => void + let resolveErrorHandlerStarted!: () => void const errorHandlerStarted = new Promise((r) => { resolveErrorHandlerStarted = r }) + const onRequest = (req: Request) => { + capturedReq = req + req.signal.addEventListener('abort', resolveAborted) + } + + const fetchCallback = + mode === 'sync' + ? (req: Request) => { + onRequest(req) + throw new Error('sync error') + } + : async (req: Request) => { + onRequest(req) + throw new Error('async error') + } + const errorHandler = async () => { resolveErrorHandlerStarted() await new Promise(() => {}) // never resolves — client will disconnect first @@ -333,48 +391,97 @@ describe('Abort request - error path', () => { const req = request(server) .get('/') .end(() => {}) - await errorHandlerStarted + await withTimeout(errorHandlerStarted, 'error handler did not start') req.abort() - await abortedPromise + await withTimeout(abortedPromise, 'request abort did not propagate') expect(capturedReq?.signal.aborted).toBe(true) } finally { server.close() } + } + + it.each(['sync', 'async'] as const)( + 'should abort request signal when client disconnects while async error handler is running after %s', + async (mode) => { + await runAbortDuringErrorHandlerCase(mode) + } + ) +}) + +describe('Abort request - cacheable response path', () => { + it.each([ + ['string', () => new Response('fast path')], + ['Uint8Array', () => new Response(new TextEncoder().encode('fast path'))], + ['null', () => new Response(null, { status: 204 })], + ] as const)( + 'should avoid attaching a close listener for sync immediate cacheable %s responses', + async (_type, createResponse) => { + const { closeListenerCount, response } = await runRequestAndCollectOutgoingEvents(() => + createResponse() + ) + + expect(closeListenerCount).toBe(0) + + if (response.status === 204) { + expect(response.text).toBe('') + } else { + expect(response.text).toBe('fast path') + } + } + ) + + it('should attach a close listener and send the body for sync Blob responses', async () => { + const { closeListenerCount, response } = await runRequestAndCollectOutgoingEvents( + () => + new Response(new Blob(['blob-body']), { + headers: { + 'content-type': 'text/plain; charset=UTF-8', + }, + }) + ) + + expect(closeListenerCount).toBe(1) + expect(response.text).toBe('blob-body') }) - it('should abort request signal when client disconnects while async error handler is running after async throw', async () => { - let capturedReq: Request | undefined - let resolveAborted: () => void + it('should abort request signal when client disconnects during sync cacheable ReadableStream response', async () => { + let resolveAborted!: () => void const abortedPromise = new Promise((r) => { resolveAborted = r }) - const fetchCallback = async (req: Request) => { - capturedReq = req - req.signal.addEventListener('abort', () => resolveAborted()) - throw new Error('async error') - } - - let resolveErrorHandlerStarted: () => void - const errorHandlerStarted = new Promise((r) => { - resolveErrorHandlerStarted = r + let capturedReq: Request | undefined + let resolveStreamConstructed!: () => void + const streamConstructed = new Promise((r) => { + resolveStreamConstructed = r }) - const errorHandler = async () => { - resolveErrorHandlerStarted() - await new Promise(() => {}) // never resolves — client will disconnect first + const fetchCallback = (req: Request) => { + capturedReq = req + req.signal.addEventListener('abort', resolveAborted) + + const body = new ReadableStream({ + start() { + resolveStreamConstructed() + }, + async pull() { + await new Promise(() => {}) // never resolves — client will disconnect first + }, + }) + + return new Response(body) } - const requestListener = getRequestListener(fetchCallback, { errorHandler }) + const requestListener = getRequestListener(fetchCallback) const server = createServer(requestListener) try { const req = request(server) .get('/') .end(() => {}) - await errorHandlerStarted + await withTimeout(streamConstructed, 'stream body was not constructed') req.abort() - await abortedPromise + await withTimeout(abortedPromise, 'request abort did not propagate for cacheable stream') expect(capturedReq?.signal.aborted).toBe(true) } finally { server.close()