-
Notifications
You must be signed in to change notification settings - Fork 76
/
constrained-output.tsx
355 lines (334 loc) · 11.8 KB
/
constrained-output.tsx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
/**
* This module defines affordances for constraining the output of the model
* into specific formats, such as JSON, YAML, or Markdown.
* @packageDocumentation
*/
import yaml from 'js-yaml';
import { Jsonifiable } from 'type-fest';
import z from 'zod';
import { zodToJsonSchema } from 'zod-to-json-schema';
import {
AssistantMessage,
ChatCompletion,
ModelPropsWithChildren,
SystemMessage,
UserMessage,
} from '../core/completion.js';
import { AIJSXError, ErrorBlame, ErrorCode } from '../core/errors.js';
import * as AI from '../index.js';
import { patchedUntruncateJson } from '../lib/util.js';
/** Props shared by the constrained-output completion components in this module. */
export type ObjectCompletion = ModelPropsWithChildren & {
  /**
   * Optional Zod schema the result must satisfy. It works like a validator, and in addition its
   * JSON Schema rendering is described to the model so it knows the expected shape.
   *
   * @note To match OpenAI function definition specs, the schema must be a Zod object.
   * Arrays and other types should be wrapped in a top-level object in order to be used.
   *
   * For example, describing a list of strings like this is not accepted:
   * `const schema: z.Schema = z.array(z.string())`
   *
   * Wrap it in an object instead:
   * `const schema: z.ZodObject = z.object({ arr: z.array(z.string()) })`
   */
  schema?: z.ZodObject<any>;
  /**
   * Extra checks run against the parsed object. A validator reports a problem by throwing;
   * its return value is ignored.
   */
  validators?: Array<(obj: object) => void>;
  // TODO (@farzad): better name/framing for example.
  /**
   * Optional sample output. The prompt-based completion path appends it to the system prompt
   * ("For example: …"); the function-calling JSON path does not use it.
   */
  example?: string;
};
/** {@link ObjectCompletion} props plus the format-specific hooks used to parse and check the output. */
export type TypedObjectCompletion = ObjectCompletion & {
  /** Human-readable format name used in prompts and error messages, e.g. JSON or YAML. */
  typeName: string;
  /**
   * Turns the model's text into the object that the validators inspect.
   * Parsing exists only for validation: the component still produces the string, not the parsed value.
   */
  parser: (str: string) => object;
  /**
   * Optional repair step for intermediate (streamed, possibly incomplete) output before it is
   * parsed, e.g. closing unterminated brackets. See the `untruncate-json` package for an example.
   */
  partialResultCleaner?: (str: string) => string;
};
/** {@link TypedObjectCompletion} props plus a bound on how many attempts the retrying component makes. */
export type TypedObjectCompletionWithRetry = TypedObjectCompletion & {
  /**
   * Upper bound on the total number of attempts, the first one included
   * (`ObjectCompletionWithRetry` defaults it to 3 and always makes at least one attempt).
   */
  retries?: number;
};
/**
 * A {@link ChatCompletion} component that constrains the output to be a valid JSON string.
 * It uses a combination of prompt engineering and validation with retries to ensure that the output is valid.
 *
 * When the model supports function calling, the JSON is produced through a forced function call
 * (see {@link JsonChatCompletionFunctionCall}); that path makes a single attempt. If the model reports
 * that it does not support functions, it falls back to prompting for JSON, with up to `retries` attempts.
 *
 * Though not required, you can provide a Zod schema to validate the output against. This is useful for
 * ensuring that the output is of the expected type. The schema is honored on both paths.
 *
 * @example
 * ```tsx
 * const FamilyTree: z.Schema = z.array(
 *   z.object({
 *     name: z.string(),
 *     children: z.lazy(() => FamilyTree).optional(),
 *   })
 * );
 *
 * const RootFamilyTree: z.ZodObject<any> = z.object({
 *   tree: FamilyTree,
 * });
 *
 * return (
 *   <JsonChatCompletion schema={RootFamilyTree}>
 *     <UserMessage>
 *       Create a nested family tree with names and ages.
 *       It should include a total of 5 people.
 *     </UserMessage>
 *   </JsonChatCompletion>
 * );
 * ```
 * @returns A string that is a valid JSON or throws an error after `retries` attempts.
 * Intermediate results that are valid, are also yielded.
 */
