Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions js/ai/src/formats/array.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import { GenkitError } from '@genkit-ai/core';
import { extractItems } from '../extract';
import type { Formatter } from './types';

export const arrayFormatter: Formatter<unknown[], unknown[], number> = {
export const arrayFormatter: Formatter<unknown[], unknown[]> = {
name: 'array',
config: {
contentType: 'application/json',
Expand All @@ -43,16 +43,15 @@ export const arrayFormatter: Formatter<unknown[], unknown[], number> = {
}

return {
parseChunk: (chunk, cursor = 0) => {
const { items, cursor: newCursor } = extractItems(
chunk.accumulatedText,
cursor
);
parseChunk: (chunk) => {
// first, determine the cursor position from the previous chunks
const cursor = chunk.previousChunks?.length
? extractItems(chunk.previousText).cursor
: 0;
// then, extract the items starting at that cursor
const { items } = extractItems(chunk.accumulatedText, cursor);

return {
output: items,
cursor: newCursor,
};
return items;
},

parseResponse: (response) => {
Expand Down
2 changes: 1 addition & 1 deletion js/ai/src/formats/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ export async function resolveFormat(
return arg as Formatter;
}

export const DEFAULT_FORMATS: Formatter<any, any, any>[] = [
export const DEFAULT_FORMATS: Formatter<any, any>[] = [
jsonFormatter,
arrayFormatter,
textFormatter,
Expand Down
9 changes: 3 additions & 6 deletions js/ai/src/formats/json.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import { extractJson } from '../extract';
import type { Formatter } from './types';

export const jsonFormatter: Formatter<unknown, unknown, string> = {
export const jsonFormatter: Formatter<unknown, unknown> = {
name: 'json',
config: {
contentType: 'application/json',
Expand All @@ -36,11 +36,8 @@ ${JSON.stringify(request.output!.schema!)}
}

return {
parseChunk: (chunk, cursor = '') => {
return {
output: extractJson(chunk.accumulatedText),
cursor: chunk.accumulatedText,
};
parseChunk: (chunk) => {
return extractJson(chunk.accumulatedText);
},

parseResponse: (response) => {
Expand Down
41 changes: 25 additions & 16 deletions js/ai/src/formats/jsonl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function objectLines(text: string): string[] {
.filter((line) => line.startsWith('{'));
}

export const jsonlFormatter: Formatter<unknown[], unknown[], number> = {
export const jsonlFormatter: Formatter<unknown[], unknown[]> = {
name: 'jsonl',
config: {
contentType: 'application/jsonl',
Expand Down Expand Up @@ -54,27 +54,36 @@ ${JSON.stringify(request.output.schema.items)}
}

return {
parseChunk: (chunk, cursor = 0) => {
const jsonLines = objectLines(chunk.accumulatedText);
parseChunk: (chunk) => {
const results: unknown[] = [];
let newCursor = cursor;

for (let i = cursor; i < jsonLines.length; i++) {
try {
const result = JSON5.parse(jsonLines[i]);
if (result) {
results.push(result);
const text = chunk.accumulatedText;

let startIndex = 0;
if (chunk.previousChunks?.length) {
const lastNewline = chunk.previousText.lastIndexOf('\n');
if (lastNewline !== -1) {
startIndex = lastNewline + 1;
}
}

const lines = text.slice(startIndex).split('\n');

for (const line of lines) {
const trimmed = line.trim();
if (trimmed.startsWith('{')) {
try {
const result = JSON5.parse(trimmed);
if (result) {
results.push(result);
}
} catch (e) {
break;
}
newCursor = i + 1;
} catch (e) {
break;
}
}

return {
output: results,
cursor: newCursor,
};
return results;
},

parseResponse: (response) => {
Expand Down
4 changes: 1 addition & 3 deletions js/ai/src/formats/text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ export const textFormatter: Formatter<string, string> = {
handler: () => {
return {
parseChunk: (chunk) => {
return {
output: chunk.text,
};
return chunk.text;
},

parseResponse: (response) => {
Expand Down
17 changes: 2 additions & 15 deletions js/ai/src/formats/types.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,17 @@
import { GenerateResponse, GenerateResponseChunk } from '../generate.js';
import { ModelRequest, Part } from '../model.js';

export interface ParsedChunk<CO = unknown, CC = unknown> {
output: CO;
/**
* The cursor of a parsed chunk response holds context that is relevant to continue parsing.
* The returned cursor will be passed into the next iteration of the chunk parser. Cursors
* are not exposed to external consumers of the formatter.
*/
cursor?: CC;
}

/**
 * Union of the content-type strings named in this module: JSON, plain text,
 * and JSON Lines.
 *
 * NOTE(review): this alias is not referenced in the visible part of the file;
 * confirm where it is consumed.
 */
type OutputContentTypes =
| 'application/json'
| 'text/plain'
| 'application/jsonl';

/**
 * Describes an output format: its name, the output configuration it
 * contributes to the model request, and a handler that builds the parsers
 * for one request.
 *
 * @typeParam O - Value produced by `parseResponse` for a complete response.
 * @typeParam CO - Value produced by `parseChunk` for a streamed chunk.
 */
export interface Formatter<O = unknown, CO = unknown> {
  name: string;
  config: ModelRequest['output'];
  handler: (req: ModelRequest) => {
    /** Parses the complete response into the format's output type. */
    parseResponse(response: GenerateResponse): O;
    /**
     * Parses a streamed chunk and returns the parsed value directly.
     *
     * The optional `cursor` argument is a leftover from the removed cursor
     * mechanism. The `CC` type parameter it used to refer to is no longer
     * declared on this interface, so it is typed `unknown` here; it is kept
     * only so call sites that still pass a second argument keep
     * type-checking.
     */
    parseChunk?: (chunk: GenerateResponseChunk, cursor?: unknown) => CO;
    /** Optional instructions supplied by the format, as text or parts. */
    instructions?: string | Part[];
  };
}
6 changes: 4 additions & 2 deletions js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,14 @@ async function generate(
streamingCallback
? (chunk: GenerateResponseChunkData) => {
// Store accumulated chunk data
accumulatedChunks.push(chunk);
if (streamingCallback) {
streamingCallback!(
new GenerateResponseChunk(chunk, accumulatedChunks)
new GenerateResponseChunk(chunk, {
previousChunks: accumulatedChunks,
})
);
}
accumulatedChunks.push(chunk);
}
: undefined,
async () => {
Expand Down
60 changes: 43 additions & 17 deletions js/ai/src/generate/chunk.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,50 @@

import { GenkitError } from '@genkit-ai/core';
import { extractJson } from '../extract.js';
import { GenerateResponseChunkData, Part, ToolRequestPart } from '../model.js';
import {
GenerateResponseChunkData,
Part,
Role,
ToolRequestPart,
} from '../model.js';

/**
 * Call signature of a chunk parser: it receives a `GenerateResponseChunk`
 * and returns a value of type `T`. `GenerateResponseChunk` stores one as its
 * `parser` property and invokes it with the chunk itself when `output` is
 * read.
 */
export interface ChunkParser<T = unknown> {
(chunk: GenerateResponseChunk<T>): T;
}

export class GenerateResponseChunk<T = unknown>
implements GenerateResponseChunkData
{
/** The index of the candidate this chunk corresponds to. */
/** The index of the message this chunk corresponds to, starting with `0` for the first model response of the generation. */
index?: number;
/** The role of the message this chunk corresponds to. Will always be `model` or `tool`. */
role?: Role;
/** The content generated in this chunk. */
content: Part[];
/** Custom model-specific data for this chunk. */
custom?: unknown;
/** Accumulated chunks for partial output extraction. */
accumulatedChunks?: GenerateResponseChunkData[];
previousChunks?: GenerateResponseChunkData[];
/** The parser to be used to parse `output` from this chunk. */
parser?: ChunkParser<T>;

constructor(
data: GenerateResponseChunkData,
accumulatedChunks?: GenerateResponseChunkData[]
options?: {
previousChunks?: GenerateResponseChunkData[];
role?: Role;
index?: number;
parser?: ChunkParser<T>;
}
) {
this.index = data.index;
this.content = data.content || [];
this.custom = data.custom;
this.accumulatedChunks = accumulatedChunks;
this.previousChunks = options?.previousChunks
? [...options.previousChunks]
: undefined;
this.index = options?.index;
this.role = options?.role;
this.parser = options?.parser;
}

/**
Expand All @@ -53,13 +75,20 @@ export class GenerateResponseChunk<T = unknown>
* @returns A string of all concatenated chunk text content.
*/
get accumulatedText(): string {
if (!this.accumulatedChunks)
return this.previousText + this.text;
}

/**
* Concatenates all `text` parts of all preceding chunks.
*/
get previousText(): string {
if (!this.previousChunks)
throw new GenkitError({
status: 'FAILED_PRECONDITION',
message: 'Cannot compose accumulated text without accumulated chunks.',
message: 'Cannot compose accumulated text without previous chunks.',
});

return this.accumulatedChunks
return this.previousChunks
?.map((c) => c.content.map((p) => p.text || '').join(''))
.join('');
}
Expand Down Expand Up @@ -92,18 +121,15 @@ export class GenerateResponseChunk<T = unknown>
}

/**
 * Parses the chunk into the desired output format using the parser associated
 * with the generate request, or falls back to naive JSON parsing otherwise.
 *
 * @returns The parser's result when a parser is set; otherwise `this.data`
 * when it is truthy; otherwise JSON extracted from the accumulated text.
 * Returns `null` when that last fallback is reached but no previous chunks
 * were supplied.
 */
get output(): T | null {
  if (this.parser) return this.parser(this);
  return (
    this.data ||
    // `accumulatedText` reads `previousText`, which throws when the chunk was
    // built without previous chunks. Reading `output` should not throw in
    // that case, so report "nothing parsed" with `null` instead.
    (this.previousChunks ? extractJson(this.accumulatedText) : null)
  );
}

toJSON(): GenerateResponseChunkData {
return { index: this.index, content: this.content, custom: this.custom };
return { content: this.content, custom: this.custom };
}
}
8 changes: 3 additions & 5 deletions js/ai/tests/formats/array_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,13 @@ describe('arrayFormat', () => {
const newChunk: GenerateResponseChunkData = {
content: [{ text: chunk.text }],
};
chunks.push(newChunk);

const result = parser.parseChunk!(
new GenerateResponseChunk(newChunk, chunks),
lastCursor
new GenerateResponseChunk(newChunk, { previousChunks: chunks })
);
chunks.push(newChunk);

assert.deepStrictEqual(result.output, chunk.want);
lastCursor = result.cursor!;
assert.deepStrictEqual(result, chunk.want);
}
});
}
Expand Down
7 changes: 3 additions & 4 deletions js/ai/tests/formats/json_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,14 @@ describe('jsonFormat', () => {
const newChunk: GenerateResponseChunkData = {
content: [{ text: chunk.text }],
};
chunks.push(newChunk);

const result = parser.parseChunk!(
new GenerateResponseChunk(newChunk, chunks),
new GenerateResponseChunk(newChunk, { previousChunks: [...chunks] }),
lastCursor
);
chunks.push(newChunk);

assert.deepStrictEqual(result.output, chunk.want);
lastCursor = result.cursor!;
assert.deepStrictEqual(result, chunk.want);
}
});
}
Expand Down
9 changes: 3 additions & 6 deletions js/ai/tests/formats/jsonl_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,18 @@ describe('jsonlFormat', () => {
it(st.desc, () => {
const parser = jsonlFormatter.handler({ messages: [] });
const chunks: GenerateResponseChunkData[] = [];
let lastCursor = 0;

for (const chunk of st.chunks) {
const newChunk: GenerateResponseChunkData = {
content: [{ text: chunk.text }],
};
chunks.push(newChunk);

const result = parser.parseChunk!(
new GenerateResponseChunk(newChunk, chunks),
lastCursor
new GenerateResponseChunk(newChunk, { previousChunks: chunks })
);
chunks.push(newChunk);

assert.deepStrictEqual(result.output, chunk.want);
lastCursor = result.cursor!;
assert.deepStrictEqual(result, chunk.want);
}
});
}
Expand Down
6 changes: 3 additions & 3 deletions js/ai/tests/formats/text_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ describe('textFormat', () => {
const newChunk: GenerateResponseChunkData = {
content: [{ text: chunk.text }],
};
chunks.push(newChunk);

const result = parser.parseChunk!(
new GenerateResponseChunk(newChunk, chunks)
new GenerateResponseChunk(newChunk, { previousChunks: chunks })
);
chunks.push(newChunk);

assert.strictEqual(result.output, chunk.want);
assert.strictEqual(result, chunk.want);
}
});
}
Expand Down
Loading
Loading