diff --git a/lib/async_hooks.js b/lib/async_hooks.js index 0943534790550c..cb6fc183004dee 100644 --- a/lib/async_hooks.js +++ b/lib/async_hooks.js @@ -2,6 +2,7 @@ const { NumberIsSafeInteger, + PromiseResolve, ReflectApply, Symbol, } = primordials; @@ -211,19 +212,73 @@ class AsyncResource { } const storageList = []; +const seenLayer = []; +let trackerCount = 0; +let depth = 0; + +function refreshStorageHooks() { + if (storageList.length === 0) { + storageHookWithTracking.disable(); + storageHook.disable(); + } else if (trackerCount > 0) { + storageHookWithTracking.enable(); + storageHook.disable(); + } else { + storageHookWithTracking.disable(); + storageHook.enable(); + } +} + +function patchPromiseBarrier(currentResource) { + PromiseResolve({ + then(resolve) { + const resource = executionAsyncResource(); + propagateToStorageLists(resource, currentResource); + resolve(); + } + }); +} + +function propagateToStorageLists(resource, currentResource) { + for (let i = 0; i < storageList.length; ++i) { + storageList[i]._propagate(resource, currentResource); + } +} + const storageHook = createHook({ init(asyncId, type, triggerAsyncId, resource) { const currentResource = executionAsyncResource(); // Value of currentResource is always a non null object - for (let i = 0; i < storageList.length; ++i) { - storageList[i]._propagate(resource, currentResource); + propagateToStorageLists(resource, currentResource); + } +}); + +const storageHookWithTracking = createHook({ + init(asyncId, type, triggerAsyncId, resource) { + const currentResource = executionAsyncResource(); + // Value of currentResource is always a non null object + propagateToStorageLists(resource, currentResource); + + if (type === 'PROMISE' && !seenLayer[depth]) { + seenLayer[depth] = true; + patchPromiseBarrier(currentResource); } + }, + + before(asyncId) { + depth++; + seenLayer[depth] = false; + }, + + after(asyncId) { + depth--; } }); class AsyncLocalStorage { - constructor() { + constructor({ trackAsyncAwait = false } = {}) { this.kResourceStore = Symbol('kResourceStore'); + this.trackAsyncAwait = trackAsyncAwait; this.enabled = false; } @@ -232,9 +287,10 @@ class AsyncLocalStorage { this.enabled = false; // If this.enabled, the instance must be in storageList storageList.splice(storageList.indexOf(this), 1); - if (storageList.length === 0) { - storageHook.disable(); + if (this.trackAsyncAwait) { + trackerCount--; } + refreshStorageHooks(); } } @@ -250,7 +306,10 @@ class AsyncLocalStorage { if (!this.enabled) { this.enabled = true; storageList.push(this); - storageHook.enable(); + if (this.trackAsyncAwait) { + trackerCount++; + } + refreshStorageHooks(); } const resource = executionAsyncResource(); resource[this.kResourceStore] = store; diff --git a/test/parallel/test-async-local-storage-async-await.js b/test/parallel/test-async-local-storage-async-await.js new file mode 100644 index 00000000000000..f0ec016fa0792c --- /dev/null +++ b/test/parallel/test-async-local-storage-async-await.js @@ -0,0 +1,51 @@ +'use strict'; +const common = require('../common'); +const assert = require('assert'); +const { AsyncLocalStorage } = require('async_hooks'); + +const store = new AsyncLocalStorage({ trackAsyncAwait: true }); +let checked = 0; + +function thenable(expected, count) { + return { + then: common.mustCall((cb) => { + assert.strictEqual(expected, store.getStore()); + checked++; + cb(); + }, count) + }; +} + +function main(n) { + const firstData = Symbol('first-data'); + const secondData = Symbol('second-data'); + + const first = thenable(firstData, 1); + const second = thenable(secondData, 1); + const third = thenable(firstData, 2); + + return store.run(firstData, common.mustCall(async () => { + assert.strictEqual(firstData, store.getStore()); + await first; + + await store.run(secondData, common.mustCall(async () => { + assert.strictEqual(secondData, store.getStore()); + await second; + assert.strictEqual(secondData, store.getStore()); + })); + + await Promise.all([ third, third ]); + assert.strictEqual(firstData, store.getStore()); + })); +} + +const outerData = Symbol('outer-data'); + +Promise.all([ + store.run(outerData, () => Promise.resolve(thenable(outerData))), + Promise.resolve(3).then(common.mustCall(main)), + main(1), + main(2) +]).then(common.mustCall(() => { + assert.strictEqual(checked, 13); +}));