Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/refactor-extract-prepare-helpers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@funkai/agents": patch
---

Extract shared setup into prepareGeneration() and prepareFlowAgent() helpers to deduplicate generate/stream methods
266 changes: 149 additions & 117 deletions packages/agents/src/core/agents/base/agent.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { generateText, streamText, stepCountIs } from "ai";
import type { AsyncIterableStream } from "ai";
import type { AsyncIterableStream, LanguageModel } from "ai";

import { resolveOutput } from "@/core/agents/base/output.js";
import type { OutputSpec } from "@/core/agents/base/output.js";
import type {
Agent,
AgentConfig,
Expand All @@ -20,6 +21,7 @@ import {
toTokenUsage,
} from "@/core/agents/base/utils.js";
import { createDefaultLogger } from "@/core/logger.js";
import type { Logger } from "@/core/logger.js";
import type { Tool } from "@/core/tool.js";
import { fireHooks, wrapHook } from "@/lib/hooks.js";
import { withModelMiddleware } from "@/lib/middleware.js";
Expand Down Expand Up @@ -223,6 +225,120 @@ export function agent<
return { ok: true, input: parsed.data as TInput };
}

/**
* Resolved values shared by both `generate()` and `stream()`.
*
* Returned by `prepareGeneration()` so each method only contains
* the logic that differs (the AI SDK call and result handling).
*
* @private
*/
interface PreparedGeneration {
readonly input: TInput;
readonly model: LanguageModel;
readonly aiTools: ReturnType<typeof buildAITools>;
readonly system: string | undefined;
readonly promptParams: { prompt: string } | { messages: Message[] };
readonly output: OutputSpec | undefined;
readonly maxSteps: number;
readonly signal: AbortSignal | undefined;
readonly onStepFinish: (step: {
toolCalls?: ReadonlyArray<{ toolName: string } & Record<string, unknown>>;
toolResults?: ReadonlyArray<{ toolName: string } & Record<string, unknown>>;
usage?: { inputTokens?: number; outputTokens?: number; totalTokens?: number };
}) => Promise<void>;
}

/**
* Perform the shared setup for `generate()` and `stream()`.
*
* Resolves the model/tools/system/prompt/output, fires onStart hooks,
* and builds the `onStepFinish` handler. Input validation and logger
* resolution are handled by the caller so that validation errors
* return early while model/tool errors propagate through the caller's
* try/catch.
*
* @private
*/
async function prepareGeneration(
input: TInput,
log: Logger,
overrides: AgentOverrides<TTools, TSubAgents> | undefined,
): Promise<PreparedGeneration> {
const overrideModel = readOverride(overrides, "model");
const modelRef = overrideModel ?? config.model;
const baseModel = resolveModel(modelRef, config.resolver);
const model = await withModelMiddleware({ model: baseModel });

const overrideTools = readOverride(overrides, "tools");
const overrideAgents = readOverride(overrides, "agents");
const mergedTools = { ...config.tools, ...overrideTools } as Record<string, Tool>;
const mergedAgents = { ...config.agents, ...overrideAgents } as SubAgents;
const hasTools = Object.keys(mergedTools).length > 0;
const hasAgents = Object.keys(mergedAgents).length > 0;

const aiTools = buildAITools(
valueOrUndefined(hasTools, mergedTools),
valueOrUndefined(hasAgents, mergedAgents),
);

const overrideSystem = readOverride(overrides, "system");
const systemConfig = overrideSystem ?? config.system;
const system = resolveSystem(systemConfig, input);

const promptParams = buildPrompt(input, config);

const overrideOutput = readOverride(overrides, "output");
const outputParam = overrideOutput ?? config.output;
const output = resolveOptionalOutput(outputParam);

const overrideMaxSteps = readOverride(overrides, "maxSteps");
const maxSteps = overrideMaxSteps ?? config.maxSteps ?? 20;
const signal = readOverride(overrides, "signal");

await fireHooks(
log,
wrapHook(config.onStart, { input }),
wrapHook(readOverride(overrides, "onStart"), { input }),
);

const stepCounter = { value: 0 };
const onStepFinish = async (step: {
toolCalls?: ReadonlyArray<{ toolName: string } & Record<string, unknown>>;
toolResults?: ReadonlyArray<{ toolName: string } & Record<string, unknown>>;
usage?: { inputTokens?: number; outputTokens?: number; totalTokens?: number };
}) => {
const stepId = `${config.name}:${stepCounter.value++}`;
const toolCalls = (step.toolCalls ?? []).map((tc) => {
const args = extractProperty(tc, "args");
return { toolName: tc.toolName, argsTextLength: safeSerializedLength(args) };
});
const toolResults = (step.toolResults ?? []).map((tr) => {
const result = extractProperty(tr, "result");
return { toolName: tr.toolName, resultTextLength: safeSerializedLength(result) };
});
const usage = extractUsage(step.usage);
const event = { stepId, toolCalls, toolResults, usage };
await fireHooks(
log,
wrapHook(config.onStepFinish, event),
wrapHook(readOverride(overrides, "onStepFinish"), event),
);
};

return {
input,
model,
aiTools,
system,
promptParams,
output,
maxSteps,
signal,
onStepFinish,
};
}

