diff --git a/integration/websockets/e2e/gateway-validation-pipe.spec.ts b/integration/websockets/e2e/gateway-validation-pipe.spec.ts new file mode 100644 index 00000000000..0108e440c3a --- /dev/null +++ b/integration/websockets/e2e/gateway-validation-pipe.spec.ts @@ -0,0 +1,103 @@ +import { INestApplication } from '@nestjs/common'; +import { WsAdapter } from '@nestjs/platform-ws'; +import { Test } from '@nestjs/testing'; +import * as WebSocket from 'ws'; +import { ValidationPipeGateway } from '../src/validation-pipe.gateway'; +import { expect } from 'chai'; +import { ApplicationGateway } from '../src/app.gateway'; + +async function createNestApp(...gateways): Promise { + const testingModule = await Test.createTestingModule({ + providers: gateways, + }).compile(); + const app = testingModule.createNestApplication(); + app.useWebSocketAdapter(new WsAdapter(app) as any); + return app; +} + +const testBody = { ws: null, app: null }; + +async function prepareGatewayAndClientForResponseAction( + gateway: typeof ValidationPipeGateway | ApplicationGateway, + action: () => void, +) { + testBody.app = await createNestApp(gateway); + await testBody.app.listen(3000); + + testBody.ws = new WebSocket('ws://localhost:8080'); + await new Promise(resolve => testBody.ws.on('open', resolve)); + + testBody.ws.send( + JSON.stringify({ + event: 'push', + data: { + stringProp: 123, + }, + }), + ); + + action(); +} + +const UNCAUGHT_EXCEPTION = 'uncaughtException'; + +type WsExceptionWithWrappedValidationError = { + getError: () => { + response: { + message: string[]; + }; + }; +}; + +function prepareToHandleExpectedUncaughtException() { + const listeners = process.listeners(UNCAUGHT_EXCEPTION); + process.removeAllListeners(UNCAUGHT_EXCEPTION); + + process.on( + UNCAUGHT_EXCEPTION, + (err: WsExceptionWithWrappedValidationError) => { + expect(err.getError().response.message[0]).to.equal( + 'stringProp must be a string', + ); + reattachUncaughtExceptionListeners(listeners); + }, + ); +} + +function reattachUncaughtExceptionListeners( + listeners: NodeJS.UncaughtExceptionListener[], +) { + process.removeAllListeners(UNCAUGHT_EXCEPTION); + for (const listener of listeners) { + process.on(UNCAUGHT_EXCEPTION, listener); + } +} + +describe('WebSocketGateway with ValidationPipe', () => { + it(`should throw WsException`, async () => { + prepareToHandleExpectedUncaughtException(); + + await prepareGatewayAndClientForResponseAction( + ValidationPipeGateway, + () => { + testBody.ws.once('message', () => {}); + }, + ); + }); + + it('should return message normally', async () => { + await new Promise(resolve => + prepareGatewayAndClientForResponseAction(ApplicationGateway, async () => { + testBody.ws.once('message', msg => { + expect(JSON.parse(msg).data.stringProp).to.equal(123); + resolve(); + }); + }), + ); + }); + + afterEach(function (done) { + testBody.ws.close(); + testBody.app.close().then(() => done()); + }); +}); diff --git a/integration/websockets/src/validation-pipe.gateway.ts b/integration/websockets/src/validation-pipe.gateway.ts new file mode 100644 index 00000000000..df2f9aef2a9 --- /dev/null +++ b/integration/websockets/src/validation-pipe.gateway.ts @@ -0,0 +1,40 @@ +import { + ArgumentsHost, + Catch, + UseFilters, + UsePipes, + ValidationPipe, +} from '@nestjs/common'; +import { + BaseWsExceptionFilter, + MessageBody, + SubscribeMessage, + WebSocketGateway, +} from '@nestjs/websockets'; +import { IsString } from 'class-validator'; + +class TestModel { + @IsString() + stringProp: string; +} + +@Catch() +export class AllExceptionsFilter extends BaseWsExceptionFilter { + catch(exception: unknown, host: ArgumentsHost) { + throw exception; + } +} + +@WebSocketGateway(8080) +@UsePipes(new ValidationPipe()) +@UseFilters(new AllExceptionsFilter()) +export class ValidationPipeGateway { + @SubscribeMessage('push') + onPush(@MessageBody() data: TestModel) { + console.log('received msg'); + return { + event: 'push', + data, + }; + } +} diff --git a/packages/common/constants.ts b/packages/common/constants.ts index 71f27685d88..6794ea82b8b 100644 --- a/packages/common/constants.ts +++ b/packages/common/constants.ts @@ -45,3 +45,4 @@ export const INJECTABLE_WATERMARK = '__injectable__'; export const CONTROLLER_WATERMARK = '__controller__'; export const CATCH_WATERMARK = '__catch__'; export const ENTRY_PROVIDER_WATERMARK = '__entryProvider__'; +export const GATEWAY_METADATA = 'websockets:is_gateway'; diff --git a/packages/common/decorators/core/use-pipes.decorator.ts b/packages/common/decorators/core/use-pipes.decorator.ts index 03a7bc08de4..0e1a49a400d 100644 --- a/packages/common/decorators/core/use-pipes.decorator.ts +++ b/packages/common/decorators/core/use-pipes.decorator.ts @@ -3,6 +3,7 @@ import { PipeTransform } from '../../interfaces/index'; import { extendArrayMetadata } from '../../utils/extend-metadata.util'; import { isFunction } from '../../utils/shared.utils'; import { validateEach } from '../../utils/validate-each.util'; +import { isTargetAware } from '../../interfaces/features/target-aware-pipe.interface'; /** * Decorator that binds pipes to the scope of the controller or method, @@ -43,6 +44,12 @@ export function UsePipes( return descriptor; } validateEach(target, pipes, isPipeValid, '@UsePipes', 'pipe'); + + const pipesWithSetTarget = pipes.filter(pipe => isTargetAware(pipe)); + pipesWithSetTarget.forEach(pipeWithSetTarget => + pipeWithSetTarget['setTarget'](target), + ); + extendArrayMetadata(PIPES_METADATA, pipes, target); return target; }; diff --git a/packages/common/exceptions/index.ts b/packages/common/exceptions/index.ts index ab2a948b17b..a0d9174eb9b 100644 --- a/packages/common/exceptions/index.ts +++ b/packages/common/exceptions/index.ts @@ -20,3 +20,4 @@ export * from './gateway-timeout.exception'; export * from './im-a-teapot.exception'; export * from './precondition-failed.exception'; export * from './misdirected.exception'; +export * from './ws-exception'; diff --git a/packages/common/exceptions/ws-exception.ts b/packages/common/exceptions/ws-exception.ts new file mode 100644 index 00000000000..f63dd266778 --- /dev/null +++ b/packages/common/exceptions/ws-exception.ts @@ -0,0 +1,27 @@ +import { isObject, isString } from '../utils/shared.utils'; + +export class WsException extends Error { + constructor(private readonly error: string | object) { + super(); + this.initMessage(); + } + + public initMessage() { + if (isString(this.error)) { + this.message = this.error; + } else if ( + isObject(this.error) && + isString((this.error as Record).message) + ) { + this.message = (this.error as Record).message; + } else if (this.constructor) { + this.message = this.constructor.name + .match(/[A-Z][a-z]+|[0-9]+/g) + .join(' '); + } + } + + public getError(): string | object { + return this.error; + } +} diff --git a/packages/common/interfaces/features/target-aware-pipe.interface.ts b/packages/common/interfaces/features/target-aware-pipe.interface.ts new file mode 100644 index 00000000000..15f79cdb38f --- /dev/null +++ b/packages/common/interfaces/features/target-aware-pipe.interface.ts @@ -0,0 +1,12 @@ +/** + * Interface describing method to set the target of the pipe decorator + */ +export interface TargetAwarePipe { + isTargetAware: true; + + setTarget(target: unknown): void; +} + +export function isTargetAware(pipe: unknown): pipe is TargetAwarePipe { + return pipe['isTargetAware']; +} diff --git a/packages/common/pipes/validation.pipe.ts b/packages/common/pipes/validation.pipe.ts index 6ee05a4adf9..3b97002dce3 100644 --- a/packages/common/pipes/validation.pipe.ts +++ b/packages/common/pipes/validation.pipe.ts @@ -19,6 +19,10 @@ import { } from '../utils/http-error-by-code.util'; import { loadPackage } from '../utils/load-package.util'; import { isNil, isUndefined } from '../utils/shared.utils'; +import { GATEWAY_METADATA } from '../constants'; +import { WsException } from '../exceptions'; +import { HttpException } from '../exceptions'; +import { TargetAwarePipe } from '../interfaces/features/target-aware-pipe.interface'; /** * @publicApi @@ -44,7 +48,7 @@ let classTransformer: TransformerPackage = {} as any; * @publicApi */ @Injectable() -export class ValidationPipe implements PipeTransform { +export class ValidationPipe implements PipeTransform, TargetAwarePipe { protected isTransformEnabled: boolean; protected isDetailedOutputDisabled?: boolean; protected validatorOptions: ValidatorOptions; @@ -53,6 +57,9 @@ export class ValidationPipe implements PipeTransform { protected expectedType: Type; protected exceptionFactory: (errors: ValidationError[]) => any; protected validateCustomDecorators: boolean; + protected isInGatewayMode = false; + protected target: unknown; + isTargetAware = true as const; constructor(@Optional() options?: ValidationPipeOptions) { options = options || {}; @@ -82,6 +89,10 @@ export class ValidationPipe implements PipeTransform { classTransformer = this.loadTransformer(options.transformerPackage); } + public setTarget(target: unknown) { + this.target = target; + } + protected loadValidator( validatorPackage?: ValidatorPackage, ): ValidatorPackage { @@ -105,6 +116,15 @@ export class ValidationPipe implements PipeTransform { } public async transform(value: any, metadata: ArgumentMetadata) { + if ( + !this.isInGatewayMode && + this.target && + Reflect.getMetadata(GATEWAY_METADATA, this.target) + ) { + this.isInGatewayMode = true; + this.exceptionFactory = this.createExceptionFactory(); + } + if (this.expectedType) { metadata = { ...metadata, metatype: this.expectedType }; } @@ -165,12 +185,27 @@ export class ValidationPipe implements PipeTransform { } public createExceptionFactory() { + let errorConstructorWrapper: (error: unknown) => unknown | WsException = ( + error: unknown, + ) => error; + if (this.isInGatewayMode) { + errorConstructorWrapper = (error: unknown) => { + if (error instanceof HttpException) { + return new WsException(error); + } + }; + } + return (validationErrors: ValidationError[] = []) => { if (this.isDetailedOutputDisabled) { - return new HttpErrorByCode[this.errorHttpStatusCode](); + return errorConstructorWrapper( + new HttpErrorByCode[this.errorHttpStatusCode](), + ); } const errors = this.flattenValidationErrors(validationErrors); - return new HttpErrorByCode[this.errorHttpStatusCode](errors); + return errorConstructorWrapper( + new HttpErrorByCode[this.errorHttpStatusCode](errors), + ); }; } diff --git a/packages/websockets/context/ws-context-creator.ts b/packages/websockets/context/ws-context-creator.ts index c699612b6f7..8391565c854 100644 --- a/packages/websockets/context/ws-context-creator.ts +++ b/packages/websockets/context/ws-context-creator.ts @@ -24,7 +24,7 @@ import { } from '@nestjs/core/interceptors'; import { PipesConsumer, PipesContextCreator } from '@nestjs/core/pipes'; import { MESSAGE_METADATA, PARAM_ARGS_METADATA } from '../constants'; -import { WsException } from '../errors/ws-exception'; +import { WsException } from '@nestjs/common'; import { WsParamsFactory } from '../factories/ws-params-factory'; import { ExceptionFiltersContext } from './exception-filters-context'; import { DEFAULT_CALLBACK_METADATA } from './ws-metadata-constants'; diff --git a/packages/websockets/errors/index.ts b/packages/websockets/errors/index.ts index 83185f3a6ef..9c33e75e6e6 100644 --- a/packages/websockets/errors/index.ts +++ b/packages/websockets/errors/index.ts @@ -1 +1 @@ -export * from './ws-exception'; +export { WsException } from '@nestjs/common/exceptions/ws-exception'; diff --git a/packages/websockets/errors/ws-exception.ts b/packages/websockets/errors/ws-exception.ts index cf7e639a2cc..b4995808077 100644 --- a/packages/websockets/errors/ws-exception.ts +++ b/packages/websockets/errors/ws-exception.ts @@ -1,27 +1,8 @@ -import { isObject, isString } from '@nestjs/common/utils/shared.utils'; +import { WsException as WSE } from '@nestjs/common/exceptions/ws-exception'; -export class WsException extends Error { - constructor(private readonly error: string | object) { - super(); - this.initMessage(); - } +/** + * @deprecated WsException has been moved to @nestjs/common + */ +const WsException = WSE; - public initMessage() { - if (isString(this.error)) { - this.message = this.error; - } else if ( - isObject(this.error) && - isString((this.error as Record).message) - ) { - this.message = (this.error as Record).message; - } else if (this.constructor) { - this.message = this.constructor.name - .match(/[A-Z][a-z]+|[0-9]+/g) - .join(' '); - } - } - - public getError(): string | object { - return this.error; - } -} +export { WsException }; diff --git a/packages/websockets/exceptions/base-ws-exception-filter.ts b/packages/websockets/exceptions/base-ws-exception-filter.ts index bcc48b0a68b..7e1f9f8df94 100644 --- a/packages/websockets/exceptions/base-ws-exception-filter.ts +++ b/packages/websockets/exceptions/base-ws-exception-filter.ts @@ -1,7 +1,7 @@ import { ArgumentsHost, Logger, WsExceptionFilter } from '@nestjs/common'; import { isObject } from '@nestjs/common/utils/shared.utils'; import { MESSAGES } from '@nestjs/core/constants'; -import { WsException } from '../errors/ws-exception'; +import { WsException } from '@nestjs/common'; /** * @publicApi diff --git a/packages/websockets/exceptions/ws-exceptions-handler.ts b/packages/websockets/exceptions/ws-exceptions-handler.ts index feed7d3979c..550bef76161 100644 --- a/packages/websockets/exceptions/ws-exceptions-handler.ts +++ b/packages/websockets/exceptions/ws-exceptions-handler.ts @@ -3,7 +3,7 @@ import { ArgumentsHost } from '@nestjs/common'; import { ExceptionFilterMetadata } from '@nestjs/common/interfaces/exceptions/exception-filter-metadata.interface'; import { selectExceptionFilterMetadata } from '@nestjs/common/utils/select-exception-filter-metadata.util'; import { InvalidExceptionFilterException } from '@nestjs/core/errors/exceptions/invalid-exception-filter.exception'; -import { WsException } from '../errors/ws-exception'; +import { WsException } from '@nestjs/common'; import { BaseWsExceptionFilter } from './base-ws-exception-filter'; /**