diff --git a/js/plugins/dotprompt/src/prompt.ts b/js/plugins/dotprompt/src/prompt.ts index 43ffc4664c..4c5de60765 100644 --- a/js/plugins/dotprompt/src/prompt.ts +++ b/js/plugins/dotprompt/src/prompt.ts @@ -175,6 +175,8 @@ export class Dotprompt implements PromptMetadata { jsonSchema: options.output?.jsonSchema || this.output?.jsonSchema, }, tools: (options.tools || []).concat(this.tools || []), + streamingCallback: options.streamingCallback, + returnToolRequests: options.returnToolRequests, }; } diff --git a/js/plugins/dotprompt/tests/prompt_test.ts b/js/plugins/dotprompt/tests/prompt_test.ts index ab5534ac7f..3c7a66761a 100644 --- a/js/plugins/dotprompt/tests/prompt_test.ts +++ b/js/plugins/dotprompt/tests/prompt_test.ts @@ -90,6 +90,20 @@ describe('Prompt', () => { await invalidSchemaPrompt.render({ input: { foo: 'baz' } }); }, ValidationError); }); + + it('should render with overrided fields', async () => { + const prompt = testPrompt(`Hello {{name}}, how are you?`); + + const streamingCallback = (c) => console.log(c); + + const rendered = await prompt.render({ + input: { name: 'Michael' }, + streamingCallback, + returnToolRequests: true, + }); + assert.strictEqual(rendered.streamingCallback, streamingCallback); + assert.strictEqual(rendered.returnToolRequests, true); + }); }); describe('#generate', () => {