export async function* JsonChatCompletion(
  { schema, ...props }: Omit<TypedObjectCompletionWithRetry, 'typeName' | 'parser' | 'partialResultCleaner'>,
  { render }: AI.ComponentContext
) {
  try {
    return yield* render(<JsonChatCompletionFunctionCall schema={schema ?? z.object({}).nonstrict()} {...props} />);
  } catch (e: unknown) {
    // Only "model does not support functions" triggers the fallback; every other failure
    // (including validation failures) propagates. Guard against non-object throws so that
    // reading `.code` cannot itself throw and mask the original error.
    const code = typeof e === 'object' && e !== null ? (e as { code?: unknown }).code : undefined;
    if (code !== ErrorCode.ChatModelDoesNotSupportFunctions) {
      throw e;
    }
  }
  // Forward the caller's schema so the fallback both describes it in the prompt and validates against it.
  return yield* render(
    <ObjectCompletionWithRetry
      {...props}
      schema={schema}
      typeName="JSON"
      parser={JSON.parse}
      partialResultCleaner={patchedUntruncateJson}
    />
  );
}
/**
 * Parses `str` as YAML and requires the result to be a mapping or a sequence.
 *
 * Almost any text is valid YAML: a bare word or sentence parses to a scalar string, and an empty
 * document parses to `undefined`. Without this check, prose (or nothing at all) would be accepted
 * as a "YAML object" whenever no schema is supplied.
 *
 * @throws Error when the parsed document is a scalar, `null`, or empty.
 */
function parseYamlObject(str: string): object {
  const parsed: unknown = yaml.load(str);
  if (typeof parsed !== 'object' || parsed === null) {
    throw new Error(
      `Expected a YAML mapping or sequence, but the document parsed to ${parsed === null ? 'null' : typeof parsed}.`
    );
  }
  return parsed;
}

/**
 * A {@link ChatCompletion} component that constrains the output to be a valid YAML string.
 * It uses a combination of prompt engineering and validation with retries to ensure that the output is valid.
 *
 * The output must parse to a YAML mapping or sequence; scalar documents (e.g. plain prose) are rejected
 * and retried.
 *
 * Though not required, you can provide a Zod schema to validate the output against. This is useful for
 * ensuring that the output is of the expected type.
 *
 * @example
 * ```tsx
 * const FamilyTree: z.Schema = z.array(
 *   z.object({
 *     name: z.string(),
 *     children: z.lazy(() => FamilyTree).optional(),
 *   })
 * );
 * const RootFamilyTree: z.ZodObject<any> = z.object({
 *   tree: FamilyTree,
 * });
 *
 * return (
 *   <YamlChatCompletion schema={RootFamilyTree}>
 *     <UserMessage>
 *       Create a nested family tree with names and ages.
 *       It should include a total of 5 people.
 *     </UserMessage>
 *   </YamlChatCompletion>
 * );
 * ```
 * @returns A string that is a valid YAML or throws an error after `retries` attempts.
 * Intermediate results that are valid, are also yielded.
 */
export async function* YamlChatCompletion(
  props: Omit<TypedObjectCompletionWithRetry, 'typeName' | 'parser'>,
  { render }: AI.ComponentContext
) {
  return yield* render(<ObjectCompletionWithRetry {...props} typeName="YAML" parser={parseYamlObject} />);
}
/**
 * Error thrown when the model's output fails to parse or validate as the requested format.
 *
 * Its error code is always `ErrorCode.ModelOutputDidNotMatchConstraint`. `metadata.output` carries the
 * model output that was rejected and `metadata.validationError` the parser/validator message; throwers
 * may attach further JSON-compatible fields (this module adds e.g. `typeName` and `retries`).
 */