async function generate(
rawInput: TInput,
overrides?: AgentOverrides<TTools, TSubAgents>,
Expand All @@ -231,78 +347,36 @@ export function agent<
if (!validated.ok) {
return { ok: false, error: validated.error };
}
const input = validated.input;

const overrideLogger = readOverride(overrides, "logger");
const log = (overrideLogger ?? baseLogger).child({ agentId: config.name });
const startedAt = Date.now();

try {
const overrideModel = readOverride(overrides, "model");
const modelRef = overrideModel ?? config.model;
const baseModel = resolveModel(modelRef, config.resolver);
const model = await withModelMiddleware({ model: baseModel });

const overrideTools = readOverride(overrides, "tools");
const overrideAgents = readOverride(overrides, "agents");
const mergedTools = { ...config.tools, ...overrideTools } as Record<string, Tool>;
const mergedAgents = { ...config.agents, ...overrideAgents } as SubAgents;
const hasTools = Object.keys(mergedTools).length > 0;
const hasAgents = Object.keys(mergedAgents).length > 0;

const aiTools = buildAITools(
valueOrUndefined(hasTools, mergedTools),
valueOrUndefined(hasAgents, mergedAgents),
);

const overrideSystem = readOverride(overrides, "system");
const systemConfig = overrideSystem ?? config.system;
const system = resolveSystem(systemConfig, input);

const promptParams = buildPrompt(input, config);

const overrideOutput = readOverride(overrides, "output");
const outputParam = overrideOutput ?? config.output;
const output = resolveOptionalOutput(outputParam);

await fireHooks(
log,
wrapHook(config.onStart, { input }),
wrapHook(readOverride(overrides, "onStart"), { input }),
);
const prepared = await prepareGeneration(validated.input, log, overrides);
const {
input,
model,
aiTools,
system,
promptParams,
output,
maxSteps,
signal,
onStepFinish,
} = prepared;

log.debug("agent.generate start", { name: config.name });

const overrideMaxSteps = readOverride(overrides, "maxSteps");
const maxSteps = overrideMaxSteps ?? config.maxSteps ?? 20;
const overrideSignal = readOverride(overrides, "signal");
const stepCounter = { value: 0 };
const aiResult = await generateText({
model,
system,
...promptParams,
tools: aiTools,
output,
stopWhen: stepCountIs(maxSteps),
abortSignal: overrideSignal,
onStepFinish: async (step) => {
const stepId = `${config.name}:${stepCounter.value++}`;
const toolCalls = (step.toolCalls ?? []).map((tc) => {
const args = extractProperty(tc, "args");
return { toolName: tc.toolName, argsTextLength: safeSerializedLength(args) };
});
const toolResults = (step.toolResults ?? []).map((tr) => {
const result = extractProperty(tr, "result");
return { toolName: tr.toolName, resultTextLength: safeSerializedLength(result) };
});
const usage = extractUsage(step.usage);
const event = { stepId, toolCalls, toolResults, usage };
await fireHooks(
log,
wrapHook(config.onStepFinish, event),
wrapHook(readOverride(overrides, "onStepFinish"), event),
);
},
abortSignal: signal,
onStepFinish,
});

const duration = Date.now() - startedAt;
Expand Down Expand Up @@ -335,8 +409,8 @@ export function agent<

await fireHooks(
log,
wrapHook(config.onError, { input, error }),
wrapHook(readOverride(overrides, "onError"), { input, error }),
wrapHook(config.onError, { input: validated.input, error }),
wrapHook(readOverride(overrides, "onError"), { input: validated.input, error }),
);

return {
Expand All @@ -358,78 +432,36 @@ export function agent<
if (!validated.ok) {
return { ok: false, error: validated.error };
}
const input = validated.input;

const overrideLogger = readOverride(overrides, "logger");
const log = (overrideLogger ?? baseLogger).child({ agentId: config.name });
const startedAt = Date.now();

try {
const overrideModel = readOverride(overrides, "model");
const modelRef = overrideModel ?? config.model;
const baseModel = resolveModel(modelRef, config.resolver);
const model = await withModelMiddleware({ model: baseModel });

const overrideTools = readOverride(overrides, "tools");
const overrideAgents = readOverride(overrides, "agents");
const mergedTools = { ...config.tools, ...overrideTools } as Record<string, Tool>;
const mergedAgents = { ...config.agents, ...overrideAgents } as SubAgents;
const hasTools = Object.keys(mergedTools).length > 0;
const hasAgents = Object.keys(mergedAgents).length > 0;

const aiTools = buildAITools(
valueOrUndefined(hasTools, mergedTools),
valueOrUndefined(hasAgents, mergedAgents),
);

const overrideSystem = readOverride(overrides, "system");
const systemConfig = overrideSystem ?? config.system;
const system = resolveSystem(systemConfig, input);

const promptParams = buildPrompt(input, config);

const overrideOutput = readOverride(overrides, "output");
const outputParam = overrideOutput ?? config.output;
const output = resolveOptionalOutput(outputParam);

await fireHooks(
log,
wrapHook(config.onStart, { input }),
wrapHook(readOverride(overrides, "onStart"), { input }),
);
const prepared = await prepareGeneration(validated.input, log, overrides);
const {
input,
model,
aiTools,
system,
promptParams,
output,
maxSteps,
signal,
onStepFinish,
} = prepared;

log.debug("agent.stream start", { name: config.name });

const overrideMaxSteps = readOverride(overrides, "maxSteps");
const maxSteps = overrideMaxSteps ?? config.maxSteps ?? 20;
const overrideSignal = readOverride(overrides, "signal");
const stepCounter = { value: 0 };
const aiResult = streamText({
model,
system,
...promptParams,
tools: aiTools,
output,
stopWhen: stepCountIs(maxSteps),
abortSignal: overrideSignal,
onStepFinish: async (step) => {
const stepId = `${config.name}:${stepCounter.value++}`;
const toolCalls = (step.toolCalls ?? []).map((tc) => {
const args = extractProperty(tc, "args");
return { toolName: tc.toolName, argsTextLength: safeSerializedLength(args) };
});
const toolResults = (step.toolResults ?? []).map((tr) => {
const result = extractProperty(tr, "result");
return { toolName: tr.toolName, resultTextLength: safeSerializedLength(result) };
});
const usage = extractUsage(step.usage);
const event = { stepId, toolCalls, toolResults, usage };
await fireHooks(
log,
wrapHook(config.onStepFinish, event),
wrapHook(readOverride(overrides, "onStepFinish"), event),
);
},
abortSignal: signal,
onStepFinish,
});

const { readable, writable } = new TransformStream<StreamPart, StreamPart>();
Expand Down Expand Up @@ -521,8 +553,8 @@ export function agent<

await fireHooks(
log,
wrapHook(config.onError, { input, error }),
wrapHook(readOverride(overrides, "onError"), { input, error }),
wrapHook(config.onError, { input: validated.input, error }),
wrapHook(readOverride(overrides, "onError"), { input: validated.input, error }),
);

return {
Expand Down
Loading