Skip to content

Commit

Permalink
feat: Add support for shortening embeddings in OpenAIEmbeddings (#312)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz committed Jan 26, 2024
1 parent c725db0 commit 5f5eb54
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 7 deletions.
21 changes: 20 additions & 1 deletion packages/langchain_openai/lib/src/chat_models/models/models.dart
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,26 @@ class ChatOpenAIOptions extends ChatModelOptions {

/// ID of the model to use (e.g. 'gpt-3.5-turbo').
///
/// See https://platform.openai.com/docs/api-reference/chat/create#chat-create-model
/// Available models:
/// - `gpt-4`
/// - `gpt-4-0314`
/// - `gpt-4-0613`
/// - `gpt-4-32k`
/// - `gpt-4-32k-0314`
/// - `gpt-4-32k-0613`
/// - `gpt-4-turbo-preview`
/// - `gpt-4-1106-preview`
/// - `gpt-4-0125-preview`
/// - `gpt-4-vision-preview`
/// - `gpt-3.5-turbo`
/// - `gpt-3.5-turbo-16k`
/// - `gpt-3.5-turbo-0301`
/// - `gpt-3.5-turbo-0613`
/// - `gpt-3.5-turbo-1106`
/// - `gpt-3.5-turbo-16k-0613`
///
/// Note that this list may be outdated.
/// See https://platform.openai.com/docs/models for the latest list.
final String? model;

/// Number between -2.0 and 2.0. Positive values penalize new tokens based on
Expand Down
18 changes: 16 additions & 2 deletions packages/langchain_openai/lib/src/embeddings/openai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class OpenAIEmbeddings implements Embeddings {
/// [OpenAI dashboard](https://platform.openai.com/account/api-keys).
/// - `organization`: your OpenAI organization ID (if applicable).
/// - [OpenAIEmbeddings.model]
/// - [OpenAIEmbeddings.dimensions]
/// - [OpenAIEmbeddings.batchSize]
/// - [OpenAIEmbeddings.user]
///
Expand All @@ -130,6 +131,7 @@ class OpenAIEmbeddings implements Embeddings {
final Map<String, dynamic>? queryParams,
final http.Client? client,
this.model = 'text-embedding-ada-002',
this.dimensions,
this.batchSize = 512,
this.user,
}) : _client = OpenAIClient(
Expand All @@ -144,11 +146,21 @@ class OpenAIEmbeddings implements Embeddings {
/// A client for interacting with OpenAI API.
final OpenAIClient _client;

/// ID of the model to use (e.g. 'text-embedding-ada-002').
/// ID of the model to use (e.g. 'text-embedding-3-small').
///
/// See https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-model
/// Available models:
/// - `text-embedding-3-small`
/// - `text-embedding-3-large`
/// - `text-embedding-ada-002`
///
/// Note that this list may be outdated.
/// See https://platform.openai.com/docs/models for the latest list.
String model;

/// The number of dimensions the resulting output embeddings should have.
/// Only supported in `text-embedding-3` and later models.
int? dimensions;

/// The maximum number of documents to embed in a single request.
/// This is limited by max input tokens for the model
/// (e.g. 8191 tokens for text-embedding-ada-002).
Expand Down Expand Up @@ -181,6 +193,7 @@ class OpenAIEmbeddings implements Embeddings {
input: EmbeddingInput.listString(
batch.map((final doc) => doc.pageContent).toList(growable: false),
),
dimensions: dimensions,
user: user,
),
);
Expand All @@ -197,6 +210,7 @@ class OpenAIEmbeddings implements Embeddings {
request: CreateEmbeddingRequest(
model: EmbeddingModel.modelId(model),
input: EmbeddingInput.string(query),
dimensions: dimensions,
user: user,
),
);
Expand Down
8 changes: 7 additions & 1 deletion packages/langchain_openai/lib/src/llms/models/models.dart
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@ class OpenAIOptions extends LLMOptions {

/// ID of the model to use (e.g. 'gpt-3.5-turbo-instruct').
///
/// See https://platform.openai.com/docs/api-reference/completions/create#completions-create-model
/// Available models:
/// - `gpt-3.5-turbo-instruct`
/// - `davinci-002`
/// - `babbage-002`
///
/// Note that this list may be outdated.
/// See https://platform.openai.com/docs/models for the latest list.
final String? model;

/// Generates best_of completions server-side and returns the "best"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,20 @@ void main() {
final openaiApiKey = Platform.environment['OPENAI_API_KEY'];

// Checks the length of the vector returned by `embedQuery`, first with the
// default model and then once per entry of `models`.
test('Test OpenAIEmbeddings.embedQuery', () async {
// NOTE(review): the next three statements look like the pre-change body that
// the diff view shows next to the new loop below — confirm whether they are
// still part of the test.
// No `model` is passed here, so the constructor default applies.
final embeddings = OpenAIEmbeddings(apiKey: openaiApiKey);
final res = await embeddings.embedQuery('Hello world');
expect(res.length, 1536);
// Pairs of (model id, expected embedding length).
final models = [
('text-embedding-ada-002', 1536),
('text-embedding-3-small', 1536),
('text-embedding-3-large', 3072),
];

// One embedding request per model; `dimensions` is left unset.
for (final (modelId, modelDim) in models) {
final embeddings = OpenAIEmbeddings(
apiKey: openaiApiKey,
model: modelId,
);
final res = await embeddings.embedQuery('Hello world');
// The model id is passed as `reason` so a failure names the model.
expect(res.length, modelDim, reason: modelId);
}
});

test('Test OpenAIEmbeddings.embedDocuments', () async {
Expand All @@ -33,5 +44,15 @@ void main() {
expect(res[0].length, 1536);
expect(res[1].length, 1536);
});

// Requests an embedding with an explicit `dimensions` value and checks that
// the returned vector has exactly that length.
test('Test shortening embeddings', () async {
  const targetDimensions = 256;
  final shortenedEmbeddings = OpenAIEmbeddings(
    apiKey: openaiApiKey,
    model: 'text-embedding-3-small',
    dimensions: targetDimensions,
  );
  final vector = await shortenedEmbeddings.embedQuery('Hello world');
  expect(vector.length, targetDimensions);
});
});
}

0 comments on commit 5f5eb54

Please sign in to comment.