diff --git a/lib/internal/streams/iter/pull.js b/lib/internal/streams/iter/pull.js index 3ff88b251d182a..62781047a6b358 100644 --- a/lib/internal/streams/iter/pull.js +++ b/lib/internal/streams/iter/pull.js @@ -13,6 +13,9 @@ const { ArrayPrototypePush, ArrayPrototypeSlice, PromisePrototypeThen, + PromiseWithResolvers, + SafePromisePrototypeFinally, + SafePromiseRace, SymbolAsyncIterator, SymbolIterator, TypedArrayPrototypeGetByteLength, @@ -607,6 +610,77 @@ async function* applyValidatedStatefulAsyncTransform(source, transform, options) options.signal?.throwIfAborted(); } +/** + * Read one item from an async iterator, rejecting early if the signal aborts. + * @param {AsyncIterator} iterator - The iterator to read from. + * @param {AbortSignal|undefined} signal - Optional abort signal. + * @returns {Promise>|IteratorResult} + */ +function abortableNext(iterator, signal) { + if (signal === undefined) { + return iterator.next(); + } + + signal.throwIfAborted(); + + const next = iterator.next(); + const { promise, reject } = PromiseWithResolvers(); + const onAbort = () => reject(signal.reason); + signal.addEventListener('abort', onAbort, { __proto__: null, once: true }); + if (signal.aborted) { + onAbort(); + } + + return SafePromisePrototypeFinally(SafePromiseRace([next, promise]), () => { + signal.removeEventListener('abort', onAbort); + }); +} + +/** + * Wrap an async source so each pending read is abort-aware. + * @param {AsyncIterable} source - The source to read from. + * @param {AbortSignal|undefined} signal - Optional abort signal. + * @returns {AsyncIterable} + */ +function yieldAbortable(source, signal) { + if (signal === undefined) { + return source; + } + + return { + __proto__: null, + async *[SymbolAsyncIterator]() { + const iterator = source[SymbolAsyncIterator](); + let completed = false; + let aborted = false; + + try { + while (true) { + const { done, value } = await abortableNext(iterator, signal); + if (done) { + completed = true; + return; + } + signal.throwIfAborted(); + yield value; + } + } catch (error) { + aborted = signal.aborted; + throw error; + } finally { + if (!completed && typeof iterator.return === 'function') { + const result = iterator.return(); + if (aborted) { + PromisePrototypeThen(result, undefined, () => {}); + } else { + await result; + } + } + } + }, + }; +} + /** * Create an async pipeline from source through transforms. * @yields {Uint8Array[]} @@ -615,17 +689,14 @@ async function* createAsyncPipeline(source, transforms, signal) { // Check for abort signal?.throwIfAborted(); - const normalized = source; - // Fast path: no transforms, just yield normalized source directly if (transforms.length === 0) { - for await (const batch of normalized) { - signal?.throwIfAborted(); - yield batch; - } + yield* yieldAbortable(source, signal); return; } + const normalized = yieldAbortable(source, signal); + // Create internal controller for transform cancellation. // Note: if signal was already aborted, we threw above - no need to check here. const controller = new AbortController(); diff --git a/test/parallel/test-stream-iter-pull-async.js b/test/parallel/test-stream-iter-pull-async.js index 157cc5e265ea34..c75a4d305503b9 100644 --- a/test/parallel/test-stream-iter-pull-async.js +++ b/test/parallel/test-stream-iter-pull-async.js @@ -156,6 +156,44 @@ async function testPullSignalAbortMidIteration() { await assert.rejects(() => iter.next(), { name: 'AbortError' }); } +async function testPullSignalAbortWhileSourceNextPending() { + const source = { + [Symbol.asyncIterator]() { + return { + async next() { + await new Promise(() => {}); + }, + }; + }, + }; + const ac = new AbortController(); + const iter = pull(source, { signal: ac.signal })[Symbol.asyncIterator](); + const next = iter.next(); + ac.abort(); + await assert.rejects(next, { name: 'AbortError' }); +} + +async function testPullSignalAbortWithTransformWhileSourceNextPending() { + const source = { + [Symbol.asyncIterator]() { + return { + async next() { + await new Promise(() => {}); + }, + }; + }, + }; + const ac = new AbortController(); + const iter = pull( + source, + (chunks) => chunks, + { signal: ac.signal }, + )[Symbol.asyncIterator](); + const next = iter.next(); + ac.abort(); + await assert.rejects(next, { name: 'AbortError' }); +} + // Pull consumer break (return()) cleans up transform signal async function testPullConsumerBreakCleanup() { let signalAborted = false; @@ -351,6 +389,8 @@ async function testTransformOptionsNotShared() { testPullSourceError(), testTapCallbackError(), testPullSignalAbortMidIteration(), + testPullSignalAbortWhileSourceNextPending(), + testPullSignalAbortWithTransformWhileSourceNextPending(), testPullConsumerBreakCleanup(), testPullTransformReturnsPromise(), testPullTransformYieldsStrings(),