diff --git a/packages/core/interceptors/interceptors-consumer.ts b/packages/core/interceptors/interceptors-consumer.ts index bcf3a592184..22f6d096e80 100644 --- a/packages/core/interceptors/interceptors-consumer.ts +++ b/packages/core/interceptors/interceptors-consumer.ts @@ -24,19 +24,16 @@ export class InterceptorsConsumer { const context = this.createContext(args, instance, callback); context.setType(type); - const start$ = defer(() => this.transformDeferred(next)); - const nextFn = - (i = 0) => - async () => { - if (i >= interceptors.length) { - return start$; - } - const handler: CallHandler = { - handle: () => fromPromise(nextFn(i + 1)()).pipe(mergeAll()), - }; - return interceptors[i].intercept(context, handler); + const nextFn = async (i = 0) => { + if (i >= interceptors.length) { + return this.transformDeferred(next); + } + const handler: CallHandler = { + handle: () => fromPromise(nextFn(i + 1)).pipe(mergeAll()), }; - return nextFn()(); + return interceptors[i].intercept(context, handler); + }; + return defer(() => nextFn()).pipe(mergeAll()); } public createContext( diff --git a/packages/core/test/interceptors/interceptors-consumer.spec.ts b/packages/core/test/interceptors/interceptors-consumer.spec.ts index d2efbcaf41e..d16c4a6af7b 100644 --- a/packages/core/test/interceptors/interceptors-consumer.spec.ts +++ b/packages/core/test/interceptors/interceptors-consumer.spec.ts @@ -1,5 +1,7 @@ +import { CallHandler, ExecutionContext, NestInterceptor } from '@nestjs/common'; +import { AsyncLocalStorage } from 'async_hooks'; import { expect } from 'chai'; -import { lastValueFrom, of } from 'rxjs'; +import { lastValueFrom, Observable, of } from 'rxjs'; import * as sinon from 'sinon'; import { InterceptorsConsumer } from '../../interceptors/interceptors-consumer'; @@ -35,7 +37,7 @@ describe('InterceptorsConsumer', () => { beforeEach(() => { next = sinon.stub().returns(Promise.resolve('')); }); - it('should call every `intercept` method', async () => { + it('does not call `intercept` (lazy evaluation)', async () => { await consumer.intercept( interceptors, null, @@ -44,6 +46,19 @@ describe('InterceptorsConsumer', () => { next, ); + expect(interceptors[0].intercept.called).to.be.false; + expect(interceptors[1].intercept.called).to.be.false; + }); + it('should call every `intercept` method when subscribe', async () => { + const intercepted = await consumer.intercept( + interceptors, + null, + { constructor: null }, + null, + next, + ); + await transformToResult(intercepted); + expect(interceptors[0].intercept.calledOnce).to.be.true; expect(interceptors[1].intercept.calledOnce).to.be.true; }); @@ -58,15 +73,6 @@ describe('InterceptorsConsumer', () => { expect(next.called).to.be.false; }); it('should call `next` when subscribe', async () => { - async function transformToResult(resultOrDeferred: any) { - if ( - resultOrDeferred && - typeof resultOrDeferred.subscribe === 'function' - ) { - return lastValueFrom(resultOrDeferred); - } - return resultOrDeferred; - } const intercepted = await consumer.intercept( interceptors, null, @@ -78,6 +84,30 @@ describe('InterceptorsConsumer', () => { expect(next.called).to.be.true; }); }); + + describe('AsyncLocalStorage', () => { + it('Allows an interceptor to set values in AsyncLocalStorage that are accesible from the controller', async () => { + const storage = new AsyncLocalStorage>({}); + class StorageInterceptor implements NestInterceptor { + intercept( + _context: ExecutionContext, + next: CallHandler, + ): Observable | Promise> { + return storage.run({ value: 'hello' }, () => next.handle()); + } + } + const next = () => Promise.resolve(storage.getStore().value); + const intercepted = await consumer.intercept( + [new StorageInterceptor()], + null, + { constructor: null }, + null, + next, + ); + const result = await transformToResult(intercepted); + expect(result).to.equal('hello'); + }); + }); }); describe('createContext', () => { it('should return execution context object', () => { @@ -119,3 +149,13 @@ describe('InterceptorsConsumer', () => { }); }); }); + +async function transformToResult(resultOrDeferred: any) { + console.log('RESULT_OR_DEFERRED', resultOrDeferred); + if (resultOrDeferred && typeof resultOrDeferred.subscribe === 'function') { + const result = await lastValueFrom(resultOrDeferred); + console.log('RESULT_OR_DEFERRED', result); + return result; + } + return resultOrDeferred; +}