Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/bold-dancers-see.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@openai/agents-core': patch
---

feat: #679 Add runInParallel option to input guardrail initialization
5 changes: 3 additions & 2 deletions packages/agents-core/src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,9 @@ export interface AgentConfiguration<
mcpServers: MCPServer[];

/**
* A list of checks that run in parallel to the agent's execution, before generating a response.
* Runs only if the agent is the first agent in the chain.
* A list of checks that run in parallel to the agent by default; set `runInParallel` to false to
* block LLM/tool calls until the guardrail completes. Runs only if the agent is the first agent
* in the chain.
*/
inputGuardrails: InputGuardrail[];

Expand Down
10 changes: 10 additions & 0 deletions packages/agents-core/src/guardrail.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ export interface InputGuardrail {
* The function that performs the guardrail check
*/
execute: InputGuardrailFunction;

/**
* Whether the guardrail should execute alongside the agent (true, default) or block the
* agent until it completes (false).
*/
runInParallel?: boolean;
}

/**
Expand Down Expand Up @@ -105,6 +111,7 @@ export interface InputGuardrailMetadata {
*/
export interface InputGuardrailDefinition extends InputGuardrailMetadata {
guardrailFunction: InputGuardrailFunction;
runInParallel: boolean;
run(args: InputGuardrailFunctionArgs): Promise<InputGuardrailResult>;
}

