Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 63 additions & 17 deletions js/core/src/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ import {
export { Status, StatusCodes, StatusSchema } from './statusTypes.js';
export { JSONSchema7 };

/**
* Action metadata.
*/
export interface ActionMetadata<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
Expand All @@ -43,16 +46,32 @@ export interface ActionMetadata<
metadata?: M;
}

/**
* Results of an action run. Includes telemetry.
*/
export interface ActionResult<O> {
result: O;
telemetry: {
traceId: string;
spanId: string;
};
}

/**
* Self-describing, validating, observable, locally and remotely callable function.
*/
export type Action<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
M extends Record<string, any> = Record<string, any>,
> = ((input: z.infer<I>) => Promise<z.infer<O>>) & {
__action: ActionMetadata<I, O, M>;
run(input: z.infer<I>): Promise<ActionResult<z.infer<O>>>;
};

export type SideChannelData = Record<string, any>;

/**
* Action factory params.
*/
type ActionParams<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
Expand All @@ -73,10 +92,16 @@ type ActionParams<
use?: Middleware<z.infer<I>, z.infer<O>>[];
};

/**
* Middleware function for actions.
*/
export interface Middleware<I = any, O = any> {
(req: I, next: (req?: I) => Promise<O>): Promise<O>;
}

/**
* Creates an action with provided middleware.
*/
export function actionWithMiddleware<
I extends z.ZodTypeAny,
O extends z.ZodTypeAny,
Expand All @@ -86,10 +111,17 @@ export function actionWithMiddleware<
middleware: Middleware<z.infer<I>, z.infer<O>>[]
): Action<I, O, M> {
const wrapped = (async (req: z.infer<I>) => {
return (await wrapped.run(req)).result;
}) as Action<I, O, M>;
wrapped.__action = action.__action;
wrapped.run = async (req: z.infer<I>): Promise<ActionResult<z.infer<O>>> => {
let telemetry;
const dispatch = async (index: number, req: z.infer<I>) => {
if (index === middleware.length) {
// end of the chain, call the original model action
return await action(req);
const result = await action.run(req);
telemetry = result.telemetry;
return result.result;
}

const currentMiddleware = middleware[index];
Expand All @@ -98,9 +130,8 @@ export function actionWithMiddleware<
);
};

return await dispatch(0, req);
}) as Action<I, O, M>;
wrapped.__action = action.__action;
return { result: await dispatch(0, req), telemetry };
};
return wrapped;
}

Expand All @@ -120,18 +151,36 @@ export function action<
? config.name
: `${config.name.pluginId}/${config.name.actionId}`;
const actionFn = async (input: I) => {
return (await actionFn.run(input)).result;
};
actionFn.__action = {
name: actionName,
description: config.description,
inputSchema: config.inputSchema,
inputJsonSchema: config.inputJsonSchema,
outputSchema: config.outputSchema,
outputJsonSchema: config.outputJsonSchema,
metadata: config.metadata,
} as ActionMetadata<I, O, M>;
actionFn.run = async (
input: z.infer<I>
): Promise<ActionResult<z.infer<O>>> => {
input = parseSchema(input, {
schema: config.inputSchema,
jsonSchema: config.inputJsonSchema,
});
let traceId;
let spanId;
let output = await newTrace(
{
name: actionName,
labels: {
[SPAN_TYPE_ATTR]: 'action',
},
},
async (metadata) => {
async (metadata, span) => {
traceId = span.spanContext().traceId;
spanId = span.spanContext().spanId;
metadata.name = actionName;
metadata.input = input;

Expand All @@ -145,17 +194,14 @@ export function action<
schema: config.outputSchema,
jsonSchema: config.outputJsonSchema,
});
return output;
return {
result: output,
telemetry: {
traceId,
spanId,
},
};
};
actionFn.__action = {
name: actionName,
description: config.description,
inputSchema: config.inputSchema,
inputJsonSchema: config.inputJsonSchema,
outputSchema: config.outputSchema,
outputJsonSchema: config.outputJsonSchema,
metadata: config.metadata,
} as ActionMetadata<I, O, M>;

if (config.use) {
return actionWithMiddleware(actionFn, config.use);
Expand Down
53 changes: 15 additions & 38 deletions js/core/src/reflection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,7 @@ import { GENKIT_VERSION } from './index.js';
import { logger } from './logging.js';
import { Registry } from './registry.js';
import { toJsonSchema } from './schema.js';
import {
flushTracing,
newTrace,
setCustomMetadataAttribute,
setTelemetryServerUrl,
} from './tracing.js';
import { flushTracing, setTelemetryServerUrl } from './tracing.js';

// TODO: Move this to common location for schemas.
export const RunActionResponseSchema = z.object({
Expand Down Expand Up @@ -169,48 +164,30 @@ export class ReflectionServer {
return;
}
if (stream === 'true') {
const result = await newTrace(
{ name: 'dev-run-action-wrapper' },
async (_, span) => {
setCustomMetadataAttribute('genkit-dev-internal', 'true');
traceId = span.spanContext().traceId;
return await runWithStreamingCallback(
(chunk) => {
response.write(JSON.stringify(chunk) + '\n');
},
async () => await action(input)
);
}
const result = await runWithStreamingCallback(
(chunk) => {
response.write(JSON.stringify(chunk) + '\n');
},
async () => await action.run(input)
);
await flushTracing();
response.write(
JSON.stringify({
result,
telemetry: traceId
? {
traceId,
}
: undefined,
result: result.result,
telemetry: {
traceId: result.telemetry.traceId,
},
} as RunActionResponse)
);
response.end();
} else {
const result = await newTrace(
{ name: 'dev-run-action-wrapper' },
async (_, span) => {
setCustomMetadataAttribute('genkit-dev-internal', 'true');
traceId = span.spanContext().traceId;
return await action(input);
}
);
const result = await action.run(input);
await flushTracing();
response.send({
result,
telemetry: traceId
? {
traceId,
}
: undefined,
result: result.result,
telemetry: {
traceId: result.telemetry.traceId,
},
} as RunActionResponse);
}
} catch (err) {
Expand Down
32 changes: 32 additions & 0 deletions js/core/tests/action_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,36 @@ describe('action', () => {
20 // "foomiddle1middle2".length + 1 + 2
);
});

it('returns telemetry info', async () => {
const act = action(
{
name: 'foo',
inputSchema: z.string(),
outputSchema: z.number(),
use: [
async (input, next) => (await next(input + 'middle1')) + 1,
async (input, next) => (await next(input + 'middle2')) + 2,
],
},
async (input) => {
return input.length;
}
);

const result = await act.run('foo');
assert.strictEqual(
result.result,
20 // "foomiddle1middle2".length + 1 + 2
);
assert.strictEqual(result.telemetry !== null, true);
assert.strictEqual(
result.telemetry.traceId !== null && result.telemetry.traceId.length > 0,
true
);
assert.strictEqual(
result.telemetry.spanId !== null && result.telemetry.spanId.length > 0,
true
);
});
});
1 change: 0 additions & 1 deletion js/genkit/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ export {
type Middleware,
type ReflectionServerOptions,
type RunActionResponse,
type SideChannelData,
type Status,
type StreamableFlow,
type StreamingCallback,
Expand Down
Loading