From fa4b5c376a191fea50c3f8b1d6b07cef0480a74e Mon Sep 17 00:00:00 2001 From: David Miguel Lozano Date: Thu, 9 May 2024 17:31:46 +0200 Subject: [PATCH] feat!: Migrate internal client from googleai_dart to google_generative_ai (#407) --- .../chat_models/integrations/googleai.md | 72 ++++--- .../text_embedding/integrations/google_ai.md | 6 +- examples/browser_summarizer/pubspec.lock | 4 +- .../chat_models/integrations/googleai.dart | 48 +++-- examples/docs_examples/pubspec.lock | 19 +- examples/docs_examples/pubspec_overrides.yaml | 4 +- examples/hello_world_backend/pubspec.lock | 4 +- examples/hello_world_cli/pubspec.lock | 4 +- examples/hello_world_flutter/pubspec.lock | 4 +- melos.yaml | 1 + packages/googleai_dart/README.md | 6 +- packages/googleai_dart/lib/src/client.dart | 2 +- .../google_ai/chat_google_generative_ai.dart | 203 ++++++++++++------ .../src/chat_models/google_ai/mappers.dart | 191 ++++++++-------- .../lib/src/chat_models/google_ai/types.dart | 30 +-- .../google_ai/google_ai_embeddings.dart | 139 ++++++++---- .../src/utils/https_client/http_client.dart | 61 ++++++ .../utils/https_client/http_client_html.dart | 21 ++ .../utils/https_client/http_client_io.dart | 19 ++ .../utils/https_client/http_client_stub.dart | 14 ++ packages/langchain_google/pubspec.yaml | 3 +- .../langchain_google/pubspec_overrides.yaml | 4 +- .../chat_google_generative_ai_test.dart | 52 +++-- .../google_ai/google_ai_embeddings_test.dart | 18 +- 24 files changed, 596 insertions(+), 333 deletions(-) create mode 100644 packages/langchain_google/lib/src/utils/https_client/http_client.dart create mode 100644 packages/langchain_google/lib/src/utils/https_client/http_client_html.dart create mode 100644 packages/langchain_google/lib/src/utils/https_client/http_client_io.dart create mode 100644 packages/langchain_google/lib/src/utils/https_client/http_client_stub.dart diff --git a/docs/modules/model_io/models/chat_models/integrations/googleai.md 
b/docs/modules/model_io/models/chat_models/integrations/googleai.md index be3cd37b..f3c0eeaf 100644 --- a/docs/modules/model_io/models/chat_models/integrations/googleai.md +++ b/docs/modules/model_io/models/chat_models/integrations/googleai.md @@ -4,11 +4,21 @@ Wrapper around [Google AI for Developers](https://ai.google.dev/) API (aka Gemin ## Setup -To use `ChatGoogleGenerativeAI` you need to have an API key. You can get one [here](https://makersuite.google.com/app/apikey). - -The following models are available at the moment: -- `gemini-pro`: text -> text model -- `gemini-pro-vision`: text / image -> text model +To use `ChatGoogleGenerativeAI` you need to have an API key. You can get one [here](https://aistudio.google.com/app/apikey). + +The following models are available: +- `gemini-1.0-pro` (or `gemini-pro`): + * text -> text model + * Max input token: 30720 + * Max output tokens: 2048 +- `gemini-pro-vision`: + * text / image -> text model + * Max input token: 12288 + * Max output tokens: 4096 +- `gemini-1.5-pro-latest`: text / image -> text model + * text / image -> text model + * Max input token: 1048576 + * Max output tokens: 8192 Mind that this list may not be up-to-date. Refer to the [documentation](https://ai.google.dev/models) for the updated list. @@ -20,17 +30,15 @@ final apiKey = Platform.environment['GOOGLEAI_API_KEY']; final chatModel = ChatGoogleGenerativeAI( apiKey: apiKey, defaultOptions: ChatGoogleGenerativeAIOptions( + model: 'gemini-1.5-pro-latest', temperature: 0, ), ); -const template = ''' -You are a helpful assistant that translates {input_language} to {output_language}. 
- -Text to translate: -{text}'''; -final humanMessagePrompt = HumanChatMessagePromptTemplate.fromTemplate(template); -final chatPrompt = ChatPromptTemplate.fromPromptMessages([humanMessagePrompt]); +final chatPrompt = ChatPromptTemplate.fromTemplates([ + (ChatMessageType.system, 'You are a helpful assistant that translates {input_language} to {output_language}.'), + (ChatMessageType.human, 'Text to translate:\n{text}'), +]); final chain = chatPrompt | chatModel | StringOutputParser(); @@ -40,7 +48,7 @@ final res = await chain.invoke({ 'text': 'I love programming.', }); print(res); -// -> 'J'adore la programmation.'final +// -> 'J'adore programmer.' ``` ## Multimodal support @@ -51,7 +59,7 @@ final apiKey = Platform.environment['GOOGLEAI_API_KEY']; final chatModel = ChatGoogleGenerativeAI( apiKey: apiKey, defaultOptions: ChatGoogleGenerativeAIOptions( - model: 'gemini-pro-vision', + model: 'gemini-1.5-pro-latest', temperature: 0, ), ); @@ -71,7 +79,7 @@ final res = await chatModel.invoke( ]), ); print(res.output.content); -// -> 'A Red and Green Apple' +// -> 'That is an apple.' 
``` ## Streaming @@ -79,27 +87,27 @@ print(res.output.content); ```dart final apiKey = Platform.environment['GOOGLEAI_API_KEY']; -final promptTemplate = ChatPromptTemplate.fromTemplate( - 'You are a helpful assistant that replies only with numbers ' - 'in order without any spaces or commas ' - 'List the numbers from 1 to {max_num}'); +final promptTemplate = ChatPromptTemplate.fromTemplates(const [ + (ChatMessageType.system, 'You are a helpful assistant that replies only with numbers in order without any spaces or commas.'), + (ChatMessageType.human, 'List the numbers from 1 to {max_num}'), +]); -final chatModel = ChatGoogleGenerativeAI(apiKey: apiKey); +final chatModel = ChatGoogleGenerativeAI( + apiKey: apiKey, + defaultOptions: const ChatGoogleGenerativeAIOptions( + model: 'gemini-1.5-pro-latest', + temperature: 0, + ), +); final chain = promptTemplate.pipe(chatModel).pipe(StringOutputParser()); final stream = chain.stream({'max_num': '30'}); await stream.forEach(print); -// 1234567891011121 -// 31415161718192021222324252627282 -// 930 -``` - -## Limitations +// 1 +// 2345678910111213 +// 1415161718192021 +// 222324252627282930 -As of the time this doc was written (15/12/23), Gemini has some restrictions on the types and structure of prompts it accepts. Specifically: - -1. When providing multimodal (image) inputs, you are restricted to at most 1 message of “human” (user) type. You cannot pass multiple messages (though the single human message may have multiple content entries). -2. System messages are not accepted. -3. For regular chat conversations, messages must follow the human/ai/human/ai alternating pattern. You may not provide 2 AI or human messages in sequence. -4. Message may be blocked if they violate the safety checks of the LLM. In this case, the model will return an empty response. 
+chatModel.close(); +``` diff --git a/docs/modules/retrieval/text_embedding/integrations/google_ai.md b/docs/modules/retrieval/text_embedding/integrations/google_ai.md index 41ae2c59..6d84e8a1 100644 --- a/docs/modules/retrieval/text_embedding/integrations/google_ai.md +++ b/docs/modules/retrieval/text_embedding/integrations/google_ai.md @@ -4,8 +4,10 @@ The embedding service in the [Gemini API](https://ai.google.dev/docs/embeddings_ ## Available models -- `embedding-001` (default) - * Optimized for creating embeddings for text of up to 2048 tokens +- `text-embedding-004` + * Dimensions: 768 (with support for reduced dimensionality) +- `embedding-001` + * Dimensions: 768 The previous list of models may not be exhaustive or up-to-date. Check out the [Google AI documentation](https://ai.google.dev/models/gemini) for the latest list of available models. diff --git a/examples/browser_summarizer/pubspec.lock b/examples/browser_summarizer/pubspec.lock index 8b71a1a8..11e5271d 100644 --- a/examples/browser_summarizer/pubspec.lock +++ b/examples/browser_summarizer/pubspec.lock @@ -246,7 +246,7 @@ packages: path: "../../packages/langchain_openai" relative: true source: path - version: "0.6.0+1" + version: "0.6.0+2" langchain_tiktoken: dependency: transitive description: @@ -309,7 +309,7 @@ packages: path: "../../packages/openai_dart" relative: true source: path - version: "0.3.0" + version: "0.3.1" path: dependency: transitive description: diff --git a/examples/docs_examples/bin/modules/model_io/models/chat_models/integrations/googleai.dart b/examples/docs_examples/bin/modules/model_io/models/chat_models/integrations/googleai.dart index 6e61d1c0..9bb99c43 100644 --- a/examples/docs_examples/bin/modules/model_io/models/chat_models/integrations/googleai.dart +++ b/examples/docs_examples/bin/modules/model_io/models/chat_models/integrations/googleai.dart @@ -17,19 +17,18 @@ Future _chatGoogleGenerativeAI() async { final chatModel = ChatGoogleGenerativeAI( apiKey: apiKey, 
defaultOptions: const ChatGoogleGenerativeAIOptions( + model: 'gemini-1.5-pro-latest', temperature: 0, ), ); - const template = ''' -You are a helpful assistant that translates {input_language} to {output_language}. - -Text to translate: -{text}'''; - final humanMessagePrompt = - HumanChatMessagePromptTemplate.fromTemplate(template); - final chatPrompt = - ChatPromptTemplate.fromPromptMessages([humanMessagePrompt]); + final chatPrompt = ChatPromptTemplate.fromTemplates(const [ + ( + ChatMessageType.system, + 'You are a helpful assistant that translates {input_language} to {output_language}.' + ), + (ChatMessageType.human, 'Text to translate:\n{text}'), + ]); final chain = chatPrompt | chatModel | const StringOutputParser(); @@ -39,7 +38,7 @@ Text to translate: 'text': 'I love programming.', }); print(res); - // -> 'J'adore la programmation.' + // -> 'J'adore programmer.' chatModel.close(); } @@ -50,7 +49,7 @@ Future _chatGoogleGenerativeAIMultiModal() async { final chatModel = ChatGoogleGenerativeAI( apiKey: apiKey, defaultOptions: const ChatGoogleGenerativeAIOptions( - model: 'gemini-pro-vision', + model: 'gemini-1.5-pro-latest', temperature: 0, ), ); @@ -70,7 +69,7 @@ Future _chatGoogleGenerativeAIMultiModal() async { ]), ); print(res.output.content); - // -> 'A Red and Green Apple' + // -> 'That is an apple.' 
chatModel.close(); } @@ -78,20 +77,31 @@ Future _chatGoogleGenerativeAIMultiModal() async { Future _chatOpenAIStreaming() async { final apiKey = Platform.environment['GOOGLEAI_API_KEY']; - final promptTemplate = ChatPromptTemplate.fromTemplate( + final promptTemplate = ChatPromptTemplate.fromTemplates(const [ + ( + ChatMessageType.system, 'You are a helpful assistant that replies only with numbers ' - 'in order without any spaces or commas ' - 'List the numbers from 1 to {max_num}'); + 'in order without any spaces or commas.', + ), + (ChatMessageType.human, 'List the numbers from 1 to {max_num}'), + ]); - final chatModel = ChatGoogleGenerativeAI(apiKey: apiKey); + final chatModel = ChatGoogleGenerativeAI( + apiKey: apiKey, + defaultOptions: const ChatGoogleGenerativeAIOptions( + model: 'gemini-1.5-pro-latest', + temperature: 0, + ), + ); final chain = promptTemplate.pipe(chatModel).pipe(const StringOutputParser()); final stream = chain.stream({'max_num': '30'}); await stream.forEach(print); - // 1234567891011121 - // 31415161718192021222324252627282 - // 930 + // 1 + // 2345678910111213 + // 1415161718192021 + // 222324252627282930 chatModel.close(); } diff --git a/examples/docs_examples/pubspec.lock b/examples/docs_examples/pubspec.lock index 65989569..3df17c8a 100644 --- a/examples/docs_examples/pubspec.lock +++ b/examples/docs_examples/pubspec.lock @@ -128,6 +128,14 @@ packages: url: "https://pub.dev" source: hosted version: "0.8.12" + google_generative_ai: + dependency: transitive + description: + name: google_generative_ai + sha256: "3a51ed314e596ddee9654ab9f1372fff0b8f3917be937655bcadebd5fa063df0" + url: "https://pub.dev" + source: hosted + version: "0.3.3" google_identity_services_web: dependency: transitive description: @@ -136,13 +144,6 @@ packages: url: "https://pub.dev" source: hosted version: "0.3.1+1" - googleai_dart: - dependency: "direct overridden" - description: - path: "../../packages/googleai_dart" - relative: true - source: path - version: 
"0.0.4" googleapis: dependency: transitive description: @@ -270,7 +271,7 @@ packages: path: "../../packages/langchain_openai" relative: true source: path - version: "0.6.0+1" + version: "0.6.0+2" langchain_tiktoken: dependency: transitive description: @@ -323,7 +324,7 @@ packages: path: "../../packages/openai_dart" relative: true source: path - version: "0.3.0" + version: "0.3.1" path: dependency: transitive description: diff --git a/examples/docs_examples/pubspec_overrides.yaml b/examples/docs_examples/pubspec_overrides.yaml index 94d3894b..e02da308 100644 --- a/examples/docs_examples/pubspec_overrides.yaml +++ b/examples/docs_examples/pubspec_overrides.yaml @@ -1,9 +1,7 @@ -# melos_managed_dependency_overrides: chromadb,googleai_dart,langchain,langchain_chroma,langchain_google,langchain_mistralai,langchain_ollama,langchain_openai,mistralai_dart,ollama_dart,openai_dart,vertex_ai,langchain_core,langchain_community +# melos_managed_dependency_overrides: chromadb,langchain,langchain_chroma,langchain_google,langchain_mistralai,langchain_ollama,langchain_openai,mistralai_dart,ollama_dart,openai_dart,vertex_ai,langchain_core,langchain_community dependency_overrides: chromadb: path: ../../packages/chromadb - googleai_dart: - path: ../../packages/googleai_dart langchain: path: ../../packages/langchain langchain_chroma: diff --git a/examples/hello_world_backend/pubspec.lock b/examples/hello_world_backend/pubspec.lock index f170887a..a6ab9a9d 100644 --- a/examples/hello_world_backend/pubspec.lock +++ b/examples/hello_world_backend/pubspec.lock @@ -133,7 +133,7 @@ packages: path: "../../packages/langchain_openai" relative: true source: path - version: "0.6.0+1" + version: "0.6.0+2" langchain_tiktoken: dependency: transitive description: @@ -156,7 +156,7 @@ packages: path: "../../packages/openai_dart" relative: true source: path - version: "0.3.0" + version: "0.3.1" path: dependency: transitive description: diff --git a/examples/hello_world_cli/pubspec.lock 
b/examples/hello_world_cli/pubspec.lock index f00ce58e..4f3a1669 100644 --- a/examples/hello_world_cli/pubspec.lock +++ b/examples/hello_world_cli/pubspec.lock @@ -125,7 +125,7 @@ packages: path: "../../packages/langchain_openai" relative: true source: path - version: "0.6.0+1" + version: "0.6.0+2" langchain_tiktoken: dependency: transitive description: @@ -148,7 +148,7 @@ packages: path: "../../packages/openai_dart" relative: true source: path - version: "0.3.0" + version: "0.3.1" path: dependency: transitive description: diff --git a/examples/hello_world_flutter/pubspec.lock b/examples/hello_world_flutter/pubspec.lock index 4e703b73..c9698539 100644 --- a/examples/hello_world_flutter/pubspec.lock +++ b/examples/hello_world_flutter/pubspec.lock @@ -154,7 +154,7 @@ packages: path: "../../packages/langchain_openai" relative: true source: path - version: "0.6.0+1" + version: "0.6.0+2" langchain_tiktoken: dependency: transitive description: @@ -193,7 +193,7 @@ packages: path: "../../packages/openai_dart" relative: true source: path - version: "0.3.0" + version: "0.3.1" path: dependency: transitive description: diff --git a/melos.yaml b/melos.yaml index e47b8950..a6684972 100644 --- a/melos.yaml +++ b/melos.yaml @@ -35,6 +35,7 @@ command: flutter_markdown: ^0.6.22 freezed_annotation: ^2.4.1 gcloud: ^0.8.12 + google_generative_ai: 0.3.3 googleapis: ^12.0.0 googleapis_auth: ^1.5.1 http: ^1.1.0 diff --git a/packages/googleai_dart/README.md b/packages/googleai_dart/README.md index 5765c7ff..cf4aa823 100644 --- a/packages/googleai_dart/README.md +++ b/packages/googleai_dart/README.md @@ -5,7 +5,9 @@ [![](https://dcbadge.vercel.app/api/server/x4qbhqecVR?style=flat)](https://discord.gg/x4qbhqecVR) [![MIT](https://img.shields.io/badge/license-MIT-purple.svg)](https://github.com/davidmigloz/langchain_dart/blob/main/LICENSE) -Unofficial Dart client for [Google AI](https://ai.google.dev) for Developers (Gemini API). 
+Unofficial Dart client for [Google AI](https://ai.google.dev) for Developers (Gemini API v1). + +> Note: The official [`google_generative_ai`](https://pub.dev/packages/google_generative_ai) now has feature parity with this package (except for the [Model info](https://github.com/google-gemini/generative-ai-dart/issues/93) endpoints). We plan to deprecate this package in the near future. ## Features @@ -53,7 +55,7 @@ Refer to the [documentation](https://ai.google.dev/docs) for more information ab ### Authentication -The Google AI API uses API keys for authentication. Visit [Google AI Studio dashboard](https://makersuite.google.com/app/apikey) page to retrieve the API key you'll use in your requests. +The Google AI API uses API keys for authentication. Visit [Google AI Studio dashboard](https://aistudio.google.com/app/apikey) page to retrieve the API key you'll use in your requests. > **Remember that your API key is a secret!** > Do not share it with others or expose it in any client-side code (browsers, apps). Production requests must be routed through your own backend server where your API key can be securely loaded from an environment variable or key management service. diff --git a/packages/googleai_dart/lib/src/client.dart b/packages/googleai_dart/lib/src/client.dart index c6abc41a..39138156 100644 --- a/packages/googleai_dart/lib/src/client.dart +++ b/packages/googleai_dart/lib/src/client.dart @@ -16,7 +16,7 @@ class GoogleAIClient extends g.GoogleAIClient { /// /// Main configuration options: /// - `apiKey`: your Google AI API key. You can find your API key in the - /// [Google AI Studio dashboard](https://makersuite.google.com/app/apikey). + /// [Google AI Studio dashboard](https://aistudio.google.com/app/apikey). 
/// /// Advance configuration options: /// - `baseUrl`: the base URL to use.You can override this to use a different diff --git a/packages/langchain_google/lib/src/chat_models/google_ai/chat_google_generative_ai.dart b/packages/langchain_google/lib/src/chat_models/google_ai/chat_google_generative_ai.dart index c6ccf07d..c36c970d 100644 --- a/packages/langchain_google/lib/src/chat_models/google_ai/chat_google_generative_ai.dart +++ b/packages/langchain_google/lib/src/chat_models/google_ai/chat_google_generative_ai.dart @@ -1,9 +1,10 @@ -import 'package:googleai_dart/googleai_dart.dart'; +import 'package:google_generative_ai/google_generative_ai.dart'; import 'package:http/http.dart' as http; import 'package:langchain_core/chat_models.dart'; import 'package:langchain_core/prompts.dart'; import 'package:uuid/uuid.dart'; +import '../../utils/https_client/http_client.dart'; import 'mappers.dart'; import 'types.dart'; @@ -25,13 +26,23 @@ import 'types.dart'; /// ### Setup /// /// To use `ChatGoogleGenerativeAI` you need to have an API key. -/// You can get one [here](https://makersuite.google.com/app/apikey). +/// You can get one [here](https://aistudio.google.com/app/apikey). /// /// ### Available models /// -/// The following models are available at the moment: -/// - `gemini-pro`: text -> text model -/// - `gemini-pro-vision`: text / image -> text model +/// The following models are available: +/// - `gemini-1.0-pro` (or `gemini-pro`): +/// * text -> text model +/// * Max input token: 30720 +/// * Max output tokens: 2048 +/// - `gemini-pro-vision`: +/// * text / image -> text model +/// * Max input token: 12288 +/// * Max output tokens: 4096 +/// - `gemini-1.5-pro-latest`: text / image -> text model +/// * text / image -> text model +/// * Max input token: 1048576 +/// * Max output tokens: 8192 /// /// Mind that this list may not be up-to-date. /// Refer to the [documentation](https://ai.google.dev/models) for the updated list. 
@@ -142,7 +153,7 @@ class ChatGoogleGenerativeAI /// /// Main configuration options: /// - `apiKey`: your Google AI API key. You can find your API key in the - /// [Google AI Studio dashboard](https://makersuite.google.com/app/apikey). + /// [Google AI Studio dashboard](https://aistudio.google.com/app/apikey). /// - [ChatGoogleGenerativeAI.defaultOptions] /// /// Advance configuration options: @@ -163,49 +174,64 @@ class ChatGoogleGenerativeAI super.defaultOptions = const ChatGoogleGenerativeAIOptions( model: 'gemini-pro', ), - }) : _client = GoogleAIClient( - apiKey: apiKey, - baseUrl: baseUrl, - headers: headers, - queryParams: queryParams, - client: client, - ); + }) : _currentModel = defaultOptions.model ?? '', + _httpClient = createDefaultHttpClient( + baseHttpClient: client, + baseUrl: + baseUrl ?? 'https://generativelanguage.googleapis.com/v1beta', + headers: { + if (apiKey != null) 'x-goog-api-key': apiKey, + ...?headers, + }, + queryParams: queryParams ?? const {}, + ) { + _googleAiClient = _createGoogleAiClient( + _currentModel, + apiKey: apiKey ?? '', + httpClient: _httpClient, + ); + } + + /// The HTTP client to use. + final CustomHttpClient _httpClient; /// A client for interacting with Google AI API. - final GoogleAIClient _client; + late GenerativeModel _googleAiClient; /// A UUID generator. late final Uuid _uuid = const Uuid(); /// Set or replace the API key. - set apiKey(final String value) => _client.apiKey = value; + set apiKey(final String value) => + _httpClient.headers['x-goog-api-key'] = value; /// Get the API key. - String get apiKey => _client.apiKey; + String get apiKey => _httpClient.headers['x-goog-api-key'] ?? ''; @override String get modelType => 'chat-google-generative-ai'; + /// The current model set in [_googleAiClient]; + String _currentModel; + + /// The current system instruction set in [_googleAiClient]; + String? 
_currentSystemInstruction; + @override Future invoke( final PromptValue input, { final ChatGoogleGenerativeAIOptions? options, }) async { final id = _uuid.v4(); - final (model, isTuned) = _getNormalizedModel(options); - final request = _generateCompletionRequest( - input.toChatMessages(), - options: options, + final (model, prompt, safetySettings, generationConfig, tools, toolConfig) = + _generateCompletionRequest(input.toChatMessages(), options: options); + final completion = await _googleAiClient.generateContent( + prompt, + safetySettings: safetySettings, + generationConfig: generationConfig, + tools: tools, + toolConfig: toolConfig, ); - final completion = await (isTuned - ? _client.generateContentTunedModel( - tunedModelId: model, - request: request, - ) - : _client.generateContent( - modelId: model, - request: request, - )); return completion.toChatResult(id, model); } @@ -215,37 +241,51 @@ class ChatGoogleGenerativeAI final ChatGoogleGenerativeAIOptions? options, }) { final id = _uuid.v4(); - final (model, isTuned) = _getNormalizedModel(options); - assert(!isTuned, 'Tuned models are not supported for streaming.'); - final request = _generateCompletionRequest( - input.toChatMessages(), - options: options, - ); - - return _client - .streamGenerateContent(modelId: model, request: request) + final (model, prompt, safetySettings, generationConfig, tools, toolConfig) = + _generateCompletionRequest(input.toChatMessages(), options: options); + return _googleAiClient + .generateContentStream( + prompt, + safetySettings: safetySettings, + generationConfig: generationConfig, + tools: tools, + toolConfig: toolConfig, + ) .map((final completion) => completion.toChatResult(id, model)); } /// Creates a [GenerateContentRequest] from the given input. - GenerateContentRequest _generateCompletionRequest( + ( + String model, + Iterable prompt, + List? safetySettings, + GenerationConfig? generationConfig, + List? tools, + ToolConfig? 
toolConfig, + ) _generateCompletionRequest( final List messages, { final ChatGoogleGenerativeAIOptions? options, }) { - return GenerateContentRequest( - contents: messages.toContentList(), - generationConfig: GenerationConfig( - topP: options?.topP ?? defaultOptions.topP, - topK: options?.topK ?? defaultOptions.topK, + _updateClientIfNeeded(messages, options); + + return ( + _currentModel, + messages.toContentList(), + (options?.safetySettings ?? defaultOptions.safetySettings) + ?.toSafetySettings(), + GenerationConfig( candidateCount: options?.candidateCount ?? defaultOptions.candidateCount, + stopSequences: + options?.stopSequences ?? defaultOptions.stopSequences ?? const [], maxOutputTokens: options?.maxOutputTokens ?? defaultOptions.maxOutputTokens, temperature: options?.temperature ?? defaultOptions.temperature, - stopSequences: options?.stopSequences ?? defaultOptions.stopSequences, + topP: options?.topP ?? defaultOptions.topP, + topK: options?.topK ?? defaultOptions.topK, ), - safetySettings: (options?.safetySettings ?? defaultOptions.safetySettings) - ?.toSafetySettings(), + null, // options?.tools?.toTools(), + null, // options?.toolConfig?.toToolConfig(), ); } @@ -264,36 +304,65 @@ class ChatGoogleGenerativeAI final PromptValue promptValue, { final ChatGoogleGenerativeAIOptions? options, }) async { - final (model, _) = _getNormalizedModel(options); - final tokens = await _client.countTokens( - modelId: model, - request: CountTokensRequest( - contents: promptValue.toChatMessages().toContentList(), - ), - ); - return tokens.totalTokens ?? 0; + final messages = promptValue.toChatMessages(); + _updateClientIfNeeded(messages, options); + final tokens = await _googleAiClient.countTokens(messages.toContentList()); + return tokens.totalTokens; } /// Closes the client and cleans up any resources associated with it. void close() { - _client.endSession(); + _httpClient.close(); + } + + /// Create a new [GenerativeModel] instance. 
+ GenerativeModel _createGoogleAiClient( + final String model, { + final String? apiKey, + final CustomHttpClient? httpClient, + final String? systemInstruction, + }) { + return GenerativeModel( + model: model, + apiKey: apiKey ?? this.apiKey, + httpClient: httpClient ?? _httpClient, + systemInstruction: + systemInstruction != null ? Content.system(systemInstruction) : null, + ); } - /// Returns the model code to use and whether it is a tuned model. - (String model, bool isTuned) _getNormalizedModel( + /// Recreate the [GenerativeModel] instance. + void _recreateGoogleAiClient( + final String model, + final String? systemInstruction, + ) { + _googleAiClient = + _createGoogleAiClient(model, systemInstruction: systemInstruction); + } + + /// Updates the model in [_googleAiClient] if needed. + void _updateClientIfNeeded( + final List messages, final ChatGoogleGenerativeAIOptions? options, ) { - final rawModel = + final model = options?.model ?? defaultOptions.model ?? throwNullModelError(); - if (!rawModel.contains('/')) { - return (rawModel, false); + final systemInstruction = messages.firstOrNull is SystemChatMessage + ? 
messages.firstOrNull?.contentAsString + : null; + + bool recreate = false; + if (model != _currentModel) { + _currentModel = model; + recreate = true; + } + if (systemInstruction != _currentSystemInstruction) { + _currentSystemInstruction = systemInstruction; + recreate = true; + } + if (recreate) { + _recreateGoogleAiClient(model, systemInstruction); } - final parts = rawModel.split('/'); - return switch (parts.first) { - 'tunedModels' => (parts.skip(1).join('/'), true), - 'models' => (parts.skip(1).join('/'), false), - _ => (rawModel, false), - }; } } diff --git a/packages/langchain_google/lib/src/chat_models/google_ai/mappers.dart b/packages/langchain_google/lib/src/chat_models/google_ai/mappers.dart index 0b03bf05..945b1d46 100644 --- a/packages/langchain_google/lib/src/chat_models/google_ai/mappers.dart +++ b/packages/langchain_google/lib/src/chat_models/google_ai/mappers.dart @@ -1,60 +1,52 @@ // ignore_for_file: public_member_api_docs +import 'dart:convert'; + import 'package:collection/collection.dart'; -import 'package:googleai_dart/googleai_dart.dart'; +import 'package:google_generative_ai/google_generative_ai.dart' as g; import 'package:langchain_core/chat_models.dart'; import 'package:langchain_core/language_models.dart'; import 'types.dart'; -const _authorUser = 'user'; -const _authorAI = 'model'; - extension ChatMessagesMapper on List { - List toContentList() { - return map( - (final message) => switch (message) { - SystemChatMessage() => throw UnsupportedError( - 'Google AI does not support system messages at the moment. 
' - 'Attach your system message in the human message.', - ), - final HumanChatMessage msg => Content( - role: _authorUser, - parts: _mapHumanChatMessageContentParts(msg.content), - ), - final AIChatMessage aiChatMessage => Content( - role: _authorAI, - parts: [ - Part(text: aiChatMessage.content), - ], - ), - final CustomChatMessage customChatMessage => Content( - role: customChatMessage.role, - parts: [ - Part(text: customChatMessage.content), - ], - ), - ToolChatMessage() => throw UnsupportedError( - 'Google AI does not support tool messages', - ), - }, - ).toList(growable: false); + List toContentList() { + return where((msg) => msg is! SystemChatMessage) + .map( + (final message) => switch (message) { + SystemChatMessage() => + throw AssertionError('System messages should be filtered out'), + final HumanChatMessage msg => + g.Content.multi(_mapHumanChatMessageContentParts(msg.content)), + final AIChatMessage msg => + g.Content.model([g.TextPart(msg.content)]), + final CustomChatMessage msg => + g.Content(msg.role, [g.TextPart(msg.content)]), + ToolChatMessage() => throw UnsupportedError( + 'Google AI does not support tool messages', + ), + }, + ) + .toList(growable: false); } - List _mapHumanChatMessageContentParts( + List _mapHumanChatMessageContentParts( final ChatMessageContent content, ) { return switch (content) { - final ChatMessageContentText c => [Part(text: c.text)], + final ChatMessageContentText c => [g.TextPart(c.text)], final ChatMessageContentImage c => [ - Part(inlineData: Blob(mimeType: c.mimeType, data: c.data)), + if (c.data.startsWith('http')) + g.FilePart(Uri.parse(c.data)) + else + g.DataPart(c.mimeType ?? 
'', base64Decode(c.data)), ], final ChatMessageContentMultiModal c => c.parts .map( (final p) => switch (p) { - final ChatMessageContentText c => Part(text: c.text), - final ChatMessageContentImage c => Part( - inlineData: Blob(mimeType: c.mimeType, data: c.data), - ), + final ChatMessageContentText c => g.TextPart(c.text), + final ChatMessageContentImage c => c.data.startsWith('http') + ? g.FilePart(Uri.parse(c.data)) + : g.DataPart(c.mimeType ?? '', base64Decode(c.data)), ChatMessageContentMultiModal() => throw UnsupportedError( 'Cannot have multimodal content in multimodal content', ), @@ -65,86 +57,103 @@ extension ChatMessagesMapper on List { } } -extension GenerateContentResponseMapper on GenerateContentResponse { +extension GenerateContentResponseMapper on g.GenerateContentResponse { ChatResult toChatResult(final String id, final String model) { - final candidate = candidates?.first; + final candidate = candidates.first; return ChatResult( id: id, output: AIChatMessage( - content: candidate?.content?.parts - ?.map((final p) => p.text) - .whereNotNull() - .join('\n') ?? 
- '', + content: candidate.content.parts + .map( + (p) => switch (p) { + final g.TextPart p => p.text, + final g.DataPart p => base64Encode(p.bytes), + final g.FilePart p => p.uri.toString(), + g.FunctionResponse() => throw UnimplementedError( + 'FunctionResponse part not yet supported', + ), + g.FunctionCall() => throw UnimplementedError( + 'FunctionResponse part not yet supported', + ), + }, + ) + .whereNotNull() + .join('\n'), ), - finishReason: _mapFinishReason(candidate?.finishReason), + finishReason: _mapFinishReason(candidate.finishReason), metadata: { 'model': model, 'block_reason': promptFeedback?.blockReason?.name, + 'block_reason_message': promptFeedback?.blockReasonMessage, + 'safety_ratings': candidate.safetyRatings + ?.map( + (r) => { + 'category': r.category.name, + 'probability': r.probability.name, + }, + ) + .toList(growable: false), + 'citation_metadata': candidate.citationMetadata?.citationSources + .map( + (s) => { + 'start_index': s.startIndex, + 'end_index': s.endIndex, + 'uri': s.uri.toString(), + 'license': s.license, + }, + ) + .toList(growable: false), + 'finish_message': candidate.finishMessage, }, usage: LanguageModelUsage( - totalTokens: candidates?.map((final c) => c.tokenCount ?? 0).sum ?? 0, + promptTokens: usageMetadata?.promptTokenCount, + responseTokens: usageMetadata?.candidatesTokenCount, + totalTokens: usageMetadata?.totalTokenCount, ), ); } FinishReason _mapFinishReason( - final CandidateFinishReason? reason, + final g.FinishReason? 
reason, ) => switch (reason) { - CandidateFinishReason.finishReasonUnspecified => - FinishReason.unspecified, - CandidateFinishReason.stop => FinishReason.stop, - CandidateFinishReason.maxTokens => FinishReason.length, - CandidateFinishReason.safety => FinishReason.contentFilter, - CandidateFinishReason.recitation => FinishReason.recitation, - CandidateFinishReason.other => FinishReason.unspecified, + g.FinishReason.unspecified => FinishReason.unspecified, + g.FinishReason.stop => FinishReason.stop, + g.FinishReason.maxTokens => FinishReason.length, + g.FinishReason.safety => FinishReason.contentFilter, + g.FinishReason.recitation => FinishReason.recitation, + g.FinishReason.other => FinishReason.unspecified, null => FinishReason.unspecified, }; } extension SafetySettingsMapper on List { - List toSafetySettings() { + List toSafetySettings() { return map( - (final setting) => SafetySetting( - category: switch (setting.category) { - ChatGoogleGenerativeAISafetySettingCategory.harmCategoryUnspecified => - SafetySettingCategory.harmCategoryUnspecified, - ChatGoogleGenerativeAISafetySettingCategory.harmCategoryDerogatory => - SafetySettingCategory.harmCategoryDerogatory, - ChatGoogleGenerativeAISafetySettingCategory.harmCategoryToxicity => - SafetySettingCategory.harmCategoryToxicity, - ChatGoogleGenerativeAISafetySettingCategory.harmCategoryViolence => - SafetySettingCategory.harmCategoryViolence, - ChatGoogleGenerativeAISafetySettingCategory.harmCategorySexual => - SafetySettingCategory.harmCategorySexual, - ChatGoogleGenerativeAISafetySettingCategory.harmCategoryMedical => - SafetySettingCategory.harmCategoryMedical, - ChatGoogleGenerativeAISafetySettingCategory.harmCategoryDangerous => - SafetySettingCategory.harmCategoryDangerous, - ChatGoogleGenerativeAISafetySettingCategory.harmCategoryHarassment => - SafetySettingCategory.harmCategoryHarassment, - ChatGoogleGenerativeAISafetySettingCategory.harmCategoryHateSpeech => - 
SafetySettingCategory.harmCategoryHateSpeech, - ChatGoogleGenerativeAISafetySettingCategory - .harmCategorySexuallyExplicit => - SafetySettingCategory.harmCategorySexuallyExplicit, - ChatGoogleGenerativeAISafetySettingCategory - .harmCategoryDangerousContent => - SafetySettingCategory.harmCategoryDangerousContent, + (final setting) => g.SafetySetting( + switch (setting.category) { + ChatGoogleGenerativeAISafetySettingCategory.unspecified => + g.HarmCategory.unspecified, + ChatGoogleGenerativeAISafetySettingCategory.harassment => + g.HarmCategory.harassment, + ChatGoogleGenerativeAISafetySettingCategory.hateSpeech => + g.HarmCategory.hateSpeech, + ChatGoogleGenerativeAISafetySettingCategory.sexuallyExplicit => + g.HarmCategory.sexuallyExplicit, + ChatGoogleGenerativeAISafetySettingCategory.dangerousContent => + g.HarmCategory.dangerousContent, }, - threshold: switch (setting.threshold) { - ChatGoogleGenerativeAISafetySettingThreshold - .harmBlockThresholdUnspecified => - SafetySettingThreshold.harmBlockThresholdUnspecified, + switch (setting.threshold) { + ChatGoogleGenerativeAISafetySettingThreshold.unspecified => + g.HarmBlockThreshold.unspecified, ChatGoogleGenerativeAISafetySettingThreshold.blockLowAndAbove => - SafetySettingThreshold.blockLowAndAbove, + g.HarmBlockThreshold.low, ChatGoogleGenerativeAISafetySettingThreshold.blockMediumAndAbove => - SafetySettingThreshold.blockMediumAndAbove, + g.HarmBlockThreshold.medium, ChatGoogleGenerativeAISafetySettingThreshold.blockOnlyHigh => - SafetySettingThreshold.blockOnlyHigh, + g.HarmBlockThreshold.high, ChatGoogleGenerativeAISafetySettingThreshold.blockNone => - SafetySettingThreshold.blockNone, + g.HarmBlockThreshold.none, }, ), ).toList(growable: false); diff --git a/packages/langchain_google/lib/src/chat_models/google_ai/types.dart b/packages/langchain_google/lib/src/chat_models/google_ai/types.dart index 996db447..12d95987 100644 --- a/packages/langchain_google/lib/src/chat_models/google_ai/types.dart +++ 
b/packages/langchain_google/lib/src/chat_models/google_ai/types.dart @@ -127,37 +127,19 @@ class ChatGoogleGenerativeAISafetySetting { /// Docs: https://ai.google.dev/docs/safety_setting_gemini enum ChatGoogleGenerativeAISafetySettingCategory { /// The harm category is unspecified. - harmCategoryUnspecified, - - /// The harm category is identity attack. - harmCategoryDerogatory, - - /// The harm category is profanity. - harmCategoryToxicity, - - /// The harm category is violence. - harmCategoryViolence, - - /// The harm category is sexual content. - harmCategorySexual, - - /// The harm category is medical. - harmCategoryMedical, - - /// The harm category is illegal activities. - harmCategoryDangerous, + unspecified, /// The harm category is harassment. - harmCategoryHarassment, + harassment, /// The harm category is hate speech. - harmCategoryHateSpeech, + hateSpeech, /// The harm category is sexually explicit content. - harmCategorySexuallyExplicit, + sexuallyExplicit, /// The harm category is dangerous content. - harmCategoryDangerousContent, + dangerousContent, } /// Controls the probability threshold at which harm is blocked. @@ -165,7 +147,7 @@ enum ChatGoogleGenerativeAISafetySettingCategory { /// Docs: https://ai.google.dev/docs/safety_setting_gemini enum ChatGoogleGenerativeAISafetySettingThreshold { /// Threshold is unspecified, block using default threshold. - harmBlockThresholdUnspecified, + unspecified, /// Block when low, medium or high probability of unsafe content. 
blockLowAndAbove, diff --git a/packages/langchain_google/lib/src/embeddings/google_ai/google_ai_embeddings.dart b/packages/langchain_google/lib/src/embeddings/google_ai/google_ai_embeddings.dart index 9e601fac..b5996abd 100644 --- a/packages/langchain_google/lib/src/embeddings/google_ai/google_ai_embeddings.dart +++ b/packages/langchain_google/lib/src/embeddings/google_ai/google_ai_embeddings.dart @@ -1,10 +1,13 @@ -import 'package:collection/collection.dart'; -import 'package:googleai_dart/googleai_dart.dart'; +import 'package:collection/collection.dart' show IterableNullableExtension; +import 'package:google_generative_ai/google_generative_ai.dart' + show Content, EmbedContentRequest, GenerativeModel, TaskType; import 'package:http/http.dart' as http; import 'package:langchain_core/documents.dart'; import 'package:langchain_core/embeddings.dart'; import 'package:langchain_core/utils.dart'; +import '../../utils/https_client/http_client.dart'; + /// {@template google_generative_ai_embeddings} /// Wrapper around Google AI embedding models API /// @@ -20,8 +23,10 @@ import 'package:langchain_core/utils.dart'; /// /// ### Available models /// +/// - `text-embedding-004` +/// * Dimensions: 768 (with support for reduced dimensionality) /// - `embedding-001` -/// * Optimized for creating embeddings for text of up to 2048 tokens +/// * Dimensions: 768 /// /// The previous list of models may not be exhaustive or up-to-date. Check out /// the [Google AI documentation](https://ai.google.dev/models/gemini) @@ -34,8 +39,18 @@ import 'package:langchain_core/utils.dart'; /// embeddings. 
/// /// This class uses the specifies the following task type: -/// - `retrievalDocument`: for embedding documents -/// - `retrievalQuery`: for embedding queries +/// - `retrievalDocument`: for [embedDocuments] +/// - `retrievalQuery`: for [embedQuery] +/// +/// ### Reduced dimensionality +/// +/// Some embedding models support specifying a smaller number of dimensions +/// for the resulting embeddings. This can be useful when you want to save +/// computing and storage costs with minor performance loss. Use the +/// [dimensions] parameter to specify the number of dimensions. +/// +/// You can also use this feature to reduce the dimensions to 2D or 3D for +/// visualization purposes. /// /// ### Title /// @@ -64,8 +79,10 @@ class GoogleGenerativeAIEmbeddings implements Embeddings { /// /// Main configuration options: /// - `apiKey`: your Google AI API key. You can find your API key in the - /// [Google AI Studio dashboard](https://makersuite.google.com/app/apikey). - /// - [GoogleGenerativeAIEmbeddings.model] + /// [Google AI Studio dashboard](https://aistudio.google.com/app/apikey). + /// - `model`: the embeddings model to use. You can find a list of available + /// embedding models here: https://ai.google.dev/models/gemini + /// - [GoogleGenerativeAIEmbeddings.dimensions] /// - [GoogleGenerativeAIEmbeddings.batchSize] /// - [GoogleGenerativeAIEmbeddings.docTitleKey] /// @@ -84,29 +101,45 @@ class GoogleGenerativeAIEmbeddings implements Embeddings { final Map? headers, final Map? queryParams, final http.Client? client, - this.model = 'embedding-001', + String model = 'text-embedding-004', this.dimensions, this.batchSize = 100, this.docTitleKey = 'title', - }) : _client = GoogleAIClient( - apiKey: apiKey, - baseUrl: baseUrl, - headers: headers, - queryParams: queryParams, - client: client, - ); + }) : _model = model, + _httpClient = createDefaultHttpClient( + baseHttpClient: client, + baseUrl: + baseUrl ?? 
'https://generativelanguage.googleapis.com/v1beta', + headers: { + if (apiKey != null) 'x-goog-api-key': apiKey, + ...?headers, + }, + queryParams: queryParams ?? const {}, + ) { + _googleAiClient = _createGoogleAiClient(model, apiKey ?? '', _httpClient); + } + + /// The HTTP client to use. + final CustomHttpClient _httpClient; /// A client for interacting with Google AI API. - final GoogleAIClient _client; + late GenerativeModel _googleAiClient; /// The embeddings model to use. - /// - /// You can find a list of available embedding models here: - /// https://ai.google.dev/models/gemini - String model; + String _model; + + /// Set the embeddings model to use. + set model(final String model) { + _recreateGoogleAiClient(model); + _model = model; + } + + /// Get the embeddings model to use. + String get model => _model; /// The number of dimensions the resulting output embeddings should have. /// Only supported in `text-embedding-004` and later models. + /// TODO https://github.com/google-gemini/generative-ai-dart/pull/149 int? dimensions; /// The maximum number of documents to embed in a single request. @@ -116,10 +149,11 @@ class GoogleGenerativeAIEmbeddings implements Embeddings { String docTitleKey; /// Set or replace the API key. - set apiKey(final String value) => _client.apiKey = value; + set apiKey(final String value) => + _httpClient.headers['x-goog-api-key'] = value; /// Get the API key. - String get apiKey => _client.apiKey; + String get apiKey => _httpClient.headers['x-goog-api-key'] ?? 
''; @override Future>> embedDocuments( @@ -129,25 +163,20 @@ class GoogleGenerativeAIEmbeddings implements Embeddings { final List>> embeddings = await Future.wait( batches.map((final batch) async { - final data = await _client.batchEmbedContents( - modelId: model, - request: BatchEmbedContentsRequest( - requests: batch.map((final doc) { - return EmbedContentRequest( - title: doc.metadata[docTitleKey], - content: Content(parts: [Part(text: doc.pageContent)]), - taskType: EmbedContentRequestTaskType.retrievalDocument, - model: 'models/$model', - outputDimensionality: dimensions, - ); - }).toList(growable: false), - ), + final data = await _googleAiClient.batchEmbedContents( + batch.map((final doc) { + return EmbedContentRequest( + Content.text(doc.pageContent), + taskType: TaskType.retrievalDocument, + title: doc.metadata[docTitleKey], + // outputDimensionality: dimensions, TODO + ); + }).toList(growable: false), ); return data.embeddings - ?.map((final p) => p.values) - .whereNotNull() - .toList(growable: false) ?? - const []; + .map((final p) => p.values) + .whereNotNull() + .toList(growable: false); }), ); @@ -156,18 +185,34 @@ class GoogleGenerativeAIEmbeddings implements Embeddings { @override Future> embedQuery(final String query) async { - final data = await _client.embedContent( - modelId: model, - request: EmbedContentRequest( - content: Content(parts: [Part(text: query)]), - taskType: EmbedContentRequestTaskType.retrievalQuery, - ), + final data = await _googleAiClient.embedContent( + Content.text(query), + taskType: TaskType.retrievalQuery, + // outputDimensionality: dimensions, TODO ); - return data.embedding?.values ?? const []; + return data.embedding.values; } /// Closes the client and cleans up any resources associated with it. void close() { - _client.endSession(); + _httpClient.close(); + } + + /// Create a new [GenerativeModel] instance. + GenerativeModel _createGoogleAiClient( + final String model, [ + final String? 
apiKey, + final CustomHttpClient? httpClient, + ]) { + return GenerativeModel( + model: model, + apiKey: apiKey ?? this.apiKey, + httpClient: httpClient ?? _httpClient, + ); + } + + /// Recreate the [GenerativeModel] instance. + void _recreateGoogleAiClient(final String model) { + _googleAiClient = _createGoogleAiClient(model); } } diff --git a/packages/langchain_google/lib/src/utils/https_client/http_client.dart b/packages/langchain_google/lib/src/utils/https_client/http_client.dart new file mode 100644 index 00000000..479d2164 --- /dev/null +++ b/packages/langchain_google/lib/src/utils/https_client/http_client.dart @@ -0,0 +1,61 @@ +import 'package:http/http.dart' as http; + +export 'http_client_stub.dart' + if (dart.library.io) 'http_client_io.dart' + if (dart.library.js) 'http_client_html.dart' + if (dart.library.html) 'http_client_html.dart'; + +/// {@template custom_http_client} +/// Custom HTTP client that wraps the base HTTP client and allows to override +/// the base URL, headers, and query parameters. +/// {@endtemplate} +class CustomHttpClient extends http.BaseClient { + /// {@macro custom_http_client} + CustomHttpClient({ + required this.baseHttpClient, + required this.baseUrl, + required this.headers, + required this.queryParams, + }); + + /// Base HTTP client to use. + final http.Client baseHttpClient; + + /// Base URL to use. + final Uri baseUrl; + + /// Headers to send with every request. + final Map headers; + + /// Query parameters to send with every request. 
+ final Map queryParams; + + @override + Future send(http.BaseRequest request) { + final newUrl = baseUrl.resolveUri(request.url).replace( + queryParameters: { + ...request.url.queryParameters, + ...queryParams, + }, + ); + + http.BaseRequest newRequest; + if (request is http.Request) { + newRequest = http.Request(request.method, newUrl) + ..headers.addAll({ + ...request.headers, + ...headers, + }) + ..persistentConnection = request.persistentConnection + ..followRedirects = request.followRedirects + ..maxRedirects = request.maxRedirects + ..bodyBytes = request.bodyBytes; + } else { + throw UnsupportedError( + 'Request type not supported (${request.runtimeType})', + ); + } + + return baseHttpClient.send(newRequest); + } +} diff --git a/packages/langchain_google/lib/src/utils/https_client/http_client_html.dart b/packages/langchain_google/lib/src/utils/https_client/http_client_html.dart new file mode 100644 index 00000000..becf0515 --- /dev/null +++ b/packages/langchain_google/lib/src/utils/https_client/http_client_html.dart @@ -0,0 +1,21 @@ +import 'package:fetch_client/fetch_client.dart' as fetch; +import 'package:http/http.dart' as http; +import 'package:http/retry.dart'; + +import 'http_client.dart'; + +/// Creates a FetchClient-based HTTP client for web platforms. +CustomHttpClient createDefaultHttpClient({ + http.Client? baseHttpClient, + required String baseUrl, + required Map headers, + required Map queryParams, +}) { + return CustomHttpClient( + baseHttpClient: baseHttpClient ??
+ RetryClient(fetch.FetchClient(mode: fetch.RequestMode.cors)), + baseUrl: Uri.parse(baseUrl), + headers: headers, + queryParams: queryParams, + ); +} diff --git a/packages/langchain_google/lib/src/utils/https_client/http_client_io.dart b/packages/langchain_google/lib/src/utils/https_client/http_client_io.dart new file mode 100644 index 00000000..608b703d --- /dev/null +++ b/packages/langchain_google/lib/src/utils/https_client/http_client_io.dart @@ -0,0 +1,19 @@ +import 'package:http/http.dart' as http; +import 'package:http/retry.dart'; + +import 'http_client.dart'; + +/// Creates an IOClient. +CustomHttpClient createDefaultHttpClient({ + http.Client? baseHttpClient, + required String baseUrl, + required Map headers, + required Map queryParams, +}) { + return CustomHttpClient( + baseHttpClient: baseHttpClient ?? RetryClient(http.Client()), + baseUrl: Uri.parse(baseUrl), + headers: headers, + queryParams: queryParams, + ); +} diff --git a/packages/langchain_google/lib/src/utils/https_client/http_client_stub.dart b/packages/langchain_google/lib/src/utils/https_client/http_client_stub.dart new file mode 100644 index 00000000..8f49aa3b --- /dev/null +++ b/packages/langchain_google/lib/src/utils/https_client/http_client_stub.dart @@ -0,0 +1,14 @@ +import 'package:http/http.dart' as http; + +import 'http_client.dart'; + +/// Creates a default HTTP client for the current platform. +CustomHttpClient createDefaultHttpClient({ + http.Client? 
baseHttpClient, + required String baseUrl, + required Map headers, + required Map queryParams, +}) => + throw UnsupportedError( + 'Cannot create a client without dart:html or dart:io.', + ); diff --git a/packages/langchain_google/pubspec.yaml b/packages/langchain_google/pubspec.yaml index c277f162..51553414 100644 --- a/packages/langchain_google/pubspec.yaml +++ b/packages/langchain_google/pubspec.yaml @@ -18,8 +18,9 @@ environment: dependencies: collection: '>=1.17.0 <1.19.0' + fetch_client: ^1.0.2 gcloud: ^0.8.12 - googleai_dart: ^0.0.4 + google_generative_ai: 0.3.3 googleapis: ^12.0.0 googleapis_auth: ^1.5.1 http: ^1.1.0 diff --git a/packages/langchain_google/pubspec_overrides.yaml b/packages/langchain_google/pubspec_overrides.yaml index b0513992..50319fbe 100644 --- a/packages/langchain_google/pubspec_overrides.yaml +++ b/packages/langchain_google/pubspec_overrides.yaml @@ -1,7 +1,5 @@ -# melos_managed_dependency_overrides: googleai_dart,vertex_ai,langchain_core +# melos_managed_dependency_overrides: vertex_ai,langchain_core dependency_overrides: - googleai_dart: - path: ../googleai_dart langchain_core: path: ../langchain_core vertex_ai: diff --git a/packages/langchain_google/test/chat_models/google_ai/chat_google_generative_ai_test.dart b/packages/langchain_google/test/chat_models/google_ai/chat_google_generative_ai_test.dart index 88d421fd..d39d8c9e 100644 --- a/packages/langchain_google/test/chat_models/google_ai/chat_google_generative_ai_test.dart +++ b/packages/langchain_google/test/chat_models/google_ai/chat_google_generative_ai_test.dart @@ -13,11 +13,16 @@ import 'package:test/test.dart'; void main() { group('ChatGoogleGenerativeAI tests', () { + const defaultModel = 'gemini-pro'; + late ChatGoogleGenerativeAI chatModel; setUp(() async { chatModel = ChatGoogleGenerativeAI( apiKey: Platform.environment['GOOGLEAI_API_KEY'], + defaultOptions: const ChatGoogleGenerativeAIOptions( + model: defaultModel, + ), ); }); @@ -26,24 +31,27 @@ void main() { }); 
test('Test Text-only input with gemini-pro', () async { - final res = await chatModel.invoke( - PromptValue.string( - 'List the numbers from 1 to 9 in order ' - 'without any spaces, commas or additional explanations.', - ), - options: const ChatGoogleGenerativeAIOptions( - model: 'gemini-pro', - temperature: 0, - ), - ); - expect(res.id, isNotEmpty); - expect(res.finishReason, isNot(FinishReason.unspecified)); - expect(res.metadata['model'], 'gemini-pro'); - expect(res.metadata['block_reason'], isNull); - expect( - res.output.content.replaceAll(RegExp(r'[\s\n]'), ''), - contains('123456789'), - ); + const models = ['gemini-1.0-pro', 'gemini-1.5-pro-latest']; + for (final model in models) { + final res = await chatModel.invoke( + PromptValue.string( + 'List the numbers from 1 to 9 in order ' + 'without any spaces, commas or additional explanations.', + ), + options: ChatGoogleGenerativeAIOptions( + model: model, + temperature: 0, + ), + ); + expect(res.id, isNotEmpty); + expect(res.finishReason, isNot(FinishReason.unspecified)); + expect(res.metadata['model'], startsWith(model)); + expect(res.metadata['block_reason'], isNull); + expect( + res.output.content.replaceAll(RegExp(r'[\s\n]'), ''), + contains('123456789'), + ); + } }); test('Test models prefix', () async { @@ -53,7 +61,7 @@ void main() { 'without any spaces, commas or additional explanations.', ), options: const ChatGoogleGenerativeAIOptions( - model: 'models/gemini-pro', + model: defaultModel, temperature: 0, ), ); @@ -91,7 +99,7 @@ void main() { 'without any spaces, commas or additional explanations.', ), options: const ChatGoogleGenerativeAIOptions( - model: 'gemini-pro', + model: defaultModel, stopSequences: ['4'], ), ); @@ -108,7 +116,7 @@ void main() { maxOutputTokens: 2, ), ); - expect(res.output.content, lessThan(20)); + expect(res.output.content.length, lessThan(20)); expect(res.finishReason, FinishReason.length); }); @@ -126,7 +134,7 @@ void main() { final res = await chatModel.invoke( prompt, 
options: const ChatGoogleGenerativeAIOptions( - model: 'gemini-pro', + model: defaultModel, temperature: 0, ), ); diff --git a/packages/langchain_google/test/embeddings/google_ai/google_ai_embeddings_test.dart b/packages/langchain_google/test/embeddings/google_ai/google_ai_embeddings_test.dart index 4e2be55e..bc942e51 100644 --- a/packages/langchain_google/test/embeddings/google_ai/google_ai_embeddings_test.dart +++ b/packages/langchain_google/test/embeddings/google_ai/google_ai_embeddings_test.dart @@ -22,8 +22,15 @@ void main() { }); test('Test GoogleGenerativeAIEmbeddings.embedQuery', () async { - final res = await embeddings.embedQuery('Hello world'); - expect(res.length, 768); + const models = ['text-embedding-004', 'embedding-001']; + for (final model in models) { + embeddings.model = model; + final res = await embeddings.embedQuery( + 'Hello world', + ); + expect(res.length, 768); + embeddings.close(); + } }); test('Test GoogleGenerativeAIEmbeddings.embedDocuments', () async { @@ -41,5 +48,12 @@ void main() { expect(res[0].length, 768); expect(res[1].length, 768); }); + + // TODO https://github.com/google-gemini/generative-ai-dart/pull/149 + test('Test shortening embeddings', skip: true, () async { + embeddings.dimensions = 256; + final res = await embeddings.embedQuery('Hello world'); + expect(res.length, 256); + }); }); }