export class CompletionError extends AIJSXError {
  constructor(
    message: string,
    // Parameter properties: `blame` and `metadata` are also public readonly fields of this class,
    // with `metadata` typed to guarantee `output` and `validationError` are present.
    public readonly blame: ErrorBlame,
    public readonly metadata: Jsonifiable & { output: string; validationError: string }
  ) {
    super(message, ErrorCode.ModelOutputDidNotMatchConstraint, blame, metadata);
  }
}
/**
 * A {@link ChatCompletion} component that constrains the output to be a valid object format (e.g. JSON/YAML),
 * making a single attempt.
 *
 * The prompt asks for a `typeName` object and appends the schema's JSON Schema and the `example` when given.
 * Each intermediate frame is passed through `partialResultCleaner` (if any), then `parser`, then every
 * validator (the schema's `parse` first); frames that pass and are longer than the last yielded one are
 * yielded. The final frame is validated exactly as the model produced it, without the cleaner, so a
 * truncated response is reported as an error instead of being silently "healed".
 *
 * Though not required, you can provide a Zod schema to validate the output against. This is useful for
 * ensuring that the output is of the expected type.
 *
 * @returns The model's final output string once it parses and validates.
 * Intermediate results that are valid, are also yielded.
 * @throws {@link CompletionError} when the final output fails to parse or validate.
 */
async function* OneShotObjectCompletion(
  { children, typeName, validators, example, schema, parser, partialResultCleaner, ...props }: TypedObjectCompletion,
  { render }: AI.ComponentContext
) {
  // If a schema is provided, it is added to the list of validators as well as the prompt.
  const validatorsAndSchema = schema ? [schema.parse, ...(validators ?? [])] : validators ?? [];
  const childrenWithCompletion = (
    <ChatCompletion {...props}>
      {children}
      <SystemMessage>
        Respond with a {typeName} object that encodes your response.
        {schema
          ? `The ${typeName} object should match this JSON Schema: ${JSON.stringify(zodToJsonSchema(schema))}\n`
          : ''}
        {example ? `For example: ${example}\n` : ''}
        Respond with only the {typeName} object. Do not include any explanatory prose. Do not include ```
        {typeName.toLowerCase()} ``` code blocks.
      </SystemMessage>
    </ChatCompletion>
  );

  /** Throws if `str` does not parse, or if the parsed value fails any validator. */
  const validate = (str: string) => {
    const object = parser(str);
    for (const validator of validatorsAndSchema) {
      validator(object);
    }
  };

  const renderGenerator = render(childrenWithCompletion)[Symbol.asyncIterator]();
  let lastYieldedLen = 0;
  while (true) {
    const partial = await renderGenerator.next();
    // The cleaner only repairs in-progress frames; the final output must be valid as produced.
    const str = !partial.done && partialResultCleaner ? partialResultCleaner(partial.value) : partial.value;
    try {
      validate(str);
    } catch (e: unknown) {
      if (partial.done) {
        throw new CompletionError(`The model did not produce a valid ${typeName} object`, 'runtime', {
          typeName,
          output: partial.value,
          validationError: e instanceof Error ? e.message : String(e),
        });
      }
      // An in-progress frame that doesn't validate yet is simply skipped.
      continue;
    }
    if (partial.done) {
      return str;
    }
    // Only yield frames that grew, so consumers don't see repeated identical output.
    if (str.length > lastYieldedLen) {
      lastYieldedLen = str.length;
      yield str;
    }
  }
}
/**
 * A {@link ChatCompletion} component that constrains the output to be a valid object format (e.g. JSON/YAML).
 * If an attempt fails validation, it asks the model to fix its previous output. At most `retries` attempts
 * are made in total, the first one included, and at least one attempt is always made.
 *
 * Only {@link CompletionError}s (parse/validation failures) lead to another attempt; any other error,
 * such as a failed model request, propagates immediately.
 *
 * @returns A string that validates as the given type.
 * Intermediate results that are valid, are also yielded.
 * @throws {@link CompletionError} when every attempt fails; its metadata holds the last output and validation error.
 */
