diff --git a/packages/server/src/builder.test.ts b/packages/server/src/builder.test.ts index eb9ba0f00..9bbb964ed 100644 --- a/packages/server/src/builder.test.ts +++ b/packages/server/src/builder.test.ts @@ -130,7 +130,7 @@ describe('builder', () => { expect(applied).toBe(decorateMiddlewareSpy.mock.results[0]?.value) expect(decorateMiddlewareSpy).toBeCalledTimes(1) - expect(decorateMiddlewareSpy).toBeCalledWith(mid) + expect(decorateMiddlewareSpy).toBeCalledWith(mid, def.errorMap) }) it('.errors', () => { diff --git a/packages/server/src/builder.ts b/packages/server/src/builder.ts index c1afe03fe..e094fd816 100644 --- a/packages/server/src/builder.ts +++ b/packages/server/src/builder.ts @@ -6,7 +6,7 @@ import type { Context, MergedCurrentContext, MergedInitialContext } from './cont import type { ORPCErrorConstructorMap } from './error' import type { Lazy } from './lazy' import type { AnyMiddleware, MapInputMiddleware, Middleware } from './middleware' -import type { DecoratedMiddleware } from './middleware-decorated' +import type { AnyDecoratedMiddleware, DecoratedMiddleware } from './middleware-decorated' import type { ProcedureHandler } from './procedure' import type { Router } from './router' import type { EnhancedRouter, EnhanceRouterOptions } from './router-utils' @@ -146,7 +146,7 @@ export class Builder< middleware, TInput, TOutput = any>( // = any here is important to make middleware can be used in any output by default middleware: Middleware, TMeta>, ): DecoratedMiddleware { // any ensures middleware can used in any procedure - return decorateMiddleware(middleware) + return decorateMiddleware(middleware, this['~orpc'].errorMap) } /** @@ -190,16 +190,24 @@ export class Builder< > use( - middleware: AnyMiddleware, + middleware: AnyMiddleware | AnyDecoratedMiddleware, mapInput?: MapInputMiddleware, ): BuilderWithMiddlewares { const mapped = mapInput ? decorateMiddleware(middleware).mapInput(mapInput) : middleware + if (!('errorMap' in middleware) || !middleware.errorMap) { + return new Builder({ + ...this['~orpc'], + middlewares: addMiddleware(this['~orpc'].middlewares, mapped), + }) as any + } + return new Builder({ ...this['~orpc'], middlewares: addMiddleware(this['~orpc'].middlewares, mapped), + errorMap: mergeErrorMap(this['~orpc'].errorMap, middleware.errorMap), }) as any } diff --git a/packages/server/src/middleware-decorated.ts b/packages/server/src/middleware-decorated.ts index 592500ab2..c38cefe62 100644 --- a/packages/server/src/middleware-decorated.ts +++ b/packages/server/src/middleware-decorated.ts @@ -1,4 +1,4 @@ -import type { Meta } from '@orpc/contract' +import type { ErrorMap, Meta } from '@orpc/contract' import type { IntersectPick } from '@orpc/shared' import type { Context, MergedCurrentContext, MergedInitialContext } from './context' import type { ORPCErrorConstructorMap } from './error' @@ -12,6 +12,11 @@ export interface DecoratedMiddleware< TErrorConstructorMap extends ORPCErrorConstructorMap, TMeta extends Meta, > extends Middleware { + /** + * Error map associated with this middleware (if any) + * @internal + */ + errorMap?: ErrorMap /** * Change the expected input type by providing a map function. */ @@ -78,6 +83,8 @@ export interface DecoratedMiddleware< > } +export type AnyDecoratedMiddleware = DecoratedMiddleware + export function decorateMiddleware< TInContext extends Context, TOutContext extends Context, @@ -86,23 +93,39 @@ export function decorateMiddleware< TErrorConstructorMap extends ORPCErrorConstructorMap, TMeta extends Meta, >( - middleware: Middleware, + middleware: Middleware + | DecoratedMiddleware, + errorMap?: ErrorMap, ): DecoratedMiddleware { const decorated = ((...args) => middleware(...args)) as DecoratedMiddleware + // Attach error map if provided + if (errorMap) { + decorated.errorMap = errorMap + } + if ('errorMap' in middleware) { + decorated.errorMap = middleware.errorMap + } + decorated.mapInput = (mapInput) => { const mapped = decorateMiddleware( (options, input, ...rest) => middleware(options as any, mapInput(input as any), ...rest as [any]), + decorated.errorMap, // Preserve error map ) return mapped as any } - decorated.concat = (concatMiddleware: AnyMiddleware, mapInput?: MapInputMiddleware) => { + decorated.concat = (concatMiddleware: AnyMiddleware | AnyDecoratedMiddleware, mapInput?: MapInputMiddleware) => { const mapped = mapInput ? decorateMiddleware(concatMiddleware).mapInput(mapInput) : concatMiddleware + const combinedErrorMap = { + ...decorated.errorMap, + ...('errorMap' in concatMiddleware ? concatMiddleware.errorMap : undefined), + } + const concatted = decorateMiddleware((options, input, output, ...rest) => { const merged = middleware({ ...options, @@ -114,7 +137,7 @@ export function decorateMiddleware< } as any, input as any, output as any, ...rest) return merged - }) + }, combinedErrorMap) return concatted as any } diff --git a/packages/server/src/procedure-decorated.ts b/packages/server/src/procedure-decorated.ts index 5a6fb4c8e..41f0e9302 100644 --- a/packages/server/src/procedure-decorated.ts +++ b/packages/server/src/procedure-decorated.ts @@ -12,6 +12,7 @@ import type { IntersectPick, MaybeOptionalOptions } from '@orpc/shared' import type { Context, MergedCurrentContext, MergedInitialContext } from './context' import type { ORPCErrorConstructorMap } from './error' import type { AnyMiddleware, MapInputMiddleware, Middleware } from './middleware' +import type { AnyDecoratedMiddleware } from './middleware-decorated' import type { ProcedureActionableClient } from './procedure-action' import type { CreateProcedureClientOptions, ProcedureClient } from './procedure-client' import { mergeErrorMap, mergeMeta, mergeRoute } from '@orpc/contract' @@ -136,14 +137,22 @@ export class DecoratedProcedure< TMeta > - use(middleware: AnyMiddleware, mapInput?: MapInputMiddleware): DecoratedProcedure { + use(middleware: AnyMiddleware | AnyDecoratedMiddleware, mapInput?: MapInputMiddleware): DecoratedProcedure { const mapped = mapInput ? decorateMiddleware(middleware).mapInput(mapInput) : middleware + if (!('errorMap' in middleware) || !middleware.errorMap) { + return new DecoratedProcedure({ + ...this['~orpc'], + middlewares: addMiddleware(this['~orpc'].middlewares, mapped), + }) + } + return new DecoratedProcedure({ ...this['~orpc'], middlewares: addMiddleware(this['~orpc'].middlewares, mapped), + errorMap: mergeErrorMap(this['~orpc'].errorMap, middleware.errorMap), }) } diff --git a/packages/server/tests/error-middleware-pattern.test.ts b/packages/server/tests/error-middleware-pattern.test.ts new file mode 100644 index 000000000..d3bc956e2 --- /dev/null +++ b/packages/server/tests/error-middleware-pattern.test.ts @@ -0,0 +1,80 @@ +import { createSafeClient } from '@orpc/client' +import { createRouterClient, os } from '../src' + +describe('error middleware patterns', () => { + it('should have defined=true when using original supported pattern (same base for middleware and procedure)', async () => { + // ✅ CORRECT: Define errors on base, use same base for both middleware and procedure + const base = os.errors({ + UNAUTHORIZED: {}, + }) + + const middleware = base.middleware(async ({ next, context, errors }) => { + throw errors.UNAUTHORIZED() + }) + + const router = base.use(middleware).handler(async () => {}) + + const client = createSafeClient(createRouterClient(router)) + const [error,, isDefined] = await client() + + // Should have defined=true because error was thrown from defined error map + expect((error as any).defined).toBe(true) + expect(isDefined).toBe(true) + expect((error as any).code).toBe('UNAUTHORIZED') + }) + + it('should have defined=true with automatic error map merging (new behavior)', async () => { + // ✅ NEW BEHAVIOR: Middleware error maps are automatically merged + const middleware = os.errors({ + UNAUTHORIZED: {}, + }).middleware(async ({ next, context, errors }) => { + throw errors.UNAUTHORIZED() + }) + + // Using base os (no errors) but middleware error map should be automatically merged + const router = os.use(middleware).handler(async () => {}) + + const client = createSafeClient(createRouterClient(router)) + const [error,, isDefined] = await client() + + // Should now have defined=true because middleware error map gets merged automatically + expect((error as any).defined).toBe(true) + expect(isDefined).toBe(true) + expect((error as any).code).toBe('UNAUTHORIZED') + // Verify the error map was merged + expect(router['~orpc'].errorMap).toHaveProperty('UNAUTHORIZED') + }) + + it('should merge errors from different sources correctly', async () => { + const middleware = os.errors({ + UNAUTHORIZED: {}, + }).middleware(async ({ next, context, errors }) => { + // Don't throw, just continue to test error map merging + return next({ context }) + }) + const router = os + .use(middleware) + .errors({ NOT_FOUND: {} }) + .handler(async ({ errors }) => { + // Should have access to both UNAUTHORIZED and NOT_FOUND + expect('UNAUTHORIZED' in errors).toBe(true) + expect('NOT_FOUND' in errors).toBe(true) + + // @ts-expect-error TODO: Currently, errors defined in middleware is not inferred into the procedure + const unauthorizedError = errors.UNAUTHORIZED() + const notFoundError = errors.NOT_FOUND() + + expect(unauthorizedError.defined).toBe(true) + expect(notFoundError.defined).toBe(true) + + throw notFoundError + }) + + const client = createSafeClient(createRouterClient(router)) + const [error,, isDefined] = await client() + + expect((error as any).defined).toBe(true) + expect(isDefined).toBe(true) + expect((error as any).code).toBe('NOT_FOUND') + }) +})