Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 96 additions & 19 deletions js/src/wrappers/ai-sdk/ai-sdk.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,20 @@ describe("ai sdk client unit tests", TEST_SUITE_OPTIONS, () => {
expect(metrics.start).toBeLessThanOrEqual(metrics.end);
expect(metrics.end).toBeLessThanOrEqual(end);

expect(metrics.tokens).toBeGreaterThan(0);
expect(metrics.prompt_tokens).toBeGreaterThan(0);
expect(metrics.completion_tokens).toBeGreaterThan(0);
// Token/cost metrics live on the child doGenerate span to avoid
// double-counting. Parent span should NOT have them.
expect(metrics.tokens).toBeUndefined();
expect(metrics.prompt_tokens).toBeUndefined();
expect(metrics.completion_tokens).toBeUndefined();

// Verify child doGenerate span carries the metrics
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const doGenSpan = spans.find(
(s: any) => s.span_attributes?.name === "doGenerate",
) as any;
expect(doGenSpan).toBeDefined();
expect(doGenSpan.metrics.prompt_tokens).toBeGreaterThan(0);
expect(doGenSpan.metrics.completion_tokens).toBeGreaterThan(0);

// Check that output is present and not omitted
expect(span.output).toBeDefined();
Expand Down Expand Up @@ -268,9 +279,20 @@ describe("ai sdk client unit tests", TEST_SUITE_OPTIONS, () => {
expect(metrics.start).toBeLessThanOrEqual(metrics.end);
expect(metrics.end).toBeLessThanOrEqual(end);

expect(metrics.tokens).toBeGreaterThan(0);
expect(metrics.prompt_tokens).toBeGreaterThan(0);
expect(metrics.completion_tokens).toBeGreaterThan(0);
// Token/cost metrics live on the child doGenerate span to avoid
// double-counting. Parent span should NOT have them.
expect(metrics.tokens).toBeUndefined();
expect(metrics.prompt_tokens).toBeUndefined();
expect(metrics.completion_tokens).toBeUndefined();

// Verify child doGenerate span carries the metrics
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const doGenSpan = spans.find(
(s: any) => s.span_attributes?.name === "doGenerate",
) as any;
expect(doGenSpan).toBeDefined();
expect(doGenSpan.metrics.prompt_tokens).toBeGreaterThan(0);
expect(doGenSpan.metrics.completion_tokens).toBeGreaterThan(0);

// Verify image content is properly handled as attachment
const messageContent = span.input.messages[0].content;
Expand Down Expand Up @@ -373,9 +395,20 @@ describe("ai sdk client unit tests", TEST_SUITE_OPTIONS, () => {
expect(metrics.start).toBeLessThanOrEqual(metrics.end);
expect(metrics.end).toBeLessThanOrEqual(end);

expect(metrics.tokens).toBeGreaterThan(0);
expect(metrics.prompt_tokens).toBeGreaterThan(0);
expect(metrics.completion_tokens).toBeGreaterThan(0);
// Token/cost metrics live on the child doGenerate span to avoid
// double-counting. Parent span should NOT have them.
expect(metrics.tokens).toBeUndefined();
expect(metrics.prompt_tokens).toBeUndefined();
expect(metrics.completion_tokens).toBeUndefined();

// Verify child doGenerate span carries the metrics
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const doGenSpan = spans.find(
(s: any) => s.span_attributes?.name === "doGenerate",
) as any;
expect(doGenSpan).toBeDefined();
expect(doGenSpan.metrics.prompt_tokens).toBeGreaterThan(0);
expect(doGenSpan.metrics.completion_tokens).toBeGreaterThan(0);

// Verify file content is properly handled as attachment
const messageContent = span.input.messages[0].content;
Expand Down Expand Up @@ -453,9 +486,20 @@ describe("ai sdk client unit tests", TEST_SUITE_OPTIONS, () => {
expect(ttft).toBeGreaterThanOrEqual(metrics.time_to_first_token);
}

expect(metrics.tokens).toBeGreaterThan(0);
expect(metrics.prompt_tokens).toBeGreaterThan(0);
expect(metrics.completion_tokens).toBeGreaterThan(0);
// Token/cost metrics live on the child doStream span to avoid
// double-counting. Parent span should NOT have them.
expect(metrics.tokens).toBeUndefined();
expect(metrics.prompt_tokens).toBeUndefined();
expect(metrics.completion_tokens).toBeUndefined();

// Verify child doStream span carries the metrics
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const doStreamSpan = spans.find(
(s: any) => s.span_attributes?.name === "doStream",
) as any;
expect(doStreamSpan).toBeDefined();
expect(doStreamSpan.metrics.prompt_tokens).toBeGreaterThan(0);
expect(doStreamSpan.metrics.completion_tokens).toBeGreaterThan(0);
});

test("ai sdk multi-turn conversation", async () => {
Expand Down Expand Up @@ -506,9 +550,20 @@ describe("ai sdk client unit tests", TEST_SUITE_OPTIONS, () => {
expect(metrics.start).toBeLessThanOrEqual(metrics.end);
expect(metrics.end).toBeLessThanOrEqual(end);

expect(metrics.tokens).toBeGreaterThan(0);
expect(metrics.prompt_tokens).toBeGreaterThan(0);
expect(metrics.completion_tokens).toBeGreaterThan(0);
// Token/cost metrics live on the child doGenerate span to avoid
// double-counting. Parent span should NOT have them.
expect(metrics.tokens).toBeUndefined();
expect(metrics.prompt_tokens).toBeUndefined();
expect(metrics.completion_tokens).toBeUndefined();

// Verify child doGenerate span carries the metrics
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const doGenSpan = spans.find(
(s: any) => s.span_attributes?.name === "doGenerate",
) as any;
expect(doGenSpan).toBeDefined();
expect(doGenSpan.metrics.prompt_tokens).toBeGreaterThan(0);
expect(doGenSpan.metrics.completion_tokens).toBeGreaterThan(0);
});

test("ai sdk system prompt", async () => {
Expand Down Expand Up @@ -2375,8 +2430,16 @@ describe.skipIf(!AI_GATEWAY_API_KEY)(
expect(generateTextSpan.metadata.model).toBe("gpt-4o-mini");
expect(generateTextSpan.metadata.provider).toBe("openai");

// Verify cost is extracted from gateway marketCost
expect(generateTextSpan.metrics.estimated_cost).toBeGreaterThan(0);
// Cost should NOT be on the parent span (to avoid double-counting with
// child doGenerate spans). It should be on the doGenerate child span.
expect(generateTextSpan.metrics.estimated_cost).toBeUndefined();

// eslint-disable-next-line @typescript-eslint/no-explicit-any
const doGenerateSpan = spans.find(
(s: any) => s.span_attributes?.name === "doGenerate",
) as any;
expect(doGenerateSpan).toBeDefined();
expect(doGenerateSpan.metrics.estimated_cost).toBeGreaterThan(0);
});

test("multi-step tool use extracts total cost", async () => {
Expand Down Expand Up @@ -2407,8 +2470,22 @@ describe.skipIf(!AI_GATEWAY_API_KEY)(

expect(generateTextSpan).toBeDefined();

// Cost should be sum of all steps
expect(generateTextSpan.metrics.estimated_cost).toBeGreaterThan(0);
// Cost should NOT be on the parent span (to avoid double-counting).
// Individual doGenerate child spans carry per-step costs.
expect(generateTextSpan.metrics.estimated_cost).toBeUndefined();

// eslint-disable-next-line @typescript-eslint/no-explicit-any
const doGenerateSpans = spans.filter(
(s: any) => s.span_attributes?.name === "doGenerate",
) as any[];
expect(doGenerateSpans.length).toBeGreaterThan(0);

// At least one doGenerate span should have cost
const totalCost = doGenerateSpans.reduce(
(sum: number, s: any) => sum + (s.metrics.estimated_cost ?? 0),
0,
);
expect(totalCost).toBeGreaterThan(0);

// Verify model/provider in metadata
expect(generateTextSpan.metadata.model).toBe("gpt-4o-mini");
Expand Down
57 changes: 49 additions & 8 deletions js/src/wrappers/ai-sdk/ai-sdk.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,14 @@ const makeGenerateTextWrapper = (

return traced(
async (span) => {
const { wrappedModel, getMetrics } = wrapModelAndGetMetrics(
params.model,
aiSDK,
);

const result = await generateText({
...params,
model: wrapModel(params.model, aiSDK),
model: wrappedModel,
tools: wrapTools(params.tools),
});

Expand All @@ -230,7 +235,7 @@ const makeGenerateTextWrapper = (

span.log({
output: await processOutput(result, options.denyOutputPaths),
metrics: extractTokenMetrics(result),
metrics: getMetrics(result),
...(Object.keys(resolvedMetadata).length > 0
? { metadata: resolvedMetadata }
: {}),
Expand Down Expand Up @@ -521,6 +526,28 @@ const wrapGenerateText = (
return makeGenerateTextWrapper("generateText", options, generateText, aiSDK);
};

/**
* Wraps the model and returns a metrics extractor for the parent span.
* When the model is wrapped, child doGenerate/doStream spans already carry
* token/cost metrics, so the extractor returns undefined to prevent
* double-counting on the parent.
*/
const wrapModelAndGetMetrics = (
model: any,
aiSDK?: any,
): {
wrappedModel: any;
getMetrics: (result: any) => Record<string, number> | undefined;
} => {
const wrappedModel = wrapModel(model, aiSDK);
const modelIsWrapped = wrappedModel?._braintrustWrapped === true;
return {
wrappedModel,
getMetrics: (result: any) =>
modelIsWrapped ? undefined : extractTokenMetrics(result),
};
};

const wrapGenerateObject = (
generateObject: any,
options: WrapAISDKOptions = {},
Expand All @@ -542,9 +569,14 @@ const wrapGenerateObject = (

return traced(
async (span) => {
const { wrappedModel, getMetrics } = wrapModelAndGetMetrics(
params.model,
aiSDK,
);

const result = await generateObject({
...params,
model: wrapModel(params.model, aiSDK),
model: wrappedModel,
tools: wrapTools(params.tools),
});

Expand All @@ -562,7 +594,7 @@ const wrapGenerateObject = (

span.log({
output,
metrics: extractTokenMetrics(result),
metrics: getMetrics(result),
...(Object.keys(resolvedMetadata).length > 0
? { metadata: resolvedMetadata }
: {}),
Expand Down Expand Up @@ -654,10 +686,15 @@ const makeStreamTextWrapper = (
try {
const startTime = Date.now();
let receivedFirst = false;
const { wrappedModel, getMetrics } = wrapModelAndGetMetrics(
params.model,
aiSDK,
);

const result = withCurrent(span, () =>
streamText({
...params,
model: wrapModel(params.model, aiSDK),
model: wrappedModel,
tools: wrapTools(params.tools),
onChunk: (chunk: any) => {
if (!receivedFirst) {
Expand Down Expand Up @@ -686,7 +723,7 @@ const makeStreamTextWrapper = (

span.log({
output: await processOutput(event, options.denyOutputPaths),
metrics: extractTokenMetrics(event),
metrics: getMetrics(event),
...(Object.keys(resolvedMetadata).length > 0
? { metadata: resolvedMetadata }
: {}),
Expand Down Expand Up @@ -824,11 +861,15 @@ const wrapStreamObject = (
try {
const startTime = Date.now();
let receivedFirst = false;
const { wrappedModel, getMetrics } = wrapModelAndGetMetrics(
params.model,
aiSDK,
);

const result = withCurrent(span, () =>
streamObject({
...params,
model: wrapModel(params.model, aiSDK),
model: wrappedModel,
tools: wrapTools(params.tools),
onChunk: (chunk: any) => {
if (!receivedFirst) {
Expand Down Expand Up @@ -856,7 +897,7 @@ const wrapStreamObject = (

span.log({
output: await processOutput(event, options.denyOutputPaths),
metrics: extractTokenMetrics(event),
metrics: getMetrics(event),
...(Object.keys(resolvedMetadata).length > 0
? { metadata: resolvedMetadata }
: {}),
Expand Down
Loading