async function* ObjectCompletionWithRetry(
  { children, retries = 3, ...props }: TypedObjectCompletionWithRetry,
  { render, logger }: AI.ComponentContext
) {
  /** Extracts the rejected output and its validation error; anything other than a CompletionError is rethrown. */
  const failureOf = (e: unknown): { output: string; validationError: string } => {
    if (!(e instanceof CompletionError)) {
      throw e;
    }
    return { output: e.metadata.output, validationError: e.metadata.validationError };
  };

  let output: string;
  let validationError: string;
  try {
    return yield* render(<OneShotObjectCompletion {...props}>{children}</OneShotObjectCompletion>);
  } catch (e: unknown) {
    ({ output, validationError } = failureOf(e));
  }
  logger.debug({ attempt: 1, expectedFormat: props.typeName, output }, `Output did not validate to ${props.typeName}.`);

  for (let retryIndex = 1; retryIndex < retries; retryIndex++) {
    // The repair prompt drops the original conversation: the model only sees its previous output
    // and the validation error it triggered.
    const completionRetry = (
      <OneShotObjectCompletion {...props}>
        <SystemMessage>
          You are a {props.typeName} object generator. Create a {props.typeName} object (context redacted).
        </SystemMessage>
        <AssistantMessage>{output}</AssistantMessage>
        <UserMessage>
          Try again. Here's the validation error when trying to parse the output as {props.typeName}:{'\n'}
          ```log filename="error.log"{'\n'}
          {validationError}
          {'\n```\n'}
          You must reformat your previous output to be a valid {props.typeName} object, but you must keep the same data.
        </UserMessage>
      </OneShotObjectCompletion>
    );
    try {
      return yield* render(completionRetry);
    } catch (e: unknown) {
      ({ output, validationError } = failureOf(e));
    }
    logger.debug(
      { attempt: retryIndex + 1, expectedFormat: props.typeName, output },
      `Output did not validate to ${props.typeName}.`
    );
  }

  throw new CompletionError(
    `The model did not produce a valid ${props.typeName} object, even after ${retries} attempts.`,
    'runtime',
    {
      typeName: props.typeName,
      retries,
      output,
      validationError,
    }
  );
}
/**
 * A {@link ChatCompletion} component that constrains the output to be a valid JSON string.
 * It (ab)uses OpenAI function calls to generate the JSON string: the model is forced to call a `print`
 * function whose parameters are `schema`, and the call's arguments become the output.
 *
 * Each rendered frame is expected to be a JSON-encoded function call whose `arguments` field holds the
 * (possibly partial) argument object — assumption based on how frames are decoded below. Intermediate
 * frames that cannot be decoded or do not validate yet are skipped; valid ones are yielded as JSON.
 * Only one attempt is made (no retries).
 *
 * @returns A string that is a valid JSON or throws an error.
 * @throws {@link CompletionError} when the final arguments fail validation.
 *
 * @hidden
 */
export async function* JsonChatCompletionFunctionCall(
  { schema, validators, children, ...props }: ObjectCompletion,
  { render }: AI.ComponentContext
) {
  // If a schema is provided, it is added to the list of validators as well as the prompt.
  const validatorsAndSchema = schema ? [schema.parse, ...(validators ?? [])] : validators ?? [];
  const childrenWithCompletion = (
    <ChatCompletion
      experimental_streamFunctionCallOnly
      {...props}
      functionDefinitions={{
        print: {
          description: 'Prints the response in a human readable format.',
          parameters: schema,
        },
      }}
      forcedFunction="print"
    >
      {children}
      <SystemMessage>
        Your response must use the `print` function that is provided. No other explanation needed. do not respond with
        an assistant message. Just call the function.
      </SystemMessage>
    </ChatCompletion>
  );

  /** Decodes a rendered frame and returns the `print` call's arguments. */
  const argumentsOf = (frame: string) => JSON.parse(frame).arguments;
  /** Throws if `object` fails any validator. */
  const validate = (object: object) => {
    for (const validator of validatorsAndSchema) {
      validator(object);
    }
  };

  const frames = render(childrenWithCompletion);
  for await (const frame of frames) {
    let object: object;
    try {
      // Decoding happens inside the try: an in-progress frame that isn't decodable yet must be
      // skipped like any other not-yet-valid frame, not abort the whole stream.
      object = argumentsOf(frame);
      validate(object);
    } catch {
      continue;
    }
    yield JSON.stringify(object);
  }

  const object = argumentsOf(await frames);
  try {
    validate(object);
  } catch (e: unknown) {
    throw new CompletionError('The model did not produce a valid JSON object', 'runtime', {
      typeName: 'JSON',
      output: JSON.stringify(object),
      validationError: e instanceof Error ? e.message : String(e),
    });
  }
  return JSON.stringify(object);
}