From 2cbd538a3b67ef6bdd9ab7b92bebc3c8c7a1bea1 Mon Sep 17 00:00:00 2001
From: luisredondo <46001223+luisredondo@users.noreply.github.com>
Date: Mon, 12 Feb 2024 15:40:40 -0600
Subject: [PATCH] feat: Add streaming support to googleai_dart client (#299)

Co-authored-by: David Miguel
---
 packages/googleai_dart/README.md              |  33 ++-
 .../example/googleai_dart_example.dart        |  29 +++
 packages/googleai_dart/lib/src/client.dart    |  50 ++++
 .../lib/src/generated/client.dart             |  27 ---
 .../oas/googleai_openapi_curated.yaml         |  27 ---
 .../googleai_dart/test/count_tokens_test.dart |   1 -
 .../test/embed_content_test.dart              |   1 -
 .../test/generate_content_test.dart           |   1 -
 .../googleai_dart/test/model_info_test.dart   |   1 -
 .../test/stream_generate_content_test.dart    | 228 ++++++++++++++++++
 10 files changed, 336 insertions(+), 62 deletions(-)
 create mode 100644 packages/googleai_dart/test/stream_generate_content_test.dart

diff --git a/packages/googleai_dart/README.md b/packages/googleai_dart/README.md
index c4665b86..825bcc7c 100644
--- a/packages/googleai_dart/README.md
+++ b/packages/googleai_dart/README.md
@@ -15,12 +15,9 @@ Unofficial Dart client for [Google AI](https://ai.google.dev) for Developers (Ge
 - Custom base URL, headers and query params support (e.g. HTTP proxies)
 - Custom HTTP client support (e.g. SOCKS5 proxies or advanced use cases)
 
-> [!NOTE]
-> Streaming support coming soon.
-
 **Supported endpoints:**
 
-- Generate content
+- Generate content (with streaming support)
 - Count tokens
 - Embed content
 - Models info
@@ -35,6 +32,7 @@ Unofficial Dart client for [Google AI](https://ai.google.dev) for Developers (Ge
   + [Text-only input](#text-only-input)
   + [Text-and-image input](#text-and-image-input)
   + [Multi-turn conversations (chat)](#multi-turn-conversations-chat)
+  + [Streaming generated content](#streaming-generated-content)
 * [Count tokens](#count-tokens)
 * [Embedding](#embedding)
 * [Model info](#model-info)
@@ -167,6 +165,33 @@
 print(res.candidates?.first.content?.parts?.first.text);
 // In the heart of a tranquil village nestled amidst the rolling hills of 17th century France...
 ```
 
+#### Streaming generated content
+
+By default, `generateContent` returns a response after completing the entire generation process. You can achieve faster interactions by not waiting for the entire result and instead using `streamGenerateContent` to handle partial results as they become available.
+
+```dart
+final stream = client.streamGenerateContent(
+  modelId: 'gemini-pro',
+  request: const GenerateContentRequest(
+    contents: [
+      Content(
+        parts: [
+          Part(text: 'Write a story about a magic backpack.'),
+        ],
+      ),
+    ],
+    generationConfig: GenerationConfig(
+      temperature: 0.8,
+    ),
+  ),
+);
+
+stream.listen((final res) {
+  print(res.candidates?.first.content?.parts?.first.text);
+  // In a quaint little town nestled amidst rolling hills, there lived a...
+});
+```
+
 ### Count tokens
 
 When using long prompts, it might be useful to count tokens before sending any content to the model.
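The README example above consumes the stream with `listen`; it can equally be consumed with `await for`, which makes completion and cleanup explicit. A minimal, self-contained sketch of that pattern follows; reading the key from a `GOOGLEAI_API_KEY` environment variable is an assumption borrowed from the package's own tests, not part of the patch.

```dart
import 'dart:io';

import 'package:googleai_dart/googleai_dart.dart';

Future<void> main() async {
  // Assumption: the API key is exposed via an environment variable,
  // mirroring the package's own test setup.
  final client = GoogleAIClient(
    apiKey: Platform.environment['GOOGLEAI_API_KEY'],
  );

  final stream = client.streamGenerateContent(
    modelId: 'gemini-pro',
    request: const GenerateContentRequest(
      contents: [
        Content(
          parts: [
            Part(text: 'Write a story about a magic backpack.'),
          ],
        ),
      ],
    ),
  );

  try {
    await for (final res in stream) {
      // Each event carries a partial GenerateContentResponse.
      print(res.candidates?.first.content?.parts?.first.text);
    }
  } finally {
    // Close the underlying HTTP client when done.
    client.endSession();
  }
}
```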
diff --git a/packages/googleai_dart/example/googleai_dart_example.dart b/packages/googleai_dart/example/googleai_dart_example.dart
index fd54821e..6087e0f2 100644
--- a/packages/googleai_dart/example/googleai_dart_example.dart
+++ b/packages/googleai_dart/example/googleai_dart_example.dart
@@ -1,4 +1,5 @@
 // ignore_for_file: avoid_print, avoid_redundant_argument_values
+import 'dart:async';
 import 'dart:convert';
 import 'dart:io';
 
@@ -14,6 +15,9 @@ Future<void> main() async {
   await _generateContentTextAndImageInput(client);
   await _generateContentMultiTurnConversations(client);
 
+  // Stream generate content
+  await _streamGenerateContentTextInput(client);
+
   // Count tokens
   await _countTokens(client);
 
@@ -114,6 +118,31 @@ Future<void> _generateContentMultiTurnConversations(
   // In the heart of a tranquil village nestled amidst the rolling hills of 17th century France...
 }
 
+Future<void> _streamGenerateContentTextInput(
+  final GoogleAIClient client,
+) async {
+  final stream = client.streamGenerateContent(
+    modelId: 'gemini-pro',
+    request: const GenerateContentRequest(
+      contents: [
+        Content(
+          parts: [
+            Part(text: 'Write a story about a magic backpack.'),
+          ],
+        ),
+      ],
+      generationConfig: GenerationConfig(
+        temperature: 0.8,
+      ),
+    ),
+  );
+
+  await for (final res in stream) {
+    print(res.candidates?.first.content?.parts?.first.text);
+    // In a quaint little town nestled amidst rolling hills, there lived a...
+  }
+}
+
 Future<void> _countTokens(final GoogleAIClient client) async {
   final res = await client.countTokens(
     modelId: 'gemini-pro',
diff --git a/packages/googleai_dart/lib/src/client.dart b/packages/googleai_dart/lib/src/client.dart
index ef55e569..605aec83 100644
--- a/packages/googleai_dart/lib/src/client.dart
+++ b/packages/googleai_dart/lib/src/client.dart
@@ -1,7 +1,11 @@
 // ignore_for_file: use_super_parameters
+import 'dart:async';
+import 'dart:convert';
+
 import 'package:http/http.dart' as http;
 
 import 'generated/client.dart' as g;
+import 'generated/schema/schema.dart';
 import 'http_client/http_client.dart';
 
 /// Client for Google AI API (Gemini API).
@@ -39,8 +43,54 @@ class GoogleAIClient extends g.GoogleAIClient {
         client: client ?? createDefaultHttpClient(),
       );
 
+  // ------------------------------------------
+  // METHOD: streamGenerateContent
+  // ------------------------------------------
+
+  /// Generates a streamed response from the model given an input `GenerateContentRequest`.
+  ///
+  /// `modelId`: The id of the model to use.
+  ///
+  /// `request`: Request to generate a completion from the model.
+  ///
+  /// `POST` `https://generativelanguage.googleapis.com/v1/models/{modelId}:streamGenerateContent`
+  Stream<GenerateContentResponse> streamGenerateContent({
+    final String modelId = 'gemini-pro',
+    final GenerateContentRequest? request,
+  }) async* {
+    final streamedResponse = await makeRequestStream(
+      baseUrl: 'https://generativelanguage.googleapis.com/v1',
+      queryParams: {
+        'alt': 'sse',
+      },
+      path: '/models/$modelId:streamGenerateContent',
+      method: g.HttpMethod.post,
+      requestType: 'application/json',
+      responseType: 'application/json',
+      body: request,
+    );
+
+    yield* streamedResponse.stream
+        .transform(const _GoogleAIStreamTransformer())
+        .map((final d) => GenerateContentResponse.fromJson(json.decode(d)));
+  }
+
   @override
   Future<void> onRequest(final http.BaseRequest request) {
     return onRequestHandler(request);
   }
 }
+
+class _GoogleAIStreamTransformer
+    extends StreamTransformerBase<List<int>, String> {
+  const _GoogleAIStreamTransformer();
+
+  @override
+  Stream<String> bind(final Stream<List<int>> stream) {
+    return stream
+        .transform(utf8.decoder)
+        .transform(const LineSplitter())
+        .where((final s) => s.isNotEmpty)
+        .map((final s) => s.substring(6)); // Strip the 'data: ' SSE prefix.
+  }
+}
diff --git a/packages/googleai_dart/lib/src/generated/client.dart b/packages/googleai_dart/lib/src/generated/client.dart
index b39bf9f7..ebf48164 100644
--- a/packages/googleai_dart/lib/src/generated/client.dart
+++ b/packages/googleai_dart/lib/src/generated/client.dart
@@ -583,33 +583,6 @@ class GoogleAIClient {
     return BatchEmbedContentsResponse.fromJson(_jsonDecode(r));
   }
 
-  // ------------------------------------------
-  // METHOD: streamGenerateContent
-  // ------------------------------------------
-
-  /// Generates a streamed response from the model given an input `GenerateContentRequest`.
-  ///
-  /// `modelId`: The id of the model to use.
-  ///
-  /// `request`: Request to generate a completion from the model.
-  ///
-  /// `POST` `https://generativelanguage.googleapis.com/v1/models/{modelId}:streamGenerateContent`
-  Future<GenerateContentResponse> streamGenerateContent({
-    String modelId = 'gemini-pro',
-    GenerateContentRequest? request,
-  }) async {
-    final r = await makeRequest(
-      baseUrl: 'https://generativelanguage.googleapis.com/v1',
-      path: '/models/$modelId:streamGenerateContent',
-      method: HttpMethod.post,
-      isMultipart: false,
-      requestType: 'application/json',
-      responseType: 'application/json',
-      body: request,
-    );
-    return GenerateContentResponse.fromJson(_jsonDecode(r));
-  }
-
   // ------------------------------------------
   // METHOD: embedContent
   // ------------------------------------------
diff --git a/packages/googleai_dart/oas/googleai_openapi_curated.yaml b/packages/googleai_dart/oas/googleai_openapi_curated.yaml
index 385131d9..783cde47 100644
--- a/packages/googleai_dart/oas/googleai_openapi_curated.yaml
+++ b/packages/googleai_dart/oas/googleai_openapi_curated.yaml
@@ -1020,33 +1020,6 @@ paths:
           schema:
             type: string
             default: embedding-001
-  /models/{modelId}:streamGenerateContent:
-    post:
-      description: >-
-        Generates a streamed response from the model given an input
-        `GenerateContentRequest`.
-      operationId: streamGenerateContent
-      requestBody:
-        content:
-          application/json:
-            schema:
-              $ref: '#/components/schemas/GenerateContentRequest'
-      security: [ ]
-      responses:
-        '200':
-          description: Successful response
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/GenerateContentResponse'
-      parameters:
-        - in: path
-          name: modelId
-          description: The id of the model to use.
-          required: false
-          schema:
-            type: string
-            default: gemini-pro
   /models/{modelId}:embedContent:
     post:
       description: Generates an embedding from the model given an input `Content`.
diff --git a/packages/googleai_dart/test/count_tokens_test.dart b/packages/googleai_dart/test/count_tokens_test.dart
index 2f3265d6..38fea1d7 100644
--- a/packages/googleai_dart/test/count_tokens_test.dart
+++ b/packages/googleai_dart/test/count_tokens_test.dart
@@ -1,4 +1,3 @@
-// ignore_for_file: avoid_redundant_argument_values
 @TestOn('vm')
 library; // Uses dart:io
 
diff --git a/packages/googleai_dart/test/embed_content_test.dart b/packages/googleai_dart/test/embed_content_test.dart
index f262cd42..449da57d 100644
--- a/packages/googleai_dart/test/embed_content_test.dart
+++ b/packages/googleai_dart/test/embed_content_test.dart
@@ -1,4 +1,3 @@
-// ignore_for_file: avoid_redundant_argument_values
 @TestOn('vm')
 library; // Uses dart:io
 
diff --git a/packages/googleai_dart/test/generate_content_test.dart b/packages/googleai_dart/test/generate_content_test.dart
index 6b0e7b85..d0515a56 100644
--- a/packages/googleai_dart/test/generate_content_test.dart
+++ b/packages/googleai_dart/test/generate_content_test.dart
@@ -1,4 +1,3 @@
-// ignore_for_file: avoid_redundant_argument_values
 @TestOn('vm')
 library; // Uses dart:io
 
diff --git a/packages/googleai_dart/test/model_info_test.dart b/packages/googleai_dart/test/model_info_test.dart
index 791d2cdc..89aa00f0 100644
--- a/packages/googleai_dart/test/model_info_test.dart
+++ b/packages/googleai_dart/test/model_info_test.dart
@@ -1,4 +1,3 @@
-// ignore_for_file: avoid_redundant_argument_values
 @TestOn('vm')
 library; // Uses dart:io
 
diff --git a/packages/googleai_dart/test/stream_generate_content_test.dart b/packages/googleai_dart/test/stream_generate_content_test.dart
new file mode 100644
index 00000000..d361b23c
--- /dev/null
+++ b/packages/googleai_dart/test/stream_generate_content_test.dart
@@ -0,0 +1,228 @@
+@TestOn('vm')
+library; // Uses dart:io
+
+import 'dart:convert';
+import 'dart:io';
+
+import 'package:googleai_dart/googleai_dart.dart';
+import 'package:test/test.dart';
+
+void main() {
+  group('Google AI stream generate content API tests', () {
+    late GoogleAIClient client;
+
+    setUp(() async {
+      client = GoogleAIClient(
+        apiKey: Platform.environment['GOOGLEAI_API_KEY'],
+      );
+    });
+
+    tearDown(() {
+      client.endSession();
+    });
+
+    test('Test Text-only input with gemini-pro', () async {
+      final stream = client.streamGenerateContent(
+        modelId: 'gemini-pro',
+        request: const GenerateContentRequest(
+          contents: [
+            Content(
+              parts: [
+                Part(
+                  text: 'List the numbers from 1 to 100 in order '
+                      'without any spaces, commas or additional explanations.',
+                ),
+              ],
+            ),
+          ],
+          generationConfig: GenerationConfig(
+            temperature: 0,
+          ),
+        ),
+      );
+
+      var text = '';
+      await for (final res in stream) {
+        expect(res.promptFeedback?.blockReason, isNull);
+        expect(res.candidates, isNotEmpty);
+        final candidate = res.candidates!.first;
+        expect(candidate.index, 0);
+        expect(candidate.finishReason, CandidateFinishReason.stop);
+        expect(candidate.content, isNotNull);
+        final content = candidate.content!;
+        expect(content.role, 'model');
+        expect(content.parts, hasLength(1));
+        text += content.parts!.first.text ?? '';
+      }
+
+      expect(
+        text.replaceAll(RegExp(r'[\s\n]'), ''),
+        contains('123456789'),
+      );
+    });
+
+    test('Text-and-image input with gemini-pro-vision', () async {
+      final stream = client.streamGenerateContent(
+        modelId: 'gemini-pro-vision',
+        request: GenerateContentRequest(
+          contents: [
+            Content(
+              parts: [
+                const Part(
+                  text: 'What is this picture? Be detailed. '
+                      'List all the elements that you see.',
+                ),
+                Part(
+                  inlineData: Blob(
+                    mimeType: 'image/png',
+                    data: base64.encode(
+                      await File('./test/assets/1.png').readAsBytes(),
+                    ),
+                  ),
+                ),
+              ],
+            ),
+          ],
+        ),
+      );
+
+      var text = '';
+      await for (final res in stream) {
+        expect(res.promptFeedback?.blockReason, isNull);
+        expect(res.candidates, isNotEmpty);
+        final candidate = res.candidates!.first;
+        expect(candidate.index, 0);
+        expect(candidate.finishReason, CandidateFinishReason.stop);
+        expect(candidate.content, isNotNull);
+        final content = candidate.content!;
+        expect(content.role, 'model');
+        expect(content.parts, hasLength(1));
+        final part = content.parts!.first;
+        text += ' ${part.text!}';
+      }
+
+      expect(
+        text,
+        anyOf(
+          contains('coffee'),
+          contains('blueberries'),
+          contains('cookies'),
+        ),
+      );
+    });
+
+    test('Test stop sequence', () async {
+      final stream = client.streamGenerateContent(
+        modelId: 'gemini-pro',
+        request: const GenerateContentRequest(
+          contents: [
+            Content(
+              parts: [
+                Part(
+                  text: 'List the numbers from 1 to 9 in order '
+                      'without any spaces, commas or additional explanations.',
+                ),
+              ],
+            ),
+          ],
+          generationConfig: GenerationConfig(
+            stopSequences: ['4'],
+          ),
+        ),
+      );
+
+      var text = '';
+      await for (final res in stream) {
+        expect(res.candidates, isNotEmpty);
+        final candidate = res.candidates!.first;
+        expect(candidate.content, isNotNull);
+        final content = candidate.content!;
+        text +=
+            content.parts!.first.text?.replaceAll(RegExp(r'[\s\n]'), '') ?? '';
+      }
+
+      expect(text, contains('123'));
+      expect(text, isNot(contains('456789')));
+    });
+
+    test('Test max tokens', () async {
+      final stream = client.streamGenerateContent(
+        modelId: 'gemini-pro',
+        request: const GenerateContentRequest(
+          contents: [
+            Content(
+              parts: [
+                Part(text: 'Tell me a joke'),
+              ],
+            ),
+          ],
+          generationConfig: GenerationConfig(
+            maxOutputTokens: 2,
+          ),
+        ),
+      );
+
+      await for (final res in stream) {
+        expect(res.candidates, isNotEmpty);
+        final candidate = res.candidates!.first;
+        expect(candidate.finishReason, CandidateFinishReason.maxTokens);
+      }
+    });
+
+    test('Test Multi-turn conversations with gemini-pro', () async {
+      final stream = client.streamGenerateContent(
+        modelId: 'gemini-pro',
+        request: const GenerateContentRequest(
+          contents: [
+            Content(
+              role: 'user',
+              parts: [
+                Part(
+                  text: 'List the numbers from 1 to 9 in order '
+                      'without any spaces, commas or additional explanations.',
+                ),
+              ],
+            ),
+            Content(
+              role: 'model',
+              parts: [
+                Part(
+                  text: '123456789',
+                ),
+              ],
+            ),
+            Content(
+              role: 'user',
+              parts: [
+                Part(
+                  text: 'Remove the number 4 from the list',
+                ),
+              ],
+            ),
+          ],
+        ),
+      );
+
+      var text = '';
+
+      await for (final res in stream) {
+        expect(res.promptFeedback?.blockReason, isNull);
+        expect(res.candidates, isNotEmpty);
+        final candidate = res.candidates!.first;
+        expect(candidate.index, 0);
+        expect(candidate.finishReason, CandidateFinishReason.stop);
+        expect(candidate.content, isNotNull);
+        final content = candidate.content!;
+        expect(content.role, 'model');
+        expect(content.parts, hasLength(1));
+        final part = content.parts!.first;
+        text += ' ${part.text!}';
+      }
+
+      expect(
+        text,
+        contains('12356789'),
+      );
+    });
+  });
+}
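The `_GoogleAIStreamTransformer` added by this patch relies on the server-sent-events framing the API uses when queried with `alt=sse`: each event arrives as a line of the form `data: <json>`, so decoding UTF-8, splitting on newlines, dropping the blank separator lines, and stripping the six-character `data: ` prefix yields one JSON document per event. The sketch below feeds an invented sample payload through the same transform chain; only the chain itself mirrors the patch, the bytes are made up for illustration.

```dart
import 'dart:convert';

Future<void> main() async {
  // Invented sample of the SSE byte stream returned when 'alt=sse' is set;
  // real responses carry full GenerateContentResponse JSON objects.
  final sse = Stream<List<int>>.fromIterable([
    utf8.encode('data: {"candidates":[{"content":{"parts":'
        '[{"text":"Once"}],"role":"model"},"index":0}]}\n\n'),
    utf8.encode('data: {"candidates":[{"content":{"parts":'
        '[{"text":" upon a time"}],"role":"model"},"index":0}]}\n\n'),
  ]);

  // Same chain as _GoogleAIStreamTransformer: decode UTF-8, split lines,
  // drop the blank separators, strip the 6-character 'data: ' prefix.
  final events = sse
      .transform(utf8.decoder)
      .transform(const LineSplitter())
      .where((final s) => s.isNotEmpty)
      .map((final s) => s.substring(6));

  await for (final event in events) {
    final decoded = jsonDecode(event) as Map<String, dynamic>;
    print(decoded['candidates']);
  }
}
```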