diff --git a/.changeset/bold-dancers-see.md b/.changeset/bold-dancers-see.md new file mode 100644 index 00000000..f332b915 --- /dev/null +++ b/.changeset/bold-dancers-see.md @@ -0,0 +1,5 @@ +--- +'@openai/agents-core': patch +--- + +feat: #679 Add runInParallel option to input guardrail initialization diff --git a/packages/agents-core/src/agent.ts b/packages/agents-core/src/agent.ts index 2139b635..fbd73608 100644 --- a/packages/agents-core/src/agent.ts +++ b/packages/agents-core/src/agent.ts @@ -235,8 +235,9 @@ export interface AgentConfiguration< mcpServers: MCPServer[]; /** - * A list of checks that run in parallel to the agent's execution, before generating a response. - * Runs only if the agent is the first agent in the chain. + * A list of checks that run in parallel to the agent by default; set `runInParallel` to false to + * block LLM/tool calls until the guardrail completes. Runs only if the agent is the first agent + * in the chain. */ inputGuardrails: InputGuardrail[]; diff --git a/packages/agents-core/src/guardrail.ts b/packages/agents-core/src/guardrail.ts index fa6fdd4c..1b27a115 100644 --- a/packages/agents-core/src/guardrail.ts +++ b/packages/agents-core/src/guardrail.ts @@ -65,6 +65,12 @@ export interface InputGuardrail { * The function that performs the guardrail check */ execute: InputGuardrailFunction; + + /** + * Whether the guardrail should execute alongside the agent (true, default) or block the + * agent until it completes (false). + */ + runInParallel?: boolean; } /** @@ -105,6 +111,7 @@ export interface InputGuardrailMetadata { */ export interface InputGuardrailDefinition extends InputGuardrailMetadata { guardrailFunction: InputGuardrailFunction; + runInParallel: boolean; run(args: InputGuardrailFunctionArgs): Promise; } @@ -114,6 +121,7 @@ export interface InputGuardrailDefinition extends InputGuardrailMetadata { export interface DefineInputGuardrailArgs { name: string; execute: InputGuardrailFunction; + runInParallel?: boolean; } /** @@ -122,10 +130,12 @@ export interface DefineInputGuardrailArgs { export function defineInputGuardrail({ name, execute, + runInParallel = true, }: DefineInputGuardrailArgs): InputGuardrailDefinition { return { type: 'input', name, + runInParallel, guardrailFunction: execute, async run(args: InputGuardrailFunctionArgs): Promise { return { diff --git a/packages/agents-core/src/run.ts b/packages/agents-core/src/run.ts index fdc5645b..91cc25b3 100644 --- a/packages/agents-core/src/run.ts +++ b/packages/agents-core/src/run.ts @@ -4,6 +4,7 @@ import { defineOutputGuardrail, InputGuardrail, InputGuardrailDefinition, + InputGuardrailResult, OutputGuardrail, OutputGuardrailDefinition, OutputGuardrailFunctionArgs, @@ -603,6 +604,34 @@ export class Runner extends RunHooks> { AgentOutputType >[]; + #getInputGuardrailDefinitions< + TContext, + TAgent extends Agent, + >(state: RunState): InputGuardrailDefinition[] { + return this.inputGuardrailDefs.concat( + state._currentAgent.inputGuardrails.map(defineInputGuardrail), + ); + } + + #splitInputGuardrails< + TContext, + TAgent extends Agent, + >(state: RunState) { + const guardrails = this.#getInputGuardrailDefinitions(state); + const blocking: InputGuardrailDefinition[] = []; + const parallel: InputGuardrailDefinition[] = []; + + for (const guardrail of guardrails) { + if (guardrail.runInParallel === false) { + blocking.push(guardrail); + } else { + parallel.push(guardrail); + } + } + + return { blocking, parallel }; + } + /** * @internal * Resolves the effective model once so both run loops obey the same precedence rules. @@ -738,8 +767,21 @@ export class Runner extends RunHooks> { `Running agent ${state._currentAgent.name} (turn ${state._currentTurn})`, ); + let parallelGuardrailPromise: + | Promise + | undefined; if (state._currentTurn === 1) { - await this.#runInputGuardrails(state); + const guardrails = this.#splitInputGuardrails(state); + if (guardrails.blocking.length > 0) { + await this.#runInputGuardrails(state, guardrails.blocking); + } + if (guardrails.parallel.length > 0) { + parallelGuardrailPromise = this.#runInputGuardrails( + state, + guardrails.parallel, + ); + parallelGuardrailPromise.catch(() => {}); + } } const turnInput = serverConversationTracker @@ -829,6 +871,10 @@ export class Runner extends RunHooks> { state._currentTurnPersistedItemCount = 0; } state._currentStep = turnResult.nextStep; + + if (parallelGuardrailPromise) { + await parallelGuardrailPromise; + } } if ( @@ -1007,8 +1053,25 @@ export class Runner extends RunHooks> { `Running agent ${currentAgent.name} (turn ${result.state._currentTurn})`, ); + let guardrailError: unknown; + let parallelGuardrailPromise: + | Promise + | undefined; if (result.state._currentTurn === 1) { - await this.#runInputGuardrails(result.state); + const guardrails = this.#splitInputGuardrails(result.state); + if (guardrails.blocking.length > 0) { + await this.#runInputGuardrails(result.state, guardrails.blocking); + } + if (guardrails.parallel.length > 0) { + const promise = this.#runInputGuardrails( + result.state, + guardrails.parallel, + ); + parallelGuardrailPromise = promise.catch((err) => { + guardrailError = err; + return []; + }); + } } const turnInput = serverConversationTracker @@ -1038,6 +1101,10 @@ export class Runner extends RunHooks> { sessionInputUpdate, ); + if (guardrailError) { + throw guardrailError; + } + handedInputToModel = true; await persistStreamInputIfNeeded(); @@ -1064,6 +1131,9 @@ export class Runner extends RunHooks> { ), signal: options.signal, })) { + if (guardrailError) { + throw guardrailError; + } if (event.type === 'response_done') { const parsed = StreamEventResponseCompleted.parse(event); finalResponse = { @@ -1080,6 +1150,13 @@ export class Runner extends RunHooks> { result._addItem(new RunRawModelStreamEvent(event)); } + if (parallelGuardrailPromise) { + await parallelGuardrailPromise; + if (guardrailError) { + throw guardrailError; + } + } + result.state._noActiveAgentRun = false; if (!finalResponse) { @@ -1276,10 +1353,12 @@ export class Runner extends RunHooks> { async #runInputGuardrails< TContext, TAgent extends Agent, - >(state: RunState) { - const guardrails = this.inputGuardrailDefs.concat( - state._currentAgent.inputGuardrails.map(defineInputGuardrail), - ); + >( + state: RunState, + guardrailsOverride?: InputGuardrailDefinition[], + ): Promise { + const guardrails = + guardrailsOverride ?? this.#getInputGuardrailDefinitions(state); if (guardrails.length > 0) { const guardrailArgs = { agent: state._currentAgent, @@ -1300,6 +1379,7 @@ export class Runner extends RunHooks> { ); }), ); + state._inputGuardrailResults.push(...results); for (const result of results) { if (result.output.tripwireTriggered) { if (state._currentAgentSpan) { @@ -1315,6 +1395,7 @@ export class Runner extends RunHooks> { ); } } + return results; } catch (e) { if (e instanceof InputGuardrailTripwireTriggered) { throw e; @@ -1328,6 +1409,7 @@ export class Runner extends RunHooks> { ); } } + return []; } async #runOutputGuardrails< diff --git a/packages/agents-core/test/guardrail.test.ts b/packages/agents-core/test/guardrail.test.ts index 262b28c4..80e46977 100644 --- a/packages/agents-core/test/guardrail.test.ts +++ b/packages/agents-core/test/guardrail.test.ts @@ -44,6 +44,31 @@ describe('guardrail helpers', () => { expect(agent.outputGuardrails[0].name).toEqual('og'); }); + it('defaults input guardrails to run in parallel', () => { + const guardrail = defineInputGuardrail({ + name: 'ig', + execute: async (_args) => ({ + outputInfo: { ok: true }, + tripwireTriggered: false, + }), + }); + + expect(guardrail.runInParallel).toBe(true); + }); + + it('uses configured runInParallel value for input guardrails', () => { + const guardrail = defineInputGuardrail({ + name: 'blocking', + execute: async (_args) => ({ + outputInfo: { ok: true }, + tripwireTriggered: false, + }), + runInParallel: false, + }); + + expect(guardrail.runInParallel).toBe(false); + }); + it('executes input guardrail and returns expected result', async () => { const guardrailFn = vi.fn(async (_args: InputGuardrailFunctionArgs) => ({ outputInfo: { ok: true }, diff --git a/packages/agents-core/test/run.stream.test.ts b/packages/agents-core/test/run.stream.test.ts index 13dc4963..ad2d9c0e 100644 --- a/packages/agents-core/test/run.stream.test.ts +++ b/packages/agents-core/test/run.stream.test.ts @@ -944,6 +944,7 @@ describe('Runner.run (streaming)', () => { const guardrail = { name: 'block', + runInParallel: false, execute: vi.fn().mockResolvedValue({ tripwireTriggered: true, outputInfo: { reason: 'blocked' }, @@ -973,6 +974,66 @@ describe('Runner.run (streaming)', () => { expect(saveInputSpy).not.toHaveBeenCalled(); }); + + it('runs blocking input guardrails before streaming starts', async () => { + let guardrailFinished = false; + + const guardrail = { + name: 'blocking', + runInParallel: false, + execute: vi.fn(async () => { + await Promise.resolve(); + guardrailFinished = true; + return { + tripwireTriggered: false, + outputInfo: { ok: true }, + }; + }), + }; + + class ExpectGuardrailBeforeStreamModel implements Model { + getResponse(_request: ModelRequest): Promise { + throw new Error('Unexpected call to getResponse'); + } + + async *getStreamedResponse( + _request: ModelRequest, + ): AsyncIterable { + expect(guardrailFinished).toBe(true); + yield { + type: 'response_done', + response: { + id: 'stream1', + usage: { + requests: 1, + inputTokens: 0, + outputTokens: 0, + totalTokens: 0, + }, + output: [fakeModelMessage('ok')], + }, + } satisfies StreamEvent; + } + } + + const agent = new Agent({ + name: 'BlockingStreamAgent', + model: new ExpectGuardrailBeforeStreamModel(), + inputGuardrails: [guardrail], + }); + + const runner = new Runner(); + const result = await runner.run(agent, 'hi', { stream: true }); + + for await (const _ of result.toStream()) { + // consume + } + await result.completed; + + expect(result.finalOutput).toBe('ok'); + expect(result.inputGuardrailResults).toHaveLength(1); + expect(guardrail.execute).toHaveBeenCalledTimes(1); + }); }); class ImmediateStreamingModel implements Model { diff --git a/packages/agents-core/test/run.test.ts b/packages/agents-core/test/run.test.ts index 51eb6d2e..21c083c1 100644 --- a/packages/agents-core/test/run.test.ts +++ b/packages/agents-core/test/run.test.ts @@ -321,6 +321,49 @@ describe('Runner.run', () => { expect(guardrailFn).toHaveBeenCalledTimes(1); }); + it('waits for blocking input guardrails before calling the model', async () => { + let guardrailCompleted = false; + const blockingGuardrail = { + name: 'blocking-ig', + runInParallel: false, + execute: vi.fn(async () => { + await Promise.resolve(); + guardrailCompleted = true; + return { tripwireTriggered: false, outputInfo: {} }; + }), + }; + + class ExpectGuardrailFirstModel implements Model { + calls = 0; + + async getResponse(_request: ModelRequest): Promise { + this.calls++; + expect(guardrailCompleted).toBe(true); + return { + output: [fakeModelMessage('done')], + usage: new Usage(), + }; + } + + /* eslint-disable require-yield */ + async *getStreamedResponse(_request: ModelRequest) { + throw new Error('not implemented'); + } + /* eslint-enable require-yield */ + } + + const agent = new Agent({ + name: 'BlockingGuard', + model: new ExpectGuardrailFirstModel(), + inputGuardrails: [blockingGuardrail], + }); + + const result = await run(agent, 'hello'); + expect(result.finalOutput).toBe('done'); + expect(result.inputGuardrailResults).toHaveLength(1); + expect(blockingGuardrail.execute).toHaveBeenCalledTimes(1); + }); + it('output guardrail success', async () => { const guardrailFn = vi.fn(async () => ({ tripwireTriggered: false,