From b9e37470b4158ce01fffbbf1b3cf68d0fdc34ed6 Mon Sep 17 00:00:00 2001
From: Riley Tomasek
Date: Sun, 28 May 2023 22:52:21 -0400
Subject: [PATCH] Add bulk embedding and completion methods

This is helpful when you run into the requests/second rate limit, and sending
a single batched request is generally more efficient than sending one request
per input.
---
 .gitignore                |  1 +
 package.json              |  1 +
 src/openai-client.ts      | 60 +++++++++++++++++++++++++++++++++++++--
 src/schemas/completion.ts |  9 ++++++
 src/schemas/embedding.ts  | 17 +++++++++++
 yarn.lock                 |  5 ++++
 6 files changed, 90 insertions(+), 3 deletions(-)

diff --git a/.gitignore b/.gitignore
index f06235c..9c97bbd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 node_modules
 dist
+.env

diff --git a/package.json b/package.json
index 34d3302..8282d84 100644
--- a/package.json
+++ b/package.json
@@ -30,6 +30,7 @@
   },
   "devDependencies": {
     "@types/node": "^18.14.6",
+    "dotenv": "^16.0.3",
     "esbuild": "^0.17.11",
     "eslint": "^8.35.0",
     "eslint-config-hckrs": "^0.0.3",

diff --git a/src/openai-client.ts b/src/openai-client.ts
index b9335ca..6b28a34 100644
--- a/src/openai-client.ts
+++ b/src/openai-client.ts
@@ -1,13 +1,24 @@
 import { createApiInstance } from './fetch-api';
-import { CompletionParamsSchema } from './schemas/completion';
+import {
+  BulkCompletionParamsSchema,
+  CompletionParamsSchema,
+} from './schemas/completion';
 import { EditParamsSchema } from './schemas/edit';
-import { EmbeddingParamsSchema } from './schemas/embedding';
+import {
+  BulkEmbeddingParamsSchema,
+  EmbeddingParamsSchema,
+} from './schemas/embedding';
 import type {
+  BulkCompletionParams,
   CompletionParams,
   CompletionResponse,
 } from './schemas/completion';
 import type { EditParams, EditResponse } from './schemas/edit';
-import type { EmbeddingParams, EmbeddingResponse } from './schemas/embedding';
+import type {
+  EmbeddingParams,
+  EmbeddingResponse,
+  BulkEmbeddingParams,
+} from './schemas/embedding';
 import type { FetchOptions } from './fetch-api';
 import type {
   ChatCompletionParams,
@@ -81,6 +92,28 @@ export class OpenAIClient {
     return { embedding, response };
   }
 
+  /**
+   * Create embeddings for an array of input strings.
+   * @param params.input The strings to embed.
+   * @param params.model The model to use for the embedding.
+   * @param params.user A unique identifier representing the end-user.
+   */
+  async createEmbeddings(params: BulkEmbeddingParams): Promise<{
+    /** The embeddings for the input strings. */
+    embeddings: number[][];
+    /** The raw response from the API. */
+    response: EmbeddingResponse;
+  }> {
+    const reqBody = BulkEmbeddingParamsSchema.parse(params);
+    const response: EmbeddingResponse = await this.api
+      .post('embeddings', { json: reqBody })
+      .json();
+    // Sort ascending by index to be safe.
+    const items = response.data.sort((a, b) => a.index - b.index);
+    const embeddings = items.map((item) => item.embedding);
+    return { embeddings, response };
+  }
+
   /**
    * Create a completion for a single prompt string.
    */
@@ -98,6 +131,27 @@ export class OpenAIClient {
     return { completion, response };
   }
 
+  /**
+   * Create completions for an array of prompt strings.
+   */
+  async createCompletions(params: BulkCompletionParams): Promise<{
+    /** The completion strings. */
+    completions: string[];
+    /** The raw response from the API. */
+    response: CompletionResponse;
+  }> {
+    const reqBody = BulkCompletionParamsSchema.parse(params);
+    const response: CompletionResponse = await this.api
+      .post('completions', { json: reqBody })
+      .json();
+    // Sort ascending by index to be safe.
+    const choices = response.choices.sort(
+      (a, b) => (a.index ?? 0) - (b.index ?? 0)
+    );
+    const completions = choices.map((choice) => choice.text || '');
+    return { completions, response };
+  }
+
   /**
    * Create a completion for a single prompt string and stream back partial progress.
    * @param params typical standard OpenAI completion parameters
diff --git a/src/schemas/completion.ts b/src/schemas/completion.ts
index f190f04..3cbc2cf 100644
--- a/src/schemas/completion.ts
+++ b/src/schemas/completion.ts
@@ -85,6 +85,15 @@ export const CompletionParamsSchema = z.object({
 
 export type CompletionParams = z.input<typeof CompletionParamsSchema>;
 
+export const BulkCompletionParamsSchema = CompletionParamsSchema.extend({
+  /**
+   * The array of string prompts to generate completions for.
+   */
+  prompt: z.array(z.string()),
+});
+
+export type BulkCompletionParams = z.input<typeof BulkCompletionParamsSchema>;
+
 export type CompletionResponse = {
   id: string;
   object: string;
diff --git a/src/schemas/embedding.ts b/src/schemas/embedding.ts
index 92d9dc1..85d592b 100644
--- a/src/schemas/embedding.ts
+++ b/src/schemas/embedding.ts
@@ -38,6 +38,23 @@ export const EmbeddingParamsSchema = z.object({
 
 export type EmbeddingParams = z.input<typeof EmbeddingParamsSchema>;
 
+export const BulkEmbeddingParamsSchema = z.object({
+  /**
+   * ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.
+   */
+  model: EmbeddingModel,
+  /**
+   * The strings to embed.
+   */
+  input: z.array(z.string()),
+  /**
+   * A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. [Learn more](/docs/usage-policies/end-user-ids).
+   */
+  user: z.string().optional(),
+});
+
+export type BulkEmbeddingParams = z.input<typeof BulkEmbeddingParamsSchema>;
+
 export type EmbeddingResponse = {
   data: {
     embedding: number[];
diff --git a/yarn.lock b/yarn.lock
index f9f630b..18feff1 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -1450,6 +1450,11 @@ dot-prop@^6.0.1:
   dependencies:
     is-obj "^2.0.0"
 
+dotenv@^16.0.3:
+  version "16.0.3"
+  resolved "https://registry.yarnpkg.com/dotenv/-/dotenv-16.0.3.tgz#115aec42bac5053db3c456db30cc243a5a836a07"
+  integrity sha512-7GO6HghkA5fYG9TYnNxi14/7K9f5occMlp3zXAuSxn7CKCxt9xbNWG7yF8hTCSUchlfWSe3uLmlPfigevRItzQ==
+
 duplexer3@^0.1.4:
   version "0.1.5"
   resolved "https://registry.yarnpkg.com/duplexer3/-/duplexer3-0.1.5.tgz#0b5e4d7bad5de8901ea4440624c8e1d20099217e"
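
A minimal usage sketch for the new createEmbeddings/createCompletions methods added by this patch. The package name, client construction, and model names below are illustrative assumptions and are not shown in the diff; only the two method signatures and their return shapes come from the patch itself.

    // Assumption: OpenAIClient is exported from the package root and accepts
    // an apiKey option in its constructor; neither detail appears in this patch.
    import { OpenAIClient } from 'openai-fetch';

    const client = new OpenAIClient({ apiKey: process.env.OPENAI_API_KEY });

    async function main(): Promise<void> {
      // One request embeds every input string; embeddings[i] corresponds to
      // input[i] because createEmbeddings sorts the response data by index.
      const { embeddings } = await client.createEmbeddings({
        model: 'text-embedding-ada-002', // assumed model name
        input: ['First string to embed', 'Second string to embed'],
      });
      console.log(embeddings.length); // 2, one vector per input string

      // One request completes every prompt; completions[i] corresponds to prompt[i].
      const { completions } = await client.createCompletions({
        model: 'text-davinci-003', // assumed model name
        prompt: ['Say hello in French.', 'Say hello in Spanish.'],
      });
      console.log(completions.length); // 2, one completion per prompt
    }

    main().catch(console.error);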