Skip to content

Commit 25532a8

Browse files
authored
fix: .next broken when calling multiple times (#337)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Refactor** - Streamlined internal processing logic to ensure more consistent and reliable handling of sequential operations. - **Tests** - Expanded test coverage by adding scenarios that validate the system's behavior when executing multiple sequential steps. - Improved test setup for more accurate state management between test runs. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 6b06c96 commit 25532a8

4 files changed

Lines changed: 90 additions & 34 deletions

File tree

packages/server/src/procedure-client.test.ts

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ const procedureCases = [
4545
] as const
4646

4747
beforeEach(() => {
48-
vi.clearAllMocks()
48+
vi.resetAllMocks()
4949
})
5050

5151
describe.each(procedureCases)('createProcedureClient - case %s', async (_, procedure) => {
@@ -479,6 +479,37 @@ describe.each(procedureCases)('createProcedureClient - case %s', async (_, proce
479479
expect(context).toBeCalledTimes(1)
480480
expect(context).toBeCalledWith({ cache: true })
481481
})
482+
483+
it('can multiple .next calls', async () => {
484+
const client = createProcedureClient(procedure)
485+
486+
preMid1.mockImplementationOnce(async ({ next }, input, output) => output([(await next()).output, (await next()).output, (await next()).output, (await next()).output]))
487+
488+
let index = 0
489+
490+
preMid2.mockImplementation(({ next }) => next({ context: { preMid2: index++ } }))
491+
postMid1.mockImplementation(({ next }) => next({ context: { postMid1: index++ } }))
492+
493+
await expect(client({ val: '123' })).resolves.toEqual([{ val: 123 }, { val: 123 }, { val: 123 }, { val: 123 }])
494+
495+
expect(preMid1).toBeCalledTimes(1)
496+
expect(preMid2).toBeCalledTimes(4)
497+
expect(postMid1).toBeCalledTimes(4)
498+
expect(postMid2).toBeCalledTimes(4)
499+
expect(handler).toBeCalledTimes(4)
500+
501+
expect((handler as any).mock.calls[0][0].context.preMid2).toBe(0)
502+
expect((handler as any).mock.calls[0][0].context.postMid1).toBe(1)
503+
504+
expect((handler as any).mock.calls[1][0].context.preMid2).toBe(2)
505+
expect((handler as any).mock.calls[1][0].context.postMid1).toBe(3)
506+
507+
expect((handler as any).mock.calls[2][0].context.preMid2).toBe(4)
508+
expect((handler as any).mock.calls[2][0].context.postMid1).toBe(5)
509+
510+
expect((handler as any).mock.calls[3][0].context.preMid2).toBe(6)
511+
expect((handler as any).mock.calls[3][0].context.postMid1).toBe(7)
512+
})
482513
})
483514

484515
it('still work without InputSchema', async () => {

packages/server/src/procedure-client.ts

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import type { Client, ClientContext } from '@orpc/client'
22
import type { Interceptor, MaybeOptionalOptions, Value } from '@orpc/shared'
33
import type { Lazyable } from './lazy'
4-
import type { MiddlewareNextFn } from './middleware'
54
import type { AnyProcedure, Procedure, ProcedureHandlerOptions } from './procedure'
65
import { ORPCError } from '@orpc/client'
76
import { type AnySchema, type ErrorFromErrorMap, type ErrorMap, type InferSchemaInput, type InferSchemaOutput, type Meta, ValidationError } from '@orpc/contract'
@@ -162,44 +161,37 @@ async function executeProcedureInternal(procedure: AnyProcedure, options: Proced
162161
const middlewares = procedure['~orpc'].middlewares
163162
const inputValidationIndex = Math.min(Math.max(0, procedure['~orpc'].inputValidationIndex), middlewares.length)
164163
const outputValidationIndex = Math.min(Math.max(0, procedure['~orpc'].outputValidationIndex), middlewares.length)
165-
let currentIndex = 0
166-
let currentContext = options.context
167-
let currentInput = options.input
168164

169-
const next: MiddlewareNextFn<any> = async (...[nextOptions]) => {
170-
const index = currentIndex
171-
const midContext = nextOptions?.context ?? {} as any
172-
173-
currentIndex += 1
174-
currentContext = mergeCurrentContext(currentContext, midContext)
165+
const next = async (index: number, context: Context, input: unknown): Promise<unknown> => {
166+
let currentInput = input
175167

176168
if (index === inputValidationIndex) {
177169
currentInput = await validateInput(procedure, currentInput)
178170
}
179171

180172
const mid = middlewares[index]
181173

182-
const result = mid
183-
? {
184-
context: midContext,
185-
output: (await mid({ ...options, context: currentContext, next }, currentInput, middlewareOutputFn)).output,
186-
}
187-
: {
188-
context: midContext,
189-
output: await procedure['~orpc'].handler({ ...options, context: currentContext, input: currentInput }),
190-
}
174+
const output = mid
175+
? (await mid({
176+
...options,
177+
context,
178+
next: async (...[nextOptions]) => {
179+
const nextContext: Context = nextOptions?.context ?? {}
191180

192-
if (index === outputValidationIndex) {
193-
const validatedOutput = await validateOutput(procedure, result.output)
181+
return {
182+
output: await next(index + 1, mergeCurrentContext(context, nextContext), currentInput),
183+
context: nextContext,
184+
}
185+
},
186+
}, currentInput, middlewareOutputFn)).output
187+
: await procedure['~orpc'].handler({ ...options, context, input: currentInput })
194188

195-
return {
196-
context: result.context,
197-
output: validatedOutput,
198-
}
189+
if (index === outputValidationIndex) {
190+
return await validateOutput(procedure, output)
199191
}
200192

201-
return result
193+
return output
202194
}
203195

204-
return (await next()).output
196+
return next(0, options.context, options.input)
205197
}

packages/shared/src/interceptor.test.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,41 @@ describe('intercept', () => {
115115
next: 'hello2',
116116
})
117117
})
118+
119+
it('can multiple .next calls', async () => {
120+
interceptor2.mockReturnValueOnce(Promise.resolve('__interceptor2__'))
121+
122+
const result = await intercept(
123+
[
124+
async ({ next }) => [await next(), await next(), await next()],
125+
interceptor1,
126+
interceptor2,
127+
],
128+
{
129+
foo: 'bar',
130+
},
131+
main,
132+
)
133+
134+
expect(result).toEqual(['__interceptor2__', '__main__', '__main__'])
135+
136+
expect(interceptor1).toHaveBeenCalledTimes(3)
137+
expect(interceptor1).toHaveBeenCalledWith({
138+
foo: 'bar',
139+
next: expect.any(Function),
140+
})
141+
142+
expect(interceptor2).toHaveBeenCalledTimes(3)
143+
expect(interceptor2).toHaveBeenCalledWith({
144+
foo: 'bar',
145+
next: expect.any(Function),
146+
})
147+
148+
expect(main).toHaveBeenCalledTimes(2)
149+
expect(main).toHaveBeenCalledWith({
150+
foo: 'bar',
151+
})
152+
})
118153
})
119154

120155
describe('onStart / onSuccess / onError / onFinish', () => {

packages/shared/src/interceptor.ts

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,20 +103,18 @@ export async function intercept<TOptions extends InterceptableOptions, TResult,
103103
options: NoInfer<TOptions>,
104104
main: NoInfer<(options: TOptions) => Promisable<TResult>>,
105105
): Promise<TResult> {
106-
let index = 0
107-
108-
const next = async (options: TOptions): Promise<TResult> => {
109-
const interceptor = interceptors[index++]
106+
const next = async (options: TOptions, index: number): Promise<TResult> => {
107+
const interceptor = interceptors[index]
110108

111109
if (!interceptor) {
112110
return await main(options)
113111
}
114112

115113
return await interceptor({
116114
...options,
117-
next: (newOptions: TOptions = options) => next(newOptions),
115+
next: (newOptions: TOptions = options) => next(newOptions, index + 1),
118116
})
119117
}
120118

121-
return await next(options)
119+
return next(options, 0)
122120
}

0 commit comments

Comments
 (0)