Skip to content

Commit

Permalink
Merge pull request #11018 from nestjs/revert-10809-fix/durable-payloa…
Browse files Browse the repository at this point in the history
…d-regression

Revert "fix(core,microservices): inject the context when the tree is not durable"
  • Loading branch information
kamilmysliwiec committed Feb 2, 2023
2 parents da708c7 + 9141d24 commit 4ad3cbc
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 40 deletions.
35 changes: 14 additions & 21 deletions packages/core/middleware/middleware-module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import { ExecutionContextHost } from '../helpers/execution-context-host';
import { STATIC_CONTEXT } from '../injector/constants';
import { NestContainer } from '../injector/container';
import { Injector } from '../injector/injector';
import { ContextId, InstanceWrapper } from '../injector/instance-wrapper';
import { InstanceWrapper } from '../injector/instance-wrapper';
import { InstanceToken, Module } from '../injector/module';
import { GraphInspector } from '../inspector/graph-inspector';
import {
Expand Down Expand Up @@ -250,9 +250,6 @@ export class MiddlewareModule<
const proxy = await this.createProxy(instance);
return this.registerHandler(applicationRef, routeInfo, proxy);
}

const isTreeDurable = wrapper.isDependencyTreeDurable();

await this.registerHandler(
applicationRef,
routeInfo,
Expand All @@ -262,7 +259,19 @@ export class MiddlewareModule<
next: () => void,
) => {
try {
const contextId = this.getContextId(req, isTreeDurable);
const contextId = ContextIdFactory.getByRequest(req);
if (!req[REQUEST_CONTEXT_ID]) {
Object.defineProperty(req, REQUEST_CONTEXT_ID, {
value: contextId,
enumerable: false,
writable: false,
configurable: false,
});
this.container.registerRequestProvider(
contextId.getParent ? contextId.payload : req,
contextId,
);
}
const contextInstance = await this.injector.loadPerContext(
instance,
moduleRef,
Expand Down Expand Up @@ -360,20 +369,4 @@ export class MiddlewareModule<

router(path, middlewareFunction);
}

private getContextId(request: unknown, isTreeDurable: boolean): ContextId {
const contextId = ContextIdFactory.getByRequest(request);
if (!request[REQUEST_CONTEXT_ID]) {
Object.defineProperty(request, REQUEST_CONTEXT_ID, {
value: contextId,
enumerable: false,
writable: false,
configurable: false,
});

const requestProviderValue = isTreeDurable ? contextId.payload : request;
this.container.registerRequestProvider(requestProviderValue, contextId);
}
return contextId;
}
}
13 changes: 5 additions & 8 deletions packages/core/router/router-explorer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -355,16 +355,13 @@ export class RouterExplorer {
) {
const { instance } = instanceWrapper;
const collection = moduleRef.controllers;

const isTreeDurable = instanceWrapper.isDependencyTreeDurable();

return async <TRequest extends Record<any, any>, TResponse>(
req: TRequest,
res: TResponse,
next: () => void,
) => {
try {
const contextId = this.getContextId(req, isTreeDurable);
const contextId = this.getContextId(req);
const contextInstance = await this.injector.loadPerContext(
instance,
moduleRef,
Expand Down Expand Up @@ -400,7 +397,6 @@ export class RouterExplorer {

private getContextId<T extends Record<any, unknown> = any>(
request: T,
isTreeDurable: boolean,
): ContextId {
const contextId = ContextIdFactory.getByRequest(request);
if (!request[REQUEST_CONTEXT_ID as any]) {
Expand All @@ -410,9 +406,10 @@ export class RouterExplorer {
writable: false,
configurable: false,
});

const requestProviderValue = isTreeDurable ? contextId.payload : request;
this.container.registerRequestProvider(requestProviderValue, contextId);
this.container.registerRequestProvider(
contextId.getParent ? contextId.payload : request,
contextId,
);
}
return contextId;
}
Expand Down
22 changes: 11 additions & 11 deletions packages/microservices/listeners-controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -205,15 +205,13 @@ export class ListenersController {
const collection = moduleRef.controllers;
const { instance } = wrapper;

const isTreeDurable = wrapper.isDependencyTreeDurable();

const requestScopedHandler: MessageHandler = async (...args: unknown[]) => {
try {
let contextId: ContextId;

let [dataOrContextHost] = args;
if (dataOrContextHost instanceof RequestContextHost) {
contextId = this.getContextId(dataOrContextHost, isTreeDurable);
contextId = this.getContextId(dataOrContextHost);
args.shift();
} else {
const [data, reqCtx] = args;
Expand All @@ -222,7 +220,11 @@ export class ListenersController {
data,
reqCtx as BaseRpcContext,
);
contextId = this.getContextId(request, isTreeDurable);
contextId = this.getContextId(request);
this.container.registerRequestProvider(
contextId.getParent ? contextId.payload : request,
contextId,
);
dataOrContextHost = request;
}
const contextInstance = await this.injector.loadPerContext(
Expand Down Expand Up @@ -268,10 +270,7 @@ export class ListenersController {
return requestScopedHandler;
}

private getContextId<T extends RequestContext = any>(
request: T,
isTreeDurable: boolean,
): ContextId {
private getContextId<T extends RequestContext = any>(request: T): ContextId {
const contextId = ContextIdFactory.getByRequest(request);
if (!request[REQUEST_CONTEXT_ID as any]) {
Object.defineProperty(request, REQUEST_CONTEXT_ID, {
Expand All @@ -280,9 +279,10 @@ export class ListenersController {
writable: false,
configurable: false,
});

const requestProviderValue = isTreeDurable ? contextId.payload : request;
this.container.registerRequestProvider(requestProviderValue, contextId);
this.container.registerRequestProvider(
contextId.getParent ? contextId.payload : request,
contextId,
);
}
return contextId;
}
Expand Down

0 comments on commit 4ad3cbc

Please sign in to comment.