Expand All @@ -114,6 +121,7 @@ export interface InputGuardrailDefinition extends InputGuardrailMetadata {
export interface DefineInputGuardrailArgs {
name: string;
execute: InputGuardrailFunction;
runInParallel?: boolean;
}

/**
Expand All @@ -122,10 +130,12 @@ export interface DefineInputGuardrailArgs {
export function defineInputGuardrail({
name,
execute,
runInParallel = true,
}: DefineInputGuardrailArgs): InputGuardrailDefinition {
return {
type: 'input',
name,
runInParallel,
guardrailFunction: execute,
async run(args: InputGuardrailFunctionArgs): Promise<InputGuardrailResult> {
return {
Expand Down
94 changes: 88 additions & 6 deletions packages/agents-core/src/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
defineOutputGuardrail,
InputGuardrail,
InputGuardrailDefinition,
InputGuardrailResult,
OutputGuardrail,
OutputGuardrailDefinition,
OutputGuardrailFunctionArgs,
Expand Down Expand Up @@ -603,6 +604,34 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
AgentOutputType<unknown>
>[];

#getInputGuardrailDefinitions<
TContext,
TAgent extends Agent<TContext, AgentOutputType>,
>(state: RunState<TContext, TAgent>): InputGuardrailDefinition[] {
return this.inputGuardrailDefs.concat(
state._currentAgent.inputGuardrails.map(defineInputGuardrail),
);
}

#splitInputGuardrails<
TContext,
TAgent extends Agent<TContext, AgentOutputType>,
>(state: RunState<TContext, TAgent>) {
const guardrails = this.#getInputGuardrailDefinitions(state);
const blocking: InputGuardrailDefinition[] = [];
const parallel: InputGuardrailDefinition[] = [];

for (const guardrail of guardrails) {
if (guardrail.runInParallel === false) {
blocking.push(guardrail);
} else {
parallel.push(guardrail);
}
}

return { blocking, parallel };
}

/**
* @internal
* Resolves the effective model once so both run loops obey the same precedence rules.
Expand Down Expand Up @@ -738,8 +767,21 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
`Running agent ${state._currentAgent.name} (turn ${state._currentTurn})`,
);

let parallelGuardrailPromise:
| Promise<InputGuardrailResult[]>
| undefined;
if (state._currentTurn === 1) {
await this.#runInputGuardrails(state);
const guardrails = this.#splitInputGuardrails(state);
if (guardrails.blocking.length > 0) {
await this.#runInputGuardrails(state, guardrails.blocking);
}
if (guardrails.parallel.length > 0) {
parallelGuardrailPromise = this.#runInputGuardrails(
state,
guardrails.parallel,
);
parallelGuardrailPromise.catch(() => {});
}
}

const turnInput = serverConversationTracker
Expand Down Expand Up @@ -829,6 +871,10 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
state._currentTurnPersistedItemCount = 0;
}
state._currentStep = turnResult.nextStep;

if (parallelGuardrailPromise) {
await parallelGuardrailPromise;
}
}

if (
Expand Down Expand Up @@ -1007,8 +1053,25 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
`Running agent ${currentAgent.name} (turn ${result.state._currentTurn})`,
);

let guardrailError: unknown;
let parallelGuardrailPromise:
| Promise<InputGuardrailResult[]>
| undefined;
if (result.state._currentTurn === 1) {
await this.#runInputGuardrails(result.state);
const guardrails = this.#splitInputGuardrails(result.state);
if (guardrails.blocking.length > 0) {
await this.#runInputGuardrails(result.state, guardrails.blocking);
}
if (guardrails.parallel.length > 0) {
const promise = this.#runInputGuardrails(
result.state,
guardrails.parallel,
);
parallelGuardrailPromise = promise.catch((err) => {
guardrailError = err;
return [];
});
}
}

const turnInput = serverConversationTracker
Expand Down Expand Up @@ -1038,6 +1101,10 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
sessionInputUpdate,
);

if (guardrailError) {
throw guardrailError;
}

handedInputToModel = true;
await persistStreamInputIfNeeded();

Expand All @@ -1064,6 +1131,9 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
),
signal: options.signal,
})) {
if (guardrailError) {
throw guardrailError;
}
if (event.type === 'response_done') {
const parsed = StreamEventResponseCompleted.parse(event);
finalResponse = {
Expand All @@ -1080,6 +1150,13 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
result._addItem(new RunRawModelStreamEvent(event));
}

if (parallelGuardrailPromise) {
await parallelGuardrailPromise;
if (guardrailError) {
throw guardrailError;
}
}

result.state._noActiveAgentRun = false;

if (!finalResponse) {
Expand Down Expand Up @@ -1276,10 +1353,12 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
async #runInputGuardrails<
TContext,
TAgent extends Agent<TContext, AgentOutputType>,
>(state: RunState<TContext, TAgent>) {
const guardrails = this.inputGuardrailDefs.concat(
state._currentAgent.inputGuardrails.map(defineInputGuardrail),
);
>(
state: RunState<TContext, TAgent>,
guardrailsOverride?: InputGuardrailDefinition[],
): Promise<InputGuardrailResult[]> {
const guardrails =
guardrailsOverride ?? this.#getInputGuardrailDefinitions(state);
if (guardrails.length > 0) {
const guardrailArgs = {
agent: state._currentAgent,
Expand All @@ -1300,6 +1379,7 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
);
}),
);
state._inputGuardrailResults.push(...results);
for (const result of results) {
if (result.output.tripwireTriggered) {
if (state._currentAgentSpan) {
Expand All @@ -1315,6 +1395,7 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
);
}
}
return results;
} catch (e) {
if (e instanceof InputGuardrailTripwireTriggered) {
throw e;
Expand All @@ -1328,6 +1409,7 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
);
}
}
return [];
}

async #runOutputGuardrails<
Expand Down
25 changes: 25 additions & 0 deletions packages/agents-core/test/guardrail.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,31 @@ describe('guardrail helpers', () => {
expect(agent.outputGuardrails[0].name).toEqual('og');
});

it('defaults input guardrails to run in parallel', () => {
const guardrail = defineInputGuardrail({
name: 'ig',
execute: async (_args) => ({
outputInfo: { ok: true },
tripwireTriggered: false,
}),
});

expect(guardrail.runInParallel).toBe(true);
});

it('uses configured runInParallel value for input guardrails', () => {
const guardrail = defineInputGuardrail({
name: 'blocking',
execute: async (_args) => ({
outputInfo: { ok: true },
tripwireTriggered: false,
}),
runInParallel: false,
});

expect(guardrail.runInParallel).toBe(false);
});

it('executes input guardrail and returns expected result', async () => {
const guardrailFn = vi.fn(async (_args: InputGuardrailFunctionArgs) => ({
outputInfo: { ok: true },
Expand Down
61 changes: 61 additions & 0 deletions packages/agents-core/test/run.stream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,7 @@ describe('Runner.run (streaming)', () => {

const guardrail = {
name: 'block',
runInParallel: false,
execute: vi.fn().mockResolvedValue({
tripwireTriggered: true,
outputInfo: { reason: 'blocked' },
Expand Down Expand Up @@ -973,6 +974,66 @@ describe('Runner.run (streaming)', () => {

expect(saveInputSpy).not.toHaveBeenCalled();
});

it('runs blocking input guardrails before streaming starts', async () => {
let guardrailFinished = false;

const guardrail = {
name: 'blocking',
runInParallel: false,
execute: vi.fn(async () => {
await Promise.resolve();
guardrailFinished = true;
return {
tripwireTriggered: false,
outputInfo: { ok: true },
};
}),
};

class ExpectGuardrailBeforeStreamModel implements Model {
getResponse(_request: ModelRequest): Promise<ModelResponse> {
throw new Error('Unexpected call to getResponse');
}

async *getStreamedResponse(
_request: ModelRequest,
): AsyncIterable<StreamEvent> {
expect(guardrailFinished).toBe(true);
yield {
type: 'response_done',
response: {
id: 'stream1',
usage: {
requests: 1,
inputTokens: 0,
outputTokens: 0,
totalTokens: 0,
},
output: [fakeModelMessage('ok')],
},
} satisfies StreamEvent;
}
}

const agent = new Agent({
name: 'BlockingStreamAgent',
model: new ExpectGuardrailBeforeStreamModel(),
inputGuardrails: [guardrail],
});

const runner = new Runner();
const result = await runner.run(agent, 'hi', { stream: true });

for await (const _ of result.toStream()) {
// consume
}
await result.completed;

expect(result.finalOutput).toBe('ok');
expect(result.inputGuardrailResults).toHaveLength(1);
expect(guardrail.execute).toHaveBeenCalledTimes(1);
});
});

class ImmediateStreamingModel implements Model {
Expand Down
43 changes: 43 additions & 0 deletions packages/agents-core/test/run.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,49 @@ describe('Runner.run', () => {
expect(guardrailFn).toHaveBeenCalledTimes(1);
});

it('waits for blocking input guardrails before calling the model', async () => {
let guardrailCompleted = false;
const blockingGuardrail = {
name: 'blocking-ig',
runInParallel: false,
execute: vi.fn(async () => {
await Promise.resolve();
guardrailCompleted = true;
return { tripwireTriggered: false, outputInfo: {} };
}),
};

class ExpectGuardrailFirstModel implements Model {
calls = 0;

async getResponse(_request: ModelRequest): Promise<ModelResponse> {
this.calls++;
expect(guardrailCompleted).toBe(true);
return {
output: [fakeModelMessage('done')],
usage: new Usage(),
};
}

/* eslint-disable require-yield */
async *getStreamedResponse(_request: ModelRequest) {
throw new Error('not implemented');
}
/* eslint-enable require-yield */
}

const agent = new Agent({
name: 'BlockingGuard',
model: new ExpectGuardrailFirstModel(),
inputGuardrails: [blockingGuardrail],
});

const result = await run(agent, 'hello');
expect(result.finalOutput).toBe('done');
expect(result.inputGuardrailResults).toHaveLength(1);
expect(blockingGuardrail.execute).toHaveBeenCalledTimes(1);
});

it('output guardrail success', async () => {
const guardrailFn = vi.fn(async () => ({
tripwireTriggered: false,
Expand Down