diff --git a/readme.md b/readme.md index b923f8a..aa3f7c1 100644 --- a/readme.md +++ b/readme.md @@ -49,6 +49,19 @@ client.createCompletion(params: CompletionParams): Promise<{ }> ``` +To get a streaming response, use the `streamCompletion` method. + +```ts +client.streamCompletion(params: CompletionParams): Promise< + ReadableStream<{ + /** The completion string. */ + completion: string; + /** The raw response from the API. */ + response: CompletionResponse; + }> + > +``` + ### Create Chat Completion See: [OpenAI docs](https://beta.openai.com/docs/api-reference/chat) | [Type definitions](/src/schemas/chat-completion.ts) diff --git a/src/fetch-api.ts b/src/fetch-api.ts index 086df4c..c1344a4 100644 --- a/src/fetch-api.ts +++ b/src/fetch-api.ts @@ -4,7 +4,7 @@ import { OpenAIApiError } from './errors'; const DEFAULT_BASE_URL = 'https://api.openai.com/v1'; -export interface FetchOptions extends Options { +export interface FetchOptions extends Omit<Options, 'credentials'> { credentials?: string; } diff --git a/src/openai-client.ts b/src/openai-client.ts index 13dc964..4f45e69 100644 --- a/src/openai-client.ts +++ b/src/openai-client.ts @@ -12,10 +12,10 @@ import type { FetchOptions } from './fetch-api'; import type { ChatCompletionParams, ChatCompletionResponse, - ChatResponseMessage} from './schemas/chat-completion'; -import { - ChatCompletionParamsSchema + ChatResponseMessage, } from './schemas/chat-completion'; +import { ChatCompletionParamsSchema } from './schemas/chat-completion'; +import { StreamCompletionChunker } from './streaming'; export type ConfigOpts = { /** @@ -91,6 +91,49 @@ export class OpenAIClient { return { completion, response }; } + /** + * Create a completion for a single prompt string and stream back partial progress. + * @param params standard OpenAI completion parameters + * @returns A stream of completion chunks.
+ * + * @example + * + * ```ts + * const client = new OpenAIClient(process.env.OPENAI_API_KEY); + * const stream = await client.streamCompletion({ + * model: "text-davinci-003", + * prompt: "Give me some lyrics, make it up.", + * max_tokens: 256, + * temperature: 0, + * }); + * + * for await (const chunk of stream) { + * process.stdout.write(chunk.completion); + * } + * ``` + */ + async streamCompletion(params: CompletionParams): Promise< + ReadableStream<{ + /** The completion string. */ + completion: string; + /** The raw response from the API. */ + response: CompletionResponse; + }> + > { + const reqBody = CompletionParamsSchema.parse(params); + const response = await this.api.post('completions', { + json: { ...reqBody, stream: true }, + onDownloadProgress: () => {}, // trick ky to return ReadableStream. + }); + const stream = response.body as ReadableStream; + return stream.pipeThrough( + new StreamCompletionChunker((response: CompletionResponse) => { + const completion = response.choices[0].text || ''; + return { completion, response }; + }) + ); + } + /** * Create a completion for a chat message. */ @@ -111,6 +154,31 @@ export class OpenAIClient { return { message, response }; } + async streamChatCompletion(params: ChatCompletionParams): Promise< + ReadableStream<{ + /** The completion message. */ + message: ChatResponseMessage; + /** The raw response from the API. */ + response: ChatCompletionResponse; + }> + > { + const reqBody = ChatCompletionParamsSchema.parse(params); + const response = await this.api.post('chat/completions', { + json: { ...reqBody, stream: true }, + onDownloadProgress: () => {}, // trick ky to return ReadableStream. 
+ }); + const stream = response.body as ReadableStream; + return stream.pipeThrough( + new StreamCompletionChunker((response: ChatCompletionResponse) => { + const message = response.choices[0].delta || { + role: 'assistant', + content: '', + }; + return { message, response }; + }) + ); + } + /** * Create an edit for a single input string. */ diff --git a/src/schemas/chat-completion.ts b/src/schemas/chat-completion.ts index 43517c1..dbae332 100644 --- a/src/schemas/chat-completion.ts +++ b/src/schemas/chat-completion.ts @@ -76,4 +76,6 @@ export type ChatCompletionResponseChoices = { index?: number; finish_reason?: string; message?: ChatResponseMessage; + /** Used instead of `message` when streaming */ + delta?: ChatResponseMessage; }[]; diff --git a/src/streaming.ts b/src/streaming.ts new file mode 100644 index 0000000..7eaf023 --- /dev/null +++ b/src/streaming.ts @@ -0,0 +1,71 @@ +/** A function that converts from raw Completion response from OpenAI + * into a nicer object which includes the first choice in response from OpenAI. + */ +type ResponseFactory = (response: Raw) => Nice; + +/** + * A parser for the streaming responses from the OpenAI API. + * + * Conveniently shaped like an argument for WritableStream constructor. + */ +class OpenAIStreamParser { + private responseFactory: ResponseFactory; + onchunk?: (chunk: Nice) => void; + onend?: () => void; + + constructor(responseFactory: ResponseFactory) { + this.responseFactory = responseFactory; + } + + /** + * Takes the ReadableStream chunks, produced by `fetch` and turns them into + * `CompletionResponse` objects. + * @param chunk The chunk of data from the stream. 
+ */ + write(chunk: Uint8Array): void { + const decoder = new TextDecoder(); + const s = decoder.decode(chunk); + s.split('\n') + .map((line) => line.trim()) + .filter((line) => line.length > 0) + .forEach((line) => { + const pos = line.indexOf(':'); + const name = line.substring(0, pos); + if (name !== 'data') return; + const content = line.substring(pos + 1).trim(); + if (content.length == 0) return; + if (content === '[DONE]') { + this.onend?.(); + return; + } + try { + const parsed = JSON.parse(content); + this.onchunk?.(this.responseFactory(parsed)); + } catch (e) { + console.error('Failed parsing streamed JSON chunk', e); + } + }); + } +} + +/** + * A transform stream that takes the streaming responses from the OpenAI API + * and turns them into useful response objects. + */ +export class StreamCompletionChunker + implements TransformStream +{ + writable: WritableStream; + readable: ReadableStream; + + constructor(responseFactory: ResponseFactory) { + const parser = new OpenAIStreamParser(responseFactory); + this.writable = new WritableStream(parser); + this.readable = new ReadableStream({ + start(controller) { + parser.onchunk = (chunk: Nice) => controller.enqueue(chunk); + parser.onend = () => controller.close(); + }, + }); + } +} diff --git a/tsconfig.json b/tsconfig.json index 37d6397..040a9bc 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -10,7 +10,7 @@ "esModuleInterop": true, "forceConsistentCasingInFileNames": true, "isolatedModules": true, - "lib": ["es2021"], + "lib": ["es2021", "DOM"], "module": "commonjs", "moduleResolution": "node", "outDir": "dist",