diff --git a/packages/ai-jsx/package.json b/packages/ai-jsx/package.json index 95bdca43e..907ff29d1 100644 --- a/packages/ai-jsx/package.json +++ b/packages/ai-jsx/package.json @@ -4,7 +4,7 @@ "repository": "fixie-ai/ai-jsx", "bugs": "https://github.com/fixie-ai/ai-jsx/issues", "homepage": "https://ai-jsx.com", - "version": "0.9.2", + "version": "0.10.0", "volta": { "extends": "../../package.json" }, diff --git a/packages/ai-jsx/src/core/conversation.tsx b/packages/ai-jsx/src/core/conversation.tsx index 63eab7fa2..59347c0b9 100644 --- a/packages/ai-jsx/src/core/conversation.tsx +++ b/packages/ai-jsx/src/core/conversation.tsx @@ -279,7 +279,7 @@ export async function* Converse( yield AI.AppendOnlyStream; const fullConversation = [] as ConversationMessage[]; - let next = memo(children); + let next = children; while (true) { const newMessages = await renderToConversation(next, render, logger); if (newMessages.length === 0) { @@ -319,7 +319,7 @@ export async function* ShowConversation( present?: (message: ConversationMessage) => AI.Node; onComplete?: (conversation: ConversationMessage[], render: AI.RenderContext['render']) => Promise | void; }, - { render, isAppendOnlyRender, memo }: AI.ComponentContext + { render, isAppendOnlyRender }: AI.ComponentContext ): AI.RenderableStream { // If we're in an append-only render, do the transformation in an append-only manner so as not to block. if (isAppendOnlyRender) { @@ -341,8 +341,7 @@ export async function* ShowConversation( return toConversationMessages(frame).map(present ?? ((m) => m.element)); } - // Memoize before rendering so that the all the conversational components get memoized as well. - const finalFrame = yield* render(memo(children), { + const finalFrame = yield* render(children, { map: handleFrame, stop: isConversationalComponent, appendOnly: isAppendOnlyRender, @@ -399,7 +398,7 @@ export async function ShrinkConversation( budget: number; children: Node; }, - { render, memo, logger }: AI.ComponentContext + { render, logger }: AI.ComponentContext ) { /** * We construct a tree of immutable and shrinkable nodes such that shrinkable nodes @@ -513,9 +512,7 @@ export async function ShrinkConversation( return roots.map((root) => (root.type === 'immutable' ? root.element : treeRootsToNode(root.children))); } - const memoized = memo(children); - - const rendered = await render(memoized, { + const rendered = await render(children, { stop: (e) => isConversationalComponent(e) || e.tag === InternalShrinkable, }); diff --git a/packages/ai-jsx/src/core/memoize.tsx b/packages/ai-jsx/src/core/memoize.tsx index d06cb0b53..3ab19df4c 100644 --- a/packages/ai-jsx/src/core/memoize.tsx +++ b/packages/ai-jsx/src/core/memoize.tsx @@ -1,9 +1,17 @@ -import { Renderable, RenderContext, AppendOnlyStream, RenderableStream } from './render.js'; -import { Node, getReferencedNode, isIndirectNode, makeIndirectNode, isElement } from './node.js'; +import { + Renderable, + RenderContext, + AppendOnlyStream, + RenderableStream, + AppendOnlyStreamValue, + isAppendOnlyStreamValue, + valueToAppend, +} from './render.js'; +import { Node, Element, getReferencedNode, isIndirectNode, makeIndirectNode, isElement } from './node.js'; import { Logger } from './log.js'; import { bindAsyncGeneratorToActiveContext } from './opentelemetry.js'; +import _ from 'lodash'; -let lastMemoizedId = 0; /** @hidden */ export const memoizedIdSymbol = Symbol('memoizedId'); @@ -12,10 +20,10 @@ export const memoizedIdSymbol = Symbol('memoizedId'); * "Partially" memoizes a renderable such that it will only be rendered once in any * single `RenderContext`. */ -export function partialMemo(node: Node, existingId?: number): Node; -export function partialMemo(renderable: Renderable, existingId?: number): Renderable; -export function partialMemo(renderable: Renderable, existingId?: number): Node | Renderable { - const id = existingId ?? ++lastMemoizedId; +export function partialMemo(element: Element, id: number): Element; +export function partialMemo(node: Node, id: number): Node; +export function partialMemo(renderable: Renderable, id: number): Renderable; +export function partialMemo(renderable: Renderable, id: number): Node | Renderable { if (typeof renderable !== 'object' || renderable === null) { return renderable; } @@ -76,32 +84,54 @@ export function partialMemo(renderable: Renderable, existingId?: number): Node | // N.B. Async context doesn't get bound to the generator, so we need to do that manually. const generator = bindAsyncGeneratorToActiveContext(unboundGenerator); - const sink: (Renderable | typeof AppendOnlyStream)[] = []; - let finalResult: Renderable | typeof AppendOnlyStream = null; + const sink: (Node | AppendOnlyStreamValue)[] = []; + let completed = false; let nextPromise: Promise | null = null; return { [memoizedIdSymbol]: id, - async *[Symbol.asyncIterator](): AsyncGenerator< - Renderable | typeof AppendOnlyStream, - Renderable | typeof AppendOnlyStream - > { + async *[Symbol.asyncIterator](): AsyncGenerator { let index = 0; + let isAppendOnly = false; + while (true) { if (index < sink.length) { - yield sink[index++]; + // There's something we can yield/return right away. + let concatenatedNodes = [] as Node[]; + while (index < sink.length) { + let value = sink[index++]; + if (isAppendOnlyStreamValue(value)) { + isAppendOnly = true; + value = valueToAppend(value); + } + + if (isAppendOnly) { + concatenatedNodes.push(value); + } else { + // In case the stream changes to append-only, reset the concatenated nodes. + concatenatedNodes = [value]; + } + } + + const valueToYield = isAppendOnly ? AppendOnlyStream(concatenatedNodes) : _.last(sink); + if (completed) { + return valueToYield; + } + + yield valueToYield; continue; - } else if (completed) { - return finalResult; - } else if (nextPromise == null) { + } + + if (nextPromise == null) { nextPromise = generator.next().then((result) => { - const memoized = result.value === AppendOnlyStream ? result.value : partialMemo(result.value, id); + const memoized = isAppendOnlyStreamValue(result.value) + ? AppendOnlyStream(partialMemo(valueToAppend(result.value), id)) + : partialMemo(result.value, id); + + sink.push(memoized); if (result.done) { completed = true; - finalResult = memoized; - } else { - sink.push(memoized); } nextPromise = null; }); diff --git a/packages/ai-jsx/src/core/render.ts b/packages/ai-jsx/src/core/render.ts index c7e4e2e1c..d0a531dd5 100644 --- a/packages/ai-jsx/src/core/render.ts +++ b/packages/ai-jsx/src/core/render.ts @@ -33,20 +33,33 @@ import { import { openTelemetryStreamRenderer } from './opentelemetry.js'; import { getEnvVar } from '../lib/util.js'; +const appendOnlyStreamSymbol = Symbol('AI.appendOnlyStream'); + /** * A value that can be yielded by a component to indicate that each yielded value should * be appended to, rather than replace, the previously yielded values. */ -export const AppendOnlyStream = Symbol('AI.appendOnlyStream'); +export function AppendOnlyStream(node?: Node) { + return { [appendOnlyStreamSymbol]: node }; +} + +/** @hidden */ +export type AppendOnlyStreamValue = typeof AppendOnlyStream | { [appendOnlyStreamSymbol]: Node }; + +/** @hidden */ +export function isAppendOnlyStreamValue(value: unknown): value is AppendOnlyStreamValue { + return value === AppendOnlyStream || (typeof value === 'object' && value !== null && appendOnlyStreamSymbol in value); +} +/** @hidden */ +export function valueToAppend(value: AppendOnlyStreamValue): Node { + return typeof value === 'object' ? value[appendOnlyStreamSymbol] : undefined; +} /** * A RenderableStream represents an async iterable that yields {@link Renderable}s. */ export interface RenderableStream { - [Symbol.asyncIterator]: () => AsyncGenerator< - Renderable | typeof AppendOnlyStream, - Renderable | typeof AppendOnlyStream - >; + [Symbol.asyncIterator]: () => AsyncGenerator; } /** @@ -173,6 +186,7 @@ export interface RenderContext { * * The memoization is fully recursive. */ + memo(element: Element): Element; memo(renderable: Renderable): Node; /** @@ -295,13 +309,8 @@ async function* renderStream( } if (isElement(renderable)) { if (shouldStop(renderable)) { - // If the renderable already has a context bound to it, leave it as-is because that context would've - // taken precedence over the current one. But, if it does _not_ have a bound context, we bind - // the current context so that if/when it is rendered, rendering will "continue on" as-is. - if (!attachedContext(renderable)) { - return [withContext(renderable, context)]; - } - return [renderable]; + // Don't render it, but memoize it so that rendering picks up where we left off. + return [context.memo(renderable)]; } const renderingContext = attachedContext(renderable) ?? context; if (renderingContext !== context) { @@ -338,12 +347,16 @@ async function* renderStream( let isAppendOnlyStream = false; while (true) { const next = await iterator.next(); - if (next.value === AppendOnlyStream) { + let valueToRender = next.value; + if (isAppendOnlyStreamValue(valueToRender)) { // TODO: I'd like to emit a log here indicating that an element has chosen to AppendOnlyStream, // but I'm not sure what the best way is to know which element/renderId produced `renderable`. isAppendOnlyStream = true; - } else if (isAppendOnlyStream) { - const renderResult = context.render(next.value, recursiveRenderOpts); + valueToRender = valueToAppend(valueToRender); + } + + if (isAppendOnlyStream) { + const renderResult = context.render(valueToRender, recursiveRenderOpts); for await (const frame of renderResult) { yield lastValue.concat(frame); } @@ -352,9 +365,9 @@ async function* renderStream( // Subsequently yielded values might not be append-only, so we can't yield them. (But // if this iterator is `done` then we rely on the recursive call to decide when it's safe // to yield.) - lastValue = await context.render(next.value, recursiveRenderOpts); + lastValue = await context.render(valueToRender, recursiveRenderOpts); } else { - lastValue = yield* context.render(next.value, recursiveRenderOpts); + lastValue = yield* context.render(valueToRender, recursiveRenderOpts); } if (next.done) { @@ -393,12 +406,20 @@ export function createRenderContext(opts?: { logger?: LogImplementation; enableO renderFn = openTelemetryStreamRenderer(renderFn); logger = new CombinedLogger([logger, new OpenTelemetryLogger()]); } - return createRenderContextInternal(renderFn, { - [LoggerContext[contextKey].userContextSymbol]: logger, - }); + return createRenderContextInternal( + renderFn, + { + [LoggerContext[contextKey].userContextSymbol]: logger, + }, + { id: 0 } + ); } -function createRenderContextInternal(renderStream: StreamRenderer, userContext: Record): RenderContext { +function createRenderContextInternal( + renderStream: StreamRenderer, + userContext: Record, + memoizedIdHolder: { id: number } +): RenderContext { const context: RenderContext = { render: ( renderable: Renderable, @@ -490,15 +511,20 @@ function createRenderContextInternal(renderStream: StreamRenderer, userContext: return defaultValue; }, - memo: (renderable) => withContext(partialMemo(renderable), context), + memo: (renderable: Renderable) => withContext(partialMemo(renderable, ++memoizedIdHolder.id), context) as any, - wrapRender: (getRenderStream) => createRenderContextInternal(getRenderStream(renderStream), userContext), + wrapRender: (getRenderStream) => + createRenderContextInternal(getRenderStream(renderStream), userContext, memoizedIdHolder), [pushContextSymbol]: (contextReference, value) => - createRenderContextInternal(renderStream, { - ...userContext, - [contextReference[contextKey].userContextSymbol]: value, - }), + createRenderContextInternal( + renderStream, + { + ...userContext, + [contextReference[contextKey].userContextSymbol]: value, + }, + memoizedIdHolder + ), }; return context; diff --git a/packages/docs/docs/changelog.md b/packages/docs/docs/changelog.md index 019f68ee8..aa92ff3d5 100644 --- a/packages/docs/docs/changelog.md +++ b/packages/docs/docs/changelog.md @@ -1,6 +1,13 @@ # Changelog -## 0.9.2 +## 0.10.0 + +- Memoized streaming elements no longer replay their entire stream with every render. Instead, they start with the last rendered frame. +- Elements returned by partial rendering are automatically memoized to ensure they only render once. +- Streaming components can no longer yield promises or generators. Only `Node`s or `AI.AppendOnlyStream` values can be yielded. +- The `AI.AppendOnlyStream` value is now a function that can be called with a non-empty value to append. + +## [0.9.2](https://github.com/fixie-ai/ai-jsx/commit/219aebeb5e062bf3470a239443626915e0503ad9) - In the [OpenTelemetry integration](./guides/observability.md#opentelemetry-integration): - Add prompt/completion attributes with token counts for ``. This replaces the `tokenCount` attribute added in 0.9.1. diff --git a/packages/docs/docs/guides/rendering.md b/packages/docs/docs/guides/rendering.md index 14257b08f..042d1847c 100644 --- a/packages/docs/docs/guides/rendering.md +++ b/packages/docs/docs/guides/rendering.md @@ -362,3 +362,5 @@ function MyUserMessages() { ; ``` + +Elements returned by partial rendering will be [memoized](./rules-of-jsx.md#memoization) so that they render only once. diff --git a/packages/docs/docs/guides/rules-of-jsx.md b/packages/docs/docs/guides/rules-of-jsx.md index 211bd60d1..6eb8ec5f7 100644 --- a/packages/docs/docs/guides/rules-of-jsx.md +++ b/packages/docs/docs/guides/rules-of-jsx.md @@ -111,7 +111,7 @@ function* GenerateImage() { AI.JSX will interpret each `yield`ed value as a new value which should totally overwrite the previously-yielded values, so the caller would see a progression of increasingly high-quality images. -However, sometimes your data source will give you deltas, so replacing the previous contents doesn't make much sense. In this case, `yield` the [`AppendOnlyStream`](../api/modules/core_render.md#appendonlystream) symbol to indicate that `yield`ed results should be interpreted as deltas: +However, sometimes your data source will give you deltas, so replacing the previous contents doesn't make much sense. In this case, `yield` the [`AppendOnlyStream`](../api/modules/core_render.md#appendonlystream) value to indicate that `yield`ed results should be interpreted as deltas: ```tsx import * as AI from 'ai-jsx'; @@ -227,6 +227,12 @@ const catName = memo( Now, `catName` will result in a single model call, and its value will be reused everywhere that component appears in the tree. +:::note Memoized Streams + +If a streaming element is memoized, rendering will start with the last rendered frame rather than replaying every frame. + +::: + # See Also - [Rendering](./rendering.md) diff --git a/packages/examples/test/core/completion.tsx b/packages/examples/test/core/completion.tsx index 5c1b7c368..1b477a162 100644 --- a/packages/examples/test/core/completion.tsx +++ b/packages/examples/test/core/completion.tsx @@ -75,8 +75,6 @@ describe('OpenTelemetry', () => { const spans = memoryExporter.getFinishedSpans(); const minimalSpans = _.map(spans, 'attributes'); - // Unfortunately, the @memoizedId will be sensitive how many tests in this file ran before it. - // To avoid issues with that, we put this test first. expect(minimalSpans).toMatchInlineSnapshot(` [ { @@ -106,7 +104,7 @@ describe('OpenTelemetry', () => { "ai.jsx.tree": ""opentel response from OpenAI"", }, { - "ai.jsx.completion": "[{"element":"\\n {\\"opentel response from OpenAI\\"}\\n","cost":10}]", + "ai.jsx.completion": "[{"element":"\\n {\\"opentel response from OpenAI\\"}\\n","cost":10}]", "ai.jsx.prompt": "[{"element":"\\n {\\"hello\\"}\\n","cost":4}]", "ai.jsx.result": "opentel response from OpenAI", "ai.jsx.tag": "OpenAIChatModel", @@ -140,25 +138,25 @@ describe('OpenTelemetry', () => { expect(_.map(memoryExporter.getFinishedSpans(), 'attributes')).toMatchInlineSnapshot(` [ { - "ai.jsx.result": "[ + "ai.jsx.result": "[ {"hello"} ]", "ai.jsx.tag": "UserMessage", - "ai.jsx.tree": " + "ai.jsx.tree": " {"hello"} ", }, { - "ai.jsx.result": "[ + "ai.jsx.result": "[ {"hello"} ]", "ai.jsx.tag": "UserMessage", - "ai.jsx.tree": " + "ai.jsx.tree": " {"hello"} ", }, { - "ai.jsx.result": "[ + "ai.jsx.result": "[ {"hello"} ]", "ai.jsx.tag": "ShrinkConversation", @@ -171,14 +169,14 @@ describe('OpenTelemetry', () => { { "ai.jsx.result": "hello", "ai.jsx.tag": "UserMessage", - "ai.jsx.tree": " + "ai.jsx.tree": " {"hello"} ", }, { "ai.jsx.result": "hello", "ai.jsx.tag": "UserMessage", - "ai.jsx.tree": " + "ai.jsx.tree": " {"hello"} ", }, @@ -190,7 +188,7 @@ describe('OpenTelemetry', () => { { "ai.jsx.result": "opentel response from OpenAI", "ai.jsx.tag": "AssistantMessage", - "ai.jsx.tree": " + "ai.jsx.tree": " {"▮"} ", }, @@ -202,16 +200,16 @@ describe('OpenTelemetry', () => { { "ai.jsx.result": "opentel response from OpenAI", "ai.jsx.tag": "AssistantMessage", - "ai.jsx.tree": " + "ai.jsx.tree": " {"opentel response from OpenAI"} ", }, { - "ai.jsx.result": "[ + "ai.jsx.result": "[ {"opentel response from OpenAI"} ]", "ai.jsx.tag": "AssistantMessage", - "ai.jsx.tree": " + "ai.jsx.tree": " {"opentel response from OpenAI"} ", }, @@ -223,13 +221,13 @@ describe('OpenTelemetry', () => { { "ai.jsx.result": "opentel response from OpenAI", "ai.jsx.tag": "AssistantMessage", - "ai.jsx.tree": " + "ai.jsx.tree": " {"opentel response from OpenAI"} ", }, { - "ai.jsx.completion": "[{"element":"\\n {\\"opentel response from OpenAI\\"}\\n","cost":10}]", - "ai.jsx.prompt": "[{"element":"\\n {\\"hello\\"}\\n","cost":4}]", + "ai.jsx.completion": "[{"element":"\\n {\\"opentel response from OpenAI\\"}\\n","cost":10}]", + "ai.jsx.prompt": "[{"element":"\\n {\\"hello\\"}\\n","cost":4}]", "ai.jsx.result": "opentel response from OpenAI", "ai.jsx.tag": "OpenAIChatModel", "ai.jsx.tree": " diff --git a/packages/examples/test/core/memoize.tsx b/packages/examples/test/core/memoize.tsx new file mode 100644 index 000000000..a3298ef6b --- /dev/null +++ b/packages/examples/test/core/memoize.tsx @@ -0,0 +1,186 @@ +import * as AI from 'ai-jsx'; + +it('ensures that elements are only rendered once', async () => { + let didRender = false; + function Component() { + if (didRender) { + return 'FAIL'; + } + + didRender = true; + return 'PASS'; + } + + const ctx = AI.createRenderContext(); + const element = ctx.memo(); + expect(await ctx.render(element)).toBe('PASS'); + expect(await ctx.render(element)).toBe('PASS'); +}); + +it('works with nested components', async () => { + let didRender = false; + function Component() { + if (didRender) { + return 'FAIL'; + } + + didRender = true; + return 'PASS'; + } + + function Parent() { + return ( + <> + + + ); + } + + const ctx = AI.createRenderContext(); + const element = ctx.memo(); + expect(await ctx.render(element)).toBe('PASS'); + expect(await ctx.render(element)).toBe('PASS'); +}); + +it('works with nested/async components', async () => { + let didRender = false; + function Component() { + if (didRender) { + return 'FAIL'; + } + + didRender = true; + return 'PASS'; + } + + function AsyncParent() { + return Promise.resolve( + <> + + + ); + } + + const ctx = AI.createRenderContext(); + const element = ctx.memo(); + expect(await ctx.render(element)).toBe('PASS'); + expect(await ctx.render(element)).toBe('PASS'); +}); + +it('works for streams', async () => { + async function* Component() { + yield 3; + yield 2; + yield 1; + return 'LIFTOFF'; + } + + const ctx = AI.createRenderContext(); + const element = ctx.memo(); + + const frames = [] as string[]; + const renderResult = ctx.render(element); + for await (const frame of renderResult) { + frames.push(frame); + } + expect(frames).toEqual(['3', '2', '1']); + expect(await renderResult).toBe('LIFTOFF'); + expect(await ctx.render(element)).toBe('LIFTOFF'); +}); + +it('works for append-only streams', async () => { + async function* Component() { + yield AI.AppendOnlyStream; + yield 3; + yield 2; + yield 1; + return 'LIFTOFF'; + } + + const ctx = AI.createRenderContext(); + const element = ctx.memo(); + + const frames = [] as string[]; + const renderResult = ctx.render(element); + for await (const frame of renderResult) { + frames.push(frame); + } + expect(frames).toEqual(['', '3', '32', '321']); + expect(await renderResult).toEqual('321LIFTOFF'); + expect(await ctx.render(element)).toBe('321LIFTOFF'); +}); + +it('works for streams that become append-only', async () => { + async function* Component() { + yield 4; + yield 3; + yield AI.AppendOnlyStream; + yield 2; + yield 1; + return 'LIFTOFF'; + } + + const ctx = AI.createRenderContext(); + const element = ctx.memo(); + + const frames = [] as string[]; + const renderResult = ctx.render(element); + for await (const frame of renderResult) { + frames.push(frame); + } + expect(frames).toEqual(['4', '3', '3', '32', '321']); + expect(await renderResult).toEqual('321LIFTOFF'); + expect(await ctx.render(element)).toBe('321LIFTOFF'); +}); + +it('works for streams that become append-only using a value', async () => { + async function* Component() { + yield 4; + yield 3; + yield AI.AppendOnlyStream(2); + yield 1; + return 'LIFTOFF'; + } + + const ctx = AI.createRenderContext(); + const element = ctx.memo(); + + const frames = [] as string[]; + const renderResult = ctx.render(element); + for await (const frame of renderResult) { + frames.push(frame); + } + expect(frames).toEqual(['4', '3', '32', '321']); + expect(await renderResult).toEqual('321LIFTOFF'); + expect(await ctx.render(element)).toBe('321LIFTOFF'); +}); + +it('coalesces frames when there are multiple concurrent renders', async () => { + async function* Component() { + yield AI.AppendOnlyStream; + yield 3; + yield 2; + yield 1; + return 'LIFTOFF'; + } + + const ctx = AI.createRenderContext(); + const element = ctx.memo(); + + const iterator1 = ctx.render(element)[Symbol.asyncIterator](); + const iterator2 = ctx.render(element)[Symbol.asyncIterator](); + + expect((await iterator1.next()).value).toBe(''); + expect((await iterator2.next()).value).toBe(''); + + expect((await iterator1.next()).value).toBe('3'); + expect((await iterator1.next()).value).toBe('32'); + expect((await iterator1.next()).value).toBe('321'); + expect((await iterator2.next()).value).toBe('321'); + + expect(await iterator1.next()).toEqual({ value: '321LIFTOFF', done: true }); + expect(await iterator2.next()).toEqual({ value: '321LIFTOFF', done: true }); + + const iterator3 = ctx.render(element)[Symbol.asyncIterator](); + expect(await iterator3.next()).toEqual({ value: '321LIFTOFF', done: true }); +});